Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding shrec dataset #375

Merged
merged 13 commits into from
Jun 18, 2021
1 change: 1 addition & 0 deletions docs/modules/kaolin.io.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ and :ref:`materials module<kaolin.io.materials>` contains Materials definition t
kaolin.io.shapenet
kaolin.io.usd
kaolin.io.modelnet
kaolin.io.shrec
1 change: 1 addition & 0 deletions kaolin/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
from . import shapenet
from . import usd
from . import modelnet
from . import shrec
6 changes: 0 additions & 6 deletions kaolin/io/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,22 +126,18 @@ def _get_cache_key(dataset, index):

KaolinDatasetItem = namedtuple('KaolinDatasetItem', ['data', 'attributes'])


class KaolinDataset(Dataset):
"""A dataset supporting the separation of data and attributes, and combines
them in its `__getitem__`.

The return value of `__getitem__` will be a named tuple containing the
return value of both `get_data` and `get_attributes`.

The difference between `get_data` and `get_attributes` is that data are able
to be transformed or preprocessed (such as using `ProcessedDataset`), while
attributes are generally not.
"""

def __getitem__(self, index):
"""Returns the item at the given index.

Will contain a named tuple of both data and attributes.
"""
attributes = self.get_attributes(index)
Expand All @@ -156,7 +152,6 @@ def get_data(self, index):
@abstractmethod
def get_attributes(self, index):
"""Returns the attributes at the given index.

Attributes are usually not transformed by wrappers such as
`ProcessedDataset`.
"""
Expand All @@ -167,7 +162,6 @@ def __len__(self):
"""Returns the number of entries."""
pass


