Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove cutmix and mixup from prototype #7787

Merged
merged 2 commits into from
Aug 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 1 addition & 47 deletions test/test_prototype_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,10 @@
make_bounding_box,
make_detection_mask,
make_image,
make_images,
make_segmentation_mask,
make_video,
make_videos,
)

from prototype_common_utils import make_label, make_one_hot_labels
from prototype_common_utils import make_label

from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
from torchvision.prototype import datapoints, transforms
Expand All @@ -44,49 +41,6 @@ def parametrize(transforms_with_inputs):
)


@parametrize(
    [
        (
            transform,
            [
                dict(inpt=media, one_hot_label=label)
                for media, label in itertools.product(
                    itertools.chain(
                        make_images(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
                        make_videos(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
                    ),
                    make_one_hot_labels(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
                )
            ],
        )
        for transform in (
            transforms.RandomMixUp(alpha=1.0),
            transforms.RandomCutMix(alpha=1.0),
        )
    ]
)
def test_mixup_cutmix(transform, input):
    # Smoke test: a supported sample (batched tensor media + one-hot label)
    # must be accepted as-is.
    transform(input)

    # Extra non-datapoint entries in the sample must be passed through
    # without raising.
    sample_with_extras = dict(input, path="/path/to/somewhere", num=1234)
    transform(sample_with_extras)

    # Samples containing a plain label, bounding box, or mask must be
    # rejected with a TypeError.
    err_msg = "does not support PIL images, bounding boxes, masks and plain labels"
    for unsupported in (
        make_label(),
        make_bounding_box(format="XYXY"),
        make_detection_mask(),
        make_segmentation_mask(),
    ):
        bad_sample = dict(input, unsupported=unsupported)
        with pytest.raises(TypeError, match=err_msg):
            transform(bad_sample)


class TestSimpleCopyPaste:
def create_fake_image(self, mocker, image_type):
if image_type == PIL.Image.Image:
Expand Down
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from ._presets import StereoMatching # usort: skip

from ._augment import RandomCutMix, RandomMixUp, SimpleCopyPaste
from ._augment import SimpleCopyPaste
from ._geometry import FixedSizeCrop
from ._misc import PermuteDimensions, TransposeDimensions
from ._type_conversion import LabelToOneHot
95 changes: 1 addition & 94 deletions torchvision/prototype/transforms/_augment.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import math
from typing import Any, cast, Dict, List, Optional, Tuple, Union

import PIL.Image
Expand All @@ -9,100 +8,8 @@
from torchvision.prototype import datapoints as proto_datapoints
from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform

from torchvision.transforms.v2._transform import _RandomApplyTransform
from torchvision.transforms.v2.functional._geometry import _check_interpolation
from torchvision.transforms.v2.utils import has_any, is_simple_tensor, query_size


class _BaseMixUpCutMix(_RandomApplyTransform):
    """Shared machinery for MixUp/CutMix: input validation, the
    Beta(alpha, alpha) mixing distribution, and one-hot-label blending."""

    def __init__(self, alpha: float, p: float = 0.5) -> None:
        super().__init__(p=p)
        self.alpha = alpha
        # Beta(alpha, alpha) produces the mixing coefficient lam in [0, 1].
        self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))

    def _check_inputs(self, flat_inputs: List[Any]) -> None:
        # A valid sample must contain both a tensor image/video and a one-hot label.
        has_media = has_any(flat_inputs, datapoints.Image, datapoints.Video, is_simple_tensor)
        has_one_hot = has_any(flat_inputs, proto_datapoints.OneHotLabel)
        if not (has_media and has_one_hot):
            raise TypeError(f"{type(self).__name__}() is only defined for tensor images/videos and one-hot labels.")
        # Reject unsupported datapoint kinds explicitly instead of silently passing them through.
        if has_any(flat_inputs, PIL.Image.Image, datapoints.BoundingBoxes, datapoints.Mask, proto_datapoints.Label):
            raise TypeError(
                f"{type(self).__name__}() does not support PIL images, bounding boxes, masks and plain labels."
            )

    def _mixup_onehotlabel(self, inpt: proto_datapoints.OneHotLabel, lam: float) -> proto_datapoints.OneHotLabel:
        # Labels must be batched: each entry is mixed with its roll-by-one neighbor.
        if inpt.ndim < 2:
            raise ValueError("Need a batch of one hot labels")
        mixed = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
        return proto_datapoints.OneHotLabel.wrap_like(inpt, mixed)


class RandomMixUp(_BaseMixUpCutMix):
    """MixUp: convex-combine each batched image/video (and its one-hot label)
    with the sample obtained by rolling the batch by one position."""

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        # One mixing coefficient per sample, shared by media and labels.
        return dict(lam=float(self._dist.sample(())))  # type: ignore[arg-type]

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        lam = params["lam"]
        if isinstance(inpt, proto_datapoints.OneHotLabel):
            return self._mixup_onehotlabel(inpt, lam)
        if is_simple_tensor(inpt) or isinstance(inpt, (datapoints.Image, datapoints.Video)):
            # Videos carry an extra temporal dim; in both cases a leading batch dim is required.
            expected_ndim = 5 if isinstance(inpt, datapoints.Video) else 4
            if inpt.ndim < expected_ndim:
                raise ValueError("The transform expects a batched input")
            mixed = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
            if isinstance(inpt, (datapoints.Image, datapoints.Video)):
                mixed = type(inpt).wrap_like(inpt, mixed)  # type: ignore[arg-type]
            return mixed
        # Anything else is passed through untouched.
        return inpt


class RandomCutMix(_BaseMixUpCutMix):
    """CutMix: paste a random rectangular patch from the roll-by-one neighbor
    into each batched image/video, and mix the one-hot labels by the patch's
    actual (clamped) area fraction."""

    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
        """Sample the patch box and the area-adjusted mixing coefficient.

        Returns:
            dict with:
                box: ``(x1, y1, x2, y2)`` pixel coordinates of the patch,
                    clamped to the image bounds.
                lam_adjusted: label-mixing weight recomputed from the clamped
                    box so it matches the fraction of pixels actually kept.
        """
        lam = float(self._dist.sample(()))  # type: ignore[arg-type]

        H, W = query_size(flat_inputs)

        # Patch center is uniform over the image; half-extents scale with
        # sqrt(1 - lam) so the unclamped patch covers a (1 - lam) area fraction.
        r_x = torch.randint(W, ())
        r_y = torch.randint(H, ())

        r = 0.5 * math.sqrt(1.0 - lam)
        r_w_half = int(r * W)
        r_h_half = int(r * H)

        x1 = int(torch.clamp(r_x - r_w_half, min=0))
        y1 = int(torch.clamp(r_y - r_h_half, min=0))
        x2 = int(torch.clamp(r_x + r_w_half, max=W))
        y2 = int(torch.clamp(r_y + r_h_half, max=H))
        box = (x1, y1, x2, y2)

        # Clamping may shrink the patch, so recompute lam from the actual area.
        lam_adjusted = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))

        return dict(box=box, lam_adjusted=lam_adjusted)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        if isinstance(inpt, (datapoints.Image, datapoints.Video)) or is_simple_tensor(inpt):
            box = params["box"]
            # Videos carry an extra temporal dim; a leading batch dim is required.
            expected_ndim = 5 if isinstance(inpt, datapoints.Video) else 4
            if inpt.ndim < expected_ndim:
                raise ValueError("The transform expects a batched input")
            x1, y1, x2, y2 = box
            rolled = inpt.roll(1, 0)
            output = inpt.clone()
            output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2]

            if isinstance(inpt, (datapoints.Image, datapoints.Video)):
                # Fix: invoke the wrap_like classmethod on the type, consistent
                # with RandomMixUp, rather than through the instance.
                output = type(inpt).wrap_like(inpt, output)  # type: ignore[arg-type]

            return output
        elif isinstance(inpt, proto_datapoints.OneHotLabel):
            lam_adjusted = params["lam_adjusted"]
            return self._mixup_onehotlabel(inpt, lam_adjusted)
        else:
            # Anything else is passed through untouched.
            return inpt
from torchvision.transforms.v2.utils import is_simple_tensor


class SimpleCopyPaste(Transform):
Expand Down
Loading