diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index b934248dee43..5fbe8bbf10f6 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -412,7 +412,6 @@ def __init__(self, model, decay=0.9999, tau=2000, updates=0):
         for p in self.ema.parameters():
             p.requires_grad_(False)
 
-    @smart_inference_mode()
     def update(self, model):
         # Update EMA parameters
         self.updates += 1
@@ -423,7 +422,7 @@ def update(self, model):
             if v.dtype.is_floating_point:  # true for FP16 and FP32
                 v *= d
                 v += (1 - d) * msd[k].detach()
-                assert v.dtype == msd[k].dtype == torch.float32, f'EMA {v.dtype} and model {msd[k]} must be updated in FP32'
+                assert v.dtype == msd[k].dtype == torch.float32, f'EMA {v.dtype} and model {msd[k].dtype} must both be FP32'
 
     def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
         # Update EMA attributes
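Note on the dropped `@smart_inference_mode()`: the in-place writes in `update()` stay safe as ordinary eager ops because the EMA parameters were created with `requires_grad_(False)`. For reference, here is a minimal, self-contained sketch of the update rule this method implements, assuming a plain `nn.Module`; the `ema_update` helper and the toy `nn.Linear` are illustrative, not part of this patch:

```python
import math
from copy import deepcopy

import torch
import torch.nn as nn


def ema_update(ema, model, updates, decay=0.9999, tau=2000):
    # Decay ramps from 0 toward `decay` as updates accumulate (helps early epochs)
    d = decay * (1 - math.exp(-updates / tau))
    msd = model.state_dict()  # detached views that share storage with the live weights
    for k, v in ema.state_dict().items():
        if v.dtype.is_floating_point:  # true for FP16 and FP32
            v *= d
            v += (1 - d) * msd[k].detach()


# Toy usage: `model` stands in for the training model
model = nn.Linear(4, 2)
ema = deepcopy(model).eval()
for p in ema.parameters():
    p.requires_grad_(False)

for step in range(1, 4):
    with torch.no_grad():
        model.weight.add_(0.1)  # stand-in for an optimizer step
    ema_update(ema, model, updates=step)
```

Since `state_dict()` returns detached tensors that alias the parameter storage, the in-place `v *= d` / `v += ...` mutate the EMA weights directly; non-floating buffers such as `num_batches_tracked` are skipped by the dtype check.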