From e8d79543ab941d9e5819c18dcf60dffb8245b87e Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Wed, 3 Mar 2021 20:19:13 -0800
Subject: [PATCH 1/2] Resume with custom anchors fix

---
 train.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/train.py b/train.py
index 1b8b315ce927..84ce0e9f740e 100644
--- a/train.py
+++ b/train.py
@@ -75,10 +75,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
         with torch_distributed_zero_first(rank):
             attempt_download(weights)  # download if not found locally
         ckpt = torch.load(weights, map_location=device)  # load checkpoint
-        if hyp.get('anchors'):
-            ckpt['model'].yaml['anchors'] = round(hyp['anchors'])  # force autoanchor
-        model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc).to(device)  # create
-        exclude = ['anchor'] if opt.cfg or hyp.get('anchors') else []  # exclude keys
+        model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
+        exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else []  # exclude keys
         state_dict = ckpt['model'].float().state_dict()  # to FP32
         state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude)  # intersect
         model.load_state_dict(state_dict, strict=False)  # load

From 5f9d6a618bd643863f170a39c36353623a3c4885 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Wed, 3 Mar 2021 21:02:26 -0800
Subject: [PATCH 2/2] Update train.py

---
 train.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/train.py b/train.py
index 84ce0e9f740e..ecac59857ccc 100644
--- a/train.py
+++ b/train.py
@@ -214,6 +214,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
         # Anchors
         if not opt.noautoanchor:
             check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
+        model.half().float()  # pre-reduce anchor precision

     # Model parameters
     hyp['box'] *= 3. / nl  # scale to layers
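
Note (not part of the patches above): a minimal sketch of the effect of the second commit's `model.half().float()`. Anchors recomputed by `check_anchors` in FP32 get truncated when the checkpoint is later saved in half precision, so a resumed run would otherwise see slightly different anchor values than the run that produced them; casting to FP16 and back once, before the first save, absorbs that rounding up front. The anchor values below are hypothetical; only `torch` is assumed.

    import torch

    # Hypothetical FP32 anchor widths, as autoanchor might produce them
    anchors = torch.tensor([10.1234567, 16.7654321, 33.0000019])

    # Same cast chain as model.half().float() in the patch:
    # FP16 discards low-order mantissa bits, .float() restores the dtype
    rounded = anchors.half().float()

    print(anchors - rounded)  # small nonzero deltas, absorbed before the first checkpoint save

The first commit serves the same resume-consistency goal from the other direction: the anchor override moves into the `Model` constructor via its `anchors` argument, and the `and not opt.resume` guard keeps checkpoint anchors out of the `exclude` list on resume, so they are loaded back rather than reinitialized.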