Add float8_e4m3 and float8_e3m4 types support #23585

Open
apivovarov
Contributor

@apivovarov apivovarov commented Sep 12, 2024

Description

Amazon has proposed two new FP8 types, Float8E4M3 and Float8E3M4. These types are implemented in commercially available hardware (Amazon EC2 Trn1 instances) and have been added to MLIR builtin types, LLVM APFloat, ml_dtypes, and StableHLO.

The XLA implementation of Float8E4M3 and Float8E3M4 is in review. See the PR links in the Related PRs section below.

This PR adds support for the f8E4M3 and f8E3M4 types to JAX.

The f8E4M3 type follows IEEE 754 conventions.

f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 - 7 = -6
- Precision specifies the total number of bits used for the significand (mantissa), 
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
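
The f8E4M3 values listed above can be checked with a small decoder. This is a minimal sketch in plain Python, assuming only the S.EEEE.MMM layout and bias 7 described here; the function name `decode_f8e4m3` is made up for illustration and is not part of JAX or ml_dtypes.

```python
import math


def decode_f8e4m3(bits: int) -> float:
    """Decode an 8-bit pattern laid out as S.EEEE.MMM with exponent bias 7."""
    sign = -1.0 if (bits >> 7) & 1 else 1.0
    exp = (bits >> 3) & 0xF   # 4 stored exponent bits
    man = bits & 0x7          # 3 stored mantissa bits
    if exp == 0xF:            # all-ones exponent: infinity or NaN
        return sign * math.inf if man == 0 else math.nan
    if exp == 0:              # zero or subnormal: no implicit 1, exponent fixed at -6
        return sign * (man / 8) * 2.0 ** -6
    return sign * (1 + man / 8) * 2.0 ** (exp - 7)


print(decode_f8e4m3(0b0_1110_111))  # max normal
print(decode_f8e4m3(0b0_0001_000))  # min normal
print(decode_f8e4m3(0b0_0000_111))  # max subnormal
print(decode_f8e4m3(0b0_0000_001))  # min subnormal
```

The four printed values correspond to the max/min normal and max/min subnormal entries in the list above.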

f8E3M4 type follows IEEE 754 convention

f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 - 3 = -2
- Precision specifies the total number of bits used for the significand (mantissa), 
    including implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 = +/-2^(-2) x 2^(-4) = +/-2^(-6)
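
As a cross-check of the f8E3M4 list, the sketch below enumerates all 256 bit patterns and classifies them. It assumes only the S.EEE.MMMM layout and bias 3 given above; `classify_f8e3m4` and `decode_f8e3m4` are illustrative names, not library functions.

```python
import math
from collections import Counter


def classify_f8e3m4(bits: int) -> str:
    """Classify an 8-bit pattern laid out as S.EEE.MMMM."""
    exp = (bits >> 4) & 0x7
    man = bits & 0xF
    if exp == 0x7:
        return "inf" if man == 0 else "nan"
    if exp == 0:
        return "zero" if man == 0 else "subnormal"
    return "normal"


def decode_f8e3m4(bits: int) -> float:
    """Decode an 8-bit pattern laid out as S.EEE.MMMM with exponent bias 3."""
    sign = -1.0 if bits & 0x80 else 1.0
    exp = (bits >> 4) & 0x7
    man = bits & 0xF
    if exp == 0x7:            # all-ones exponent: infinity or NaN
        return sign * math.inf if man == 0 else math.nan
    if exp == 0:              # zero or subnormal: exponent fixed at -2
        return sign * (man / 16) * 2.0 ** -2
    return sign * (1 + man / 16) * 2.0 ** (exp - 3)


counts = Counter(classify_f8e3m4(b) for b in range(256))
finite = [decode_f8e3m4(b) for b in range(256)
          if classify_f8e3m4(b) in ("normal", "subnormal")]
largest = max(finite)
smallest_positive = min(v for v in finite if v > 0)

print(dict(counts))
print(largest, smallest_positive)
```

The counts follow directly from the layout: one exponent code is reserved for infinities and NaNs, one for zeros and subnormals, and the remaining six hold normal numbers.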

Related PRs:

  • LLVM PR-97179 [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
  • LLVM PR-97118 [MLIR] Add f8E4M3 IEEE 754 type (Merged)
  • LLVM PR-99698 [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
  • LLVM PR-101230 [MLIR] Add f8E3M4 IEEE 754 type (Merged)
  • StableHLO PR-2486 [RFC] Add f8E4M3 and f8E3M4 types support (Merged)
  • StableHLO PR-2482 Add f8E4M3 and f8E3M4 types support (Merged)
  • ml_dtypes PR-161 Add float8_e4m3 (Merged)
  • ml_dtypes PR-171 Add float8_e3m4 (Merged)
  • XLA PR-17075 [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Merged)
  • XLA PR-16585 Add support for float8_e4m3 and float8_e3m4 types (in Review)

How to build/install

This PR requires ml_dtypes version 20240821 or later.

The current version on PyPI is 0.4.0, released on April 1, 2024, which is outdated. Therefore, ml_dtypes should be installed from source.

Related issue: jax-ml/ml_dtypes#185 [Question] Can we release a new version of ml_dtypes?

## Install the latest ml_dtypes
cd ml_dtypes
pip3 install .

## Install jaxlib and JAX
cd jax

### install jaxlib
python3 build/build.py
pip3 install dist/*.whl

### install jax
pip3 install .

Smoke test

import jax
import jax.numpy as jnp
from jax import Array, random

key1 = random.PRNGKey(41)
key2 = random.PRNGKey(42)
a = random.uniform(key1, shape=(4,4), dtype="float8_e4m3")
b = random.uniform(key2, shape=(4,4), dtype="float8_e4m3")

def foo(a, b):
  return a@b

# StableHLO
print(jax.jit(foo).lower(a,b).as_text())

# HLO (optimized for cpu)
print(jax.jit(foo).lower(a,b).compile().as_text())

c = foo(a, b)
print(repr(c))

Array([[1, 0.9375, 1.25, 0.5625],
       [0.75, 0.625, 0.75, 0.5],
       [0.8125, 0.8125, 1.25, 0.40625],
       [0.8125, 0.875, 1.25, 0.4375]], dtype=float8_e4m3)

@jakevdp
Collaborator

jakevdp commented Sep 12, 2024

Thanks for the contribution! I don't think we'll be able to bump our ml_dtypes requirement any time soon, so if we want to merge this we'll have to make it robust to older ml_dtypes versions (the reason is that tensorflow pins a specific ml_dtypes version, and some workflows depend on installing both JAX and tensorflow).

The good news is this is easy enough to do with a few version guards: if you look at the initial implementation of float8 types in JAX, you can see the pattern we used previously.

@jakevdp
Collaborator

jakevdp commented Sep 12, 2024

Here's an example of how this was handled in the past: https://github.com/google/jax/blob/jax-v0.4.12/jax/_src/dtypes.py#L71

Basically, we only define the dtype in JAX if it's defined in ml_dtypes.
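
A minimal sketch of that guard pattern follows. It is not the code from `jax/_src/dtypes.py`; the helper name `available_float8_dtypes` is hypothetical, and `types.SimpleNamespace` objects stand in for older and newer ml_dtypes modules so the example runs without ml_dtypes installed.

```python
import types


def available_float8_dtypes(ml_dtypes_module, candidate_names):
    """Return {name: dtype} for the candidates the given module actually defines."""
    found = {}
    for name in candidate_names:
        dtype = getattr(ml_dtypes_module, name, None)
        if dtype is not None:
            found[name] = dtype
    return found


# Stand-ins (assumption, for illustration only): an older module that lacks
# the new types and a newer one that defines them.
old_ml_dtypes = types.SimpleNamespace(float8_e4m3fn=object(), float8_e5m2=object())
new_ml_dtypes = types.SimpleNamespace(
    float8_e4m3fn=object(), float8_e5m2=object(),
    float8_e4m3=object(), float8_e3m4=object())

CANDIDATES = ("float8_e4m3fn", "float8_e5m2", "float8_e4m3", "float8_e3m4")

print(sorted(available_float8_dtypes(old_ml_dtypes, CANDIDATES)))
print(sorted(available_float8_dtypes(new_ml_dtypes, CANDIDATES)))
```

With a real installation, the same function could be called with the imported `ml_dtypes` module in place of the stand-ins.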

Another strategy we could use is the module-level __getattr__ for these types, so that if the ml_dtypes version is too old, we raise an error that specifies what version is required.
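
The module-level `__getattr__` idea (PEP 562) could look roughly like the sketch below. The module is built dynamically so the example fits in one file; in a real package the function would simply be defined at the top level of the module. The minimum version shown and all names here are assumptions for illustration, not what the PR implements.

```python
import types

# Assumption for illustration: the ml_dtypes release that first ships each type.
_MIN_ML_DTYPES_VERSION = {"float8_e4m3": "0.5.0", "float8_e3m4": "0.5.0"}


def make_dtypes_module(name, defined):
    """Build a module whose missing float8 attributes raise a helpful error."""
    module = types.ModuleType(name)
    module.__dict__.update(defined)

    def module_getattr(attr):
        # Only called when normal attribute lookup on the module fails.
        if attr in _MIN_ML_DTYPES_VERSION:
            raise AttributeError(
                f"{name}.{attr} requires ml_dtypes >= {_MIN_ML_DTYPES_VERSION[attr]}")
        raise AttributeError(f"module {name!r} has no attribute {attr!r}")

    module.__getattr__ = module_getattr
    return module


demo = make_dtypes_module("demo_dtypes", {"float8_e5m2": "stand-in dtype"})
print(demo.float8_e5m2)
try:
    demo.float8_e4m3
except AttributeError as err:
    print(err)
```

Because the hook raises `AttributeError`, `hasattr` checks on the module keep working, while direct access tells the user which version is needed.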

@hawkinsp
Collaborator

Incidentally, the current TF pin is: Requires-Dist: ml-dtypes <0.5.0,>=0.3.1.

If we release ml_dtypes as 0.4.1 instead of 0.5.0 we probably could bump the minimum version.

I suspect we could ease this process if we committed to semver for ml_dtypes so TF felt like they could be less conservative in their pins. (Adding dtypes is hopefully safe!)
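
To make the pin arithmetic concrete, here is a small sketch that checks candidate release numbers against the `<0.5.0,>=0.3.1` range quoted above. It uses a simplified tuple comparison that only handles plain `X.Y.Z` versions; the helper names are made up for this example.

```python
def parse_version(text):
    """Parse a plain 'X.Y.Z' version string into a tuple of ints."""
    return tuple(int(part) for part in text.split("."))


def satisfies_tf_pin(version):
    """True if the version falls inside the quoted pin: >=0.3.1 and <0.5.0."""
    return parse_version("0.3.1") <= parse_version(version) < parse_version("0.5.0")


for candidate in ("0.4.0", "0.4.1", "0.5.0"):
    print(candidate, satisfies_tf_pin(candidate))
```

A release numbered 0.4.1 stays inside the range, while 0.5.0 falls outside it, which is the point being made about the release number.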

@hawkinsp
Collaborator

That said I'd probably do it the way Jake said for now and then we can think about the minimum version bump separately, there may be other factors I haven't considered (e.g., users being stuck on an older TF for whatever reason).

@apivovarov apivovarov changed the title Add float8_e4m3 type support Add float8_e4m3 and float8_e3m4 types support Sep 26, 2024
@apivovarov
Contributor Author

> That said I'd probably do it the way Jake said for now and then we can think about the minimum version bump separately, there may be other factors I haven't considered (e.g., users being stuck on an older TF for whatever reason).

I updated the PR and tested it with ml_dtypes 0.4.0 and 0.5.0
@jakevdp @hawkinsp

third_party/xla/workspace.bzl (outdated, resolved)
jax/_src/dtypes.py (outdated, resolved)
jax/_src/export/serialization.py (outdated, resolved)
jax/_src/interpreters/mlir.py (outdated, resolved)
jax/_src/lax/lax.py (outdated, resolved)
jax/_src/public_test_util.py (outdated, resolved)
jax/numpy/__init__.pyi (outdated, resolved)
tests/dtypes_test.py (outdated, resolved)
@jakevdp jakevdp self-assigned this Sep 26, 2024
@apivovarov apivovarov force-pushed the float8_e4m3 branch 3 times, most recently from a6218f2 to 0b862ed Compare September 27, 2024 03:41
Collaborator

@jakevdp jakevdp left a comment


Looks good – at some point we should refactor things so that we keep a single source of truth for the list of custom float types, rather than re-defining it a half dozen times across the package. But that can be for another PR.

@jakevdp
Collaborator

jakevdp commented Sep 27, 2024

Please fix the lint issues – thanks! Also, the test failures look real. It seems that there's some place where the new float8 types must be registered

@apivovarov
Contributor Author

apivovarov commented Sep 27, 2024

> Please fix the lint issues – thanks! Also, the test failures look real. It seems that there's some place where the new float8 types must be registered

MyPy

fixed mypy issues

btw, Contributing to JAX explains how to run lint/ruff/mypy/jupytext locally

pre-commit run --all-files

All passed now

Regarding failed tests

FAILED tests/export_test.py::JaxExportTest::test_poly_numeric_dtypes_dtype_float8_e3m4

e4m3 and e3m4 were added to stablehlo 1.7.0 (3 weeks ago, Sep 4, 2024)

jax/_src/export/_export.py uses

target_version = hlo.get_version_from_compatibility_requirement(
      hlo.StablehloCompatibilityRequirement.WEEK_4)

This returns target_version 1.5.0, which is older than StableHLO 1.7.0.

Workaround: use NONE instead of WEEK_4, which returns 1.7.5.

FAILED tests/array_test.py::JaxArrayTest::test_shards_have_correct_dtype17

FAILED tests/dtypes_test.py::TestPromotionTables::testFloat8PromotionError

The tests work fine if I use XLA COMMIT_ID from XLA PR openxla/xla#16585 Add support for float8_e4m3 and float8_e3m4 types (in Review)

I guess we need to put this PR on hold and rerun the tests once XLA PR-16585 is merged to XLA main and XLA_COMMIT is updated in JAX. The StableHLO WEEK_4 issue should resolve itself in 1-2 weeks too.

Test report on CPU

pytest -n auto tests/

26921 passed, 12530 skipped in 492.32s (0:08:12)
