Adding mixup/cutmix in references script.
datumbox committed Sep 9, 2021
1 parent c1bc525 commit c4ca8c9
Showing 3 changed files with 20 additions and 15 deletions.
16 changes: 10 additions & 6 deletions references/classification/train.py
@@ -165,10 +165,16 @@ def main(args):
     train_dir = os.path.join(args.data_path, 'train')
     val_dir = os.path.join(args.data_path, 'val')
     dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
 
+    collate_fn = None
+    if args.mixup_alpha > 0.0 or args.cutmix_alpha > 0.0:
+        mixupcutmix = torchvision.transforms.RandomMixupCutmix(len(dataset.classes), mixup_alpha=args.mixup_alpha,
+                                                               cutmix_alpha=args.cutmix_alpha)
+        collate_fn = lambda batch: mixupcutmix(*torch.utils.data._utils.collate.default_collate(batch))  # noqa: E731
     data_loader = torch.utils.data.DataLoader(
         dataset, batch_size=args.batch_size,
-        sampler=train_sampler, num_workers=args.workers, pin_memory=True)
+        sampler=train_sampler, num_workers=args.workers, pin_memory=True,
+        collate_fn=collate_fn)
     data_loader_test = torch.utils.data.DataLoader(
         dataset_test, batch_size=args.batch_size,
         sampler=test_sampler, num_workers=args.workers, pin_memory=True)
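
Why a collate_fn: Mixup and CutMix mix pairs of samples, so they operate on whole batches and cannot run in the usual per-sample transform pipeline. Wrapping default_collate lets the stacking happen first and the mixing second. A minimal sketch of the same pattern as the lambda above, with make_collate_fn and batch_transform as hypothetical names:

import torch
from torch.utils.data._utils.collate import default_collate

def make_collate_fn(batch_transform):
    # default_collate stacks per-sample (image, label) pairs into
    # (batch, target) tensors; batch_transform then mixes pairs within the batch.
    def collate_fn(samples):
        return batch_transform(*default_collate(samples))
    return collate_fn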
@@ -254,7 +260,6 @@ def main(args):
 def get_args_parser(add_help=True):
     import argparse
     parser = argparse.ArgumentParser(description='PyTorch Classification Training', add_help=add_help)
-
     parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', help='dataset')
     parser.add_argument('--model', default='resnet18', help='model')
     parser.add_argument('--device', default='cuda', help='device')
@@ -273,6 +278,8 @@ def get_args_parser(add_help=True):
     parser.add_argument('--label-smoothing', default=0.0, type=float,
                         help='label smoothing (default: 0.0)',
                         dest='label_smoothing')
+    parser.add_argument('--mixup-alpha', default=0.0, type=float, help='mixup alpha (default: 0.0)')
+    parser.add_argument('--cutmix-alpha', default=0.0, type=float, help='cutmix alpha (default: 0.0)')
     parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
     parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
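
Both alphas default to 0.0, which leaves the augmentation disabled unless explicitly requested. A quick, hypothetical smoke test of the new flags:

args = get_args_parser().parse_args(['--mixup-alpha', '0.2', '--cutmix-alpha', '1.0'])
assert args.mixup_alpha == 0.2 and args.cutmix_alpha == 1.0  # both 0.0 by default, i.e. off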
@@ -306,7 +313,6 @@ def get_args_parser(add_help=True):
     )
     parser.add_argument('--auto-augment', default=None, help='auto augment policy (default: None)')
     parser.add_argument('--random-erase', default=0.0, type=float, help='random erasing probability (default: 0.0)')
-
     # Mixed precision training parameters
     parser.add_argument('--apex', action='store_true',
                         help='Use apex for mixed precision training')
@@ -315,7 +321,6 @@ def get_args_parser(add_help=True):
                         'O0 for FP32 training, O1 for mixed precision training.'
                         'For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet'
     )
-
     # distributed training parameters
     parser.add_argument('--world-size', default=1, type=int,
                         help='number of distributed processes')
@@ -326,7 +331,6 @@ def get_args_parser(add_help=True):
     parser.add_argument(
         '--model-ema-decay', type=float, default=0.99,
         help='decay factor for Exponential Moving Average of model parameters(default: 0.99)')
-
     return parser
 
 
2 changes: 1 addition & 1 deletion test/test_transforms_tensor.py
@@ -788,5 +788,5 @@ def test_random_mixupcutmix_with_real_data():
 
     torch.testing.assert_close(
         torch.stack(stats).mean(dim=0),
-        torch.tensor([46.931968688964844, 69.97343444824219, 0.459820032119751])
+        torch.tensor([46.9443473815918, 64.79092407226562, 0.459820032119751])
     )
17 changes: 9 additions & 8 deletions torchvision/transforms/transforms.py
@@ -1955,8 +1955,9 @@ def __repr__(self):
         return self.__class__.__name__ + '(p={})'.format(self.p)
 
 
+# TODO: move this to references before merging and delete the tests
 class RandomMixupCutmix(torch.nn.Module):
-    """Randomly apply Mixum or Cutmix to the provided batch and targets.
+    """Randomly apply Mixup or Cutmix to the provided batch and targets.
     The class implements the data augmentations as described in the papers
     `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_ and
     `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
@@ -2014,8 +2015,8 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
             return batch, target
 
         # It's faster to roll the batch by one instead of shuffling it to create image pairs
-        batch_flipped = batch.roll(1)
-        target_flipped = target.roll(1)
+        batch_rolled = batch.roll(1, 0)
+        target_rolled = target.roll(1)
 
         if self.mixup_alpha <= 0.0:
             use_mixup = False
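
The rename from *_flipped to *_rolled also fixes the roll semantics: batch.roll(1) with no dims argument flattens the tensor, shifts every element by one, and restores the shape, so images are no longer cleanly paired, whereas batch.roll(1, 0) shifts along the batch dimension only, pairing each image with its predecessor (the updated test expectation above reflects this). A toy illustration, not from the test suite:

import torch

batch = torch.arange(4).reshape(4, 1).float()  # stand-in for a batch of 4 samples
print(batch.roll(1, 0).flatten())              # tensor([3., 0., 1., 2.]): sample i pairs with sample i-1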
@@ -2025,8 +2026,8 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
         if use_mixup:
             # Implemented as on mixup paper, page 3.
             lambda_param = float(torch._sample_dirichlet(torch.tensor([self.mixup_alpha, self.mixup_alpha]))[0])
-            batch_flipped.mul_(1.0 - lambda_param)
-            batch.mul_(lambda_param).add_(batch_flipped)
+            batch_rolled.mul_(1.0 - lambda_param)
+            batch.mul_(lambda_param).add_(batch_rolled)
         else:
             # Implemented as on cutmix paper, page 12 (with minor corrections on typos).
             lambda_param = float(torch._sample_dirichlet(torch.tensor([self.cutmix_alpha, self.cutmix_alpha]))[0])
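
Side note on the sampling: the first coordinate of a two-component Dirichlet with concentration [alpha, alpha] is distributed Beta(alpha, alpha), the distribution both papers prescribe; torch._sample_dirichlet is a private op, presumably chosen here for TorchScript compatibility. The public-API equivalent, as a sketch:

import torch

alpha = 0.2  # illustrative value
lam = torch.distributions.Beta(alpha, alpha).sample()
# same law as: torch._sample_dirichlet(torch.tensor([alpha, alpha]))[0]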
@@ -2044,11 +2045,11 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
             x2 = int(torch.clamp(r_x + r_w_half, max=W))
             y2 = int(torch.clamp(r_y + r_h_half, max=H))
 
-            batch[:, :, y1:y2, x1:x2] = batch_flipped[:, :, y1:y2, x1:x2]
+            batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
             lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
 
-            target_flipped.mul_(1.0 - lambda_param)
-            target.mul_(lambda_param).add_(target_flipped)
+            target_rolled.mul_(1.0 - lambda_param)
+            target.mul_(lambda_param).add_(target_rolled)
 
         return batch, target
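
Note that the initially sampled lambda_param only sizes the cut box; because clamping at the image border can shrink the box, the value used to mix the targets is recomputed from the clamped coordinates. A worked example with illustrative numbers:

W = H = 224
x1, y1, x2, y2 = 56, 56, 168, 168            # a 112x112 box after clamping
lam = 1.0 - (x2 - x1) * (y2 - y1) / (W * H)  # 1 - 12544/50176 = 0.75
# the mixed target then keeps 75% of the original label mass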
