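"""Gradio demo: text-to-image search with a trained CLIP model.

A text query is embedded with the model's text encoder and compared against
precomputed embeddings of the validation images; the best match is displayed.
"""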
import gradio as gr
import cv2
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import DistilBertTokenizer

from implement import *
import config as CFG
from main import build_loaders
from CLIP import CLIPModel


def get_image_embeddings(valid_df, model_path):
    """Load the trained CLIP model and precompute an embedding for every validation image."""
    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
    valid_loader = build_loaders(valid_df, tokenizer, mode="valid")

    model = CLIPModel().to(CFG.device)
    model.load_state_dict(torch.load(model_path, map_location=CFG.device))
    model.eval()

    valid_image_embeddings = []
    with torch.no_grad():
        for batch in tqdm(valid_loader):
            image_features = model.image_encoder(batch["image"].to(CFG.device))
            image_embeddings = model.image_projection(image_features)
            valid_image_embeddings.append(image_embeddings)
    return model, torch.cat(valid_image_embeddings)
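

# Build the validation split and precompute its image embeddings once at startup;
# find_matches() reuses them for every query.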
_, valid_df = make_train_valid_dfs()
model, image_embeddings = get_image_embeddings(valid_df, "best.pt")


def find_matches(query, n=9):
    """Embed a text query and return the validation image that matches it best."""
    tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
    encoded_query = tokenizer([query])
    batch = {
        key: torch.tensor(values).to(CFG.device)
        for key, values in encoded_query.items()
    }
    with torch.no_grad():
        text_features = model.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        text_embeddings = model.text_projection(text_features)
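
    # `text_embeddings` now lives in the same joint embedding space as the
    # precomputed image embeddings, so the two can be compared directly.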
    image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
    text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
    dot_similarity = text_embeddings_n @ image_embeddings_n.T
    _, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
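    # valid_df appears to contain several caption rows per image, so stepping
    # through the ranked indices in strides of 5 avoids duplicate images.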
    matches = [valid_df["image"].values[idx] for idx in indices[::5]]

    # The interface displays a single image, so return only the top match,
    # converted from OpenCV's BGR order to RGB.
    image = cv2.imread(f"{CFG.image_path}/{matches[0]}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return image


with gr.Blocks(css="style.css") as demo:
    with gr.Row():
        textbox = gr.Textbox(label="Enter a query to find matching images using a CLIP model.")
        image = gr.Image(type="numpy")
    button = gr.Button("Press")
    button.click(fn=find_matches, inputs=textbox, outputs=image)

# Launch the app; share=True also creates a public Gradio link.
demo.launch(share=True)