Skip to content

Commit

Permalink
[fbsync] Added Exponential Moving Average support to classification r…
Browse files Browse the repository at this point in the history
…eference script (#4381)

Summary:
* Added Exponential Moving Average support to classification reference script

* Addressed review comments

* Updated model argument

Reviewed By: kazhang

Differential Revision: D30898332

fbshipit-source-id: 1c9aaa2b9b1e8773fce155063bfa4de32c4c1c1e
  • Loading branch information
NicolasHug authored and facebook-github-bot committed Sep 13, 2021
1 parent 0eb492f commit 16e774a
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 4 deletions.
28 changes: 24 additions & 4 deletions references/classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
amp = None


def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq, apex=False):
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch,
print_freq, apex=False, model_ema=None):
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
Expand Down Expand Up @@ -45,11 +46,14 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, pri
metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time))

if model_ema:
model_ema.update_parameters(model)

def evaluate(model, criterion, data_loader, device, print_freq=100):

def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix=''):
model.eval()
metric_logger = utils.MetricLogger(delimiter=" ")
header = 'Test:'
header = f'Test: {log_suffix}'
with torch.no_grad():
for image, target in metric_logger.log_every(data_loader, print_freq, header):
image = image.to(device, non_blocking=True)
Expand Down Expand Up @@ -199,12 +203,18 @@ def main(args):
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module

model_ema = None
if args.model_ema:
model_ema = utils.ExponentialMovingAverage(model_without_ddp, device=device, decay=args.model_ema_decay)

if args.resume:
checkpoint = torch.load(args.resume, map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
args.start_epoch = checkpoint['epoch'] + 1
if model_ema:
model_ema.load_state_dict(checkpoint['model_ema'])

if args.test_only:
evaluate(model, criterion, data_loader_test, device=device)
Expand All @@ -215,16 +225,20 @@ def main(args):
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
train_sampler.set_epoch(epoch)
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex)
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex, model_ema)
lr_scheduler.step()
evaluate(model, criterion, data_loader_test, device=device)
if model_ema:
evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix='EMA')
if args.output_dir:
checkpoint = {
'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(),
'epoch': epoch,
'args': args}
if model_ema:
checkpoint['model_ema'] = model_ema.state_dict()
utils.save_on_master(
checkpoint,
os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
Expand Down Expand Up @@ -306,6 +320,12 @@ def get_args_parser(add_help=True):
parser.add_argument('--world-size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
parser.add_argument(
'--model-ema', action='store_true',
help='enable tracking Exponential Moving Average of model parameters')
parser.add_argument(
'--model-ema-decay', type=float, default=0.99,
help='decay factor for Exponential Moving Average of model parameters(default: 0.99)')

return parser

Expand Down
12 changes: 12 additions & 0 deletions references/classification/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,18 @@ def log_every(self, iterable, print_freq, header=None):
print('{} Total time: {}'.format(header, total_time_str))


class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
"""Maintains moving averages of model parameters using an exponential decay.
``ema_avg = decay * avg_model_param + (1 - decay) * model_param``
`torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_
is used to compute the EMA.
"""
def __init__(self, model, decay, device='cpu'):
ema_avg = (lambda avg_model_param, model_param, num_averaged:
decay * avg_model_param + (1 - decay) * model_param)
super().__init__(model, device, ema_avg)


def accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():
Expand Down

0 comments on commit 16e774a

Please sign in to comment.