agent factor for both haiku and flax versions
perrin-isir committed Jul 29, 2023
1 parent 13642b1 commit 2244e7d
Showing 17 changed files with 79 additions and 25 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -88,7 +88,7 @@ The *xpag-tutorials* repository contains a list of tutorials (colab notebooks) f

-----
## Short documentation
-<details><summary><B><I>xpag</I>: a platform for goal-conditioned RL</B></summary>
+<details><summary><B><I>xpag</I>: a platform for RL, goal-conditioned RL, and more.</B></summary>

*xpag* allows standard reinforcement learning, but it has been designed with
goal-conditioned reinforcement learning (GCRL) in mind (check out the [train_gmazes.ipynb](https://colab.research.google.com/github/perrin-isir/xpag-tutorials/blob/main/train_gmazes.ipynb)
@@ -237,7 +237,8 @@ The figure below summarizes the RL loop and the interactions between the compone
- Stéphane Caron (Inria)
- Fabian Schramm (Inria)

-* The [SAC agent](https://github.com/perrin-isir/xpag/blob/main/xpag/agents/sac) is based on the implementation of SAC in [JAXRL](https://github.com/ikostrikov/jaxrl), and some elements of the [TQC agent](https://github.com/perrin-isir/xpag/blob/main/xpag/agents/tqc) come from the implementation of TQC in [RLJAX](https://github.com/ku2482/rljax).
+* There is an interface to agents from the [RLJAX](https://github.com/ku2482/rljax) library (see [rljax_interface.py](https://github.com/perrin-isir/xpag/blob/main/xpag/agents/rljax_agents/rljax_interface.py)). This provides [haiku](https://github.com/deepmind/dm-haiku) versions of [DDPG](https://arxiv.org/abs/1509.02971), [TD3](https://arxiv.org/abs/1802.09477), [TQC](https://arxiv.org/abs/2005.04269), [SAC](https://arxiv.org/abs/1812.05905) and SAC with [DisCor](https://arxiv.org/abs/2003.07305).
+* The [flax](https://github.com/google/flax) version of the [SAC agent](https://github.com/perrin-isir/xpag/blob/main/xpag/agents/flax_agents/sac) is based on the implementation of SAC in [JAXRL](https://github.com/ikostrikov/jaxrl), and some elements of the flax version of the [TQC agent](https://github.com/perrin-isir/xpag/blob/main/xpag/agents/flax_agents/tqc) come from the implementation of TQC in [RLJAX](https://github.com/ku2482/rljax).

-----
## Citing the project
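The two README bullets added above distinguish the haiku-based agents, reached through the RLJAX interface, from the flax-based ones. As a rough sketch only (not part of the commit): instantiating one agent of each kind, using the class paths referenced above; the constructor arguments mirror the (observation_dim, action_dim) signatures visible further down in this diff, and any additional parameters the real classes may require are omitted.

```python
# Hedged sketch, not taken from the commit: a flax-based SAC next to the
# haiku-based (RLJAX) SAC. Import paths follow the renames in this commit;
# the constructor calls are assumptions based on the signatures shown below.
from xpag.agents.flax_agents.sac.sac import FlaxSAC
from xpag.agents.rljax_agents.rljax_interface import RljaxSAC

observation_dim, action_dim = 8, 2
flax_sac = FlaxSAC(observation_dim, action_dim)    # flax version, based on JAXRL
haiku_sac = RljaxSAC(observation_dim, action_dim)  # haiku version, via the RLJAX interface
```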
10 changes: 2 additions & 8 deletions xpag/agents/__init__.py
@@ -1,8 +1,2 @@
-from xpag.agents.agent import (
-    Agent,
-)
-from xpag.agents.sac.sac import SAC
-from xpag.agents.td3.td3 import TD3
-from xpag.agents.tqc.tqc import TQC
-from xpag.agents.sdqn.sdqn import SDQN, SDQNSetter
-from xpag.agents.rljax_agents.rljax_interface import RLJAXSAC
+from xpag.agents.agent import Agent
+from xpag.agents.all_agents import SAC, TD3, TQC, SDQN, SDQNSetter
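The new xpag/agents/all_agents.py module is not shown in this commit. A purely hypothetical sketch of what it might contain, assuming the renamed flax classes are aliased back to the historical public names so that `from xpag.agents import SAC, TD3, TQC, SDQN, SDQNSetter` keeps working; the sac and tqc paths follow the renames visible below, while the td3 and sdqn paths are guesses.

```python
# Hypothetical sketch of xpag/agents/all_agents.py (real contents not in this diff).
# Assumption: the Flax implementations are re-exported under the old public names.
from xpag.agents.flax_agents.sac.sac import FlaxSAC as SAC
from xpag.agents.flax_agents.td3.td3 import FlaxTD3 as TD3  # path assumed
from xpag.agents.flax_agents.tqc.tqc import FlaxTQC as TQC
from xpag.agents.flax_agents.sdqn.sdqn import FlaxSDQN as SDQN  # path assumed
from xpag.agents.flax_agents.sdqn.sdqn import FlaxSDQNSetter as SDQNSetter  # path assumed
```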
File renamed without changes.
4 changes: 2 additions & 2 deletions xpag/agents/sac/sac.py → xpag/agents/flax_agents/sac/sac.py
@@ -4,7 +4,7 @@

import os
from xpag.agents.agent import Agent
-from xpag.agents.sac.sac_from_jaxrl import Batch, SACLearner
+from xpag.agents.flax_agents.sac.sac_from_jaxrl import Batch, SACLearner
from xpag.tools.utils import squeeze
import functools
from typing import Callable, Any, Tuple
@@ -26,7 +26,7 @@ def _qvalue(
)


-class SAC(Agent):
+class FlaxSAC(Agent):
"""
Interface to the SAC agent from JAXRL (https://github.com/ikostrikov/jaxrl)
File renamed without changes.
File renamed without changes.
@@ -37,7 +37,7 @@ class TrainingState:
steps: jnp.ndarray


-class SDQN(Agent):
+class FlaxSDQN(Agent):
def __init__(
self,
observation_dim,
@@ -564,8 +564,8 @@ def train_on_batch(self, batch):
return metrics


-class SDQNSetter(Setter):
-    def __init__(self, sdqn_agent: SDQN):
+class FlaxSDQNSetter(Setter):
+    def __init__(self, sdqn_agent: FlaxSDQN):
super().__init__("SDQNSetter")
self.agent = sdqn_agent

File renamed without changes.
@@ -38,7 +38,7 @@ class TrainingState:
steps: jnp.ndarray


-class TD3(Agent):
+class FlaxTD3(Agent):
def __init__(
self,
observation_dim,
File renamed without changes.
4 changes: 2 additions & 2 deletions xpag/agents/tqc/tqc.py → xpag/agents/flax_agents/tqc/tqc.py
@@ -32,7 +32,7 @@

import os
from xpag.agents.agent import Agent
-from xpag.agents.sac.sac_from_jaxrl import (
+from xpag.agents.flax_agents.sac.sac_from_jaxrl import (
PRNGKey,
InfoDict,
Params,
@@ -335,7 +335,7 @@ def update(self, batch: Batch) -> InfoDict:
return info


-class TQC(Agent):
+class FlaxTQC(Agent):
def __init__(self, observation_dim, action_dim, params=None):
"""
Interface to TQC agent
2 changes: 1 addition & 1 deletion xpag/agents/rljax_agents/algorithm/__init__.py
@@ -1,5 +1,5 @@
from xpag.agents.rljax_agents.algorithm.ddpg import DDPG
from xpag.agents.rljax_agents.algorithm.sac import SAC
-from xpag.agents.rljax_agents.algorithm.sac_discor import SAC_DisCor
+from xpag.agents.rljax_agents.algorithm.sac_discor import SACDisCor
from xpag.agents.rljax_agents.algorithm.td3 import TD3
from xpag.agents.rljax_agents.algorithm.tqc import TQC
27 changes: 26 additions & 1 deletion xpag/agents/rljax_agents/algorithm/ddpg.py
@@ -1,6 +1,7 @@
from functools import partial
from typing import Tuple

+import os
+import joblib
import haiku as hk
import jax
import jax.numpy as jnp
@@ -220,3 +221,27 @@ def _loss_actor(
action = self.actor.apply(params_actor, state)
mean_q = self.critic.apply(params_critic, state, action)[0].mean()
return -mean_q, None

+    def save_params(self, save_dir):
+        os.makedirs(save_dir, exist_ok=True)
+        for filename in [
+            "params_critic",
+            "opt_state_critic",
+            "params_critic_target",
+            "params_actor",
+            "opt_state_actor",
+        ]:
+            with open(os.path.join(save_dir, filename + ".joblib"), "wb") as f_:
+                joblib.dump(self.__dict__[filename], f_)
+
+    def load_params(self, save_dir):
+        for filename in [
+            "params_critic",
+            "opt_state_critic",
+            "params_critic_target",
+            "params_actor",
+            "opt_state_actor",
+        ]:
+            self.__dict__[filename] = jax.tree_util.tree_map(
+                jnp.array, joblib.load(os.path.join(save_dir, filename + ".joblib"))
+            )
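The save_params/load_params methods added above (and mirrored below for SAC+DisCor) follow one pattern: each named field of the agent is dumped to its own .joblib file, and reloading converts every leaf back to a jnp array via jax.tree_util.tree_map. A standalone illustration of that pattern; field names and array shapes are invented for the example.

```python
# Standalone illustration of the checkpointing pattern used in this commit:
# dump each field to "<name>.joblib", reload it, and convert every leaf back
# to a jnp.ndarray. Field names and shapes are illustrative only.
import os

import jax
import jax.numpy as jnp
import joblib

fields = {"params_actor": {"w": jnp.ones((4, 2))}, "params_critic": {"w": jnp.zeros((4, 1))}}

def save_params(state, save_dir):
    os.makedirs(save_dir, exist_ok=True)
    for filename, value in state.items():
        with open(os.path.join(save_dir, filename + ".joblib"), "wb") as f_:
            joblib.dump(value, f_)

def load_params(save_dir, filenames):
    return {
        filename: jax.tree_util.tree_map(
            jnp.array, joblib.load(os.path.join(save_dir, filename + ".joblib"))
        )
        for filename in filenames
    }

save_params(fields, "/tmp/agent_checkpoint")
restored = load_params("/tmp/agent_checkpoint", list(fields))
```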
40 changes: 38 additions & 2 deletions xpag/agents/rljax_agents/algorithm/sac_discor.py
@@ -1,11 +1,13 @@
import jax.numpy as jnp

+import os
+import joblib
+import jax
from xpag.agents.rljax_agents.algorithm.misc import DisCorMixIn
from xpag.agents.rljax_agents.algorithm.sac import SAC
from xpag.agents.rljax_agents.util import optimize


-class SAC_DisCor(DisCorMixIn, SAC):
+class SACDisCor(DisCorMixIn, SAC):
name = "SAC+DisCor"

def __init__(
@@ -167,3 +169,37 @@ def update(self, writer=None):
writer.add_scalar("loss/error", loss_error, self.learning_step)
writer.add_scalar("stat/alpha", jnp.exp(self.log_alpha), self.learning_step)
writer.add_scalar("stat/entropy", -mean_log_pi, self.learning_step)

+    def save_params(self, save_dir):
+        os.makedirs(save_dir, exist_ok=True)
+        for filename in [
+            "params_critic",
+            "opt_state_critic",
+            "params_critic_target",
+            "params_actor",
+            "opt_state_actor",
+            "log_alpha",
+            "opt_state_alpha",
+            "params_error",
+            "opt_state_error",
+            "params_error_target",
+        ]:
+            with open(os.path.join(save_dir, filename + ".joblib"), "wb") as f_:
+                joblib.dump(self.__dict__[filename], f_)
+
+    def load_params(self, save_dir):
+        for filename in [
+            "params_critic",
+            "opt_state_critic",
+            "params_critic_target",
+            "params_actor",
+            "opt_state_actor",
+            "log_alpha",
+            "opt_state_alpha",
+            "params_error",
+            "opt_state_error",
+            "params_error_target",
+        ]:
+            self.__dict__[filename] = jax.tree_util.tree_map(
+                jnp.array, joblib.load(os.path.join(save_dir, filename + ".joblib"))
+            )
1 change: 0 additions & 1 deletion xpag/agents/rljax_agents/algorithm/td3.py
@@ -1,6 +1,5 @@
from functools import partial
from typing import Tuple

import haiku as hk
import jax
import jax.numpy as jnp
1 change: 0 additions & 1 deletion xpag/agents/rljax_agents/algorithm/tqc.py
@@ -1,6 +1,5 @@
from functools import partial
from typing import List

import haiku as hk
import jax
import jax.numpy as jnp
2 changes: 1 addition & 1 deletion xpag/agents/rljax_agents/rljax_interface.py
@@ -18,7 +18,7 @@ def sample(self, batch_size):
return 1.0, self.next_batch


-class RLJAXSAC(Agent):
+class RljaxSAC(Agent):
"""
Interface to the SAC agent from RLJAX (https://github.com/toshikwa/rljax)
