SynthPose (Transformers 🤗 VitPose Base variant)

The SynthPose model was proposed in OpenCapBench: A Benchmark to Bridge Pose Estimation and Biomechanics by Yoni Gozlan, Antoine Falisse, Scott Uhlrich, Anthony Gatti, Michael Black, Akshay Chaudhari.

Intended use cases

This model uses a VitPose Base backbone. SynthPose is a new approach that enables finetuning of pre-trained 2D human pose models to predict an arbitrarily denser set of keypoints for accurate kinematic analysis through the use of synthetic data.
More details are available in OpenCapBench: A Benchmark to Bridge Pose Estimation and Biomechanics.
This particular variant was finetuned on a set of keypoints usually found in motion capture setups, and includes the COCO keypoints as well.

The model predicts the following 52 markers:

{
    0: "Nose",
    1: "L_Eye",
    2: "R_Eye",
    3: "L_Ear",
    4: "R_Ear",
    5: "L_Shoulder",
    6: "R_Shoulder",
    7: "L_Elbow",
    8: "R_Elbow",
    9: "L_Wrist",
    10: "R_Wrist",
    11: "L_Hip",
    12: "R_Hip",
    13: "L_Knee",
    14: "R_Knee",
    15: "L_Ankle",
    16: "R_Ankle",
    17: "sternum",
    18: "rshoulder",
    19: "lshoulder",
    20: "r_lelbow",
    21: "l_lelbow",
    22: "r_melbow",
    23: "l_melbow",
    24: "r_lwrist",
    25: "l_lwrist",
    26: "r_mwrist",
    27: "l_mwrist",
    28: "r_ASIS",
    29: "l_ASIS",
    30: "r_PSIS",
    31: "l_PSIS",
    32: "r_knee",
    33: "l_knee",
    34: "r_mknee",
    35: "l_mknee",
    36: "r_ankle",
    37: "l_ankle",
    38: "r_mankle",
    39: "l_mankle",
    40: "r_5meta",
    41: "l_5meta",
    42: "r_toe",
    43: "l_toe",
    44: "r_big_toe",
    45: "l_big_toe",
    46: "l_calc",
    47: "r_calc",
    48: "C7",
    49: "L2",
    50: "T11",
    51: "T6",
}

The first 17 keypoints are the COCO keypoints, and the remaining 35 are anatomical markers.

Usage

Image inference

Here's how to load the model and run inference on an image:

import torch
import requests
import numpy as np

from PIL import Image

from transformers import (
    AutoProcessor,
    RTDetrForObjectDetection,
    VitPoseForPoseEstimation,
)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Download the example picture and hand the raw response stream to PIL.
url = "http://farm4.staticflickr.com/3300/3416216247_f9c6dfc939_z.jpg"
response = requests.get(url, stream=True)
image = Image.open(response.raw)

# ------------------------------------------------------------------------
# Stage 1. Detect humans on the image
# ------------------------------------------------------------------------

# Any person detector can be used here; RT-DETR is just the one chosen for this example.
detector_checkpoint = "PekingU/rtdetr_r50vd_coco_o365"
person_image_processor = AutoProcessor.from_pretrained(detector_checkpoint)
person_model = RTDetrForObjectDetection.from_pretrained(detector_checkpoint, device_map=device)

inputs = person_image_processor(images=image, return_tensors="pt").to(device)

# Inference only: no gradients are needed.
with torch.no_grad():
    outputs = person_model(**inputs)

# One (height, width) pair per input image; detections scoring below 0.3 are dropped.
target_sizes = torch.tensor([(image.height, image.width)])
results = person_image_processor.post_process_object_detection(
    outputs, target_sizes=target_sizes, threshold=0.3
)
result = results[0]  # only one image was passed in

# In the COCO label set, index 0 is the "person" class
is_person = result["labels"] == 0
person_boxes = result["boxes"][is_person].cpu().numpy()

# Convert boxes from VOC (x1, y1, x2, y2) to COCO (x1, y1, w, h) format:
# subtracting the top-left corner from the bottom-right corner yields (w, h).
person_boxes[:, 2:4] -= person_boxes[:, 0:2]

# ------------------------------------------------------------------------
# Stage 2. Detect keypoints for each person found
# ------------------------------------------------------------------------

