diff --git a/utils/loss.py b/utils/loss.py index 194c8e503e0e..5aa9f017d2af 100644 --- a/utils/loss.py +++ b/utils/loss.py @@ -7,7 +7,7 @@ import torch.nn as nn from utils.metrics import bbox_iou -from utils.torch_utils import is_parallel +from utils.torch_utils import de_parallel def smooth_BCE(eps=0.1): # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441 @@ -107,7 +107,7 @@ def __init__(self, model, autobalance=False): if g > 0: BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g) - det = model.module.model[-1] if is_parallel(model) else model.model[-1] # Detect() module + det = de_parallel(model).model[-1] # Detect() module self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, 0.02]) # P3-P7 self.ssi = list(det.stride).index(16) if autobalance else 0 # stride 16 index self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance diff --git a/utils/torch_utils.py b/utils/torch_utils.py index 7e464190f9ba..2a45f434c6a5 100644 --- a/utils/torch_utils.py +++ b/utils/torch_utils.py @@ -295,7 +295,7 @@ class ModelEMA: def __init__(self, model, decay=0.9999, updates=0): # Create EMA - self.ema = deepcopy(model.module if is_parallel(model) else model).eval() # FP32 EMA + self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA # if next(model.parameters()).device.type != 'cpu': # self.ema.half() # FP16 EMA self.updates = updates # number of EMA updates @@ -309,7 +309,7 @@ def update(self, model): self.updates += 1 d = self.decay(self.updates) - msd = model.module.state_dict() if is_parallel(model) else model.state_dict() # model state_dict + msd = de_parallel(model).state_dict() # model state_dict for k, v in self.ema.state_dict().items(): if v.dtype.is_floating_point: v *= d