From b643a9e6b9f01d9ab89d3fbe4afb368f3efa3ac0 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 24 Aug 2022 12:21:31 +0200 Subject: [PATCH] Add CSV logging to GenericLogger Enable CSV logging for Classify training. Signed-off-by: Glenn Jocher --- utils/loggers/__init__.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/utils/loggers/__init__.py b/utils/loggers/__init__.py index 59d4b566836a..880039b1914c 100644 --- a/utils/loggers/__init__.py +++ b/utils/loggers/__init__.py @@ -242,9 +242,10 @@ class GenericLogger: def __init__(self, opt, console_logger, include=('tb', 'wandb')): # init default loggers - self.save_dir = opt.save_dir + self.save_dir = Path(opt.save_dir) self.include = include self.console_logger = console_logger + self.csv = self.save_dir / 'results.csv' # CSV logger if 'tb' in self.include: prefix = colorstr('TensorBoard: ') self.console_logger.info( @@ -258,14 +259,21 @@ def __init__(self, opt, console_logger, include=('tb', 'wandb')): else: self.wandb = None - def log_metrics(self, metrics_dict, epoch): + def log_metrics(self, metrics, epoch): # Log metrics dictionary to all loggers + if self.csv: + keys, vals = list(metrics.keys()), list(metrics.values()) + n = len(metrics) + 1 # number of cols + s = '' if self.csv.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n') # header + with open(self.csv, 'a') as f: + f.write(s + ('%23.5g,' * n % tuple([epoch] + vals)).rstrip(',') + '\n') + if self.tb: - for k, v in metrics_dict.items(): + for k, v in metrics.items(): self.tb.add_scalar(k, v, epoch) if self.wandb: - self.wandb.log(metrics_dict, step=epoch) + self.wandb.log(metrics, step=epoch) def log_images(self, files, name='Images', epoch=0): # Log images to all loggers @@ -291,6 +299,11 @@ def log_model(self, model_path, epoch=0, metadata={}): art.add_file(str(model_path)) wandb.log_artifact(art) + def update_params(self, params): + # Update the paramters logged + if self.wandb: + wandb.run.config.update(params, allow_val_change=True) + def log_tensorboard_graph(tb, model, imgsz=(640, 640)): # Log model graph to TensorBoard