"""Gradio demo for Sa2VA (image and video grounded understanding).

The script loads the ``ByteDance/Sa2VA-4B`` checkpoint once at import time and
exposes two tabs: one that answers a prompt about a single image and one that
answers a prompt about a video.  When the model's answer contains a ``[SEG]``
token and ``mmengine`` is installed, the predicted masks are drawn on top of
the input and returned next to the text answer.
"""
import os
import tempfile
from typing import List, Optional, Tuple

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

from third_parts import VideoReader

try:
    from mmengine.visualization import Visualizer
except ImportError:
    Visualizer = None
    print("Warning: mmengine is not installed, visualization is disabled.")

# Placeholder that marks where the visual tokens go inside the text prompt.
# NOTE(review): taken from the Sa2VA model-card examples; the original comment
# in this file mentioned "a placeholder for the video frames" whose token was
# missing from the text -- confirm against the checkpoint's remote code.
IMAGE_TOKEN = "<image>"

# Token in the model answer that signals segmentation masks were produced.
SEG_TOKEN = "[SEG]"

# Keep one frame out of every VIDEO_FRAME_INTERVAL when sampling a video.
VIDEO_FRAME_INTERVAL = 6

# Frame rate used for the output clip when the source rate cannot be read.
DEFAULT_FPS = 30.0

# Load the model and tokenizer
model_path = "ByteDance/Sa2VA-4B"

# NOTE(review): `device_map="auto"` already places the weights on a device;
# the trailing `.cuda()` is kept exactly as in the original script -- confirm
# whether it is still required.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True,
).eval().cuda()

tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True,
)


def _build_prompt(prompt: Optional[str]) -> str:
    """Return *prompt* with the image placeholder prepended if it is absent."""
    text = prompt or ""
    if IMAGE_TOKEN in text:
        return text
    return f"{IMAGE_TOKEN}{text}"


def _source_fps(video_path: str, default: float = DEFAULT_FPS) -> float:
    """Read the frame rate of *video_path*, falling back to *default*."""
    capture = cv2.VideoCapture(video_path)
    try:
        fps = capture.get(cv2.CAP_PROP_FPS)
    finally:
        capture.release()
    # The comparison is also False for NaN, so garbage values use the default.
    return fps if fps and fps > 0 else default


def _frames_to_video(frame_paths: List[str], output_path: str, fps: float) -> str:
    """Encode the images at *frame_paths* into an mp4 file at *output_path*.

    The size of the first frame defines the video size; any frame with a
    different size is resized to match because the writer only accepts frames
    of the size it was opened with.

    Raises:
        FileNotFoundError: if the first frame cannot be read.
    """
    first_frame = cv2.imread(frame_paths[0])
    if first_frame is None:
        raise FileNotFoundError(f"Cannot read frame: {frame_paths[0]}")
    height, width = first_frame.shape[:2]

    writer = cv2.VideoWriter(
        output_path,
        cv2.VideoWriter_fourcc(*"mp4v"),
        fps,
        (width, height),
    )
    try:
        for frame_path in frame_paths:
            frame = cv2.imread(frame_path)
            if frame is None:
                continue
            if frame.shape[:2] != (height, width):
                frame = cv2.resize(frame, (width, height))
            writer.write(frame)
    finally:
        writer.release()
    return output_path


def read_video(video_path, video_interval):
    """Sample frames from a video and save each one as a JPEG.

    Args:
        video_path: Path of the video file to read.
        video_interval: Step between kept frames (every N-th frame is used).

    Returns:
        A ``(frames, image_paths)`` tuple: the sampled frames as PIL images
        and the paths of the JPEG files they were written to, in frame order.
    """
    raw_frames = VideoReader(video_path)[::video_interval]
    # The directory is intentionally not removed here: callers read the saved
    # frames back after this function returns.
    temp_dir = tempfile.mkdtemp()

    frames = []
    image_paths = []
    for frame_idx, raw_frame in enumerate(raw_frames):
        # BGR (opencv system) to RGB (numpy system)
        frame_image = Image.fromarray(raw_frame[..., ::-1])
        image_path = os.path.join(temp_dir, f"frame_{frame_idx:04d}.jpg")
        frame_image.save(image_path, format="JPEG")
        frames.append(frame_image)
        image_paths.append(image_path)

    return frames, image_paths


def visualize(pred_mask, image_path, work_dir):
    """Draw *pred_mask* over the image at *image_path* and save the result.

    Args:
        pred_mask: Binary mask(s) handed to ``Visualizer.draw_binary_masks``.
        image_path: Path of the image to annotate.
        work_dir: Existing directory that receives the annotated image.

    Returns:
        Path of the annotated image; it keeps the input's file name.

    Raises:
        FileNotFoundError: if the image cannot be read by OpenCV.
    """
    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f"Cannot read image: {image_path}")

    visualizer = Visualizer()
    visualizer.set_image(img)
    visualizer.draw_binary_masks(pred_mask, colors='g', alphas=0.4)
    visual_result = visualizer.get_image()

    output_path = os.path.join(work_dir, os.path.basename(image_path))
    cv2.imwrite(output_path, visual_result)
    return output_path


