fffiloni committed on
Commit
2a274cc
·
verified ·
1 Parent(s): 4140fc1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -2
app.py CHANGED
@@ -3,8 +3,16 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
3
  from PIL import Image
4
  import numpy as np
5
  import os
 
6
  import gradio as gr
7
 
 
 
 
 
 
 
 
8
  # Load the model and tokenizer
9
  model_path = "ByteDance/Sa2VA-4B"
10
 
@@ -20,6 +28,17 @@ tokenizer = AutoTokenizer.from_pretrained(
20
  trust_remote_code = True,
21
  )
22
 
 
 
 
 
 
 
 
 
 
 
 
23
  def image_vision(image_input_path, prompt):
24
  image_path = image_input_path
25
  text_prompts = f"<image>{prompt}"
@@ -34,6 +53,7 @@ def image_vision(image_input_path, prompt):
34
  return_dict = model.predict_forward(**input_dict)
35
  print(return_dict)
36
  answer = return_dict["prediction"] # the text format answer
 
37
  seg_image = return_dict["prediction_masks"]
38
 
39
  return answer, seg_image
@@ -41,7 +61,15 @@ def image_vision(image_input_path, prompt):
41
  def main_infer(image_input_path, prompt):
42
 
43
  answer, seg_image = image_vision(image_input_path, prompt)
44
- return answer, seg_image[0]
 
 
 
 
 
 
 
 
45
 
46
  # Gradio UI
47
 
@@ -56,7 +84,7 @@ with gr.Blocks() as demo:
56
  submit_btn = gr.Button("Submit", scale=1)
57
  with gr.Column():
58
  output_res = gr.Textbox(label="Response")
59
- output_image = gr.Image(label="Segmentation")
60
 
61
  submit_btn.click(
62
  fn = main_infer,
 
3
  from PIL import Image
4
  import numpy as np
5
  import os
6
+ import tempfile
7
  import gradio as gr
8
 
9
+ import cv2
10
+ try:
11
+ from mmengine.visualization import Visualizer
12
+ except ImportError:
13
+ Visualizer = None
14
+ print("Warning: mmengine is not installed, visualization is disabled.")
15
+
16
  # Load the model and tokenizer
17
  model_path = "ByteDance/Sa2VA-4B"
18
 
 
28
  trust_remote_code = True,
29
  )
30
 
def visualize(pred_mask, image_path, work_dir):
    """Overlay a binary mask on an image and write the result to disk.

    Args:
        pred_mask: Mask passed unchanged to
            ``Visualizer.draw_binary_masks`` (drawn in green, alpha 0.4).
        image_path: Path of the image to load with ``cv2.imread``.
        work_dir: Directory in which the overlay image is written.

    Returns:
        The path of the written image: ``work_dir`` joined with the base
        name of ``image_path``.

    Raises:
        RuntimeError: If mmengine's ``Visualizer`` could not be imported,
            or if the overlay could not be written.
        FileNotFoundError: If ``image_path`` cannot be read as an image.
    """
    # The module-level import falls back to None when mmengine is missing;
    # fail with a clear message instead of "NoneType is not callable".
    if Visualizer is None:
        raise RuntimeError(
            "mmengine is not installed, visualization is disabled."
        )

    img = cv2.imread(image_path)
    # cv2.imread reports failure by returning None rather than raising.
    if img is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")

    visualizer = Visualizer()
    visualizer.set_image(img)
    visualizer.draw_binary_masks(pred_mask, colors='g', alphas=0.4)
    visual_result = visualizer.get_image()

    output_path = os.path.join(work_dir, os.path.basename(image_path))
    # cv2.imwrite returns False when the file could not be saved.
    if not cv2.imwrite(output_path, visual_result):
        raise RuntimeError(f"Could not write visualization: {output_path}")
    return output_path
41
+
42
  def image_vision(image_input_path, prompt):
43
  image_path = image_input_path
44
  text_prompts = f"<image>{prompt}"
 
53
  return_dict = model.predict_forward(**input_dict)
54
  print(return_dict)
55
  answer = return_dict["prediction"] # the text format answer
56
+
57
  seg_image = return_dict["prediction_masks"]
58
 
59
  return answer, seg_image
 
def main_infer(image_input_path, prompt):
    """Run the model on an image/prompt pair and render any segmentation.

    Args:
        image_input_path: Path of the input image.
        prompt: User prompt forwarded to ``image_vision``.

    Returns:
        A ``(answer, seg_result)`` tuple. ``answer`` is the text returned by
        ``image_vision``. ``seg_result`` is the path of the mask overlay
        written by ``visualize``, or ``None`` when the answer contains no
        ``[SEG]`` marker, no mask was returned, or mmengine is unavailable.
    """
    answer, seg_image = image_vision(image_input_path, prompt)

    # Default to "no image" so the return below is always defined; the
    # previous version raised UnboundLocalError for answers without [SEG].
    seg_result = None

    has_masks = seg_image is not None and len(seg_image) > 0
    if '[SEG]' in answer and Visualizer is not None and has_masks:
        # Index only once we know masks exist.
        # NOTE(review): presumably seg_image[0] holds the masks for the
        # first [SEG] target and [0] selects its first mask - confirm
        # against the model's predict_forward output.
        pred_masks = seg_image[0]
        pred_mask = pred_masks[0]

        # mkdtemp already creates the directory. It is not removed here
        # because the returned path must still exist after this returns.
        temp_dir = tempfile.mkdtemp()
        seg_result = visualize(pred_mask, image_input_path, temp_dir)

    return answer, seg_result
73
 
74
  # Gradio UI
75
 
 
84
  submit_btn = gr.Button("Submit", scale=1)
85
  with gr.Column():
86
  output_res = gr.Textbox(label="Response")
87
+ output_image = gr.Image(label="Segmentation", type="numpy")
88
 
89
  submit_btn.click(
90
  fn = main_infer,