Skip to content

Commit

Permalink
Cranelift: Constant propagate floats (#8954)
Browse files Browse the repository at this point in the history
* const propagate `fadd`, `fsub`, `fmul`, `fdiv`

* add `sqrt`, `ceil`, `floor`, `trunc`, `nearest`

* todo

* bail if result is NaN

* add `fmin`, `fmax`

* `non_nan` helper methods

* explain why no const folding of NaNs

* use `f32`/`f64` `round_ties_even` methods

Those methods are stable since Rust version 1.77.0
  • Loading branch information
primoly authored Jul 19, 2024
1 parent 3d7a1c8 commit 542af68
Show file tree
Hide file tree
Showing 5 changed files with 502 additions and 33 deletions.
58 changes: 28 additions & 30 deletions cranelift/codegen/src/ir/immediates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,11 @@ impl Ieee32 {
self.as_f32().is_nan()
}

/// Returns `None` if `self` is a NaN and `Some(self)` otherwise.
pub fn non_nan(self) -> Option<Self> {
Some(self).filter(|f| !f.is_nan())
}

/// Converts Self to a rust f32
pub fn as_f32(self) -> f32 {
f32::from_bits(self.0)
Expand All @@ -927,17 +932,21 @@ impl Ieee32 {

/// Computes the absolute value of self.
pub fn abs(self) -> Self {
Self::with_float(self.as_f32().abs())
Self(self.0 & !(1u32 << 31))
}

/// Returns a number composed of the magnitude of self and the sign of sign.
pub fn copysign(self, sign: Self) -> Self {
Self::with_float(self.as_f32().copysign(sign.as_f32()))
if self.is_negative() == sign.is_negative() {
self
} else {
self.neg()
}
}

/// Returns true if self has a negative sign, including -0.0, NaNs with negative sign bit and negative infinity.
pub fn is_negative(&self) -> bool {
self.as_f32().is_sign_negative()
self.0 & (1 << 31) != 0
}

/// Returns true if self is positive or negative zero
Expand All @@ -963,17 +972,7 @@ impl Ieee32 {
/// Returns the nearest integer to `self`. Rounds half-way cases to the number
/// with an even least significant digit.
pub fn round_ties_even(self) -> Self {
// TODO: Replace with the native implementation once
// https://github.com/rust-lang/rust/issues/96710 is stabilized
let toint_32: f32 = 1.0 / f32::EPSILON;

let f = self.as_f32();
let e = self.0 >> 23 & 0xff;
if e >= 0x7f_u32 + 23 {
self
} else {
Self::with_float((f.abs() + toint_32 - toint_32).copysign(f))
}
Self::with_float(self.as_f32().round_ties_even())
}
}

Expand Down Expand Up @@ -1017,7 +1016,7 @@ impl Neg for Ieee32 {
type Output = Ieee32;

fn neg(self) -> Self::Output {
Self::with_float(self.as_f32().neg())
Self(self.0 ^ (1 << 31))
}
}

Expand Down Expand Up @@ -1133,6 +1132,11 @@ impl Ieee64 {
self.as_f64().is_nan()
}

/// Returns `None` if `self` is a NaN and `Some(self)` otherwise.
pub fn non_nan(self) -> Option<Self> {
Some(self).filter(|f| !f.is_nan())
}

/// Converts Self to a rust f64
pub fn as_f64(self) -> f64 {
f64::from_bits(self.0)
Expand All @@ -1145,17 +1149,21 @@ impl Ieee64 {

/// Computes the absolute value of self.
pub fn abs(self) -> Self {
Self::with_float(self.as_f64().abs())
Self(self.0 & !(1u64 << 63))
}

/// Returns a number composed of the magnitude of self and the sign of sign.
pub fn copysign(self, sign: Self) -> Self {
Self::with_float(self.as_f64().copysign(sign.as_f64()))
if self.is_negative() == sign.is_negative() {
self
} else {
self.neg()
}
}

/// Returns true if self has a negative sign, including -0.0, NaNs with negative sign bit and negative infinity.
pub fn is_negative(&self) -> bool {
self.as_f64().is_sign_negative()
self.0 & (1 << 63) != 0
}

/// Returns true if self is positive or negative zero
Expand All @@ -1181,17 +1189,7 @@ impl Ieee64 {
/// Returns the nearest integer to `self`. Rounds half-way cases to the number
/// with an even least significant digit.
pub fn round_ties_even(self) -> Self {
// TODO: Replace with the native implementation once
// https://github.com/rust-lang/rust/issues/96710 is stabilized
let toint_64: f64 = 1.0 / f64::EPSILON;

let f = self.as_f64();
let e = self.0 >> 52 & 0x7ff_u64;
if e >= 0x3ff_u64 + 52 {
self
} else {
Self::with_float((f.abs() + toint_64 - toint_64).copysign(f))
}
Self::with_float(self.as_f64().round_ties_even())
}
}

Expand Down Expand Up @@ -1241,7 +1239,7 @@ impl Neg for Ieee64 {
type Output = Ieee64;

fn neg(self) -> Self::Output {
Self::with_float(self.as_f64().neg())
Self(self.0 ^ (1 << 63))
}
}

Expand Down
112 changes: 112 additions & 0 deletions cranelift/codegen/src/isle_prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -950,6 +950,62 @@ macro_rules! isle_common_prelude_methods {
a.copysign(b)
}

fn f32_add(&mut self, lhs: Ieee32, rhs: Ieee32) -> Option<Ieee32> {
(lhs + rhs).non_nan()
}

fn f32_sub(&mut self, lhs: Ieee32, rhs: Ieee32) -> Option<Ieee32> {
(lhs - rhs).non_nan()
}

fn f32_mul(&mut self, lhs: Ieee32, rhs: Ieee32) -> Option<Ieee32> {
(lhs * rhs).non_nan()
}

fn f32_div(&mut self, lhs: Ieee32, rhs: Ieee32) -> Option<Ieee32> {
(lhs / rhs).non_nan()
}

fn f32_sqrt(&mut self, n: Ieee32) -> Option<Ieee32> {
n.sqrt().non_nan()
}

fn f32_ceil(&mut self, n: Ieee32) -> Option<Ieee32> {
n.ceil().non_nan()
}

fn f32_floor(&mut self, n: Ieee32) -> Option<Ieee32> {
n.floor().non_nan()
}

fn f32_trunc(&mut self, n: Ieee32) -> Option<Ieee32> {
n.trunc().non_nan()
}

fn f32_nearest(&mut self, n: Ieee32) -> Option<Ieee32> {
n.round_ties_even().non_nan()
}

fn f32_min(&mut self, a: Ieee32, b: Ieee32) -> Option<Ieee32> {
if a.is_nan() || b.is_nan() {
None
} else if a <= b {
Some(a)
} else {
Some(b)
}
}

fn f32_max(&mut self, a: Ieee32, b: Ieee32) -> Option<Ieee32> {
if a.is_nan() || b.is_nan() {
None
} else if a >= b {
Some(a)
} else {
Some(b)
}
}

fn f32_neg(&mut self, n: Ieee32) -> Ieee32 {
n.neg()
}
Expand All @@ -962,6 +1018,62 @@ macro_rules! isle_common_prelude_methods {
a.copysign(b)
}

fn f64_add(&mut self, lhs: Ieee64, rhs: Ieee64) -> Option<Ieee64> {
(lhs + rhs).non_nan()
}

fn f64_sub(&mut self, lhs: Ieee64, rhs: Ieee64) -> Option<Ieee64> {
(lhs - rhs).non_nan()
}

fn f64_mul(&mut self, lhs: Ieee64, rhs: Ieee64) -> Option<Ieee64> {
(lhs * rhs).non_nan()
}

fn f64_div(&mut self, lhs: Ieee64, rhs: Ieee64) -> Option<Ieee64> {
(lhs / rhs).non_nan()
}

fn f64_sqrt(&mut self, n: Ieee64) -> Option<Ieee64> {
n.sqrt().non_nan()
}

fn f64_ceil(&mut self, n: Ieee64) -> Option<Ieee64> {
n.ceil().non_nan()
}

fn f64_floor(&mut self, n: Ieee64) -> Option<Ieee64> {
n.floor().non_nan()
}

fn f64_trunc(&mut self, n: Ieee64) -> Option<Ieee64> {
n.trunc().non_nan()
}

fn f64_nearest(&mut self, n: Ieee64) -> Option<Ieee64> {
n.round_ties_even().non_nan()
}

fn f64_min(&mut self, a: Ieee64, b: Ieee64) -> Option<Ieee64> {
if a.is_nan() || b.is_nan() {
None
} else if a <= b {
Some(a)
} else {
Some(b)
}
}

fn f64_max(&mut self, a: Ieee64, b: Ieee64) -> Option<Ieee64> {
if a.is_nan() || b.is_nan() {
None
} else if a >= b {
Some(a)
} else {
Some(b)
}
}

fn f64_neg(&mut self, n: Ieee64) -> Ieee64 {
n.neg()
}
Expand Down
89 changes: 86 additions & 3 deletions cranelift/codegen/src/opts/cprop.isle
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,6 @@
(if-let $true (u64_lt (i64_as_u64 (i64_neg k)) (i64_as_u64 k)))
(iadd ty x (iconst ty (imm64_masked ty (i64_as_u64 (i64_neg k))))))

;; TODO: fadd, fsub, fmul, fdiv, fneg, fabs

;; A splat of a constant can become a direct `vconst` with the appropriate bit
;; pattern.
(rule (simplify (splat dst (iconst $I8 n)))
Expand Down Expand Up @@ -280,7 +278,92 @@
(decl pure u64_bswap64 (u64) u64)
(extern constructor u64_bswap64 u64_bswap64)

;; Constant fold bitwise float operations (fneg/fabs/fcopysign)
;; Constant fold float operations
;; Note: With the exception of fabs, fneg and copysign,
;; constant folding is only performed when the result of
;; an instruction isn't NaN. We want the NaN bit patterns
;; produced by an instruction to be consistent, and
;; compile-time evaluation in a cross-compilation scenario
;; risks producing different NaN bit patterns than the target
;; would have at run-time.
;; TODO: fcmp, fma, demote, promote, to-int ops
(rule (simplify (fadd $F32 (f32const $F32 lhs) (f32const $F32 rhs)))
(if-let r (f32_add lhs rhs))
(subsume (f32const $F32 r)))
(rule (simplify (fadd $F64 (f64const $F64 lhs) (f64const $F64 rhs)))
(if-let r (f64_add lhs rhs))
(subsume (f64const $F64 r)))

(rule (simplify (fsub $F32 (f32const $F32 lhs) (f32const $F32 rhs)))
(if-let r (f32_sub lhs rhs))
(subsume (f32const $F32 r)))
(rule (simplify (fsub $F64 (f64const $F64 lhs) (f64const $F64 rhs)))
(if-let r (f64_sub lhs rhs))
(subsume (f64const $F64 r)))

(rule (simplify (fmul $F32 (f32const $F32 lhs) (f32const $F32 rhs)))
(if-let r (f32_mul lhs rhs))
(subsume (f32const $F32 r)))
(rule (simplify (fmul $F64 (f64const $F64 lhs) (f64const $F64 rhs)))
(if-let r (f64_mul lhs rhs))
(subsume (f64const $F64 r)))

(rule (simplify (fdiv $F32 (f32const $F32 lhs) (f32const $F32 rhs)))
(if-let r (f32_div lhs rhs))
(subsume (f32const $F32 r)))
(rule (simplify (fdiv $F64 (f64const $F64 lhs) (f64const $F64 rhs)))
(if-let r (f64_div lhs rhs))
(subsume (f64const $F64 r)))

(rule (simplify (sqrt $F32 (f32const $F32 n)))
(if-let r (f32_sqrt n))
(subsume (f32const $F32 r)))
(rule (simplify (sqrt $F64 (f64const $F64 n)))
(if-let r (f64_sqrt n))
(subsume (f64const $F64 r)))

(rule (simplify (ceil $F32 (f32const $F32 n)))
(if-let r (f32_ceil n))
(subsume (f32const $F32 r)))
(rule (simplify (ceil $F64 (f64const $F64 n)))
(if-let r (f64_ceil n))
(subsume (f64const $F64 r)))

(rule (simplify (floor $F32 (f32const $F32 n)))
(if-let r (f32_floor n))
(subsume (f32const $F32 r)))
(rule (simplify (floor $F64 (f64const $F64 n)))
(if-let r (f64_floor n))
(subsume (f64const $F64 r)))

(rule (simplify (trunc $F32 (f32const $F32 n)))
(if-let r (f32_trunc n))
(subsume (f32const $F32 r)))
(rule (simplify (trunc $F64 (f64const $F64 n)))
(if-let r (f64_trunc n))
(subsume (f64const $F64 r)))

(rule (simplify (nearest $F32 (f32const $F32 n)))
(if-let r (f32_nearest n))
(subsume (f32const $F32 r)))
(rule (simplify (nearest $F64 (f64const $F64 n)))
(if-let r (f64_nearest n))
(subsume (f64const $F64 r)))

(rule (simplify (fmin $F32 (f32const $F32 n) (f32const $F32 m)))
(if-let r (f32_min n m))
(subsume (f32const $F32 r)))
(rule (simplify (fmin $F64 (f64const $F64 n) (f64const $F64 m)))
(if-let r (f64_min n m))
(subsume (f64const $F64 r)))

(rule (simplify (fmax $F32 (f32const $F32 n) (f32const $F32 m)))
(if-let r (f32_max n m))
(subsume (f32const $F32 r)))
(rule (simplify (fmax $F64 (f64const $F64 n) (f64const $F64 m)))
(if-let r (f64_max n m))
(subsume (f64const $F64 r)))

(rule (simplify (fneg $F16 (f16const $F16 n)))
(subsume (f16const $F16 (f16_neg n))))
(rule (simplify (fneg $F32 (f32const $F32 n)))
Expand Down
Loading

0 comments on commit 542af68

Please sign in to comment.