From 7fa039b03e3ed2ef015825eea89172cba11e7f19 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Mon, 1 Mar 2021 16:51:20 -0800
Subject: [PATCH 1/2] EMA bug fix 2

---
 hubconf.py             | 2 +-
 models/experimental.py | 3 ++-
 train.py               | 3 +--
 utils/general.py       | 6 ++++--
 4 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/hubconf.py b/hubconf.py
index 47eee4477725..a8eb51681794 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -120,7 +120,7 @@ def custom(path_or_model='path/to/model.pt', autoshape=True):
     """
     model = torch.load(path_or_model) if isinstance(path_or_model, str) else path_or_model  # load checkpoint
     if isinstance(model, dict):
-        model = model['model']  # load model
+        model = model['ema' if model.get('ema') else 'model']  # load model
 
     hub_model = Model(model.yaml).to(next(model.parameters()).device)  # create
     hub_model.load_state_dict(model.float().state_dict())  # load state_dict
diff --git a/models/experimental.py b/models/experimental.py
index 5fe56858c54a..d79052314f9b 100644
--- a/models/experimental.py
+++ b/models/experimental.py
@@ -115,7 +115,8 @@ def attempt_load(weights, map_location=None):
     model = Ensemble()
     for w in weights if isinstance(weights, list) else [weights]:
         attempt_download(w)
-        model.append(torch.load(w, map_location=map_location)['model'].float().fuse().eval())  # load FP32 model
+        ckpt = torch.load(w, map_location=map_location)  # load
+        model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval())  # FP32 model
 
     # Compatibility updates
     for m in model.modules():
diff --git a/train.py b/train.py
index 5c203f12651d..bbf879f3af5f 100644
--- a/train.py
+++ b/train.py
@@ -383,8 +383,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
             ckpt = {'epoch': epoch,
                     'best_fitness': best_fitness,
                     'training_results': results_file.read_text(),
-                    'model': ema.ema if final_epoch else deepcopy(
-                        model.module if is_parallel(model) else model).half(),
+                    'model': deepcopy(model.module if is_parallel(model) else model).half(),
                     'ema': (deepcopy(ema.ema).half(), ema.updates),
                     'optimizer': optimizer.state_dict(),
                     'wandb_id': wandb_run.id if wandb else None}
diff --git a/utils/general.py b/utils/general.py
index e5bbc50c6177..72b8359e6627 100755
--- a/utils/general.py
+++ b/utils/general.py
@@ -481,9 +481,11 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
     return output
 
 
-def strip_optimizer(f='weights/best.pt', s=''):  # from utils.general import *; strip_optimizer()
+def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
     # Strip optimizer from 'f' to finalize training, optionally save as 's'
     x = torch.load(f, map_location=torch.device('cpu'))
+    if x.get('ema'):
+        x['model'] = x['ema']  # replace model with ema
     for k in 'optimizer', 'training_results', 'wandb_id', 'ema':  # keys
         x[k] = None
     x['epoch'] = -1
@@ -492,7 +494,7 @@ def strip_optimizer(f='best.pt', s=''):  # from utils.general import *;
         p.requires_grad = False
     torch.save(x, s or f)
     mb = os.path.getsize(s or f) / 1E6  # filesize
-    print('Optimizer stripped from %s,%s %.1fMB' % (f, (' saved as %s,' % s) if s else '', mb))
+    print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
 
 
 def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):
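
A minimal sketch (hypothetical helper, not YOLOv5 code) of the fallback this
patch introduces in hubconf.py and models/experimental.py: prefer the EMA
weights when a checkpoint carries them, otherwise fall back to the raw model.

def select_weights(ckpt):
    # dict.get('ema') is falsy both when the key is absent (older checkpoints)
    # and when it is present but None (checkpoints finalized by strip_optimizer),
    # so either case falls back to the raw 'model' entry
    return ckpt['ema' if ckpt.get('ema') else 'model']

assert select_weights({'model': 'raw'}) == 'raw'                          # old schema
assert select_weights({'model': 'raw', 'ema': None}) == 'raw'             # stripped
assert select_weights({'model': 'raw', 'ema': 'averaged'}) == 'averaged'  # new schema
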
From 048a5ef3b21bc5573ce94fb22a219efc306a2bf6 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Mon, 1 Mar 2021 17:06:22 -0800
Subject: [PATCH 2/2] update

---
 train.py         | 7 ++++---
 utils/general.py | 2 +-
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/train.py b/train.py
index bbf879f3af5f..e2c82339f7fe 100644
--- a/train.py
+++ b/train.py
@@ -151,8 +151,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
 
         # EMA
         if ema and ckpt.get('ema'):
-            ema.ema.load_state_dict(ckpt['ema'][0].float().state_dict())
-            ema.updates = ckpt['ema'][1]
+            ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
+            ema.updates = ckpt['updates']
 
         # Results
         if ckpt.get('training_results') is not None:
@@ -384,7 +384,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                     'best_fitness': best_fitness,
                     'training_results': results_file.read_text(),
                     'model': deepcopy(model.module if is_parallel(model) else model).half(),
-                    'ema': (deepcopy(ema.ema).half(), ema.updates),
+                    'ema': deepcopy(ema.ema).half(),
+                    'updates': ema.updates,
                     'optimizer': optimizer.state_dict(),
                     'wandb_id': wandb_run.id if wandb else None}
diff --git a/utils/general.py b/utils/general.py
index 72b8359e6627..df8cf7bab60d 100755
--- a/utils/general.py
+++ b/utils/general.py
@@ -486,7 +486,7 @@ def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_op
     x = torch.load(f, map_location=torch.device('cpu'))
     if x.get('ema'):
         x['model'] = x['ema']  # replace model with ema
-    for k in 'optimizer', 'training_results', 'wandb_id', 'ema':  # keys
+    for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates':  # keys
         x[k] = None
     x['epoch'] = -1
     x['model'].half()  # to FP16
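
Taken together, the two patches settle on a flat checkpoint schema: 'ema'
holds the EMA module itself and the update counter moves into its own
'updates' key, so the resume path can call ckpt['ema'].float() directly
instead of unpacking the (ema, updates) tuple saved by the first patch.
A minimal end-to-end sketch (stand-in model, not YOLOv5 code):

from copy import deepcopy

from torch import nn

model = nn.Linear(4, 2)  # stand-in for the detection model
ema = deepcopy(model)    # stand-in for ModelEMA(model).ema

# Schema written by train.py after these patches
ckpt = {'epoch': 0,
        'model': deepcopy(model).half(),  # raw weights, FP16
        'ema': deepcopy(ema).half(),      # EMA module, no longer a tuple
        'updates': 123,                   # EMA update count, now its own key
        'optimizer': None,
        'training_results': '',
        'wandb_id': None}

# Resume path, mirroring the train.py hunk above
if ckpt.get('ema'):
    ema.load_state_dict(ckpt['ema'].float().state_dict())
    updates = ckpt['updates']

# Finalization, mirroring strip_optimizer(): publish the EMA as 'model',
# then null out the training-only keys
if ckpt.get('ema'):
    ckpt['model'] = ckpt['ema']
for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates':
    ckpt[k] = None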