Add template for Isaac Lab examples in JAX
Showing 3 changed files with 220 additions and 43 deletions.
docs/source/examples/isaaclab/generator/templates/ppo_skrl_py_jax (163 additions, 0 deletions)
import flax.linen as nn
import jax
import jax.numpy as jnp

# import the skrl components to build the RL system
from skrl import config
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
from skrl.envs.loaders.jax import load_isaaclab_env
from skrl.envs.wrappers.jax import wrap_env
from skrl.memories.jax import RandomMemory
from skrl.models.jax import DeterministicMixin, GaussianMixin, Model
from skrl.resources.preprocessors.jax import RunningStandardScaler
from skrl.resources.schedulers.jax import KLAdaptiveRL
from skrl.trainers.jax import SequentialTrainer
from skrl.utils import set_seed


config.jax.backend = "jax"  # or "numpy"
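# note (added for context): this setting selects whether skrl's internal
# computations run on jax.numpy or plain NumPy; "jax" is the default backend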


# seed for reproducibility
set_seed()  # e.g. `set_seed(40)` for fixed seed


# define models (stochastic and deterministic models) using mixins
class Policy(GaussianMixin, Model):
    def __init__(self, observation_space, action_space, device=None, clip_actions=False,
                 clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum", **kwargs):
        Model.__init__(self, observation_space, action_space, device, **kwargs)
        GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)

    @nn.compact  # marks the given module method allowing inlined submodules
    def __call__(self, inputs, role):
        {% for index in range(models.policy.hiddens | length) %}
        {% if loop.first %}
        x = nn.{{ models.policy.hidden_activation__jax[index] }}(nn.Dense({{ models.policy.hiddens[index] }})(inputs["states"]))
        {% else %}
        x = nn.{{ models.policy.hidden_activation__jax[index] }}(nn.Dense({{ models.policy.hiddens[index] }})(x))
        {% endif %}
        {% endfor %}
        x = nn.Dense(self.num_actions)(x)
        log_std = self.param("log_std", lambda _: jnp.ones(self.num_actions))
        return x, log_std, {}

class Value(DeterministicMixin, Model):
    def __init__(self, observation_space, action_space, device=None, clip_actions=False, **kwargs):
        Model.__init__(self, observation_space, action_space, device, **kwargs)
        DeterministicMixin.__init__(self, clip_actions)

    @nn.compact  # marks the given module method allowing inlined submodules
    def __call__(self, inputs, role):
        {% for index in range(models.value.hiddens | length) %}
        {% if loop.first %}
        x = nn.{{ models.value.hidden_activation__jax[index] }}(nn.Dense({{ models.value.hiddens[index] }})(inputs["states"]))
        {% else %}
        x = nn.{{ models.value.hidden_activation__jax[index] }}(nn.Dense({{ models.value.hiddens[index] }})(x))
        {% endif %}
        {% endfor %}
        x = nn.Dense(1)(x)
        return x, {}
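
# For illustration only: with a hypothetical generator configuration such as
#   models.policy.hiddens = [256, 128]
#   models.policy.hidden_activation__jax = ["elu", "elu"]
# the templated loop in `Policy.__call__` above would render to
#   x = nn.elu(nn.Dense(256)(inputs["states"]))
#   x = nn.elu(nn.Dense(128)(x))
# ahead of the fixed output layers (the `Value` loop renders analogously)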


# load and wrap the Isaac Lab environment
env = load_isaaclab_env(task_name="{{ metadata.task }}")
env = wrap_env(env)
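# (assumption, for illustration: like skrl's other environment loaders, the
# loader also accepts optional arguments such as `num_envs=...` to override
# the task's default number of parallel environments)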

device = env.device


# instantiate a memory as rollout buffer (any memory can be used for this)
memory = RandomMemory(memory_size={{ agent.rollouts }}, num_envs=env.num_envs, device=device)


# instantiate the agent's models (function approximators).
# PPO requires 2 models, visit its documentation for more details
# https://skrl.readthedocs.io/en/latest/api/agents/ppo.html#models
models = {}
models["policy"] = Policy(env.observation_space, env.action_space, device)
models["value"] = Value(env.observation_space, env.action_space, device)

# instantiate models' state dict
for role, model in models.items():
    model.init_state_dict(role)
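# (note: flax modules initialize their parameters lazily, so skrl's JAX
# models expose `init_state_dict` to instantiate each model's parameters
# before the agent uses them)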


# configure and instantiate the agent (visit its documentation to see all the options)
# https://skrl.readthedocs.io/en/latest/api/agents/ppo.html#configuration-and-hyperparameters
cfg = PPO_DEFAULT_CONFIG.copy()
cfg["rollouts"] = {{ agent.rollouts }}  # memory_size
cfg["learning_epochs"] = {{ agent.learning_epochs }}
cfg["mini_batches"] = {{ agent.mini_batches }}
cfg["discount_factor"] = {{ agent.discount_factor }}
cfg["lambda"] = {{ agent.lambda }}
cfg["learning_rate"] = {{ "%.1e" | format(agent.learning_rate) }}
{% if agent.learning_rate_scheduler == "KLAdaptiveLR" %}
cfg["learning_rate_scheduler"] = KLAdaptiveRL
cfg["learning_rate_scheduler_kwargs"] = {"kl_threshold": {{ agent.learning_rate_scheduler_kwargs.kl_threshold }}}
{% else %}
cfg["learning_rate_scheduler"] = None
cfg["learning_rate_scheduler_kwargs"] = {}
{% endif %}
cfg["random_timesteps"] = 0
cfg["learning_starts"] = 0
cfg["grad_norm_clip"] = {{ agent.grad_norm_clip }}
cfg["ratio_clip"] = {{ agent.ratio_clip }}
cfg["value_clip"] = {{ agent.value_clip }}
cfg["clip_predicted_values"] = {{ agent.clip_predicted_values }}
cfg["entropy_loss_scale"] = {{ agent.entropy_loss_scale }}
cfg["value_loss_scale"] = {{ agent.value_loss_scale }}
cfg["kl_threshold"] = {{ agent.kl_threshold }}
{% if agent.rewards_shaper_scale == 1.0 %}
cfg["rewards_shaper"] = None
{% else %}
cfg["rewards_shaper"] = lambda rewards, *args, **kwargs: rewards * {{ agent.rewards_shaper_scale }}
{% endif %}
cfg["time_limit_bootstrap"] = True
{% if agent.state_preprocessor == "RunningStandardScaler" %}
cfg["state_preprocessor"] = RunningStandardScaler
cfg["state_preprocessor_kwargs"] = {"size": env.observation_space, "device": device}
{% else %}
cfg["state_preprocessor"] = None
cfg["state_preprocessor_kwargs"] = {}
{% endif %}
{% if agent.value_preprocessor == "RunningStandardScaler" %}
cfg["value_preprocessor"] = RunningStandardScaler
cfg["value_preprocessor_kwargs"] = {"size": 1, "device": device}
{% else %}
cfg["value_preprocessor"] = None
cfg["value_preprocessor_kwargs"] = {}
{% endif %}
# logging to TensorBoard and write checkpoints (in timesteps)
cfg["experiment"]["write_interval"] = "auto"
cfg["experiment"]["checkpoint_interval"] = "auto"
cfg["experiment"]["directory"] = "runs/jax/{{ metadata.task }}"

agent = PPO(models=models,
            memory=memory,
            cfg=cfg,
            observation_space=env.observation_space,
            action_space=env.action_space,
            device=device)


# configure and instantiate the RL trainer
cfg_trainer = {"timesteps": {{ trainer.timesteps }}, "headless": True}
trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent)

# start training
trainer.train()


# # ---------------------------------------------------------
# # comment the code above: `trainer.train()`, and...
# # uncomment the following lines to evaluate a trained agent
# # ---------------------------------------------------------
# from skrl.utils.huggingface import download_model_from_huggingface

# # download the trained agent's checkpoint from Hugging Face Hub and load it
# path = download_model_from_huggingface("skrl/IsaacLab-{{ metadata.task }}-PPO", filename="agent.pickle")
# agent.load(path)

# # start evaluation
# trainer.eval()
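
To show how the template's placeholders fit together, here is a minimal sketch of rendering it with Jinja2 into a runnable script. All variable values below are hypothetical placeholders chosen for illustration, not defaults shipped with skrl.

import jinja2

# hypothetical values for every placeholder the template references
variables = {
    "metadata": {"task": "Isaac-Cartpole-v0"},
    "models": {
        "policy": {"hiddens": [32, 32], "hidden_activation__jax": ["elu", "elu"]},
        "value": {"hiddens": [32, 32], "hidden_activation__jax": ["elu", "elu"]},
    },
    "agent": {
        "rollouts": 16,
        "learning_epochs": 8,
        "mini_batches": 1,
        "discount_factor": 0.99,
        "lambda": 0.95,
        "learning_rate": 3e-4,
        "learning_rate_scheduler": "KLAdaptiveLR",
        "learning_rate_scheduler_kwargs": {"kl_threshold": 0.008},
        "grad_norm_clip": 1.0,
        "ratio_clip": 0.2,
        "value_clip": 0.2,
        "clip_predicted_values": True,
        "entropy_loss_scale": 0.0,
        "value_loss_scale": 2.0,
        "kl_threshold": 0,
        "rewards_shaper_scale": 1.0,
        "state_preprocessor": "RunningStandardScaler",
        "value_preprocessor": "RunningStandardScaler",
    },
    "trainer": {"timesteps": 1600},
}

# load the template added by this commit and render it with the values above
with open("docs/source/examples/isaaclab/generator/templates/ppo_skrl_py_jax") as file:
    template = jinja2.Template(file.read())

# write the rendered training script (hypothetical output filename)
with open("ppo_cartpole_jax.py", "w") as file:
    file.write(template.render(**variables))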