File size: 3,361 Bytes
cc7fbfd
 
 
 
 
 
 
21232f6
cc7fbfd
 
 
 
 
 
 
390940a
fd219d5
21232f6
cc7fbfd
 
 
 
 
 
 
 
 
 
 
 
 
fd219d5
21232f6
390940a
cc7fbfd
 
390940a
fd219d5
390940a
 
fd219d5
 
390940a
cc7fbfd
 
 
21232f6
cc7fbfd
21232f6
 
 
 
 
 
 
 
 
 
 
 
cc7fbfd
 
21232f6
 
cc7fbfd
 
 
390940a
 
 
cc7fbfd
 
fa46c8a
 
cc7fbfd
 
 
 
 
32e2c7e
cc7fbfd
 
 
 
 
390940a
c97f39d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw
import requests
from transformers import SamModel, SamProcessor
import cv2
from typing import List

# Run on the GPU when PyTorch can see one; fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the SAM ViT-Base checkpoint (moved onto `device`) and its matching
# processor, which handles image/point pre-processing and mask post-processing.
model = SamModel.from_pretrained("facebook/sam-vit-base")
model = model.to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# Image-embedding cache shared by the handlers below: either None or
# [pixel_values, image_embeddings] for the most recently embedded image.
cache_data = None

def mask_2_dots(mask: np.ndarray) -> List[List[List[int]]]:
    """Reduce a brush-stroke mask to one prompt point per drawn blob.

    The mask is converted to grayscale (if it has channels), thresholded at
    127, and morphologically closed with a 5x5 kernel so ragged strokes merge
    into solid blobs. The centroid of each *outer* contour becomes one
    ``[x, y]`` point.

    Args:
        mask: Image where drawn strokes are bright (> 127), expected uint8;
            either 2-D grayscale or 3/4-channel RGB(A).

    Returns:
        ``[points]`` -- the point list for a single image, nested the way
        ``SamProcessor(input_points=...)`` is called with it. ``points`` is
        empty when nothing was drawn.
    """
    gray = mask if mask.ndim == 2 else cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
    _, thresh = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
    kernel = np.ones((5, 5), np.uint8)
    closed = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel)
    # RETR_EXTERNAL: holes inside a stroke (e.g. a drawn ring) must not
    # produce additional prompt points.
    contours, _ = cv2.findContours(closed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    points = []
    for contour in contours:
        moments = cv2.moments(contour)
        if moments["m00"]:
            cx = int(moments["m10"] / moments["m00"])
            cy = int(moments["m01"] / moments["m00"])
        else:
            # Zero-area contour (single pixel or 1-px-wide line): the area
            # moment is 0, so use the mean of the contour's vertices instead.
            cx, cy = (int(v) for v in contour.reshape(-1, 2).mean(axis=0))
        points.append([cx, cy])
    return [points]

@torch.no_grad()
def foward_pass(image_input: np.ndarray, points: List[List[int]]) -> np.ndarray:
    """Run SAM on ``image_input`` prompted by ``points`` and return its masks.

    The image embedding is cached in the module-level ``cache_data`` as
    ``[pixel_values, embedding]`` and reused while the preprocessed pixels are
    unchanged. Repeated prompts on the same image therefore skip
    ``model.get_image_embeddings``.

    Args:
        image_input: Image array accepted by ``PIL.Image.fromarray``.
        points: Prompt points in the nested form produced by ``mask_2_dots``
            (``[[[x, y], ...]]``).

    Returns:
        Array of shape ``(H, W, num_masks)`` with the post-processed masks for
        the single input image.
    """
    global cache_data
    pil_image = Image.fromarray(image_input)
    inputs = processor(pil_image, input_points=points, return_tensors="pt").to(device)
    # The decoder receives precomputed embeddings instead of raw pixels.
    pixel_values = inputs.pop("pixel_values")

    if cache_data is None or not torch.equal(pixel_values, cache_data[0]):
        cache_data = [pixel_values, model.get_image_embeddings(pixel_values)]

    # Call the module itself (not .forward) so registered hooks still run.
    outputs = model(image_embeddings=cache_data[1], **inputs)
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    # masks[0] is (1, num_masks, H, W); drop the leading axis and move the
    # mask axis last -> (H, W, num_masks).
    return masks[0].squeeze(0).numpy().transpose(1, 2, 0)

def _editor_value_to_image_and_mask(inputs) -> tuple:
    """Split an image-editor value into ``(image, stroke_mask)`` arrays.

    Accepts the legacy sketch-tool dict ``{"image": ..., "mask": ...}`` as-is.
    Also accepts the Gradio 4 ``gr.ImageEditor`` dict ``{"background": ...,
    "layers": [...], "composite": ...}``. For that format, strokes from every
    layer are merged (via the alpha channel when present) into a white-on-black
    3-channel uint8 mask that ``mask_2_dots`` can threshold.

    Raises:
        gr.Error: If no image was provided.
    """
    if not inputs:
        raise gr.Error("Please upload an image first.")
    if "image" in inputs and "mask" in inputs:
        return inputs["image"], inputs["mask"]

    image = inputs.get("background")
    if image is None:
        raise gr.Error("Please upload an image first.")
    image = np.asarray(image)

    # NOTE(review): assumes every layer has the same height/width as the
    # background image -- confirm for cropped/resized editor canvases.
    strokes = np.zeros(image.shape[:2], dtype=np.uint8)
    for layer in inputs.get("layers") or []:
        if layer is None:
            continue
        layer = np.asarray(layer)
        if layer.ndim == 3 and layer.shape[2] == 4:
            drawn = layer[:, :, 3]  # alpha marks painted pixels, whatever the brush colour
        elif layer.ndim == 3:
            drawn = layer.max(axis=2)  # any non-black pixel counts as painted
        else:
            drawn = layer
        strokes = np.maximum(strokes, drawn.astype(np.uint8))
    return image, np.dstack([strokes] * 3)


def main_func(inputs) -> List[Image.Image]:
    """Segment the object marked by brush dots and return a preview gallery.

    Args:
        inputs: Value of the image-editor component (see
            ``_editor_value_to_image_and_mask`` for the accepted formats).

    Returns:
        The input image with a red dot drawn at each prompt point, followed by
        one grayscale image per predicted mask.

    Raises:
        gr.Error: If no image was provided or no dots were drawn.
    """
    image_array, stroke_mask = _editor_value_to_image_and_mask(inputs)
    points = mask_2_dots(stroke_mask)
    if not points[0]:
        raise gr.Error("Draw at least one dot on the object you want to segment.")
    masks = foward_pass(image_array, points)

    preview = Image.fromarray(image_array)
    draw = ImageDraw.Draw(preview)
    radius = 10
    for x, y in points[0]:
        draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill="red")

    mask_images = [
        Image.fromarray((masks[:, :, i] * 255).astype(np.uint8))
        for i in range(masks.shape[2])
    ]
    return [preview, *mask_images]

