diff --git a/src/sparseml/pytorch/torchvision/train.py b/src/sparseml/pytorch/torchvision/train.py
index 446c39e5f30..7741717d72d 100644
--- a/src/sparseml/pytorch/torchvision/train.py
+++ b/src/sparseml/pytorch/torchvision/train.py
@@ -15,6 +15,7 @@
 # Adapted from https://github.com/pytorch/vision
 
 import datetime
+import logging
 import math
 import os
 import sys
@@ -39,10 +40,19 @@
     download_framework_model_by_recipe_type,
     torch_distributed_zero_first,
 )
+from sparseml.pytorch.utils.logger import (
+    LoggerManager,
+    PythonLogger,
+    TensorBoardLogger,
+    WANDBLogger,
+)
 from sparseml.pytorch.utils.model import load_model
 from sparsezoo import Model
 
+_LOGGER = logging.getLogger(__name__)
+
+
 def train_one_epoch(
     model: torch.nn.Module,
     criterion: torch.nn.Module,
@@ -55,7 +65,7 @@ def train_one_epoch(
     scaler=None,
 ) -> utils.MetricLogger:
     model.train()
-    metric_logger = utils.MetricLogger(delimiter="  ")
+    metric_logger = utils.MetricLogger(_LOGGER, delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
     metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))
 
@@ -122,7 +132,7 @@ def evaluate(
     log_suffix="",
 ) -> utils.MetricLogger:
     model.eval()
-    metric_logger = utils.MetricLogger(delimiter="  ")
+    metric_logger = utils.MetricLogger(_LOGGER, delimiter="  ")
     header = f"Test: {log_suffix}"
     num_processed_samples = 0
 
@@ -162,10 +172,10 @@ def evaluate(
 
     metric_logger.synchronize_between_processes()
 
-    print(
-        header,
-        f"Acc@1 {metric_logger.acc1.global_avg:.3f}",
-        f"Acc@5 {metric_logger.acc5.global_avg:.3f}",
+    _LOGGER.info(
+        header
+        + f" Acc@1 {metric_logger.acc1.global_avg:.3f}"
+        + f" Acc@5 {metric_logger.acc5.global_avg:.3f}"
     )
     return metric_logger
@@ -183,7 +193,7 @@ def _get_cache_path(filepath):
 
 def load_data(traindir, valdir, args):
     # Data loading code
-    print("Loading data")
+    _LOGGER.info("Loading data")
     val_resize_size, val_crop_size, train_crop_size = (
         args.val_resize_size,
         args.val_crop_size,
@@ -191,12 +201,12 @@ def load_data(traindir, valdir, args):
     )
     interpolation = InterpolationMode(args.interpolation)
 
-    print("Loading training data")
+    _LOGGER.info("Loading training data")
     st = time.time()
     cache_path = _get_cache_path(traindir)
     if args.cache_dataset and os.path.exists(cache_path):
         # Attention, as the transforms are also cached!
-        print(f"Loading dataset_train from {cache_path}")
+        _LOGGER.info(f"Loading dataset_train from {cache_path}")
         dataset, _ = torch.load(cache_path)
     else:
         auto_augment_policy = getattr(args, "auto_augment", None)
@@ -215,16 +225,16 @@ def load_data(traindir, valdir, args):
             ),
         )
         if args.cache_dataset:
-            print(f"Saving dataset_train to {cache_path}")
+            _LOGGER.info(f"Saving dataset_train to {cache_path}")
             utils.mkdir(os.path.dirname(cache_path))
             utils.save_on_master((dataset, traindir), cache_path)
-    print("Took", time.time() - st)
+    _LOGGER.info(f"Took {time.time() - st}")
 
-    print("Loading validation data")
+    _LOGGER.info("Loading validation data")
     cache_path = _get_cache_path(valdir)
     if args.cache_dataset and os.path.exists(cache_path):
         # Attention, as the transforms are also cached!
