Add template for Isaac Lab examples in JAX
Toni-SM committed Aug 3, 2024
1 parent 3ca1ac8 commit 35d4991
Showing 3 changed files with 220 additions and 43 deletions.
86 changes: 50 additions & 36 deletions docs/source/examples/isaaclab/generator/isaaclab_generator.py
@@ -110,44 +110,58 @@ def generate_yaml(self) -> None:
file.write(content)

def generate_python_script(self) -> None:
def convert_hidden_activation(activations):
def convert_hidden_activation(activations, framework):
mapping = {
"": "Identity",
"relu": "ReLU",
"tanh": "Tanh",
"sigmoid": "Sigmoid",
"leaky_relu": "LeakyReLU",
"elu": "ELU",
"softplus": "Softplus",
"softsign": "Softsign",
"selu": "SELU",
"softmax": "Softmax",
"torch": {
"": "Identity",
"relu": "ReLU",
"tanh": "Tanh",
"sigmoid": "Sigmoid",
"leaky_relu": "LeakyReLU",
"elu": "ELU",
"softplus": "Softplus",
"softsign": "Softsign",
"selu": "SELU",
"softmax": "Softmax",
},
"jax": {
"relu": "relu",
"tanh": "tanh",
"sigmoid": "sigmoid",
"leaky_relu": "leaky_relu",
"elu": "elu",
"softplus": "softplus",
"softsign": "soft_sign",
"selu": "selu",
"softmax": "softmax",
},
}
return [mapping[activation] for activation in activations]

content = ""
if self.library == "skrl":
# generate file name
os.makedirs("skrl_examples", exist_ok=True)
task_name = "_".join([item.lower() for item in self.cfg["metadata"]["task"].split("-")[1:-1]])
path = os.path.join("skrl_examples", f"torch_{task_name}_ppo.py")
with open("templates/ppo_skrl_py_torch") as file:
content = file.read()
if not content:
raise ValueError
# update config
self.cfg["models"]["policy"]["hidden_activation"] = convert_hidden_activation(
self.cfg["models"]["policy"]["hidden_activation"]
)
self.cfg["models"]["value"]["hidden_activation"] = convert_hidden_activation(
self.cfg["models"]["value"]["hidden_activation"]
)
# render template
template = Template(content, keep_trailing_newline=True, trim_blocks=True, lstrip_blocks=True)
content = template.render(self.cfg)
# save file
with open(path, "w") as file:
file.write(content)
return [mapping[framework][activation] for activation in activations]

task_name = "_".join([item.lower() for item in self.cfg["metadata"]["task"].split("-")[1:-1]])
for framework in ["torch", "jax"]:
content = ""
if self.library == "skrl":
# generate file name
os.makedirs("skrl_examples", exist_ok=True)
path = os.path.join("skrl_examples", f"{framework}_{task_name}_ppo.py")
with open(f"templates/ppo_skrl_py_{framework}") as file:
content = file.read()
if not content:
raise ValueError
# update config
self.cfg["models"]["policy"][f"hidden_activation__{framework}"] = convert_hidden_activation(
self.cfg["models"]["policy"]["hidden_activation"], framework
)
self.cfg["models"]["value"][f"hidden_activation__{framework}"] = convert_hidden_activation(
self.cfg["models"]["value"]["hidden_activation"], framework
)
# render template
template = Template(content, keep_trailing_newline=True, trim_blocks=True, lstrip_blocks=True)
content = template.render(self.cfg)
# save file
with open(path, "w") as file:
file.write(content)


if __name__ == "__main__":
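For reference, a minimal standalone sketch (not part of the commit) of how the new per-framework activation conversion behaves. The mapping below is trimmed to three entries and the activation lists are hypothetical; the full tables are in the diff above.

def convert_hidden_activation(activations, framework):
    # map generic activation names from the task configuration to
    # framework-specific names (trimmed example of the tables above)
    mapping = {
        "torch": {"relu": "ReLU", "tanh": "Tanh", "elu": "ELU"},
        "jax": {"relu": "relu", "tanh": "tanh", "elu": "elu"},
    }
    return [mapping[framework][activation] for activation in activations]

print(convert_hidden_activation(["elu", "elu"], "torch"))  # ['ELU', 'ELU']
print(convert_hidden_activation(["elu", "elu"], "jax"))    # ['elu', 'elu']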
163 changes: 163 additions & 0 deletions docs/source/examples/isaaclab/generator/templates/ppo_skrl_py_jax
@@ -0,0 +1,163 @@
import flax.linen as nn
import jax
import jax.numpy as jnp

# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
from skrl.envs.loaders.jax import load_isaaclab_env
from skrl.envs.wrappers.jax import wrap_env
from skrl.memories.jax import RandomMemory
from skrl.models.jax import DeterministicMixin, GaussianMixin, Model
from skrl.resources.preprocessors.jax import RunningStandardScaler
from skrl.resources.schedulers.jax import KLAdaptiveRL
from skrl.trainers.jax import SequentialTrainer
from skrl.utils import set_seed


config.jax.backend = "jax" # or "numpy"


# seed for reproducibility
set_seed() # e.g. `set_seed(40)` for fixed seed


# define models (stochastic and deterministic models) using mixins
class Policy(GaussianMixin, Model):
def __init__(self, observation_space, action_space, device=None, clip_actions=False,
clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum", **kwargs):
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)

@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
{% for index in range(models.policy.hiddens | length) %}
{% if loop.first %}
x = nn.{{ models.policy.hidden_activation__jax[index] }}(nn.Dense({{ models.policy.hiddens[index] }})(inputs["states"]))
{% else %}
x = nn.{{ models.policy.hidden_activation__jax[index] }}(nn.Dense({{ models.policy.hiddens[index] }})(x))
{% endif %}
{% endfor %}
x = nn.Dense(self.num_actions)(x)
log_std = self.param("log_std", lambda _: jnp.ones(self.num_actions))
return x, log_std, {}

class Value(DeterministicMixin, Model):
def __init__(self, observation_space, action_space, device=None, clip_actions=False, **kwargs):
Model.__init__(self, observation_space, action_space, device, **kwargs)
DeterministicMixin.__init__(self, clip_actions)

@nn.compact # marks the given module method allowing inlined submodules
def __call__(self, inputs, role):
{% for index in range(models.value.hiddens | length) %}
{% if loop.first %}
x = nn.{{ models.value.hidden_activation__jax[index] }}(nn.Dense({{ models.value.hiddens[index] }})(inputs["states"]))
{% else %}
x = nn.{{ models.value.hidden_activation__jax[index] }}(nn.Dense({{ models.value.hiddens[index] }})(x))
{% endif %}
{% endfor %}
x = nn.Dense(1)(x)
return x, {}


# load and wrap the Isaac Lab environment
env = load_isaaclab_env(task_name="{{ metadata.task }}")
env = wrap_env(env)

device = env.device


# instantiate a memory as rollout buffer (any memory can be used for this)
memory = RandomMemory(memory_size={{ agent.rollouts }}, num_envs=env.num_envs, device=device)


# instantiate the agent's models (function approximators).
# PPO requires 2 models, visit its documentation for more details
# https://skrl.readthedocs.io/en/latest/api/agents/ppo.html#models
models = {}
models["policy"] = Policy(env.observation_space, env.action_space, device)
models["value"] = Value(env.observation_space, env.action_space, device)

# instantiate models' state dict
for role, model in models.items():
model.init_state_dict(role)


# configure and instantiate the agent (visit its documentation to see all the options)
# https://skrl.readthedocs.io/en/latest/api/agents/ppo.html#configuration-and-hyperparameters
cfg = PPO_DEFAULT_CONFIG.copy()
cfg["rollouts"] = {{ agent.rollouts }} # memory_size
cfg["learning_epochs"] = {{ agent.learning_epochs }}
cfg["mini_batches"] = {{ agent.mini_batches }}
cfg["discount_factor"] = {{ agent.discount_factor }}
cfg["lambda"] = {{ agent.lambda }}
cfg["learning_rate"] = {{ "%.1e" | format(agent.learning_rate) }}
{% if agent.learning_rate_scheduler == "KLAdaptiveLR" %}
cfg["learning_rate_scheduler"] = KLAdaptiveRL
cfg["learning_rate_scheduler_kwargs"] = {"kl_threshold": {{ agent.learning_rate_scheduler_kwargs.kl_threshold }}}
{% else %}
cfg["learning_rate_scheduler"] = None
cfg["learning_rate_scheduler_kwargs"] = {}
{% endif %}
cfg["random_timesteps"] = 0
cfg["learning_starts"] = 0
cfg["grad_norm_clip"] = {{ agent.grad_norm_clip }}
cfg["ratio_clip"] = {{ agent.ratio_clip }}
cfg["value_clip"] = {{ agent.value_clip }}
cfg["clip_predicted_values"] = {{ agent.clip_predicted_values }}
cfg["entropy_loss_scale"] = {{ agent.entropy_loss_scale }}
cfg["value_loss_scale"] = {{ agent.value_loss_scale }}
cfg["kl_threshold"] = {{ agent.kl_threshold }}
{% if agent.rewards_shaper_scale == 1.0 %}
cfg["rewards_shaper"] = None
{% else %}
cfg["rewards_shaper"] = lambda rewards, *args, **kwargs: rewards * {{ agent.rewards_shaper_scale }}
{% endif %}
cfg["time_limit_bootstrap"] = True
{% if agent.state_preprocessor == "RunningStandardScaler" %}
cfg["state_preprocessor"] = RunningStandardScaler
cfg["state_preprocessor_kwargs"] = {"size": env.observation_space, "device": device}
{% else %}
cfg["state_preprocessor"] = None
cfg["state_preprocessor_kwargs"] = {}
{% endif %}
{% if agent.value_preprocessor == "RunningStandardScaler" %}
cfg["value_preprocessor"] = RunningStandardScaler
cfg["value_preprocessor_kwargs"] = {"size": 1, "device": device}
{% else %}
cfg["value_preprocessor"] = None
cfg["value_preprocessor_kwargs"] = {}
{% endif %}
# logging to TensorBoard and write checkpoints (in timesteps)
cfg["experiment"]["write_interval"] = "auto"
cfg["experiment"]["checkpoint_interval"] = "auto"
cfg["experiment"]["directory"] = "runs/jax/{{ metadata.task }}"

agent = PPO(models=models,
memory=memory,
cfg=cfg,
observation_space=env.observation_space,
action_space=env.action_space,
device=device)


# configure and instantiate the RL trainer
cfg_trainer = {"timesteps": {{ trainer.timesteps }}, "headless": True}
trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent)

# start training
trainer.train()


# # ---------------------------------------------------------
# # comment the code above: `trainer.train()`, and...
# # uncomment the following lines to evaluate a trained agent
# # ---------------------------------------------------------
# from skrl.utils.huggingface import download_model_from_huggingface

# # download the trained agent's checkpoint from Hugging Face Hub and load it
# path = download_model_from_huggingface("skrl/IsaacLab-{{ metadata.task }}-PPO", filename="agent.pickle")
# agent.load(path)

# # start evaluation
# trainer.eval()
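As an illustration (not part of the commit), for a hypothetical configuration with models.policy.hiddens = [512, 256] and hidden_activation = ["elu", "elu"], the Jinja blocks in the Policy class above would render to something like:

class Policy(GaussianMixin, Model):
    def __init__(self, observation_space, action_space, device=None, clip_actions=False,
                 clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum", **kwargs):
        Model.__init__(self, observation_space, action_space, device, **kwargs)
        GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)

    @nn.compact
    def __call__(self, inputs, role):
        # two hidden Dense layers with ELU activations, produced by the template loop
        x = nn.elu(nn.Dense(512)(inputs["states"]))
        x = nn.elu(nn.Dense(256)(x))
        # mean output layer plus a learnable log standard deviation parameter
        x = nn.Dense(self.num_actions)(x)
        log_std = self.param("log_std", lambda _: jnp.ones(self.num_actions))
        return x, log_std, {}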
14 changes: 7 additions & 7 deletions docs/source/examples/isaaclab/generator/templates/ppo_skrl_py_torch
@@ -28,14 +28,14 @@ class Policy(GaussianMixin, Model):
{% for index in range(models.policy.hiddens | length) %}
{% if loop.first %}
self.net = nn.Sequential(nn.Linear(self.num_observations, {{ models.policy.hiddens | first }}),
nn.{{ models.policy.hidden_activation | first }}(),
nn.{{ models.policy.hidden_activation__torch | first }}(),
{% elif loop.last %}
nn.Linear({{ models.policy.hiddens[loop.previtem] }}, {{ models.policy.hiddens | last }}),
nn.{{ models.policy.hidden_activation | last }}(),
nn.{{ models.policy.hidden_activation__torch | last }}(),
nn.Linear({{ models.policy.hiddens | last }}, self.num_actions))
{% else %}
nn.Linear({{ models.policy.hiddens[loop.previtem] }}, {{ models.policy.hiddens[index] }}),
nn.{{ models.policy.hidden_activation[index] }}(),
nn.{{ models.policy.hidden_activation__torch[index] }}(),
{% endif %}
{% endfor %}
self.log_std_parameter = nn.Parameter(torch.ones(self.num_actions))
@@ -51,14 +51,14 @@ class Value(DeterministicMixin, Model):
{% for index in range(models.value.hiddens | length) %}
{% if loop.first %}
self.net = nn.Sequential(nn.Linear(self.num_observations, {{ models.value.hiddens | first }}),
nn.{{ models.value.hidden_activation | first }}(),
nn.{{ models.value.hidden_activation__torch | first }}(),
{% elif loop.last %}
nn.Linear({{ models.value.hiddens[loop.previtem] }}, {{ models.value.hiddens | last }}),
nn.{{ models.value.hidden_activation | last }}(),
nn.{{ models.value.hidden_activation__torch | last }}(),
nn.Linear({{ models.value.hiddens | last }}, 1))
{% else %}
nn.Linear({{ models.value.hiddens[loop.previtem] }}, {{ models.value.hiddens[index] }}),
nn.{{ models.value.hidden_activation[index] }}(),
nn.{{ models.value.hidden_activation__torch[index] }}(),
{% endif %}
{% endfor %}

@@ -160,7 +160,7 @@ cfg["kl_threshold"] = {{ agent.kl_threshold }}
{% if agent.rewards_shaper_scale == 1.0 %}
cfg["rewards_shaper"] = None
{% else %}
cfg["rewards_shaper"] = lambda rewards, timestep, timesteps: rewards * {{ agent.rewards_shaper_scale }}
cfg["rewards_shaper"] = lambda rewards, *args, **kwargs: rewards * {{ agent.rewards_shaper_scale }}
{% endif %}
cfg["time_limit_bootstrap"] = True
{% if agent.state_preprocessor == "RunningStandardScaler" %}
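Similarly (not part of the commit), with a hypothetical configuration of models.value.hiddens = [512, 256] and hidden_activation = ["elu", "elu"], the value-network loop in the torch template above would render to:

# rendered torch value network for the hypothetical configuration
self.net = nn.Sequential(nn.Linear(self.num_observations, 512),
                         nn.ELU(),
                         nn.Linear(512, 256),
                         nn.ELU(),
                         nn.Linear(256, 1))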
