[fbsync] Adding Mixup and Cutmix (#4379)
Summary:
* Add RandomMixupCutmix.

* Add test with real data.

* Use dataloader and collate in the test.

* Make RandomMixupCutmix JIT scriptable.

* Move out label_smoothing and try roll instead of flip.

* Add mixup/cutmix to the references script.

* Handle one-hot encoded targets in accuracy.

* Add device support in tests.

* Separate Mixup from Cutmix.

* Add check for floats.

* Add device on expected value.

* Remove hardcoded weights.

* One-hot encode only when necessary.

* Fix linter.

* Move mixup and cutmix to references.

* Final code clean-up.

Reviewed By: datumbox

Differential Revision: D31268036

fbshipit-source-id: 6a73c079d667443da898e3b175b88978b24d52ad
NicolasHug authored and facebook-github-bot committed Sep 30, 2021
1 parent 66724be commit 54f89b9
Showing 5 changed files with 210 additions and 6 deletions.
19 changes: 17 additions & 2 deletions references/classification/train.py
@@ -4,11 +4,13 @@

import torch
import torch.utils.data
+from torch.utils.data.dataloader import default_collate
from torch import nn
import torchvision
from torchvision.transforms.functional import InterpolationMode

import presets
+import transforms
import utils
try:
@@ -164,10 +166,21 @@ def main(args):
    train_dir = os.path.join(args.data_path, 'train')
    val_dir = os.path.join(args.data_path, 'val')
    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)

+    collate_fn = None
+    num_classes = len(dataset.classes)
+    mixup_transforms = []
+    if args.mixup_alpha > 0.0:
+        mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
+    if args.cutmix_alpha > 0.0:
+        mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
+    if mixup_transforms:
+        mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
+        collate_fn = lambda batch: mixupcutmix(*default_collate(batch))  # noqa: E731
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size,
-        sampler=train_sampler, num_workers=args.workers, pin_memory=True)
+        sampler=train_sampler, num_workers=args.workers, pin_memory=True,
+        collate_fn=collate_fn)
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.batch_size,
        sampler=test_sampler, num_workers=args.workers, pin_memory=True)
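
Both transforms operate on an already-stacked batch rather than on single samples, which is why they are wired in through the DataLoader's collate_fn. A minimal self-contained sketch of the same wiring, assuming it runs from references/classification so that `import transforms` picks up the new module (FakeData and all hyperparameters here are placeholder choices for illustration):

import torch
import torchvision
from torch.utils.data.dataloader import default_collate

import transforms  # references/classification/transforms.py, added below

num_classes = 10
mixupcutmix = torchvision.transforms.RandomChoice([
    transforms.RandomMixup(num_classes, p=1.0, alpha=0.2),
    transforms.RandomCutmix(num_classes, p=1.0, alpha=1.0),
])

def collate_fn(batch):
    # default_collate stacks the samples into (B, C, H, W) images and (B,)
    # integer targets; the randomly chosen transform then mixes the batch
    # and one-hot encodes the targets.
    return mixupcutmix(*default_collate(batch))

dataset = torchvision.datasets.FakeData(size=64, num_classes=num_classes,
                                        transform=torchvision.transforms.ToTensor())
loader = torch.utils.data.DataLoader(dataset, batch_size=8, collate_fn=collate_fn)
images, targets = next(iter(loader))  # targets now have shape (8, 10)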
@@ -272,6 +285,8 @@ def get_args_parser(add_help=True):
    parser.add_argument('--label-smoothing', default=0.0, type=float,
                        help='label smoothing (default: 0.0)',
                        dest='label_smoothing')
+    parser.add_argument('--mixup-alpha', default=0.0, type=float, help='mixup alpha (default: 0.0)')
+    parser.add_argument('--cutmix-alpha', default=0.0, type=float, help='cutmix alpha (default: 0.0)')
    parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
    parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
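
Both new flags default to 0.0, so mixup/cutmix stay opt-in and are only wired in when a positive alpha is given. A hypothetical invocation (the data path and alpha values below are illustrative, not recommendations from this commit):

python train.py --data-path /datasets/imagenet --mixup-alpha 0.2 --cutmix-alpha 1.0 --label-smoothing 0.1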
175 changes: 175 additions & 0 deletions references/classification/transforms.py
@@ -0,0 +1,175 @@
import math
import torch

from typing import Tuple
from torch import Tensor
from torchvision.transforms import functional as F


class RandomMixup(torch.nn.Module):
    """Randomly apply Mixup to the provided batch and targets.
    The class implements the data augmentations as described in the paper
    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.
    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for mixup.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default set to False.
    """

    def __init__(self, num_classes: int,
                 p: float = 0.5, alpha: float = 1.0,
                 inplace: bool = False) -> None:
        super().__init__()
        assert num_classes > 0, "Please provide a valid positive value for the num_classes."
        assert alpha > 0, "Alpha param can't be zero."

        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer tensor of size (B, )
        Returns:
            Tensor: Randomly transformed batch.
        """
        if batch.ndim != 4:
            raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim))
        elif target.ndim != 1:
            raise ValueError("Target ndim should be 1. Got {}".format(target.ndim))
        elif not batch.is_floating_point():
            raise TypeError('Batch dtype should be a float tensor. Got {}.'.format(batch.dtype))
        elif target.dtype != torch.int64:
            raise TypeError("Target dtype should be torch.int64. Got {}".format(target.dtype))

        if not self.inplace:
            batch = batch.clone()
            target = target.clone()

        if target.ndim == 1:
            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)

        if torch.rand(1).item() >= self.p:
            return batch, target

        # It's faster to roll the batch by one instead of shuffling it to create image pairs
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

        # Implemented as on mixup paper, page 3.
        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
        batch_rolled.mul_(1.0 - lambda_param)
        batch.mul_(lambda_param).add_(batch_rolled)

        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)

        return batch, target

    def __repr__(self) -> str:
        s = self.__class__.__name__ + '('
        s += 'num_classes={num_classes}'
        s += ', p={p}'
        s += ', alpha={alpha}'
        s += ', inplace={inplace}'
        s += ')'
        return s.format(**self.__dict__)
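
A short illustration of the transform above, under the same contract its validation enforces (float batch, int64 targets); shapes and hyperparameters are arbitrary:

import torch

batch = torch.rand(4, 3, 32, 32)                       # float batch, (B, C, H, W)
target = torch.randint(0, 10, (4,))                    # int64 labels, (B,)
mixup = RandomMixup(num_classes=10, p=1.0, alpha=0.2)  # p=1.0 forces the mix
mixed_batch, mixed_target = mixup(batch, target)
# Each output image is lam * x[i] + (1 - lam) * x[i - 1] (the rolled pair),
# and each row of mixed_target is the same convex combination of one-hot
# labels, so the rows still sum to 1.
assert mixed_target.shape == (4, 10)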


class RandomCutmix(torch.nn.Module):
    """Randomly apply Cutmix to the provided batch and targets.
    The class implements the data augmentations as described in the paper
    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
    <https://arxiv.org/abs/1905.04899>`_.
    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for cutmix.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default set to False.
    """

    def __init__(self, num_classes: int,
                 p: float = 0.5, alpha: float = 1.0,
                 inplace: bool = False) -> None:
        super().__init__()
        assert num_classes > 0, "Please provide a valid positive value for the num_classes."
        assert alpha > 0, "Alpha param can't be zero."

        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer tensor of size (B, )
        Returns:
            Tensor: Randomly transformed batch.
        """
        if batch.ndim != 4:
            raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim))
        elif target.ndim != 1:
            raise ValueError("Target ndim should be 1. Got {}".format(target.ndim))
        elif not batch.is_floating_point():
            raise TypeError('Batch dtype should be a float tensor. Got {}.'.format(batch.dtype))
        elif target.dtype != torch.int64:
            raise TypeError("Target dtype should be torch.int64. Got {}".format(target.dtype))

        if not self.inplace:
            batch = batch.clone()
            target = target.clone()

        if target.ndim == 1:
            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)

        if torch.rand(1).item() >= self.p:
            return batch, target

        # It's faster to roll the batch by one instead of shuffling it to create image pairs
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

        # Implemented as on cutmix paper, page 12 (with minor corrections on typos).
        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
        W, H = F.get_image_size(batch)

        r_x = torch.randint(W, (1,))
        r_y = torch.randint(H, (1,))

        r = 0.5 * math.sqrt(1.0 - lambda_param)
        r_w_half = int(r * W)
        r_h_half = int(r * H)

        x1 = int(torch.clamp(r_x - r_w_half, min=0))
        y1 = int(torch.clamp(r_y - r_h_half, min=0))
        x2 = int(torch.clamp(r_x + r_w_half, max=W))
        y2 = int(torch.clamp(r_y + r_h_half, max=H))

        batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
        lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))

        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)

        return batch, target

    def __repr__(self) -> str:
        s = self.__class__.__name__ + '('
        s += 'num_classes={num_classes}'
        s += ', p={p}'
        s += ', alpha={alpha}'
        s += ', inplace={inplace}'
        s += ')'
        return s.format(**self.__dict__)
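
Note that lambda_param is recomputed from the clipped box, so the label weights always match the exact fraction of pixels that were pasted in. Usage mirrors RandomMixup (the values below are illustrative):

import torch

batch = torch.rand(8, 3, 32, 32)
target = torch.randint(0, 100, (8,))
cutmix = RandomCutmix(num_classes=100, p=1.0, alpha=1.0)
mixed_batch, mixed_target = cutmix(batch, target)
# A rectangular patch of every image now comes from its rolled neighbour;
# the recomputed lambda weights the two labels by untouched vs. pasted area.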
2 changes: 2 additions & 0 deletions references/classification/utils.py
@@ -189,6 +189,8 @@ def accuracy(output, target, topk=(1,)):
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
+        if target.ndim == 2:
+            target = target.max(dim=1)[1]

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
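
With mixup/cutmix enabled, targets reach accuracy as (B, num_classes) soft labels, so the new branch falls back to the dominant class before comparing against the predictions. A tiny illustration of that branch in isolation (values are made up):

import torch

target = torch.tensor([[0.7, 0.3, 0.0],
                       [0.0, 0.2, 0.8]])   # soft labels from mixup/cutmix
if target.ndim == 2:
    target = target.max(dim=1)[1]          # -> tensor([0, 2])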
3 changes: 2 additions & 1 deletion test/test_transforms.py
@@ -1311,7 +1311,8 @@ def test_random_choice():
            transforms.Resize(15),
            transforms.Resize(20),
            transforms.CenterCrop(10)
-        ]
+        ],
+        [1 / 3, 1 / 3, 1 / 3]
    )
    img = transforms.ToPILImage()(torch.rand(3, 25, 25))
    num_samples = 250
17 changes: 14 additions & 3 deletions torchvision/transforms/transforms.py
@@ -515,9 +515,20 @@ def __call__(self, img):
class RandomChoice(RandomTransforms):
    """Apply single transformation randomly picked from a list. This transform does not support torchscript.
    """
-    def __call__(self, img):
-        t = random.choice(self.transforms)
-        return t(img)
+    def __init__(self, transforms, p=None):
+        super().__init__(transforms)
+        if p is not None and not isinstance(p, Sequence):
+            raise TypeError("Argument transforms should be a sequence")
+        self.p = p
+
+    def __call__(self, *args):
+        t = random.choices(self.transforms, weights=self.p)[0]
+        return t(*args)
+
+    def __repr__(self):
+        format_string = super().__repr__()
+        format_string += '(p={0})'.format(self.p)
+        return format_string


class RandomCrop(torch.nn.Module):
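
The reworked RandomChoice above now accepts optional selection weights and forwards arbitrary positional arguments to the chosen transform, which is what lets it dispatch (batch, target) pairs to RandomMixup/RandomCutmix. A quick sketch of the weighted behaviour (the weights are illustrative):

import torchvision.transforms as T

# Picks Resize with probability 0.8 and CenterCrop with probability 0.2;
# with p=None (the default) the choice stays uniform, as before.
t = T.RandomChoice([T.Resize(15), T.CenterCrop(10)], p=[0.8, 0.2])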
