introduce heuristic for simple tensor handling of transforms v2 (#7170)
pmeier authored Feb 8, 2023
1 parent 1222b49 commit 1120aa9
Showing 3 changed files with 250 additions and 101 deletions.
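
In short: transforms v2 now treats a plain `torch.Tensor` as an image only if no `datapoints.Image`, `datapoints.Video`, or PIL image is present in the sample; otherwise plain tensors pass through untouched. Even in the all-plain-tensor case, only the first plain tensor is transformed and the remaining ones pass through. Below is a minimal sketch of the resulting behavior, based on the tests and warning messages in this commit (assuming the prototype API at this point in time; `ToDtype` stands in for any v2 transform):

import torch
from torchvision.prototype import datapoints, transforms

plain = torch.zeros(3, 8, 8, dtype=torch.uint8)
image = datapoints.Image(torch.zeros(3, 8, 8, dtype=torch.uint8))

# No Image/Video in the sample: the plain tensor is treated as the image.
out = transforms.ToDtype({torch.Tensor: torch.float64})(plain)
assert out.dtype is torch.float64

# An Image is present: the plain tensor passes through unchanged. Constructing
# the transform with both keys also emits a UserWarning (see the second file below).
transform = transforms.ToDtype({torch.Tensor: torch.float64, datapoints.Image: torch.float64})
out_plain, out_image = transform((plain, image))
assert out_plain is plain                  # skipped by the heuristic
assert out_image.dtype is torch.float64    # the datapoint is converted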
297 changes: 200 additions & 97 deletions test/test_prototype_transforms.py
@@ -1,15 +1,16 @@
 import itertools
 import re

 import numpy as np

 import PIL.Image

 import pytest
 import torch

 import torchvision.prototype.transforms.utils
-from common_utils import assert_equal, cpu_and_gpu
+from common_utils import cpu_and_gpu
 from prototype_common_utils import (
+    assert_equal,
     DEFAULT_EXTRA_DIMS,
     make_bounding_box,
     make_bounding_boxes,
@@ -25,7 +26,7 @@
 )
 from torchvision.ops.boxes import box_iou
 from torchvision.prototype import datapoints, transforms
-from torchvision.prototype.transforms.utils import check_type
+from torchvision.prototype.transforms.utils import check_type, is_simple_tensor
 from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image

 BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims]
@@ -222,6 +223,67 @@ def test_random_resized_crop(self, transform, input):
         transform(input)


+@pytest.mark.parametrize(
+    "flat_inputs",
+    itertools.permutations(
+        [
+            next(make_vanilla_tensor_images()),
+            next(make_vanilla_tensor_images()),
+            next(make_pil_images()),
+            make_image(),
+            next(make_videos()),
+        ],
+        3,
+    ),
+)
+def test_simple_tensor_heuristic(flat_inputs):
+    def split_on_simple_tensor(to_split):
+        # This takes a sequence that is structurally aligned with `flat_inputs` and splits its items into three parts:
+        # 1. The first simple tensor. If none is present, this will be `None`
+        # 2. A list of the remaining simple tensors
+        # 3. A list of all other items
+        simple_tensors = []
+        others = []
+        # Splitting always happens on the original `flat_inputs` so that any erroneous type changes made by the
+        # transform do not affect the splitting.
+        for item, inpt in zip(to_split, flat_inputs):
+            (simple_tensors if is_simple_tensor(inpt) else others).append(item)
+        return simple_tensors[0] if simple_tensors else None, simple_tensors[1:], others
+
+    class CopyCloneTransform(transforms.Transform):
+        def _transform(self, inpt, params):
+            return inpt.clone() if isinstance(inpt, torch.Tensor) else inpt.copy()
+
+        @staticmethod
+        def was_applied(output, inpt):
+            identity = output is inpt
+            if identity:
+                return False
+
+            # Make sure nothing fishy is going on
+            assert_equal(output, inpt)
+            return True
+
+    first_simple_tensor_input, other_simple_tensor_inputs, other_inputs = split_on_simple_tensor(flat_inputs)
+
+    transform = CopyCloneTransform()
+    transformed_sample = transform(flat_inputs)
+
+    first_simple_tensor_output, other_simple_tensor_outputs, other_outputs = split_on_simple_tensor(transformed_sample)
+
+    if first_simple_tensor_input is not None:
+        if other_inputs:
+            assert not transform.was_applied(first_simple_tensor_output, first_simple_tensor_input)
+        else:
+            assert transform.was_applied(first_simple_tensor_output, first_simple_tensor_input)
+
+    for output, inpt in zip(other_simple_tensor_outputs, other_simple_tensor_inputs):
+        assert not transform.was_applied(output, inpt)
+
+    for input, output in zip(other_inputs, other_outputs):
+        assert transform.was_applied(output, input)
+
+
 @pytest.mark.parametrize("p", [0.0, 1.0])
 class TestRandomHorizontalFlip:
     def input_expected_image_tensor(self, p, dtype=torch.float32):
@@ -1755,117 +1817,158 @@ def test__transform(self, mocker):
     )


-@pytest.mark.parametrize(
-    ("dtype", "expected_dtypes"),
-    [
-        (
-            torch.float64,
-            {torch.Tensor: torch.float64, datapoints.Image: torch.float64, datapoints.BoundingBox: torch.float64},
-        ),
-        (
-            {torch.Tensor: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64},
-            {torch.Tensor: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64},
-        ),
-    ],
-)
-def test_to_dtype(dtype, expected_dtypes):
-    sample = dict(
-        plain_tensor=torch.testing.make_tensor(5, dtype=torch.int64, device="cpu"),
-        image=make_image(dtype=torch.uint8),
-        bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, dtype=torch.float32),
-        str="str",
-        int=0,
-    )
-
-    transform = transforms.ToDtype(dtype)
-    transformed_sample = transform(sample)
-
-    for key, value in sample.items():
-        value_type = type(value)
-        transformed_value = transformed_sample[key]
-
-        # make sure the transformation retains the type
-        assert isinstance(transformed_value, value_type)
-
-        if isinstance(value, torch.Tensor):
-            assert transformed_value.dtype is expected_dtypes[value_type]
-        else:
-            assert transformed_value is value
+class TestToDtype:
+    @pytest.mark.parametrize(
+        ("dtype", "expected_dtypes"),
+        [
+            (
+                torch.float64,
+                {
+                    datapoints.Video: torch.float64,
+                    datapoints.Image: torch.float64,
+                    datapoints.BoundingBox: torch.float64,
+                },
+            ),
+            (
+                {datapoints.Video: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64},
+                {datapoints.Video: torch.int32, datapoints.Image: torch.float32, datapoints.BoundingBox: torch.float64},
+            ),
+        ],
+    )
+    def test_call(self, dtype, expected_dtypes):
+        sample = dict(
+            video=make_video(dtype=torch.int64),
+            image=make_image(dtype=torch.uint8),
+            bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY, dtype=torch.float32),
+            str="str",
+            int=0,
+        )
+
+        transform = transforms.ToDtype(dtype)
+        transformed_sample = transform(sample)
+
+        for key, value in sample.items():
+            value_type = type(value)
+            transformed_value = transformed_sample[key]
+
+            # make sure the transformation retains the type
+            assert isinstance(transformed_value, value_type)
+
+            if isinstance(value, torch.Tensor):
+                assert transformed_value.dtype is expected_dtypes[value_type]
+            else:
+                assert transformed_value is value
+
+    @pytest.mark.filterwarnings("error")
+    def test_plain_tensor_call(self):
+        tensor = torch.empty((), dtype=torch.float32)
+        transform = transforms.ToDtype({torch.Tensor: torch.float64})
+
+        assert transform(tensor).dtype is torch.float64
+
+    @pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video])
+    def test_plain_tensor_warning(self, other_type):
+        with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")):
+            transforms.ToDtype(dtype={torch.Tensor: torch.float32, other_type: torch.float64})


-@pytest.mark.parametrize(
-    ("dims", "inverse_dims"),
-    [
-        (
-            {torch.Tensor: (1, 2, 0), datapoints.Image: (2, 1, 0), datapoints.Video: None},
-            {torch.Tensor: (2, 0, 1), datapoints.Image: (2, 1, 0), datapoints.Video: None},
-        ),
-        (
-            {torch.Tensor: (1, 2, 0), datapoints.Image: (2, 1, 0), datapoints.Video: (1, 2, 3, 0)},
-            {torch.Tensor: (2, 0, 1), datapoints.Image: (2, 1, 0), datapoints.Video: (3, 0, 1, 2)},
-        ),
-    ],
-)
-def test_permute_dimensions(dims, inverse_dims):
-    sample = dict(
-        plain_tensor=torch.testing.make_tensor((3, 28, 28), dtype=torch.uint8, device="cpu"),
-        image=make_image(),
-        bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY),
-        video=make_video(),
-        str="str",
-        int=0,
-    )
-
-    transform = transforms.PermuteDimensions(dims)
-    transformed_sample = transform(sample)
-
-    for key, value in sample.items():
-        value_type = type(value)
-        transformed_value = transformed_sample[key]
-
-        if check_type(
-            value, (datapoints.Image, torchvision.prototype.transforms.utils.is_simple_tensor, datapoints.Video)
-        ):
-            if transform.dims.get(value_type) is not None:
-                assert transformed_value.permute(inverse_dims[value_type]).equal(value)
-            assert type(transformed_value) == torch.Tensor
-        else:
-            assert transformed_value is value
+class TestPermuteDimensions:
+    @pytest.mark.parametrize(
+        ("dims", "inverse_dims"),
+        [
+            (
+                {datapoints.Image: (2, 1, 0), datapoints.Video: None},
+                {datapoints.Image: (2, 1, 0), datapoints.Video: None},
+            ),
+            (
+                {datapoints.Image: (2, 1, 0), datapoints.Video: (1, 2, 3, 0)},
+                {datapoints.Image: (2, 1, 0), datapoints.Video: (3, 0, 1, 2)},
+            ),
+        ],
+    )
+    def test_call(self, dims, inverse_dims):
+        sample = dict(
+            image=make_image(),
+            bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY),
+            video=make_video(),
+            str="str",
+            int=0,
+        )
+
+        transform = transforms.PermuteDimensions(dims)
+        transformed_sample = transform(sample)
+
+        for key, value in sample.items():
+            value_type = type(value)
+            transformed_value = transformed_sample[key]
+
+            if check_type(
+                value, (datapoints.Image, torchvision.prototype.transforms.utils.is_simple_tensor, datapoints.Video)
+            ):
+                if transform.dims.get(value_type) is not None:
+                    assert transformed_value.permute(inverse_dims[value_type]).equal(value)
+                assert type(transformed_value) == torch.Tensor
+            else:
+                assert transformed_value is value
+
+    @pytest.mark.filterwarnings("error")
+    def test_plain_tensor_call(self):
+        tensor = torch.empty((2, 3, 4))
+        transform = transforms.PermuteDimensions(dims=(1, 2, 0))
+
+        assert transform(tensor).shape == (3, 4, 2)
+
+    @pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video])
+    def test_plain_tensor_warning(self, other_type):
+        with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")):
+            transforms.PermuteDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)})


-@pytest.mark.parametrize(
-    "dims",
-    [
-        (-1, -2),
-        {torch.Tensor: (-1, -2), datapoints.Image: (1, 2), datapoints.Video: None},
-    ],
-)
-def test_transpose_dimensions(dims):
-    sample = dict(
-        plain_tensor=torch.testing.make_tensor((3, 28, 28), dtype=torch.uint8, device="cpu"),
-        image=make_image(),
-        bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY),
-        video=make_video(),
-        str="str",
-        int=0,
-    )
-
-    transform = transforms.TransposeDimensions(dims)
-    transformed_sample = transform(sample)
-
-    for key, value in sample.items():
-        value_type = type(value)
-        transformed_value = transformed_sample[key]
-
-        transposed_dims = transform.dims.get(value_type)
-        if check_type(
-            value, (datapoints.Image, torchvision.prototype.transforms.utils.is_simple_tensor, datapoints.Video)
-        ):
-            if transposed_dims is not None:
-                assert transformed_value.transpose(*transposed_dims).equal(value)
-            assert type(transformed_value) == torch.Tensor
-        else:
-            assert transformed_value is value
+class TestTransposeDimensions:
+    @pytest.mark.parametrize(
+        "dims",
+        [
+            (-1, -2),
+            {datapoints.Image: (1, 2), datapoints.Video: None},
+        ],
+    )
+    def test_call(self, dims):
+        sample = dict(
+            image=make_image(),
+            bounding_box=make_bounding_box(format=datapoints.BoundingBoxFormat.XYXY),
+            video=make_video(),
+            str="str",
+            int=0,
+        )
+
+        transform = transforms.TransposeDimensions(dims)
+        transformed_sample = transform(sample)
+
+        for key, value in sample.items():
+            value_type = type(value)
+            transformed_value = transformed_sample[key]
+
+            transposed_dims = transform.dims.get(value_type)
+            if check_type(
+                value, (datapoints.Image, torchvision.prototype.transforms.utils.is_simple_tensor, datapoints.Video)
+            ):
+                if transposed_dims is not None:
+                    assert transformed_value.transpose(*transposed_dims).equal(value)
+                assert type(transformed_value) == torch.Tensor
+            else:
+                assert transformed_value is value
+
+    @pytest.mark.filterwarnings("error")
+    def test_plain_tensor_call(self):
+        tensor = torch.empty((2, 3, 4))
+        transform = transforms.TransposeDimensions(dims=(0, 2))
+
+        assert transform(tensor).shape == (4, 3, 2)
+
+    @pytest.mark.parametrize("other_type", [datapoints.Image, datapoints.Video])
+    def test_plain_tensor_warning(self, other_type):
+        with pytest.warns(UserWarning, match=re.escape("`torch.Tensor` will *not* be transformed")):
+            transforms.TransposeDimensions(dims={torch.Tensor: (0, 1), other_type: (1, 0)})


 class TestUniformTemporalSubsample:
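
For reference, the `is_simple_tensor` helper imported above is what the heuristic keys on: a "simple" tensor is a plain `torch.Tensor` that is not one of the datapoint subclasses. A rough, hypothetical sketch of its semantics follows (an illustration only; the actual helper lives in `torchvision.prototype.transforms.utils` and is not part of this diff):

import torch
from torchvision.prototype import datapoints

def is_simple_tensor_sketch(inpt):
    # Hypothetical stand-in: the real check excludes datapoint subclasses via
    # their common base class rather than an explicit tuple like this one.
    return isinstance(inpt, torch.Tensor) and not isinstance(
        inpt, (datapoints.Image, datapoints.Video, datapoints.BoundingBox, datapoints.Mask)
    )

assert is_simple_tensor_sketch(torch.rand(3, 8, 8))
assert not is_simple_tensor_sketch(datapoints.Image(torch.rand(3, 8, 8)))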
19 changes: 19 additions & 0 deletions torchvision/prototype/transforms/_misc.py
@@ -1,3 +1,4 @@
+import warnings
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union

 import PIL.Image
@@ -155,6 +156,12 @@ def __init__(self, dtype: Union[torch.dtype, Dict[Type, Optional[torch.dtype]]])
         super().__init__()
         if not isinstance(dtype, dict):
             dtype = _get_defaultdict(dtype)
+        if torch.Tensor in dtype and any(cls in dtype for cls in [datapoints.Image, datapoints.Video]):
+            warnings.warn(
+                "Got `dtype` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. "
+                "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
+                "in case a `datapoints.Image` or `datapoints.Video` is present in the input."
+            )
         self.dtype = dtype

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
@@ -171,6 +178,12 @@ def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]]]])
         super().__init__()
         if not isinstance(dims, dict):
             dims = _get_defaultdict(dims)
+        if torch.Tensor in dims and any(cls in dims for cls in [datapoints.Image, datapoints.Video]):
+            warnings.warn(
+                "Got `dims` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. "
+                "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
+                "in case a `datapoints.Image` or `datapoints.Video` is present in the input."
+            )
         self.dims = dims

     def _transform(
@@ -189,6 +202,12 @@ def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, int]]]])
         super().__init__()
         if not isinstance(dims, dict):
             dims = _get_defaultdict(dims)
+        if torch.Tensor in dims and any(cls in dims for cls in [datapoints.Image, datapoints.Video]):
+            warnings.warn(
+                "Got `dims` values for `torch.Tensor` and either `datapoints.Image` or `datapoints.Video`. "
+                "Note that a plain `torch.Tensor` will *not* be transformed by this (or any other transformation) "
+                "in case a `datapoints.Image` or `datapoints.Video` is present in the input."
+            )
         self.dims = dims

     def _transform(
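
The same guard is repeated in `ToDtype`, `PermuteDimensions`, and `TransposeDimensions`. A small sketch of when it fires, assuming the prototype names of this commit:

import warnings

import torch
from torchvision.prototype import datapoints, transforms

# A dict with a `torch.Tensor` entry next to an Image/Video entry warns at
# construction time, since plain tensors are skipped whenever an Image or
# Video is present in the sample.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    transforms.PermuteDimensions(dims={torch.Tensor: (1, 2, 0), datapoints.Image: (2, 1, 0)})
assert any("will *not* be transformed" in str(w.message) for w in caught)

# An Image/Video-only mapping is silent.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    transforms.PermuteDimensions(dims={datapoints.Image: (2, 1, 0)})
assert not caught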
(diff for the third changed file, 31 additions and 4 deletions, did not load)
