From aa35ada72208cb82ab958a6e0fc65dbbe8bd0995 Mon Sep 17 00:00:00 2001 From: Kalen Michael Date: Tue, 7 Sep 2021 18:32:15 +0200 Subject: [PATCH] Optimised Callback Class to Reduce Code and Fix Errors (#4688) * added callbacks * added back callback to main * added save_dir to callback output * reduced code count * updated callbacks * added default callback class to main, added missing parameters to on_model_save * Glenn updates Co-authored-by: Glenn Jocher --- train.py | 20 ++++---- utils/callbacks.py | 123 ++++----------------------------------------- val.py | 4 +- 3 files changed, 22 insertions(+), 125 deletions(-) diff --git a/train.py b/train.py index 72aee2cb8883..f9aa3d4b5f69 100644 --- a/train.py +++ b/train.py @@ -56,7 +56,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary opt, device, - callbacks=Callbacks() + callbacks ): save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, = \ Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \ @@ -231,7 +231,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz) model.half().float() # pre-reduce anchor precision - callbacks.on_pretrain_routine_end() + callbacks.run('on_pretrain_routine_end') # DDP mode if cuda and RANK != -1: @@ -333,7 +333,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB) pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % ( f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1])) - callbacks.on_train_batch_end(ni, model, imgs, targets, paths, plots, opt.sync_bn) + callbacks.run('on_train_batch_end', ni, model, imgs, targets, paths, plots, opt.sync_bn) # end batch ------------------------------------------------------------------------------------------------ # Scheduler @@ -342,7 +342,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary if RANK in [-1, 0]: # mAP - callbacks.on_train_epoch_end(epoch=epoch) + callbacks.run('on_train_epoch_end', epoch=epoch) ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights']) final_epoch = (epoch + 1 == epochs) or stopper.possible_stop if not noval or final_epoch: # Calculate mAP @@ -364,7 +364,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary if fi > best_fitness: best_fitness = fi log_vals = list(mloss) + list(results) + lr - callbacks.on_fit_epoch_end(log_vals, epoch, best_fitness, fi) + callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi) # Save model if (not nosave) or (final_epoch and not evolve): # if save @@ -381,7 +381,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary if best_fitness == fi: torch.save(ckpt, best) del ckpt - callbacks.on_model_save(last, epoch, final_epoch, best_fitness, fi) + callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi) # Stop Single-GPU if RANK == -1 and stopper(epoch=epoch, fitness=fi): @@ -418,7 +418,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary for f in last, best: if f.exists(): strip_optimizer(f) # strip optimizers - callbacks.on_train_end(last, best, plots, epoch) + callbacks.run('on_train_end', last, best, plots, epoch) LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}") torch.cuda.empty_cache() @@ -467,7 +467,7 @@ def parse_opt(known=False): return opt -def main(opt): +def main(opt, callbacks=Callbacks()): # Checks set_logging(RANK) if RANK in [-1, 0]: @@ -505,7 +505,7 @@ def main(opt): # Train if not opt.evolve: - train(opt.hyp, opt, device) + train(opt.hyp, opt, device, callbacks) if WORLD_SIZE > 1 and RANK == 0: _ = [print('Destroying process group... ', end=''), dist.destroy_process_group(), print('Done.')] @@ -585,7 +585,7 @@ def main(opt): hyp[k] = round(hyp[k], 5) # significant digits # Train mutation - results = train(hyp.copy(), opt, device) + results = train(hyp.copy(), opt, device, callbacks) # Write mutation results print_mutation(results, hyp.copy(), save_dir, opt.bucket) diff --git a/utils/callbacks.py b/utils/callbacks.py index 19c334430b5d..327b8639b60c 100644 --- a/utils/callbacks.py +++ b/utils/callbacks.py @@ -9,6 +9,7 @@ class Callbacks: Handles all registered callbacks for YOLOv5 Hooks """ + # Define the available callbacks _callbacks = { 'on_pretrain_routine_start': [], 'on_pretrain_routine_end': [], @@ -34,16 +35,13 @@ class Callbacks: 'teardown': [], } - def __init__(self): - return - def register_action(self, hook, name='', callback=None): """ Register a new action to a callback hook Args: hook The callback hook name to register the action to - name The name of the action + name The name of the action for later reference callback The callback to fire """ assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}" @@ -62,118 +60,17 @@ def get_registered_actions(self, hook=None): else: return self._callbacks - def run_callbacks(self, hook, *args, **kwargs): + def run(self, hook, *args, **kwargs): """ Loop through the registered actions and fire all callbacks - """ - for logger in self._callbacks[hook]: - # print(f"Running callbacks.{logger['callback'].__name__}()") - logger['callback'](*args, **kwargs) - - def on_pretrain_routine_start(self, *args, **kwargs): - """ - Fires all registered callbacks at the start of each pretraining routine - """ - self.run_callbacks('on_pretrain_routine_start', *args, **kwargs) - - def on_pretrain_routine_end(self, *args, **kwargs): - """ - Fires all registered callbacks at the end of each pretraining routine - """ - self.run_callbacks('on_pretrain_routine_end', *args, **kwargs) - - def on_train_start(self, *args, **kwargs): - """ - Fires all registered callbacks at the start of each training - """ - self.run_callbacks('on_train_start', *args, **kwargs) - - def on_train_epoch_start(self, *args, **kwargs): - """ - Fires all registered callbacks at the start of each training epoch - """ - self.run_callbacks('on_train_epoch_start', *args, **kwargs) - - def on_train_batch_start(self, *args, **kwargs): - """ - Fires all registered callbacks at the start of each training batch - """ - self.run_callbacks('on_train_batch_start', *args, **kwargs) - def optimizer_step(self, *args, **kwargs): - """ - Fires all registered callbacks on each optimizer step - """ - self.run_callbacks('optimizer_step', *args, **kwargs) - - def on_before_zero_grad(self, *args, **kwargs): - """ - Fires all registered callbacks before zero grad - """ - self.run_callbacks('on_before_zero_grad', *args, **kwargs) - - def on_train_batch_end(self, *args, **kwargs): - """ - Fires all registered callbacks at the end of each training batch - """ - self.run_callbacks('on_train_batch_end', *args, **kwargs) - - def on_train_epoch_end(self, *args, **kwargs): - """ - Fires all registered callbacks at the end of each training epoch - """ - self.run_callbacks('on_train_epoch_end', *args, **kwargs) - - def on_val_start(self, *args, **kwargs): - """ - Fires all registered callbacks at the start of the validation - """ - self.run_callbacks('on_val_start', *args, **kwargs) - - def on_val_batch_start(self, *args, **kwargs): - """ - Fires all registered callbacks at the start of each validation batch - """ - self.run_callbacks('on_val_batch_start', *args, **kwargs) - - def on_val_image_end(self, *args, **kwargs): - """ - Fires all registered callbacks at the end of each val image - """ - self.run_callbacks('on_val_image_end', *args, **kwargs) - - def on_val_batch_end(self, *args, **kwargs): - """ - Fires all registered callbacks at the end of each validation batch - """ - self.run_callbacks('on_val_batch_end', *args, **kwargs) - - def on_val_end(self, *args, **kwargs): - """ - Fires all registered callbacks at the end of the validation - """ - self.run_callbacks('on_val_end', *args, **kwargs) - - def on_fit_epoch_end(self, *args, **kwargs): - """ - Fires all registered callbacks at the end of each fit (train+val) epoch - """ - self.run_callbacks('on_fit_epoch_end', *args, **kwargs) - - def on_model_save(self, *args, **kwargs): - """ - Fires all registered callbacks after each model save + Args: + hook The name of the hook to check, defaults to all + args Arguments to receive from YOLOv5 + kwargs Keyword Arguments to receive from YOLOv5 """ - self.run_callbacks('on_model_save', *args, **kwargs) - def on_train_end(self, *args, **kwargs): - """ - Fires all registered callbacks at the end of training - """ - self.run_callbacks('on_train_end', *args, **kwargs) + assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}" - def teardown(self, *args, **kwargs): - """ - Fires all registered callbacks before teardown - """ - self.run_callbacks('teardown', *args, **kwargs) + for logger in self._callbacks[hook]: + logger['callback'](*args, **kwargs) diff --git a/val.py b/val.py index 1aa37d12dfac..947cd78f7e1f 100644 --- a/val.py +++ b/val.py @@ -216,7 +216,7 @@ def run(data, save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / (path.stem + '.txt')) if save_json: save_one_json(predn, jdict, path, class_map) # append to COCO-JSON dictionary - callbacks.on_val_image_end(pred, predn, path, names, img[si]) + callbacks.run('on_val_image_end', pred, predn, path, names, img[si]) # Plot images if plots and batch_i < 3: @@ -253,7 +253,7 @@ def run(data, # Plots if plots: confusion_matrix.plot(save_dir=save_dir, names=list(names.values())) - callbacks.on_val_end() + callbacks.run('on_val_end') # Save JSON if save_json and len(jdict):