Skip to content

Commit

Permalink
Improved model+EMA checkpointing (ultralytics#2292)
Browse files Browse the repository at this point in the history
* Enhanced model+EMA checkpointing

* update

* bug fix

* bug fix 2

* always save optimizer

* ema half

* remove model.float()

* model half

* carry ema/model in fp32

* rm model.float()

* both to float always

* cleanup

* cleanup
  • Loading branch information
glenn-jocher committed Feb 25, 2021
1 parent 5437c78 commit 4d9a3dd
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 12 deletions.
1 change: 0 additions & 1 deletion test.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,6 @@ def test(data,
if not training:
s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
print(f"Results saved to {save_dir}{s}")
model.float() # for training
maps = np.zeros(nc) + map
for i, c in enumerate(ap_class):
maps[c] = ap[i]
Expand Down
25 changes: 16 additions & 9 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from utils.google_utils import attempt_download
from utils.loss import ComputeLoss
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -136,6 +136,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
id=ckpt.get('wandb_id') if 'ckpt' in locals() else None)
loggers = {'wandb': wandb} # loggers dict

# EMA
ema = ModelEMA(model) if rank in [-1, 0] else None

# Resume
start_epoch, best_fitness = 0, 0.0
if pretrained:
Expand All @@ -144,6 +147,11 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
optimizer.load_state_dict(ckpt['optimizer'])
best_fitness = ckpt['best_fitness']

# EMA
if ema and ckpt.get('ema'):
ema.ema.load_state_dict(ckpt['ema'][0].float().state_dict())
ema.updates = ckpt['ema'][1]

# Results
if ckpt.get('training_results') is not None:
results_file.write_text(ckpt['training_results']) # write results.txt
Expand Down Expand Up @@ -173,9 +181,6 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
logger.info('Using SyncBatchNorm()')

# EMA
ema = ModelEMA(model) if rank in [-1, 0] else None

# DDP mode
if cuda and rank != -1:
model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank)
Expand All @@ -191,7 +196,6 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):

# Process 0
if rank in [-1, 0]:
ema.updates = start_epoch * nb // accumulate # set EMA updates
testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, opt, # testloader
hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
world_size=opt.world_size, workers=opt.workers,
Expand Down Expand Up @@ -335,8 +339,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
# DDP process 0 or single-GPU
if rank in [-1, 0]:
# mAP
if ema:
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
final_epoch = epoch + 1 == epochs
if not opt.notest or final_epoch: # Calculate mAP
results, maps, times = test.test(opt.data,
Expand Down Expand Up @@ -378,15 +381,19 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
ckpt = {'epoch': epoch,
'best_fitness': best_fitness,
'training_results': results_file.read_text(),
'model': ema.ema,
'optimizer': None if final_epoch else optimizer.state_dict(),
'model': (model.module if is_parallel(model) else model).half(),
'ema': (ema.ema.half(), ema.updates),
'optimizer': optimizer.state_dict(),
'wandb_id': wandb_run.id if wandb else None}

# Save last, best and delete
torch.save(ckpt, last)
if best_fitness == fi:
torch.save(ckpt, best)
del ckpt

model.float(), ema.ema.float()

# end epoch ----------------------------------------------------------------------------------------------------
# end training

Expand Down
4 changes: 2 additions & 2 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,8 +484,8 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
def strip_optimizer(f='weights/best.pt', s=''): # from utils.general import *; strip_optimizer()
# Strip optimizer from 'f' to finalize training, optionally save as 's'
x = torch.load(f, map_location=torch.device('cpu'))
for key in 'optimizer', 'training_results', 'wandb_id':
x[key] = None
for k in 'optimizer', 'training_results', 'wandb_id', 'ema': # keys
x[k] = None
x['epoch'] = -1
x['model'].half() # to FP16
for p in x['model'].parameters():
Expand Down

0 comments on commit 4d9a3dd

Please sign in to comment.