From 7ede3a4923be4540461ccf22be0864353d4ff0ae Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 20 Feb 2021 11:09:06 -0800 Subject: [PATCH] Update loss.py --- utils/loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils/loss.py b/utils/loss.py index 481d25e207f2..2302d18de87d 100644 --- a/utils/loss.py +++ b/utils/loss.py @@ -105,8 +105,8 @@ def __init__(self, model, autobalance=False): BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g) det = model.module.model[-1] if is_parallel(model) else model.model[-1] # Detect() module - self.balance = {3: [4.0, 1.0, 0.4], 4: [4.0, 1.0, 0.25, 0.06], 5: [4.0, 1.0, 0.25, 0.06, .02]}[det.nl] - self.ssi = (det.stride == 16).nonzero(as_tuple=False).item() # stride 16 index + self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, .02]) # P3-P7 + self.ssi = list(det.stride).index(16) if autobalance else 0 # stride 16 index self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, model.gr, h, autobalance for k in 'na', 'nc', 'nl', 'anchors': setattr(self, k, getattr(det, k))