From 4899f92902997facfdeb91e94d25063e3660c124 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Sun, 21 Aug 2022 03:22:24 +0200
Subject: [PATCH] zero-mAP fix 3 (#9058)

* zero-mAP fix 3

Signed-off-by: Glenn Jocher

* Update torch_utils.py

Signed-off-by: Glenn Jocher

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update torch_utils.py

Signed-off-by: Glenn Jocher

Signed-off-by: Glenn Jocher
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 utils/torch_utils.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index b934248dee43..5fbe8bbf10f6 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -412,7 +412,6 @@ def __init__(self, model, decay=0.9999, tau=2000, updates=0):
         for p in self.ema.parameters():
             p.requires_grad_(False)
 
-    @smart_inference_mode()
     def update(self, model):
         # Update EMA parameters
         self.updates += 1
@@ -423,7 +422,7 @@ def update(self, model):
             if v.dtype.is_floating_point:  # true for FP16 and FP32
                 v *= d
                 v += (1 - d) * msd[k].detach()
-        assert v.dtype == msd[k].dtype == torch.float32, f'EMA {v.dtype} and model {msd[k]} must be updated in FP32'
+        assert v.dtype == msd[k].detach().dtype == torch.float32, f'EMA {v.dtype} and model {msd[k]} must both be FP32'
 
     def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
         # Update EMA attributes
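
For readers outside the YOLOv5 codebase, here is a minimal standalone sketch of the decayed EMA update this patch touches. SimpleEMA is a hypothetical illustration of the technique, not the actual ModelEMA class in utils/torch_utils.py. Its update method deliberately carries no inference-mode decoration, mirroring the decorator removal above, and it reproduces the FP32 invariant the patched assert enforces.

# Minimal sketch of a decayed EMA update (hypothetical SimpleEMA,
# not YOLOv5's ModelEMA).
import math
from copy import deepcopy

import torch
import torch.nn as nn


class SimpleEMA:
    """Keeps a decayed exponential moving average of a model's state_dict."""

    def __init__(self, model, decay=0.9999, tau=2000):
        self.ema = deepcopy(model).eval()  # FP32 EMA copy, never trained directly
        self.updates = 0
        # Decay ramps from 0 toward `decay`, so early updates track the live model
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        # Deliberately NOT wrapped in torch.inference_mode(); the patch above
        # removes the equivalent @smart_inference_mode() decorator in YOLOv5.
        self.updates += 1
        d = self.decay(self.updates)
        msd = model.state_dict()  # live model state_dict
        for k, v in self.ema.state_dict().items():
            if v.dtype.is_floating_point:  # true for FP16 and FP32
                v *= d  # in-place blend: ema = d * ema + (1 - d) * model
                v += (1 - d) * msd[k].detach()
                # Same invariant the patched assert enforces: both sides FP32
                assert v.dtype == msd[k].dtype == torch.float32


model = nn.Linear(4, 2)
ema = SimpleEMA(model)
ema.update(model)  # blends current weights into the EMA copy in-place

Keeping the blend outside inference mode leaves the EMA weights as ordinary FP32 tensors that downstream evaluation can use normally, and the strengthened assert catches any non-FP32 state sneaking into the update, presumably the failure mode behind the zero-mAP symptom this PR targets.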