EMA bug fix 2 (#2330)
* EMA bug fix 2

* update
glenn-jocher authored Mar 2, 2021
1 parent fd96810 commit fab5085
Showing 4 changed files with 13 additions and 10 deletions.
2 changes: 1 addition & 1 deletion hubconf.py
@@ -120,7 +120,7 @@ def custom(path_or_model='path/to/model.pt', autoshape=True):
     """
     model = torch.load(path_or_model) if isinstance(path_or_model, str) else path_or_model  # load checkpoint
     if isinstance(model, dict):
-        model = model['model']  # load model
+        model = model['ema' if model.get('ema') else 'model']  # load model

     hub_model = Model(model.yaml).to(next(model.parameters()).device)  # create
     hub_model.load_state_dict(model.float().state_dict())  # load state_dict
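In practice this means a full training checkpoint passed through the hub loader now yields the smoothed EMA weights whenever they were saved. A minimal usage sketch, assuming the standard torch.hub entrypoint for this repo (the checkpoint path is illustrative):

import torch

# illustrative local checkpoint path; 'custom' is the hubconf.py entrypoint patched above
model = torch.hub.load('ultralytics/yolov5', 'custom', path_or_model='runs/train/exp/weights/best.pt')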
3 changes: 2 additions & 1 deletion models/experimental.py
@@ -115,7 +115,8 @@ def attempt_load(weights, map_location=None):
     model = Ensemble()
     for w in weights if isinstance(weights, list) else [weights]:
         attempt_download(w)
-        model.append(torch.load(w, map_location=map_location)['model'].float().fuse().eval())  # load FP32 model
+        ckpt = torch.load(w, map_location=map_location)  # load
+        model.append(ckpt['ema' if ckpt.get('ema') else 'model'].float().fuse().eval())  # FP32 model

     # Compatibility updates
     for m in model.modules():
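attempt_load() is the loader used by the repo's inference scripts, so detection and testing now pick up the EMA weights automatically as well. A minimal usage sketch (the weights path is illustrative):

from models.experimental import attempt_load

# illustrative path; returns a fused FP32 model in eval mode (an Ensemble for a list of weights)
model = attempt_load('runs/train/exp/weights/best.pt', map_location='cpu')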
10 changes: 5 additions & 5 deletions train.py
@@ -151,8 +151,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):

         # EMA
         if ema and ckpt.get('ema'):
-            ema.ema.load_state_dict(ckpt['ema'][0].float().state_dict())
-            ema.updates = ckpt['ema'][1]
+            ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
+            ema.updates = ckpt['updates']

         # Results
         if ckpt.get('training_results') is not None:
@@ -383,9 +383,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                 ckpt = {'epoch': epoch,
                         'best_fitness': best_fitness,
                         'training_results': results_file.read_text(),
-                        'model': ema.ema if final_epoch else deepcopy(
-                            model.module if is_parallel(model) else model).half(),
-                        'ema': (deepcopy(ema.ema).half(), ema.updates),
+                        'model': deepcopy(model.module if is_parallel(model) else model).half(),
+                        'ema': deepcopy(ema.ema).half(),
+                        'updates': ema.updates,
                         'optimizer': optimizer.state_dict(),
                         'wandb_id': wandb_run.id if wandb else None}

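The checkpoint schema change is the heart of the fix: 'ema' was previously a (model, updates) tuple, and 'model' held the EMA only on the final epoch; now 'model' always holds the raw model, 'ema' holds the EMA model alone, and the update counter moves to a new top-level 'updates' key. A quick way to tell which format a saved file uses (the path is illustrative):

import torch

ckpt = torch.load('runs/train/exp/weights/last.pt', map_location='cpu')  # illustrative path
print(type(ckpt['ema']))    # tuple in the old format; the EMA model itself in the new format
print(ckpt.get('updates'))  # int in the new format; None for old-format checkpoints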
8 changes: 5 additions & 3 deletions utils/general.py
@@ -481,18 +481,20 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None
     return output


-def strip_optimizer(f='weights/best.pt', s=''):  # from utils.general import *; strip_optimizer()
+def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
     # Strip optimizer from 'f' to finalize training, optionally save as 's'
     x = torch.load(f, map_location=torch.device('cpu'))
-    for k in 'optimizer', 'training_results', 'wandb_id', 'ema':  # keys
+    if x.get('ema'):
+        x['model'] = x['ema']  # replace model with ema
+    for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates':  # keys
         x[k] = None
     x['epoch'] = -1
     x['model'].half()  # to FP16
     for p in x['model'].parameters():
         p.requires_grad = False
     torch.save(x, s or f)
     mb = os.path.getsize(s or f) / 1E6  # filesize
-    print('Optimizer stripped from %s,%s %.1fMB' % (f, (' saved as %s,' % s) if s else '', mb))
+    print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")


 def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):
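With this change, stripping a checkpoint first promotes the EMA weights into the 'model' slot before nulling the training-only keys, so a stripped best.pt carries the smoothed weights. A minimal usage sketch (the path is illustrative):

from utils.general import strip_optimizer

# illustrative path; overwrites the file in place (pass s='other.pt' to save a copy instead)
strip_optimizer('runs/train/exp/weights/best.pt')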

4 comments on commit fab5085

@train255 (Contributor)

@glenn-jocher I get an error with the latest commit:

python train.py --img 640 --batch 16 --epochs 100 --data /content/config.yaml --cfg ./models/yolov5x.yaml --weights '/content/yolov5/runs/train/exp3/weights/best.pt'

Log

Transferred 792/794 items from /content/yolov5/runs/train/exp3/weights/best.pt
Scaled weight_decay = 0.0005
Optimizer groups: 134 .bias, 134 conv.weight, 131 other
Traceback (most recent call last):
  File "train.py", line 532, in <module>
    train(hyp, opt, device, tb_writer, wandb)
  File "train.py", line 154, in train
    ema.ema.load_state_dict(ckpt['ema'].float().state_dict())
AttributeError: 'tuple' object has no attribute 'float'

@glenn-jocher (Member, Author)

@train255 this might be caused by updating the code in the middle of a training run and then resuming. In that case you may need to restart your training, or git revert back one commit and --resume, then update after training completes.
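For context, the failure above comes from resuming an old-format checkpoint (where ckpt['ema'] is a tuple) with the new loading code. A hedged sketch of a shim that would accept both layouts, not part of this commit, with names following train.py above:

# illustrative compatibility shim, not in this commit
if ema and ckpt.get('ema'):
    if isinstance(ckpt['ema'], tuple):  # old format: ckpt['ema'] == (ema_model, n_updates)
        ema_model, ema.updates = ckpt['ema']
    else:  # new format: separate 'ema' and 'updates' keys
        ema_model, ema.updates = ckpt['ema'], ckpt['updates']
    ema.ema.load_state_dict(ema_model.float().state_dict())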

@ggyybb commented on fab5085 Mar 4, 2021

> @train255 this might be caused by updating the code in the middle of a training run and then resuming. In that case you may need to restart your training, or git revert back one commit and --resume, then update after training completes.

Oh my god, good work! I trained my model for four days, and today I found a problem in best.pt: it had not saved the EMA model. Now I will pull the latest code and train my model again, 555555555 (crying).

@glenn-jocher (Member, Author)

@ggyybb sorry about that! We try to make all our PRs robust to starting/stopping/resuming trainings, but in this case we were not able to do that.

Everything is working correctly now, but in general it's probably best to wait until training runs are complete before doing a git pull.
