# Blip_QA / app.py
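
"""Gradio demo for BLIP visual question answering (VQA).

Loads a local BLIP VQA checkpoint and serves a web UI where the user uploads
an image, asks a question about it, and gets a generated answer back.
"""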
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
from pathlib import Path
from blip_vqa import blip_vqa
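
# Pick the compute device (GPU when available) and the input resolution
# this BLIP checkpoint expects.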
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
image_size = 384


class App:
    def __init__(self):
        self.selected_model = 0
        # Load BLIP for question answering from the checkpoint shipped
        # alongside this script
        print("Loading BLIP for question answering")
        model_path = str(Path(__file__).parent / 'blip_vqa.pth')
        self.qa_model = blip_vqa(pretrained=model_path, image_size=image_size, vit='base')
        self.qa_model.eval()
        self.qa_model = self.qa_model.to(device)
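        # Build the Gradio UI: an image input alongside a Question/Answer tab
        # holding the question textbox, the "Ask" button, and the answer label.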
        with gr.Blocks() as demo:
            gr.Markdown(
                "# BLIP Image Question Answering\n"
                "This model lets you ask questions about an image and get answers.\n"
                "It can be used to caption images for Stable Diffusion fine-tuning or many other applications.\n"
                "Brought to Gradio by @ParisNeo from the original BLIP code on GitHub: [https://github.com/salesforce/BLIP](https://github.com/salesforce/BLIP)\n"
                "The model is described in this paper: [https://arxiv.org/abs/2201.12086](https://arxiv.org/abs/2201.12086)"
            )
            with gr.Row():
                self.image_source = gr.Image(type="numpy", label="Image")
                with gr.Tabs():
                    with gr.Tab("Question/Answer"):
                        self.question = gr.Textbox(label="Custom question (if applicable)", value="Describe this image")
                        self.answer = gr.Button("Ask")
                        self.lbl_caption = gr.Label(label="Caption")
                        self.answer.click(self.answer_question_image, [self.image_source, self.question], self.lbl_caption)
        # Launch the interface
        demo.launch()

    def answer_question_image(self, img, custom_question="Describe this image"):
        # Preprocess the image the way BLIP expects: bicubic resize to
        # 384x384, tensor conversion, then per-channel normalization.
        preprocess = transforms.Compose([
            transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        ])
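        # The mean/std above are the CLIP-style normalization statistics
        # used throughout the BLIP codebase for its image encoder.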
        img = preprocess(Image.fromarray(img.astype('uint8'), 'RGB'))
        # Run the model in generation mode; no gradients are needed at inference
        with torch.no_grad():
            answers = self.qa_model(img.unsqueeze(0).to(device), custom_question, train=False, inference='generate')
        # The model returns one generated answer per image in the batch
        return answers[0]
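
# Instantiating App loads the model, builds the UI, and starts the server via
# demo.launch(); Gradio serves it at http://127.0.0.1:7860 by default.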
if __name__ == "__main__":
    app = App()