Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integration of TrivialAugment with the current AutoAugment Code #4221

Merged
merged 32 commits into from
Sep 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
c56148c
Initial Proposal
SamuelGabriel Jul 29, 2021
f4552ed
Tensor Save Test + Test Name Fix
SamuelGabriel Jul 29, 2021
25968e6
Formatting + removing unused argument
SamuelGabriel Jul 29, 2021
2feff4f
fix old argument
SamuelGabriel Jul 29, 2021
58c7ba8
fix isnan check error + indexing error with jit
SamuelGabriel Jul 29, 2021
33d5d59
Merge branch 'master' into trivialaugment_implementation
SamuelGabriel Jul 29, 2021
a015611
Merge branch 'master' into trivialaugment_implementation
SamuelGabriel Aug 2, 2021
7a7a739
Merge branch 'master' into trivialaugment_implementation
SamuelGabriel Aug 17, 2021
406848a
Fix Flake8 error.
SamuelGabriel Aug 17, 2021
f743481
Fix MyPy error.
SamuelGabriel Aug 17, 2021
19d8696
Fix Flake8 error.
SamuelGabriel Aug 17, 2021
1ed1568
Fix PyTorch JIT error in UnitTests due to type annotation.
SamuelGabriel Aug 17, 2021
536446e
Merge branch 'master' into trivialaugment_implementation
SamuelGabriel Aug 17, 2021
942fb66
Merge branch 'master' into trivialaugment_implementation
datumbox Aug 17, 2021
16784a1
Merge branch 'main' into trivialaugment_implementation
datumbox Aug 26, 2021
c8fb6c7
Merge branch 'main' into trivialaugment_implementation
datumbox Aug 26, 2021
2fc8633
Fixing tests.
datumbox Aug 26, 2021
729c0db
Removing type ignore.
datumbox Aug 26, 2021
d02100a
Merge branch 'main' into SamuelGabriel_trivialaugment_implementation
datumbox Aug 26, 2021
83552c6
Adding support of ta_wide in references.
datumbox Aug 26, 2021
cd6a75e
Merge branch 'main' into trivialaugment_implementation
datumbox Aug 27, 2021
1fe25fb
Move methods in classes.
datumbox Aug 27, 2021
226998c
Moving new classes to the bottom.
datumbox Aug 27, 2021
425c52d
Specialize to TA to TAwide
datumbox Aug 27, 2021
fa8a6d5
Merge branch 'main' into trivialaugment_implementation
datumbox Aug 31, 2021
7483dbc
Merge branch 'main' into SamuelGabriel_trivialaugment_implementation
datumbox Sep 2, 2021
bd2dc17
Add missing type
datumbox Sep 2, 2021
0087be0
Merge branch 'main' into trivialaugment_implementation
datumbox Sep 2, 2021
5770a03
Fixing lint
datumbox Sep 2, 2021
46f886c
Fix doc
datumbox Sep 2, 2021
2933667
Merge branch 'main' into trivialaugment_implementation
SamuelGabriel Sep 6, 2021
30bbae9
Fix search space of TrivialAugment.
SamuelGabriel Sep 6, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,11 @@ The new transform can be used standalone or mixed-and-matched with existing tran
.. autoclass:: RandAugment
:members:

`TrivialAugmentWide <https://arxiv.org/abs/2103.10158>`_ is a dataset-independent data-augmentation technique which improves the accuracy of Image Classification models.

.. autoclass:: TrivialAugmentWide
:members:

.. _functional_transforms:

Functional Transforms
Expand Down
8 changes: 8 additions & 0 deletions gallery/plot_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,14 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs):
imgs = [augmenter(orig_img) for _ in range(4)]
plot(imgs)

####################################
# TrivialAugmentWide
# ~~~~~~~~~~~~~~~~~~
# The :class:`~torchvision.transforms.TrivialAugmentWide` transform automatically augments the data.
augmenter = T.TrivialAugmentWide()
imgs = [augmenter(orig_img) for _ in range(4)]
plot(imgs)

####################################
# Randomly-applied transforms
# ---------------------------
Expand Down
2 changes: 2 additions & 0 deletions references/classification/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ def __init__(self, crop_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.2
if auto_augment_policy is not None:
if auto_augment_policy == "ra":
trans.append(autoaugment.RandAugment())
elif auto_augment_policy == "ta_wide":
trans.append(autoaugment.TrivialAugmentWide())
else:
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
trans.append(autoaugment.AutoAugment(policy=aa_policy))
Expand Down
11 changes: 11 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1502,6 +1502,17 @@ def test_randaugment(num_ops, magnitude, fill):
transform.__repr__()


@pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)])
@pytest.mark.parametrize('num_magnitude_bins', [10, 13, 30])
def test_trivialaugmentwide(fill, num_magnitude_bins):
random.seed(42)
img = Image.open(GRACE_HOPPER)
transform = transforms.TrivialAugmentWide(fill=fill, num_magnitude_bins=num_magnitude_bins)
for _ in range(100):
img = transform(img)
transform.__repr__()


