diff --git a/train.py b/train.py index 665b4f5b609e..0bfcaffc16db 100644 --- a/train.py +++ b/train.py @@ -52,7 +52,7 @@ from utils.loggers.wandb.wandb_utils import check_wandb_resume from utils.loss import ComputeLoss from utils.metrics import fitness -from utils.plots import plot_evolve, plot_labels +from utils.plots import plot_evolve from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_DDP, smart_optimizer, smart_resume, torch_distributed_zero_first) @@ -215,15 +215,11 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio prefix=colorstr('val: '))[0] if not resume: - if plots: - plot_labels(labels, names, save_dir) - - # Anchors if not opt.noautoanchor: - check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz) + check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz) # run AutoAnchor model.half().float() # pre-reduce anchor precision - callbacks.run('on_pretrain_routine_end') + callbacks.run('on_pretrain_routine_end', labels, names, plots) # DDP mode if cuda and RANK != -1: diff --git a/utils/callbacks.py b/utils/callbacks.py index 2b32df0bf1c1..166d8938322d 100644 --- a/utils/callbacks.py +++ b/utils/callbacks.py @@ -3,6 +3,8 @@ Callback utils """ +import threading + class Callbacks: """" @@ -55,17 +57,20 @@ def get_registered_actions(self, hook=None): """ return self._callbacks[hook] if hook else self._callbacks - def run(self, hook, *args, **kwargs): + def run(self, hook, *args, thread=False, **kwargs): """ - Loop through the registered actions and fire all callbacks + Loop through the registered actions and fire all callbacks on main thread Args: hook: The name of the hook to check, defaults to all args: Arguments to receive from YOLOv5 + thread: (boolean) Run callbacks in daemon thread kwargs: Keyword Arguments to receive from YOLOv5 """ assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}" - for logger in self._callbacks[hook]: - logger['callback'](*args, **kwargs) + if thread: + threading.Thread(target=logger['callback'], args=args, kwargs=kwargs, daemon=True).start() + else: + logger['callback'](*args, **kwargs) diff --git a/utils/general.py b/utils/general.py index d9f436a36359..3bc6fbc22d57 100755 --- a/utils/general.py +++ b/utils/general.py @@ -622,7 +622,7 @@ def download_one(url, dir): dir.mkdir(parents=True, exist_ok=True) # make directory if threads > 1: pool = ThreadPool(threads) - pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multi-threaded + pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multithreaded pool.close() pool.join() else: diff --git a/utils/loggers/__init__.py b/utils/loggers/__init__.py index b95a463717f8..c5cdd92772f2 100644 --- a/utils/loggers/__init__.py +++ b/utils/loggers/__init__.py @@ -11,10 +11,10 @@ import torch from torch.utils.tensorboard import SummaryWriter -from utils.general import colorstr, cv2 +from utils.general import colorstr, cv2, threaded from utils.loggers.clearml.clearml_utils import ClearmlLogger from utils.loggers.wandb.wandb_utils import WandbLogger -from utils.plots import plot_images, plot_results +from utils.plots import plot_images, plot_labels, plot_results from utils.torch_utils import de_parallel LOGGERS = ('csv', 'tb', 'wandb', 'clearml') # *.csv, TensorBoard, Weights & Biases, ClearML @@ -110,13 +110,15 @@ def on_train_start(self): # Callback runs on train start pass - def on_pretrain_routine_end(self): + def on_pretrain_routine_end(self, labels, names, plots): # Callback runs on pre-train routine end + if plots: + plot_labels(labels, names, self.save_dir) paths = self.save_dir.glob('*labels*.jpg') # training labels if self.wandb: self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]}) - if self.clearml: - pass # ClearML saves these images automatically using hooks + # if self.clearml: + # pass # ClearML saves these images automatically using hooks def on_train_batch_end(self, ni, model, imgs, targets, paths, plots): # Callback runs on train batch end diff --git a/utils/plots.py b/utils/plots.py index 2c7a80b4c872..7e1de43aba1b 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -340,7 +340,6 @@ def plot_val_study(file='', dir='', x=None): # from utils.plots import *; plot_ @try_except # known issue https://github.com/ultralytics/yolov5/issues/5395 -@Timeout(30) # known issue https://github.com/ultralytics/yolov5/issues/5611 def plot_labels(labels, names=(), save_dir=Path('')): # plot dataset labels LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")