Merge branch 'main' into f8E4M3_f8E3M4
apivovarov committed Sep 3, 2024
2 parents 7f4c57d + d68ab07 commit 526cebc
Showing 5 changed files with 347 additions and 0 deletions.
178 changes: 178 additions & 0 deletions rfcs/20240808-f8E4M3_f8E3M4.md
@@ -0,0 +1,178 @@
# RFC: Float8E4M3 and Float8E3M4

Status: In Review<br/>
Initial version: 8/8/2024<br/>
Last updated: 8/9/2024<br/>
Discussion thread: [PR-2486](https://github.com/openxla/stablehlo/pull/2486)
[RFC] Add f8E4M3 and f8E3M4 types support

## Summary

Amazon has proposed two new FP8 types, Float8E4M3 and Float8E3M4. These
types are implemented in commercially available hardware[^1] and have been
added to the MLIR builtin types[^2]˒[^3] and LLVM APFloat[^4]˒[^5].

Both Float8E4M3 and Float8E3M4 follow IEEE 754 conventions, like the existing
Float8E5M2 type.

### Float8E4M3

8-bit floating point type with 1 sign bit, 4 exponent bits, and 3 mantissa
bits, following IEEE 754 conventions, with bit layout S1E4M3.

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Minimum stored exponent value: 1 (binary 0001)
- Maximum stored exponent value: 14 (binary 1110)
- Minimum unbiased exponent value: 1 − 7 = −6
- Maximum unbiased exponent value: 14 - 7 = 7
- Precision specifies the total number of bits used for the significand
(mantissa), including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Min exp (unbiased): -6
- Max exp (unbiased): 7
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Min normal number: S.0001.000 = +/-2^(1 - 7) x (1 + 0) = +/-2^(-6)
- Max normal number: S.1110.111 = +/-2^(14 - 7) x (1 + 7/8) = +/-240
- Min subnormal number: S.0000.001 = +/-2^(-6) x 1/8 = +/-2^(-9)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 7/8 = +/-2^(-9) x 7
```
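
The bit-level rules listed above can be checked with a small decoder. This is a minimal sketch in plain Python, not part of the proposal; the function name `decode_f8e4m3` is hypothetical and chosen only for illustration.

```python
import math


def decode_f8e4m3(byte: int) -> float:
    """Decode one f8E4M3 bit pattern (0..255) into a Python float.

    Hypothetical helper for illustration; it follows the S.EEEE.MMM layout
    described in the RFC text.
    """
    sign = -1.0 if (byte >> 7) & 1 else 1.0
    exp = (byte >> 3) & 0xF  # 4 stored exponent bits
    man = byte & 0x7         # 3 mantissa bits
    bias = 7
    if exp == 0xF:           # all-ones exponent: infinity or NaN
        return sign * math.inf if man == 0 else math.nan
    if exp == 0:             # zero or subnormal: no implicit leading 1
        return sign * (man / 8) * 2.0 ** (1 - bias)
    return sign * (1 + man / 8) * 2.0 ** (exp - bias)


print(decode_f8e4m3(0b0_1110_111))  # max normal, S.1110.111
print(decode_f8e4m3(0b0_0000_001))  # min subnormal, S.0000.001
```

The two printed values correspond to the "Max normal number" and "Min subnormal number" entries in the list above.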
#### Comparison of Float8E4M3FN and Float8E4M3
| |Float8E4M3FN |Float8E4M3 |
|-------------------|------------------------------------------------------------------------|-------------------------------------------------------------------------|
|Bias |7 |7 |
|Min Normal Value |`0bS0001000` = -1<sup>S</sup> $\times$ 1.0 $\times$ 2<sup>-6</sup> |`0bS0001000` = -1<sup>S</sup> $\times$ 1.0 $\times$ 2<sup>-6</sup> |
|Max Normal Value |`0bS1111110` = -1<sup>S</sup> $\times$ 1.75 $\times$ 2<sup>8</sup> = 448|`0bS1110111` = -1<sup>S</sup> $\times$ 1.875 $\times$ 2<sup>7</sup> = 240|
|Min Subnormal Value|`0bS0000001` = -1<sup>S</sup> $\times$ 0.125 $\times$ 2<sup>-6</sup> |`0bS0000001` = -1<sup>S</sup> $\times$ 0.125 $\times$ 2<sup>-6</sup> |
|Max Subnormal Value|`0bS0000111` = -1<sup>S</sup> $\times$ 0.875 $\times$ 2<sup>-6</sup> |`0bS0000111` = -1<sup>S</sup> $\times$ 0.875 $\times$ 2<sup>-6</sup> |
|NaN |`0bS1111111` |`0bS1111MMM`, where `MMM` is non-zero. |
|Infinity |N/A |`0bS1111000` |
|-0.0 |`0b10000000` |`0b10000000` |
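
To make the table concrete, the sketch below decodes every bit pattern under both conventions and lists the patterns that decode differently. It is an illustration in plain Python under the layouts given in the table; `decode_e4m3` and `same` are hypothetical helper names.

```python
import math


def decode_e4m3(byte: int, fn: bool) -> float:
    """Decode an 8-bit E4M3 pattern; fn=True selects the Float8E4M3FN rules."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp, man = (byte >> 3) & 0xF, byte & 0x7
    if exp == 0xF:
        if fn:
            # FN variant: only S.1111.111 is NaN; there is no infinity.
            if man == 0x7:
                return math.nan
        else:
            # IEEE-style variant: all-ones exponent is infinity or NaN.
            return sign * math.inf if man == 0 else math.nan
    if exp == 0:
        return sign * (man / 8) * 2.0 ** -6
    return sign * (1 + man / 8) * 2.0 ** (exp - 7)


def same(a: float, b: float) -> bool:
    """Equality that treats two NaNs as equal."""
    return (math.isnan(a) and math.isnan(b)) or a == b


differing = [b for b in range(256)
             if not same(decode_e4m3(b, True), decode_e4m3(b, False))]
print([hex(b) for b in differing])
```

Only patterns with an all-ones exponent differ: Float8E4M3FN spends them on extra finite values up to 448, while Float8E4M3 reserves them for infinity and NaN.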
### Float8E3M4
8-bit floating point type with 1 sign bit, 3 exponent bits, and 4 mantissa
bits, following IEEE 754 conventions, with bit layout S1E3M4.
```c
f8E3M4 (IEEE 754)
- Exponent bias: 3
- Minimum stored exponent value: 1 (binary 001)
- Maximum stored exponent value: 6 (binary 110)
- Minimum unbiased exponent value: 1 − 3 = −2
- Maximum unbiased exponent value: 6 - 3 = 3
- Precision specifies the total number of bits used for the significand
(mantissa), including implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs
Additional details:
- Min exp (unbiased): -2
- Max exp (unbiased): 3
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Min normal number: S.001.0000 = +/-2^(1 - 3) x (1 + 0) = +/-0.25
- Max normal number: S.110.1111 = +/-2^(6 - 3) x (1 + 15/16) = +/-15.5
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 = +/-2^(-6)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-6) x 15
```
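
As a cross-check of the f8E3M4 list above, the following sketch enumerates all 256 bit patterns and summarizes the finite ones. It is plain Python written for illustration; `decode_f8e3m4` is a hypothetical name, not an API from MLIR, LLVM, or StableHLO.

```python
import math


def decode_f8e3m4(byte: int) -> float:
    """Decode one f8E3M4 bit pattern (S.EEE.MMMM) into a Python float."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 4) & 0x7  # 3 stored exponent bits
    man = byte & 0xF         # 4 mantissa bits
    bias = 3
    if exp == 0x7:           # all-ones exponent: infinity or NaN
        return sign * math.inf if man == 0 else math.nan
    if exp == 0:             # zero or subnormal
        return sign * (man / 16) * 2.0 ** (1 - bias)
    return sign * (1 + man / 16) * 2.0 ** (exp - bias)


values = [decode_f8e3m4(b) for b in range(256)]
finite = [v for v in values if math.isfinite(v)]
nan_count = sum(1 for v in values if math.isnan(v))

# Number of finite patterns, number of NaN patterns, largest finite value,
# and smallest positive value.
print(len(finite), nan_count, max(finite), min(v for v in finite if v > 0))
```

Of the 256 patterns, 2 encode infinities and 30 encode NaNs, which leaves 224 finite patterns (including both zeros).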

### Comparison of Float8E5M2, Float8E4M3 and Float8E3M4

| |Float8E5M2 |Float8E4M3 |Float8E3M4 |
|-------------------|----------------------------------------------------------------------------|-------------------------------------------------------------------------|---------------------------------------------------------------------------|
|Bias |15 |7 |3 |
|Min Normal Value |`0bS0000100` = -1<sup>S</sup> $\times$ 1.0 $\times$ 2<sup>-14</sup> |`0bS0001000` = -1<sup>S</sup> $\times$ 1.0 $\times$ 2<sup>-6</sup> |`0bS0010000` = -1<sup>S</sup> $\times$ 1.0 $\times$ 2<sup>-2</sup> |
|Max Normal Value |`0bS1111011` = -1<sup>S</sup> $\times$ 1.75 $\times$ 2<sup>15</sup> = 57344 |`0bS1110111` = -1<sup>S</sup> $\times$ 1.875 $\times$ 2<sup>7</sup> = 240|`0bS1101111` = -1<sup>S</sup> $\times$ 1.9375 $\times$ 2<sup>3</sup> = 15.5|
|Min Subnormal Value|`0bS0000001` = -1<sup>S</sup> $\times$ 0.25 $\times$ 2<sup>-14</sup> |`0bS0000001` = -1<sup>S</sup> $\times$ 0.125 $\times$ 2<sup>-6</sup> |`0bS0000001` = -1<sup>S</sup> $\times$ 0.0625 $\times$ 2<sup>-2</sup> |
|Max Subnormal Value|`0bS0000011` = -1<sup>S</sup> $\times$ 0.75 $\times$ 2<sup>-14</sup> |`0bS0000111` = -1<sup>S</sup> $\times$ 0.875 $\times$ 2<sup>-6</sup> |`0bS0001111` = -1<sup>S</sup> $\times$ 0.9375 $\times$ 2<sup>-2</sup> |
|NaN |`0bS11111MM`, where `MM` is non-zero. |`0bS1111MMM`, where `MMM` is non-zero. |`0bS111MMMM`, where `MMMM` is non-zero. |
|Infinity |`0bS1111100` |`0bS1111000` |`0bS1110000` |
|-0.0 |`0b10000000` |`0b10000000` |`0b10000000` |
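
The numeric rows of this table follow from the exponent and mantissa widths alone. The sketch below derives them in Python, assuming the IEEE-style bias `2^(E-1) - 1` and an all-ones exponent reserved for infinity and NaN; `ieee_fp8_limits` is a hypothetical helper name.

```python
def ieee_fp8_limits(exp_bits: int, man_bits: int) -> dict:
    """Derive range limits for an IEEE-style format with the given field widths."""
    bias = 2 ** (exp_bits - 1) - 1
    max_stored_exp = 2 ** exp_bits - 2  # all-ones exponent is reserved
    min_exp = 1 - bias                  # unbiased exponent of the smallest normal
    return {
        "bias": bias,
        "min_normal": 2.0 ** min_exp,
        "max_normal": (2 - 2.0 ** -man_bits) * 2.0 ** (max_stored_exp - bias),
        "min_subnormal": 2.0 ** (min_exp - man_bits),
        "max_subnormal": (1 - 2.0 ** -man_bits) * 2.0 ** min_exp,
    }


for name, (e, m) in {"Float8E5M2": (5, 2),
                     "Float8E4M3": (4, 3),
                     "Float8E3M4": (3, 4)}.items():
    print(name, ieee_fp8_limits(e, m))
```

Each extra mantissa bit halves the relative spacing between neighboring values, while each exponent bit removed sharply narrows the dynamic range; that is the trade-off among the three types.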

## Changes in StableHLO

I propose adding Float8E4M3 and Float8E3M4 types to StableHLO similar to the
previously introduced FP8 types (below) with some differences:

- [FP8 RFC](https://github.com/openxla/xla/discussions/22)
- [[RFC] Add Float8E4M3FNUZ and Float8E5M2FNUZ to StableHLO](https://github.com/openxla/stablehlo/pull/1342)

### StableHLO Interpreter

To provide a reference implementation, I intend to add support for
Float8E4M3 and Float8E3M4 in the StableHLO interpreter. This will be
useful for testing other backends and validating new implementations. This will
be achieved in two ways:

1. Map directly to the appropriate APFloat operation.
2. Cast up to the appropriate type, use that implementation, cast back down.

### Float8E4M3 and Float8E3M4 Arithmetic

I intend for Float8E4M3 and Float8E3M4 to be types that support the
appropriate arithmetic operations, like any other floating point type. On
platforms without hardware support for these types, a backend may either
reject the program with an error or cast up to a supported higher-precision
type, compute the answer, and cast back down.

This is a simple approach that aligns with user expectations of a floating
point data type, and is the approach taken by BFloat16. This also gives
backends freedom to exploit any hardware support.
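
The "cast up, compute, cast down" fallback can be sketched in plain Python, where the float64 `float` type plays the role of the higher-precision type. This is an illustrative model only, not how any backend or the StableHLO interpreter implements it; `round_to_f8e4m3` and `emulated_add` are hypothetical names, and round-to-nearest, ties-to-even is assumed.

```python
import math


def _magnitude(code: int) -> float:
    """Magnitude encoded by a non-negative f8E4M3 code, ignoring special cases."""
    exp, man = code >> 3, code & 0x7
    if exp == 0:
        return (man / 8) * 2.0 ** -6
    return (1 + man / 8) * 2.0 ** (exp - 7)


# Magnitudes for codes 0x00..0x78. Code 0x78 is infinity; treating it as 256
# here makes overflow fall out of the ordinary rounding rule.
_MAGS = [_magnitude(c) for c in range(0x79)]


def round_to_f8e4m3(x: float) -> float:
    """Cast down: round a float64 to the nearest f8E4M3 value, ties to even."""
    if math.isnan(x):
        return math.nan
    sign, mag = math.copysign(1.0, x), abs(x)
    if mag >= _MAGS[-1]:
        return sign * math.inf
    lo = max(c for c in range(0x78) if _MAGS[c] <= mag)  # neighbor below
    hi = lo + 1                                          # neighbor above
    d_lo, d_hi = mag - _MAGS[lo], _MAGS[hi] - mag
    code = lo if d_lo < d_hi or (d_lo == d_hi and lo % 2 == 0) else hi
    return sign * (math.inf if code == 0x78 else _MAGS[code])


def emulated_add(a: float, b: float) -> float:
    """Add two f8E4M3 values: compute in float64, then cast back down."""
    return round_to_f8e4m3(a + b)


print(emulated_add(128.0, 12.0))  # exact sum 140 is not representable
```

A table lookup is used here because the format has only 121 non-negative codes; a real implementation would more likely round the significand directly.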

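Here's an example of a real JAX program (logging the MLIR) computing a simple
dot product in Float8E4M3. Note the answer is slightly "wrong", as expected
due to the lower precision: the exact result is 140, but Float8E4M3 values
between 128 and 240 are spaced 16 apart, so round-to-nearest yields 144.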

```python
>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.arange(8, dtype=jnp.float8_e4m3)
module @jit_iota {
  func.func public @main() -> tensor<8xf8E4M3> {
    %0 = stablehlo.iota dim = 0 : tensor<8xf8E4M3>
    return %0 : tensor<8xf8E4M3>
  }
}
>>> x
Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=float8_e4m3)
>>> x @ x
module @jit_matmul {
  func.func public @main(%arg0: tensor<8xf8E4M3> {mhlo.sharding = ""}, %arg1: tensor<8xf8E4M3> {mhlo.sharding = ""}) -> tensor<f8E4M3> {
    %0 = "stablehlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = #stablehlo.dot<lhs_contracting_dimensions = [0], rhs_contracting_dimensions = [0]>, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<8xf8E4M3>, tensor<8xf8E4M3>) -> tensor<f8E4M3>
    return %0 : tensor<f8E4M3>
  }
}
Array(144, dtype=float8_e4m3)
```
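
The result 144 in the transcript can be reproduced by hand. The sketch below computes the exact dot product and rounds it once to 3 stored mantissa bits; it assumes a single final rounding (the transcript does not show how the backend accumulates) and ignores subnormals and overflow. `round_to_mantissa_bits` is a hypothetical helper.

```python
import math


def round_to_mantissa_bits(x: float, man_bits: int = 3) -> float:
    """Round a positive, normal-range value to `man_bits` stored mantissa bits."""
    _, e = math.frexp(x)             # x = m * 2**e with 0.5 <= m < 1
    ulp = 2.0 ** (e - 1 - man_bits)  # spacing of representable values near x
    return round(x / ulp) * ulp      # Python's round() breaks ties to even


exact = sum(i * i for i in range(8))  # 0 + 1 + 4 + 9 + 16 + 25 + 36 + 49
print(exact, round_to_mantissa_bits(float(exact)))
```

Near 140 the spacing is 16, so the two candidates are 128 and 144, and 144 is the closer one.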

### Testing

Building on the StableHLO interpreter, I intend to introduce tests for all
possible operations with Float8E4M3 and Float8E3M4 inputs. At a minimum,
this means adding cases to the `interpret_X.mlir` family of tests.

### References and Links

- [RFC: FP8 in StableHLO](https://github.com/openxla/stablehlo/blob/main/rfcs/20221031-fp8.md)
- [RFC: Float8E4M3FNUZ and Float8E5M2FNUZ](https://github.com/openxla/stablehlo/blob/main/rfcs/20230321-fp8_fnuz.md)

[^1]: [Amazon EC2 Trn1 Instances](https://aws.amazon.com/ec2/instance-types/trn1/)
[^2]: LLVM [PR-97118](https://github.com/llvm/llvm-project/pull/97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged)
[^3]: LLVM [PR-101230](https://github.com/llvm/llvm-project/pull/101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged)
[^4]: LLVM [PR-97179](https://github.com/llvm/llvm-project/pull/97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
[^5]: LLVM [PR-99698](https://github.com/llvm/llvm-project/pull/99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
54 changes: 54 additions & 0 deletions stablehlo/integrations/c/StablehloAttributes.cpp
@@ -212,6 +212,60 @@ int64_t stablehloGatherDimensionNumbersGetIndexVectorDim(MlirAttribute attr) {
      .getIndexVectorDim();
}

//===----------------------------------------------------------------------===//
// DotAlgorithm
//===----------------------------------------------------------------------===//

MlirAttribute stablehloDotAlgorithmGet(
    MlirContext ctx, MlirType lhsPrecisionType, MlirType rhsPrecisionType,
    MlirType accumulationType, int64_t lhsComponentCount,
    int64_t rhsComponentCount, int64_t numPrimitiveOperations,
    bool allowImpreciseAccumulation) {
  return wrap(mlir::stablehlo::DotAlgorithmAttr::get(
      unwrap(ctx), unwrap(lhsPrecisionType), unwrap(rhsPrecisionType),
      unwrap(accumulationType), lhsComponentCount, rhsComponentCount,
      numPrimitiveOperations, allowImpreciseAccumulation));
}

bool stablehloAttributeIsADotAlgorithm(MlirAttribute attr) {
  return llvm::isa<mlir::stablehlo::DotAlgorithmAttr>(unwrap(attr));
}

MlirType stablehloDotAlgorithmGetLhsPrecisionType(MlirAttribute attr) {
  return wrap(llvm::cast<mlir::stablehlo::DotAlgorithmAttr>(unwrap(attr))
                  .getLhsPrecisionType());
}

MlirType stablehloDotAlgorithmGetRhsPrecisionType(MlirAttribute attr) {
  return wrap(llvm::cast<mlir::stablehlo::DotAlgorithmAttr>(unwrap(attr))
                  .getRhsPrecisionType());
}

MlirType stablehloDotAlgorithmGetAccumulationType(MlirAttribute attr) {
  return wrap(llvm::cast<mlir::stablehlo::DotAlgorithmAttr>(unwrap(attr))
                  .getAccumulationType());
}

int64_t stablehloDotAlgorithmGetLhsComponentCount(MlirAttribute attr) {
  return llvm::cast<mlir::stablehlo::DotAlgorithmAttr>(unwrap(attr))
      .getLhsComponentCount();
}

int64_t stablehloDotAlgorithmGetRhsComponentCount(MlirAttribute attr) {
  return llvm::cast<mlir::stablehlo::DotAlgorithmAttr>(unwrap(attr))
      .getRhsComponentCount();
}

int64_t stablehloDotAlgorithmGetNumPrimitiveOperations(MlirAttribute attr) {
  return llvm::cast<mlir::stablehlo::DotAlgorithmAttr>(unwrap(attr))
      .getNumPrimitiveOperations();
}

bool stablehloDotAlgorithmGetAllowImpreciseAccumulation(MlirAttribute attr) {
  return llvm::cast<mlir::stablehlo::DotAlgorithmAttr>(unwrap(attr))
      .getAllowImpreciseAccumulation();
}

//===----------------------------------------------------------------------===//
// DotDimensionNumbers
//===----------------------------------------------------------------------===//
33 changes: 33 additions & 0 deletions stablehlo/integrations/c/StablehloAttributes.h
@@ -113,6 +113,39 @@ MLIR_CAPI_EXPORTED int64_t stablehloGatherDimensionNumbersGetStartIndexMapElem(
MLIR_CAPI_EXPORTED int64_t
stablehloGatherDimensionNumbersGetIndexVectorDim(MlirAttribute attr);

//===----------------------------------------------------------------------===//
// DotAlgorithm
//===----------------------------------------------------------------------===//

MLIR_CAPI_EXPORTED MlirAttribute stablehloDotAlgorithmGet(
    MlirContext ctx, MlirType lhsPrecisionType, MlirType rhsPrecisionType,
    MlirType accumulationType, int64_t lhsComponentCount,
    int64_t rhsComponentCount, int64_t numPrimitiveOperations,
    bool allowImpreciseAccumulation);

MLIR_CAPI_EXPORTED bool stablehloAttributeIsADotAlgorithm(MlirAttribute attr);

MLIR_CAPI_EXPORTED MlirType
stablehloDotAlgorithmGetLhsPrecisionType(MlirAttribute attr);

MLIR_CAPI_EXPORTED MlirType
stablehloDotAlgorithmGetRhsPrecisionType(MlirAttribute attr);

MLIR_CAPI_EXPORTED MlirType
stablehloDotAlgorithmGetAccumulationType(MlirAttribute attr);

MLIR_CAPI_EXPORTED int64_t
stablehloDotAlgorithmGetLhsComponentCount(MlirAttribute attr);

MLIR_CAPI_EXPORTED int64_t
stablehloDotAlgorithmGetRhsComponentCount(MlirAttribute attr);

MLIR_CAPI_EXPORTED int64_t
stablehloDotAlgorithmGetNumPrimitiveOperations(MlirAttribute attr);

MLIR_CAPI_EXPORTED bool stablehloDotAlgorithmGetAllowImpreciseAccumulation(
    MlirAttribute attr);

//===----------------------------------------------------------------------===//
// DotDimensionNumbers
//===----------------------------------------------------------------------===//
56 changes: 56 additions & 0 deletions stablehlo/integrations/python/StablehloModule.cpp
@@ -220,6 +220,62 @@ PYBIND11_MODULE(_stablehlo, m) {
            return stablehloGatherDimensionNumbersGetIndexVectorDim(self);
          });

  mlir::python::adaptors::mlir_attribute_subclass(
      m, "DotAlgorithm", stablehloAttributeIsADotAlgorithm)
      .def_classmethod(
          "get",
          [](py::object cls, MlirType lhsPrecisionType,
             MlirType rhsPrecisionType, MlirType accumulationType,
             int64_t lhsComponentCount, int64_t rhsComponentCount,
             int64_t numPrimitiveOperations, bool allowImpreciseAccumulation,
             MlirContext ctx) {
            return cls(stablehloDotAlgorithmGet(
                ctx, lhsPrecisionType, rhsPrecisionType, accumulationType,
                lhsComponentCount, rhsComponentCount, numPrimitiveOperations,
                allowImpreciseAccumulation));
          },
          py::arg("cls"), py::arg("lhs_precision_type"),
          py::arg("rhs_precision_type"), py::arg("accumulation_type"),
          py::arg("lhs_component_count"), py::arg("rhs_component_count"),
          py::arg("num_primitive_operations"),
          py::arg("allow_imprecise_accumulation"), py::arg("ctx") = py::none(),
          "Creates a DotAlgorithm attribute with the given dimension "
          "configuration.")
      .def_property_readonly(
          "lhs_precision_type",
          [](MlirAttribute self) {
            return stablehloDotAlgorithmGetLhsPrecisionType(self);
          })
      .def_property_readonly(
          "rhs_precision_type",
          [](MlirAttribute self) {
            return stablehloDotAlgorithmGetRhsPrecisionType(self);
          })
      .def_property_readonly(
          "accumulation_type",
          [](MlirAttribute self) {
            return stablehloDotAlgorithmGetAccumulationType(self);
          })
      .def_property_readonly(
          "lhs_component_count",
          [](MlirAttribute self) {
            return stablehloDotAlgorithmGetLhsComponentCount(self);
          })
      .def_property_readonly(
          "rhs_component_count",
          [](MlirAttribute self) {
            return stablehloDotAlgorithmGetRhsComponentCount(self);
          })
      .def_property_readonly(
          "num_primitive_operations",
          [](MlirAttribute self) {
            return stablehloDotAlgorithmGetNumPrimitiveOperations(self);
          })
      .def_property_readonly(
          "allow_imprecise_accumulation", [](MlirAttribute self) {
            return stablehloDotAlgorithmGetAllowImpreciseAccumulation(self);
          });

  mlir::python::adaptors::mlir_attribute_subclass(
      m, "DotDimensionNumbers", stablehloAttributeIsADotDimensionNumbers)
      .def_classmethod(
26 changes: 26 additions & 0 deletions stablehlo/integrations/python/tests/stablehlo.py
@@ -82,6 +82,32 @@ def test_conv_dimension_numbers():
  assert attr.output_spatial_dimensions == [2, 3]


@run
def test_dot_algorithm():
  # BF16_BF16_F32_X3
  attr = stablehlo.DotAlgorithm.get(
      lhs_precision_type=ir.BF16Type.get(),
      rhs_precision_type=ir.BF16Type.get(),
      accumulation_type=ir.F32Type.get(),
      lhs_component_count=1,
      rhs_component_count=1,
      num_primitive_operations=3,
      allow_imprecise_accumulation=False)
  assert attr is not None
  assert str(attr) == ("#stablehlo.dot_algorithm<lhs_precision_type = bf16, "
                       "rhs_precision_type = bf16, accumulation_type = f32, "
                       "lhs_component_count = 1, rhs_component_count = 1, "
                       "num_primitive_operations = 3, "
                       "allow_imprecise_accumulation = false>")
  assert isinstance(attr.lhs_precision_type, ir.BF16Type)
  assert isinstance(attr.rhs_precision_type, ir.BF16Type)
  assert isinstance(attr.accumulation_type, ir.F32Type)
  assert attr.lhs_component_count == 1
  assert attr.rhs_component_count == 1
  assert attr.num_primitive_operations == 3
  assert attr.allow_imprecise_accumulation == False


@run
def test_dot_dimension_numbers():
  attr = stablehlo.DotDimensionNumbers.get(
