From 0e4961551d3b9cd6e766381cb7539531de20450b Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 30 Jun 2023 10:30:33 +0200 Subject: [PATCH] port affine tests (#7708) --- test/test_transforms_v2.py | 124 ------- test/test_transforms_v2_functional.py | 83 ----- test/test_transforms_v2_refactored.py | 465 ++++++++++++++++++++++++- test/transforms_v2_dispatcher_infos.py | 15 - test/transforms_v2_kernel_infos.py | 182 ---------- 5 files changed, 453 insertions(+), 416 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 311a442ffed..e9d1dfc0517 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -668,130 +668,6 @@ def test_boundingbox_spatial_size(self, angle, expand): assert out_img.spatial_size == out_bbox.spatial_size -class TestRandomAffine: - def test_assertions(self): - with pytest.raises(ValueError, match="is a single number, it must be positive"): - transforms.RandomAffine(-0.7) - - for d in [[-0.7], [-0.7, 0, 0.7]]: - with pytest.raises(ValueError, match="degrees should be a sequence of length 2"): - transforms.RandomAffine(d) - - with pytest.raises(TypeError, match="Got inappropriate fill arg"): - transforms.RandomAffine(12, fill="abc") - - with pytest.raises(TypeError, match="Got inappropriate fill arg"): - transforms.RandomAffine(12, fill="abc") - - for kwargs in [ - {"center": 12}, - {"translate": 12}, - {"scale": 12}, - ]: - with pytest.raises(TypeError, match="should be a sequence of length"): - transforms.RandomAffine(12, **kwargs) - - for kwargs in [{"center": [1, 2, 3]}, {"translate": [1, 2, 3]}, {"scale": [1, 2, 3]}]: - with pytest.raises(ValueError, match="should be a sequence of length"): - transforms.RandomAffine(12, **kwargs) - - with pytest.raises(ValueError, match="translation values should be between 0 and 1"): - transforms.RandomAffine(12, translate=[-1.0, 2.0]) - - with pytest.raises(ValueError, match="scale values should be positive"): - transforms.RandomAffine(12, scale=[-1.0, 2.0]) - - with pytest.raises(ValueError, match="is a single number, it must be positive"): - transforms.RandomAffine(12, shear=-10) - - for s in [[-0.7], [-0.7, 0, 0.7]]: - with pytest.raises(ValueError, match="shear should be a sequence of length 2"): - transforms.RandomAffine(12, shear=s) - - @pytest.mark.parametrize("degrees", [23, [0, 45], (0, 45)]) - @pytest.mark.parametrize("translate", [None, [0.1, 0.2]]) - @pytest.mark.parametrize("scale", [None, [0.7, 1.2]]) - @pytest.mark.parametrize("shear", [None, 2.0, [5.0, 15.0], [1.0, 2.0, 3.0, 4.0]]) - def test__get_params(self, degrees, translate, scale, shear, mocker): - image = mocker.MagicMock(spec=datapoints.Image) - image.num_channels = 3 - image.spatial_size = (24, 32) - h, w = image.spatial_size - - transform = transforms.RandomAffine(degrees, translate=translate, scale=scale, shear=shear) - params = transform._get_params([image]) - - if not isinstance(degrees, (list, tuple)): - assert -degrees <= params["angle"] <= degrees - else: - assert degrees[0] <= params["angle"] <= degrees[1] - - if translate is not None: - w_max = int(round(translate[0] * w)) - h_max = int(round(translate[1] * h)) - assert -w_max <= params["translate"][0] <= w_max - assert -h_max <= params["translate"][1] <= h_max - else: - assert params["translate"] == (0, 0) - - if scale is not None: - assert scale[0] <= params["scale"] <= scale[1] - else: - assert params["scale"] == 1.0 - - if shear is not None: - if isinstance(shear, float): - assert -shear <= params["shear"][0] <= shear - assert params["shear"][1] == 0.0 - elif len(shear) == 2: - assert shear[0] <= params["shear"][0] <= shear[1] - assert params["shear"][1] == 0.0 - else: - assert shear[0] <= params["shear"][0] <= shear[1] - assert shear[2] <= params["shear"][1] <= shear[3] - else: - assert params["shear"] == (0, 0) - - @pytest.mark.parametrize("degrees", [23, [0, 45], (0, 45)]) - @pytest.mark.parametrize("translate", [None, [0.1, 0.2]]) - @pytest.mark.parametrize("scale", [None, [0.7, 1.2]]) - @pytest.mark.parametrize("shear", [None, 2.0, [5.0, 15.0], [1.0, 2.0, 3.0, 4.0]]) - @pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)]) - @pytest.mark.parametrize("center", [None, [2.0, 3.0]]) - def test__transform(self, degrees, translate, scale, shear, fill, center, mocker): - interpolation = InterpolationMode.BILINEAR - transform = transforms.RandomAffine( - degrees, - translate=translate, - scale=scale, - shear=shear, - interpolation=interpolation, - fill=fill, - center=center, - ) - - if isinstance(degrees, (tuple, list)): - assert transform.degrees == [float(degrees[0]), float(degrees[1])] - else: - assert transform.degrees == [float(-degrees), float(degrees)] - - fn = mocker.patch("torchvision.transforms.v2.functional.affine") - inpt = mocker.MagicMock(spec=datapoints.Image) - inpt.num_channels = 3 - inpt.spatial_size = (24, 32) - - # vfdev-5, Feature Request: let's store params as Transform attribute - # This could be also helpful for users - # Otherwise, we can mock transform._get_params - torch.manual_seed(12) - _ = transform(inpt) - torch.manual_seed(12) - params = transform._get_params([inpt]) - - fill = transforms._utils._convert_fill_arg(fill) - fn.assert_called_once_with(inpt, **params, interpolation=interpolation, fill=fill, center=center) - - class TestRandomCrop: def test_assertions(self): with pytest.raises(ValueError, match="Please provide only two dimensions"): diff --git a/test/test_transforms_v2_functional.py b/test/test_transforms_v2_functional.py index 93996432aa5..79ea20d854e 100644 --- a/test/test_transforms_v2_functional.py +++ b/test/test_transforms_v2_functional.py @@ -665,77 +665,6 @@ def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_): return true_matrix -@pytest.mark.parametrize("device", cpu_and_cuda()) -def test_correctness_affine_bounding_box_on_fixed_input(device): - # Check transformation against known expected output - format = datapoints.BoundingBoxFormat.XYXY - spatial_size = (64, 64) - in_boxes = [ - [20, 25, 35, 45], - [50, 5, 70, 22], - [spatial_size[1] // 2 - 10, spatial_size[0] // 2 - 10, spatial_size[1] // 2 + 10, spatial_size[0] // 2 + 10], - [1, 1, 5, 5], - ] - in_boxes = torch.tensor(in_boxes, dtype=torch.float64, device=device) - # Tested parameters - angle = 63 - scale = 0.89 - dx = 0.12 - dy = 0.23 - - # Expected bboxes computed using albumentations: - # from albumentations.augmentations.geometric.functional import bbox_shift_scale_rotate - # from albumentations.augmentations.geometric.functional import normalize_bbox, denormalize_bbox - # expected_bboxes = [] - # for in_box in in_boxes: - # n_in_box = normalize_bbox(in_box, *spatial_size) - # n_out_box = bbox_shift_scale_rotate(n_in_box, -angle, scale, dx, dy, *spatial_size) - # out_box = denormalize_bbox(n_out_box, *spatial_size) - # expected_bboxes.append(out_box) - expected_bboxes = [ - (24.522435977922218, 34.375689508290854, 46.443125279998114, 54.3516575015695), - (54.88288587110401, 50.08453280875634, 76.44484547743795, 72.81332520036864), - (27.709526487041554, 34.74952648704156, 51.650473512958435, 58.69047351295844), - (48.56528888843238, 9.611532109828834, 53.35347829361575, 14.39972151501221), - ] - - expected_bboxes = clamp_bounding_box( - datapoints.BoundingBox(expected_bboxes, format="XYXY", spatial_size=spatial_size) - ).tolist() - - output_boxes = F.affine_bounding_box( - in_boxes, - format=format, - spatial_size=spatial_size, - angle=angle, - translate=(dx * spatial_size[1], dy * spatial_size[0]), - scale=scale, - shear=(0, 0), - ) - - torch.testing.assert_close(output_boxes.tolist(), expected_bboxes) - - -@pytest.mark.parametrize("device", cpu_and_cuda()) -def test_correctness_affine_segmentation_mask_on_fixed_input(device): - # Check transformation against known expected output and CPU/CUDA devices - - # Create a fixed input segmentation mask with 2 square masks - # in top-left, bottom-left corners - mask = torch.zeros(1, 32, 32, dtype=torch.long, device=device) - mask[0, 2:10, 2:10] = 1 - mask[0, 32 - 9 : 32 - 3, 3:9] = 2 - - # Rotate 90 degrees and scale - expected_mask = torch.rot90(mask, k=-1, dims=(-2, -1)) - expected_mask = torch.nn.functional.interpolate(expected_mask[None, :].float(), size=(64, 64), mode="nearest") - expected_mask = expected_mask[0, :, 16 : 64 - 16, 16 : 64 - 16].long() - - out_mask = F.affine_mask(mask, 90, [0.0, 0.0], 64.0 / 32.0, [0.0, 0.0]) - - torch.testing.assert_close(out_mask, expected_mask) - - @pytest.mark.parametrize("angle", range(-90, 90, 56)) @pytest.mark.parametrize("expand, center", [(True, None), (False, None), (False, (12, 14))]) def test_correctness_rotate_bounding_box(angle, expand, center): @@ -950,18 +879,6 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width, torch.testing.assert_close(output_spatial_size, spatial_size) -@pytest.mark.parametrize("device", cpu_and_cuda()) -def test_correctness_horizontal_flip_segmentation_mask_on_fixed_input(device): - mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device) - mask[:, :, 0] = 1 - - out_mask = F.horizontal_flip_mask(mask) - - expected_mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device) - expected_mask[:, :, -1] = 1 - torch.testing.assert_close(out_mask, expected_mask) - - @pytest.mark.parametrize("device", cpu_and_cuda()) def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device): mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 05eb47ab69e..0db4824d584 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -1,5 +1,6 @@ import contextlib import inspect +import math import re from typing import get_type_hints from unittest import mock @@ -25,6 +26,8 @@ ) from torch.testing import assert_close from torchvision import datapoints + +from torchvision.transforms._functional_tensor import _max_value as get_max_value from torchvision.transforms.functional import pil_modes_mapping from torchvision.transforms.v2 import functional as F @@ -162,7 +165,7 @@ def _check_dispatcher_dispatch(dispatcher, kernel, input, *args, **kwargs): if isinstance(input, datapoints._datapoint.Datapoint): # Due to our complex dispatch architecture for datapoints, we cannot spy on the kernel directly, # but rather have to patch the `Datapoint.__F` attribute to contain the spied on kernel. - spy = mock.MagicMock(wraps=kernel) + spy = mock.MagicMock(wraps=kernel, name=kernel.__name__) with mock.patch.object(F, kernel.__name__, spy): # Due to Python's name mangling, the `Datapoint.__F` attribute is only accessible from inside the class. # Since that is not the case here, we need to prefix f"_{cls.__name__}" @@ -473,10 +476,9 @@ def test_kernel_bounding_box(self, format, size, use_max_size, dtype, device): ) @pytest.mark.parametrize( - "dtype_and_make_mask", [(torch.uint8, make_segmentation_mask), (torch.bool, make_detection_mask)] + ("dtype", "make_mask"), [(torch.uint8, make_segmentation_mask), (torch.bool, make_detection_mask)] ) - def test_kernel_mask(self, dtype_and_make_mask): - dtype, make_mask = dtype_and_make_mask + def test_kernel_mask(self, dtype, make_mask): check_kernel(F.resize_mask, make_mask(dtype=dtype), size=self.OUTPUT_SIZES[-1]) def test_kernel_video(self): @@ -744,7 +746,7 @@ def _make_input(self, input_type, *, dtype=None, device="cpu", spatial_size=(17, @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8]) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_kernel_image_tensor(self, dtype, device): - check_kernel(F.horizontal_flip_image_tensor, self._make_input(torch.Tensor)) + check_kernel(F.horizontal_flip_image_tensor, self._make_input(torch.Tensor, dtype=dtype, device=device)) @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat)) @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) @@ -785,16 +787,16 @@ def test_dispatcher(self, kernel, input_type): @pytest.mark.parametrize( ("input_type", "kernel"), [ - (torch.Tensor, F.resize_image_tensor), - (PIL.Image.Image, F.resize_image_pil), - (datapoints.Image, F.resize_image_tensor), - (datapoints.BoundingBox, F.resize_bounding_box), - (datapoints.Mask, F.resize_mask), - (datapoints.Video, F.resize_video), + (torch.Tensor, F.horizontal_flip_image_tensor), + (PIL.Image.Image, F.horizontal_flip_image_pil), + (datapoints.Image, F.horizontal_flip_image_tensor), + (datapoints.BoundingBox, F.horizontal_flip_bounding_box), + (datapoints.Mask, F.horizontal_flip_mask), + (datapoints.Video, F.horizontal_flip_video), ], ) def test_dispatcher_signature(self, kernel, input_type): - check_dispatcher_signatures_match(F.resize, kernel=kernel, input_type=input_type) + check_dispatcher_signatures_match(F.horizontal_flip, kernel=kernel, input_type=input_type) @pytest.mark.parametrize( "input_type", @@ -860,3 +862,442 @@ def test_transform_noop(self, input_type, device): output = transform(input) assert_equal(output, input) + + +class TestAffine: + def _make_input( + self, input_type, *, dtype=None, device="cpu", spatial_size=(17, 11), mask_type="segmentation", **kwargs + ): + if input_type in {torch.Tensor, PIL.Image.Image, datapoints.Image}: + input = make_image(size=spatial_size, dtype=dtype or torch.uint8, device=device, **kwargs) + if input_type is torch.Tensor: + input = input.as_subclass(torch.Tensor) + elif input_type is PIL.Image.Image: + input = F.to_image_pil(input) + elif input_type is datapoints.BoundingBox: + kwargs.setdefault("format", datapoints.BoundingBoxFormat.XYXY) + input = make_bounding_box( + dtype=dtype or torch.float32, + device=device, + spatial_size=spatial_size, + **kwargs, + ) + elif input_type is datapoints.Mask: + if mask_type == "segmentation": + make_mask = make_segmentation_mask + default_dtype = torch.uint8 + elif mask_type == "detection": + make_mask = make_detection_mask + default_dtype = torch.bool + input = make_mask(size=spatial_size, dtype=dtype or default_dtype, device=device, **kwargs) + elif input_type is datapoints.Video: + input = make_video(size=spatial_size, dtype=dtype or torch.uint8, device=device, **kwargs) + + return input + + def _adapt_fill(self, value, *, dtype): + """Adapt fill values in the range [0.0, 1.0] to the value range of the dtype""" + if value is None: + return value + + max_value = get_max_value(dtype) + + if isinstance(value, (int, float)): + return type(value)(value * max_value) + elif isinstance(value, (list, tuple)): + return type(value)(type(v)(v * max_value) for v in value) + else: + raise ValueError(f"fill should be an int or float, or a list or tuple of the former, but got '{value}'") + + _EXHAUSTIVE_TYPE_AFFINE_KWARGS = dict( + # float, int + angle=[-10.9, 18], + # two-list of float, two-list of int, two-tuple of float, two-tuple of int + translate=[[6.3, -0.6], [1, -3], (16.6, -6.6), (-2, 4)], + # float + scale=[0.5], + # float, int, + # one-list of float, one-list of int, one-tuple of float, one-tuple of int + # two-list of float, two-list of int, two-tuple of float, two-tuple of int + shear=[35.6, 38, [-37.7], [-23], (5.3,), (-52,), [5.4, 21.8], [-47, 51], (-11.2, 36.7), (8, -53)], + # None + # two-list of float, two-list of int, two-tuple of float, two-tuple of int + center=[None, [1.2, 4.9], [-3, 1], (2.5, -4.7), (3, 2)], + ) + # The special case for shear makes sure we pick a value that is supported while JIT scripting + _MINIMAL_AFFINE_KWARGS = { + k: vs[0] if k != "shear" else next(v for v in vs if isinstance(v, list)) + for k, vs in _EXHAUSTIVE_TYPE_AFFINE_KWARGS.items() + } + _CORRECTNESS_AFFINE_KWARGS = { + k: [v for v in vs if v is None or isinstance(v, float) or (isinstance(v, list) and len(v) > 1)] + for k, vs in _EXHAUSTIVE_TYPE_AFFINE_KWARGS.items() + } + + _EXHAUSTIVE_TYPE_FILLS = [ + None, + 1, + 0.5, + [1], + [0.2], + (0,), + (0.7,), + [1, 0, 1], + [0.1, 0.2, 0.3], + (0, 1, 0), + (0.9, 0.234, 0.314), + ] + _CORRECTNESS_FILL = [ + v for v in _EXHAUSTIVE_TYPE_FILLS if v is None or isinstance(v, float) or (isinstance(v, list) and len(v) > 1) + ] + + _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES = dict( + degrees=[30, (-15, 20)], + translate=[None, (0.5, 0.5)], + scale=[None, (0.75, 1.25)], + shear=[None, (12, 30, -17, 5), 10, (-5, 12)], + ) + _CORRECTNESS_TRANSFORM_AFFINE_RANGES = { + k: next(v for v in vs if v is not None) for k, vs in _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES.items() + } + + def _check_kernel(self, kernel, input, *args, **kwargs): + kwargs_ = self._MINIMAL_AFFINE_KWARGS.copy() + kwargs_.update(kwargs) + check_kernel(kernel, input, *args, **kwargs_) + + @pytest.mark.parametrize( + ("param", "value"), + [ + (param, value) + for param, values in [ + ("angle", _EXHAUSTIVE_TYPE_AFFINE_KWARGS["angle"]), + ("translate", _EXHAUSTIVE_TYPE_AFFINE_KWARGS["translate"]), + ("shear", _EXHAUSTIVE_TYPE_AFFINE_KWARGS["shear"]), + ("center", _EXHAUSTIVE_TYPE_AFFINE_KWARGS["center"]), + ("interpolation", [transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR]), + ("fill", _EXHAUSTIVE_TYPE_FILLS), + ] + for value in values + ], + ) + @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_kernel_image_tensor(self, param, value, dtype, device): + if param == "fill": + value = self._adapt_fill(value, dtype=dtype) + self._check_kernel( + F.affine_image_tensor, + self._make_input(torch.Tensor, dtype=dtype, device=device), + **{param: value}, + check_scripted_vs_eager=not (param in {"shear", "fill"} and isinstance(value, (int, float))), + check_cuda_vs_cpu=dict(atol=1, rtol=0) + if dtype is torch.uint8 and param == "interpolation" and value is transforms.InterpolationMode.BILINEAR + else True, + ) + + @pytest.mark.parametrize( + ("param", "value"), + [ + (param, value) + for param, values in [ + ("angle", _EXHAUSTIVE_TYPE_AFFINE_KWARGS["angle"]), + ("translate", _EXHAUSTIVE_TYPE_AFFINE_KWARGS["translate"]), + ("shear", _EXHAUSTIVE_TYPE_AFFINE_KWARGS["shear"]), + ("center", _EXHAUSTIVE_TYPE_AFFINE_KWARGS["center"]), + ] + for value in values + ], + ) + @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat)) + @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_kernel_bounding_box(self, param, value, format, dtype, device): + bounding_box = self._make_input(datapoints.BoundingBox, format=format, dtype=dtype, device=device) + self._check_kernel( + F.affine_bounding_box, + self._make_input(datapoints.BoundingBox, format=format, dtype=dtype, device=device), + format=format, + spatial_size=bounding_box.spatial_size, + **{param: value}, + check_scripted_vs_eager=not (param == "shear" and isinstance(value, (int, float))), + ) + + @pytest.mark.parametrize("mask_type", ["segmentation", "detection"]) + def test_kernel_mask(self, mask_type): + check_kernel( + F.affine_mask, self._make_input(datapoints.Mask, mask_type=mask_type), **self._MINIMAL_AFFINE_KWARGS + ) + + def test_kernel_video(self): + check_kernel(F.affine_video, self._make_input(datapoints.Video), **self._MINIMAL_AFFINE_KWARGS) + + @pytest.mark.parametrize( + ("input_type", "kernel"), + [ + (torch.Tensor, F.affine_image_tensor), + (PIL.Image.Image, F.affine_image_pil), + (datapoints.Image, F.affine_image_tensor), + (datapoints.BoundingBox, F.affine_bounding_box), + (datapoints.Mask, F.affine_mask), + (datapoints.Video, F.affine_video), + ], + ) + def test_dispatcher(self, kernel, input_type): + check_dispatcher(F.affine, kernel, self._make_input(input_type), **self._MINIMAL_AFFINE_KWARGS) + + @pytest.mark.parametrize( + ("input_type", "kernel"), + [ + (torch.Tensor, F.affine_image_tensor), + (PIL.Image.Image, F.affine_image_pil), + (datapoints.Image, F.affine_image_tensor), + (datapoints.BoundingBox, F.affine_bounding_box), + (datapoints.Mask, F.affine_mask), + (datapoints.Video, F.affine_video), + ], + ) + def test_dispatcher_signature(self, kernel, input_type): + check_dispatcher_signatures_match(F.affine, kernel=kernel, input_type=input_type) + + @pytest.mark.parametrize( + "input_type", + [torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.BoundingBox, datapoints.Mask, datapoints.Video], + ) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_transform(self, input_type, device): + input = self._make_input(input_type, device=device) + + check_transform(transforms.RandomAffine, input, **self._CORRECTNESS_TRANSFORM_AFFINE_RANGES) + + @pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"]) + @pytest.mark.parametrize("translate", _CORRECTNESS_AFFINE_KWARGS["translate"]) + @pytest.mark.parametrize("scale", _CORRECTNESS_AFFINE_KWARGS["scale"]) + @pytest.mark.parametrize("shear", _CORRECTNESS_AFFINE_KWARGS["shear"]) + @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"]) + @pytest.mark.parametrize( + "interpolation", [transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR] + ) + @pytest.mark.parametrize("fill", _CORRECTNESS_FILL) + def test_functional_image_correctness(self, angle, translate, scale, shear, center, interpolation, fill): + image = self._make_input(torch.Tensor, dtype=torch.uint8, device="cpu") + + fill = self._adapt_fill(fill, dtype=torch.uint8) + + actual = F.affine( + image, + angle=angle, + translate=translate, + scale=scale, + shear=shear, + center=center, + interpolation=interpolation, + fill=fill, + ) + expected = F.to_image_tensor( + F.affine( + F.to_image_pil(image), + angle=angle, + translate=translate, + scale=scale, + shear=shear, + center=center, + interpolation=interpolation, + fill=fill, + ) + ) + + mae = (actual.float() - expected.float()).abs().mean() + assert mae < 2 if interpolation is transforms.InterpolationMode.NEAREST else 8 + + @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"]) + @pytest.mark.parametrize( + "interpolation", [transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR] + ) + @pytest.mark.parametrize("fill", _CORRECTNESS_FILL) + @pytest.mark.parametrize("seed", list(range(5))) + def test_transform_image_correctness(self, center, interpolation, fill, seed): + image = self._make_input(torch.Tensor, dtype=torch.uint8, device="cpu") + + fill = self._adapt_fill(fill, dtype=torch.uint8) + + transform = transforms.RandomAffine( + **self._CORRECTNESS_TRANSFORM_AFFINE_RANGES, center=center, interpolation=interpolation, fill=fill + ) + + torch.manual_seed(seed) + actual = transform(image) + + torch.manual_seed(seed) + expected = F.to_image_tensor(transform(F.to_image_pil(image))) + + mae = (actual.float() - expected.float()).abs().mean() + assert mae < 2 if interpolation is transforms.InterpolationMode.NEAREST else 8 + + def _compute_affine_matrix(self, *, angle, translate, scale, shear, center): + rot = math.radians(angle) + cx, cy = center + tx, ty = translate + sx, sy = [math.radians(s) for s in ([shear, 0.0] if isinstance(shear, (int, float)) else shear)] + + c_matrix = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]]) + t_matrix = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]]) + c_matrix_inv = np.linalg.inv(c_matrix) + rs_matrix = np.array( + [ + [scale * math.cos(rot), -scale * math.sin(rot), 0], + [scale * math.sin(rot), scale * math.cos(rot), 0], + [0, 0, 1], + ] + ) + shear_x_matrix = np.array([[1, -math.tan(sx), 0], [0, 1, 0], [0, 0, 1]]) + shear_y_matrix = np.array([[1, 0, 0], [-math.tan(sy), 1, 0], [0, 0, 1]]) + rss_matrix = np.matmul(rs_matrix, np.matmul(shear_y_matrix, shear_x_matrix)) + true_matrix = np.matmul(t_matrix, np.matmul(c_matrix, np.matmul(rss_matrix, c_matrix_inv))) + return true_matrix + + def _reference_affine_bounding_box(self, bounding_box, *, angle, translate, scale, shear, center): + if center is None: + center = [s * 0.5 for s in bounding_box.spatial_size[::-1]] + + affine_matrix = self._compute_affine_matrix( + angle=angle, translate=translate, scale=scale, shear=shear, center=center + ) + affine_matrix = affine_matrix[:2, :] + + expected_bboxes = reference_affine_bounding_box_helper( + bounding_box, + format=bounding_box.format, + spatial_size=bounding_box.spatial_size, + affine_matrix=affine_matrix, + ) + + return expected_bboxes + + @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat)) + @pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"]) + @pytest.mark.parametrize("translate", _CORRECTNESS_AFFINE_KWARGS["translate"]) + @pytest.mark.parametrize("scale", _CORRECTNESS_AFFINE_KWARGS["scale"]) + @pytest.mark.parametrize("shear", _CORRECTNESS_AFFINE_KWARGS["shear"]) + @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"]) + def test_functional_bounding_box_correctness(self, format, angle, translate, scale, shear, center): + bounding_box = self._make_input(datapoints.BoundingBox, format=format) + + actual = F.affine( + bounding_box, + angle=angle, + translate=translate, + scale=scale, + shear=shear, + center=center, + ) + expected = self._reference_affine_bounding_box( + bounding_box, + angle=angle, + translate=translate, + scale=scale, + shear=shear, + center=center, + ) + + torch.testing.assert_close(actual, expected) + + @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat)) + @pytest.mark.parametrize("center", _CORRECTNESS_AFFINE_KWARGS["center"]) + @pytest.mark.parametrize("seed", list(range(5))) + def test_transform_bounding_box_correctness(self, format, center, seed): + bounding_box = self._make_input(datapoints.BoundingBox, format=format) + + transform = transforms.RandomAffine(**self._CORRECTNESS_TRANSFORM_AFFINE_RANGES, center=center) + + torch.manual_seed(seed) + params = transform._get_params([bounding_box]) + + torch.manual_seed(seed) + actual = transform(bounding_box) + + expected = self._reference_affine_bounding_box(bounding_box, **params, center=center) + + torch.testing.assert_close(actual, expected) + + @pytest.mark.parametrize("degrees", _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES["degrees"]) + @pytest.mark.parametrize("translate", _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES["translate"]) + @pytest.mark.parametrize("scale", _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES["scale"]) + @pytest.mark.parametrize("shear", _EXHAUSTIVE_TYPE_TRANSFORM_AFFINE_RANGES["shear"]) + @pytest.mark.parametrize("seed", list(range(10))) + def test_transform_get_params_bounds(self, degrees, translate, scale, shear, seed): + image = self._make_input(torch.Tensor) + height, width = F.get_spatial_size(image) + + transform = transforms.RandomAffine(degrees=degrees, translate=translate, scale=scale, shear=shear) + + torch.manual_seed(seed) + params = transform._get_params([image]) + + if isinstance(degrees, (int, float)): + assert -degrees <= params["angle"] <= degrees + else: + assert degrees[0] <= params["angle"] <= degrees[1] + + if translate is not None: + width_max = int(round(translate[0] * width)) + height_max = int(round(translate[1] * height)) + assert -width_max <= params["translate"][0] <= width_max + assert -height_max <= params["translate"][1] <= height_max + else: + assert params["translate"] == (0, 0) + + if scale is not None: + assert scale[0] <= params["scale"] <= scale[1] + else: + assert params["scale"] == 1.0 + + if shear is not None: + if isinstance(shear, (int, float)): + assert -shear <= params["shear"][0] <= shear + assert params["shear"][1] == 0.0 + elif len(shear) == 2: + assert shear[0] <= params["shear"][0] <= shear[1] + assert params["shear"][1] == 0.0 + elif len(shear) == 4: + assert shear[0] <= params["shear"][0] <= shear[1] + assert shear[2] <= params["shear"][1] <= shear[3] + else: + assert params["shear"] == (0, 0) + + @pytest.mark.parametrize("param", ["degrees", "translate", "scale", "shear", "center"]) + @pytest.mark.parametrize("value", [0, [0], [0, 0, 0]]) + def test_transform_sequence_len_errors(self, param, value): + if param in {"degrees", "shear"} and not isinstance(value, list): + return + + kwargs = {param: value} + if param != "degrees": + kwargs["degrees"] = 0 + + with pytest.raises( + ValueError if isinstance(value, list) else TypeError, match=f"{param} should be a sequence of length 2" + ): + transforms.RandomAffine(**kwargs) + + def test_transform_negative_degrees_error(self): + with pytest.raises(ValueError, match="If degrees is a single number, it must be positive"): + transforms.RandomAffine(degrees=-1) + + @pytest.mark.parametrize("translate", [[-1, 0], [2, 0], [-1, 2]]) + def test_transform_translate_range_error(self, translate): + with pytest.raises(ValueError, match="translation values should be between 0 and 1"): + transforms.RandomAffine(degrees=0, translate=translate) + + @pytest.mark.parametrize("scale", [[-1, 0], [0, -1], [-1, -1]]) + def test_transform_scale_range_error(self, scale): + with pytest.raises(ValueError, match="scale values should be positive"): + transforms.RandomAffine(degrees=0, scale=scale) + + def test_transform_negative_shear_error(self): + with pytest.raises(ValueError, match="If shear is a single number, it must be positive"): + transforms.RandomAffine(degrees=0, shear=-1) + + def test_transform_unknown_fill_error(self): + with pytest.raises(TypeError, match="Got inappropriate fill arg"): + transforms.RandomAffine(degrees=0, fill="fill") diff --git a/test/transforms_v2_dispatcher_infos.py b/test/transforms_v2_dispatcher_infos.py index e0f7edd7129..b217e1638c7 100644 --- a/test/transforms_v2_dispatcher_infos.py +++ b/test/transforms_v2_dispatcher_infos.py @@ -138,21 +138,6 @@ def fill_sequence_needs_broadcast(args_kwargs): DISPATCHER_INFOS = [ - DispatcherInfo( - F.affine, - kernels={ - datapoints.Image: F.affine_image_tensor, - datapoints.Video: F.affine_video, - datapoints.BoundingBox: F.affine_bounding_box, - datapoints.Mask: F.affine_mask, - }, - pil_kernel_info=PILKernelInfo(F.affine_image_pil), - test_marks=[ - *xfails_pil_if_fill_sequence_needs_broadcast, - xfail_jit_python_scalar_arg("shear"), - xfail_jit_python_scalar_arg("fill"), - ], - ), DispatcherInfo( F.vertical_flip, kernels={ diff --git a/test/transforms_v2_kernel_infos.py b/test/transforms_v2_kernel_infos.py index 54fd3a679a5..0daae8aeec8 100644 --- a/test/transforms_v2_kernel_infos.py +++ b/test/transforms_v2_kernel_infos.py @@ -1,7 +1,6 @@ import decimal import functools import itertools -import math import numpy as np import PIL.Image @@ -156,46 +155,6 @@ def xfail_jit_python_scalar_arg(name, *, reason=None): KERNEL_INFOS = [] -_AFFINE_KWARGS = combinations_grid( - angle=[-87, 15, 90], - translate=[(5, 5), (-5, -5)], - scale=[0.77, 1.27], - shear=[(12, 12), (0, 0)], -) - - -def _diversify_affine_kwargs_types(affine_kwargs): - angle = affine_kwargs["angle"] - for diverse_angle in [int(angle), float(angle)]: - yield dict(affine_kwargs, angle=diverse_angle) - - shear = affine_kwargs["shear"] - for diverse_shear in [tuple(shear), list(shear), int(shear[0]), float(shear[0])]: - yield dict(affine_kwargs, shear=diverse_shear) - - -def _full_affine_params(**partial_params): - partial_params.setdefault("angle", 0.0) - partial_params.setdefault("translate", [0.0, 0.0]) - partial_params.setdefault("scale", 1.0) - partial_params.setdefault("shear", [0.0, 0.0]) - partial_params.setdefault("center", None) - return partial_params - - -_DIVERSE_AFFINE_PARAMS = [ - _full_affine_params(**{name: arg}) - for name, args in [ - ("angle", [1.0, 2]), - ("translate", [[1.0, 0.5], [1, 2], (1.0, 0.5), (1, 2)]), - ("scale", [0.5]), - ("shear", [1.0, 2, [1.0], [2], (1.0,), (2,), [1.0, 0.5], [1, 2], (1.0, 0.5), (1, 2)]), - ("center", [None, [1.0, 0.5], [1, 2], (1.0, 0.5), (1, 2)]), - ] - for arg in args -] - - def get_fills(*, num_channels, dtype): yield None @@ -226,72 +185,6 @@ def float32_vs_uint8_fill_adapter(other_args, kwargs): return other_args, dict(kwargs, fill=fill) -def sample_inputs_affine_image_tensor(): - make_affine_image_loaders = functools.partial( - make_image_loaders, sizes=["random"], color_spaces=["RGB"], dtypes=[torch.float32] - ) - - for image_loader, affine_params in itertools.product(make_affine_image_loaders(), _DIVERSE_AFFINE_PARAMS): - yield ArgsKwargs(image_loader, **affine_params) - - for image_loader in make_affine_image_loaders(): - for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype): - yield ArgsKwargs(image_loader, **_full_affine_params(), fill=fill) - - for image_loader, interpolation in itertools.product( - make_affine_image_loaders(), - [ - F.InterpolationMode.NEAREST, - F.InterpolationMode.BILINEAR, - ], - ): - yield ArgsKwargs(image_loader, **_full_affine_params(), fill=0) - - -def reference_inputs_affine_image_tensor(): - for image_loader, affine_kwargs in itertools.product(make_image_loaders_for_interpolation(), _AFFINE_KWARGS): - yield ArgsKwargs( - image_loader, - interpolation=F.InterpolationMode.NEAREST, - **affine_kwargs, - ) - - -def sample_inputs_affine_bounding_box(): - for bounding_box_loader, affine_params in itertools.product( - make_bounding_box_loaders(formats=[datapoints.BoundingBoxFormat.XYXY]), _DIVERSE_AFFINE_PARAMS - ): - yield ArgsKwargs( - bounding_box_loader, - format=bounding_box_loader.format, - spatial_size=bounding_box_loader.spatial_size, - **affine_params, - ) - - -def _compute_affine_matrix(angle, translate, scale, shear, center): - rot = math.radians(angle) - cx, cy = center - tx, ty = translate - sx, sy = [math.radians(sh_) for sh_ in shear] - - c_matrix = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]]) - t_matrix = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]]) - c_matrix_inv = np.linalg.inv(c_matrix) - rs_matrix = np.array( - [ - [scale * math.cos(rot), -scale * math.sin(rot), 0], - [scale * math.sin(rot), scale * math.cos(rot), 0], - [0, 0, 1], - ] - ) - shear_x_matrix = np.array([[1, -math.tan(sx), 0], [0, 1, 0], [0, 0, 1]]) - shear_y_matrix = np.array([[1, 0, 0], [-math.tan(sy), 1, 0], [0, 0, 1]]) - rss_matrix = np.matmul(rs_matrix, np.matmul(shear_y_matrix, shear_x_matrix)) - true_matrix = np.matmul(t_matrix, np.matmul(c_matrix, np.matmul(rss_matrix, c_matrix_inv))) - return true_matrix - - def reference_affine_bounding_box_helper(bounding_box, *, format, spatial_size, affine_matrix): def transform(bbox, affine_matrix_, format_, spatial_size_): # Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1 @@ -342,81 +235,6 @@ def transform(bbox, affine_matrix_, format_, spatial_size_): return expected_bboxes -def reference_affine_bounding_box(bounding_box, *, format, spatial_size, angle, translate, scale, shear, center=None): - if center is None: - center = [s * 0.5 for s in spatial_size[::-1]] - - affine_matrix = _compute_affine_matrix(angle, translate, scale, shear, center) - affine_matrix = affine_matrix[:2, :] - - expected_bboxes = reference_affine_bounding_box_helper( - bounding_box, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix - ) - - return expected_bboxes - - -def reference_inputs_affine_bounding_box(): - for bounding_box_loader, affine_kwargs in itertools.product( - make_bounding_box_loaders(extra_dims=[()]), - _AFFINE_KWARGS, - ): - yield ArgsKwargs( - bounding_box_loader, - format=bounding_box_loader.format, - spatial_size=bounding_box_loader.spatial_size, - **affine_kwargs, - ) - - -def sample_inputs_affine_mask(): - for mask_loader in make_mask_loaders(sizes=["random"], num_categories=["random"], num_objects=["random"]): - yield ArgsKwargs(mask_loader, **_full_affine_params()) - - -def sample_inputs_affine_video(): - for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]): - yield ArgsKwargs(video_loader, **_full_affine_params()) - - -KERNEL_INFOS.extend( - [ - KernelInfo( - F.affine_image_tensor, - sample_inputs_fn=sample_inputs_affine_image_tensor, - reference_fn=pil_reference_wrapper(F.affine_image_pil), - reference_inputs_fn=reference_inputs_affine_image_tensor, - float32_vs_uint8=True, - closeness_kwargs=pil_reference_pixel_difference(10, mae=True), - test_marks=[ - xfail_jit_python_scalar_arg("shear"), - xfail_jit_python_scalar_arg("fill"), - ], - ), - KernelInfo( - F.affine_bounding_box, - sample_inputs_fn=sample_inputs_affine_bounding_box, - reference_fn=reference_affine_bounding_box, - reference_inputs_fn=reference_inputs_affine_bounding_box, - test_marks=[ - xfail_jit_python_scalar_arg("shear"), - ], - ), - KernelInfo( - F.affine_mask, - sample_inputs_fn=sample_inputs_affine_mask, - test_marks=[ - xfail_jit_python_scalar_arg("shear"), - ], - ), - KernelInfo( - F.affine_video, - sample_inputs_fn=sample_inputs_affine_video, - ), - ] -) - - def sample_inputs_convert_format_bounding_box(): formats = list(datapoints.BoundingBoxFormat) for bounding_box_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats):