fffiloni committed · verified
Commit 5672cc2 · Parent: c0784bd

Update app.py

Files changed (1):
  1. app.py +57 -22
app.py CHANGED
@@ -28,6 +28,16 @@ tokenizer = AutoTokenizer.from_pretrained(
     trust_remote_code = True,
 )
 
+from third_parts import VideoReader
+def read_video(video_path, video_interval):
+    vid_frames = VideoReader(video_path)[::video_interval]
+    for frame_idx in range(len(vid_frames)):
+        frame_image = vid_frames[frame_idx]
+        frame_image = frame_image[..., ::-1] # BGR (opencv system) to RGB (numpy system)
+        frame_image = Image.fromarray(frame_image)
+        vid_frames[frame_idx] = frame_image
+    return vid_frames
+
 def visualize(pred_mask, image_path, work_dir):
     visualizer = Visualizer()
     img = cv2.imread(image_path)
@@ -56,13 +66,6 @@ def image_vision(image_input_path, prompt):
 
     seg_image = return_dict["prediction_masks"]
 
-    return answer, seg_image
-
-def main_infer(image_input_path, prompt):
-
-    answer, seg_image = image_vision(image_input_path, prompt)
-
-
     if '[SEG]' in answer and Visualizer is not None:
         pred_masks = seg_image[0]
         temp_dir = tempfile.mkdtemp()
@@ -73,26 +76,58 @@ def main_infer(image_input_path, prompt):
     else:
         return answer, None
 
+def video_vision(video_input_path, prompt):
+    vid_frames = read_video(video_input_path, video_interval=6)
+    # create a question (<image> is a placeholder for the video frames)
+    question = f"<image>{prompt}"
+    result = model.predict_forward(
+        video=vid_frames,
+        text=question,
+        tokenizer=tokenizer,
+    )
+    prediction = result['prediction']
+    print(prediction)
+
+    return result['prediction'], None
+
+
 
 # Gradio UI
 
 with gr.Blocks() as demo:
     with gr.Column():
         gr.Markdown("# Sa2VA: Marrying SAM2 with LLaVA for Dense Grounded Understanding of Images and Videos")
-        with gr.Row():
-            with gr.Column():
-                image_input = gr.Image(label="Image IN", type="filepath")
-                with gr.Row():
-                    instruction = gr.Textbox(label="Instruction", scale=4)
-                    submit_btn = gr.Button("Submit", scale=1)
-            with gr.Column():
-                output_res = gr.Textbox(label="Response")
-                output_image = gr.Image(label="Segmentation", type="numpy")
-
-        submit_btn.click(
-            fn = main_infer,
-            inputs = [image_input, instruction],
-            outputs = [output_res, output_image]
-        )
+        with gr.Tab("Single Image"):
+            with gr.Row():
+                with gr.Column():
+                    image_input = gr.Image(label="Image IN", type="filepath")
+                    with gr.Row():
+                        instruction = gr.Textbox(label="Instruction", scale=4)
+                        submit_image_btn = gr.Button("Submit", scale=1)
+                with gr.Column():
+                    output_res = gr.Textbox(label="Response")
+                    output_image = gr.Image(label="Segmentation", type="numpy")
+
+            submit_image_btn.click(
+                fn = image_vision,
+                inputs = [image_input, instruction],
+                outputs = [output_res, output_image]
+            )
+        with gr.Tab("Video"):
+            with gr.Row():
+                with gr.Column():
+                    video_input = gr.Image(label="Video IN")
+                    with gr.Row():
+                        vid_instruction = gr.Textbox(label="Instruction", scale=4)
+                        submit_video_btn = gr.Button("Submit", scale=1)
+                with gr.Column():
+                    vid_output_res = gr.Textbox(label="Response")
+                    output_video = gr.Video(label="Segmentation")
+
+            submit_video_btn.click(
+                fn = video_vision,
+                inputs = [video_input, vid_instruction],
+                outputs = [vid_output_res, output_video]
+            )
 
 demo.queue().launch(show_api=False, show_error=True)
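The new read_video helper relies on VideoReader from the Space's bundled third_parts package, whose implementation is not part of this diff. For reference, here is a minimal sketch of the same subsample-and-convert behavior using only cv2 and PIL (both already in use above, per cv2.imread and Image.fromarray); the name read_video_cv2 is hypothetical, and the premise that third_parts.VideoReader yields BGR numpy frames the way cv2 does is an assumption:

import cv2
from PIL import Image

def read_video_cv2(video_path, video_interval):
    # Sketch only, not the Space's actual loader.
    cap = cv2.VideoCapture(video_path)
    frames, idx = [], 0
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        if idx % video_interval == 0:
            # cv2 decodes to BGR; flip the channel axis to get RGB for PIL
            frames.append(Image.fromarray(frame[..., ::-1]))
        idx += 1
    cap.release()
    return frames

Sampling every sixth frame (video_interval=6, as video_vision passes) presumably bounds how many frames reach model.predict_forward.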
 
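Note that video_vision returns None for the output_video slot, so the Video tab currently surfaces only the text response. Two loose ends stand out: the tab's input is declared as gr.Image(label="Video IN"), where gr.Video would hand video_vision an actual video file path, and the predicted masks are never rendered. If predict_forward also exposes a "prediction_masks" entry for video input, as it does on the image path, the masks could be overlaid per frame and assembled into a clip. A rough sketch; masks_to_video is hypothetical, and the per-frame boolean-mask shape is an assumption, not something this commit establishes:

import os
import tempfile
import cv2
import numpy as np

def masks_to_video(vid_frames, pred_masks, fps=5):
    # Hypothetical helper, not in the commit: overlay one HxW boolean mask per
    # sampled frame (vid_frames are PIL RGB images, as read_video returns) and
    # write an mp4 that a gr.Video output can display.
    out_path = os.path.join(tempfile.mkdtemp(), "seg.mp4")
    writer = None
    for frame, mask in zip(vid_frames, pred_masks):
        img = np.array(frame)[..., ::-1].copy()    # RGB back to BGR for OpenCV
        overlay = img.copy()
        overlay[mask.astype(bool)] = (0, 0, 255)   # paint masked pixels red (BGR)
        img = cv2.addWeighted(img, 0.5, overlay, 0.5, 0.0)
        if writer is None:
            h, w = img.shape[:2]
            writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
        writer.write(img)
    if writer is not None:
        writer.release()
    return out_path

Under those assumptions, video_vision could then end with return result['prediction'], masks_to_video(vid_frames, result['prediction_masks']) instead of returning None.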