diff --git a/test/test_transforms_v2_functional.py b/test/test_transforms_v2_functional.py
index 826ba8b57e1..315993c750e 100644
--- a/test/test_transforms_v2_functional.py
+++ b/test/test_transforms_v2_functional.py
@@ -1013,43 +1013,3 @@ def test_correctness_uniform_temporal_subsample(device):
 
     out_video = F.uniform_temporal_subsample(video, 8)
     assert out_video.unique().tolist() == [0, 1, 2, 3, 5, 6, 7, 9]
-
-
-# TODO: We can remove this test and related torchvision workaround
-# once we fixed related pytorch issue: https://github.com/pytorch/pytorch/issues/68430
-@make_info_args_kwargs_parametrization(
-    [info for info in KERNEL_INFOS if info.kernel is F.resize_image],
-    args_kwargs_fn=lambda info: info.reference_inputs_fn(),
-)
-def test_memory_format_consistency_resize_image_tensor(test_id, info, args_kwargs):
-    (input, *other_args), kwargs = args_kwargs.load("cpu")
-
-    output = info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)
-
-    error_msg_fn = parametrized_error_message(input, *other_args, **kwargs)
-    assert input.ndim == 3, error_msg_fn
-    input_stride = input.stride()
-    output_stride = output.stride()
-    # Here we check output memory format according to the input:
-    # if input_stride is (..., 1) then input is most likely channels first and thus
-    #   output strides should match channels first strides (H * W, H, 1)
-    # if input_stride is (1, ...) then input is most likely channels last and thus
-    #   output strides should match channels last strides (1, W * C, C)
-    if input_stride[-1] == 1:
-        expected_stride = (output.shape[-2] * output.shape[-1], output.shape[-1], 1)
-        assert expected_stride == output_stride, error_msg_fn("")
-    elif input_stride[0] == 1:
-        expected_stride = (1, output.shape[0] * output.shape[-1], output.shape[0])
-        assert expected_stride == output_stride, error_msg_fn("")
-    else:
-        assert False, error_msg_fn("")
-
-
-def test_resize_float16_no_rounding():
-    # Make sure Resize() doesn't round float16 images
-    # Non-regression test for https://github.com/pytorch/vision/issues/7667
-
-    img = torch.randint(0, 256, size=(1, 3, 100, 100), dtype=torch.float16)
-    out = F.resize(img, size=(10, 10))
-    assert out.dtype == torch.float16
-    assert (out.round() - out).sum() > 0
diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py
index 14842c85c4b..0b9024c946b 100644
--- a/test/test_transforms_v2_refactored.py
+++ b/test/test_transforms_v2_refactored.py
@@ -735,6 +735,68 @@ def test_no_regression_5405(self, make_input):
 
         assert max(F.get_size(output)) == max_size
 
+    def _make_image(self, *args, batch_dims=(), memory_format=torch.contiguous_format, **kwargs):
+        # torch.channels_last memory_format is only available for 4D tensors, i.e. (B, C, H, W). However, images coming
+        # from PIL or our own I/O functions do not have a batch dimension and are thus 3D, i.e. (C, H, W). Still, the
+        # layout of the data in memory is channels last. To emulate this when a 3D input is requested here, we create
+        # the image as 4D and create a view with the right shape afterwards. With this, the layout in memory is channels
+        # last although PyTorch doesn't recognize it as such.
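+        # Note: Tensor.view() never copies, so the view created below keeps the
+        # channels-last data layout, e.g. strides (1, W * C, C) for a 3D image.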
+        emulate_channels_last = memory_format is torch.channels_last and len(batch_dims) != 1
+
+        image = make_image(
+            *args,
+            batch_dims=(math.prod(batch_dims),) if emulate_channels_last else batch_dims,
+            memory_format=memory_format,
+            **kwargs,
+        )
+
+        if emulate_channels_last:
+            image = datapoints.wrap(image.view(*batch_dims, *image.shape[-3:]), like=image)
+
+        return image
+
+    def _check_stride(self, image, *, memory_format):
+        C, H, W = F.get_dimensions(image)
+        if memory_format is torch.contiguous_format:
+            expected_stride = (H * W, W, 1)
+        elif memory_format is torch.channels_last:
+            expected_stride = (1, W * C, C)
+        else:
+            raise ValueError(f"Unknown memory_format: {memory_format}")
+
+        assert image.stride() == expected_stride
+
+    # TODO: We can remove this test and the related torchvision workaround
+    # once the related PyTorch issue is fixed: https://github.com/pytorch/pytorch/issues/68430
+    @pytest.mark.parametrize("interpolation", INTERPOLATION_MODES)
+    @pytest.mark.parametrize("antialias", [True, False])
+    @pytest.mark.parametrize("memory_format", [torch.contiguous_format, torch.channels_last])
+    @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32])
+    @pytest.mark.parametrize("device", cpu_and_cuda())
+    def test_kernel_image_memory_format_consistency(self, interpolation, antialias, memory_format, dtype, device):
+        size = self.OUTPUT_SIZES[0]
+
+        input = self._make_image(self.INPUT_SIZE, dtype=dtype, device=device, memory_format=memory_format)
+
+        # Smoke test to make sure we aren't starting with wrong assumptions
+        self._check_stride(input, memory_format=memory_format)
+
+        output = F.resize_image(input, size=size, interpolation=interpolation, antialias=antialias)
+
+        self._check_stride(output, memory_format=memory_format)
+
+    def test_float16_no_rounding(self):
+        # Make sure Resize() doesn't round float16 images
+        # Non-regression test for https://github.com/pytorch/vision/issues/7667
+
+        input = make_image_tensor(self.INPUT_SIZE, dtype=torch.float16)
+        output = F.resize_image(input, size=self.OUTPUT_SIZES[0])
+
+        assert output.dtype is torch.float16
+        assert (output.round() - output).abs().sum() > 0
+
 
 class TestHorizontalFlip:
     @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])