refactor transforms v2 tests #7562

Merged on Jun 21, 2023 (53 commits)
Changes shown below are from 1 commit.

Commits (53)
c12991b
[PoC] refactor transforms v2 tests
pmeier May 8, 2023
597bde0
Merge branch 'main' into refactor-v2-tests
pmeier May 10, 2023
d5167c2
Merge branch 'main' into refactor-v2-tests
pmeier May 10, 2023
95672b4
Merge branch 'main' into refactor-v2-tests
pmeier May 18, 2023
9dfe0fb
complete kernel checks
pmeier May 18, 2023
aa52e9d
align parameter names
pmeier May 18, 2023
d35c381
add tolerance handling
pmeier May 19, 2023
3b3edd7
add tolerance handling and dispatcher tests
pmeier May 22, 2023
8082ed4
Merge branch 'main' into refactor-v2-tests
pmeier May 22, 2023
fde36a2
don't check device in CUDA vs CPU test
pmeier May 23, 2023
785d4d9
address small comments
pmeier May 23, 2023
c178d26
refactor tolerances
pmeier May 23, 2023
3e75473
Merge branch 'main' into refactor-v2-tests
pmeier May 24, 2023
3ee85ee
add batch dim parametrization
pmeier May 24, 2023
c4c2987
Merge branch 'main' into refactor-v2-tests
pmeier May 31, 2023
045b237
simplify batch checks
pmeier May 31, 2023
0f12486
reduce dispatcher parametrization
pmeier May 31, 2023
1c46b24
Merge branch 'main' into refactor-v2-tests
pmeier Jun 1, 2023
f6157f3
polish kernel and dispatcher tests
pmeier Jun 1, 2023
ff1fb6a
Merge branch 'main' into refactor-v2-tests
pmeier Jun 2, 2023
43877de
add transforms tests
pmeier Jun 19, 2023
8ab5374
add tests for extra warnings and errors
pmeier Jun 19, 2023
5f1a68b
fix antialias test
pmeier Jun 19, 2023
b94d8bc
add output size checks
pmeier Jun 19, 2023
959200a
Merge branch 'main' into refactor-v2-tests
pmeier Jun 19, 2023
87ff4ae
fix cuda tolerances
pmeier Jun 19, 2023
2547acd
address small comments
pmeier Jun 19, 2023
a83f9fb
improve bicubic cuda check
pmeier Jun 19, 2023
3b10248
properly parametrize over simple tensors and PIL images
pmeier Jun 19, 2023
50da561
add bicubic cuda comment
pmeier Jun 19, 2023
658c307
Merge branch 'main' into refactor-v2-tests
pmeier Jun 19, 2023
b25cef8
reorder tests
pmeier Jun 20, 2023
026e295
fix int interpolation test
pmeier Jun 20, 2023
6770596
fix warnings
pmeier Jun 20, 2023
d77baec
add noop test
pmeier Jun 20, 2023
ad7437d
add regression test
pmeier Jun 20, 2023
1ad491d
fix warnings
pmeier Jun 20, 2023
6ffeeb4
improve pil compat interpolation test
pmeier Jun 20, 2023
60c3023
fix format not being used in bbox test
pmeier Jun 20, 2023
dc5b4f1
use pre-defined PIL interpolation mode mapping
pmeier Jun 20, 2023
23b1599
improve noop test
pmeier Jun 20, 2023
e66bf4f
unify dispatcher checks
pmeier Jun 20, 2023
624b925
dont use MAE for correctness checks
pmeier Jun 20, 2023
ad06894
unify image correctness tests
pmeier Jun 20, 2023
9239eda
Merge branch 'main' into refactor-v2-tests
NicolasHug Jun 21, 2023
03667b0
Update test/test_transforms_v2_refactored.py
pmeier Jun 21, 2023
4deb952
reinstate high bicubic tolerance
pmeier Jun 21, 2023
33d3207
Merge branch 'refactor-v2-tests' of github.com:pmeier/vision into ref…
pmeier Jun 21, 2023
c496bfd
Merge branch 'main' into refactor-v2-tests
pmeier Jun 21, 2023
ff4c0ea
fix
pmeier Jun 21, 2023
d5358b9
remove most old v2 resize tests
pmeier Jun 21, 2023
5957d42
port bounding box correctness tests
pmeier Jun 21, 2023
db6b137
Merge branch 'main' into refactor-v2-tests
pmeier Jun 21, 2023
146 changes: 146 additions & 0 deletions test/test_transforms_v2_refactored.py
@@ -0,0 +1,146 @@
import pytest
import torch

from common_utils import cache, cpu_and_gpu
from torch.testing import assert_close
from torch.utils._pytree import tree_map
from torchvision import datapoints

from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2.utils import is_simple_tensor


def _check_cuda_vs_cpu(kernel, input_cuda, *other_args, **kwargs):
    input_cuda = input_cuda.as_subclass(torch.Tensor)
    input_cpu = input_cuda.to("cpu")

    actual = kernel(input_cuda, *other_args, **kwargs)
    expected = kernel(input_cpu, *other_args, **kwargs)

    assert_close(actual, expected)


@cache
def _script(fn):
    try:
        return torch.jit.script(fn)
    except Exception as error:
        raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from error


def _check_scripted_vs_eager(kernel_eager, input, *other_args, **kwargs):
    kernel_scripted = _script(kernel_eager)

    input = input.as_subclass(torch.Tensor)
    actual = kernel_scripted(input, *other_args, **kwargs)
    expected = kernel_eager(input, *other_args, **kwargs)

    assert_close(actual, expected)


def _unbatch(batch, *, data_dims):
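    # Splits a batched tensor (or a (tensor, *metadata) sequence) into a nested list of unbatched samples by
    # recursively unbinding the leading dimensions until only `data_dims` dimensions remain. For example, an
    # image batch of shape (2, 5, 3, H, W) with data_dims=3 becomes a nested list of 2 lists of 5 images of
    # shape (3, H, W) each.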
    if isinstance(batch, torch.Tensor):
        batched_tensor = batch
        metadata = ()
    else:
        batched_tensor, *metadata = batch

    if batched_tensor.ndim == data_dims:
        return batch

    return [
        _unbatch(unbatched, data_dims=data_dims)
        for unbatched in (
            batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
        )
    ]


def _check_batched_vs_single(kernel, batched_input, *other_args, **kwargs):
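    # Checks that applying `kernel` to a batched input gives the same results as applying it to each unbatched
    # sample individually.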
    input_type = datapoints.Image if is_simple_tensor(batched_input) else type(batched_input)
    # This dictionary contains the number of rightmost dimensions that contain the actual data.
    # Everything to the left is considered a batch dimension.
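    # For example, `data_dims=3` for images means that an input of shape (2, 4, 3, 32, 32) has batch dimensions
    # (2, 4) and data dimensions (3, 32, 32).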
    data_dims = {
        datapoints.Image: 3,
        datapoints.BoundingBox: 1,
        # `Mask`'s are special in the sense that the data dimensions depend on the type of mask. For detection
        # masks it is 3 `(*, N, H, W)`, but for segmentation masks it is 2 `(*, H, W)`. Since both are grouped
        # under one type, all kernels should also work without differentiating between the two. Thus, we go with
        # 2 here as the common ground.
        datapoints.Mask: 2,
        datapoints.Video: 4,
    }.get(input_type)
    if data_dims is None:
        raise pytest.UsageError(
            f"The number of data dimensions cannot be determined for input of type {input_type.__name__}."
        ) from None
    elif batched_input.ndim <= data_dims or not all(batched_input.shape[:-data_dims]):
        # input is not batched or has a degenerate batch shape
        return

    batched_input = batched_input.as_subclass(torch.Tensor)
    batched_output = kernel(batched_input, *other_args, **kwargs)
    actual = _unbatch(batched_output, data_dims=data_dims)

    single_inputs = _unbatch(batched_input, data_dims=data_dims)
    expected = tree_map(lambda single_input: kernel(single_input, *other_args, **kwargs), single_inputs)

    assert_close(actual, expected)


def check_kernel(
    kernel,
    input,
    *other_kernel_args,
    check_cuda_vs_cpu=True,
    check_scripted_vs_eager=True,
    check_batched_vs_single=True,
    **kernel_kwargs,
    # TODO: tolerances!
):
    initial_input_version = input._version
    output = kernel(input.as_subclass(torch.Tensor), *other_kernel_args, **kernel_kwargs)

    # check that no inplace operation happened
    assert input._version == initial_input_version

    assert output.dtype == input.dtype
    assert output.device == input.device

    # TODO: we can do better here, by passing the regular output of the kernel to each auxiliary helper
    # instead of recomputing it there multiple times

    if check_cuda_vs_cpu and input.device.type == "cuda":
        _check_cuda_vs_cpu(kernel, input, *other_kernel_args, **kernel_kwargs)

    if check_scripted_vs_eager:
        _check_scripted_vs_eager(kernel, input, *other_kernel_args, **kernel_kwargs)

    if check_batched_vs_single:
        _check_batched_vs_single(kernel, input, *other_kernel_args, **kernel_kwargs)


def check_dispatcher():
    pass


def check_transform():
    pass


class TestResize:
    @pytest.mark.parametrize("size", [(11, 17), (15, 13)])
    @pytest.mark.parametrize("antialias", [True, False])
    @pytest.mark.parametrize("device", cpu_and_gpu())
    def test_resize_image_tensor(self, size, antialias, device):
        image = torch.rand((3, 14, 16), dtype=torch.float32, device=device)
        check_kernel(F.resize_image_tensor, image, size=size, antialias=antialias)

    def test_resize_bounding_box(self):
        pass

    def test_resize(self):
        pass

    def test_Resize(self):
        pass