
Commit

Merge branch 'main' into pil-dispatch
pmeier authored Aug 16, 2023
2 parents c5a1a07 + c1592f9 commit e3db48a
Showing 18 changed files with 78 additions and 97 deletions.
1 change: 1 addition & 0 deletions docs/source/datapoints.rst
@@ -19,3 +19,4 @@ see e.g. :ref:`sphx_glr_auto_examples_plot_transforms_v2_e2e.py`.
Mask
Datapoint
set_return_type
wrap
6 changes: 3 additions & 3 deletions gallery/plot_custom_datapoints.py
@@ -53,11 +53,11 @@ class MyDatapoint(datapoints.Datapoint):
def hflip_my_datapoint(my_dp, *args, **kwargs):
    print("Flipping!")
    out = my_dp.flip(-1)
    return MyDatapoint.wrap_like(my_dp, out)
    return datapoints.wrap(out, like=my_dp)


# %%
# To understand why ``wrap_like`` is used, see
# To understand why :func:`~torchvision.datapoints.wrap` is used, see
# :ref:`datapoint_unwrapping_behaviour`. Ignore the ``*args, **kwargs`` for now,
# we will explain it below in :ref:`param_forwarding`.
#
@@ -107,7 +107,7 @@ def hflip_my_datapoint(my_dp, *args, **kwargs):
def hflip_my_datapoint(my_dp): # noqa
    print("Flipping!")
    out = my_dp.flip(-1)
    return MyDatapoint.wrap_like(my_dp, out)
    return datapoints.wrap(out, like=my_dp)


# %%
21 changes: 9 additions & 12 deletions gallery/plot_datapoints.py
@@ -107,26 +107,23 @@
print(bboxes)

# %%
# Using the ``wrap_like()`` class method
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Using ``datapoints.wrap()``
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# You can also use the ``wrap_like()`` class method to wrap a tensor object
# You can also use the :func:`~torchvision.datapoints.wrap` function to wrap a tensor object
# into a datapoint. This is useful when you already have an object of the
# desired type, which typically happens when writing transforms: you just want
# to wrap the output like the input. This API is inspired by utils like
# :func:`torch.zeros_like`:
# to wrap the output like the input.

new_bboxes = torch.tensor([0, 20, 30, 40])
new_bboxes = datapoints.BoundingBoxes.wrap_like(bboxes, new_bboxes)
new_bboxes = datapoints.wrap(new_bboxes, like=bboxes)
assert isinstance(new_bboxes, datapoints.BoundingBoxes)
assert new_bboxes.canvas_size == bboxes.canvas_size
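
# A quick sketch of overriding the inherited metadata: any ``format`` or
# ``canvas_size`` passed explicitly takes precedence over the values taken from
# ``like`` (``(100, 100)`` here is just an arbitrary illustrative value).
other_bboxes = datapoints.wrap(new_bboxes, like=bboxes, canvas_size=(100, 100))
assert other_bboxes.canvas_size == (100, 100)
assert other_bboxes.format == bboxes.format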


# %%
# The metadata of ``new_bboxes`` is the same as ``bboxes``, but you could pass
# it as a parameter to override it. Check the
# :meth:`~torchvision.datapoints.BoundingBoxes.wrap_like` documentation for
# more details.
# it as a parameter to override it.
#
# Do I have to wrap the output of the datasets myself?
# ----------------------------------------------------
@@ -230,11 +227,11 @@ def get_transform(train):
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# You can re-wrap a pure tensor into a datapoint by just calling the datapoint
# constructor, or by using the ``.wrap_like()`` class method (see more details
# above in :ref:`datapoint_creation`):
# constructor, or by using the :func:`~torchvision.datapoints.wrap` function
# (see more details above in :ref:`datapoint_creation`):

new_bboxes = bboxes + 3
new_bboxes = datapoints.BoundingBoxes.wrap_like(bboxes, new_bboxes)
new_bboxes = datapoints.wrap(new_bboxes, like=bboxes)
assert isinstance(new_bboxes, datapoints.BoundingBoxes)
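
# A sketch of the constructor route mentioned above; unlike ``wrap``, the
# constructor does not copy metadata from another object, so ``format`` and
# ``canvas_size`` are passed explicitly here (values reused from ``bboxes``).
newer_bboxes = datapoints.BoundingBoxes(bboxes + 3, format=bboxes.format, canvas_size=bboxes.canvas_size)
assert isinstance(newer_bboxes, datapoints.BoundingBoxes)
assert newer_bboxes.canvas_size == bboxes.canvas_size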

# %%
4 changes: 2 additions & 2 deletions test/test_datapoints.py
@@ -213,13 +213,13 @@ def test_inplace_op_no_wrapping(make_input, return_type):


@pytest.mark.parametrize("make_input", [make_image, make_bounding_box, make_segmentation_mask, make_video])
def test_wrap_like(make_input):
def test_wrap(make_input):
    dp = make_input()

    # any operation besides the ones listed in _FORCE_TORCHFUNCTION_SUBCLASS will do here
    output = dp * 2

    dp_new = type(dp).wrap_like(dp, output)
    dp_new = datapoints.wrap(output, like=dp)

    assert type(dp_new) is type(dp)
    assert dp_new.data_ptr() == output.data_ptr()
6 changes: 3 additions & 3 deletions test/test_transforms_v2_refactored.py
@@ -571,7 +571,7 @@ def _reference_resize_bounding_boxes(self, bounding_boxes, *, size, max_size=Non
canvas_size=(new_height, new_width),
affine_matrix=affine_matrix,
)
return datapoints.BoundingBoxes.wrap_like(bounding_boxes, expected_bboxes, canvas_size=(new_height, new_width))
return datapoints.wrap(expected_bboxes, like=bounding_boxes, canvas_size=(new_height, new_width))

@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
@pytest.mark.parametrize("size", OUTPUT_SIZES)
@@ -816,7 +816,7 @@ def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes):
affine_matrix=affine_matrix,
)

return datapoints.BoundingBoxes.wrap_like(bounding_boxes, expected_bboxes)
return datapoints.wrap(expected_bboxes, like=bounding_boxes)

@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
@pytest.mark.parametrize(
@@ -1279,7 +1279,7 @@ def _reference_vertical_flip_bounding_boxes(self, bounding_boxes):
affine_matrix=affine_matrix,
)

return datapoints.BoundingBoxes.wrap_like(bounding_boxes, expected_bboxes)
return datapoints.wrap(expected_bboxes, like=bounding_boxes)

@pytest.mark.parametrize("format", list(datapoints.BoundingBoxFormat))
@pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)])
22 changes: 22 additions & 0 deletions torchvision/datapoints/__init__.py
@@ -12,3 +12,25 @@
import warnings

warnings.warn(_BETA_TRANSFORMS_WARNING)


def wrap(wrappee, *, like, **kwargs):
    """Convert a :class:`torch.Tensor` (``wrappee``) into the same :class:`~torchvision.datapoints.Datapoint` subclass as ``like``.

    If ``like`` is a :class:`~torchvision.datapoints.BoundingBoxes`, the ``format`` and ``canvas_size`` of
    ``like`` are assigned to ``wrappee``, unless they are passed as ``kwargs``.

    Args:
        wrappee (Tensor): The tensor to convert.
        like (Datapoint): The reference. ``wrappee`` will be converted into the same subclass as ``like``.
        kwargs: Can contain "format" and "canvas_size" if ``like`` is a :class:`~torchvision.datapoints.BoundingBoxes`.
            Ignored otherwise.
    """
    if isinstance(like, BoundingBoxes):
        return BoundingBoxes._wrap(
            wrappee,
            format=kwargs.get("format", like.format),
            canvas_size=kwargs.get("canvas_size", like.canvas_size),
        )
    else:
        return wrappee.as_subclass(type(like))
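
A minimal usage sketch of the helper above (the bounding box values and sizes are arbitrary illustrative choices, not taken from this diff):

    from torchvision import datapoints

    boxes = datapoints.BoundingBoxes([[0, 0, 5, 5]], format="XYXY", canvas_size=(10, 10))
    plain = boxes + 1                                   # most ops on a datapoint return a plain tensor
    same = datapoints.wrap(plain, like=boxes)           # BoundingBoxes again, metadata copied from ``boxes``
    other = datapoints.wrap(plain, like=boxes, canvas_size=(20, 20))  # kwargs override the copied metadata
    assert same.canvas_size == (10, 10) and other.canvas_size == (20, 20)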
26 changes: 0 additions & 26 deletions torchvision/datapoints/_bounding_box.py
@@ -75,32 +75,6 @@ def __new__(
        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
        return cls._wrap(tensor, format=format, canvas_size=canvas_size)

    @classmethod
    def wrap_like(
        cls,
        other: BoundingBoxes,
        tensor: torch.Tensor,
        *,
        format: Optional[Union[BoundingBoxFormat, str]] = None,
        canvas_size: Optional[Tuple[int, int]] = None,
    ) -> BoundingBoxes:
        """Wrap a :class:`torch.Tensor` as :class:`BoundingBoxes` from a reference.

        Args:
            other (BoundingBoxes): Reference bounding box.
            tensor (Tensor): Tensor to be wrapped as :class:`BoundingBoxes`
            format (BoundingBoxFormat, str, optional): Format of the bounding box. If omitted, it is taken from the
                reference.
            canvas_size (two-tuple of ints, optional): Height and width of the corresponding image or video. If
                omitted, it is taken from the reference.
        """
        return cls._wrap(
            tensor,
            format=format if format is not None else other.format,
            canvas_size=canvas_size if canvas_size is not None else other.canvas_size,
        )

    @classmethod
    def _wrap_output(
        cls,
4 changes: 0 additions & 4 deletions torchvision/datapoints/_datapoint.py
@@ -31,10 +31,6 @@ def _to_tensor(
        requires_grad = data.requires_grad if isinstance(data, torch.Tensor) else False
        return torch.as_tensor(data, dtype=dtype, device=device).requires_grad_(requires_grad)

    @classmethod
    def wrap_like(cls: Type[D], other: D, tensor: torch.Tensor) -> D:
        return tensor.as_subclass(cls)

    @classmethod
    def _wrap_output(
        cls,
7 changes: 0 additions & 7 deletions torchvision/prototype/datapoints/_label.py
@@ -32,13 +32,6 @@ def __new__(
        tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
        return cls._wrap(tensor, categories=categories)

    @classmethod
    def wrap_like(cls: Type[L], other: L, tensor: torch.Tensor, *, categories: Optional[Sequence[str]] = None) -> L:
        return cls._wrap(
            tensor,
            categories=categories if categories is not None else other.categories,
        )

    @classmethod
    def from_category(
        cls: Type[L],
16 changes: 7 additions & 9 deletions torchvision/prototype/transforms/_augment.py
@@ -36,11 +36,9 @@ def _copy_paste(
antialias: Optional[bool],
) -> Tuple[torch.Tensor, Dict[str, Any]]:

paste_masks = paste_target["masks"].wrap_like(paste_target["masks"], paste_target["masks"][random_selection])
paste_boxes = paste_target["boxes"].wrap_like(paste_target["boxes"], paste_target["boxes"][random_selection])
paste_labels = paste_target["labels"].wrap_like(
paste_target["labels"], paste_target["labels"][random_selection]
)
paste_masks = datapoints.wrap(paste_target["masks"][random_selection], like=paste_target["masks"])
paste_boxes = datapoints.wrap(paste_target["boxes"][random_selection], like=paste_target["boxes"])
paste_labels = datapoints.wrap(paste_target["labels"][random_selection], like=paste_target["labels"])

masks = target["masks"]

@@ -143,7 +141,7 @@ def _insert_outputs(
c0, c1, c2, c3 = 0, 0, 0, 0
for i, obj in enumerate(flat_sample):
if isinstance(obj, datapoints.Image):
flat_sample[i] = datapoints.Image.wrap_like(obj, output_images[c0])
flat_sample[i] = datapoints.wrap(output_images[c0], like=obj)
c0 += 1
elif isinstance(obj, PIL.Image.Image):
flat_sample[i] = F.to_image_pil(output_images[c0])
@@ -152,13 +150,13 @@
flat_sample[i] = output_images[c0]
c0 += 1
elif isinstance(obj, datapoints.BoundingBoxes):
flat_sample[i] = datapoints.BoundingBoxes.wrap_like(obj, output_targets[c1]["boxes"])
flat_sample[i] = datapoints.wrap(output_targets[c1]["boxes"], like=obj)
c1 += 1
elif isinstance(obj, datapoints.Mask):
flat_sample[i] = datapoints.Mask.wrap_like(obj, output_targets[c2]["masks"])
flat_sample[i] = datapoints.wrap(output_targets[c2]["masks"], like=obj)
c2 += 1
elif isinstance(obj, (proto_datapoints.Label, proto_datapoints.OneHotLabel)):
flat_sample[i] = obj.wrap_like(obj, output_targets[c3]["labels"]) # type: ignore[arg-type]
flat_sample[i] = datapoints.wrap(output_targets[c3]["labels"], like=obj)
c3 += 1

def forward(self, *inputs: Any) -> Any:
6 changes: 3 additions & 3 deletions torchvision/prototype/transforms/_geometry.py
@@ -112,11 +112,11 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:

if params["is_valid"] is not None:
if isinstance(inpt, (Label, OneHotLabel, datapoints.Mask)):
inpt = inpt.wrap_like(inpt, inpt[params["is_valid"]]) # type: ignore[arg-type]
inpt = datapoints.wrap(inpt[params["is_valid"]], like=inpt)
elif isinstance(inpt, datapoints.BoundingBoxes):
inpt = datapoints.BoundingBoxes.wrap_like(
inpt,
inpt = datapoints.wrap(
F.clamp_bounding_boxes(inpt[params["is_valid"]], format=inpt.format, canvas_size=inpt.canvas_size),
like=inpt,
)

if params["needs_pad"]:
4 changes: 2 additions & 2 deletions torchvision/transforms/v2/_augment.py
@@ -249,7 +249,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))

if isinstance(inpt, (datapoints.Image, datapoints.Video)):
output = type(inpt).wrap_like(inpt, output) # type: ignore[arg-type]
output = datapoints.wrap(output, like=inpt)

return output
else:
@@ -319,7 +319,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2]

if isinstance(inpt, (datapoints.Image, datapoints.Video)):
output = inpt.wrap_like(inpt, output) # type: ignore[arg-type]
output = datapoints.wrap(output, like=inpt)

return output
else:
2 changes: 1 addition & 1 deletion torchvision/transforms/v2/_auto_augment.py
@@ -620,7 +620,7 @@ def forward(self, *inputs: Any) -> Any:
mix = mix.reshape(orig_dims).to(dtype=image_or_video.dtype)

if isinstance(orig_image_or_video, (datapoints.Image, datapoints.Video)):
mix = orig_image_or_video.wrap_like(orig_image_or_video, mix) # type: ignore[arg-type]
mix = datapoints.wrap(mix, like=orig_image_or_video)
elif isinstance(orig_image_or_video, PIL.Image.Image):
mix = F.to_image_pil(mix)

2 changes: 1 addition & 1 deletion torchvision/transforms/v2/_geometry.py
@@ -338,7 +338,7 @@ class FiveCrop(Transform):
... images_or_videos, labels = sample
... batch_size = len(images_or_videos)
... image_or_video = images_or_videos[0]
... images_or_videos = image_or_video.wrap_like(image_or_video, torch.stack(images_or_videos))
... images_or_videos = datapoints.wrap(torch.stack(images_or_videos), like=image_or_video)
... labels = torch.full((batch_size,), label, device=images_or_videos.device)
... return images_or_videos, labels
...
4 changes: 2 additions & 2 deletions torchvision/transforms/v2/_misc.py
@@ -131,7 +131,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
output = output.reshape(shape)

if isinstance(inpt, (datapoints.Image, datapoints.Video)):
output = type(inpt).wrap_like(inpt, output) # type: ignore[arg-type]
output = datapoints.wrap(output, like=inpt)
return output


@@ -423,4 +423,4 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if is_label:
return output

return type(inpt).wrap_like(inpt, output)
return datapoints.wrap(output, like=inpt)
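
Taken together, the replacement applied throughout this diff follows one mechanical pattern. A before/after sketch with illustrative values (the box coordinates, format, and canvas size below are placeholders, not identifiers from the diff):

    from torchvision import datapoints

    obj = datapoints.BoundingBoxes([[0, 0, 10, 10]], format="XYXY", canvas_size=(32, 32))
    out = obj + 1  # plain tensor again

    # before this commit: classmethod on the datapoint subclass
    # new = type(obj).wrap_like(obj, out)

    # after this commit: module-level function, metadata overridable via kwargs
    new = datapoints.wrap(out, like=obj)
    assert new.format == obj.format and new.canvas_size == (32, 32)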