diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 7bed48e6c15..b1760f6f965 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -12,13 +12,10 @@ make_bounding_box, make_detection_mask, make_image, - make_images, - make_segmentation_mask, make_video, - make_videos, ) -from prototype_common_utils import make_label, make_one_hot_labels +from prototype_common_utils import make_label from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video from torchvision.prototype import datapoints, transforms @@ -44,49 +41,6 @@ def parametrize(transforms_with_inputs): ) -@parametrize( - [ - ( - transform, - [ - dict(inpt=inpt, one_hot_label=one_hot_label) - for inpt, one_hot_label in itertools.product( - itertools.chain( - make_images(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]), - make_videos(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]), - ), - make_one_hot_labels(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]), - ) - ], - ) - for transform in [ - transforms.RandomMixUp(alpha=1.0), - transforms.RandomCutMix(alpha=1.0), - ] - ] -) -def test_mixup_cutmix(transform, input): - transform(input) - - input_copy = dict(input) - input_copy["path"] = "/path/to/somewhere" - input_copy["num"] = 1234 - transform(input_copy) - - # Check if we raise an error if sample contains bbox or mask or label - err_msg = "does not support PIL images, bounding boxes, masks and plain labels" - input_copy = dict(input) - for unsup_data in [ - make_label(), - make_bounding_box(format="XYXY"), - make_detection_mask(), - make_segmentation_mask(), - ]: - input_copy["unsupported"] = unsup_data - with pytest.raises(TypeError, match=err_msg): - transform(input_copy) - - class TestSimpleCopyPaste: def create_fake_image(self, mocker, image_type): if image_type == PIL.Image.Image: diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index e3a18599806..c264db5d33d 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -1,6 +1,6 @@ from ._presets import StereoMatching # usort: skip -from ._augment import RandomCutMix, RandomMixUp, SimpleCopyPaste +from ._augment import SimpleCopyPaste from ._geometry import FixedSizeCrop from ._misc import PermuteDimensions, TransposeDimensions from ._type_conversion import LabelToOneHot diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 4da6cfcf9cd..95585fe287c 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -1,4 +1,3 @@ -import math from typing import Any, cast, Dict, List, Optional, Tuple, Union import PIL.Image @@ -9,100 +8,8 @@ from torchvision.prototype import datapoints as proto_datapoints from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform -from torchvision.transforms.v2._transform import _RandomApplyTransform from torchvision.transforms.v2.functional._geometry import _check_interpolation -from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_size - - -class _BaseMixUpCutMix(_RandomApplyTransform): - def __init__(self, alpha: float, p: float = 0.5) -> None: - super().__init__(p=p) - self.alpha = alpha - self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha])) - - def _check_inputs(self, flat_inputs: List[Any]) -> None: - if not ( - has_any(flat_inputs, datapoints.Image, datapoints.Video, is_simple_tensor) - and has_any(flat_inputs, proto_datapoints.OneHotLabel) - ): - raise TypeError(f"{type(self).__name__}() is only defined for tensor images/videos and one-hot labels.") - if has_any(flat_inputs, PIL.Image.Image, datapoints.BoundingBoxes, datapoints.Mask, proto_datapoints.Label): - raise TypeError( - f"{type(self).__name__}() does not support PIL images, bounding boxes, masks and plain labels." - ) - - def _mixup_onehotlabel(self, inpt: proto_datapoints.OneHotLabel, lam: float) -> proto_datapoints.OneHotLabel: - if inpt.ndim < 2: - raise ValueError("Need a batch of one hot labels") - output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam)) - return proto_datapoints.OneHotLabel.wrap_like(inpt, output) - - -class RandomMixUp(_BaseMixUpCutMix): - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - return dict(lam=float(self._dist.sample(()))) # type: ignore[arg-type] - - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - lam = params["lam"] - if isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt): - expected_ndim = 5 if isinstance(inpt, datapoints.Video) else 4 - if inpt.ndim < expected_ndim: - raise ValueError("The transform expects a batched input") - output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam)) - - if isinstance(inpt, (datapoints.Image, datapoints.Video)): - output = type(inpt).wrap_like(inpt, output) # type: ignore[arg-type] - - return output - elif isinstance(inpt, proto_datapoints.OneHotLabel): - return self._mixup_onehotlabel(inpt, lam) - else: - return inpt - - -class RandomCutMix(_BaseMixUpCutMix): - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - lam = float(self._dist.sample(())) # type: ignore[arg-type] - - H, W = query_size(flat_inputs) - - r_x = torch.randint(W, ()) - r_y = torch.randint(H, ()) - - r = 0.5 * math.sqrt(1.0 - lam) - r_w_half = int(r * W) - r_h_half = int(r * H) - - x1 = int(torch.clamp(r_x - r_w_half, min=0)) - y1 = int(torch.clamp(r_y - r_h_half, min=0)) - x2 = int(torch.clamp(r_x + r_w_half, max=W)) - y2 = int(torch.clamp(r_y + r_h_half, max=H)) - box = (x1, y1, x2, y2) - - lam_adjusted = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H)) - - return dict(box=box, lam_adjusted=lam_adjusted) - - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - if isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt): - box = params["box"] - expected_ndim = 5 if isinstance(inpt, datapoints.Video) else 4 - if inpt.ndim < expected_ndim: - raise ValueError("The transform expects a batched input") - x1, y1, x2, y2 = box - rolled = inpt.roll(1, 0) - output = inpt.clone() - output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2] - - if isinstance(inpt, (datapoints.Image, datapoints.Video)): - output = inpt.wrap_like(inpt, output) # type: ignore[arg-type] - - return output - elif isinstance(inpt, proto_datapoints.OneHotLabel): - lam_adjusted = params["lam_adjusted"] - return self._mixup_onehotlabel(inpt, lam_adjusted) - else: - return inpt +from torchvision.transforms.v2.utils import is_simple_tensor class SimpleCopyPaste(Transform):