Skip to content

Commit

Permalink
Add functool.wraps() annotation to flax.nn.jit.
Browse files Browse the repository at this point in the history
At the moment, all the jit names in a jaxpr show up as "jitted". functools.partial does not forward names.

PiperOrigin-RevId: 648760671
  • Loading branch information
hawkinsp authored and Flax Authors committed Jul 2, 2024
1 parent 5b8265c commit 37123d5
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions flax/linen/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,8 +608,10 @@ def core_fn(
return res

core_fns = [
functools.partial(core_fn, prewrapped_fn, class_fn)
for prewrapped_fn, class_fn in zip(prewrapped_fns, class_fns)
functools.wraps(class_fn)(
functools.partial(core_fn, prewrapped_fn, class_fn)
)
for prewrapped_fn, class_fn in zip(prewrapped_fns, class_fns)
]

# here we apply the given lifting transform to the scope-ingesting fn
Expand Down

0 comments on commit 37123d5

Please sign in to comment.