pose_checkpoint = "yonigozlan/synthpose-vit-base-hf"
image_processor = AutoProcessor.from_pretrained(pose_checkpoint)
model = VitPoseForPoseEstimation.from_pretrained(pose_checkpoint, device_map=device)

# `boxes` takes one entry per image; here a single image with all its person boxes.
inputs = image_processor(image, boxes=[person_boxes], return_tensors="pt").to(device)

# Inference only: no gradients are needed.
with torch.no_grad():
    outputs = model(**inputs)

# The same boxes are passed again for post-processing.
pose_results = image_processor.post_process_pose_estimation(outputs, boxes=[person_boxes])
image_pose_result = pose_results[0]  # entry for the first (and only) image

Visualization for supervision users

import supervision as sv

# Gather the per-person tensors, then stack them into one array per field.
keypoints_per_person = [person["keypoints"] for person in image_pose_result]
scores_per_person = [person["scores"] for person in image_pose_result]

xy = torch.stack(keypoints_per_person).cpu().numpy()
scores = torch.stack(scores_per_person).cpu().numpy()

key_points = sv.KeyPoints(xy=xy, confidence=scores)

vertex_annotator = sv.VertexAnnotator(color=sv.Color.PINK, radius=2)

# Annotate a copy so the original `image` stays untouched.
annotated_frame = vertex_annotator.annotate(scene=image.copy(), key_points=key_points)
annotated_frame

Advanced manual visualization

import math
import cv2

def draw_points(image, keypoints, scores, pose_keypoint_color, keypoint_score_threshold, radius, show_keypoint_weight):
    """Draw a filled circle on ``image`` (in place) for every confident keypoint.

    Args:
        image: Image array that OpenCV draws on; it is modified in place.
        keypoints: Sequence of points; entries 0 and 1 of each point are used as x and y.
        scores: One confidence value per keypoint, in the same order as ``keypoints``.
        pose_keypoint_color: One colour per keypoint, or ``None`` to draw nothing.
        keypoint_score_threshold: Keypoints whose score is not strictly greater
            than this value are skipped.
        radius: Circle radius passed to ``cv2.circle``.
        show_keypoint_weight: When true, each circle is alpha-blended onto the
            image using its score (clamped to [0, 1]) as opacity.

    Raises:
        AssertionError: If ``pose_keypoint_color`` and ``keypoints`` differ in length.
    """
    if pose_keypoint_color is None:
        # Without colours there is nothing to draw; indexing None would raise TypeError.
        return
    assert len(pose_keypoint_color) == len(keypoints)

    for kid, (kpt, kpt_score) in enumerate(zip(keypoints, scores)):
        if not kpt_score > keypoint_score_threshold:
            continue

        center = (int(kpt[0]), int(kpt[1]))
        color = tuple(int(c) for c in pose_keypoint_color[kid])

        if show_keypoint_weight:
            # Draw on a copy and blend it back: blending the image with itself
            # (as the previous code did) leaves the pixels unchanged, so the
            # score never influenced the opacity.
            overlay = image.copy()
            cv2.circle(overlay, center, radius, color, -1)
            transparency = max(0, min(1, float(kpt_score)))
            cv2.addWeighted(overlay, transparency, image, 1 - transparency, 0, dst=image)
        else:
            cv2.circle(image, center, radius, color, -1)

