Skip to content

Commit

Permalink
fix_bbox_sanitize_tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Aug 24, 2023
1 parent 054432d commit 9d11700
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 5 deletions.
7 changes: 3 additions & 4 deletions gallery/v2_transforms/plot_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,9 @@
format="XYXY", canvas_size=img.shape[-2:])

transforms = v2.Compose([
v2.RandomPhotometricDistort(),
v2.RandomIoUCrop(),
v2.RandomHorizontalFlip(p=0.5),
v2.SanitizeBoundingBoxes(),
v2.RandomResizedCrop(size=(224, 224), antialias=True),
v2.RandomPhotometricDistort(p=1),
v2.RandomHorizontalFlip(p=1),
])
out_img, out_bboxes = transforms(img, bboxes)

Expand Down
14 changes: 14 additions & 0 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1256,6 +1256,20 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
assert out_labels.tolist() == valid_indices


def test_sanitize_bounding_boxes_no_label():
# Non-regression test for https://github.com/pytorch/vision/issues/7878

img = make_image()
boxes = make_bounding_boxes()

with pytest.raises(ValueError, match="or a two-tuple whose second item is a dict"):
transforms.SanitizeBoundingBoxes()(img, boxes)

out_img, out_boxes = transforms.SanitizeBoundingBoxes(labels_getter=None)(img, boxes)
assert isinstance(out_img, datapoints.Image)
assert isinstance(out_boxes, datapoints.BoundingBoxes)


def test_sanitize_bounding_boxes_errors():

good_bbox = datapoints.BoundingBoxes(
Expand Down
2 changes: 1 addition & 1 deletion torchvision/transforms/v2/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def _find_labels_default_heuristic(inputs: Any) -> torch.Tensor:
inputs = inputs[1]

# MixUp, CutMix
if isinstance(inputs, torch.Tensor):
if is_pure_tensor(inputs):
return inputs

if not isinstance(inputs, collections.abc.Mapping):
Expand Down

0 comments on commit 9d11700

Please sign in to comment.