From 814ea3ff9f4e0b02879121289a81507c72cf8f9f Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Tue, 2 Aug 2022 14:41:45 +0200
Subject: [PATCH 1/5] New `smart_resume()`

---
 train.py             | 33 ++++++---------------------------
 utils/torch_utils.py | 25 +++++++++++++++++++++++++
 2 files changed, 31 insertions(+), 27 deletions(-)

diff --git a/train.py b/train.py
index 20fef265110c..99a43f8614c4 100644
--- a/train.py
+++ b/train.py
@@ -54,7 +54,7 @@
 from utils.metrics import fitness
 from utils.plots import plot_evolve, plot_labels
 from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_DDP, smart_optimizer,
-                               torch_distributed_zero_first)
+                               smart_resume, torch_distributed_zero_first)
 
 LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))  # https://pytorch.org/docs/stable/elastic/run.html
 RANK = int(os.getenv('RANK', -1))
@@ -163,26 +163,9 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictio
     ema = ModelEMA(model) if RANK in {-1, 0} else None
 
     # Resume
-    start_epoch, best_fitness = 0, 0.0
+    best_fitness, start_epoch = 0.0, 0
     if pretrained:
-        # Optimizer
-        if ckpt['optimizer'] is not None:
-            optimizer.load_state_dict(ckpt['optimizer'])
-            best_fitness = ckpt['best_fitness']
-
-        # EMA
-        if ema and ckpt.get('ema'):
-            ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
-            ema.updates = ckpt['updates']
-
-        # Epochs
-        start_epoch = ckpt['epoch'] + 1
-        if resume:
-            assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.'
-        if epochs < start_epoch:
-            LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
-            epochs += ckpt['epoch']  # finetune additional epochs
-
+        best_fitness, start_epoch, epochs = smart_resume(ckpt, optimizer, ema, weights, epochs, resume)
         del ckpt, csd
 
     # DP mode
@@ -212,8 +195,8 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictio
                                               quad=opt.quad,
                                               prefix=colorstr('train: '),
                                               shuffle=True)
-    mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max())  # max label class
-    nb = len(train_loader)  # number of batches
+    labels = np.concatenate(dataset.labels, 0)
+    mlc = int(labels[:, 0].max())  # max label class
     assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
 
     # Process 0
@@ -232,10 +215,6 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictio
                                        prefix=colorstr('val: '))[0]
 
         if not resume:
-            labels = np.concatenate(dataset.labels, 0)
-            # c = torch.tensor(labels[:, 0])  # classes
-            # cf = torch.bincount(c.long(), minlength=nc) + 1.  # frequency
-            # model._initialize_biases(cf.to(device))
             if plots:
                 plot_labels(labels, names, save_dir)
 
@@ -263,6 +242,7 @@ def train(hyp, opt, device, callbacks):  # hyp is path/to/hyp.yaml or hyp dictio
 
     # Start training
     t0 = time.time()
+    nb = len(train_loader)  # number of batches
     nw = max(round(hyp['warmup_epochs'] * nb), 100)  # number of warmup iterations, max(3 epochs, 100 iterations)
     # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
     last_opt_step = -1
@@ -510,7 +490,6 @@ def main(opt, callbacks=Callbacks()):
         with open(Path(ckpt).parent.parent / 'opt.yaml', errors='ignore') as f:
             opt = argparse.Namespace(**yaml.safe_load(f))  # replace
         opt.cfg, opt.weights, opt.resume = '', ckpt, True  # reinstate
-        LOGGER.info(f'Resuming training from {ckpt}')
     else:
         opt.data, opt.cfg, opt.hyp, opt.weights, opt.project = \
             check_file(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp), str(opt.weights), str(opt.project)  # checks
diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index 5f2a22c36f1a..df2a38441042 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -306,6 +306,31 @@ def smart_optimizer(model, name='Adam', lr=0.001, momentum=0.9, weight_decay=1e-
     return optimizer
 
 
+def smart_resume(ckpt, optimizer, ema, weights, epochs, resume):
+    # Resume training from a partially trained checkpoint
+
+    # Optimizer
+    best_fitness = 0.0
+    if ckpt['optimizer'] is not None:
+        optimizer.load_state_dict(ckpt['optimizer'])
+        best_fitness = ckpt['best_fitness']
+
+    # EMA
+    if ema and ckpt.get('ema'):
+        ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
+        ema.updates = ckpt['updates']
+
+    # Epochs
+    start_epoch = ckpt['epoch'] + 1
+    if resume:
+        assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.'
+    if epochs < start_epoch:
+        LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
+        epochs += ckpt['epoch']  # finetune additional epochs
+
+    return best_fitness, start_epoch, epochs
+
+
 class EarlyStopping:
     # YOLOv5 simple early stopper
     def __init__(self, patience=30):

From 794a15f56eb1c4343f1a1a93b66c1c26a71009bb Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Tue, 2 Aug 2022 14:43:59 +0200
Subject: [PATCH 2/5] Update torch_utils.py

---
 utils/torch_utils.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index df2a38441042..5f1f2b677c16 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -324,6 +324,7 @@ def smart_resume(ckpt, optimizer, ema, weights, epochs, resume):
     start_epoch = ckpt['epoch'] + 1
     if resume:
         assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.'
+        LOGGER.info(f'Resuming training from {weights} to {epochs} epochs')
     if epochs < start_epoch:
         LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
         epochs += ckpt['epoch']  # finetune additional epochs

From 53e73a3f93011b2b47729c5250525a0185c8e301 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Tue, 2 Aug 2022 14:48:12 +0200
Subject: [PATCH 3/5] Update torch_utils.py

---
 utils/torch_utils.py | 17 +++++------------
 1 file changed, 5 insertions(+), 12 deletions(-)

diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index 5f1f2b677c16..069a1b109779 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -306,29 +306,22 @@ def smart_optimizer(model, name='Adam', lr=0.001, momentum=0.9, weight_decay=1e-
     return optimizer
 
 
-def smart_resume(ckpt, optimizer, ema, weights, epochs, resume):
+def smart_resume(ckpt, optimizer, ema=None, weights='yolov5s.pt', epochs=300, resume=True):
     # Resume training from a partially trained checkpoint
-
-    # Optimizer
-    best_fitness = 0.0
+    best_fitness, start_epoch = 0.0
+    start_epoch = ckpt['epoch'] + 1
     if ckpt['optimizer'] is not None:
-        optimizer.load_state_dict(ckpt['optimizer'])
+        optimizer.load_state_dict(ckpt['optimizer'])  # optimizer
         best_fitness = ckpt['best_fitness']
-
-    # EMA
     if ema and ckpt.get('ema'):
-        ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
+        ema.ema.load_state_dict(ckpt['ema'].float().state_dict())  # EMA
         ema.updates = ckpt['updates']
-
-    # Epochs
-    start_epoch = ckpt['epoch'] + 1
     if resume:
         assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.'
         LOGGER.info(f'Resuming training from {weights} to {epochs} epochs')
     if epochs < start_epoch:
         LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
         epochs += ckpt['epoch']  # finetune additional epochs
-
     return best_fitness, start_epoch, epochs
 
 

From 14bcb1fb9f233d7e79c215eb48e8f6e4cb91d890 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Tue, 2 Aug 2022 14:50:09 +0200
Subject: [PATCH 4/5] Update torch_utils.py

---
 utils/torch_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index 069a1b109779..ffdd3d68b7f7 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -318,7 +318,7 @@ def smart_resume(ckpt, optimizer, ema=None, weights='yolov5s.pt', epochs=300, re
         ema.updates = ckpt['updates']
     if resume:
         assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.'
-        LOGGER.info(f'Resuming training from {weights} to {epochs} epochs')
+        LOGGER.info(f'Resuming training from {weights} for {epochs - start_epoch} more epochs to {epochs} total epochs')
     if epochs < start_epoch:
         LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.")
         epochs += ckpt['epoch']  # finetune additional epochs

From 21e0562661b587dd70dcb282f44ec88a1d8ede97 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Tue, 2 Aug 2022 14:53:17 +0200
Subject: [PATCH 5/5] fix

---
 utils/torch_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index ffdd3d68b7f7..391ddead2985 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -308,7 +308,7 @@ def smart_optimizer(model, name='Adam', lr=0.001, momentum=0.9, weight_decay=1e-
 
 def smart_resume(ckpt, optimizer, ema=None, weights='yolov5s.pt', epochs=300, resume=True):
     # Resume training from a partially trained checkpoint
-    best_fitness, start_epoch = 0.0
+    best_fitness = 0.0
     start_epoch = ckpt['epoch'] + 1
     if ckpt['optimizer'] is not None:
         optimizer.load_state_dict(ckpt['optimizer'])  # optimizer
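
Usage note: the sketch below shows how the refactored helper ends up being called from train.py once this series is applied. It is a minimal illustration, not part of the patches: the checkpoint path, nc=80 and the SGD hyper-parameters are placeholder values, and weight loading is simplified (train.py filters the checkpoint state dict with intersect_dicts before loading).

    # Minimal resume sketch using the new smart_resume() helper (illustrative only).
    # Assumes a YOLOv5 checkpoint containing 'model', 'optimizer', 'best_fitness',
    # 'ema', 'updates' and 'epoch' entries, as the helper expects.
    import torch

    from models.yolo import Model
    from utils.torch_utils import ModelEMA, smart_optimizer, smart_resume

    weights, epochs, resume = 'runs/train/exp/weights/last.pt', 300, True  # placeholder values
    ckpt = torch.load(weights, map_location='cpu')  # checkpoint dict saved by a previous training run

    model = Model(ckpt['model'].yaml, ch=3, nc=80)  # rebuild the network as train.py does
    model.load_state_dict(ckpt['model'].float().state_dict(), strict=False)  # simplified weight load
    optimizer = smart_optimizer(model, 'SGD', lr=0.01, momentum=0.937, weight_decay=0.0005)
    ema = ModelEMA(model)

    # One call replaces the inline resume block removed from train.py: it restores optimizer
    # state, EMA weights/updates and the epoch counters, and extends epochs when a finished
    # run is being fine-tuned rather than resumed.
    best_fitness, start_epoch, epochs = smart_resume(ckpt, optimizer, ema, weights, epochs, resume)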