Support for Stoch Wt Avg (SWA) closes #321 #320

Draft · wants to merge 5 commits into base: master
Conversation

@pchalasani pchalasani commented Nov 27, 2022

Stochastic Weight Averaging (SWA) is (quoting/paraphrasing from their page):

a simple procedure that improves generalization in deep learning over Stochastic Gradient Descent (SGD) at no additional cost, and can be used as a drop-in replacement for any other optimizer in PyTorch. SWA has a wide range of applications and features, [...] including [...] improve the stability of training as well as the final average rewards of policy-gradient methods in deep reinforcement learning.

See the PyTorch SWA page for more.

Description

A relatively simple change in exp_manager.py: it allows an additional key "swa" in policy_kwargs, e.g.

hyperparams["policy_kwargs"]["swa"] = {
   "swa_start": 5,
   "swa_freq": 3,
   "swa_lr": 0.05,
}
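For intuition, here is a minimal, framework-free sketch of the averaging rule these keys control (a plain-Python stand-in, not the PR's actual implementation): starting at step swa_start, every swa_freq steps the current weights are folded into a running average.

```python
def swa_average(weight_history, swa_start=5, swa_freq=3):
    """Running average of the weights captured every `swa_freq` steps
    once step >= swa_start. `weight_history[t]` is the weight vector
    (a list of floats) after training step t."""
    avg, n = None, 0
    for step, w in enumerate(weight_history):
        if step >= swa_start and (step - swa_start) % swa_freq == 0:
            n += 1
            if avg is None:
                avg = list(w)
            else:
                # incremental mean: avg += (w - avg) / n
                avg = [a + (x - a) / n for a, x in zip(avg, w)]
    return avg

# Toy example: a single scalar "weight" drifting upward over 12 steps;
# with swa_start=5, swa_freq=3 the snapshots at steps 5, 8, 11 are averaged.
history = [[float(t)] for t in range(12)]
print(swa_average(history))  # -> [8.0]
```

In the real change, this averaging is done by the SWA optimizer wrapper rather than by hand; the sketch only illustrates what swa_start and swa_freq mean.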

Motivation and Context

SWA might help improve stability and reduce sensitivity to random seeds in some DRL applications.

Closes #321

  • I have raised an issue to propose this change (required for new features and bug fixes)

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist:

  • I've read the CONTRIBUTION guide (required)
  • I have updated the changelog accordingly (required).
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.
  • I have reformatted the code using make format (required)
  • I have checked the codestyle using make check-codestyle and make lint (required)
  • I have ensured make pytest and make type both pass. (required)

Note: we are using a maximum length of 127 characters per line

@pchalasani pchalasani changed the title Support for Stoch Wt Avg (SWA) Support for Stoch Wt Avg (SWA) closes #321 Nov 27, 2022
@pchalasani pchalasani marked this pull request as ready for review November 27, 2022 02:51
@pchalasani (Contributor, Author)
I realized we need to call opt.swap_swa_sgd() at the end of training, and some further thought is needed on how this affects the computation of validation metrics by EvalCallback etc.

@pchalasani pchalasani marked this pull request as draft November 27, 2022 17:43
@pchalasani (Contributor, Author)
I added opt.swap_swa_sgd() after model.learn().
We also need to do this before and after each evaluate_policy() call in EvalCallback (which lives in the original sb3 repo), so that validation metrics are computed with the SWA-averaged model weights. We could potentially subclass EvalCallback to accomplish this.
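The swap/evaluate/swap pattern could be packaged as a context manager. A hypothetical sketch (not code from this PR) follows, assuming an optimizer with a torchcontrib-style swap_swa_sgd() that exchanges the live SGD weights with the SWA running average; a toy optimizer stands in for the real one to show the semantics.

```python
from contextlib import contextmanager

@contextmanager
def swa_weights(opt):
    """Temporarily evaluate with SWA-averaged weights, then restore
    the live SGD weights, even if evaluation raises."""
    opt.swap_swa_sgd()      # averaged weights into the model
    try:
        yield
    finally:
        opt.swap_swa_sgd()  # live SGD weights back

# Toy optimizer illustrating swap_swa_sgd()'s exchange semantics.
class ToyOpt:
    def __init__(self):
        self.active = "sgd_weights"
    def swap_swa_sgd(self):
        self.active = (
            "swa_weights" if self.active == "sgd_weights" else "sgd_weights"
        )

opt = ToyOpt()
with swa_weights(opt):
    # evaluate_policy() would run here, seeing the averaged weights
    assert opt.active == "swa_weights"
assert opt.active == "sgd_weights"  # live weights restored for training
```

A subclassed EvalCallback could wrap its evaluate_policy() call in such a context manager so training resumes with the un-averaged weights.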

Successfully merging this pull request may close these issues.

[Feature Request] Support Stochastic Weight Averaging (SWA) for improved stability