From abb4643c4906bd0bf0c7d8f413fad7bdfa199254 Mon Sep 17 00:00:00 2001
From: "Matthias F. Brandstetter"
Date: Thu, 20 Jan 2022 13:41:28 +0100
Subject: [PATCH 1/3] New flag 'stop_training' in util.callbacks.Callbacks
 class to prematurely stop training from callback handler

---
 train.py           | 8 ++++++++
 utils/callbacks.py | 3 +++
 2 files changed, 11 insertions(+)

diff --git a/train.py b/train.py
index b20b7dbb2dda..e1c4ffcdb352 100644
--- a/train.py
+++ b/train.py
@@ -352,6 +352,8 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                 pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % (
                     f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
                 callbacks.run('on_train_batch_end', ni, model, imgs, targets, paths, plots, opt.sync_bn)
+                if callbacks.stop_training:
+                    return
             # end batch ------------------------------------------------------------------------------------------------
 
         # Scheduler
@@ -381,6 +383,8 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                 best_fitness = fi
             log_vals = list(mloss) + list(results) + lr
             callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi)
+            if callbacks.stop_training:
+                return
 
             # Save model
             if (not nosave) or (final_epoch and not evolve):  # if save
@@ -401,6 +405,8 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                     torch.save(ckpt, w / f'epoch{epoch}.pt')
                 del ckpt
                 callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)
+                if callbacks.stop_training:
+                    return
 
         # Stop Single-GPU
         if RANK == -1 and stopper(epoch=epoch, fitness=fi):
@@ -440,6 +446,8 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                                             compute_loss=compute_loss)  # val best model with plots
                     if is_coco:
                         callbacks.run('on_fit_epoch_end', list(mloss) + list(results) + lr, epoch, best_fitness, fi)
+                        if callbacks.stop_training:
+                            return
 
         callbacks.run('on_train_end', last, best, plots, epoch, results)
         LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
diff --git a/utils/callbacks.py b/utils/callbacks.py
index 13d82ebc2e41..c99f60e6f49c 100644
--- a/utils/callbacks.py
+++ b/utils/callbacks.py
@@ -36,6 +36,9 @@ def __init__(self):
             'teardown': [],
         }
 
+        # Set this to True in your callback handler to prematurely stop training
+        self.stop_training = False
+
     def register_action(self, hook, name='', callback=None):
         """
         Register a new action to a callback hook

From 047b7e8e41ab6e190c486a016496dd098df42040 Mon Sep 17 00:00:00 2001
From: Matthias
Date: Sat, 22 Jan 2022 18:35:11 +0100
Subject: [PATCH 2/3] Removed most of the new checks, leaving only the one
 after calling 'on_train_batch_end'

---
 train.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/train.py b/train.py
index e1c4ffcdb352..510377e1178e 100644
--- a/train.py
+++ b/train.py
@@ -383,8 +383,6 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                 best_fitness = fi
             log_vals = list(mloss) + list(results) + lr
             callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi)
-            if callbacks.stop_training:
-                return
 
             # Save model
             if (not nosave) or (final_epoch and not evolve):  # if save
@@ -405,8 +403,6 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                     torch.save(ckpt, w / f'epoch{epoch}.pt')
                 del ckpt
                 callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)
-                if callbacks.stop_training:
-                    return
 
         # Stop Single-GPU
         if RANK == -1 and stopper(epoch=epoch, fitness=fi):
@@ -446,8 +442,6 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                                             compute_loss=compute_loss)  # val best model with plots
                     if is_coco:
                         callbacks.run('on_fit_epoch_end', list(mloss) + list(results) + lr, epoch, best_fitness, fi)
-                        if callbacks.stop_training:
-                            return
 
         callbacks.run('on_train_end', last, best, plots, epoch, results)
         LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")

From 9fd2d2554c934a8e6153c33df843e20a8b4faf30 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Sat, 22 Jan 2022 16:26:43 -1000
Subject: [PATCH 3/3] Cleanup

---
 utils/callbacks.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/utils/callbacks.py b/utils/callbacks.py
index c99f60e6f49c..c51c268f20d6 100644
--- a/utils/callbacks.py
+++ b/utils/callbacks.py
@@ -35,9 +35,7 @@ def __init__(self):
             'on_params_update': [],
             'teardown': [],
         }
-
-        # Set this to True in your callback handler to prematurely stop training
-        self.stop_training = False
+        self.stop_training = False  # set True to interrupt training
 
     def register_action(self, hook, name='', callback=None):
         """
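
Usage note: a minimal sketch, not taken from the patches above, of how a handler
registered on 'on_train_batch_end' might flip the new flag to stop training early.
The handler name 'stop_after_n_batches', its 'limit' parameter and the action name
'early_stop' are illustrative only; the positional arguments match what train.py
passes to 'on_train_batch_end' in the diff above.

    # minimal sketch, assuming the YOLOv5 repo layout shown in the patches
    from utils.callbacks import Callbacks

    callbacks = Callbacks()

    def stop_after_n_batches(ni, model, imgs, targets, paths, plots, sync_bn, limit=100):
        # 'ni' is the integrated batch counter train.py passes to 'on_train_batch_end'
        if ni >= limit:
            callbacks.stop_training = True  # train() returns at its next check of the flag

    callbacks.register_action('on_train_batch_end', name='early_stop', callback=stop_after_n_batches)
    # pass this 'callbacks' instance into train()/main() as usual; with the check kept
    # after 'on_train_batch_end', training then stops after roughly 'limit' batches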