From 9d11700fa9d4846cb79ac3a04aa640b820ab6e8e Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 24 Aug 2023 15:14:00 +0100 Subject: [PATCH] fix_bbox_sanitize_tensor --- gallery/v2_transforms/plot_transforms_v2.py | 7 +++---- test/test_transforms_v2.py | 14 ++++++++++++++ torchvision/transforms/v2/_utils.py | 2 +- 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/gallery/v2_transforms/plot_transforms_v2.py b/gallery/v2_transforms/plot_transforms_v2.py index 3058df23444..0f97431da59 100644 --- a/gallery/v2_transforms/plot_transforms_v2.py +++ b/gallery/v2_transforms/plot_transforms_v2.py @@ -99,10 +99,9 @@ format="XYXY", canvas_size=img.shape[-2:]) transforms = v2.Compose([ - v2.RandomPhotometricDistort(), - v2.RandomIoUCrop(), - v2.RandomHorizontalFlip(p=0.5), - v2.SanitizeBoundingBoxes(), + v2.RandomResizedCrop(size=(224, 224), antialias=True), + v2.RandomPhotometricDistort(p=1), + v2.RandomHorizontalFlip(p=1), ]) out_img, out_bboxes = transforms(img, bboxes) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 325d864dc05..982c86d0426 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -1256,6 +1256,20 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type): assert out_labels.tolist() == valid_indices +def test_sanitize_bounding_boxes_no_label(): + # Non-regression test for https://github.com/pytorch/vision/issues/7878 + + img = make_image() + boxes = make_bounding_boxes() + + with pytest.raises(ValueError, match="or a two-tuple whose second item is a dict"): + transforms.SanitizeBoundingBoxes()(img, boxes) + + out_img, out_boxes = transforms.SanitizeBoundingBoxes(labels_getter=None)(img, boxes) + assert isinstance(out_img, datapoints.Image) + assert isinstance(out_boxes, datapoints.BoundingBoxes) + + def test_sanitize_bounding_boxes_errors(): good_bbox = datapoints.BoundingBoxes( diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index 3c6977fae91..6b327d45c0e 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -112,7 +112,7 @@ def _find_labels_default_heuristic(inputs: Any) -> torch.Tensor: inputs = inputs[1] # MixUp, CutMix - if isinstance(inputs, torch.Tensor): + if is_pure_tensor(inputs): return inputs if not isinstance(inputs, collections.abc.Mapping):