Skip to content

Commit

Permalink
Daemon plot_labels() for faster start (#9057)
Browse files Browse the repository at this point in the history
* Daemon `plot_labels()` for faster start

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update train.py

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
glenn-jocher and pre-commit-ci[bot] committed Aug 21, 2022
1 parent 27fb6fd commit e0700cc
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 18 deletions.
10 changes: 3 additions & 7 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from utils.loggers.wandb.wandb_utils import check_wandb_resume
from utils.loss import ComputeLoss
from utils.metrics import fitness
from utils.plots import plot_evolve, plot_labels
from utils.plots import plot_evolve
from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_DDP, smart_optimizer,
smart_resume, torch_distributed_zero_first)

Expand Down Expand Up @@ -215,15 +215,11 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
prefix=colorstr('val: '))[0]

if not resume:
if plots:
plot_labels(labels, names, save_dir)

# Anchors
if not opt.noautoanchor:
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz) # run AutoAnchor
model.half().float() # pre-reduce anchor precision

callbacks.run('on_pretrain_routine_end')
callbacks.run('on_pretrain_routine_end', labels, names, plots)

# DDP mode
if cuda and RANK != -1:
Expand Down
13 changes: 9 additions & 4 deletions utils/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
Callback utils
"""

import threading


class Callbacks:
""""
Expand Down Expand Up @@ -55,17 +57,20 @@ def get_registered_actions(self, hook=None):
"""
return self._callbacks[hook] if hook else self._callbacks

def run(self, hook, *args, **kwargs):
def run(self, hook, *args, thread=False, **kwargs):
"""
Loop through the registered actions and fire all callbacks
Loop through the registered actions and fire all callbacks on main thread
Args:
hook: The name of the hook to check, defaults to all
args: Arguments to receive from YOLOv5
thread: (boolean) Run callbacks in daemon thread
kwargs: Keyword Arguments to receive from YOLOv5
"""

assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"

for logger in self._callbacks[hook]:
logger['callback'](*args, **kwargs)
if thread:
threading.Thread(target=logger['callback'], args=args, kwargs=kwargs, daemon=True).start()
else:
logger['callback'](*args, **kwargs)
2 changes: 1 addition & 1 deletion utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,7 @@ def download_one(url, dir):
dir.mkdir(parents=True, exist_ok=True) # make directory
if threads > 1:
pool = ThreadPool(threads)
pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multi-threaded
pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multithreaded
pool.close()
pool.join()
else:
Expand Down
12 changes: 7 additions & 5 deletions utils/loggers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
import torch
from torch.utils.tensorboard import SummaryWriter

from utils.general import colorstr, cv2
from utils.general import colorstr, cv2, threaded
from utils.loggers.clearml.clearml_utils import ClearmlLogger
from utils.loggers.wandb.wandb_utils import WandbLogger
from utils.plots import plot_images, plot_results
from utils.plots import plot_images, plot_labels, plot_results
from utils.torch_utils import de_parallel

LOGGERS = ('csv', 'tb', 'wandb', 'clearml') # *.csv, TensorBoard, Weights & Biases, ClearML
Expand Down Expand Up @@ -110,13 +110,15 @@ def on_train_start(self):
# Callback runs on train start
pass

def on_pretrain_routine_end(self):
def on_pretrain_routine_end(self, labels, names, plots):
# Callback runs on pre-train routine end
if plots:
plot_labels(labels, names, self.save_dir)
paths = self.save_dir.glob('*labels*.jpg') # training labels
if self.wandb:
self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})
if self.clearml:
pass # ClearML saves these images automatically using hooks
# if self.clearml:
# pass # ClearML saves these images automatically using hooks

def on_train_batch_end(self, ni, model, imgs, targets, paths, plots):
# Callback runs on train batch end
Expand Down
1 change: 0 additions & 1 deletion utils/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,6 @@ def plot_val_study(file='', dir='', x=None): # from utils.plots import *; plot_


@try_except # known issue https://github.com/ultralytics/yolov5/issues/5395
@Timeout(30) # known issue https://github.com/ultralytics/yolov5/issues/5611
def plot_labels(labels, names=(), save_dir=Path('')):
# plot dataset labels
LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
Expand Down

0 comments on commit e0700cc

Please sign in to comment.