Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add functool.wraps() annotation to flax.nn.jit.
At the moment, all the jit names in a jaxpr show up as "jitted". functools.partial does not forward names. PiperOrigin-RevId: 648751795
- Loading branch information