diff --git a/train.py b/train.py index 8234025212b8..14d7ac8fc027 100644 --- a/train.py +++ b/train.py @@ -472,9 +472,10 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): if opt.resume: # resume an interrupted run ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist' + apriori = opt.global_rank, opt.local_rank with open(Path(ckpt).parent.parent / 'opt.yaml') as f: opt = argparse.Namespace(**yaml.load(f, Loader=yaml.FullLoader)) # replace - opt.cfg, opt.weights, opt.resume = '', ckpt, True + opt.cfg, opt.weights, opt.resume, opt.global_rank, opt.local_rank = '', ckpt, True, *apriori # reinstate logger.info('Resuming training from %s' % ckpt) else: # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')