diff --git a/utils/general.py b/utils/general.py index ca1225a09e48..45cc6103ab39 100755 --- a/utils/general.py +++ b/utils/general.py @@ -496,8 +496,7 @@ def compute_loss(p, targets, model): # predictions, targets, model s = 3 / np # output count scaling lbox *= h['giou'] * s lobj *= h['obj'] * s * (1.4 if np == 4 else 1.) - if model.nc > 1: - lcls *= h['cls'] * s + lcls *= h['cls'] * s bs = tobj.shape[0] # batch size loss = lbox + lobj + lcls @@ -524,7 +523,7 @@ def build_targets(p, targets, model): gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain # Match targets to anchors - t, offsets = targets * gain, 0 + t = targets * gain if nt: # Matches r = t[:, :, 4:6] / anchors[:, None] # wh ratio @@ -540,6 +539,9 @@ def build_targets(p, targets, model): j = torch.stack((torch.ones_like(j), j, k, l, m)) t = t.repeat((5, 1, 1))[j] offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j] + else: + t = targets[0] + offsets = 0 # Define b, c = t[:, :2].long().T # image, class