- print(f"Loading dataset_test from {cache_path}") + _LOGGER.info(f"Loading dataset_test from {cache_path}") dataset_test, _ = torch.load(cache_path) else: preprocessing = presets.ClassificationPresetEval( @@ -238,11 +248,11 @@ def load_data(traindir, valdir, args): preprocessing, ) if args.cache_dataset: - print(f"Saving dataset_test to {cache_path}") + _LOGGER.info(f"Saving dataset_test to {cache_path}") utils.mkdir(os.path.dirname(cache_path)) utils.save_on_master((dataset_test, valdir), cache_path) - print("Creating data loaders") + _LOGGER.info("Creating data loaders") if args.distributed: if hasattr(args, "ra_sampler") and args.ra_sampler: train_sampler = RASampler(dataset, shuffle=True, repetitions=args.ra_reps) @@ -268,7 +278,10 @@ def main(args): utils.mkdir(args.output_dir) utils.init_distributed_mode(args) - print(args) + if not utils.is_main_process(): + _LOGGER.disabled = True + + _LOGGER.info(args) device = torch.device(args.device) @@ -318,7 +331,7 @@ def collate_fn(batch): pin_memory=True, ) - print("Creating model") + _LOGGER.info("Creating model") if args.arch_key in ModelRegistry.available_keys(): with torch_distributed_zero_first(args.rank if args.distributed else None): model = ModelRegistry.create( @@ -439,7 +452,6 @@ def collate_fn(batch): # NOTE: override manager with the checkpoint's manager manager = ScheduledModifierManager.from_yaml(checkpoint["recipe"]) checkpoint_manager = None - manager.initialize(model, epoch=checkpoint["epoch"]) else: raise ValueError("Flag --resume is set but checkpoint does not have recipe") @@ -480,11 +492,30 @@ def collate_fn(batch): evaluate(model, criterion, data_loader_test, device) return - optimizer = ( - manager.modify(model, optimizer, len(data_loader)) - if manager is not None - else optimizer - ) + if utils.is_main_process(): + loggers = [ + PythonLogger(logger=_LOGGER), + TensorBoardLogger(log_path=args.output_dir), + ] + try: + loggers.append(WANDBLogger()) + except ImportError: + warnings.warn("Unable to import wandb for logging") + logger = LoggerManager(loggers) + else: + logger = LoggerManager(log_python=False) + + def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int): + for metric_name, smoothed_value in metrics.meters.items(): + logger.log_scalar( + f"{tag}/{metric_name}", smoothed_value.global_avg, step=epoch + ) + + if manager is not None: + manager.initialize(model, epoch=args.start_epoch, loggers=logger) + optimizer = manager.modify( + model, optimizer, len(data_loader), epoch=args.start_epoch + ) lr_scheduler = _get_lr_scheduler( args, optimizer, checkpoint=checkpoint, manager=manager @@ -497,7 +528,8 @@ def collate_fn(batch): best_top1_acc = -math.inf - print("Start training") + _LOGGER.info("Start training") + start_time = time.time() max_epochs = manager.max_epochs if manager is not None else args.epochs for epoch in range(args.start_epoch, max_epochs): @@ -506,6 +538,7 @@ def collate_fn(batch): if manager is not None and manager.qat_active(epoch=epoch): scaler = None model_ema = None + train_metrics = train_one_epoch( model, criterion, @@ -517,18 +550,25 @@ def collate_fn(batch): model_ema=model_ema, scaler=scaler, ) + log_metrics("Train", train_metrics, epoch) + if lr_scheduler: lr_scheduler.step() + eval_metrics = evaluate(model, criterion, data_loader_test, device) + log_metrics("Test", eval_metrics, epoch) + top1_acc = eval_metrics.acc1.global_avg if model_ema: - evaluate( + ema_eval_metrics = evaluate( model_ema, criterion, data_loader_test, device, log_suffix="EMA", ) + 
log_metrics("Test/EMA", ema_eval_metrics, epoch) + is_new_best = epoch >= args.save_best_after and top1_acc > best_top1_acc if is_new_best: best_top1_acc = top1_acc @@ -575,7 +615,7 @@ def collate_fn(batch): total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) - print(f"Training time {total_time_str}") + _LOGGER.info(f"Training time {total_time_str}") def _get_lr_scheduler(args, optimizer, checkpoint=None, manager=None): diff --git a/src/sparseml/pytorch/torchvision/utils.py b/src/sparseml/pytorch/torchvision/utils.py index 8f743067908..e991c6e132d 100644 --- a/src/sparseml/pytorch/torchvision/utils.py +++ b/src/sparseml/pytorch/torchvision/utils.py @@ -5,6 +5,7 @@ import datetime import errno import hashlib +import logging import os import time from collections import OrderedDict, defaultdict, deque @@ -74,7 +75,8 @@ def __str__(self): class MetricLogger: - def __init__(self, delimiter="\t"): + def __init__(self, logger: logging.Logger, delimiter="\t"): + self.logger = logger self.meters = defaultdict(SmoothedValue) self.delimiter = delimiter @@ -148,7 +150,7 @@ def log_every(self, iterable, print_freq, header=None): eta_seconds = iter_time.global_avg * (len(iterable) - i) eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) if torch.cuda.is_available(): - print( + self.logger.info( log_msg.format( i, len(iterable), @@ -160,7 +162,7 @@ def log_every(self, iterable, print_freq, header=None): ) ) else: - print( + self.logger.info( log_msg.format( i, len(iterable), @@ -174,7 +176,7 @@ def log_every(self, iterable, print_freq, header=None): end = time.time() total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) - print(f"{header} Total time: {total_time_str}") + self.logger.info(f"{header} Total time: {total_time_str}") class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):