Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[6D Pose Estimation] Support YOLO6D #575

Open
wants to merge 12 commits into
base: dev
Choose a base branch
from
9 changes: 7 additions & 2 deletions mmyolo/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .transforms import * # noqa: F401,F403
from .utils import BatchShapePolicy, yolov5_collate
from .yolo6d_linemod import YOLO6DDataset
from .yolov5_coco import YOLOv5CocoDataset
from .yolov5_crowdhuman import YOLOv5CrowdHumanDataset
from .yolov5_voc import YOLOv5VOCDataset

__all__ = [
'YOLOv5CocoDataset', 'YOLOv5VOCDataset', 'BatchShapePolicy',
'yolov5_collate', 'YOLOv5CrowdHumanDataset'
'YOLOv5CocoDataset',
'YOLOv5VOCDataset',
'BatchShapePolicy',
'yolov5_collate',
'YOLOv5CrowdHumanDataset',
'YOLO6DDataset',
]
5 changes: 4 additions & 1 deletion mmyolo/datasets/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .formatting import Pack6DInputs
from .loading import Load6DAnnotations
from .mix_img_transforms import Mosaic, Mosaic9, YOLOv5MixUp, YOLOXMixUp
from .transforms import (LetterResize, LoadAnnotations, PPYOLOERandomCrop,
PPYOLOERandomDistort, RemoveDataElement,
Expand All @@ -9,5 +11,6 @@
'YOLOv5KeepRatioResize', 'LetterResize', 'Mosaic', 'YOLOXMixUp',
'YOLOv5MixUp', 'YOLOv5HSVRandomAug', 'LoadAnnotations',
'YOLOv5RandomAffine', 'PPYOLOERandomDistort', 'PPYOLOERandomCrop',
'Mosaic9', 'YOLOv5CopyPaste', 'RemoveDataElement'
'Mosaic9', 'YOLOv5CopyPaste', 'RemoveDataElement', 'Load6DAnnotations',
'Pack6DInputs'
]
99 changes: 99 additions & 0 deletions mmyolo/datasets/transforms/formatting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import numpy as np

from mmengine.structures import InstanceData
from mmyolo.registry import TRANSFORMS
from mmcv.transforms import BaseTransform, to_tensor
from mmyolo.structures import DataSample6D

@TRANSFORMS.register_module()
class Pack6DInputs(BaseTransform):
# todo: 预测值是哪个?需要更改
mapping_table = {
'gt_bboxes': 'bboxes',
'gt_bboxes_labels': 'labels',
'gt_rotations': 'rotation',
'gt_translations': 'translation',
'gt_center': 'center',
'gt_corners': 'corners'
}

