diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index bed6180a4fb..e16c0677c9f 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -6171,50 +6171,48 @@ def test_transform_invalid_quality_error(self, quality): transforms.JPEG(quality=quality) -class TestQuerySize: +class TestUtils: + # TODO: Still need to test has_all, has_any, check_type and get_bouding_boxes @pytest.mark.parametrize( - "make_input, input_name", - [ - (lambda: torch.rand(3, 32, 64), "pure_tensor"), - (lambda: tv_tensors.Image(torch.rand(3, 32, 64)), "tv_tensor_image"), - (lambda: PIL.Image.new("RGB", (64, 32)), "pil_image"), - (lambda: tv_tensors.Video(torch.rand(1, 3, 32, 64)), "tv_tensor_video"), - (lambda: tv_tensors.Mask(torch.randint(0, 2, (32, 64))), "tv_tensor_mask"), - ], - ids=["pure_tensor", "tv_tensor_image", "pil_image", "tv_tensor_video", "tv_tensor_mask"], + "make_input1", [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask] + ) + @pytest.mark.parametrize( + "make_input2", [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask] ) - def test_functional(self, make_input, input_name): - input1 = make_input() - input2 = make_input() - # Both inputs should have the same size (32, 64) - assert transforms.query_size([input1, input2]) == (32, 64) + @pytest.mark.parametrize("query", [transforms.query_size, transforms.query_chw]) + def test_query_size_and_query_chw(self, make_input1, make_input2, query): + size = (32, 64) + input1 = make_input1(size) + input2 = make_input2(size) + + if query is transforms.query_chw and not any( + transforms.check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video)) + for inpt in (input1, input2) + ): + return + + expected = size if query is transforms.query_size else ((3,) + size) + assert query([input1, input2]) == expected @pytest.mark.parametrize( - "make_input, input_name", - [ - (lambda: torch.rand(3, 32, 64), "pure_tensor"), - (lambda: tv_tensors.Image(torch.rand(3, 32, 64)), "tv_tensor_image"), - (lambda: PIL.Image.new("RGB", (64, 32)), "pil_image"), - (lambda: tv_tensors.Video(torch.rand(1, 3, 32, 64)), "tv_tensor_video"), - (lambda: tv_tensors.Mask(torch.randint(0, 2, (32, 64))), "tv_tensor_mask"), - ], - ids=["pure_tensor", "tv_tensor_image", "pil_image", "tv_tensor_video", "tv_tensor_mask"], - ) - def test_functional_mixed_types(self, make_input, input_name): - input1 = make_input() - input2 = make_input() - # Both inputs should have the same size (32, 64) - assert transforms.query_size([input1, input2]) == (32, 64) - - def test_different_sizes(self): - img_tensor = torch.rand(3, 32, 64) # (C, H, W) - img_tensor_different_size = torch.rand(3, 48, 96) # (C, H, W) - # Should raise ValueError for different sizes - with pytest.raises(ValueError, match="Found multiple HxW dimensions"): - transforms.query_size([img_tensor, img_tensor_different_size]) - - def test_no_valid_image(self): - invalid_input = torch.rand(1, 10) # Non-image tensor - # Should raise TypeError for invalid input - with pytest.raises(TypeError, match="No image, video, mask or bounding box was found"): - transforms.query_size([invalid_input]) + "make_input1", [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask] + ) + @pytest.mark.parametrize( + "make_input2", [make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask] + ) + @pytest.mark.parametrize("query", [transforms.query_size, transforms.query_chw]) + def test_different_sizes(self, make_input1, make_input2, query): + input1 = make_input1((10, 10)) + input2 = make_input2((20, 20)) + if query is transforms.query_chw and not all( + transforms.check_type(inpt, (is_pure_tensor, tv_tensors.Image, PIL.Image.Image, tv_tensors.Video)) + for inpt in (input1, input2) + ): + return + with pytest.raises(ValueError, match="Found multiple"): + query([input1, input2]) + + @pytest.mark.parametrize("query", [transforms.query_size, transforms.query_chw]) + def test_no_valid_input(self, query): + with pytest.raises(TypeError, match="No image"): + query(["blah"]) diff --git a/torchvision/transforms/v2/__init__.py b/torchvision/transforms/v2/__init__.py index 33d83f1fe3f..2d66917b6ea 100644 --- a/torchvision/transforms/v2/__init__.py +++ b/torchvision/transforms/v2/__init__.py @@ -55,5 +55,6 @@ ) from ._temporal import UniformTemporalSubsample from ._type_conversion import PILToTensor, ToImage, ToPILImage, ToPureTensor +from ._utils import check_type, get_bounding_boxes, has_all, has_any, query_chw, query_size from ._deprecated import ToTensor # usort: skip diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index 507a029fd67..d5705d55c4b 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -151,5 +151,3 @@ from ._type_conversion import pil_to_tensor, to_image, to_pil_image from ._deprecated import get_image_size, to_tensor # usort: skip - -from ._utils import query_size