Skip to content

Commit

Permalink
[nnx] experimental transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Jul 2, 2024
1 parent 15e0e8d commit 6f87daa
Show file tree
Hide file tree
Showing 13 changed files with 1,447 additions and 113 deletions.
9 changes: 8 additions & 1 deletion flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from .nnx import compat as compat
from .nnx import traversals as traversals
from .nnx import filterlib as filterlib
from .nnx import transforms as transforms
from .nnx.filterlib import WithTag as WithTag
from .nnx.filterlib import PathContains as PathContains
from .nnx.filterlib import OfType as OfType
Expand Down Expand Up @@ -103,6 +104,8 @@
from .nnx.rnglib import RngCount as RngCount
from .nnx.rnglib import ForkStates as ForkStates
from .nnx.rnglib import fork as fork
from .nnx.rnglib import split_rngs as split_rngs
from .nnx.rnglib import restore_rngs as restore_rngs
from .nnx.spmd import PARTITION_NAME as PARTITION_NAME
from .nnx.spmd import get_partition_spec as get_partition_spec
from .nnx.spmd import get_named_sharding as get_named_sharding
Expand All @@ -122,15 +125,19 @@
from .nnx.transforms.looping import Scan as Scan
from .nnx.transforms.parallelization import Vmap as Vmap
from .nnx.transforms.parallelization import Pmap as Pmap
from .nnx.transforms.transforms import grad as grad
from .nnx.transforms.general import split_inputs as split_inputs
from .nnx.transforms.general import merge_inputs as merge_inputs
from .nnx.transforms.transforms import jit as jit
from .nnx.transforms.transforms import grad as grad
from .nnx.transforms.transforms import remat as remat
from .nnx.transforms.looping import scan as scan
from .nnx.transforms.transforms import value_and_grad as value_and_grad
from .nnx.transforms.parallelization import vmap as vmap
from .nnx.transforms.parallelization import pmap as pmap
from .nnx.transforms.transforms import eval_shape as eval_shape
from .nnx.transforms.transforms import cond as cond
from .nnx.transforms.experimental import vmap as experimental_vmap
from .nnx.transforms.experimental import StateAxes as StateAxes
from .nnx.variables import EMPTY as EMPTY
from .nnx.variables import A as A
from .nnx.variables import BatchStat as BatchStat
Expand Down
202 changes: 202 additions & 0 deletions flax/nnx/nnx/extract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
import abc
import typing as tp

import jax
from jax._src.tree_util import broadcast_prefix

from flax import struct
from flax.nnx.nnx.state import State
from flax.typing import PathParts
from flax.nnx.nnx import graph


class Missing:
  """Marker type used to build a "no value supplied" sentinel instance."""


# Sentinel used as the default for optional arguments; callers are detected
# as "did not pass a value" through an identity check (`is MISSING`).
MISSING = Missing()
# Generic type variable for values whose type is preserved by a function.
A = tp.TypeVar('A')
# Type variable restricted to `Extractable` and its subclasses.
E = tp.TypeVar('E', bound='Extractable')
# Alias for the integer position assigned to an extracted node.
Index = int


class Extractable(abc.ABC):
  """Abstract base for placeholder objects that carry an integer ``index``.

  The helpers in this module recognise placeholders with
  ``isinstance(x, Extractable)`` and read their ``index`` attribute.
  """

  @property
  @abc.abstractmethod
  def index(self) -> Index:
    """Integer position associated with this placeholder."""
    ...


class ExtractableStates(Extractable):
  """Abstract ``Extractable`` that also exposes states and a graph definition.

  Concrete subclasses must provide ``index`` (inherited requirement),
  ``states`` and ``graphdef``.
  """

  @property
  @abc.abstractmethod
  def states(self) -> tp.Iterable[State]:
    """Iterable of the ``State`` objects held by this placeholder."""
    ...

  @property
  @abc.abstractmethod
  def graphdef(self) -> graph.GraphDef[tp.Any]:
    """Graph definition associated with the held states."""
    ...


