Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add SRFacicalLandmarkDataset. #329

Merged
merged 3 commits into from
May 31, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mmedit/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .img_inpainting_dataset import ImgInpaintingDataset
from .registry import DATASETS, PIPELINES
from .sr_annotation_dataset import SRAnnotationDataset
from .sr_facical_landmark_dataset import SRFacicalLandmarkDataset
from .sr_folder_dataset import SRFolderDataset
from .sr_folder_gt_dataset import SRFolderGTDataset
from .sr_folder_ref_dataset import SRFolderRefDataset
Expand All @@ -29,5 +30,5 @@
'SRVimeo90KDataset', 'BaseGenerationDataset', 'GenerationPairedDataset',
'GenerationUnpairedDataset', 'SRVid4Dataset', 'SRFolderGTDataset',
'SRREDSMultipleGTDataset', 'SRVimeo90KMultipleGTDataset',
'SRTestMultipleGTDataset', 'SRFolderRefDataset'
'SRTestMultipleGTDataset', 'SRFolderRefDataset', 'SRFacicalLandmarkDataset'
]
63 changes: 63 additions & 0 deletions mmedit/datasets/sr_facical_landmark_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import os.path as osp

import numpy as np

from .base_sr_dataset import BaseSRDataset
from .registry import DATASETS


@DATASETS.register_module()
class SRFacicalLandmarkDataset(BaseSRDataset):
    """Facial image and landmark dataset, driven by an annotation file, for
    image restoration.

    For every sample the dataset provides the gt (Ground-Truth) image path
    together with the image shape, the face bounding box and the facial
    landmark, then applies the configured transforms and returns a dict of
    paired data plus meta information.

    This is the "annotation file mode":
    the annotation is a ``npy`` file holding a list of dicts, one per image,
    each carrying the image name, image shape, face box and landmark.

    Example of an annotation file::

        dict1(file=*, bbox=*, shape=*, landmark=*)
        dict2(file=*, bbox=*, shape=*, landmark=*)

    Args:
        gt_folder (str | :obj:`Path`): Path to a gt folder.
        ann_file (str | :obj:`Path`): Path to the annotation file.
        pipeline (list[dict | callable]): A sequence of data transformations.
        scale (int): Upsampling scale ratio.
        test_mode (bool): Store `True` when building test dataset.
            Default: `False`.
    """

    def __init__(self, gt_folder, ann_file, pipeline, scale, test_mode=False):
        super().__init__(pipeline, scale, test_mode)
        # Normalize to str so both ``Path`` objects and plain strings work.
        self.gt_folder = str(gt_folder)
        self.ann_file = str(ann_file)
        self.data_infos = self.load_annotations()

    def load_annotations(self):
        """Load annotations for the SR dataset.

        Reads the ``npy`` annotation file (a pickled list of dicts) and
        expands each relative ``gt_path`` into a full path under
        ``self.gt_folder``. Each dict carries the image name, image shape
        (usually for gt), bbox and landmark.

        Returns:
            dict: Returned dict for GT and landmark.
                Contains: gt_path, bbox, shape, landmark.
        """
        # ``allow_pickle`` is required: the npy file stores Python dicts,
        # not a plain numeric array.
        annotations = np.load(self.ann_file, allow_pickle=True)
        for info in annotations:
            info['gt_path'] = osp.join(self.gt_folder, info['gt_path'])
        return annotations
Binary file added tests/data/face/000001.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/data/facemark_ann.npy
Binary file not shown.
45 changes: 42 additions & 3 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from mmedit.datasets import (AdobeComp1kDataset, BaseGenerationDataset,
BaseSRDataset, GenerationPairedDataset,
GenerationUnpairedDataset, RepeatDataset,
SRAnnotationDataset, SRFolderDataset,
SRFolderGTDataset, SRFolderRefDataset,
SRLmdbDataset, SRREDSDataset,
SRAnnotationDataset, SRFacicalLandmarkDataset,
SRFolderDataset, SRFolderGTDataset,
SRFolderRefDataset, SRLmdbDataset, SRREDSDataset,
SRREDSMultipleGTDataset, SRTestMultipleGTDataset,
SRVid4Dataset, SRVimeo90KDataset,
SRVimeo90KMultipleGTDataset)
Expand Down Expand Up @@ -376,6 +376,45 @@ def test_sr_folder_ref_dataset(self):
scale=4,
filename_tmpl_lq=filename_tmpl)

def test_sr_landmark_dataset(self):
    # setup
    sr_pipeline = [
        dict(
            type='LoadImageFromFile',
            io_backend='disk',
            key='gt',
            flag='color',
            channel_order='rgb',
            backend='cv2')
    ]

    target_keys = ['gt_path', 'bbox', 'shape', 'landmark']
    gt_folder = self.data_prefix / 'face'
    ann_file = self.data_prefix / 'facemark_ann.npy'

    # Exercise both accepted input forms: Path objects and plain strings.
    path_variants = (
        (gt_folder, ann_file),
        (str(gt_folder), str(ann_file)),
    )
    for gt_input, ann_input in path_variants:
        sr_landmark_dataset = SRFacicalLandmarkDataset(
            gt_folder=gt_input,
            ann_file=ann_input,
            pipeline=sr_pipeline,
            scale=4)
        data_infos = sr_landmark_dataset.data_infos
        assert len(data_infos) == 1
        result = sr_landmark_dataset[0]
        assert len(sr_landmark_dataset) == 1
        assert check_keys_contain(result.keys(), target_keys)

def test_sr_lmdb_dataset(self):
# setup
lq_lmdb_folder = self.data_prefix / 'lq.lmdb'
Expand Down