Implement HER (DLR-RM#120)
* Added working HER version; online sampling is missing.

* Updated test_her.

* Added first version of online her sampling. Still problems with tensor dimensions.

* Reformat

* Fixed tests

* Added some comments.

* Updated changelog.

* Add missing init file

* Fixed some small bugs.

* Reduced arguments for HER, small changes.

* Added getattr. Fixed bug for online sampling.

* Updated save/load functions. Small changes.

* Added her to init.

* Updated save method.

* Updated her ratio.

* Move obs_wrapper

* Added DQN test.

* Fix potential bug

* Offline and online HER share the same sample_goal function.

* Changed lists into arrays.

* Updated her test.

* Fix online sampling

* Fixed action bug. Updated time limit for episodes.

* Updated convert_dict method to take keys as arguments.

* Renamed obs dict wrapper.

* Seed bit flipping env

* Remove get_episode_dict

* Add fast online sampling version

* Added documentation.

* Vectorized reward computation

* Vectorized goal sampling

* Update time limit for episodes in online her sampling.

* Fix max episode length inference

* Bug fix for Fetch envs

* Fix for HER + gSDE

* Reformat (new black version)

* Added info dict to compute new reward. Check her_replay_buffer again.

* Fix info buffer

* Updated done flag.

* Fixes for gSDE

* Offline HER version now uses HerReplayBuffer as episode storage.

* Fix num_timesteps computation

* Fix get torch params

* Vectorized version for offline sampling.

* Modified offline her sampling to use sample method of her_replay_buffer

* Updated HER tests.

* Updated documentation

* Cleanup docstrings

* Updated to review comments

* Fix pytype

* Update according to review comments.

* Removed random goal strategy. Updated sample transitions.

* Updated migration. Removed time signal removal.

* Update doc

* Fix potential load issue

* Add VecNormalize support for dict obs

* Updated saving/loading replay buffer for HER.

* Fix test memory usage

* Fixed save/load replay buffer.

* Fixed save/load replay buffer

* Fixed transition index after loading replay buffer in online sampling

* Better error handling

* Add tests for get_time_limit

* More tests for VecNormalize with dict obs

* Update doc

* Improve HER description

* Add test for sde support

* Add comments

* Add comments

* Remove check that was always valid

* Fix for terminal observation

* Updated buffer size in offline version and reset of HER buffer

* Reformat

* Update doc

* Remove np.empty + add doc

* Fix loading

* Updated loading replay buffer

* Separate online and offline sampling + bug fixes

* Update tensorboard log name

* Version bump

* Bug fix for special case

Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
3 people authored and leor-c committed Nov 12, 2020
1 parent ed29d47 commit 940ee5d
Showing 34 changed files with 1,899 additions and 102 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -5,7 +5,7 @@ pytest:
./scripts/run_tests.sh

type:
pytype
pytype -j auto

lint:
# stop the build if there are Python syntax errors or undefined names
13 changes: 1 addition & 12 deletions README.md
@@ -35,20 +35,9 @@ These algorithms will make it easier for the research community and industry to
| Type hints | :heavy_check_mark: |


### Roadmap to V1.0

Please look at the issue for more details.
Planned features:

- [ ] HER

### Planned features (v1.1+)

- [ ] DQN extensions (prioritized replay, double q-learning, ...)
- [ ] Support for `Tuple` and `Dict` observation spaces
- [ ] Recurrent Policies
- [ ] TRPO

Please take a look at the [Roadmap](https://github.com/DLR-RM/stable-baselines3/issues/1) and [Milestones](https://github.com/DLR-RM/stable-baselines3/milestones).

## Migration guide: from Stable-Baselines (SB2) to Stable-Baselines3 (SB3)

78 changes: 76 additions & 2 deletions docs/guide/examples.rst
@@ -18,8 +18,7 @@ notebooks:
- `Atari Games`_
- `RL Baselines zoo`_
- `PyBullet`_

.. - `Hindsight Experience Replay`_
- `Hindsight Experience Replay`_

.. _Getting Started: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/stable_baselines_getting_started.ipynb
.. _Training, Saving, Loading: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/saving_loading_dqn.ipynb
@@ -343,6 +342,81 @@ will compute a running average and standard deviation of input features (it can
    env.norm_reward = False


Hindsight Experience Replay (HER)
---------------------------------

For this example, we are using `Highway-Env <https://github.com/eleurent/highway-env>`_ by `@eleurent <https://github.com/eleurent>`_.


.. image:: ../_static/img/colab-badge.svg
:target: https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/stable_baselines_her.ipynb


.. figure:: https://raw.githubusercontent.com/eleurent/highway-env/gh-media/docs/media/parking-env.gif

The highway-parking-v0 environment.

The parking env is a goal-conditioned continuous control task, in which the vehicle must park in a given space with the appropriate heading.

.. note::

The hyperparameters in the following example were optimized for that environment.


.. code-block:: python

    import gym
    import highway_env
    import numpy as np

    from stable_baselines3 import HER, SAC, DDPG, TD3
    from stable_baselines3.common.noise import NormalActionNoise

    env = gym.make("parking-v0")

    # Create 4 artificial transitions per real transition
    n_sampled_goal = 4

    # SAC hyperparams:
    model = HER(
        "MlpPolicy",
        env,
        SAC,
        n_sampled_goal=n_sampled_goal,
        goal_selection_strategy="future",
        # IMPORTANT: because the env is not wrapped with a TimeLimit wrapper
        # we have to manually specify the max number of steps per episode
        max_episode_length=100,
        verbose=1,
        buffer_size=int(1e6),
        learning_rate=1e-3,
        gamma=0.95,
        batch_size=256,
        online_sampling=True,
        policy_kwargs=dict(net_arch=[256, 256, 256]),
    )

    model.learn(int(2e5))
    model.save("her_sac_highway")

    # Load saved model
    model = HER.load("her_sac_highway", env=env)

    obs = env.reset()

    # Evaluate the agent
    episode_reward = 0
    for _ in range(100):
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, done, info = env.step(action)
        env.render()
        episode_reward += reward
        if done or info.get("is_success", False):
            print("Reward:", episode_reward, "Success?", info.get("is_success", False))
            episode_reward = 0.0
            obs = env.reset()

Record a Video
--------------

8 changes: 8 additions & 0 deletions docs/guide/migration.rst
@@ -163,6 +163,14 @@ Despite this change, no change in performance should be expected.
To match SB2 behavior, you need to explicitly pass ``deterministic=True``


HER
^^^

The ``HER`` implementation now also supports online sampling of the new goals, done in a vectorized fashion.
The goal selection strategy ``RANDOM`` is no longer supported.
``HER`` now supports the ``VecNormalize`` wrapper, but only when ``online_sampling=True``.
For performance reasons, the maximum number of steps per episode must be specified (see the :ref:`HER <her>` documentation).
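
A minimal sketch of the updated constructor (mirroring the bit-flipping example in the :ref:`HER <her>` documentation; ``BitFlippingEnv`` is used here only for illustration):

.. code-block:: python

    from stable_baselines3 import HER, DQN
    from stable_baselines3.common.bit_flipping_env import BitFlippingEnv

    env = BitFlippingEnv(n_bits=15, max_steps=15)

    model = HER(
        "MlpPolicy",
        env,
        DQN,
        n_sampled_goal=4,
        goal_selection_strategy="future",  # the "random" strategy is no longer available
        online_sampling=True,              # required for VecNormalize support
        max_episode_length=15,             # must be given (or inferable from the env)
        verbose=1,
    )
    model.learn(1000)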


New logger API
--------------
1 change: 1 addition & 0 deletions docs/index.rst
@@ -57,6 +57,7 @@ Main Features
modules/a2c
modules/ddpg
modules/dqn
modules/her
modules/ppo
modules/sac
modules/td3
10 changes: 7 additions & 3 deletions docs/misc/changelog.rst
@@ -4,7 +4,7 @@ Changelog
==========


Pre-Release 0.10.0a0 (WIP)
Pre-Release 0.10.0a1 (WIP)
------------------------------

Breaking Changes:
@@ -14,11 +14,14 @@ New Features:
New Features:
^^^^^^^^^^^^^
- Allow custom actor/critic network architectures using ``net_arch=dict(qf=[400, 300], pi=[64, 64])`` for off-policy algorithms (SAC, TD3, DDPG)
- Added Hindsight Experience Replay ``HER``. (@megan-klaiber)
- ``VecNormalize`` now supports ``gym.spaces.Dict`` observation spaces
- Support logging videos to Tensorboard (@SwamyDev)

Bug Fixes:
^^^^^^^^^^
- Fix GAE computation for on-policy algorithms (off-by-one for the last value) (thanks @Wovchena)
- Fixed potential issue when loading a different environment
- Fix ignoring the exclude parameter when recording logs using json, csv or log as logging format (@SwamyDev)
- Make ``make_vec_env`` support the ``env_kwargs`` argument when using an env ID str (@ManifoldFR)
- Fix model creation initializing CUDA even when `device="cpu"` is provided
Expand All @@ -37,6 +40,7 @@ Others:
Documentation:
^^^^^^^^^^^^^^
- Added first draft of migration guide
- Enabled doc for ``CnnPolicies``


Pre-Release 0.9.0 (2020-10-03)
@@ -68,6 +72,7 @@ New Features:

Bug Fixes:
^^^^^^^^^^
- Added ``unwrap_vec_wrapper()`` to ``common.vec_env`` to extract ``VecEnvWrapper`` if needed
- Fixed a bug where the environment was reset twice when using ``evaluate_policy``
- Fix logging of ``clip_fraction`` in PPO (@diditforlulz273)
- Fixed a bug where cuda support was wrongly checked when passing the GPU index, e.g., ``device="cuda:0"`` (@liorcohen5)
@@ -160,7 +165,6 @@ Documentation:
- Fixed typo in custom policy doc (@RaphaelWag)



Pre-Release 0.7.0 (2020-06-10)
------------------------------

@@ -461,4 +465,4 @@ And all the contributors:
@MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching
@flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3
@tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio
@diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88
@diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber
21 changes: 20 additions & 1 deletion docs/modules/a2c.rst
@@ -11,7 +11,7 @@ It uses multiple workers to avoid the use of a replay buffer.


.. warning::

If you find training unstable or want to match the performance of stable-baselines A2C, consider using the
``RMSpropTFLike`` optimizer from ``stable_baselines3.common.sb2_compat.rmsprop_tf_like``.
You can change the optimizer with ``A2C(policy_kwargs=dict(optimizer_class=RMSpropTFLike))``.
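
For example (a minimal sketch; ``CartPole-v1`` is used here only for illustration):

.. code-block:: python

    from stable_baselines3 import A2C
    from stable_baselines3.common.sb2_compat.rmsprop_tf_like import RMSpropTFLike

    # Use the TensorFlow-like RMSprop variant to mimic Stable-Baselines (SB2) A2C
    model = A2C(
        "MlpPolicy",
        "CartPole-v1",
        policy_kwargs=dict(optimizer_class=RMSpropTFLike),
        verbose=1,
    )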
@@ -79,3 +79,22 @@ Parameters
.. autoclass:: A2C
:members:
:inherited-members:


A2C Policies
-------------

.. autoclass:: MlpPolicy
:members:
:inherited-members:

.. autoclass:: stable_baselines3.common.policies.ActorCriticPolicy
:members:
:noindex:

.. autoclass:: CnnPolicy
:members:

.. autoclass:: stable_baselines3.common.policies.ActorCriticCnnPolicy
:members:
:noindex:
8 changes: 5 additions & 3 deletions docs/modules/ddpg.rst
@@ -98,7 +98,9 @@ DDPG Policies
:members:
:inherited-members:

.. autoclass:: stable_baselines3.td3.policies.TD3Policy
:members:
:noindex:

.. .. autoclass:: CnnPolicy
.. :members:
.. :inherited-members:
.. autoclass:: CnnPolicy
:members:
4 changes: 4 additions & 0 deletions docs/modules/dqn.rst
@@ -90,5 +90,9 @@ DQN Policies
:members:
:inherited-members:

.. autoclass:: stable_baselines3.dqn.policies.DQNPolicy
:members:
:noindex:

.. autoclass:: CnnPolicy
:members:
121 changes: 121 additions & 0 deletions docs/modules/her.rst
@@ -0,0 +1,121 @@
.. _her:

.. automodule:: stable_baselines3.her


HER
====

`Hindsight Experience Replay (HER) <https://arxiv.org/abs/1707.01495>`_

HER is an algorithm that works with off-policy methods (DQN, SAC, TD3 and DDPG for example).
HER uses the fact that even if a desired goal was not achieved, other goals may have been achieved during a rollout.
It creates "virtual" transitions by relabeling transitions (changing the desired goal) from past episodes.
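
As a conceptual sketch (this illustrates the ``future`` strategy only and is not the actual ``HerReplayBuffer`` implementation; the ``episode`` list of transition dicts is an assumed input format):

.. code-block:: python

    import random

    def relabel_episode(env, episode, n_sampled_goal=4):
        """Create "virtual" transitions by substituting the desired goal with a goal
        achieved later in the same episode ("future" strategy).

        ``episode`` is assumed to be a list of dicts with keys
        "obs", "next_obs", "action", "achieved_goal" (goal reached after the step)
        and "desired_goal"; ``env`` must be a ``gym.GoalEnv``.
        """
        virtual_transitions = []
        for t, transition in enumerate(episode[:-1]):
            for _ in range(n_sampled_goal):
                # Sample a goal that was actually achieved at a later step of the episode
                new_goal = random.choice(episode[t + 1:])["achieved_goal"]
                # Recompute the reward with respect to the substituted goal
                new_reward = env.compute_reward(transition["achieved_goal"], new_goal, {})
                virtual_transitions.append(
                    dict(transition, desired_goal=new_goal, reward=new_reward)
                )
        return virtual_transitions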



.. warning::

HER requires the environment to inherit from `gym.GoalEnv <https://github.com/openai/gym/blob/3394e245727c1ae6851b504a50ba77c73cd4c65b/gym/core.py#L160>`_


.. warning::

For performance reasons, the maximum number of steps per episode must be specified.
In most cases, it will be inferred if you specify ``max_episode_steps`` when registering the environment
or if you use a ``gym.wrappers.TimeLimit`` (and ``env.spec`` is not None).
Otherwise, you can directly pass ``max_episode_length`` to the model constructor.
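
As a rough sketch of this inference (not the exact helper used internally; ``resolve_max_episode_length`` is a hypothetical name):

.. code-block:: python

    def resolve_max_episode_length(env, max_episode_length=None):
        # An explicit value passed to the model constructor takes precedence
        if max_episode_length is not None:
            return max_episode_length
        # ``env.spec`` is set when the env was registered (e.g. with ``max_episode_steps``)
        # or created through ``gym.make``, which also adds a ``TimeLimit`` wrapper
        spec = getattr(env, "spec", None)
        if spec is not None and spec.max_episode_steps is not None:
            return spec.max_episode_steps
        raise ValueError(
            "The max episode length could not be inferred. "
            "Pass `max_episode_length` to the model constructor."
        )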


.. warning::

``HER`` supports the ``VecNormalize`` wrapper, but only when ``online_sampling=True``


Notes
-----

- Original paper: https://arxiv.org/abs/1707.01495
- OpenAI paper: `Plappert et al. (2018)`_
- OpenAI blog post: https://openai.com/blog/ingredients-for-robotics-research/


.. _Plappert et al. (2018): https://arxiv.org/abs/1802.09464

Can I use?
----------

Please refer to the documentation of the wrapped model (DQN, SAC, TD3 or DDPG) for that section.

Example
-------

.. code-block:: python

    from stable_baselines3 import HER, DDPG, DQN, SAC, TD3
    from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy
    from stable_baselines3.common.bit_flipping_env import BitFlippingEnv
    from stable_baselines3.common.vec_env import DummyVecEnv
    from stable_baselines3.common.vec_env.obs_dict_wrapper import ObsDictWrapper

    model_class = DQN  # works also with SAC, DDPG and TD3

    N_BITS = 15

    env = BitFlippingEnv(n_bits=N_BITS, continuous=model_class in [DDPG, SAC, TD3], max_steps=N_BITS)

    # Available strategies (cf paper): future, final, episode
    goal_selection_strategy = 'future'  # equivalent to GoalSelectionStrategy.FUTURE

    # If True the HER transitions will get sampled online
    online_sampling = True
    # Time limit for the episodes
    max_episode_length = N_BITS

    # Initialize the model
    model = HER('MlpPolicy', env, model_class, n_sampled_goal=4, goal_selection_strategy=goal_selection_strategy,
                online_sampling=online_sampling, verbose=1, max_episode_length=max_episode_length)

    # Train the model
    model.learn(1000)

    model.save("./her_bit_env")
    model = HER.load('./her_bit_env', env=env)

    obs = env.reset()
    for _ in range(100):
        action, _ = model.model.predict(obs, deterministic=True)
        obs, reward, done, _ = env.step(action)

        if done:
            obs = env.reset()

Parameters
----------

.. autoclass:: HER
:members:

Goal Selection Strategies
-------------------------

.. autoclass:: GoalSelectionStrategy
:members:
:inherited-members:
:undoc-members:


Obs Dict Wrapper
----------------

.. autoclass:: ObsDictWrapper
:members:
:inherited-members:
:undoc-members:


HER Replay Buffer
-----------------

.. autoclass:: HerReplayBuffer
:members:
:inherited-members:
19 changes: 19 additions & 0 deletions docs/modules/ppo.rst
@@ -80,3 +80,22 @@ Parameters
.. autoclass:: PPO
:members:
:inherited-members:


PPO Policies
-------------

.. autoclass:: MlpPolicy
:members:
:inherited-members:

.. autoclass:: stable_baselines3.common.policies.ActorCriticPolicy
:members:
:noindex:

.. autoclass:: CnnPolicy
:members:

.. autoclass:: stable_baselines3.common.policies.ActorCriticCnnPolicy
:members:
:noindex: