From d808920dcc5b707ea36e9d3e19f0061a276830af Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Sun, 21 Aug 2022 03:16:36 +0200
Subject: [PATCH 1/4] zero-mAP fix 3

Signed-off-by: Glenn Jocher
---
 utils/torch_utils.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index b934248dee43..9fad8bdea2c0 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -412,7 +412,6 @@ def __init__(self, model, decay=0.9999, tau=2000, updates=0):
         for p in self.ema.parameters():
             p.requires_grad_(False)
 
-    @smart_inference_mode()
     def update(self, model):
         # Update EMA parameters
         self.updates += 1

From 8a02e918b6aadfee5c28c0d7c2cb96b0fb53497e Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Sun, 21 Aug 2022 03:19:48 +0200
Subject: [PATCH 2/4] Update torch_utils.py

Signed-off-by: Glenn Jocher
---
 utils/torch_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index 9fad8bdea2c0..782829eaaae6 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -422,7 +422,7 @@ def update(self, model):
             if v.dtype.is_floating_point:  # true for FP16 and FP32
                 v *= d
                 v += (1 - d) * msd[k].detach()
-                assert v.dtype == msd[k].dtype == torch.float32, f'EMA {v.dtype} and model {msd[k]} must be updated in FP32'
+                assert v.dtype == msd[k].detach().dtype == torch.float32, f'EMA {v.dtype} and model {msd[k]} must be updated in FP32'
 
     def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
         # Update EMA attributes

From b2949e8d4a332fb52e2bea0d9591d5a6ab9fe35d Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sun, 21 Aug 2022 01:20:06 +0000
Subject: [PATCH 3/4] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 utils/torch_utils.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index 782829eaaae6..54c5a7831a79 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -422,7 +422,8 @@ def update(self, model):
             if v.dtype.is_floating_point:  # true for FP16 and FP32
                 v *= d
                 v += (1 - d) * msd[k].detach()
-                assert v.dtype == msd[k].detach().dtype == torch.float32, f'EMA {v.dtype} and model {msd[k]} must be updated in FP32'
+                assert v.dtype == msd[k].detach(
+                ).dtype == torch.float32, f'EMA {v.dtype} and model {msd[k]} must be updated in FP32'
 
     def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
         # Update EMA attributes

From f4d67e5fd25e259609c0a5c403395abeb2a07ba6 Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Sun, 21 Aug 2022 03:21:51 +0200
Subject: [PATCH 4/4] Update torch_utils.py

Signed-off-by: Glenn Jocher
---
 utils/torch_utils.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/utils/torch_utils.py b/utils/torch_utils.py
index 54c5a7831a79..5fbe8bbf10f6 100644
--- a/utils/torch_utils.py
+++ b/utils/torch_utils.py
@@ -422,8 +422,7 @@ def update(self, model):
             if v.dtype.is_floating_point:  # true for FP16 and FP32
                 v *= d
                 v += (1 - d) * msd[k].detach()
-                assert v.dtype == msd[k].detach(
-                ).dtype == torch.float32, f'EMA {v.dtype} and model {msd[k]} must be updated in FP32'
+                assert v.dtype == msd[k].detach().dtype == torch.float32, f'EMA {v.dtype} and model {msd[k]} must both be FP32'
 
     def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
         # Update EMA attributes
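
For context, the four patches above converge on a decorator-free EMA update with an FP32 dtype assertion. The sketch below is a minimal, self-contained reconstruction of that update loop, not the yolov5 API: the class name SimpleEMA, the decay schedule, and the demo at the bottom are illustrative assumptions; only the update() body mirrors the hunks.

# Minimal, self-contained sketch of the EMA update the patches above modify.
# SimpleEMA, the decay lambda, and the demo are assumptions for illustration;
# only the body of update() is taken from the hunks in this series.
import math
from copy import deepcopy

import torch
import torch.nn as nn


class SimpleEMA:
    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        self.ema = deepcopy(model).eval()  # FP32 copy of the model
        self.updates = updates  # number of EMA updates so far
        # Assumed ramp-up schedule consistent with the (decay, tau) signature
        # shown in patch 1/4: decay grows from 0 toward `decay`
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))
        for p in self.ema.parameters():
            p.requires_grad_(False)

    # No @smart_inference_mode() here: patch 1/4 deletes that decorator
    # from update() as part of the zero-mAP fix
    def update(self, model):
        # Update EMA parameters
        self.updates += 1
        d = self.decay(self.updates)
        msd = model.state_dict()  # current (live) model weights
        for k, v in self.ema.state_dict().items():
            if v.dtype.is_floating_point:  # true for FP16 and FP32
                v *= d  # in-place: keep fraction d of the old EMA value
                v += (1 - d) * msd[k].detach()  # blend in (1 - d) of the model
                # Final form of the assert from patch 4/4; this sketch prints
                # the model dtype, where the original interpolates msd[k]
                assert v.dtype == msd[k].detach().dtype == torch.float32, \
                    f'EMA {v.dtype} and model {msd[k].dtype} must both be FP32'


# Usage: EMA weights drift toward the live model's weights on each update
model = nn.Linear(4, 2)
ema = SimpleEMA(model)
with torch.no_grad():
    model.weight += 1.0  # stand-in for a training step
ema.update(model)
print(ema.ema.weight[0, 0], model.weight[0, 0])  # EMA lags the live weight

Note that the series never touches the EMA arithmetic itself: patch 1/4 drops the decorator, patch 2/4 tightens the assertion to compare the detached dtype, patch 3/4 is an automated line wrap from pre-commit.ci, and patch 4/4 undoes the wrap and shortens the assertion message.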