Add f8E4M3 and f8E3M4 types support #2482

Merged 1 commit on Sep 4, 2024
6 changes: 4 additions & 2 deletions docs/spec.md
@@ -245,8 +245,8 @@ BooleanType ::= 'i1'
IntegerType ::= SignedIntegerType | UnsignedIntegerType
SignedIntegerType ::= 'si2' | 'si4' | 'si8' | 'si16' | 'si32' | 'si64'
UnsignedIntegerType ::= 'ui2' | 'ui4' | 'ui8' | 'ui16' | 'ui32' | 'ui64'
FloatType ::= 'f8E4M3FN' | 'f8E5M2' | 'f8E4M3FNUZ' | 'f8E5M2FNUZ'
| 'f8E4M3B11FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64'
FloatType ::= 'f8E3M4' | 'f8E4M3' | 'f8E4M3FN' | 'f8E4M3FNUZ' | 'f8E4M3B11FNUZ'
| 'f8E5M2' | 'f8E5M2FNUZ' | 'bf16' | 'f16' | 'f32' | 'f64'
TensorFloat32 ::= 'tf32'
ComplexType ::= 'complex' '<' ComplexElementType '>'
ComplexElementType ::= 'f32' | 'f64'
@@ -265,6 +265,8 @@ values of type `tensor<T>`).
inclusive, and unsigned `uiN` types represent integer values from `0` to
`2^N-1` inclusive.
* **Floating-point types** can be one of the following:
* `f8E3M4`, `f8E4M3` and `f8E5M2` 8-bit floating point numbers following
IEEE-754 conventions.
* `f8E4M3FN` and `f8E5M2` types corresponding to respectively the
`E4M3` and `E5M2` encodings of the FP8 format described in
[FP8 Formats for Deep Learning](https://arxiv.org/abs/2209.05433).
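As a note on what "IEEE-754 conventions" means for the two new entries above: unlike the existing FN/FNUZ variants, both `f8E3M4` and `f8E4M3` reserve an exponent pattern for infinities and NaNs, with `f8E3M4` trading range for precision (4 mantissa bits) and `f8E4M3` the opposite (4 exponent bits). The snippet below is a minimal, illustrative sketch — not part of this PR — that queries the `llvm::APFloat` semantics backing the new types (the same API the rest of this change uses).

```cpp
// Illustrative only: print headline characteristics of the two new FP8 formats.
#include <cstdio>
#include "llvm/ADT/APFloat.h"

using llvm::APFloat;

static void describe(const char *name, const llvm::fltSemantics &sem) {
  std::printf("%s: largest finite = %g, smallest normal = %g, has inf = %d\n",
              name, APFloat::getLargest(sem).convertToDouble(),
              APFloat::getSmallestNormalized(sem).convertToDouble(),
              APFloat::getInf(sem).isInfinity());
}

int main() {
  describe("f8E3M4", APFloat::Float8E3M4());  // 4 mantissa bits: finer precision, max 15.5
  describe("f8E4M3", APFloat::Float8E4M3());  // 4 exponent bits: wider range, max 240
  return 0;
}
```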
5 changes: 3 additions & 2 deletions stablehlo/dialect/Base.cpp
@@ -647,8 +647,9 @@ FailureOr<KnownDotAlgorithm> getKnownDotAlgorithm(
if (lhsComponentCount != 1 || rhsComponentCount != 1) return failure();

auto isAnyF8 = [](Type t) {
return llvm::isa<Float8E4M3FNType, Float8E5M2Type, Float8E4M3FNUZType,
Float8E4M3B11FNUZType, Float8E5M2FNUZType>(t);
return llvm::isa<Float8E3M4Type, Float8E4M3Type, Float8E4M3FNType,
Float8E5M2Type, Float8E4M3FNUZType, Float8E4M3B11FNUZType,
Float8E5M2FNUZType>(t);
};
if (isAnyF8(lhsPrecisionType) && isAnyF8(rhsPrecisionType) &&
accumulationType.isF32() && numPrimitiveOperations == 1) {
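For context, the predicate extended above is what lets a dot algorithm whose lhs/rhs precision types are any FP8 format, with `f32` accumulation and a single primitive operation, be recognized as a known algorithm. Below is a minimal standalone sketch of the same check; the `main` driver and includes are illustrative, not part of the PR.

```cpp
// Illustrative only: mirrors the extended isAnyF8 predicate from Base.cpp.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/Casting.h"

static bool isAnyF8(mlir::Type t) {
  return llvm::isa<mlir::Float8E3M4Type, mlir::Float8E4M3Type,
                   mlir::Float8E4M3FNType, mlir::Float8E5M2Type,
                   mlir::Float8E4M3FNUZType, mlir::Float8E4M3B11FNUZType,
                   mlir::Float8E5M2FNUZType>(t);
}

int main() {
  mlir::MLIRContext ctx;
  // The two types introduced in this change now pass the FP8 check.
  bool ok = isAnyF8(mlir::Float8E3M4Type::get(&ctx)) &&
            isAnyF8(mlir::Float8E4M3Type::get(&ctx));
  return ok ? 0 : 1;
}
```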
4 changes: 2 additions & 2 deletions stablehlo/dialect/Base.td
@@ -42,8 +42,8 @@ def HLO_SInt : SignlessIntOfWidths<[2, 4, 8, 16, 32, 64]>;
def HLO_UInt : UnsignedIntOfWidths<[2, 4, 8, 16, 32, 64]>;
def HLO_Int : AnyTypeOf<[HLO_SInt, HLO_UInt]>;

def HLO_Float : AnyTypeOf<[F8E4M3B11FNUZ, F8E4M3FN, F8E4M3FNUZ, F8E5M2,
F8E5M2FNUZ, F16, F32, F64, BF16]>;
def HLO_Float : AnyTypeOf<[F8E3M4, F8E4M3, F8E4M3FN, F8E4M3FNUZ, F8E4M3B11FNUZ,
F8E5M2, F8E5M2FNUZ, F16, F32, F64, BF16]>;
def HLO_Float32Or64 : AnyTypeOf<[F32, F64]>;

def HLO_Complex : Complex<AnyTypeOf<[F32, F64]>>;
2 changes: 1 addition & 1 deletion stablehlo/dialect/Version.h
@@ -38,7 +38,7 @@ class Version {
static FailureOr<Version> fromString(llvm::StringRef versionRef);

/// Return a Version representing the current VHLO dialect version.
static Version getCurrentVersion() { return Version(1, 6, 4); }
static Version getCurrentVersion() { return Version(1, 7, 0); }

/// Return a Version representing the minimum supported VHLO dialect version.
static Version getMinimumVersion() { return Version(0, 9, 0); }
24 changes: 23 additions & 1 deletion stablehlo/dialect/VhloBytecode.cpp
@@ -189,7 +189,7 @@ enum AttributeCode {
/// location is updated.
enum TypeCode {
// TO ADD TYPE: Add an enum value with doc string for new type.
// Next available code: 35
// Next available code: 37

/// BooleanV1Type {
/// }
Expand All @@ -216,6 +216,14 @@ enum TypeCode {
/// }
kFloatF64V1Type = 5,

/// FloatF8E3M4V1Type {
/// }
kFloatF8E3M4V1Type = 36,

/// FloatF8E4M3V1Type {
/// }
kFloatF8E4M3V1Type = 35,

/// FloatF8E4M3FNV1Type {
/// }
kFloatF8E4M3FNV1Type = 6,
@@ -698,9 +706,11 @@ const llvm::fltSemantics &getFloatSemantics(Type type) {
if (isa<FloatF16V1Type>(type)) return APFloat::IEEEhalf();
if (isa<FloatF32V1Type>(type)) return APFloat::IEEEsingle();
if (isa<FloatF64V1Type>(type)) return APFloat::IEEEdouble();
if (isa<FloatF8E3M4V1Type>(type)) return APFloat::Float8E3M4();
if (isa<FloatF8E4M3FNUZV1Type>(type)) return APFloat::Float8E4M3FNUZ();
if (isa<FloatF8E4M3B11FNUZV1Type>(type)) return APFloat::Float8E4M3B11FNUZ();
if (isa<FloatF8E4M3FNV1Type>(type)) return APFloat::Float8E4M3FN();
if (isa<FloatF8E4M3V1Type>(type)) return APFloat::Float8E4M3();
if (isa<FloatF8E5M2FNUZV1Type>(type)) return APFloat::Float8E5M2FNUZ();
if (isa<FloatF8E5M2V1Type>(type)) return APFloat::Float8E5M2();
if (isa<FloatTF32V1Type>(type)) return APFloat::FloatTF32();
@@ -968,6 +978,8 @@ Type VhloBytecodeInterface::readType(DialectBytecodeReader &reader) const {
return FloatF64V1Type::get(getContext());
case vhlo_encoding::kFloatF8E5M2V1Type:
return FloatF8E5M2V1Type::get(getContext());
case vhlo_encoding::kFloatF8E4M3V1Type:
return FloatF8E4M3V1Type::get(getContext());
case vhlo_encoding::kFloatF8E4M3FNV1Type:
return FloatF8E4M3FNV1Type::get(getContext());
case vhlo_encoding::kFloatF8E5M2FNUZV1Type:
@@ -976,6 +988,8 @@ Type VhloBytecodeInterface::readType(DialectBytecodeReader &reader) const {
return FloatF8E4M3FNUZV1Type::get(getContext());
case vhlo_encoding::kFloatF8E4M3B11FNUZV1Type:
return FloatF8E4M3B11FNUZV1Type::get(getContext());
case vhlo_encoding::kFloatF8E3M4V1Type:
return FloatF8E3M4V1Type::get(getContext());
case vhlo_encoding::kFloatTF32V1Type:
return FloatTF32V1Type::get(getContext());
case vhlo_encoding::kFunctionV1Type:
@@ -1060,6 +1074,14 @@ LogicalResult VhloBytecodeInterface::writeType(
LOG_WRITE_CALL;
return writer.writeVarInt(vhlo_encoding::kFloatF64V1Type), success();
})
.Case([&](FloatF8E3M4V1Type) {
LOG_WRITE_CALL;
return writer.writeVarInt(vhlo_encoding::kFloatF8E3M4V1Type), success();
})
.Case([&](FloatF8E4M3V1Type) {
LOG_WRITE_CALL;
return writer.writeVarInt(vhlo_encoding::kFloatF8E4M3V1Type), success();
})
.Case([&](FloatF8E4M3FNV1Type) {
LOG_WRITE_CALL;
return writer.writeVarInt(vhlo_encoding::kFloatF8E4M3FNV1Type),
1 change: 1 addition & 0 deletions stablehlo/dialect/VhloDialect.td
@@ -45,6 +45,7 @@ def VHLO_Dialect : Dialect {
1.4.0: Add `tan` op to StableHLO opset.
1.5.0: Make collective ops (`all_reduce`, `all_gather`, `all_to_all`) variadic.
1.6.0: Add DotAlgorithm specification to `dot_general`.
1.7.0: Introduce `f8E4M3` and `f8E3M4` types.
}];

let useDefaultAttributePrinterParser = 0;
12 changes: 12 additions & 0 deletions stablehlo/dialect/VhloTypes.cpp
@@ -84,6 +84,12 @@ void VhloTypeConverter::addBuiltinToVhloConversions() {
[&](Float32Type type) { return FloatF32V1Type::get(type.getContext()); });
addConversion(
[&](Float64Type type) { return FloatF64V1Type::get(type.getContext()); });
addConversion([&](Float8E3M4Type type) {
return FloatF8E3M4V1Type::get(type.getContext());
});
addConversion([&](Float8E4M3Type type) {
return FloatF8E4M3V1Type::get(type.getContext());
});
addConversion([&](Float8E4M3FNType type) {
return FloatF8E4M3FNV1Type::get(type.getContext());
});
@@ -176,6 +182,12 @@ void VhloTypeConverter::addVhloToBuiltinConversions() {
[&](FloatF32V1Type type) { return Float32Type::get(type.getContext()); });
addConversion(
[&](FloatF64V1Type type) { return Float64Type::get(type.getContext()); });
addConversion([&](FloatF8E3M4V1Type type) {
return Float8E3M4Type::get(type.getContext());
});
addConversion([&](FloatF8E4M3V1Type type) {
return Float8E4M3Type::get(type.getContext());
});
addConversion([&](FloatF8E4M3FNV1Type type) {
return Float8E4M3FNType::get(type.getContext());
});
6 changes: 6 additions & 0 deletions stablehlo/dialect/VhloTypes.td
@@ -79,6 +79,12 @@ def VHLO_FloatF32V1 : VHLO_TypeDef<"FloatF32V1", "f32_v1", "0.9.0", "current">;
// Corresponds to the 'f64' FloatType from the StableHLO spec.
def VHLO_FloatF64V1 : VHLO_TypeDef<"FloatF64V1","f64_v1", "0.9.0", "current">;

// Corresponds to the 'f8E3M4' FloatType from the StableHLO spec.
def VHLO_FloatF8E3M4V1 : VHLO_TypeDef<"FloatF8E3M4V1", "f8E3M4_v1", "1.7.0", "current">;

// Corresponds to the 'f8E4M3' FloatType from the StableHLO spec.
def VHLO_FloatF8E4M3V1 : VHLO_TypeDef<"FloatF8E4M3V1", "f8E4M3_v1", "1.7.0", "current">;

// Corresponds to the 'f8E4M3FN' FloatType from the StableHLO spec.
def VHLO_FloatF8E4M3FNV1 : VHLO_TypeDef<"FloatF8E4M3FNV1", "f8E4M3FN_v1", "0.9.0", "current">;

22 changes: 17 additions & 5 deletions stablehlo/reference/Tensor.cpp
@@ -118,11 +118,21 @@ Element Tensor::get(const Index &index) const {
getSizeInBytes(elementType) * flattenIndex(getShape(), index);

// Handle floating-point types.
if (elementType.isFloat8E3M4()) {
auto elementData = reinterpret_cast<const uint8_t *>(elementPtr);
return Element(elementType, APFloat(llvm::APFloatBase::Float8E3M4(),
APInt(8, *elementData)));
}
if (elementType.isFloat8E4M3B11FNUZ()) {
auto elementData = reinterpret_cast<const uint8_t *>(elementPtr);
return Element(elementType, APFloat(llvm::APFloatBase::Float8E4M3B11FNUZ(),
APInt(8, *elementData)));
}
if (elementType.isFloat8E4M3()) {
auto elementData = reinterpret_cast<const uint8_t *>(elementPtr);
return Element(elementType, APFloat(llvm::APFloatBase::Float8E4M3(),
APInt(8, *elementData)));
}
if (elementType.isFloat8E4M3FN()) {
auto elementData = reinterpret_cast<const uint8_t *>(elementPtr);
return Element(elementType, APFloat(llvm::APFloatBase::Float8E4M3FN(),
@@ -252,7 +262,8 @@ void Tensor::set(const Index &index, const Element &element) {
getSizeInBytes(elementType) * flattenIndex(getShape(), index);

// Handle floating-point types.
if (elementType.isFloat8E4M3B11FNUZ() || elementType.isFloat8E4M3FN() ||
if (elementType.isFloat8E3M4() || elementType.isFloat8E4M3B11FNUZ() ||
elementType.isFloat8E4M3() || elementType.isFloat8E4M3FN() ||
elementType.isFloat8E4M3FNUZ() || elementType.isFloat8E5M2() ||
elementType.isFloat8E5M2FNUZ()) {
auto elementData = reinterpret_cast<uint8_t *>(elementPtr);
@@ -446,17 +457,18 @@ Tensor makeTensor(DenseElementsAttr attr) {
auto elementType = type.getElementType();

// Handle floating-point types.
if (elementType.isFloat8E4M3B11FNUZ() || elementType.isFloat8E4M3FN() ||
if (elementType.isFloat8E3M4() || elementType.isFloat8E4M3B11FNUZ() ||
elementType.isFloat8E4M3() || elementType.isFloat8E4M3FN() ||
elementType.isFloat8E4M3FNUZ() || elementType.isFloat8E5M2() ||
elementType.isFloat8E5M2FNUZ()) {
auto floatValues = llvm::map_to_vector(
attr.getValues<APFloat>(), [&](APFloat value) -> uint8_t {
return value.bitcastToAPInt().getZExtValue();
});

// For f8E4M3B11FNUZ, f8E4M3FN, f8E4M3FNUZ, f8E5M2, and f8E5M2FNUZ
// floating-point types, we use uint8_t as their storage type because there
// are no builtin types for those.
// For f8E3M4, f8E4M3, f8E4M3FN, f8E4M3FNUZ, f8E4M3B11FNUZ, f8E5M2, and
// f8E5M2FNUZ floating-point types, we use uint8_t as their storage type
// because there are no builtin types for those.
return Tensor(type, HeapAsmResourceBlob::allocateAndCopyInferAlign<uint8_t>(
floatValues));
}
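The reference tensors store every FP8 element as a raw byte and rebuild the value by pairing that byte with the matching `APFloat` semantics, exactly as the `Tensor::get`/`Tensor::set` paths above now do for `f8E3M4` and `f8E4M3`. A minimal standalone sketch of that round trip (illustrative, assuming only the LLVM `APFloat` API):

```cpp
// Illustrative only: byte-level storage and reconstruction of an f8E3M4 value.
#include <cassert>
#include <cstdint>
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"

using llvm::APFloat;
using llvm::APInt;

int main() {
  // Encode 1.0 as f8E3M4 and keep only its 8-bit payload (what Tensor stores).
  APFloat value(APFloat::Float8E3M4(), "1.0");
  uint8_t storage =
      static_cast<uint8_t>(value.bitcastToAPInt().getZExtValue());

  // Decode: the stored byte plus the matching semantics reproduces the value,
  // mirroring what Tensor::get does for f8E3M4 above.
  APFloat decoded(APFloat::Float8E3M4(), APInt(8, storage));
  assert(decoded.bitwiseIsEqual(value));
  (void)decoded;
  return 0;
}
```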
3 changes: 2 additions & 1 deletion stablehlo/reference/Types.cpp
@@ -48,7 +48,8 @@ bool isSupportedIntegerType(Type type) {
}

bool isSupportedFloatType(Type type) {
return type.isFloat8E4M3B11FNUZ() || type.isFloat8E4M3FN() ||
return type.isFloat8E3M4() || type.isFloat8E4M3B11FNUZ() ||
type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
type.isFloat8E4M3FNUZ() || type.isFloat8E5M2() ||
type.isFloat8E5M2FNUZ() || type.isF16() || type.isBF16() ||
type.isF32() || type.isF64();
16 changes: 16 additions & 0 deletions stablehlo/tests/interpret/constant.mlir
@@ -96,6 +96,14 @@ func.func @constant_op_test_ui64() {

// -----

func.func @constant_op_test_f8_e3m4() {
%0 = stablehlo.constant dense<[0.0, -0.0, 1.0, 0.125, 0.1, 3.1415, 0x7F, 0xFF, 0x01, 0x81]> : tensor<10xf8E3M4>
check.expect_almost_eq_const %0, dense<[0.0, -0.0, 1.0, 0.125, 0.09375, 3.125, 0x7F, 0xFF, 0.015625, -0.015625]> : tensor<10xf8E3M4>
func.return
}

// -----

func.func @constant_op_test_f8_e4m3b11_fnuz() {
%0 = stablehlo.constant dense<[0.0, -0.0, 1.0, 0.125, 0.1, 3.1415, 0x7F, 0xFF, 0x01, 0x81]> : tensor<10xf8E4M3B11FNUZ>
check.expect_almost_eq_const %0, dense<[0.0, 0.0, 1.0, 0.125, 0.101563, 3.25, 30.0, -30.0, 0.00012207, -0.00012207]> : tensor<10xf8E4M3B11FNUZ>
@@ -104,6 +112,14 @@ func.func @constant_op_test_f8_e4m3b11_fnuz() {

// -----

func.func @constant_op_test_f8_e4m3() {
%0 = stablehlo.constant dense<[0.0, -0.0, 1.0, 0.125, 0.1, 3.1415, 0x7F, 0xFF, 0x01, 0x81]> : tensor<10xf8E4M3>
check.expect_almost_eq_const %0, dense<[0.0, -0.0, 1.0, 0.125, 0.1015630, 3.25, 0x7F, 0xFF, 0.001953130, -0.001953130]> : tensor<10xf8E4M3>
func.return
}

// -----

func.func @constant_op_test_f8_e4m3_fn() {
%0 = stablehlo.constant dense<[0.0, -0.0, 1.0, 0.125, 0.1, 3.1415, 0x7F, 0xFF, 0x01, 0x81]> : tensor<10xf8E4M3FN>
check.expect_almost_eq_const %0, dense<[0.0, -0.0, 1.0, 0.125, 0.1015630, 3.25, 0x7F, 0xFF, 0.001953130, -0.001953130]> : tensor<10xf8E4M3FN>
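The expected constants in the two new tests follow from round-to-nearest-even in each format: 0.1 falls into the subnormal range of `f8E3M4` (step 2^-6) and rounds to 0.09375, stays normal in `f8E4M3` and rounds to 0.1015625, while 3.1415 rounds to 3.125 and 3.25 respectively. A minimal sketch reproducing those roundings with `APFloat` (illustrative, not part of the PR):

```cpp
// Illustrative only: reproduce the f8E3M4 / f8E4M3 roundings checked above.
#include <cstdio>
#include "llvm/ADT/APFloat.h"

using llvm::APFloat;

// Round a double to the given 8-bit format, then widen back for printing.
static double roundTo(const llvm::fltSemantics &sem, double x) {
  bool losesInfo = false;
  APFloat v(x);
  v.convert(sem, APFloat::rmNearestTiesToEven, &losesInfo);
  v.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &losesInfo);
  return v.convertToDouble();
}

int main() {
  std::printf("f8E3M4(0.1)    = %g\n", roundTo(APFloat::Float8E3M4(), 0.1));     // 0.09375
  std::printf("f8E4M3(0.1)    = %g\n", roundTo(APFloat::Float8E4M3(), 0.1));     // 0.1015625
  std::printf("f8E3M4(3.1415) = %g\n", roundTo(APFloat::Float8E3M4(), 3.1415));  // 3.125
  std::printf("f8E4M3(3.1415) = %g\n", roundTo(APFloat::Float8E4M3(), 3.1415));  // 3.25
  return 0;
}
```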
23 changes: 23 additions & 0 deletions stablehlo/tests/interpret/dot_general.mlir
@@ -72,3 +72,26 @@ func.func @dot_general_op_test_different_operand_and_result_element_types() {
[[5.0, 6.0], [7.0, 8.0]]]> : tensor<2x2x2xf64>
func.return
}

// -----

func.func @dot_general_op_test_f8E3M4() {
%0 = stablehlo.constant dense<[0.0, 1.0, 2.0, 3.0]> : tensor<4xf8E3M4>
%result = stablehlo.dot_general %0, %0,
contracting_dims = [0] x [0]
: (tensor<4xf8E3M4>, tensor<4xf8E3M4>) -> tensor<f8E3M4>
check.expect_almost_eq_const %result, dense<14.0> : tensor<f8E3M4>
func.return
}

// -----

func.func @dot_general_op_test_f8E4M3() {
%0 = stablehlo.constant dense<[0.0, 1.0, 2.0, 3.0,
4.0, 5.0, 6.0, 7.0]> : tensor<8xf8E4M3>
%result = stablehlo.dot_general %0, %0,
contracting_dims = [0] x [0]
: (tensor<8xf8E4M3>, tensor<8xf8E4M3>) -> tensor<f8E4M3>
check.expect_almost_eq_const %result, dense<140.0> : tensor<f8E4M3>
func.return
}
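For reference, the expected constants follow directly from the contractions: the f8E3M4 test computes 0·0 + 1·1 + 2·2 + 3·3 = 14, and the f8E4M3 test computes 0² + 1² + … + 7² = 140; `check.expect_almost_eq_const` is used because the result still has to be rounded to the nearest value representable in the 8-bit result type.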