#!/usr/bin/env python

"""A demo of the ViTPose model.

This code is based on the implementation from the Colab notebook:
https://colab.research.google.com/drive/1e8fcby5rhKZWcr9LSN8mNbQ0TU4Dxxpo
"""

import pathlib

import gradio as gr
import PIL.Image
import spaces
import supervision as sv
import torch
from transformers import AutoProcessor, RTDetrForObjectDetection, VitPoseForPoseEstimation

DESCRIPTION = "# ViTPose"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

person_detector_name = "PekingU/rtdetr_r50vd_coco_o365"
person_image_processor = AutoProcessor.from_pretrained(person_detector_name)
person_model = RTDetrForObjectDetection.from_pretrained(person_detector_name, device_map=device)

pose_model_name = "usyd-community/vitpose-base-simple"
pose_image_processor = AutoProcessor.from_pretrained(pose_model_name)
pose_model = VitPoseForPoseEstimation.from_pretrained(pose_model_name, device_map=device)


@spaces.GPU
@torch.inference_mode()
def run(image: PIL.Image.Image) -> tuple[PIL.Image.Image, list[dict]]:
    # Stage 1: detect people with RT-DETR.
    inputs = person_image_processor(images=image, return_tensors="pt").to(device)
    outputs = person_model(**inputs)
    results = person_image_processor.post_process_object_detection(
        outputs, target_sizes=torch.tensor([(image.height, image.width)]), threshold=0.3
    )
    result = results[0]  # take results for the first image

    # The "person" class has label index 0 in the COCO dataset.
    person_boxes_xyxy = result["boxes"][result["labels"] == 0]
    person_boxes_xyxy = person_boxes_xyxy.cpu().numpy()

    # Convert boxes from VOC (x1, y1, x2, y2) to COCO (x1, y1, w, h) format.
    person_boxes = person_boxes_xyxy.copy()
    person_boxes[:, 2] = person_boxes[:, 2] - person_boxes[:, 0]
    person_boxes[:, 3] = person_boxes[:, 3] - person_boxes[:, 1]

    # Stage 2: estimate keypoints for each detected person with ViTPose.
    inputs = pose_image_processor(image, boxes=[person_boxes], return_tensors="pt").to(device)

    # For the vitpose-plus-base checkpoint we should additionally provide dataset_index
    # to specify which MoE experts to use for inference.
    if pose_model.config.backbone_config.num_experts > 1:
        dataset_index = torch.tensor([0] * len(inputs["pixel_values"]))
        dataset_index = dataset_index.to(inputs["pixel_values"].device)
        inputs["dataset_index"] = dataset_index

    outputs = pose_model(**inputs)
    pose_results = pose_image_processor.post_process_pose_estimation(outputs, boxes=[person_boxes])
    image_pose_result = pose_results[0]  # results for the first image

    # Make the results more human-readable.
    human_readable_results = []
    for i, person_pose in enumerate(image_pose_result):
        data = {
            "person_id": i,
            "bbox": person_pose["bbox"].numpy().tolist(),
            "keypoints": [],
        }
        for keypoint, label, score in zip(
            person_pose["keypoints"], person_pose["labels"], person_pose["scores"], strict=True
        ):
            keypoint_name = pose_model.config.id2label[label.item()]
            x, y = keypoint
            data["keypoints"].append({"name": keypoint_name, "x": x.item(), "y": y.item(), "score": score.item()})
        human_readable_results.append(data)

    # Collect keypoints into a tensor of shape (n_objects, n_keypoints, 2) for visualization.
    xy = [pose_result["keypoints"] for pose_result in image_pose_result]
    xy = torch.stack(xy).cpu().numpy()

    scores = [pose_result["scores"] for pose_result in image_pose_result]
    scores = torch.stack(scores).cpu().numpy()

    keypoints = sv.KeyPoints(xy=xy, confidence=scores)
    detections = sv.Detections(xyxy=person_boxes_xyxy)

    edge_annotator = sv.EdgeAnnotator(color=sv.Color.GREEN, thickness=1)
    vertex_annotator = sv.VertexAnnotator(color=sv.Color.RED, radius=2)
    bounding_box_annotator = sv.BoxAnnotator(color=sv.Color.WHITE, color_lookup=sv.ColorLookup.INDEX, thickness=1)
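    # Draw on a copy of the input so the PIL image handed in by Gradio is left untouched.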
    # Annotate bounding boxes.
    annotated_frame = bounding_box_annotator.annotate(scene=image.copy(), detections=detections)

    # Annotate edges and vertices.
    annotated_frame = edge_annotator.annotate(scene=annotated_frame, key_points=keypoints)
    return vertex_annotator.annotate(scene=annotated_frame, key_points=keypoints), human_readable_results


paths = sorted(pathlib.Path("images").glob("*.jpg"))

with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil")
            run_button = gr.Button()
        with gr.Column():
            output_image = gr.Image(label="Output Image")
            output_json = gr.JSON(label="Output JSON")
    gr.Examples(examples=paths, inputs=input_image, outputs=[output_image, output_json], fn=run)

    run_button.click(
        fn=run,
        inputs=input_image,
        outputs=[output_image, output_json],
    )

if __name__ == "__main__":
    demo.queue(max_size=20).launch()
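
# Minimal offline usage sketch (an assumption, not part of the app): the two-stage
# detect-then-pose pipeline in `run` can also be exercised without the Gradio UI,
# e.g. from a Python shell. "images/example.jpg" is a hypothetical path used only
# for illustration.
#
#     annotated, keypoints = run(PIL.Image.open("images/example.jpg"))
#     annotated.save("annotated.jpg")
#     print(keypoints[0]["bbox"], len(keypoints[0]["keypoints"]))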