|
import cv2
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import DistilBertTokenizer
import gradio as gr

from implement import *
import config as CFG
from main import build_loaders
from CLIP import CLIPModel
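
# Gradio demo: retrieve the validation images that best match a text query,
# using a trained CLIP model. Relies on config.py defining CFG.text_tokenizer,
# CFG.device, and CFG.image_path, and on a checkpoint saved as "best.pt".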
|
with gr.Blocks(css="style.css") as demo: |
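    # Precompute an embedding for every validation image once at startup; each
    # incoming text query is then scored against this cached matrix.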
|
    def get_image_embeddings(valid_df, model_path):
        tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
        valid_loader = build_loaders(valid_df, tokenizer, mode="valid")

        # Load the trained checkpoint and switch to inference mode.
        model = CLIPModel().to(CFG.device)
        model.load_state_dict(torch.load(model_path, map_location=CFG.device))
        model.eval()

        # Run every validation image through the image encoder and projection
        # head, stacking the results into a single (num_images, dim) matrix.
        valid_image_embeddings = []
        with torch.no_grad():
            for batch in tqdm(valid_loader):
                image_features = model.image_encoder(batch["image"].to(CFG.device))
                image_embeddings = model.image_projection(image_features)
                valid_image_embeddings.append(image_embeddings)
        return model, torch.cat(valid_image_embeddings)

    _, valid_df = make_train_valid_dfs()
    model, image_embeddings = get_image_embeddings(valid_df, "best.pt")

    def find_matches(query, n=9):
        # Tokenize the query and move the tensors to the model's device.
        tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
        encoded_query = tokenizer([query])
        batch = {
            key: torch.tensor(values).to(CFG.device)
            for key, values in encoded_query.items()
        }
        # Embed the query with the text encoder and projection head.
        with torch.no_grad():
            text_features = model.text_encoder(
                input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
            )
            text_embeddings = model.text_projection(text_features)

        # L2-normalize both sides so the dot product equals cosine similarity.
        image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
        text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
        dot_similarity = text_embeddings_n @ image_embeddings_n.T

        # The dataframe has one row per caption (five per image), so take the
        # top n * 5 hits and keep every fifth index to get n distinct images.
        _, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
        matches = [valid_df["image"].values[idx] for idx in indices[::5]]

        # Read the matched files and convert from OpenCV's BGR to RGB.
        images = []
        for match in matches:
            image = cv2.imread(f"{CFG.image_path}/{match}")
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            images.append(image)
        return images
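
    # Quick sanity check outside the UI (hypothetical query string):
    #   previews = find_matches("a group of people on a beach", n=9)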
|
    with gr.Row():
        textbox = gr.Textbox(label="Enter a query to find matching images using a CLIP model.")
        gallery = gr.Gallery(label="Matching images", columns=3)

    button = gr.Button("Find matches")
    button.click(fn=find_matches, inputs=textbox, outputs=gallery)

demo.launch(share=True)
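# Note: share=True also creates a temporary public gradio.live URL in addition
# to the local server; drop the flag to serve the demo only locally.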
|
|