Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[APFloat] Add support for f8E3M4 IEEE 754 type #99698

Merged
merged 1 commit into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions clang/lib/AST/MicrosoftMangle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -986,6 +986,7 @@ void MicrosoftCXXNameMangler::mangleFloat(llvm::APFloat Number) {
case APFloat::S_Float8E5M2FNUZ:
case APFloat::S_Float8E4M3FNUZ:
case APFloat::S_Float8E4M3B11FNUZ:
case APFloat::S_Float8E3M4:
case APFloat::S_FloatTF32:
case APFloat::S_Float6E3M2FN:
case APFloat::S_Float6E2M3FN:
Expand Down
6 changes: 6 additions & 0 deletions llvm/include/llvm/ADT/APFloat.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,9 @@ struct APFloatBase {
// This format's exponent bias is 11, instead of the 7 (2 ** (4 - 1) - 1)
// that IEEE precedent would imply.
S_Float8E4M3B11FNUZ,
// 8-bit floating point number following IEEE-754 conventions with bit
// layout S1E3M4.
S_Float8E3M4,
// Floating point number that occupies 32 bits or less of storage, providing
// improved range compared to half (16-bit) formats, at (potentially)
// greater throughput than single precision (32-bit) formats.
Expand Down Expand Up @@ -224,6 +227,7 @@ struct APFloatBase {
static const fltSemantics &Float8E4M3FN() LLVM_READNONE;
static const fltSemantics &Float8E4M3FNUZ() LLVM_READNONE;
static const fltSemantics &Float8E4M3B11FNUZ() LLVM_READNONE;
static const fltSemantics &Float8E3M4() LLVM_READNONE;
static const fltSemantics &FloatTF32() LLVM_READNONE;
static const fltSemantics &Float6E3M2FN() LLVM_READNONE;
static const fltSemantics &Float6E2M3FN() LLVM_READNONE;
Expand Down Expand Up @@ -646,6 +650,7 @@ class IEEEFloat final : public APFloatBase {
APInt convertFloat8E4M3FNAPFloatToAPInt() const;
APInt convertFloat8E4M3FNUZAPFloatToAPInt() const;
APInt convertFloat8E4M3B11FNUZAPFloatToAPInt() const;
APInt convertFloat8E3M4APFloatToAPInt() const;
APInt convertFloatTF32APFloatToAPInt() const;
APInt convertFloat6E3M2FNAPFloatToAPInt() const;
APInt convertFloat6E2M3FNAPFloatToAPInt() const;
Expand All @@ -665,6 +670,7 @@ class IEEEFloat final : public APFloatBase {
void initFromFloat8E4M3FNAPInt(const APInt &api);
void initFromFloat8E4M3FNUZAPInt(const APInt &api);
void initFromFloat8E4M3B11FNUZAPInt(const APInt &api);
void initFromFloat8E3M4APInt(const APInt &api);
void initFromFloatTF32APInt(const APInt &api);
void initFromFloat6E3M2FNAPInt(const APInt &api);
void initFromFloat6E2M3FNAPInt(const APInt &api);
Expand Down
20 changes: 20 additions & 0 deletions llvm/lib/Support/APFloat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ static constexpr fltSemantics semFloat8E4M3FNUZ = {
7, -7, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero};
static constexpr fltSemantics semFloat8E4M3B11FNUZ = {
4, -10, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero};
static constexpr fltSemantics semFloat8E3M4 = {3, -2, 5, 8};
static constexpr fltSemantics semFloatTF32 = {127, -126, 11, 19};
static constexpr fltSemantics semFloat6E3M2FN = {
4, -2, 3, 6, fltNonfiniteBehavior::FiniteOnly};
Expand Down Expand Up @@ -217,6 +218,8 @@ const llvm::fltSemantics &APFloatBase::EnumToSemantics(Semantics S) {
return Float8E4M3FNUZ();
case S_Float8E4M3B11FNUZ:
return Float8E4M3B11FNUZ();
case S_Float8E3M4:
return Float8E3M4();
case S_FloatTF32:
return FloatTF32();
case S_Float6E3M2FN:
Expand Down Expand Up @@ -257,6 +260,8 @@ APFloatBase::SemanticsToEnum(const llvm::fltSemantics &Sem) {
return S_Float8E4M3FNUZ;
else if (&Sem == &llvm::APFloat::Float8E4M3B11FNUZ())
return S_Float8E4M3B11FNUZ;
else if (&Sem == &llvm::APFloat::Float8E3M4())
return S_Float8E3M4;
else if (&Sem == &llvm::APFloat::FloatTF32())
return S_FloatTF32;
else if (&Sem == &llvm::APFloat::Float6E3M2FN())
Expand Down Expand Up @@ -287,6 +292,7 @@ const fltSemantics &APFloatBase::Float8E4M3FNUZ() { return semFloat8E4M3FNUZ; }
const fltSemantics &APFloatBase::Float8E4M3B11FNUZ() {
return semFloat8E4M3B11FNUZ;
}
const fltSemantics &APFloatBase::Float8E3M4() { return semFloat8E3M4; }
const fltSemantics &APFloatBase::FloatTF32() { return semFloatTF32; }
const fltSemantics &APFloatBase::Float6E3M2FN() { return semFloat6E3M2FN; }
const fltSemantics &APFloatBase::Float6E2M3FN() { return semFloat6E2M3FN; }
Expand Down Expand Up @@ -3643,6 +3649,11 @@ APInt IEEEFloat::convertFloat8E4M3B11FNUZAPFloatToAPInt() const {
return convertIEEEFloatToAPInt<semFloat8E4M3B11FNUZ>();
}

APInt IEEEFloat::convertFloat8E3M4APFloatToAPInt() const {
assert(partCount() == 1);
return convertIEEEFloatToAPInt<semFloat8E3M4>();
}

APInt IEEEFloat::convertFloatTF32APFloatToAPInt() const {
assert(partCount() == 1);
return convertIEEEFloatToAPInt<semFloatTF32>();
Expand Down Expand Up @@ -3704,6 +3715,9 @@ APInt IEEEFloat::bitcastToAPInt() const {
if (semantics == (const llvm::fltSemantics *)&semFloat8E4M3B11FNUZ)
return convertFloat8E4M3B11FNUZAPFloatToAPInt();

if (semantics == (const llvm::fltSemantics *)&semFloat8E3M4)
return convertFloat8E3M4APFloatToAPInt();

if (semantics == (const llvm::fltSemantics *)&semFloatTF32)
return convertFloatTF32APFloatToAPInt();

Expand Down Expand Up @@ -3932,6 +3946,10 @@ void IEEEFloat::initFromFloat8E4M3B11FNUZAPInt(const APInt &api) {
initFromIEEEAPInt<semFloat8E4M3B11FNUZ>(api);
}

void IEEEFloat::initFromFloat8E3M4APInt(const APInt &api) {
initFromIEEEAPInt<semFloat8E3M4>(api);
}

void IEEEFloat::initFromFloatTF32APInt(const APInt &api) {
initFromIEEEAPInt<semFloatTF32>(api);
}
Expand Down Expand Up @@ -3977,6 +3995,8 @@ void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
return initFromFloat8E4M3FNUZAPInt(api);
if (Sem == &semFloat8E4M3B11FNUZ)
return initFromFloat8E4M3B11FNUZAPInt(api);
if (Sem == &semFloat8E3M4)
return initFromFloat8E3M4APInt(api);
if (Sem == &semFloatTF32)
return initFromFloatTF32APInt(api);
if (Sem == &semFloat6E3M2FN)
Expand Down
81 changes: 81 additions & 0 deletions llvm/unittests/ADT/APFloatTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2141,6 +2141,8 @@ TEST(APFloatTest, getZero) {
{&APFloat::Float8E4M3FNUZ(), true, false, {0, 0}, 1},
{&APFloat::Float8E4M3B11FNUZ(), false, false, {0, 0}, 1},
{&APFloat::Float8E4M3B11FNUZ(), true, false, {0, 0}, 1},
{&APFloat::Float8E3M4(), false, true, {0, 0}, 1},
{&APFloat::Float8E3M4(), true, true, {0x80ULL, 0}, 1},
{&APFloat::FloatTF32(), false, true, {0, 0}, 1},
{&APFloat::FloatTF32(), true, true, {0x40000ULL, 0}, 1},
{&APFloat::Float6E3M2FN(), false, true, {0, 0}, 1},
Expand Down Expand Up @@ -6636,6 +6638,45 @@ TEST(APFloatTest, Float8E4M3FNUZToDouble) {
EXPECT_TRUE(std::isnan(QNaN.convertToDouble()));
}

TEST(APFloatTest, Float8E3M4ToDouble) {
APFloat PosZero = APFloat::getZero(APFloat::Float8E3M4(), false);
APFloat PosZeroToDouble(PosZero.convertToDouble());
EXPECT_TRUE(PosZeroToDouble.isPosZero());
APFloat NegZero = APFloat::getZero(APFloat::Float8E3M4(), true);
APFloat NegZeroToDouble(NegZero.convertToDouble());
EXPECT_TRUE(NegZeroToDouble.isNegZero());

APFloat One(APFloat::Float8E3M4(), "1.0");
EXPECT_EQ(1.0, One.convertToDouble());
APFloat Two(APFloat::Float8E3M4(), "2.0");
EXPECT_EQ(2.0, Two.convertToDouble());
APFloat PosLargest = APFloat::getLargest(APFloat::Float8E3M4(), false);
EXPECT_EQ(15.5F, PosLargest.convertToDouble());
APFloat NegLargest = APFloat::getLargest(APFloat::Float8E3M4(), true);
EXPECT_EQ(-15.5F, NegLargest.convertToDouble());
APFloat PosSmallest =
APFloat::getSmallestNormalized(APFloat::Float8E3M4(), false);
EXPECT_EQ(0x1.p-2, PosSmallest.convertToDouble());
APFloat NegSmallest =
APFloat::getSmallestNormalized(APFloat::Float8E3M4(), true);
EXPECT_EQ(-0x1.p-2, NegSmallest.convertToDouble());

APFloat PosSmallestDenorm =
APFloat::getSmallest(APFloat::Float8E3M4(), false);
EXPECT_TRUE(PosSmallestDenorm.isDenormal());
EXPECT_EQ(0x1.p-6, PosSmallestDenorm.convertToDouble());
APFloat NegSmallestDenorm = APFloat::getSmallest(APFloat::Float8E3M4(), true);
EXPECT_TRUE(NegSmallestDenorm.isDenormal());
EXPECT_EQ(-0x1.p-6, NegSmallestDenorm.convertToDouble());

APFloat PosInf = APFloat::getInf(APFloat::Float8E3M4());
EXPECT_EQ(std::numeric_limits<double>::infinity(), PosInf.convertToDouble());
APFloat NegInf = APFloat::getInf(APFloat::Float8E3M4(), true);
EXPECT_EQ(-std::numeric_limits<double>::infinity(), NegInf.convertToDouble());
APFloat QNaN = APFloat::getQNaN(APFloat::Float8E3M4());
EXPECT_TRUE(std::isnan(QNaN.convertToDouble()));
}

TEST(APFloatTest, FloatTF32ToDouble) {
APFloat One(APFloat::FloatTF32(), "1.0");
EXPECT_EQ(1.0, One.convertToDouble());
Expand Down Expand Up @@ -6944,6 +6985,46 @@ TEST(APFloatTest, Float8E4M3FNToFloat) {
EXPECT_TRUE(std::isnan(QNaN.convertToFloat()));
}

TEST(APFloatTest, Float8E3M4ToFloat) {
APFloat PosZero = APFloat::getZero(APFloat::Float8E3M4(), false);
APFloat PosZeroToFloat(PosZero.convertToFloat());
EXPECT_TRUE(PosZeroToFloat.isPosZero());
APFloat NegZero = APFloat::getZero(APFloat::Float8E3M4(), true);
APFloat NegZeroToFloat(NegZero.convertToFloat());
EXPECT_TRUE(NegZeroToFloat.isNegZero());

APFloat One(APFloat::Float8E3M4(), "1.0");
EXPECT_EQ(1.0F, One.convertToFloat());
APFloat Two(APFloat::Float8E3M4(), "2.0");
EXPECT_EQ(2.0F, Two.convertToFloat());

APFloat PosLargest = APFloat::getLargest(APFloat::Float8E3M4(), false);
EXPECT_EQ(15.5F, PosLargest.convertToFloat());
APFloat NegLargest = APFloat::getLargest(APFloat::Float8E3M4(), true);
EXPECT_EQ(-15.5F, NegLargest.convertToFloat());
APFloat PosSmallest =
APFloat::getSmallestNormalized(APFloat::Float8E3M4(), false);
EXPECT_EQ(0x1.p-2, PosSmallest.convertToFloat());
APFloat NegSmallest =
APFloat::getSmallestNormalized(APFloat::Float8E3M4(), true);
EXPECT_EQ(-0x1.p-2, NegSmallest.convertToFloat());

APFloat PosSmallestDenorm =
APFloat::getSmallest(APFloat::Float8E3M4(), false);
EXPECT_TRUE(PosSmallestDenorm.isDenormal());
EXPECT_EQ(0x1.p-6, PosSmallestDenorm.convertToFloat());
APFloat NegSmallestDenorm = APFloat::getSmallest(APFloat::Float8E3M4(), true);
EXPECT_TRUE(NegSmallestDenorm.isDenormal());
EXPECT_EQ(-0x1.p-6, NegSmallestDenorm.convertToFloat());

APFloat PosInf = APFloat::getInf(APFloat::Float8E3M4());
EXPECT_EQ(std::numeric_limits<float>::infinity(), PosInf.convertToFloat());
APFloat NegInf = APFloat::getInf(APFloat::Float8E3M4(), true);
EXPECT_EQ(-std::numeric_limits<float>::infinity(), NegInf.convertToFloat());
APFloat QNaN = APFloat::getQNaN(APFloat::Float8E3M4());
EXPECT_TRUE(std::isnan(QNaN.convertToFloat()));
}

TEST(APFloatTest, FloatTF32ToFloat) {
APFloat PosZero = APFloat::getZero(APFloat::FloatTF32());
APFloat PosZeroToFloat(PosZero.convertToFloat());
Expand Down
Loading