Skip to content

Commit

Permalink
Update gymnasium wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Aug 3, 2024
1 parent ae66fe7 commit 7e4d1cd
Showing 1 changed file with 10 additions and 23 deletions.
33 changes: 10 additions & 23 deletions skrl/envs/wrappers/torch/gymnasium_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,14 @@ def __init__(self, env: Any) -> None:

self._vectorized = False
try:
if isinstance(env, gymnasium.vector.SyncVectorEnv) or isinstance(env, gymnasium.vector.AsyncVectorEnv):
if isinstance(env, gymnasium.vector.VectorEnv) or isinstance(env, gymnasium.experimental.vector.VectorEnv):
self._vectorized = True
self._reset_once = True
self._obs_tensor = None
self._info_dict = None
self._observation = None
self._info = None
except Exception as e:
logger.warning(f"Failed to check for a vectorized environment: {e}")

@property
def state_space(self) -> gymnasium.Space:
"""State space
An alias for the ``observation_space`` property
"""
if self._vectorized:
return self._env.single_observation_space
return self._env.observation_space

@property
def observation_space(self) -> gymnasium.Space:
"""Observation space
Expand Down Expand Up @@ -133,8 +123,8 @@ def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch

# save observation and info for vectorized envs
if self._vectorized:
self._obs_tensor = observation
self._info_dict = info
self._observation = observation
self._info = info

return observation, reward, terminated, truncated, info

Expand All @@ -147,19 +137,16 @@ def reset(self) -> Tuple[torch.Tensor, Any]:
# handle vectorized envs
if self._vectorized:
if not self._reset_once:
return self._obs_tensor, self._info_dict
return self._observation, self._info
self._reset_once = False

# reset the env/envs
observation, info = self._env.reset()
return self._observation_to_tensor(observation), info

def render(self, *args, **kwargs) -> None:
def render(self, *args, **kwargs) -> Any:
"""Render the environment
"""
self._env.render(*args, **kwargs)

def close(self) -> None:
"""Close the environment
"""
self._env.close()
if self._vectorized:
return self._env.call("render", *args, **kwargs)
return self._env.render(*args, **kwargs)

0 comments on commit 7e4d1cd

Please sign in to comment.