Skip to content

Commit

Permalink
Plumb spmd_axis_name through transforms.vmap through to JAX vmap
Browse files Browse the repository at this point in the history
This ensure transforms.vmap matches lift.vmap following #2390

PiperOrigin-RevId: 467287573
  • Loading branch information
James Lee-Thorp authored and Flax Authors committed Aug 19, 2022
1 parent 2b73efd commit 4449d87
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions flax/linen/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,9 +433,11 @@ def vmap(target: Target,
variable_axes: Mapping[lift.CollectionFilter,
lift.InOutAxis] = FrozenDict(),
split_rngs: Mapping[lift.PRNGSequenceFilter, bool] = FrozenDict(),
in_axes=0, out_axes=0,
in_axes=0,
out_axes=0,
axis_size: Optional[int] = None,
axis_name: Optional[str] = None,
spmd_axis_name: Optional[str] = None,
methods=None) -> Target:
"""A lifted version of ``jax.vmap``.
Expand Down Expand Up @@ -486,17 +488,26 @@ def vmap(target: Target,
with parallel reduction primitives (e.g. `jax.lax.pmean`,
`jax.lax.ppermute`, etc.)
methods: If `target` is a `Module`, the methods of `Module` to vmap over.
spmd_axis_name: Axis name added to any pjit sharding constraints appearing
in `fn`. See also
https://github.com/google/flax/blob/main/flax/linen/partitioning.py.
Returns:
A batched/vectorized version of ``target``, with the same arguments but with
extra axes at positions indicated by ``in_axes``, and the same return value,
but with extra axes at positions indicated by ``out_axes``.
"""
return lift_transform(
lift.vmap, target, variable_axes, split_rngs,
lift.vmap,
target,
variable_axes,
split_rngs,
methods=methods,
in_axes=in_axes, out_axes=out_axes,
axis_size=axis_size, axis_name=axis_name)
in_axes=in_axes,
out_axes=out_axes,
axis_size=axis_size,
axis_name=axis_name,
spmd_axis_name=spmd_axis_name)


def jit(target: Target,
Expand Down

0 comments on commit 4449d87

Please sign in to comment.