Skip to content

Commit

Permalink
Remove cutmix and mixup from prototype (#7787)
Browse files (browse the repository at this point in the history)
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
  • Loading branch information
NicolasHug and pmeier authored Aug 2, 2023
1 parent cab9fba commit f3c89cc
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 142 deletions.
48 changes: 1 addition & 47 deletions test/test_prototype_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,10 @@
make_bounding_box,
make_detection_mask,
make_image,
make_images,
make_segmentation_mask,
make_video,
make_videos,
)

from prototype_common_utils import make_label, make_one_hot_labels
from prototype_common_utils import make_label

from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
from torchvision.prototype import datapoints, transforms
Expand All @@ -44,49 +41,6 @@ def parametrize(transforms_with_inputs):
)


def _mixup_cutmix_samples():
    """Build a fresh list of sample dicts for the MixUp / CutMix smoke test.

    Every float image and video produced with ``BATCH_EXTRA_DIMS`` is paired
    with every float one-hot label produced with the same extra dims.
    """
    # Images first, then videos, then labels: the same consumption order as
    # chaining the two factories and feeding them to a cartesian product.
    media = [
        *make_images(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
        *make_videos(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
    ]
    labels = list(make_one_hot_labels(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]))
    return [dict(inpt=medium, one_hot_label=label) for medium in media for label in labels]


@parametrize(
    [
        (transform, _mixup_cutmix_samples())
        for transform in (
            transforms.RandomMixUp(alpha=1.0),
            transforms.RandomCutMix(alpha=1.0),
        )
    ]
)
def test_mixup_cutmix(transform, input):
    """Smoke-test ``RandomMixUp`` / ``RandomCutMix`` on supported and unsupported samples."""
    # A sample holding only the image/video and its one-hot label goes through.
    transform(input)

    # Extra entries that are neither tensors nor datapoints go through as well.
    transform({**input, "path": "/path/to/somewhere", "num": 1234})

    # Check if we raise an error if sample contains bbox or mask or label
    err_msg = "does not support PIL images, bounding boxes, masks and plain labels"
    unsupported_entries = (
        make_label(),
        make_bounding_box(format="XYXY"),
        make_detection_mask(),
        make_segmentation_mask(),
    )
    for unsupported in unsupported_entries:
        with pytest.raises(TypeError, match=err_msg):
            transform({**input, "unsupported": unsupported})


class TestSimpleCopyPaste:
def create_fake_image(self, mocker, image_type):
if image_type == PIL.Image.Image:
Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from ._presets import StereoMatching # usort: skip

from ._augment import RandomCutMix, RandomMixUp, SimpleCopyPaste
from ._augment import SimpleCopyPaste
from ._geometry import FixedSizeCrop
from ._misc import PermuteDimensions, TransposeDimensions
from ._type_conversion import LabelToOneHot
95 changes: 1 addition & 94 deletions torchvision/prototype/transforms/_augment.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import math
from typing import Any, cast, Dict, List, Optional, Tuple, Union

import PIL.Image
Expand All @@ -9,100 +8,8 @@
from torchvision.prototype import datapoints as proto_datapoints
from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform

from torchvision.transforms.v2._transform import _RandomApplyTransform
from torchvision.transforms.v2.functional._geometry import _check_interpolation
from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_size


class _BaseMixUpCutMix(_RandomApplyTransform):
    """Common pieces shared by ``RandomMixUp`` and ``RandomCutMix``.

    Holds the ``Beta(alpha, alpha)`` distribution the subclasses sample their
    mixing coefficient from, validates the kinds of inputs in a sample, and
    provides the one-hot label mixing routine.

    Args:
        alpha: Both concentration parameters of the Beta distribution.
        p: Forwarded unchanged to ``_RandomApplyTransform``.
    """

    def __init__(self, alpha: float, p: float = 0.5) -> None:
        super().__init__(p=p)
        self.alpha = alpha
        concentration1 = torch.tensor([alpha])
        concentration0 = torch.tensor([alpha])
        self._dist = torch.distributions.Beta(concentration1, concentration0)

    def _check_inputs(self, flat_inputs: List[Any]) -> None:
        """Reject samples this transform is not defined for.

        Raises:
            TypeError: If the sample lacks a tensor image/video or a one-hot
                label, or if it contains a PIL image, bounding boxes, a mask
                or a plain label.
        """
        transform_name = type(self).__name__

        # Both an image/video tensor and a one-hot label have to be present.
        if not has_any(flat_inputs, datapoints.Image, datapoints.Video, is_simple_tensor) or not has_any(
            flat_inputs, proto_datapoints.OneHotLabel
        ):
            raise TypeError(f"{transform_name}() is only defined for tensor images/videos and one-hot labels.")

        unsupported_types = (PIL.Image.Image, datapoints.BoundingBoxes, datapoints.Mask, proto_datapoints.Label)
        if has_any(flat_inputs, *unsupported_types):
            raise TypeError(
                f"{transform_name}() does not support PIL images, bounding boxes, masks and plain labels."
            )

    def _mixup_onehotlabel(self, inpt: proto_datapoints.OneHotLabel, lam: float) -> proto_datapoints.OneHotLabel:
        """Blend each label with the one rolled in from the previous index along dim 0.

        The result is ``lam * inpt + (1 - lam) * inpt.roll(1, 0)``, re-wrapped
        as a ``OneHotLabel``.

        Raises:
            ValueError: If ``inpt`` has fewer than two dimensions.
        """
        if inpt.ndim < 2:
            raise ValueError("Need a batch of one hot labels")

        # ``roll`` returns a new tensor, so the in-place ops never touch ``inpt``.
        mixed = inpt.roll(1, 0).mul_(1.0 - lam)
        mixed = mixed.add_(inpt.mul(lam))
        return proto_datapoints.OneHotLabel.wrap_like(inpt, mixed)


class RandomMixUp(_BaseMixUpCutMix):
    """Blend every batch element with the one rolled in from the previous index.

    Images, videos and one-hot labels are all combined as
    ``lam * x + (1 - lam) * x.roll(1, 0)`` with a single ``lam`` per call.
    """

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        """Sample the mixing coefficient ``lam`` from the Beta distribution."""
        lam = float(self._dist.sample(()))  # type: ignore[arg-type]
        return {"lam": lam}

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Mix ``inpt`` if it is an image, video, plain tensor or one-hot label.

        Any other input is returned unchanged.

        Raises:
            ValueError: If an image/tensor has fewer than 4 dimensions or a
                video has fewer than 5.
        """
        lam = params["lam"]

        is_image_or_video = isinstance(inpt, (datapoints.Image, datapoints.Video))
        if is_image_or_video or is_simple_tensor(inpt):
            # Videos carry one more dimension than images.
            min_ndim = 5 if isinstance(inpt, datapoints.Video) else 4
            if inpt.ndim < min_ndim:
                raise ValueError("The transform expects a batched input")

            mixed = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
            if is_image_or_video:
                # Restore the datapoint type of the input on the result.
                mixed = type(inpt).wrap_like(inpt, mixed)  # type: ignore[arg-type]
            return mixed

        if isinstance(inpt, proto_datapoints.OneHotLabel):
            return self._mixup_onehotlabel(inpt, lam)

        return inpt


class RandomCutMix(_BaseMixUpCutMix):
    """Paste a random box from the rolled batch into every batch element.

    One-hot labels are mixed with a coefficient derived from the area of the
    box that was actually pasted.
    """

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        """Draw the paste box and the matching label coefficient.

        Returns:
            A dict with ``box`` as ``(x1, y1, x2, y2)`` clamped to the input
            size, and ``lam_adjusted`` as one minus the box's share of the
            full ``H * W`` area.
        """
        lam = float(self._dist.sample(()))  # type: ignore[arg-type]

        height, width = query_size(flat_inputs)

        # Box centre: x is drawn before y.
        center_x = torch.randint(width, ())
        center_y = torch.randint(height, ())

        # Half extents scale with sqrt(1 - lam) along each axis.
        half_ratio = 0.5 * math.sqrt(1.0 - lam)
        half_width = int(half_ratio * width)
        half_height = int(half_ratio * height)

        left = int(torch.clamp(center_x - half_width, min=0))
        top = int(torch.clamp(center_y - half_height, min=0))
        right = int(torch.clamp(center_x + half_width, max=width))
        bottom = int(torch.clamp(center_y + half_height, max=height))

        # Clamping can shrink the box, so recompute the coefficient from the
        # area that is really replaced.
        box_area = (right - left) * (bottom - top)
        lam_adjusted = float(1.0 - box_area / (width * height))

        return {"box": (left, top, right, bottom), "lam_adjusted": lam_adjusted}

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        """Apply the box paste to images/videos/tensors and mix one-hot labels.

        Any other input is returned unchanged.

        Raises:
            ValueError: If an image/tensor has fewer than 4 dimensions or a
                video has fewer than 5.
        """
        is_image_or_video = isinstance(inpt, (datapoints.Image, datapoints.Video))
        if is_image_or_video or is_simple_tensor(inpt):
            left, top, right, bottom = params["box"]

            # Videos carry one more dimension than images.
            min_ndim = 5 if isinstance(inpt, datapoints.Video) else 4
            if inpt.ndim < min_ndim:
                raise ValueError("The transform expects a batched input")

            donor = inpt.roll(1, 0)
            pasted = inpt.clone()
            pasted[..., top:bottom, left:right] = donor[..., top:bottom, left:right]

            if is_image_or_video:
                # Restore the datapoint type of the input on the result.
                pasted = inpt.wrap_like(inpt, pasted)  # type: ignore[arg-type]
            return pasted

        if isinstance(inpt, proto_datapoints.OneHotLabel):
            return self._mixup_onehotlabel(inpt, params["lam_adjusted"])

        return inpt
from torchvision.transforms.v2.utils import is_simple_tensor


class SimpleCopyPaste(Transform):
Expand Down

0 comments on commit f3c89cc

Please sign in to comment.