Skip to content

Commit

Permalink
Fix env checker single-step-env edge case (#1524)
Browse files Browse the repository at this point in the history
* Fix env checker single-step-env edge case

Before this change, env checker failed to `reset()` the tested
environment before calling `step()` when checking for `Inf` / `NaN`.
This could cause environments which happened to have only one `step()`
available before the episode was terminated to fail.

This is now fixed.

* Code review fixes #1

As suggested by Antonin Raffin <antonin.raffin@ensta.org>.
  • Loading branch information
lutogniew authored May 25, 2023
1 parent 1bfb55d commit e763163
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 3 deletions.
5 changes: 4 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Changelog
==========

Release 2.0.0a10 (WIP)
Release 2.0.0a11 (WIP)
--------------------------

**Gymnasium support**
Expand Down Expand Up @@ -39,6 +39,8 @@ Bug Fixes:
- Fixed ``VecExtractDictObs`` does not handle terminal observation (@WeberSamuel)
- Set NumPy version to ``>=1.20`` due to use of ``numpy.typing`` (@troiganto)
- Fixed loading DQN changes ``target_update_interval`` (@tobirohrer)
- Fixed env checker to properly reset the env before calling ``step()`` when checking
for ``Inf`` and ``NaN`` (@lutogniew)

Deprecations:
^^^^^^^^^^^^^
Expand Down Expand Up @@ -1346,3 +1348,4 @@ And all the contributors:
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
@lutogniew
1 change: 1 addition & 0 deletions stable_baselines3/common/env_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act
def _check_nan(env: gym.Env) -> None:
"""Check for Inf and NaN using the VecWrapper."""
vec_env = VecCheckNan(DummyVecEnv([lambda: env]))
vec_env.reset()
for _ in range(10):
action = np.array([env.action_space.sample()])
_, _, _, _ = vec_env.step(action)
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.0.0a10
2.0.0a11
46 changes: 45 additions & 1 deletion tests/test_env_checker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Optional
from typing import Any, Dict, Optional, Tuple

import gymnasium as gym
import numpy as np
Expand Down Expand Up @@ -112,3 +112,47 @@ def step(self, action):
test_env = TestEnv()
with pytest.raises(AssertionError, match=error_message):
check_env(env=test_env)


class LimitedStepsTestEnv(gym.Env):
action_space = spaces.Discrete(n=2)
observation_space = spaces.Discrete(n=2)

def __init__(self, steps_before_termination: int = 1):
super().__init__()

assert steps_before_termination >= 1
self._steps_before_termination = steps_before_termination

self._steps_called = 0
self._terminated = False

def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[int, Dict]:
super().reset(seed=seed)

self._steps_called = 0
self._terminated = False

return 0, {}

def step(self, action: np.ndarray) -> Tuple[int, float, bool, bool, Dict[str, Any]]:
self._steps_called += 1

assert not self._terminated

observation = 0
reward = 0.0
self._terminated = self._steps_called >= self._steps_before_termination
truncated = False

return observation, reward, self._terminated, truncated, {}

def render(self) -> None:
pass


def test_check_env_single_step_env():
test_env = LimitedStepsTestEnv(steps_before_termination=1)

# This should not throw
check_env(env=test_env, warn=True)

0 comments on commit e763163

Please sign in to comment.