class ExtractionIndex(struct.PyTreeNode, Extractable):
  """Placeholder recording the position of a graph node taken out of a pytree."""

  # Declared with `pytree_node=False`, i.e. not as a pytree child of this node.
  _index: Index = struct.field(pytree_node=False)

  @property
  def index(self) -> Index:
    """Return the stored position (the ``_index`` field)."""
    return self._index


@tp.overload
def extract_graph_nodes(pytree: A, /) -> tuple[A, tuple[tp.Any, ...]]: ...


@tp.overload
def extract_graph_nodes(
  pytree: A, /, *, prefix: tp.Any
) -> tuple[A, tuple[tp.Any, ...], tuple[tp.Any, ...]]: ...


def extract_graph_nodes(
  pytree: A, /, *, prefix: tp.Any = MISSING
) -> (
  tuple[A, tuple[tp.Any, ...]]
  | tuple[A, tuple[tp.Any, ...], tuple[tp.Any, ...]]
):
  """Replace the graph-node leaves of ``pytree`` with ``ExtractionIndex`` objects.

  Each leaf for which ``graph.is_graph_node`` is true is swapped for an
  ``ExtractionIndex`` holding the position of that node in the returned
  tuple of nodes. A node that appears several times (as tracked by
  ``graph.RefMap``) is recorded once and every occurrence gets the same index.

  Args:
    pytree: the structure to scan.
    prefix: optional prefix tree that is broadcast over the leaves of
      ``pytree``; when given, the prefix of each extracted node is returned too.

  Returns:
    ``(pytree_out, nodes)`` when ``prefix`` is not given, otherwise
    ``(pytree_out, nodes, node_prefixes)``.

  Raises:
    ValueError: if the same node is reached under two different prefixes.
  """
  seen = graph.RefMap[tp.Any, Index]()
  seen_prefixes = []
  new_leaves = []

  # The `is_leaf` predicate is handed to `broadcast_prefix` so that a `None`
  # inside `prefix` is kept as a prefix value.
  broadcasted = broadcast_prefix(
    prefix,
    pytree,
    is_leaf=lambda x: x is None,
  )
  path_leaf_pairs, structure = jax.tree_util.tree_flatten_with_path(pytree)

  # Internal invariant: one broadcast prefix per flattened leaf.
  assert len(path_leaf_pairs) == len(broadcasted)

  for (keypath, leaf), leaf_prefix in zip(path_leaf_pairs, broadcasted):
    if not graph.is_graph_node(leaf):
      new_leaves.append(leaf)
      continue

    if leaf in seen:
      position = seen[leaf]
      # A repeated node must carry the prefix recorded at its first occurrence.
      if leaf_prefix != seen_prefixes[position]:
        path_str = jax.tree_util.keystr(keypath)
        raise ValueError(
          f'Inconsistent aliasing detected. Node {type(leaf)} at path {path_str} '
          f'has different prefixes: {leaf_prefix} and {seen_prefixes[position]}.'
        )
    else:
      position = len(seen)
      seen[leaf] = position
      seen_prefixes.append(leaf_prefix)

    new_leaves.append(ExtractionIndex(position))

  rebuilt = jax.tree.unflatten(structure, new_leaves)
  extracted = tuple(seen)

  if prefix is MISSING:
    return rebuilt, extracted
  return rebuilt, extracted, tuple(seen_prefixes)


def insert_graph_nodes(pytree: A, nodes: tuple[tp.Any, ...], /) -> A:
  """Put graph nodes back into a pytree in place of their placeholders.

  Args:
    pytree: structure that may contain ``Extractable`` placeholders.
    nodes: tuple of nodes; each placeholder is replaced by
      ``nodes[placeholder.index]``.

  Returns:
    A pytree in which every ``Extractable`` has been replaced; all other
    leaves are returned unchanged.
  """

  def _is_placeholder(x):
    return isinstance(x, Extractable)

  def _restore(x):
    return nodes[x.index] if _is_placeholder(x) else x

  # Placeholders are treated as leaves so that `_restore` receives them whole.
  return jax.tree_util.tree_map(_restore, pytree, is_leaf=_is_placeholder)


