Skip to content

Commit

Permalink
Add CSV logging to GenericLogger (ultralytics#9128)
Browse files Browse the repository at this point in the history
Enable CSV logging for Classify training.

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
  • Loading branch information
glenn-jocher committed Aug 24, 2022
1 parent f8816f5 commit f0e5a60
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions utils/loggers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,10 @@ class GenericLogger:

def __init__(self, opt, console_logger, include=('tb', 'wandb')):
# init default loggers
self.save_dir = opt.save_dir
self.save_dir = Path(opt.save_dir)
self.include = include
self.console_logger = console_logger
self.csv = self.save_dir / 'results.csv' # CSV logger
if 'tb' in self.include:
prefix = colorstr('TensorBoard: ')
self.console_logger.info(
Expand All @@ -258,14 +259,21 @@ def __init__(self, opt, console_logger, include=('tb', 'wandb')):
else:
self.wandb = None

def log_metrics(self, metrics_dict, epoch):
def log_metrics(self, metrics, epoch):
# Log metrics dictionary to all loggers
if self.csv:
keys, vals = list(metrics.keys()), list(metrics.values())
n = len(metrics) + 1 # number of cols
s = '' if self.csv.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n') # header
with open(self.csv, 'a') as f:
f.write(s + ('%23.5g,' * n % tuple([epoch] + vals)).rstrip(',') + '\n')

if self.tb:
for k, v in metrics_dict.items():
for k, v in metrics.items():
self.tb.add_scalar(k, v, epoch)

if self.wandb:
self.wandb.log(metrics_dict, step=epoch)
self.wandb.log(metrics, step=epoch)

def log_images(self, files, name='Images', epoch=0):
# Log images to all loggers
Expand All @@ -291,6 +299,11 @@ def log_model(self, model_path, epoch=0, metadata={}):
art.add_file(str(model_path))
wandb.log_artifact(art)

def update_params(self, params):
# Update the paramters logged
if self.wandb:
wandb.run.config.update(params, allow_val_change=True)


def log_tensorboard_graph(tb, model, imgsz=(640, 640)):
# Log model graph to TensorBoard
Expand Down

0 comments on commit f0e5a60

Please sign in to comment.