From f7c7bdf5fae5e946e67812c5791d91cec83fccd5 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 24 Aug 2023 15:41:54 +0100 Subject: [PATCH 1/2] Stricter SanitizeBoundingBoxes labels_getter heuristic (#7880) --- gallery/v2_transforms/plot_transforms_v2.py | 7 +++---- test/test_transforms_v2.py | 14 ++++++++++++++ torchvision/transforms/v2/_utils.py | 2 +- 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/gallery/v2_transforms/plot_transforms_v2.py b/gallery/v2_transforms/plot_transforms_v2.py index 3058df23444..0f97431da59 100644 --- a/gallery/v2_transforms/plot_transforms_v2.py +++ b/gallery/v2_transforms/plot_transforms_v2.py @@ -99,10 +99,9 @@ format="XYXY", canvas_size=img.shape[-2:]) transforms = v2.Compose([ - v2.RandomPhotometricDistort(), - v2.RandomIoUCrop(), - v2.RandomHorizontalFlip(p=0.5), - v2.SanitizeBoundingBoxes(), + v2.RandomResizedCrop(size=(224, 224), antialias=True), + v2.RandomPhotometricDistort(p=1), + v2.RandomHorizontalFlip(p=1), ]) out_img, out_bboxes = transforms(img, bboxes) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 325d864dc05..982c86d0426 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -1256,6 +1256,20 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type): assert out_labels.tolist() == valid_indices +def test_sanitize_bounding_boxes_no_label(): + # Non-regression test for https://github.com/pytorch/vision/issues/7878 + + img = make_image() + boxes = make_bounding_boxes() + + with pytest.raises(ValueError, match="or a two-tuple whose second item is a dict"): + transforms.SanitizeBoundingBoxes()(img, boxes) + + out_img, out_boxes = transforms.SanitizeBoundingBoxes(labels_getter=None)(img, boxes) + assert isinstance(out_img, datapoints.Image) + assert isinstance(out_boxes, datapoints.BoundingBoxes) + + def test_sanitize_bounding_boxes_errors(): good_bbox = datapoints.BoundingBoxes( diff --git a/torchvision/transforms/v2/_utils.py b/torchvision/transforms/v2/_utils.py index 3c6977fae91..6b327d45c0e 100644 --- a/torchvision/transforms/v2/_utils.py +++ b/torchvision/transforms/v2/_utils.py @@ -112,7 +112,7 @@ def _find_labels_default_heuristic(inputs: Any) -> torch.Tensor: inputs = inputs[1] # MixUp, CutMix - if isinstance(inputs, torch.Tensor): + if is_pure_tensor(inputs): return inputs if not isinstance(inputs, collections.abc.Mapping): From b82d8833c2a872acf1d73541032c57762dc5f0cc Mon Sep 17 00:00:00 2001 From: David Chiu Date: Thu, 24 Aug 2023 23:21:53 +0800 Subject: [PATCH 2/2] Fix typos in _augment.py (#7877) Co-authored-by: Nicolas Hug --- torchvision/transforms/v2/_augment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index 130950fee34..a9bad8f9bf7 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -229,7 +229,7 @@ class MixUp(_BaseMixUpCutMix): alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1. num_classes (int): number of classes in the batch. Used for one-hot-encoding. labels_getter (callable or "default", optional): indicates how to identify the labels in the input. - By default, this will pick the second parameter a the labels if it's a tensor. This covers the most + By default, this will pick the second parameter as the labels if it's a tensor. This covers the most common scenario where this transform is called as ``MixUp()(imgs_batch, labels_batch)``. It can also be a callable that takes the same input as the transform, and returns the labels. """ @@ -279,7 +279,7 @@ class CutMix(_BaseMixUpCutMix): alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1. num_classes (int): number of classes in the batch. Used for one-hot-encoding. labels_getter (callable or "default", optional): indicates how to identify the labels in the input. - By default, this will pick the second parameter a the labels if it's a tensor. This covers the most + By default, this will pick the second parameter as the labels if it's a tensor. This covers the most common scenario where this transform is called as ``CutMix()(imgs_batch, labels_batch)``. It can also be a callable that takes the same input as the transform, and returns the labels. """