From 199b099ac945a426686b634dd967b0dbbf025015 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 6 Feb 2021 10:23:11 -0800 Subject: [PATCH 1/2] Update train.py --- train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 4ec97ae71e16..4cbd022bd231 100644 --- a/train.py +++ b/train.py @@ -190,7 +190,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): # Process 0 if rank in [-1, 0]: ema.updates = start_epoch * nb // accumulate # set EMA updates - testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt, # testloader + testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, opt, # testloader hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1, world_size=opt.world_size, workers=opt.workers, pad=0.5, prefix=colorstr('val: '))[0] @@ -338,7 +338,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): final_epoch = epoch + 1 == epochs if not opt.notest or final_epoch: # Calculate mAP results, maps, times = test.test(opt.data, - batch_size=total_batch_size, + batch_size=batch_size * 2, imgsz=imgsz_test, model=ema.ema, single_cls=opt.single_cls, From 89b31dc3587f3bb9634b3b4073026972140e9211 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 6 Feb 2021 10:26:07 -0800 Subject: [PATCH 2/2] Update loss.py --- utils/loss.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/utils/loss.py b/utils/loss.py index 889ddf7295da..2490d4bb7cfc 100644 --- a/utils/loss.py +++ b/utils/loss.py @@ -105,8 +105,7 @@ def __init__(self, model, autobalance=False): BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g) det = model.module.model[-1] if is_parallel(model) else model.model[-1] # Detect() module - self.balance = {3: [3.67, 1.0, 0.43], 4: [3.78, 1.0, 0.39, 0.22], 5: [3.88, 1.0, 0.37, 0.17, 0.10]}[det.nl] - # self.balance = [1.0] * det.nl + self.balance = {3: [3.67, 1.0, 0.43], 4: [4.0, 1.0, 0.25, 0.06], 5: [4.0, 1.0, 0.25, 0.06, .02]}[det.nl] self.ssi = (det.stride == 16).nonzero(as_tuple=False).item() # stride 16 index self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, model.gr, h, autobalance for k in 'na', 'nc', 'nl', 'anchors':