From c1592f963ab69baa740eee9e6c0d167446cd92c0 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 16 Aug 2023 09:36:58 +0100 Subject: [PATCH] Remove wrap_like class method and add datapoints.wrap() function (#7832) Co-authored-by: Philip Meier --- docs/source/datapoints.rst | 1 + gallery/plot_custom_datapoints.py | 6 ++-- gallery/plot_datapoints.py | 21 +++++------- test/test_datapoints.py | 4 +-- test/test_transforms_v2_refactored.py | 6 ++-- torchvision/datapoints/__init__.py | 22 ++++++++++++ torchvision/datapoints/_bounding_box.py | 26 -------------- torchvision/datapoints/_datapoint.py | 4 --- torchvision/prototype/datapoints/_label.py | 7 ---- torchvision/prototype/transforms/_augment.py | 16 ++++----- torchvision/prototype/transforms/_geometry.py | 6 ++-- torchvision/transforms/v2/_augment.py | 4 +-- torchvision/transforms/v2/_auto_augment.py | 2 +- torchvision/transforms/v2/_geometry.py | 2 +- torchvision/transforms/v2/_misc.py | 4 +-- .../transforms/v2/functional/_geometry.py | 34 +++++++++---------- torchvision/transforms/v2/functional/_meta.py | 4 +-- .../transforms/v2/functional/_utils.py | 6 ++-- 18 files changed, 78 insertions(+), 97 deletions(-) diff --git a/docs/source/datapoints.rst b/docs/source/datapoints.rst index 7351c8685a2..0599545f7f3 100644 --- a/docs/source/datapoints.rst +++ b/docs/source/datapoints.rst @@ -19,3 +19,4 @@ see e.g. :ref:`sphx_glr_auto_examples_plot_transforms_v2_e2e.py`. Mask Datapoint set_return_type + wrap diff --git a/gallery/plot_custom_datapoints.py b/gallery/plot_custom_datapoints.py index b0a48d75d6a..a8db878119a 100644 --- a/gallery/plot_custom_datapoints.py +++ b/gallery/plot_custom_datapoints.py @@ -53,11 +53,11 @@ class MyDatapoint(datapoints.Datapoint): def hflip_my_datapoint(my_dp, *args, **kwargs): print("Flipping!") out = my_dp.flip(-1) - return MyDatapoint.wrap_like(my_dp, out) + return datapoints.wrap(out, like=my_dp) # %% -# To understand why ``wrap_like`` is used, see +# To understand why :func:`~torchvision.datapoints.wrap` is used, see # :ref:`datapoint_unwrapping_behaviour`. Ignore the ``*args, **kwargs`` for now, # we will explain it below in :ref:`param_forwarding`. # @@ -107,7 +107,7 @@ def hflip_my_datapoint(my_dp, *args, **kwargs): def hflip_my_datapoint(my_dp): # noqa print("Flipping!") out = my_dp.flip(-1) - return MyDatapoint.wrap_like(my_dp, out) + return datapoints.wrap(out, like=my_dp) # %% diff --git a/gallery/plot_datapoints.py b/gallery/plot_datapoints.py index 5bbf6c200af..eecefe9551c 100644 --- a/gallery/plot_datapoints.py +++ b/gallery/plot_datapoints.py @@ -107,26 +107,23 @@ print(bboxes) # %% -# Using the ``wrap_like()`` class method -# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# Using ``datapoints.wrap()`` +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^ # -# You can also use the ``wrap_like()`` class method to wrap a tensor object +# You can also use the :func:`~torchvision.datapoints.wrap` function to wrap a tensor object # into a datapoint. This is useful when you already have an object of the # desired type, which typically happens when writing transforms: you just want -# to wrap the output like the input. This API is inspired by utils like -# :func:`torch.zeros_like`: +# to wrap the output like the input. 
new_bboxes = torch.tensor([0, 20, 30, 40]) -new_bboxes = datapoints.BoundingBoxes.wrap_like(bboxes, new_bboxes) +new_bboxes = datapoints.wrap(new_bboxes, like=bboxes) assert isinstance(new_bboxes, datapoints.BoundingBoxes) assert new_bboxes.canvas_size == bboxes.canvas_size # %% # The metadata of ``new_bboxes`` is the same as ``bboxes``, but you could pass -# it as a parameter to override it. Check the -# :meth:`~torchvision.datapoints.BoundingBoxes.wrap_like` documentation for -# more details. +# it as a parameter to override it. # # Do I have to wrap the output of the datasets myself? # ---------------------------------------------------- @@ -230,11 +227,11 @@ def get_transform(train): # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ # # You can re-wrap a pure tensor into a datapoint by just calling the datapoint -# constructor, or by using the ``.wrap_like()`` class method (see more details -# above in :ref:`datapoint_creation`): +# constructor, or by using the :func:`~torchvision.datapoints.wrap` function +# (see more details above in :ref:`datapoint_creation`): new_bboxes = bboxes + 3 -new_bboxes = datapoints.BoundingBoxes.wrap_like(bboxes, new_bboxes) +new_bboxes = datapoints.wrap(new_bboxes, like=bboxes) assert isinstance(new_bboxes, datapoints.BoundingBoxes) # %% diff --git a/test/test_datapoints.py b/test/test_datapoints.py index 1042587e396..4da2eb39383 100644 --- a/test/test_datapoints.py +++ b/test/test_datapoints.py @@ -213,13 +213,13 @@ def test_inplace_op_no_wrapping(make_input, return_type): @pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video]) -def test_wrap_like(make_input): +def test_wrap(make_input): dp = make_input() # any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here output = dp * 2 - dp_new = type(dp).wrap_like(dp, output) + dp_new = datapoints.wrap(output, like=dp) assert type(dp_new) is type(dp) assert dp_new.data_ptr() == output.data_ptr() diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index fa1ed05b84b..414104a554b 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -570,7 +570,7 @@ def _reference_resize_bounding_boxes(self, bounding_boxes, *, size, max_size=Non canvas_size=(new_height, new_width), affine_matrix=affine_matrix, ) - return datapoints.BoundingBoxes.wrap_like(bounding_boxes, expected_bboxes, canvas_size=(new_height, new_width)) + return datapoints.wrap(expected_bboxes, like=bounding_boxes, canvas_size=(new_height, new_width)) @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat)) @pytest.mark.parametrize("size", OUTPUT_SIZES) @@ -815,7 +815,7 @@ def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes): affine_matrix=affine_matrix, ) - return datapoints.BoundingBoxes.wrap_like(bounding_boxes, expected_bboxes) + return datapoints.wrap(expected_bboxes, like=bounding_boxes) @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat)) @pytest.mark.parametrize( @@ -1278,7 +1278,7 @@ def _reference_vertical_flip_bounding_boxes(self, bounding_boxes): affine_matrix=affine_matrix, ) - return datapoints.BoundingBoxes.wrap_like(bounding_boxes, expected_bboxes) + return datapoints.wrap(expected_bboxes, like=bounding_boxes) @pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat)) @pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)]) diff --git a/torchvision/datapoints/__init__.py 
b/torchvision/datapoints/__init__.py
index 7e1295c1197..f99e25b622c 100644
--- a/torchvision/datapoints/__init__.py
+++ b/torchvision/datapoints/__init__.py
@@ -12,3 +12,25 @@
     import warnings
 
     warnings.warn(_BETA_TRANSFORMS_WARNING)
+
+
+def wrap(wrappee, *, like, **kwargs):
+    """Convert a :class:`torch.Tensor` (``wrappee``) into the same :class:`~torchvision.datapoints.Datapoint` subclass as ``like``.
+
+    If ``like`` is a :class:`~torchvision.datapoints.BoundingBoxes`, the ``format`` and ``canvas_size`` of
+    ``like`` are assigned to ``wrappee``, unless they are passed as ``kwargs``.
+
+    Args:
+        wrappee (Tensor): The tensor to convert.
+        like (Datapoint): The reference. ``wrappee`` will be converted into the same subclass as ``like``.
+        kwargs: Can contain "format" and "canvas_size" if ``like`` is a :class:`~torchvision.datapoints.BoundingBoxes`.
+            Ignored otherwise.
+    """
+    if isinstance(like, BoundingBoxes):
+        return BoundingBoxes._wrap(
+            wrappee,
+            format=kwargs.get("format", like.format),
+            canvas_size=kwargs.get("canvas_size", like.canvas_size),
+        )
+    else:
+        return wrappee.as_subclass(type(like))
diff --git a/torchvision/datapoints/_bounding_box.py b/torchvision/datapoints/_bounding_box.py
index d6f0747df17..ebed0628250 100644
--- a/torchvision/datapoints/_bounding_box.py
+++ b/torchvision/datapoints/_bounding_box.py
@@ -75,32 +75,6 @@ def __new__(
         tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
         return cls._wrap(tensor, format=format, canvas_size=canvas_size)
 
-    @classmethod
-    def wrap_like(
-        cls,
-        other: BoundingBoxes,
-        tensor: torch.Tensor,
-        *,
-        format: Optional[Union[BoundingBoxFormat, str]] = None,
-        canvas_size: Optional[Tuple[int, int]] = None,
-    ) -> BoundingBoxes:
-        """Wrap a :class:`torch.Tensor` as :class:`BoundingBoxes` from a reference.
-
-        Args:
-            other (BoundingBoxes): Reference bounding box.
-            tensor (Tensor): Tensor to be wrapped as :class:`BoundingBoxes`
-            format (BoundingBoxFormat, str, optional): Format of the bounding box. If omitted, it is taken from the
-                reference.
-            canvas_size (two-tuple of ints, optional): Height and width of the corresponding image or video. If
-                omitted, it is taken from the reference.
- - """ - return cls._wrap( - tensor, - format=format if format is not None else other.format, - canvas_size=canvas_size if canvas_size is not None else other.canvas_size, - ) - @classmethod def _wrap_output( cls, diff --git a/torchvision/datapoints/_datapoint.py b/torchvision/datapoints/_datapoint.py index 613a1fb8b25..59b017b4417 100644 --- a/torchvision/datapoints/_datapoint.py +++ b/torchvision/datapoints/_datapoint.py @@ -31,10 +31,6 @@ def _to_tensor( requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad) - @classmethod - def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D: - return tensor.as_subclass(cls) - @classmethod def _wrap_output( cls, diff --git a/torchvision/prototype/datapoints/_label.py b/torchvision/prototype/datapoints/_label.py index 7ed2f7522b0..10ac1bf8295 100644 --- a/torchvision/prototype/datapoints/_label.py +++ b/torchvision/prototype/datapoints/_label.py @@ -32,13 +32,6 @@ def __new__( tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) return cls._wrap(tensor, categories=categories) - @classmethod - def wrap_like(cls: Type[L], other: L, tensor: torch.Tensor, *, categories: Optional[Sequence[str]] = None) -> L: - return cls._wrap( - tensor, - categories=categories if categories is not None else other.categories, - ) - @classmethod def from_category( cls: Type[L], diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 53f3f801303..f2c6e89dd3a 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -36,11 +36,9 @@ def _copy_paste( antialias: Optional[bool], ) -> Tuple[torch.Tensor, Dict[str, Any]]: - paste_masks = paste_target["masks"].wrap_like(paste_target["masks"], paste_target["masks"][random_selection]) - paste_boxes = paste_target["boxes"].wrap_like(paste_target["boxes"], paste_target["boxes"][random_selection]) - paste_labels = paste_target["labels"].wrap_like( - paste_target["labels"], paste_target["labels"][random_selection] - ) + paste_masks = datapoints.wrap(paste_target["masks"][random_selection], like=paste_target["masks"]) + paste_boxes = datapoints.wrap(paste_target["boxes"][random_selection], like=paste_target["boxes"]) + paste_labels = datapoints.wrap(paste_target["labels"][random_selection], like=paste_target["labels"]) masks = target["masks"] @@ -143,7 +141,7 @@ def _insert_outputs( c0, c1, c2, c3 = 0, 0, 0, 0 for i, obj in enumerate(flat_sample): if isinstance(obj, datapoints.Image): - flat_sample[i] = datapoints.Image.wrap_like(obj, output_images[c0]) + flat_sample[i] = datapoints.wrap(output_images[c0], like=obj) c0 += 1 elif isinstance(obj, PIL.Image.Image): flat_sample[i] = F.to_image_pil(output_images[c0]) @@ -152,13 +150,13 @@ def _insert_outputs( flat_sample[i] = output_images[c0] c0 += 1 elif isinstance(obj, datapoints.BoundingBoxes): - flat_sample[i] = datapoints.BoundingBoxes.wrap_like(obj, output_targets[c1]["boxes"]) + flat_sample[i] = datapoints.wrap(output_targets[c1]["boxes"], like=obj) c1 += 1 elif isinstance(obj, datapoints.Mask): - flat_sample[i] = datapoints.Mask.wrap_like(obj, output_targets[c2]["masks"]) + flat_sample[i] = datapoints.wrap(output_targets[c2]["masks"], like=obj) c2 += 1 elif isinstance(obj, (proto_datapoints.Label, proto_datapoints.OneHotLabel)): - flat_sample[i] = obj.wrap_like(obj, output_targets[c3]["labels"]) # type: ignore[arg-type] + 
flat_sample[i] = datapoints.wrap(output_targets[c3]["labels"], like=obj) c3 += 1 def forward(self, *inputs: Any) -> Any: diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index fe2e8df47eb..8d8e7eb42f0 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -112,11 +112,11 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if params["is_valid"] is not None: if isinstance(inpt, (Label, OneHotLabel, datapoints.Mask)): - inpt = inpt.wrap_like(inpt, inpt[params["is_valid"]]) # type: ignore[arg-type] + inpt = datapoints.wrap(inpt[params["is_valid"]], like=inpt) elif isinstance(inpt, datapoints.BoundingBoxes): - inpt = datapoints.BoundingBoxes.wrap_like( - inpt, + inpt = datapoints.wrap( F.clamp_bounding_boxes(inpt[params["is_valid"]], format=inpt.format, canvas_size=inpt.canvas_size), + like=inpt, ) if params["needs_pad"]: diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index 844e0321e0c..f64ae564b54 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -249,7 +249,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam)) if isinstance(inpt, (datapoints.Image, datapoints.Video)): - output = type(inpt).wrap_like(inpt, output) # type: ignore[arg-type] + output = datapoints.wrap(output, like=inpt) return output else: @@ -319,7 +319,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2] if isinstance(inpt, (datapoints.Image, datapoints.Video)): - output = inpt.wrap_like(inpt, output) # type: ignore[arg-type] + output = datapoints.wrap(output, like=inpt) return output else: diff --git a/torchvision/transforms/v2/_auto_augment.py b/torchvision/transforms/v2/_auto_augment.py index 26eb3abbcf9..8494b64b994 100644 --- a/torchvision/transforms/v2/_auto_augment.py +++ b/torchvision/transforms/v2/_auto_augment.py @@ -620,7 +620,7 @@ def forward(self, *inputs: Any) -> Any: mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype) if isinstance(orig_image_or_video, (datapoints.Image, datapoints.Video)): - mix = orig_image_or_video.wrap_like(orig_image_or_video, mix) # type: ignore[arg-type] + mix = datapoints.wrap(mix, like=orig_image_or_video) elif isinstance(orig_image_or_video, PIL.Image.Image): mix = F.to_image_pil(mix) diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index b28fad6eabc..f441a0b747b 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -338,7 +338,7 @@ class FiveCrop(Transform): ... images_or_videos, labels = sample ... batch_size = len(images_or_videos) ... image_or_video = images_or_videos[0] - ... images_or_videos = image_or_video.wrap_like(image_or_video, torch.stack(images_or_videos)) + ... images_or_videos = datapoints.wrap(torch.stack(images_or_videos), like=image_or_video) ... labels = torch.full((batch_size,), label, device=images_or_videos.device) ... return images_or_videos, labels ... 
diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index 1550b523820..ef9ac5fd0c7 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -131,7 +131,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: output = output.reshape(shape) if isinstance(inpt, (datapoints.Image, datapoints.Video)): - output = type(inpt).wrap_like(inpt, output) # type: ignore[arg-type] + output = datapoints.wrap(output, like=inpt) return output @@ -423,4 +423,4 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if is_label: return output - return type(inpt).wrap_like(inpt, output) + return datapoints.wrap(output, like=inpt) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index f8f3b1da0b3..0872d71dd8e 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -87,7 +87,7 @@ def _horizontal_flip_bounding_boxes_dispatch(inpt: datapoints.BoundingBoxes) -> output = horizontal_flip_bounding_boxes( inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size ) - return datapoints.BoundingBoxes.wrap_like(inpt, output) + return datapoints.wrap(output, like=inpt) @_register_kernel_internal(horizontal_flip, datapoints.Video) @@ -143,7 +143,7 @@ def _vertical_flip_bounding_boxes_dispatch(inpt: datapoints.BoundingBoxes) -> da output = vertical_flip_bounding_boxes( inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size ) - return datapoints.BoundingBoxes.wrap_like(inpt, output) + return datapoints.wrap(output, like=inpt) @_register_kernel_internal(vertical_flip, datapoints.Video) @@ -321,7 +321,7 @@ def _resize_mask_dispatch( inpt: datapoints.Mask, size: List[int], max_size: Optional[int] = None, **kwargs: Any ) -> datapoints.Mask: output = resize_mask(inpt.as_subclass(torch.Tensor), size, max_size=max_size) - return datapoints.Mask.wrap_like(inpt, output) + return datapoints.wrap(output, like=inpt) def resize_bounding_boxes( @@ -349,7 +349,7 @@ def _resize_bounding_boxes_dispatch( output, canvas_size = resize_bounding_boxes( inpt.as_subclass(torch.Tensor), inpt.canvas_size, size, max_size=max_size ) - return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size) + return datapoints.wrap(output, like=inpt, canvas_size=canvas_size) @_register_kernel_internal(resize, datapoints.Video) @@ -857,7 +857,7 @@ def _affine_bounding_boxes_dispatch( shear=shear, center=center, ) - return datapoints.BoundingBoxes.wrap_like(inpt, output) + return datapoints.wrap(output, like=inpt) def affine_mask( @@ -912,7 +912,7 @@ def _affine_mask_dispatch( fill=fill, center=center, ) - return datapoints.Mask.wrap_like(inpt, output) + return datapoints.wrap(output, like=inpt) @_register_kernel_internal(affine, datapoints.Video) @@ -1058,7 +1058,7 @@ def _rotate_bounding_boxes_dispatch( expand=expand, center=center, ) - return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size) + return datapoints.wrap(output, like=inpt, canvas_size=canvas_size) def rotate_mask( @@ -1099,7 +1099,7 @@ def _rotate_mask_dispatch( **kwargs, ) -> datapoints.Mask: output = rotate_mask(inpt.as_subclass(torch.Tensor), angle=angle, expand=expand, fill=fill, center=center) - return datapoints.Mask.wrap_like(inpt, output) + return datapoints.wrap(output, like=inpt) @_register_kernel_internal(rotate, datapoints.Video) @@ -1321,7 +1321,7 @@ def 
_pad_bounding_boxes_dispatch( padding=padding, padding_mode=padding_mode, ) - return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size) + return datapoints.wrap(output, like=inpt, canvas_size=canvas_size) @_register_kernel_internal(pad, datapoints.Video) @@ -1396,7 +1396,7 @@ def _crop_bounding_boxes_dispatch( output, canvas_size = crop_bounding_boxes( inpt.as_subclass(torch.Tensor), format=inpt.format, top=top, left=left, height=height, width=width ) - return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size) + return datapoints.wrap(output, like=inpt, canvas_size=canvas_size) @_register_kernel_internal(crop, datapoints.Mask) @@ -1670,7 +1670,7 @@ def _perspective_bounding_boxes_dispatch( endpoints=endpoints, coefficients=coefficients, ) - return datapoints.BoundingBoxes.wrap_like(inpt, output) + return datapoints.wrap(output, like=inpt) def perspective_mask( @@ -1712,7 +1712,7 @@ def _perspective_mask_dispatch( fill=fill, coefficients=coefficients, ) - return datapoints.Mask.wrap_like(inpt, output) + return datapoints.wrap(output, like=inpt) @_register_kernel_internal(perspective, datapoints.Video) @@ -1887,7 +1887,7 @@ def _elastic_bounding_boxes_dispatch( output = elastic_bounding_boxes( inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, displacement=displacement ) - return datapoints.BoundingBoxes.wrap_like(inpt, output) + return datapoints.wrap(output, like=inpt) def elastic_mask( @@ -1914,7 +1914,7 @@ def _elastic_mask_dispatch( inpt: datapoints.Mask, displacement: torch.Tensor, fill: _FillTypeJIT = None, **kwargs ) -> datapoints.Mask: output = elastic_mask(inpt.as_subclass(torch.Tensor), displacement=displacement, fill=fill) - return datapoints.Mask.wrap_like(inpt, output) + return datapoints.wrap(output, like=inpt) @_register_kernel_internal(elastic, datapoints.Video) @@ -2022,7 +2022,7 @@ def _center_crop_bounding_boxes_dispatch( output, canvas_size = center_crop_bounding_boxes( inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size, output_size=output_size ) - return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size) + return datapoints.wrap(output, like=inpt, canvas_size=canvas_size) @_register_kernel_internal(center_crop, datapoints.Mask) @@ -2156,7 +2156,7 @@ def _resized_crop_bounding_boxes_dispatch( output, canvas_size = resized_crop_bounding_boxes( inpt.as_subclass(torch.Tensor), format=inpt.format, top=top, left=left, height=height, width=width, size=size ) - return datapoints.BoundingBoxes.wrap_like(inpt, output, canvas_size=canvas_size) + return datapoints.wrap(output, like=inpt, canvas_size=canvas_size) def resized_crop_mask( @@ -2178,7 +2178,7 @@ def _resized_crop_mask_dispatch( output = resized_crop_mask( inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width, size=size ) - return datapoints.Mask.wrap_like(inpt, output) + return datapoints.wrap(output, like=inpt) @_register_kernel_internal(resized_crop, datapoints.Video) diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index aed1133020f..89b19d9e887 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -223,7 +223,7 @@ def convert_format_bounding_boxes( output = _convert_format_bounding_boxes( inpt.as_subclass(torch.Tensor), old_format=inpt.format, new_format=new_format, inplace=inplace ) - return datapoints.BoundingBoxes.wrap_like(inpt, output, 
format=new_format) + return datapoints.wrap(output, like=inpt, format=new_format) else: raise TypeError( f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead." @@ -265,7 +265,7 @@ def clamp_bounding_boxes( if format is not None or canvas_size is not None: raise ValueError("For bounding box datapoint inputs, `format` and `canvas_size` must not be passed.") output = _clamp_bounding_boxes(inpt.as_subclass(torch.Tensor), format=inpt.format, canvas_size=inpt.canvas_size) - return datapoints.BoundingBoxes.wrap_like(inpt, output) + return datapoints.wrap(output, like=inpt) else: raise TypeError( f"Input can either be a plain tensor or a bounding box datapoint, but got {type(inpt)} instead." diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 1f5c6f5eea0..0ea8e5658ed 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -25,11 +25,11 @@ def wrapper(inpt, *args, **kwargs): # regardless of whether we override __torch_function__ in our base class # or not. # Also, even if we didn't call `as_subclass` here, we would still need - # this wrapper to call wrap_like(), because the Datapoint type would be + # this wrapper to call wrap(), because the Datapoint type would be # lost after the first operation due to our own __torch_function__ # logic. output = kernel(inpt.as_subclass(torch.Tensor), *args, **kwargs) - return type(inpt).wrap_like(inpt, output) + return datapoints.wrap(output, like=inpt) return wrapper @@ -137,7 +137,7 @@ def wrap(kernel): def wrapper(inpt, *args, **kwargs): output = kernel(inpt, *args, **kwargs) container_type = type(output) - return container_type(type(inpt).wrap_like(inpt, o) for o in output) + return container_type(datapoints.wrap(o, like=inpt) for o in output) return wrapper
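
A minimal usage sketch of the ``datapoints.wrap()`` function added above (illustrative only; it assumes a torchvision build that already includes this change and the beta ``torchvision.datapoints`` API):

    import torch
    from torchvision import datapoints

    # A datapoint that carries metadata (format and canvas_size).
    bboxes = datapoints.BoundingBoxes(
        [[10, 10, 50, 50]],
        format=datapoints.BoundingBoxFormat.XYXY,
        canvas_size=(100, 100),
    )

    # By default, tensor operations on a datapoint return a plain tensor.
    # wrap() converts the result back into the same subclass as `like`,
    # reusing its metadata.
    shifted = bboxes + 5
    new_bboxes = datapoints.wrap(shifted, like=bboxes)
    assert isinstance(new_bboxes, datapoints.BoundingBoxes)
    assert new_bboxes.canvas_size == bboxes.canvas_size

    # For BoundingBoxes, format and canvas_size can be overridden via kwargs.
    rescaled = datapoints.wrap(shifted * 2, like=bboxes, canvas_size=(200, 200))
    assert rescaled.canvas_size == (200, 200)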