Adding mixup/cutmix in references script.
datumbox committed Sep 9, 2021
1 parent c1bc525 commit 67acd89
Showing 3 changed files with 20 additions and 11 deletions.
12 changes: 10 additions & 2 deletions references/classification/train.py
@@ -165,10 +165,16 @@ def main(args):
     train_dir = os.path.join(args.data_path, 'train')
     val_dir = os.path.join(args.data_path, 'val')
     dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
+
+    collate_fn = None
+    if args.mixup_alpha > 0.0 or args.cutmix_alpha > 0.0:
+        mixupcutmix = torchvision.transforms.RandomMixupCutmix(len(dataset.classes), mixup_alpha=args.mixup_alpha,
+                                                               cutmix_alpha=args.cutmix_alpha)
+        collate_fn = lambda batch: mixupcutmix(*torch.utils.data._utils.collate.default_collate(batch))  # noqa: E731
     data_loader = torch.utils.data.DataLoader(
         dataset, batch_size=args.batch_size,
-        sampler=train_sampler, num_workers=args.workers, pin_memory=True)
-
+        sampler=train_sampler, num_workers=args.workers, pin_memory=True,
+        collate_fn=collate_fn)
     data_loader_test = torch.utils.data.DataLoader(
         dataset_test, batch_size=args.batch_size,
         sampler=test_sampler, num_workers=args.workers, pin_memory=True)
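For context, a minimal self-contained sketch of the collate_fn pattern the hunk above wires up; the dataset and the mixing callable below are toy stand-ins, not part of this commit:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data._utils.collate import default_collate

# Toy stand-in for an image-classification dataset: 8 RGB 32x32 images, 4 classes.
dataset = TensorDataset(torch.rand(8, 3, 32, 32), torch.randint(0, 4, (8,)))

def mix_batch(batch, target):
    # Stand-in for RandomMixupCutmix: any callable over (batch, target) fits this slot.
    return batch, target

def collate_fn(batch):
    # Collate the samples into batch tensors first, then transform the whole batch at once.
    return mix_batch(*default_collate(batch))

loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)
images, targets = next(iter(loader))
print(images.shape, targets.shape)  # torch.Size([4, 3, 32, 32]) torch.Size([4])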
@@ -273,6 +279,8 @@ def get_args_parser(add_help=True):
     parser.add_argument('--label-smoothing', default=0.0, type=float,
                         help='label smoothing (default: 0.0)',
                         dest='label_smoothing')
+    parser.add_argument('--mixup-alpha', default=0.0, type=float, help='mixup alpha (default: 0.0)')
+    parser.add_argument('--cutmix-alpha', default=0.0, type=float, help='cutmix alpha (default: 0.0)')
     parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
     parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
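With these flags, the augmentations would be enabled roughly like this; the data path and model name are illustrative, and the alpha values are common choices from the papers rather than defaults of this commit:

python train.py --data-path /path/to/imagenet --model resnet50 --mixup-alpha 0.2 --cutmix-alpha 1.0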
2 changes: 1 addition & 1 deletion test/test_transforms_tensor.py
@@ -788,5 +788,5 @@ def test_random_mixupcutmix_with_real_data():
 
     torch.testing.assert_close(
         torch.stack(stats).mean(dim=0),
-        torch.tensor([46.931968688964844, 69.97343444824219, 0.459820032119751])
+        torch.tensor([46.9443473815918, 64.79092407226562, 0.459820032119751])
     )
17 changes: 9 additions & 8 deletions torchvision/transforms/transforms.py
@@ -1955,8 +1955,9 @@ def __repr__(self):
         return self.__class__.__name__ + '(p={})'.format(self.p)
 
 
+# TODO: move this to references before merging and delete the tests
 class RandomMixupCutmix(torch.nn.Module):
-    """Randomly apply Mixum or Cutmix to the provided batch and targets.
+    """Randomly apply Mixup or Cutmix to the provided batch and targets.
     The class implements the data augmentations as described in the papers
     `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_ and
     `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
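For reference, mixup blends whole image pairs, x' = λ·x_i + (1 − λ)·x_j and y' = λ·y_i + (1 − λ)·y_j with λ ~ Beta(α, α), while CutMix pastes a rectangular patch from one image onto another and mixes the labels in proportion to the patch area. The forward method below implements both.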
@@ -2014,8 +2015,8 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
             return batch, target
 
         # It's faster to roll the batch by one instead of shuffling it to create image pairs
-        batch_flipped = batch.roll(1)
-        target_flipped = target.roll(1)
+        batch_rolled = batch.roll(1, 0)
+        target_rolled = target.roll(1)
 
         if self.mixup_alpha <= 0.0:
             use_mixup = False
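A quick sketch of why the added dim argument matters: torch.roll without dims flattens the tensor before shifting, so batch.roll(1) slides pixel values across image boundaries instead of pairing whole images, which presumably also explains the updated expected statistics in the test above:

import torch

batch = torch.tensor([[0, 0], [1, 1], [2, 2], [3, 3]])  # 4 toy "images" of 2 pixels each
print(batch.roll(1, 0))  # shifts whole samples along the batch dim: [[3, 3], [0, 0], [1, 1], [2, 2]]
print(batch.roll(1))     # flattens first, scrambling samples: [[3, 0], [0, 1], [1, 2], [2, 3]]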
@@ -2025,8 +2026,8 @@
         if use_mixup:
             # Implemented as on mixup paper, page 3.
             lambda_param = float(torch._sample_dirichlet(torch.tensor([self.mixup_alpha, self.mixup_alpha]))[0])
-            batch_flipped.mul_(1.0 - lambda_param)
-            batch.mul_(lambda_param).add_(batch_flipped)
+            batch_rolled.mul_(1.0 - lambda_param)
+            batch.mul_(lambda_param).add_(batch_rolled)
         else:
             # Implemented as on cutmix paper, page 12 (with minor corrections on typos).
             lambda_param = float(torch._sample_dirichlet(torch.tensor([self.cutmix_alpha, self.cutmix_alpha]))[0])
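Since the first component of a Dirichlet([α, α]) sample follows Beta(α, α), the private torch._sample_dirichlet calls above are equivalent to sampling via the public API:

import torch

alpha = 0.2
lambda_param = float(torch.distributions.Beta(alpha, alpha).sample())  # same distribution as above
print(0.0 <= lambda_param <= 1.0)  # True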
@@ -2044,11 +2045,11 @@
             x2 = int(torch.clamp(r_x + r_w_half, max=W))
             y2 = int(torch.clamp(r_y + r_h_half, max=H))
 
-            batch[:, :, y1:y2, x1:x2] = batch_flipped[:, :, y1:y2, x1:x2]
+            batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
             lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
 
-        target_flipped.mul_(1.0 - lambda_param)
-        target.mul_(lambda_param).add_(target_flipped)
+        target_rolled.mul_(1.0 - lambda_param)
+        target.mul_(lambda_param).add_(target_rolled)
 
         return batch, target
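A worked example of the area-based λ computed above: a 16x16 patch pasted into a 32x32 image covers a quarter of the area, so the rolled-in labels get weight 0.25:

W = H = 32
x1, y1, x2, y2 = 8, 8, 24, 24                         # clamped patch corners
lambda_param = 1.0 - (x2 - x1) * (y2 - y1) / (W * H)
print(lambda_param)                                   # 0.75: original labels keep weight 0.75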
