Skip to content

Commit

Permalink
EMA bug fix 2 (ultralytics#2330)
Browse files Browse the repository at this point in the history
* EMA bug fix 2

* update
  • Loading branch information
glenn-jocher committed Mar 2, 2021
1 parent 497f7e6 commit 442edb4
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 10 deletions.
2 changes: 1 addition & 1 deletion hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def custom(path_or_model='path/to/model.pt', autoshape=True):
"""
model = torch.load(path_or_model) if isinstance(path_or_model, str) else path_or_model # load checkpoint
if isinstance(model, dict):
model = model['model'] # load model
model = model['ema' if model.get('ema') else 'model'] # load model

hub_model = Model(model.yaml).to(next(model.parameters()).device) # create
hub_model.load_state_dict(model.float().state_dict()) # load state_dict
Expand Down
3 changes: 2 additions & 1 deletion models/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ def attempt_load(weights, map_location=None):
model = Ensemble()
for w in weights if isinstance(weights, list) else [weights]:
attempt_download(w)
model.append(torch.load(w, map_location=map_location)['model'].float().fuse().eval()) # load FP32 model
ckpt = torch.load(w, map_location=map_location) # load
model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval()) # FP32 model

# Compatibility updates
for m in model.modules():
Expand Down
10 changes: 5 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):

# EMA
if ema and ckpt.get('ema'):
ema.ema.load_state_dict(ckpt['ema'][0].float().state_dict())
ema.updates = ckpt['ema'][1]
ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
ema.updates = ckpt['updates']

# Results
if ckpt.get('training_results') is not None:
Expand Down Expand Up @@ -383,9 +383,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
ckpt = {'epoch': epoch,
'best_fitness': best_fitness,
'training_results': results_file.read_text(),
'model': ema.ema if final_epoch else deepcopy(
model.module if is_parallel(model) else model).half(),
'ema': (deepcopy(ema.ema).half(), ema.updates),
'model': deepcopy(model.module if is_parallel(model) else model).half(),
'ema': deepcopy(ema.ema).half(),
'updates': ema.updates,
'optimizer': optimizer.state_dict(),
'wandb_id': wandb_run.id if wandb else None}

Expand Down
8 changes: 5 additions & 3 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,18 +481,20 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
return output


def strip_optimizer(f='weights/best.pt', s=''): # from utils.general import *; strip_optimizer()
def strip_optimizer(f='best.pt', s=''): # from utils.general import *; strip_optimizer()
# Strip optimizer from 'f' to finalize training, optionally save as 's'
x = torch.load(f, map_location=torch.device('cpu'))
for k in 'optimizer', 'training_results', 'wandb_id', 'ema': # keys
if x.get('ema'):
x['model'] = x['ema'] # replace model with ema
for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates': # keys
x[k] = None
x['epoch'] = -1
x['model'].half() # to FP16
for p in x['model'].parameters():
p.requires_grad = False
torch.save(x, s or f)
mb = os.path.getsize(s or f) / 1E6 # filesize
print('Optimizer stripped from %s,%s %.1fMB' % (f, (' saved as %s,' % s) if s else '', mb))
print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")


def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):
Expand Down

0 comments on commit 442edb4

Please sign in to comment.