From 446b2ca571dedaf3c43b2cdc0f299084fd13b93b Mon Sep 17 00:00:00 2001 From: SamuelGabriel Date: Mon, 6 Sep 2021 11:19:10 +0200 Subject: [PATCH] Integration of TrivialAugment with the current AutoAugment Code (#4221) * Initial Proposal * Tensor Save Test + Test Name Fix * Formatting + removing unused argument * fix old argument * fix isnan check error + indexing error with jit * Fix Flake8 error. * Fix MyPy error. * Fix Flake8 error. * Fix PyTorch JIT error in UnitTests due to type annotation. * Fixing tests. * Removing type ignore. * Adding support of ta_wide in references. * Move methods in classes. * Moving new classes to the bottom. * Specialize to TA to TAwide * Add missing type * Fixing lint * Fix doc * Fix search space of TrivialAugment. Co-authored-by: Vasilis Vryniotis Co-authored-by: Vasilis Vryniotis --- docs/source/transforms.rst | 5 ++ gallery/plot_transforms.py | 8 +++ references/classification/presets.py | 2 + test/test_transforms.py | 11 ++++ test/test_transforms_tensor.py | 15 ++++- torchvision/transforms/autoaugment.py | 80 ++++++++++++++++++++++++++- 6 files changed, 119 insertions(+), 2 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 49a5a32301b..f184d5da30e 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -234,6 +234,11 @@ The new transform can be used standalone or mixed-and-matched with existing tran .. autoclass:: RandAugment :members: +`TrivialAugmentWide `_ is a dataset-independent data-augmentation technique which improves the accuracy of Image Classification models. + +.. autoclass:: TrivialAugmentWide + :members: + .. _functional_transforms: Functional Transforms diff --git a/gallery/plot_transforms.py b/gallery/plot_transforms.py index 0a0c1afb479..fe5864ebad5 100644 --- a/gallery/plot_transforms.py +++ b/gallery/plot_transforms.py @@ -253,6 +253,14 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): imgs = [augmenter(orig_img) for _ in range(4)] plot(imgs) +#################################### +# TrivialAugmentWide +# ~~~~~~~~~~~~~~~~~~ +# The :class:`~torchvision.transforms.TrivialAugmentWide` transform automatically augments the data. +augmenter = T.TrivialAugmentWide() +imgs = [augmenter(orig_img) for _ in range(4)] +plot(imgs) + #################################### # Randomly-applied transforms # --------------------------- diff --git a/references/classification/presets.py b/references/classification/presets.py index c289c3b1c8b..981dbd6ed9e 100644 --- a/references/classification/presets.py +++ b/references/classification/presets.py @@ -11,6 +11,8 @@ def __init__(self, crop_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.2 if auto_augment_policy is not None: if auto_augment_policy == "ra": trans.append(autoaugment.RandAugment()) + elif auto_augment_policy == "ta_wide": + trans.append(autoaugment.TrivialAugmentWide()) else: aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy) trans.append(autoaugment.AutoAugment(policy=aa_policy)) diff --git a/test/test_transforms.py b/test/test_transforms.py index ca11bf664c1..675e79ac3ba 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1502,6 +1502,17 @@ def test_randaugment(num_ops, magnitude, fill): transform.__repr__() +@pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)]) +@pytest.mark.parametrize('num_magnitude_bins', [10, 13, 30]) +def test_trivialaugmentwide(fill, num_magnitude_bins): + random.seed(42) + img = Image.open(GRACE_HOPPER) + transform = transforms.TrivialAugmentWide(fill=fill, num_magnitude_bins=num_magnitude_bins) + for _ in range(100): + img = transform(img) + transform.__repr__() + + def test_random_crop(): height = random.randint(10, 32) * 2 width = random.randint(10, 32) * 2 diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index c0669987213..aaf7880f124 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -547,7 +547,20 @@ def test_randaugment(device, num_ops, magnitude, fill): _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) -@pytest.mark.parametrize('augmentation', [T.AutoAugment, T.RandAugment]) +@pytest.mark.parametrize('device', cpu_and_gpu()) +@pytest.mark.parametrize('fill', [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]) +def test_trivialaugmentwide(device, fill): + tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device) + batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device) + + transform = T.TrivialAugmentWide(fill=fill) + s_transform = torch.jit.script(transform) + for _ in range(25): + _test_transform_vs_scripted(transform, s_transform, tensor) + _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) + + +@pytest.mark.parametrize('augmentation', [T.AutoAugment, T.RandAugment, T.TrivialAugmentWide]) def test_autoaugment_save(augmentation, tmpdir): transform = augmentation() s_transform = torch.jit.script(transform) diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index 3d9c8b6796f..44c7990482b 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -7,7 +7,7 @@ from . import functional as F, InterpolationMode -__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment"] +__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide"] def _apply_op(img: Tensor, op_name: str, magnitude: float, @@ -44,6 +44,8 @@ def _apply_op(img: Tensor, op_name: str, magnitude: float, img = F.equalize(img) elif op_name == "Invert": img = F.invert(img) + elif op_name == "Identity": + pass else: raise ValueError("The provided operator {} is not recognized.".format(op_name)) return img @@ -325,3 +327,79 @@ def __repr__(self) -> str: s += ', fill={fill}' s += ')' return s.format(**self.__dict__) + + +class TrivialAugmentWide(torch.nn.Module): + r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in + `"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" `. + If the image is torch Tensor, it should be of type torch.uint8, and it is expected + to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + + Args: + num_magnitude_bins (int): The number of different magnitude values. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + fill (sequence or number, optional): Pixel fill value for the area outside the transformed + image. If given a number, the value is used for all bands respectively. + """ + + def __init__(self, num_magnitude_bins: int = 30, interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None) -> None: + super().__init__() + self.num_magnitude_bins = num_magnitude_bins + self.interpolation = interpolation + self.fill = fill + + def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]: + return { + # op_name: (magnitudes, signed) + "Identity": (torch.tensor(0.0), False), + "ShearX": (torch.linspace(0.0, 0.99, num_bins), True), + "ShearY": (torch.linspace(0.0, 0.99, num_bins), True), + "TranslateX": (torch.linspace(0.0, 32.0, num_bins), True), + "TranslateY": (torch.linspace(0.0, 32.0, num_bins), True), + "Rotate": (torch.linspace(0.0, 135.0, num_bins), True), + "Brightness": (torch.linspace(0.0, 0.99, num_bins), True), + "Color": (torch.linspace(0.0, 0.99, num_bins), True), + "Contrast": (torch.linspace(0.0, 0.99, num_bins), True), + "Sharpness": (torch.linspace(0.0, 0.99, num_bins), True), + "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False), + "Solarize": (torch.linspace(256.0, 0.0, num_bins), False), + "AutoContrast": (torch.tensor(0.0), False), + "Equalize": (torch.tensor(0.0), False), + } + + def forward(self, img: Tensor) -> Tensor: + """ + img (PIL Image or Tensor): Image to be transformed. + + Returns: + PIL Image or Tensor: Transformed image. + """ + fill = self.fill + if isinstance(img, Tensor): + if isinstance(fill, (int, float)): + fill = [float(fill)] * F.get_image_num_channels(img) + elif fill is not None: + fill = [float(f) for f in fill] + + op_meta = self._augmentation_space(self.num_magnitude_bins) + op_index = int(torch.randint(len(op_meta), (1,)).item()) + op_name = list(op_meta.keys())[op_index] + magnitudes, signed = op_meta[op_name] + magnitude = float(magnitudes[torch.randint(len(magnitudes), (1,), dtype=torch.long)].item()) \ + if magnitudes.ndim > 0 else 0.0 + if signed and torch.randint(2, (1,)): + magnitude *= -1.0 + + return _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill) + + def __repr__(self) -> str: + s = self.__class__.__name__ + '(' + s += 'num_magnitude_bins={num_magnitude_bins}' + s += ', interpolation={interpolation}' + s += ', fill={fill}' + s += ')' + return s.format(**self.__dict__)