SDK: change extractor transforms to return tensors instead of primitive types

This is more consistent with torchvision models, and lets the
`ExtractBoundingBoxes` outputs be used directly with the `torchmetrics`
package (or at least the `MeanAveragePrecision` class).
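(Illustration only, not part of this commit.) With "boxes" and "labels" now being tensors, a target produced by `ExtractBoundingBoxes` has the per-image dictionary shape that `MeanAveragePrecision.update()` accepts, so it can be passed through without conversion. A minimal sketch with made-up prediction values, assuming a torchmetrics version that exposes the class under `torchmetrics.detection.mean_ap`:

import torch
from torchmetrics.detection.mean_ap import MeanAveragePrecision

metric = MeanAveragePrecision()

# A ground-truth target for one image, in the form ExtractBoundingBoxes now returns.
target = {
    "boxes": torch.tensor([[1.0, 2.0, 3.0, 4.0]]),
    "labels": torch.tensor([1]),
}

# Made-up detector output for the same image.
preds = {
    "boxes": torch.tensor([[1.1, 2.0, 3.1, 4.2]]),
    "scores": torch.tensor([0.9]),
    "labels": torch.tensor([1]),
}

metric.update([preds], [target])
print(metric.compute()["map"])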
SpecLad committed Dec 13, 2022
1 parent a3b4f97 commit 74a8cbf
Showing 2 changed files with 23 additions and 13 deletions.
20 changes: 12 additions & 8 deletions cvat-sdk/cvat_sdk/pytorch/__init__.py
@@ -24,6 +24,7 @@
 import attrs
 import attrs.validators
 import PIL.Image
+import torch
 import torchvision.datasets
 from typing_extensions import TypedDict

@@ -283,7 +284,7 @@ def __len__(self) -> int:
 class ExtractSingleLabelIndex:
     """
     A target transform that takes a `Target` object and produces a single label index
-    based on the tag in that object.
+    based on the tag in that object, as a 0-dimensional tensor.
     This makes the dataset samples compatible with the image classification networks
     in torchvision.
@@ -299,12 +300,12 @@ def __call__(self, target: Target) -> int:
         if len(tags) > 1:
             raise ValueError("sample has multiple tags")

-        return target.label_id_to_index[tags[0].label_id]
+        return torch.tensor(target.label_id_to_index[tags[0].label_id], dtype=torch.int)


 class LabeledBoxes(TypedDict):
-    boxes: Sequence[Tuple[float, float, float, float]]
-    labels: Sequence[int]
+    boxes: torch.Tensor
+    labels: torch.Tensor


 _SUPPORTED_SHAPE_TYPES = frozenset(["rectangle", "polygon", "polyline", "points", "ellipse"])
@@ -318,9 +319,9 @@ class ExtractBoundingBoxes:
     The dictionary contains the following entries:
-    "boxes": a sequence of (xmin, ymin, xmax, ymax) tuples, one for each shape
-    in the annotations.
-    "labels": a sequence of corresponding label indices.
+    "boxes": a tensor with shape [N, 4], where each row represents a bounding box of a shape
+    in the annotations in the (xmin, ymin, xmax, ymax) format.
+    "labels": a tensor with shape [N] containing corresponding label indices.
     Limitations:
@@ -356,4 +357,7 @@ def __call__(self, target: Target) -> LabeledBoxes:
             boxes.append((min(x_coords), min(y_coords), max(x_coords), max(y_coords)))
             labels.append(target.label_id_to_index[shape.label_id])

-        return LabeledBoxes(boxes=boxes, labels=labels)
+        return LabeledBoxes(
+            boxes=torch.tensor(boxes, dtype=torch.float),
+            labels=torch.tensor(labels, dtype=torch.int),
+        )
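(Illustration only, not part of this commit.) A side effect of returning 0-dimensional tensors from `ExtractSingleLabelIndex` is that `torch.utils.data.default_collate` stacks them into a 1-D batch tensor, the form classification losses such as `cross_entropy` expect; a minimal sketch assuming PyTorch 1.11+ and made-up logits:

import torch
from torch.utils.data import default_collate

# Per-sample labels in the form ExtractSingleLabelIndex now returns.
labels = [torch.tensor(0, dtype=torch.int), torch.tensor(1, dtype=torch.int)]
batch = default_collate(labels)  # tensor([0, 1], dtype=torch.int32)

logits = torch.randn(2, 3)  # made-up classifier output for 3 classes
loss = torch.nn.functional.cross_entropy(logits, batch.long())
print(batch, loss.item())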
16 changes: 11 additions & 5 deletions tests/python/sdk/test_pytorch.py
@@ -165,8 +165,8 @@ def test_extract_single_label_index(self):
             target_transform=cvatpt.ExtractSingleLabelIndex(),
         )

-        assert dataset[5][1] == 0
-        assert dataset[6][1] == 1
+        assert torch.equal(dataset[5][1], torch.tensor(0))
+        assert torch.equal(dataset[6][1], torch.tensor(1))

         with pytest.raises(ValueError):
             # no tags
@@ -192,9 +192,15 @@ def test_extract_bounding_boxes(self):
             target_transform=cvatpt.ExtractBoundingBoxes(include_shape_types={"rectangle"}),
         )

-        assert dataset[0][1] == {"boxes": [], "labels": []}
-        assert dataset[6][1] == {"boxes": [(1.0, 2.0, 3.0, 4.0)], "labels": [1]}
-        assert dataset[7][1] == {"boxes": [], "labels": []}  # points are filtered out
+        assert torch.equal(dataset[0][1]["boxes"], torch.tensor([]))
+        assert torch.equal(dataset[0][1]["labels"], torch.tensor([]))
+
+        assert torch.equal(dataset[6][1]["boxes"], torch.tensor([(1.0, 2.0, 3.0, 4.0)]))
+        assert torch.equal(dataset[6][1]["labels"], torch.tensor([1]))
+
+        # points are filtered out
+        assert torch.equal(dataset[7][1]["boxes"], torch.tensor([]))
+        assert torch.equal(dataset[7][1]["labels"], torch.tensor([]))

     def test_transforms(self):
         dataset = cvatpt.TaskVisionDataset(
