From 464e409d21d41c4e67147d8ad47340e74fc8e1fd Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 17 Aug 2023 13:11:54 +0100 Subject: [PATCH 1/3] simple tensor -> pure tensor --- test/test_prototype_datasets_builtin.py | 2 +- test/test_transforms_v2.py | 4 ++-- torchvision/transforms/v2/_transform.py | 10 +++++----- torchvision/transforms/v2/functional/_meta.py | 6 +++--- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/test/test_prototype_datasets_builtin.py b/test/test_prototype_datasets_builtin.py index 4d19b67967f..d6e4b38892c 100644 --- a/test/test_prototype_datasets_builtin.py +++ b/test/test_prototype_datasets_builtin.py @@ -151,7 +151,7 @@ def test_no_unaccompanied_simple_tensors(self, dataset_mock, config): ): raise AssertionError( f"The values of key(s) " - f"{sequence_to_str(sorted(simple_tensors), separate_last='and ')} contained simple tensors, " + f"{sequence_to_str(sorted(simple_tensors), separate_last='and ')} contained pure tensors, " f"but didn't find any (encoded) image or video." ) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index ade3bdf0b51..1615d54e8aa 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -360,8 +360,8 @@ def test_random_resized_crop(self, transform, input): def test_simple_tensor_heuristic(flat_inputs): def split_on_simple_tensor(to_split): # This takes a sequence that is structurally aligned with `flat_inputs` and splits its items into three parts: - # 1. The first simple tensor. If none is present, this will be `None` - # 2. A list of the remaining simple tensors + # 1. The first pure tensor. If none is present, this will be `None` + # 2. A list of the remaining pure tensors # 3. A list of all other items simple_tensors = [] others = [] diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index d4ee8af556d..c1baad51978 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -55,20 +55,20 @@ def forward(self, *inputs: Any) -> Any: return tree_unflatten(flat_outputs, spec) def _needs_transform_list(self, flat_inputs: List[Any]) -> List[bool]: - # Below is a heuristic on how to deal with simple tensor inputs: + # Below is a heuristic on how to deal with pure tensor inputs: # 1. Simple tensors, i.e. tensors that are not a datapoint, are passed through if there is an explicit image # (`datapoints.Image` or `PIL.Image.Image`) or video (`datapoints.Video`) in the sample. - # 2. If there is no explicit image or video in the sample, only the first encountered simple tensor is + # 2. If there is no explicit image or video in the sample, only the first encountered pure tensor is # transformed as image, while the rest is passed through. The order is defined by the returned `flat_inputs` # of `tree_flatten`, which recurses depth-first through the input. # # This heuristic stems from two requirements: # 1. We need to keep BC for single input simple tensors and treat them as images. - # 2. We don't want to treat all simple tensors as images, because some datasets like `CelebA` or `Widerface` + # 1. We need to keep BC for single input pure tensors and treat them as images. + # 2. We don't want to treat all pure tensors as images, because some datasets like `CelebA` or `Widerface` # return supplemental numerical data as tensors that cannot be transformed as images. # # The heuristic should work well for most people in practice.
The only case where it doesn't is if someone - # tries to transform multiple simple tensors at the same time, expecting them all to be treated as images. + # tries to transform multiple pure tensors at the same time, expecting them all to be treated as images. # However, this case wasn't supported by transforms v1 either, so there is no BC concern. needs_transform_list = [] diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index f2675728ce3..b5b1ad651bf 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -203,7 +203,7 @@ def convert_format_bounding_boxes( new_format: Optional[BoundingBoxFormat] = None, inplace: bool = False, ) -> torch.Tensor: - # This being a kernel / functional hybrid, we need an option to pass `old_format` explicitly for simple tensor + # This being a kernel / functional hybrid, we need an option to pass `old_format` explicitly for pure tensor # inputs as well as extract it from `datapoints.BoundingBoxes` inputs. However, putting a default value on # `old_format` means we also need to put one on `new_format` to have syntactically correct Python. Here we mimic the # default error that would be thrown if `new_format` had no default value. @@ -215,7 +215,7 @@ def convert_format_bounding_boxes( if torch.jit.is_scripting() or is_simple_tensor(inpt): if old_format is None: - raise ValueError("For simple tensor inputs, `old_format` has to be passed.") + raise ValueError("For pure tensor inputs, `old_format` has to be passed.") return _convert_format_bounding_boxes(inpt, old_format=old_format, new_format=new_format, inplace=inplace) elif isinstance(inpt, datapoints.BoundingBoxes): if old_format is not None: @@ -259,7 +259,7 @@ def clamp_bounding_boxes( if torch.jit.is_scripting() or is_simple_tensor(inpt): if format is None or canvas_size is None: - raise ValueError("For simple tensor inputs, `format` and `canvas_size` has to be passed.") + raise ValueError("For pure tensor inputs, `format` and `canvas_size` has to be passed.") return _clamp_bounding_boxes(inpt, format=format, canvas_size=canvas_size) elif isinstance(inpt, datapoints.BoundingBoxes): if format is not None or canvas_size is not None: From f93c56503553765412cccc864cca99c3485bda2c Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 17 Aug 2023 13:12:13 +0100 Subject: [PATCH 2/3] simple_tensor -> pure_tensor --- test/test_prototype_datasets_builtin.py | 10 +++--- test/test_prototype_transforms.py | 6 ++-- test/test_transforms_v2.py | 34 +++++++++---------- test/test_transforms_v2_consistency.py | 2 +- test/test_transforms_v2_functional.py | 26 +++++++------- test/test_transforms_v2_utils.py | 6 ++-- test/transforms_v2_dispatcher_infos.py | 2 +- torchvision/prototype/transforms/_augment.py | 6 ++-- torchvision/prototype/transforms/_geometry.py | 4 +-- torchvision/prototype/transforms/_misc.py | 6 ++-- torchvision/transforms/v2/_augment.py | 6 ++-- torchvision/transforms/v2/_auto_augment.py | 4 +-- torchvision/transforms/v2/_geometry.py | 4 +-- torchvision/transforms/v2/_misc.py | 8 ++--- torchvision/transforms/v2/_transform.py | 10 +++--- torchvision/transforms/v2/_type_conversion.py | 6 ++-- .../transforms/v2/functional/__init__.py | 2 +- torchvision/transforms/v2/functional/_meta.py | 6 ++-- .../transforms/v2/functional/_utils.py | 2 +- torchvision/transforms/v2/utils.py | 6 ++-- 20 files changed, 78 insertions(+), 78 deletions(-) diff --git a/test/test_prototype_datasets_builtin.py
b/test/test_prototype_datasets_builtin.py index d6e4b38892c..e29dfb17fe1 100644 --- a/test/test_prototype_datasets_builtin.py +++ b/test/test_prototype_datasets_builtin.py @@ -25,7 +25,7 @@ from torchvision.prototype.datapoints import Label from torchvision.prototype.datasets.utils import EncodedImage from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE -from torchvision.transforms.v2.utils import is_simple_tensor +from torchvision.transforms.v2.utils import is_pure_tensor def assert_samples_equal(*args, msg=None, **kwargs): @@ -140,18 +140,18 @@ def make_msg_and_close(head): raise AssertionError(make_msg_and_close("The following streams were not closed after a full iteration:")) @parametrize_dataset_mocks(DATASET_MOCKS) - def test_no_unaccompanied_simple_tensors(self, dataset_mock, config): + def test_no_unaccompanied_pure_tensors(self, dataset_mock, config): dataset, _ = dataset_mock.load(config) sample = next_consume(iter(dataset)) - simple_tensors = {key for key, value in sample.items() if is_simple_tensor(value)} + pure_tensors = {key for key, value in sample.items() if is_pure_tensor(value)} - if simple_tensors and not any( + if pure_tensors and not any( isinstance(item, (datapoints.Image, datapoints.Video, EncodedImage)) for item in sample.values() ): raise AssertionError( f"The values of key(s) " - f"{sequence_to_str(sorted(simple_tensors), separate_last='and ')} contained pure tensors, " + f"{sequence_to_str(sorted(pure_tensors), separate_last='and ')} contained pure tensors, " f"but didn't find any (encoded) image or video." ) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 32a68e14017..bf45970df97 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -18,7 +18,7 @@ from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video from torchvision.prototype import datapoints, transforms from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_pil_image -from torchvision.transforms.v2.utils import check_type, is_simple_tensor +from torchvision.transforms.v2.utils import check_type, is_pure_tensor BATCH_EXTRA_DIMS = [extra_dims for extra_dims in DEFAULT_EXTRA_DIMS if extra_dims] @@ -296,7 +296,7 @@ def test_call(self, dims, inverse_dims): value_type = type(value) transformed_value = transformed_sample[key] - if check_type(value, (Image, is_simple_tensor, Video)): + if check_type(value, (Image, is_pure_tensor, Video)): if transform.dims.get(value_type) is not None: assert transformed_value.permute(inverse_dims[value_type]).equal(value) assert type(transformed_value) == torch.Tensor @@ -341,7 +341,7 @@ def test_call(self, dims): transformed_value = transformed_sample[key] transposed_dims = transform.dims.get(value_type) - if check_type(value, (Image, is_simple_tensor, Video)): + if check_type(value, (Image, is_pure_tensor, Video)): if transposed_dims is not None: assert transformed_value.transpose(*transposed_dims).equal(value) assert type(transformed_value) == torch.Tensor diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 1615d54e8aa..d7a6f21bbe7 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -29,7 +29,7 @@ from torchvision.ops.boxes import box_iou from torchvision.transforms.functional import to_pil_image from torchvision.transforms.v2 import functional as F -from torchvision.transforms.v2.utils import check_type, is_simple_tensor, query_chw +from
torchvision.transforms.v2.utils import check_type, is_pure_tensor, query_chw def make_vanilla_tensor_images(*args, **kwargs): @@ -71,7 +71,7 @@ def auto_augment_adapter(transform, input, device): if isinstance(value, (datapoints.BoundingBoxes, datapoints.Mask)): # AA transforms don't support bounding boxes or masks continue - elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor, PIL.Image.Image)): + elif check_type(value, (datapoints.Image, datapoints.Video, is_pure_tensor, PIL.Image.Image)): if image_or_video_found: # AA transforms only support a single image or video continue @@ -101,7 +101,7 @@ def normalize_adapter(transform, input, device): if isinstance(value, PIL.Image.Image): # normalize doesn't support PIL images continue - elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor)): + elif check_type(value, (datapoints.Image, datapoints.Video, is_pure_tensor)): # normalize doesn't support integer images value = F.to_dtype(value, torch.float32, scale=True) adapted_input[key] = value @@ -357,19 +357,19 @@ def test_random_resized_crop(self, transform, input): 3, ), ) -def test_simple_tensor_heuristic(flat_inputs): - def split_on_simple_tensor(to_split): +def test_pure_tensor_heuristic(flat_inputs): + def split_on_pure_tensor(to_split): # This takes a sequence that is structurally aligned with `flat_inputs` and splits its items into three parts: # 1. The first pure tensor. If none is present, this will be `None` # 2. A list of the remaining pure tensors # 3. A list of all other items - simple_tensors = [] + pure_tensors = [] others = [] # Splitting always happens on the original `flat_inputs` to avoid any erroneous type changes by the transform to # affect the splitting. for item, inpt in zip(to_split, flat_inputs): - (simple_tensors if is_simple_tensor(inpt) else others).append(item) - return simple_tensors[0] if simple_tensors else None, simple_tensors[1:], others + (pure_tensors if is_pure_tensor(inpt) else others).append(item) + return pure_tensors[0] if pure_tensors else None, pure_tensors[1:], others class CopyCloneTransform(transforms.Transform): def _transform(self, inpt, params): @@ -385,20 +385,20 @@ def was_applied(output, inpt): assert_equal(output, inpt) return True - first_simple_tensor_input, other_simple_tensor_inputs, other_inputs = split_on_simple_tensor(flat_inputs) + first_pure_tensor_input, other_pure_tensor_inputs, other_inputs = split_on_pure_tensor(flat_inputs) transform = CopyCloneTransform() transformed_sample = transform(flat_inputs) - first_simple_tensor_output, other_simple_tensor_outputs, other_outputs = split_on_simple_tensor(transformed_sample) + first_pure_tensor_output, other_pure_tensor_outputs, other_outputs = split_on_pure_tensor(transformed_sample) - if first_simple_tensor_input is not None: + if first_pure_tensor_input is not None: if other_inputs: - assert not transform.was_applied(first_simple_tensor_output, first_simple_tensor_input) + assert not transform.was_applied(first_pure_tensor_output, first_pure_tensor_input) else: - assert transform.was_applied(first_simple_tensor_output, first_simple_tensor_input) + assert transform.was_applied(first_pure_tensor_output, first_pure_tensor_input) - for output, inpt in zip(other_simple_tensor_outputs, other_simple_tensor_inputs): + for output, inpt in zip(other_pure_tensor_outputs, other_pure_tensor_inputs): assert not transform.was_applied(output, inpt) for input, output in zip(other_inputs, other_outputs): @@ -1004,7 +1004,7 @@ def
test_classif_preset(image_type, label_type, dataset_return_type, to_tensor): image = to_pil_image(image[0]) elif image_type is torch.Tensor: image = image.as_subclass(torch.Tensor) - assert is_simple_tensor(image) + assert is_pure_tensor(image) label = 1 if label_type is int else torch.tensor([1]) @@ -1125,7 +1125,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize): image = to_pil_image(image[0]) elif image_type is torch.Tensor: image = image.as_subclass(torch.Tensor) - assert is_simple_tensor(image) + assert is_pure_tensor(image) label = torch.randint(0, 10, size=(num_boxes,)) @@ -1146,7 +1146,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize): out = t(sample) if isinstance(to_tensor, transforms.ToTensor) and image_type is not datapoints.Image: - assert is_simple_tensor(out["image"]) + assert is_pure_tensor(out["image"]) else: assert isinstance(out["image"], datapoints.Image) assert isinstance(out["label"], type(sample["label"])) diff --git a/test/test_transforms_v2_consistency.py b/test/test_transforms_v2_consistency.py index 5855fbe447f..3196a5fd82c 100644 --- a/test/test_transforms_v2_consistency.py +++ b/test/test_transforms_v2_consistency.py @@ -602,7 +602,7 @@ def check_call_consistency( raise AssertionError( f"Transforming a tensor image with shape {image_repr} failed in the prototype transform with " f"the error above. This means there is a consistency bug either in `_get_params` or in the " - f"`is_simple_tensor` path in `_transform`." + f"`is_pure_tensor` path in `_transform`." ) from exc assert_close( diff --git a/test/test_transforms_v2_functional.py b/test/test_transforms_v2_functional.py index 14a1f82b2cf..29ef54d925a 100644 --- a/test/test_transforms_v2_functional.py +++ b/test/test_transforms_v2_functional.py @@ -24,7 +24,7 @@ from torchvision.transforms.v2 import functional as F from torchvision.transforms.v2.functional._geometry import _center_crop_compute_padding from torchvision.transforms.v2.functional._meta import clamp_bounding_boxes, convert_format_bounding_boxes -from torchvision.transforms.v2.utils import is_simple_tensor +from torchvision.transforms.v2.utils import is_pure_tensor from transforms_v2_dispatcher_infos import DISPATCHER_INFOS from transforms_v2_kernel_infos import KERNEL_INFOS @@ -168,7 +168,7 @@ def _unbatch(self, batch, *, data_dims): def test_batched_vs_single(self, test_id, info, args_kwargs, device): (batched_input, *other_args), kwargs = args_kwargs.load(device) - datapoint_type = datapoints.Image if is_simple_tensor(batched_input) else type(batched_input) + datapoint_type = datapoints.Image if is_pure_tensor(batched_input) else type(batched_input) # This dictionary contains the number of rightmost dimensions that contain the actual data. # Everything to the left is considered a batch dimension. data_dims = { @@ -333,9 +333,9 @@ def test_scripted_smoke(self, info, args_kwargs, device): dispatcher = script(info.dispatcher) (image_datapoint, *other_args), kwargs = args_kwargs.load(device) - image_simple_tensor = torch.Tensor(image_datapoint) + image_pure_tensor = torch.Tensor(image_datapoint) - dispatcher(image_simple_tensor, *other_args, **kwargs) + dispatcher(image_pure_tensor, *other_args, **kwargs) # TODO: We need this until the dispatchers below also have `DispatcherInfo`'s. If they do, `test_scripted_smoke` # replaces this test for them.
@@ -358,11 +358,11 @@ def test_scriptable(self, dispatcher): script(dispatcher) @image_sample_inputs - def test_simple_tensor_output_type(self, info, args_kwargs): + def test_pure_tensor_output_type(self, info, args_kwargs): (image_datapoint, *other_args), kwargs = args_kwargs.load() - image_simple_tensor = image_datapoint.as_subclass(torch.Tensor) + image_pure_tensor = image_datapoint.as_subclass(torch.Tensor) - output = info.dispatcher(image_simple_tensor, *other_args, **kwargs) + output = info.dispatcher(image_pure_tensor, *other_args, **kwargs) # We cannot use `isinstance` here since all datapoints are instances of `torch.Tensor` as well assert type(output) is torch.Tensor @@ -505,11 +505,11 @@ class TestClampBoundingBoxes: dict(canvas_size=(1, 1)), ], ) - def test_simple_tensor_insufficient_metadata(self, metadata): - simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor) + def test_pure_tensor_insufficient_metadata(self, metadata): + pure_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor) with pytest.raises(ValueError, match=re.escape("`format` and `canvas_size` has to be passed")): - F.clamp_bounding_boxes(simple_tensor, **metadata) + F.clamp_bounding_boxes(pure_tensor, **metadata) @pytest.mark.parametrize( "metadata", @@ -538,11 +538,11 @@ def test_missing_new_format(self, inpt, old_format): with pytest.raises(TypeError, match=re.escape("missing 1 required argument: 'new_format'")): F.convert_format_bounding_boxes(inpt, old_format) - def test_simple_tensor_insufficient_metadata(self): - simple_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor) + def test_pure_tensor_insufficient_metadata(self): + pure_tensor = next(make_bounding_boxes()).as_subclass(torch.Tensor) with pytest.raises(ValueError, match=re.escape("`old_format` has to be passed")): - F.convert_format_bounding_boxes(simple_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH) + F.convert_format_bounding_boxes(pure_tensor, new_format=datapoints.BoundingBoxFormat.CXCYWH) def test_datapoint_explicit_metadata(self): datapoint = next(make_bounding_boxes()) diff --git a/test/test_transforms_v2_utils.py b/test/test_transforms_v2_utils.py index 0cf7a77ac0d..0cfe0db7077 100644 --- a/test/test_transforms_v2_utils.py +++ b/test/test_transforms_v2_utils.py @@ -37,15 +37,15 @@ ((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, datapoints.Image),), True), ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False), ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True), - ((IMAGE,), (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_simple_tensor), True), + ((IMAGE,), (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_pure_tensor), True), ( (torch.Tensor(IMAGE),), - (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_simple_tensor), + (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_pure_tensor), True, ), ( (to_pil_image(IMAGE),), - (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_simple_tensor), + (datapoints.Image, PIL.Image.Image, torchvision.transforms.v2.utils.is_pure_tensor), True, ), ], diff --git a/test/transforms_v2_dispatcher_infos.py b/test/transforms_v2_dispatcher_infos.py index 8f212c850cb..903518627de 100644 --- a/test/transforms_v2_dispatcher_infos.py +++ b/test/transforms_v2_dispatcher_infos.py @@ -107,7 +107,7 @@ def xfail_jit_python_scalar_arg(name, *, reason=None): ("TestDispatchers", test_name), pytest.mark.skip(reason="Multi-crop dispatchers return a sequence of items
rather than a single one."), ) - for test_name in ["test_simple_tensor_output_type", "test_pil_output_type", "test_datapoint_output_type"] + for test_name in ["test_pure_tensor_output_type", "test_pil_output_type", "test_datapoint_output_type"] ] multi_crop_skips.append(skip_dispatch_datapoint) diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 81f726a2dbd..eaa181b6717 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -9,7 +9,7 @@ from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform from torchvision.transforms.v2.functional._geometry import _check_interpolation -from torchvision.transforms.v2.utils import is_simple_tensor +from torchvision.transforms.v2.utils import is_pure_tensor class SimpleCopyPaste(Transform): @@ -109,7 +109,7 @@ def _extract_image_targets( # with List[image], List[BoundingBoxes], List[Mask], List[Label] images, bboxes, masks, labels = [], [], [], [] for obj in flat_sample: - if isinstance(obj, datapoints.Image) or is_simple_tensor(obj): + if isinstance(obj, datapoints.Image) or is_pure_tensor(obj): images.append(obj) elif isinstance(obj, PIL.Image.Image): images.append(F.to_image(obj)) @@ -146,7 +146,7 @@ def _insert_outputs( elif isinstance(obj, PIL.Image.Image): flat_sample[i] = F.to_pil_image(output_images[c0]) c0 += 1 - elif is_simple_tensor(obj): + elif is_pure_tensor(obj): flat_sample[i] = output_images[c0] c0 += 1 elif isinstance(obj, datapoints.BoundingBoxes): diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 8d8e7eb42f0..1350b6d1bd1 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -7,7 +7,7 @@ from torchvision.prototype.datapoints import Label, OneHotLabel from torchvision.transforms.v2 import functional as F, Transform from torchvision.transforms.v2._utils import _FillType, _get_fill, _setup_fill_arg, _setup_size -from torchvision.transforms.v2.utils import get_bounding_boxes, has_any, is_simple_tensor, query_size +from torchvision.transforms.v2.utils import get_bounding_boxes, has_any, is_pure_tensor, query_size class FixedSizeCrop(Transform): @@ -32,7 +32,7 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None: flat_inputs, PIL.Image.Image, datapoints.Image, - is_simple_tensor, + is_pure_tensor, datapoints.Video, ): raise TypeError( diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index f1b859aac03..0dd495ab05b 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -8,7 +8,7 @@ from torchvision import datapoints from torchvision.transforms.v2 import Transform -from torchvision.transforms.v2.utils import is_simple_tensor +from torchvision.transforms.v2.utils import is_pure_tensor T = TypeVar("T") @@ -25,7 +25,7 @@ def _get_defaultdict(default: T) -> Dict[Any, T]: class PermuteDimensions(Transform): - _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video) + _transformed_types = (is_pure_tensor, datapoints.Image, datapoints.Video) def __init__(self, dims: Union[Sequence[int], Dict[Type, Optional[Sequence[int]]]]) -> None: super().__init__() @@ -47,7 +47,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> torch.Tensor: class TransposeDimensions(Transform): - _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video) +
_transformed_types = (is_pure_tensor, datapoints.Image, datapoints.Video) def __init__(self, dims: Union[Tuple[int, int], Dict[Type, Optional[Tuple[int, int]]]]) -> None: super().__init__() diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index f64ae564b54..51ca4c14591 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -12,7 +12,7 @@ from ._transform import _RandomApplyTransform, Transform from ._utils import _parse_labels_getter -from .utils import has_any, is_simple_tensor, query_chw, query_size +from .utils import has_any, is_pure_tensor, query_chw, query_size class RandomErasing(_RandomApplyTransform): @@ -243,7 +243,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if inpt is params["labels"]: return self._mixup_label(inpt, lam=lam) - elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt): + elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_pure_tensor(inpt): self._check_image_or_video(inpt, batch_size=params["batch_size"]) output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam)) @@ -310,7 +310,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if inpt is params["labels"]: return self._mixup_label(inpt, lam=params["lam_adjusted"]) - elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt): + elif isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_pure_tensor(inpt): self._check_image_or_video(inpt, batch_size=params["batch_size"]) x1, y1, x2, y2 = params["box"] diff --git a/torchvision/transforms/v2/_auto_augment.py b/torchvision/transforms/v2/_auto_augment.py index 687a2396e67..097e90fc4ab 100644 --- a/torchvision/transforms/v2/_auto_augment.py +++ b/torchvision/transforms/v2/_auto_augment.py @@ -13,7 +13,7 @@ from torchvision.transforms.v2.functional._utils import _FillType, _FillTypeJIT from ._utils import _get_fill, _setup_fill_arg -from .utils import check_type, is_simple_tensor +from .utils import check_type, is_pure_tensor ImageOrVideo = Union[torch.Tensor, PIL.Image.Image, datapoints.Image, datapoints.Video] @@ -50,7 +50,7 @@ def _flatten_and_extract_image_or_video( ( datapoints.Image, PIL.Image.Image, - is_simple_tensor, + is_pure_tensor, datapoints.Video, ), ): diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index f441a0b747b..0be62ae8a12 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -24,7 +24,7 @@ _setup_float_or_seq, _setup_size, ) -from .utils import get_bounding_boxes, has_all, has_any, is_simple_tensor, query_size +from .utils import get_bounding_boxes, has_all, has_any, is_pure_tensor, query_size class RandomHorizontalFlip(_RandomApplyTransform): @@ -1149,7 +1149,7 @@ def __init__( def _check_inputs(self, flat_inputs: List[Any]) -> None: if not ( has_all(flat_inputs, datapoints.BoundingBoxes) - and has_any(flat_inputs, PIL.Image.Image, datapoints.Image, is_simple_tensor) + and has_any(flat_inputs, PIL.Image.Image, datapoints.Image, is_pure_tensor) ): raise TypeError( f"{type(self).__name__}() requires input sample to contain tensor or PIL images " diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index ef9ac5fd0c7..405fbc6c43a 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -10,7 +10,7 @@ from torchvision.transforms.v2
import functional as F, Transform from ._utils import _parse_labels_getter, _setup_float_or_seq, _setup_size -from .utils import get_bounding_boxes, has_any, is_simple_tensor +from .utils import get_bounding_boxes, has_any, is_pure_tensor # TODO: do we want/need to expose this? @@ -75,7 +75,7 @@ class LinearTransformation(Transform): _v1_transform_cls = _transforms.LinearTransformation - _transformed_types = (is_simple_tensor, datapoints.Image, datapoints.Video) + _transformed_types = (is_pure_tensor, datapoints.Image, datapoints.Video) def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tensor): super().__init__() @@ -264,7 +264,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if isinstance(self.dtype, torch.dtype): # For consistency / BC with ConvertImageDtype, we only care about images or videos when dtype # is a simple torch.dtype - if not is_simple_tensor(inpt) and not isinstance(inpt, (datapoints.Image, datapoints.Video)): + if not is_pure_tensor(inpt) and not isinstance(inpt, (datapoints.Image, datapoints.Video)): return inpt dtype: Optional[torch.dtype] = self.dtype @@ -281,7 +281,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: 'e.g. dtype={datapoints.Mask: torch.int64, "others": None} to pass-through the rest of the inputs.' ) - supports_scaling = is_simple_tensor(inpt) or isinstance(inpt, (datapoints.Image, datapoints.Video)) + supports_scaling = is_pure_tensor(inpt) or isinstance(inpt, (datapoints.Image, datapoints.Video)) if dtype is None: if self.scale and supports_scaling: warnings.warn( diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index c1baad51978..171630c9fd4 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -8,7 +8,7 @@ from torch import nn from torch.utils._pytree import tree_flatten, tree_unflatten from torchvision import datapoints -from torchvision.transforms.v2.utils import check_type, has_any, is_simple_tensor +from torchvision.transforms.v2.utils import check_type, has_any, is_pure_tensor from torchvision.utils import _log_api_usage_once from .functional._utils import _get_kernel @@ -72,15 +72,15 @@ def _needs_transform_list(self, flat_inputs: List[Any]) -> List[bool]: # However, this case wasn't supported by transforms v1 either, so there is no BC concern.
needs_transform_list = [] - transform_simple_tensor = not has_any(flat_inputs, datapoints.Image, datapoints.Video, PIL.Image.Image) + transform_pure_tensor = not has_any(flat_inputs, datapoints.Image, datapoints.Video, PIL.Image.Image) for inpt in flat_inputs: needs_transform = True if not check_type(inpt, self._transformed_types): needs_transform = False - elif is_simple_tensor(inpt): - if transform_simple_tensor: - transform_simple_tensor = False + elif is_pure_tensor(inpt): + if transform_pure_tensor: + transform_pure_tensor = False else: needs_transform = False needs_transform_list.append(needs_transform) diff --git a/torchvision/transforms/v2/_type_conversion.py b/torchvision/transforms/v2/_type_conversion.py index aec82f46f14..386eeec52d7 100644 --- a/torchvision/transforms/v2/_type_conversion.py +++ b/torchvision/transforms/v2/_type_conversion.py @@ -7,7 +7,7 @@ from torchvision import datapoints from torchvision.transforms.v2 import functional as F, Transform -from torchvision.transforms.v2.utils import is_simple_tensor +from torchvision.transforms.v2.utils import is_pure_tensor class PILToTensor(Transform): @@ -35,7 +35,7 @@ class ToImage(Transform): This transform does not support torchscript. """ - _transformed_types = (is_simple_tensor, PIL.Image.Image, np.ndarray) + _transformed_types = (is_pure_tensor, PIL.Image.Image, np.ndarray) def _transform( self, inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray], params: Dict[str, Any] ) @@ -65,7 +65,7 @@ class ToPILImage(Transform): .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes """ - _transformed_types = (is_simple_tensor, datapoints.Image, np.ndarray) + _transformed_types = (is_pure_tensor, datapoints.Image, np.ndarray) def __init__(self, mode: Optional[str] = None) -> None: super().__init__() diff --git a/torchvision/transforms/v2/functional/__init__.py b/torchvision/transforms/v2/functional/__init__.py index 3510962ff3a..5d3a18a9151 100644 --- a/torchvision/transforms/v2/functional/__init__.py +++ b/torchvision/transforms/v2/functional/__init__.py @@ -1,6 +1,6 @@ from torchvision.transforms import InterpolationMode # usort: skip -from ._utils import is_simple_tensor, register_kernel # usort: skip +from ._utils import is_pure_tensor, register_kernel # usort: skip from ._meta import ( clamp_bounding_boxes, diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index b5b1ad651bf..fc4dfb60d60 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -8,7 +8,7 @@ from torchvision.utils import _log_api_usage_once -from ._utils import _get_kernel, _register_kernel_internal, is_simple_tensor +from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor def get_dimensions(inpt: torch.Tensor) -> List[int]: @@ -213,7 +213,7 @@ def convert_format_bounding_boxes( if not torch.jit.is_scripting(): _log_api_usage_once(convert_format_bounding_boxes) - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting() or is_pure_tensor(inpt): if old_format is None: raise ValueError("For pure tensor inputs, `old_format` has to be passed.") return _convert_format_bounding_boxes(inpt, old_format=old_format, new_format=new_format, inplace=inplace) @@ -256,7 +256,7 @@ def clamp_bounding_boxes( if not torch.jit.is_scripting(): _log_api_usage_once(clamp_bounding_boxes) - if torch.jit.is_scripting() or is_simple_tensor(inpt): + if torch.jit.is_scripting() or
is_pure_tensor(inpt): if format is None or canvas_size is None: raise ValueError("For pure tensor inputs, `format` and `canvas_size` has to be passed.") diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 7fc48929917..28319e64c8f 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -8,7 +8,7 @@ _FillTypeJIT = Optional[List[float]] -def is_simple_tensor(inpt: Any) -> bool: +def is_pure_tensor(inpt: Any) -> bool: return isinstance(inpt, torch.Tensor) and not isinstance(inpt, datapoints.Datapoint) diff --git a/torchvision/transforms/v2/utils.py b/torchvision/transforms/v2/utils.py index 1d9219fb4f5..1e4ff2d05aa 100644 --- a/torchvision/transforms/v2/utils.py +++ b/torchvision/transforms/v2/utils.py @@ -6,7 +6,7 @@ from torchvision import datapoints from torchvision._utils import sequence_to_str -from torchvision.transforms.v2.functional import get_dimensions, get_size, is_simple_tensor +from torchvision.transforms.v2.functional import get_dimensions, get_size, is_pure_tensor def get_bounding_boxes(flat_inputs: List[Any]) -> datapoints.BoundingBoxes: @@ -21,7 +21,7 @@ def query_chw(flat_inputs: List[Any]) -> Tuple[int, int, int]: chws = { tuple(get_dimensions(inpt)) for inpt in flat_inputs - if check_type(inpt, (is_simple_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video)) + if check_type(inpt, (is_pure_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video)) } if not chws: raise TypeError("No image or video was found in the sample") @@ -38,7 +38,7 @@ def query_size(flat_inputs: List[Any]) -> Tuple[int, int]: if check_type( inpt, ( - is_simple_tensor, + is_pure_tensor, datapoints.Image, PIL.Image.Image, datapoints.Video, From 94539e091e04ea372e2afa6abfbfebedf9ce8b14 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 17 Aug 2023 13:13:12 +0100 Subject: [PATCH 3/3] last one? --- torchvision/transforms/v2/_transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index 171630c9fd4..e9af4b426fa 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -56,7 +56,7 @@ def forward(self, *inputs: Any) -> Any: def _needs_transform_list(self, flat_inputs: List[Any]) -> List[bool]: # Below is a heuristic on how to deal with pure tensor inputs: - # 1. Simple tensors, i.e. tensors that are not a datapoint, are passed through if there is an explicit image + # 1. Pure tensors, i.e. tensors that are not a datapoint, are passed through if there is an explicit image # (`datapoints.Image` or `PIL.Image.Image`) or video (`datapoints.Video`) in the sample. # 2. If there is no explicit image or video in the sample, only the first encountered pure tensor is # transformed as image, while the rest is passed through. The order is defined by the returned `flat_inputs`