Skip to content

Commit

Permalink
[fbsync] Added update_parameters to EMA to fix calculation (#4406)
Browse files Browse the repository at this point in the history
Reviewed By: datumbox

Differential Revision: D31268055

fbshipit-source-id: 2bedf7cd5db0a345dffa42a9ff94ce7d425e1008
  • Loading branch information
NicolasHug authored and facebook-github-bot committed Sep 30, 2021
1 parent 43ef2f4 commit dd6d318
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions references/classification/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,17 @@ def __init__(self, model, decay, device='cpu'):
decay * avg_model_param + (1 - decay) * model_param)
super().__init__(model, device, ema_avg)

def update_parameters(self, model):
for p_swa, p_model in zip(self.module.state_dict().values(), model.state_dict().values()):
device = p_swa.device
p_model_ = p_model.detach().to(device)
if self.n_averaged == 0:
p_swa.detach().copy_(p_model_)
else:
p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_,
self.n_averaged.to(device)))
self.n_averaged += 1


def accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
Expand Down

0 comments on commit dd6d318

Please sign in to comment.