Adding LoggerManager to torchvision train script (#1299)
* Adding LoggerManager to torchvision train script

* Quality

Co-authored-by: Alexandre Marques <alexandre@neuralmagic.com>
corey-nm and anmarques committed Jan 9, 2023
1 parent 8c9db35 commit 1e08775
Showing 2 changed files with 72 additions and 30 deletions.
92 changes: 66 additions & 26 deletions src/sparseml/pytorch/torchvision/train.py
@@ -15,6 +15,7 @@
# Adapted from https://github.com/pytorch/vision

import datetime
import logging
import math
import os
import sys
@@ -39,10 +40,19 @@
download_framework_model_by_recipe_type,
torch_distributed_zero_first,
)
from sparseml.pytorch.utils.logger import (
LoggerManager,
PythonLogger,
TensorBoardLogger,
WANDBLogger,
)
from sparseml.pytorch.utils.model import load_model
from sparsezoo import Model


_LOGGER = logging.getLogger(__name__)


def train_one_epoch(
model: torch.nn.Module,
criterion: torch.nn.Module,
@@ -55,7 +65,7 @@ def train_one_epoch(
scaler=None,
) -> utils.MetricLogger:
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger = utils.MetricLogger(_LOGGER, delimiter=" ")
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))

@@ -122,7 +132,7 @@ def evaluate(
log_suffix="",
) -> utils.MetricLogger:
model.eval()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger = utils.MetricLogger(_LOGGER, delimiter=" ")
header = f"Test: {log_suffix}"

num_processed_samples = 0
@@ -162,10 +172,10 @@ def evaluate(

metric_logger.synchronize_between_processes()

print(
header,
f"Acc@1 {metric_logger.acc1.global_avg:.3f}",
f"Acc@5 {metric_logger.acc5.global_avg:.3f}",
_LOGGER.info(
header
+ f"Acc@1 {metric_logger.acc1.global_avg:.3f}"
+ f"Acc@5 {metric_logger.acc5.global_avg:.3f}"
)
return metric_logger

@@ -183,20 +193,20 @@ def _get_cache_path(filepath):

def load_data(traindir, valdir, args):
# Data loading code
print("Loading data")
_LOGGER.info("Loading data")
val_resize_size, val_crop_size, train_crop_size = (
args.val_resize_size,
args.val_crop_size,
args.train_crop_size,
)
interpolation = InterpolationMode(args.interpolation)

print("Loading training data")
_LOGGER.info("Loading training data")
st = time.time()
cache_path = _get_cache_path(traindir)
if args.cache_dataset and os.path.exists(cache_path):
# Attention, as the transforms are also cached!
print(f"Loading dataset_train from {cache_path}")
_LOGGER.info(f"Loading dataset_train from {cache_path}")
dataset, _ = torch.load(cache_path)
else:
auto_augment_policy = getattr(args, "auto_augment", None)
@@ -215,16 +225,16 @@ def load_data(traindir, valdir, args):
),
)
if args.cache_dataset:
print(f"Saving dataset_train to {cache_path}")
_LOGGER.info(f"Saving dataset_train to {cache_path}")
utils.mkdir(os.path.dirname(cache_path))
utils.save_on_master((dataset, traindir), cache_path)
print("Took", time.time() - st)
_LOGGER.info(f"Took {time.time() - st}")

print("Loading validation data")
_LOGGER.info("Loading validation data")
cache_path = _get_cache_path(valdir)
if args.cache_dataset and os.path.exists(cache_path):
# Attention, as the transforms are also cached!
print(f"Loading dataset_test from {cache_path}")
_LOGGER.info(f"Loading dataset_test from {cache_path}")
dataset_test, _ = torch.load(cache_path)
else:
preprocessing = presets.ClassificationPresetEval(
@@ -238,11 +248,11 @@ def load_data(traindir, valdir, args):
preprocessing,
)
if args.cache_dataset:
print(f"Saving dataset_test to {cache_path}")
_LOGGER.info(f"Saving dataset_test to {cache_path}")
utils.mkdir(os.path.dirname(cache_path))
utils.save_on_master((dataset_test, valdir), cache_path)

print("Creating data loaders")
_LOGGER.info("Creating data loaders")
if args.distributed:
if hasattr(args, "ra_sampler") and args.ra_sampler:
train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps)
@@ -268,7 +278,10 @@ def main(args):
utils.mkdir(args.output_dir)

utils.init_distributed_mode(args)
print(args)
if not utils.is_main_process():
_LOGGER.disabled = True

_LOGGER.info(args)

device = torch.device(args.device)

@@ -318,7 +331,7 @@ def collate_fn(batch):
pin_memory=True,
)

print("Creating model")
_LOGGER.info("Creating model")
if args.arch_key in ModelRegistry.available_keys():
with torch_distributed_zero_first(args.rank if args.distributed else None):
model = ModelRegistry.create(
@@ -439,7 +452,6 @@ def collate_fn(batch):
# NOTE: override manager with the checkpoint's manager
manager = ScheduledModifierManager.from_yaml(checkpoint["recipe"])
checkpoint_manager = None
manager.initialize(model, epoch=checkpoint["epoch"])
else:
raise ValueError("Flag --resume is set but checkpoint does not have recipe")

@@ -480,11 +492,30 @@ def collate_fn(batch):
evaluate(model, criterion, data_loader_test, device)
return

optimizer = (
manager.modify(model, optimizer, len(data_loader))
if manager is not None
else optimizer
)
if utils.is_main_process():
loggers = [
PythonLogger(logger=_LOGGER),
TensorBoardLogger(log_path=args.output_dir),
]
try:
loggers.append(WANDBLogger())
except ImportError:
warnings.warn("Unable to import wandb for logging")
logger = LoggerManager(loggers)
else:
logger = LoggerManager(log_python=False)

def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int):
for metric_name, smoothed_value in metrics.meters.items():
logger.log_scalar(
f"{tag}/{metric_name}", smoothed_value.global_avg, step=epoch
)

if manager is not None:
manager.initialize(model, epoch=args.start_epoch, loggers=logger)
optimizer = manager.modify(
model, optimizer, len(data_loader), epoch=args.start_epoch
)

lr_scheduler = _get_lr_scheduler(
args, optimizer, checkpoint=checkpoint, manager=manager
@@ -497,7 +528,8 @@ def collate_fn(batch):

best_top1_acc = -math.inf

print("Start training")
_LOGGER.info("Start training")

start_time = time.time()
max_epochs = manager.max_epochs if manager is not None else args.epochs
for epoch in range(args.start_epoch, max_epochs):
@@ -506,6 +538,7 @@ def collate_fn(batch):
if manager is not None and manager.qat_active(epoch=epoch):
scaler = None
model_ema = None

train_metrics = train_one_epoch(
model,
criterion,
@@ -517,18 +550,25 @@ def collate_fn(batch):
model_ema=model_ema,
scaler=scaler,
)
log_metrics("Train", train_metrics, epoch)

if lr_scheduler:
lr_scheduler.step()

eval_metrics = evaluate(model, criterion, data_loader_test, device)
log_metrics("Test", eval_metrics, epoch)

top1_acc = eval_metrics.acc1.global_avg
if model_ema:
evaluate(
ema_eval_metrics = evaluate(
model_ema,
criterion,
data_loader_test,
device,
log_suffix="EMA",
)
log_metrics("Test/EMA", ema_eval_metrics, epoch)

is_new_best = epoch >= args.save_best_after and top1_acc > best_top1_acc
if is_new_best:
best_top1_acc = top1_acc
@@ -575,7 +615,7 @@ def collate_fn(batch):

total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(f"Training time {total_time_str}")
_LOGGER.info(f"Training time {total_time_str}")


def _get_lr_scheduler(args, optimizer, checkpoint=None, manager=None):
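For reference, a minimal standalone sketch of the logging setup this commit introduces on the main process: a LoggerManager aggregating a PythonLogger, a TensorBoardLogger, and (when wandb is available) a WANDBLogger, with per-epoch scalars reported via log_scalar the way the new log_metrics helper does. The output directory, tag names, and metric values below are illustrative placeholders, not part of the diff.

import logging
import warnings

from sparseml.pytorch.utils.logger import (
    LoggerManager,
    PythonLogger,
    TensorBoardLogger,
    WANDBLogger,
)

_LOGGER = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Aggregate several backends behind one LoggerManager, as train.py now
# does on the main process.
loggers = [
    PythonLogger(logger=_LOGGER),
    TensorBoardLogger(log_path="./output"),  # illustrative output directory
]
try:
    loggers.append(WANDBLogger())
except ImportError:
    warnings.warn("Unable to import wandb for logging")
logger_manager = LoggerManager(loggers)

# Report per-epoch scalars the same way log_metrics does in train.py.
example_metrics = {"loss": 0.42, "acc1": 76.1}  # placeholder values
for name, value in example_metrics.items():
    logger_manager.log_scalar(f"Train/{name}", value, step=0)

On non-main ranks the script instead builds LoggerManager(log_python=False) and disables the module-level Python logger, so only the main process emits console, TensorBoard, and Weights & Biases output.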
10 changes: 6 additions & 4 deletions src/sparseml/pytorch/torchvision/utils.py
@@ -5,6 +5,7 @@
import datetime
import errno
import hashlib
import logging
import os
import time
from collections import OrderedDict, defaultdict, deque
@@ -74,7 +75,8 @@ def __str__(self):


class MetricLogger:
def __init__(self, delimiter="\t"):
def __init__(self, logger: logging.Logger, delimiter="\t"):
self.logger = logger
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter

@@ -148,7 +150,7 @@ def log_every(self, iterable, print_freq, header=None):
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(
self.logger.info(
log_msg.format(
i,
len(iterable),
@@ -160,7 +162,7 @@
)
)
else:
print(
self.logger.info(
log_msg.format(
i,
len(iterable),
@@ -174,7 +176,7 @@
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(f"{header} Total time: {total_time_str}")
self.logger.info(f"{header} Total time: {total_time_str}")


class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
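A short sketch of the updated MetricLogger API in utils.py: the constructor now requires a logging.Logger as its first argument, and log_every routes its progress lines through logger.info instead of print. The import path, the update() call, and the values below follow the torchvision-style helpers this file adapts and are assumptions for illustration rather than lines taken from the diff.

import logging

# Assumed import path for the adapted torchvision helpers in sparseml.
from sparseml.pytorch.torchvision import utils

_LOGGER = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# The logger is now a required first argument; progress output goes
# through logger.info rather than print.
metric_logger = utils.MetricLogger(_LOGGER, delimiter="  ")
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))

for batch in metric_logger.log_every(range(20), 5, header="Example:"):
    # Real training/evaluation work goes here; meters are refreshed each
    # step (update() follows the torchvision-style API this file adapts).
    metric_logger.update(lr=0.1)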
