rag-augment / app.py
davidberenstein1957's picture
Update app.py
2b37664 verified
import gradio as gr
from sentence_transformers import CrossEncoder
import pandas as pd
# Hugging Face Hub id of the reranking model. It must be a *cross-encoder*
# checkpoint (trained with a head that scores a (query, passage) pair). A
# bi-encoder embedding model such as "sentence-transformers/all-MiniLM-L12-v2"
# would be given a randomly initialised scoring head by CrossEncoder, making
# the resulting scores meaningless.
RERANKER_MODEL_ID = "cross-encoder/ms-marco-MiniLM-L-12-v2"

# Loaded once at import time and shared by every `rerank` call.
reranker = CrossEncoder(RERANKER_MODEL_ID)


def rerank(query: str, documents: pd.DataFrame) -> pd.DataFrame:
    """Score each document against *query* and return them best-first.

    Args:
        query: The question the documents are ranked against.
        documents: Table with a ``"text"`` column holding one candidate
            passage per row (as produced by the input ``gr.Dataframe``).

    Returns:
        A new DataFrame (the input is not modified) with duplicate, blank and
        missing ``"text"`` rows removed, plus a ``"rank"`` column holding the
        cross-encoder relevance score, sorted by that score in descending
        order. If no non-blank documents remain, the result is empty but
        still has the ``"rank"`` column.
    """
    # drop_duplicates already returns a new frame, so the caller's table is
    # never mutated.
    documents = documents.drop_duplicates("text")

    # Interactive tables often contain empty rows (e.g. the initial blank
    # row); scoring them is pointless, and an empty batch can make
    # CrossEncoder.predict() raise.
    texts = documents["text"]
    non_blank = texts.notna() & (texts.astype(str).str.strip() != "")
    documents = documents[non_blank].copy()

    if documents.empty:
        documents["rank"] = pd.Series(dtype="float64")
        return documents

    # One (query, passage) pair per row; predict returns one score per pair,
    # in the same order as the rows.
    pairs = [[query, text] for text in documents["text"]]
    documents["rank"] = reranker.predict(pairs)
    return documents.sort_values(by="rank", ascending=False)


with gr.Blocks() as demo:
    # The description text is kept unindented inside the string: leading
    # spaces would turn the lines into a Markdown code block.
    gr.Markdown(f"""# RAG - Augment
Applies reranking to the retrieved documents using [{RERANKER_MODEL_ID}](https://huggingface.co/{RERANKER_MODEL_ID}).
Part of [AI blueprint](https://github.com/huggingface/ai-blueprint) - a blueprint for AI development, focusing on practical examples of RAG, information extraction, analysis and fine-tuning in the age of LLMs.""")
    query_input = gr.Textbox(
        label="Query", placeholder="Enter your question here...", lines=3
    )
    # Editable table of candidate passages; its "text" column is what
    # `rerank` reads.
    documents_input = gr.Dataframe(
        label="Documents", headers=["text"], wrap=True, interactive=True
    )
    submit_btn = gr.Button("Submit")
    documents_output = gr.Dataframe(
        label="Reranked documents", headers=["text", "rank"], wrap=True
    )
    submit_btn.click(
        fn=rerank,
        inputs=[query_input, documents_input],
        outputs=[documents_output],
    )

demo.launch()