"""Gradio demo for Segment Any Anomaly (SAA).

Running this script installs the bundled GroundingDINO and SAM packages and
the remaining Python dependencies, fetches the model checkpoints into
``weights/``, builds the SAA model on CPU and launches a Gradio interface.
"""
import os
import sys

# --- Environment bootstrap ---------------------------------------------------
# The third-party imports further down are placed after these installs, so the
# bootstrap has to stay ahead of them.
os.chdir('GroundingDINO/')
os.system('pip install -e .')
os.chdir('../SAM')
os.system('pip install -e .')
os.system('pip install opencv-python pycocotools matplotlib onnxruntime onnx ipykernel gradio loguru transformers timm addict yapf loguru tqdm scikit-image scikit-learn pandas tensorboard seaborn open_clip_torch einops')
os.system('pip install torch==1.10.0 torchvision==0.11.1 -f https://download.pytorch.org/whl/cu113/torch_stable.html')
os.chdir('..')

# exist_ok: restarting with an existing weights directory must not crash.
os.makedirs('weights', exist_ok=True)
os.chdir('./weights')
for weight_url in (
    'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
    'https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth',
):
    # Only fetch checkpoints that are not on disk yet, to avoid re-downloading
    # them on every start.
    if not os.path.exists(os.path.basename(weight_url)):
        os.system(f'wget {weight_url}')
os.chdir('..')

# Make the vendored packages and the project root importable.
sys.path.append('./GroundingDINO')
sys.path.append('./SAM')
sys.path.append('.')

import matplotlib.pyplot as plt
import SAA as SegmentAnyAnomaly
from utils.training_utils import *

# --- Model configuration -----------------------------------------------------
dino_config_file = 'GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'
dino_checkpoint = 'weights/groundingdino_swint_ogc.pth'
sam_checkpoint = 'weights/sam_vit_h_4b8939.pth'
box_threshold = 0.1
text_threshold = 0.1
eval_resolution = 256
device = "cpu"
root_dir = 'result'

# get the model
model = SegmentAnyAnomaly.Model(
    dino_config_file=dino_config_file,
    dino_checkpoint=dino_checkpoint,
    sam_checkpoint=sam_checkpoint,
    box_threshold=box_threshold,
    text_threshold=text_threshold,
    out_size=eval_resolution,
    device=device,
)
model = model.to(device)

import cv2
import numpy as np
import gradio as gr


def process_image(heatmap, image):
    """Overlay a colour-mapped heat map on an image.

    Args:
        heatmap: 2-D numeric array; it is min-max scaled to 0..255 before the
            JET colour map is applied.
        image: array blended 50/50 with the coloured heat map. It must have
            the same size and dtype as the colour map output for
            ``cv2.addWeighted`` to accept it.

    Returns:
        Float array holding the blended image divided by its maximum value
        (left as all zeros when the blend is entirely black).
    """
    heatmap = heatmap.astype(float)
    lowest = heatmap.min()
    span = heatmap.max() - lowest
    if span > 0:
        # True min-max scaling: dividing by the maximum alone (as before) does
        # not fill 0..255 when the minimum is non-zero, and overflows uint8
        # when the minimum is negative.
        heatmap = (heatmap - lowest) / span * 255
    else:
        # Constant map: nothing to contrast, and it avoids a division by zero.
        heatmap = np.zeros_like(heatmap)
    heatmap = heatmap.astype(np.uint8)

    heat_map = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    # NOTE(review): applyColorMap yields BGR while the Gradio input is
    # presumably RGB, so this blend may mix channel orders - confirm which
    # order the caller is expected to pass.
    visz_map = cv2.addWeighted(heat_map, 0.5, image, 0.5, 0)
    visz_map = cv2.cvtColor(visz_map, cv2.COLOR_BGR2RGB)
    visz_map = visz_map.astype(float)

    peak = visz_map.max()
    if peak > 0:
        visz_map = visz_map / peak
    return visz_map


def func(image, anomaly_description, object_name, object_number, mask_number, area_threashold):
    """Run the anomaly model on one image and build both visualisations.

    Args:
        image: input image from the Gradio image component.
        anomaly_description: text describing the anomalies to detect.
        object_name: name of the inspected object.
        object_number: number of objects, inserted into the property prompt.
        mask_number: maximum anomaly count, inserted into the property prompt.
        area_threashold: maximum anomaly area, inserted into the property
            prompt.

    Returns:
        Tuple ``(viz_score, viz_sim)``: the anomaly score overlay and the
        similarity map overlay.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        # Without this guard cv2.resize fails with an opaque error.
        raise gr.Error('Please provide an input image.')

    textual_prompts = [
        [anomaly_description, object_name]
    ]  # detect prompts, filtered phrase
    property_text_prompts = f'the image of {object_name} have {object_number} dissimilar {object_name}, with a maximum of {mask_number} anomaly. The anomaly would not exceed {area_threashold} object area. '

    model.set_ensemble_text_prompts(textual_prompts, verbose=True)
    model.set_property_text_prompts(property_text_prompts, verbose=True)

    target_size = (eval_resolution, eval_resolution)
    # The resized image feeds both the model and the overlays; a second resize
    # to the same size would be redundant.
    image = cv2.resize(image, target_size)
    score, appendix = model(image)
    similarity_map = appendix['similarity_map']

    similarity_map = cv2.resize(similarity_map, target_size)
    score = cv2.resize(score, target_size)

    viz_score = process_image(score, image)
    viz_sim = process_image(similarity_map, image)

    return viz_score, viz_sim


# --- Gradio interface --------------------------------------------------------
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Image")
            anomaly_description = gr.Textbox(label="Anomaly Description (e.g. color defect. hole. black defect. wick hole. spot. )")
            object_name = gr.Textbox(label="Object Name (e.g. candle)")
            object_number = gr.Textbox(label="Object Number (e.g. 4)")
            mask_number = gr.Textbox(label="Mask Number (e.g. 1)")
            area_threashold = gr.Textbox(label="Area Threshold (e.g. 0.3)")
        with gr.Column():
            anomaly_score = gr.Image(label="Anomaly Score")
            saliency_map = gr.Image(label="Saliency Map")

    greet_btn = gr.Button("Inference")
    greet_btn.click(
        fn=func,
        inputs=[image, anomaly_description, object_name, object_number, mask_number, area_threashold],
        outputs=[anomaly_score, saliency_map],
        api_name="Segment-Any-Anomaly",
    )

demo.launch()