From 8a0a1c4654272e9821e568577cb87c8de84d3d5a Mon Sep 17 00:00:00 2001 From: Kalen Michael Date: Thu, 29 Jul 2021 15:25:54 +0200 Subject: [PATCH 1/6] added callbacks --- train.py | 19 ++++- utils/callbacks.py | 183 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 199 insertions(+), 3 deletions(-) create mode 100644 utils/callbacks.py diff --git a/train.py b/train.py index 3f5b5ed1195b..a0acd0a0fc7f 100644 --- a/train.py +++ b/train.py @@ -43,6 +43,8 @@ from utils.metrics import fitness from utils.loggers import Loggers +from utils import callbacks + LOGGER = logging.getLogger(__name__) LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html RANK = int(os.getenv('RANK', -1)) @@ -52,6 +54,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary opt, device, + callbacks ): save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, = \ Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \ @@ -330,6 +333,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % ( f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1])) loggers.on_train_batch_end(ni, model, imgs, targets, paths, plots) + callbacks.on_train_batch_end(ni, model, imgs, targets, paths, plots) # end batch ------------------------------------------------------------------------------------------------ @@ -340,6 +344,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary if RANK in [-1, 0]: # mAP loggers.on_train_epoch_end(epoch) + callbacks.on_train_epoch_end(epoch) + ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights']) final_epoch = epoch + 1 == epochs if not noval or final_epoch: # Calculate mAP @@ -361,6 +367,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary if fi > best_fitness: best_fitness = fi loggers.on_train_val_end(mloss, results, lr, 
epoch, best_fitness, fi) + callbacks.on_val_end(mloss, results, lr, epoch, best_fitness, fi) # Save model if (not nosave) or (final_epoch and not evolve): # if save @@ -378,6 +385,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary torch.save(ckpt, best) del ckpt loggers.on_model_save(last, epoch, final_epoch, best_fitness, fi) + callbacks.on_model_save(last, epoch, final_epoch, best_fitness, fi) # end epoch ---------------------------------------------------------------------------------------------------- # end training ----------------------------------------------------------------------------------------------------- @@ -401,6 +409,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary if f.exists(): strip_optimizer(f) # strip optimizers loggers.on_train_end(last, best, plots) + callbacks.on_train_end(last, best, plots) torch.cuda.empty_cache() return results @@ -446,7 +455,11 @@ def parse_opt(known=False): return opt -def main(opt): +def main(opt, callback_handler = None): + + # Define new hook handler if one is not passed in + if not callback_handler: callback_handler = callbacks.Callbacks() + set_logging(RANK) if RANK in [-1, 0]: print(colorstr('train: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items())) @@ -482,7 +495,7 @@ def main(opt): # Train if not opt.evolve: - train(opt.hyp, opt, device) + train(opt.hyp, opt, device, callback_handler) if WORLD_SIZE > 1 and RANK == 0: _ = [print('Destroying process group... 
', end=''), dist.destroy_process_group(), print('Done.')] @@ -562,7 +575,7 @@ def main(opt): hyp[k] = round(hyp[k], 5) # significant digits # Train mutation - results = train(hyp.copy(), opt, device) + results = train(hyp.copy(), opt, device, callback_handler) # Write mutation results print_mutation(hyp.copy(), results, yaml_file, opt.bucket) diff --git a/utils/callbacks.py b/utils/callbacks.py new file mode 100644 index 000000000000..d02f114021b3 --- /dev/null +++ b/utils/callbacks.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python + +class Callbacks: + """ + Handles all registered callbacks for YOLOv5 Hooks + """ + + _callbacks = { + 'on_pretrain_routine_start' :[], + 'on_pretrain_routine_end' :[], + + 'on_train_start' :[], + 'on_train_end': [], + 'on_train_epoch_start': [], + 'on_train_epoch_end': [], + 'on_train_batch_start': [], + 'on_train_batch_end': [], + + 'on_val_start' :[], + 'on_val_end': [], + 'on_val_epoch_start': [], + 'on_val_epoch_end': [], + 'on_val_batch_start': [], + 'on_val_batch_end': [], + + + 'on_model_save': [], + 'optimizer_step': [], + 'on_before_zero_grad': [], + 'teardown': [], + } + + def __init__(self): + return + + def regsiterAction(self, hook, name, callback): + """ + Register a new action to a callback hook + + Args: + hook The callback hook name to register the action to + name The name of the action + callback The callback to fire + + Returns: + (Bool) The success state + """ + if hook in self._callbacks: + self._callbacks[hook].append({'name': name, 'callback': callback}) + return True + else: + return False + + def getRegisteredActions(self, hook=None): + """ + Returns all the registered actions by callback hook + + Args: + hook The name of the hook to check, defaults to all + """ + if hook: + return self._callbacks[hook] + else: + return self._callbacks + + def fireCallbacks(self, register, *args): + """ + Loops through the registered actions and fires all callbacks + """ + for logger in register: + logger['callback'](*args) + +
+ def on_pretrain_routine_start(self, *args): + """ + Fires all registered callbacks at the start of each pretraining routine + """ + self.fireCallbacks(self._callbacks['on_pretrain_routine_start'], *args) + + def on_pretrain_routine_end(self, *args): + """ + Fires all registered callbacks at the end of each pretraining routine + """ + self.fireCallbacks(self._callbacks['on_pretrain_routine_end'], *args) + + def on_train_start(self, *args): + """ + Fires all registered callbacks at the start of each training + """ + self.fireCallbacks(self._callbacks['on_train_start'], *args) + + def on_train_end(self, *args): + """ + Fires all registered callbacks at the end of training + """ + self.fireCallbacks(self._callbacks['on_train_end'], *args) + + def on_train_epoch_start(self, *args): + """ + Fires all registered callbacks at the start of each training epoch + """ + self.fireCallbacks(self._callbacks['on_train_epoch_start'], *args) + + def on_train_epoch_end(self, *args): + """ + Fires all registered callbacks at the end of each training epoch + """ + self.fireCallbacks(self._callbacks['on_train_epoch_end'], *args) + + + def on_train_batch_start(self, *args): + """ + Fires all registered callbacks at the start of each training batch + """ + self.fireCallbacks(self._callbacks['on_train_batch_start'], *args) + + def on_train_batch_end(self, *args): + """ + Fires all registered callbacks at the end of each training batch + """ + self.fireCallbacks(self._callbacks['on_train_batch_end'], *args) + + def on_val_start(self, *args): + """ + Fires all registered callbacks at the start of the validation + """ + self.fireCallbacks(self._callbacks['on_val_start'], *args) + + def on_val_end(self, *args): + """ + Fires all registered callbacks at the end of the validation + """ + self.fireCallbacks(self._callbacks['on_val_end'], *args) + + def on_val_epoch_start(self, *args): + """ + Fires all registered callbacks at the start of each validation epoch + """ + 
self.fireCallbacks(self._callbacks['on_val_epoch_start'], *args) + + def on_val_epoch_end(self, *args): + """ + Fires all registered callbacks at the end of each validation epoch + """ + self.fireCallbacks(self._callbacks['on_val_epoch_end'], *args) + + def on_val_batch_start(self, *args): + """ + Fires all registered callbacks at the start of each validation batch + """ + self.fireCallbacks(self._callbacks['on_val_batch_start'], *args) + + def on_val_batch_end(self, *args): + """ + Fires all registered callbacks at the end of each validation batch + """ + self.fireCallbacks(self._callbacks['on_val_batch_end'], *args) + + def on_model_save(self, *args): + """ + Fires all registered callbacks after each model save + """ + self.fireCallbacks(self._callbacks['on_model_save'], *args) + + def optimizer_step(self, *args): + """ + Fires all registered callbacks on each optimizer step + """ + self.fireCallbacks(self._callbacks['optimizer_step'], *args) + + def on_before_zero_grad(self, *args): + """ + Fires all registered callbacks before zero grad + """ + self.fireCallbacks(self._callbacks['on_before_zero_grad'], *args) + + def teardown(self, *args): + """ + Fires all registered callbacks before teardown + """ + self.fireCallbacks(self._callbacks['teardown'], *args) + + From 83fb93fb5d4c2169734e43b03e3e1eb89396d7f5 Mon Sep 17 00:00:00 2001 From: Kalen Michael Date: Tue, 3 Aug 2021 10:44:55 +0200 Subject: [PATCH 2/6] added back callback to main --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index dbaf16e9a158..1c84e51a8a56 100644 --- a/train.py +++ b/train.py @@ -457,7 +457,7 @@ def parse_opt(known=False): return opt -def main(opt): +def main(opt, callback_handler): # Checks set_logging(RANK) if RANK in [-1, 0]: From b4434b774f13482b6715ce3bfbda076c352cda6f Mon Sep 17 00:00:00 2001 From: Kalen Michael Date: Tue, 3 Aug 2021 11:37:51 +0200 Subject: [PATCH 3/6] added save_dir to callback output --- train.py | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 1c84e51a8a56..d25b37f059b0 100644 --- a/train.py +++ b/train.py @@ -386,7 +386,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary if best_fitness == fi: torch.save(ckpt, best) del ckpt - callbacks.on_model_save(last, epoch, final_epoch, best_fitness, fi) + callbacks.on_model_save(last, epoch, final_epoch, best_fitness, fi, save_dir) # end epoch ---------------------------------------------------------------------------------------------------- # end training ----------------------------------------------------------------------------------------------------- From ef0760a988895939534df54f5528337872ba71ed Mon Sep 17 00:00:00 2001 From: Kalen Michael Date: Thu, 9 Sep 2021 17:29:54 +0200 Subject: [PATCH 4/6] merged in upstream --- train.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/train.py b/train.py index 626b7c32f821..7c05849dee00 100644 --- a/train.py +++ b/train.py @@ -383,9 +383,6 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary if best_fitness == fi: torch.save(ckpt, best) del ckpt -<<<<<<< HEAD - callbacks.on_model_save(last, epoch, final_epoch, best_fitness, fi, save_dir) -======= callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi) # Stop Single-GPU @@ -401,7 +398,6 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary # with torch_distributed_zero_first(RANK): # if stop: # break # must break all DDP ranks ->>>>>>> 2d9411dbb85ae63b8ca9913726844767898eb021 # end epoch ---------------------------------------------------------------------------------------------------- # end training ----------------------------------------------------------------------------------------------------- @@ -473,11 +469,7 @@ def parse_opt(known=False): return opt -<<<<<<< HEAD -def main(opt, callback_handler): -======= def main(opt, callbacks=Callbacks()): ->>>>>>> 2d9411dbb85ae63b8ca9913726844767898eb021 # Checks set_logging(RANK) if RANK in [-1, 
0]: @@ -516,11 +508,7 @@ def main(opt, callbacks=Callbacks()): # Train if not opt.evolve: -<<<<<<< HEAD - train(opt.hyp, opt, device, callback_handler) -======= train(opt.hyp, opt, device, callbacks) ->>>>>>> 2d9411dbb85ae63b8ca9913726844767898eb021 if WORLD_SIZE > 1 and RANK == 0: _ = [print('Destroying process group... ', end=''), dist.destroy_process_group(), print('Done.')] @@ -600,11 +588,7 @@ def main(opt, callbacks=Callbacks()): hyp[k] = round(hyp[k], 5) # significant digits # Train mutation -<<<<<<< HEAD - results = train(hyp.copy(), opt, device, callback_handler) -======= results = train(hyp.copy(), opt, device, callbacks) ->>>>>>> 2d9411dbb85ae63b8ca9913726844767898eb021 # Write mutation results print_mutation(results, hyp.copy(), save_dir, opt.bucket) From 46bb613eb63e11ac2dabdabdd9243f3ce6f4d067 Mon Sep 17 00:00:00 2001 From: Kalen Michael Date: Sat, 11 Sep 2021 09:17:29 +0200 Subject: [PATCH 5/6] removed ghost code --- train.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/train.py b/train.py index 2ce6e5e4c7ef..e5410eeeba9f 100644 --- a/train.py +++ b/train.py @@ -47,8 +47,6 @@ from utils.loggers import Loggers from utils.callbacks import Callbacks -from utils import callbacks - LOGGER = logging.getLogger(__name__) LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html RANK = int(os.getenv('RANK', -1)) From 72353f069895f4da9c6401a46797c601fee4d651 Mon Sep 17 00:00:00 2001 From: Kalen Michael Date: Wed, 29 Sep 2021 14:52:49 +0200 Subject: [PATCH 6/6] fixed parsing error for google temp links --- utils/general.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/general.py b/utils/general.py index 28301f8573bb..f2afb480cc63 100755 --- a/utils/general.py +++ b/utils/general.py @@ -313,7 +313,7 @@ def check_file(file, suffix=''): return file elif file.startswith(('http:/', 'https:/')): # download url = str(Path(file)).replace(':/', '://') # Pathlib turns :// -> :/ - file = 
Path(urllib.parse.unquote(file)).name.split('?')[0] # '%2F' to '/', split https://url.com/file.txt?auth + file = Path(urllib.parse.unquote(file).split('?')[0]).name # '%2F' to '/', split https://url.com/file.txt?auth print(f'Downloading {url} to {file}...') torch.hub.download_url_to_file(url, file) assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {url}' # check