class ProcessedDataset(KaolinDataset):
def __init__(self, dataset, preprocessing_transform=None,
cache_dir=None, num_workers=None, transform=None,
Expand Down
179 changes: 179 additions & 0 deletions kaolin/io/shrec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
# Copyright (c) 2019-2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings

from pathlib import Path

from kaolin.io.dataset import KaolinDataset
from kaolin.io.obj import import_mesh, ignore_error_handler

# Synset to labels mapping: each key is a synset id (string of digits), each
# value is the list of human-readable labels accepted for that category.
#
# Some labels appear under two synsets: 'bike' ('03790512' and '02834778') and
# 'phone' ('03261776' and '04401088'). This matters for the inverse mapping
# built below, where a label can only point at a single synset.
#
# NOTE(review): 'passenger vehi' (under '02924116') looks truncated, possibly
# 'passenger vehicle' -- confirm against the upstream taxonomy before changing.
synset_to_labels = {
'03790512': ['motorcycle', 'bike'],
'02808440': ['bathtub', 'bathing tub', 'bath', 'tub'],
'02871439': ['bookshelf'],
'03761084': ['microwave', 'microwave oven'],
'04530566': ['vessel', 'watercraft'],
'02691156': ['airplane', 'aeroplane', 'plane'],
'04379243': ['table'],
'03337140': ['file', 'file cabinet', 'filing cabinet'],
'04256520': ['sofa', 'couch', 'lounge'],
'03636649': ['lamp'],
'03928116': ['piano', 'pianoforte', 'forte-piano'],
'04004475': ['printer', 'printing machine'],
'03593526': ['jar'],
'04330267': ['stove'],
'04554684': ['washer', 'automatic washer', 'washing machine'],
'03948459': ['pistol', 'handgun', 'side arm', 'shooting iron'],
'03001627': ['chair'],
'03797390': ['mug'],
'02801938': ['basket', 'handbasket'],
'03710193': ['mailbox', 'letter box'],
'03938244': ['pillow'],
'03624134': ['knife'],
'02954340': ['cap'],
'02773838': ['bag', 'traveling bag', 'travelling bag', 'grip', 'suitcase'],
'02747177': ['ashcan', 'trash can', 'garbage can', 'wastebin',
'ash bin', 'ash-bin', 'ashbin', 'dustbin', 'trash barrel', 'trash bin'],
'04460130': ['tower'],
'02933112': ['cabinet'],
'02876657': ['bottle'],
'03991062': ['pot', 'flowerpot'],
'02843684': ['birdhouse'],
'02818832': ['bed'],
'02958343': ['car', 'auto', 'automobile', 'machine', 'motorcar'],
'03642806': ['laptop', 'laptop computer'],
'03085013': ['computer keyboard', 'keypad'],
'04074963': ['remote control', 'remote'],
'02924116': ['bus', 'autobus', 'coach', 'charabanc', 'double-decker',
'jitney', 'motorbus', 'motorcoach', 'omnibus', 'passenger vehi'],
'04225987': ['skateboard'],
'03261776': ['earphone', 'earpiece', 'headphone', 'phone'],
'02880940': ['bowl'],
'03325088': ['faucet', 'spigot'],
'03211117': ['display', 'video display'],
'04468005': ['train', 'railroad train'],
'03691459': ['loudspeaker', 'speaker', 'speaker unit', 'loudspeaker system', 'speaker system'],
'04090263': ['rifle'],
'02946921': ['can', 'tin', 'tin can'],
'04099429': ['rocket', 'projectile'],
'03467517': ['guitar'],
'04401088': ['telephone', 'phone', 'telephone set'],
'03046257': ['clock'],
'03759954': ['microphone', 'mike'],
'03513137': ['helmet'],
'02834778': ['bicycle', 'bike', 'wheel', 'cycle'],
'03207941': ['dishwasher', 'dish washer', 'dishwashing machine'],
'02828884': ['bench'],
'02942699': ['camera', 'photographic camera']}

# Label to Synset mapping (for ShapeNet core classes)
# Built by inverting synset_to_labels. When a label is listed under several
# synsets, the entry that comes last in synset_to_labels wins, so 'bike'
# resolves to '02834778' and 'phone' resolves to '04401088'.
label_to_synset = {label: synset for synset, labels in synset_to_labels.items() for label in labels}

def _convert_categories(categories):
    """Translate requested categories into synset ids.

    Each entry of ``categories`` may be a synset id (a key of
    ``synset_to_labels``) or a human-readable label (a key of
    ``label_to_synset``). Labels are replaced by their synset id; every other
    entry, known or not, is passed through unchanged.

    Args:
        categories (list of str): Requested categories.

    Returns:
        list of str: One entry per requested category, in the same order.

    Warns:
        UserWarning: If at least one entry is neither a known synset id nor
            a known label.
    """
    # The previous check negated a generator expression, which is always
    # truthy, so the warning could never fire. Test each entry explicitly.
    has_unknown = any(
        c not in synset_to_labels and c not in label_to_synset
        for c in categories)
    if has_unknown:
        # Adjacent literals instead of a backslash continuation, so that no
        # source indentation leaks into the message.
        warnings.warn(
            'Some or all of the categories requested are not part of '
            'Shrec16. Data loading may fail if these categories are not '
            'available.')
    return [label_to_synset.get(c, c) for c in categories]

class SHREC16(KaolinDataset):
    r"""Dataset class for SHREC16, used for the "Large-scale 3D shape retrieval
    from ShapeNet Core55" contest at Eurographics 2016.

    More details about the challenge and the dataset are available
    `here <https://shapenet.cs.stanford.edu/shrec16/>`_.

    The `__getitem__` method will return a `KaolinDatasetItem`, with its `data`
    field containing the namedtuple returned by
    :func:`kaolin.io.obj.import_mesh`, and its `attributes` field containing a
    dict with the keys ``name``, ``path``, ``synset`` and ``label``.

    Args:
        root (str): Path to the root directory of the dataset.
        categories (list): List of categories to load (each class is
            specified as a string, and must be a valid `SHREC16`
            category). If this argument is not specified, all categories
            are loaded by default. Ignored for the test split, which is
            not organised by category.
        split (str): String to indicate whether to load train, test or val set.

    Raises:
        ValueError: If ``split`` is not a supported value, or if an expected
            directory is missing under ``root``.
    """

    def __init__(self, root: str, categories: list = None, split: str = "train"):
        # "test_allinone" is tolerated as an alias of "test": earlier code
        # compared the split against that string, so callers may pass it.
        if split not in ("train", "val", "test", "test_allinone"):
            raise ValueError(f'Split must be either train, test or val, got {split} instead.')

        self.root = Path(root)
        self.paths = []
        self.synset_idxs = []

        if split in ("test", "test_allinone"):
            # The test split carries no category information: a single
            # placeholder entry is used and every model points at index 0.
            self.synsets = [None]
            self.labels = [None]

            # Accept both directory names the code has referred to.
            candidates = [self.root / "test_allinone", self.root / "test"]
            test_dir = next((d for d in candidates if d.exists()), None)
            if test_dir is None:
                raise ValueError(
                    'Test split was not found at location {0} or {1}.'.format(
                        str(candidates[0]), str(candidates[1])))

            models = sorted(test_dir.glob('*'))
            self.paths += models
            self.synset_idxs += [0] * len(models)
        else:
            if categories is None:
                self.synsets = list(synset_to_labels.keys())
            else:
                self.synsets = _convert_categories(categories)
            self.labels = [synset_to_labels[s] for s in self.synsets]

            # One sub-directory per synset under the split directory.
            for i, syn in enumerate(self.synsets):
                class_target = self.root / split / syn

                if not class_target.exists():
                    raise ValueError(
                        'Class {0} ({1}) was not found at location {2}.'.format(
                            syn, self.labels[i], str(class_target)))

                # Sorted so that item order is deterministic across runs.
                models = sorted(class_target.glob('*'))

                self.paths += models
                self.synset_idxs += [i] * len(models)

        self.names = [p.name for p in self.paths]

    def __len__(self):
        """Returns the number of models found for the selected split."""
        return len(self.paths)

    def get_data(self, index):
        """Loads and returns the mesh stored at ``self.paths[index]``."""
        obj_location = self.paths[index]
        mesh = import_mesh(str(obj_location), error_handler=ignore_error_handler)
        return mesh

    def get_attributes(self, index):
        """Returns a dict with the name, path, synset and label of an item.

        For the test split, ``synset`` and ``label`` are ``None``.
        """
        synset_idx = self.synset_idxs[index]
        attributes = {
            'name': self.names[index],
            'path': self.paths[index],
            'synset': self.synsets[synset_idx],
            'label': self.labels[synset_idx]
        }
        return attributes

    def get_cache_key(self, index):
        """Returns the file name of the item, used as its cache key."""
        return self.names[index]
103 changes: 103 additions & 0 deletions tests/python/kaolin/io/test_shrec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path

import os

import pytest
import torch

from kaolin.io.obj import return_type
from kaolin.io.shrec import SHREC16

# Location of the dataset on the test machine.
SHREC16_PATH = '/data/shrec16/'
# Category selections, each given once as synset ids and once as labels.
SHREC16_TEST_CATEGORY_SYNSETS = ['02691156']
SHREC16_TEST_CATEGORY_LABELS = ['airplane']
SHREC16_TEST_CATEGORY_SYNSETS_2 = ['02958343']
SHREC16_TEST_CATEGORY_LABELS_2 = ['car']
SHREC16_TEST_CATEGORY_SYNSETS_MULTI = ['02691156', '02958343']
SHREC16_TEST_CATEGORY_LABELS_MULTI = ['airplane', 'car']

# Every value passed as `categories` in the tests below; None means that no
# categories argument is given.
ALL_CATEGORIES = [
None,
SHREC16_TEST_CATEGORY_SYNSETS,
SHREC16_TEST_CATEGORY_LABELS,
SHREC16_TEST_CATEGORY_SYNSETS_2,
SHREC16_TEST_CATEGORY_LABELS_2,
SHREC16_TEST_CATEGORY_SYNSETS_MULTI,
SHREC16_TEST_CATEGORY_LABELS_MULTI,
]


# Skip test in a CI environment
@pytest.mark.skipif(os.getenv('CI') == 'true', reason="CI does not have dataset")
@pytest.mark.parametrize('categories', ALL_CATEGORIES)
@pytest.mark.parametrize('split', ['train', 'val', 'test'])
class TestSHREC16:
    """Checks `SHREC16` items for every split and category selection.

    The dataset is read from ``SHREC16_PATH``; the whole class is skipped when
    the ``CI`` environment variable is ``'true'``.
    """

    @pytest.fixture(autouse=True)
    def shrec16_dataset(self, categories, split):
        """Builds the dataset for the current (categories, split) pair."""
        return SHREC16(root=SHREC16_PATH,
                       categories=categories,
                       split=split)

    @pytest.mark.parametrize('index', [0, -1])
    def test_basic_getitem(self, shrec16_dataset, index, split):
        """Checks types and shapes of the first and the last item."""
        assert len(shrec16_dataset) > 0

        # -1 stands for "last item"; resolve it to an explicit position.
        if index == -1:
            index = len(shrec16_dataset) - 1

        item = shrec16_dataset[index]
        data = item.data
        attributes = item.attributes
        assert isinstance(data, return_type)
        assert isinstance(attributes, dict)

        assert isinstance(data.vertices, torch.Tensor)
        assert len(data.vertices.shape) == 2
        assert data.vertices.shape[1] == 3
        assert isinstance(data.faces, torch.Tensor)
        assert len(data.faces.shape) == 2

        assert isinstance(attributes['name'], str)
        assert isinstance(attributes['path'], Path)

        if split == "test":
            assert attributes['synset'] is None
            assert attributes['label'] is None
        else:
            assert isinstance(attributes['synset'], str)
            assert isinstance(attributes['label'], list)

    @pytest.mark.parametrize('index', [-1, -2])
    def test_neg_index(self, shrec16_dataset, index):
        """Checks that a negative index returns the same item as its
        positive equivalent."""
        assert len(shrec16_dataset) > 0

        gt_item = shrec16_dataset[len(shrec16_dataset) + index]
        gt_data = gt_item.data
        gt_attributes = gt_item.attributes

        item = shrec16_dataset[index]
        data = item.data
        attributes = item.attributes

        assert torch.equal(data.vertices, gt_data.vertices)
        assert torch.equal(data.faces, gt_data.faces)

        assert attributes['name'] == gt_attributes['name']
        assert attributes['path'] == gt_attributes['path']
        assert attributes['synset'] == gt_attributes['synset']
        # 'label' was the only attribute left unchecked.
        assert attributes['label'] == gt_attributes['label']