
Add TRPO #40

Merged
merged 33 commits on Dec 29, 2021

Conversation


@cyprienc cyprienc commented Sep 8, 2021

Description

This PR adds TRPO (Trust Region Policy Optimization): https://arxiv.org/abs/1502.05477
It is still a work in progress (see the TODO list below).
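For context, TRPO computes a natural-gradient step and then shrinks it with a backtracking line search until the surrogate objective improves and the KL constraint holds. A minimal NumPy sketch of that line search (a simplified illustration, not this PR's code; `objective_fn`, `kl_fn`, and the constants are hypothetical stand-ins):

```python
import numpy as np

def backtracking_line_search(theta, step_dir, objective_fn, kl_fn,
                             max_kl=0.01, shrink=0.8, max_iters=10):
    """Shrink the candidate step until the surrogate objective improves
    and the KL constraint is satisfied (simplified sketch)."""
    f0 = objective_fn(theta)
    step = 1.0
    for _ in range(max_iters):
        candidate = theta + step * step_dir
        # Accept only if the objective improved AND the KL budget holds
        if objective_fn(candidate) > f0 and kl_fn(candidate) <= max_kl:
            return candidate
        step *= shrink  # backtrack: try a smaller step
    return theta  # reject the update if no step satisfies both conditions
```

Rejecting the update entirely when no step passes both checks is what makes the method "trust region": a bad direction never moves the policy.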

Context

Closes #38

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist:

  • I've read the CONTRIBUTION guide (required)
  • The functionality/performance matches that of the source (required for new training algorithms or training-related features).
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have included an example of using the feature (required for new features).
  • I have included baseline results (required for new training algorithms or training-related features).
  • I have updated the documentation accordingly.
  • I have updated the changelog accordingly (required).
  • I have reformatted the code using make format (required)
  • I have checked the codestyle using make check-codestyle and make lint (required)
  • I have ensured make pytest and make type both pass. (required)

MuJoCo 1M Benchmark

MuJoCo v2.1.0, v3 envs

| Environments | TRPO         |
|--------------|--------------|
| HalfCheetah  | 1803 +/- 46  |
| Ant          | 3554 +/- 591 |
| Hopper       | 3372 +/- 215 |
| Walker2d     | 4502 +/- 234 |
| Swimmer      | 359 +/- 2    |

(Learning-curve plots attached: Results_Ant, Results_HalfCheetah, Results_Hopper, Results_Swimmer, Results_Walker2d.)

WIP - Trust Region Policy Optimization

- Currently the Hessian vector product is not working (see inline comments for more detail)
- Adding a no_grad block for the line search
- Additional assert in the conjugate solver to help debugging
- Adding ActorCriticPolicy.get_distribution
- Using the Distribution object to compute the KL divergence
- Checking for objective improvement in the line search
- Moving magic numbers to instance variables
- Improving numerical stability of the conjugate gradient algorithm
- Critic updates
- Changes around the alpha of the line search
- Adding TRPO to __init__ files
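The conjugate solver mentioned in the commits above solves the linear system H x = g without ever materializing H, using only a function that computes H @ v. A self-contained NumPy sketch of the classic algorithm (names mirror the PR's `conjugate_gradient_solver` / `matrix_vector_dot_func`, but this is an illustration under simplified assumptions, not the PR's code):

```python
import numpy as np

def conjugate_gradient_solver(matrix_vector_dot_func, b,
                              max_iter=10, residual_tol=1e-10):
    """Solve A x = b for symmetric positive-definite A, given only
    a function computing A @ v (matrix-free conjugate gradient)."""
    x = np.zeros_like(b)
    r = b.copy()        # residual b - A @ x, with x = 0
    p = r.copy()        # current search direction
    rs_old = r @ r
    for _ in range(max_iter):
        Ap = matrix_vector_dot_func(p)
        alpha = rs_old / (p @ Ap)      # step size along p
        x += alpha * p
        r -= alpha * Ap
        rs_new = r @ r
        if rs_new < residual_tol:      # early exit helps numerical stability
            break
        p = r + (rs_new / rs_old) * p  # next direction, conjugate to the rest
        rs_old = rs_new
    return x
```

In TRPO, `matrix_vector_dot_func` is the Fisher/Hessian-vector product of the KL divergence, so the full second-order matrix never needs to be formed.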

araffin commented Sep 9, 2021

Thanks for the PR.
Before I help you run larger-scale experiments, could you first match the TRPO results from SB2 on simple envs (classic control: CartPole, Pendulum, LunarLander)?
You can find tuned params in the SB2 zoo: https://github.com/araffin/rl-baselines-zoo/blob/master/hyperparams/trpo.yml

Once this is done, you should focus on documentation and tests, and I will run experiments on pybullet envs + atari games ;)

- Renaming cg_solver to conjugate_gradient_solver and renaming parameter Avp_fun to matrix_vector_dot_func + docstring
- Extra comments + better variable names in trpo.py
- Defining a method for the Hessian vector product instead of an inline function
- Fix registering correct policies for TRPO and using correct policy base in constructor
- Refactoring sb3_contrib.common.policies to reuse as much code as possible from SB3
- get_distribution will be added directly to the SB3 version of ActorCriticPolicy; this commit reflects this
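The "method for the Hessian vector product" typically refers to the standard double-backprop trick: instead of forming the Hessian, differentiate the dot product of the gradient with a vector. A minimal PyTorch sketch of that trick (an illustration, not the PR's exact method):

```python
import torch

def hessian_vector_product(loss, params, vector):
    """Compute H @ vector, where H is the Hessian of loss w.r.t. params,
    via two backward passes and without materializing H."""
    # First backward pass: gradient of the loss, kept in the autograd graph
    grads = torch.autograd.grad(loss, params, create_graph=True)
    flat_grad = torch.cat([g.reshape(-1) for g in grads])
    # Second backward pass: d/dtheta (grad . v) equals H @ v
    grad_dot_v = (flat_grad * vector).sum()
    hvp = torch.autograd.grad(grad_dot_v, params)
    return torch.cat([h.reshape(-1) for h in hvp])
```

This is the `matrix_vector_dot_func` the conjugate gradient solver needs: H is only ever touched through products H @ v, which cost about as much as two gradient evaluations.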

araffin commented Sep 13, 2021

Could you remove the protection on your master branch so I can push changes?
While waiting for that, you can find them here: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/tree/feat/get-dist

(Next time, please use another branch ;))

@cyprienc

Here are the results using the SB2 hyperparameters (I'll update the PR on the zoo with the parameters used):

| Environments   | TRPO         |
|----------------|--------------|
| CartPole-v1    | 468 +/- 27   |
| Pendulum-v0    | -330 +/- 110 |
| LunarLander-v2 | 53 +/- 57    |


araffin commented Sep 13, 2021

> Here are the results using the SB2 hyperparameters (I'll update the PR on the zoo with the parameters used):
>
> | Environments   | TRPO         |
> |----------------|--------------|
> | CartPole-v1    | 468 +/- 27   |
> | Pendulum-v0    | -330 +/- 110 |
> | LunarLander-v2 | 53 +/- 57    |

Thanks =)! Could you also add a learning curve/results comparison with SB2 here? (Plot scripts are included in the zoo; I can help with them if needed.)


araffin commented Sep 29, 2021

Hi,
So I took a closer look and started experimenting with bullet envs. After some fixes, the results look good =D (I updated the hyperparams in DLR-RM/rl-baselines3-zoo#163.)
The entropy coeff is still missing though (important for Atari games, I think).

Could you start adding the documentation page + more tests?

Btw, I'm thinking about renaming some variables (related to the backtracking line search) so we are more consistent with other implementations, but those are just details...


@araffin araffin left a comment

LGTM, thanks =)

@araffin araffin merged commit 59be198 into Stable-Baselines-Team:master Dec 29, 2021

araffin commented Oct 14, 2022

@cyprienc I think it's time to move TRPO to SB3 =)!
Could you do a PR that adds TRPO to SB3 and removes it from SB3-Contrib, while maintaining backward compatibility by doing `from stable_baselines3 import TRPO` in the init?

@cyprienc

@araffin sure, will do.

Linked issue: [Feature Request] Implement TRPO (#38)