From bea6d6b27c366cd07dd5202356f372e02c1f3f9b Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 20 May 2024 04:56:26 -0700 Subject: [PATCH] Replace deprecated `jax.tree_*` functions with `jax.tree.*` The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25. PiperOrigin-RevId: 635420186 Change-Id: Ie71a2deb905622b947a9b075ce55bcb1bff46462 --- acme/agents/jax/ars/builder.py | 5 +++-- acme/agents/jax/bc/learning.py | 2 +- acme/agents/jax/cql/learning.py | 8 +++++--- acme/agents/jax/mbop/dataset.py | 4 ++-- acme/agents/jax/mbop/ensemble.py | 18 ++++++++++-------- acme/agents/jax/mbop/ensemble_test.py | 14 ++++++++------ acme/agents/jax/mbop/losses.py | 20 ++++++++++++-------- acme/agents/jax/mbop/mppi.py | 5 +++-- acme/agents/jax/mpo/learning.py | 4 ++-- acme/agents/jax/mpo/networks.py | 2 +- acme/agents/jax/mpo/utils.py | 13 +++++++------ acme/agents/jax/r2d2/learning.py | 4 ++-- acme/agents/jax/sac/learning.py | 5 +++-- acme/datasets/tfds.py | 2 +- 14 files changed, 60 insertions(+), 46 deletions(-) diff --git a/acme/agents/jax/ars/builder.py b/acme/agents/jax/ars/builder.py index 01fc8c3fcb..27b715bdb1 100644 --- a/acme/agents/jax/ars/builder.py +++ b/acme/agents/jax/ars/builder.py @@ -50,8 +50,9 @@ def apply( normalized_obs = normalization_apply_fn(obs, normalization_params) action = policy_network.apply(policy_params, normalized_obs) return action, { - 'params_key': - jax.tree_map(lambda x: jnp.expand_dims(x, axis=0), params_key) + 'params_key': jax.tree.map( + lambda x: jnp.expand_dims(x, axis=0), params_key + ) } return apply diff --git a/acme/agents/jax/bc/learning.py b/acme/agents/jax/bc/learning.py index 11406bdaed..46eb2607a9 100644 --- a/acme/agents/jax/bc/learning.py +++ b/acme/agents/jax/bc/learning.py @@ -194,7 +194,7 @@ def get_variables(self, names: List[str]) -> List[networks_lib.Params]: def save(self) -> TrainingState: # Serialize only the first replica of parameters and optimizer state. - return jax.tree_map(utils.get_from_first_device, self._state) + return jax.tree.map(utils.get_from_first_device, self._state) def restore(self, state: TrainingState): self._state = utils.replicate_in_all_devices(state) diff --git a/acme/agents/jax/cql/learning.py b/acme/agents/jax/cql/learning.py index 82b1406861..226375b630 100644 --- a/acme/agents/jax/cql/learning.py +++ b/acme/agents/jax/cql/learning.py @@ -337,9 +337,11 @@ def update_step( critic_grads, state.critic_optimizer_state) critic_params = optax.apply_updates(state.critic_params, critic_update) - new_target_critic_params = jax.tree_map( - lambda x, y: x * (1 - tau) + y * tau, state.target_critic_params, - critic_params) + new_target_critic_params = jax.tree.map( + lambda x, y: x * (1 - tau) + y * tau, + state.target_critic_params, + critic_params, + ) metrics = { 'critic_loss': critic_loss, diff --git a/acme/agents/jax/mbop/dataset.py b/acme/agents/jax/mbop/dataset.py index 22bfdf2065..bfa56a7956 100644 --- a/acme/agents/jax/mbop/dataset.py +++ b/acme/agents/jax/mbop/dataset.py @@ -210,11 +210,11 @@ def get_normalization_stats( """ # Set up normalization: example = next(iterator) - unbatched_single_example = jax.tree_map(lambda x: x[0, PREVIOUS, :], example) + unbatched_single_example = jax.tree.map(lambda x: x[0, PREVIOUS, :], example) mean_std = running_statistics.init_state(unbatched_single_example) for batch in itertools.islice(iterator, num_normalization_batches - 1): - example = jax.tree_map(lambda x: x[:, PREVIOUS, :], batch) + example = jax.tree.map(lambda x: x[:, PREVIOUS, :], batch) mean_std = running_statistics.update(mean_std, example) return mean_std diff --git a/acme/agents/jax/mbop/ensemble.py b/acme/agents/jax/mbop/ensemble.py index c7ccc412c8..d468442352 100644 --- a/acme/agents/jax/mbop/ensemble.py +++ b/acme/agents/jax/mbop/ensemble.py @@ -100,15 +100,17 @@ def apply_round_robin(base_apply: Callable[[networks.Params, Any], Any], num_networks = jax.tree_util.tree_leaves(params)[0].shape[0] # Reshape args and kwargs for the round-robin: - args = jax.tree_map( - functools.partial(_split_batch_dimension, num_networks), args) - kwargs = jax.tree_map( - functools.partial(_split_batch_dimension, num_networks), kwargs) + args = jax.tree.map( + functools.partial(_split_batch_dimension, num_networks), args + ) + kwargs = jax.tree.map( + functools.partial(_split_batch_dimension, num_networks), kwargs + ) # `out.shape` is `(num_networks, initial_batch_size/num_networks, ...) out = jax.vmap(base_apply)(params, *args, **kwargs) # Reshape to [initial_batch_size, ]. Using the 'F' order # forces the original values to the last dimension. - return jax.tree_map(lambda x: x.reshape((-1,) + x.shape[2:], order='F'), out) + return jax.tree.map(lambda x: x.reshape((-1,) + x.shape[2:], order='F'), out) def apply_all(base_apply: Callable[[networks.Params, Any], Any], @@ -133,8 +135,8 @@ def apply_all(base_apply: Callable[[networks.Params, Any], Any], # `num_networks` is the size of the batch dimension in `params`. num_networks = jax.tree_util.tree_leaves(params)[0].shape[0] - args = jax.tree_map(functools.partial(_repeat_n, num_networks), args) - kwargs = jax.tree_map(functools.partial(_repeat_n, num_networks), kwargs) + args = jax.tree.map(functools.partial(_repeat_n, num_networks), args) + kwargs = jax.tree.map(functools.partial(_repeat_n, num_networks), kwargs) # `out` is of shape `(num_networks, batch_size, )`. return jax.vmap(base_apply)(params, *args, **kwargs) @@ -155,7 +157,7 @@ def apply_mean(base_apply: Callable[[networks.Params, Any], Any], Output shape will be [batch_size, ] """ out = apply_all(base_apply, params, *args, **kwargs) - return jax.tree_map(functools.partial(jnp.mean, axis=0), out) + return jax.tree.map(functools.partial(jnp.mean, axis=0), out) def make_ensemble(base_network: networks.FeedForwardNetwork, diff --git a/acme/agents/jax/mbop/ensemble_test.py b/acme/agents/jax/mbop/ensemble_test.py index 9890a78121..685563681d 100644 --- a/acme/agents/jax/mbop/ensemble_test.py +++ b/acme/agents/jax/mbop/ensemble_test.py @@ -52,10 +52,10 @@ def struct_params_adding_ffn(sx: Any) -> networks.FeedForwardNetwork: """Like params_adding_ffn, but with pytree inputs, preserves structure.""" def init_fn(key, sx=sx): - return jax.tree_map(lambda x: jax.random.uniform(key, x.shape), sx) + return jax.tree.map(lambda x: jax.random.uniform(key, x.shape), sx) def apply_fn(params, x): - return jax.tree_map(lambda p, v: p + v, params, x) + return jax.tree.map(lambda p, v: p + v, params, x) return networks.FeedForwardNetwork(init=init_fn, apply=apply_fn) @@ -291,9 +291,10 @@ def test_round_robin_random(self): for i in range(9): np.testing.assert_allclose( out[i], - ffn.apply(jax.tree_map(lambda p, i=i: p[i % 3], params), bx[i]), - atol=1E-5, - rtol=1E-5) + ffn.apply(jax.tree.map(lambda p, i=i: p[i % 3], params), bx[i]), + atol=1e-5, + rtol=1e-5, + ) def test_mean_random(self): x = jnp.ones(10) @@ -318,7 +319,8 @@ def test_mean_random(self): # Check results explicitly: all_members = jnp.concatenate([ jnp.expand_dims( - ffn.apply(jax.tree_map(lambda p, i=i: p[i], params), bx), axis=0) + ffn.apply(jax.tree.map(lambda p, i=i: p[i], params), bx), axis=0 + ) for i in range(3) ]) batch_means = jnp.mean(all_members, axis=0) diff --git a/acme/agents/jax/mbop/losses.py b/acme/agents/jax/mbop/losses.py index 4ec911f431..fb74834b56 100644 --- a/acme/agents/jax/mbop/losses.py +++ b/acme/agents/jax/mbop/losses.py @@ -53,11 +53,13 @@ def world_model_loss(apply_fn: Callable[[networks.Observation, networks.Action], Returns: A scalar loss value as jnp.ndarray. """ - observation_t = jax.tree_map(lambda obs: obs[:, dataset.CURRENT, ...], - steps.observation) + observation_t = jax.tree.map( + lambda obs: obs[:, dataset.CURRENT, ...], steps.observation + ) action_t = steps.action[:, dataset.CURRENT, ...] - observation_tp1 = jax.tree_map(lambda obs: obs[:, dataset.NEXT, ...], - steps.observation) + observation_tp1 = jax.tree.map( + lambda obs: obs[:, dataset.NEXT, ...], steps.observation + ) reward_t = steps.reward[:, dataset.CURRENT, ...] (predicted_observation_tp1, predicted_reward_t) = apply_fn(observation_t, action_t) @@ -86,8 +88,9 @@ def policy_prior_loss( Returns: A scalar loss value as jnp.ndarray. """ - observation_t = jax.tree_map(lambda obs: obs[:, dataset.CURRENT, ...], - steps.observation) + observation_t = jax.tree.map( + lambda obs: obs[:, dataset.CURRENT, ...], steps.observation + ) action_tm1 = steps.action[:, dataset.PREVIOUS, ...] action_t = steps.action[:, dataset.CURRENT, ...] @@ -109,8 +112,9 @@ def return_loss(apply_fn: Callable[[networks.Observation, networks.Action], Returns: A scalar loss value as jnp.ndarray. """ - observation_t = jax.tree_map(lambda obs: obs[:, dataset.CURRENT, ...], - steps.observation) + observation_t = jax.tree.map( + lambda obs: obs[:, dataset.CURRENT, ...], steps.observation + ) action_t = steps.action[:, dataset.CURRENT, ...] n_step_return_t = steps.extras[dataset.N_STEP_RETURN][:, dataset.CURRENT, ...] diff --git a/acme/agents/jax/mbop/mppi.py b/acme/agents/jax/mbop/mppi.py index 1731a9940a..42eaf683eb 100644 --- a/acme/agents/jax/mbop/mppi.py +++ b/acme/agents/jax/mbop/mppi.py @@ -183,8 +183,9 @@ def mppi_planner( policy_prior_state = policy_prior.init(random_key) # Broadcast so that we have n_trajectories copies of each: - observation_t = jax.tree_map( - functools.partial(_repeat_n, config.n_trajectories), observation) + observation_t = jax.tree.map( + functools.partial(_repeat_n, config.n_trajectories), observation + ) action_tm1 = jnp.broadcast_to(action_trajectory_tm1[0], (config.n_trajectories,) + action_trajectory_tm1[0].shape) diff --git a/acme/agents/jax/mpo/learning.py b/acme/agents/jax/mpo/learning.py index 8e9f3e7821..b73f089092 100644 --- a/acme/agents/jax/mpo/learning.py +++ b/acme/agents/jax/mpo/learning.py @@ -681,7 +681,7 @@ def _sgd_step( dual_params.log_penalty_temperature) elif isinstance(dual_params, discrete_losses.CategoricalMPOParams): dual_metrics['params/dual/log_alpha_avg'] = dual_params.log_alpha - metrics.update(jax.tree_map(jnp.mean, dual_metrics)) + metrics.update(jax.tree.map(jnp.mean, dual_metrics)) return new_state, metrics @@ -733,7 +733,7 @@ def get_variables(self, names: List[str]) -> network_lib.Params: return [variables[name] for name in names] def save(self) -> TrainingState: - return jax.tree_map(mpo_utils.get_from_first_device, self._state) + return jax.tree.map(mpo_utils.get_from_first_device, self._state) def restore(self, state: TrainingState): self._state = utils.replicate_in_all_devices(state, self._local_devices) diff --git a/acme/agents/jax/mpo/networks.py b/acme/agents/jax/mpo/networks.py index 3097cdeae5..4c8a249c73 100644 --- a/acme/agents/jax/mpo/networks.py +++ b/acme/agents/jax/mpo/networks.py @@ -259,7 +259,7 @@ def critic_fn(observation: types.NestedArray, def add_batch(nest, batch_size: Optional[int]): """Adds a batch dimension at axis 0 to the leaves of a nested structure.""" broadcast = lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape) - return jax.tree_map(broadcast, nest) + return jax.tree.map(broadcast, nest) def w_init_identity(shape: Sequence[int], dtype) -> jnp.ndarray: diff --git a/acme/agents/jax/mpo/utils.py b/acme/agents/jax/mpo/utils.py index ba4e60420d..d11dc8b840 100644 --- a/acme/agents/jax/mpo/utils.py +++ b/acme/agents/jax/mpo/utils.py @@ -42,7 +42,7 @@ def _slice_and_maybe_to_numpy(x): x = x[0] return _fetch_devicearray(x) if as_numpy else x - return jax.tree_map(_slice_and_maybe_to_numpy, nest) + return jax.tree.map(_slice_and_maybe_to_numpy, nest) def rolling_window(x: jnp.ndarray, @@ -80,11 +80,11 @@ def tree_map_distribution( if isinstance(x, distrax.Distribution): safe_f = lambda y: f(y) if isinstance(y, jnp.ndarray) else y nil, tree_data = x.tree_flatten() - new_tree_data = jax.tree_map(safe_f, tree_data) + new_tree_data = jax.tree.map(safe_f, tree_data) new_x = x.tree_unflatten(new_tree_data, nil) return new_x elif isinstance(x, tfd.Distribution): - return jax.tree_map(f, x) + return jax.tree.map(f, x) else: return f(x) @@ -95,8 +95,9 @@ def make_sequences_from_transitions( """Convert a batch of transitions into a batch of 1-step sequences.""" stack = lambda x, y: jnp.stack((x, y), axis=num_batch_dims) duplicate = lambda x: stack(x, x) - observation = jax.tree_map(stack, transitions.observation, - transitions.next_observation) + observation = jax.tree.map( + stack, transitions.observation, transitions.next_observation + ) reward = duplicate(transitions.reward) return adders.Step( # pytype: disable=wrong-arg-types # jnp-type @@ -105,5 +106,5 @@ def make_sequences_from_transitions( reward=reward, discount=duplicate(transitions.discount), start_of_episode=jnp.zeros_like(reward, dtype=jnp.bool_), - extras=jax.tree_map(duplicate, transitions.extras), + extras=jax.tree.map(duplicate, transitions.extras), ) diff --git a/acme/agents/jax/r2d2/learning.py b/acme/agents/jax/r2d2/learning.py index 1b6df2186e..f3528f9030 100644 --- a/acme/agents/jax/r2d2/learning.py +++ b/acme/agents/jax/r2d2/learning.py @@ -101,14 +101,14 @@ def loss( # Maybe burn the core state in. if burn_in_length: - burn_obs = jax.tree_map(lambda x: x[:burn_in_length], data.observation) + burn_obs = jax.tree.map(lambda x: x[:burn_in_length], data.observation) key_grad, key1, key2 = jax.random.split(key_grad, 3) _, online_state = networks.unroll(params, key1, burn_obs, online_state) _, target_state = networks.unroll(target_params, key2, burn_obs, target_state) # Only get data to learn on from after the end of the burn in period. - data = jax.tree_map(lambda seq: seq[burn_in_length:], data) + data = jax.tree.map(lambda seq: seq[burn_in_length:], data) # Unroll on sequences to get online and target Q-Values. key1, key2 = jax.random.split(key_grad) diff --git a/acme/agents/jax/sac/learning.py b/acme/agents/jax/sac/learning.py index c10b11e1ce..2360884e32 100644 --- a/acme/agents/jax/sac/learning.py +++ b/acme/agents/jax/sac/learning.py @@ -176,8 +176,9 @@ def update_step( critic_grads, state.q_optimizer_state) q_params = optax.apply_updates(state.q_params, critic_update) - new_target_q_params = jax.tree_map(lambda x, y: x * (1 - tau) + y * tau, - state.target_q_params, q_params) + new_target_q_params = jax.tree.map( + lambda x, y: x * (1 - tau) + y * tau, state.target_q_params, q_params + ) metrics = { 'critic_loss': critic_loss, diff --git a/acme/datasets/tfds.py b/acme/datasets/tfds.py index 5c0a236901..8d68ee7648 100644 --- a/acme/datasets/tfds.py +++ b/acme/datasets/tfds.py @@ -137,7 +137,7 @@ def __init__(self, # we capture the whole dataset. size = _dataset_size_upperbound(dataset) data = next(dataset.batch(size).as_numpy_iterator()) - self._dataset_size = jax.tree_flatten( + self._dataset_size = jax.tree.flatten( jax.tree_util.tree_map(lambda x: x.shape[0], data) )[0][0] device = jax_utils._pmap_device_order()