import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
from pathlib import Path

from blip_vqa import blip_vqa

# Run on the GPU when available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# The BLIP VQA (ViT-base) checkpoint expects 384x384 input images
image_size = 384

class App:
    def __init__(self):
        self.selected_model = 0

        print("Loading BLIP for question answering")
        # Load the BLIP VQA checkpoint stored next to this script
        model_path = str(Path(__file__).parent / 'blip_vqa.pth')
        self.qa_model = blip_vqa(pretrained=model_path, image_size=image_size, vit='base')
        self.qa_model.eval()
        self.qa_model = self.qa_model.to(device)

        # Build the Gradio Blocks UI
        with gr.Blocks() as demo:
            gr.Markdown(
                "# BLIP image question answering\n"
                "This demo lets you ask questions about an image and get answers.\n"
                "It can be used to caption images for Stable Diffusion fine-tuning or for many other applications.\n"
                "Brought to Gradio by @ParisNeo, based on the original BLIP code: [https://github.com/salesforce/BLIP](https://github.com/salesforce/BLIP)\n"
                "The model is described in this paper: [https://arxiv.org/abs/2201.12086](https://arxiv.org/abs/2201.12086)"
            )
            with gr.Row():
                self.image_source = gr.Image(shape=(448, 448))
                with gr.Tabs():
                    with gr.Tab("Question/Answer"):
                        self.question = gr.Textbox(label="Custom question (if applicable)", value="Describe this image")
                        self.answer = gr.Button("Ask")
                        self.lbl_caption = gr.Label(label="Caption")
                        # Run VQA on the uploaded image when the button is clicked
                        self.answer.click(self.answer_question_image, [self.image_source, self.question], self.lbl_caption)

        demo.launch()

    def answer_question_image(self, img, custom_question="Describe this image"):
        # Convert the numpy image coming from Gradio into the normalized tensor BLIP expects
        preprocess = transforms.Compose([
            transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            # CLIP/BLIP image normalization statistics
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        ])
        img = preprocess(Image.fromarray(img.astype('uint8'), 'RGB'))

        # Generate a free-form answer to the question about the image
        with torch.no_grad():
            answer = self.qa_model(img.unsqueeze(0).to(device), custom_question, train=False, inference='generate')

        return answer[0]

app = App()
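
# Usage sketch (assumption: this script sits next to blip_vqa.py and the blip_vqa.pth
# checkpoint from the BLIP repository). Running the script with Python loads the model,
# builds the interface, and launches the Gradio server; open the printed local URL,
# upload an image, type a question, and press "Ask".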