From 014d039cb38904e74abe1859e3df10201f6b5a98 Mon Sep 17 00:00:00 2001 From: Roman Donchenko Date: Tue, 13 Dec 2022 12:01:24 +0300 Subject: [PATCH] TaskVisionDataset: add a way to specify a custom label-name-to-index mapping The default numbering is generally unpredictable, so it's not especially useful outside of basic use cases. Given that there's no way to specify custom indexes for labels on the server side, make it possible on the client side. The input dictionary maps label _names_ to indexes, because names is what a CVAT user would be familiar with. The output dictionary in the `Target` class still maps label _IDs_, because that's what CVAT uses in the annotation data models. --- cvat-sdk/cvat_sdk/pytorch/__init__.py | 35 +++++++++++++++++++++------ tests/python/sdk/test_pytorch.py | 15 ++++++++++++ 2 files changed, 42 insertions(+), 8 deletions(-) diff --git a/cvat-sdk/cvat_sdk/pytorch/__init__.py b/cvat-sdk/cvat_sdk/pytorch/__init__.py index 7812954548d9..d8082e85744a 100644 --- a/cvat-sdk/cvat_sdk/pytorch/__init__.py +++ b/cvat-sdk/cvat_sdk/pytorch/__init__.py @@ -66,8 +66,7 @@ class Target: label_id_to_index: Mapping[int, int] """ A mapping from label_id values in `LabeledImage` and `LabeledShape` objects - to an index in the range [0, num_labels), where num_labels is the number of labels - defined in the task. This mapping is consistent across all samples for a given task. + to an integer index. This mapping is consistent across all samples for a given task. """ @@ -100,6 +99,7 @@ def __init__( transforms: Optional[Callable] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, + label_name_to_index: Mapping[str, int] = None, ) -> None: """ Creates a dataset corresponding to the task with ID `task_id` on the @@ -108,6 +108,17 @@ def __init__( `transforms`, `transform` and `target_transforms` are optional transformation functions; see the documentation for `torchvision.datasets.VisionDataset` for more information. + + `label_name_to_index` affects the `label_id_to_index` member in `Target` objects + returned by the dataset. If it is specified, then it must contain an entry for + each label name in the task. The `label_id_to_index` mapping will be constructed + so that each label will be mapped to the index corresponding to the label's name + in `label_name_to_index`. + + If `label_name_to_index` is unspecified or set to `None`, then `label_id_to_index` + will map each label ID to a distinct integer in the range [0, `num_labels`), where + `num_labels` is the number of labels defined in the task. This mapping will be + generally unpredictable, but consistent for a given task. """ self._logger = client.logger @@ -163,12 +174,20 @@ def __init__( self._logger.info("All chunks downloaded") - self._label_id_to_index = types.MappingProxyType( - { - label["id"]: label_index - for label_index, label in enumerate(sorted(self._task.labels, key=lambda l: l.id)) - } - ) + if label_name_to_index is None: + self._label_id_to_index = types.MappingProxyType( + { + label.id: label_index + for label_index, label in enumerate(sorted(self._task.labels, key=lambda l: l.id)) + } + ) + else: + self._label_id_to_index = types.MappingProxyType( + { + label.id: label_name_to_index[label.name] + for label in self._task.labels + } + ) annotations = self._ensure_model( "annotations.json", LabeledData, self._task.get_annotations, "annotations" diff --git a/tests/python/sdk/test_pytorch.py b/tests/python/sdk/test_pytorch.py index 5be990c60fb8..b4f0238beab4 100644 --- a/tests/python/sdk/test_pytorch.py +++ b/tests/python/sdk/test_pytorch.py @@ -211,3 +211,18 @@ def test_transforms(self): assert isinstance(dataset[0][0], cvatpt.Target) assert isinstance(dataset[0][1], PIL.Image.Image) + + def test_custom_label_mapping(self): + label_name_to_id = { + label.name: label.id for label in self.task.labels + } + + dataset = cvatpt.TaskVisionDataset( + self.client, + self.task.id, + label_name_to_index={"person": 123, "car": 456}, + ) + + _, target = dataset[5] + assert target.label_id_to_index[label_name_to_id["person"]] == 123 + assert target.label_id_to_index[label_name_to_id["car"]] == 456