|
import torch |
|
import gradio as gr |
|
import re |
|
import cv2 |
|
from PIL import ImageDraw, Image |
|
|
|
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration |
|
|
|
# Hugging Face Hub id of the PaliGemma "mix" checkpoint served by this demo.
mix_model_id = "google/paligemma-3b-mix-224"

# Load the processor (tokenizer + image preprocessing) and the model once, at
# import time, so every Gradio request reuses the same instances.
mix_processor = AutoProcessor.from_pretrained(mix_model_id)
mix_model = PaliGemmaForConditionalGeneration.from_pretrained(
    mix_model_id,
)
|
|
|
|
|
def parse_multiple_locations(decoded_output, num_bins=1024):
    """Extract labelled bounding boxes from PaliGemma detection output.

    Each detection is four ``<locNNNN>`` tokens, in the order
    ``y_min, x_min, y_max, x_max``, followed by the object label. Multiple
    detections are separated by ``;``. For example::

        "<loc0256><loc0512><loc0768><loc1023> traffic light ; <loc0000>... cat"

    Args:
        decoded_output: Text decoded from the model's generated tokens.
        num_bins: Number of location bins each axis is quantized into.
            PaliGemma uses 1024 (tokens ``<loc0000>`` .. ``<loc1023>``), so
            dividing by 1024 maps a bin index back to a fraction of the image.

    Returns:
        A list of dicts, one per detection, in output order. Each dict has
        ``'label'`` (str, may contain spaces) and ``'bbox'``
        (``[y1, x1, y2, x2]`` as fractions of image height/width).
        The list is empty when no detections are found.
    """
    # The label is everything after the four loc tokens, up to the next ';'
    # separator, the next special token ('<'), or the end of the line. This
    # keeps multi-word labels such as "traffic light" intact.
    loc_pattern = r"<loc(\d{4})><loc(\d{4})><loc(\d{4})><loc(\d{4})>\s*([^;<\r\n]+)"

    coords_and_labels = []
    for *locs, raw_label in re.findall(loc_pattern, decoded_output):
        label = raw_label.strip()
        if not label:
            # A box with no label is not a usable detection.
            continue
        y1, x1, y2, x2 = (int(loc) / num_bins for loc in locs)
        coords_and_labels.append({
            'label': label,
            'bbox': [y1, x1, y2, x2],
        })

    return coords_and_labels
|
|
|
|
|
def draw_multiple_bounding_boxes(image, coords_and_labels, color="red", line_width=3):
    """Draw labelled bounding boxes on an RGB copy of *image*.

    Args:
        image: A PIL image. It is NOT modified; drawing happens on a copy
            converted to RGB, which is also the mode the model is fed.
        coords_and_labels: Iterable of dicts as produced by
            ``parse_multiple_locations``: ``'label'`` (str) and ``'bbox'``
            (``[y1, x1, y2, x2]`` as fractions of image height/width).
        color: Outline and text color passed to ``ImageDraw``.
        line_width: Rectangle outline width in pixels.

    Returns:
        A new RGB PIL image with every box and its label drawn on it.
    """
    # Image.convert() returns a new image, even when the mode is already
    # RGB, so the caller's image is never mutated.
    annotated = image.convert("RGB")
    draw = ImageDraw.Draw(annotated)
    width, height = annotated.size

    for obj in coords_and_labels:
        y1, x1, y2, x2 = obj['bbox'][:4]
        # Scale to pixels and order each pair. A malformed prediction with
        # min > max would otherwise make ImageDraw.rectangle raise ValueError.
        top, bottom = sorted((y1 * height, y2 * height))
        left, right = sorted((x1 * width, x2 * width))

        draw.rectangle([left, top, right, bottom], outline=color, width=line_width)
        draw.text((left, top), obj['label'], fill=color)

    return annotated
|
|
|
|
|
def process_image(image, prompt):
    """Run PaliGemma on *image* with *prompt* and render any detections.

    The Gradio interface declares two outputs (an image and a textbox), so
    every return path yields an ``(image, text)`` pair.

    Args:
        image: PIL image from the Gradio ``Image(type="pil")`` input, or
            ``None`` when nothing was uploaded.
        prompt: Text prompt. Bounding boxes are only produced for
            detection-style prompts; presumably "detect <object>" for the
            PaliGemma mix model — confirm against the model card.

    Returns:
        ``(annotated_image, "Label: ..., Coordinates: ...")`` lines when
        detections are found. Otherwise ``(image, message)``, where the
        message includes the model's raw text answer. On an IndexError during
        generation, ``(None, error message)``.
    """
    if image is None:
        return None, "Please upload an image."

    # Work on an RGB copy: the model expects RGB, and drawing on the copy
    # keeps the user's input image untouched.
    rgb_image = image.convert("RGB")

    # Keyword arguments: the positional order of PaliGemmaProcessor.__call__
    # (text vs. images) has differed between transformers versions.
    inputs = mix_processor(text=prompt, images=rgb_image, return_tensors="pt")
    input_len = inputs["input_ids"].shape[-1]

    try:
        output = mix_model.generate(**inputs, max_new_tokens=100)
        # generate() returns prompt + completion; decode only the new tokens.
        decoded_output = mix_processor.decode(output[0][input_len:], skip_special_tokens=True)
    except IndexError as e:
        print(f"IndexError: {e}")
        return None, "An error occurred during processing."

    coords_and_labels = parse_multiple_locations(decoded_output)

    if not coords_and_labels:
        # Still surface the model's answer (e.g. for captioning/VQA prompts).
        return rgb_image, f"No bounding boxes detected.\nModel output: {decoded_output.strip()}"

    image_with_boxes = draw_multiple_bounding_boxes(rgb_image, coords_and_labels)
    labels_and_coords = "\n".join(
        f"Label: {obj['label']}, Coordinates: {obj['bbox']}" for obj in coords_and_labels
    )
    return image_with_boxes, labels_and_coords
|
|
|
|
|
# Gradio components, created in the order they appear in the UI: an uploaded
# image plus a free-text prompt go in; the annotated image and a text
# summary of the detections come out.
image_input = gr.Image(type="pil")
prompt_input = gr.Textbox(label="Prompt", placeholder="Enter your question")
annotated_output = gr.Image(label="Output Image with Bounding Boxes")
detections_output = gr.Textbox(label="Bounding Box Coordinates and Labels")

inputs = [image_input, prompt_input]
outputs = [annotated_output, detections_output]

interface_text = dict(
    title="Object Detection with Mix PaliGemma Model",
    description="Upload an image and get object detections with bounding boxes and labels.",
)

demo = gr.Interface(fn=process_image, inputs=inputs, outputs=outputs, **interface_text)

# Start the web server. This runs at import time, as in the original script.
demo.launch(debug=True)