Skip to content

Commit

Permalink
Update train.py test batch_size (ultralytics#2148)
Browse files Browse the repository at this point in the history
* Update train.py

* Update loss.py
  • Loading branch information
glenn-jocher committed Feb 6, 2021
1 parent e69f773 commit 8b768a9
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
4 changes: 2 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
# Process 0
if rank in [-1, 0]:
ema.updates = start_epoch * nb // accumulate # set EMA updates
testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt, # testloader
testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, opt, # testloader
hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
world_size=opt.world_size, workers=opt.workers,
pad=0.5, prefix=colorstr('val: '))[0]
Expand Down Expand Up @@ -338,7 +338,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
final_epoch = epoch + 1 == epochs
if not opt.notest or final_epoch: # Calculate mAP
results, maps, times = test.test(opt.data,
batch_size=total_batch_size,
batch_size=batch_size * 2,
imgsz=imgsz_test,
model=ema.ema,
single_cls=opt.single_cls,
Expand Down
3 changes: 1 addition & 2 deletions utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,7 @@ def __init__(self, model, autobalance=False):
BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)

det = model.module.model[-1] if is_parallel(model) else model.model[-1] # Detect() module
self.balance = {3: [3.67, 1.0, 0.43], 4: [3.78, 1.0, 0.39, 0.22], 5: [3.88, 1.0, 0.37, 0.17, 0.10]}[det.nl]
# self.balance = [1.0] * det.nl
self.balance = {3: [3.67, 1.0, 0.43], 4: [4.0, 1.0, 0.25, 0.06], 5: [4.0, 1.0, 0.25, 0.06, .02]}[det.nl]
self.ssi = (det.stride == 16).nonzero(as_tuple=False).item() # stride 16 index
self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, model.gr, h, autobalance
for k in 'na', 'nc', 'nl', 'anchors':
Expand Down

0 comments on commit 8b768a9

Please sign in to comment.