Skip to content

Commit

Permalink
Update model definitions to support different input spaces in jax
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Oct 7, 2024
1 parent 5f8b615 commit 722dfa2
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 11 deletions.
3 changes: 3 additions & 0 deletions skrl/utils/model_instantiators/jax/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from skrl.models.jax import CategoricalMixin # noqa
from skrl.models.jax import Model # noqa
from skrl.utils.model_instantiators.jax.common import convert_deprecated_parameters, generate_containers
from skrl.utils.spaces.jax import unflatten_tensorized_space # noqa


def categorical_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
Expand Down Expand Up @@ -84,6 +85,8 @@ def setup(self):
{networks}
def __call__(self, inputs, role):
states = unflatten_tensorized_space(self.observation_space, inputs.get("states"))
taken_actions = unflatten_tensorized_space(self.action_space, inputs.get("taken_actions"))
{forward}
return output, {{}}
"""
Expand Down
12 changes: 6 additions & 6 deletions skrl/utils/model_instantiators/jax/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def visit_Call(self, node: ast.Call):
node.func = ast.Attribute(value=ast.Name("jnp"), attr="concatenate")
node.keywords = [ast.keyword(arg="axis", value=ast.Constant(value=-1))]
# operation: permute
if node.func.id == "permute":
elif node.func.id == "permute":
node.func = ast.Attribute(value=ast.Name("jnp"), attr="permute_dims")
return node

Expand All @@ -61,11 +61,11 @@ def visit_Call(self, node: ast.Call):
NodeTransformer().visit(tree)
source = ast.unparse(tree)
# enum substitutions
source = source.replace("Shape.STATES_ACTIONS", "STATES_ACTIONS").replace("STATES_ACTIONS", 'jnp.concatenate([inputs["states"], inputs["taken_actions"]], axis=-1)')
source = source.replace("Shape.OBSERVATIONS_ACTIONS", "OBSERVATIONS_ACTIONS").replace("OBSERVATIONS_ACTIONS", 'jnp.concatenate([inputs["states"], inputs["taken_actions"]], axis=-1)')
source = source.replace("Shape.STATES", "STATES").replace("STATES", 'inputs["states"]')
source = source.replace("Shape.OBSERVATIONS", "OBSERVATIONS").replace("OBSERVATIONS", 'inputs["states"]')
source = source.replace("Shape.ACTIONS", "ACTIONS").replace("ACTIONS", 'inputs["taken_actions"]')
source = source.replace("Shape.STATES_ACTIONS", "STATES_ACTIONS").replace("STATES_ACTIONS", "jnp.concatenate([states, taken_actions], axis=-1)")
source = source.replace("Shape.OBSERVATIONS_ACTIONS", "OBSERVATIONS_ACTIONS").replace("OBSERVATIONS_ACTIONS", "jnp.concatenate([states, taken_actions], axis=-1)")
source = source.replace("Shape.STATES", "STATES").replace("STATES", "states")
source = source.replace("Shape.OBSERVATIONS", "OBSERVATIONS").replace("OBSERVATIONS", "states")
source = source.replace("Shape.ACTIONS", "ACTIONS").replace("ACTIONS", "taken_actions")
return source

def _parse_output(source: Union[str, Sequence[str]]) -> Tuple[Union[str, Sequence[str]], Sequence[str], int]:
Expand Down
3 changes: 3 additions & 0 deletions skrl/utils/model_instantiators/jax/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from skrl.models.jax import DeterministicMixin # noqa
from skrl.models.jax import Model # noqa
from skrl.utils.model_instantiators.jax.common import convert_deprecated_parameters, generate_containers
from skrl.utils.spaces.jax import unflatten_tensorized_space # noqa


def deterministic_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
Expand Down Expand Up @@ -81,6 +82,8 @@ def setup(self):
{networks}
def __call__(self, inputs, role):
states = unflatten_tensorized_space(self.observation_space, inputs.get("states"))
taken_actions = unflatten_tensorized_space(self.action_space, inputs.get("taken_actions"))
{forward}
return output, {{}}
"""
Expand Down
3 changes: 3 additions & 0 deletions skrl/utils/model_instantiators/jax/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from skrl.models.jax import GaussianMixin # noqa
from skrl.models.jax import Model # noqa
from skrl.utils.model_instantiators.jax.common import convert_deprecated_parameters, generate_containers
from skrl.utils.spaces.jax import unflatten_tensorized_space # noqa


def gaussian_model(observation_space: Optional[Union[int, Tuple[int], gym.Space, gymnasium.Space]] = None,
Expand Down Expand Up @@ -95,6 +96,8 @@ def setup(self):
self.log_std_parameter = self.param("log_std_parameter", lambda _: {initial_log_std} * jnp.ones({output["size"]}))
def __call__(self, inputs, role):
states = unflatten_tensorized_space(self.observation_space, inputs.get("states"))
taken_actions = unflatten_tensorized_space(self.action_space, inputs.get("taken_actions"))
{forward}
return output, self.log_std_parameter, {{}}
"""
Expand Down
10 changes: 5 additions & 5 deletions skrl/utils/model_instantiators/torch/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,11 @@ def visit_Call(self, node: ast.Call):
NodeTransformer().visit(tree)
source = ast.unparse(tree)
# enum substitutions
source = source.replace("Shape.STATES_ACTIONS", "STATES_ACTIONS").replace("STATES_ACTIONS", 'torch.cat((states, taken_actions), dim=1)')
source = source.replace("Shape.OBSERVATIONS_ACTIONS", "OBSERVATIONS_ACTIONS").replace("OBSERVATIONS_ACTIONS", 'torch.cat((states, taken_actions), dim=1)')
source = source.replace("Shape.STATES", "STATES").replace("STATES", 'states')
source = source.replace("Shape.OBSERVATIONS", "OBSERVATIONS").replace("OBSERVATIONS", 'states')
source = source.replace("Shape.ACTIONS", "ACTIONS").replace("ACTIONS", 'taken_actions')
source = source.replace("Shape.STATES_ACTIONS", "STATES_ACTIONS").replace("STATES_ACTIONS", "torch.cat([states, taken_actions], dim=1)")
source = source.replace("Shape.OBSERVATIONS_ACTIONS", "OBSERVATIONS_ACTIONS").replace("OBSERVATIONS_ACTIONS", "torch.cat([states, taken_actions], dim=1)")
source = source.replace("Shape.STATES", "STATES").replace("STATES", "states")
source = source.replace("Shape.OBSERVATIONS", "OBSERVATIONS").replace("OBSERVATIONS", "states")
source = source.replace("Shape.ACTIONS", "ACTIONS").replace("ACTIONS", "taken_actions")
return source

def _parse_output(source: Union[str, Sequence[str]]) -> Tuple[Union[str, Sequence[str]], Sequence[str], int]:
Expand Down

0 comments on commit 722dfa2

Please sign in to comment.