diff --git a/references/classification/train.py b/references/classification/train.py
index 89eae31c2cd..2d4a7c9a6fc 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -165,10 +165,16 @@ def main(args):
     train_dir = os.path.join(args.data_path, 'train')
     val_dir = os.path.join(args.data_path, 'val')
     dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
+
+    collate_fn = None
+    if args.mixup_alpha > 0.0 or args.cutmix_alpha > 0.0:
+        mixupcutmix = torchvision.transforms.RandomMixupCutmix(len(dataset.classes), mixup_alpha=args.mixup_alpha,
+                                                               cutmix_alpha=args.cutmix_alpha)
+        collate_fn = lambda batch: mixupcutmix(*torch.utils.data._utils.collate.default_collate(batch))  # noqa: E731
     data_loader = torch.utils.data.DataLoader(
         dataset, batch_size=args.batch_size,
-        sampler=train_sampler, num_workers=args.workers, pin_memory=True)
-
+        sampler=train_sampler, num_workers=args.workers, pin_memory=True,
+        collate_fn=collate_fn)
     data_loader_test = torch.utils.data.DataLoader(
         dataset_test, batch_size=args.batch_size,
         sampler=test_sampler, num_workers=args.workers, pin_memory=True)
@@ -273,6 +279,8 @@ def get_args_parser(add_help=True):
     parser.add_argument('--label-smoothing', default=0.0, type=float,
                         help='label smoothing (default: 0.0)',
                         dest='label_smoothing')
+    parser.add_argument('--mixup-alpha', default=0.0, type=float, help='mixup alpha (default: 0.0)')
+    parser.add_argument('--cutmix-alpha', default=0.0, type=float, help='cutmix alpha (default: 0.0)')
     parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
     parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py
index 744b94cfddb..058ee1a0cfc 100644
--- a/test/test_transforms_tensor.py
+++ b/test/test_transforms_tensor.py
@@ -788,5 +788,5 @@ def test_random_mixupcutmix_with_real_data():
     torch.testing.assert_close(
         torch.stack(stats).mean(dim=0),
-        torch.tensor([46.931968688964844, 69.97343444824219, 0.459820032119751])
+        torch.tensor([46.9443473815918, 64.79092407226562, 0.459820032119751])
     )
diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py
index 87994a40f3f..6c7e47c221f 100644
--- a/torchvision/transforms/transforms.py
+++ b/torchvision/transforms/transforms.py
@@ -1955,8 +1955,9 @@ def __repr__(self):
         return self.__class__.__name__ + '(p={})'.format(self.p)
 
 
+# TODO: move this to references before merging and delete the tests
 class RandomMixupCutmix(torch.nn.Module):
-    """Randomly apply Mixum or Cutmix to the provided batch and targets.
+    """Randomly apply Mixup or Cutmix to the provided batch and targets.
     The class implements the data augmentations as described in the papers
     `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_ and
     `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
@@ -2014,8 +2015,8 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
             return batch, target
 
         # It's faster to roll the batch by one instead of shuffling it to create image pairs
-        batch_flipped = batch.roll(1)
-        target_flipped = target.roll(1)
+        batch_rolled = batch.roll(1, 0)
+        target_rolled = target.roll(1)
 
         if self.mixup_alpha <= 0.0:
             use_mixup = False
@@ -2025,8 +2026,8 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
         if use_mixup:
             # Implemented as on mixup paper, page 3.
             lambda_param = float(torch._sample_dirichlet(torch.tensor([self.mixup_alpha, self.mixup_alpha]))[0])
-            batch_flipped.mul_(1.0 - lambda_param)
-            batch.mul_(lambda_param).add_(batch_flipped)
+            batch_rolled.mul_(1.0 - lambda_param)
+            batch.mul_(lambda_param).add_(batch_rolled)
         else:
             # Implemented as on cutmix paper, page 12 (with minor corrections on typos).
             lambda_param = float(torch._sample_dirichlet(torch.tensor([self.cutmix_alpha, self.cutmix_alpha]))[0])
@@ -2044,11 +2045,11 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
             x2 = int(torch.clamp(r_x + r_w_half, max=W))
             y2 = int(torch.clamp(r_y + r_h_half, max=H))
 
-            batch[:, :, y1:y2, x1:x2] = batch_flipped[:, :, y1:y2, x1:x2]
+            batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
             lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
 
-        target_flipped.mul_(1.0 - lambda_param)
-        target.mul_(lambda_param).add_(target_flipped)
+        target_rolled.mul_(1.0 - lambda_param)
+        target.mul_(lambda_param).add_(target_rolled)
 
         return batch, target
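
As a quick way to exercise the branch, here is a minimal usage sketch of the collate_fn wiring from the train.py hunk above. It assumes the transform is exposed as `torchvision.transforms.RandomMixupCutmix` (as in this diff); the toy dataset, number of classes, and alpha/batch-size values are placeholders, not values from the PR.

```python
import torch
import torchvision
from torch.utils.data._utils.collate import default_collate

# Placeholder data: 8 random "images" with integer labels in [0, 10).
images = torch.rand(8, 3, 32, 32)
labels = torch.randint(0, 10, (8,))
dataset = torch.utils.data.TensorDataset(images, labels)

# As in the train.py hunk: Mixup/Cutmix needs whole batches, so it is applied
# in the collate function rather than as a per-sample transform.
mixupcutmix = torchvision.transforms.RandomMixupCutmix(
    10, mixup_alpha=0.2, cutmix_alpha=1.0)

def collate_fn(batch):
    return mixupcutmix(*default_collate(batch))

data_loader = torch.utils.data.DataLoader(dataset, batch_size=4, collate_fn=collate_fn)
batch, target = next(iter(data_loader))
# Per the target.mul_/add_ lines above, targets come back blended by lambda.
print(batch.shape, target.shape)
```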
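
The `batch_flipped` -> `batch_rolled` rename also carries a real fix: without a `dims` argument, `Tensor.roll` flattens the tensor, shifts every element by one, and reshapes back, so pixel values leak across neighbouring images; `batch.roll(1, 0)` shifts along the batch dimension, pairing each image with the whole previous sample. A toy illustration (not part of the PR):

```python
import torch

# Two tiny 2x2 single-channel "images".
batch = torch.arange(8, dtype=torch.float32).reshape(2, 1, 2, 2)

wrong = batch.roll(1)     # flattened shift: the last pixel wraps into image 0
right = batch.roll(1, 0)  # batch-dim shift: sample order becomes [1, 0]

assert torch.equal(right[0], batch[1]) and torch.equal(right[1], batch[0])
assert not torch.equal(wrong[0], batch[1])  # pixels, not samples, were moved
```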