Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Exponential Moving Average support to classification reference script #4381

Merged
merged 5 commits into from
Sep 9, 2021

Conversation

prabhat00155
Copy link
Contributor

@prabhat00155 prabhat00155 commented Sep 7, 2021

Resolves #4346, resolves #4281

cc @datumbox

Copy link
Contributor

@datumbox datumbox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @prabhat00155, overall it looks good. I left a couple of comments for your consideration. Let me know what you think.

Also do you plan to investigate how we can avoid loading two models on the GPU?

references/classification/train.py Show resolved Hide resolved
references/classification/train.py Outdated Show resolved Hide resolved
Copy link
Member

@NicolasHug NicolasHug left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @prabhat00155 , I made some minor comments / questions

Also, it'd be good to validate the new code somehow. We don't really have tests for the references but perhaps it would be relevant to report the result of a model relying on the --model_avg option to compare to the baseline?

references/classification/utils.py Outdated Show resolved Hide resolved
@prabhat00155
Copy link
Contributor Author

Thanks @prabhat00155 , I made some minor comments / questions

Also, it'd be good to validate the new code somehow. We don't really have tests for the references but perhaps it would be relevant to report the result of a model relying on the --model_avg option to compare to the baseline?

Thanks @NicolasHug! I ran tests locally(on CPU) on a toy dataset and on aws cluster for 5 epochs. I will run it to completion to verify the results before merging this PR.

@prabhat00155
Copy link
Contributor Author

Thanks @prabhat00155, overall it looks good. I left a couple of comments for your consideration. Let me know what you think.

Also do you plan to investigate how we can avoid loading two models on the GPU?

Thanks @datumbox! Yes, I was thinking of doing that in a follow-up PR.

@datumbox
Copy link
Contributor

datumbox commented Sep 8, 2021

@prabhat00155 sounds good to me.

The only thing worth addressing here is to pass the non parallel model in ema (see my earlier comment). I believe what you have now will fail to handle properly the checkpoints on a multi gpu setup. Worth confirming by doing two epochs on AWS and resuming from checkpoint. If it works you can leave as is.

Everything else are nits that can be done later on separate PRs.

Copy link
Contributor

@datumbox datumbox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@prabhat00155 As we discussed offline, this is a great contribution so let's unblock the merge and push further investigations on follow up PRs.

@prabhat00155
Copy link
Contributor Author

@prabhat00155 As we discussed offline, this is a great contribution so let's unblock the merge and push further investigations on follow up PRs.

Makes sense, thanks!

@prabhat00155 prabhat00155 merged commit 12fd3a6 into pytorch:main Sep 9, 2021
@prabhat00155 prabhat00155 deleted the prabhat00155/ema_support branch September 9, 2021 16:05
facebook-github-bot pushed a commit that referenced this pull request Sep 13, 2021
…eference script (#4381)

Summary:
* Added Exponential Moving Average support to classification reference script

* Addressed review comments

* Updated model argument

Reviewed By: kazhang

Differential Revision: D30898332

fbshipit-source-id: 1c9aaa2b9b1e8773fce155063bfa4de32c4c1c1e
facebook-github-bot pushed a commit to pytorch/pytorch that referenced this pull request Sep 28, 2021
…65495)

Summary:
While implementing [EMA](pytorch/vision#4381 extends AveragedModel) in torchvision, update_parameters() from AveragedModel could not be used as it did not handle state_dict(), so a custom update_parameters() needed to be defined in [EMA class](pytorch/vision#4406). This PR aims to handle this scenario removing the need for this custom update_parameters() implementation.

Discussion: pytorch/vision#4406 (review)

Pull Request resolved: #65495

Reviewed By: datumbox

Differential Revision: D31176742

Pulled By: prabhat00155

fbshipit-source-id: 326d14876018f21cf602bab5eaba344678dbabe2
prabhat00155 added a commit to prabhat00155/pytorch that referenced this pull request Oct 5, 2021
…ytorch#65495)

Summary:
While implementing [EMA](pytorch/vision#4381 extends AveragedModel) in torchvision, update_parameters() from AveragedModel could not be used as it did not handle state_dict(), so a custom update_parameters() needed to be defined in [EMA class](pytorch/vision#4406). This PR aims to handle this scenario removing the need for this custom update_parameters() implementation.

Discussion: pytorch/vision#4406 (review)

Pull Request resolved: pytorch#65495

Reviewed By: datumbox

Differential Revision: D31176742

Pulled By: prabhat00155

fbshipit-source-id: 326d14876018f21cf602bab5eaba344678dbabe2
(cherry picked from commit 2ea724b)
malfet pushed a commit to pytorch/pytorch that referenced this pull request Oct 6, 2021
…65495) (#65755)

* Added option to update parameters using state_dict in AveragedModel (#65495)

Summary:
While implementing [EMA](pytorch/vision#4381 extends AveragedModel) in torchvision, update_parameters() from AveragedModel could not be used as it did not handle state_dict(), so a custom update_parameters() needed to be defined in [EMA class](pytorch/vision#4406). This PR aims to handle this scenario removing the need for this custom update_parameters() implementation.

Discussion: pytorch/vision#4406 (review)

Pull Request resolved: #65495

Reviewed By: datumbox

Differential Revision: D31176742

Pulled By: prabhat00155

fbshipit-source-id: 326d14876018f21cf602bab5eaba344678dbabe2
(cherry picked from commit 2ea724b)

* Added validation of mode parameter in AveragedModel (#65921)

Summary:
Discussion: #65495 (comment)

Pull Request resolved: #65921

Reviewed By: albanD

Differential Revision: D31310105

Pulled By: prabhat00155

fbshipit-source-id: 417691832a7c793744830c11e0ce53e3972d21a3
(cherry picked from commit c7748fc)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
4 participants