diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 3743581794f..755a7b0350c 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -82,6 +82,28 @@ def auto_augment_adapter(transform, input, device):
     return adapted_input
 
 
+def auto_augment_detection_adapter(transform, input, device):
+    adapted_input = {}
+    image_or_video_found = False
+    bounding_box_found = False
+    for key, value in input.items():
+        if isinstance(value, datapoints.Mask):
+            # AA detection transforms don't support masks
+            continue
+        elif isinstance(value, datapoints.BoundingBox):
+            if bounding_box_found:
+                # AA detection transforms only support a single bounding box tensor
+                continue
+            bounding_box_found = True
+        elif check_type(value, (datapoints.Image, datapoints.Video, is_simple_tensor, PIL.Image.Image)):
+            if image_or_video_found:
+                # AA transforms only support a single image or video
+                continue
+            image_or_video_found = True
+        adapted_input[key] = value
+    return adapted_input
+
+
 def linear_transformation_adapter(transform, input, device):
     flat_inputs = list(input.values())
     c, h, w = query_chw(
@@ -119,6 +141,10 @@ class TestSmoke:
            (transforms.AutoAugment(), auto_augment_adapter),
            (transforms.RandAugment(), auto_augment_adapter),
            (transforms.TrivialAugmentWide(), auto_augment_adapter),
+            (transforms._AutoAugmentDetection("v0"), auto_augment_detection_adapter),
+            (transforms._AutoAugmentDetection("v1"), auto_augment_detection_adapter),
+            (transforms._AutoAugmentDetection("v2"), auto_augment_detection_adapter),
+            (transforms._AutoAugmentDetection("v3"), auto_augment_detection_adapter),
            (transforms.ColorJitter(brightness=0.1, contrast=0.2, saturation=0.3, hue=0.15), None),
            (transforms.Grayscale(), None),
            (transforms.RandomAdjustSharpness(sharpness_factor=0.5, p=1.0), None),
@@ -310,6 +336,37 @@ def test_common(self, transform, adapter, container_type, image_or_video, device
     def test_auto_augment(self, transform, input):
         transform(input)
 
+    @pytest.mark.parametrize(
+        "transform",
+        [
+            transforms._AutoAugmentDetection("v0"),
+            transforms._AutoAugmentDetection("v1"),
+            transforms._AutoAugmentDetection("v2"),
+            transforms._AutoAugmentDetection("v3"),
+        ],
+    )
+    @pytest.mark.parametrize("seed", range(10))
+    def test_auto_augment_detection(self, transform, seed):
+        torch.manual_seed(seed)
+        image = datapoints.Image(torch.randint(0, 256, size=(3, 480, 640), dtype=torch.uint8))
+        boxes = torch.tensor(
+            [
+                [388.3100, 38.8300, 638.5600, 480.0000],
+                [82.6800, 314.7300, 195.4400, 445.7400],
+                [199.1000, 214.1700, 316.4100, 399.2800],
+                [159.8400, 183.6800, 216.8700, 273.1200],
+                [15.8000, 265.4600, 93.2900, 395.1800],
+                [88.2300, 266.0800, 222.9800, 371.2500],
+                [176.9000, 283.8600, 208.4300, 292.3000],
+                [537.6300, 230.8300, 580.7100, 291.3300],
+                [421.1200, 230.4700, 580.6700, 350.9500],
+                [427.4200, 185.4300, 494.0200, 266.3500],
+            ]
+        )
+        bboxes = datapoints.BoundingBox(boxes, format="XYXY", spatial_size=image.shape[-2:])
+        input = (image, bboxes)
+        transform(input)
+
     @parametrize(
         [
             (
diff --git a/torchvision/transforms/v2/__init__.py b/torchvision/transforms/v2/__init__.py
index 6573446a33a..24728bcdeaa 100644
--- a/torchvision/transforms/v2/__init__.py
+++ b/torchvision/transforms/v2/__init__.py
@@ -5,7 +5,7 @@
 from ._transform import Transform  # usort: skip
 
 from ._augment import RandomErasing
-from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
+from ._auto_augment import _AutoAugmentDetection, AugMix, AutoAugment, RandAugment, TrivialAugmentWide
 from ._color import (
     ColorJitter,
     Grayscale,
diff --git a/torchvision/transforms/v2/_auto_augment.py b/torchvision/transforms/v2/_auto_augment.py
index 34c0ced43d2..4a4a620471a 100644
--- a/torchvision/transforms/v2/_auto_augment.py
+++ b/torchvision/transforms/v2/_auto_augment.py
@@ -1,5 +1,5 @@
 import math
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union
 
 import PIL.Image
 import torch
@@ -15,6 +15,58 @@
 from .utils import check_type, is_simple_tensor
 
 
+def _solarize_add(
+    image: Union[datapoints._ImageType, datapoints._VideoType], addition: float = 0.0, threshold: float = 128 / 255
+) -> Union[datapoints._ImageType, datapoints._VideoType]:
+    if check_type(image, (datapoints.Image, is_simple_tensor, datapoints.Video)):
+        bound = _FT._max_value(image.dtype)
+    else:
+        bound = 255
+    if bound != 1:
+        addition = round(addition * bound)
+        threshold = round(threshold * bound)
+
+    if isinstance(image, PIL.Image.Image):
+
+        def pixel_fn(pixel: int) -> int:
+            return max(0, min(pixel + int(addition), bound)) if pixel < int(threshold) else pixel
+
+        return PIL.Image.eval(image, pixel_fn)
+
+    added_image = image.add(addition).clip(0, bound)
+    result = torch.where(image < threshold, added_image, image)
+    if isinstance(image, datapoints._datapoint.Datapoint):
+        result = image.wrap_like(image, result)
+    return result
+
+
+def _cutout(
+    image: Union[datapoints._ImageType, datapoints._VideoType],
+    pad_size: Union[int, Tuple[int, int], None] = None,
+    pad_fraction: float = 0.0,
+    replace: int = 0,
+) -> Union[datapoints._ImageType, datapoints._VideoType]:
+    img_c, img_h, img_w = F.get_dimensions(image)
+
+    # Sample the center location in the image where the zero mask will be applied.
+    cutout_center_height = int(torch.randint(img_h, ()))
+    cutout_center_width = int(torch.randint(img_w, ()))
+
+    if isinstance(pad_size, int):
+        pad_size = (pad_size, pad_size)
+    elif pad_size is None:
+        assert pad_fraction > 0.0
+        pad_size = (int(pad_fraction * img_h), int(pad_fraction * img_w))
+
+    lower_pad = max(0, cutout_center_height - pad_size[0])
+    upper_pad = max(0, img_h - cutout_center_height - pad_size[0])
+    left_pad = max(0, cutout_center_width - pad_size[1])
+    right_pad = max(0, img_w - cutout_center_width - pad_size[1])
+
+    cutout_shape = [img_h - (lower_pad + upper_pad), img_w - (left_pad + right_pad)]
+    return F.erase(image, lower_pad, left_pad, cutout_shape[0], cutout_shape[1], torch.tensor(replace))
+
+
 class _AutoAugmentBase(Transform):
     def __init__(
         self,
@@ -74,18 +126,18 @@ def _unflatten_and_insert_image_or_video(
         flat_inputs[idx] = image_or_video
         return tree_unflatten(flat_inputs, spec)
 
-    def _apply_image_or_video_transform(
+    def _apply_transform(
         self,
-        image: Union[datapoints._ImageType, datapoints._VideoType],
+        inpt: datapoints._InputType,
         transform_id: str,
         magnitude: float,
         interpolation: Union[InterpolationMode, int],
         fill: Dict[Type, datapoints._FillTypeJIT],
-    ) -> Union[datapoints._ImageType, datapoints._VideoType]:
-        fill_ = fill[type(image)]
+    ) -> Union[datapoints._ImageType, datapoints._VideoType, datapoints.BoundingBox]:
+        fill_ = fill[type(inpt)]
 
         if transform_id == "Identity":
-            return image
+            return inpt
         elif transform_id == "ShearX":
             # magnitude should be arctan(magnitude)
             # official autoaug: (1, level, 0, 0, 1, 0)
@@ -94,7 +146,7 @@ def _apply_image_or_video_transform(
             # torchvision: (1, tan(level), 0, 0, 1, 0)
             # https://github.com/pytorch/vision/blob/0c2373d0bba3499e95776e7936e207d8a1676e65/torchvision/transforms/functional.py#L976
             return F.affine(
-                image,
+                inpt,
                 angle=0.0,
                 translate=[0, 0],
                 scale=1.0,
@@ -107,7 +159,7 @@ def _apply_image_or_video_transform(
             # magnitude should be arctan(magnitude)
             # See above
             return F.affine(
-                image,
+                inpt,
                 angle=0.0,
                 translate=[0, 0],
                 scale=1.0,
@@ -118,7 +170,7 @@ def _apply_image_or_video_transform(
             )
         elif transform_id == "TranslateX":
             return F.affine(
-                image,
+                inpt,
                 angle=0.0,
                 translate=[int(magnitude), 0],
                 scale=1.0,
@@ -128,7 +180,7 @@ def _apply_image_or_video_transform(
             )
         elif transform_id == "TranslateY":
             return F.affine(
-                image,
+                inpt,
                 angle=0.0,
                 translate=[0, int(magnitude)],
                 scale=1.0,
@@ -137,33 +189,38 @@ def _apply_image_or_video_transform(
                 fill=fill_,
             )
         elif transform_id == "Rotate":
-            return F.rotate(image, angle=magnitude, interpolation=interpolation, fill=fill_)
+            return F.rotate(inpt, angle=magnitude, interpolation=interpolation, fill=fill_)
         elif transform_id == "Brightness":
-            return F.adjust_brightness(image, brightness_factor=1.0 + magnitude)
+            return F.adjust_brightness(inpt, brightness_factor=1.0 + magnitude)
         elif transform_id == "Color":
-            return F.adjust_saturation(image, saturation_factor=1.0 + magnitude)
+            return F.adjust_saturation(inpt, saturation_factor=1.0 + magnitude)
         elif transform_id == "Contrast":
-            return F.adjust_contrast(image, contrast_factor=1.0 + magnitude)
+            return F.adjust_contrast(inpt, contrast_factor=1.0 + magnitude)
         elif transform_id == "Sharpness":
-            return F.adjust_sharpness(image, sharpness_factor=1.0 + magnitude)
+            return F.adjust_sharpness(inpt, sharpness_factor=1.0 + magnitude)
         elif transform_id == "Posterize":
-            return F.posterize(image, bits=int(magnitude))
+            return F.posterize(inpt, bits=int(magnitude))
         elif transform_id == "Solarize":
-            bound = _FT._max_value(image.dtype) if isinstance(image, torch.Tensor) else 255.0
-            return F.solarize(image, threshold=bound * magnitude)
+            if check_type(inpt, (datapoints.Image, is_simple_tensor, datapoints.Video)):
+                bound = _FT._max_value(inpt.dtype)
+            else:
+                bound = 255
+            return F.solarize(inpt, threshold=bound * magnitude)
         elif transform_id == "AutoContrast":
-            return F.autocontrast(image)
+            return F.autocontrast(inpt)
         elif transform_id == "Equalize":
-            return F.equalize(image)
+            return F.equalize(inpt)
         elif transform_id == "Invert":
-            return F.invert(image)
+            return F.invert(inpt)
+        elif transform_id == "Flip":
+            return F.horizontal_flip(inpt)
         else:
             raise ValueError(f"No transform available for {transform_id}")
 
 
 class AutoAugment(_AutoAugmentBase):
     r"""[BETA] AutoAugment data augmentation method based on
-    `"AutoAugment: Learning Augmentation Strategies from Data" `_.
+    `"AutoAugment: Learning Augmentation Strategies from Data" `_.
 
     .. v2betastatus:: AutoAugment transform
 
@@ -330,7 +387,7 @@ def forward(self, *inputs: Any) -> Any:
             else:
                 magnitude = 0.0
 
-            image_or_video = self._apply_image_or_video_transform(
+            image_or_video = self._apply_transform(
                 image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
             )
 
@@ -414,7 +471,7 @@ def forward(self, *inputs: Any) -> Any:
                     magnitude *= -1
             else:
                 magnitude = 0.0
-            image_or_video = self._apply_image_or_video_transform(
+            image_or_video = self._apply_transform(
                 image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
             )
 
@@ -486,7 +543,7 @@ def forward(self, *inputs: Any) -> Any:
         else:
             magnitude = 0.0
 
-        image_or_video = self._apply_image_or_video_transform(
+        image_or_video = self._apply_transform(
             image_or_video, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
         )
         return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, image_or_video)
@@ -609,7 +666,7 @@ def forward(self, *inputs: Any) -> Any:
                 else:
                     magnitude = 0.0
 
-                aug = self._apply_image_or_video_transform(
+                aug = self._apply_transform(
                     aug, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
                 )
                 mix.add_(combined_weights[:, i].reshape(batch_dims) * aug)
@@ -621,3 +678,309 @@ def forward(self, *inputs: Any) -> Any:
             mix = F.to_image_pil(mix)
 
         return self._unflatten_and_insert_image_or_video(flat_inputs_with_spec, mix)
+
+
+class _AutoAugmentDetectionBase(_AutoAugmentBase):
+    def _apply_transform(
+        self,
+        inpt: datapoints._InputType,
+        transform_id: str,
+        magnitude: float,
+        interpolation: Union[InterpolationMode, int],
+        fill: Dict[Type, datapoints._FillTypeJIT],
+    ) -> Union[datapoints._ImageType, datapoints._VideoType, datapoints.BoundingBox]:
+        if transform_id in ["SolarizeAdd", "Cutout", "BBox_Cutout"]:
+            if isinstance(inpt, datapoints.BoundingBox):
+                return inpt
+            elif transform_id == "SolarizeAdd":
+                return _solarize_add(inpt, magnitude)
+            elif transform_id == "Cutout":
+                return _cutout(inpt, pad_size=int(magnitude))
+            elif transform_id == "BBox_Cutout":
+                return _cutout(inpt, pad_fraction=magnitude)
+        return super()._apply_transform(inpt, transform_id, magnitude, interpolation, fill)
+
+    def _flatten_and_extract_image_or_video_and_bboxes(
+        self,
+        inputs: Any,
+        unsupported_types: Tuple[Type, ...] = (datapoints.Mask,),
+    ) -> Tuple[
+        Tuple[List[Any], TreeSpec, int, int],
+        Union[datapoints._ImageType, datapoints._VideoType],
+        datapoints.BoundingBox,
+    ]:
+        flat_inputs, spec = tree_flatten(inputs if len(inputs) > 1 else inputs[0])
+
+        image_or_videos = []
+        bboxes_list = []
+        for idx, inpt in enumerate(flat_inputs):
+            if check_type(
+                inpt,
+                (
+                    datapoints.Image,
+                    PIL.Image.Image,
+                    is_simple_tensor,
+                    datapoints.Video,
+                ),
+            ):
+                image_or_videos.append((idx, inpt))
+            elif isinstance(inpt, datapoints.BoundingBox):
+                bboxes_list.append((idx, inpt))
+            elif isinstance(inpt, unsupported_types):
+                raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")
+
+        if not image_or_videos:
+            raise TypeError("Found no image or video in the sample.")
+        if len(image_or_videos) > 1:
+            raise TypeError(
+                f"Auto augment transformations are only properly defined for a single image or video, "
+                f"but found {len(image_or_videos)}."
+            )
+        if not bboxes_list:
+            raise TypeError("Found no bounding box in the sample.")
+        if len(bboxes_list) > 1:
+            raise TypeError(
+                f"Auto augment transformations are only properly defined for a single bboxes tensor, "
+                f"but found {len(bboxes_list)}."
+            )
+
+        idx1, image_or_video = image_or_videos[0]
+        idx2, bboxes = bboxes_list[0]
+        return (flat_inputs, spec, idx1, idx2), image_or_video, bboxes
+
+    def _unflatten_and_insert_image_or_video_and_bboxes(
+        self,
+        flat_inputs_with_spec: Tuple[List[Any], TreeSpec, int, int],
+        image_or_video: Union[datapoints._ImageType, datapoints._VideoType],
+        bboxes: datapoints.BoundingBox,
+    ) -> Any:
+        flat_inputs, spec, idx1, idx2 = flat_inputs_with_spec
+        flat_inputs[idx1] = image_or_video
+        flat_inputs[idx2] = bboxes
+        return tree_unflatten(flat_inputs, spec)
+
+    @staticmethod
+    def _transform_image_or_video_in_bboxes(
+        fn: Callable[..., torch.Tensor], fn_kwargs: dict, image: torch.Tensor, bboxes: datapoints.BoundingBox
+    ) -> torch.Tensor:
+        new_img = image.clone()
+        xyxy_bboxes = F.convert_format_bounding_box(bboxes, new_format=datapoints.BoundingBoxFormat.XYXY)
+        for bbox in xyxy_bboxes.to(torch.long):
+            bbox_img = new_img[..., bbox[1] : bbox[3], bbox[0] : bbox[2]]
+            out_bbox_img = fn(bbox_img, **fn_kwargs)
+            new_img[..., bbox[1] : bbox[3], bbox[0] : bbox[2]] = out_bbox_img
+        return new_img
+
+    def _apply_image_or_video_and_bboxes_transform(
+        self,
+        image: Union[datapoints._ImageType, datapoints._VideoType],
+        bboxes: datapoints.BoundingBox,
+        transform_id: str,
+        magnitude: float,
+        interpolation: InterpolationMode,
+        fill: Dict[Type, datapoints._FillTypeJIT],
+    ) -> Tuple[Any, datapoints.BoundingBox]:
+        if transform_id == "BBox_Cutout":
+            random_index = int(torch.randint(len(bboxes), ()))
+            chosen_bbox = bboxes.wrap_like(bboxes, bboxes[random_index].unsqueeze(0))
+            fn_kwargs = dict(transform_id=transform_id, magnitude=magnitude, interpolation=interpolation, fill=fill)
+            image = self._transform_image_or_video_in_bboxes(self._apply_transform, fn_kwargs, image, chosen_bbox)
+        elif transform_id.endswith("_Only_BBoxes"):
+            transform_id = transform_id.replace("_Only_BBoxes", "")
+            fn_kwargs = dict(transform_id=transform_id, magnitude=magnitude, interpolation=interpolation, fill=fill)
+            image = self._transform_image_or_video_in_bboxes(self._apply_transform, fn_kwargs, image, bboxes)
+        else:
+            image = self._apply_transform(image, transform_id, magnitude, interpolation, fill)
+            bboxes = cast(
+                datapoints.BoundingBox, self._apply_transform(bboxes, transform_id, magnitude, interpolation, fill)
+            )
+
+        return image, bboxes
+
+
+class _AutoAugmentDetection(_AutoAugmentDetectionBase):
+    r"""[BETA] AutoAugment data augmentation method for the object detection task, based on
+    `"Learning Data Augmentation Strategies for Object Detection" <https://arxiv.org/abs/1906.11172>`_.
+
+    .. v2betastatus:: AutoAugment transform
+
+    This transformation works on images, videos and bounding boxes only.
+
+    If the input is :class:`torch.Tensor`, it should be of type ``torch.uint8``, and it is expected
+    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
+    If img is PIL Image, it is expected to be in mode "L" or "RGB".
+
+    Args:
+        policy (str, optional): The name of the AutoAugment policy to use. The available
+            options are `v0`, `v1`, `v2` and `v3`. `v0` is the policy used for
+            all of the results in the paper and was found to achieve the best results
+            on the COCO dataset. `v1`, `v2` and `v3` are additional good policies
+            found on the COCO dataset that have slight variation in what operations
+            were used during the search procedure along with how many operations are
+            applied in parallel to a single image (2 vs 3).
+        interpolation (InterpolationMode, optional): Desired interpolation enum defined by
+            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
+            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
+        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
+            image. If given a number, the value is used for all bands respectively.
+    """
+
+    _AUGMENTATION_SPACE = {
+        "AutoContrast": (lambda num_bins, height, width: None, False),
+        "Equalize": (lambda num_bins, height, width: None, False),
+        "Posterize": (
+            lambda num_bins, height, width: torch.linspace(8, 4, num_bins + 1).round().int(),
+            False,
+        ),
+        "Solarize": (
+            lambda num_bins, height, width: torch.linspace(256 / 255, 0, num_bins + 1),
+            False,
+        ),
+        "SolarizeAdd": (
+            lambda num_bins, height, width: torch.linspace(0, 110 / 255, num_bins + 1),
+            False,
+        ),
+        "Color": (lambda num_bins, height, width: torch.linspace(-0.9, 0.9, num_bins + 1), False),
+        "Contrast": (lambda num_bins, height, width: torch.linspace(-0.9, 0.9, num_bins + 1), False),
+        "Brightness": (lambda num_bins, height, width: torch.linspace(-0.9, 0.9, num_bins + 1), False),
+        "Sharpness": (lambda num_bins, height, width: torch.linspace(-0.9, 0.9, num_bins + 1), False),
+        "Cutout": (lambda num_bins, height, width: torch.linspace(0, 100, num_bins + 1).round().int(), False),
+        "BBox_Cutout": (lambda num_bins, height, width: torch.linspace(0.0, 0.75, num_bins + 1), False),
+        "TranslateX": (
+            lambda num_bins, height, width: torch.linspace(0.0, 250.0 / 331.0 * width, num_bins + 1),
+            True,
+        ),
+        "TranslateY": (
+            lambda num_bins, height, width: torch.linspace(0.0, 250.0 / 331.0 * height, num_bins + 1),
+            True,
+        ),
+        "ShearX": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins + 1), True),
+        "ShearY": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins + 1), True),
+        "Rotate": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins + 1), True),
+        "Rotate_Only_BBoxes": (lambda num_bins, height, width: torch.linspace(0.0, 30.0, num_bins + 1), True),
+        "ShearX_Only_BBoxes": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins + 1), True),
+        "ShearY_Only_BBoxes": (lambda num_bins, height, width: torch.linspace(0.0, 0.3, num_bins + 1), True),
+        "TranslateX_Only_BBoxes": (
+            lambda num_bins, height, width: torch.linspace(0.0, 120.0 / 331.0 * width, num_bins + 1),
+            True,
+        ),
+        "TranslateY_Only_BBoxes": (
+            lambda num_bins, height, width: torch.linspace(0.0, 120.0 / 331.0 * height, num_bins + 1),
+            True,
+        ),
+        "Flip_Only_BBoxes": (lambda num_bins, height, width: None, False),
+        "Solarize_Only_BBoxes": (
+            lambda num_bins, height, width: torch.linspace(256 / 255, 0, num_bins + 1).round().int(),
+            False,
+        ),
+        "Equalize_Only_BBoxes": (lambda num_bins, height, width: None, False),
+        "Cutout_Only_BBoxes": (lambda num_bins, height, width: torch.linspace(0, 50, num_bins + 1), False),
+    }
+
+    def __init__(
+        self,
+        policy: str = "v0",
+        interpolation: InterpolationMode = InterpolationMode.NEAREST,
+        fill: Union[datapoints._FillType, Dict[Type, datapoints._FillType]] = None,
+    ) -> None:
+        super().__init__(interpolation=interpolation, fill=fill)
+        self.policy = policy
+        self._policies = self._get_policies(policy)
+
+    def _get_policies(self, policy: str) -> List[List[Tuple[str, float, Optional[int]]]]:
+        if policy == "v0":
+            return [
+                [("TranslateX", 0.6, 4), ("Equalize", 0.8, 10)],
+                [("TranslateY_Only_BBoxes", 0.2, 2), ("Cutout", 0.8, 8)],
+                [("Sharpness", 0.0, 8), ("ShearX", 0.4, 0)],
+                [("ShearY", 1.0, 2), ("TranslateY_Only_BBoxes", 0.6, 6)],
+                [("Rotate", 0.6, 10), ("Color", 1.0, 6)],
+            ]
+        elif policy == "v1":
+            return [
+                [("TranslateX", 0.6, 4), ("Equalize", 0.8, 10)],
+                [("TranslateY_Only_BBoxes", 0.2, 2), ("Cutout", 0.8, 8)],
+                [("Sharpness", 0.0, 8), ("ShearX", 0.4, 0)],
+                [("ShearY", 1.0, 2), ("TranslateY_Only_BBoxes", 0.6, 6)],
+                [("Rotate", 0.6, 10), ("Color", 1.0, 6)],
+                [("Color", 0.0, 0), ("ShearX_Only_BBoxes", 0.8, 4)],
+                [("ShearY_Only_BBoxes", 0.8, 2), ("Flip_Only_BBoxes", 0.0, 10)],
+                [("Equalize", 0.6, 10), ("TranslateX", 0.2, 2)],
+                [("Color", 1.0, 10), ("TranslateY_Only_BBoxes", 0.4, 6)],
+                [("Rotate", 0.8, 10), ("Contrast", 0.0, 10)],
+                [("Cutout", 0.2, 2), ("Brightness", 0.8, 10)],
+                [("Color", 1.0, 6), ("Equalize", 1.0, 2)],
+                [("Cutout_Only_BBoxes", 0.4, 6), ("TranslateY_Only_BBoxes", 0.8, 2)],
+                [("Color", 0.2, 8), ("Rotate", 0.8, 10)],
+                [("Sharpness", 0.4, 4), ("TranslateY_Only_BBoxes", 0.0, 4)],
+                [("Sharpness", 1.0, 4), ("SolarizeAdd", 0.4, 4)],
+                [("Rotate", 1.0, 8), ("Sharpness", 0.2, 8)],
+                [("ShearY", 0.6, 10), ("Equalize_Only_BBoxes", 0.6, 8)],
+                [("ShearX", 0.2, 6), ("TranslateY_Only_BBoxes", 0.2, 10)],
+                [("SolarizeAdd", 0.6, 8), ("Brightness", 0.8, 10)],
+            ]
+        elif policy == "v2":
+            return [
+                [("Color", 0.0, 6), ("Cutout", 0.6, 8), ("Sharpness", 0.4, 8)],
+                [("Rotate", 0.4, 8), ("Sharpness", 0.4, 2), ("Rotate", 0.8, 10)],
+                [("TranslateY", 1.0, 8), ("AutoContrast", 0.8, 2)],
+                [("AutoContrast", 0.4, 6), ("ShearX", 0.8, 8), ("Brightness", 0.0, 10)],
+                [("SolarizeAdd", 0.2, 6), ("Contrast", 0.0, 10), ("AutoContrast", 0.6, 0)],
+                [("Cutout", 0.2, 0), ("Solarize", 0.8, 8), ("Color", 1.0, 4)],
+                [("TranslateY", 0.0, 4), ("Equalize", 0.6, 8), ("Solarize", 0.0, 10)],
+                [("TranslateY", 0.2, 2), ("ShearY", 0.8, 8), ("Rotate", 0.8, 8)],
+                [("Cutout", 0.8, 8), ("Brightness", 0.8, 8), ("Cutout", 0.2, 2)],
+                [("Color", 0.8, 4), ("TranslateY", 1.0, 6), ("Rotate", 0.6, 6)],
+                [("Rotate", 0.6, 10), ("BBox_Cutout", 1.0, 4), ("Cutout", 0.2, 8)],
+                [("Rotate", 0.0, 0), ("Equalize", 0.6, 6), ("ShearY", 0.6, 8)],
+                [("Brightness", 0.8, 8), ("AutoContrast", 0.4, 2), ("Brightness", 0.2, 2)],
+                [("TranslateY", 0.4, 8), ("Solarize", 0.4, 6), ("SolarizeAdd", 0.2, 10)],
+                [("Contrast", 1.0, 10), ("SolarizeAdd", 0.2, 8), ("Equalize", 0.2, 4)],
+            ]
policy == "v3": + return [ + [("Posterize", 0.8, 2), ("TranslateX", 1.0, 8)], + [("BBox_Cutout", 0.2, 10), ("Sharpness", 1.0, 8)], + [("Rotate", 0.6, 8), ("Rotate", 0.8, 10)], + [("Equalize", 0.8, 10), ("AutoContrast", 0.2, 10)], + [("SolarizeAdd", 0.2, 2), ("TranslateY", 0.2, 8)], + [("Sharpness", 0.0, 2), ("Color", 0.4, 8)], + [("Equalize", 1.0, 8), ("TranslateY", 1.0, 8)], + [("Posterize", 0.6, 2), ("Rotate", 0.0, 10)], + [("AutoContrast", 0.6, 0), ("Rotate", 1.0, 6)], + [("Equalize", 0.0, 4), ("Cutout", 0.8, 10)], + [("Brightness", 1.0, 2), ("TranslateY", 1.0, 6)], + [("Contrast", 0.0, 2), ("ShearY", 0.8, 0)], + [("AutoContrast", 0.8, 10), ("Contrast", 0.2, 10)], + [("Rotate", 1.0, 10), ("Cutout", 1.0, 10)], + [("SolarizeAdd", 0.8, 6), ("Equalize", 0.8, 8)], + ] + else: + raise ValueError(f"The provided policy {policy} is not recognized.") + + def forward(self, *inputs: Any) -> Any: + flat_inputs_with_spec, image_or_video, bboxes = self._flatten_and_extract_image_or_video_and_bboxes( + inputs, unsupported_types=(datapoints.Mask,) + ) + height, width = get_spatial_size(image_or_video) + policy = self._policies[int(torch.randint(len(self._policies), ()))] + + for transform_id, probability, magnitude_idx in policy: + if not torch.rand(()) <= probability: + continue + + magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id] + + magnitudes = magnitudes_fn(10, height, width) + if magnitudes is not None: + magnitude = float(magnitudes[magnitude_idx]) + if signed and torch.rand(()) <= 0.5: + magnitude *= -1 + else: + magnitude = 0.0 + + image_or_video, bboxes = self._apply_image_or_video_and_bboxes_transform( + image_or_video, bboxes, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill + ) + + return self._unflatten_and_insert_image_or_video_and_bboxes(flat_inputs_with_spec, image_or_video, bboxes)