diff --git a/cvat-sdk/cvat_sdk/pytorch/__init__.py b/cvat-sdk/cvat_sdk/pytorch/__init__.py index 7812954548d9..d8082e85744a 100644 --- a/cvat-sdk/cvat_sdk/pytorch/__init__.py +++ b/cvat-sdk/cvat_sdk/pytorch/__init__.py @@ -66,8 +66,7 @@ class Target: label_id_to_index: Mapping[int, int] """ A mapping from label_id values in `LabeledImage` and `LabeledShape` objects - to an index in the range [0, num_labels), where num_labels is the number of labels - defined in the task. This mapping is consistent across all samples for a given task. + to an integer index. This mapping is consistent across all samples for a given task. """ @@ -100,6 +99,7 @@ def __init__( transforms: Optional[Callable] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, + label_name_to_index: Mapping[str, int] = None, ) -> None: """ Creates a dataset corresponding to the task with ID `task_id` on the @@ -108,6 +108,17 @@ def __init__( `transforms`, `transform` and `target_transforms` are optional transformation functions; see the documentation for `torchvision.datasets.VisionDataset` for more information. + + `label_name_to_index` affects the `label_id_to_index` member in `Target` objects + returned by the dataset. If it is specified, then it must contain an entry for + each label name in the task. The `label_id_to_index` mapping will be constructed + so that each label will be mapped to the index corresponding to the label's name + in `label_name_to_index`. + + If `label_name_to_index` is unspecified or set to `None`, then `label_id_to_index` + will map each label ID to a distinct integer in the range [0, `num_labels`), where + `num_labels` is the number of labels defined in the task. The specific index assigned + to each label is unspecified, but consistent for a given task. 
""" self._logger = client.logger @@ -163,12 +174,20 @@ def __init__( self._logger.info("All chunks downloaded") - self._label_id_to_index = types.MappingProxyType( - { - label["id"]: label_index - for label_index, label in enumerate(sorted(self._task.labels, key=lambda l: l.id)) - } - ) + if label_name_to_index is None: + self._label_id_to_index = types.MappingProxyType( + { + label.id: label_index + for label_index, label in enumerate(sorted(self._task.labels, key=lambda l: l.id)) + } + ) + else: + self._label_id_to_index = types.MappingProxyType( + { + label.id: label_name_to_index[label.name] + for label in self._task.labels + } + ) annotations = self._ensure_model( "annotations.json", LabeledData, self._task.get_annotations, "annotations" diff --git a/tests/python/sdk/test_pytorch.py b/tests/python/sdk/test_pytorch.py index 5be990c60fb8..b4f0238beab4 100644 --- a/tests/python/sdk/test_pytorch.py +++ b/tests/python/sdk/test_pytorch.py @@ -211,3 +211,18 @@ def test_transforms(self): assert isinstance(dataset[0][0], cvatpt.Target) assert isinstance(dataset[0][1], PIL.Image.Image) + + def test_custom_label_mapping(self): + label_name_to_id = { + label.name: label.id for label in self.task.labels + } + + dataset = cvatpt.TaskVisionDataset( + self.client, + self.task.id, + label_name_to_index={"person": 123, "car": 456}, + ) + + _, target = dataset[5] + assert target.label_id_to_index[label_name_to_id["person"]] == 123 + assert target.label_id_to_index[label_name_to_id["car"]] == 456