From 2d85b63871e84adbdb586ce4fc8b0fc05655f88d Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Tue, 2 Aug 2022 15:13:58 +0200 Subject: [PATCH] New `smart_resume()` (#8838) * New `smart_resume()` * Update torch_utils.py * Update torch_utils.py * Update torch_utils.py * fix --- train.py | 33 ++++++--------------------------- utils/torch_utils.py | 19 +++++++++++++++++++ 2 files changed, 25 insertions(+), 27 deletions(-) diff --git a/train.py b/train.py index 20fef265110c..99a43f8614c4 100644 --- a/train.py +++ b/train.py @@ -54,7 +54,7 @@ from utils.metrics import fitness from utils.plots import plot_evolve, plot_labels from utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel, select_device, smart_DDP, smart_optimizer, - torch_distributed_zero_first) + smart_resume, torch_distributed_zero_first) LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html RANK = int(os.getenv('RANK', -1)) @@ -163,26 +163,9 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio ema = ModelEMA(model) if RANK in {-1, 0} else None # Resume - start_epoch, best_fitness = 0, 0.0 + best_fitness, start_epoch = 0.0, 0 if pretrained: - # Optimizer - if ckpt['optimizer'] is not None: - optimizer.load_state_dict(ckpt['optimizer']) - best_fitness = ckpt['best_fitness'] - - # EMA - if ema and ckpt.get('ema'): - ema.ema.load_state_dict(ckpt['ema'].float().state_dict()) - ema.updates = ckpt['updates'] - - # Epochs - start_epoch = ckpt['epoch'] + 1 - if resume: - assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.' - if epochs < start_epoch: - LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.") - epochs += ckpt['epoch'] # finetune additional epochs - + best_fitness, start_epoch, epochs = smart_resume(ckpt, optimizer, ema, weights, epochs, resume) del ckpt, csd # DP mode @@ -212,8 +195,8 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio quad=opt.quad, prefix=colorstr('train: '), shuffle=True) - mlc = int(np.concatenate(dataset.labels, 0)[:, 0].max()) # max label class - nb = len(train_loader) # number of batches + labels = np.concatenate(dataset.labels, 0) + mlc = int(labels[:, 0].max()) # max label class assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}' # Process 0 @@ -232,10 +215,6 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio prefix=colorstr('val: '))[0] if not resume: - labels = np.concatenate(dataset.labels, 0) - # c = torch.tensor(labels[:, 0]) # classes - # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency - # model._initialize_biases(cf.to(device)) if plots: plot_labels(labels, names, save_dir) @@ -263,6 +242,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio # Start training t0 = time.time() + nb = len(train_loader) # number of batches nw = max(round(hyp['warmup_epochs'] * nb), 100) # number of warmup iterations, max(3 epochs, 100 iterations) # nw = min(nw, (epochs - start_epoch) / 2 * nb) # limit warmup to < 1/2 of training last_opt_step = -1 @@ -510,7 +490,6 @@ def main(opt, callbacks=Callbacks()): with open(Path(ckpt).parent.parent / 'opt.yaml', errors='ignore') as f: opt = argparse.Namespace(**yaml.safe_load(f)) # replace opt.cfg, opt.weights, opt.resume = '', ckpt, True # reinstate - LOGGER.info(f'Resuming training from {ckpt}') else: opt.data, opt.cfg, opt.hyp, opt.weights, opt.project = \ check_file(opt.data), check_yaml(opt.cfg), check_yaml(opt.hyp), str(opt.weights), str(opt.project) # checks diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 5f2a22c36f1a..391ddead2985 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -306,6 +306,25 @@ def smart_optimizer(model, name='Adam', lr=0.001, momentum=0.9, weight_decay=1e- return optimizer +def smart_resume(ckpt, optimizer, ema=None, weights='yolov5s.pt', epochs=300, resume=True): + # Resume training from a partially trained checkpoint + best_fitness = 0.0 + start_epoch = ckpt['epoch'] + 1 + if ckpt['optimizer'] is not None: + optimizer.load_state_dict(ckpt['optimizer']) # optimizer + best_fitness = ckpt['best_fitness'] + if ema and ckpt.get('ema'): + ema.ema.load_state_dict(ckpt['ema'].float().state_dict()) # EMA + ema.updates = ckpt['updates'] + if resume: + assert start_epoch > 0, f'{weights} training to {epochs} epochs is finished, nothing to resume.' + LOGGER.info(f'Resuming training from {weights} for {epochs - start_epoch} more epochs to {epochs} total epochs') + if epochs < start_epoch: + LOGGER.info(f"{weights} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {epochs} more epochs.") + epochs += ckpt['epoch'] # finetune additional epochs + return best_fitness, start_epoch, epochs + + class EarlyStopping: # YOLOv5 simple early stopper def __init__(self, patience=30):