Skip to content

Commit

Permalink
[RLlib] JAXPolicy prep PR #2 (move get_activation_fn (backward-compat…
Browse files Browse the repository at this point in the history
…ibly), minor fixes and preparations). (ray-project#13091)
  • Loading branch information
sven1977 authored Dec 31, 2020
1 parent 6a54897 commit 8726521
Show file tree
Hide file tree
Showing 14 changed files with 108 additions and 21 deletions.
3 changes: 2 additions & 1 deletion rllib/agents/ddpg/ddpg_torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import get_activation_fn, try_import_torch
from ray.rllib.models.utils import get_activation_fn
from ray.rllib.utils.framework import try_import_torch

torch, nn = try_import_torch()

Expand Down
2 changes: 1 addition & 1 deletion rllib/agents/maml/maml_tf_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
vf_preds_fetches, compute_and_clip_gradients, setup_config, \
ValueNetworkMixin
from ray.rllib.evaluation.postprocessing import Postprocessing
from ray.rllib.models.utils import get_activation_fn
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.tf_policy_template import build_tf_policy
from ray.rllib.utils import try_import_tf
from ray.rllib.utils.framework import get_activation_fn

tf1, tf, tfv = try_import_tf()

Expand Down
3 changes: 2 additions & 1 deletion rllib/agents/sac/sac_torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import get_activation_fn, try_import_torch
from ray.rllib.models.utils import get_activation_fn
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.spaces.simplex import Simplex
from ray.rllib.utils.typing import ModelConfigDict, TensorType

Expand Down
3 changes: 2 additions & 1 deletion rllib/models/tf/fcnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils.framework import get_activation_fn, try_import_tf
from ray.rllib.models.utils import get_activation_fn
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import Dict, TensorType, List, ModelConfigDict

tf1, tf, tfv = try_import_tf()
Expand Down
6 changes: 3 additions & 3 deletions rllib/models/tf/layers/noisy_layer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import numpy as np

from ray.rllib.utils.framework import get_activation_fn, get_variable, \
try_import_tf
from ray.rllib.utils.framework import TensorType, TensorShape
from ray.rllib.models.utils import get_activation_fn
from ray.rllib.utils.framework import get_variable, try_import_tf, \
TensorType, TensorShape

tf1, tf, tfv = try_import_tf()

Expand Down
4 changes: 2 additions & 2 deletions rllib/models/tf/visionnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.misc import normc_initializer
from ray.rllib.models.utils import get_filter_config
from ray.rllib.utils.framework import get_activation_fn, try_import_tf
from ray.rllib.models.utils import get_activation_fn, get_filter_config
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import ModelConfigDict, TensorType

tf1, tf, tfv = try_import_tf()
Expand Down
3 changes: 2 additions & 1 deletion rllib/models/torch/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import numpy as np
from typing import Union, Tuple, Any, List

from ray.rllib.utils.framework import get_activation_fn, try_import_torch
from ray.rllib.models.utils import get_activation_fn
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import TensorType

torch, nn = try_import_torch()
Expand Down
4 changes: 2 additions & 2 deletions rllib/models/torch/modules/convtranspose2d_stack.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Tuple

from ray.rllib.models.torch.misc import Reshape
from ray.rllib.models.utils import get_initializer
from ray.rllib.utils.framework import get_activation_fn, try_import_torch
from ray.rllib.models.utils import get_activation_fn, get_initializer
from ray.rllib.utils.framework import try_import_torch

torch, nn = try_import_torch()
if torch:
Expand Down
4 changes: 2 additions & 2 deletions rllib/models/torch/modules/noisy_layer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np

from ray.rllib.utils.framework import get_activation_fn, try_import_torch
from ray.rllib.utils.framework import TensorType
from ray.rllib.models.utils import get_activation_fn
from ray.rllib.utils.framework import try_import_torch, TensorType

torch, nn = try_import_torch()

Expand Down
81 changes: 79 additions & 2 deletions rllib/models/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,62 @@
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from typing import Optional

from ray.rllib.utils.framework import try_import_jax, try_import_tf, \
try_import_torch


def get_activation_fn(name: Optional[str] = None, framework: str = "tf"):
    """Returns a framework specific activation function, given a name string.

    Args:
        name (Optional[str]): One of "relu", "tanh", "swish", "linear" or
            None (the default). "linear" and None both mean "no
            activation". A callable may also be passed in, in which case
            it is returned unchanged.
        framework (str): One of "jax", "tf|tfe|tf2" or "torch".

    Returns:
        A framework-specific activation function. e.g. tf.nn.tanh or
        torch.nn.ReLU. None if name in ["linear", None].

    Raises:
        ValueError: If name is an unknown activation function.
        AssertionError: If `framework` is not supported or the requested
            framework is not installed.
    """
    # Already a callable, return as-is (regardless of `framework`).
    if callable(name):
        return name

    assert framework in ["torch", "jax", "tf", "tfe", "tf2"],\
        "Unsupported framework `{}`!".format(framework)

    # "linear"/None mean "no activation" in every supported framework.
    if name in ["linear", None]:
        return None

    # Infer the correct activation function from the string specifier.
    if framework == "torch":
        if name == "swish":
            from ray.rllib.utils.torch_ops import Swish
            return Swish
        _, nn = try_import_torch()
        # Same "not installed" check as `get_initializer` performs.
        assert nn is not None,\
            "`torch` not installed. Try `pip install torch`."
        if name == "relu":
            return nn.ReLU
        elif name == "tanh":
            return nn.Tanh
    elif framework == "jax":
        jax, _ = try_import_jax()
        assert jax is not None,\
            "`jax` not installed. Try `pip install jax flax`."
        if name == "swish":
            return jax.nn.swish
        if name == "relu":
            return jax.nn.relu
        elif name == "tanh":
            # NOTE(review): "tanh" is mapped to `jax.nn.hard_tanh` here,
            # which does not look like a plain tanh - confirm that this
            # is intentional before relying on it.
            return jax.nn.hard_tanh
    else:
        _, tf, _ = try_import_tf()
        assert tf is not None,\
            "`tensorflow` not installed. Try `pip install tensorflow`."
        # Any activation exposed under `tf.nn` can be requested by name.
        fn = getattr(tf.nn, name, None)
        if fn is not None:
            return fn

    raise ValueError("Unknown activation ({}) for framework={}!".format(
        name, framework))


def get_filter_config(shape):
Expand Down Expand Up @@ -40,7 +98,7 @@ def get_initializer(name, framework="tf"):
Args:
name (str): One of "xavier_uniform" (default), "xavier_normal".
framework (str): One of "tf" or "torch".
framework (str): One of "jax", "tf|tfe|tf2" or "torch".
Returns:
A framework-specific initializer function, e.g.
Expand All @@ -50,14 +108,33 @@ def get_initializer(name, framework="tf"):
Raises:
ValueError: If name is an unknown initializer.
"""
# Already a callable, return as-is.
if callable(name):
return name

if framework == "jax":
_, flax = try_import_jax()
assert flax is not None,\
"`flax` not installed. Try `pip install jax flax`."
import flax.linen as nn
if name in [None, "default", "xavier_uniform"]:
return nn.initializers.xavier_uniform()
elif name == "xavier_normal":
return nn.initializers.xavier_normal()
if framework == "torch":
_, nn = try_import_torch()
assert nn is not None,\
"`torch` not installed. Try `pip install torch`."
if name in [None, "default", "xavier_uniform"]:
return nn.init.xavier_uniform_
elif name == "xavier_normal":
return nn.init.xavier_normal_
else:
assert framework in ["tf", "tfe", "tf2"],\
"Unsupported framework `{}`!".format(framework)
tf1, tf, tfv = try_import_tf()
assert tf is not None,\
"`tensorflow` not installed. Try `pip install tensorflow`."
if name in [None, "default", "xavier_uniform"]:
return tf.keras.initializers.GlorotUniform
elif name == "xavier_normal":
Expand Down
4 changes: 2 additions & 2 deletions rllib/utils/deprecation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def deprecation_warning(old, new=None, error=None):
Args:
old (str): A description of the "thing" that is to be deprecated.
new (Optional[str]): A description of the new "thing" that replaces it.
error (Optional[bool,Exception]): Whether or which exception to throw.
If True, throw ValueError.
error (Optional[Union[bool,Exception]]): Whether or which exception to
throw. If True, throw ValueError.
"""
msg = "`{}` has been deprecated.{}".format(
old, (" Use `{}` instead.".format(new) if new else ""))
Expand Down
3 changes: 2 additions & 1 deletion rllib/utils/exploration/curiosity.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.torch_action_dist import TorchCategorical, \
TorchMultiCategorical
from ray.rllib.models.utils import get_activation_fn
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils import NullContextManager
from ray.rllib.utils.annotations import override
from ray.rllib.utils.exploration.exploration import Exploration
from ray.rllib.utils.framework import get_activation_fn, try_import_tf, \
from ray.rllib.utils.framework import try_import_tf, \
try_import_torch
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.tf_ops import get_placeholder, one_hot as tf_one_hot
Expand Down
7 changes: 6 additions & 1 deletion rllib/utils/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import sys
from typing import Any, Optional

from ray.rllib.utils.deprecation import deprecation_warning
from ray.rllib.utils.typing import TensorStructType, TensorShape, TensorType

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -252,7 +253,7 @@ def get_variable(value,
return value


# TODO: (sven) move to models/utils.py
# Deprecated: Use rllib.models.utils::get_activation_fn instead.
def get_activation_fn(name: Optional[str] = None, framework: str = "tf"):
"""Returns a framework specific activation function, given a name string.
Expand All @@ -268,6 +269,10 @@ def get_activation_fn(name: Optional[str] = None, framework: str = "tf"):
Raises:
ValueError: If name is an unknown activation function.
"""
deprecation_warning(
"rllib/utils/framework.py::get_activation_fn",
"rllib/models/utils.py::get_activation_fn",
error=False)
if framework == "torch":
if name in ["linear", None]:
return None
Expand Down
2 changes: 1 addition & 1 deletion rllib/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ray.rllib.utils.framework import try_import_jax, try_import_tf, \
try_import_torch

jax, flax = try_import_jax()
jax, _ = try_import_jax()
tf1, tf, tfv = try_import_tf()
if tf1:
eager_mode = None
Expand Down

0 comments on commit 8726521

Please sign in to comment.