# Hugging Face Space: zero-shot object detection (OWLv2) + SAM segmentation demo.
import io
import random

import numpy as np
import torch
from PIL import Image

import gradio as gr
import spaces
from transformers import Owlv2ForObjectDetection, Owlv2Processor, SamModel, SamProcessor
def apply_colored_masks_on_image(image, masks):
    """Overlay each mask on *image* as a semi-transparent, randomly colored layer.

    Args:
        image: ``PIL.Image`` or array-like (H, W, 3) uint8 RGB image.
        masks: tensor of shape (num_masks, ..., H, W) of binary masks.
            # assumes SAM post-processed masks — TODO confirm upstream shape
    Returns:
        RGBA ``PIL.Image`` with every mask composited onto the input.
    """
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image.astype('uint8'), 'RGB')
    image_rgba = image.convert("RGBA")
    for i in range(masks.shape[0]):
        mask = masks[i].squeeze().cpu().numpy()
        # Scale the mask to half opacity (128) so the underlying image stays
        # visible.  The original code created the overlay with alpha=128 but
        # then replaced that alpha channel wholesale via putalpha(mask*255),
        # which made masked regions fully opaque.
        alpha = Image.fromarray((mask * 128).astype(np.uint8), 'L')
        color = tuple(random.randint(0, 255) for _ in range(3)) + (128,)
        colored_mask = Image.new("RGBA", image.size, color)
        colored_mask.putalpha(alpha)
        image_rgba = Image.alpha_composite(image_rgba, colored_mask)
    return image_rgba
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Zero-shot object detector (OWLv2) and its processor; weights are fetched
# from the Hugging Face hub on first run.
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble").to(device)
processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
# Segment Anything (SAM) model used to turn the detected boxes into masks.
model_sam = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
processor_sam = SamProcessor.from_pretrained("facebook/sam-vit-huge")
def query_image(img, text_queries, score_threshold=0.5):
    """Detect *text_queries* in *img* with OWLv2, then segment the hits with SAM.

    Args:
        img: ndarray (H, W, 3) RGB image from the Gradio ``Image`` component.
        text_queries: comma-separated label names, e.g. ``"cat, dog"``.
        score_threshold: minimum detection confidence for a box to be kept.

    Returns:
        Tuple of (image, list of (box, label) pairs) suitable for
        ``gr.AnnotatedImage``.  When nothing passes the threshold the
        original image is returned unsegmented.
    """
    # Strip whitespace so "cat, dog" yields clean labels; drop empty entries.
    text_queries = [q.strip() for q in text_queries.split(",") if q.strip()]
    # OWLv2 pads inputs to a square; post-process against the padded size so
    # box coordinates map back onto the original pixels.
    size = max(img.shape[:2])
    target_sizes = torch.Tensor([[size, size]])
    inputs = processor(text=text_queries, images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        model_outputs = model(**inputs)
    # Move raw outputs to the CPU before post-processing.
    model_outputs.logits = model_outputs.logits.cpu()
    model_outputs.pred_boxes = model_outputs.pred_boxes.cpu()
    results = processor.post_process_object_detection(outputs=model_outputs, target_sizes=target_sizes)
    boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
    img_pil = Image.fromarray(img.astype('uint8'), 'RGB')
    result_labels = []
    result_boxes = []
    for box, score, label in zip(boxes, scores, labels):
        if score >= score_threshold:
            box = [int(i) for i in box.tolist()]
            result_labels.append((box, text_queries[label.item()]))
            result_boxes.append(box)
    # Guard: SAM cannot be prompted with an empty box list — skip segmentation.
    if not result_boxes:
        return img_pil, result_labels
    sam_image = generate_image_with_sam(np.array(img_pil), result_boxes)
    return sam_image, result_labels
def generate_image_with_sam(img, input_boxes):
    """Segment *img* with SAM using the given bounding boxes as prompts.

    Args:
        img: ndarray (H, W, 3) uint8 RGB image.
        input_boxes: list of ``[x0, y0, x1, y1]`` pixel boxes, one per object.

    Returns:
        RGBA ``PIL.Image`` with one semi-transparent colored mask per box.
    """
    img_pil = Image.fromarray(img.astype('uint8'), 'RGB')
    # Run the heavy vision encoder once and cache its embeddings.
    inputs = processor_sam(img_pil, return_tensors="pt").to(device)
    image_embeddings = model_sam.get_image_embeddings(inputs["pixel_values"])
    # Re-run the processor with box prompts, then swap the raw pixels for the
    # cached embeddings so the encoder is not invoked a second time.
    # (The original had a no-op `inputs["input_boxes"].shape` statement here.)
    inputs = processor_sam(img_pil, input_boxes=[input_boxes], return_tensors="pt").to(device)
    inputs.pop("pixel_values", None)
    inputs.update({"image_embeddings": image_embeddings})
    with torch.no_grad():
        outputs = model_sam(**inputs, multimask_output=False)
    masks = processor_sam.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    return apply_colored_masks_on_image(img_pil, masks[0])
# Short blurb; note the Interface below passes its own description string,
# so this constant is currently unused.
description = "\nSplit anythings\n"
# Gradio UI: image + comma-separated queries + confidence slider in,
# annotated/segmented image out.  launch() starts the web server.
demo = gr.Interface(
    fn=query_image,
    inputs=[gr.Image(), gr.Textbox(label="Query Text"), gr.Slider(0, 1, value=0.1, label="Score Threshold")],
    outputs=gr.AnnotatedImage(),
    title="Zero-Shot Object Detection SV3",
    description="This interface demonstrates object detection using zero-shot object detection and SAM for image segmentation.",
    examples=[
        ["images/dark_cell.png", "gray cells", 0.1],
        ["images/animals.png", "Rabbit,Squirrel,Parrot,Hedgehog,Turtle,Ladybug,Chick,Frog,Butterfly,Snail,Mouse", 0.35],
    ],
)
demo.launch()