def reset_data():
    """Forget the cached ``[pixel_values, embedding]`` pair.

    After this, the next segmentation request recomputes the image embedding.
    """
    global cache_data
    cache_data = None

# Build the UI. The editor supplies the image plus brush strokes; the gallery
# shows the annotated input followed by the predicted masks.
with gr.Blocks() as demo:
    gr.Markdown("# How to use")
    gr.Markdown(
        "To start, upload an image, then use the brush to draw dots on the object "
        "you want to segment. Don't worry if your dots aren't perfect: the code "
        "finds the middle of each drawn mark. Then press the **Segment Image** "
        "button to create masks for the object the dots are on."
    )
    gr.Markdown("# Demo to run Segment Anything base model")
    gr.Markdown(
        "This app uses the [Segment Anything](https://huggingface.co/facebook/sam-vit-base) "
        "model from Meta to get masks from points in an image."
    )
    with gr.Tab("Segment Image"):
        with gr.Row():
            # main_func works on numpy arrays; make that contract explicit.
            image_input = gr.ImageEditor(type="numpy")
            image_output = gr.Gallery()

        image_button = gr.Button("Segment Image")

    image_button.click(main_func, inputs=image_input, outputs=image_output)
    # Forget the cached image embedding whenever a new image is uploaded.
    image_input.upload(reset_data)

demo.launch()