Skip to content

Commit

Permalink
[Feature] Support H3WB dataset (#2736)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiexinch authored Dec 26, 2023
1 parent 8c4a6e0 commit fdab6f7
Show file tree
Hide file tree
Showing 12 changed files with 1,609 additions and 53 deletions.
1,151 changes: 1,151 additions & 0 deletions configs/_base_/datasets/h3wb.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion demo/body3d_pose_lifter_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def parse_args():
parser.add_argument(
'--bbox-thr',
type=float,
default=0.9,
default=0.3,
help='Bounding box score threshold')
parser.add_argument('--kpt-thr', type=float, default=0.3)
parser.add_argument(
Expand Down
77 changes: 77 additions & 0 deletions docs/en/dataset_zoo/3d_wholebody_keypoint.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# 3D Wholebody Keypoint Datasets

It is recommended to symlink the dataset root to `$MMPOSE/data`.
If your folder structure is different, you may need to change the corresponding paths in config files.

MMPose supported datasets:

- [H3WB](#h3wb) \[ [Homepage](https://github.com/wholebody3d/wholebody3d) \]

## H3WB

<!-- [DATASET] -->

<details>
<summary align="right"><a href="https://arxiv.org/abs/2211.15692">H3WB (ICCV'2023)</a></summary>

```bibtex
@InProceedings{Zhu_2023_ICCV,
author = {Zhu, Yue and Samet, Nermin and Picard, David},
title = {H3WB: Human3.6M 3D WholeBody Dataset and Benchmark},
booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
month = {October},
year = {2023},
pages = {20166-20177}
}
```

</details>

<div align="center">
<img src="https://user-images.githubusercontent.com/100993824/227770977-c8f00355-c43a-467e-8444-d307789cf4b2.png" height="300px">
</div>

For [H3WB](https://github.com/wholebody3d/wholebody3d), please follow the [document](3d_body_keypoint.md#human36m) to download the [Human3.6M](http://vision.imar.ro/human3.6m/description.php) dataset, then download the H3WB annotations from the official [webpage](https://github.com/wholebody3d/wholebody3d). NOTE: check the [updates](https://github.com/wholebody3d/wholebody3d?tab=readme-ov-file#updates) section of their README before downloading the annotations.

The data should have the following structure:

```text
mmpose
├── mmpose
├── docs
├── tests
├── tools
├── configs
`── data
    ├── h36m
        ├── annotation_body3d
        |   ├── cameras.pkl
        |   ├── h3wb_train.npz
        |   ├── fps50
        |   |   ├── h36m_test.npz
        |   |   ├── h36m_train.npz
        |   |   ├── joint2d_rel_stats.pkl
        |   |   ├── joint2d_stats.pkl
        |   |   ├── joint3d_rel_stats.pkl
        |   |   `── joint3d_stats.pkl
        |   `── fps10
        |       ├── h36m_test.npz
        |       ├── h36m_train.npz
        |       ├── joint2d_rel_stats.pkl
        |       ├── joint2d_stats.pkl
        |       ├── joint3d_rel_stats.pkl
        |       `── joint3d_stats.pkl
        `── images
            ├── S1
            |   ├── S1_Directions_1.54138969
            |   |   ├── S1_Directions_1.54138969_00001.jpg
            |   |   ├── S1_Directions_1.54138969_00002.jpg
            |   |   ├── ...
            |   ├── ...
            ├── S5
            ├── S6
            ├── S7
            ├── S8
            ├── S9
            `── S11
```
28 changes: 23 additions & 5 deletions mmpose/apis/inference_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ def convert_keypoint_definition(keypoints, pose_det_dataset,
ndarray[K, 2 or 3]: the transformed 2D keypoints.
"""
assert pose_lift_dataset in [
'h36m'], '`pose_lift_dataset` should be ' \
'h36m', 'h3wb'], '`pose_lift_dataset` should be ' \
f'`h36m`, but got {pose_lift_dataset}.'

keypoints_new = np.zeros((keypoints.shape[0], 17, keypoints.shape[2]),
dtype=keypoints.dtype)
if pose_lift_dataset == 'h36m':
if pose_det_dataset in ['h36m']:
if pose_lift_dataset in ['h36m', 'h3wb']:
if pose_det_dataset in ['h36m', 'coco_wholebody']:
keypoints_new = keypoints
elif pose_det_dataset in ['coco', 'posetrack18']:
# pelvis (root) is in the middle of l_hip and r_hip
Expand Down Expand Up @@ -265,8 +265,26 @@ def inference_pose_lifter_model(model,
bbox_center = dataset_info['stats_info']['bbox_center']
bbox_scale = dataset_info['stats_info']['bbox_scale']
else:
bbox_center = None
bbox_scale = None
if norm_pose_2d:
# compute the average bbox center and scale from the
# datasamples in pose_results_2d
bbox_center = np.zeros((1, 2), dtype=np.float32)
bbox_scale = 0
num_bbox = 0
for pose_res in pose_results_2d:
for data_sample in pose_res:
for bbox in data_sample.pred_instances.bboxes:
bbox_center += np.array([[(bbox[0] + bbox[2]) / 2,
(bbox[1] + bbox[3]) / 2]
])
bbox_scale += max(bbox[2] - bbox[0],
bbox[3] - bbox[1])
num_bbox += 1
bbox_center /= num_bbox
bbox_scale /= num_bbox
else:
bbox_center = None
bbox_scale = None

pose_results_2d_copy = []
for i, pose_res in enumerate(pose_results_2d):
Expand Down
39 changes: 24 additions & 15 deletions mmpose/codecs/image_pose_lifting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple
from typing import List, Optional, Tuple, Union

import numpy as np

Expand All @@ -20,7 +20,7 @@ class ImagePoseLifting(BaseKeypointCodec):
Args:
num_keypoints (int): The number of keypoints in the dataset.
root_index (int): Root keypoint index in the pose.
root_index (Union[int, List]): Root keypoint index in the pose.
remove_root (bool): If true, remove the root keypoint from the pose.
Default: ``False``.
save_index (bool): If true, store the root position separated from the
Expand Down Expand Up @@ -52,18 +52,21 @@ class ImagePoseLifting(BaseKeypointCodec):

def __init__(self,
num_keypoints: int,
root_index: int,
root_index: Union[int, List] = 0,
remove_root: bool = False,
save_index: bool = False,
reshape_keypoints: bool = True,
concat_vis: bool = False,
keypoints_mean: Optional[np.ndarray] = None,
keypoints_std: Optional[np.ndarray] = None,
target_mean: Optional[np.ndarray] = None,
target_std: Optional[np.ndarray] = None):
target_std: Optional[np.ndarray] = None,
additional_encode_keys: Optional[List[str]] = None):
super().__init__()

self.num_keypoints = num_keypoints
if isinstance(root_index, int):
root_index = [root_index]
self.root_index = root_index
self.remove_root = remove_root
self.save_index = save_index
Expand Down Expand Up @@ -96,6 +99,9 @@ def __init__(self,
self.target_mean = target_mean
self.target_std = target_std

if additional_encode_keys is not None:
self.auxiliary_encode_keys.update(additional_encode_keys)

def encode(self,
keypoints: np.ndarray,
keypoints_visible: Optional[np.ndarray] = None,
Expand Down Expand Up @@ -161,36 +167,38 @@ def encode(self,

# Zero-center the target pose around a given root keypoint
assert (lifting_target.ndim >= 2 and
lifting_target.shape[-2] > self.root_index), \
lifting_target.shape[-2] > max(self.root_index)), \
f'Got invalid joint shape {lifting_target.shape}'

root = lifting_target[..., self.root_index, :]
lifting_target_label = lifting_target - lifting_target[
..., self.root_index:self.root_index + 1, :]
root = np.mean(
lifting_target[..., self.root_index, :], axis=-2, dtype=np.float32)
lifting_target_label = lifting_target - root[np.newaxis, ...]

if self.remove_root:
if self.remove_root and len(self.root_index) == 1:
root_index = self.root_index[0]
lifting_target_label = np.delete(
lifting_target_label, self.root_index, axis=-2)
lifting_target_label, root_index, axis=-2)
lifting_target_visible = np.delete(
lifting_target_visible, self.root_index, axis=-2)
lifting_target_visible, root_index, axis=-2)
assert lifting_target_weight.ndim in {
2, 3
}, (f'lifting_target_weight.ndim {lifting_target_weight.ndim} '
'is not in {2, 3}')

axis_to_remove = -2 if lifting_target_weight.ndim == 3 else -1
lifting_target_weight = np.delete(
lifting_target_weight, self.root_index, axis=axis_to_remove)
lifting_target_weight, root_index, axis=axis_to_remove)
# Add a flag to avoid latter transforms that rely on the root
# joint or the original joint index
encoded['target_root_removed'] = True

# Save the root index which is necessary to restore the global pose
if self.save_index:
encoded['target_root_index'] = self.root_index
encoded['target_root_index'] = root_index

# Normalize the 2D keypoint coordinate with mean and std
keypoint_labels = keypoints.copy()

if self.keypoints_mean is not None:
assert self.keypoints_mean.shape[1:] == keypoints.shape[1:], (
f'self.keypoints_mean.shape[1:] {self.keypoints_mean.shape[1:]} ' # noqa
Expand All @@ -203,7 +211,8 @@ def encode(self,
if self.target_mean is not None:
assert self.target_mean.shape == lifting_target_label.shape, (
f'self.target_mean.shape {self.target_mean.shape} '
f'!= lifting_target_label.shape {lifting_target_label.shape}')
f'!= lifting_target_label.shape {lifting_target_label.shape}' # noqa
)
encoded['target_mean'] = self.target_mean.copy()
encoded['target_std'] = self.target_std.copy()

Expand Down Expand Up @@ -263,7 +272,7 @@ def decode(self,

if target_root is not None and target_root.size > 0:
keypoints = keypoints + target_root
if self.remove_root:
if self.remove_root and len(self.root_index) == 1:
keypoints = np.insert(
keypoints, self.root_index, target_root, axis=1)
scores = np.ones(keypoints.shape[:-1], dtype=np.float32)
Expand Down
31 changes: 16 additions & 15 deletions mmpose/codecs/video_pose_lifting.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.

from copy import deepcopy
from typing import Optional, Tuple
from typing import List, Optional, Tuple, Union

import numpy as np

Expand All @@ -24,7 +24,8 @@ class VideoPoseLifting(BaseKeypointCodec):
num_keypoints (int): The number of keypoints in the dataset.
zero_center: Whether to zero-center the target around root. Default:
``True``.
root_index (int): Root keypoint index in the pose. Default: 0.
root_index (Union[int, List]): Root keypoint index in the pose.
Default: 0.
remove_root (bool): If true, remove the root keypoint from the pose.
Default: ``False``.
save_index (bool): If true, store the root position separated from the
Expand Down Expand Up @@ -54,7 +55,7 @@ class VideoPoseLifting(BaseKeypointCodec):
def __init__(self,
num_keypoints: int,
zero_center: bool = True,
root_index: int = 0,
root_index: Union[int, List] = 0,
remove_root: bool = False,
save_index: bool = False,
reshape_keypoints: bool = True,
Expand All @@ -64,6 +65,8 @@ def __init__(self,

self.num_keypoints = num_keypoints
self.zero_center = zero_center
if isinstance(root_index, int):
root_index = [root_index]
self.root_index = root_index
self.remove_root = remove_root
self.save_index = save_index
Expand Down Expand Up @@ -143,36 +146,34 @@ def encode(self,
# Zero-center the target pose around a given root keypoint
if self.zero_center:
assert (lifting_target.ndim >= 2 and
lifting_target.shape[-2] > self.root_index), \
lifting_target.shape[-2] > max(self.root_index)), \
f'Got invalid joint shape {lifting_target.shape}'

root = lifting_target[..., self.root_index, :]
lifting_target_label -= lifting_target_label[
..., self.root_index:self.root_index + 1, :]
root = np.mean(lifting_target[..., self.root_index, :], axis=-2)
lifting_target_label -= root[..., np.newaxis, :]
encoded['target_root'] = root

if self.remove_root:
if self.remove_root and len(self.root_index) == 1:
root_index = self.root_index[0]
lifting_target_label = np.delete(
lifting_target_label, self.root_index, axis=-2)
lifting_target_label, root_index, axis=-2)
lifting_target_visible = np.delete(
lifting_target_visible, self.root_index, axis=-2)
lifting_target_visible, root_index, axis=-2)
assert lifting_target_weight.ndim in {
2, 3
}, (f'Got invalid lifting target weights shape '
f'{lifting_target_weight.shape}')

axis_to_remove = -2 if lifting_target_weight.ndim == 3 else -1
lifting_target_weight = np.delete(
lifting_target_weight,
self.root_index,
axis=axis_to_remove)
lifting_target_weight, root_index, axis=axis_to_remove)
# Add a flag to avoid latter transforms that rely on the root
# joint or the original joint index
encoded['target_root_removed'] = True

# Save the root index for restoring the global pose
if self.save_index:
encoded['target_root_index'] = self.root_index
encoded['target_root_index'] = root_index

# Normalize the 2D keypoint coordinate with image width and height
_camera_param = deepcopy(camera_param)
Expand Down Expand Up @@ -237,7 +238,7 @@ def decode(self,

if target_root is not None and target_root.size > 0:
keypoints = keypoints + target_root
if self.remove_root:
if self.remove_root and len(self.root_index) == 1:
keypoints = np.insert(
keypoints, self.root_index, target_root, axis=1)
scores = np.ones(keypoints.shape[:-1], dtype=np.float32)
Expand Down
3 changes: 2 additions & 1 deletion mmpose/datasets/datasets/wholebody3d/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .h3wb_dataset import H36MWholeBodyDataset
from .ubody3d_dataset import UBody3dDataset

__all__ = ['UBody3dDataset']
__all__ = ['UBody3dDataset', 'H36MWholeBodyDataset']
Loading

0 comments on commit fdab6f7

Please sign in to comment.