Skip to content

Commit

Permalink
Use spaces utils to process actions in PettingZoo wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Oct 2, 2024
1 parent 52ce622 commit 80a8f02
Showing 1 changed file with 7 additions and 19 deletions.
26 changes: 7 additions & 19 deletions skrl/envs/wrappers/torch/pettingzoo_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
import torch

from skrl.envs.wrappers.torch.base import MultiAgentEnvWrapper
from skrl.utils.spaces.torch import flatten_tensorized_space, tensorize_space
from skrl.utils.spaces.torch import (
flatten_tensorized_space,
tensorize_space,
unflatten_tensorized_space,
untensorize_space
)


class PettingZooWrapper(MultiAgentEnvWrapper):
Expand All @@ -19,23 +24,6 @@ def __init__(self, env: Any) -> None:
"""
super().__init__(env)

def _tensor_to_action(self, actions: torch.Tensor, space: gymnasium.Space) -> Any:
    """Convert the action to the Gymnasium expected format

    :param actions: The actions to perform
    :type actions: torch.Tensor
    :param space: The action space used to select the conversion
    :type space: gymnasium.Space

    :raise ValueError: If the action space type is not supported

    :return: The action in the Gymnasium format
    :rtype: Any supported Gymnasium action space
    """
    if isinstance(space, gymnasium.spaces.Discrete):
        # Discrete actions are passed on as a plain Python scalar
        return actions.item()
    elif isinstance(space, gymnasium.spaces.Box):
        # detach() first: Tensor.numpy() raises a RuntimeError on tensors that
        # require grad. np.array() copies, so the result does not share memory
        # with the source tensor.
        return np.array(actions.detach().cpu().numpy(), dtype=space.dtype).reshape(space.shape)
    raise ValueError(f"Action space type {type(space)} not supported. Please report this issue")

def step(self, actions: Mapping[str, torch.Tensor]) -> \
Tuple[Mapping[str, torch.Tensor], Mapping[str, torch.Tensor],
Mapping[str, torch.Tensor], Mapping[str, torch.Tensor], Mapping[str, Any]]:
Expand All @@ -47,7 +35,7 @@ def step(self, actions: Mapping[str, torch.Tensor]) -> \
:return: Observation, reward, terminated, truncated, info
:rtype: tuple of dictionaries torch.Tensor and any other info
"""
actions = {uid: self._tensor_to_action(action, self.action_space(uid)) for uid, action in actions.items()}
actions = {uid: untensorize_space(self.action_spaces[uid], unflatten_tensorized_space(self.action_spaces[uid], action)) for uid, action in actions.items()}
observations, rewards, terminated, truncated, infos = self._env.step(actions)

# convert response to torch
Expand Down

0 comments on commit 80a8f02

Please sign in to comment.