Skip to content

Commit

Permalink
Use regular expression to detect environment types
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Aug 4, 2024
1 parent 9e5b5e4 commit b549576
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions skrl/envs/wrappers/torch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Any, Union

import gym
import gymnasium
import re

from skrl import logger
from skrl.envs.wrappers.torch.base import MultiAgentEnvWrapper, Wrapper
Expand Down Expand Up @@ -70,10 +69,13 @@ def wrap_env(env: Any, wrapper: str = "auto", verbose: bool = True) -> Union[Wra
:rtype: Wrapper or MultiAgentEnvWrapper
"""
def _get_wrapper_name(env, verbose):
def _in(value, container):
def _in(values, container):
if type(values) == str:
values = [values]
for item in container:
if value in item:
return True
for value in values:
if value in item or re.match(value, item):
return True
return False

base_classes = [str(base).replace("<class '", "").replace("'>", "") for base in env.__class__.__bases__]
Expand All @@ -85,7 +87,7 @@ def _in(value, container):
if verbose:
logger.info(f"Environment wrapper: 'auto' (class: {', '.join(base_classes)})")

if _in("omni.isaac.lab.envs.manager_based_env.ManagerBasedEnv", base_classes) or _in("omni.isaac.lab.envs.direct_rl_env.DirectRLEnv", base_classes):
if _in("omni.isaac.lab.envs..*", base_classes):
return "isaaclab"
elif _in("omni.isaac.gym.vec_env.vec_env_base.VecEnvBase", base_classes) or _in("omni.isaac.gym.vec_env.vec_env_mt.VecEnvMT", base_classes):
return "omniverse-isaacgym"
Expand All @@ -97,7 +99,7 @@ def _in(value, container):
return "dm"
elif _in("pettingzoo.utils.env", base_classes) or _in("pettingzoo.utils.wrappers", base_classes):
return "pettingzoo"
elif _in("gymnasium.core.Env", base_classes) or _in("gymnasium.core.Wrapper", base_classes):
elif _in("gymnasium..*", base_classes):
return "gymnasium"
elif _in("gym.core.Env", base_classes) or _in("gym.core.Wrapper", base_classes):
return "gym"
Expand Down

0 comments on commit b549576

Please sign in to comment.