[linen] allow checkpoint to cache
PiperOrigin-RevId: 633570309
Cristian Garcia authored and Flax Authors committed Jul 4, 2024
1 parent 0fb1777 commit aa353d8
Showing 3 changed files with 146 additions and 66 deletions.
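In short: ``nn.checkpoint`` now reuses its traced computation across ``apply`` calls the way ``nn.jit`` already did, by passing a module hash key through the lifted transform and keeping the impure ``scope_fn``/``repack_fn`` closures out of the traced arguments. A minimal sketch of the observable effect, distilled from the new test at the bottom of this commit (the ``Block`` module and ``n_traces`` counter are illustrative, not part of the change):

import jax
import jax.numpy as jnp
import flax.linen as nn

n_traces = 0  # incremented whenever the checkpointed body is traced

class Block(nn.Module):
  dim: int

  @nn.checkpoint  # rematerialized on backward; with this change, also cached
  @nn.compact
  def __call__(self, x):
    global n_traces
    n_traces += 1
    return nn.Dense(self.dim)(x)

block = Block(4)
x = jnp.ones((1, 4))
params = block.init(jax.random.PRNGKey(0), x)['params']  # first trace
block.apply({'params': params}, x)                       # second trace
block.apply({'params': params}, x)                       # cache hit: no retrace
assert n_traces == 2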
107 changes: 71 additions & 36 deletions flax/core/lift.py
@@ -15,20 +15,20 @@
 """Jax transform lifting."""
 
-import collections
-import functools
-from typing import (
-  Any,
-  TypeVar,
-)
+from collections.abc import Callable, Iterable, Mapping, Sequence
+import contextlib
+import dataclasses
+import functools
+import threading
+from typing import Any, Generic, TypeVar
 import warnings
 
 from flax import traceback_util
 from flax.typing import (
-  In,
-  Out,
-  InOutAxis,
-  InOutScanAxis,
+  In,
+  InOutAxis,
+  InOutScanAxis,
+  Out,
 )
 import jax
 from jax import random
@@ -51,6 +51,26 @@
 
 traceback_util.register_exclusion(__file__)
 
+A = TypeVar('A')
+
+
+@dataclasses.dataclass
+class TransformContext(Generic[A], threading.local):
+  """Context for a transform."""
+
+  stack: list[A] = dataclasses.field(default_factory=list)
+
+  @contextlib.contextmanager
+  def push(self, a: A):
+    self.stack.append(a)
+    try:
+      yield
+    finally:
+      self.stack.pop()
+
+  def get(self) -> A:
+    return self.stack[-1]
+
+
 def tree_map_rngs(fn, tree):
   """Needed for mapping JAX random.* functions over PRNGKey leaves."""
@@ -1416,12 +1436,12 @@ def checkpoint(
   This function is aliased to ``lift.remat`` just like ``jax.remat``.
 
   Args:
-    fn: scope function for which intermediate computations should be
-      re-computed when computing gradients.
+    fn: scope function for which intermediate computations should be re-computed
+      when computing gradients.
     variables: The variable collections that are lifted. By default all
       collections are lifted.
-    rngs: The PRNG sequences that are lifted. By default all PRNG sequences
-      are lifted.
+    rngs: The PRNG sequences that are lifted. By default all PRNG sequences are
+      lifted.
     concrete: Optional, boolean indicating whether ``fun`` may involve
       value-dependent Python control flow (default ``False``). Support for such
       control flow is optional, and disabled by default, because in some
@@ -1440,29 +1460,48 @@
     arguments as static can avoid ConcretizationTypeErrors when tracing, but
     at the cost of more retracing overheads.
     policy: Experimental checkpoint policy, see ``jax.checkpoint``.
 
   Returns:
     A wrapped version of ``fn``. When computing gradients, intermediate
     computations will be re-computed on the backward pass.
   """
+  # argnum 0 is the hash key; user static_argnums shift by 3 because rematted
+  # takes (hash_key, variable_groups, rng_groups) before the user arguments
+  static_argnums_ = (0,) + tuple(i + 3 for i in static_argnums)
+  scope_fns: list[Callable] = []
+  repack_fns: list[Callable] = []
+
+  @functools.partial(
+    jax.remat,
+    concrete=concrete,
+    static_argnums=static_argnums_,
+    prevent_cse=prevent_cse,
+    policy=policy,
+  )
+  @functools.wraps(fn)
+  def rematted(hash_key, variable_groups, rng_groups, *args, **kwargs):
+    scope_fn = scope_fns[-1]
+    repack_fn = repack_fns[-1]
+    scope = scope_fn(variable_groups, rng_groups)
+    y = fn(scope, hash_key, *args, **kwargs)
+    return y, repack_fn(scope)
 
-  def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args, **kwargs):
-    # add 2 to each static_argnums because we add two initial arguments to rematted
-    static_argnums_ = jax.tree_util.tree_map(lambda x: x + 2, static_argnums)
-
-    @functools.partial(
-      jax.remat,
-      concrete=concrete,
-      static_argnums=static_argnums_,
-      prevent_cse=prevent_cse,
-      policy=policy,
-    )
-    @functools.wraps(fn)
-    def rematted(variable_groups, rng_groups, *args, **kwargs):
-      scope = scope_fn(variable_groups, rng_groups)
-      y = fn(scope, *args, **kwargs)
-      return y, repack_fn(scope)
-
-    return rematted(variable_groups, rng_groups, *args, **kwargs)
+  def inner(
+    scope_fn,
+    repack_fn,
+    variable_groups,
+    rng_groups,
+    hash_key,
+    *args,
+    **kwargs,
+  ):
+    repack_fns.append(repack_fn)
+    scope_fns.append(scope_fn)
+    try:
+      return rematted(hash_key, variable_groups, rng_groups, *args, **kwargs)
+    finally:
+      scope_fns.pop()
+      repack_fns.pop()
 
   return pack(
     inner,
@@ -1554,8 +1593,9 @@ def jit(
   # Close over scope_fn & repack_fn to avoid recompilation
   # this is impure but we use the fingerprint arg to differentiate between cases
   # where scope_fn or repack_fn actually produce non-identical results.
-  scope_fn = None  # type: Callable | None
-  repack_fn = None  # type: Callable | None
+  # scope_fns: list[Callable] = []
+  # repack_fns: list[Callable] = []
+  jit_context = TransformContext[tuple[Callable, Callable]]()
 
   @functools.partial(
     jax.jit,
@@ -1567,33 +1607,28 @@
   )
   @functools.wraps(fn)
   def jitted(fingerprint, variable_groups, rng_groups, *args, **kwargs):
-    nonlocal scope_fn, repack_fn
+    scope_fn, repack_fn = jit_context.get()
     hash_key = fingerprint[1]
     # fingerprint is only used to differentiate the cache signature
-    del fingerprint
+    # del fingerprint
     scope = scope_fn(variable_groups, rng_groups)  # pylint: disable=not-callable
     y = fn(scope, hash_key, *args, **kwargs)
     return y, repack_fn(scope)  # pylint: disable=not-callable
 
   def inner(
-    scope_fun,
-    repack_fun,
+    scope_fn,
+    repack_fn,
     variable_groups,
     rng_groups,
     module_hash_key,
     *args,
     **kwargs,
   ):
-    nonlocal scope_fn, repack_fn
-    try:
-      scope_fn = scope_fun
-      repack_fn = repack_fun
+    with jit_context.push((scope_fn, repack_fn)):
       scopes = jax.tree_util.tree_leaves(scope_fn(variable_groups, rng_groups))
       mutable = tuple(_hashable_filter(scope.mutable) for scope in scopes)
       fingerprint = (mutable, module_hash_key)
       return jitted(fingerprint, variable_groups, rng_groups, *args, **kwargs)
-    finally:
-      scope_fn, repack_fn = None, None
 
   return pack(
     inner, (variables,), (variables,), (rngs,), name='jit', enable_kwargs=True
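
The new ``TransformContext`` above replaces the ``nonlocal scope_fn``/``repack_fn`` juggling with a thread-local stack: the handoff is re-entrant, so nested lifted transforms each see their own pair, and concurrent threads don't clobber each other. A standalone sketch of the pattern (the ``ctx`` and ``traced`` names are illustrative):

import contextlib
import dataclasses
import threading
from typing import Generic, TypeVar

A = TypeVar('A')

@dataclasses.dataclass
class TransformContext(Generic[A], threading.local):
  stack: list[A] = dataclasses.field(default_factory=list)

  @contextlib.contextmanager
  def push(self, a: A):
    self.stack.append(a)
    try:
      yield
    finally:
      self.stack.pop()

  def get(self) -> A:
    return self.stack[-1]

ctx = TransformContext[str]()

def traced():
  # stands in for jitted/rematted: reads the caller's handoff without it
  # becoming an argument of the traced function
  return ctx.get()

with ctx.push('outer'):
  assert traced() == 'outer'
  with ctx.push('inner'):  # nested transforms see their own value
    assert traced() == 'inner'
  assert traced() == 'outer'  # outer value restored on exit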
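The cache itself is just the tracing cache of ``jax.jit`` (and now ``jax.remat``): a hashable ``fingerprint`` — the mutable-collection filter plus the module hash key — is passed as a static argument solely to differentiate cache entries, while the impure closures it stands for stay outside the trace. A toy reduction of that mechanism (all names here are illustrative):

import functools
import jax
import jax.numpy as jnp

n_traces = 0

@functools.partial(jax.jit, static_argnums=0)
def run(fingerprint, x):
  # fingerprint is only used to differentiate the cache signature
  global n_traces
  n_traces += 1
  return x * 2

x = jnp.ones(3)
run(('mutable_filter', 'hash_key_1'), x)  # trace 1
run(('mutable_filter', 'hash_key_1'), x)  # same fingerprint: cache hit
run(('mutable_filter', 'hash_key_2'), x)  # new hash key: trace 2
assert n_traces == 2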
66 changes: 36 additions & 30 deletions flax/linen/transforms.py
@@ -484,6 +484,8 @@ def _get_fingerprint(name: str, value: Any) -> tuple[str, Any]:
 
   if isinstance(obj, str):
     return obj
+  elif hasattr(obj, '__fn_or_cls__'):  # support PaxConfig objects
+    return _fingerprint_recursive(obj.__fn_or_cls__, path, seen_modules)
   elif isinstance(obj, Module):
     fingerprint: Any
     if obj._id in seen_modules:
@@ -562,7 +564,7 @@ def _check_field_is_hashable(path: tuple[str, ...], x: Any):
     raise ValueError(f"Value at '{path_name}' is not hashable: {e}") from e
 
 
-def decorator_lift_transform_jit(class_fn, **trafo_kwargs):
+def decorator_lift_transform_cached(transform, class_fn, **trafo_kwargs):
   """Decorator for lifted transform.
 
   Similar to `decorator_lift_transform` but specialized for `jit`, it reuses the
@@ -572,7 +574,6 @@
   # Due to the ordering of method decorators, we must wrap the class_fn
   # with the module state management wrapper first to maintain Module state
   # correctly.
-  transform = lift.jit
   multi_scope = True
 
   if isinstance(class_fn, tuple):
@@ -640,11 +641,12 @@ def core_fn(
   return wrapped_fn
 
 
-def module_class_lift_transform_jit(module_class, methods=None, **trafo_kwargs):
+def module_class_lift_transform_cached(
+  transform, module_class, methods=None, **trafo_kwargs
+):
   """Module class lift transform."""
   # TODO(marcvanzee): Improve docstrings (#1977).
   # TODO(levskaya): find nicer argument convention for multi-method case?
-  transform = lift.jit
   trafo_args = ()
 
   # Prepare per-method transform args, kwargs.
@@ -765,6 +767,24 @@ def lift_transform(
     raise errors.TransformTargetError(target)
 
 
+def lift_transform_cached(
+  transform, target, *trafo_args, methods=None, **trafo_kwargs
+):
+  """Applies to class or as a decorator on class fns."""
+  # TODO(marcvanzee): Improve docstrings (#1977).
+  if _is_module_class(target):
+    return module_class_lift_transform_cached(
+      transform, target, *trafo_args, methods=methods, **trafo_kwargs
+    )
+  # we presume this is being used as a function decorator in class definition
+  elif callable(target) and not isinstance(target, Module):
+    return decorator_lift_transform_cached(
+      transform, target, *trafo_args, **trafo_kwargs
+    )
+  else:
+    raise errors.TransformTargetError(target)
+
+
 def lift_direct_transform(
   transform: Callable[..., Any],
   targets: tuple[Callable[..., Any], ...],
@@ -941,8 +961,8 @@ def jit(
     A wrapped version of target, set up for just-in-time compilation.
   """
   # TODO(marcvanzee): Improve docstrings (#1977).
-  if _is_module_class(target):
-    return module_class_lift_transform_jit(
+  return lift_transform_cached(
+    lift.jit,
     target,
     variables=variables,
     rngs=rngs,
@@ -952,21 +972,7 @@
     device=device,
     backend=backend,
     methods=methods,
-    )
-  # we presume this is being used as a function decorator in class definition
-  elif callable(target) and not isinstance(target, Module):
-    return decorator_lift_transform_jit(
-      target,
-      variables=variables,
-      rngs=rngs,
-      static_argnums=static_argnums,
-      static_argnames=static_argnames,
-      donate_argnums=donate_argnums,
-      device=device,
-      backend=backend,
-    )
-  else:
-    raise errors.TransformTargetError(target)
+  )
 
 
 def checkpoint(
@@ -1044,15 +1050,15 @@ def checkpoint(
   # lifted function
   static_argnums = jax.tree_util.tree_map(lambda x: x - 1, static_argnums)
-  return lift_transform(
+  return lift_transform_cached(
     lift.checkpoint,
     target,
     variables=variables,
     rngs=rngs,
     concrete=concrete,
     static_argnums=static_argnums,
     prevent_cse=prevent_cse,
     policy=policy,
     methods=methods,
   )


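With ``lift_transform_cached`` in place, ``nn.jit`` and ``nn.checkpoint`` share a single dispatcher covering both decorator forms, as the two return branches above show. A sketch of the two usages it routes (the module definitions are illustrative):

import jax
import jax.numpy as jnp
import flax.linen as nn

# Module-class target -> module_class_lift_transform_cached
class Inner(nn.Module):
  @nn.compact
  def __call__(self, x):
    return nn.Dense(3)(x)

JitInner = nn.jit(Inner)

# Method decorator inside a class definition -> decorator_lift_transform_cached
class Outer(nn.Module):
  @nn.checkpoint
  @nn.compact
  def __call__(self, x):
    return nn.Dense(3)(x)

x = jnp.ones((1, 4))
for mdl in (JitInner(), Outer()):
  variables = mdl.init(jax.random.PRNGKey(0), x)
  assert mdl.apply(variables, x).shape == (1, 3)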
39 changes: 39 additions & 0 deletions tests/linen/linen_transforms_test.py
@@ -864,6 +864,45 @@ def nested_repeat(mdl):
     self.assertEqual(setup_cntr, 128)
     self.assertEqual(call_cntr, 64)
 
+  def test_checkpoint_caching(self):
+    n = 0
+
+    class Layer(nn.Module):
+      dout: int
+
+      @nn.compact
+      def __call__(self, x):
+        W = self.param('W', nn.initializers.ones, (x.shape[-1], self.dout))
+        return jnp.dot(x, W)
+
+    class Block(nn.Module):
+      dim: int
+      n_layers: int
+
+      @partial(nn.checkpoint, policy=jax.checkpoint_policies.checkpoint_dots)
+      @nn.compact
+      def __call__(self, x):
+        nonlocal n
+        n += 1
+        for _ in range(self.n_layers - 1):
+          x = Layer(self.dim)(x)
+          x = jnp.sin(x)
+        return Layer(self.dim)(x)
+
+    module = Block(4, 1)
+    x = jnp.ones((1, 4))
+    params = module.init(jax.random.PRNGKey(0), x)['params']
+    self.assertEqual(n, 1)
+
+    def predict(params, x):
+      return module.apply({'params': params}, x)
+
+    module.apply({'params': params}, x)
+    self.assertEqual(n, 2)
+
+    module.apply({'params': params}, x)
+    self.assertEqual(n, 2)
+
   def test_multimethod_setup_calls(self):
     cntr = 0
 