Skip to content

Commit

Permalink
Update Isaac Gym preview wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Aug 5, 2024
1 parent d2edbcf commit 9c2b8a8
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 14 deletions.
2 changes: 1 addition & 1 deletion skrl/envs/wrappers/torch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def _in(values, container):
return "isaaclab"
elif _in("omni.isaac.gym.vec_env.vec_env_base.VecEnvBase", base_classes) or _in("omni.isaac.gym.vec_env.vec_env_mt.VecEnvMT", base_classes):
return "omniverse-isaacgym"
elif _in("tasks..*.VecTask", base_classes):
elif _in(["isaacgymenvs..*", "tasks..*.VecTask"], base_classes):
return "isaacgym-preview4" # preview 4 is the same as 3
elif _in("rlgpu.tasks..*.VecTask", base_classes):
return "isaacgym-preview2"
Expand Down
41 changes: 28 additions & 13 deletions skrl/envs/wrappers/torch/isaacgym_envs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Any, Tuple
from typing import Any, Tuple, Union

import gym

import torch

Expand All @@ -15,7 +17,8 @@ def __init__(self, env: Any) -> None:
super().__init__(env)

self._reset_once = True
self._obs_buf = None
self._observations = None
self._info = {}

def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]:
"""Perform a step in the environment
Expand All @@ -26,9 +29,9 @@ def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch
:return: Observation, reward, terminated, truncated, info
:rtype: tuple of torch.Tensor and any other info
"""
self._obs_buf, reward, terminated, info = self._env.step(actions)
truncated = info["time_outs"] if "time_outs" in info else torch.zeros_like(terminated)
return self._obs_buf, reward.view(-1, 1), terminated.view(-1, 1), truncated.view(-1, 1), info
self._observations, reward, terminated, self._info = self._env.step(actions)
truncated = self._info["time_outs"] if "time_outs" in self._info else torch.zeros_like(terminated)
return self._observations, reward.view(-1, 1), terminated.view(-1, 1), truncated.view(-1, 1), self._info

def reset(self) -> Tuple[torch.Tensor, Any]:
"""Reset the environment
Expand All @@ -37,9 +40,9 @@ def reset(self) -> Tuple[torch.Tensor, Any]:
:rtype: torch.Tensor and any other info
"""
if self._reset_once:
self._obs_buf = self._env.reset()
self._observations = self._env.reset()
self._reset_once = False
return self._obs_buf, {}
return self._observations, self._info

def render(self, *args, **kwargs) -> None:
"""Render the environment
Expand All @@ -62,7 +65,19 @@ def __init__(self, env: Any) -> None:
super().__init__(env)

self._reset_once = True
self._obs_dict = None
self._observations = None
self._info = {}

@property
def state_space(self) -> Union[gym.Space, None]:
"""State space
"""
try:
if self.num_states:
return self._unwrapped.state_space
except:
pass
return None

def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Any]:
"""Perform a step in the environment
Expand All @@ -73,9 +88,9 @@ def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch
:return: Observation, reward, terminated, truncated, info
:rtype: tuple of torch.Tensor and any other info
"""
self._obs_dict, reward, terminated, info = self._env.step(actions)
truncated = info["time_outs"] if "time_outs" in info else torch.zeros_like(terminated)
return self._obs_dict["obs"], reward.view(-1, 1), terminated.view(-1, 1), truncated.view(-1, 1), info
self._observations, reward, terminated, self._info = self._env.step(actions)
truncated = self._info["time_outs"] if "time_outs" in self._info else torch.zeros_like(terminated)
return self._observations["obs"], reward.view(-1, 1), terminated.view(-1, 1), truncated.view(-1, 1), self._info

def reset(self) -> Tuple[torch.Tensor, Any]:
"""Reset the environment
Expand All @@ -84,9 +99,9 @@ def reset(self) -> Tuple[torch.Tensor, Any]:
:rtype: torch.Tensor and any other info
"""
if self._reset_once:
self._obs_dict = self._env.reset()
self._observations = self._env.reset()
self._reset_once = False
return self._obs_dict["obs"], {}
return self._observations["obs"], self._info

def render(self, *args, **kwargs) -> None:
"""Render the environment
Expand Down

0 comments on commit 9c2b8a8

Please sign in to comment.