def __init__(self,
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape')):
self.meta_keys = meta_keys

def transform(self, results: dict) -> dict:
"""Method to pack the input data.

Args:
results (dict): Result dict from the data pipeline.

Returns:
dict:

- 'inputs' (obj:`torch.Tensor`): The forward data of models.
- 'data_sample' (obj:`DetDataSample`): The annotation info of the
sample.
"""
packed_results = dict()
if 'img' in results:
img = results['img']
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
img = np.ascontiguousarray(img.transpose(2, 0, 1))
packed_results['input'] = to_tensor(img)

data_sample = DataSample6D()
instance_data = InstanceData()

for key in self.mapping_table.keys():
if key not in results:
continue
instance_data[self.mapping_table[key]] = to_tensor(
results[key])

data_sample.gt_instances = instance_data

img_meta = {}
for key in self.meta_keys:
assert key in results, f'`{key}` is not found in `results`, ' \
f'the valid keys are {list(results)}.'
img_meta[key] = results[key]

data_sample.set_metainfo(img_meta)
packed_results['data_samples'] = data_sample

return packed_results

def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(meta_keys={self.meta_keys})'
return repr_str


@TRANSFORMS.register_module()
class ToTensor:
"""Convert some results to :obj:`torch.Tensor` by given keys.

Args:
keys (Sequence[str]): Keys that need to be converted to Tensor.
"""

def __init__(self, keys):
self.keys = keys

def __call__(self, results):
"""Call function to convert data in results to :obj:`torch.Tensor`.

Args:
results (dict): Result dict contains the data to convert.

Returns:
dict: The result dict contains the data converted
to :obj:`torch.Tensor`.
"""
for key in self.keys:
results[key] = to_tensor(results[key])
return results

def __repr__(self):
return self.__class__.__name__ + f'(keys={self.keys})'
80 changes: 80 additions & 0 deletions mmyolo/datasets/transforms/loading.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from typing import Optional, Tuple, Union, Dict

import numpy as np
from plyfile import PlyData

from mmyolo.registry import TRANSFORMS
from mmcv.transforms import BaseTransform
from mmcv.transforms import LoadAnnotations as MMCV_LoadAnnotations

@TRANSFORMS.register_module()
class Load6DAnnotations(MMCV_LoadAnnotations):
def __init__(self,
with_label: bool = True,
with_2d_bbox: bool = True,
with_corners: bool = True,
with_center: bool = True,
with_translation: bool = True,
with_rotation: bool = True,
file_client_args: dict = dict(backend='disk')
) -> None:
super(Load6DAnnotations, self).__init__(
with_bbox=with_2d_bbox,
with_label=with_label,
file_client_args=file_client_args
)
self.with_corners = with_corners
self.with_center = with_center
self.with_translation = with_translation
self.with_rotation = with_rotation


def _load_rotation(self, results:dict) -> None:
gt_rotations = []
for instance in results.get('instances', []):
gt_rotations.append(instance['rotation'])
results['gt_rotations'] = np.array(gt_rotations, dtype=np.float32).reshape(-1, 10)

def _load_translation(self, results:dict) -> None:
gt_translations = []
for instance in results.get('instances', []):
gt_translations.append(instance['translation'])
results['gt_translations'] = np.array(gt_translations, dtype=np.float32).reshape(-1,3)

def _load_center(self, results:dict) -> None:
gt_center = []
for instance in results['instances']:
gt_center.append(instance['center'])
results['gt_center'] = np.array(gt_center, dtype=np.float32).reshape(-1,2)

def _load_cornors(self, results: dict) -> None:
gt_cornors = []
for instance in results['instances']:
gt_cornors.append(instance['corners'])
results['gt_corners'] = np.array(gt_cornors, dtype=np.int64).reshape(1,-1)

def transform(self, results: dict) -> dict:
if self.with_label:
self._load_labels(results)
if self.with_bbox:
self._load_bboxes(results)
if self.with_corners:
self._load_cornors(results)
if self.with_center:
self._load_center(results)
if self.with_translation:
self._load_translation(results)
if self.with_rotation:
self._load_rotation(results)

return results

def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'with_label={self.with_label}, '
repr_str += f'(with_bbox={self.with_bbox}, '
repr_str += f'with_corners={self.with_corners}, '
repr_str += f'with_center={self.with_center}, '
repr_str += f'with_translation={self.with_translation}'
repr_str += f'file_client_args={self.file_client})'
return repr_str
73 changes: 73 additions & 0 deletions mmyolo/datasets/transforms/yolo6d_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import os
import random
import numpy as np
import mmcv
import cv2
from PIL import Image, ImageMath
from mmyolo.registry import TRANSFORMS
from mmcv.transforms import BaseTransform

def change_background(img, mask, bg):
ow, oh = img.size
bg = bg.resize((ow, oh)).convert('RGB')

imcs = list(img.split())
bgcs = list(bg.split())
maskcs = list(mask.split())
fics = list(Image.new(img.mode, img.size).split())

for c in range(len(imcs)):
negmask = maskcs[c].point(lambda i: 1 - i / 255)
posmask = maskcs[c].point(lambda i: i / 255)
fics[c] = ImageMath.eval("a * c + b * d", a=imcs[c], b=bgcs[c], c=posmask, d=negmask).convert('L')
out = Image.merge(img.mode, tuple(fics))

return out


@TRANSFORMS.register_module()
class CopyPaste6D(BaseTransform):
"""change the background"""
def __init__(self,
shape,
num_keypoints,
max_num_gt
):
self.shape = shape
self.num_keypoints = num_keypoints
self.max_num_gt = max_num_gt

def transform(self, results:dict) -> dict:
## data augmentation
img = results['img']
maskpath = results['mask_path']
bgpath = results['bg_path']

# opencv -> PIL
img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
mask = Image.open(maskpath)
bg = Image.open(bgpath)

img = change_background(img, mask, bg)
img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)

# match shape
if self.shape is not None:
oh, ow = img.shape[:2]
sx = self.shape[0]/ow
sy = self.shape[1]/oh
scale_ratio = max(sx, sy)

img = mmcv.imresize(img, (int(ow*scale_ratio),
int(oh*scale_ratio)))
img = img[0:self.shape[1], 0:self.shape[0]]

results['gt_cpts_norm'][..., 0] *= scale_ratio
results['gt_cpts_norm'][..., 1] *= scale_ratio

results['gt_bboxes'] *= scale_ratio

# PIL -> opencv
results['img'] = img

return results
Loading