Skip to content

Commit

Permalink
fix: actually select actions
Browse files Browse the repository at this point in the history
  • Loading branch information
sash-a committed Aug 5, 2024
1 parent dfaf7c4 commit 2d531fa
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions mava/systems/sac/ff_hasac.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
)
from mava.types import ObservationGlobalState
from mava.utils import make_env as environments
from mava.utils.centralised_training import get_joint_action
from mava.utils.centralised_training import get_joint_action, get_updated_joint_actions
from mava.utils.checkpointing import Checkpointer
from mava.utils.jax_utils import (
tree_at_set,
Expand All @@ -73,8 +73,8 @@ def get_action(actor_params, actor_net, keys, env, obs, batch_size):

pi = actor_net.apply(actor_params_per_agent, obs_per_agent)
action = pi.sample(seed=keys[agent])
actions.at[:, agent].set(action)
log_std.at[:, agent].set(pi.log_prob(action))
actions = actions.at[:, agent].set(action)
log_std = log_std.at[:, agent].set(pi.log_prob(action))

return actions, log_std

Expand Down Expand Up @@ -322,7 +322,7 @@ def actor_loss_fn(
# This is done by placing new_action[i] in joint_actions[i].
# [32, 4, 2] -> insert -> [32, 8]
joint_actions = actions.at[:, agent_id, :].set(new_actions).reshape(B, -1)
# joint_actions = get_updated_joint_actions(actions, new_actions)
# joint_actions = get_updated_joint_actions(actions, new_actions)[:, agent_id]

qval_1 = q_net.apply(q_params.q1, obs, joint_actions)
qval_2 = q_net.apply(q_params.q2, obs, joint_actions)
Expand Down Expand Up @@ -387,7 +387,10 @@ def update_actor_and_alpha(
for _ in range(cfg.system.policy_update_delay):
key, act_key, agent_order_key = jax.random.split(key, 3)
act_keys = jax.random.split(act_key, env.num_agents)
agent_ids = jax.random.permutation(agent_order_key, env.num_agents)
if cfg.system.shuffle_agents:
agent_ids = jax.random.permutation(agent_order_key, env.num_agents)
else:
agent_ids = jnp.arange(env.num_agents)

# todo: we can almost certainly get this from the buffer, we just need the log probs for alpha :/
actions, log_probs = get_action(
Expand Down

0 comments on commit 2d531fa

Please sign in to comment.