chore: remove evaluator for testing
sash-a committed Aug 1, 2024
1 parent 0bef27e commit f3b518f
Showing 1 changed file with 20 additions and 19 deletions.
39 changes: 20 additions & 19 deletions mava/systems/sac/ff_masac.py
@@ -520,7 +520,7 @@ def run_experiment(cfg: DictConfig) -> float:

     actor, _ = networks
     key, eval_key = jax.random.split(key)
-    evaluator, absolute_metric_evaluator = make_eval_fns(eval_env, actor.apply, cfg)
+    # evaluator, absolute_metric_evaluator = make_eval_fns(eval_env, actor.apply, cfg)
 
     if cfg.logger.checkpointing.save_model:
         checkpointer = Checkpointer(
@@ -569,19 +569,19 @@ def run_experiment(cfg: DictConfig) -> float:
         logger.log(loss_metrics, t, eval_idx, LogEvent.TRAIN)
 
         # Evaluate:
-        key, eval_key = jax.random.split(key)
-        eval_keys = jax.random.split(eval_key, cfg.arch.n_devices)
-        eval_output = evaluator(unreplicate_batch_dim(learner_state.params.actor), eval_keys)
-        jax.block_until_ready(eval_output)
+        # key, eval_key = jax.random.split(key)
+        # eval_keys = jax.random.split(eval_key, cfg.arch.n_devices)
+        # eval_output = evaluator(unreplicate_batch_dim(learner_state.params.actor), eval_keys)
+        # jax.block_until_ready(eval_output)
 
         # Log:
-        episode_return = jnp.mean(eval_output.episode_metrics["episode_return"])
-        logger.log(eval_output.episode_metrics, t, eval_idx, LogEvent.EVAL)
+        # episode_return = jnp.mean(eval_output.episode_metrics["episode_return"])
+        # logger.log(eval_output.episode_metrics, t, eval_idx, LogEvent.EVAL)
 
         # Save best actor params.
-        if cfg.arch.absolute_metric and max_episode_return <= episode_return:
-            best_params = copy.deepcopy(unreplicate_batch_dim(learner_state.params.actor))
-            max_episode_return = episode_return
+        # if cfg.arch.absolute_metric and max_episode_return <= episode_return:
+        #     best_params = copy.deepcopy(unreplicate_batch_dim(learner_state.params.actor))
+        #     max_episode_return = episode_return
 
         # Checkpoint:
         if cfg.logger.checkpointing.save_model:
@@ -594,20 +594,21 @@ def run_experiment(cfg: DictConfig) -> float:
             )
 
     # Record the performance for the final evaluation run.
-    eval_performance = float(jnp.mean(eval_output.episode_metrics[cfg.env.eval_metric]))
+    # eval_performance = float(jnp.mean(eval_output.episode_metrics[cfg.env.eval_metric]))
 
     # Measure absolute metric.
-    if cfg.arch.absolute_metric:
-        eval_keys = jax.random.split(key, cfg.arch.n_devices)
-
-        eval_output = absolute_metric_evaluator(best_params, eval_keys)
-        jax.block_until_ready(eval_output)
-
-        logger.log(eval_output.episode_metrics, t, eval_idx, LogEvent.ABSOLUTE)
+    # if cfg.arch.absolute_metric:
+    #     eval_keys = jax.random.split(key, cfg.arch.n_devices)
+    #
+    #     eval_output = absolute_metric_evaluator(best_params, eval_keys)
+    #     jax.block_until_ready(eval_output)
+    #
+    #     logger.log(eval_output.episode_metrics, t, eval_idx, LogEvent.ABSOLUTE)
 
     logger.stop()
 
-    return eval_performance
+    # return eval_performance
+    return 0.0
 
 
 @hydra.main(config_path="../../configs", config_name="default_ff_masac.yaml", version_base="1.2")
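For readers unfamiliar with the multi-device evaluation pattern this commit comments out, here is a minimal, self-contained sketch of the same key-splitting / unreplicate / block_until_ready flow. Note that dummy_evaluator and unreplicate_first_axis are hypothetical stand-ins written for illustration; Mava's actual evaluator (from make_eval_fns) and unreplicate_batch_dim are its own helpers and are not reproduced here.

import jax
import jax.numpy as jnp

n_devices = jax.device_count()

def dummy_evaluator(params, eval_keys):
    # Hypothetical stand-in: produce one "episode return" per evaluation key.
    returns = jax.vmap(lambda k: jax.random.uniform(k))(eval_keys)
    return {"episode_return": returns * params["scale"]}

def unreplicate_first_axis(tree):
    # Common JAX idiom for undoing pmap-style replication:
    # every leaf carries a leading device axis, so keep index 0.
    return jax.tree_util.tree_map(lambda x: x[0], tree)

key = jax.random.PRNGKey(0)
key, eval_key = jax.random.split(key)
eval_keys = jax.random.split(eval_key, n_devices)  # one key per device

# Params as they would look after replication across devices.
replicated_params = {"scale": jnp.ones((n_devices,))}
params = unreplicate_first_axis(replicated_params)

eval_output = dummy_evaluator(params, eval_keys)
jax.block_until_ready(eval_output)  # force async dispatch to finish before logging
print(float(jnp.mean(eval_output["episode_return"])))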