def draw_links(image, keypoints, scores, keypoint_edges, link_colors, keypoint_score_threshold, thickness, show_keypoint_weight, stick_width=2):
    """Draw the skeleton links between keypoints onto ``image`` in place.

    A link is drawn only when both endpoints lie strictly inside the image
    and both endpoint scores are strictly greater than the threshold.

    Args:
        image: Image array of shape (height, width, channels); modified in place.
        keypoints: Array indexed as ``keypoints[i, 0]`` (x) and ``keypoints[i, 1]`` (y).
        scores: Per-keypoint confidence, indexed by keypoint id.
        keypoint_edges: Sequence of (start_id, end_id) pairs, or ``None`` to draw nothing.
        link_colors: One colour per edge, or ``None`` to draw nothing.
        keypoint_score_threshold: Minimum (exclusive) score for both endpoints.
        thickness: Line thickness used when ``show_keypoint_weight`` is false.
        show_keypoint_weight: When true, each link is drawn as a filled ellipse
            and alpha-blended using the mean endpoint score as opacity.
        stick_width: Half-width of that ellipse.

    Raises:
        AssertionError: If ``link_colors`` and ``keypoint_edges`` differ in length.
    """
    height, width, _ = image.shape
    if keypoint_edges is None or link_colors is None:
        return
    assert len(link_colors) == len(keypoint_edges)

    for sk_id, sk in enumerate(keypoint_edges):
        start, end = sk[0], sk[1]
        x1, y1, score1 = int(keypoints[start, 0]), int(keypoints[start, 1]), scores[start]
        x2, y2, score2 = int(keypoints[end, 0]), int(keypoints[end, 1]), scores[end]

        drawable = (
            0 < x1 < width
            and 0 < y1 < height
            and 0 < x2 < width
            and 0 < y2 < height
            and score1 > keypoint_score_threshold
            and score2 > keypoint_score_threshold
        )
        if not drawable:
            continue

        color = tuple(int(c) for c in link_colors[sk_id])
        if not show_keypoint_weight:
            cv2.line(image, (x1, y1), (x2, y2), color, thickness=thickness)
            continue

        # Ellipse centred on the midpoint of the link and rotated along it.
        center = (int((x1 + x2) / 2), int((y1 + y2) / 2))
        length = ((y1 - y2) ** 2 + (x1 - x2) ** 2) ** 0.5
        angle = math.degrees(math.atan2(y1 - y2, x1 - x2))
        polygon = cv2.ellipse2Poly(center, (int(length / 2), int(stick_width)), int(angle), 0, 360, 1)

        # Draw on a copy and blend it back; blending the image with itself
        # would leave the pixels unchanged.
        overlay = image.copy()
        cv2.fillConvexPoly(overlay, polygon, color)
        # Opacity comes from the separate `scores` argument: `keypoints` is only
        # indexed as (x, y) here, so the old `keypoints[..., 2]` lookup went
        # out of bounds on (N, 2) input.
        transparency = max(0, min(1, 0.5 * (float(score1) + float(score2))))
        cv2.addWeighted(overlay, transparency, image, 1 - transparency, 0, dst=image)


# Note: keypoint_edges and color palette are dataset-specific
keypoint_edges = model.config.edges

# 20 RGB triplets; the index lists below pick rows from this table.
palette = np.array(
    [
        (255, 128, 0), (255, 153, 51), (255, 178, 102), (230, 230, 0),
        (255, 153, 255), (153, 204, 255), (255, 102, 255), (255, 51, 255),
        (102, 178, 255), (51, 153, 255), (255, 153, 153), (255, 102, 102),
        (255, 51, 51), (153, 255, 153), (102, 255, 102), (51, 255, 51),
        (0, 255, 0), (0, 0, 255), (255, 0, 0), (255, 255, 255),
    ]
)

# One palette row per skeleton edge (19 entries).
link_color_ids = [0] * 4 + [7] * 3 + [9] * 5 + [16] * 7
link_colors = palette[link_color_ids]

# One palette row per keypoint: 17 entries in three colour groups,
# then a single colour (row 4) for the remaining 52 - 17 keypoints.
keypoint_color_ids = [16] * 5 + [9] * 6 + [0] * 6 + [4] * (52 - 17)
keypoint_colors = palette[keypoint_color_ids]

# The drawing helpers work on a NumPy array, so convert the PIL image first.
numpy_image = np.array(image)

score_threshold = 0.3

for pose_result in image_pose_result:
    keypoints = np.array(pose_result["keypoints"])
    scores = np.array(pose_result["scores"])

    # Points first, then the links between them.
    draw_points(
        numpy_image,
        keypoints,
        scores,
        keypoint_colors,
        keypoint_score_threshold=score_threshold,
        radius=2,
        show_keypoint_weight=False,
    )
    draw_links(
        numpy_image,
        keypoints,
        scores,
        keypoint_edges,
        link_colors,
        keypoint_score_threshold=score_threshold,
        thickness=1,
        show_keypoint_weight=False,
    )

# Back to a PIL image for display.
pose_image = Image.fromarray(numpy_image)
pose_image

Downloads last month
122
Safetensors
Model size
90M params
Tensor type
F32
·
Inference API
Unable to determine this model’s pipeline type. Check the docs.