def extract_indexes(
  pytree, /, types: tuple[type[E], ...] | type[E] = Extractable
) -> tuple[E, ...]:
  """Collect every ``Extractable`` leaf found in ``pytree``.

  Args:
    pytree: structure to scan; ``Extractable`` instances are treated as leaves.
    types: type (or tuple of types) that every collected placeholder must be
      an instance of.

  Returns:
    The placeholders in the order yielded by ``jax.tree.leaves``.

  Raises:
    ValueError: if a placeholder is not an instance of ``types``.
  """

  def _is_placeholder(x):
    return isinstance(x, Extractable)

  found: list[E] = []
  for leaf in jax.tree.leaves(pytree, is_leaf=_is_placeholder):
    if not _is_placeholder(leaf):
      continue
    if not isinstance(leaf, types):
      raise ValueError(f'Expected Extractable of type {types}, got {type(leaf)}')
    found.append(leaf)
  return tuple(found)


def replace_indexes(
  pytree: A,
  replace_fn: tp.Callable[[Extractable], tp.Any],
  /,
  clear: bool = False,
) -> A:
  """Map ``replace_fn`` over the ``Extractable`` leaves of ``pytree``.

  Args:
    pytree: structure to transform; ``Extractable`` instances are treated as
      leaves.
    replace_fn: called with each placeholder; its result takes the
      placeholder's place.
    clear: when true, every leaf that is not a placeholder becomes ``None``;
      otherwise such leaves are kept as they are.

  Returns:
    The transformed pytree.
  """

  def _is_placeholder(x):
    return isinstance(x, Extractable)

  def _swap(x):
    if _is_placeholder(x):
      return replace_fn(x)
    return None if clear else x

  return jax.tree_util.tree_map(_swap, pytree, is_leaf=_is_placeholder)


def merge_extractable_states(
  extractable_states: tp.Sequence[ExtractableStates], /
):
  """Combine the states of several ``ExtractableStates`` into one ``State``.

  Every flat path taken from an item's states is prefixed with that item's
  ``index`` before all entries are passed to ``State.from_flat_path``.

  Args:
    extractable_states: non-empty sequence of ``ExtractableStates``.

  Returns:
    ``(graphdef, state)`` where ``graphdef`` is taken from the first item.

  Raises:
    ValueError: if ``extractable_states`` is empty.
  """
  if len(extractable_states) == 0:
    raise ValueError('Expected at least one ExtractableStates object')

  # Only the first item's graphdef is used.
  graphdef = extractable_states[0].graphdef

  entries = []
  for item in extractable_states:
    for sub_state in item.states:
      for path, value in sub_state.flat_state().items():
        entries.append(((item.index, *path), value))

  merged = State.from_flat_path(entries)
  return graphdef, merged


def check_consistent_aliasing(
  nodes: tuple[tp.Any, ...], prefixes: tuple[tp.Any, ...]
):
  """Verify that no graph node is reachable under more than one prefix.

  Each ``(node, prefix)`` pair is walked with ``graph.iter_graph``; every
  graph node met on the way is recorded with its path and the prefix of the
  root it was reached from.

  Args:
    nodes: root nodes to traverse.
    prefixes: one prefix per root, paired with ``nodes`` by ``zip``.

  Raises:
    ValueError: if any graph node was recorded with more than one distinct
      prefix; the message lists each such node with its paths and prefixes.
  """
  occurrences = graph.RefMap[tp.Any, list[tuple[PathParts, tp.Any]]]()

  # Pass 1: record (path, prefix) for every graph node reachable from a root.
  for root, root_prefix in zip(nodes, prefixes):
    for path, value in graph.iter_graph(root):
      if not graph.is_graph_node(value):
        continue
      entry = (path, root_prefix)
      if value in occurrences:
        occurrences[value].append(entry)
      else:
        occurrences[value] = [entry]

  def _format_path(path):
    return '/'.join(map(str, path)) if path else '<root>'

  # Pass 2: build one report section per node seen with differing prefixes.
  problems = []
  for node, recorded in occurrences.items():
    distinct = {recorded_prefix for _, recorded_prefix in recorded}
    if len(distinct) <= 1:
      continue
    listing = '\n'.join(
      f' {_format_path(path)}: {recorded_prefix}'
      for path, recorded_prefix in recorded
    )
    problems.append(f'Node: {type(node)}\n{listing}')

  if problems:
    raise ValueError(
      'Inconsistent aliasing detected. The following nodes have different prefixes:\n'
      + '\n'.join(problems)
    )