def image_vision(image_input_path, prompt):
    """Answer *prompt* about one image and optionally render its masks.

    Args:
        image_input_path: Path of the uploaded image.
        prompt: Instruction typed by the user.

    Returns:
        ``(answer, segmentation_path)``; the second item is ``None`` unless
        the answer contains ``[SEG]`` and visualization is available.
    """
    if not image_input_path:
        raise gr.Error("Please upload an image first.")

    image = Image.open(image_input_path).convert('RGB')
    input_dict = {
        'image': image,
        'text': _build_prompt(prompt),
        'past_text': '',
        'mask_prompts': None,
        'tokenizer': tokenizer,
    }
    return_dict = model.predict_forward(**input_dict)

    answer = return_dict["prediction"]  # the text format answer
    print(answer)

    if SEG_TOKEN not in answer or Visualizer is None:
        return answer, None

    masks = return_dict["prediction_masks"]
    if not masks:
        return answer, None

    # Only the masks of the first [SEG] token are visualized.
    temp_dir = tempfile.mkdtemp()
    seg_result = visualize(masks[0], image_input_path, temp_dir)
    return answer, seg_result


def video_vision(video_input_path, prompt):
    """Answer *prompt* about a video and optionally render its masks.

    The video is sampled every ``VIDEO_FRAME_INTERVAL`` frames.  When masks
    are predicted, each sampled frame is annotated and the annotated frames
    are encoded into a single mp4 so the result can be shown in a video
    component.

    Args:
        video_input_path: Path of the uploaded video.
        prompt: Instruction typed by the user.

    Returns:
        ``(answer, video_path)``; the second item is ``None`` unless the
        answer contains ``[SEG]`` and visualization is available.
    """
    if not video_input_path:
        raise gr.Error("Please upload a video first.")

    vid_frames, image_paths = read_video(
        video_input_path, video_interval=VIDEO_FRAME_INTERVAL
    )
    if not vid_frames:
        raise gr.Error("No frames could be read from the video.")

    result = model.predict_forward(
        video=vid_frames,
        text=_build_prompt(prompt),
        tokenizer=tokenizer,
    )
    prediction = result['prediction']
    print(prediction)

    if SEG_TOKEN not in prediction or Visualizer is None:
        return prediction, None

    masks = result['prediction_masks']
    if not masks:
        return prediction, None

    # Only the masks of the first [SEG] token are visualized.
    pred_masks = masks[0]

    # One working directory per request instead of one per frame.
    work_dir = tempfile.mkdtemp()
    seg_frames = [
        visualize(pred_mask, image_path, work_dir)
        for pred_mask, image_path in zip(pred_masks, image_paths)
    ]
    if not seg_frames:
        return prediction, None

    # Frames were subsampled, so slow the clip down by the same factor to
    # keep roughly the original playback duration.
    fps = max(_source_fps(video_input_path) / VIDEO_FRAME_INTERVAL, 1.0)
    output_video = _frames_to_video(
        seg_frames, os.path.join(work_dir, "segmentation.mp4"), fps
    )
    return prediction, output_video


# Gradio UI
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Sa2VA: Marrying SAM2 with LLaVA for Dense Grounded Understanding of Images and Videos")

        with gr.Tab("Single Image"):
            with gr.Row():
                with gr.Column():
                    image_input = gr.Image(label="Image IN", type="filepath")
                    with gr.Row():
                        instruction = gr.Textbox(label="Instruction", scale=4)
                        submit_image_btn = gr.Button("Submit", scale=1)
                with gr.Column():
                    output_res = gr.Textbox(label="Response")
                    output_image = gr.Image(label="Segmentation", type="numpy")

            submit_image_btn.click(
                fn=image_vision,
                inputs=[image_input, instruction],
                outputs=[output_res, output_image],
            )

        with gr.Tab("Video"):
            with gr.Row():
                with gr.Column():
                    video_input = gr.Video(label="Video IN")
                    with gr.Row():
                        vid_instruction = gr.Textbox(label="Instruction", scale=4)
                        submit_video_btn = gr.Button("Submit", scale=1)
                with gr.Column():
                    vid_output_res = gr.Textbox(label="Response")
                    output_video = gr.Video(label="Segmentation")

            submit_video_btn.click(
                fn=video_vision,
                inputs=[video_input, vid_instruction],
                outputs=[vid_output_res, output_video],
            )

demo.queue().launch(show_api=False, show_error=True)