
Optimizer Frequencies logic, and new configure_optimizers #1269

Merged
merged 16 commits into Lightning-AI:master on Mar 31, 2020

Conversation

asafmanor
Contributor

Before submitting

  • Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?
  • If you made a notable change (that affects users), did you update the CHANGELOG?

What does this PR do?

Fixes #594

Description

This PR implements the "optimizer frequencies" logic, where each optimizer is used for a number of consecutive steps equal to its associated frequency before the next optimizer is called.
This is a much-needed feature for training networks such as Wasserstein GANs, where the critic is updated several times for every generator update.
The API follows @williamFalcon's suggested API and allows LightningModule.configure_optimizers()
to return a dictionary, or multiple dictionaries, containing the new frequency key (see the sketch below).
Backward compatibility is maintained.
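
For illustration, a minimal sketch (not taken from the PR itself) of what configure_optimizers() inside a LightningModule could return with the new frequency key, for a WGAN-style setup; self.critic, self.generator and the learning rates are placeholders:

import torch

def configure_optimizers(self):
    # Sketch only: 'self.critic' / 'self.generator' and the learning rates
    # are hypothetical attributes, not part of this PR.
    n_critic = 5  # critic steps per generator step (WGAN-style)
    dis_opt = torch.optim.Adam(self.critic.parameters(), lr=1e-4)
    gen_opt = torch.optim.Adam(self.generator.parameters(), lr=1e-4)
    return (
        {'optimizer': dis_opt, 'frequency': n_critic},
        {'optimizer': gen_opt, 'frequency': 1},
    )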

Documentation:

  • The LightningModule.configure_optimizers() method.
  • The ValueError raised by the Trainer.init_optimizers() method.

Tests:

  • Tests were added in tests.models.test_gpu to assert return types of init_optimizers.

TODO:

  • A test is required to assert that the optimizers are called in the right order and at the configured frequencies (a sketch of the expected behaviour follows this list).
    I've started implementing such a test in test_optimizers but ran into questions regarding the method.
    In single-GPU training on a private project of mine, this has proven to work flawlessly.
  • I have yet to test the case where the batch is split into multiple splits.
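
As a plain-Python sketch of the behaviour such a test would assert (illustrative only, independent of the actual Trainer internals): each optimizer index is active for a number of consecutive batches equal to its frequency, and the pattern repeats.

from itertools import cycle

def optimizer_schedule(frequencies, num_batches):
    # Expand e.g. (5, 1) into the repeating pattern [0, 0, 0, 0, 0, 1] and
    # report which optimizer index is active at each training batch.
    pattern = [idx for idx, freq in enumerate(frequencies) for _ in range(freq)]
    steps = cycle(pattern)
    return [next(steps) for _ in range(num_batches)]

# With frequencies (5, 1): five critic batches, then one generator batch, repeating.
assert optimizer_schedule((5, 1), 12) == [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1]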

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

Yep 🥇

@pep8speaks

pep8speaks commented Mar 28, 2020

Hello @asafmanor! Thanks for updating this PR.

Line 769:111: E501 line too long (112 > 110 characters)
Line 786:111: E501 line too long (113 > 110 characters)
Line 788:111: E501 line too long (115 > 110 characters)

Comment last updated at 2020-03-31 13:40:10 UTC

@codecov

codecov bot commented Mar 28, 2020

Codecov Report

Merging #1269 into master will decrease coverage by 0%.
The diff coverage is 80%.

@@          Coverage Diff           @@
##           master   #1269   +/-   ##
======================================
- Coverage      92%     92%   -0%     
======================================
  Files          62      62           
  Lines        3210    3235   +25     
======================================
+ Hits         2949    2968   +19     
- Misses        261     267    +6     

@Borda Borda requested review from a team March 28, 2020 22:50
@Borda Borda added the feature (Is an improvement or enhancement) label Mar 28, 2020
@Borda Borda added this to the 0.7.2 milestone Mar 28, 2020
@Borda
Member

Borda commented Mar 28, 2020

@PyTorchLightning/core-contributors ^^ pls check...

(resolved review comments on pytorch_lightning/trainer/trainer.py, pytorch_lightning/trainer/distrib_parts.py, and tests/base/utils.py)
Member

@Borda Borda left a comment

pls, could you check my questions?

(resolved review comments on pytorch_lightning/trainer/trainer.py, pytorch_lightning/trainer/training_loop.py, and tests/models/test_gpu.py)
@asafmanor asafmanor requested a review from Borda March 29, 2020 09:15
@Borda
Member

Borda commented Mar 29, 2020

just now it turned out there is another PR refactoring optimizers, see #1279

@Borda Borda requested review from ethanwharris and a team March 29, 2020 12:36
@Borda
Member

Borda commented Mar 29, 2020

@ethanwharris pls be aware also of this work and decide which shall be merged first to reduce conflicts... if you agree we can change the destination branch, so this can be merged into yours (#1279) or that one into here...

Member

@ethanwharris ethanwharris left a comment

This looks great :)

@Borda I'm happy for this to be merged first and then I can rebase my changes from #1279 on top

@Borda Borda changed the title from "Optimizer Frequencies logic, and new configure_optimizers() API." to "Optimizer Frequencies logic, and new configure_optimizers" Mar 29, 2020
@mergify mergify bot requested a review from a team March 30, 2020 22:33
Contributor

@jeremyjordan jeremyjordan left a comment

nice work on this. i like the added option of returning dicts, makes the code more explicit when reading 👍

(resolved review comment on tests/models/test_gpu.py)
@mergify mergify bot requested a review from a team March 31, 2020 07:16
Member

@Borda Borda left a comment

LGTM 🚀

(resolved review comments on pytorch_lightning/trainer/trainer.py and pytorch_lightning/trainer/training_loop.py)
@Borda Borda added the ready (PRs ready to be merged) label Mar 31, 2020
@Borda Borda changed the title from "Optimizer Frequencies logic, and new configure_optimizers" to "Optimizer Frequencies logic, and new configure_optimizers [wip]" Mar 31, 2020
@Borda
Member

Borda commented Mar 31, 2020

@asafmanor GREAT work, may you pls just add a note to the changelog so it can go...
when you are done remove "[wip]" from the PR name, Thx

@mergify
Contributor

mergify bot commented Mar 31, 2020

This pull request is now in conflict... :(

(resolved review comment on CHANGELOG.md)
@Borda Borda changed the title from "Optimizer Frequencies logic, and new configure_optimizers [wip]" to "Optimizer Frequencies logic, and new configure_optimizers" Mar 31, 2020
@mergify
Contributor

mergify bot commented Mar 31, 2020

This pull request is now in conflict... :(

(resolved review comment on CHANGELOG.md)
@mergify mergify bot merged commit aca8c7e into Lightning-AI:master Mar 31, 2020
@mergify
Contributor

mergify bot commented Mar 31, 2020

Great job! =)

n_critic = 5
return (
    {'optimizer': dis_opt, 'frequency': n_critic},
    {'optimizer': gen_opt, 'frequency': 1}
)
Contributor

shouldn't this also have an example for scheduler?

{'optimizer': dis_opt, 'frequency': n_critic, 'lr_scheduler': Scheduler()}
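
For reference, a hedged sketch of such a combined example, reusing dis_opt, gen_opt and n_critic from the snippet above; the StepLR scheduler is only illustrative:

dis_sched = torch.optim.lr_scheduler.StepLR(dis_opt, step_size=10)
return (
    # 'lr_scheduler' attaches a scheduler to this particular optimizer,
    # alongside its 'frequency'.
    {'optimizer': dis_opt, 'frequency': n_critic, 'lr_scheduler': dis_sched},
    {'optimizer': gen_opt, 'frequency': 1}
)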

Contributor

@Borda @asafmanor?
Also, amazing job :)

@williamFalcon
Contributor

@asafmanor what about saving/loading checkpoints? did we handle storing this information for resuming training?

@asafmanor
Contributor Author

I use hparams to save the frequencies.
I can easily add the optimizer frequencies to the checkpoint.
Speaking of that, I've noticed that load_from_checkpoint() does not resume optimizer states.
Is there a different method for that? Should I implement one for resuming all of the information stored in the checkpoint?
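
For illustration, a minimal sketch of the "use hparams to save the frequencies" approach (the n_critic attribute and the _build_optimizers helper are hypothetical): the frequency is read from self.hparams, so it is stored and restored together with the other hyperparameters whenever they are written to a checkpoint.

def configure_optimizers(self):
    # 'n_critic' comes from self.hparams (e.g. argparse.Namespace(n_critic=5)),
    # so it travels with the checkpointed hyperparameters.
    dis_opt, gen_opt = self._build_optimizers()  # hypothetical helper
    return (
        {'optimizer': dis_opt, 'frequency': self.hparams.n_critic},
        {'optimizer': gen_opt, 'frequency': 1},
    )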

@Borda
Member

Borda commented Mar 31, 2020

We can do a follow-up PR if needed, but this was blocking others so we needed to get it done...

alexeykarnachev pushed a commit to alexeykarnachev/pytorch-lightning that referenced this pull request Apr 3, 2020
Optimizer Frequencies logic, and new configure_optimizers (Lightning-AI#1269)

* init_optimizers accepts Dict, Sequence[Dict]
and returns optimizer_frequencies.
optimizer_frequencies was added as a member of Trainer.

* Optimizer frequencies logic implemented in training_loop.
Description added to configure_optimizers in LightningModule

* optimizer frequencies tests added to test_gpu

* Fixed formatting for merging PR Lightning-AI#1269

* Apply suggestions from code review

* Apply suggestions from code review

Co-Authored-By: Asaf Manor <32155911+asafmanor@users.noreply.github.com>

* Update trainer.py

* Moving get_optimizers_iterable() outside.

* Update note

* Apply suggestions from code review

* formatting

* formatting

* Update CHANGELOG.md

* formatting

* Update CHANGELOG.md

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
@Borda Borda modified the milestones: 0.7.2, v0.7.x Apr 18, 2021
Labels
feature (Is an improvement or enhancement), ready (PRs ready to be merged)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

GAN example: Only one backward() call?
7 participants