# Hugging Face Space status when this file was captured: Runtime error
import gradio as gr | |
import torch | |
from datasets import load_dataset | |
from sentence_transformers import SentenceTransformer | |
from sentence_transformers.util import semantic_search | |
# Corpus of existing blog-post titles; get_titles reads its "title" column
# from the "train" split.
title_dataset = load_dataset("pyimagesearch/blog-title", data_files="bp-title.csv")

# Precomputed title embeddings, loaded through pandas into a float32 tensor
# (torch.float) that semantic_search can score queries against.
# NOTE(review): assumes row i of embeddings.csv is the embedding of row i of
# bp-title.csv, and that it was produced by the same encoder loaded below —
# confirm, otherwise search results map to the wrong titles.
title_embeddings = torch.from_numpy(
    load_dataset("pyimagesearch/blog-title", data_files="embeddings.csv")["train"]
    .to_pandas()
    .to_numpy()
).to(torch.float)

# Sentence encoder used to embed each incoming query at request time.
model = SentenceTransformer("paraphrase-MiniLM-L6-v2")
# Text shown in the Gradio interface header.
title = "Title Semantic Search"
description = (
    "Provide a blog post title, and we'll find the most similar titles "
    "from our already written blog posts."
)

# Sample queries passed to the interface as clickable examples.
examples = [
    "Introduction to Keras",
    "Conditional GANs with Keras",
    "A Gentle Introduction to PyTorch with Deep Learning",
]
def get_titles(query: str, top_k: int = 5) -> dict:
    """Return the existing blog titles most similar to *query*.

    The query is embedded with the module-level ``model`` and compared
    against the precomputed ``title_embeddings`` using
    ``sentence_transformers.util.semantic_search``.

    Args:
        query: Free-text blog post title entered by the user.
        top_k: Maximum number of matches to return. Defaults to 5, the value
            this function previously hard-coded.

    Returns:
        A dict mapping each matched title to its similarity score. This is
        the ``{label: confidence}`` shape that ``gr.Label`` renders. A blank
        or whitespace-only query yields an empty dict.
    """
    # An empty query has no meaningful nearest neighbours; don't pretend otherwise.
    if not query or not query.strip():
        return {}

    query_embed = model.encode(query)
    # semantic_search returns one hit list per query; a single query was sent.
    hits = semantic_search(query_embed, title_embeddings, top_k=top_k)[0]

    # Index individual rows rather than re-reading the whole "title" column
    # for every hit.
    # NOTE(review): relies on corpus_id (a row of title_embeddings) matching
    # the same row in the title split — confirm the two CSVs are aligned.
    corpus = title_dataset["train"]

    titles = {}
    for hit in hits:
        # Hits come back sorted by decreasing score, so setdefault keeps the
        # best score if the same title appears more than once in the corpus
        # (plain assignment would let a worse duplicate overwrite it).
        titles.setdefault(corpus[hit["corpus_id"]]["title"], hit["score"])
    return titles
# Gradio UI: a single text input feeding get_titles, whose {title: score}
# dict is rendered by a Label component as ranked matches.
space = gr.Interface(
    fn=get_titles,
    inputs=gr.Textbox(label="Input Title"),
    # Show up to 5 classes, the number of hits get_titles returns by default (top_k=5).
    outputs=gr.Label(num_top_classes=5),
    title=title,
    description=description,
    examples=examples,
)

# Only start the server when run as a script (e.g. `python app.py`), not when
# the module is imported.
if __name__ == "__main__":
    space.launch()