diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 16a7737e5..d2a3d6060 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.0.0a10 (WIP) +Release 2.0.0a11 (WIP) -------------------------- **Gymnasium support** @@ -39,6 +39,8 @@ Bug Fixes: - Fixed ``VecExtractDictObs`` does not handle terminal observation (@WeberSamuel) - Set NumPy version to ``>=1.20`` due to use of ``numpy.typing`` (@troiganto) - Fixed loading DQN changes ``target_update_interval`` (@tobirohrer) +- Fixed env checker to properly reset the env before calling ``step()`` when checking + for ``Inf`` and ``NaN`` (@lutogniew) Deprecations: ^^^^^^^^^^^^^ @@ -1346,3 +1348,4 @@ And all the contributors: @Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto +@lutogniew diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index 058710df9..b6ce490df 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -110,6 +110,7 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act def _check_nan(env: gym.Env) -> None: """Check for Inf and NaN using the VecWrapper.""" vec_env = VecCheckNan(DummyVecEnv([lambda: env])) + vec_env.reset() for _ in range(10): action = np.array([env.action_space.sample()]) _, _, _, _ = vec_env.step(action) diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 7385c4c8b..d70b1bb71 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.0.0a10 +2.0.0a11 diff --git a/tests/test_env_checker.py b/tests/test_env_checker.py index e855e2137..c0a5e0610 100644 --- a/tests/test_env_checker.py +++ b/tests/test_env_checker.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional +from typing import Any, Dict, Optional, Tuple import gymnasium as gym import numpy as np @@ -112,3 +112,47 @@ def step(self, action): test_env = TestEnv() with pytest.raises(AssertionError, match=error_message): check_env(env=test_env) + + +class LimitedStepsTestEnv(gym.Env): + action_space = spaces.Discrete(n=2) + observation_space = spaces.Discrete(n=2) + + def __init__(self, steps_before_termination: int = 1): + super().__init__() + + assert steps_before_termination >= 1 + self._steps_before_termination = steps_before_termination + + self._steps_called = 0 + self._terminated = False + + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[int, Dict]: + super().reset(seed=seed) + + self._steps_called = 0 + self._terminated = False + + return 0, {} + + def step(self, action: np.ndarray) -> Tuple[int, float, bool, bool, Dict[str, Any]]: + self._steps_called += 1 + + assert not self._terminated + + observation = 0 + reward = 0.0 + self._terminated = self._steps_called >= self._steps_before_termination + truncated = False + + return observation, reward, self._terminated, truncated, {} + + def render(self) -> None: + pass + + +def test_check_env_single_step_env(): + test_env = LimitedStepsTestEnv(steps_before_termination=1) + + # This should not throw + check_env(env=test_env, warn=True)