68 changes: 14 additions & 54 deletions flax/nnx/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from __future__ import annotations

from collections import defaultdict
import dataclasses
import enum
import functools
Expand Down Expand Up @@ -61,8 +60,8 @@ def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]:

@dataclasses.dataclass
class GraphContext(threading.local):
update_context_stacks: defaultdict[str, list[UpdateContext]] = (
dataclasses.field(default_factory=lambda: defaultdict(list))
update_context_stacks: dict[str, list[UpdateContext]] = dataclasses.field(
default_factory=dict
)


Expand Down Expand Up @@ -1021,21 +1020,27 @@ class UpdateContextManager:

def __enter__(self):
ctx = UpdateContext(self.tag, None, None)
GRAPH_CONTEXT.update_context_stacks[self.tag].append(ctx)
if self.tag not in GRAPH_CONTEXT.update_context_stacks:
GRAPH_CONTEXT.update_context_stacks[self.tag] = [ctx]
else:
GRAPH_CONTEXT.update_context_stacks[self.tag].append(ctx)
return ctx

def __exit__(self, *args):
stack = GRAPH_CONTEXT.update_context_stacks[self.tag]
if not stack:
if self.tag not in GRAPH_CONTEXT.update_context_stacks:
raise RuntimeError(
f'No update context found for tag {self.tag!r}, this is a bug.'
)
stack = GRAPH_CONTEXT.update_context_stacks[self.tag]

ctx = GRAPH_CONTEXT.update_context_stacks[self.tag].pop()
ctx = stack.pop()
# clear references
ctx.refmap = None
ctx.idxmap = None

if not stack:
del GRAPH_CONTEXT.update_context_stacks[self.tag]

def __call__(self, f: F) -> F:
@functools.wraps(f)
def update_context_manager_wrapper(*args, **kwargs):
Expand Down Expand Up @@ -1142,10 +1147,9 @@ def update_context(tag: str):

def current_update_context(tag: str) -> UpdateContext:
"""Returns the current active :class:`UpdateContext` for the given tag."""
stack = GRAPH_CONTEXT.update_context_stacks[tag]
if not stack:
if tag not in GRAPH_CONTEXT.update_context_stacks:
raise ValueError(f'No update context found for tag {tag!r}.')
return stack[-1]
return GRAPH_CONTEXT.update_context_stacks[tag][-1]


# --------------------------------------------------------
Expand Down Expand Up @@ -1595,50 +1599,6 @@ class Static(tp.Generic[A]):

jax.tree_util.register_static(Static)

# ---------------------------------------------------------
# insert/extract_graph_nodes API
# ---------------------------------------------------------


@dataclasses.dataclass(frozen=True)
class GraphNodeIndex:
"""Index of a graph node in a Pytree structure."""

index: Index


jax.tree_util.register_static(GraphNodeIndex)


def extract_graph_nodes(pytree: A, /) -> tuple[A, tuple[tp.Any, ...]]:
"""Extracts all graph nodes from a pytree."""
nodes = RefMap[tp.Any, Index]()

def _maybe_extract(x):
if is_graph_node(x):
if x not in nodes:
index = nodes[x] = len(nodes)
else:
index = nodes[x]
return GraphNodeIndex(index)
return x

return jax.tree_util.tree_map(_maybe_extract, pytree), tuple(nodes)


def insert_graph_nodes(pytree: A, nodes: tuple[tp.Any, ...], /) -> A:
"""Inserts graph nodes into a pytree."""

def _maybe_insert(x):
if isinstance(x, GraphNodeIndex):
return nodes[x.index]
return x

return jax.tree_util.tree_map(
_maybe_insert, pytree, is_leaf=lambda x: isinstance(x, GraphNodeIndex)
)


# ---------------------------------------------------------
# Pytree
# ---------------------------------------------------------
Expand Down
Loading

0 comments on commit 6f87daa

Please sign in to comment.