Skip to content

Commit

Permalink
Add segmentation and classification support for ClearML (#10752)
Browse files Browse the repository at this point in the history
* Added ClearML instance segmentation and classification support

* Cleaned up ClearML plot output

* typos

* Log results as plots instead of debug samples

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
  • Loading branch information
3 people committed Jan 3, 2024
1 parent 66edf38 commit 151c953
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 13 deletions.
49 changes: 40 additions & 9 deletions utils/loggers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,11 @@ def on_pretrain_routine_end(self, labels, names):
paths = self.save_dir.glob('*labels*.jpg') # training labels
if self.wandb:
self.wandb.log({'Labels': [wandb.Image(str(x), caption=x.name) for x in paths]})
# if self.clearml:
# pass # ClearML saves these images automatically using hooks
if self.comet_logger:
self.comet_logger.on_pretrain_routine_end(paths)
if self.clearml:
for path in paths:
self.clearml.log_plot(title=path.stem, plot_path=path)

def on_train_batch_end(self, model, ni, imgs, targets, paths, vals):
log_dict = dict(zip(self.keys[:3], vals))
Expand Down Expand Up @@ -255,9 +256,7 @@ def on_fit_epoch_end(self, vals, epoch, best_fitness, fi):
for k, v in x.items():
self.tb.add_scalar(k, v, epoch)
elif self.clearml: # log to ClearML if TensorBoard not used
for k, v in x.items():
title, series = k.split('/')
self.clearml.task.get_logger().report_scalar(title, series, v, epoch)
self.clearml.log_scalars(x, epoch)

if self.wandb:
if best_fitness == fi:
Expand Down Expand Up @@ -311,9 +310,10 @@ def on_train_end(self, last, best, epoch, results):
self.wandb.finish_run()

if self.clearml and not self.opt.evolve:
self.clearml.task.update_output_model(model_path=str(best if best.exists() else last),
name='Best Model',
auto_delete_file=False)
self.clearml.log_summary(dict(zip(self.keys[3:10], results)))
[self.clearml.log_plot(title=f.stem, plot_path=f) for f in files]
self.clearml.log_model(str(best if best.exists() else last),
"Best Model" if best.exists() else "Last Model", epoch)

if self.comet_logger:
final_results = dict(zip(self.keys[3:10], results))
Expand All @@ -325,6 +325,8 @@ def on_params_update(self, params: dict):
self.wandb.wandb_run.config.update(params, allow_val_change=True)
if self.comet_logger:
self.comet_logger.on_params_update(params)
if self.clearml:
self.clearml.task.connect(params)


class GenericLogger:
Expand All @@ -337,7 +339,7 @@ class GenericLogger:
include: loggers to include
"""

def __init__(self, opt, console_logger, include=('tb', 'wandb')):
def __init__(self, opt, console_logger, include=('tb', 'wandb', 'clearml')):
# init default loggers
self.save_dir = Path(opt.save_dir)
self.include = include
Expand All @@ -356,6 +358,22 @@ def __init__(self, opt, console_logger, include=('tb', 'wandb')):
else:
self.wandb = None

if clearml and 'clearml' in self.include:
try:
# Hyp is not available in classification mode
if 'hyp' not in opt:
hyp = {}
else:
hyp = opt.hyp
self.clearml = ClearmlLogger(opt, hyp)
except Exception:
self.clearml = None
prefix = colorstr('ClearML: ')
LOGGER.warning(f'{prefix}WARNING ⚠️ ClearML is installed but not configured, skipping ClearML logging.'
f' See https://github.com/ultralytics/yolov5/tree/master/utils/loggers/clearml#readme')
else:
self.clearml = None

def log_metrics(self, metrics, epoch):
# Log metrics dictionary to all loggers
if self.csv:
Expand All @@ -372,6 +390,9 @@ def log_metrics(self, metrics, epoch):
if self.wandb:
self.wandb.log(metrics, step=epoch)

if self.clearml:
self.clearml.log_scalars(metrics, epoch)

def log_images(self, files, name='Images', epoch=0):
# Log images to all loggers
files = [Path(f) for f in (files if isinstance(files, (tuple, list)) else [files])] # to Path
Expand All @@ -384,6 +405,12 @@ def log_images(self, files, name='Images', epoch=0):
if self.wandb:
self.wandb.log({name: [wandb.Image(str(f), caption=f.name) for f in files]}, step=epoch)

if self.clearml:
if name == 'Results':
[self.clearml.log_plot(f.stem, f) for f in files]
else:
self.clearml.log_debug_samples(files, title=name)

def log_graph(self, model, imgsz=(640, 640)):
# Log model graph to all loggers
if self.tb:
Expand All @@ -395,11 +422,15 @@ def log_model(self, model_path, epoch=0, metadata={}):
art = wandb.Artifact(name=f'run_{wandb.run.id}_model', type='model', metadata=metadata)
art.add_file(str(model_path))
wandb.log_artifact(art)
if self.clearml:
self.clearml.log_model(model_path=model_path, model_name=model_path.stem)

def update_params(self, params):
# Update the parameters logged
if self.wandb:
wandb.run.config.update(params, allow_val_change=True)
if self.clearml:
self.clearml.task.connect(params)


def log_tensorboard_graph(tb, model, imgsz=(640, 640)):
Expand Down
65 changes: 61 additions & 4 deletions utils/loggers/clearml/clearml_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import re
from pathlib import Path

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import yaml
from ultralytics.utils.plotting import Annotator, colors
Expand Down Expand Up @@ -78,18 +80,22 @@ def __init__(self, opt, hyp):
# Maximum number of images to log to clearML per epoch
self.max_imgs_to_log_per_epoch = 16
# Get the interval of epochs when bounding box images should be logged
self.bbox_interval = opt.bbox_interval
# Only for detection task though!
if 'bbox_interval' in opt:
self.bbox_interval = opt.bbox_interval
self.clearml = clearml
self.task = None
self.data_dict = None
if self.clearml:
self.task = Task.init(
project_name=opt.project if opt.project != 'runs/train' else 'YOLOv5',
project_name=opt.project if not str(opt.project).startswith('runs/') else 'YOLOv5',
task_name=opt.name if opt.name != 'exp' else 'Training',
tags=['YOLOv5'],
output_uri=True,
reuse_last_task_id=opt.exist_ok,
auto_connect_frameworks={'pytorch': False}
auto_connect_frameworks={
'pytorch': False,
'matplotlib': False}
# We disconnect pytorch auto-detection, because we added manual model save points in the code
)
# ClearML's hooks will already grab all general parameters
Expand All @@ -112,6 +118,57 @@ def __init__(self, opt, hyp):
# to give it to them
opt.data = self.data_dict

def log_scalars(self, metrics, epoch):
"""
Log scalars/metrics to ClearML
arguments:
metrics (dict) Metrics in dict format: {"metrics/mAP": 0.8, ...}
epoch (int) iteration number for the current set of metrics
"""
for k, v in metrics.items():
title, series = k.split('/')
self.task.get_logger().report_scalar(title, series, v, epoch)

def log_model(self, model_path, model_name, epoch=0):
"""
Log model weights to ClearML
arguments:
model_path (PosixPath or str) Path to the model weights
model_name (str) Name of the model visible in ClearML
epoch (int) Iteration / epoch of the model weights
"""
self.task.update_output_model(model_path=str(model_path),
name=model_name,
iteration=epoch,
auto_delete_file=False)

def log_summary(self, metrics):
"""
Log final metrics to a summary table
arguments:
metrics (dict) Metrics in dict format: {"metrics/mAP": 0.8, ...}
"""
for k, v in metrics.items():
self.task.get_logger().report_single_value(k, v)

def log_plot(self, title, plot_path):
"""
Log image as plot in the plot section of ClearML
arguments:
title (str) Title of the plot
plot_path (PosixPath or str) Path to the saved image file
"""
img = mpimg.imread(plot_path)
fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[]) # no ticks
ax.imshow(img)

self.task.get_logger().report_matplotlib_figure(title, "", figure=fig, report_interactive=False)

def log_debug_samples(self, files, title='Debug Samples'):
"""
Log files (images) as debug samples in the ClearML task.
Expand All @@ -125,7 +182,7 @@ def log_debug_samples(self, files, title='Debug Samples'):
it = re.search(r'_batch(\d+)', f.name)
iteration = int(it.groups()[0]) if it else 0
self.task.get_logger().report_image(title=title,
series=f.name.replace(it.group(), ''),
series=f.name.replace(f"_batch{iteration}", ''),
local_path=str(f),
iteration=iteration)

Expand Down

0 comments on commit 151c953

Please sign in to comment.