Skip to content

Commit

Permalink
Fix reset info being lost in vector environments (#3111)
Browse files Browse the repository at this point in the history
* Fix reset info

* Added test for checking vector info
  • Loading branch information
pseudo-rnd-thoughts committed Oct 4, 2022
1 parent 21e6e27 commit 1486d33
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 6 deletions.
7 changes: 4 additions & 3 deletions gym/vector/async_vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,9 +566,10 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
info,
) = env.step(data)
if terminated or truncated:
old_observation = observation
old_observation, old_info = observation, info
observation, info = env.reset()
info["final_observation"] = old_observation
info["final_info"] = old_info
pipe.send(((observation, reward, terminated, truncated, info), True))
elif command == "seed":
env.seed(data)
Expand Down Expand Up @@ -636,10 +637,10 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
info,
) = env.step(data)
if terminated or truncated:
old_observation = observation
old_observation, old_info = observation, info
observation, info = env.reset()
info["final_observation"] = old_observation

info["final_info"] = old_info
write_to_shared_memory(
observation_space, index, observation, shared_memory
)
Expand Down
3 changes: 2 additions & 1 deletion gym/vector/sync_vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,10 @@ def step_wait(self):
) = env.step(action)

if self._terminateds[i] or self._truncateds[i]:
old_observation = observation
old_observation, old_info = observation, info
observation, info = env.reset()
info["final_observation"] = old_observation
info["final_info"] = old_info
observations.append(observation)
infos = self._add_info(infos, info, i)
self.observations = concatenate(
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"box2d": ["box2d-py==2.3.5", "pygame==2.1.0", "swig==4.*"],
"classic_control": ["pygame==2.1.0"],
"mujoco_py": ["mujoco_py<2.2,>=2.1"],
"mujoco": ["mujoco==2.2.0", "imageio>=2.14.1"],
"mujoco": ["mujoco==2.2", "imageio>=2.14.1"],
"toy_text": ["pygame==2.1.0"],
"other": ["lz4>=3.1.0", "opencv-python>=3.0", "matplotlib>=3.0", "moviepy>=1.0.0"],
}
Expand Down
67 changes: 66 additions & 1 deletion tests/vector/test_vector_env.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from functools import partial

import numpy as np
import pytest

from gym.spaces import Tuple
from gym.spaces import Discrete, Tuple
from gym.vector.async_vector_env import AsyncVectorEnv
from gym.vector.sync_vector_env import SyncVectorEnv
from gym.vector.vector_env import VectorEnv
from tests.testing_env import GenericTestEnv
from tests.vector.utils import CustomSpace, make_env


Expand Down Expand Up @@ -58,3 +61,65 @@ def test_custom_space_vector_env():

assert isinstance(env.single_action_space, CustomSpace)
assert isinstance(env.action_space, Tuple)


@pytest.mark.parametrize(
    "vectoriser",
    (
        SyncVectorEnv,
        partial(AsyncVectorEnv, shared_memory=True),
        partial(AsyncVectorEnv, shared_memory=False),
    ),
    ids=["Sync", "Async with shared memory", "Async without shared memory"],
)
def test_final_obs_info(vectoriser):
    """Check that an auto-reset step exposes the last observation and info.

    A single sub-environment is stepped with actions 1 and 2 (non-terminal)
    and then with action 3, which terminates the episode. On that last step
    the returned observation and info must come from the reset, while the
    terminal step's own observation and info must be available under the
    ``"final_observation"`` and ``"final_info"`` keys.
    """

    def _reset(self, seed=None, options=None):
        # Reset always yields observation 0 and a recognisable info marker.
        return 0, {"reset": True}

    def _step(self, action):
        # Actions below 3 are echoed back as the observation; action 3 ends
        # the episode with observation 0. Reward is 0 and the episode is
        # never truncated. The info dict records the action taken.
        next_obs = action if action < 3 else 0
        return next_obs, 0, action >= 3, False, {"action": action}

    def env_factory():
        return GenericTestEnv(
            action_space=Discrete(4),
            observation_space=Discrete(4),
            reset_fn=_reset,
            step_fn=_step,
        )

    vec_env = vectoriser([env_factory])

    obs, info = vec_env.reset()
    assert obs == np.array([0])
    assert info == {"reset": np.array([True]), "_reset": np.array([True])}

    # Non-terminal steps: observation mirrors the action, info carries it.
    for act in (1, 2):
        obs, _, terminated, _, info = vec_env.step([act])
        assert obs == np.array([act])
        assert terminated == np.array([False])
        assert info == {"action": np.array([act]), "_action": np.array([True])}

    # Terminal step: the returned obs/info are those of the automatic reset.
    obs, _, terminated, _, info = vec_env.step([3])
    assert obs == np.array([0])
    assert terminated == np.array([True])
    assert info["reset"] == np.array([True])

    # The terminal step's own observation and info must not be lost.
    assert "final_observation" in info
    assert "final_info" in info
    assert info["final_observation"] == np.array([0])
    assert info["final_info"] == {"action": 3}

0 comments on commit 1486d33

Please sign in to comment.