feat: add support for continuous action space environments to happo
RuanJohn committed Aug 1, 2024
1 parent 6c56349 commit b3fce1a
Showing 5 changed files with 78 additions and 13 deletions.
7 changes: 7 additions & 0 deletions mava/configs/default_ff_happo.yaml
@@ -0,0 +1,7 @@
defaults:
- logger: ff_happo
- arch: anakin
- system: ppo/ff_happo
- network: mlp
- env: rware
- _self_
4 changes: 4 additions & 0 deletions mava/configs/logger/ff_happo.yaml
@@ -0,0 +1,4 @@
defaults:
- base_logger

system_name: ff_happo
23 changes: 23 additions & 0 deletions mava/configs/system/ppo/ff_happo.yaml
@@ -0,0 +1,23 @@
# --- Defaults FF-HAPPO ---
total_timesteps: ~ # Set the total environment steps.
# If unspecified, it's derived from num_updates; otherwise, num_updates adjusts based on this value.
num_updates: 1220 # Number of updates
seed: 42

# --- Agent observations ---
add_agent_id: False

# --- RL hyperparameters ---
actor_lr: 2.5e-4 # Learning rate for actor network
critic_lr: 2.5e-4 # Learning rate for critic network
update_batch_size: 2 # Number of vectorised gradient updates per device.
rollout_length: 128 # Number of environment steps per vectorised environment.
ppo_epochs: 4 # Number of ppo epochs per training data batch.
num_minibatches: 2 # Number of minibatches per ppo epoch.
gamma: 0.99 # Discounting factor.
gae_lambda: 0.95 # Lambda value for GAE computation.
clip_eps: 0.2 # Clipping value for PPO updates and value function.
ent_coef: 0.01 # Entropy regularisation term for loss function.
vf_coef: 0.5 # Critic weight in the total loss.
max_grad_norm: 0.5 # Maximum norm of the gradients for a weight update.
decay_learning_rates: False # Whether learning rates should be linearly decayed during training.
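
The three new config files above compose into a single runnable FF-HAPPO config. Below is a minimal sketch (not part of this commit) of loading and overriding it with Hydra's compose API; it assumes a full Mava checkout, and note that config_path is resolved relative to the file that calls initialize, so the path below is illustrative:

from hydra import compose, initialize

# Compose the new default_ff_happo config and override a couple of the
# system hyperparameters defined in system/ppo/ff_happo.yaml above.
with initialize(version_base="1.2", config_path="mava/configs"):
    cfg = compose(
        config_name="default_ff_happo",
        overrides=["system.seed=7", "system.rollout_length=64"],
    )
    print(cfg.system.ppo_epochs)  # 4, taken from system/ppo/ff_happo.yaml

The same config is what the updated @hydra.main entry point in ff_happo.py (further down in this diff) loads by default.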
13 changes: 10 additions & 3 deletions mava/evaluator.py
@@ -66,12 +66,19 @@ def _env_step(eval_state: EvalState) -> EvalState:
key, policy_key = jax.random.split(key)
# Add a batch dimension to the observation.
if is_happo:
action = jnp.zeros((config.system.num_agents), dtype=jnp.int32)
# check if the environment is continuous to set the action array correctly.
if "Continuous" in config.network.action_head._target_:
action = jnp.zeros(
(config.system.num_agents, config.system.action_dim), dtype=jnp.float32
)
else:
action = jnp.zeros((config.system.num_agents), dtype=jnp.int32)

for agent in range(config.system.num_agents):
single_agent_obs = jax.tree_util.tree_map(
lambda x: x[jnp.newaxis, agent], last_timestep.observation
lambda x, agent=agent: x[jnp.newaxis, agent], last_timestep.observation
)
agent_params = jax.tree_util.tree_map(lambda x: x[agent], params)
agent_params = jax.tree_util.tree_map(lambda x, agent=agent: x[agent], params)
pi = apply_fn(agent_params, single_agent_obs)
if config.arch.evaluation_greedy:
action_per_agent = pi.mode()
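The evaluator change does two things: the preallocated action array becomes a float array of shape (num_agents, action_dim) when the configured action head is continuous, and the per-agent lambdas passed to jax.tree_util.tree_map now bind the loop variable as a default argument (agent=agent). The default argument addresses Python's late-binding closures: a lambda created in a loop and called later sees the loop variable's final value. In the committed loop the lambdas are invoked immediately by tree_map, so the change mainly guards against that capture pattern (and the usual linter warning about unbound loop variables). A minimal standalone sketch with toy shapes:

import jax.numpy as jnp

xs = jnp.arange(12).reshape(3, 4)  # stand-in for per-agent data: 3 agents, 4 features

# Closures built in a loop capture `agent` by reference ...
late = [lambda x: x[agent] for agent in range(3)]
# ... while a default argument freezes the current value per iteration.
bound = [lambda x, agent=agent: x[agent] for agent in range(3)]

print([f(xs).tolist() for f in late])   # every call slices agent 2
print([f(xs).tolist() for f in bound])  # slices agents 0, 1, 2 as intended
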
44 changes: 34 additions & 10 deletions mava/systems/ppo/ff_happo.py
@@ -83,13 +83,26 @@ def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTra
# SELECT ACTION
key, policy_key = jax.random.split(key)
value = critic_apply_fn(params.critic_params, last_timestep.observation)
actions = jnp.zeros((config.arch.num_envs, config.system.num_agents), dtype=jnp.int32)

# check if the environment is continuous or discrete to set the action array correctly.
if "Continuous" in config.network.action_head._target_:
actions = jnp.zeros(
(config.arch.num_envs, config.system.num_agents, config.system.action_dim),
dtype=jnp.float32,
)
else:
actions = jnp.zeros(
(config.arch.num_envs, config.system.num_agents), dtype=jnp.int32
)

log_probs = jnp.zeros((config.arch.num_envs, config.system.num_agents))
for agent in range(config.system.num_agents):
single_agent_obs = jax.tree_util.tree_map(
lambda x: x[:, agent], last_timestep.observation
lambda x, agent=agent: x[:, agent], last_timestep.observation
)
agent_params = jax.tree_util.tree_map(
lambda x, agent=agent: x[agent], params.actor_params
)
agent_params = jax.tree_util.tree_map(lambda x: x[agent], params.actor_params)
actor_policy = actor_apply_fn(agent_params, single_agent_obs)
action = actor_policy.sample(seed=policy_key)
log_prob = actor_policy.log_prob(action)
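
In _env_step the per-step action buffer is now preallocated with a shape and dtype matching the configured action head: integer indices for discrete heads, float vectors of length config.system.action_dim for continuous ones, with the choice driven by the Hydra _target_ string of the action head. A minimal standalone sketch of that dispatch follows; the target string and the sizes are illustrative only, not names from the Mava codebase:

import jax.numpy as jnp

num_envs, num_agents, action_dim = 16, 3, 5
action_head_target = "mava.networks.ContinuousActionHead"  # hypothetical _target_ value

if "Continuous" in action_head_target:
    # Continuous control: one action vector of length action_dim per agent.
    actions = jnp.zeros((num_envs, num_agents, action_dim), dtype=jnp.float32)
else:
    # Discrete control: one integer action index per agent.
    actions = jnp.zeros((num_envs, num_agents), dtype=jnp.int32)

print(actions.shape, actions.dtype)  # (16, 3, 5) float32 with the continuous head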
@@ -221,9 +234,15 @@ def _critic_loss_fn(
shuffled_agents = jax.random.permutation(shuffle_key, config.system.num_agents)
for agent in shuffled_agents:
key, entropy_key = jax.random.split(key)
agent_params = jax.tree_util.tree_map(lambda x: x[agent], agents_params)
agent_traj = jax.tree_util.tree_map(lambda x: x[:, agent], traj_batch)
agent_opt_state = jax.tree_util.tree_map(lambda x: x[agent], agent_opt_states)
agent_params = jax.tree_util.tree_map(
lambda x, agent=agent: x[agent], agents_params
)
agent_traj = jax.tree_util.tree_map(
lambda x, agent=agent: x[:, agent], traj_batch
)
agent_opt_state = jax.tree_util.tree_map(
lambda x, agent=agent: x[agent], agent_opt_states
)

actor_loss_info_per_agent, actor_grads_per_agent = actor_grad_fn(
agent_params,
@@ -248,10 +267,14 @@ def _critic_loss_fn(
actor_new_params = optax.apply_updates(agent_params, actor_updates)

agents_params = jax.tree_util.tree_map(
lambda x, y: x.at[agent].set(y), agents_params, actor_new_params
lambda x, y, agent=agent: x.at[agent].set(y),
agents_params,
actor_new_params,
)
agent_opt_states = jax.tree_util.tree_map(
lambda x, y: x.at[agent].set(y), agent_opt_states, actor_new_opt_state
lambda x, y, agent=agent: x.at[agent].set(y),
agent_opt_states,
actor_new_opt_state,
)

# Updating the gae
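
The sequential HAPPO update walks the agents in a shuffled order, slicing agent i's leaves out of the stacked parameter and optimiser pytrees, updating them, and writing them back with .at[i].set(...); the agent=agent defaults again pin the loop variable. Below is a minimal sketch of that slice-and-write-back pattern on a toy pytree, run eagerly here whereas the committed loop executes inside the jitted learner; the parameter "update" is a stand-in for the optax step, not the real actor update:

import jax
import jax.numpy as jnp

num_agents = 3
# Toy stacked pytree: the leading axis indexes agents, as in agents_params above.
stacked_params = {"w": jnp.ones((num_agents, 4)), "b": jnp.zeros((num_agents,))}

shuffle_key = jax.random.PRNGKey(0)
for agent in jax.random.permutation(shuffle_key, num_agents):
    # Slice this agent's leaves out of the stacked pytree.
    agent_params = jax.tree_util.tree_map(lambda x, agent=agent: x[agent], stacked_params)
    # Stand-in for the per-agent PPO actor update.
    new_agent_params = jax.tree_util.tree_map(lambda x: x - 0.01, agent_params)
    # Write the updated leaves back in place.
    stacked_params = jax.tree_util.tree_map(
        lambda x, y, agent=agent: x.at[agent].set(y), stacked_params, new_agent_params
    )

print(stacked_params["w"][:, 0])  # every agent's weights shifted by -0.01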
@@ -376,6 +399,7 @@ def learner_setup(

# Get number of agents.
config.system.num_agents = env.num_agents
config.system.action_dim = env.action_dim

# PRNG keys.
key, actor_net_key, critic_net_key = keys
@@ -610,15 +634,15 @@ def run_experiment(_config: DictConfig) -> float:
return eval_performance


@hydra.main(config_path="../../configs", config_name="default_ff_mappo.yaml", version_base="1.2")
@hydra.main(config_path="../../configs", config_name="default_ff_happo.yaml", version_base="1.2")
def hydra_entry_point(cfg: DictConfig) -> float:
"""Experiment entry point."""
# Allow dynamic attributes.
OmegaConf.set_struct(cfg, False)

# Run experiment.
eval_performance = run_experiment(cfg)
print(f"{Fore.CYAN}{Style.BRIGHT}MAPPO experiment completed{Style.RESET_ALL}")
print(f"{Fore.CYAN}{Style.BRIGHT}HAPPO experiment completed{Style.RESET_ALL}")
return eval_performance


