[fbsync] port vertical flip (#7712)
Reviewed By: vmoens

Differential Revision: D47186566

fbshipit-source-id: 92dd32411629e98d4b82c69cd9a000bd92eeb5fb
NicolasHug authored and facebook-github-bot committed Jul 3, 2023
1 parent 541dbe2 commit d5276bf
Showing 5 changed files with 146 additions and 151 deletions.
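For orientation, the surface being ported is the v2 vertical flip: the RandomVerticalFlip transform and the F.vertical_flip functional with its per-type kernels. A minimal sketch of that API, assuming the standard v2 imports (the tensor shape is illustrative, not taken from the diff):

import torch
from torchvision import datapoints
from torchvision.transforms import v2 as transforms
from torchvision.transforms.v2 import functional as F

image = datapoints.Image(torch.randint(0, 256, (3, 17, 11), dtype=torch.uint8))

flipped = transforms.RandomVerticalFlip(p=1)(image)  # p=1 forces the flip
assert torch.equal(flipped, F.vertical_flip(image))  # the transform defers to the functional
assert torch.equal(flipped, image.flip(-2))          # which reverses the height (row) dimension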
55 changes: 1 addition & 54 deletions test/test_transforms_v2.py
@@ -29,7 +29,7 @@
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import datapoints
from torchvision.ops.boxes import box_iou
-from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image
+from torchvision.transforms.functional import InterpolationMode, to_pil_image
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2.utils import check_type, is_simple_tensor, query_chw

@@ -406,59 +406,6 @@ def was_applied(output, inpt):
assert transform.was_applied(output, input)


@pytest.mark.parametrize("p", [0.0, 1.0])
class TestRandomVerticalFlip:
def input_expected_image_tensor(self, p, dtype=torch.float32):
input = torch.tensor([[[1, 1], [0, 0]], [[1, 1], [0, 0]]], dtype=dtype)
expected = torch.tensor([[[0, 0], [1, 1]], [[0, 0], [1, 1]]], dtype=dtype)

return input, expected if p == 1 else input

def test_simple_tensor(self, p):
input, expected = self.input_expected_image_tensor(p)
transform = transforms.RandomVerticalFlip(p=p)

actual = transform(input)

assert_equal(expected, actual)

def test_pil_image(self, p):
input, expected = self.input_expected_image_tensor(p, dtype=torch.uint8)
transform = transforms.RandomVerticalFlip(p=p)

actual = transform(to_pil_image(input))

assert_equal(expected, pil_to_tensor(actual))

def test_datapoints_image(self, p):
input, expected = self.input_expected_image_tensor(p)
transform = transforms.RandomVerticalFlip(p=p)

actual = transform(datapoints.Image(input))

assert_equal(datapoints.Image(expected), actual)

def test_datapoints_mask(self, p):
input, expected = self.input_expected_image_tensor(p)
transform = transforms.RandomVerticalFlip(p=p)

actual = transform(datapoints.Mask(input))

assert_equal(datapoints.Mask(expected), actual)

def test_datapoints_bounding_box(self, p):
input = datapoints.BoundingBox([0, 0, 5, 5], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(10, 10))
transform = transforms.RandomVerticalFlip(p=p)

actual = transform(input)

expected_image_tensor = torch.tensor([0, 5, 5, 10]) if p == 1.0 else input
expected = datapoints.BoundingBox.wrap_like(input, expected_image_tensor)
assert_equal(expected, actual)
assert actual.format == expected.format
assert actual.spatial_size == expected.spatial_size


class TestPad:
def test_assertions(self):
with pytest.raises(TypeError, match="Got inappropriate padding arg"):
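The expectations hard-coded in the removed TestRandomVerticalFlip above remain valid descriptions of the kernel's behavior; a quick sanity sketch using only the values from the deleted test (the bounding-box case is revisited with the affine reference after the new test file below):

import torch

image = torch.tensor([[[1, 1], [0, 0]], [[1, 1], [0, 0]]], dtype=torch.float32)
expected = torch.tensor([[[0, 0], [1, 1]], [[0, 0], [1, 1]]], dtype=torch.float32)

# A vertical flip reverses the row (height) order in every channel, which is
# exactly what the tensor kernel's image.flip(-2) does.
assert torch.equal(image.flip(-2), expected)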
148 changes: 143 additions & 5 deletions test/test_transforms_v2_refactored.py
@@ -842,7 +842,7 @@ def _reference_horizontal_flip_bounding_box(self, bounding_box):
"fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)]
)
def test_bounding_box_correctness(self, format, fn):
-bounding_box = self._make_input(datapoints.BoundingBox)
+bounding_box = self._make_input(datapoints.BoundingBox, format=format)

actual = fn(bounding_box)
expected = self._reference_horizontal_flip_bounding_box(bounding_box)
@@ -1025,12 +1025,10 @@ def test_kernel_bounding_box(self, param, value, format, dtype, device):

@pytest.mark.parametrize("mask_type", ["segmentation", "detection"])
def test_kernel_mask(self, mask_type):
-check_kernel(
-F.affine_mask, self._make_input(datapoints.Mask, mask_type=mask_type), **self._MINIMAL_AFFINE_KWARGS
-)
+self._check_kernel(F.affine_mask, self._make_input(datapoints.Mask, mask_type=mask_type))

def test_kernel_video(self):
-check_kernel(F.affine_video, self._make_input(datapoints.Video), **self._MINIMAL_AFFINE_KWARGS)
+self._check_kernel(F.affine_video, self._make_input(datapoints.Video))

@pytest.mark.parametrize(
("input_type", "kernel"),
@@ -1301,3 +1299,143 @@ def test_transform_negative_shear_error(self):
def test_transform_unknown_fill_error(self):
with pytest.raises(TypeError, match="Got inappropriate fill arg"):
transforms.RandomAffine(degrees=0, fill="fill")


class TestVerticalFlip:
def _make_input(self, input_type, *, dtype=None, device="cpu", spatial_size=(17, 11), **kwargs):
if input_type in {torch.Tensor, PIL.Image.Image, datapoints.Image}:
input = make_image(size=spatial_size, dtype=dtype or torch.uint8, device=device, **kwargs)
if input_type is torch.Tensor:
input = input.as_subclass(torch.Tensor)
elif input_type is PIL.Image.Image:
input = F.to_image_pil(input)
elif input_type is datapoints.BoundingBox:
kwargs.setdefault("format", datapoints.BoundingBoxFormat.XYXY)
input = make_bounding_box(
dtype=dtype or torch.float32,
device=device,
spatial_size=spatial_size,
**kwargs,
)
elif input_type is datapoints.Mask:
input = make_segmentation_mask(size=spatial_size, dtype=dtype or torch.uint8, device=device, **kwargs)
elif input_type is datapoints.Video:
input = make_video(size=spatial_size, dtype=dtype or torch.uint8, device=device, **kwargs)

return input

@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_image_tensor(self, dtype, device):
check_kernel(F.vertical_flip_image_tensor, self._make_input(torch.Tensor, dtype=dtype, device=device))

@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_kernel_bounding_box(self, format, dtype, device):
bounding_box = self._make_input(datapoints.BoundingBox, dtype=dtype, device=device, format=format)
check_kernel(
F.vertical_flip_bounding_box,
bounding_box,
format=format,
spatial_size=bounding_box.spatial_size,
)

@pytest.mark.parametrize(
"dtype_and_make_mask", [(torch.uint8, make_segmentation_mask), (torch.bool, make_detection_mask)]
)
def test_kernel_mask(self, dtype_and_make_mask):
dtype, make_mask = dtype_and_make_mask
check_kernel(F.vertical_flip_mask, make_mask(dtype=dtype))

def test_kernel_video(self):
check_kernel(F.vertical_flip_video, self._make_input(datapoints.Video))

@pytest.mark.parametrize(
("input_type", "kernel"),
[
(torch.Tensor, F.vertical_flip_image_tensor),
(PIL.Image.Image, F.vertical_flip_image_pil),
(datapoints.Image, F.vertical_flip_image_tensor),
(datapoints.BoundingBox, F.vertical_flip_bounding_box),
(datapoints.Mask, F.vertical_flip_mask),
(datapoints.Video, F.vertical_flip_video),
],
)
def test_dispatcher(self, kernel, input_type):
check_dispatcher(F.vertical_flip, kernel, self._make_input(input_type))

@pytest.mark.parametrize(
("input_type", "kernel"),
[
(torch.Tensor, F.vertical_flip_image_tensor),
(PIL.Image.Image, F.vertical_flip_image_pil),
(datapoints.Image, F.vertical_flip_image_tensor),
(datapoints.BoundingBox, F.vertical_flip_bounding_box),
(datapoints.Mask, F.vertical_flip_mask),
(datapoints.Video, F.vertical_flip_video),
],
)
def test_dispatcher_signature(self, kernel, input_type):
check_dispatcher_signatures_match(F.vertical_flip, kernel=kernel, input_type=input_type)

@pytest.mark.parametrize(
"input_type",
[torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.BoundingBox, datapoints.Mask, datapoints.Video],
)
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_transform(self, input_type, device):
input = self._make_input(input_type, device=device)

check_transform(transforms.RandomVerticalFlip, input, p=1)

@pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)])
def test_image_correctness(self, fn):
image = self._make_input(torch.Tensor, dtype=torch.uint8, device="cpu")

actual = fn(image)
expected = F.to_image_tensor(F.vertical_flip(F.to_image_pil(image)))

torch.testing.assert_close(actual, expected)

def _reference_vertical_flip_bounding_box(self, bounding_box):
affine_matrix = np.array(
[
[1, 0, 0],
[0, -1, bounding_box.spatial_size[0]],
],
dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
)

expected_bboxes = reference_affine_bounding_box_helper(
bounding_box,
format=bounding_box.format,
spatial_size=bounding_box.spatial_size,
affine_matrix=affine_matrix,
)

return datapoints.BoundingBox.wrap_like(bounding_box, expected_bboxes)

@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
@pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)])
def test_bounding_box_correctness(self, format, fn):
bounding_box = self._make_input(datapoints.BoundingBox, format=format)

actual = fn(bounding_box)
expected = self._reference_vertical_flip_bounding_box(bounding_box)

torch.testing.assert_close(actual, expected)

@pytest.mark.parametrize(
"input_type",
[torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.BoundingBox, datapoints.Mask, datapoints.Video],
)
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_transform_noop(self, input_type, device):
input = self._make_input(input_type, device=device)

transform = transforms.RandomVerticalFlip(p=0)

output = transform(input)

assert_equal(output, input)
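The new _reference_vertical_flip_bounding_box expresses the flip as an affine map over box corners: for an image of height H, a vertical flip sends (x, y) to (x, H - y), i.e. the matrix [[1, 0, 0], [0, -1, H]] in homogeneous coordinates. A self-contained sketch of that computation (box values and spatial size are illustrative, and the final min/max step stands in for what reference_affine_bounding_box_helper does over all four corners):

import numpy as np

height = 10                                  # spatial_size[0]; illustrative
box = np.array([0.0, 0.0, 5.0, 5.0])         # XYXY corners; illustrative

affine_matrix = np.array([[1, 0, 0], [0, -1, height]], dtype="float32")
corners = np.array([[box[0], box[1], 1.0], [box[2], box[3], 1.0]])
flipped = corners @ affine_matrix.T          # (x, y) -> (x, height - y) for each corner

# Take per-axis min/max so that (x1, y1) is the top-left corner again.
x1, x2 = sorted(flipped[:, 0])
y1, y2 = sorted(flipped[:, 1])
print([x1, y1, x2, y2])                      # [0.0, 5.0, 5.0, 10.0]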
10 changes: 0 additions & 10 deletions test/transforms_v2_dispatcher_infos.py
@@ -138,16 +138,6 @@ def fill_sequence_needs_broadcast(args_kwargs):


DISPATCHER_INFOS = [
DispatcherInfo(
F.vertical_flip,
kernels={
datapoints.Image: F.vertical_flip_image_tensor,
datapoints.Video: F.vertical_flip_video,
datapoints.BoundingBox: F.vertical_flip_bounding_box,
datapoints.Mask: F.vertical_flip_mask,
},
pil_kernel_info=PILKernelInfo(F.vertical_flip_image_pil, kernel_name="vertical_flip_image_pil"),
),
DispatcherInfo(
F.rotate,
kernels={
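The kernel mapping removed from DISPATCHER_INFOS above is now asserted directly by the parametrized test_dispatcher and test_dispatcher_signature cases in TestVerticalFlip. A rough sketch of the routing those tests cover, assuming the standard v2 imports (shapes are illustrative):

import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

image = datapoints.Image(torch.randint(0, 256, (3, 17, 11), dtype=torch.uint8))
video = datapoints.Video(torch.randint(0, 256, (4, 3, 17, 11), dtype=torch.uint8))

# F.vertical_flip dispatches on the input type, e.g. Image -> vertical_flip_image_tensor
# and Video -> vertical_flip_video, matching the removed entry above.
assert torch.equal(F.vertical_flip(image), F.vertical_flip_image_tensor(image))
assert torch.equal(F.vertical_flip(video), F.vertical_flip_video(video))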
81 changes: 0 additions & 81 deletions test/transforms_v2_kernel_infos.py
@@ -264,87 +264,6 @@ def reference_inputs_convert_format_bounding_box():
)


def sample_inputs_vertical_flip_image_tensor():
for image_loader in make_image_loaders(sizes=["random"], dtypes=[torch.float32]):
yield ArgsKwargs(image_loader)


def reference_inputs_vertical_flip_image_tensor():
for image_loader in make_image_loaders(extra_dims=[()], dtypes=[torch.uint8]):
yield ArgsKwargs(image_loader)


def sample_inputs_vertical_flip_bounding_box():
for bounding_box_loader in make_bounding_box_loaders(
formats=[datapoints.BoundingBoxFormat.XYXY], dtypes=[torch.float32]
):
yield ArgsKwargs(
bounding_box_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size
)


def sample_inputs_vertical_flip_mask():
for image_loader in make_mask_loaders(sizes=["random"], dtypes=[torch.uint8]):
yield ArgsKwargs(image_loader)


def sample_inputs_vertical_flip_video():
for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
yield ArgsKwargs(video_loader)


def reference_vertical_flip_bounding_box(bounding_box, *, format, spatial_size):
affine_matrix = np.array(
[
[1, 0, 0],
[0, -1, spatial_size[0]],
],
dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
)

expected_bboxes = reference_affine_bounding_box_helper(
bounding_box, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix
)

return expected_bboxes


def reference_inputs_vertical_flip_bounding_box():
for bounding_box_loader in make_bounding_box_loaders(extra_dims=[()]):
yield ArgsKwargs(
bounding_box_loader,
format=bounding_box_loader.format,
spatial_size=bounding_box_loader.spatial_size,
)


KERNEL_INFOS.extend(
[
KernelInfo(
F.vertical_flip_image_tensor,
kernel_name="vertical_flip_image_tensor",
sample_inputs_fn=sample_inputs_vertical_flip_image_tensor,
reference_fn=pil_reference_wrapper(F.vertical_flip_image_pil),
reference_inputs_fn=reference_inputs_vertical_flip_image_tensor,
float32_vs_uint8=True,
),
KernelInfo(
F.vertical_flip_bounding_box,
sample_inputs_fn=sample_inputs_vertical_flip_bounding_box,
reference_fn=reference_vertical_flip_bounding_box,
reference_inputs_fn=reference_inputs_vertical_flip_bounding_box,
),
KernelInfo(
F.vertical_flip_mask,
sample_inputs_fn=sample_inputs_vertical_flip_mask,
),
KernelInfo(
F.vertical_flip_video,
sample_inputs_fn=sample_inputs_vertical_flip_video,
),
]
)

_ROTATE_ANGLES = [-87, 15, 90]


3 changes: 2 additions & 1 deletion torchvision/transforms/v2/functional/_geometry.py
@@ -93,7 +93,8 @@ def vertical_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
return image.flip(-2)


-vertical_flip_image_pil = _FP.vflip
+def vertical_flip_image_pil(image: PIL.Image) -> PIL.Image:
+    return _FP.vflip(image)


def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:
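The PIL kernel changes from a bare alias of _FP.vflip to a small wrapper function. One plausible motivation (an assumption; the diff does not state it) is that the refactored tests introspect the kernel itself, e.g. check_dispatcher_signatures_match, and a def carries its own name and signature where the alias would report those of _FP.vflip. A tiny sketch of that difference:

import inspect
from torchvision.transforms.v2 import functional as F

print(F.vertical_flip_image_pil.__name__)   # 'vertical_flip_image_pil'; the old alias reported 'vflip'
print(list(inspect.signature(F.vertical_flip_image_pil).parameters))  # ['image']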
