From c8e3b2a5925e7b7ed21662e86a7e9553170a5633 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Wed, 15 Sep 2021 18:32:22 +0100
Subject: [PATCH] Adding Mixup and Cutmix (#4379)

* Add RandomMixupCutmix.

* Add test with real data.

* Use dataloader and collate in the test.

* Making RandomMixupCutmix JIT scriptable.

* Move out label_smoothing and try roll instead of flip

* Adding mixup/cutmix in references script.

* Handle one-hot encoded target in accuracy.

* Add support of devices on tests.

* Separate Mixup from Cutmix.

* Add check for floats.

* Adding device on expect value.

* Remove hardcoded weights.

* One-hot only when necessary.

* Fix linter.

* Moving mixup and cutmix to references.

* Final code clean up.
---
 references/classification/train.py      |  19 ++-
 references/classification/transforms.py | 175 ++++++++++++++++++++++++
 references/classification/utils.py      |   2 +
 test/test_transforms.py                 |   3 +-
 torchvision/transforms/transforms.py    |  17 ++-
 5 files changed, 210 insertions(+), 6 deletions(-)
 create mode 100644 references/classification/transforms.py

diff --git a/references/classification/train.py b/references/classification/train.py
index a3e4c9ad8e9..3ec9039a018 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -4,11 +4,13 @@
 
 import torch
 import torch.utils.data
+from torch.utils.data.dataloader import default_collate
 from torch import nn
 import torchvision
 from torchvision.transforms.functional import InterpolationMode
 
 import presets
+import transforms
 import utils
 
 try:
@@ -164,10 +166,21 @@ def main(args):
     train_dir = os.path.join(args.data_path, 'train')
     val_dir = os.path.join(args.data_path, 'val')
     dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
+
+    collate_fn = None
+    num_classes = len(dataset.classes)
+    mixup_transforms = []
+    if args.mixup_alpha > 0.0:
+        mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
+    if args.cutmix_alpha > 0.0:
+        mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
+    if mixup_transforms:
+        mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
+        collate_fn = lambda batch: mixupcutmix(*default_collate(batch))  # noqa: E731
     data_loader = torch.utils.data.DataLoader(
         dataset, batch_size=args.batch_size,
-        sampler=train_sampler, num_workers=args.workers, pin_memory=True)
-
+        sampler=train_sampler, num_workers=args.workers, pin_memory=True,
+        collate_fn=collate_fn)
     data_loader_test = torch.utils.data.DataLoader(
         dataset_test, batch_size=args.batch_size,
         sampler=test_sampler, num_workers=args.workers, pin_memory=True)
@@ -272,6 +285,8 @@ def get_args_parser(add_help=True):
     parser.add_argument('--label-smoothing', default=0.0, type=float,
                         help='label smoothing (default: 0.0)',
                         dest='label_smoothing')
+    parser.add_argument('--mixup-alpha', default=0.0, type=float, help='mixup alpha (default: 0.0)')
+    parser.add_argument('--cutmix-alpha', default=0.0, type=float, help='cutmix alpha (default: 0.0)')
    parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
    parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
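Usage sketch (illustrative, not part of the patch): the collate_fn above applies Mixup/Cutmix only after default_collate has stacked the samples, so the transforms see a whole (B, C, H, W) batch rather than individual images. The snippet below wires the same pieces together outside the reference script; FakeData, the sizes, and the alpha values are placeholder choices, RandomMixup/RandomCutmix come from the new references/classification/transforms.py, and the weighted RandomChoice relies on the torchvision change at the end of this patch:

    import torch
    import torchvision
    from torch.utils.data.dataloader import default_collate
    from torchvision.datasets import FakeData
    from torchvision.transforms import Compose, PILToTensor, ConvertImageDtype

    from transforms import RandomMixup, RandomCutmix  # references/classification/transforms.py

    # Both transforms require float images and int64 targets.
    dataset = FakeData(size=64, image_size=(3, 32, 32), num_classes=10,
                       transform=Compose([PILToTensor(), ConvertImageDtype(torch.float)]))

    mixupcutmix = torchvision.transforms.RandomChoice([
        RandomMixup(num_classes=10, p=1.0, alpha=0.2),
        RandomCutmix(num_classes=10, p=1.0, alpha=1.0),
    ])
    collate_fn = lambda batch: mixupcutmix(*default_collate(batch))  # noqa: E731

    data_loader = torch.utils.data.DataLoader(dataset, batch_size=8, collate_fn=collate_fn)
    images, targets = next(iter(data_loader))
    print(images.shape, targets.shape)  # torch.Size([8, 3, 32, 32]) torch.Size([8, 10])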
diff --git a/references/classification/transforms.py b/references/classification/transforms.py
new file mode 100644
index 00000000000..c4d83ce410c
--- /dev/null
+++ b/references/classification/transforms.py
@@ -0,0 +1,175 @@
+import math
+import torch
+
+from typing import Tuple
+from torch import Tensor
+from torchvision.transforms import functional as F
+
+
+class RandomMixup(torch.nn.Module):
+    """Randomly apply Mixup to the provided batch and targets.
+    The class implements the data augmentations as described in the paper
+    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.
+
+    Args:
+        num_classes (int): number of classes used for one-hot encoding.
+        p (float): probability of the batch being transformed. Default value is 0.5.
+        alpha (float): hyperparameter of the Beta distribution used for mixup.
+            Default value is 1.0.
+        inplace (bool): boolean to make this transform inplace. Default set to False.
+    """
+
+    def __init__(self, num_classes: int,
+                 p: float = 0.5, alpha: float = 1.0,
+                 inplace: bool = False) -> None:
+        super().__init__()
+        assert num_classes > 0, "Please provide a valid positive value for num_classes."
+        assert alpha > 0, "Alpha param should be positive."
+
+        self.num_classes = num_classes
+        self.p = p
+        self.alpha = alpha
+        self.inplace = inplace
+
+    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
+        """
+        Args:
+            batch (Tensor): Float tensor of size (B, C, H, W)
+            target (Tensor): Integer tensor of size (B, )
+
+        Returns:
+            Tuple[Tensor, Tensor]: Randomly transformed batch and target.
+        """
+        if batch.ndim != 4:
+            raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim))
+        elif target.ndim != 1:
+            raise ValueError("Target ndim should be 1. Got {}".format(target.ndim))
+        elif not batch.is_floating_point():
+            raise TypeError("Batch dtype should be a float tensor. Got {}.".format(batch.dtype))
+        elif target.dtype != torch.int64:
+            raise TypeError("Target dtype should be torch.int64. Got {}".format(target.dtype))
+
+        if not self.inplace:
+            batch = batch.clone()
+            target = target.clone()
+
+        if target.ndim == 1:
+            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
+
+        if torch.rand(1).item() >= self.p:
+            return batch, target
+
+        # It's faster to roll the batch by one instead of shuffling it to create image pairs
+        batch_rolled = batch.roll(1, 0)
+        target_rolled = target.roll(1, 0)
+
+        # Implemented as described in the mixup paper, page 3.
+        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
+        batch_rolled.mul_(1.0 - lambda_param)
+        batch.mul_(lambda_param).add_(batch_rolled)
+
+        target_rolled.mul_(1.0 - lambda_param)
+        target.mul_(lambda_param).add_(target_rolled)
+
+        return batch, target
+
+    def __repr__(self) -> str:
+        s = self.__class__.__name__ + '('
+        s += 'num_classes={num_classes}'
+        s += ', p={p}'
+        s += ', alpha={alpha}'
+        s += ', inplace={inplace}'
+        s += ')'
+        return s.format(**self.__dict__)
+
+
+class RandomCutmix(torch.nn.Module):
+    """Randomly apply Cutmix to the provided batch and targets.
+    The class implements the data augmentations as described in the paper
+    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
+    <https://arxiv.org/abs/1905.04899>`_.
+
+    Args:
+        num_classes (int): number of classes used for one-hot encoding.
+        p (float): probability of the batch being transformed. Default value is 0.5.
+        alpha (float): hyperparameter of the Beta distribution used for cutmix.
+            Default value is 1.0.
+        inplace (bool): boolean to make this transform inplace. Default set to False.
+    """
+
+    def __init__(self, num_classes: int,
+                 p: float = 0.5, alpha: float = 1.0,
+                 inplace: bool = False) -> None:
+        super().__init__()
+        assert num_classes > 0, "Please provide a valid positive value for num_classes."
+        assert alpha > 0, "Alpha param should be positive."
+
+        self.num_classes = num_classes
+        self.p = p
+        self.alpha = alpha
+        self.inplace = inplace
+
+    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
+        """
+        Args:
+            batch (Tensor): Float tensor of size (B, C, H, W)
+            target (Tensor): Integer tensor of size (B, )
+
+        Returns:
+            Tuple[Tensor, Tensor]: Randomly transformed batch and target.
+        """
+        if batch.ndim != 4:
+            raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim))
+        elif target.ndim != 1:
+            raise ValueError("Target ndim should be 1. Got {}".format(target.ndim))
+        elif not batch.is_floating_point():
+            raise TypeError("Batch dtype should be a float tensor. Got {}.".format(batch.dtype))
+        elif target.dtype != torch.int64:
+            raise TypeError("Target dtype should be torch.int64. Got {}".format(target.dtype))
+
+        if not self.inplace:
+            batch = batch.clone()
+            target = target.clone()
+
+        if target.ndim == 1:
+            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
+
+        if torch.rand(1).item() >= self.p:
+            return batch, target
+
+        # It's faster to roll the batch by one instead of shuffling it to create image pairs
+        batch_rolled = batch.roll(1, 0)
+        target_rolled = target.roll(1, 0)
+
+        # Implemented as described in the cutmix paper, page 12 (with minor corrections on typos).
+        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
+        W, H = F.get_image_size(batch)
+
+        r_x = torch.randint(W, (1,))
+        r_y = torch.randint(H, (1,))
+
+        r = 0.5 * math.sqrt(1.0 - lambda_param)
+        r_w_half = int(r * W)
+        r_h_half = int(r * H)
+
+        x1 = int(torch.clamp(r_x - r_w_half, min=0))
+        y1 = int(torch.clamp(r_y - r_h_half, min=0))
+        x2 = int(torch.clamp(r_x + r_w_half, max=W))
+        y2 = int(torch.clamp(r_y + r_h_half, max=H))
+
+        batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
+        lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
+
+        target_rolled.mul_(1.0 - lambda_param)
+        target.mul_(lambda_param).add_(target_rolled)
+
+        return batch, target
+
+    def __repr__(self) -> str:
+        s = self.__class__.__name__ + '('
+        s += 'num_classes={num_classes}'
+        s += ', p={p}'
+        s += ', alpha={alpha}'
+        s += ', inplace={inplace}'
+        s += ')'
+        return s.format(**self.__dict__)
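Note (not part of the patch): both transforms sample the mixing coefficient through the private torch._sample_dirichlet rather than torch.distributions.Beta, which keeps the modules JIT scriptable (see the commit message above); the marginal of Dirichlet([alpha, alpha]) is exactly Beta(alpha, alpha), as the illustrative snippet below spells out. Note also that RandomCutmix recomputes lambda_param from the clipped box area, so the target weights always match the fraction of pixels actually pasted.

    import torch

    alpha = 1.0
    # Scriptable draw, as used by RandomMixup/RandomCutmix above:
    lam_scriptable = float(torch._sample_dirichlet(torch.tensor([alpha, alpha]))[0])
    # Public but non-scriptable equivalent:
    lam_public = float(torch.distributions.Beta(alpha, alpha).sample())
    # Both follow Beta(alpha, alpha); with alpha == 1.0 that is Uniform(0, 1).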
+ """ + + def __init__(self, num_classes: int, + p: float = 0.5, alpha: float = 1.0, + inplace: bool = False) -> None: + super().__init__() + assert num_classes > 0, "Please provide a valid positive value for the num_classes." + assert alpha > 0, "Alpha param can't be zero." + + self.num_classes = num_classes + self.p = p + self.alpha = alpha + self.inplace = inplace + + def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: + """ + Args: + batch (Tensor): Float tensor of size (B, C, H, W) + target (Tensor): Integer tensor of size (B, ) + + Returns: + Tensor: Randomly transformed batch. + """ + if batch.ndim != 4: + raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim)) + elif target.ndim != 1: + raise ValueError("Target ndim should be 1. Got {}".format(target.ndim)) + elif not batch.is_floating_point(): + raise TypeError('Batch dtype should be a float tensor. Got {}.'.format(batch.dtype)) + elif target.dtype != torch.int64: + raise TypeError("Target dtype should be torch.int64. Got {}".format(target.dtype)) + + if not self.inplace: + batch = batch.clone() + target = target.clone() + + if target.ndim == 1: + target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype) + + if torch.rand(1).item() >= self.p: + return batch, target + + # It's faster to roll the batch by one instead of shuffling it to create image pairs + batch_rolled = batch.roll(1, 0) + target_rolled = target.roll(1, 0) + + # Implemented as on cutmix paper, page 12 (with minor corrections on typos). + lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0]) + W, H = F.get_image_size(batch) + + r_x = torch.randint(W, (1,)) + r_y = torch.randint(H, (1,)) + + r = 0.5 * math.sqrt(1.0 - lambda_param) + r_w_half = int(r * W) + r_h_half = int(r * H) + + x1 = int(torch.clamp(r_x - r_w_half, min=0)) + y1 = int(torch.clamp(r_y - r_h_half, min=0)) + x2 = int(torch.clamp(r_x + r_w_half, max=W)) + y2 = int(torch.clamp(r_y + r_h_half, max=H)) + + batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2] + lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H)) + + target_rolled.mul_(1.0 - lambda_param) + target.mul_(lambda_param).add_(target_rolled) + + return batch, target + + def __repr__(self) -> str: + s = self.__class__.__name__ + '(' + s += 'num_classes={num_classes}' + s += ', p={p}' + s += ', alpha={alpha}' + s += ', inplace={inplace}' + s += ')' + return s.format(**self.__dict__) diff --git a/references/classification/utils.py b/references/classification/utils.py index bf7662ad023..fad607636e5 100644 --- a/references/classification/utils.py +++ b/references/classification/utils.py @@ -189,6 +189,8 @@ def accuracy(output, target, topk=(1,)): with torch.no_grad(): maxk = max(topk) batch_size = target.size(0) + if target.ndim == 2: + target = target.max(dim=1)[1] _, pred = output.topk(maxk, 1, True, True) pred = pred.t() diff --git a/test/test_transforms.py b/test/test_transforms.py index 675e79ac3ba..541b0adfb6c 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1311,7 +1311,8 @@ def test_random_choice(): transforms.Resize(15), transforms.Resize(20), transforms.CenterCrop(10) - ] + ], + [1 / 3, 1 / 3, 1 / 3] ) img = transforms.ToPILImage()(torch.rand(3, 25, 25)) num_samples = 250 diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 4b3c08dbce7..8da0d016f4d 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -515,9 
diff --git a/test/test_transforms.py b/test/test_transforms.py
index 675e79ac3ba..541b0adfb6c 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -1311,7 +1311,8 @@ def test_random_choice():
             transforms.Resize(15),
             transforms.Resize(20),
             transforms.CenterCrop(10)
-        ]
+        ],
+        [1 / 3, 1 / 3, 1 / 3]
     )
     img = transforms.ToPILImage()(torch.rand(3, 25, 25))
     num_samples = 250
diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py
index 4b3c08dbce7..8da0d016f4d 100644
--- a/torchvision/transforms/transforms.py
+++ b/torchvision/transforms/transforms.py
@@ -515,9 +515,20 @@ def __call__(self, img):
 class RandomChoice(RandomTransforms):
     """Apply single transformation randomly picked from a list. This transform does not support torchscript.
     """
-    def __call__(self, img):
-        t = random.choice(self.transforms)
-        return t(img)
+    def __init__(self, transforms, p=None):
+        super().__init__(transforms)
+        if p is not None and not isinstance(p, Sequence):
+            raise TypeError("Argument p should be a sequence")
+        self.p = p
+
+    def __call__(self, *args):
+        t = random.choices(self.transforms, weights=self.p)[0]
+        return t(*args)
+
+    def __repr__(self):
+        format_string = super().__repr__()
+        format_string += '(p={0})'.format(self.p)
+        return format_string
 
 
 class RandomCrop(torch.nn.Module):
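Usage sketch for the extended RandomChoice (illustrative, not part of the patch): p is optional and weights the draw via random.choices, while the widened *args signature is what lets the reference script apply one choice to a (batch, target) pair; p=None preserves the old uniform behaviour:

    import torch
    from torchvision import transforms

    # Weighted pick between two deterministic flips.
    weighted_choice = transforms.RandomChoice(
        [transforms.RandomHorizontalFlip(p=1.0), transforms.RandomVerticalFlip(p=1.0)],
        p=[0.9, 0.1],
    )
    img = transforms.ToPILImage()(torch.rand(3, 25, 25))
    out = weighted_choice(img)  # horizontal flip roughly 90% of the time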