Skip to content

Commit

Permalink
Add f8E4M3 and f8E3M4 types support
Browse files Browse the repository at this point in the history
  • Loading branch information
apivovarov committed Aug 30, 2024
1 parent 54aa1a5 commit 356dc4b
Show file tree
Hide file tree
Showing 18 changed files with 3,108 additions and 74 deletions.
6 changes: 4 additions & 2 deletions docs/spec.md
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,8 @@ BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f8E4M3FN' | 'f8E5M2' | 'f8E4M3FNUZ' | 'f8E5M2FNUZ'
| 'f8E4M3B11FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64'
FloatType ::= 'f8E3M4' | 'f8E4M3' | 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ'
| 'f8E5M2' | 'f8E5M2FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'
Expand All @@ -265,6 +265,8 @@ values of type `tensor<T>`).
inclusive, and unsigned `uiN` types represent integer values from `0` to
`2^N-1` inclusive.
* **Floating-point types** can be one of the following:
* `f8E3M4`, `f8E4M3` and `f8E5M2` 8-bit floating point numbers following
IEEE-754 conventions.
* `f8E4M3FN` and `f8E5M2` types corresponding to respectively the
`E4M3` and `E5M2` encodings of the FP8 format described in
[FP8 Formats for Deep Learning](https://arxiv.org/abs/2209.05433).
Expand Down
5 changes: 3 additions & 2 deletions stablehlo/dialect/Base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -647,8 +647,9 @@ FailureOr<KnownDotAlgorithm> getKnownDotAlgorithm(
if (lhsComponentCount != 1 || rhsComponentCount != 1) return failure();

auto isAnyF8 = [](Type t) {
return llvm::isa<Float8E4M3FNType, Float8E5M2Type, Float8E4M3FNUZType,
Float8E4M3B11FNUZType, Float8E5M2FNUZType>(t);
return llvm::isa<Float8E3M4Type, Float8E4M3Type, Float8E4M3FNType,
Float8E5M2Type, Float8E4M3FNUZType, Float8E4M3B11FNUZType,
Float8E5M2FNUZType>(t);
};
if (isAnyF8(lhsPrecisionType) && isAnyF8(rhsPrecisionType) &&
accumulationType.isF32() && numPrimitiveOperations == 1) {
Expand Down
4 changes: 2 additions & 2 deletions stablehlo/dialect/Base.td
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def HLO_SInt : SignlessIntOfWidths<[2, 4, 8, 16, 32, 64]>;
def HLO_UInt : UnsignedIntOfWidths<[2, 4, 8, 16, 32, 64]>;
def HLO_Int : AnyTypeOf<[HLO_SInt, HLO_UInt]>;

def HLO_Float : AnyTypeOf<[F8E4M3B11FNUZ, F8E4M3FN, F8E4M3FNUZ, F8E5M2,
F8E5M2FNUZ, F16, F32, F64, BF16]>;
def HLO_Float : AnyTypeOf<[F8E3M4, F8E4M3, F8E4M3FN, F8E4M3FNUZ, F8E4M3B11FNUZ,
F8E5M2, F8E5M2FNUZ, F16, F32, F64, BF16]>;
def HLO_Float32Or64 : AnyTypeOf<[F32, F64]>;

def HLO_Complex : Complex<AnyTypeOf<[F32, F64]>>;
Expand Down
2 changes: 1 addition & 1 deletion stablehlo/dialect/Version.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class Version {
static FailureOr<Version> fromString(llvm::StringRef versionRef);

/// Return a Version representing the current VHLO dialect version.
static Version getCurrentVersion() { return Version(1, 6, 3); }
static Version getCurrentVersion() { return Version(1, 7, 0); }

/// Return a Version representing the minimum supported VHLO dialect version.
static Version getMinimumVersion() { return Version(0, 9, 0); }
Expand Down
22 changes: 22 additions & 0 deletions stablehlo/dialect/VhloBytecode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,14 @@ enum TypeCode {
/// }
kFloatF64V1Type = 5,

/// FloatF8E3M4V1Type {
/// }
kFloatF8E3M4V1Type = 36,

/// FloatF8E4M3V1Type {
/// }
kFloatF8E4M3V1Type = 35,

/// FloatF8E4M3FNV1Type {
/// }
kFloatF8E4M3FNV1Type = 6,
Expand Down Expand Up @@ -698,9 +706,11 @@ const llvm::fltSemantics &getFloatSemantics(Type type) {
if (isa<FloatF16V1Type>(type)) return APFloat::IEEEhalf();
if (isa<FloatF32V1Type>(type)) return APFloat::IEEEsingle();
if (isa<FloatF64V1Type>(type)) return APFloat::IEEEdouble();
if (isa<FloatF8E3M4V1Type>(type)) return APFloat::Float8E3M4();
if (isa<FloatF8E4M3FNUZV1Type>(type)) return APFloat::Float8E4M3FNUZ();
if (isa<FloatF8E4M3B11FNUZV1Type>(type)) return APFloat::Float8E4M3B11FNUZ();
if (isa<FloatF8E4M3FNV1Type>(type)) return APFloat::Float8E4M3FN();
if (isa<FloatF8E4M3V1Type>(type)) return APFloat::Float8E4M3();
if (isa<FloatF8E5M2FNUZV1Type>(type)) return APFloat::Float8E5M2FNUZ();
if (isa<FloatF8E5M2V1Type>(type)) return APFloat::Float8E5M2();
if (isa<FloatTF32V1Type>(type)) return APFloat::FloatTF32();
Expand Down Expand Up @@ -968,6 +978,8 @@ Type VhloBytecodeInterface::readType(DialectBytecodeReader &reader) const {
return FloatF64V1Type::get(getContext());
case vhlo_encoding::kFloatF8E5M2V1Type:
return FloatF8E5M2V1Type::get(getContext());
case vhlo_encoding::kFloatF8E4M3V1Type:
return FloatF8E4M3V1Type::get(getContext());
case vhlo_encoding::kFloatF8E4M3FNV1Type:
return FloatF8E4M3FNV1Type::get(getContext());
case vhlo_encoding::kFloatF8E5M2FNUZV1Type:
Expand All @@ -976,6 +988,8 @@ Type VhloBytecodeInterface::readType(DialectBytecodeReader &reader) const {
return FloatF8E4M3FNUZV1Type::get(getContext());
case vhlo_encoding::kFloatF8E4M3B11FNUZV1Type:
return FloatF8E4M3B11FNUZV1Type::get(getContext());
case vhlo_encoding::kFloatF8E3M4V1Type:
return FloatF8E3M4V1Type::get(getContext());
case vhlo_encoding::kFloatTF32V1Type:
return FloatTF32V1Type::get(getContext());
case vhlo_encoding::kFunctionV1Type:
Expand Down Expand Up @@ -1060,6 +1074,14 @@ LogicalResult VhloBytecodeInterface::writeType(
LOG_WRITE_CALL;
return writer.writeVarInt(vhlo_encoding::kFloatF64V1Type), success();
})
.Case([&](FloatF8E3M4V1Type) {
LOG_WRITE_CALL;
return writer.writeVarInt(vhlo_encoding::kFloatF8E3M4V1Type), success();
})
.Case([&](FloatF8E4M3V1Type) {
LOG_WRITE_CALL;
return writer.writeVarInt(vhlo_encoding::kFloatF8E4M3V1Type), success();
})
.Case([&](FloatF8E4M3FNV1Type) {
LOG_WRITE_CALL;
return writer.writeVarInt(vhlo_encoding::kFloatF8E4M3FNV1Type),
Expand Down
1 change: 1 addition & 0 deletions stablehlo/dialect/VhloDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def VHLO_Dialect : Dialect {
1.4.0: Add `tan` op to StableHLO opset.
1.5.0: Make collective ops (`all_reduce`, `all_gather`, `all_to_all`) variadic.
1.6.0: Add DotAlgorithm specificaiton to `dot_general`.
1.7.0: Introduce `f8E4M3` and `f8E3M4` types.
}];

let useDefaultAttributePrinterParser = 0;
Expand Down
12 changes: 12 additions & 0 deletions stablehlo/dialect/VhloTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ void VhloTypeConverter::addBuiltinToVhloConversions() {
[&](Float32Type type) { return FloatF32V1Type::get(type.getContext()); });
addConversion(
[&](Float64Type type) { return FloatF64V1Type::get(type.getContext()); });
addConversion([&](Float8E3M4Type type) {
return FloatF8E3M4V1Type::get(type.getContext());
});
addConversion([&](Float8E4M3Type type) {
return FloatF8E4M3V1Type::get(type.getContext());
});
addConversion([&](Float8E4M3FNType type) {
return FloatF8E4M3FNV1Type::get(type.getContext());
});
Expand Down Expand Up @@ -176,6 +182,12 @@ void VhloTypeConverter::addVhloToBuiltinConversions() {
[&](FloatF32V1Type type) { return Float32Type::get(type.getContext()); });
addConversion(
[&](FloatF64V1Type type) { return Float64Type::get(type.getContext()); });
addConversion([&](FloatF8E3M4V1Type type) {
return Float8E3M4Type::get(type.getContext());
});
addConversion([&](FloatF8E4M3V1Type type) {
return Float8E4M3Type::get(type.getContext());
});
addConversion([&](FloatF8E4M3FNV1Type type) {
return Float8E4M3FNType::get(type.getContext());
});
Expand Down
6 changes: 6 additions & 0 deletions stablehlo/dialect/VhloTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ def VHLO_FloatF32V1 : VHLO_TypeDef<"FloatF32V1", "f32_v1", "0.9.0", "current">;
// Corresponds to the 'f64' FloatType from the StableHLO spec.
def VHLO_FloatF64V1 : VHLO_TypeDef<"FloatF64V1","f64_v1", "0.9.0", "current">;

// Corresponds to the 'f8E3M4' FloatType from the StableHLO spec.
def VHLO_FloatF8E3M4V1 : VHLO_TypeDef<"FloatF8E3M4V1", "f8E3M4_v1", "1.7.0", "current">;

// Corresponds to the 'f8E4M3' FloatType from the StableHLO spec.
def VHLO_FloatF8E4M3V1 : VHLO_TypeDef<"FloatF8E4M3V1", "f8E4M3_v1", "1.7.0", "current">;

// Corresponds to the 'f8E4M3FN' FloatType from the StableHLO spec.
def VHLO_FloatF8E4M3FNV1 : VHLO_TypeDef<"FloatF8E4M3FNV1", "f8E4M3FN_v1", "0.9.0", "current">;

Expand Down
22 changes: 17 additions & 5 deletions stablehlo/reference/Tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,21 @@ Element Tensor::get(const Index &index) const {
getSizeInBytes(elementType) * flattenIndex(getShape(), index);

// Handle floating-point types.
if (elementType.isFloat8E3M4()) {
auto elementData = reinterpret_cast<const uint8_t *>(elementPtr);
return Element(elementType, APFloat(llvm::APFloatBase::Float8E3M4(),
APInt(8, *elementData)));
}
if (elementType.isFloat8E4M3B11FNUZ()) {
auto elementData = reinterpret_cast<const uint8_t *>(elementPtr);
return Element(elementType, APFloat(llvm::APFloatBase::Float8E4M3B11FNUZ(),
APInt(8, *elementData)));
}
if (elementType.isFloat8E4M3()) {
auto elementData = reinterpret_cast<const uint8_t *>(elementPtr);
return Element(elementType, APFloat(llvm::APFloatBase::Float8E4M3(),
APInt(8, *elementData)));
}
if (elementType.isFloat8E4M3FN()) {
auto elementData = reinterpret_cast<const uint8_t *>(elementPtr);
return Element(elementType, APFloat(llvm::APFloatBase::Float8E4M3FN(),
Expand Down Expand Up @@ -252,7 +262,8 @@ void Tensor::set(const Index &index, const Element &element) {
getSizeInBytes(elementType) * flattenIndex(getShape(), index);

// Handle floating-point types.
if (elementType.isFloat8E4M3B11FNUZ() || elementType.isFloat8E4M3FN() ||
if (elementType.isFloat8E3M4() || elementType.isFloat8E4M3B11FNUZ() ||
elementType.isFloat8E4M3() || elementType.isFloat8E4M3FN() ||
elementType.isFloat8E4M3FNUZ() || elementType.isFloat8E5M2() ||
elementType.isFloat8E5M2FNUZ()) {
auto elementData = reinterpret_cast<uint8_t *>(elementPtr);
Expand Down Expand Up @@ -446,17 +457,18 @@ Tensor makeTensor(DenseElementsAttr attr) {
auto elementType = type.getElementType();

// Handle floating-point types.
if (elementType.isFloat8E4M3B11FNUZ() || elementType.isFloat8E4M3FN() ||
if (elementType.isFloat8E3M4() || elementType.isFloat8E4M3B11FNUZ() ||
elementType.isFloat8E4M3() || elementType.isFloat8E4M3FN() ||
elementType.isFloat8E4M3FNUZ() || elementType.isFloat8E5M2() ||
elementType.isFloat8E5M2FNUZ()) {
auto floatValues = llvm::map_to_vector(
attr.getValues<APFloat>(), [&](APFloat value) -> uint8_t {
return value.bitcastToAPInt().getZExtValue();
});

// For f8E4M3B11FNUZ, f8E4M3FN, f8E4M3FNUZ, f8E5M2, and f8E5M2FNUZ
// floating-point types, we use uint8_t as their storage type because there
// are no builtin types for those.
// For f8E3M4, f8E4M3, f8E4M3FN, f8E4M3FNUZ, f8E4M3B11FNUZ, f8E5M2, and
// f8E5M2FNUZ floating-point types, we use uint8_t as their storage type
// because there are no builtin types for those.
return Tensor(type, HeapAsmResourceBlob::allocateAndCopyInferAlign<uint8_t>(
floatValues));
}
Expand Down
3 changes: 2 additions & 1 deletion stablehlo/reference/Types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ bool isSupportedIntegerType(Type type) {
}

bool isSupportedFloatType(Type type) {
return type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3FN() ||
return type.isFloat8E3M4() || type.isFloat8E4M3B11FNUZ() ||
type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
type.isFloat8E4M3FNUZ() || type.isFloat8E5M2() ||
type.isFloat8E5M2FNUZ() || type.isF16() || type.isBF16() ||
type.isF32() || type.isF64();
Expand Down
16 changes: 16 additions & 0 deletions stablehlo/tests/interpret/constant.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,14 @@ func.func @constant_op_test_ui64() {

// -----

func.func @constant_op_test_f8_e3m4() {
%0 = stablehlo.constant dense<[0.0, -0.0, 1.0, 0.125, 0.1, 3.1415, 0x7F, 0xFF, 0x01, 0x81]> : tensor<10xf8E3M4>
check.expect_almost_eq_const %0, dense<[0.0, -0.0, 1.0, 0.125, 0.09375, 3.125, 0x7F, 0xFF, 0.015625, -0.015625]> : tensor<10xf8E3M4>
func.return
}

// -----

func.func @constant_op_test_f8_e4m3b11_fnuz() {
%0 = stablehlo.constant dense<[0.0, -0.0, 1.0, 0.125, 0.1, 3.1415, 0x7F, 0xFF, 0x01, 0x81]> : tensor<10xf8E4M3B11FNUZ>
check.expect_almost_eq_const %0, dense<[0.0, 0.0, 1.0, 0.125, 0.101563, 3.25, 30.0, -30.0, 0.00012207, -0.00012207]> : tensor<10xf8E4M3B11FNUZ>
Expand All @@ -104,6 +112,14 @@ func.func @constant_op_test_f8_e4m3b11_fnuz() {

// -----

func.func @constant_op_test_f8_e4m3() {
%0 = stablehlo.constant dense<[0.0, -0.0, 1.0, 0.125, 0.1, 3.1415, 0x7F, 0xFF, 0x01, 0x81]> : tensor<10xf8E4M3>
check.expect_almost_eq_const %0, dense<[0.0, -0.0, 1.0, 0.125, 0.1015630, 3.25, 0x7F, 0xFF, 0.001953130, -0.001953130]> : tensor<10xf8E4M3>
func.return
}

// -----

func.func @constant_op_test_f8_e4m3_fn() {
%0 = stablehlo.constant dense<[0.0, -0.0, 1.0, 0.125, 0.1, 3.1415, 0x7F, 0xFF, 0x01, 0x81]> : tensor<10xf8E4M3FN>
check.expect_almost_eq_const %0, dense<[0.0, -0.0, 1.0, 0.125, 0.1015630, 3.25, 0x7F, 0xFF, 0.001953130, -0.001953130]> : tensor<10xf8E4M3FN>
Expand Down
23 changes: 23 additions & 0 deletions stablehlo/tests/interpret/dot_general.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,26 @@ func.func @dot_general_op_test_different_operand_and_result_element_types() {
[[5.0, 6.0], [7.0, 8.0]]]> : tensor<2x2x2xf64>
func.return
}

// -----

func.func @add_op_test_f8E3M4() {
%0 = stablehlo.constant dense<[0.0, 1.0, 2.0, 3.0]> : tensor<4xf8E3M4>
%result = stablehlo.dot_general %0, %0,
contracting_dims = [0] x [0]
: (tensor<4xf8E3M4>, tensor<4xf8E3M4>) -> tensor<f8E3M4>
check.expect_almost_eq_const %result, dense<14.0> : tensor<f8E3M4>
func.return
}

// -----

func.func @add_op_test_f8E4M3() {
%0 = stablehlo.constant dense<[0.0, 1.0, 2.0, 3.0,
4.0, 5.0, 6.0, 7.0]> : tensor<8xf8E4M3>
%result = stablehlo.dot_general %0, %0,
contracting_dims = [0] x [0]
: (tensor<8xf8E4M3>, tensor<8xf8E4M3>) -> tensor<f8E4M3>
check.expect_almost_eq_const %result, dense<140.0> : tensor<f8E4M3>
func.return
}
Loading

0 comments on commit 356dc4b

Please sign in to comment.