Commit

Merge branch 'main' into feature/torchvision-distillation-support
rahul-tuli committed Jan 11, 2023
2 parents 735956e + da3ffb0 commit 6be8dd8
Showing 1 changed file with 47 additions and 15 deletions.
src/sparseml/pytorch/torchvision/train.py (62 changes: 47 additions & 15 deletions)
@@ -23,7 +23,7 @@
import warnings
from functools import update_wrapper
from types import SimpleNamespace
from typing import Optional
from typing import Callable, Optional

import torch
import torch.utils.data
@@ -62,23 +62,32 @@ def train_one_epoch(
device: torch.device,
epoch: int,
args,
log_metrics_fn: Callable[[str, utils.MetricLogger, int, int], None],
manager=None,
model_ema=None,
scaler=None,
) -> utils.MetricLogger:
accum_steps = args.gradient_accum_steps

model.train()
metric_logger = utils.MetricLogger(_LOGGER, delimiter=" ")
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))
metric_logger.add_meter(
"imgs_per_sec", utils.SmoothedValue(window_size=10, fmt="{value}")
)
metric_logger.add_meter("loss", utils.SmoothedValue(window_size=accum_steps))
metric_logger.add_meter("acc1", utils.SmoothedValue(window_size=accum_steps))
metric_logger.add_meter("acc5", utils.SmoothedValue(window_size=accum_steps))

steps_accumulated = 0
num_optim_steps = 0

# initial zero grad for gradient accumulation
optimizer.zero_grad()

header = f"Epoch: [{epoch}]"
for i, (image, target) in enumerate(
metric_logger.log_every(data_loader, args.print_freq, header)
for (image, target) in metric_logger.log_every(
data_loader, args.logging_steps * accum_steps, header
):
start_time = time.time()
image, target = image.to(device), target.to(device)
Expand All @@ -89,7 +98,7 @@ def train_one_epoch(
output = output[0]
loss = criterion(output, target)

if steps_accumulated % args.gradient_accum_steps == 0:
if steps_accumulated % accum_steps == 0:
if manager is not None:
loss = manager.loss_update(
loss=loss,
@@ -119,9 +128,10 @@ def train_one_epoch(

# zero grad here to start accumulating next set of gradients
optimizer.zero_grad()
num_optim_steps += 1
steps_accumulated += 1

if model_ema and i % args.model_ema_steps == 0:
if model_ema and num_optim_steps % args.model_ema_steps == 0:
model_ema.update_parameters(model)
if epoch < args.lr_warmup_epochs:
# Reset ema buffer to keep copying weights during warmup period
@@ -132,7 +142,12 @@ def train_one_epoch(
metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
metric_logger.meters["img/s"].update(batch_size / (time.time() - start_time))
metric_logger.meters["imgs_per_sec"].update(
batch_size / (time.time() - start_time)
)

if num_optim_steps % args.logging_steps == 0:
log_metrics_fn("Train", metric_logger, epoch, num_optim_steps)
return metric_logger


@@ -504,10 +519,17 @@ def collate_fn(batch):
criterion,
data_loader_test,
device,
print_freq=args.logging_steps,
log_suffix="EMA",
)
else:
evaluate(model, criterion, data_loader_test, device)
evaluate(
model,
criterion,
data_loader_test,
device,
print_freq=args.logging_steps,
)
return

if utils.is_main_process():
@@ -523,10 +545,13 @@ def collate_fn(batch):
else:
logger = LoggerManager(log_python=False)

def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int):
steps_per_epoch = len(data_loader) / args.gradient_accum_steps

def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: int):
step = int(epoch * steps_per_epoch + epoch_step)
for metric_name, smoothed_value in metrics.meters.items():
logger.log_scalar(
f"{tag}/{metric_name}", smoothed_value.global_avg, step=epoch
f"{tag}/{metric_name}", smoothed_value.global_avg, step=step
)

if manager is not None:
@@ -537,7 +562,7 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int):
distillation_teacher=args.distill_teacher,
)
optimizer = manager.modify(
model, optimizer, len(data_loader), epoch=args.start_epoch
model, optimizer, steps_per_epoch=steps_per_epoch, epoch=args.start_epoch
)

lr_scheduler = _get_lr_scheduler(
@@ -570,17 +595,18 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int):
device,
epoch,
args,
log_metrics,
manager=manager,
model_ema=model_ema,
scaler=scaler,
)
log_metrics("Train", train_metrics, epoch)
log_metrics("Train", train_metrics, epoch, steps_per_epoch)

if lr_scheduler:
lr_scheduler.step()

eval_metrics = evaluate(model, criterion, data_loader_test, device)
log_metrics("Test", eval_metrics, epoch)
log_metrics("Test", eval_metrics, epoch, steps_per_epoch)

top1_acc = eval_metrics.acc1.global_avg
if model_ema:
@@ -591,7 +617,7 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int):
device,
log_suffix="EMA",
)
log_metrics("Test/EMA", ema_eval_metrics, epoch)
log_metrics("Test/EMA", ema_eval_metrics, epoch, steps_per_epoch)

is_new_best = epoch >= args.save_best_after and top1_acc > best_top1_acc
if is_new_best:
@@ -916,7 +942,13 @@ def new_func(*args, **kwargs):
type=float,
help="minimum lr of lr schedule",
)
@click.option("--print-freq", default=10, type=int, help="print frequency")
@click.option("--print-freq", default=None, type=int, help="DEPRECATED. Does nothing.")
@click.option(
"--logging-steps",
default=10,
type=int,
help="Frequency in number of batch updates for logging/printing",
)
@click.option("--output-dir", default=".", type=str, help="path to save outputs")
@click.option("--resume", default=None, type=str, help="path of checkpoint")
@click.option(
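For context when reading the hunks above: the commit replaces per-batch `print_freq` printing with logging every `logging_steps` optimizer steps (counted via `num_optim_steps`), and folds (epoch, step-within-epoch) into a single global step using `steps_per_epoch = len(data_loader) / gradient_accum_steps`. The Python sketch below only illustrates that pattern; `make_log_metrics`, `train_one_epoch_sketch`, and `log_scalar` are hypothetical names, and the MetricLogger/click machinery from the real train.py is replaced with a plain dict and a scalar-logging callback.

# Hypothetical, minimal sketch of the step-based logging this commit moves to;
# it is a simplification, not the actual sparseml implementation.
from typing import Callable, Dict, Iterable, Tuple

import torch


def make_log_metrics(
    steps_per_epoch: float,
    log_scalar: Callable[[str, float, int], None],
) -> Callable[[str, Dict[str, float], int, int], None]:
    # Fold (epoch, step-within-epoch) into one global step so curves from
    # different epochs share a single x-axis, mirroring
    # `step = int(epoch * steps_per_epoch + epoch_step)` in the diff.
    def log_metrics(tag: str, metrics: Dict[str, float], epoch: int, epoch_step: int) -> None:
        step = int(epoch * steps_per_epoch + epoch_step)
        for name, value in metrics.items():
            log_scalar(f"{tag}/{name}", value, step)

    return log_metrics


def train_one_epoch_sketch(
    batches: Iterable[Tuple[torch.Tensor, torch.Tensor]],
    optimizer: torch.optim.Optimizer,
    accum_steps: int,
    logging_steps: int,
    epoch: int,
    log_metrics_fn: Callable[[str, Dict[str, float], int, int], None],
) -> None:
    num_optim_steps = 0      # counts optimizer updates, not raw batches
    steps_accumulated = 0
    optimizer.zero_grad()    # initial zero grad for gradient accumulation
    for image, target in batches:
        # ... forward pass, loss computation, and loss.backward() go here ...
        steps_accumulated += 1
        if steps_accumulated % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            num_optim_steps += 1
            if num_optim_steps % logging_steps == 0:
                # placeholder metrics dict; the real code passes a MetricLogger
                log_metrics_fn("Train", {"loss": 0.0}, epoch, num_optim_steps)

The same `steps_per_epoch` value is what the commit also passes to `manager.modify(...)`, keeping the sparsification manager's notion of steps aligned with the logged global step.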
