Skip to content

Commit

Permalink
[fbsync] Integration of TrivialAugment with the current AutoAugment C…
Browse files Browse the repository at this point in the history
…ode (#4221)

Summary:
* Initial Proposal

* Tensor Save Test + Test Name Fix

* Formatting + removing unused argument

* fix old argument

* fix isnan check error + indexing error with jit

* Fix Flake8 error.

* Fix MyPy error.

* Fix Flake8 error.

* Fix PyTorch JIT error in UnitTests due to type annotation.

* Fixing tests.

* Removing type ignore.

* Adding support of ta_wide in references.

* Move methods in classes.

* Moving new classes to the bottom.

* Specialize to TA to TAwide

* Add missing type

* Fixing lint

* Fix doc

* Fix search space of TrivialAugment.

Reviewed By: fmassa

Differential Revision: D30793337

fbshipit-source-id: 01ffd0268c10beb7d96017ad9490d3d5c9238810

Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
Co-authored-by: Vasilis Vryniotis <vvryniotis@fb.com>
  • Loading branch information
2 people authored and facebook-github-bot committed Sep 9, 2021
1 parent 81cba99 commit d6a64d9
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 2 deletions.
5 changes: 5 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,11 @@ The new transform can be used standalone or mixed-and-matched with existing tran
.. autoclass:: RandAugment
:members:

`TrivialAugmentWide <https://arxiv.org/abs/2103.10158>`_ is a dataset-independent data-augmentation technique which improves the accuracy of Image Classification models.

.. autoclass:: TrivialAugmentWide
:members:

.. _functional_transforms:

Functional Transforms
Expand Down
8 changes: 8 additions & 0 deletions gallery/plot_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,14 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs):
imgs = [augmenter(orig_img) for _ in range(4)]
plot(imgs)

####################################
# TrivialAugmentWide
# ~~~~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.TrivialAugmentWide` transform automatically augments the data.
augmenter = T.TrivialAugmentWide()
imgs = [augmenter(orig_img) for _ in range(4)]
plot(imgs)

####################################
# Randomly-applied transforms
# ---------------------------
Expand Down
2 changes: 2 additions & 0 deletions references/classification/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ def __init__(self, crop_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.2
if auto_augment_policy is not None:
if auto_augment_policy == "ra":
trans.append(autoaugment.RandAugment())
elif auto_augment_policy == "ta_wide":
trans.append(autoaugment.TrivialAugmentWide())
else:
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
trans.append(autoaugment.AutoAugment(policy=aa_policy))
Expand Down
11 changes: 11 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1502,6 +1502,17 @@ def test_randaugment(num_ops, magnitude, fill):
transform.__repr__()


@pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)])
@pytest.mark.parametrize('num_magnitude_bins', [10, 13, 30])
def test_trivialaugmentwide(fill, num_magnitude_bins):
random.seed(42)
img = Image.open(GRACE_HOPPER)
transform = transforms.TrivialAugmentWide(fill=fill, num_magnitude_bins=num_magnitude_bins)
for _ in range(100):
img = transform(img)
transform.__repr__()


def test_random_crop():
height = random.randint(10, 32) * 2
width = random.randint(10, 32) * 2
Expand Down
15 changes: 14 additions & 1 deletion test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,20 @@ def test_randaugment(device, num_ops, magnitude, fill):
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)


@pytest.mark.parametrize('augmentation', [T.AutoAugment, T.RandAugment])
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('fill', [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1])
def test_trivialaugmentwide(device, fill):
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)

transform = T.TrivialAugmentWide(fill=fill)
s_transform = torch.jit.script(transform)
for _ in range(25):
_test_transform_vs_scripted(transform, s_transform, tensor)
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)


@pytest.mark.parametrize('augmentation', [T.AutoAugment, T.RandAugment, T.TrivialAugmentWide])
def test_autoaugment_save(augmentation, tmpdir):
transform = augmentation()
s_transform = torch.jit.script(transform)
Expand Down
80 changes: 79 additions & 1 deletion torchvision/transforms/autoaugment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from . import functional as F, InterpolationMode

__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment"]
__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide"]


def _apply_op(img: Tensor, op_name: str, magnitude: float,
Expand Down Expand Up @@ -44,6 +44,8 @@ def _apply_op(img: Tensor, op_name: str, magnitude: float,
img = F.equalize(img)
elif op_name == "Invert":
img = F.invert(img)
elif op_name == "Identity":
pass
else:
raise ValueError("The provided operator {} is not recognized.".format(op_name))
return img
Expand Down Expand Up @@ -325,3 +327,79 @@ def __repr__(self) -> str:
s += ', fill={fill}'
s += ')'
return s.format(**self.__dict__)


class TrivialAugmentWide(torch.nn.Module):
r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in
`"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`.
If the image is torch Tensor, it should be of type torch.uint8, and it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Args:
num_magnitude_bins (int): The number of different magnitude values.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively.
"""

def __init__(self, num_magnitude_bins: int = 30, interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None) -> None:
super().__init__()
self.num_magnitude_bins = num_magnitude_bins
self.interpolation = interpolation
self.fill = fill

def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
return {
# op_name: (magnitudes, signed)
"Identity": (torch.tensor(0.0), False),
"ShearX": (torch.linspace(0.0, 0.99, num_bins), True),
"ShearY": (torch.linspace(0.0, 0.99, num_bins), True),
"TranslateX": (torch.linspace(0.0, 32.0, num_bins), True),
"TranslateY": (torch.linspace(0.0, 32.0, num_bins), True),
"Rotate": (torch.linspace(0.0, 135.0, num_bins), True),
"Brightness": (torch.linspace(0.0, 0.99, num_bins), True),
"Color": (torch.linspace(0.0, 0.99, num_bins), True),
"Contrast": (torch.linspace(0.0, 0.99, num_bins), True),
"Sharpness": (torch.linspace(0.0, 0.99, num_bins), True),
"Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False),
"Solarize": (torch.linspace(256.0, 0.0, num_bins), False),
"AutoContrast": (torch.tensor(0.0), False),
"Equalize": (torch.tensor(0.0), False),
}

def forward(self, img: Tensor) -> Tensor:
"""
img (PIL Image or Tensor): Image to be transformed.
Returns:
PIL Image or Tensor: Transformed image.
"""
fill = self.fill
if isinstance(img, Tensor):
if isinstance(fill, (int, float)):
fill = [float(fill)] * F.get_image_num_channels(img)
elif fill is not None:
fill = [float(f) for f in fill]

op_meta = self._augmentation_space(self.num_magnitude_bins)
op_index = int(torch.randint(len(op_meta), (1,)).item())
op_name = list(op_meta.keys())[op_index]
magnitudes, signed = op_meta[op_name]
magnitude = float(magnitudes[torch.randint(len(magnitudes), (1,), dtype=torch.long)].item()) \
if magnitudes.ndim > 0 else 0.0
if signed and torch.randint(2, (1,)):
magnitude *= -1.0

return _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)

def __repr__(self) -> str:
s = self.__class__.__name__ + '('
s += 'num_magnitude_bins={num_magnitude_bins}'
s += ', interpolation={interpolation}'
s += ', fill={fill}'
s += ')'
return s.format(**self.__dict__)

0 comments on commit d6a64d9

Please sign in to comment.