diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 97b8735db..f42f75872 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Pre-Release 0.7.0a0 (WIP) +Pre-Release 0.7.0a1 (WIP) ------------------------------ Breaking Changes: @@ -18,6 +18,7 @@ Bug Fixes: ^^^^^^^^^^ - Fixed ``render()`` method for ``VecEnvs`` - Fixed ``seed()``` method for ``SubprocVecEnv`` +- Fixed loading on GPU for testing when using gSDE and ``deterministic=False`` Deprecations: ^^^^^^^^^^^^^ diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index e64243cce..c35de3d1d 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -375,6 +375,10 @@ def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs): for name in tensors: recursive_setattr(model, name, tensors[name]) + # Sample gSDE exploration matrix, so it uses the right device + # see issue #44 + if model.use_sde: + model.policy.reset_noise() return model @staticmethod diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py index 5f7526435..78482f44a 100644 --- a/stable_baselines3/sac/policies.py +++ b/stable_baselines3/sac/policies.py @@ -128,8 +128,8 @@ def get_std(self) -> th.Tensor: :return: (th.Tensor) """ - assert isinstance(self.action_dist, StateDependentNoiseDistribution), \ - 'get_std() is only available when using gSDE' + msg = 'get_std() is only available when using gSDE' + assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg return self.action_dist.get_std(self.log_std) def reset_noise(self, batch_size: int = 1) -> None: @@ -138,8 +138,8 @@ def reset_noise(self, batch_size: int = 1) -> None: :param batch_size: (int) """ - assert isinstance(self.action_dist, StateDependentNoiseDistribution), \ - 'reset_noise() is only available when using gSDE' + msg = 'reset_noise() is only available when using gSDE' + assert isinstance(self.action_dist, StateDependentNoiseDistribution), msg self.action_dist.sample_weights(self.log_std, batch_size=batch_size) def get_action_dist_params(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, Dict[str, th.Tensor]]: @@ -354,6 +354,14 @@ def _get_data(self) -> Dict[str, Any]: )) return data + def reset_noise(self, batch_size: int = 1) -> None: + """ + Sample new weights for the exploration matrix, when using gSDE. + + :param batch_size: (int) + """ + self.actor.reset_noise(batch_size=batch_size) + def make_actor(self) -> Actor: return Actor(**self.actor_kwargs).to(self.device) diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 93acf06b0..cde2c3fbb 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -0.7.0a0 +0.7.0a1