forked from DLR-RM/stable-baselines3
* Added working her version, Online sampling is missing.
* Updated test_her.
* Added first version of online her sampling. Still problems with tensor dimensions.
* Reformat
* Fixed tests
* Added some comments.
* Updated changelog.
* Add missing init file
* Fixed some small bugs.
* Reduced arguments for HER, small changes.
* Added getattr. Fixed bug for online sampling.
* Updated save/load functions. Small changes.
* Added her to init.
* Updated save method.
* Updated her ratio.
* Move obs_wrapper
* Added DQN test.
* Fix potential bug
* Offline and online her share same sample_goal function.
* Changed lists into arrays.
* Updated her test.
* Fix online sampling
* Fixed action bug. Updated time limit for episodes.
* Updated convert_dict method to take keys as arguments.
* Renamed obs dict wrapper.
* Seed bit flipping env
* Remove get_episode_dict
* Add fast online sampling version
* Added documentation.
* Vectorized reward computation
* Vectorized goal sampling
* Update time limit for episodes in online her sampling.
* Fix max episode length inference
* Bug fix for Fetch envs
* Fix for HER + gSDE
* Reformat (new black version)
* Added info dict to compute new reward. Check her_replay_buffer again.
* Fix info buffer
* Updated done flag.
* Fixes for gSDE
* Offline her version uses now HerReplayBuffer as episode storage.
* Fix num_timesteps computation
* Fix get torch params
* Vectorized version for offline sampling.
* Modified offline her sampling to use sample method of her_replay_buffer
* Updated HER tests.
* Updated documentation
* Cleanup docstrings
* Updated to review comments
* Fix pytype
* Update according to review comments.
* Removed random goal strategy. Updated sample transitions.
* Updated migration. Removed time signal removal.
* Update doc
* Fix potential load issue
* Add VecNormalize support for dict obs
* Updated saving/loading replay buffer for HER.
* Fix test memory usage
* Fixed save/load replay buffer.
* Fixed save/load replay buffer
* Fixed transition index after loading replay buffer in online sampling
* Better error handling
* Add tests for get_time_limit
* More tests for VecNormalize with dict obs
* Update doc
* Improve HER description
* Add test for sde support
* Add comments
* Add comments
* Remove check that was always valid
* Fix for terminal observation
* Updated buffer size in offline version and reset of HER buffer
* Reformat
* Update doc
* Remove np.empty + add doc
* Fix loading
* Updated loading replay buffer
* Separate online and offline sampling + bug fixes
* Update tensorboard log name
* Version bump
* Bug fix for special case

Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
Showing 34 changed files with 1,899 additions and 102 deletions.
@@ -57,6 +57,7 @@ Main Features

modules/a2c
modules/ddpg
modules/dqn
modules/her
modules/ppo
modules/sac
modules/td3

@@ -0,0 +1,121 @@
.. _her:

.. automodule:: stable_baselines3.her


HER
====

`Hindsight Experience Replay (HER) <https://arxiv.org/abs/1707.01495>`_

HER is an algorithm that works with off-policy methods (DQN, SAC, TD3 and DDPG for example).
HER uses the fact that even if a desired goal was not achieved, other goals may have been achieved during a rollout.
It creates "virtual" transitions by relabeling transitions (changing the desired goal) from past episodes.

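The relabeling idea can be illustrated with a minimal, library-independent sketch. It is not the ``HerReplayBuffer`` implementation: the function name ``relabel_future``, the toy episode, and the sparse 0 / -1 reward are assumptions made for illustration only. Each transition is paired with goals that were actually achieved at the same step or later in the episode (the ``future`` strategy from the paper), and the reward is recomputed for the substituted goal.

```python
import random


def relabel_future(achieved_goals, n_sampled_goal=4, seed=0):
    """Build virtual (step, new_goal, reward) tuples for one episode.

    achieved_goals[t] is the goal achieved after the transition at step t.
    This is an illustrative helper, not part of stable_baselines3.
    """
    rng = random.Random(seed)
    episode_length = len(achieved_goals)
    virtual = []
    for t in range(episode_length):
        for _ in range(n_sampled_goal):
            # "future" strategy: reuse a goal achieved at step t or later
            future_t = rng.randrange(t, episode_length)
            new_goal = achieved_goals[future_t]
            # Assumed sparse reward: 0 if the new goal is reached, -1 otherwise
            reward = 0.0 if achieved_goals[t] == new_goal else -1.0
            virtual.append((t, new_goal, reward))
    return virtual


# Toy 3-step episode in a 2-bit goal space (made-up data)
achieved = [(1, 0), (1, 1), (0, 1)]
virtual_transitions = relabel_future(achieved, n_sampled_goal=2)
```

With ``n_sampled_goal=2`` every real transition yields two virtual ones, so the share of relabeled data grows with that parameter.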
.. warning::

    HER requires the environment to inherit from `gym.GoalEnv <https://github.com/openai/gym/blob/3394e245727c1ae6851b504a50ba77c73cd4c65b/gym/core.py#L160>`_


.. warning::

    For performance reasons, the maximum number of steps per episode must be specified.
    In most cases, it will be inferred if you specify ``max_episode_steps`` when registering the environment
    or if you use a ``gym.wrappers.TimeLimit`` (and ``env.spec`` is not None).
    Otherwise, you can directly pass ``max_episode_length`` to the model constructor.

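The lookup order described in the warning above might be sketched as follows. This is a hedged illustration, not the library's own helper: ``infer_max_episode_length`` is a hypothetical name, the environments are stand-in objects rather than real Gym environments, and giving an explicit value precedence over the inferred one is an assumption.

```python
from types import SimpleNamespace


def infer_max_episode_length(env, max_episode_length=None):
    """Return the explicit limit if given, else fall back to env.spec."""
    if max_episode_length is not None:
        return max_episode_length
    # env.spec may be missing or None for environments that were not registered
    spec = getattr(env, "spec", None)
    limit = getattr(spec, "max_episode_steps", None)
    if limit is None:
        raise ValueError(
            "The maximum episode length could not be inferred; "
            "pass max_episode_length explicitly."
        )
    return limit


# Stand-ins for a registered environment and one without a spec
registered_env = SimpleNamespace(spec=SimpleNamespace(max_episode_steps=50))
bare_env = SimpleNamespace(spec=None)
```

Raising an error when nothing can be inferred makes the missing time limit visible at construction time instead of during training.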
.. warning::

    ``HER`` supports the ``VecNormalize`` wrapper, but only when ``online_sampling=True``.


Notes
-----

- Original paper: https://arxiv.org/abs/1707.01495
- OpenAI paper: `Plappert et al. (2018)`_
- OpenAI blog post: https://openai.com/blog/ingredients-for-robotics-research/


.. _Plappert et al. (2018): https://arxiv.org/abs/1802.09464

Can I use?
----------

Please refer to the documentation of the underlying model (DQN, SAC, TD3 or DDPG) for that section.

Example
-------

.. code-block:: python

    from stable_baselines3 import HER, DDPG, DQN, SAC, TD3
    from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy
    from stable_baselines3.common.bit_flipping_env import BitFlippingEnv
    from stable_baselines3.common.vec_env import DummyVecEnv
    from stable_baselines3.common.vec_env.obs_dict_wrapper import ObsDictWrapper

    model_class = DQN  # works also with SAC, DDPG and TD3

    N_BITS = 15

    env = BitFlippingEnv(n_bits=N_BITS, continuous=model_class in [DDPG, SAC, TD3], max_steps=N_BITS)

    # Available strategies (cf paper): future, final, episode
    goal_selection_strategy = 'future'  # equivalent to GoalSelectionStrategy.FUTURE

    # If True the HER transitions will get sampled online
    online_sampling = True
    # Time limit for the episodes
    max_episode_length = N_BITS

    # Initialize the model
    model = HER(
        'MlpPolicy',
        env,
        model_class,
        n_sampled_goal=4,
        goal_selection_strategy=goal_selection_strategy,
        online_sampling=online_sampling,
        verbose=1,
        max_episode_length=max_episode_length,
    )
    # Train the model
    model.learn(1000)

    model.save("./her_bit_env")
    model = HER.load('./her_bit_env', env=env)

    obs = env.reset()
    for _ in range(100):
        action, _ = model.model.predict(obs, deterministic=True)
        obs, reward, done, _ = env.step(action)

        if done:
            obs = env.reset()

Parameters
----------

.. autoclass:: HER
  :members:

Goal Selection Strategies
-------------------------

.. autoclass:: GoalSelectionStrategy
  :members:
  :inherited-members:
  :undoc-members:

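The three strategies named in the example (``future``, ``final``, ``episode``) differ only in which step of the episode supplies the replacement goal. The sketch below shows that choice under stated assumptions: ``Strategy`` is a hypothetical stand-in for ``GoalSelectionStrategy``, ``select_goal_index`` is an illustrative helper, and goals are indexed so that entry ``t`` is the goal achieved after the transition at step ``t``.

```python
import random
from enum import Enum


class Strategy(Enum):
    """Hypothetical stand-in; not the library's GoalSelectionStrategy."""

    FUTURE = "future"
    FINAL = "final"
    EPISODE = "episode"


def select_goal_index(strategy, t, episode_length, rng):
    """Pick the step whose achieved goal replaces the desired goal at step t."""
    if strategy is Strategy.FUTURE:
        # A goal achieved at step t or later in the same episode
        return rng.randrange(t, episode_length)
    if strategy is Strategy.FINAL:
        # The goal achieved at the last step of the episode
        return episode_length - 1
    if strategy is Strategy.EPISODE:
        # A goal achieved at any step of the episode
        return rng.randrange(episode_length)
    raise ValueError(f"Unknown strategy: {strategy}")


rng = random.Random(0)
final_index = select_goal_index(Strategy.FINAL, t=1, episode_length=5, rng=rng)
future_index = select_goal_index(Strategy.FUTURE, t=3, episode_length=5, rng=rng)
episode_index = select_goal_index(Strategy.EPISODE, t=3, episode_length=5, rng=rng)
```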
Obs Dict Wrapper
----------------

.. autoclass:: ObsDictWrapper
  :members:
  :inherited-members:
  :undoc-members:

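Goal environments return dictionary observations, whereas the wrapped off-policy models expect a flat input. One plausible way to bridge the two is to concatenate the observation with the desired goal, as sketched below. The helper ``flatten_goal_obs`` and its default key names are assumptions for illustration and should not be read as the actual ``ObsDictWrapper`` API.

```python
def flatten_goal_obs(obs, observation_key="observation", goal_key="desired_goal"):
    """Concatenate the observation and the desired goal into one flat list.

    Illustrative helper only; the achieved goal is left out because the
    policy conditions on where it should go, not on where it already is.
    """
    return list(obs[observation_key]) + list(obs[goal_key])


# Toy dictionary observation from a 2-bit goal environment (made-up data)
dict_obs = {
    "observation": [1, 0],
    "achieved_goal": [1, 0],
    "desired_goal": [1, 1],
}
flat_obs = flatten_goal_obs(dict_obs)
```

Passing the keys as arguments keeps the helper usable with environments that name their dictionary entries differently.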
HER Replay Buffer
-----------------

.. autoclass:: HerReplayBuffer
  :members:
  :inherited-members: