Added update_parameters to EMA to fix calculation #4406

Merged: prabhat00155 merged 1 commit into pytorch:main from prabhat00155/fix_ema on Sep 14, 2021

@prabhat00155 prabhat00155 commented Sep 14, 2021

@datumbox datumbox left a comment

LGTM, thanks Prabhat.

As @kazhang noted, this is an issue in PyTorch core. Ideally the base class should include all state, not just the params. We use a similar approach in other averaging schemes (see this), so this is aligned with what we've seen in the past. Though this workaround will do for our use case, I think it's still worth raising it on PyTorch core and either correcting it or introducing better control over which params we should consider.
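
To illustrate the difference between averaging only the params and averaging all state, here is a minimal pure-Python sketch. It is not the torchvision or PyTorch implementation: `ToyAveragedModel`, `ema_update`, and the `include_buffers` flag are hypothetical names, plain floats stand in for tensors, and it assumes the first update copies the model's values while later updates apply the EMA rule.

```python
def ema_update(avg, new, decay):
    """One exponential-moving-average step for a single value."""
    return decay * avg + (1.0 - decay) * new


class ToyAveragedModel:
    """Hypothetical, simplified stand-in for an averaged-model wrapper."""

    def __init__(self, model_state, param_names, decay):
        self.state = dict(model_state)       # copy taken at construction time
        self.param_names = set(param_names)  # entries treated as learned params
        self.decay = decay
        self.n_averaged = 0

    def update(self, model_state, include_buffers):
        for name, value in model_state.items():
            if not include_buffers and name not in self.param_names:
                continue  # buffers keep the value copied at construction
            if self.n_averaged == 0:
                self.state[name] = value  # first update copies the value
            else:
                self.state[name] = ema_update(self.state[name], value, self.decay)
        self.n_averaged += 1


# "weight" plays the role of a param, "running_mean" the role of a buffer
# (for example, a BatchNorm running statistic).
initial = {"weight": 0.0, "running_mean": 0.0}
params_only = ToyAveragedModel(initial, ["weight"], decay=0.5)
full_state = ToyAveragedModel(initial, ["weight"], decay=0.5)

for step in (1.0, 2.0):
    trained = {"weight": step, "running_mean": 10.0 * step}
    params_only.update(trained, include_buffers=False)
    full_state.update(trained, include_buffers=True)

print(params_only.state)  # running_mean is still the stale initial value
print(full_state.state)   # running_mean has been averaged as well
```

In the params-only variant the buffer never moves from the value it had when the wrapper was created, which is the kind of mismatch the custom update_parameters() in this PR works around by iterating over the full state.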

@prabhat00155

Makes sense. Will follow up on this.

@prabhat00155 prabhat00155 merged commit c1f1e22 into pytorch:main Sep 14, 2021
@prabhat00155 prabhat00155 deleted the prabhat00155/fix_ema branch September 14, 2021 10:20
facebook-github-bot pushed a commit to pytorch/pytorch that referenced this pull request Sep 28, 2021
…65495)

Summary:
While implementing EMA (pytorch/vision#4381, which extends AveragedModel) in torchvision, update_parameters() from AveragedModel could not be used because it did not handle state_dict(), so a custom update_parameters() had to be defined in the EMA class (pytorch/vision#4406). This PR handles this scenario, removing the need for the custom update_parameters() implementation.

Discussion: pytorch/vision#4406 (review)

Pull Request resolved: #65495

Reviewed By: datumbox

Differential Revision: D31176742

Pulled By: prabhat00155

fbshipit-source-id: 326d14876018f21cf602bab5eaba344678dbabe2
facebook-github-bot pushed a commit that referenced this pull request Sep 30, 2021
Reviewed By: datumbox

Differential Revision: D31268055

fbshipit-source-id: 2bedf7cd5db0a345dffa42a9ff94ce7d425e1008
prabhat00155 added a commit to prabhat00155/pytorch that referenced this pull request Oct 5, 2021
…ytorch#65495)

(cherry picked from commit 2ea724b)
malfet pushed a commit to pytorch/pytorch that referenced this pull request Oct 6, 2021
…65495) (#65755)

* Added option to update parameters using state_dict in AveragedModel (#65495)

(cherry picked from commit 2ea724b)

* Added validation of mode parameter in AveragedModel (#65921)

Summary:
Discussion: #65495 (comment)

Pull Request resolved: #65921

Reviewed By: albanD

Differential Revision: D31310105

Pulled By: prabhat00155

fbshipit-source-id: 417691832a7c793744830c11e0ce53e3972d21a3
(cherry picked from commit c7748fc)
Linked issue (may be closed by this pull request): Investigate Exponential Moving Average result in classification script