From 8e8a208117a654c0945fa08208d2e7e6f3206441 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 3 Oct 2024 12:35:11 +0100 Subject: [PATCH] [Cherry-pick for 0.20] Expose transforms.v2 utils for writing custom transforms (#8673) Co-authored-by: venkatram-dev <45727389+venkatram-dev@users.noreply.github.com> --- test/test_transforms_v2.py | 47 +++++++++++++++++++++++++++ torchvision/transforms/v2/__init__.py | 1 + 2 files changed, 48 insertions(+) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index f9218c3e840..e16c0677c9f 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -6169,3 +6169,50 @@ def test_transform_sequence_len_error(self, quality): def test_transform_invalid_quality_error(self, quality): with pytest.raises(ValueError, match="quality must be an integer from 1 to 100"): transforms.JPEG(quality=quality) + + +class TestUtils: + # TODO: Still need to test has_all, has_any, check_type and get_bounding_boxes + @pytest.mark.parametrize( + "make_input1", [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask] + ) + @pytest.mark.parametrize( + "make_input2", [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask] + ) + @pytest.mark.parametrize("query", [transforms.query_size, transforms.query_chw]) + def test_query_size_and_query_chw(self, make_input1, make_input2, query): + size = (32, 64) + input1 = make_input1(size) + input2 = make_input2(size) + + if query is transforms.query_chw and not any( + transforms.check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video)) + for inpt in (input1, input2) + ): + return + + expected = size if query is transforms.query_size else ((3,) + size) + assert query([input1, input2]) == expected + + @pytest.mark.parametrize( + "make_input1", [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask] + ) + @pytest.mark.parametrize( + "make_input2", 
[make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask] + ) + @pytest.mark.parametrize("query", [transforms.query_size, transforms.query_chw]) + def test_different_sizes(self, make_input1, make_input2, query): + input1 = make_input1((10, 10)) + input2 = make_input2((20, 20)) + if query is transforms.query_chw and not all( + transforms.check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video)) + for inpt in (input1, input2) + ): + return + with pytest.raises(ValueError, match="Found multiple"): + query([input1, input2]) + + @pytest.mark.parametrize("query", [transforms.query_size, transforms.query_chw]) + def test_no_valid_input(self, query): + with pytest.raises(TypeError, match="No image"): + query(["blah"]) diff --git a/torchvision/transforms/v2/__init__.py b/torchvision/transforms/v2/__init__.py index 33d83f1fe3f..2d66917b6ea 100644 --- a/torchvision/transforms/v2/__init__.py +++ b/torchvision/transforms/v2/__init__.py @@ -55,5 +55,6 @@ ) from ._temporal import UniformTemporalSubsample from ._type_conversion import PILToTensor, ToImage, ToPILImage, ToPureTensor +from ._utils import check_type, get_bounding_boxes, has_all, has_any, query_chw, query_size from ._deprecated import ToTensor # usort: skip