"""Video reasoning segmentation demo.

Loads a causal-LM checkpoint (default ``ByteDance/Sa2VA-8B``) with
``trust_remote_code=True``, feeds it the image frames found in a folder together
with a text prompt via ``model.predict_forward``, prints the textual answer and,
when the answer contains ``[SEG]`` and mmengine is available, writes the frames
with the predicted masks overlaid to an output directory.
"""
import argparse
import os

from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer
import cv2

try:
    from mmengine.visualization import Visualizer
except ImportError:
    Visualizer = None
    print("Warning: mmengine is not installed, visualization is disabled.")

# File extensions (lower-case, with leading dot) treated as image frames.
IMAGE_EXTENSIONS = frozenset({".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"})

# Output directory used when --work-dir is not given.
DEFAULT_VIS_DIR = './temp_visualize_results'


def parse_args():
    """Parse the command-line arguments of the demo.

    Returns:
        argparse.Namespace with ``image_folder``, ``model_path``, ``work_dir``,
        ``text`` and ``select``.
    """
    parser = argparse.ArgumentParser(description='Video Reasoning Segmentation')
    parser.add_argument(
        'image_folder', help='Path to the folder containing the image frames')
    parser.add_argument('--model_path', default="ByteDance/Sa2VA-8B")
    parser.add_argument(
        '--work-dir', default=None, help='The dir to save results.')
    parser.add_argument(
        '--text', type=str, default="Please describe the video content.")
    parser.add_argument(
        '--select', type=int, default=-1,
        help='1-based index of a single frame to run on as an image; '
             'values <= 0 run on all frames as a video.')
    return parser.parse_args()


def visualize(pred_mask, image_path, work_dir):
    """Overlay ``pred_mask`` on the image at ``image_path`` and save the result.

    The output is written to ``work_dir`` under the same base name as the
    input image. ``work_dir`` must already exist.

    Args:
        pred_mask: Binary mask passed to ``Visualizer.draw_binary_masks``.
        image_path: Path of the frame the mask belongs to.
        work_dir: Directory the rendered image is written to.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals an unreadable/unsupported file (e.g. .gif) by
        # returning None instead of raising; skip rather than crash later.
        print(f"Warning: could not read {image_path}, skipping visualization.")
        return
    visualizer = Visualizer()
    visualizer.set_image(img)
    visualizer.draw_binary_masks(pred_mask, colors='g', alphas=0.4)
    visual_result = visualizer.get_image()
    output_path = os.path.join(work_dir, os.path.basename(image_path))
    cv2.imwrite(output_path, visual_result)


def _list_image_paths(image_folder):
    """Return the paths of the image files in ``image_folder``, sorted by name."""
    return [
        os.path.join(image_folder, filename)
        for filename in sorted(os.listdir(image_folder))
        if os.path.splitext(filename)[1].lower() in IMAGE_EXTENSIONS
    ]


def _load_frames(image_paths):
    """Load every path as an RGB PIL image, closing each file after decoding."""
    frames = []
    for path in image_paths:
        # convert() returns a new in-memory image, so the file handle opened by
        # Image.open can be released immediately.
        with Image.open(path) as img:
            frames.append(img.convert('RGB'))
    return frames


def main():
    """Run the demo: load the model, run inference, optionally save overlays."""
    cfg = parse_args()

    # Validate the inputs before the (expensive) model load so that user
    # errors fail fast with a readable message.
    image_paths = _list_image_paths(cfg.image_folder)
    if not image_paths:
        raise SystemExit(f"No image files found in {cfg.image_folder}")
    if cfg.select > len(image_paths):
        raise SystemExit(
            f"--select {cfg.select} is out of range: "
            f"only {len(image_paths)} frame(s) found in {cfg.image_folder}")

    model_path = cfg.model_path
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True
    )

    if cfg.select > 0:
        # Single-frame (image) mode: --select is 1-based.
        frame_paths = [image_paths[cfg.select - 1]]
        img_frame = _load_frames(frame_paths)[0]
        print(f"Selected frame {cfg.select}")
        print(f"The input is:\n{cfg.text}")
        result = model.predict_forward(
            image=img_frame,
            text=cfg.text,
            tokenizer=tokenizer,
        )
    else:
        # Video mode: all frames, in file-name order.
        frame_paths = image_paths
        vid_frames = _load_frames(frame_paths)
        print(f"The input is:\n{cfg.text}")
        result = model.predict_forward(
            video=vid_frames,
            text=cfg.text,
            tokenizer=tokenizer,
        )

    prediction = result['prediction']
    print(f"The output is:\n{prediction}")

    if '[SEG]' not in prediction or Visualizer is None:
        return

    # Only the masks of the first [SEG] entry are visualized.
    _seg_idx = 0
    pred_masks = result['prediction_masks'][_seg_idx]

    out_dir = cfg.work_dir or DEFAULT_VIS_DIR
    os.makedirs(out_dir, exist_ok=True)

    # NOTE(review): assumes pred_masks holds one mask per frame that was fed to
    # the model, in input order (a single mask in image mode) — confirm against
    # the model's predict_forward implementation.
    # Pairing masks with the paths that were actually used keeps the selected
    # frame's mask from being drawn on the first frame of the folder.
    for pred_mask, image_path in zip(pred_masks, frame_paths):
        visualize(pred_mask, image_path, out_dir)


if __name__ == "__main__":
    main()