File size: 4,947 Bytes
82b20ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#!/usr/bin/env python

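"""A demo of the ViTPose model.

People are first detected with an RT-DETR object detector; ViTPose then
estimates body keypoints inside each detected person box.

This code is based on the implementation from the Colab notebook:
https://colab.research.google.com/drive/1e8fcby5rhKZWcr9LSN8mNbQ0TU4Dxxpo
"""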

import pathlib

import gradio as gr
import PIL.Image
import spaces
import supervision as sv
import torch
from transformers import AutoProcessor, RTDetrForObjectDetection, VitPoseForPoseEstimation

DESCRIPTION = "# ViTPose"

# Prefer the GPU when CUDA is available; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Stage 1: RT-DETR object detector, used by `run` to find person boxes.
person_detector_name = "PekingU/rtdetr_r50vd_coco_o365"
person_image_processor = AutoProcessor.from_pretrained(person_detector_name)
person_model = RTDetrForObjectDetection.from_pretrained(
    person_detector_name,
    device_map=device,
)

# Stage 2: ViTPose keypoint estimator, applied to the person boxes from stage 1.
pose_model_name = "usyd-community/vitpose-base-simple"
pose_image_processor = AutoProcessor.from_pretrained(pose_model_name)
pose_model = VitPoseForPoseEstimation.from_pretrained(
    pose_model_name,
    device_map=device,
)


def _to_human_readable(image_pose_result: list[dict], id2label: dict[int, str]) -> list[dict]:
    """Convert the pose results for one image into JSON-friendly dicts.

    Args:
        image_pose_result: Per-person results from
            ``post_process_pose_estimation`` for a single image; each entry is
            read for its ``"bbox"``, ``"keypoints"``, ``"labels"`` and
            ``"scores"`` tensors.
        id2label: Mapping from keypoint label id to keypoint name.

    Returns:
        One dict per person with keys ``"person_id"``, ``"bbox"`` and
        ``"keypoints"``, where ``keypoints`` is a list of
        ``{"name", "x", "y", "score"}`` dicts.
    """
    human_readable_results = []
    for person_id, person_pose in enumerate(image_pose_result):
        keypoints = []
        # strict=True: every keypoint must have exactly one label and one score.
        for keypoint, label, score in zip(
            person_pose["keypoints"], person_pose["labels"], person_pose["scores"], strict=True
        ):
            x, y = keypoint
            keypoints.append({"name": id2label[label.item()], "x": x.item(), "y": y.item(), "score": score.item()})
        human_readable_results.append(
            {"person_id": person_id, "bbox": person_pose["bbox"].numpy().tolist(), "keypoints": keypoints}
        )
    return human_readable_results


def _annotate(image: PIL.Image.Image, person_boxes_xyxy, image_pose_result: list[dict]) -> PIL.Image.Image:
    """Draw person boxes, skeleton edges and keypoints on a copy of *image*.

    Args:
        image: The original input image; it is not modified.
        person_boxes_xyxy: NumPy array of person boxes in (x1, y1, x2, y2)
            pixel coordinates, one row per person.
        image_pose_result: Non-empty per-person pose results for this image.

    Returns:
        The annotated copy of *image*.
    """
    # Stack per-person results into NumPy arrays of shape
    # (n_people, n_keypoints, 2) for coordinates and (n_people, n_keypoints)
    # for confidences.
    xy = torch.stack([pose["keypoints"] for pose in image_pose_result]).cpu().numpy()
    scores = torch.stack([pose["scores"] for pose in image_pose_result]).cpu().numpy()

    keypoints = sv.KeyPoints(xy=xy, confidence=scores)
    detections = sv.Detections(xyxy=person_boxes_xyxy)

    edge_annotator = sv.EdgeAnnotator(color=sv.Color.GREEN, thickness=1)
    vertex_annotator = sv.VertexAnnotator(color=sv.Color.RED, radius=2)
    bounding_box_annotator = sv.BoxAnnotator(color=sv.Color.WHITE, color_lookup=sv.ColorLookup.INDEX, thickness=1)

    # Draw on a copy so the caller's image stays untouched: boxes first, then
    # skeleton edges, then keypoint vertices on top.
    annotated_frame = bounding_box_annotator.annotate(scene=image.copy(), detections=detections)
    annotated_frame = edge_annotator.annotate(scene=annotated_frame, key_points=keypoints)
    return vertex_annotator.annotate(scene=annotated_frame, key_points=keypoints)


@spaces.GPU
@torch.inference_mode()
def run(image: PIL.Image.Image) -> tuple[PIL.Image.Image, list[dict]]:
    """Detect people in *image*, estimate their poses, and draw the results.

    Args:
        image: Input image (the Gradio ``Image`` component delivers a PIL image).

    Returns:
        A tuple ``(annotated_image, results)``. ``annotated_image`` is a copy of
        the input with person boxes, skeleton edges and keypoints drawn on it.
        ``results`` has one dict per detected person (``person_id``, ``bbox``,
        ``keypoints``). When no person is detected, an unannotated copy of the
        input and an empty list are returned.
    """
    inputs = person_image_processor(images=image, return_tensors="pt").to(device)
    outputs = person_model(**inputs)
    results = person_image_processor.post_process_object_detection(
        outputs, target_sizes=torch.tensor([(image.height, image.width)]), threshold=0.3
    )
    result = results[0]  # results for the first (and only) image

    # Label 0 is the "person" class in the COCO label set used by the detector.
    person_boxes_xyxy = result["boxes"][result["labels"] == 0].cpu().numpy()

    # No people found: there is nothing to estimate or draw, and torch.stack in
    # _annotate would raise on an empty list, so return early instead of crashing.
    if len(person_boxes_xyxy) == 0:
        return image.copy(), []

    # Convert boxes from VOC (x1, y1, x2, y2) to COCO (x1, y1, w, h) format,
    # which is what the pose processor is given below.
    person_boxes = person_boxes_xyxy.copy()
    person_boxes[:, 2] -= person_boxes[:, 0]
    person_boxes[:, 3] -= person_boxes[:, 1]

    inputs = pose_image_processor(image, boxes=[person_boxes], return_tensors="pt").to(device)

    # Mixture-of-experts checkpoints (e.g. vitpose-plus-base) additionally need a
    # dataset_index that specifies which experts to use; 0 is used for every crop.
    if pose_model.config.backbone_config.num_experts > 1:
        dataset_index = torch.tensor([0] * len(inputs["pixel_values"]))
        inputs["dataset_index"] = dataset_index.to(inputs["pixel_values"].device)

    outputs = pose_model(**inputs)

    pose_results = pose_image_processor.post_process_pose_estimation(outputs, boxes=[person_boxes])
    image_pose_result = pose_results[0]  # results for the first image

    human_readable_results = _to_human_readable(image_pose_result, pose_model.config.id2label)
    annotated_image = _annotate(image, person_boxes_xyxy, image_pose_result)
    return annotated_image, human_readable_results


# Example images for the demo: every *.jpg in the "images" directory (resolved
# relative to the current working directory), in sorted order.
paths = sorted(pathlib.Path("images").glob("*.jpg"))


# Gradio UI: input image and run button on the left; annotated image and
# per-person keypoint JSON on the right. Component placement follows statement
# order inside the `with` blocks, so reordering these lines changes the layout.
with gr.Blocks(css_paths="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            # type="pil" so that `run` receives a PIL.Image.Image.
            input_image = gr.Image(label="Input Image", type="pil")
            run_button = gr.Button()
        with gr.Column():
            output_image = gr.Image(label="Output Image")
            output_json = gr.JSON(label="Output JSON")

    # Clickable examples feeding `input_image`; passing `fn`/`outputs` lets Gradio
    # run `run` on an example (whether results are pre-cached depends on Gradio's
    # example-caching settings — NOTE(review): confirm for the deployment target).
    gr.Examples(examples=paths, inputs=input_image, outputs=[output_image, output_json], fn=run)

    # The same pipeline on the user's own image when the button is pressed.
    run_button.click(
        fn=run,
        inputs=input_image,
        outputs=[output_image, output_json],
    )


if __name__ == "__main__":
    # Enable request queuing with at most 20 pending entries, then start the server.
    demo.queue(max_size=20)
    demo.launch()