Skip to content

Commit

Permalink
Copybara import of the project:
Browse files Browse the repository at this point in the history
--
f79d106 by George Necula <gcnecula@gmail.com>:

[export] Fix

A user reported an error when trying to export a function
that has a "lower" attribute (to impersonate a jitted function)
but does not have a "__name__" attribute.
The solution is to use the default name "<unnamed function>".

While I was at it I have added a `util.fun_name` to get
the name of a Callable, and I use it in several places.

COPYBARA_INTEGRATE_REVIEW=#21572 from gnecula:exp_fix_name f79d106
PiperOrigin-RevId: 639236990
  • Loading branch information
gnecula authored and jax authors committed Jun 1, 2024
1 parent 432159a commit be1e40d
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 9 deletions.
4 changes: 2 additions & 2 deletions jax/_src/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
from jax._src.util import (safe_map, safe_zip, HashableFunction, unzip2, unzip3,
as_hashable_function, distributed_debug_log,
tuple_insert, moveaxis, split_list, wrap_name,
merge_lists, partition_list)
merge_lists, partition_list, fun_name)

source_info_util.register_exclusion(__file__)
traceback_util.register_exclusion(__file__)
Expand Down Expand Up @@ -577,7 +577,7 @@ def infer_params(*args):
in_axes_flat, args_flat)

params = dict(
name=getattr(fun, '__name__', '<unnamed function>'),
name=fun_name(fun),
in_axes=tuple(in_axes_flat),
out_axes_thunk=out_axes_thunk,
donated_invars=donated_invars,
Expand Down
6 changes: 3 additions & 3 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
from jax._src.util import (
HashableFunction, safe_map, safe_zip, wraps,
distributed_debug_log, split_list, weakref_lru_cache,
merge_lists, flatten, unflatten, subs_list)
merge_lists, flatten, unflatten, subs_list, fun_name)

map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
Expand Down Expand Up @@ -339,7 +339,7 @@ def cache_miss(*args, **kwargs):

fun = jit_info.fun
cpp_pjit_f = xc._xla.pjit(
getattr(fun, "__name__", "<unnamed function>"),
fun_name(fun),
fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames,
jit_info.donate_argnums, tree_util.dispatch_registry,
pxla.shard_arg,
Expand Down Expand Up @@ -652,7 +652,7 @@ def _infer_params(jit_info, args, kwargs):
out_layouts=out_layouts_flat,
resource_env=resource_env,
donated_invars=donated_invars,
name=getattr(flat_fun, '__name__', '<unknown>'),
name=fun_name(flat_fun),
keep_unused=keep_unused,
inline=inline,
)
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,7 @@ def __init__(
args_info, # PyTree of ArgInfo
out_tree: tree_util.PyTreeDef,
no_kwargs: bool = False,
fun_name: str = "unknown",
fun_name: str = "<unnamed function>",
jaxpr: core.ClosedJaxpr | None = None):

self._lowering = lowering
Expand All @@ -634,7 +634,7 @@ def from_flat_info(cls,
donate_argnums: tuple[int, ...],
out_tree: tree_util.PyTreeDef,
no_kwargs: bool = False,
fun_name: str = "unknown",
fun_name: str = "<unnamed function>",
jaxpr: core.ClosedJaxpr | None = None):
"""Initialize from flat info (``in_avals`` etc.) and an input PyTreeDef.
Expand Down
5 changes: 4 additions & 1 deletion jax/_src/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,9 @@ def __eq__(self, other):
def wrap_name(name, transform_name):
return transform_name + '(' + name + ')'

def fun_name(fun: Callable):
return getattr(fun, "__name__", "<unnamed function>")

def canonicalize_axis(axis, num_dims) -> int:
"""Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims)."""
axis = operator.index(axis)
Expand Down Expand Up @@ -399,7 +402,7 @@ def wraps(
"""
def wrapper(fun: T) -> T:
try:
name = getattr(wrapped, "__name__", "<unnamed function>")
name = fun_name(wrapped)
doc = getattr(wrapped, "__doc__", "") or ""
fun.__dict__.update(getattr(wrapped, "__dict__", {}))
fun.__annotations__ = getattr(wrapped, "__annotations__", {})
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/export/_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def do_export(*args_specs, **kwargs_specs) -> Exported:
symbolic_scope = (d.scope, k_path)
continue
symbolic_scope[0]._check_same_scope(
d, when=f"when exporting {getattr(wrapped_fun_jax, '__name__')}",
d, when=f"when exporting {util.fun_name(wrapped_fun_jax)}",
self_descr=f"current (from {_shape_poly.args_kwargs_path_to_str(symbolic_scope[1])}) ",
other_descr=_shape_poly.args_kwargs_path_to_str(k_path))

Expand Down
19 changes: 19 additions & 0 deletions tests/export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,25 @@ def f(x, y):
r"and found for 'w' \(args\[1\]\) scope .*", re.DOTALL)):
get_exported(f)(x_poly_spec, y_poly_spec)

def test_poly_export_callable_with_no_name(self):
# This was reported by a user
class MyCallable:
def __call__(self, x):
return jnp.sin(x)

# This makes it look like a jitted-function
def lower(self, x,
_experimental_lowering_parameters=None):
return jax.jit(self.__call__).lower(
x,
_experimental_lowering_parameters=_experimental_lowering_parameters)

a, = export.symbolic_shape("a,")
# No error
_ = get_exported(MyCallable())(
jax.ShapeDtypeStruct((a, a), dtype=np.float32)
)

@jtu.parameterized_filterable(
kwargs=[
dict(v=v)
Expand Down

0 comments on commit be1e40d

Please sign in to comment.