diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index 88108906bfd3..b934248dee43 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -422,7 +422,7 @@ def update(self, model):
         for k, v in self.ema.state_dict().items():
             if v.dtype.is_floating_point:  # true for FP16 and FP32
                 v *= d
-                v += (1 - d) * msd[k]
+                v += (1 - d) * msd[k].detach()
                 assert v.dtype == msd[k].dtype == torch.float32, f'EMA {v.dtype} and model {msd[k].dtype} must be updated in FP32'
 
     def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
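
For context, below is a minimal, self-contained sketch of the EMA update pattern this hunk touches. The `ModelEMA` class, the ramped decay schedule, and the `tau` parameter are illustrative assumptions rather than the file's actual surrounding code; the point of the change is that `.detach()` reads the live weights as plain data, so the in-place blend `v = d * v + (1 - d) * msd[k]` cannot attach the EMA tensors to the autograd graph.

import copy
import math

import torch
import torch.nn as nn


class ModelEMA:
    # Minimal EMA holder (hypothetical sketch): keeps a detached copy of the
    # model and blends it toward the live weights as ema = d*ema + (1-d)*model.
    def __init__(self, model, decay=0.9999, tau=2000):
        self.ema = copy.deepcopy(model).eval()
        self.updates = 0
        # Assumed ramped schedule: decay is small early on so the EMA tracks the
        # freshly initialized model quickly, approaching `decay` as updates grow.
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        self.updates += 1
        d = self.decay(self.updates)
        msd = model.state_dict()  # live model state_dict
        with torch.no_grad():
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:  # skip integer buffers, e.g. num_batches_tracked
                    v *= d
                    # .detach() reads the weight as plain data, so this in-place
                    # update can never be recorded by autograd.
                    v += (1 - d) * msd[k].detach()


# Usage: update the shadow copy once per optimizer step.
model = nn.Linear(10, 2)
ema = ModelEMA(model)
ema.update(model)

The in-place `v *= d; v += ...` form works because `state_dict()` returns tensors that share storage with the EMA model's parameters and buffers, so mutating them updates the shadow copy directly without reassigning any attributes.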