From b757bb5fa45df7d856fc25753f2a59a5a3ce8bb0 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Wed, 24 Feb 2021 18:26:46 -0800
Subject: [PATCH] Improved model+EMA checkpointing (#2292)

* Enhanced model+EMA checkpointing

* update

* bug fix

* bug fix 2

* always save optimizer

* ema half

* remove model.float()

* model half

* carry ema/model in fp32

* rm model.float()

* both to float always

* cleanup

* cleanup
---
 test.py          |  1 -
 train.py         | 25 ++++++++++++++++---------
 utils/general.py |  4 ++--
 3 files changed, 18 insertions(+), 12 deletions(-)

diff --git a/test.py b/test.py
index ecd45f5f4943..9f484c809052 100644
--- a/test.py
+++ b/test.py
@@ -272,7 +272,6 @@ def test(data,
     if not training:
         s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
         print(f"Results saved to {save_dir}{s}")
-        model.float()  # for training
     maps = np.zeros(nc) + map
     for i, c in enumerate(ap_class):
         maps[c] = ap[i]

diff --git a/train.py b/train.py
index 8533667fe57f..7aa57fa99e24 100644
--- a/train.py
+++ b/train.py
@@ -31,7 +31,7 @@
 from utils.google_utils import attempt_download
 from utils.loss import ComputeLoss
 from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
-from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first
+from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel
 
 logger = logging.getLogger(__name__)
 
@@ -136,6 +136,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                                id=ckpt.get('wandb_id') if 'ckpt' in locals() else None)
         loggers = {'wandb': wandb}  # loggers dict
 
+    # EMA
+    ema = ModelEMA(model) if rank in [-1, 0] else None
+
     # Resume
     start_epoch, best_fitness = 0, 0.0
     if pretrained:
@@ -144,6 +147,11 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
             optimizer.load_state_dict(ckpt['optimizer'])
             best_fitness = ckpt['best_fitness']
 
+        # EMA
+        if ema and ckpt.get('ema'):
+            ema.ema.load_state_dict(ckpt['ema'][0].float().state_dict())
+            ema.updates = ckpt['ema'][1]
+
         # Results
         if ckpt.get('training_results') is not None:
             results_file.write_text(ckpt['training_results'])  # write results.txt
@@ -173,9 +181,6 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
         logger.info('Using SyncBatchNorm()')
 
-    # EMA
-    ema = ModelEMA(model) if rank in [-1, 0] else None
-
     # DDP mode
     if cuda and rank != -1:
         model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank)
@@ -191,7 +196,6 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
 
     # Process 0
     if rank in [-1, 0]:
-        ema.updates = start_epoch * nb // accumulate  # set EMA updates
         testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, opt,  # testloader
                                        hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
                                        world_size=opt.world_size, workers=opt.workers,
@@ -335,8 +339,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
         # DDP process 0 or single-GPU
         if rank in [-1, 0]:
             # mAP
-            if ema:
-                ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
+            ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
             final_epoch = epoch + 1 == epochs
             if not opt.notest or final_epoch:  # Calculate mAP
                 results, maps, times = test.test(opt.data,
@@ -378,8 +381,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                 ckpt = {'epoch': epoch,
                         'best_fitness': best_fitness,
                         'training_results': results_file.read_text(),
-                        'model': ema.ema,
-                        'optimizer': None if final_epoch else optimizer.state_dict(),
+                        'model': (model.module if is_parallel(model) else model).half(),
+                        'ema': (ema.ema.half(), ema.updates),
+                        'optimizer': optimizer.state_dict(),
                         'wandb_id': wandb_run.id if wandb else None}
 
                 # Save last, best and delete
@@ -387,6 +391,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                 if best_fitness == fi:
                     torch.save(ckpt, best)
                 del ckpt
+
+                model.float(), ema.ema.float()
+
         # end epoch ----------------------------------------------------------------------------------------------------
     # end training
 
diff --git a/utils/general.py b/utils/general.py
index 3b5f4629b00a..e5bbc50c6177 100755
--- a/utils/general.py
+++ b/utils/general.py
@@ -484,8 +484,8 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
 def strip_optimizer(f='weights/best.pt', s=''):  # from utils.general import *; strip_optimizer()
     # Strip optimizer from 'f' to finalize training, optionally save as 's'
     x = torch.load(f, map_location=torch.device('cpu'))
-    for key in 'optimizer', 'training_results', 'wandb_id':
-        x[key] = None
+    for k in 'optimizer', 'training_results', 'wandb_id', 'ema':  # keys
+        x[k] = None
     x['epoch'] = -1
     x['model'].half()  # to FP16
     for p in x['model'].parameters():
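
For reference, the save/resume flow introduced by this patch can be condensed into the following sketch. It is not part of the diff: 'last.pt' is a placeholder path, and model, ema (a ModelEMA wrapper), optimizer, epoch, best_fitness and results_file are assumed to already exist as they do inside train() in train.py.

# Sketch (not part of the patch) of the checkpoint layout and EMA resume logic above.
import torch
from utils.torch_utils import is_parallel

# Save: model and EMA weights are cast to FP16 to shrink the checkpoint;
# the optimizer state is now always kept so training can resume exactly.
ckpt = {'epoch': epoch,
        'best_fitness': best_fitness,
        'training_results': results_file.read_text(),
        'model': (model.module if is_parallel(model) else model).half(),
        'ema': (ema.ema.half(), ema.updates),
        'optimizer': optimizer.state_dict()}
torch.save(ckpt, 'last.pt')
model.float(), ema.ema.float()  # cast back to FP32 before the next epoch

# Resume: EMA weights are converted back to FP32 before loading, mirroring
# the ckpt.get('ema') branch added to train().
ckpt = torch.load('last.pt', map_location='cpu')
if ema and ckpt.get('ema'):
    ema.ema.load_state_dict(ckpt['ema'][0].float().state_dict())
    ema.updates = ckpt['ema'][1]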