Skip to content

Commit

Permalink
Use the unwrapped environment to fix gymnasium warnings in PyTorch
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Aug 2, 2024
1 parent 82ea390 commit 1c6b53a
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 21 deletions.
37 changes: 17 additions & 20 deletions skrl/envs/wrappers/torch/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Mapping, Sequence, Tuple
from typing import Any, Mapping, Sequence, Tuple, Union

import gym

Expand All @@ -13,20 +13,16 @@ def __init__(self, env: Any) -> None:
:type env: Any supported RL environment
"""
self._env = env
try:
self._unwrapped = self._env.unwrapped
except:
self._unwrapped = env

# device (faster than @property)
if hasattr(self._env, "device"):
self.device = torch.device(self._env.device)
if hasattr(self._unwrapped, "device"):
self.device = torch.device(self._unwrapped.device)
else:
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# spaces
try:
self._action_space = self._env.single_action_space
self._observation_space = self._env.single_observation_space
except AttributeError:
self._action_space = self._env.action_space
self._observation_space = self._env.observation_space
self._state_space = self._env.state_space if hasattr(self._env, "state_space") else self._observation_space

def __getattr__(self, key: str) -> Any:
"""Get an attribute from the wrapped environment
Expand All @@ -41,7 +37,9 @@ def __getattr__(self, key: str) -> Any:
"""
if hasattr(self._env, key):
return getattr(self._env, key)
raise AttributeError(f"Wrapped environment ({self._env.__class__.__name__}) does not have attribute '{key}'")
if hasattr(self._unwrapped, key):
return getattr(self._unwrapped, key)
raise AttributeError(f"Wrapped environment ({self._unwrapped.__class__.__name__}) does not have attribute '{key}'")

def reset(self) -> Tuple[torch.Tensor, Any]:
"""Reset the environment
Expand Down Expand Up @@ -82,36 +80,35 @@ def num_envs(self) -> int:
If the wrapped environment does not have the ``num_envs`` property, it will be set to 1
"""
return self._env.num_envs if hasattr(self._env, "num_envs") else 1
return self._unwrapped.num_envs if hasattr(self._unwrapped, "num_envs") else 1

@property
def num_agents(self) -> int:
"""Number of agents
If the wrapped environment does not have the ``num_agents`` property, it will be set to 1
"""
return self._env.num_agents if hasattr(self._env, "num_agents") else 1
return self._unwrapped.num_agents if hasattr(self._unwrapped, "num_agents") else 1

@property
def state_space(self) -> gym.Space:
def state_space(self) -> Union[gym.Space, None]:
"""State space
If the wrapped environment does not have the ``state_space`` property,
the value of the ``observation_space`` property will be used
If the wrapped environment does not have the ``state_space`` property, ``None`` will be returned
"""
return self._state_space
return self._unwrapped.state_space if hasattr(self._unwrapped, "state_space") else None

@property
def observation_space(self) -> gym.Space:
"""Observation space
"""
return self._observation_space
return self._unwrapped.observation_space

@property
def action_space(self) -> gym.Space:
"""Action space
"""
return self._action_space
return self._unwrapped.action_space


class MultiAgentEnvWrapper(object):
Expand Down
20 changes: 19 additions & 1 deletion skrl/envs/wrappers/torch/isaaclab_envs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Any, Tuple

import gymnasium

import torch

from skrl.envs.wrappers.torch.base import Wrapper
Expand All @@ -17,7 +19,23 @@ def __init__(self, env: Any) -> None:
self._reset_once = True
self._obs_dict = None

self._observation_space = self._observation_space["policy"]
@property
def observation_space(self) -> gymnasium.Space:
"""Observation space
"""
try:
return self._unwrapped.single_observation_space["policy"]
except:
return self._unwrapped.observation_space["policy"]

@property
def action_space(self) -> gymnasium.Space:
"""Action space
"""
try:
return self._unwrapped.single_action_space
except:
return self._unwrapped.action_space

def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]:
"""Perform a step in the environment
Expand Down

0 comments on commit 1c6b53a

Please sign in to comment.