# NOTE: removed page-capture artifacts ("Spaces:" / "Runtime error") that
# were not part of the source file.
'''
Save SAM automatic-mask predictions to per-image pickle files.
'''
import argparse
import json
import os
import pickle

import cv2
import numpy as np
import torch
import torch.multiprocessing as mp
from tqdm import tqdm

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
# Per-split [image directory, panoptic annotation json] pairs, both relative
# to the DETECTRON2_DATASETS root. An empty annotation entry means the split
# is enumerated by listing its image directory instead of reading a json.
img_anno = {
    'ade20k_val': [
        'ADEChallengeData2016/images/validation',
        'ADEChallengeData2016/ade20k_panoptic_val.json',
    ],
    'pc_val': ['pascal_ctx_d2/images/validation', ''],
    'pas_val': ['pascal_voc_d2/images/validation', ''],
}
# Checkpoint file for each supported SAM backbone variant.
sam_checkpoint_dict = dict(
    vit_b='pretrained_checkpoint/sam_vit_b_01ec64.pth',
    vit_h='pretrained_checkpoint/sam_vit_h_4b8939.pth',
    vit_l='pretrained_checkpoint/sam_vit_l_0b3195.pth',
    vit_t='pretrained_checkpoint/mobile_sam.pt',
)
def process_images(args, gpu, data_chunk, save_path, if_parallel):
    """Generate SAM automatic-mask predictions for a chunk of images and
    pickle them, one ``<image_id>.pkl`` file per image, under *save_path*.

    Args:
        args: parsed CLI namespace; uses ``args.sam_model`` and ``args.data_name``.
        gpu: CUDA device index this worker binds to (ignored when
            ``if_parallel`` is True).
        data_chunk: iterable of image records — COCO-style dicts carrying
            ``'coco_url'`` or ``'file_name'``, or bare filename strings.
        save_path: output directory; images whose pickle already exists
            are skipped, so the job is resumable.
        if_parallel: when True, wrap the model with ``DataParallel`` instead
            of pinning it to a single device.

    NOTE(review): relies on the module-level globals ``sam_checkpoint_dict``,
    ``img_anno`` and on ``dataset_path`` being set by the ``__main__`` section;
    this works with fork-based multiprocessing but would fail under the
    'spawn' start method — TODO confirm.
    """

    def _build_model(parallel):
        # Load the requested SAM backbone and move it onto the GPU(s).
        checkpoint = sam_checkpoint_dict[args.sam_model]
        model = sam_model_registry[args.sam_model](checkpoint=checkpoint)
        if not parallel:
            torch.cuda.set_device(gpu)
            model = model.cuda()
        else:
            # NOTE(review): the DataParallel wrapper is immediately unwrapped
            # via ``.module``, so this branch is equivalent to a plain
            # ``.cuda()``; kept as-is to preserve the original behavior.
            model = torch.nn.DataParallel(model.cuda()).module
        return model

    sam = _build_model(if_parallel)
    mask_generator = SamAutomaticMaskGenerator(
        model=sam,
        pred_iou_thresh=0.8,
        stability_score_thresh=0.7,
        crop_n_layers=0,
        crop_n_points_downscale_factor=2,
        min_mask_region_area=100,
        output_mode='coco_rle',
    )

    # The image directory is the same for every record, so compute it once.
    # (Bug fix: the original only assigned ``file_path`` on the 'file_name'
    # and plain-string branches, so 'coco_url' records raised NameError.)
    file_path = os.path.join(dataset_path, img_anno[args.data_name][0])

    for image_info in tqdm(data_chunk):
        # Normalize the record to a '.jpg' file name.
        if isinstance(image_info, dict):
            if 'coco_url' in image_info:
                file_name = image_info['coco_url'].split('/')[-1].split('.')[0] + '.jpg'
            elif 'file_name' in image_info:
                file_name = image_info['file_name'].split('.')[0] + '.jpg'
        else:
            assert isinstance(image_info, str)
            file_name = image_info.split('.')[0] + '.jpg'
        image_path = f'{file_path}/{file_name}'
        try:
            # Flatten any sub-directory separators so the pickle name is flat.
            image_id = file_name.split('.')[0].replace('/', '_')
            savepath = f'{save_path}/{image_id}.pkl'
            if not os.path.exists(savepath):
                image = cv2.imread(image_path)
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                # NOTE(review): stock SamAutomaticMaskGenerator.generate()
                # takes only the image; the ``train`` kwarg presumably comes
                # from a project fork — confirm against the local package.
                everything_mask = mask_generator.generate(image, train=False)
                # Keep at most the 50 largest masks to bound pickle size.
                everything_mask = sorted(
                    everything_mask, key=lambda m: m['area'], reverse=True
                )[:50]
                with open(savepath, 'wb') as f:
                    pickle.dump(everything_mask, f)
        except Exception as e:
            # Best-effort: log and continue so one bad image does not kill
            # the whole worker process.
            print(f"Failed to load or convert image at {image_path}. Error: {e}")
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_name', type=str, default='pas_val')
    parser.add_argument('--sam_model', type=str, default='vit_h')
    argss = parser.parse_args()

    gpus = os.getenv("CUDA_VISIBLE_DEVICES", "")
    # ``dataset_path`` is read as a global by process_images (fork start method).
    dataset_path = os.getenv("DETECTRON2_DATASETS", "/users/cx_xchen/DATASETS/")
    num_gpus = len([x.strip() for x in gpus.split(",") if x.strip().isdigit()])
    # Bug fix: np.array_split(..., 0) raises when CUDA_VISIBLE_DEVICES is
    # unset or empty, so always run at least one worker.
    num_gpus = max(num_gpus, 1)
    print(f"Using {num_gpus} GPUs")

    # Build the work list: either the 'images' entries of the panoptic json,
    # or a plain directory listing when the dataset has no annotation file.
    anno_rel_path = img_anno[argss.data_name][1]
    if anno_rel_path != '':
        json_file_path = os.path.join(dataset_path, anno_rel_path)
        with open(json_file_path, 'r') as file:
            data = json.load(file)
        data_chunks = np.array_split(data['images'], num_gpus)
    else:
        image_dir = os.path.join(dataset_path, img_anno[argss.data_name][0])
        data_chunks = np.array_split(os.listdir(image_dir), num_gpus)

    save_path = f'output/SAM_masks_pred/{argss.sam_model}_{argss.data_name}'
    os.makedirs(save_path, exist_ok=True)

    # One worker process per visible GPU; each worker owns one chunk.
    processes = []
    for gpu in range(num_gpus):
        p = mp.Process(
            target=process_images,
            args=(argss, gpu, data_chunks[gpu], save_path, False),
        )
        p.start()
        processes.append(p)
    for p in processes:
        p.join()