ariG23498's picture
ariG23498 HF staff
fix: typo
ff9b8ad
import gradio as gr
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import semantic_search
# Blog-post titles; get_titles looks up the "title" column by corpus index.
title_dataset = load_dataset("pyimagesearch/blog-title", data_files="bp-title.csv")

# Precomputed title embeddings shipped as a CSV in the same dataset repo.
# They are converted CSV -> pandas -> numpy -> float tensor so they can be
# passed to semantic_search as the corpus.
# NOTE(review): presumably one row per title, in the same order as
# bp-title.csv, and produced by the same model as below -- confirm.
_embeddings_csv = load_dataset("pyimagesearch/blog-title", data_files="embeddings.csv")
_embeddings_array = _embeddings_csv["train"].to_pandas().to_numpy()
title_embeddings = torch.from_numpy(_embeddings_array).float()

# Sentence encoder used to embed incoming queries.
model = SentenceTransformer("paraphrase-MiniLM-L6-v2")
# Static text shown in the Gradio interface.
title = "Title Semantic Search"
description = (
    "Provide a blog post title, and we'll find the most similar titles from our already written blog posts."
)

# Sample queries offered to the user as clickable examples.
examples = [
    "Introduction to Keras",
    "Conditional GANs with Keras",
    "A Gentle Introduction to PyTorch with Deep Learning",
]
def get_titles(query: str) -> dict:
    """Return the blog titles most similar to ``query`` with their scores.

    The query is encoded with the module-level ``model`` and compared against
    the precomputed ``title_embeddings`` via ``semantic_search``; the top five
    hits are mapped back to title strings from ``title_dataset``.

    Args:
        query: Title text to search for.

    Returns:
        A dict mapping each matched title to its similarity score. If several
        hits share the same title text, the score of the first (best-ranked)
        hit is kept.
    """
    query_embed = model.encode(query)
    # semantic_search returns one hit list per query; a single query is
    # passed here, so take the first (only) list.
    hits = semantic_search(query_embed, title_embeddings, top_k=5)[0]

    corpus = title_dataset["train"]
    titles = {}
    for hit in hits:
        # Index the row first, then the column: corpus["title"][i] would
        # read the entire title column again for every hit.
        selected_title = corpus[hit["corpus_id"]]["title"]
        # semantic_search documents its hits as sorted by decreasing score,
        # so setdefault keeps the best score when a title text repeats
        # instead of overwriting it with a lower one.
        titles.setdefault(selected_title, hit["score"])
    return titles
# Build the demo UI: a single textbox feeds get_titles, and the title -> score
# mapping it returns is shown in a Label limited to the five top classes.
query_box = gr.Textbox(label="Input Title")
ranked_titles = gr.Label(num_top_classes=5)

space = gr.Interface(
    fn=get_titles,
    inputs=query_box,
    outputs=ranked_titles,
    title=title,
    description=description,
    examples=examples,
)

# Start the app as soon as the script runs.
space.launch()