Releases: google/jax

Jaxlib release v0.4.30

18 Jun 15:07
jaxlib-v0.4.30

jaxlib version 0.4.30

Jax release v0.4.30

18 Jun 15:07
jax-v0.4.30

jax version 0.4.30

Jaxlib release v0.4.29

10 Jun 18:31
  • Bug fixes

    • Fixed a bug where XLA sharded some concatenation operations incorrectly,
      which manifested as an incorrect output for cumulative reductions (#21403).
    • Fixed a bug where XLA:CPU miscompiled certain matmul fusions
      (openxla/xla#13301).
    • Fixed a compiler crash on GPU (#21396).
  • Deprecations

    • jax.tree.map(f, None, non-None) now emits a DeprecationWarning, and will
      raise an error in a future version of jax. None is only a tree-prefix of
      itself. To preserve the current behavior, you can ask jax.tree.map to
      treat None as a leaf value by writing:
      jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None).
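
A minimal sketch of the None-as-leaf workaround described above. The leaf function f and the example trees a and b are hypothetical, chosen only for illustration; they do not come from the release notes.

```python
import jax


def f(x, y):
    # Hypothetical leaf-wise function used only for this illustration.
    return x + y


a = {"w": 1.0, "b": None}
b = {"w": 2.0, "b": 3.0}

# is_leaf makes None a leaf of `a`, so the None in `a` can line up with the
# non-None leaf 3.0 in `b`; the lambda then passes None through unchanged.
out = jax.tree.map(
    lambda x, y: None if x is None else f(x, y),
    a,
    b,
    is_leaf=lambda x: x is None,
)
print(out)
```

Because None is declared a leaf, the structure of a (rather than None being treated as an empty subtree) drives the mapping, and positions that were None in a stay None in the result.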

JAX v0.4.29

10 Jun 18:31
  • Changes

    • We anticipate that this will be the last release of JAX and jaxlib
      supporting a monolithic CUDA jaxlib. Future releases will use the CUDA
      plugin jaxlib (e.g. pip install jax[cuda12]).
    • JAX now requires ml_dtypes version 0.4.0 or newer.
    • Removed backwards-compatibility support for old usage of the
      jax.experimental.export API. It is no longer possible to use
      from jax.experimental.export import export; use
      from jax.experimental import export instead.
      The removed functionality has been deprecated since 0.4.24.
  • Deprecations

    • jax.sharding.XLACompatibleSharding is deprecated. Please use
      jax.sharding.Sharding.
    • jax.experimental.Exported.in_shardings has been renamed to
      jax.experimental.Exported.in_shardings_hlo. Likewise, out_shardings has
      been renamed to out_shardings_hlo. The old names will be removed after
      3 months.
    • Removed a number of previously-deprecated APIs:
      • from jax.core: non_negative_dim, DimSize, Shape
      • from jax.lax: tie_in
      • from jax.nn: normalize
      • from jax.interpreters.xla: backend_specific_translations,
        translations, register_translation, xla_destructure,
        TranslationRule, TranslationContext, XlaOp.
    • The tol argument of jax.numpy.linalg.matrix_rank is being
      deprecated and will soon be removed. Use rtol instead.
    • The rcond argument of jax.numpy.linalg.pinv is being
      deprecated and will soon be removed. Use rtol instead.
    • The deprecated jax.config submodule has been removed. To configure JAX,
      use import jax and then reference the config object via jax.config.
    • jax.random APIs no longer accept batched keys, which some of them
      previously accepted unintentionally. Going forward, we recommend
      explicit use of jax.vmap in such cases.
  • New Functionality

    • Added jax.experimental.Exported.in_shardings_jax to construct
      shardings that can be used with the JAX APIs from the HloShardings
      that are stored in the Exported objects.
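
A minimal migration sketch for three of the deprecations above: reaching the config object through jax.config, passing rtol to matrix_rank and pinv, and mapping over a batch of random keys with jax.vmap. It assumes jax 0.4.29 or newer; the example matrix, tolerance, and number of keys are illustrative choices, not values from the release notes.

```python
import jax
import jax.numpy as jnp

# The config object is reached as an attribute of the top-level jax module
# (after a plain `import jax`), not through a separate jax.config submodule.
jax.config.update("jax_enable_x64", False)

# An exactly rank-1 example matrix (illustrative).
m = jnp.array([[1.0, 0.0], [0.0, 0.0]])

# rtol replaces the deprecated `tol` (matrix_rank) and `rcond` (pinv) arguments.
rank = jnp.linalg.matrix_rank(m, rtol=1e-6)
m_pinv = jnp.linalg.pinv(m, rtol=1e-6)

# jax.random functions take a single key; map over a batch of keys explicitly.
keys = jax.random.split(jax.random.key(0), 4)
samples = jax.vmap(lambda k: jax.random.normal(k, (3,)))(keys)

print(rank, m_pinv.shape, samples.shape)
```

Calling jax.vmap makes the batching over keys explicit, which is the pattern recommended above now that jax.random APIs no longer accept batched keys.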

jaxlib v0.4.28

09 May 23:27
  • Bug fixes

    • Fixed a memory corruption bug in the type name of Array and JIT Python
      objects in Python 3.10 or earlier.
    • Fixed the warning "'+ptx84' is not a recognized feature for this target"
      emitted under CUDA 12.4.
    • Fixed a slow compilation problem on CPU.
  • Changes

    • The Windows build is now built with Clang instead of MSVC.

JAX v0.4.28

09 May 23:28
  • Bug fixes

    • Reverted a change to make_jaxpr that was breaking Equinox (#21116).
  • Deprecations & removals

    • The kind argument to jax.numpy.sort and jax.numpy.argsort
      has been removed. Use stable=True or stable=False instead.
    • Removed get_compute_capability from the jax.experimental.pallas.gpu
      module. Use the compute_capability attribute of a GPU device, returned
      by jax.devices or jax.local_devices, instead.
  • Changes

    • The minimum jaxlib version of this release is 0.4.27.
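
A short sketch of the replacements for the two removals above, assuming jax 0.4.28 or newer. The input array is illustrative. On a machine without a GPU the list of compute capabilities is simply empty, and getattr is used defensively in case a GPU device does not expose the attribute.

```python
import jax
import jax.numpy as jnp

x = jnp.array([3, 1, 2, 1])

# The removed `kind` argument is replaced by the boolean `stable` flag.
sorted_x = jnp.sort(x, stable=True)
# With stable=True, equal values keep their original relative order.
order = jnp.argsort(x, stable=True)

# Instead of get_compute_capability from jax.experimental.pallas.gpu, read the
# compute_capability attribute of GPU devices returned by jax.devices().
gpu_caps = [
    getattr(d, "compute_capability", None)
    for d in jax.devices()
    if d.platform == "gpu"
]

print(sorted_x, order, gpu_caps)
```

With a stable sort, the two 1s in x keep their original order, so argsort yields index 1 before index 3.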

Jaxlib release v0.4.27

07 May 17:44
jaxlib-v0.4.27

jaxlib version 0.4.27

Jax release v0.4.27

07 May 17:44
jax-v0.4.27

jax version 0.4.27

Jaxlib release v0.4.26

03 Apr 22:09
jaxlib-v0.4.26

jaxlib version 0.4.26

Jax release v0.4.26

03 Apr 22:09
jax-v0.4.26

jax version 0.4.26