move channels_last handling to resize test
pmeier committed Aug 24, 2023
1 parent 194a758 commit b53cd8e
Showing 2 changed files with 23 additions and 13 deletions.
14 changes: 2 additions & 12 deletions test/common_utils.py
@@ -1,7 +1,6 @@
 import contextlib
 import functools
 import itertools
-import math
 import os
 import pathlib
 import random
@@ -382,23 +381,14 @@ def make_image(
     dtype = dtype or torch.uint8
     max_value = get_max_value(dtype)
 
-    shape = make_tensor_shape = (*batch_dims, num_channels, *size)
-    # torch.channels_last memory_format is only available for 4D tensors, i.e. (B, C, H, W). However, images coming from
-    # PIL or our own I/O functions do not have a batch dimension and are thus 3D, i.e. (C, H, W). Still, the layout of
-    # the data in memory is channels last. To emulate this when a 3D input is requested here, we create the image as 4D
-    # and create a view with the right shape afterwards. With this the layout in memory is channels last although
-    # PyTorch doesn't recognize it as such.
-    if memory_format is torch.channels_last and len(batch_dims) != 1:
-        make_tensor_shape = (math.prod(shape[:-3]), *shape[-3:])
-
     data = torch.testing.make_tensor(
-        make_tensor_shape,
+        (*batch_dims, num_channels, *size),
         low=0,
         high=max_value,
         dtype=dtype,
         device=device,
         memory_format=memory_format,
-    ).view(shape)
+    )
     if color_space in {"GRAY_ALPHA", "RGBA"}:
         data[..., -1, :, :] = max_value
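Aside: the simplified make_image can pass memory_format straight through because torch.testing.make_tensor accepts it directly, and torch.channels_last is only defined for rank-4 (B, C, H, W) tensors. A quick standalone sketch of that constraint (assumed behavior, not part of this commit):

import torch

torch.empty(2, 3, 8, 8, memory_format=torch.channels_last)  # fine: 4D (B, C, H, W)
try:
    torch.empty(3, 8, 8, memory_format=torch.channels_last)  # 3D image, no batch dim
except RuntimeError as err:
    print(err)  # channels_last requires a rank-4 tensor

This is why the 3D emulation no longer belongs in make_image; it moves into the one test that needs it, below.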
22 changes: 21 additions & 1 deletion test/test_transforms_v2_refactored.py
@@ -746,6 +746,26 @@ def _check_stride(self, image, *, memory_format):

         assert image.stride() == expected_stride
 
+    def _make_image(self, *args, batch_dims=(), memory_format=torch.contiguous_format, **kwargs):
+        # torch.channels_last memory_format is only available for 4D tensors, i.e. (B, C, H, W). However, images coming
+        # from PIL or our own I/O functions do not have a batch dimension and are thus 3D, i.e. (C, H, W). Still, the
+        # layout of the data in memory is channels last. To emulate this when a 3D input is requested here, we create
+        # the image as 4D and create a view with the right shape afterwards. With this the layout in memory is channels
+        # last although PyTorch doesn't recognize it as such.
+        emulate_channels_last = memory_format is torch.channels_last and len(batch_dims) != 1
+
+        image = make_image(
+            *args,
+            batch_dims=(math.prod(batch_dims),) if emulate_channels_last else batch_dims,
+            memory_format=memory_format,
+            **kwargs,
+        )
+
+        if emulate_channels_last:
+            image = datapoints.wrap(image.view(*batch_dims, *image.shape[-3:]), like=image)
+
+        return image
+
     # TODO: We can remove this test and the related torchvision workaround
     # once we fix the related pytorch issue: https://github.com/pytorch/pytorch/issues/68430
     @pytest.mark.parametrize("interpolation", INTERPOLATION_MODES)
@@ -761,7 +781,7 @@ def test_kernel_image_memory_format_consistency(
         if not (max_size_kwarg := self._make_max_size_kwarg(use_max_size=use_max_size, size=size)):
             return
 
-        input = make_image(self.INPUT_SIZE, dtype=dtype, device=device, memory_format=memory_format)
+        input = self._make_image(self.INPUT_SIZE, dtype=dtype, device=device, memory_format=memory_format)
 
         # Smoke test to make sure we aren't starting with wrong assumptions
         self._check_stride(input, memory_format=memory_format)
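For reference, a minimal standalone sketch (not code from this commit) of the trick _make_image relies on: create the image as 4D with channels_last, then view it back to 3D. Dropping a size-1 dimension is always a valid view, so no copy is made and the NHWC memory layout survives, even though PyTorch no longer reports the 3D result as channels_last.

import torch

batched = torch.empty(1, 3, 32, 32, memory_format=torch.channels_last)
print(batched.stride())  # (3072, 1, 96, 3): channels vary fastest in memory

image = batched.view(3, 32, 32)  # drops only the size-1 batch dim, no copy
print(image.stride())  # (1, 96, 3): same memory layout, now a plain 3D tensor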
