Commit

chore: add hasac config
sash-a committed Jul 31, 2024
1 parent 327f293 commit 67e0f6c
Showing 5 changed files with 64 additions and 7 deletions.
7 changes: 7 additions & 0 deletions mava/configs/default_ff_hasac.yaml
@@ -0,0 +1,7 @@
defaults:
- _self_
- logger: ff_hasac
- arch: anakin
- system: sac/ff_hasac
- network: continuous_mlp
- env: mabrax
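
The defaults list above composes the new HASAC experiment from existing config groups (logger, arch, system, network, env). As a hedged illustration of how Hydra resolves it, the snippet below loads and prints the merged config via Hydra's compose API; it is not part of this commit, and the relative config_path assumes the snippet lives at the repository root.

# Illustrative only: compose the new default_ff_hasac config and inspect the merged result.
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(config_path="mava/configs", version_base="1.2"):
    cfg = compose(config_name="default_ff_hasac.yaml")
    print(OmegaConf.to_yaml(cfg))  # logger, arch, system, network and env groups merged into one config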
4 changes: 4 additions & 0 deletions mava/configs/logger/ff_hasac.yaml
@@ -0,0 +1,4 @@
defaults:
- base_logger

system_name: ff_hasac
2 changes: 1 addition & 1 deletion mava/configs/logger/ff_masac.yaml
@@ -1,4 +1,4 @@
defaults:
- base_logger

-system_name: ff_masac
+system_name: ff_hasac
36 changes: 36 additions & 0 deletions mava/configs/system/sac/ff_hasac.yaml
@@ -0,0 +1,36 @@
# --- Defaults FF-HASAC ---
seed: 1

# --- Agent observations ---
add_agent_id: False

# --- RL hyperparameters ---
# step related
total_timesteps: ~ # Set the total environment steps.
# If unspecified, it's derived from num_updates; otherwise, num_updates adjusts based on this value.
num_updates: 625000 # Number of updates
explore_steps: 500 # number of steps to take with random actions at the start of training
update_batch_size: 1 # number of vectorised gradient updates per device.

rollout_length: 1 # number of environment steps per vectorised environment.
epochs: 32 # number of learn epochs per training data batch.
policy_update_delay: 4 # the delay before training the policy -
# Every `policy_update_delay` q network learning steps the policy network is trained.
# It is trained `policy_update_delay` times to compensate, this is a TD3 trick.

# sizes
buffer_size: 1000000 # size of the replay buffer. Note: total size is this * num_devices
batch_size: 32

# learning rates
policy_lr: 3e-4 # the learning rate of the policy network optimizer
q_lr: 1e-3 # the learning rate of the Q network optimizer
alpha_lr: 3e-4 # the learning rate of the alpha optimizer

# SAC specific
tau: 0.005 # smoothing coefficient for target networks
gamma: 0.99 # discount factor

autotune: True # whether to autotune alpha
target_entropy_scale: 5.0 # scale factor for target entropy (when auto-tuning)
init_alpha: 0.1 # initial entropy value when not using autotune
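
Two of the hyperparameters above interact during training: tau drives Polyak-averaged target-network updates, and policy_update_delay implements the TD3-style trick described in the comments (train the policy only every N Q-network steps, then N times in a row to compensate). The sketch below is a minimal, self-contained illustration of that pattern, not the ff_hasac.py implementation; the toy parameter pytrees and the placeholder gradient steps are assumptions.

# Minimal sketch (not Mava's code) of how `tau` and `policy_update_delay` are typically used.
import jax
import jax.numpy as jnp

def soft_target_update(online, target, tau=0.005):
    # Polyak averaging, leaf-wise over the pytree: target <- tau * online + (1 - tau) * target
    return jax.tree_util.tree_map(lambda o, t: tau * o + (1.0 - tau) * t, online, target)

q_params = {"w": jnp.ones((2, 2))}   # toy online Q-network parameters
q_target = {"w": jnp.zeros((2, 2))}  # toy target Q-network parameters

policy_update_delay = 4
for step in range(1, 9):
    # ... one Q-network gradient step per iteration would go here ...
    if step % policy_update_delay == 0:
        for _ in range(policy_update_delay):
            pass  # ... one delayed policy/alpha gradient step per inner iteration ...
    q_target = soft_target_update(q_params, q_target, tau=0.005)

print(q_target["w"][0, 0])  # after 8 soft updates: 1 - 0.995 ** 8 ≈ 0.039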
22 changes: 16 additions & 6 deletions mava/systems/sac/ff_hasac.py
@@ -404,9 +404,15 @@ def update_actor_and_alpha(
for agent_id in agent_ids:
actor_key, alpha_key = jax.random.split(key)

-agent_actor_params = jax.tree_util.tree_map(lambda x, agent_id=agent_id: x[agent_id], params.actor)
-actor_opt_state = jax.tree_util.tree_map(lambda x, agent_id=agent_id: x[agent_id], opt_states.actor)
-obs_per_agent = jax.tree_util.tree_map(lambda x, agent_id=agent_id: x[:, agent_id], data.obs)
+agent_actor_params = jax.tree_util.tree_map(
+    lambda x, agent_id=agent_id: x[agent_id], params.actor
+)
+actor_opt_state = jax.tree_util.tree_map(
+    lambda x, agent_id=agent_id: x[agent_id], opt_states.actor
+)
+obs_per_agent = jax.tree_util.tree_map(
+    lambda x, agent_id=agent_id: x[:, agent_id], data.obs
+)
# Update actor.
actor_grad_fn = jax.value_and_grad(actor_loss_fn)
actor_loss, act_grads = actor_grad_fn(
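
The reformatted lines above use two idioms worth spelling out: slicing a stacked per-agent pytree with jax.tree_util.tree_map, and binding agent_id as a lambda default argument so each lambda captures the current loop value. The standalone example below (toy shapes, not Mava's data structures) shows the same pattern.

# Standalone illustration of the per-agent slicing idiom; shapes and values are made up.
import jax
import jax.numpy as jnp

stacked_params = {"w": jnp.arange(6).reshape(3, 2)}  # leading axis indexes 3 agents

for agent_id in range(3):
    # `agent_id=agent_id` binds the current loop value into the lambda. Because tree_map
    # calls the lambda immediately, this mainly guards against the classic late-binding
    # closure pitfall (and the associated linter warning) rather than changing behaviour here.
    agent_params = jax.tree_util.tree_map(lambda x, agent_id=agent_id: x[agent_id], stacked_params)
    print(agent_id, agent_params["w"])  # agent 0 -> [0 1], agent 1 -> [2 3], agent 2 -> [4 5]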
@@ -470,10 +476,14 @@ def update_actor_and_alpha(
)

new_log_alphas = jax.tree_util.tree_map(
-    lambda x, y, agent_id=agent_id: x.at[agent_id].set(y), params.log_alpha, new_log_alpha
+    lambda x, y, agent_id=agent_id: x.at[agent_id].set(y),
+    params.log_alpha,
+    new_log_alpha,
)
new_alpha_opt_states = jax.tree_util.tree_map(
-    lambda x, y, agent_id=agent_id: x.at[agent_id].set(y), opt_states.alpha, new_alpha_opt_state
+    lambda x, y, agent_id=agent_id: x.at[agent_id].set(y),
+    opt_states.alpha,
+    new_alpha_opt_state,
)
params = params._replace(log_alpha=new_log_alphas)
opt_states = opt_states._replace(alpha=new_alpha_opt_states)
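
The per-agent results are written back into the stacked pytrees with JAX's functional index update, x.at[agent_id].set(y), applied leaf-wise by tree_map. A hedged, standalone sketch of the same idiom with toy values (not the real log-alpha parameters or optimiser states):

# Standalone sketch of the functional per-agent write-back idiom; values are made up.
import jax
import jax.numpy as jnp

log_alphas = {"log_alpha": jnp.zeros(3)}        # one entry per agent
new_log_alpha = {"log_alpha": jnp.array(-1.0)}  # hypothetical updated value for agent 1

log_alphas = jax.tree_util.tree_map(lambda x, y: x.at[1].set(y), log_alphas, new_log_alpha)
print(log_alphas["log_alpha"])  # [ 0. -1.  0.] -- arrays are updated out of place, never mutated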
@@ -700,7 +710,7 @@ def run_experiment(cfg: DictConfig) -> float:
return eval_performance


-@hydra.main(config_path="../../configs", config_name="default_ff_masac.yaml", version_base="1.2")
+@hydra.main(config_path="../../configs", config_name="default_ff_hasac.yaml", version_base="1.2")
def hydra_entry_point(cfg: DictConfig) -> float:
"""Experiment entry point."""
# Allow dynamic attributes.
