fffiloni's picture
Migrated from GitHub
d59f323 verified
raw
history blame
3.21 kB
import argparse
import os
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer
import cv2
try:
from mmengine.visualization import Visualizer
except ImportError:
Visualizer = None
print("Warning: mmengine is not installed, visualization is disabled.")
def parse_args(argv=None):
    """Parse the command-line options of the demo.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            script entry point relies on.

    Returns:
        argparse.Namespace: holds ``image_folder``, ``model_path``,
        ``work_dir``, ``text`` and ``select``.
    """
    parser = argparse.ArgumentParser(description='Video Reasoning Segmentation')
    # The positional argument is listed with os.listdir by the script, so it
    # has to be a directory of frames, not a single image file.
    parser.add_argument(
        'image_folder',
        help='Path to the folder containing the frame images.')
    parser.add_argument(
        '--model_path',
        default="ByteDance/Sa2VA-8B",
        help='Model name or path handed to from_pretrained.')
    parser.add_argument(
        '--work-dir',
        default=None,
        help='The dir to save results.')
    parser.add_argument(
        '--text',
        type=str,
        default="<image>Please describe the video content.",
        help='Text prompt sent to the model.')
    parser.add_argument(
        '--select',
        type=int,
        default=-1,
        help='1-based index of a single frame to process as an image; '
             'values below 1 process all frames as a video.')
    return parser.parse_args(argv)
def visualize(pred_mask, image_path, work_dir):
    """Overlay a binary mask on an image and save the result in *work_dir*.

    The output file keeps the base name of *image_path*. *work_dir* is not
    created here; the caller must make sure it exists.

    Args:
        pred_mask: Binary mask handed to ``Visualizer.draw_binary_masks``.
            # assumes it matches the image's height/width -- TODO confirm
        image_path: Path of the image to draw on.
        work_dir: Directory the rendered image is written to.

    Raises:
        RuntimeError: If mmengine's ``Visualizer`` could not be imported.
        OSError: If the image cannot be read or the result cannot be written.
    """
    if Visualizer is None:
        # The module-level import fallback sets Visualizer to None; calling
        # it would otherwise fail with an opaque TypeError.
        raise RuntimeError(
            "mmengine is not installed, cannot visualize masks.")

    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError(f"Could not read image: {image_path}")

    # NOTE(review): the frame keeps cv2's channel order from imread through
    # imwrite; no colour conversion is applied in between.
    visualizer = Visualizer()
    visualizer.set_image(img)
    visualizer.draw_binary_masks(pred_mask, colors='g', alphas=0.4)
    visual_result = visualizer.get_image()

    output_path = os.path.join(work_dir, os.path.basename(image_path))
    if not cv2.imwrite(output_path, visual_result):
        # cv2.imwrite reports failure through its boolean return value.
        raise OSError(f"Could not write visualization: {output_path}")
# File extensions (lower-case, with dot) that are treated as video frames.
_IMAGE_EXTENSIONS = frozenset(
    {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"})

# Output directory used when --work-dir is not given.
_DEFAULT_OUTPUT_DIR = './temp_visualize_results'


def _list_frame_paths(folder):
    """Return the paths of the image files in *folder*, sorted by file name."""
    return [
        os.path.join(folder, name)
        for name in sorted(os.listdir(folder))
        if os.path.splitext(name)[1].lower() in _IMAGE_EXTENSIONS
    ]


def _load_frame(path):
    """Open the image at *path* with PIL and return it converted to RGB."""
    return Image.open(path).convert('RGB')


def main():
    """Run the model on a folder of frames and save mask overlays.

    With ``--select N`` (N >= 1) only the N-th frame (1-based, in sorted file
    name order) is sent to the model as a single image; otherwise all frames
    are sent as a video. When the answer contains ``[SEG]`` and mmengine is
    available, the masks of the first ``[SEG]`` are drawn onto the processed
    frame(s) and written to ``--work-dir`` (or ``./temp_visualize_results``).

    Raises:
        SystemExit: If the folder holds no image files or ``--select`` is
            larger than the number of frames.
    """
    cfg = parse_args()

    # Validate the inputs before the (expensive) model load.
    frame_paths = _list_frame_paths(cfg.image_folder)
    if not frame_paths:
        raise SystemExit(f"No image files found in {cfg.image_folder!r}.")
    single_frame = cfg.select > 0
    if single_frame and cfg.select > len(frame_paths):
        raise SystemExit(
            f"--select {cfg.select} is out of range: "
            f"only {len(frame_paths)} frame(s) found.")

    model_path = cfg.model_path
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True
    )

    if single_frame:
        # Only the selected frame is decoded and processed.
        target_paths = [frame_paths[cfg.select - 1]]
        print(f"Selected frame {cfg.select}")
        print(f"The input is:\n{cfg.text}")
        result = model.predict_forward(
            image=_load_frame(target_paths[0]),
            text=cfg.text,
            tokenizer=tokenizer,
        )
    else:
        target_paths = frame_paths
        print(f"The input is:\n{cfg.text}")
        result = model.predict_forward(
            video=[_load_frame(path) for path in frame_paths],
            text=cfg.text,
            tokenizer=tokenizer,
        )

    prediction = result['prediction']
    print(f"The output is:\n{prediction}")

    if '[SEG]' not in prediction or Visualizer is None:
        return

    # Only the masks belonging to the first [SEG] token are visualized.
    pred_masks = result['prediction_masks'][0]

    output_dir = cfg.work_dir or _DEFAULT_OUTPUT_DIR
    os.makedirs(output_dir, exist_ok=True)

    # Pair each mask with the frame that was actually sent to the model. In
    # single-frame mode that is just the selected frame, not every file in
    # the folder.
    # NOTE(review): presumably the single-image result stores its mask at
    # index 0 of this stack -- confirm against the model's remote code.
    for mask_idx, frame_path in enumerate(target_paths):
        visualize(pred_masks[mask_idx], frame_path, output_dir)


if __name__ == "__main__":
    main()