Skip to content

Commit

Permalink
Merge pull request #21404 from superbobry:maint
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 636862158
  • Loading branch information
jax authors committed May 24, 2024
2 parents e0bf783 + 0a694a1 commit e86c436
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 22 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ Remember to align the itemized text with the first line of an item within a list

## jax 0.4.29

* Breaking changes
* JAX now requires ml_dtypes version 0.4.0 or newer.

* Deprecations
* Removed a number of previously-deprecated APIs:
* from {mod}`jax.core`: `non_negative_dim`, `DimSize`, `Shape`
Expand Down
17 changes: 6 additions & 11 deletions jax/_src/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@
except:
pass
else:
if _ml_dtypes_version < (0, 2, 0):
raise ValueError("JAX requires ml_dtypes version 0.2.0 or newer; "
if _ml_dtypes_version < (0, 4, 0):
raise ValueError("JAX requires ml_dtypes version 0.4.0 or newer; "
f"installed version is {ml_dtypes.__version__}.")

export = set_module('jax.dtypes')
Expand Down Expand Up @@ -500,26 +500,21 @@ def _type_promotion_lattice(jax_numpy_dtype_promotion: str) -> dict[JAXType, lis
This DAG maps each type to its immediately higher type on the lattice.
"""
b1, = _bool_types
_uint4, u1, u2, u4, u8, _int4, i1, i2, i4, i8 = _int_types
uint4, u1, u2, u4, u8, int4, i1, i2, i4, i8 = _int_types
*f1_types, bf, f2, f4, f8 = _float_types
c4, c8 = _complex_types
i_, f_, c_ = _weak_types
if jax_numpy_dtype_promotion == 'standard':
out: dict[JAXType, list[JAXType]]
out = {
b1: [i_],
u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_],
i_: [u1, i1], i1: [i2], i2: [i4], i4: [i8], i8: [f_],
uint4: [], u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_],
i_: [uint4, int4, u1, i1],
int4: [], i1: [i2], i2: [i4], i4: [i8], i8: [f_],
f_: [*f1_types, bf, f2, c_],
**{t: [] for t in f1_types}, bf: [f4], f2: [f4], f4: [f8, c4], f8: [c8],
c_: [c4], c4: [c8], c8: [],
}
if _int4_dtype is not None:
out[i_].append(_int4_dtype)
out[_int4_dtype] = []
if _uint4_dtype is not None:
out[i_].append(_uint4_dtype)
out[_uint4_dtype] = []
return out
elif jax_numpy_dtype_promotion == 'strict':
return {
Expand Down
9 changes: 0 additions & 9 deletions jax/_src/public_test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,6 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''):
def maybe_upcast(x):
if x.dtype in custom_float_dtypes:
return x.astype(np.float32)
# TODO(reedwm): Upcasting int4 to int8 will no longer be neccessary once
# ml_dtypes has a stable release with commit
# https://github.com/jax-ml/ml_dtypes/commit/348fd3704306cae97f617c38045cee6bc416bf10.
# Remove these checks once JAX depends on a version on ml_dtypes with that
# commit.
if x.dtype == _dtypes.int4:
return x.astype(np.int8)
if x.dtype == _dtypes.uint4:
return x.astype(np.uint8)
return x

a = maybe_upcast(a)
Expand Down
2 changes: 1 addition & 1 deletion jaxlib/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def has_ext_modules(self):
'scipy>=1.9',
"scipy>=1.11.1; python_version>='3.12'",
'numpy>=1.22',
'ml_dtypes>=0.2.0',
'ml_dtypes>=0.4.0',
],
extras_require={
'cuda12_pip': [
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def load_version_module(pkg_path):
package_data={'jax': ['py.typed', "*.pyi", "**/*.pyi"]},
python_requires='>=3.9',
install_requires=[
'ml_dtypes>=0.2.0',
'ml_dtypes>=0.4.0',
'numpy>=1.22',
"numpy>=1.23.2; python_version>='3.11'",
"numpy>=1.26.0; python_version>='3.12'",
Expand Down

0 comments on commit e86c436

Please sign in to comment.