From 1120aa9e89d86bc466bba8f10ba382a5c63d6056 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 8 Feb 2023 20:02:02 +0100 Subject: [PATCH] introduce heuristic for simple tensor handling of transforms v2 (#7170) --- test/test_prototype_transforms.py | 297 ++++++++++++------ torchvision/prototype/transforms/_misc.py | 19 ++ .../prototype/transforms/_transform.py | 35 ++- 3 files changed, 250 insertions(+), 101 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 335fbfd4fe3..29c2bc1358a 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -1,15 +1,16 @@ import itertools +import re import numpy as np import PIL.Image - import pytest import torch import torchvision.prototype.transforms.utils -from common_utils import assert_equal, cpu_and_gpu +from common_utils import cpu_and_gpu from prototype_common_utils import ( + assert_equal, DEFAULT_EXTRA_DIMS, make_bounding_box, make_bounding_boxes, @@ -25,7 +26,7 @@ ) from torchvision.ops.boxes import box_iou from torchvision.prototype import datapoints, transforms -from torchvision.prototype.transforms.utils import check_type +from torchvision.prototype.transforms.utils import check_type, is_simple_tensor from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims] @@ -222,6 +223,67 @@ def test_random_resized_crop(self, transform, input): transform(input) +@pytest.mark.parametrize( + "flat_inputs", + itertools.permutations( + [ + next(make_vanilla_tensor_images()), + next(make_vanilla_tensor_images()), + next(make_pil_images()), + make_image(), + next(make_videos()), + ], + 3, + ), +) +def test_simple_tensor_heuristic(flat_inputs): + def split_on_simple_tensor(to_split): + # This takes a sequence that is structurally aligned with `flat_inputs` and splits its items into three parts: + # 1. The first simple tensor. If none is present, this will be `None` + # 2. A list of the remaining simple tensors + # 3. A list of all other items + simple_tensors = [] + others = [] + # Splitting always happens on the original `flat_inputs` to avoid any erroneous type changes by the transform to + # affect the splitting. + for item, inpt in zip(to_split, flat_inputs): + (simple_tensors if is_simple_tensor(inpt) else others).append(item) + return simple_tensors[0] if simple_tensors else None, simple_tensors[1:], others + + class CopyCloneTransform(transforms.Transform): + def _transform(self, inpt, params): + return inpt.clone() if isinstance(inpt, torch.Tensor) else inpt.copy() + + @staticmethod + def was_applied(output, inpt): + identity = output is inpt + if identity: + return False + + # Make sure nothing fishy is going on + assert_equal(output, inpt) + return True + + first_simple_tensor_input, other_simple_tensor_inputs, other_inputs = split_on_simple_tensor(flat_inputs) + + transform = CopyCloneTransform() + transformed_sample = transform(flat_inputs) + + first_simple_tensor_output, other_simple_tensor_outputs, other_outputs = split_on_simple_tensor(transformed_sample) + + if first_simple_tensor_input is not None: + if other_inputs: + assert not transform.was_applied(first_simple_tensor_output, first_simple_tensor_input) + else: + assert transform.was_applied(first_simple_tensor_output, first_simple_tensor_input) + + for output, inpt in zip(other_simple_tensor_outputs, other_simple_tensor_inputs): + assert not transform.was_applied(output, inpt) + + for input, output in zip(other_inputs, other_outputs): + assert transform.was_applied(output, input) + + @pytest.mark.parametrize("p", [0.0, 1.0]) class TestRandomHorizontalFlip: def input_expected_image_tensor(self, p, dtype=torch.float32): @@ -1755,117 +1817,158 @@ def test__transform(self, mocker): ) -@pytest.mark.parametrize( - ("dtype", "expected_dtypes"), - [ - ( - torch.float64, - {torch.Tensor: torch.float64, datapoints.Image: torch.float64, datapoints.BoundingBox: torch.float64}, - ), - ( - {torch.Tensor: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64}, - {torch.Tensor: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64}, - ), - ], -) -def test_to_dtype(dtype, expected_dtypes): - sample = dict( - plain_tensor=torch.testing.make_tensor(5, dtype=torch.int64, device="cpu"), - image=make_image(dtype=torch.uint8), - bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, dtype=torch.float32), - str="str", - int=0, +class TestToDtype: + @pytest.mark.parametrize( + ("dtype", "expected_dtypes"), + [ + ( + torch.float64, + { + datapoints.Video: torch.float64, + datapoints.Image: torch.float64, + datapoints.BoundingBox: torch.float64, + }, + ), + ( + {datapoints.Video: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64}, + {datapoints.Video: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64}, + ), + ], ) + def test_call(self, dtype, expected_dtypes): + sample = dict( + video=make_video(dtype=torch.int64), + image=make_image(dtype=torch.uint8), + bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, dtype=torch.float32), + str="str", + int=0, + ) - transform = transforms.ToDtype(dtype) - transformed_sample = transform(sample) + transform = transforms.ToDtype(dtype) + transformed_sample = transform(sample) - for key, value in sample.items(): - value_type = type(value) - transformed_value = transformed_sample[key] + for key, value in sample.items(): + value_type = type(value) + transformed_value = transformed_sample[key] - # make sure the transformation retains the type - assert isinstance(transformed_value, value_type) + # make sure the transformation retains the type + assert isinstance(transformed_value, value_type) - if isinstance(value, torch.Tensor): - assert transformed_value.dtype is expected_dtypes[value_type] - else: - assert transformed_value is value + if isinstance(value, torch.Tensor): + assert transformed_value.dtype is expected_dtypes[value_type] + else: + assert transformed_value is value + @pytest.mark.filterwarnings("error") + def test_plain_tensor_call(self): + tensor = torch.empty((), dtype=torch.float32) + transform = transforms.ToDtype({torch.Tensor: torch.float64}) -@pytest.mark.parametrize( - ("dims", "inverse_dims"), - [ - ( - {torch.Tensor: (1, 2, 0), datapoints.Image: (2, 1, 0), datapoints.Video: None}, - {torch.Tensor: (2, 0, 1), datapoints.Image: (2, 1, 0), datapoints.Video: None}, - ), - ( - {torch.Tensor: (1, 2, 0), datapoints.Image: (2, 1, 0), datapoints.Video: (1, 2, 3, 0)}, - {torch.Tensor: (2, 0, 1), datapoints.Image: (2, 1, 0), datapoints.Video: (3, 0, 1, 2)}, - ), - ], -) -def test_permute_dimensions(dims, inverse_dims): - sample = dict( - plain_tensor=torch.testing.make_tensor((3, 28, 28), dtype=torch.uint8, device="cpu"), - image=make_image(), - bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY), - video=make_video(), - str="str", - int=0, + assert transform(tensor).dtype is torch.float64 + + @pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video]) + def test_plain_tensor_warning(self, other_type): + with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")): + transforms.ToDtype(dtype={torch.Tensor: torch.float32, other_type: torch.float64}) + + +class TestPermuteDimensions: + @pytest.mark.parametrize( + ("dims", "inverse_dims"), + [ + ( + {datapoints.Image: (2, 1, 0), datapoints.Video: None}, + {datapoints.Image: (2, 1, 0), datapoints.Video: None}, + ), + ( + {datapoints.Image: (2, 1, 0), datapoints.Video: (1, 2, 3, 0)}, + {datapoints.Image: (2, 1, 0), datapoints.Video: (3, 0, 1, 2)}, + ), + ], ) + def test_call(self, dims, inverse_dims): + sample = dict( + image=make_image(), + bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY), + video=make_video(), + str="str", + int=0, + ) - transform = transforms.PermuteDimensions(dims) - transformed_sample = transform(sample) + transform = transforms.PermuteDimensions(dims) + transformed_sample = transform(sample) - for key, value in sample.items(): - value_type = type(value) - transformed_value = transformed_sample[key] + for key, value in sample.items(): + value_type = type(value) + transformed_value = transformed_sample[key] - if check_type( - value, (datapoints.Image, torchvision.prototype.transforms.utils.is_simple_tensor, datapoints.Video) - ): - if transform.dims.get(value_type) is not None: - assert transformed_value.permute(inverse_dims[value_type]).equal(value) - assert type(transformed_value) == torch.Tensor - else: - assert transformed_value is value + if check_type( + value, (datapoints.Image, torchvision.prototype.transforms.utils.is_simple_tensor, datapoints.Video) + ): + if transform.dims.get(value_type) is not None: + assert transformed_value.permute(inverse_dims[value_type]).equal(value) + assert type(transformed_value) == torch.Tensor + else: + assert transformed_value is value + @pytest.mark.filterwarnings("error") + def test_plain_tensor_call(self): + tensor = torch.empty((2, 3, 4)) + transform = transforms.PermuteDimensions(dims=(1, 2, 0)) -@pytest.mark.parametrize( - "dims", - [ - (-1, -2), - {torch.Tensor: (-1, -2), datapoints.Image: (1, 2), datapoints.Video: None}, - ], -) -def test_transpose_dimensions(dims): - sample = dict( - plain_tensor=torch.testing.make_tensor((3, 28, 28), dtype=torch.uint8, device="cpu"), - image=make_image(), - bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY), - video=make_video(), - str="str", - int=0, + assert transform(tensor).shape == (3, 4, 2) + + @pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video]) + def test_plain_tensor_warning(self, other_type): + with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")): + transforms.PermuteDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)}) + + +class TestTransposeDimensions: + @pytest.mark.parametrize( + "dims", + [ + (-1, -2), + {datapoints.Image: (1, 2), datapoints.Video: None}, + ], ) + def test_call(self, dims): + sample = dict( + image=make_image(), + bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY), + video=make_video(), + str="str", + int=0, + ) - transform = transforms.TransposeDimensions(dims) - transformed_sample = transform(sample) + transform = transforms.TransposeDimensions(dims) + transformed_sample = transform(sample) - for key, value in sample.items(): - value_type = type(value) - transformed_value = transformed_sample[key] + for key, value in sample.items(): + value_type = type(value) + transformed_value = transformed_sample[key] - transposed_dims = transform.dims.get(value_type) - if check_type( - value, (datapoints.Image, torchvision.prototype.transforms.utils.is_simple_tensor, datapoints.Video) - ): - if transposed_dims is not None: - assert transformed_value.transpose(*transposed_dims).equal(value) - assert type(transformed_value) == torch.Tensor - else: - assert transformed_value is value + transposed_dims = transform.dims.get(value_type) + if check_type( + value, (datapoints.Image, torchvision.prototype.transforms.utils.is_simple_tensor, datapoints.Video) + ): + if transposed_dims is not None: + assert transformed_value.transpose(*transposed_dims).equal(value) + assert type(transformed_value) == torch.Tensor + else: + assert transformed_value is value + + @pytest.mark.filterwarnings("error") + def test_plain_tensor_call(self): + tensor = torch.empty((2, 3, 4)) + transform = transforms.TransposeDimensions(dims=(0, 2)) + + assert transform(tensor).shape == (4, 3, 2) + + @pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video]) + def test_plain_tensor_warning(self, other_type): + with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")): + transforms.TransposeDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)}) class TestUniformTemporalSubsample: diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index 07ab53aff82..e7bb62da18e 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -1,3 +1,4 @@ +import warnings from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union import PIL.Image @@ -155,6 +156,12 @@ def __init__(self, dtype: Union[torch.dtype, Dict[Type, Optional[torch.dtype]]]) super().__init__() if not isinstance(dtype, dict): dtype = _get_defaultdict(dtype) + if torch.Tensor in dtype and any(cls in dtype for cls in [datapoints.Image, datapoints.Video]): + warnings.warn( + "Got `dtype` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. " + "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) " + "in case a `datapoints.Image` or `datapoints.Video` is present in the input." + ) self.dtype = dtype def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: @@ -171,6 +178,12 @@ def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]] super().__init__() if not isinstance(dims, dict): dims = _get_defaultdict(dims) + if torch.Tensor in dims and any(cls in dims for cls in [datapoints.Image, datapoints.Video]): + warnings.warn( + "Got `dims` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. " + "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) " + "in case a `datapoints.Image` or `datapoints.Video` is present in the input." + ) self.dims = dims def _transform( @@ -189,6 +202,12 @@ def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, i super().__init__() if not isinstance(dims, dict): dims = _get_defaultdict(dims) + if torch.Tensor in dims and any(cls in dims for cls in [datapoints.Image, datapoints.Video]): + warnings.warn( + "Got `dims` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. " + "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) " + "in case a `datapoints.Image` or `datapoints.Video` is present in the input." + ) self.dims = dims def _transform( diff --git a/torchvision/prototype/transforms/_transform.py b/torchvision/prototype/transforms/_transform.py index 206889ace72..675b0787e83 100644 --- a/torchvision/prototype/transforms/_transform.py +++ b/torchvision/prototype/transforms/_transform.py @@ -7,7 +7,8 @@ import torch from torch import nn from torch.utils._pytree import tree_flatten, tree_unflatten -from torchvision.prototype.transforms.utils import check_type +from torchvision.prototype import datapoints +from torchvision.prototype.transforms.utils import check_type, has_any, is_simple_tensor from torchvision.utils import _log_api_usage_once @@ -37,9 +38,35 @@ def forward(self, *inputs: Any) -> Any: params = self._get_params(flat_inputs) - flat_outputs = [ - self._transform(inpt, params) if check_type(inpt, self._transformed_types) else inpt for inpt in flat_inputs - ] + # Below is a heuristic on how to deal with simple tensor inputs: + # 1. Simple tensors, i.e. tensors that are not a datapoint, are passed through if there is an explicit image + # (`datapoints.Image` or `PIL.Image.Image`) or video (`datapoints.Video`) in the sample. + # 2. If there is no explicit image or video in the sample, only the first encountered simple tensor is + # transformed as image, while the rest is passed through. The order is defined by the returned `flat_inputs` + # of `tree_flatten`, which recurses depth-first through the input. + # + # This heuristic stems from two requirements: + # 1. We need to keep BC for single input simple tensors and treat them as images. + # 2. We don't want to treat all simple tensors as images, because some datasets like `CelebA` or `Widerface` + # return supplemental numerical data as tensors that cannot be transformed as images. + # + # The heuristic should work well for most people in practice. The only case where it doesn't is if someone + # tries to transform multiple simple tensors at the same time, expecting them all to be treated as images. + # However, this case wasn't supported by transforms v1 either, so there is no BC concern. + flat_outputs = [] + transform_simple_tensor = not has_any(flat_inputs, datapoints.Image, datapoints.Video, PIL.Image.Image) + for inpt in flat_inputs: + needs_transform = True + + if not check_type(inpt, self._transformed_types): + needs_transform = False + elif is_simple_tensor(inpt): + if transform_simple_tensor: + transform_simple_tensor = False + else: + needs_transform = False + + flat_outputs.append(self._transform(inpt, params) if needs_transform else inpt) return tree_unflatten(flat_outputs, spec)