Skip to content

Commit

Permalink
[APFloat] Add support for f8E4M3 IEEE 754 type (#97179)
Browse files Browse the repository at this point in the history
This PR adds `f8E4M3` type to APFloat.

`f8E4M3` type  follows IEEE 754 convention

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 = −6
- Precision specifies the total number of bits used for the significand (mantisa), 
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
```

Related PRs:
- [PR-97118](#97118) Add f8E4M3
IEEE 754 type to mlir
  • Loading branch information
apivovarov committed Jul 18, 2024
1 parent 1e6672a commit f363317
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 3 deletions.
6 changes: 3 additions & 3 deletions clang/include/clang/AST/Stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -460,10 +460,10 @@ class alignas(void *) Stmt {
unsigned : NumExprBits;

static_assert(
llvm::APFloat::S_MaxSemantics < 16,
"Too many Semantics enum values to fit in bitfield of size 4");
llvm::APFloat::S_MaxSemantics < 32,
"Too many Semantics enum values to fit in bitfield of size 5");
LLVM_PREFERRED_TYPE(llvm::APFloat::Semantics)
unsigned Semantics : 4; // Provides semantics for APFloat construction
unsigned Semantics : 5; // Provides semantics for APFloat construction
LLVM_PREFERRED_TYPE(bool)
unsigned IsExact : 1;
};
Expand Down
1 change: 1 addition & 0 deletions clang/lib/AST/MicrosoftMangle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -981,6 +981,7 @@ void MicrosoftCXXNameMangler::mangleFloat(llvm::APFloat Number) {
case APFloat::S_IEEEquad: Out << 'Y'; break;
case APFloat::S_PPCDoubleDouble: Out << 'Z'; break;
case APFloat::S_Float8E5M2:
case APFloat::S_Float8E4M3:
case APFloat::S_Float8E4M3FN:
case APFloat::S_Float8E5M2FNUZ:
case APFloat::S_Float8E4M3FNUZ:
Expand Down
6 changes: 6 additions & 0 deletions llvm/include/llvm/ADT/APFloat.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,9 @@ struct APFloatBase {
// This format's exponent bias is 16, instead of the 15 (2 ** (5 - 1) - 1)
// that IEEE precedent would imply.
S_Float8E5M2FNUZ,
// 8-bit floating point number following IEEE-754 conventions with bit
// layout S1E4M3.
S_Float8E4M3,
// 8-bit floating point number mostly following IEEE-754 conventions with
// bit layout S1E4M3 as described in https://arxiv.org/abs/2209.05433.
// Unlike IEEE-754 types, there are no infinity values, and NaN is
Expand Down Expand Up @@ -217,6 +220,7 @@ struct APFloatBase {
static const fltSemantics &PPCDoubleDouble() LLVM_READNONE;
static const fltSemantics &Float8E5M2() LLVM_READNONE;
static const fltSemantics &Float8E5M2FNUZ() LLVM_READNONE;
static const fltSemantics &Float8E4M3() LLVM_READNONE;
static const fltSemantics &Float8E4M3FN() LLVM_READNONE;
static const fltSemantics &Float8E4M3FNUZ() LLVM_READNONE;
static const fltSemantics &Float8E4M3B11FNUZ() LLVM_READNONE;
Expand Down Expand Up @@ -638,6 +642,7 @@ class IEEEFloat final : public APFloatBase {
APInt convertPPCDoubleDoubleAPFloatToAPInt() const;
APInt convertFloat8E5M2APFloatToAPInt() const;
APInt convertFloat8E5M2FNUZAPFloatToAPInt() const;
APInt convertFloat8E4M3APFloatToAPInt() const;
APInt convertFloat8E4M3FNAPFloatToAPInt() const;
APInt convertFloat8E4M3FNUZAPFloatToAPInt() const;
APInt convertFloat8E4M3B11FNUZAPFloatToAPInt() const;
Expand All @@ -656,6 +661,7 @@ class IEEEFloat final : public APFloatBase {
void initFromPPCDoubleDoubleAPInt(const APInt &api);
void initFromFloat8E5M2APInt(const APInt &api);
void initFromFloat8E5M2FNUZAPInt(const APInt &api);
void initFromFloat8E4M3APInt(const APInt &api);
void initFromFloat8E4M3FNAPInt(const APInt &api);
void initFromFloat8E4M3FNUZAPInt(const APInt &api);
void initFromFloat8E4M3B11FNUZAPInt(const APInt &api);
Expand Down
20 changes: 20 additions & 0 deletions llvm/lib/Support/APFloat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ static constexpr fltSemantics semIEEEquad = {16383, -16382, 113, 128};
static constexpr fltSemantics semFloat8E5M2 = {15, -14, 3, 8};
static constexpr fltSemantics semFloat8E5M2FNUZ = {
15, -15, 3, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero};
static constexpr fltSemantics semFloat8E4M3 = {7, -6, 4, 8};
static constexpr fltSemantics semFloat8E4M3FN = {
8, -6, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::AllOnes};
static constexpr fltSemantics semFloat8E4M3FNUZ = {
Expand Down Expand Up @@ -208,6 +209,8 @@ const llvm::fltSemantics &APFloatBase::EnumToSemantics(Semantics S) {
return Float8E5M2();
case S_Float8E5M2FNUZ:
return Float8E5M2FNUZ();
case S_Float8E4M3:
return Float8E4M3();
case S_Float8E4M3FN:
return Float8E4M3FN();
case S_Float8E4M3FNUZ:
Expand Down Expand Up @@ -246,6 +249,8 @@ APFloatBase::SemanticsToEnum(const llvm::fltSemantics &Sem) {
return S_Float8E5M2;
else if (&Sem == &llvm::APFloat::Float8E5M2FNUZ())
return S_Float8E5M2FNUZ;
else if (&Sem == &llvm::APFloat::Float8E4M3())
return S_Float8E4M3;
else if (&Sem == &llvm::APFloat::Float8E4M3FN())
return S_Float8E4M3FN;
else if (&Sem == &llvm::APFloat::Float8E4M3FNUZ())
Expand Down Expand Up @@ -276,6 +281,7 @@ const fltSemantics &APFloatBase::PPCDoubleDouble() {
}
const fltSemantics &APFloatBase::Float8E5M2() { return semFloat8E5M2; }
const fltSemantics &APFloatBase::Float8E5M2FNUZ() { return semFloat8E5M2FNUZ; }
const fltSemantics &APFloatBase::Float8E4M3() { return semFloat8E4M3; }
const fltSemantics &APFloatBase::Float8E4M3FN() { return semFloat8E4M3FN; }
const fltSemantics &APFloatBase::Float8E4M3FNUZ() { return semFloat8E4M3FNUZ; }
const fltSemantics &APFloatBase::Float8E4M3B11FNUZ() {
Expand Down Expand Up @@ -3617,6 +3623,11 @@ APInt IEEEFloat::convertFloat8E5M2FNUZAPFloatToAPInt() const {
return convertIEEEFloatToAPInt<semFloat8E5M2FNUZ>();
}

APInt IEEEFloat::convertFloat8E4M3APFloatToAPInt() const {
assert(partCount() == 1);
return convertIEEEFloatToAPInt<semFloat8E4M3>();
}

APInt IEEEFloat::convertFloat8E4M3FNAPFloatToAPInt() const {
assert(partCount() == 1);
return convertIEEEFloatToAPInt<semFloat8E4M3FN>();
Expand Down Expand Up @@ -3681,6 +3692,9 @@ APInt IEEEFloat::bitcastToAPInt() const {
if (semantics == (const llvm::fltSemantics *)&semFloat8E5M2FNUZ)
return convertFloat8E5M2FNUZAPFloatToAPInt();

if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3)
return convertFloat8E4M3APFloatToAPInt();

if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3FN)
return convertFloat8E4M3FNAPFloatToAPInt();

Expand Down Expand Up @@ -3902,6 +3916,10 @@ void IEEEFloat::initFromFloat8E5M2FNUZAPInt(const APInt &api) {
initFromIEEEAPInt<semFloat8E5M2FNUZ>(api);
}

void IEEEFloat::initFromFloat8E4M3APInt(const APInt &api) {
initFromIEEEAPInt<semFloat8E4M3>(api);
}

void IEEEFloat::initFromFloat8E4M3FNAPInt(const APInt &api) {
initFromIEEEAPInt<semFloat8E4M3FN>(api);
}
Expand Down Expand Up @@ -3951,6 +3969,8 @@ void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
return initFromFloat8E5M2APInt(api);
if (Sem == &semFloat8E5M2FNUZ)
return initFromFloat8E5M2FNUZAPInt(api);
if (Sem == &semFloat8E4M3)
return initFromFloat8E4M3APInt(api);
if (Sem == &semFloat8E4M3FN)
return initFromFloat8E4M3FNAPInt(api);
if (Sem == &semFloat8E4M3FNUZ)
Expand Down
66 changes: 66 additions & 0 deletions llvm/unittests/ADT/APFloatTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2133,6 +2133,8 @@ TEST(APFloatTest, getZero) {
{&APFloat::Float8E5M2(), true, true, {0x80ULL, 0}, 1},
{&APFloat::Float8E5M2FNUZ(), false, false, {0, 0}, 1},
{&APFloat::Float8E5M2FNUZ(), true, false, {0, 0}, 1},
{&APFloat::Float8E4M3(), false, true, {0, 0}, 1},
{&APFloat::Float8E4M3(), true, true, {0x80ULL, 0}, 1},
{&APFloat::Float8E4M3FN(), false, true, {0, 0}, 1},
{&APFloat::Float8E4M3FN(), true, true, {0x80ULL, 0}, 1},
{&APFloat::Float8E4M3FNUZ(), false, false, {0, 0}, 1},
Expand Down Expand Up @@ -6532,6 +6534,34 @@ TEST(APFloatTest, Float8E5M2ToDouble) {
EXPECT_TRUE(std::isnan(QNaN.convertToDouble()));
}

TEST(APFloatTest, Float8E4M3ToDouble) {
APFloat One(APFloat::Float8E4M3(), "1.0");
EXPECT_EQ(1.0, One.convertToDouble());
APFloat Two(APFloat::Float8E4M3(), "2.0");
EXPECT_EQ(2.0, Two.convertToDouble());
APFloat PosLargest = APFloat::getLargest(APFloat::Float8E4M3(), false);
EXPECT_EQ(240.0F, PosLargest.convertToDouble());
APFloat NegLargest = APFloat::getLargest(APFloat::Float8E4M3(), true);
EXPECT_EQ(-240.0F, NegLargest.convertToDouble());
APFloat PosSmallest =
APFloat::getSmallestNormalized(APFloat::Float8E4M3(), false);
EXPECT_EQ(0x1.p-6, PosSmallest.convertToDouble());
APFloat NegSmallest =
APFloat::getSmallestNormalized(APFloat::Float8E4M3(), true);
EXPECT_EQ(-0x1.p-6, NegSmallest.convertToDouble());

APFloat SmallestDenorm = APFloat::getSmallest(APFloat::Float8E4M3(), false);
EXPECT_TRUE(SmallestDenorm.isDenormal());
EXPECT_EQ(0x1.p-9, SmallestDenorm.convertToDouble());

APFloat PosInf = APFloat::getInf(APFloat::Float8E4M3());
EXPECT_EQ(std::numeric_limits<double>::infinity(), PosInf.convertToDouble());
APFloat NegInf = APFloat::getInf(APFloat::Float8E4M3(), true);
EXPECT_EQ(-std::numeric_limits<double>::infinity(), NegInf.convertToDouble());
APFloat QNaN = APFloat::getQNaN(APFloat::Float8E4M3());
EXPECT_TRUE(std::isnan(QNaN.convertToDouble()));
}

TEST(APFloatTest, Float8E4M3FNToDouble) {
APFloat One(APFloat::Float8E4M3FN(), "1.0");
EXPECT_EQ(1.0, One.convertToDouble());
Expand Down Expand Up @@ -6846,6 +6876,42 @@ TEST(APFloatTest, Float8E5M2ToFloat) {
EXPECT_TRUE(std::isnan(QNaN.convertToFloat()));
}

TEST(APFloatTest, Float8E4M3ToFloat) {
APFloat PosZero = APFloat::getZero(APFloat::Float8E4M3());
APFloat PosZeroToFloat(PosZero.convertToFloat());
EXPECT_TRUE(PosZeroToFloat.isPosZero());
APFloat NegZero = APFloat::getZero(APFloat::Float8E4M3(), true);
APFloat NegZeroToFloat(NegZero.convertToFloat());
EXPECT_TRUE(NegZeroToFloat.isNegZero());

APFloat One(APFloat::Float8E4M3(), "1.0");
EXPECT_EQ(1.0F, One.convertToFloat());
APFloat Two(APFloat::Float8E4M3(), "2.0");
EXPECT_EQ(2.0F, Two.convertToFloat());

APFloat PosLargest = APFloat::getLargest(APFloat::Float8E4M3(), false);
EXPECT_EQ(240.0F, PosLargest.convertToFloat());
APFloat NegLargest = APFloat::getLargest(APFloat::Float8E4M3(), true);
EXPECT_EQ(-240.0F, NegLargest.convertToFloat());
APFloat PosSmallest =
APFloat::getSmallestNormalized(APFloat::Float8E4M3(), false);
EXPECT_EQ(0x1.p-6, PosSmallest.convertToFloat());
APFloat NegSmallest =
APFloat::getSmallestNormalized(APFloat::Float8E4M3(), true);
EXPECT_EQ(-0x1.p-6, NegSmallest.convertToFloat());

APFloat SmallestDenorm = APFloat::getSmallest(APFloat::Float8E4M3(), false);
EXPECT_TRUE(SmallestDenorm.isDenormal());
EXPECT_EQ(0x1.p-9, SmallestDenorm.convertToFloat());

APFloat PosInf = APFloat::getInf(APFloat::Float8E4M3());
EXPECT_EQ(std::numeric_limits<float>::infinity(), PosInf.convertToFloat());
APFloat NegInf = APFloat::getInf(APFloat::Float8E4M3(), true);
EXPECT_EQ(-std::numeric_limits<float>::infinity(), NegInf.convertToFloat());
APFloat QNaN = APFloat::getQNaN(APFloat::Float8E4M3());
EXPECT_TRUE(std::isnan(QNaN.convertToFloat()));
}

TEST(APFloatTest, Float8E4M3FNToFloat) {
APFloat PosZero = APFloat::getZero(APFloat::Float8E4M3FN());
APFloat PosZeroToFloat(PosZero.convertToFloat());
Expand Down

0 comments on commit f363317

Please sign in to comment.