From ca012d39c6ba265091d9373c8ca00157b933d3e9 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 16 Aug 2023 16:09:13 +0200 Subject: [PATCH 1/3] make PIL kernels private (#7831) --- docs/source/transforms.rst | 3 +- gallery/plot_transforms_v2_e2e.py | 4 +- references/detection/presets.py | 8 +- references/segmentation/presets.py | 6 +- test/common_utils.py | 10 +- test/test_prototype_transforms.py | 4 +- test/test_transforms_v2.py | 29 +-- test/test_transforms_v2_consistency.py | 16 +- test/test_transforms_v2_functional.py | 28 +-- test/test_transforms_v2_refactored.py | 110 +++++----- test/test_transforms_v2_utils.py | 4 +- test/transforms_v2_dispatcher_infos.py | 82 ++++---- test/transforms_v2_kernel_infos.py | 91 ++++---- torchvision/prototype/transforms/_augment.py | 4 +- torchvision/transforms/v2/__init__.py | 2 +- torchvision/transforms/v2/_auto_augment.py | 2 +- torchvision/transforms/v2/_type_conversion.py | 13 +- .../transforms/v2/functional/__init__.py | 128 ++++++------ .../transforms/v2/functional/_augment.py | 10 +- .../transforms/v2/functional/_color.py | 130 ++++++------ .../transforms/v2/functional/_deprecated.py | 2 +- .../transforms/v2/functional/_geometry.py | 196 +++++++++--------- torchvision/transforms/v2/functional/_meta.py | 26 +-- torchvision/transforms/v2/functional/_misc.py | 26 ++- .../v2/functional/_type_conversion.py | 8 +- 25 files changed, 454 insertions(+), 488 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 0df46c92530..6700395717f 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -228,12 +228,11 @@ Conversion ToPILImage v2.ToPILImage - v2.ToImagePIL ToTensor v2.ToTensor PILToTensor v2.PILToTensor - v2.ToImageTensor + v2.ToImage ConvertImageDtype v2.ConvertImageDtype v2.ToDtype diff --git a/gallery/plot_transforms_v2_e2e.py b/gallery/plot_transforms_v2_e2e.py index ccffea766c8..b837b9ba972 100644 --- a/gallery/plot_transforms_v2_e2e.py +++ b/gallery/plot_transforms_v2_e2e.py @@ -27,7 +27,7 @@ def show(sample): image, target = sample if isinstance(image, PIL.Image.Image): - image = F.to_image_tensor(image) + image = F.to_image(image) image = F.to_dtype(image, torch.uint8, scale=True) annotated_image = draw_bounding_boxes(image, target["boxes"], colors="yellow", width=3) @@ -101,7 +101,7 @@ def load_example_coco_detection_dataset(**kwargs): transforms.RandomZoomOut(fill={PIL.Image.Image: (123, 117, 104), "others": 0}), transforms.RandomIoUCrop(), transforms.RandomHorizontalFlip(), - transforms.ToImageTensor(), + transforms.ToImage(), transforms.ConvertImageDtype(torch.float32), transforms.SanitizeBoundingBoxes(), ] diff --git a/references/detection/presets.py b/references/detection/presets.py index 098ec85e690..09ca148a263 100644 --- a/references/detection/presets.py +++ b/references/detection/presets.py @@ -33,7 +33,7 @@ def __init__( transforms = [] backend = backend.lower() if backend == "datapoint": - transforms.append(T.ToImageTensor()) + transforms.append(T.ToImage()) elif backend == "tensor": transforms.append(T.PILToTensor()) elif backend != "pil": @@ -71,7 +71,7 @@ def __init__( if backend == "pil": # Note: we could just convert to pure tensors even in v2. 
- transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()] + transforms += [T.ToImage() if use_v2 else T.PILToTensor()] transforms += [T.ConvertImageDtype(torch.float)] @@ -94,11 +94,11 @@ def __init__(self, backend="pil", use_v2=False): backend = backend.lower() if backend == "pil": # Note: we could just convert to pure tensors even in v2? - transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()] + transforms += [T.ToImage() if use_v2 else T.PILToTensor()] elif backend == "tensor": transforms += [T.PILToTensor()] elif backend == "datapoint": - transforms += [T.ToImageTensor()] + transforms += [T.ToImage()] else: raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}") diff --git a/references/segmentation/presets.py b/references/segmentation/presets.py index e62fd5ae301..755cb236dcb 100644 --- a/references/segmentation/presets.py +++ b/references/segmentation/presets.py @@ -32,7 +32,7 @@ def __init__( transforms = [] backend = backend.lower() if backend == "datapoint": - transforms.append(T.ToImageTensor()) + transforms.append(T.ToImage()) elif backend == "tensor": transforms.append(T.PILToTensor()) elif backend != "pil": @@ -81,7 +81,7 @@ def __init__( if backend == "tensor": transforms += [T.PILToTensor()] elif backend == "datapoint": - transforms += [T.ToImageTensor()] + transforms += [T.ToImage()] elif backend != "pil": raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}") @@ -92,7 +92,7 @@ def __init__( if backend == "pil": # Note: we could just convert to pure tensors even in v2? - transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()] + transforms += [T.ToImage() if use_v2 else T.PILToTensor()] transforms += [ T.ConvertImageDtype(torch.float), diff --git a/test/common_utils.py b/test/common_utils.py index 8d5eb047534..9713901bdcf 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -27,7 +27,7 @@ from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair from torchvision import datapoints, io from torchvision.transforms._functional_tensor import _max_value as get_max_value -from torchvision.transforms.v2.functional import to_dtype_image_tensor, to_image_pil, to_image_tensor +from torchvision.transforms.v2.functional import to_dtype_image, to_image, to_pil_image IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"]) @@ -293,7 +293,7 @@ def __init__( **other_parameters, ): if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]): - actual, expected = [to_image_tensor(input) for input in [actual, expected]] + actual, expected = [to_image(input) for input in [actual, expected]] super().__init__(actual, expected, **other_parameters) self.mae = mae @@ -536,7 +536,7 @@ def make_image_tensor(*args, **kwargs): def make_image_pil(*args, **kwargs): - return to_image_pil(make_image(*args, **kwargs)) + return to_pil_image(make_image(*args, **kwargs)) def make_image_loader( @@ -609,12 +609,12 @@ def fn(shape, dtype, device, memory_format): ) ) - image_tensor = to_image_tensor(image_pil) + image_tensor = to_image(image_pil) if memory_format == torch.contiguous_format: image_tensor = image_tensor.to(device=device, memory_format=memory_format, copy=True) else: image_tensor = image_tensor.to(device=device) - image_tensor = to_dtype_image_tensor(image_tensor, dtype=dtype, scale=True) + image_tensor = to_dtype_image(image_tensor, dtype=dtype, scale=True) return datapoints.Image(image_tensor) diff --git 
a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 43a7df4f3a2..32a68e14017 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -17,7 +17,7 @@ from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video from torchvision.prototype import datapoints, transforms -from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_image_pil +from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_pil_image from torchvision.transforms.v2.utils import check_type, is_simple_tensor BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims] @@ -387,7 +387,7 @@ def make_datapoints(): size = (600, 800) num_objects = 22 - pil_image = to_image_pil(make_image(size=size, color_space="RGB")) + pil_image = to_pil_image(make_image(size=size, color_space="RGB")) target = { "boxes": make_bounding_box(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float), "labels": make_label(extra_dims=(num_objects,), categories=80), diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 4db2abe7fc4..ade3bdf0b51 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -666,19 +666,19 @@ def test_check_transformed_types(self, inpt_type, mocker): t(inpt) -class TestToImageTensor: +class TestToImage: @pytest.mark.parametrize( "inpt_type", [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int], ) def test__transform(self, inpt_type, mocker): fn = mocker.patch( - "torchvision.transforms.v2.functional.to_image_tensor", + "torchvision.transforms.v2.functional.to_image", return_value=torch.rand(1, 3, 8, 8), ) inpt = mocker.MagicMock(spec=inpt_type) - transform = transforms.ToImageTensor() + transform = transforms.ToImage() transform(inpt) if inpt_type in (datapoints.BoundingBoxes, datapoints.Image, str, int): assert fn.call_count == 0 @@ -686,30 +686,13 @@ def test__transform(self, inpt_type, mocker): fn.assert_called_once_with(inpt) -class TestToImagePIL: - @pytest.mark.parametrize( - "inpt_type", - [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int], - ) - def test__transform(self, inpt_type, mocker): - fn = mocker.patch("torchvision.transforms.v2.functional.to_image_pil") - - inpt = mocker.MagicMock(spec=inpt_type) - transform = transforms.ToImagePIL() - transform(inpt) - if inpt_type in (datapoints.BoundingBoxes, PIL.Image.Image, str, int): - assert fn.call_count == 0 - else: - fn.assert_called_once_with(inpt, mode=transform.mode) - - class TestToPILImage: @pytest.mark.parametrize( "inpt_type", [torch.Tensor, PIL.Image.Image, datapoints.Image, np.ndarray, datapoints.BoundingBoxes, str, int], ) def test__transform(self, inpt_type, mocker): - fn = mocker.patch("torchvision.transforms.v2.functional.to_image_pil") + fn = mocker.patch("torchvision.transforms.v2.functional.to_pil_image") inpt = mocker.MagicMock(spec=inpt_type) transform = transforms.ToPILImage() @@ -1013,7 +996,7 @@ def test_antialias_warning(): @pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image)) @pytest.mark.parametrize("label_type", (torch.Tensor, int)) @pytest.mark.parametrize("dataset_return_type", (dict, tuple)) -@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImageTensor)) +@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, 
transforms.ToImage)) def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor): image = datapoints.Image(torch.randint(0, 256, size=(1, 3, 250, 250), dtype=torch.uint8)) @@ -1074,7 +1057,7 @@ def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor): @pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image)) @pytest.mark.parametrize("data_augmentation", ("hflip", "lsj", "multiscale", "ssd", "ssdlite")) -@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImageTensor)) +@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImage)) @pytest.mark.parametrize("sanitize", (True, False)) def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize): torch.manual_seed(0) diff --git a/test/test_transforms_v2_consistency.py b/test/test_transforms_v2_consistency.py index bcab4355c54..5855fbe447f 100644 --- a/test/test_transforms_v2_consistency.py +++ b/test/test_transforms_v2_consistency.py @@ -30,7 +30,7 @@ from torchvision.transforms import functional as legacy_F from torchvision.transforms.v2 import functional as prototype_F from torchvision.transforms.v2._utils import _get_fill -from torchvision.transforms.v2.functional import to_image_pil +from torchvision.transforms.v2.functional import to_pil_image from torchvision.transforms.v2.utils import query_size DEFAULT_MAKE_IMAGES_KWARGS = dict(color_spaces=["RGB"], extra_dims=[(4,)]) @@ -630,7 +630,7 @@ def check_call_consistency( ) if image.ndim == 3 and supports_pil: - image_pil = to_image_pil(image) + image_pil = to_pil_image(image) try: torch.manual_seed(0) @@ -869,7 +869,7 @@ def test_pil_to_tensor(self): legacy_transform = legacy_transforms.PILToTensor() for image in make_images(extra_dims=[()]): - image_pil = to_image_pil(image) + image_pil = to_pil_image(image) assert_equal(prototype_transform(image_pil), legacy_transform(image_pil)) @@ -879,7 +879,7 @@ def test_to_tensor(self): legacy_transform = legacy_transforms.ToTensor() for image in make_images(extra_dims=[()]): - image_pil = to_image_pil(image) + image_pil = to_pil_image(image) image_numpy = np.array(image_pil) assert_equal(prototype_transform(image_pil), legacy_transform(image_pil)) @@ -1088,7 +1088,7 @@ def make_datapoints(self, with_mask=True): def make_label(extra_dims, categories): return torch.randint(categories, extra_dims, dtype=torch.int64) - pil_image = to_image_pil(make_image(size=size, color_space="RGB")) + pil_image = to_pil_image(make_image(size=size, color_space="RGB")) target = { "boxes": make_bounding_box(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float), "labels": make_label(extra_dims=(num_objects,), categories=80), @@ -1192,7 +1192,7 @@ def make_datapoints(self, supports_pil=True, image_dtype=torch.uint8): conv_fns = [] if supports_pil: - conv_fns.append(to_image_pil) + conv_fns.append(to_pil_image) conv_fns.extend([torch.Tensor, lambda x: x]) for conv_fn in conv_fns: @@ -1201,8 +1201,8 @@ def make_datapoints(self, supports_pil=True, image_dtype=torch.uint8): dp = (conv_fn(datapoint_image), datapoint_mask) dp_ref = ( - to_image_pil(datapoint_image) if supports_pil else datapoint_image.as_subclass(torch.Tensor), - to_image_pil(datapoint_mask), + to_pil_image(datapoint_image) if supports_pil else datapoint_image.as_subclass(torch.Tensor), + to_pil_image(datapoint_mask), ) yield dp, dp_ref diff --git a/test/test_transforms_v2_functional.py b/test/test_transforms_v2_functional.py index bf447c8ce71..14a1f82b2cf 100644 
--- a/test/test_transforms_v2_functional.py +++ b/test/test_transforms_v2_functional.py @@ -280,12 +280,12 @@ def test_float32_vs_uint8(self, test_id, info, args_kwargs): adapted_other_args, adapted_kwargs = info.float32_vs_uint8(other_args, kwargs) actual = info.kernel( - F.to_dtype_image_tensor(input, dtype=torch.float32, scale=True), + F.to_dtype_image(input, dtype=torch.float32, scale=True), *adapted_other_args, **adapted_kwargs, ) - expected = F.to_dtype_image_tensor(info.kernel(input, *other_args, **kwargs), dtype=torch.float32, scale=True) + expected = F.to_dtype_image(info.kernel(input, *other_args, **kwargs), dtype=torch.float32, scale=True) assert_close( actual, @@ -377,7 +377,7 @@ def test_pil_output_type(self, info, args_kwargs): if image_datapoint.ndim > 3: pytest.skip("Input is batched") - image_pil = F.to_image_pil(image_datapoint) + image_pil = F.to_pil_image(image_datapoint) output = info.dispatcher(image_pil, *other_args, **kwargs) @@ -470,7 +470,7 @@ def test_bounding_boxes_format_consistency(self, info, args_kwargs): (F.hflip, F.horizontal_flip), (F.vflip, F.vertical_flip), (F.get_image_num_channels, F.get_num_channels), - (F.to_pil_image, F.to_image_pil), + (F.to_pil_image, F.to_pil_image), (F.elastic_transform, F.elastic), (F.to_grayscale, F.rgb_to_grayscale), ] @@ -493,7 +493,7 @@ def assert_samples_from_standard_normal(t): mean = image.mean(dim=(1, 2)).tolist() std = image.std(dim=(1, 2)).tolist() - assert_samples_from_standard_normal(F.normalize_image_tensor(image, mean, std)) + assert_samples_from_standard_normal(F.normalize_image(image, mean, std)) class TestClampBoundingBoxes: @@ -899,7 +899,7 @@ def _compute_expected_mask(mask, output_size): _, image_height, image_width = mask.shape if crop_width > image_height or crop_height > image_width: padding = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width) - mask = F.pad_image_tensor(mask, padding, fill=0) + mask = F.pad_image(mask, padding, fill=0) left = round((image_width - crop_width) * 0.5) top = round((image_height - crop_height) * 0.5) @@ -920,7 +920,7 @@ def _compute_expected_mask(mask, output_size): @pytest.mark.parametrize("ksize", [(3, 3), [3, 5], (23, 23)]) @pytest.mark.parametrize("sigma", [[0.5, 0.5], (0.5, 0.5), (0.8, 0.8), (1.7, 1.7)]) def test_correctness_gaussian_blur_image_tensor(device, canvas_size, dt, ksize, sigma): - fn = F.gaussian_blur_image_tensor + fn = F.gaussian_blur_image # true_cv2_results = { # # np_img = np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3)) @@ -977,8 +977,8 @@ def test_correctness_gaussian_blur_image_tensor(device, canvas_size, dt, ksize, PIL.Image.new("RGB", (32, 32), 122), ], ) -def test_to_image_tensor(inpt): - output = F.to_image_tensor(inpt) +def test_to_image(inpt): + output = F.to_image(inpt) assert isinstance(output, torch.Tensor) assert output.shape == (3, 32, 32) @@ -993,8 +993,8 @@ def test_to_image_tensor(inpt): ], ) @pytest.mark.parametrize("mode", [None, "RGB"]) -def test_to_image_pil(inpt, mode): - output = F.to_image_pil(inpt, mode=mode) +def test_to_pil_image(inpt, mode): + output = F.to_pil_image(inpt, mode=mode) assert isinstance(output, PIL.Image.Image) assert np.asarray(inpt).sum() == np.asarray(output).sum() @@ -1002,12 +1002,12 @@ def test_to_image_pil(inpt, mode): def test_equalize_image_tensor_edge_cases(): inpt = torch.zeros(3, 200, 200, dtype=torch.uint8) - output = F.equalize_image_tensor(inpt) + output = F.equalize_image(inpt) torch.testing.assert_close(inpt, output) inpt = torch.zeros(5, 3, 200, 200, 
dtype=torch.uint8) inpt[..., 100:, 100:] = 1 - output = F.equalize_image_tensor(inpt) + output = F.equalize_image(inpt) assert output.unique().tolist() == [0, 255] @@ -1024,7 +1024,7 @@ def test_correctness_uniform_temporal_subsample(device): # TODO: We can remove this test and related torchvision workaround # once we fixed related pytorch issue: https://github.com/pytorch/pytorch/issues/68430 @make_info_args_kwargs_parametrization( - [info for info in KERNEL_INFOS if info.kernel is F.resize_image_tensor], + [info for info in KERNEL_INFOS if info.kernel is F.resize_image], args_kwargs_fn=lambda info: info.reference_inputs_fn(), ) def test_memory_format_consistency_resize_image_tensor(test_id, info, args_kwargs): diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index c83327a069e..9d359e59559 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -437,7 +437,7 @@ def test_kernel_image_tensor(self, size, interpolation, use_max_size, antialias, check_cuda_vs_cpu_tolerances = dict(rtol=0, atol=atol / 255 if dtype.is_floating_point else atol) check_kernel( - F.resize_image_tensor, + F.resize_image, make_image(self.INPUT_SIZE, dtype=dtype, device=device), size=size, interpolation=interpolation, @@ -495,9 +495,9 @@ def test_functional(self, size, make_input): @pytest.mark.parametrize( ("kernel", "input_type"), [ - (F.resize_image_tensor, torch.Tensor), - (F.resize_image_pil, PIL.Image.Image), - (F.resize_image_tensor, datapoints.Image), + (F.resize_image, torch.Tensor), + (F._resize_image_pil, PIL.Image.Image), + (F.resize_image, datapoints.Image), (F.resize_bounding_boxes, datapoints.BoundingBoxes), (F.resize_mask, datapoints.Mask), (F.resize_video, datapoints.Video), @@ -541,9 +541,7 @@ def test_image_correctness(self, size, interpolation, use_max_size, fn): image = make_image(self.INPUT_SIZE, dtype=torch.uint8) actual = fn(image, size=size, interpolation=interpolation, **max_size_kwarg, antialias=True) - expected = F.to_image_tensor( - F.resize(F.to_image_pil(image), size=size, interpolation=interpolation, **max_size_kwarg) - ) + expected = F.to_image(F.resize(F.to_pil_image(image), size=size, interpolation=interpolation, **max_size_kwarg)) self._check_output_size(image, actual, size=size, **max_size_kwarg) torch.testing.assert_close(actual, expected, atol=1, rtol=0) @@ -739,7 +737,7 @@ class TestHorizontalFlip: @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8]) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_kernel_image_tensor(self, dtype, device): - check_kernel(F.horizontal_flip_image_tensor, make_image(dtype=dtype, device=device)) + check_kernel(F.horizontal_flip_image, make_image(dtype=dtype, device=device)) @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat)) @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) @@ -770,9 +768,9 @@ def test_functional(self, make_input): @pytest.mark.parametrize( ("kernel", "input_type"), [ - (F.horizontal_flip_image_tensor, torch.Tensor), - (F.horizontal_flip_image_pil, PIL.Image.Image), - (F.horizontal_flip_image_tensor, datapoints.Image), + (F.horizontal_flip_image, torch.Tensor), + (F._horizontal_flip_image_pil, PIL.Image.Image), + (F.horizontal_flip_image, datapoints.Image), (F.horizontal_flip_bounding_boxes, datapoints.BoundingBoxes), (F.horizontal_flip_mask, datapoints.Mask), (F.horizontal_flip_video, datapoints.Video), @@ -796,7 +794,7 @@ def test_image_correctness(self, fn): image = 
make_image(dtype=torch.uint8, device="cpu") actual = fn(image) - expected = F.to_image_tensor(F.horizontal_flip(F.to_image_pil(image))) + expected = F.to_image(F.horizontal_flip(F.to_pil_image(image))) torch.testing.assert_close(actual, expected) @@ -900,7 +898,7 @@ def test_kernel_image_tensor(self, param, value, dtype, device): if param == "fill": value = adapt_fill(value, dtype=dtype) self._check_kernel( - F.affine_image_tensor, + F.affine_image, make_image(dtype=dtype, device=device), **{param: value}, check_scripted_vs_eager=not (param in {"shear", "fill"} and isinstance(value, (int, float))), @@ -946,9 +944,9 @@ def test_functional(self, make_input): @pytest.mark.parametrize( ("kernel", "input_type"), [ - (F.affine_image_tensor, torch.Tensor), - (F.affine_image_pil, PIL.Image.Image), - (F.affine_image_tensor, datapoints.Image), + (F.affine_image, torch.Tensor), + (F._affine_image_pil, PIL.Image.Image), + (F.affine_image, datapoints.Image), (F.affine_bounding_boxes, datapoints.BoundingBoxes), (F.affine_mask, datapoints.Mask), (F.affine_video, datapoints.Video), @@ -991,9 +989,9 @@ def test_functional_image_correctness(self, angle, translate, scale, shear, cent interpolation=interpolation, fill=fill, ) - expected = F.to_image_tensor( + expected = F.to_image( F.affine( - F.to_image_pil(image), + F.to_pil_image(image), angle=angle, translate=translate, scale=scale, @@ -1026,7 +1024,7 @@ def test_transform_image_correctness(self, center, interpolation, fill, seed): actual = transform(image) torch.manual_seed(seed) - expected = F.to_image_tensor(transform(F.to_image_pil(image))) + expected = F.to_image(transform(F.to_pil_image(image))) mae = (actual.float() - expected.float()).abs().mean() assert mae < 2 if interpolation is transforms.InterpolationMode.NEAREST else 8 @@ -1204,7 +1202,7 @@ class TestVerticalFlip: @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8]) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_kernel_image_tensor(self, dtype, device): - check_kernel(F.vertical_flip_image_tensor, make_image(dtype=dtype, device=device)) + check_kernel(F.vertical_flip_image, make_image(dtype=dtype, device=device)) @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat)) @pytest.mark.parametrize("dtype", [torch.float32, torch.int64]) @@ -1235,9 +1233,9 @@ def test_functional(self, make_input): @pytest.mark.parametrize( ("kernel", "input_type"), [ - (F.vertical_flip_image_tensor, torch.Tensor), - (F.vertical_flip_image_pil, PIL.Image.Image), - (F.vertical_flip_image_tensor, datapoints.Image), + (F.vertical_flip_image, torch.Tensor), + (F._vertical_flip_image_pil, PIL.Image.Image), + (F.vertical_flip_image, datapoints.Image), (F.vertical_flip_bounding_boxes, datapoints.BoundingBoxes), (F.vertical_flip_mask, datapoints.Mask), (F.vertical_flip_video, datapoints.Video), @@ -1259,7 +1257,7 @@ def test_image_correctness(self, fn): image = make_image(dtype=torch.uint8, device="cpu") actual = fn(image) - expected = F.to_image_tensor(F.vertical_flip(F.to_image_pil(image))) + expected = F.to_image(F.vertical_flip(F.to_pil_image(image))) torch.testing.assert_close(actual, expected) @@ -1339,7 +1337,7 @@ def test_kernel_image_tensor(self, param, value, dtype, device): if param != "angle": kwargs["angle"] = self._MINIMAL_AFFINE_KWARGS["angle"] check_kernel( - F.rotate_image_tensor, + F.rotate_image, make_image(dtype=dtype, device=device), **kwargs, check_scripted_vs_eager=not (param == "fill" and isinstance(value, (int, float))), @@ -1385,9 +1383,9 @@ def 
test_functional(self, make_input): @pytest.mark.parametrize( ("kernel", "input_type"), [ - (F.rotate_image_tensor, torch.Tensor), - (F.rotate_image_pil, PIL.Image.Image), - (F.rotate_image_tensor, datapoints.Image), + (F.rotate_image, torch.Tensor), + (F._rotate_image_pil, PIL.Image.Image), + (F.rotate_image, datapoints.Image), (F.rotate_bounding_boxes, datapoints.BoundingBoxes), (F.rotate_mask, datapoints.Mask), (F.rotate_video, datapoints.Video), @@ -1419,9 +1417,9 @@ def test_functional_image_correctness(self, angle, center, interpolation, expand fill = adapt_fill(fill, dtype=torch.uint8) actual = F.rotate(image, angle=angle, center=center, interpolation=interpolation, expand=expand, fill=fill) - expected = F.to_image_tensor( + expected = F.to_image( F.rotate( - F.to_image_pil(image), angle=angle, center=center, interpolation=interpolation, expand=expand, fill=fill + F.to_pil_image(image), angle=angle, center=center, interpolation=interpolation, expand=expand, fill=fill ) ) @@ -1452,7 +1450,7 @@ def test_transform_image_correctness(self, center, interpolation, expand, fill, actual = transform(image) torch.manual_seed(seed) - expected = F.to_image_tensor(transform(F.to_image_pil(image))) + expected = F.to_image(transform(F.to_pil_image(image))) mae = (actual.float() - expected.float()).abs().mean() assert mae < 1 if interpolation is transforms.InterpolationMode.NEAREST else 6 @@ -1621,8 +1619,8 @@ class TestToDtype: @pytest.mark.parametrize( ("kernel", "make_input"), [ - (F.to_dtype_image_tensor, make_image_tensor), - (F.to_dtype_image_tensor, make_image), + (F.to_dtype_image, make_image_tensor), + (F.to_dtype_image, make_image), (F.to_dtype_video, make_video), ], ) @@ -1801,7 +1799,7 @@ class TestAdjustBrightness: @pytest.mark.parametrize( ("kernel", "make_input"), [ - (F.adjust_brightness_image_tensor, make_image), + (F.adjust_brightness_image, make_image), (F.adjust_brightness_video, make_video), ], ) @@ -1817,9 +1815,9 @@ def test_functional(self, make_input): @pytest.mark.parametrize( ("kernel", "input_type"), [ - (F.adjust_brightness_image_tensor, torch.Tensor), - (F.adjust_brightness_image_pil, PIL.Image.Image), - (F.adjust_brightness_image_tensor, datapoints.Image), + (F.adjust_brightness_image, torch.Tensor), + (F._adjust_brightness_image_pil, PIL.Image.Image), + (F.adjust_brightness_image, datapoints.Image), (F.adjust_brightness_video, datapoints.Video), ], ) @@ -1831,7 +1829,7 @@ def test_image_correctness(self, brightness_factor): image = make_image(dtype=torch.uint8, device="cpu") actual = F.adjust_brightness(image, brightness_factor=brightness_factor) - expected = F.to_image_tensor(F.adjust_brightness(F.to_image_pil(image), brightness_factor=brightness_factor)) + expected = F.to_image(F.adjust_brightness(F.to_pil_image(image), brightness_factor=brightness_factor)) torch.testing.assert_close(actual, expected) @@ -1979,9 +1977,9 @@ class TestShapeGetters: @pytest.mark.parametrize( ("kernel", "make_input"), [ - (F.get_dimensions_image_tensor, make_image_tensor), - (F.get_dimensions_image_pil, make_image_pil), - (F.get_dimensions_image_tensor, make_image), + (F.get_dimensions_image, make_image_tensor), + (F._get_dimensions_image_pil, make_image_pil), + (F.get_dimensions_image, make_image), (F.get_dimensions_video, make_video), ], ) @@ -1996,9 +1994,9 @@ def test_get_dimensions(self, kernel, make_input): @pytest.mark.parametrize( ("kernel", "make_input"), [ - (F.get_num_channels_image_tensor, make_image_tensor), - (F.get_num_channels_image_pil, make_image_pil), - 
(F.get_num_channels_image_tensor, make_image), + (F.get_num_channels_image, make_image_tensor), + (F._get_num_channels_image_pil, make_image_pil), + (F.get_num_channels_image, make_image), (F.get_num_channels_video, make_video), ], ) @@ -2012,9 +2010,9 @@ def test_get_num_channels(self, kernel, make_input): @pytest.mark.parametrize( ("kernel", "make_input"), [ - (F.get_size_image_tensor, make_image_tensor), - (F.get_size_image_pil, make_image_pil), - (F.get_size_image_tensor, make_image), + (F.get_size_image, make_image_tensor), + (F._get_size_image_pil, make_image_pil), + (F.get_size_image, make_image), (F.get_size_bounding_boxes, make_bounding_box), (F.get_size_mask, make_detection_mask), (F.get_size_mask, make_segmentation_mask), @@ -2101,7 +2099,7 @@ def test_errors(self): F.register_kernel(F.resize, object) with pytest.raises(ValueError, match="cannot be registered for the builtin datapoint classes"): - F.register_kernel(F.resize, datapoints.Image)(F.resize_image_tensor) + F.register_kernel(F.resize, datapoints.Image)(F.resize_image) class CustomDatapoint(datapoints.Datapoint): pass @@ -2119,9 +2117,9 @@ class TestGetKernel: # We are using F.resize as functional and the kernels below as proxy. Any other functional / kernels combination # would also be fine KERNELS = { - torch.Tensor: F.resize_image_tensor, - PIL.Image.Image: F.resize_image_pil, - datapoints.Image: F.resize_image_tensor, + torch.Tensor: F.resize_image, + PIL.Image.Image: F._resize_image_pil, + datapoints.Image: F.resize_image, datapoints.BoundingBoxes: F.resize_bounding_boxes, datapoints.Mask: F.resize_mask, datapoints.Video: F.resize_video, @@ -2217,10 +2215,10 @@ class TestPermuteChannels: @pytest.mark.parametrize( ("kernel", "make_input"), [ - (F.permute_channels_image_tensor, make_image_tensor), + (F.permute_channels_image, make_image_tensor), # FIXME # check_kernel does not support PIL kernel, but it should - (F.permute_channels_image_tensor, make_image), + (F.permute_channels_image, make_image), (F.permute_channels_video, make_video), ], ) @@ -2236,9 +2234,9 @@ def test_functional(self, make_input): @pytest.mark.parametrize( ("kernel", "input_type"), [ - (F.permute_channels_image_tensor, torch.Tensor), - (F.permute_channels_image_pil, PIL.Image.Image), - (F.permute_channels_image_tensor, datapoints.Image), + (F.permute_channels_image, torch.Tensor), + (F._permute_channels_image_pil, PIL.Image.Image), + (F.permute_channels_image, datapoints.Image), (F.permute_channels_video, datapoints.Video), ], ) diff --git a/test/test_transforms_v2_utils.py b/test/test_transforms_v2_utils.py index f880dac6c67..0cf7a77ac0d 100644 --- a/test/test_transforms_v2_utils.py +++ b/test/test_transforms_v2_utils.py @@ -7,7 +7,7 @@ from common_utils import DEFAULT_SIZE, make_bounding_box, make_detection_mask, make_image from torchvision import datapoints -from torchvision.transforms.v2.functional import to_image_pil +from torchvision.transforms.v2.functional import to_pil_image from torchvision.transforms.v2.utils import has_all, has_any @@ -44,7 +44,7 @@ True, ), ( - (to_image_pil(IMAGE),), + (to_pil_image(IMAGE),), (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_simple_tensor), True, ), diff --git a/test/transforms_v2_dispatcher_infos.py b/test/transforms_v2_dispatcher_infos.py index cef5c360430..8f212c850cb 100644 --- a/test/transforms_v2_dispatcher_infos.py +++ b/test/transforms_v2_dispatcher_infos.py @@ -142,32 +142,32 @@ def fill_sequence_needs_broadcast(args_kwargs): DispatcherInfo( F.crop, kernels={ - 
datapoints.Image: F.crop_image_tensor, + datapoints.Image: F.crop_image, datapoints.Video: F.crop_video, datapoints.BoundingBoxes: F.crop_bounding_boxes, datapoints.Mask: F.crop_mask, }, - pil_kernel_info=PILKernelInfo(F.crop_image_pil, kernel_name="crop_image_pil"), + pil_kernel_info=PILKernelInfo(F._crop_image_pil, kernel_name="crop_image_pil"), ), DispatcherInfo( F.resized_crop, kernels={ - datapoints.Image: F.resized_crop_image_tensor, + datapoints.Image: F.resized_crop_image, datapoints.Video: F.resized_crop_video, datapoints.BoundingBoxes: F.resized_crop_bounding_boxes, datapoints.Mask: F.resized_crop_mask, }, - pil_kernel_info=PILKernelInfo(F.resized_crop_image_pil), + pil_kernel_info=PILKernelInfo(F._resized_crop_image_pil), ), DispatcherInfo( F.pad, kernels={ - datapoints.Image: F.pad_image_tensor, + datapoints.Image: F.pad_image, datapoints.Video: F.pad_video, datapoints.BoundingBoxes: F.pad_bounding_boxes, datapoints.Mask: F.pad_mask, }, - pil_kernel_info=PILKernelInfo(F.pad_image_pil, kernel_name="pad_image_pil"), + pil_kernel_info=PILKernelInfo(F._pad_image_pil, kernel_name="pad_image_pil"), test_marks=[ *xfails_pil( reason=( @@ -184,12 +184,12 @@ def fill_sequence_needs_broadcast(args_kwargs): DispatcherInfo( F.perspective, kernels={ - datapoints.Image: F.perspective_image_tensor, + datapoints.Image: F.perspective_image, datapoints.Video: F.perspective_video, datapoints.BoundingBoxes: F.perspective_bounding_boxes, datapoints.Mask: F.perspective_mask, }, - pil_kernel_info=PILKernelInfo(F.perspective_image_pil), + pil_kernel_info=PILKernelInfo(F._perspective_image_pil), test_marks=[ *xfails_pil_if_fill_sequence_needs_broadcast, xfail_jit_python_scalar_arg("fill"), @@ -198,23 +198,23 @@ def fill_sequence_needs_broadcast(args_kwargs): DispatcherInfo( F.elastic, kernels={ - datapoints.Image: F.elastic_image_tensor, + datapoints.Image: F.elastic_image, datapoints.Video: F.elastic_video, datapoints.BoundingBoxes: F.elastic_bounding_boxes, datapoints.Mask: F.elastic_mask, }, - pil_kernel_info=PILKernelInfo(F.elastic_image_pil), + pil_kernel_info=PILKernelInfo(F._elastic_image_pil), test_marks=[xfail_jit_python_scalar_arg("fill")], ), DispatcherInfo( F.center_crop, kernels={ - datapoints.Image: F.center_crop_image_tensor, + datapoints.Image: F.center_crop_image, datapoints.Video: F.center_crop_video, datapoints.BoundingBoxes: F.center_crop_bounding_boxes, datapoints.Mask: F.center_crop_mask, }, - pil_kernel_info=PILKernelInfo(F.center_crop_image_pil), + pil_kernel_info=PILKernelInfo(F._center_crop_image_pil), test_marks=[ xfail_jit_python_scalar_arg("output_size"), ], @@ -222,10 +222,10 @@ def fill_sequence_needs_broadcast(args_kwargs): DispatcherInfo( F.gaussian_blur, kernels={ - datapoints.Image: F.gaussian_blur_image_tensor, + datapoints.Image: F.gaussian_blur_image, datapoints.Video: F.gaussian_blur_video, }, - pil_kernel_info=PILKernelInfo(F.gaussian_blur_image_pil), + pil_kernel_info=PILKernelInfo(F._gaussian_blur_image_pil), test_marks=[ xfail_jit_python_scalar_arg("kernel_size"), xfail_jit_python_scalar_arg("sigma"), @@ -234,58 +234,58 @@ def fill_sequence_needs_broadcast(args_kwargs): DispatcherInfo( F.equalize, kernels={ - datapoints.Image: F.equalize_image_tensor, + datapoints.Image: F.equalize_image, datapoints.Video: F.equalize_video, }, - pil_kernel_info=PILKernelInfo(F.equalize_image_pil, kernel_name="equalize_image_pil"), + pil_kernel_info=PILKernelInfo(F._equalize_image_pil, kernel_name="equalize_image_pil"), ), DispatcherInfo( F.invert, kernels={ - datapoints.Image: 
F.invert_image_tensor, + datapoints.Image: F.invert_image, datapoints.Video: F.invert_video, }, - pil_kernel_info=PILKernelInfo(F.invert_image_pil, kernel_name="invert_image_pil"), + pil_kernel_info=PILKernelInfo(F._invert_image_pil, kernel_name="invert_image_pil"), ), DispatcherInfo( F.posterize, kernels={ - datapoints.Image: F.posterize_image_tensor, + datapoints.Image: F.posterize_image, datapoints.Video: F.posterize_video, }, - pil_kernel_info=PILKernelInfo(F.posterize_image_pil, kernel_name="posterize_image_pil"), + pil_kernel_info=PILKernelInfo(F._posterize_image_pil, kernel_name="posterize_image_pil"), ), DispatcherInfo( F.solarize, kernels={ - datapoints.Image: F.solarize_image_tensor, + datapoints.Image: F.solarize_image, datapoints.Video: F.solarize_video, }, - pil_kernel_info=PILKernelInfo(F.solarize_image_pil, kernel_name="solarize_image_pil"), + pil_kernel_info=PILKernelInfo(F._solarize_image_pil, kernel_name="solarize_image_pil"), ), DispatcherInfo( F.autocontrast, kernels={ - datapoints.Image: F.autocontrast_image_tensor, + datapoints.Image: F.autocontrast_image, datapoints.Video: F.autocontrast_video, }, - pil_kernel_info=PILKernelInfo(F.autocontrast_image_pil, kernel_name="autocontrast_image_pil"), + pil_kernel_info=PILKernelInfo(F._autocontrast_image_pil, kernel_name="autocontrast_image_pil"), ), DispatcherInfo( F.adjust_sharpness, kernels={ - datapoints.Image: F.adjust_sharpness_image_tensor, + datapoints.Image: F.adjust_sharpness_image, datapoints.Video: F.adjust_sharpness_video, }, - pil_kernel_info=PILKernelInfo(F.adjust_sharpness_image_pil, kernel_name="adjust_sharpness_image_pil"), + pil_kernel_info=PILKernelInfo(F._adjust_sharpness_image_pil, kernel_name="adjust_sharpness_image_pil"), ), DispatcherInfo( F.erase, kernels={ - datapoints.Image: F.erase_image_tensor, + datapoints.Image: F.erase_image, datapoints.Video: F.erase_video, }, - pil_kernel_info=PILKernelInfo(F.erase_image_pil), + pil_kernel_info=PILKernelInfo(F._erase_image_pil), test_marks=[ skip_dispatch_datapoint, ], @@ -293,42 +293,42 @@ def fill_sequence_needs_broadcast(args_kwargs): DispatcherInfo( F.adjust_contrast, kernels={ - datapoints.Image: F.adjust_contrast_image_tensor, + datapoints.Image: F.adjust_contrast_image, datapoints.Video: F.adjust_contrast_video, }, - pil_kernel_info=PILKernelInfo(F.adjust_contrast_image_pil, kernel_name="adjust_contrast_image_pil"), + pil_kernel_info=PILKernelInfo(F._adjust_contrast_image_pil, kernel_name="adjust_contrast_image_pil"), ), DispatcherInfo( F.adjust_gamma, kernels={ - datapoints.Image: F.adjust_gamma_image_tensor, + datapoints.Image: F.adjust_gamma_image, datapoints.Video: F.adjust_gamma_video, }, - pil_kernel_info=PILKernelInfo(F.adjust_gamma_image_pil, kernel_name="adjust_gamma_image_pil"), + pil_kernel_info=PILKernelInfo(F._adjust_gamma_image_pil, kernel_name="adjust_gamma_image_pil"), ), DispatcherInfo( F.adjust_hue, kernels={ - datapoints.Image: F.adjust_hue_image_tensor, + datapoints.Image: F.adjust_hue_image, datapoints.Video: F.adjust_hue_video, }, - pil_kernel_info=PILKernelInfo(F.adjust_hue_image_pil, kernel_name="adjust_hue_image_pil"), + pil_kernel_info=PILKernelInfo(F._adjust_hue_image_pil, kernel_name="adjust_hue_image_pil"), ), DispatcherInfo( F.adjust_saturation, kernels={ - datapoints.Image: F.adjust_saturation_image_tensor, + datapoints.Image: F.adjust_saturation_image, datapoints.Video: F.adjust_saturation_video, }, - pil_kernel_info=PILKernelInfo(F.adjust_saturation_image_pil, kernel_name="adjust_saturation_image_pil"), + 
pil_kernel_info=PILKernelInfo(F._adjust_saturation_image_pil, kernel_name="adjust_saturation_image_pil"), ), DispatcherInfo( F.five_crop, kernels={ - datapoints.Image: F.five_crop_image_tensor, + datapoints.Image: F.five_crop_image, datapoints.Video: F.five_crop_video, }, - pil_kernel_info=PILKernelInfo(F.five_crop_image_pil), + pil_kernel_info=PILKernelInfo(F._five_crop_image_pil), test_marks=[ xfail_jit_python_scalar_arg("size"), *multi_crop_skips, @@ -337,19 +337,19 @@ def fill_sequence_needs_broadcast(args_kwargs): DispatcherInfo( F.ten_crop, kernels={ - datapoints.Image: F.ten_crop_image_tensor, + datapoints.Image: F.ten_crop_image, datapoints.Video: F.ten_crop_video, }, test_marks=[ xfail_jit_python_scalar_arg("size"), *multi_crop_skips, ], - pil_kernel_info=PILKernelInfo(F.ten_crop_image_pil), + pil_kernel_info=PILKernelInfo(F._ten_crop_image_pil), ), DispatcherInfo( F.normalize, kernels={ - datapoints.Image: F.normalize_image_tensor, + datapoints.Image: F.normalize_image, datapoints.Video: F.normalize_video, }, test_marks=[ diff --git a/test/transforms_v2_kernel_infos.py b/test/transforms_v2_kernel_infos.py index ac5651d3217..acb9a857750 100644 --- a/test/transforms_v2_kernel_infos.py +++ b/test/transforms_v2_kernel_infos.py @@ -122,12 +122,12 @@ def wrapper(input_tensor, *other_args, **kwargs): f"Can only test single tensor images against PIL, but input has shape {input_tensor.shape}" ) - input_pil = F.to_image_pil(input_tensor) + input_pil = F.to_pil_image(input_tensor) output_pil = pil_kernel(input_pil, *other_args, **kwargs) if not isinstance(output_pil, PIL.Image.Image): return output_pil - output_tensor = F.to_image_tensor(output_pil) + output_tensor = F.to_image(output_pil) # 2D mask shenanigans if output_tensor.ndim == 2 and input_tensor.ndim == 3: @@ -331,10 +331,10 @@ def reference_inputs_crop_bounding_boxes(): KERNEL_INFOS.extend( [ KernelInfo( - F.crop_image_tensor, + F.crop_image, kernel_name="crop_image_tensor", sample_inputs_fn=sample_inputs_crop_image_tensor, - reference_fn=pil_reference_wrapper(F.crop_image_pil), + reference_fn=pil_reference_wrapper(F._crop_image_pil), reference_inputs_fn=reference_inputs_crop_image_tensor, float32_vs_uint8=True, ), @@ -347,7 +347,7 @@ def reference_inputs_crop_bounding_boxes(): KernelInfo( F.crop_mask, sample_inputs_fn=sample_inputs_crop_mask, - reference_fn=pil_reference_wrapper(F.crop_image_pil), + reference_fn=pil_reference_wrapper(F._crop_image_pil), reference_inputs_fn=reference_inputs_crop_mask, float32_vs_uint8=True, ), @@ -373,7 +373,7 @@ def reference_resized_crop_image_tensor(*args, **kwargs): F.InterpolationMode.BICUBIC, }: raise pytest.UsageError("Anti-aliasing is always active in PIL") - return F.resized_crop_image_pil(*args, **kwargs) + return F._resized_crop_image_pil(*args, **kwargs) def reference_inputs_resized_crop_image_tensor(): @@ -417,7 +417,7 @@ def sample_inputs_resized_crop_video(): KERNEL_INFOS.extend( [ KernelInfo( - F.resized_crop_image_tensor, + F.resized_crop_image, sample_inputs_fn=sample_inputs_resized_crop_image_tensor, reference_fn=reference_resized_crop_image_tensor, reference_inputs_fn=reference_inputs_resized_crop_image_tensor, @@ -570,9 +570,9 @@ def pad_xfail_jit_fill_condition(args_kwargs): KERNEL_INFOS.extend( [ KernelInfo( - F.pad_image_tensor, + F.pad_image, sample_inputs_fn=sample_inputs_pad_image_tensor, - reference_fn=pil_reference_wrapper(F.pad_image_pil), + reference_fn=pil_reference_wrapper(F._pad_image_pil), reference_inputs_fn=reference_inputs_pad_image_tensor, 
float32_vs_uint8=float32_vs_uint8_fill_adapter, closeness_kwargs=float32_vs_uint8_pixel_difference(), @@ -595,7 +595,7 @@ def pad_xfail_jit_fill_condition(args_kwargs): KernelInfo( F.pad_mask, sample_inputs_fn=sample_inputs_pad_mask, - reference_fn=pil_reference_wrapper(F.pad_image_pil), + reference_fn=pil_reference_wrapper(F._pad_image_pil), reference_inputs_fn=reference_inputs_pad_mask, float32_vs_uint8=float32_vs_uint8_fill_adapter, ), @@ -690,9 +690,9 @@ def sample_inputs_perspective_video(): KERNEL_INFOS.extend( [ KernelInfo( - F.perspective_image_tensor, + F.perspective_image, sample_inputs_fn=sample_inputs_perspective_image_tensor, - reference_fn=pil_reference_wrapper(F.perspective_image_pil), + reference_fn=pil_reference_wrapper(F._perspective_image_pil), reference_inputs_fn=reference_inputs_perspective_image_tensor, float32_vs_uint8=float32_vs_uint8_fill_adapter, closeness_kwargs={ @@ -715,7 +715,7 @@ def sample_inputs_perspective_video(): KernelInfo( F.perspective_mask, sample_inputs_fn=sample_inputs_perspective_mask, - reference_fn=pil_reference_wrapper(F.perspective_image_pil), + reference_fn=pil_reference_wrapper(F._perspective_image_pil), reference_inputs_fn=reference_inputs_perspective_mask, float32_vs_uint8=True, closeness_kwargs={ @@ -786,7 +786,7 @@ def sample_inputs_elastic_video(): KERNEL_INFOS.extend( [ KernelInfo( - F.elastic_image_tensor, + F.elastic_image, sample_inputs_fn=sample_inputs_elastic_image_tensor, reference_inputs_fn=reference_inputs_elastic_image_tensor, float32_vs_uint8=float32_vs_uint8_fill_adapter, @@ -870,9 +870,9 @@ def sample_inputs_center_crop_video(): KERNEL_INFOS.extend( [ KernelInfo( - F.center_crop_image_tensor, + F.center_crop_image, sample_inputs_fn=sample_inputs_center_crop_image_tensor, - reference_fn=pil_reference_wrapper(F.center_crop_image_pil), + reference_fn=pil_reference_wrapper(F._center_crop_image_pil), reference_inputs_fn=reference_inputs_center_crop_image_tensor, float32_vs_uint8=True, test_marks=[ @@ -889,7 +889,7 @@ def sample_inputs_center_crop_video(): KernelInfo( F.center_crop_mask, sample_inputs_fn=sample_inputs_center_crop_mask, - reference_fn=pil_reference_wrapper(F.center_crop_image_pil), + reference_fn=pil_reference_wrapper(F._center_crop_image_pil), reference_inputs_fn=reference_inputs_center_crop_mask, float32_vs_uint8=True, test_marks=[ @@ -924,7 +924,7 @@ def sample_inputs_gaussian_blur_video(): KERNEL_INFOS.extend( [ KernelInfo( - F.gaussian_blur_image_tensor, + F.gaussian_blur_image, sample_inputs_fn=sample_inputs_gaussian_blur_image_tensor, closeness_kwargs=cuda_vs_cpu_pixel_difference(), test_marks=[ @@ -1010,10 +1010,10 @@ def sample_inputs_equalize_video(): KERNEL_INFOS.extend( [ KernelInfo( - F.equalize_image_tensor, + F.equalize_image, kernel_name="equalize_image_tensor", sample_inputs_fn=sample_inputs_equalize_image_tensor, - reference_fn=pil_reference_wrapper(F.equalize_image_pil), + reference_fn=pil_reference_wrapper(F._equalize_image_pil), float32_vs_uint8=True, reference_inputs_fn=reference_inputs_equalize_image_tensor, ), @@ -1043,10 +1043,10 @@ def sample_inputs_invert_video(): KERNEL_INFOS.extend( [ KernelInfo( - F.invert_image_tensor, + F.invert_image, kernel_name="invert_image_tensor", sample_inputs_fn=sample_inputs_invert_image_tensor, - reference_fn=pil_reference_wrapper(F.invert_image_pil), + reference_fn=pil_reference_wrapper(F._invert_image_pil), reference_inputs_fn=reference_inputs_invert_image_tensor, float32_vs_uint8=True, ), @@ -1082,10 +1082,10 @@ def sample_inputs_posterize_video(): 
KERNEL_INFOS.extend( [ KernelInfo( - F.posterize_image_tensor, + F.posterize_image, kernel_name="posterize_image_tensor", sample_inputs_fn=sample_inputs_posterize_image_tensor, - reference_fn=pil_reference_wrapper(F.posterize_image_pil), + reference_fn=pil_reference_wrapper(F._posterize_image_pil), reference_inputs_fn=reference_inputs_posterize_image_tensor, float32_vs_uint8=True, closeness_kwargs=float32_vs_uint8_pixel_difference(), @@ -1127,10 +1127,10 @@ def sample_inputs_solarize_video(): KERNEL_INFOS.extend( [ KernelInfo( - F.solarize_image_tensor, + F.solarize_image, kernel_name="solarize_image_tensor", sample_inputs_fn=sample_inputs_solarize_image_tensor, - reference_fn=pil_reference_wrapper(F.solarize_image_pil), + reference_fn=pil_reference_wrapper(F._solarize_image_pil), reference_inputs_fn=reference_inputs_solarize_image_tensor, float32_vs_uint8=uint8_to_float32_threshold_adapter, closeness_kwargs=float32_vs_uint8_pixel_difference(), @@ -1161,10 +1161,10 @@ def sample_inputs_autocontrast_video(): KERNEL_INFOS.extend( [ KernelInfo( - F.autocontrast_image_tensor, + F.autocontrast_image, kernel_name="autocontrast_image_tensor", sample_inputs_fn=sample_inputs_autocontrast_image_tensor, - reference_fn=pil_reference_wrapper(F.autocontrast_image_pil), + reference_fn=pil_reference_wrapper(F._autocontrast_image_pil), reference_inputs_fn=reference_inputs_autocontrast_image_tensor, float32_vs_uint8=True, closeness_kwargs={ @@ -1206,10 +1206,10 @@ def sample_inputs_adjust_sharpness_video(): KERNEL_INFOS.extend( [ KernelInfo( - F.adjust_sharpness_image_tensor, + F.adjust_sharpness_image, kernel_name="adjust_sharpness_image_tensor", sample_inputs_fn=sample_inputs_adjust_sharpness_image_tensor, - reference_fn=pil_reference_wrapper(F.adjust_sharpness_image_pil), + reference_fn=pil_reference_wrapper(F._adjust_sharpness_image_pil), reference_inputs_fn=reference_inputs_adjust_sharpness_image_tensor, float32_vs_uint8=True, closeness_kwargs=float32_vs_uint8_pixel_difference(2), @@ -1241,7 +1241,7 @@ def sample_inputs_erase_video(): KERNEL_INFOS.extend( [ KernelInfo( - F.erase_image_tensor, + F.erase_image, kernel_name="erase_image_tensor", sample_inputs_fn=sample_inputs_erase_image_tensor, ), @@ -1276,10 +1276,10 @@ def sample_inputs_adjust_contrast_video(): KERNEL_INFOS.extend( [ KernelInfo( - F.adjust_contrast_image_tensor, + F.adjust_contrast_image, kernel_name="adjust_contrast_image_tensor", sample_inputs_fn=sample_inputs_adjust_contrast_image_tensor, - reference_fn=pil_reference_wrapper(F.adjust_contrast_image_pil), + reference_fn=pil_reference_wrapper(F._adjust_contrast_image_pil), reference_inputs_fn=reference_inputs_adjust_contrast_image_tensor, float32_vs_uint8=True, closeness_kwargs={ @@ -1329,10 +1329,10 @@ def sample_inputs_adjust_gamma_video(): KERNEL_INFOS.extend( [ KernelInfo( - F.adjust_gamma_image_tensor, + F.adjust_gamma_image, kernel_name="adjust_gamma_image_tensor", sample_inputs_fn=sample_inputs_adjust_gamma_image_tensor, - reference_fn=pil_reference_wrapper(F.adjust_gamma_image_pil), + reference_fn=pil_reference_wrapper(F._adjust_gamma_image_pil), reference_inputs_fn=reference_inputs_adjust_gamma_image_tensor, float32_vs_uint8=True, closeness_kwargs={ @@ -1372,10 +1372,10 @@ def sample_inputs_adjust_hue_video(): KERNEL_INFOS.extend( [ KernelInfo( - F.adjust_hue_image_tensor, + F.adjust_hue_image, kernel_name="adjust_hue_image_tensor", sample_inputs_fn=sample_inputs_adjust_hue_image_tensor, - reference_fn=pil_reference_wrapper(F.adjust_hue_image_pil), + 
reference_fn=pil_reference_wrapper(F._adjust_hue_image_pil), reference_inputs_fn=reference_inputs_adjust_hue_image_tensor, float32_vs_uint8=True, closeness_kwargs={ @@ -1414,10 +1414,10 @@ def sample_inputs_adjust_saturation_video(): KERNEL_INFOS.extend( [ KernelInfo( - F.adjust_saturation_image_tensor, + F.adjust_saturation_image, kernel_name="adjust_saturation_image_tensor", sample_inputs_fn=sample_inputs_adjust_saturation_image_tensor, - reference_fn=pil_reference_wrapper(F.adjust_saturation_image_pil), + reference_fn=pil_reference_wrapper(F._adjust_saturation_image_pil), reference_inputs_fn=reference_inputs_adjust_saturation_image_tensor, float32_vs_uint8=True, closeness_kwargs={ @@ -1517,8 +1517,7 @@ def multi_crop_pil_reference_wrapper(pil_kernel): def wrapper(input_tensor, *other_args, **kwargs): output = pil_reference_wrapper(pil_kernel)(input_tensor, *other_args, **kwargs) return type(output)( - F.to_dtype_image_tensor(F.to_image_tensor(output_pil), dtype=input_tensor.dtype, scale=True) - for output_pil in output + F.to_dtype_image(F.to_image(output_pil), dtype=input_tensor.dtype, scale=True) for output_pil in output ) return wrapper @@ -1532,9 +1531,9 @@ def wrapper(input_tensor, *other_args, **kwargs): KERNEL_INFOS.extend( [ KernelInfo( - F.five_crop_image_tensor, + F.five_crop_image, sample_inputs_fn=sample_inputs_five_crop_image_tensor, - reference_fn=multi_crop_pil_reference_wrapper(F.five_crop_image_pil), + reference_fn=multi_crop_pil_reference_wrapper(F._five_crop_image_pil), reference_inputs_fn=reference_inputs_five_crop_image_tensor, test_marks=_common_five_ten_crop_marks, ), @@ -1544,9 +1543,9 @@ def wrapper(input_tensor, *other_args, **kwargs): test_marks=_common_five_ten_crop_marks, ), KernelInfo( - F.ten_crop_image_tensor, + F.ten_crop_image, sample_inputs_fn=sample_inputs_ten_crop_image_tensor, - reference_fn=multi_crop_pil_reference_wrapper(F.ten_crop_image_pil), + reference_fn=multi_crop_pil_reference_wrapper(F._ten_crop_image_pil), reference_inputs_fn=reference_inputs_ten_crop_image_tensor, test_marks=_common_five_ten_crop_marks, ), @@ -1600,7 +1599,7 @@ def sample_inputs_normalize_video(): KERNEL_INFOS.extend( [ KernelInfo( - F.normalize_image_tensor, + F.normalize_image, kernel_name="normalize_image_tensor", sample_inputs_fn=sample_inputs_normalize_image_tensor, reference_fn=reference_normalize_image_tensor, diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index f2c6e89dd3a..81f726a2dbd 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -112,7 +112,7 @@ def _extract_image_targets( if isinstance(obj, datapoints.Image) or is_simple_tensor(obj): images.append(obj) elif isinstance(obj, PIL.Image.Image): - images.append(F.to_image_tensor(obj)) + images.append(F.to_image(obj)) elif isinstance(obj, datapoints.BoundingBoxes): bboxes.append(obj) elif isinstance(obj, datapoints.Mask): @@ -144,7 +144,7 @@ def _insert_outputs( flat_sample[i] = datapoints.wrap(output_images[c0], like=obj) c0 += 1 elif isinstance(obj, PIL.Image.Image): - flat_sample[i] = F.to_image_pil(output_images[c0]) + flat_sample[i] = F.to_pil_image(output_images[c0]) c0 += 1 elif is_simple_tensor(obj): flat_sample[i] = output_images[c0] diff --git a/torchvision/transforms/v2/__init__.py b/torchvision/transforms/v2/__init__.py index 4451cb7a1a2..38da78fa4d7 100644 --- a/torchvision/transforms/v2/__init__.py +++ b/torchvision/transforms/v2/__init__.py @@ -52,7 +52,7 @@ ToDtype, ) from 
._temporal import UniformTemporalSubsample -from ._type_conversion import PILToTensor, ToImagePIL, ToImageTensor, ToPILImage +from ._type_conversion import PILToTensor, ToImage, ToPILImage from ._deprecated import ToTensor # usort: skip diff --git a/torchvision/transforms/v2/_auto_augment.py b/torchvision/transforms/v2/_auto_augment.py index 8494b64b994..687a2396e67 100644 --- a/torchvision/transforms/v2/_auto_augment.py +++ b/torchvision/transforms/v2/_auto_augment.py @@ -622,6 +622,6 @@ def forward(self, *inputs: Any) -> Any: if isinstance(orig_image_or_video, (datapoints.Image, datapoints.Video)): mix = datapoints.wrap(mix, like=orig_image_or_video) elif isinstance(orig_image_or_video, PIL.Image.Image): - mix = F.to_image_pil(mix) + mix = F.to_pil_image(mix) return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, mix) diff --git a/torchvision/transforms/v2/_type_conversion.py b/torchvision/transforms/v2/_type_conversion.py index 60f44c5d3db..aec82f46f14 100644 --- a/torchvision/transforms/v2/_type_conversion.py +++ b/torchvision/transforms/v2/_type_conversion.py @@ -26,7 +26,7 @@ def _transform(self, inpt: PIL.Image.Image, params: Dict[str, Any]) -> torch.Ten return F.pil_to_tensor(inpt) -class ToImageTensor(Transform): +class ToImage(Transform): """[BETA] Convert a tensor, ndarray, or PIL Image to :class:`~torchvision.datapoints.Image` ; this does not scale values. @@ -40,10 +40,10 @@ class ToImageTensor(Transform): def _transform( self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any] ) -> datapoints.Image: - return F.to_image_tensor(inpt) + return F.to_image(inpt) -class ToImagePIL(Transform): +class ToPILImage(Transform): """[BETA] Convert a tensor or an ndarray to PIL Image - this does not scale values. .. v2betastatus:: ToImagePIL transform @@ -74,9 +74,4 @@ def __init__(self, mode: Optional[str] = None) -> None: def _transform( self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any] ) -> PIL.Image.Image: - return F.to_image_pil(inpt, mode=self.mode) - - -# We changed the name to align them with the new naming scheme. Still, `ToPILImage` is -# prevalent and well understood. Thus, we just alias it without deprecating the old name. 
-ToPILImage = ToImagePIL + return F.to_pil_image(inpt, mode=self.mode) diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index f3295860155..3510962ff3a 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -5,173 +5,173 @@ from ._meta import ( clamp_bounding_boxes, convert_format_bounding_boxes, - get_dimensions_image_tensor, - get_dimensions_image_pil, + get_dimensions_image, + _get_dimensions_image_pil, get_dimensions_video, get_dimensions, get_num_frames_video, get_num_frames, get_image_num_channels, - get_num_channels_image_tensor, - get_num_channels_image_pil, + get_num_channels_image, + _get_num_channels_image_pil, get_num_channels_video, get_num_channels, get_size_bounding_boxes, - get_size_image_tensor, - get_size_image_pil, + get_size_image, + _get_size_image_pil, get_size_mask, get_size_video, get_size, ) # usort: skip -from ._augment import erase, erase_image_pil, erase_image_tensor, erase_video +from ._augment import _erase_image_pil, erase, erase_image, erase_video from ._color import ( + _adjust_brightness_image_pil, + _adjust_contrast_image_pil, + _adjust_gamma_image_pil, + _adjust_hue_image_pil, + _adjust_saturation_image_pil, + _adjust_sharpness_image_pil, + _autocontrast_image_pil, + _equalize_image_pil, + _invert_image_pil, + _permute_channels_image_pil, + _posterize_image_pil, + _rgb_to_grayscale_image_pil, + _solarize_image_pil, adjust_brightness, - adjust_brightness_image_pil, - adjust_brightness_image_tensor, + adjust_brightness_image, adjust_brightness_video, adjust_contrast, - adjust_contrast_image_pil, - adjust_contrast_image_tensor, + adjust_contrast_image, adjust_contrast_video, adjust_gamma, - adjust_gamma_image_pil, - adjust_gamma_image_tensor, + adjust_gamma_image, adjust_gamma_video, adjust_hue, - adjust_hue_image_pil, - adjust_hue_image_tensor, + adjust_hue_image, adjust_hue_video, adjust_saturation, - adjust_saturation_image_pil, - adjust_saturation_image_tensor, + adjust_saturation_image, adjust_saturation_video, adjust_sharpness, - adjust_sharpness_image_pil, - adjust_sharpness_image_tensor, + adjust_sharpness_image, adjust_sharpness_video, autocontrast, - autocontrast_image_pil, - autocontrast_image_tensor, + autocontrast_image, autocontrast_video, equalize, - equalize_image_pil, - equalize_image_tensor, + equalize_image, equalize_video, invert, - invert_image_pil, - invert_image_tensor, + invert_image, invert_video, permute_channels, - permute_channels_image_pil, - permute_channels_image_tensor, + permute_channels_image, permute_channels_video, posterize, - posterize_image_pil, - posterize_image_tensor, + posterize_image, posterize_video, rgb_to_grayscale, - rgb_to_grayscale_image_pil, - rgb_to_grayscale_image_tensor, + rgb_to_grayscale_image, solarize, - solarize_image_pil, - solarize_image_tensor, + solarize_image, solarize_video, to_grayscale, ) from ._geometry import ( + _affine_image_pil, + _center_crop_image_pil, + _crop_image_pil, + _elastic_image_pil, + _five_crop_image_pil, + _horizontal_flip_image_pil, + _pad_image_pil, + _perspective_image_pil, + _resize_image_pil, + _resized_crop_image_pil, + _rotate_image_pil, + _ten_crop_image_pil, + _vertical_flip_image_pil, affine, affine_bounding_boxes, - affine_image_pil, - affine_image_tensor, + affine_image, affine_mask, affine_video, center_crop, center_crop_bounding_boxes, - center_crop_image_pil, - center_crop_image_tensor, + center_crop_image, center_crop_mask, 
center_crop_video, crop, crop_bounding_boxes, - crop_image_pil, - crop_image_tensor, + crop_image, crop_mask, crop_video, elastic, elastic_bounding_boxes, - elastic_image_pil, - elastic_image_tensor, + elastic_image, elastic_mask, elastic_transform, elastic_video, five_crop, - five_crop_image_pil, - five_crop_image_tensor, + five_crop_image, five_crop_video, hflip, # TODO: Consider moving all pure alias definitions at the bottom of the file horizontal_flip, horizontal_flip_bounding_boxes, - horizontal_flip_image_pil, - horizontal_flip_image_tensor, + horizontal_flip_image, horizontal_flip_mask, horizontal_flip_video, pad, pad_bounding_boxes, - pad_image_pil, - pad_image_tensor, + pad_image, pad_mask, pad_video, perspective, perspective_bounding_boxes, - perspective_image_pil, - perspective_image_tensor, + perspective_image, perspective_mask, perspective_video, resize, resize_bounding_boxes, - resize_image_pil, - resize_image_tensor, + resize_image, resize_mask, resize_video, resized_crop, resized_crop_bounding_boxes, - resized_crop_image_pil, - resized_crop_image_tensor, + resized_crop_image, resized_crop_mask, resized_crop_video, rotate, rotate_bounding_boxes, - rotate_image_pil, - rotate_image_tensor, + rotate_image, rotate_mask, rotate_video, ten_crop, - ten_crop_image_pil, - ten_crop_image_tensor, + ten_crop_image, ten_crop_video, vertical_flip, vertical_flip_bounding_boxes, - vertical_flip_image_pil, - vertical_flip_image_tensor, + vertical_flip_image, vertical_flip_mask, vertical_flip_video, vflip, ) from ._misc import ( + _gaussian_blur_image_pil, convert_image_dtype, gaussian_blur, - gaussian_blur_image_pil, - gaussian_blur_image_tensor, + gaussian_blur_image, gaussian_blur_video, normalize, - normalize_image_tensor, + normalize_image, normalize_video, to_dtype, - to_dtype_image_tensor, + to_dtype_image, to_dtype_video, ) from ._temporal import uniform_temporal_subsample, uniform_temporal_subsample_video -from ._type_conversion import pil_to_tensor, to_image_pil, to_image_tensor, to_pil_image +from ._type_conversion import pil_to_tensor, to_image, to_pil_image from ._deprecated import get_image_size, to_tensor # usort: skip diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py index 4a927be9777..48b8865c4cc 100644 --- a/torchvision/transforms/v2/functional/_augment.py +++ b/torchvision/transforms/v2/functional/_augment.py @@ -18,7 +18,7 @@ def erase( inplace: bool = False, ) -> torch.Tensor: if torch.jit.is_scripting(): - return erase_image_tensor(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) + return erase_image(inpt, i=i, j=j, h=h, w=w, v=v, inplace=inplace) _log_api_usage_once(erase) @@ -28,7 +28,7 @@ def erase( @_register_kernel_internal(erase, torch.Tensor) @_register_kernel_internal(erase, datapoints.Image) -def erase_image_tensor( +def erase_image( image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False ) -> torch.Tensor: if not inplace: @@ -39,11 +39,11 @@ def erase_image_tensor( @_register_kernel_internal(erase, PIL.Image.Image) -def erase_image_pil( +def _erase_image_pil( image: PIL.Image.Image, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False ) -> PIL.Image.Image: t_img = pil_to_tensor(image) - output = erase_image_tensor(t_img, i=i, j=j, h=h, w=w, v=v, inplace=inplace) + output = erase_image(t_img, i=i, j=j, h=h, w=w, v=v, inplace=inplace) return to_pil_image(output, mode=image.mode) @@ -51,4 +51,4 @@ def erase_image_pil( def erase_video( video: 
torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False ) -> torch.Tensor: - return erase_image_tensor(video, i=i, j=j, h=h, w=w, v=v, inplace=inplace) + return erase_image(video, i=i, j=j, h=h, w=w, v=v, inplace=inplace) diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 82bd236645e..825ffa207b0 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -9,14 +9,14 @@ from torchvision.utils import _log_api_usage_once -from ._misc import _num_value_bits, to_dtype_image_tensor -from ._type_conversion import pil_to_tensor, to_image_pil +from ._misc import _num_value_bits, to_dtype_image +from ._type_conversion import pil_to_tensor, to_pil_image from ._utils import _get_kernel, _register_kernel_internal def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor: if torch.jit.is_scripting(): - return rgb_to_grayscale_image_tensor(inpt, num_output_channels=num_output_channels) + return rgb_to_grayscale_image(inpt, num_output_channels=num_output_channels) _log_api_usage_once(rgb_to_grayscale) @@ -29,7 +29,7 @@ def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch. to_grayscale = rgb_to_grayscale -def _rgb_to_grayscale_image_tensor( +def _rgb_to_grayscale_image( image: torch.Tensor, num_output_channels: int = 1, preserve_dtype: bool = True ) -> torch.Tensor: if image.shape[-3] == 1: @@ -47,14 +47,14 @@ def _rgb_to_grayscale_image_tensor( @_register_kernel_internal(rgb_to_grayscale, torch.Tensor) @_register_kernel_internal(rgb_to_grayscale, datapoints.Image) -def rgb_to_grayscale_image_tensor(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor: +def rgb_to_grayscale_image(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor: if num_output_channels not in (1, 3): raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.") - return _rgb_to_grayscale_image_tensor(image, num_output_channels=num_output_channels, preserve_dtype=True) + return _rgb_to_grayscale_image(image, num_output_channels=num_output_channels, preserve_dtype=True) @_register_kernel_internal(rgb_to_grayscale, PIL.Image.Image) -def rgb_to_grayscale_image_pil(image: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image: +def _rgb_to_grayscale_image_pil(image: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image: if num_output_channels not in (1, 3): raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.") return _FP.to_grayscale(image, num_output_channels=num_output_channels) @@ -71,7 +71,7 @@ def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Te def adjust_brightness(inpt: torch.Tensor, brightness_factor: float) -> torch.Tensor: if torch.jit.is_scripting(): - return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor) + return adjust_brightness_image(inpt, brightness_factor=brightness_factor) _log_api_usage_once(adjust_brightness) @@ -81,7 +81,7 @@ def adjust_brightness(inpt: torch.Tensor, brightness_factor: float) -> torch.Ten @_register_kernel_internal(adjust_brightness, torch.Tensor) @_register_kernel_internal(adjust_brightness, datapoints.Image) -def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float) -> torch.Tensor: +def adjust_brightness_image(image: torch.Tensor, brightness_factor: float) -> torch.Tensor: if brightness_factor < 0: raise 
ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.") @@ -96,18 +96,18 @@ def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float @_register_kernel_internal(adjust_brightness, PIL.Image.Image) -def adjust_brightness_image_pil(image: PIL.Image.Image, brightness_factor: float) -> PIL.Image.Image: +def _adjust_brightness_image_pil(image: PIL.Image.Image, brightness_factor: float) -> PIL.Image.Image: return _FP.adjust_brightness(image, brightness_factor=brightness_factor) @_register_kernel_internal(adjust_brightness, datapoints.Video) def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> torch.Tensor: - return adjust_brightness_image_tensor(video, brightness_factor=brightness_factor) + return adjust_brightness_image(video, brightness_factor=brightness_factor) def adjust_saturation(inpt: torch.Tensor, saturation_factor: float) -> torch.Tensor: if torch.jit.is_scripting(): - return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor) + return adjust_saturation_image(inpt, saturation_factor=saturation_factor) _log_api_usage_once(adjust_saturation) @@ -117,7 +117,7 @@ def adjust_saturation(inpt: torch.Tensor, saturation_factor: float) -> torch.Ten @_register_kernel_internal(adjust_saturation, torch.Tensor) @_register_kernel_internal(adjust_saturation, datapoints.Image) -def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float) -> torch.Tensor: +def adjust_saturation_image(image: torch.Tensor, saturation_factor: float) -> torch.Tensor: if saturation_factor < 0: raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.") @@ -128,24 +128,24 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float if c == 1: # Match PIL behaviour return image - grayscale_image = _rgb_to_grayscale_image_tensor(image, num_output_channels=1, preserve_dtype=False) + grayscale_image = _rgb_to_grayscale_image(image, num_output_channels=1, preserve_dtype=False) if not image.is_floating_point(): grayscale_image = grayscale_image.floor_() return _blend(image, grayscale_image, saturation_factor) -adjust_saturation_image_pil = _register_kernel_internal(adjust_saturation, PIL.Image.Image)(_FP.adjust_saturation) +_adjust_saturation_image_pil = _register_kernel_internal(adjust_saturation, PIL.Image.Image)(_FP.adjust_saturation) @_register_kernel_internal(adjust_saturation, datapoints.Video) def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> torch.Tensor: - return adjust_saturation_image_tensor(video, saturation_factor=saturation_factor) + return adjust_saturation_image(video, saturation_factor=saturation_factor) def adjust_contrast(inpt: torch.Tensor, contrast_factor: float) -> torch.Tensor: if torch.jit.is_scripting(): - return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor) + return adjust_contrast_image(inpt, contrast_factor=contrast_factor) _log_api_usage_once(adjust_contrast) @@ -155,7 +155,7 @@ def adjust_contrast(inpt: torch.Tensor, contrast_factor: float) -> torch.Tensor: @_register_kernel_internal(adjust_contrast, torch.Tensor) @_register_kernel_internal(adjust_contrast, datapoints.Image) -def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> torch.Tensor: +def adjust_contrast_image(image: torch.Tensor, contrast_factor: float) -> torch.Tensor: if contrast_factor < 0: raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.") @@ -164,7 +164,7 @@ def 
adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}") fp = image.is_floating_point() if c == 3: - grayscale_image = _rgb_to_grayscale_image_tensor(image, num_output_channels=1, preserve_dtype=False) + grayscale_image = _rgb_to_grayscale_image(image, num_output_channels=1, preserve_dtype=False) if not fp: grayscale_image = grayscale_image.floor_() else: @@ -173,17 +173,17 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> return _blend(image, mean, contrast_factor) -adjust_contrast_image_pil = _register_kernel_internal(adjust_contrast, PIL.Image.Image)(_FP.adjust_contrast) +_adjust_contrast_image_pil = _register_kernel_internal(adjust_contrast, PIL.Image.Image)(_FP.adjust_contrast) @_register_kernel_internal(adjust_contrast, datapoints.Video) def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch.Tensor: - return adjust_contrast_image_tensor(video, contrast_factor=contrast_factor) + return adjust_contrast_image(video, contrast_factor=contrast_factor) def adjust_sharpness(inpt: torch.Tensor, sharpness_factor: float) -> torch.Tensor: if torch.jit.is_scripting(): - return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor) + return adjust_sharpness_image(inpt, sharpness_factor=sharpness_factor) _log_api_usage_once(adjust_sharpness) @@ -193,7 +193,7 @@ def adjust_sharpness(inpt: torch.Tensor, sharpness_factor: float) -> torch.Tenso @_register_kernel_internal(adjust_sharpness, torch.Tensor) @_register_kernel_internal(adjust_sharpness, datapoints.Image) -def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor: +def adjust_sharpness_image(image: torch.Tensor, sharpness_factor: float) -> torch.Tensor: num_channels, height, width = image.shape[-3:] if num_channels not in (1, 3): raise TypeError(f"Input image tensor can have 1 or 3 channels, but found {num_channels}") @@ -245,17 +245,17 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float) return output -adjust_sharpness_image_pil = _register_kernel_internal(adjust_sharpness, PIL.Image.Image)(_FP.adjust_sharpness) +_adjust_sharpness_image_pil = _register_kernel_internal(adjust_sharpness, PIL.Image.Image)(_FP.adjust_sharpness) @_register_kernel_internal(adjust_sharpness, datapoints.Video) def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torch.Tensor: - return adjust_sharpness_image_tensor(video, sharpness_factor=sharpness_factor) + return adjust_sharpness_image(video, sharpness_factor=sharpness_factor) def adjust_hue(inpt: torch.Tensor, hue_factor: float) -> torch.Tensor: if torch.jit.is_scripting(): - return adjust_hue_image_tensor(inpt, hue_factor=hue_factor) + return adjust_hue_image(inpt, hue_factor=hue_factor) _log_api_usage_once(adjust_hue) @@ -335,7 +335,7 @@ def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor: @_register_kernel_internal(adjust_hue, torch.Tensor) @_register_kernel_internal(adjust_hue, datapoints.Image) -def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Tensor: +def adjust_hue_image(image: torch.Tensor, hue_factor: float) -> torch.Tensor: if not (-0.5 <= hue_factor <= 0.5): raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].") @@ -351,7 +351,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten return image orig_dtype = image.dtype - image = to_dtype_image_tensor(image, 
torch.float32, scale=True) + image = to_dtype_image(image, torch.float32, scale=True) image = _rgb_to_hsv(image) h, s, v = image.unbind(dim=-3) @@ -359,20 +359,20 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten image = torch.stack((h, s, v), dim=-3) image_hue_adj = _hsv_to_rgb(image) - return to_dtype_image_tensor(image_hue_adj, orig_dtype, scale=True) + return to_dtype_image(image_hue_adj, orig_dtype, scale=True) -adjust_hue_image_pil = _register_kernel_internal(adjust_hue, PIL.Image.Image)(_FP.adjust_hue) +_adjust_hue_image_pil = _register_kernel_internal(adjust_hue, PIL.Image.Image)(_FP.adjust_hue) @_register_kernel_internal(adjust_hue, datapoints.Video) def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor: - return adjust_hue_image_tensor(video, hue_factor=hue_factor) + return adjust_hue_image(video, hue_factor=hue_factor) def adjust_gamma(inpt: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor: if torch.jit.is_scripting(): - return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain) + return adjust_gamma_image(inpt, gamma=gamma, gain=gain) _log_api_usage_once(adjust_gamma) @@ -382,14 +382,14 @@ def adjust_gamma(inpt: torch.Tensor, gamma: float, gain: float = 1) -> torch.Ten @_register_kernel_internal(adjust_gamma, torch.Tensor) @_register_kernel_internal(adjust_gamma, datapoints.Image) -def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1.0) -> torch.Tensor: +def adjust_gamma_image(image: torch.Tensor, gamma: float, gain: float = 1.0) -> torch.Tensor: if gamma < 0: raise ValueError("Gamma should be a non-negative real number") # The input image is either assumed to be at [0, 1] scale (if float) or is converted to that scale (if integer). # Since the gamma is non-negative, the output remains at [0, 1] scale. if not torch.is_floating_point(image): - output = to_dtype_image_tensor(image, torch.float32, scale=True).pow_(gamma) + output = to_dtype_image(image, torch.float32, scale=True).pow_(gamma) else: output = image.pow(gamma) @@ -398,20 +398,20 @@ def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1 # of the output can go beyond [0, 1]. 
output = output.mul_(gain).clamp_(0.0, 1.0) - return to_dtype_image_tensor(output, image.dtype, scale=True) + return to_dtype_image(output, image.dtype, scale=True) -adjust_gamma_image_pil = _register_kernel_internal(adjust_gamma, PIL.Image.Image)(_FP.adjust_gamma) +_adjust_gamma_image_pil = _register_kernel_internal(adjust_gamma, PIL.Image.Image)(_FP.adjust_gamma) @_register_kernel_internal(adjust_gamma, datapoints.Video) def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> torch.Tensor: - return adjust_gamma_image_tensor(video, gamma=gamma, gain=gain) + return adjust_gamma_image(video, gamma=gamma, gain=gain) def posterize(inpt: torch.Tensor, bits: int) -> torch.Tensor: if torch.jit.is_scripting(): - return posterize_image_tensor(inpt, bits=bits) + return posterize_image(inpt, bits=bits) _log_api_usage_once(posterize) @@ -421,7 +421,7 @@ def posterize(inpt: torch.Tensor, bits: int) -> torch.Tensor: @_register_kernel_internal(posterize, torch.Tensor) @_register_kernel_internal(posterize, datapoints.Image) -def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor: +def posterize_image(image: torch.Tensor, bits: int) -> torch.Tensor: if image.is_floating_point(): levels = 1 << bits return image.mul(levels).floor_().clamp_(0, levels - 1).mul_(1.0 / levels) @@ -434,17 +434,17 @@ def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor: return image & mask -posterize_image_pil = _register_kernel_internal(posterize, PIL.Image.Image)(_FP.posterize) +_posterize_image_pil = _register_kernel_internal(posterize, PIL.Image.Image)(_FP.posterize) @_register_kernel_internal(posterize, datapoints.Video) def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor: - return posterize_image_tensor(video, bits=bits) + return posterize_image(video, bits=bits) def solarize(inpt: torch.Tensor, threshold: float) -> torch.Tensor: if torch.jit.is_scripting(): - return solarize_image_tensor(inpt, threshold=threshold) + return solarize_image(inpt, threshold=threshold) _log_api_usage_once(solarize) @@ -454,24 +454,24 @@ def solarize(inpt: torch.Tensor, threshold: float) -> torch.Tensor: @_register_kernel_internal(solarize, torch.Tensor) @_register_kernel_internal(solarize, datapoints.Image) -def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor: +def solarize_image(image: torch.Tensor, threshold: float) -> torch.Tensor: if threshold > _max_value(image.dtype): raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}") - return torch.where(image >= threshold, invert_image_tensor(image), image) + return torch.where(image >= threshold, invert_image(image), image) -solarize_image_pil = _register_kernel_internal(solarize, PIL.Image.Image)(_FP.solarize) +_solarize_image_pil = _register_kernel_internal(solarize, PIL.Image.Image)(_FP.solarize) @_register_kernel_internal(solarize, datapoints.Video) def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor: - return solarize_image_tensor(video, threshold=threshold) + return solarize_image(video, threshold=threshold) def autocontrast(inpt: torch.Tensor) -> torch.Tensor: if torch.jit.is_scripting(): - return autocontrast_image_tensor(inpt) + return autocontrast_image(inpt) _log_api_usage_once(autocontrast) @@ -481,7 +481,7 @@ def autocontrast(inpt: torch.Tensor) -> torch.Tensor: @_register_kernel_internal(autocontrast, torch.Tensor) @_register_kernel_internal(autocontrast, datapoints.Image) -def 
autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor: +def autocontrast_image(image: torch.Tensor) -> torch.Tensor: c = image.shape[-3] if c not in [1, 3]: raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}") @@ -510,17 +510,17 @@ def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor: return diff.div_(inv_scale).clamp_(0, bound).to(image.dtype) -autocontrast_image_pil = _register_kernel_internal(autocontrast, PIL.Image.Image)(_FP.autocontrast) +_autocontrast_image_pil = _register_kernel_internal(autocontrast, PIL.Image.Image)(_FP.autocontrast) @_register_kernel_internal(autocontrast, datapoints.Video) def autocontrast_video(video: torch.Tensor) -> torch.Tensor: - return autocontrast_image_tensor(video) + return autocontrast_image(video) def equalize(inpt: torch.Tensor) -> torch.Tensor: if torch.jit.is_scripting(): - return equalize_image_tensor(inpt) + return equalize_image(inpt) _log_api_usage_once(equalize) @@ -530,7 +530,7 @@ def equalize(inpt: torch.Tensor) -> torch.Tensor: @_register_kernel_internal(equalize, torch.Tensor) @_register_kernel_internal(equalize, datapoints.Image) -def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor: +def equalize_image(image: torch.Tensor) -> torch.Tensor: if image.numel() == 0: return image @@ -545,7 +545,7 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor: # Since we need to convert in most cases anyway and out of the acceptable dtypes mentioned in 1. `torch.uint8` is # by far the most common, we choose it as base. output_dtype = image.dtype - image = to_dtype_image_tensor(image, torch.uint8, scale=True) + image = to_dtype_image(image, torch.uint8, scale=True) # The histogram is computed by using the flattened image as index. For example, a pixel value of 127 in the image # corresponds to adding 1 to index 127 in the histogram. 
@@ -596,20 +596,20 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor: equalized_image = lut.gather(dim=-1, index=flat_image).view_as(image) output = torch.where(valid_equalization, equalized_image, image) - return to_dtype_image_tensor(output, output_dtype, scale=True) + return to_dtype_image(output, output_dtype, scale=True) -equalize_image_pil = _register_kernel_internal(equalize, PIL.Image.Image)(_FP.equalize) +_equalize_image_pil = _register_kernel_internal(equalize, PIL.Image.Image)(_FP.equalize) @_register_kernel_internal(equalize, datapoints.Video) def equalize_video(video: torch.Tensor) -> torch.Tensor: - return equalize_image_tensor(video) + return equalize_image(video) def invert(inpt: torch.Tensor) -> torch.Tensor: if torch.jit.is_scripting(): - return invert_image_tensor(inpt) + return invert_image(inpt) _log_api_usage_once(invert) @@ -619,7 +619,7 @@ def invert(inpt: torch.Tensor) -> torch.Tensor: @_register_kernel_internal(invert, torch.Tensor) @_register_kernel_internal(invert, datapoints.Image) -def invert_image_tensor(image: torch.Tensor) -> torch.Tensor: +def invert_image(image: torch.Tensor) -> torch.Tensor: if image.is_floating_point(): return 1.0 - image elif image.dtype == torch.uint8: @@ -629,12 +629,12 @@ def invert_image_tensor(image: torch.Tensor) -> torch.Tensor: return image.bitwise_xor((1 << _num_value_bits(image.dtype)) - 1) -invert_image_pil = _register_kernel_internal(invert, PIL.Image.Image)(_FP.invert) +_invert_image_pil = _register_kernel_internal(invert, PIL.Image.Image)(_FP.invert) @_register_kernel_internal(invert, datapoints.Video) def invert_video(video: torch.Tensor) -> torch.Tensor: - return invert_image_tensor(video) + return invert_image(video) def permute_channels(inpt: torch.Tensor, permutation: List[int]) -> torch.Tensor: @@ -660,7 +660,7 @@ def permute_channels(inpt: torch.Tensor, permutation: List[int]) -> torch.Tensor ValueError: If ``len(permutation)`` doesn't match the number of channels in the input. 
""" if torch.jit.is_scripting(): - return permute_channels_image_tensor(inpt, permutation=permutation) + return permute_channels_image(inpt, permutation=permutation) _log_api_usage_once(permute_channels) @@ -670,7 +670,7 @@ def permute_channels(inpt: torch.Tensor, permutation: List[int]) -> torch.Tensor @_register_kernel_internal(permute_channels, torch.Tensor) @_register_kernel_internal(permute_channels, datapoints.Image) -def permute_channels_image_tensor(image: torch.Tensor, permutation: List[int]) -> torch.Tensor: +def permute_channels_image(image: torch.Tensor, permutation: List[int]) -> torch.Tensor: shape = image.shape num_channels, height, width = shape[-3:] @@ -688,10 +688,10 @@ def permute_channels_image_tensor(image: torch.Tensor, permutation: List[int]) - @_register_kernel_internal(permute_channels, PIL.Image.Image) -def permute_channels_image_pil(image: PIL.Image.Image, permutation: List[int]) -> PIL.Image: - return to_image_pil(permute_channels_image_tensor(pil_to_tensor(image), permutation=permutation)) +def _permute_channels_image_pil(image: PIL.Image.Image, permutation: List[int]) -> PIL.Image: + return to_pil_image(permute_channels_image(pil_to_tensor(image), permutation=permutation)) @_register_kernel_internal(permute_channels, datapoints.Video) def permute_channels_video(video: torch.Tensor, permutation: List[int]) -> torch.Tensor: - return permute_channels_image_tensor(video, permutation=permutation) + return permute_channels_image(video, permutation=permutation) diff --git a/torchvision/transforms/v2/functional/_deprecated.py b/torchvision/transforms/v2/functional/_deprecated.py index 1cb7f50e5c7..aac56c51cca 100644 --- a/torchvision/transforms/v2/functional/_deprecated.py +++ b/torchvision/transforms/v2/functional/_deprecated.py @@ -10,7 +10,7 @@ def to_tensor(inpt: Any) -> torch.Tensor: warnings.warn( "The function `to_tensor(...)` is deprecated and will be removed in a future release. " - "Instead, please use `to_image_tensor(...)` followed by `to_dtype(..., dtype=torch.float32, scale=True)`." + "Instead, please use `to_image(...)` followed by `to_dtype(..., dtype=torch.float32, scale=True)`." 
) return _F.to_tensor(inpt) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 898e7e0c1a8..0cd43590bee 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -23,7 +23,7 @@ from torchvision.utils import _log_api_usage_once -from ._meta import clamp_bounding_boxes, convert_format_bounding_boxes, get_size_image_pil +from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_format_bounding_boxes from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal @@ -41,7 +41,7 @@ def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> Interp def horizontal_flip(inpt: torch.Tensor) -> torch.Tensor: if torch.jit.is_scripting(): - return horizontal_flip_image_tensor(inpt) + return horizontal_flip_image(inpt) _log_api_usage_once(horizontal_flip) @@ -51,18 +51,18 @@ def horizontal_flip(inpt: torch.Tensor) -> torch.Tensor: @_register_kernel_internal(horizontal_flip, torch.Tensor) @_register_kernel_internal(horizontal_flip, datapoints.Image) -def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor: +def horizontal_flip_image(image: torch.Tensor) -> torch.Tensor: return image.flip(-1) @_register_kernel_internal(horizontal_flip, PIL.Image.Image) -def horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image: +def _horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image: return _FP.hflip(image) @_register_kernel_internal(horizontal_flip, datapoints.Mask) def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor: - return horizontal_flip_image_tensor(mask) + return horizontal_flip_image(mask) def horizontal_flip_bounding_boxes( @@ -92,12 +92,12 @@ def _horizontal_flip_bounding_boxes_dispatch(inpt: datapoints.BoundingBoxes) -> @_register_kernel_internal(horizontal_flip, datapoints.Video) def horizontal_flip_video(video: torch.Tensor) -> torch.Tensor: - return horizontal_flip_image_tensor(video) + return horizontal_flip_image(video) def vertical_flip(inpt: torch.Tensor) -> torch.Tensor: if torch.jit.is_scripting(): - return vertical_flip_image_tensor(inpt) + return vertical_flip_image(inpt) _log_api_usage_once(vertical_flip) @@ -107,18 +107,18 @@ def vertical_flip(inpt: torch.Tensor) -> torch.Tensor: @_register_kernel_internal(vertical_flip, torch.Tensor) @_register_kernel_internal(vertical_flip, datapoints.Image) -def vertical_flip_image_tensor(image: torch.Tensor) -> torch.Tensor: +def vertical_flip_image(image: torch.Tensor) -> torch.Tensor: return image.flip(-2) @_register_kernel_internal(vertical_flip, PIL.Image.Image) -def vertical_flip_image_pil(image: PIL.Image) -> PIL.Image: +def _vertical_flip_image_pil(image: PIL.Image) -> PIL.Image: return _FP.vflip(image) @_register_kernel_internal(vertical_flip, datapoints.Mask) def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor: - return vertical_flip_image_tensor(mask) + return vertical_flip_image(mask) def vertical_flip_bounding_boxes( @@ -148,7 +148,7 @@ def _vertical_flip_bounding_boxes_dispatch(inpt: datapoints.BoundingBoxes) -> da @_register_kernel_internal(vertical_flip, datapoints.Video) def vertical_flip_video(video: torch.Tensor) -> torch.Tensor: - return vertical_flip_image_tensor(video) + return vertical_flip_image(video) # We changed the names to align them with the transforms, i.e. `RandomHorizontalFlip`. 
Still, `hflip` and `vflip` are @@ -178,7 +178,7 @@ def resize( antialias: Optional[Union[str, bool]] = "warn", ) -> torch.Tensor: if torch.jit.is_scripting(): - return resize_image_tensor(inpt, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias) + return resize_image(inpt, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias) _log_api_usage_once(resize) @@ -188,7 +188,7 @@ def resize( @_register_kernel_internal(resize, torch.Tensor) @_register_kernel_internal(resize, datapoints.Image) -def resize_image_tensor( +def resize_image( image: torch.Tensor, size: List[int], interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, @@ -267,7 +267,7 @@ def resize_image_tensor( return image.reshape(shape[:-3] + (num_channels, new_height, new_width)) -def resize_image_pil( +def _resize_image_pil( image: PIL.Image.Image, size: Union[Sequence[int], int], interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, @@ -289,7 +289,7 @@ def resize_image_pil( @_register_kernel_internal(resize, PIL.Image.Image) -def _resize_image_pil_dispatch( +def __resize_image_pil_dispatch( image: PIL.Image.Image, size: Union[Sequence[int], int], interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, @@ -298,7 +298,7 @@ def _resize_image_pil_dispatch( ) -> PIL.Image.Image: if antialias is False: warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.") - return resize_image_pil(image, size=size, interpolation=interpolation, max_size=max_size) + return _resize_image_pil(image, size=size, interpolation=interpolation, max_size=max_size) def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = None) -> torch.Tensor: @@ -308,7 +308,7 @@ def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = N else: needs_squeeze = False - output = resize_image_tensor(mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size) + output = resize_image(mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size) if needs_squeeze: output = output.squeeze(0) @@ -360,7 +360,7 @@ def resize_video( max_size: Optional[int] = None, antialias: Optional[Union[str, bool]] = "warn", ) -> torch.Tensor: - return resize_image_tensor(video, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias) + return resize_image(video, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias) def affine( @@ -374,7 +374,7 @@ def affine( center: Optional[List[float]] = None, ) -> torch.Tensor: if torch.jit.is_scripting(): - return affine_image_tensor( + return affine_image( inpt, angle=angle, translate=translate, @@ -648,7 +648,7 @@ def _affine_grid( @_register_kernel_internal(affine, torch.Tensor) @_register_kernel_internal(affine, datapoints.Image) -def affine_image_tensor( +def affine_image( image: torch.Tensor, angle: Union[int, float], translate: List[float], @@ -700,7 +700,7 @@ def affine_image_tensor( @_register_kernel_internal(affine, PIL.Image.Image) -def affine_image_pil( +def _affine_image_pil( image: PIL.Image.Image, angle: Union[int, float], translate: List[float], @@ -717,7 +717,7 @@ def affine_image_pil( # it is visually better to estimate the center without 0.5 offset # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine if center is None: - height, width = get_size_image_pil(image) + height, width = _get_size_image_pil(image) center = [width * 
0.5, height * 0.5] matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear) @@ -875,7 +875,7 @@ def affine_mask( else: needs_squeeze = False - output = affine_image_tensor( + output = affine_image( mask, angle=angle, translate=translate, @@ -926,7 +926,7 @@ def affine_video( fill: _FillTypeJIT = None, center: Optional[List[float]] = None, ) -> torch.Tensor: - return affine_image_tensor( + return affine_image( video, angle=angle, translate=translate, @@ -947,9 +947,7 @@ def rotate( fill: _FillTypeJIT = None, ) -> torch.Tensor: if torch.jit.is_scripting(): - return rotate_image_tensor( - inpt, angle=angle, interpolation=interpolation, expand=expand, fill=fill, center=center - ) + return rotate_image(inpt, angle=angle, interpolation=interpolation, expand=expand, fill=fill, center=center) _log_api_usage_once(rotate) @@ -959,7 +957,7 @@ def rotate( @_register_kernel_internal(rotate, torch.Tensor) @_register_kernel_internal(rotate, datapoints.Image) -def rotate_image_tensor( +def rotate_image( image: torch.Tensor, angle: float, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, @@ -1004,7 +1002,7 @@ def rotate_image_tensor( @_register_kernel_internal(rotate, PIL.Image.Image) -def rotate_image_pil( +def _rotate_image_pil( image: PIL.Image.Image, angle: float, interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST, @@ -1074,7 +1072,7 @@ def rotate_mask( else: needs_squeeze = False - output = rotate_image_tensor( + output = rotate_image( mask, angle=angle, expand=expand, @@ -1111,7 +1109,7 @@ def rotate_video( center: Optional[List[float]] = None, fill: _FillTypeJIT = None, ) -> torch.Tensor: - return rotate_image_tensor(video, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) + return rotate_image(video, angle, interpolation=interpolation, expand=expand, fill=fill, center=center) def pad( @@ -1121,7 +1119,7 @@ def pad( padding_mode: str = "constant", ) -> torch.Tensor: if torch.jit.is_scripting(): - return pad_image_tensor(inpt, padding=padding, fill=fill, padding_mode=padding_mode) + return pad_image(inpt, padding=padding, fill=fill, padding_mode=padding_mode) _log_api_usage_once(pad) @@ -1155,7 +1153,7 @@ def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]: @_register_kernel_internal(pad, torch.Tensor) @_register_kernel_internal(pad, datapoints.Image) -def pad_image_tensor( +def pad_image( image: torch.Tensor, padding: List[int], fill: Optional[Union[int, float, List[float]]] = None, @@ -1253,7 +1251,7 @@ def _pad_with_vector_fill( return output -pad_image_pil = _register_kernel_internal(pad, PIL.Image.Image)(_FP.pad) +_pad_image_pil = _register_kernel_internal(pad, PIL.Image.Image)(_FP.pad) @_register_kernel_internal(pad, datapoints.Mask) @@ -1275,7 +1273,7 @@ def pad_mask( else: needs_squeeze = False - output = pad_image_tensor(mask, padding=padding, fill=fill, padding_mode=padding_mode) + output = pad_image(mask, padding=padding, fill=fill, padding_mode=padding_mode) if needs_squeeze: output = output.squeeze(0) @@ -1331,12 +1329,12 @@ def pad_video( fill: Optional[Union[int, float, List[float]]] = None, padding_mode: str = "constant", ) -> torch.Tensor: - return pad_image_tensor(video, padding, fill=fill, padding_mode=padding_mode) + return pad_image(video, padding, fill=fill, padding_mode=padding_mode) def crop(inpt: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor: if torch.jit.is_scripting(): - return crop_image_tensor(inpt, top=top, left=left, height=height, 
width=width) + return crop_image(inpt, top=top, left=left, height=height, width=width) _log_api_usage_once(crop) @@ -1346,7 +1344,7 @@ def crop(inpt: torch.Tensor, top: int, left: int, height: int, width: int) -> to @_register_kernel_internal(crop, torch.Tensor) @_register_kernel_internal(crop, datapoints.Image) -def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor: +def crop_image(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor: h, w = image.shape[-2:] right = left + width @@ -1364,8 +1362,8 @@ def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, wid return image[..., top:bottom, left:right] -crop_image_pil = _FP.crop -_register_kernel_internal(crop, PIL.Image.Image)(crop_image_pil) +_crop_image_pil = _FP.crop +_register_kernel_internal(crop, PIL.Image.Image)(_crop_image_pil) def crop_bounding_boxes( @@ -1407,7 +1405,7 @@ def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) else: needs_squeeze = False - output = crop_image_tensor(mask, top, left, height, width) + output = crop_image(mask, top, left, height, width) if needs_squeeze: output = output.squeeze(0) @@ -1417,7 +1415,7 @@ def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) @_register_kernel_internal(crop, datapoints.Video) def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor: - return crop_image_tensor(video, top, left, height, width) + return crop_image(video, top, left, height, width) def perspective( @@ -1429,7 +1427,7 @@ def perspective( coefficients: Optional[List[float]] = None, ) -> torch.Tensor: if torch.jit.is_scripting(): - return perspective_image_tensor( + return perspective_image( inpt, startpoints=startpoints, endpoints=endpoints, @@ -1500,7 +1498,7 @@ def _perspective_coefficients( @_register_kernel_internal(perspective, torch.Tensor) @_register_kernel_internal(perspective, datapoints.Image) -def perspective_image_tensor( +def perspective_image( image: torch.Tensor, startpoints: Optional[List[List[int]]], endpoints: Optional[List[List[int]]], @@ -1547,7 +1545,7 @@ def perspective_image_tensor( @_register_kernel_internal(perspective, PIL.Image.Image) -def perspective_image_pil( +def _perspective_image_pil( image: PIL.Image.Image, startpoints: Optional[List[List[int]]], endpoints: Optional[List[List[int]]], @@ -1686,7 +1684,7 @@ def perspective_mask( else: needs_squeeze = False - output = perspective_image_tensor( + output = perspective_image( mask, startpoints, endpoints, interpolation=InterpolationMode.NEAREST, fill=fill, coefficients=coefficients ) @@ -1724,7 +1722,7 @@ def perspective_video( fill: _FillTypeJIT = None, coefficients: Optional[List[float]] = None, ) -> torch.Tensor: - return perspective_image_tensor( + return perspective_image( video, startpoints, endpoints, interpolation=interpolation, fill=fill, coefficients=coefficients ) @@ -1736,7 +1734,7 @@ def elastic( fill: _FillTypeJIT = None, ) -> torch.Tensor: if torch.jit.is_scripting(): - return elastic_image_tensor(inpt, displacement=displacement, interpolation=interpolation, fill=fill) + return elastic_image(inpt, displacement=displacement, interpolation=interpolation, fill=fill) _log_api_usage_once(elastic) @@ -1749,7 +1747,7 @@ def elastic( @_register_kernel_internal(elastic, torch.Tensor) @_register_kernel_internal(elastic, datapoints.Image) -def elastic_image_tensor( +def elastic_image( image: torch.Tensor, displacement: torch.Tensor, 
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, @@ -1809,14 +1807,14 @@ def elastic_image_tensor( @_register_kernel_internal(elastic, PIL.Image.Image) -def elastic_image_pil( +def _elastic_image_pil( image: PIL.Image.Image, displacement: torch.Tensor, interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, fill: _FillTypeJIT = None, ) -> PIL.Image.Image: t_img = pil_to_tensor(image) - output = elastic_image_tensor(t_img, displacement, interpolation=interpolation, fill=fill) + output = elastic_image(t_img, displacement, interpolation=interpolation, fill=fill) return to_pil_image(output, mode=image.mode) @@ -1910,7 +1908,7 @@ def elastic_mask( else: needs_squeeze = False - output = elastic_image_tensor(mask, displacement=displacement, interpolation=InterpolationMode.NEAREST, fill=fill) + output = elastic_image(mask, displacement=displacement, interpolation=InterpolationMode.NEAREST, fill=fill) if needs_squeeze: output = output.squeeze(0) @@ -1933,12 +1931,12 @@ def elastic_video( interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, fill: _FillTypeJIT = None, ) -> torch.Tensor: - return elastic_image_tensor(video, displacement, interpolation=interpolation, fill=fill) + return elastic_image(video, displacement, interpolation=interpolation, fill=fill) def center_crop(inpt: torch.Tensor, output_size: List[int]) -> torch.Tensor: if torch.jit.is_scripting(): - return center_crop_image_tensor(inpt, output_size=output_size) + return center_crop_image(inpt, output_size=output_size) _log_api_usage_once(center_crop) @@ -1975,7 +1973,7 @@ def _center_crop_compute_crop_anchor( @_register_kernel_internal(center_crop, torch.Tensor) @_register_kernel_internal(center_crop, datapoints.Image) -def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor: +def center_crop_image(image: torch.Tensor, output_size: List[int]) -> torch.Tensor: crop_height, crop_width = _center_crop_parse_output_size(output_size) shape = image.shape if image.numel() == 0: @@ -1995,20 +1993,20 @@ def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> tor @_register_kernel_internal(center_crop, PIL.Image.Image) -def center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image: +def _center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image: crop_height, crop_width = _center_crop_parse_output_size(output_size) - image_height, image_width = get_size_image_pil(image) + image_height, image_width = _get_size_image_pil(image) if crop_height > image_height or crop_width > image_width: padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width) - image = pad_image_pil(image, padding_ltrb, fill=0) + image = _pad_image_pil(image, padding_ltrb, fill=0) - image_height, image_width = get_size_image_pil(image) + image_height, image_width = _get_size_image_pil(image) if crop_width == image_width and crop_height == image_height: return image crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width) - return crop_image_pil(image, crop_top, crop_left, crop_height, crop_width) + return _crop_image_pil(image, crop_top, crop_left, crop_height, crop_width) def center_crop_bounding_boxes( @@ -2042,7 +2040,7 @@ def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor else: needs_squeeze = False - output = center_crop_image_tensor(image=mask, output_size=output_size) + output = 
center_crop_image(image=mask, output_size=output_size) if needs_squeeze: output = output.squeeze(0) @@ -2052,7 +2050,7 @@ def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor @_register_kernel_internal(center_crop, datapoints.Video) def center_crop_video(video: torch.Tensor, output_size: List[int]) -> torch.Tensor: - return center_crop_image_tensor(video, output_size) + return center_crop_image(video, output_size) def resized_crop( @@ -2066,7 +2064,7 @@ def resized_crop( antialias: Optional[Union[str, bool]] = "warn", ) -> torch.Tensor: if torch.jit.is_scripting(): - return resized_crop_image_tensor( + return resized_crop_image( inpt, top=top, left=left, @@ -2094,7 +2092,7 @@ def resized_crop( @_register_kernel_internal(resized_crop, torch.Tensor) @_register_kernel_internal(resized_crop, datapoints.Image) -def resized_crop_image_tensor( +def resized_crop_image( image: torch.Tensor, top: int, left: int, @@ -2104,11 +2102,11 @@ def resized_crop_image_tensor( interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, antialias: Optional[Union[str, bool]] = "warn", ) -> torch.Tensor: - image = crop_image_tensor(image, top, left, height, width) - return resize_image_tensor(image, size, interpolation=interpolation, antialias=antialias) + image = crop_image(image, top, left, height, width) + return resize_image(image, size, interpolation=interpolation, antialias=antialias) -def resized_crop_image_pil( +def _resized_crop_image_pil( image: PIL.Image.Image, top: int, left: int, @@ -2117,12 +2115,12 @@ def resized_crop_image_pil( size: List[int], interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, ) -> PIL.Image.Image: - image = crop_image_pil(image, top, left, height, width) - return resize_image_pil(image, size, interpolation=interpolation) + image = _crop_image_pil(image, top, left, height, width) + return _resize_image_pil(image, size, interpolation=interpolation) @_register_kernel_internal(resized_crop, PIL.Image.Image) -def resized_crop_image_pil_dispatch( +def _resized_crop_image_pil_dispatch( image: PIL.Image.Image, top: int, left: int, @@ -2134,7 +2132,7 @@ def resized_crop_image_pil_dispatch( ) -> PIL.Image.Image: if antialias is False: warnings.warn("Anti-alias option is always applied for PIL Image input. 
Argument antialias is ignored.") - return resized_crop_image_pil( + return _resized_crop_image_pil( image, top=top, left=left, @@ -2201,7 +2199,7 @@ def resized_crop_video( interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR, antialias: Optional[Union[str, bool]] = "warn", ) -> torch.Tensor: - return resized_crop_image_tensor( + return resized_crop_image( video, top, left, height, width, antialias=antialias, size=size, interpolation=interpolation ) @@ -2210,7 +2208,7 @@ def five_crop( inpt: torch.Tensor, size: List[int] ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: if torch.jit.is_scripting(): - return five_crop_image_tensor(inpt, size=size) + return five_crop_image(inpt, size=size) _log_api_usage_once(five_crop) @@ -2234,7 +2232,7 @@ def _parse_five_crop_size(size: List[int]) -> List[int]: @_register_five_ten_crop_kernel_internal(five_crop, torch.Tensor) @_register_five_ten_crop_kernel_internal(five_crop, datapoints.Image) -def five_crop_image_tensor( +def five_crop_image( image: torch.Tensor, size: List[int] ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: crop_height, crop_width = _parse_five_crop_size(size) @@ -2243,30 +2241,30 @@ def five_crop_image_tensor( if crop_width > image_width or crop_height > image_height: raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}") - tl = crop_image_tensor(image, 0, 0, crop_height, crop_width) - tr = crop_image_tensor(image, 0, image_width - crop_width, crop_height, crop_width) - bl = crop_image_tensor(image, image_height - crop_height, 0, crop_height, crop_width) - br = crop_image_tensor(image, image_height - crop_height, image_width - crop_width, crop_height, crop_width) - center = center_crop_image_tensor(image, [crop_height, crop_width]) + tl = crop_image(image, 0, 0, crop_height, crop_width) + tr = crop_image(image, 0, image_width - crop_width, crop_height, crop_width) + bl = crop_image(image, image_height - crop_height, 0, crop_height, crop_width) + br = crop_image(image, image_height - crop_height, image_width - crop_width, crop_height, crop_width) + center = center_crop_image(image, [crop_height, crop_width]) return tl, tr, bl, br, center @_register_five_ten_crop_kernel_internal(five_crop, PIL.Image.Image) -def five_crop_image_pil( +def _five_crop_image_pil( image: PIL.Image.Image, size: List[int] ) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]: crop_height, crop_width = _parse_five_crop_size(size) - image_height, image_width = get_size_image_pil(image) + image_height, image_width = _get_size_image_pil(image) if crop_width > image_width or crop_height > image_height: raise ValueError(f"Requested crop size {size} is bigger than input size {(image_height, image_width)}") - tl = crop_image_pil(image, 0, 0, crop_height, crop_width) - tr = crop_image_pil(image, 0, image_width - crop_width, crop_height, crop_width) - bl = crop_image_pil(image, image_height - crop_height, 0, crop_height, crop_width) - br = crop_image_pil(image, image_height - crop_height, image_width - crop_width, crop_height, crop_width) - center = center_crop_image_pil(image, [crop_height, crop_width]) + tl = _crop_image_pil(image, 0, 0, crop_height, crop_width) + tr = _crop_image_pil(image, 0, image_width - crop_width, crop_height, crop_width) + bl = _crop_image_pil(image, image_height - crop_height, 0, crop_height, crop_width) + br = _crop_image_pil(image, image_height - crop_height, 
image_width - crop_width, crop_height, crop_width) + center = _center_crop_image_pil(image, [crop_height, crop_width]) return tl, tr, bl, br, center @@ -2275,7 +2273,7 @@ def five_crop_image_pil( def five_crop_video( video: torch.Tensor, size: List[int] ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - return five_crop_image_tensor(video, size) + return five_crop_image(video, size) def ten_crop( @@ -2293,7 +2291,7 @@ def ten_crop( torch.Tensor, ]: if torch.jit.is_scripting(): - return ten_crop_image_tensor(inpt, size=size, vertical_flip=vertical_flip) + return ten_crop_image(inpt, size=size, vertical_flip=vertical_flip) _log_api_usage_once(ten_crop) @@ -2303,7 +2301,7 @@ def ten_crop( @_register_five_ten_crop_kernel_internal(ten_crop, torch.Tensor) @_register_five_ten_crop_kernel_internal(ten_crop, datapoints.Image) -def ten_crop_image_tensor( +def ten_crop_image( image: torch.Tensor, size: List[int], vertical_flip: bool = False ) -> Tuple[ torch.Tensor, @@ -2317,20 +2315,20 @@ def ten_crop_image_tensor( torch.Tensor, torch.Tensor, ]: - non_flipped = five_crop_image_tensor(image, size) + non_flipped = five_crop_image(image, size) if vertical_flip: - image = vertical_flip_image_tensor(image) + image = vertical_flip_image(image) else: - image = horizontal_flip_image_tensor(image) + image = horizontal_flip_image(image) - flipped = five_crop_image_tensor(image, size) + flipped = five_crop_image(image, size) return non_flipped + flipped @_register_five_ten_crop_kernel_internal(ten_crop, PIL.Image.Image) -def ten_crop_image_pil( +def _ten_crop_image_pil( image: PIL.Image.Image, size: List[int], vertical_flip: bool = False ) -> Tuple[ PIL.Image.Image, @@ -2344,14 +2342,14 @@ def ten_crop_image_pil( PIL.Image.Image, PIL.Image.Image, ]: - non_flipped = five_crop_image_pil(image, size) + non_flipped = _five_crop_image_pil(image, size) if vertical_flip: - image = vertical_flip_image_pil(image) + image = _vertical_flip_image_pil(image) else: - image = horizontal_flip_image_pil(image) + image = _horizontal_flip_image_pil(image) - flipped = five_crop_image_pil(image, size) + flipped = _five_crop_image_pil(image, size) return non_flipped + flipped @@ -2371,4 +2369,4 @@ def ten_crop_video( torch.Tensor, torch.Tensor, ]: - return ten_crop_image_tensor(video, size, vertical_flip=vertical_flip) + return ten_crop_image(video, size, vertical_flip=vertical_flip) diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index 89b19d9e887..f2675728ce3 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -13,7 +13,7 @@ def get_dimensions(inpt: torch.Tensor) -> List[int]: if torch.jit.is_scripting(): - return get_dimensions_image_tensor(inpt) + return get_dimensions_image(inpt) _log_api_usage_once(get_dimensions) @@ -23,7 +23,7 @@ def get_dimensions(inpt: torch.Tensor) -> List[int]: @_register_kernel_internal(get_dimensions, torch.Tensor) @_register_kernel_internal(get_dimensions, datapoints.Image, datapoint_wrapper=False) -def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]: +def get_dimensions_image(image: torch.Tensor) -> List[int]: chw = list(image.shape[-3:]) ndims = len(chw) if ndims == 3: @@ -35,17 +35,17 @@ def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]: raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}") -get_dimensions_image_pil = _register_kernel_internal(get_dimensions, 
PIL.Image.Image)(_FP.get_dimensions)
+_get_dimensions_image_pil = _register_kernel_internal(get_dimensions, PIL.Image.Image)(_FP.get_dimensions)
 
 
 @_register_kernel_internal(get_dimensions, datapoints.Video, datapoint_wrapper=False)
 def get_dimensions_video(video: torch.Tensor) -> List[int]:
-    return get_dimensions_image_tensor(video)
+    return get_dimensions_image(video)
 
 
 def get_num_channels(inpt: torch.Tensor) -> int:
     if torch.jit.is_scripting():
-        return get_num_channels_image_tensor(inpt)
+        return get_num_channels_image(inpt)
 
     _log_api_usage_once(get_num_channels)
 
@@ -55,7 +55,7 @@ def get_num_channels(inpt: torch.Tensor) -> int:
 
 @_register_kernel_internal(get_num_channels, torch.Tensor)
 @_register_kernel_internal(get_num_channels, datapoints.Image, datapoint_wrapper=False)
-def get_num_channels_image_tensor(image: torch.Tensor) -> int:
+def get_num_channels_image(image: torch.Tensor) -> int:
     chw = image.shape[-3:]
     ndims = len(chw)
     if ndims == 3:
@@ -66,12 +66,12 @@ def get_num_channels_image_tensor(image: torch.Tensor) -> int:
         raise TypeError(f"Input tensor should have at least two dimensions, but got {ndims}")
 
 
-get_num_channels_image_pil = _register_kernel_internal(get_num_channels, PIL.Image.Image)(_FP.get_image_num_channels)
+_get_num_channels_image_pil = _register_kernel_internal(get_num_channels, PIL.Image.Image)(_FP.get_image_num_channels)
 
 
 @_register_kernel_internal(get_num_channels, datapoints.Video, datapoint_wrapper=False)
 def get_num_channels_video(video: torch.Tensor) -> int:
-    return get_num_channels_image_tensor(video)
+    return get_num_channels_image(video)
 
 
 # We changed the names to ensure it can be used not only for images but also videos. Thus, we just alias it without
@@ -81,7 +81,7 @@ def get_num_channels_video(video: torch.Tensor) -> int:
 
 def get_size(inpt: torch.Tensor) -> List[int]:
     if torch.jit.is_scripting():
-        return get_size_image_tensor(inpt)
+        return get_size_image(inpt)
 
     _log_api_usage_once(get_size)
 
@@ -91,7 +91,7 @@ def get_size(inpt: torch.Tensor) -> List[int]:
 
 @_register_kernel_internal(get_size, torch.Tensor)
 @_register_kernel_internal(get_size, datapoints.Image, datapoint_wrapper=False)
-def get_size_image_tensor(image: torch.Tensor) -> List[int]:
+def get_size_image(image: torch.Tensor) -> List[int]:
     hw = list(image.shape[-2:])
     ndims = len(hw)
     if ndims == 2:
@@ -101,19 +101,19 @@ def get_size_image_tensor(image: torch.Tensor) -> List[int]:
 
 
 @_register_kernel_internal(get_size, PIL.Image.Image)
-def get_size_image_pil(image: PIL.Image.Image) -> List[int]:
+def _get_size_image_pil(image: PIL.Image.Image) -> List[int]:
     width, height = _FP.get_image_size(image)
     return [height, width]
 
 
 @_register_kernel_internal(get_size, datapoints.Video, datapoint_wrapper=False)
 def get_size_video(video: torch.Tensor) -> List[int]:
-    return get_size_image_tensor(video)
+    return get_size_image(video)
 
 
 @_register_kernel_internal(get_size, datapoints.Mask, datapoint_wrapper=False)
 def get_size_mask(mask: torch.Tensor) -> List[int]:
-    return get_size_image_tensor(mask)
+    return get_size_image(mask)
 
 
 @_register_kernel_internal(get_size, datapoints.BoundingBoxes, datapoint_wrapper=False)
diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py
index 658b61cedb0..331817bb028 100644
--- a/torchvision/transforms/v2/functional/_misc.py
+++ b/torchvision/transforms/v2/functional/_misc.py
@@ -21,7 +21,7 @@ def normalize(
     inplace: bool = False,
 ) -> torch.Tensor:
     if torch.jit.is_scripting():
-        return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
+        return normalize_image(inpt, mean=mean, std=std, inplace=inplace)
 
     _log_api_usage_once(normalize)
 
@@ -31,9 +31,7 @@ def normalize(
 
 @_register_kernel_internal(normalize, torch.Tensor)
 @_register_kernel_internal(normalize, datapoints.Image)
-def normalize_image_tensor(
-    image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False
-) -> torch.Tensor:
+def normalize_image(image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
     if not image.is_floating_point():
         raise TypeError(f"Input tensor should be a float tensor. Got {image.dtype}.")
 
@@ -68,12 +66,12 @@ def normalize_image_tensor(
 
 @_register_kernel_internal(normalize, datapoints.Video)
 def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
-    return normalize_image_tensor(video, mean, std, inplace=inplace)
+    return normalize_image(video, mean, std, inplace=inplace)
 
 
 def gaussian_blur(inpt: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> torch.Tensor:
     if torch.jit.is_scripting():
-        return gaussian_blur_image_tensor(inpt, kernel_size=kernel_size, sigma=sigma)
+        return gaussian_blur_image(inpt, kernel_size=kernel_size, sigma=sigma)
 
     _log_api_usage_once(gaussian_blur)
 
@@ -99,7 +97,7 @@ def _get_gaussian_kernel2d(
 
 @_register_kernel_internal(gaussian_blur, torch.Tensor)
 @_register_kernel_internal(gaussian_blur, datapoints.Image)
-def gaussian_blur_image_tensor(
+def gaussian_blur_image(
     image: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
 ) -> torch.Tensor:
     # TODO: consider deprecating integers from sigma on the future
@@ -164,11 +162,11 @@ def gaussian_blur_image_tensor(
 
 
 @_register_kernel_internal(gaussian_blur, PIL.Image.Image)
-def gaussian_blur_image_pil(
+def _gaussian_blur_image_pil(
     image: PIL.Image.Image, kernel_size: List[int], sigma: Optional[List[float]] = None
 ) -> PIL.Image.Image:
     t_img = pil_to_tensor(image)
-    output = gaussian_blur_image_tensor(t_img, kernel_size=kernel_size, sigma=sigma)
+    output = gaussian_blur_image(t_img, kernel_size=kernel_size, sigma=sigma)
     return to_pil_image(output, mode=image.mode)
 
 
@@ -176,12 +174,12 @@ def gaussian_blur_image_pil(
 def gaussian_blur_video(
     video: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
 ) -> torch.Tensor:
-    return gaussian_blur_image_tensor(video, kernel_size, sigma)
+    return gaussian_blur_image(video, kernel_size, sigma)
 
 
 def to_dtype(inpt: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
     if torch.jit.is_scripting():
-        return to_dtype_image_tensor(inpt, dtype=dtype, scale=scale)
+        return to_dtype_image(inpt, dtype=dtype, scale=scale)
 
     _log_api_usage_once(to_dtype)
 
@@ -206,7 +204,7 @@ def _num_value_bits(dtype: torch.dtype) -> int:
 
 @_register_kernel_internal(to_dtype, torch.Tensor)
 @_register_kernel_internal(to_dtype, datapoints.Image)
-def to_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
+def to_dtype_image(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
     if image.dtype == dtype:
         return image
 
@@ -260,12 +258,12 @@ def to_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
 
 # We encourage users to use to_dtype() instead but we keep this for BC
 def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
-    return to_dtype_image_tensor(image, dtype=dtype, scale=True)
+    return to_dtype_image(image, dtype=dtype, scale=True)
 
 
 @_register_kernel_internal(to_dtype, datapoints.Video)
 def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale: bool = False) -> torch.Tensor:
-    return to_dtype_image_tensor(video, dtype, scale=scale)
+    return to_dtype_image(video, dtype, scale=scale)
 
 
 @_register_kernel_internal(to_dtype, datapoints.BoundingBoxes, datapoint_wrapper=False)
diff --git a/torchvision/transforms/v2/functional/_type_conversion.py b/torchvision/transforms/v2/functional/_type_conversion.py
index 67572cf4a72..1f908353db3 100644
--- a/torchvision/transforms/v2/functional/_type_conversion.py
+++ b/torchvision/transforms/v2/functional/_type_conversion.py
@@ -8,7 +8,7 @@
 
 
 @torch.jit.unused
-def to_image_tensor(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> datapoints.Image:
+def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> datapoints.Image:
     if isinstance(inpt, np.ndarray):
         output = torch.from_numpy(inpt).permute((2, 0, 1)).contiguous()
     elif isinstance(inpt, PIL.Image.Image):
@@ -20,9 +20,5 @@ def to_image_tensor(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> datapoints.Image:
     return datapoints.Image(output)
 
 
-to_image_pil = _F.to_pil_image
+to_pil_image = _F.to_pil_image
 pil_to_tensor = _F.pil_to_tensor
-
-# We changed the names to align them with the new naming scheme. Still, `to_pil_image` is
-# prevalent and well understood. Thus, we just alias it without deprecating the old name.
-to_pil_image = to_image_pil

From bda807d5a49cca8382a13a835ece2813e9c320ae Mon Sep 17 00:00:00 2001
From: Omkar Salpekar
Date: Wed, 16 Aug 2023 15:01:45 -0400
Subject: [PATCH 2/3] Pre-Script Update for Aarch64 (#7834)

---
 packaging/pre_build_script.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/packaging/pre_build_script.sh b/packaging/pre_build_script.sh
index 43f60e51064..7d38f2cb4e9 100644
--- a/packaging/pre_build_script.sh
+++ b/packaging/pre_build_script.sh
@@ -11,7 +11,7 @@ if [[ "$(uname)" == Darwin ]]; then
     conda install -yq wget
 fi
 
-if [[ "$(uname)" == Darwin || "$OSTYPE" == "msys" ]]; then
+if [[ "$(uname)" == Darwin || "$OSTYPE" == "msys" || "$ARCH" == "aarch64" ]]; then
     # Install libpng from Anaconda (defaults)
     conda install libpng -yq
     conda install -yq ffmpeg=4.2 libjpeg-turbo -c pytorch

From 4cba51c5254407ec1e460a8cbb4ea06700e92637 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Wed, 16 Aug 2023 21:48:49 +0200
Subject: [PATCH 3/3] fix elastic tests (#7841)

---
 test/test_transforms_v2_refactored.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py
index 9d359e59559..339725327bd 100644
--- a/test/test_transforms_v2_refactored.py
+++ b/test/test_transforms_v2_refactored.py
@@ -2279,7 +2279,7 @@ def test_kernel_image_tensor(self, param, value, dtype, device):
         image = make_image_tensor(dtype=dtype, device=device)
 
         check_kernel(
-            F.elastic_image_tensor,
+            F.elastic_image,
             image,
             displacement=self._make_displacement(image),
             **{param: value},
@@ -2320,9 +2320,9 @@ def test_functional(self, make_input):
     @pytest.mark.parametrize(
         ("kernel", "input_type"),
         [
-            (F.elastic_image_tensor, torch.Tensor),
-            (F.elastic_image_pil, PIL.Image.Image),
-            (F.elastic_image_tensor, datapoints.Image),
+            (F.elastic_image, torch.Tensor),
+            (F._elastic_image_pil, PIL.Image.Image),
+            (F.elastic_image, datapoints.Image),
             (F.elastic_bounding_boxes, datapoints.BoundingBoxes),
             (F.elastic_mask, datapoints.Mask),
             (F.elastic_video, datapoints.Video),