Implement HER #120
Conversation
docs/misc/changelog.rst
Outdated
@@ -25,6 +25,7 @@ New Features:
- Refactored opening paths for saving and loading to use strings, pathlib or io.BufferedIOBase (@PartiallyTyped)
- Added ``DDPG`` algorithm as a special case of ``TD3``.
- Introduced ``BaseModel`` abstract parent for ``BasePolicy``, which critics inherit from.
- Added Hindsight Experience Replay ``HER``. (@megan-klaiber)
you will also need to update the documentation: add HER to the module docs and to the examples (you can mostly copy-paste what was done in the SB2 documentation ;))
stable_baselines3/her/her.py
Outdated
use_sde: bool = False,
sde_sample_freq: int = -1,
use_sde_at_warmup: bool = False,
sde_support: bool = True,
sde_support should not be here
stable_baselines3/her/her.py
Outdated
    self.goal_strategy, GoalSelectionStrategy
), "Invalid goal selection strategy," "please use one of {}".format(list(GoalSelectionStrategy))

self.env = ObsWrapper(env)
you should wrap it only afterwards, and check whether the wrapper is needed or not
stable_baselines3/her/her.py
Outdated
assert isinstance(
    self.goal_strategy, GoalSelectionStrategy
), "Invalid goal selection strategy," "please use one of {}".format(list(GoalSelectionStrategy))
- ), "Invalid goal selection strategy," "please use one of {}".format(list(GoalSelectionStrategy))
+ ), f"Invalid goal selection strategy, please use one of {list(GoalSelectionStrategy)}"

we require Python 3.6+, so you can use f-strings
stable_baselines3/her/her.py
Outdated
# get arguments for the model initialization
model_signature = signature(model.__init__)
arguments = locals()
model_init_dict = {
you need that because HER inherits from the off-policy class? I would make it inherit from BaseAlgorithm then.
It seems that you are initializing two models (and two replay buffers, including one that you don't use)
or maybe keep OffPolicyAlgorithm as the base class but initialize an empty buffer, so you can re-use learn() from the base class
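A rough sketch of that second option, with toy classes (nothing here is the actual SB3 code; names and signatures are illustrative only):

```python
class OffPolicyAlgorithmLike:
    """Toy stand-in for an off-policy base class."""

    def __init__(self, buffer_size: int):
        # normally a large preallocated replay buffer
        self.replay_buffer = [None] * buffer_size
        self.num_timesteps = 0

    def learn(self, total_timesteps: int) -> None:
        # shared training loop that subclasses re-use unchanged
        self.num_timesteps += total_timesteps


class HERLike(OffPolicyAlgorithmLike):
    """Toy HER: keeps the base class but skips its buffer allocation."""

    def __init__(self):
        # empty buffer so learn() can still be inherited from the base
        super().__init__(buffer_size=0)
        # HER stores whole episodes in its own structure instead
        self.episode_buffer = []
```

The point of the sketch: nothing in `learn()` needs to be duplicated in the subclass when the base buffer is simply left empty.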
# buffer with episodes
self.buffer = []
# TODO just for typing reason , need another solution
?
# TODO just for typing reason , need another solution
self.observations = np.zeros((self.buffer_size, self.n_envs,) + self.obs_shape, dtype=observation_space.dtype)
self.goal_strategy = goal_strategy
self.her_ratio = 1 - (1.0 / (1 + her_ratio))
missing comment, looks weird compared to what is described in the docstring
]

# concatenate observation with (desired) goal
obs = [np.concatenate([o["observation"], o["desired_goal"]], axis=1) for o in observations]
please avoid one-character variables; you can use obs_ instead
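For illustration, with made-up shapes, the concatenation above turns each dict observation into one flat per-env array:

```python
import numpy as np

# two hypothetical dict observations: a 3-dim observation and a 2-dim goal,
# each with a leading env dimension of 1
observations = [
    {"observation": np.ones((1, 3)), "desired_goal": np.zeros((1, 2))},
    {"observation": np.ones((1, 3)), "desired_goal": np.zeros((1, 2))},
]

# concatenate observation with (desired) goal along the feature axis
obs_ = [np.concatenate([o["observation"], o["desired_goal"]], axis=1) for o in observations]
# each entry now has shape (1, 5): 3 observation features + 2 goal features
```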
her_episode_lengths = episode_lengths[her_idxs]

# get new goals with goal selection strategy
if self.goal_strategy == GoalSelectionStrategy.FINAL:
this logic cannot be shared with the "offline" version?
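The offline and online samplers could indeed share one goal-selection helper. A minimal sketch of such shared logic (hypothetical function name and signature, not the PR's actual code):

```python
from enum import Enum

import numpy as np


class GoalSelectionStrategy(Enum):
    FINAL = 0    # replay with the final achieved goal of the episode
    FUTURE = 1   # replay with a goal achieved after the transition
    EPISODE = 2  # replay with any goal achieved during the episode


def sample_goal_indices(strategy, transition_idxs, episode_length, rng):
    # For each transition, return the index of the step whose achieved
    # goal will be substituted as the new desired goal.
    if strategy == GoalSelectionStrategy.FINAL:
        return np.full_like(transition_idxs, episode_length - 1)
    if strategy == GoalSelectionStrategy.FUTURE:
        # any step strictly after the transition
        # (assumes no transition is the last step of the episode)
        return rng.integers(transition_idxs + 1, episode_length)
    if strategy == GoalSelectionStrategy.EPISODE:
        return rng.integers(0, episode_length, size=len(transition_idxs))
    raise ValueError(f"Invalid goal selection strategy {strategy}")
```

Both sampling paths could then call this one helper with their own transition indices.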
stable_baselines3/her/obs_wrapper.py
Outdated
def close(self):
    return self.venv.close()

def get_attr(self, attr_name, indices=None):
you don't need to re-implement those as they are already in the base wrapper class, no?
stable_baselines3/her/her.py
Outdated
self.model._last_original_obs, new_obs_, reward_ = observation, new_obs, reward

# add current transition to episode storage
self.episode_storage.append((self.model._last_original_obs, buffer_action, reward_, new_obs_, done))
Would be clearer to use a NamedTuple (cf. what is done for the replay buffer)
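As a sketch of that NamedTuple suggestion (field names here are illustrative, not the PR's actual ones):

```python
from typing import Any, NamedTuple

import numpy as np


class Transition(NamedTuple):
    # one (s, a, r, s', done) entry of the episode storage
    obs: Any
    action: np.ndarray
    reward: float
    next_obs: Any
    done: bool


# fields are then accessed by name instead of by tuple position:
transition = Transition(
    obs={"observation": np.zeros(3), "desired_goal": np.zeros(2)},
    action=np.array([0.5]),
    reward=-1.0,
    next_obs={"observation": np.ones(3), "desired_goal": np.zeros(2)},
    done=False,
)
```

Reading `transition.reward` is harder to get wrong than remembering that index 2 holds the reward.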
stable_baselines3/her/her.py
Outdated
self.model.actor.reset_noise()

# Select action randomly or according to policy
action, buffer_action = self.model._sample_action(learning_starts, action_noise)
after thinking more about it, I think we need to define __getattr__ to automatically retrieve the attribute from self.model if present. This would allow writing self._sample_action() directly.
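A minimal sketch of that __getattr__ delegation, with toy classes standing in for the real model and wrapper:

```python
class ToyModel:
    """Stand-in for the wrapped off-policy model."""

    def _sample_action(self, learning_starts, action_noise=None):
        return "action", "buffer_action"


class HERWrapper:
    def __init__(self, model):
        self.model = model

    def __getattr__(self, name):
        # __getattr__ is only invoked when normal attribute lookup fails,
        # so anything HERWrapper does not define itself is transparently
        # fetched from the wrapped model
        return getattr(self.model, name)


her = HERWrapper(ToyModel())
# self.model._sample_action(...) can now be written as self._sample_action(...)
action, buffer_action = her._sample_action(100)
```

One caveat worth noting: because `__getattr__` fires on any failed lookup, typos in attribute names surface as errors on the inner model rather than on the wrapper.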
new_rewards = np.array(rewards)
new_rewards[her_idxs] = [
    self.env.env_method("compute_reward", ag, her_new_goals, None)
    for ag, new_goal in zip(achieved_goals, her_new_goals)
Please avoid names without meaning: achieved_goal instead of ag ;)
self.buffer[idx] = episode
self.n_transitions_stored -= self.buffer[idx] - episode_length

if self.n_transitions_stored == self.size():
can be simplified
def get_current_size(self):
    return self.n_transitions_stored

def get_transitions_stored(self):
Maybe make self.n_transitions_stored private and create a getter using @property
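That suggestion, sketched on a toy class (attribute names chosen for illustration):

```python
class EpisodeBuffer:
    def __init__(self):
        self._n_transitions_stored = 0  # private attribute

    @property
    def n_transitions_stored(self) -> int:
        # read-only getter, replacing get_transitions_stored()
        return self._n_transitions_stored

    def add(self, n: int) -> None:
        self._n_transitions_stored += n


buffer = EpisodeBuffer()
buffer.add(5)
# accessed as an attribute, not a method call
assert buffer.n_transitions_stored == 5
```

Since no setter is defined, assigning to `buffer.n_transitions_stored` raises AttributeError, which protects the counter from accidental external writes.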
def get_transitions_stored(self):
    return self.n_transitions_stored

def clear_buffer(self):
you need to re-initialize the number of transitions stored too, no?
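i.e. something along these lines, assuming the counter attribute name used elsewhere in the buffer (a sketch, not the PR's code):

```python
class HerReplayBufferSketch:
    def __init__(self):
        self.buffer = []              # list of stored episodes
        self.n_transitions_stored = 0

    def add_episode(self, episode):
        self.buffer.append(episode)
        self.n_transitions_stored += len(episode)

    def clear_buffer(self):
        self.buffer = []
        # re-initialize the transition counter as well, otherwise the
        # buffer still reports a non-zero size after being cleared
        self.n_transitions_stored = 0
```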
def get_torch_variables(self) -> Tuple[List[str], List[str]]:
    return self.model.get_torch_variables()

def save(
is there no quicker way of doing that (without duplicating too much code)?
I would store the HER-specific arguments in the model (self.model.__dict__), see what is done in SB2.
# sample virtual transitions and store them in replay buffer
self._sample_her_transitions()
# clear storage for current episode
self._episode_storage.reset()
this is not properly defined in the HER replay buffer
LGTM =)
Tested on simulated envs and on a real robot, time to merge now.
depends on which Fetch env; please look at the RL zoo: https://github.com/DLR-RM/rl-baselines3-zoo.
* Added working her version, Online sampling is missing.
* Updated test_her.
* Added first version of online her sampling. Still problems with tensor dimensions.
* Reformat
* Fixed tests
* Added some comments.
* Updated changelog.
* Add missing init file
* Fixed some small bugs.
* Reduced arguments for HER, small changes.
* Added getattr. Fixed bug for online sampling.
* Updated save/load functions. Small changes.
* Added her to init.
* Updated save method.
* Updated her ratio.
* Move obs_wrapper
* Added DQN test.
* Fix potential bug
* Offline and online her share same sample_goal function.
* Changed lists into arrays.
* Updated her test.
* Fix online sampling
* Fixed action bug. Updated time limit for episodes.
* Updated convert_dict method to take keys as arguments.
* Renamed obs dict wrapper.
* Seed bit flipping env
* Remove get_episode_dict
* Add fast online sampling version
* Added documentation.
* Vectorized reward computation
* Vectorized goal sampling
* Update time limit for episodes in online her sampling.
* Fix max episode length inference
* Bug fix for Fetch envs
* Fix for HER + gSDE
* Reformat (new black version)
* Added info dict to compute new reward. Check her_replay_buffer again.
* Fix info buffer
* Updated done flag.
* Fixes for gSDE
* Offline her version uses now HerReplayBuffer as episode storage.
* Fix num_timesteps computation
* Fix get torch params
* Vectorized version for offline sampling.
* Modified offline her sampling to use sample method of her_replay_buffer
* Updated HER tests.
* Updated documentation
* Cleanup docstrings
* Updated to review comments
* Fix pytype
* Update according to review comments.
* Removed random goal strategy. Updated sample transitions.
* Updated migration. Removed time signal removal.
* Update doc
* Fix potential load issue
* Add VecNormalize support for dict obs
* Updated saving/loading replay buffer for HER.
* Fix test memory usage
* Fixed save/load replay buffer.
* Fixed save/load replay buffer
* Fixed transition index after loading replay buffer in online sampling
* Better error handling
* Add tests for get_time_limit
* More tests for VecNormalize with dict obs
* Update doc
* Improve HER description
* Add test for sde support
* Add comments
* Add comments
* Remove check that was always valid
* Fix for terminal observation
* Updated buffer size in offline version and reset of HER buffer
* Reformat
* Update doc
* Remove np.empty + add doc
* Fix loading
* Updated loading replay buffer
* Separate online and offline sampling + bug fixes
* Update tensorboard log name
* Version bump
* Bug fix for special case

Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
It looks like the code added in this PR breaks normal envs with a dict observation space, since it assumes that whenever the observation space is a dictionary the user wants HER.
Specifically, this code:
doesn't verify at all that HER is what I want, and assumes the dict has a specific purpose, breaking training for any env with a dict obs space
Please read #216
@araffin thanks for the info and links. Might be good to throw a more readable error in that case. For now this works fine for me:

import gym
from stable_baselines3.common.vec_env import VecEnvWrapper

class FlattenVecWrapper(VecEnvWrapper):
    def __init__(self, env):
        super().__init__(env)
        self.observation_space = gym.spaces.flatten_space(self.venv.observation_space)

    def reset(self, **kwargs):
        observation = self.venv.reset(**kwargs)
        return self.observation(observation)

    def step_wait(self):
        observation, reward, done, info = self.venv.step_wait()
        return self.observation(observation), reward, done, info

    def observation(self, observation):
        return [gym.spaces.flatten(self.venv.observation_space, o) for o in observation]
Description

HER inherits from OffPolicyAlgorithm and takes the model as an argument. It also implements its own collect_rollout function.

HER can operate in two modes for now, online_sampling being True or False. If True, HER samples are added while sampling; otherwise they are added at the end of an episode. If online sampling is used, a custom HerReplayBuffer will be used, which stores the transitions episode-wise.

Motivation and Context

closes #8
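As a toy illustration of the two modes described above (the helper names and dict-based transitions are hypothetical, not the PR's API):

```python
def relabel_with_final_goal(episode):
    # hypothetical helper: create virtual transitions whose desired goal
    # is the goal actually achieved at the end of the episode
    final_goal = episode[-1]["achieved_goal"]
    return [{**transition, "desired_goal": final_goal} for transition in episode]


def store_episode(episode, replay_buffer, her_episode_buffer, online_sampling):
    if online_sampling:
        # online: keep the raw episode in a HerReplayBuffer-like structure;
        # virtual goals are substituted later, when batches are sampled
        her_episode_buffer.append(episode)
    else:
        # offline: real and virtual transitions are both created now, at the
        # end of the episode, and pushed to the regular replay buffer
        replay_buffer.extend(episode)
        replay_buffer.extend(relabel_with_final_goal(episode))
```

The trade-off sketched here: offline sampling fixes the virtual goals once at episode end, while online sampling defers relabeling, so each batch can draw fresh goals.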
Types of changes

Checklist:
- make format (required)
- make check-codestyle and make lint (required)
- make pytest and make type both pass. (required)

Missing:
Note: we are using a maximum length of 127 characters per line
Results
Results on https://github.com/eleurent/highway-env
her_parking.pdf
her.pdf