Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -28,6 +28,16 @@ tokenizer = AutoTokenizer.from_pretrained(
     trust_remote_code = True,
 )

+from third_parts import VideoReader
+def read_video(video_path, video_interval):
+    vid_frames = VideoReader(video_path)[::video_interval]
+    for frame_idx in range(len(vid_frames)):
+        frame_image = vid_frames[frame_idx]
+        frame_image = frame_image[..., ::-1] # BGR (opencv system) to RGB (numpy system)
+        frame_image = Image.fromarray(frame_image)
+        vid_frames[frame_idx] = frame_image
+    return vid_frames
+
 def visualize(pred_mask, image_path, work_dir):
     visualizer = Visualizer()
     img = cv2.imread(image_path)
@@ -56,13 +66,6 @@ def image_vision(image_input_path, prompt):

     seg_image = return_dict["prediction_masks"]

-    return answer, seg_image
-
-def main_infer(image_input_path, prompt):
-
-    answer, seg_image = image_vision(image_input_path, prompt)
-
-
     if '[SEG]' in answer and Visualizer is not None:
         pred_masks = seg_image[0]
         temp_dir = tempfile.mkdtemp()
@@ -73,26 +76,58 @@ def main_infer(image_input_path, prompt):
     else:
         return answer, None

+def video_vision(video_input_path, prompt):
+    vid_frames = read_video(video_input_path, video_interval=6)
+    # create a question (<image> is a placeholder for the video frames)
+    question = f"<image>{prompt}"
+    result = model.predict_forward(
+        video=vid_frames,
+        text=question,
+        tokenizer=tokenizer,
+    )
+    prediction = result['prediction']
+    print(prediction)
+
+    return result['prediction'], None
+
+

 # Gradio UI

 with gr.Blocks() as demo:
     with gr.Column():
         gr.Markdown("# Sa2VA: Marrying SAM2 with LLaVA for Dense Grounded Understanding of Images and Videos")
-        with gr.
-        with gr.
-
-
-
-
-
-
-
-
-
-
-
-
-
+        with gr.Tab("Single Image"):
+            with gr.Row():
+                with gr.Column():
+                    image_input = gr.Image(label="Image IN", type="filepath")
+                    with gr.Row():
+                        instruction = gr.Textbox(label="Instruction", scale=4)
+                        submit_image_btn = gr.Button("Submit", scale=1)
+                with gr.Column():
+                    output_res = gr.Textbox(label="Response")
+                    output_image = gr.Image(label="Segmentation", type="numpy")
+
+            submit_image_btn.click(
+                fn = image_vision,
+                inputs = [image_input, instruction],
+                outputs = [output_res, output_image]
+            )
+        with gr.Tab("Video"):
+            with gr.Row():
+                with gr.Column():
+                    video_input = gr.Image(label="Video IN")
+                    with gr.Row():
+                        vid_instruction = gr.Textbox(label="Instruction", scale=4)
+                        submit_video_btn = gr.Button("Submit", scale=1)
+                with gr.Column():
+                    vid_output_res = gr.Textbox(label="Response")
+                    output_video = gr.Video(label="Segmentation")
+
+            submit_video_btn.click(
+                fn = video_vision,
+                inputs = [video_input, vid_instruction],
+                outputs = [vid_output_res, output_video]
+            )

 demo.queue().launch(show_api=False, show_error=True)
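The new read_video helper uses the repo's third_parts.VideoReader to sample every Nth frame, flip BGR to RGB, and wrap each frame as a PIL image for model.predict_forward. As a rough stand-in for readers without third_parts available, the same sampling can be sketched with plain OpenCV; the function name read_video_cv2 below is hypothetical and not part of the Space's code.

import cv2
from PIL import Image

# Illustrative sketch only: approximates what read_video does, using cv2.VideoCapture
# instead of the third_parts.VideoReader that the Space actually imports.
def read_video_cv2(video_path, video_interval=6):
    cap = cv2.VideoCapture(video_path)
    frames = []
    idx = 0
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        if idx % video_interval == 0:
            # OpenCV decodes frames as BGR; convert to RGB before wrapping as PIL,
            # mirroring the frame_image[..., ::-1] step in the diff.
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        idx += 1
    cap.release()
    return frames

Either way, the result is the list of PIL frames that video_vision passes to model.predict_forward as video=vid_frames.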