def test_random_crop():
height = random.randint(10, 32) * 2
width = random.randint(10, 32) * 2
Expand Down
15 changes: 14 additions & 1 deletion test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,20 @@ def test_randaugment(device, num_ops, magnitude, fill):
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)


@pytest.mark.parametrize('augmentation', [T.AutoAugment, T.RandAugment])
@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('fill', [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1])
def test_trivialaugmentwide(device, fill):
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)

transform = T.TrivialAugmentWide(fill=fill)
s_transform = torch.jit.script(transform)
for _ in range(25):
_test_transform_vs_scripted(transform, s_transform, tensor)
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)


@pytest.mark.parametrize('augmentation', [T.AutoAugment, T.RandAugment, T.TrivialAugmentWide])
def test_autoaugment_save(augmentation, tmpdir):
transform = augmentation()
s_transform = torch.jit.script(transform)
Expand Down
80 changes: 79 additions & 1 deletion torchvision/transforms/autoaugment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from . import functional as F, InterpolationMode

__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment"]
__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide"]


def _apply_op(img: Tensor, op_name: str, magnitude: float,
Expand Down Expand Up @@ -44,6 +44,8 @@ def _apply_op(img: Tensor, op_name: str, magnitude: float,
img = F.equalize(img)
elif op_name == "Invert":
img = F.invert(img)
elif op_name == "Identity":
pass
else:
raise ValueError("The provided operator {} is not recognized.".format(op_name))
return img
Expand Down Expand Up @@ -325,3 +327,79 @@ def __repr__(self) -> str:
s += ', fill={fill}'
s += ')'
return s.format(**self.__dict__)


class TrivialAugmentWide(torch.nn.Module):
r"""Dataset-independent data-augmentation with TrivialAugment Wide, as described in
`"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`.
If the image is torch Tensor, it should be of type torch.uint8, and it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".
Args:
num_magnitude_bins (int): The number of different magnitude values.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively.
"""

def __init__(self, num_magnitude_bins: int = 30, interpolation: InterpolationMode = InterpolationMode.NEAREST,
Copy link
Contributor

@datumbox datumbox Sep 6, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this also be 31 as per your comment at #4348 (comment)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it should be. I pushed a commit to change it to the underlying branch, but I am not sure what is the best way to get such a minor fix into main now, see SamuelGabriel@9df8f13

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No worries, I'll fix it on #4370.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great, Thanks!

fill: Optional[List[float]] = None) -> None:
super().__init__()
self.num_magnitude_bins = num_magnitude_bins
self.interpolation = interpolation
self.fill = fill

def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
return {
# op_name: (magnitudes, signed)
"Identity": (torch.tensor(0.0), False),
"ShearX": (torch.linspace(0.0, 0.99, num_bins), True),
"ShearY": (torch.linspace(0.0, 0.99, num_bins), True),
"TranslateX": (torch.linspace(0.0, 32.0, num_bins), True),
"TranslateY": (torch.linspace(0.0, 32.0, num_bins), True),
"Rotate": (torch.linspace(0.0, 135.0, num_bins), True),
"Brightness": (torch.linspace(0.0, 0.99, num_bins), True),
"Color": (torch.linspace(0.0, 0.99, num_bins), True),
"Contrast": (torch.linspace(0.0, 0.99, num_bins), True),
"Sharpness": (torch.linspace(0.0, 0.99, num_bins), True),
"Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False),
"Solarize": (torch.linspace(256.0, 0.0, num_bins), False),
"AutoContrast": (torch.tensor(0.0), False),
"Equalize": (torch.tensor(0.0), False),
}

def forward(self, img: Tensor) -> Tensor:
"""
img (PIL Image or Tensor): Image to be transformed.
Returns:
PIL Image or Tensor: Transformed image.
"""
fill = self.fill
if isinstance(img, Tensor):
if isinstance(fill, (int, float)):
fill = [float(fill)] * F.get_image_num_channels(img)
elif fill is not None:
fill = [float(f) for f in fill]

op_meta = self._augmentation_space(self.num_magnitude_bins)
op_index = int(torch.randint(len(op_meta), (1,)).item())
op_name = list(op_meta.keys())[op_index]
magnitudes, signed = op_meta[op_name]
magnitude = float(magnitudes[torch.randint(len(magnitudes), (1,), dtype=torch.long)].item()) \
if magnitudes.ndim > 0 else 0.0
if signed and torch.randint(2, (1,)):
magnitude *= -1.0

return _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)

def __repr__(self) -> str:
s = self.__class__.__name__ + '('
s += 'num_magnitude_bins={num_magnitude_bins}'
s += ', interpolation={interpolation}'
s += ', fill={fill}'
s += ')'
return s.format(**self.__dict__)