[RLlib] JAXPolicy prep. PR #1. (ray-project#13077)
sven1977 authored Dec 27, 2020
1 parent 25f9f0d commit 99ae7ba
Showing 28 changed files with 501 additions and 359 deletions.
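
Every Torch policy touched here follows the same pattern: the Torch-only `build_torch_policy` helper from `ray.rllib.policy.torch_policy_template` is replaced by the framework-agnostic `build_policy_class` from `ray.rllib.policy.policy_template`, which takes an explicit `framework` argument (always `"torch"` in this commit), presumably so a future JAX backend can reuse the same template. A rough sketch of the new call shape, using keyword arguments that appear in the diffs below (the policy name and loss function are made-up placeholders, not part of this commit):

from ray.rllib.agents.a3c.a3c import DEFAULT_CONFIG  # any trainer config works for this sketch
from ray.rllib.policy.policy_template import build_policy_class


def my_loss_fn(policy, model, dist_class, train_batch):
    # Placeholder: a real policy returns a differentiable scalar loss here.
    logits, _ = model.from_batch(train_batch)
    return -logits.mean()


# `framework="torch"` is the new explicit argument; the old
# build_torch_policy helper implied it.
MyTorchPolicy = build_policy_class(
    name="MyTorchPolicy",
    framework="torch",
    get_default_config=lambda: DEFAULT_CONFIG,
    loss_fn=my_loss_fn,
)
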
5 changes: 3 additions & 2 deletions rllib/agents/a3c/a3c_torch_policy.py
@@ -1,8 +1,8 @@
 import ray
 from ray.rllib.evaluation.postprocessing import compute_advantages, \
     Postprocessing
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.utils.framework import try_import_torch

 torch, nn = try_import_torch()
@@ -84,8 +84,9 @@ def _value(self, obs):
         return self.model.value_function()[0]


-A3CTorchPolicy = build_torch_policy(
+A3CTorchPolicy = build_policy_class(
     name="A3CTorchPolicy",
+    framework="torch",
     get_default_config=lambda: ray.rllib.agents.a3c.a3c.DEFAULT_CONFIG,
     loss_fn=actor_critic_loss,
     stats_fn=loss_and_entropy_stats,
5 changes: 3 additions & 2 deletions rllib/agents/ars/ars_torch_policy.py
@@ -4,10 +4,11 @@
 import ray
 from ray.rllib.agents.es.es_torch_policy import after_init, before_init, \
     make_model_and_action_dist
-from ray.rllib.policy.torch_policy_template import build_torch_policy
+from ray.rllib.policy.policy_template import build_policy_class

-ARSTorchPolicy = build_torch_policy(
+ARSTorchPolicy = build_policy_class(
     name="ARSTorchPolicy",
+    framework="torch",
     loss_fn=None,
     get_default_config=lambda: ray.rllib.agents.ars.ars.DEFAULT_CONFIG,
     before_init=before_init,
2 changes: 1 addition & 1 deletion rllib/agents/ddpg/ddpg_torch_model.py
@@ -2,7 +2,7 @@

 from ray.rllib.models.torch.misc import SlimFC
 from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
-from ray.rllib.utils.framework import try_import_torch, get_activation_fn
+from ray.rllib.utils.framework import get_activation_fn, try_import_torch

 torch, nn = try_import_torch()

5 changes: 3 additions & 2 deletions rllib/agents/ddpg/ddpg_torch_policy.py
@@ -7,8 +7,8 @@
 from ray.rllib.agents.dqn.dqn_tf_policy import postprocess_nstep_and_prio, \
     PRIO_WEIGHTS
 from ray.rllib.models.torch.torch_action_dist import TorchDeterministic
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.torch_ops import huber_loss, l2_loss

@@ -264,8 +264,9 @@ def setup_late_mixins(policy, obs_space, action_space, config):
     TargetNetworkMixin.__init__(policy)


-DDPGTorchPolicy = build_torch_policy(
+DDPGTorchPolicy = build_policy_class(
     name="DDPGTorchPolicy",
+    framework="torch",
     loss_fn=ddpg_actor_critic_loss,
     get_default_config=lambda: ray.rllib.agents.ddpg.ddpg.DEFAULT_CONFIG,
     stats_fn=build_ddpg_stats,
5 changes: 3 additions & 2 deletions rllib/agents/dqn/dqn_torch_policy.py
@@ -14,9 +14,9 @@
 from ray.rllib.models.torch.torch_action_dist import (TorchCategorical,
     TorchDistributionWrapper)
 from ray.rllib.policy.policy import Policy
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.torch_policy import LearningRateSchedule
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.utils.error import UnsupportedSpaceException
 from ray.rllib.utils.exploration.parameter_noise import ParameterNoise
 from ray.rllib.utils.framework import try_import_torch
@@ -384,8 +384,9 @@ def extra_action_out_fn(policy: Policy, input_dict, state_batches, model,
     return {"q_values": policy.q_values}


-DQNTorchPolicy = build_torch_policy(
+DQNTorchPolicy = build_policy_class(
     name="DQNTorchPolicy",
+    framework="torch",
     loss_fn=build_q_losses,
     get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG,
     make_model_and_action_dist=build_q_model_and_distribution,
5 changes: 3 additions & 2 deletions rllib/agents/dqn/simple_q_torch_policy.py
@@ -11,8 +11,8 @@
 from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \
     TorchDistributionWrapper
 from ray.rllib.policy import Policy
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.torch_ops import huber_loss
 from ray.rllib.utils.typing import TensorType, TrainerConfigDict
@@ -127,8 +127,9 @@ def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,
     TargetNetworkMixin.__init__(policy, obs_space, action_space, config)


-SimpleQTorchPolicy = build_torch_policy(
+SimpleQTorchPolicy = build_policy_class(
     name="SimpleQPolicy",
+    framework="torch",
     loss_fn=build_q_losses,
     get_default_config=lambda: ray.rllib.agents.dqn.dqn.DEFAULT_CONFIG,
     extra_action_out_fn=extra_action_out_fn,
9 changes: 5 additions & 4 deletions rllib/agents/dreamer/dreamer_torch_policy.py
@@ -1,11 +1,11 @@
 import logging

 import ray
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
-from ray.rllib.utils.framework import try_import_torch
-from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.agents.dreamer.utils import FreezeParameters
+from ray.rllib.models.catalog import ModelCatalog
+from ray.rllib.policy.policy_template import build_policy_class
+from ray.rllib.utils.framework import try_import_torch

 torch, nn = try_import_torch()
 if torch:
@@ -236,8 +236,9 @@ def dreamer_optimizer_fn(policy, config):
     return (model_opt, actor_opt, critic_opt)


-DreamerTorchPolicy = build_torch_policy(
+DreamerTorchPolicy = build_policy_class(
     name="DreamerTorchPolicy",
+    framework="torch",
     get_default_config=lambda: ray.rllib.agents.dreamer.dreamer.DEFAULT_CONFIG,
     action_sampler_fn=action_sampler_fn,
     loss_fn=dreamer_loss,
5 changes: 3 additions & 2 deletions rllib/agents/es/es_torch_policy.py
@@ -7,8 +7,8 @@

 import ray
 from ray.rllib.models import ModelCatalog
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.utils.filter import get_filter
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space, \
@@ -126,8 +126,9 @@ def make_model_and_action_dist(policy, observation_space, action_space,
     return model, dist_class


-ESTorchPolicy = build_torch_policy(
+ESTorchPolicy = build_policy_class(
     name="ESTorchPolicy",
+    framework="torch",
     loss_fn=None,
     get_default_config=lambda: ray.rllib.agents.es.es.DEFAULT_CONFIG,
     before_init=before_init,
5 changes: 3 additions & 2 deletions rllib/agents/impala/vtrace_torch_policy.py
@@ -6,10 +6,10 @@
 from ray.rllib.agents.a3c.a3c_torch_policy import apply_grad_clipping
 import ray.rllib.agents.impala.vtrace_torch as vtrace
 from ray.rllib.models.torch.torch_action_dist import TorchCategorical
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.torch_policy import LearningRateSchedule, \
     EntropyCoeffSchedule
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.torch_ops import explained_variance, global_norm, \
     sequence_mask
@@ -260,8 +260,9 @@ def setup_mixins(policy, obs_space, action_space, config):
     LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])


-VTraceTorchPolicy = build_torch_policy(
+VTraceTorchPolicy = build_policy_class(
     name="VTraceTorchPolicy",
+    framework="torch",
     loss_fn=build_vtrace_loss,
     get_default_config=lambda: ray.rllib.agents.impala.impala.DEFAULT_CONFIG,
     stats_fn=stats,
6 changes: 3 additions & 3 deletions rllib/agents/maml/maml_tf_policy.py
@@ -1,13 +1,13 @@
 import logging

 import ray
+from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \
+    vf_preds_fetches, compute_and_clip_gradients, setup_config, \
+    ValueNetworkMixin
 from ray.rllib.evaluation.postprocessing import Postprocessing
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.tf_policy_template import build_tf_policy
 from ray.rllib.utils import try_import_tf
-from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \
-    vf_preds_fetches, compute_and_clip_gradients, setup_config, \
-    ValueNetworkMixin
 from ray.rllib.utils.framework import get_activation_fn

 tf1, tf, tfv = try_import_tf()
5 changes: 3 additions & 2 deletions rllib/agents/maml/maml_torch_policy.py
@@ -2,8 +2,8 @@

 import ray
 from ray.rllib.evaluation.postprocessing import Postprocessing
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae, \
     setup_config
 from ray.rllib.agents.ppo.ppo_torch_policy import vf_preds_fetches, \
@@ -347,8 +347,9 @@ def setup_mixins(policy, obs_space, action_space, config):
     KLCoeffMixin.__init__(policy, config)


-MAMLTorchPolicy = build_torch_policy(
+MAMLTorchPolicy = build_policy_class(
     name="MAMLTorchPolicy",
+    framework="torch",
     get_default_config=lambda: ray.rllib.agents.maml.maml.DEFAULT_CONFIG,
     loss_fn=maml_loss,
     stats_fn=maml_stats,
5 changes: 3 additions & 2 deletions rllib/agents/marwil/marwil_torch_policy.py
@@ -1,8 +1,8 @@
 import ray
 from ray.rllib.agents.marwil.marwil_tf_policy import postprocess_advantages
 from ray.rllib.evaluation.postprocessing import Postprocessing
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.torch_ops import explained_variance

@@ -75,8 +75,9 @@ def setup_mixins(policy, obs_space, action_space, config):
     ValueNetworkMixin.__init__(policy)


-MARWILTorchPolicy = build_torch_policy(
+MARWILTorchPolicy = build_policy_class(
     name="MARWILTorchPolicy",
+    framework="torch",
     loss_fn=marwil_loss,
     get_default_config=lambda: ray.rllib.agents.marwil.marwil.DEFAULT_CONFIG,
     stats_fn=stats,
5 changes: 3 additions & 2 deletions rllib/agents/mbmpo/mbmpo_torch_policy.py
@@ -13,7 +13,7 @@
 from ray.rllib.models.modelv2 import ModelV2
 from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
 from ray.rllib.policy.policy import Policy
-from ray.rllib.policy.torch_policy_template import build_torch_policy
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.typing import TrainerConfigDict

@@ -76,8 +76,9 @@ def make_model_and_action_dist(

 # Build a child class of `TorchPolicy`, given the custom functions defined
 # above.
-MBMPOTorchPolicy = build_torch_policy(
+MBMPOTorchPolicy = build_policy_class(
     name="MBMPOTorchPolicy",
+    framework="torch",
     get_default_config=lambda: ray.rllib.agents.mbmpo.mbmpo.DEFAULT_CONFIG,
     make_model_and_action_dist=make_model_and_action_dist,
     loss_fn=maml_loss,
5 changes: 3 additions & 2 deletions rllib/agents/pg/pg_torch_policy.py
@@ -10,8 +10,8 @@
 from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
 from ray.rllib.models.modelv2 import ModelV2
 from ray.rllib.policy import Policy
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.typing import TensorType

@@ -72,8 +72,9 @@ def pg_loss_stats(policy: Policy,
 # Build a child class of `TFPolicy`, given the extra options:
 # - trajectory post-processing function (to calculate advantages)
 # - PG loss function
-PGTorchPolicy = build_torch_policy(
+PGTorchPolicy = build_policy_class(
     name="PGTorchPolicy",
+    framework="torch",
     get_default_config=lambda: ray.rllib.agents.pg.pg.DEFAULT_CONFIG,
     loss_fn=pg_torch_loss,
     stats_fn=pg_loss_stats,
5 changes: 3 additions & 2 deletions rllib/agents/ppo/appo_torch_policy.py
@@ -23,9 +23,9 @@
 from ray.rllib.models.torch.torch_action_dist import \
     TorchDistributionWrapper, TorchCategorical
 from ray.rllib.policy.policy import Policy
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.torch_policy import LearningRateSchedule
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.torch_ops import explained_variance, global_norm, \
     sequence_mask
@@ -322,8 +322,9 @@ def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,

 # Build a child class of `TorchPolicy`, given the custom functions defined
 # above.
-AsyncPPOTorchPolicy = build_torch_policy(
+AsyncPPOTorchPolicy = build_policy_class(
     name="AsyncPPOTorchPolicy",
+    framework="torch",
     loss_fn=appo_surrogate_loss,
     stats_fn=stats,
     postprocess_fn=postprocess_trajectory,
12 changes: 7 additions & 5 deletions rllib/agents/ppo/ppo_torch_policy.py
@@ -14,10 +14,10 @@
 from ray.rllib.models.modelv2 import ModelV2
 from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper
 from ray.rllib.policy.policy import Policy
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
 from ray.rllib.policy.torch_policy import EntropyCoeffSchedule, \
     LearningRateSchedule
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.torch_ops import convert_to_torch_tensor, \
     explained_variance, sequence_mask
@@ -111,6 +111,9 @@ def reduce_mean_valid(t):
     policy._total_loss = total_loss
     policy._mean_policy_loss = mean_policy_loss
     policy._mean_vf_loss = mean_vf_loss
+    policy._vf_explained_var = explained_variance(
+        train_batch[Postprocessing.VALUE_TARGETS],
+        policy.model.value_function())
     policy._mean_entropy = mean_entropy
     policy._mean_kl = mean_kl

@@ -134,9 +137,7 @@ def kl_and_loss_stats(policy: Policy,
         "total_loss": policy._total_loss,
         "policy_loss": policy._mean_policy_loss,
         "vf_loss": policy._mean_vf_loss,
-        "vf_explained_var": explained_variance(
-            train_batch[Postprocessing.VALUE_TARGETS],
-            policy.model.value_function()),
+        "vf_explained_var": policy._vf_explained_var,
         "kl": policy._mean_kl,
         "entropy": policy._mean_entropy,
         "entropy_coeff": policy.entropy_coeff,
@@ -271,8 +272,9 @@ def setup_mixins(policy: Policy, obs_space: gym.spaces.Space,

 # Build a child class of `TorchPolicy`, given the custom functions defined
 # above.
-PPOTorchPolicy = build_torch_policy(
+PPOTorchPolicy = build_policy_class(
     name="PPOTorchPolicy",
+    framework="torch",
     get_default_config=lambda: ray.rllib.agents.ppo.ppo.DEFAULT_CONFIG,
     loss_fn=ppo_surrogate_loss,
     stats_fn=kl_and_loss_stats,
2 changes: 1 addition & 1 deletion rllib/agents/qmix/qmix_policy.py
@@ -143,7 +143,7 @@ def forward(self,
         return loss, mask, masked_td_error, chosen_action_qvals, targets


-# TODO(sven): Make this a TorchPolicy child via `build_torch_policy`.
+# TODO(sven): Make this a TorchPolicy child via `build_policy_class`.
 class QMixTorchPolicy(Policy):
     """QMix impl. Assumes homogeneous agents for now.
5 changes: 3 additions & 2 deletions rllib/agents/sac/sac_torch_policy.py
@@ -17,8 +17,8 @@
 from ray.rllib.models.torch.torch_action_dist import \
     TorchDistributionWrapper, TorchDirichlet
 from ray.rllib.policy.policy import Policy
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.models.torch.torch_action_dist import (
     TorchCategorical, TorchSquashedGaussian, TorchDiagGaussian, TorchBeta)
 from ray.rllib.utils.framework import try_import_torch
@@ -480,8 +480,9 @@ def setup_late_mixins(policy: Policy, obs_space: gym.spaces.Space,

 # Build a child class of `TorchPolicy`, given the custom functions defined
 # above.
-SACTorchPolicy = build_torch_policy(
+SACTorchPolicy = build_policy_class(
     name="SACTorchPolicy",
+    framework="torch",
     loss_fn=actor_critic_loss,
     get_default_config=lambda: ray.rllib.agents.sac.sac.DEFAULT_CONFIG,
     stats_fn=stats,
5 changes: 3 additions & 2 deletions rllib/agents/slateq/slateq_torch_policy.py
@@ -11,8 +11,8 @@
     TorchDistributionWrapper)
 from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
 from ray.rllib.policy.policy import Policy
+from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.policy.torch_policy_template import build_torch_policy
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.typing import (ModelConfigDict, TensorType,
     TrainerConfigDict)
@@ -403,8 +403,9 @@ def postprocess_fn_add_next_actions_for_sarsa(policy: Policy,
     return batch


-SlateQTorchPolicy = build_torch_policy(
+SlateQTorchPolicy = build_policy_class(
     name="SlateQTorchPolicy",
+    framework="torch",
     get_default_config=lambda: ray.rllib.agents.slateq.slateq.DEFAULT_CONFIG,

     # build model, loss functions, and optimizers
(The remaining changed files in this commit are not shown in this view.)