-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add float8_e4m3 and float8_e3m4 types support #23585
base: main
Are you sure you want to change the base?
Conversation
Thanks for the contribution! I don't think we'll be able to bump our The good news is this is easy enough to do with a few version guards: if you look at the initial implementation of |
Here's an example of how this was handled in the past: https://github.com/google/jax/blob/jax-v0.4.12/jax/_src/dtypes.py#L71 Basically, we only define the dtype in JAX if it's defined in Another strategy we could use is the module-level |
Incidentally, the current TF pin is : If we release I suspect we could ease this process if we committed to semver for |
That said I'd probably do it the way Jake said for now and then we can think about the minimum version bump separately, there may be other factors I haven't considered (e.g., users being stuck on an older TF for whatever reason). |
0336705
to
5553be7
Compare
5553be7
to
090ff3e
Compare
I updated the PR and tested it with ml_dtypes 0.4.0 and 0.5.0 |
a6218f2
to
0b862ed
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good – at some point we should refactor things so that we keep a single source of truth for the list of custom float types, rather than re-defining it a half dozen times across the package. But that can be for another PR.
Please fix the lint issues – thanks! Also, the test failures look real. It seems that there's some place where the new float8 types must be registered |
0b862ed
to
be89ca7
Compare
MyPyfixed mypy issues btw, Contributing to JAX explains how to run lint/ruff/mypy/jupytext locally pre-commit run --all-files All passed now Regarding failed testsFAILED tests/export_test.py::JaxExportTest::test_poly_numeric_dtypes_dtype_float8_e3m4e4m3 and e3m4 were added to stablehlo 1.7.0 (3 weeks ago, Sep 4, 2024) jax/_src/export/_export.py uses
it returns target_version 1.5.0 workaround - use NONE instead of WEEK_4 - it returns 1.7.5 FAILED tests/array_test.py::JaxArrayTest::test_shards_have_correct_dtype17FAILED tests/dtypes_test.py::TestPromotionTables::testFloat8PromotionErrorThe tests work fine if I use XLA COMMIT_ID from XLA PR openxla/xla#16585 Add support for float8_e4m3 and float8_e3m4 types (in Review) I guess we need to put this PR on hold and rerun the tests once XLA PR-16585 is merged to XLA main and XLA_COMMIT is updated in JAX. StableHLO WEEK_4 issue should resolve itself in 1-2 weeks too. Test report on CPU pytest -n auto tests/
26921 passed, 12530 skipped in 492.32s (0:08:12) |
Description
Amazon has proposed two new FP8 types,
Float8E4M3
andFloat8E3M4
. These types are implemented in commercially available hardware Amazon EC2 Trn1 Instances, and added to MLIR builtin types, LLVM APFloat, ml_dtypes, StableHLO.XLA has Float8E4M3 and Float8E3M4 implementation in Review. See PR links in Related PRs section below.
This PR adds f8E4M3 and f8E3M4 types support to JAX.
f8E4M3
type follows IEEE 754 convention.f8E3M4
type follows IEEE 754 conventionRelated PRs:
How to build/install
This PR requires ml_dtype version 20240821 or later.
The current version on PyPI is 0.4.0, released on April 1, 2024, which is outdated. Therefore, ml_dtypes should be installed from source.
Related issue: jax-ml/ml_dtypes#185 [Question] Can we release a new version of ml_dtypes?
Smoke test