Skip to content

Commit

Permalink
TaskVisionDataset: add a way to specify a custom label-name-to-index …
Browse files Browse the repository at this point in the history
…mapping

The default numbering is generally unpredictable, so it's not especially
useful outside of basic use cases. Given that there's no way to specify
custom indexes for labels on the server side, make it possible on the client
side.

The input dictionary maps label _names_ to indexes, because names are what a
CVAT user would be familiar with. The output dictionary in the `Target`
class still maps label _IDs_, because that's what CVAT uses in the
annotation data models.
  • Loading branch information
SpecLad committed Dec 13, 2022
1 parent 74a8cbf commit 014d039
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 8 deletions.
35 changes: 27 additions & 8 deletions cvat-sdk/cvat_sdk/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,7 @@ class Target:
label_id_to_index: Mapping[int, int]
"""
A mapping from label_id values in `LabeledImage` and `LabeledShape` objects
to an index in the range [0, num_labels), where num_labels is the number of labels
defined in the task. This mapping is consistent across all samples for a given task.
to an integer index. This mapping is consistent across all samples for a given task.
"""


Expand Down Expand Up @@ -100,6 +99,7 @@ def __init__(
transforms: Optional[Callable] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
label_name_to_index: Mapping[str, int] = None,
) -> None:
"""
Creates a dataset corresponding to the task with ID `task_id` on the
Expand All @@ -108,6 +108,17 @@ def __init__(
`transforms`, `transform` and `target_transform` are optional transformation
functions; see the documentation for `torchvision.datasets.VisionDataset` for
more information.
`label_name_to_index` affects the `label_id_to_index` member in `Target` objects
returned by the dataset. If it is specified, then it must contain an entry for
each label name in the task. The `label_id_to_index` mapping will be constructed
so that each label will be mapped to the index corresponding to the label's name
in `label_name_to_index`.
If `label_name_to_index` is unspecified or set to `None`, then `label_id_to_index`
will map each label ID to a distinct integer in the range [0, `num_labels`), where
`num_labels` is the number of labels defined in the task. This mapping will be
generally unpredictable, but consistent for a given task.
"""

self._logger = client.logger
Expand Down Expand Up @@ -163,12 +174,20 @@ def __init__(

self._logger.info("All chunks downloaded")

self._label_id_to_index = types.MappingProxyType(
{
label["id"]: label_index
for label_index, label in enumerate(sorted(self._task.labels, key=lambda l: l.id))
}
)
if label_name_to_index is None:
self._label_id_to_index = types.MappingProxyType(
{
label.id: label_index
for label_index, label in enumerate(sorted(self._task.labels, key=lambda l: l.id))
}
)
else:
self._label_id_to_index = types.MappingProxyType(
{
label.id: label_name_to_index[label.name]
for label in self._task.labels
}
)

annotations = self._ensure_model(
"annotations.json", LabeledData, self._task.get_annotations, "annotations"
Expand Down
15 changes: 15 additions & 0 deletions tests/python/sdk/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,18 @@ def test_transforms(self):

assert isinstance(dataset[0][0], cvatpt.Target)
assert isinstance(dataset[0][1], PIL.Image.Image)

def test_custom_label_mapping(self):
    """Check that a caller-supplied name-to-index mapping ends up in the targets."""
    # The dataset reports indexes keyed by label ID, so build a name -> ID
    # lookup from the task's labels to translate the expectations below.
    ids_by_name = {}
    for task_label in self.task.labels:
        ids_by_name[task_label.name] = task_label.id

    custom_indexes = {"person": 123, "car": 456}

    dataset = cvatpt.TaskVisionDataset(
        self.client,
        self.task.id,
        label_name_to_index=custom_indexes,
    )

    _sample, target = dataset[5]

    # Each label's ID must map to exactly the index requested for its name.
    for label_name, expected_index in custom_indexes.items():
        assert target.label_id_to_index[ids_by_name[label_name]] == expected_index

0 comments on commit 014d039

Please sign in to comment.