From 0a694a1b42c574fbc70d04bf38a7270a80bab60e Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 23 May 2024 20:42:21 +0100 Subject: [PATCH] Bumped the minimum ml_dtypes version to 0.4.0 --- CHANGELOG.md | 3 +++ jax/_src/dtypes.py | 17 ++++++----------- jax/_src/public_test_util.py | 9 --------- jaxlib/setup.py | 2 +- setup.py | 2 +- 5 files changed, 11 insertions(+), 22 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 46351308cb2c..b8fff4f72803 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,9 @@ Remember to align the itemized text with the first line of an item within a list ## jax 0.4.29 +* Breaking changes + * JAX now requires ml_dtypes version 0.4.0 or newer. + * Deprecations * Removed a number of previously-deprecated APIs: * from {mod}`jax.core`: `non_negative_dim`, `DimSize`, `Shape` diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index b4c9cfdf5d26..f85e4833e13c 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -42,8 +42,8 @@ except: pass else: - if _ml_dtypes_version < (0, 2, 0): - raise ValueError("JAX requires ml_dtypes version 0.2.0 or newer; " + if _ml_dtypes_version < (0, 4, 0): + raise ValueError("JAX requires ml_dtypes version 0.4.0 or newer; " f"installed version is {ml_dtypes.__version__}.") export = set_module('jax.dtypes') @@ -500,7 +500,7 @@ def _type_promotion_lattice(jax_numpy_dtype_promotion: str) -> dict[JAXType, lis This DAG maps each type to its immediately higher type on the lattice. """ b1, = _bool_types - _uint4, u1, u2, u4, u8, _int4, i1, i2, i4, i8 = _int_types + uint4, u1, u2, u4, u8, int4, i1, i2, i4, i8 = _int_types *f1_types, bf, f2, f4, f8 = _float_types c4, c8 = _complex_types i_, f_, c_ = _weak_types @@ -508,18 +508,13 @@ def _type_promotion_lattice(jax_numpy_dtype_promotion: str) -> dict[JAXType, lis out: dict[JAXType, list[JAXType]] out = { b1: [i_], - u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_], - i_: [u1, i1], i1: [i2], i2: [i4], i4: [i8], i8: [f_], + uint4: [], u1: [i2, u2], u2: [i4, u4], u4: [i8, u8], u8: [f_], + i_: [uint4, int4, u1, i1], + int4: [], i1: [i2], i2: [i4], i4: [i8], i8: [f_], f_: [*f1_types, bf, f2, c_], **{t: [] for t in f1_types}, bf: [f4], f2: [f4], f4: [f8, c4], f8: [c8], c_: [c4], c4: [c8], c8: [], } - if _int4_dtype is not None: - out[i_].append(_int4_dtype) - out[_int4_dtype] = [] - if _uint4_dtype is not None: - out[i_].append(_uint4_dtype) - out[_uint4_dtype] = [] return out elif jax_numpy_dtype_promotion == 'strict': return { diff --git a/jax/_src/public_test_util.py b/jax/_src/public_test_util.py index bc18226f07a7..c7cb4ee20e1b 100644 --- a/jax/_src/public_test_util.py +++ b/jax/_src/public_test_util.py @@ -104,15 +104,6 @@ def _assert_numpy_allclose(a, b, atol=None, rtol=None, err_msg=''): def maybe_upcast(x): if x.dtype in custom_float_dtypes: return x.astype(np.float32) - # TODO(reedwm): Upcasting int4 to int8 will no longer be neccessary once - # ml_dtypes has a stable release with commit - # https://github.com/jax-ml/ml_dtypes/commit/348fd3704306cae97f617c38045cee6bc416bf10. - # Remove these checks once JAX depends on a version on ml_dtypes with that - # commit. - if x.dtype == _dtypes.int4: - return x.astype(np.int8) - if x.dtype == _dtypes.uint4: - return x.astype(np.uint8) return x a = maybe_upcast(a) diff --git a/jaxlib/setup.py b/jaxlib/setup.py index c6131b7ea29f..25b5cd41e955 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -64,7 +64,7 @@ def has_ext_modules(self): 'scipy>=1.9', "scipy>=1.11.1; python_version>='3.12'", 'numpy>=1.22', - 'ml_dtypes>=0.2.0', + 'ml_dtypes>=0.4.0', ], extras_require={ 'cuda12_pip': [ diff --git a/setup.py b/setup.py index 661763345b2b..481e24a77335 100644 --- a/setup.py +++ b/setup.py @@ -72,7 +72,7 @@ def generate_proto(source): package_data={'jax': ['py.typed', "*.pyi", "**/*.pyi"]}, python_requires='>=3.9', install_requires=[ - 'ml_dtypes>=0.2.0', + 'ml_dtypes>=0.4.0', 'numpy>=1.22', "numpy>=1.23.2; python_version>='3.11'", "numpy>=1.26.0; python_version>='3.12'",