Adding LoggerManager to torchvision train script #1299

Merged
merged 4 commits on Jan 6, 2023
Changes from 3 commits
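
In short: this PR swaps the bare `print` calls in the torchvision `train.py` for a module-level `logging.Logger` (`_LOGGER`), and wires training metrics through a `LoggerManager` that fans them out to python logging, TensorBoard, and (when available) Weights & Biases. A minimal sketch of the new wiring, distilled from the diff below; the `log_path` value here is a placeholder, and the `try`/`except` fallback mirrors the one added in `main()`:

```python
import logging
import warnings

from sparseml.pytorch.utils.logger import (
    LoggerManager,
    PythonLogger,
    TensorBoardLogger,
    WANDBLogger,
)

_LOGGER = logging.getLogger(__name__)

loggers = [
    PythonLogger(logger=_LOGGER),
    TensorBoardLogger(log_path="runs/train"),  # placeholder output dir
]
try:
    # wandb is an optional dependency; skip its logger when unavailable
    loggers.append(WANDBLogger())
except ImportError:
    warnings.warn("Unable to import wandb for logging")

logger = LoggerManager(loggers)
logger.log_scalar("Train/loss", 0.42, step=0)  # one call, every backend
```
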
92 changes: 66 additions & 26 deletions src/sparseml/pytorch/torchvision/train.py
@@ -15,6 +15,7 @@
# Adapted from https://github.com/pytorch/vision

import datetime
+import logging
import math
import os
import sys
@@ -39,10 +40,19 @@
download_framework_model_by_recipe_type,
torch_distributed_zero_first,
)
+from sparseml.pytorch.utils.logger import (
+    LoggerManager,
+    PythonLogger,
+    TensorBoardLogger,
+    WANDBLogger,
+)
from sparseml.pytorch.utils.model import load_model
from sparsezoo import Model


+_LOGGER = logging.getLogger(__name__)


def train_one_epoch(
model: torch.nn.Module,
criterion: torch.nn.Module,
@@ -55,7 +65,7 @@ def train_one_epoch(
scaler=None,
) -> utils.MetricLogger:
model.train()
-    metric_logger = utils.MetricLogger(delimiter=" ")
+    metric_logger = utils.MetricLogger(_LOGGER, delimiter=" ")
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))

@@ -122,7 +132,7 @@ def evaluate(
log_suffix="",
) -> utils.MetricLogger:
model.eval()
-    metric_logger = utils.MetricLogger(delimiter=" ")
+    metric_logger = utils.MetricLogger(_LOGGER, delimiter=" ")
header = f"Test: {log_suffix}"

num_processed_samples = 0
@@ -162,10 +172,10 @@ def evaluate(

metric_logger.synchronize_between_processes()

-    print(
-        header,
-        f"Acc@1 {metric_logger.acc1.global_avg:.3f}",
-        f"Acc@5 {metric_logger.acc5.global_avg:.3f}",
+    _LOGGER.info(
+        header
+        + f"Acc@1 {metric_logger.acc1.global_avg:.3f}"
+        + f"Acc@5 {metric_logger.acc5.global_avg:.3f}"
)
return metric_logger

@@ -183,20 +193,20 @@ def _get_cache_path(filepath):

def load_data(traindir, valdir, args):
# Data loading code
print("Loading data")
_LOGGER.info("Loading data")
val_resize_size, val_crop_size, train_crop_size = (
args.val_resize_size,
args.val_crop_size,
args.train_crop_size,
)
interpolation = InterpolationMode(args.interpolation)

print("Loading training data")
_LOGGER.info("Loading training data")
st = time.time()
cache_path = _get_cache_path(traindir)
if args.cache_dataset and os.path.exists(cache_path):
# Attention, as the transforms are also cached!
print(f"Loading dataset_train from {cache_path}")
_LOGGER.info(f"Loading dataset_train from {cache_path}")
dataset, _ = torch.load(cache_path)
else:
auto_augment_policy = getattr(args, "auto_augment", None)
Expand All @@ -215,16 +225,16 @@ def load_data(traindir, valdir, args):
),
)
if args.cache_dataset:
print(f"Saving dataset_train to {cache_path}")
_LOGGER.info(f"Saving dataset_train to {cache_path}")
utils.mkdir(os.path.dirname(cache_path))
utils.save_on_master((dataset, traindir), cache_path)
print("Took", time.time() - st)
_LOGGER.info(f"Took {time.time() - st}")

print("Loading validation data")
_LOGGER.info("Loading validation data")
cache_path = _get_cache_path(valdir)
if args.cache_dataset and os.path.exists(cache_path):
# Attention, as the transforms are also cached!
print(f"Loading dataset_test from {cache_path}")
_LOGGER.info(f"Loading dataset_test from {cache_path}")
dataset_test, _ = torch.load(cache_path)
else:
preprocessing = presets.ClassificationPresetEval(
Expand All @@ -238,11 +248,11 @@ def load_data(traindir, valdir, args):
preprocessing,
)
if args.cache_dataset:
print(f"Saving dataset_test to {cache_path}")
_LOGGER.info(f"Saving dataset_test to {cache_path}")
utils.mkdir(os.path.dirname(cache_path))
utils.save_on_master((dataset_test, valdir), cache_path)

print("Creating data loaders")
_LOGGER.info("Creating data loaders")
if args.distributed:
if hasattr(args, "ra_sampler") and args.ra_sampler:
train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps)
Expand All @@ -268,7 +278,10 @@ def main(args):
utils.mkdir(args.output_dir)

utils.init_distributed_mode(args)
-    print(args)
+    if not utils.is_main_process():
+        _LOGGER.disabled = True
+
+    _LOGGER.info(args)

device = torch.device(args.device)

@@ -318,7 +331,7 @@ def collate_fn(batch):
pin_memory=True,
)

print("Creating model")
_LOGGER.info("Creating model")
if args.arch_key in ModelRegistry.available_keys():
with torch_distributed_zero_first(args.rank if args.distributed else None):
model = ModelRegistry.create(
@@ -439,7 +452,6 @@ def collate_fn(batch):
# NOTE: override manager with the checkpoint's manager
manager = ScheduledModifierManager.from_yaml(checkpoint["recipe"])
checkpoint_manager = None
-        manager.initialize(model, epoch=checkpoint["epoch"])
else:
raise ValueError("Flag --resume is set but checkpoint does not have recipe")

@@ -480,11 +492,30 @@ def collate_fn(batch):
evaluate(model, criterion, data_loader_test, device)
return

-    optimizer = (
-        manager.modify(model, optimizer, len(data_loader))
-        if manager is not None
-        else optimizer
-    )
+    if utils.is_main_process():
+        loggers = [
+            PythonLogger(logger=_LOGGER),
+            TensorBoardLogger(log_path=args.output_dir),
+        ]
+        try:
+            loggers.append(WANDBLogger())
+        except ImportError:
+            warnings.warn("Unable to import wandb for logging")
+        logger = LoggerManager(loggers)
+    else:
+        logger = LoggerManager(log_python=False)
+
+    def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int):
+        for metric_name, smoothed_value in metrics.meters.items():
+            logger.log_scalar(
+                f"{tag}/{metric_name}", smoothed_value.global_avg, step=epoch
+            )
+
+    if manager is not None:
+        manager.initialize(model, epoch=args.start_epoch, loggers=logger)
+        optimizer = manager.modify(
+            model, optimizer, len(data_loader), epoch=args.start_epoch
+        )

lr_scheduler = _get_lr_scheduler(
args, optimizer, checkpoint=checkpoint, manager=manager
@@ -497,7 +528,8 @@ def collate_fn(batch):

best_top1_acc = -math.inf

print("Start training")
_LOGGER.info("Start training")

start_time = time.time()
max_epochs = manager.max_epochs if manager is not None else args.epochs
for epoch in range(args.start_epoch, max_epochs):
@@ -506,6 +538,7 @@
if manager is not None and manager.qat_active(epoch=epoch):
scaler = None
model_ema = None

train_metrics = train_one_epoch(
model,
criterion,
@@ -517,18 +550,25 @@
model_ema=model_ema,
scaler=scaler,
)
log_metrics("Train", train_metrics, epoch)

if lr_scheduler:
lr_scheduler.step()

eval_metrics = evaluate(model, criterion, data_loader_test, device)
log_metrics("Test", eval_metrics, epoch)

top1_acc = eval_metrics.acc1.global_avg
if model_ema:
-            evaluate(
+            ema_eval_metrics = evaluate(
model_ema,
criterion,
data_loader_test,
device,
log_suffix="EMA",
)
log_metrics("Test/EMA", ema_eval_metrics, epoch)

is_new_best = epoch >= args.save_best_after and top1_acc > best_top1_acc
if is_new_best:
best_top1_acc = top1_acc
Expand Down Expand Up @@ -575,7 +615,7 @@ def collate_fn(batch):

total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(f"Training time {total_time_str}")
_LOGGER.info(f"Training time {total_time_str}")


def _get_lr_scheduler(args, optimizer, checkpoint=None, manager=None):
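
Taken together, the per-epoch reporting added to `main()` reduces to a small pattern: every meter tracked by a `utils.MetricLogger` becomes one scalar per epoch on every logging backend. A self-contained sketch of that pattern follows; it assumes the torchvision-style `update(**kwargs)` helper on `MetricLogger` (not shown in this diff) and uses only the `PythonLogger` backend so it runs without TensorBoard or wandb.

```python
import logging

from sparseml.pytorch.torchvision import utils
from sparseml.pytorch.utils.logger import LoggerManager, PythonLogger

logging.basicConfig(level=logging.INFO)
_LOGGER = logging.getLogger(__name__)

logger = LoggerManager([PythonLogger(logger=_LOGGER)])

def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int) -> None:
    # One scalar per meter, e.g. "Train/lr" and "Train/loss", step-indexed by epoch.
    for metric_name, smoothed_value in metrics.meters.items():
        logger.log_scalar(f"{tag}/{metric_name}", smoothed_value.global_avg, step=epoch)

# Toy usage: populate a MetricLogger the way train_one_epoch does, then log it.
metrics = utils.MetricLogger(_LOGGER, delimiter="  ")
metrics.update(loss=0.73, lr=0.01)  # assumes the usual update(**kwargs) helper
log_metrics("Train", metrics, epoch=0)
```
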
10 changes: 6 additions & 4 deletions src/sparseml/pytorch/torchvision/utils.py
@@ -5,6 +5,7 @@
import datetime
import errno
import hashlib
+import logging
import os
import time
from collections import OrderedDict, defaultdict, deque
@@ -74,7 +75,8 @@ def __str__(self):


class MetricLogger:
-    def __init__(self, delimiter="\t"):
+    def __init__(self, logger: logging.Logger, delimiter="\t"):
+        self.logger = logger
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter

@@ -148,7 +150,7 @@ def log_every(self, iterable, print_freq, header=None):
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
-                print(
+                self.logger.info(
log_msg.format(
i,
len(iterable),
@@ -160,7 +162,7 @@ def log_every(self, iterable, print_freq, header=None):
)
)
else:
-                print(
+                self.logger.info(
log_msg.format(
i,
len(iterable),
@@ -174,7 +176,7 @@ def log_every(self, iterable, print_freq, header=None):
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(f"{header} Total time: {total_time_str}")
self.logger.info(f"{header} Total time: {total_time_str}")


class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
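
Because `MetricLogger` now takes its destination `logging.Logger` explicitly, callers decide where progress lines end up instead of inheriting `print`. A short usage sketch under that assumption; the `update(**kwargs)` helper is the usual torchvision one and is not shown in this diff.

```python
import logging
import time

from sparseml.pytorch.torchvision import utils

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("train")

# The logger is now injected rather than implied by print()
metric_logger = utils.MetricLogger(logger, delimiter="  ")
metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))

# log_every periodically emits progress (iteration, ETA, meter values) via logger.info
for step in metric_logger.log_every(range(100), print_freq=10, header="Epoch: [0]"):
    metric_logger.update(lr=0.01)  # assumes the usual update(**kwargs) helper
    time.sleep(0.01)  # stand-in for a real training step
```
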