From 542af68d10a02460dc0c5f8e5e75ae3256a22cbc Mon Sep 17 00:00:00 2001
From: primoly <168267431+primoly@users.noreply.github.com>
Date: Fri, 19 Jul 2024 20:44:19 +0200
Subject: [PATCH] Cranelift: Constant propagate floats (#8954)

* const propagate `fadd`, `fsub`, `fmul`, `fdiv`

* add `sqrt`, `ceil`, `floor`, `trunc`, `nearest`

* todo

* bail if result is NaN

* add `fmin`, `fmax`

* `non_nan` helper methods

* explain why no const folding of NaNs

* use `f32`/`f64` `round_ties_even` methods

These methods have been stable since Rust 1.77.0
---
 cranelift/codegen/src/ir/immediates.rs        |  58 +++--
 cranelift/codegen/src/isle_prelude.rs         | 112 +++++++++
 cranelift/codegen/src/opts/cprop.isle         |  89 ++++++-
 cranelift/codegen/src/prelude.isle            |  44 ++++
 .../filetests/filetests/egraph/cprop.clif     | 232 ++++++++++++++++++
 5 files changed, 502 insertions(+), 33 deletions(-)

diff --git a/cranelift/codegen/src/ir/immediates.rs b/cranelift/codegen/src/ir/immediates.rs
index ffc212e41f7b..24982c644c95 100644
--- a/cranelift/codegen/src/ir/immediates.rs
+++ b/cranelift/codegen/src/ir/immediates.rs
@@ -915,6 +915,11 @@ impl Ieee32 {
         self.as_f32().is_nan()
     }
 
+    /// Returns `None` if `self` is a NaN and `Some(self)` otherwise.
+    pub fn non_nan(self) -> Option<Self> {
+        Some(self).filter(|f| !f.is_nan())
+    }
+
     /// Converts Self to a rust f32
     pub fn as_f32(self) -> f32 {
         f32::from_bits(self.0)
@@ -927,17 +932,21 @@ impl Ieee32 {
 
     /// Computes the absolute value of self.
     pub fn abs(self) -> Self {
-        Self::with_float(self.as_f32().abs())
+        Self(self.0 & !(1u32 << 31))
     }
 
     /// Returns a number composed of the magnitude of self and the sign of sign.
     pub fn copysign(self, sign: Self) -> Self {
-        Self::with_float(self.as_f32().copysign(sign.as_f32()))
+        if self.is_negative() == sign.is_negative() {
+            self
+        } else {
+            self.neg()
+        }
     }
 
     /// Returns true if self has a negative sign, including -0.0, NaNs with negative sign bit and negative infinity.
     pub fn is_negative(&self) -> bool {
-        self.as_f32().is_sign_negative()
+        self.0 & (1 << 31) != 0
     }
 
     /// Returns true if self is positive or negative zero
@@ -963,17 +972,7 @@ impl Ieee32 {
     /// Returns the nearest integer to `self`. Rounds half-way cases to the number
     /// with an even least significant digit.
     pub fn round_ties_even(self) -> Self {
-        // TODO: Replace with the native implementation once
-        // https://github.com/rust-lang/rust/issues/96710 is stabilized
-        let toint_32: f32 = 1.0 / f32::EPSILON;
-
-        let f = self.as_f32();
-        let e = self.0 >> 23 & 0xff;
-        if e >= 0x7f_u32 + 23 {
-            self
-        } else {
-            Self::with_float((f.abs() + toint_32 - toint_32).copysign(f))
-        }
+        Self::with_float(self.as_f32().round_ties_even())
     }
 }
 
@@ -1017,7 +1016,7 @@ impl Neg for Ieee32 {
     type Output = Ieee32;
 
     fn neg(self) -> Self::Output {
-        Self::with_float(self.as_f32().neg())
+        Self(self.0 ^ (1 << 31))
     }
 }
 
@@ -1133,6 +1132,11 @@ impl Ieee64 {
         self.as_f64().is_nan()
     }
 
+    /// Returns `None` if `self` is a NaN and `Some(self)` otherwise.
+    pub fn non_nan(self) -> Option<Self> {
+        Some(self).filter(|f| !f.is_nan())
+    }
+
     /// Converts Self to a rust f64
     pub fn as_f64(self) -> f64 {
         f64::from_bits(self.0)
@@ -1145,17 +1149,21 @@ impl Ieee64 {
 
     /// Computes the absolute value of self.
    pub fn abs(self) -> Self {
-        Self::with_float(self.as_f64().abs())
+        Self(self.0 & !(1u64 << 63))
    }
 
     /// Returns a number composed of the magnitude of self and the sign of sign.
     pub fn copysign(self, sign: Self) -> Self {
-        Self::with_float(self.as_f64().copysign(sign.as_f64()))
+        if self.is_negative() == sign.is_negative() {
+            self
+        } else {
+            self.neg()
+        }
     }
 
     /// Returns true if self has a negative sign, including -0.0, NaNs with negative sign bit and negative infinity.
     pub fn is_negative(&self) -> bool {
-        self.as_f64().is_sign_negative()
+        self.0 & (1 << 63) != 0
     }
 
     /// Returns true if self is positive or negative zero
@@ -1181,17 +1189,7 @@ impl Ieee64 {
     /// Returns the nearest integer to `self`. Rounds half-way cases to the number
     /// with an even least significant digit.
     pub fn round_ties_even(self) -> Self {
-        // TODO: Replace with the native implementation once
-        // https://github.com/rust-lang/rust/issues/96710 is stabilized
-        let toint_64: f64 = 1.0 / f64::EPSILON;
-
-        let f = self.as_f64();
-        let e = self.0 >> 52 & 0x7ff_u64;
-        if e >= 0x3ff_u64 + 52 {
-            self
-        } else {
-            Self::with_float((f.abs() + toint_64 - toint_64).copysign(f))
-        }
+        Self::with_float(self.as_f64().round_ties_even())
     }
 }
 
@@ -1241,7 +1239,7 @@ impl Neg for Ieee64 {
     type Output = Ieee64;
 
     fn neg(self) -> Self::Output {
-        Self::with_float(self.as_f64().neg())
+        Self(self.0 ^ (1 << 63))
     }
 }
 
diff --git a/cranelift/codegen/src/isle_prelude.rs b/cranelift/codegen/src/isle_prelude.rs
index f36c4bdb0a12..530cfb139b37 100644
--- a/cranelift/codegen/src/isle_prelude.rs
+++ b/cranelift/codegen/src/isle_prelude.rs
@@ -950,6 +950,62 @@ macro_rules! isle_common_prelude_methods {
             a.copysign(b)
         }
 
+        fn f32_add(&mut self, lhs: Ieee32, rhs: Ieee32) -> Option<Ieee32> {
+            (lhs + rhs).non_nan()
+        }
+
+        fn f32_sub(&mut self, lhs: Ieee32, rhs: Ieee32) -> Option<Ieee32> {
+            (lhs - rhs).non_nan()
+        }
+
+        fn f32_mul(&mut self, lhs: Ieee32, rhs: Ieee32) -> Option<Ieee32> {
+            (lhs * rhs).non_nan()
+        }
+
+        fn f32_div(&mut self, lhs: Ieee32, rhs: Ieee32) -> Option<Ieee32> {
+            (lhs / rhs).non_nan()
+        }
+
+        fn f32_sqrt(&mut self, n: Ieee32) -> Option<Ieee32> {
+            n.sqrt().non_nan()
+        }
+
+        fn f32_ceil(&mut self, n: Ieee32) -> Option<Ieee32> {
+            n.ceil().non_nan()
+        }
+
+        fn f32_floor(&mut self, n: Ieee32) -> Option<Ieee32> {
+            n.floor().non_nan()
+        }
+
+        fn f32_trunc(&mut self, n: Ieee32) -> Option<Ieee32> {
+            n.trunc().non_nan()
+        }
+
+        fn f32_nearest(&mut self, n: Ieee32) -> Option<Ieee32> {
+            n.round_ties_even().non_nan()
+        }
+
+        fn f32_min(&mut self, a: Ieee32, b: Ieee32) -> Option<Ieee32> {
+            if a.is_nan() || b.is_nan() {
+                None
+            } else if a.is_zero() && b.is_zero() {
+                // `<=` treats -0.0 and +0.0 as equal, but `fmin` must
+                // return -0.0 when the zeros differ in sign.
+                Some(if a.is_negative() { a } else { b })
+            } else if a <= b {
+                Some(a)
+            } else {
+                Some(b)
+            }
+        }
+
+        fn f32_max(&mut self, a: Ieee32, b: Ieee32) -> Option<Ieee32> {
+            if a.is_nan() || b.is_nan() {
+                None
+            } else if a.is_zero() && b.is_zero() {
+                // Symmetrically, `fmax` must return +0.0 when the zeros
+                // differ in sign.
+                Some(if a.is_negative() { b } else { a })
+            } else if a >= b {
+                Some(a)
+            } else {
+                Some(b)
+            }
+        }
+
         fn f32_neg(&mut self, n: Ieee32) -> Ieee32 {
             n.neg()
         }
@@ -962,6 +1018,62 @@ macro_rules! isle_common_prelude_methods {
             a.copysign(b)
         }
 
+        fn f64_add(&mut self, lhs: Ieee64, rhs: Ieee64) -> Option<Ieee64> {
+            (lhs + rhs).non_nan()
+        }
+
+        fn f64_sub(&mut self, lhs: Ieee64, rhs: Ieee64) -> Option<Ieee64> {
+            (lhs - rhs).non_nan()
+        }
+
+        fn f64_mul(&mut self, lhs: Ieee64, rhs: Ieee64) -> Option<Ieee64> {
+            (lhs * rhs).non_nan()
+        }
+
+        fn f64_div(&mut self, lhs: Ieee64, rhs: Ieee64) -> Option<Ieee64> {
+            (lhs / rhs).non_nan()
+        }
+
+        fn f64_sqrt(&mut self, n: Ieee64) -> Option<Ieee64> {
+            n.sqrt().non_nan()
+        }
+
+        fn f64_ceil(&mut self, n: Ieee64) -> Option<Ieee64> {
+            n.ceil().non_nan()
+        }
+
+        fn f64_floor(&mut self, n: Ieee64) -> Option<Ieee64> {
+            n.floor().non_nan()
+        }
+
+        fn f64_trunc(&mut self, n: Ieee64) -> Option<Ieee64> {
+            n.trunc().non_nan()
+        }
+
+        fn f64_nearest(&mut self, n: Ieee64) -> Option<Ieee64> {
+            n.round_ties_even().non_nan()
+        }
+
+        fn f64_min(&mut self, a: Ieee64, b: Ieee64) -> Option<Ieee64> {
+            if a.is_nan() || b.is_nan() {
+                None
+            } else if a.is_zero() && b.is_zero() {
+                // `<=` treats -0.0 and +0.0 as equal, but `fmin` must
+                // return -0.0 when the zeros differ in sign.
+                Some(if a.is_negative() { a } else { b })
+            } else if a <= b {
+                Some(a)
+            } else {
+                Some(b)
+            }
+        }
+
+        fn f64_max(&mut self, a: Ieee64, b: Ieee64) -> Option<Ieee64> {
+            if a.is_nan() || b.is_nan() {
+                None
+            } else if a.is_zero() && b.is_zero() {
+                // Symmetrically, `fmax` must return +0.0 when the zeros
+                // differ in sign.
+                Some(if a.is_negative() { b } else { a })
+            } else if a >= b {
+                Some(a)
+            } else {
+                Some(b)
+            }
+        }
+
         fn f64_neg(&mut self, n: Ieee64) -> Ieee64 {
             n.neg()
         }
diff --git a/cranelift/codegen/src/opts/cprop.isle b/cranelift/codegen/src/opts/cprop.isle
index 9b31b0e7ebd5..a9685be75987 100644
--- a/cranelift/codegen/src/opts/cprop.isle
+++ b/cranelift/codegen/src/opts/cprop.isle
@@ -182,8 +182,6 @@
       (if-let $true (u64_lt (i64_as_u64 (i64_neg k)) (i64_as_u64 k)))
       (iadd ty x (iconst ty (imm64_masked ty (i64_as_u64 (i64_neg k))))))
 
-;; TODO: fadd, fsub, fmul, fdiv, fneg, fabs
-
 ;; A splat of a constant can become a direct `vconst` with the appropriate bit
 ;; pattern.
 (rule (simplify (splat dst (iconst $I8 n)))
@@ -280,7 +278,92 @@
 (decl pure u64_bswap64 (u64) u64)
 (extern constructor u64_bswap64 u64_bswap64)
 
-;; Constant fold bitwise float operations (fneg/fabs/fcopysign)
+;; Constant fold float operations
+;; Note: With the exception of fabs, fneg and fcopysign,
+;; constant folding is only performed when the result of
+;; an instruction isn't NaN. We want the NaN bit patterns
+;; produced by an instruction to be consistent, and
+;; compile-time evaluation in a cross-compilation scenario
+;; risks producing different NaN bit patterns than the target
+;; would have at run-time.
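+;; For example, on x86-64 the invalid operation 0.0 / 0.0 produces the
+;; negative "QNaN floating-point indefinite" 0xFFC0_0000, while AArch64 and
+;; RISC-V produce the positive canonical NaN 0x7FC0_0000.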
+;; TODO: fcmp, fma, demote, promote, to-int ops
+(rule (simplify (fadd $F32 (f32const $F32 lhs) (f32const $F32 rhs)))
+      (if-let r (f32_add lhs rhs))
+      (subsume (f32const $F32 r)))
+(rule (simplify (fadd $F64 (f64const $F64 lhs) (f64const $F64 rhs)))
+      (if-let r (f64_add lhs rhs))
+      (subsume (f64const $F64 r)))
+
+(rule (simplify (fsub $F32 (f32const $F32 lhs) (f32const $F32 rhs)))
+      (if-let r (f32_sub lhs rhs))
+      (subsume (f32const $F32 r)))
+(rule (simplify (fsub $F64 (f64const $F64 lhs) (f64const $F64 rhs)))
+      (if-let r (f64_sub lhs rhs))
+      (subsume (f64const $F64 r)))
+
+(rule (simplify (fmul $F32 (f32const $F32 lhs) (f32const $F32 rhs)))
+      (if-let r (f32_mul lhs rhs))
+      (subsume (f32const $F32 r)))
+(rule (simplify (fmul $F64 (f64const $F64 lhs) (f64const $F64 rhs)))
+      (if-let r (f64_mul lhs rhs))
+      (subsume (f64const $F64 r)))
+
+(rule (simplify (fdiv $F32 (f32const $F32 lhs) (f32const $F32 rhs)))
+      (if-let r (f32_div lhs rhs))
+      (subsume (f32const $F32 r)))
+(rule (simplify (fdiv $F64 (f64const $F64 lhs) (f64const $F64 rhs)))
+      (if-let r (f64_div lhs rhs))
+      (subsume (f64const $F64 r)))
+
+(rule (simplify (sqrt $F32 (f32const $F32 n)))
+      (if-let r (f32_sqrt n))
+      (subsume (f32const $F32 r)))
+(rule (simplify (sqrt $F64 (f64const $F64 n)))
+      (if-let r (f64_sqrt n))
+      (subsume (f64const $F64 r)))
+
+(rule (simplify (ceil $F32 (f32const $F32 n)))
+      (if-let r (f32_ceil n))
+      (subsume (f32const $F32 r)))
+(rule (simplify (ceil $F64 (f64const $F64 n)))
+      (if-let r (f64_ceil n))
+      (subsume (f64const $F64 r)))
+
+(rule (simplify (floor $F32 (f32const $F32 n)))
+      (if-let r (f32_floor n))
+      (subsume (f32const $F32 r)))
+(rule (simplify (floor $F64 (f64const $F64 n)))
+      (if-let r (f64_floor n))
+      (subsume (f64const $F64 r)))
+
+(rule (simplify (trunc $F32 (f32const $F32 n)))
+      (if-let r (f32_trunc n))
+      (subsume (f32const $F32 r)))
+(rule (simplify (trunc $F64 (f64const $F64 n)))
+      (if-let r (f64_trunc n))
+      (subsume (f64const $F64 r)))
+
+(rule (simplify (nearest $F32 (f32const $F32 n)))
+      (if-let r (f32_nearest n))
+      (subsume (f32const $F32 r)))
+(rule (simplify (nearest $F64 (f64const $F64 n)))
+      (if-let r (f64_nearest n))
+      (subsume (f64const $F64 r)))
+
+(rule (simplify (fmin $F32 (f32const $F32 n) (f32const $F32 m)))
+      (if-let r (f32_min n m))
+      (subsume (f32const $F32 r)))
+(rule (simplify (fmin $F64 (f64const $F64 n) (f64const $F64 m)))
+      (if-let r (f64_min n m))
+      (subsume (f64const $F64 r)))
+
+(rule (simplify (fmax $F32 (f32const $F32 n) (f32const $F32 m)))
+      (if-let r (f32_max n m))
+      (subsume (f32const $F32 r)))
+(rule (simplify (fmax $F64 (f64const $F64 n) (f64const $F64 m)))
+      (if-let r (f64_max n m))
+      (subsume (f64const $F64 r)))
+
 (rule (simplify (fneg $F16 (f16const $F16 n)))
       (subsume (f16const $F16 (f16_neg n))))
 (rule (simplify (fneg $F32 (f32const $F32 n)))
diff --git a/cranelift/codegen/src/prelude.isle b/cranelift/codegen/src/prelude.isle
index 95f31bfdbcb0..233ab899e374 100644
--- a/cranelift/codegen/src/prelude.isle
+++ b/cranelift/codegen/src/prelude.isle
@@ -245,12 +245,56 @@
 (extern constructor f16_abs f16_abs)
 (decl pure f16_copysign (Ieee16 Ieee16) Ieee16)
 (extern constructor f16_copysign f16_copysign)
+(decl pure partial f32_add (Ieee32 Ieee32) Ieee32)
+(extern constructor f32_add f32_add)
+(decl pure partial f32_sub (Ieee32 Ieee32) Ieee32)
+(extern constructor f32_sub f32_sub)
+(decl pure partial f32_mul (Ieee32 Ieee32) Ieee32)
+(extern constructor f32_mul f32_mul)
+(decl pure partial f32_div (Ieee32 Ieee32) Ieee32)
+(extern constructor f32_div f32_div)
+(decl pure partial f32_sqrt (Ieee32) Ieee32)
+(extern constructor f32_sqrt f32_sqrt)
+(decl pure partial f32_ceil (Ieee32) Ieee32)
+(extern constructor f32_ceil f32_ceil)
+(decl pure partial f32_floor (Ieee32) Ieee32)
+(extern constructor f32_floor f32_floor)
+(decl pure partial f32_trunc (Ieee32) Ieee32)
+(extern constructor f32_trunc f32_trunc)
+(decl pure partial f32_nearest (Ieee32) Ieee32)
+(extern constructor f32_nearest f32_nearest)
+(decl pure partial f32_min (Ieee32 Ieee32) Ieee32)
+(extern constructor f32_min f32_min)
+(decl pure partial f32_max (Ieee32 Ieee32) Ieee32)
+(extern constructor f32_max f32_max)
 (decl pure f32_neg (Ieee32) Ieee32)
 (extern constructor f32_neg f32_neg)
 (decl pure f32_abs (Ieee32) Ieee32)
 (extern constructor f32_abs f32_abs)
 (decl pure f32_copysign (Ieee32 Ieee32) Ieee32)
 (extern constructor f32_copysign f32_copysign)
+(decl pure partial f64_add (Ieee64 Ieee64) Ieee64)
+(extern constructor f64_add f64_add)
+(decl pure partial f64_sub (Ieee64 Ieee64) Ieee64)
+(extern constructor f64_sub f64_sub)
+(decl pure partial f64_mul (Ieee64 Ieee64) Ieee64)
+(extern constructor f64_mul f64_mul)
+(decl pure partial f64_div (Ieee64 Ieee64) Ieee64)
+(extern constructor f64_div f64_div)
+(decl pure partial f64_sqrt (Ieee64) Ieee64)
+(extern constructor f64_sqrt f64_sqrt)
+(decl pure partial f64_ceil (Ieee64) Ieee64)
+(extern constructor f64_ceil f64_ceil)
+(decl pure partial f64_floor (Ieee64) Ieee64)
+(extern constructor f64_floor f64_floor)
+(decl pure partial f64_trunc (Ieee64) Ieee64)
+(extern constructor f64_trunc f64_trunc)
+(decl pure partial f64_nearest (Ieee64) Ieee64)
+(extern constructor f64_nearest f64_nearest)
+(decl pure partial f64_min (Ieee64 Ieee64) Ieee64)
+(extern constructor f64_min f64_min)
+(decl pure partial f64_max (Ieee64 Ieee64) Ieee64)
+(extern constructor f64_max f64_max)
 (decl pure f64_neg (Ieee64) Ieee64)
 (extern constructor f64_neg f64_neg)
 (decl pure f64_abs (Ieee64) Ieee64)
diff --git a/cranelift/filetests/filetests/egraph/cprop.clif b/cranelift/filetests/filetests/egraph/cprop.clif
index 70d026f41da7..f1ca491f4977 100644
--- a/cranelift/filetests/filetests/egraph/cprop.clif
+++ b/cranelift/filetests/filetests/egraph/cprop.clif
@@ -344,6 +344,122 @@ block0:
 ; check: v4 = f16const -NaN
 ; check: return v4 ; v4 = -NaN
 
+function %f32_fadd() -> f32 {
+block0:
+    v1 = f32const 0x1.fp2
+    v2 = f32const 0x1.9p3
+    v3 = fadd v1, v2
+    return v3
+}
+
+; check: v4 = f32const 0x1.440000p4
+; check: return v4 ; v4 = 0x1.440000p4
+
+function %f32_fsub() -> f32 {
+block0:
+    v1 = f32const 0x1.fp2
+    v2 = f32const 0x1.9p3
+    v3 = fsub v1, v2
+    return v3
+}
+
+; check: v4 = f32const -0x1.300000p2
+; check: return v4 ; v4 = -0x1.300000p2
+
+function %f32_fmul() -> f32 {
+block0:
+    v1 = f32const 0x1.fp2
+    v2 = f32const 0x1.9p3
+    v3 = fmul v1, v2
+    return v3
+}
+
+; check: v4 = f32const 0x1.838000p6
+; check: return v4 ; v4 = 0x1.838000p6
+
+function %f32_fdiv() -> f32 {
+block0:
+    v1 = f32const 0x1.9p5
+    v2 = f32const 0x1.9p3
+    v3 = fdiv v1, v2
+    return v3
+}
+
+; check: v4 = f32const 0x1.000000p2
+; check: return v4 ; v4 = 0x1.000000p2
+
+function %f32_sqrt() -> f32 {
+block0:
+    v1 = f32const 0x1.9p4
+    v2 = sqrt v1
+    return v2
+}
+
+; check: v3 = f32const 0x1.400000p2
+; check: return v3 ; v3 = 0x1.400000p2
+
+function %f32_ceil() -> f32 {
+block0:
+    v1 = f32const -0x1.9p3
+    v2 = ceil v1
+    return v2
+}
+
+; check: v3 = f32const -0x1.800000p3
+; check: return v3 ; v3 = -0x1.800000p3
+
+function %f32_floor() -> f32 {
+block0:
+    v1 = f32const -0x1.9p3
+    v2 = floor v1
+    return v2
+}
+
+; check: v3 = f32const -0x1.a00000p3
+; check: return v3 ; v3 = -0x1.a00000p3
+
+function %f32_trunc() -> f32 {
+block0:
+    v1 = f32const 0x1.9p3
+    v2 = trunc v1
+    return v2
+}
+
+; check: v3 = f32const 0x1.800000p3
+; check: return v3 ; v3 = 0x1.800000p3
+
+function %f32_nearest() -> f32 {
+block0:
+    v1 = f32const 0x1.9p3
+    v2 = nearest v1
+    return v2
+}
+
+; check: v3 = f32const 0x1.800000p3
+; check: return v3 ; v3 = 0x1.800000p3
+
+function %f32_fmin() -> f32 {
+block0:
+    v1 = f32const 0x1.5p6
+    v2 = f32const 0x1.5p7
+    v3 = fmin v2, v1
+    return v3
+}
+
+; check: v4 = f32const 0x1.500000p6
+; check: return v4 ; v4 = 0x1.500000p6
+
+function %f32_fmax() -> f32 {
+block0:
+    v1 = f32const 0x1.5p6
+    v2 = f32const 0x1.5p7
+    v3 = fmax v2, v1
+    return v3
+}
+
+; check: v4 = f32const 0x1.500000p7
+; check: return v4 ; v4 = 0x1.500000p7
+
 function %f32_fneg() -> f32 {
 block0:
     v1 = f32const 0.0
@@ -375,6 +491,122 @@ block0:
 ; check: v4 = f32const -NaN
 ; check: return v4 ; v4 = -NaN
 
+function %f64_fadd() -> f64 {
+block0:
+    v1 = f64const 0x1.fp2
+    v2 = f64const 0x1.9p3
+    v3 = fadd v1, v2
+    return v3
+}
+
+; check: v4 = f64const 0x1.4400000000000p4
+; check: return v4 ; v4 = 0x1.4400000000000p4
+
+function %f64_fsub() -> f64 {
+block0:
+    v1 = f64const 0x1.fp2
+    v2 = f64const 0x1.9p3
+    v3 = fsub v1, v2
+    return v3
+}
+
+; check: v4 = f64const -0x1.3000000000000p2
+; check: return v4 ; v4 = -0x1.3000000000000p2
+
+function %f64_fmul() -> f64 {
+block0:
+    v1 = f64const 0x1.fp2
+    v2 = f64const 0x1.9p3
+    v3 = fmul v1, v2
+    return v3
+}
+
+; check: v4 = f64const 0x1.8380000000000p6
+; check: return v4 ; v4 = 0x1.8380000000000p6
+
+function %f64_fdiv() -> f64 {
+block0:
+    v1 = f64const 0x1.9p5
+    v2 = f64const 0x1.9p3
+    v3 = fdiv v1, v2
+    return v3
+}
+
+; check: v4 = f64const 0x1.0000000000000p2
+; check: return v4 ; v4 = 0x1.0000000000000p2
+
+function %f64_sqrt() -> f64 {
+block0:
+    v1 = f64const 0x1.9p4
+    v2 = sqrt v1
+    return v2
+}
+
+; check: v3 = f64const 0x1.4000000000000p2
+; check: return v3 ; v3 = 0x1.4000000000000p2
+
+function %f64_ceil() -> f64 {
+block0:
+    v1 = f64const 0x1.9p3
+    v2 = ceil v1
+    return v2
+}
+
+; check: v3 = f64const 0x1.a000000000000p3
+; check: return v3 ; v3 = 0x1.a000000000000p3
+
+function %f64_floor() -> f64 {
+block0:
+    v1 = f64const 0x1.9p3
+    v2 = floor v1
+    return v2
+}
+
+; check: v3 = f64const 0x1.8000000000000p3
+; check: return v3 ; v3 = 0x1.8000000000000p3
+
+function %f64_trunc() -> f64 {
+block0:
+    v1 = f64const 0x1.9p3
+    v2 = trunc v1
+    return v2
+}
+
+; check: v3 = f64const 0x1.8000000000000p3
+; check: return v3 ; v3 = 0x1.8000000000000p3
+
+function %f64_nearest() -> f64 {
+block0:
+    v1 = f64const 0x1.9p3
+    v2 = nearest v1
+    return v2
+}
+
+; check: v3 = f64const 0x1.8000000000000p3
+; check: return v3 ; v3 = 0x1.8000000000000p3
+
+function %f64_fmin() -> f64 {
+block0:
+    v1 = f64const -0x1.5p6
+    v2 = f64const -0x1.5p7
+    v3 = fmin v2, v1
+    return v3
+}
+
+; check: v4 = f64const -0x1.5000000000000p7
+; check: return v4 ; v4 = -0x1.5000000000000p7
+
+function %f64_fmax() -> f64 {
+block0:
+    v1 = f64const -0x1.5p6
+    v2 = f64const -0x1.5p7
+    v3 = fmax v2, v1
+    return v3
+}
+
+; check: v4 = f64const -0x1.5000000000000p6
+; check: return v4 ; v4 = -0x1.5000000000000p6
+
 function %f64_fneg() -> f64 {
 block0:
     v1 = f64const 0.0
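
To make the folding guard above concrete, here is a small, self-contained Rust sketch (illustrative only, not part of the patch): it mirrors `Ieee32::non_nan` on a plain `f32`, and `nearest` maps onto `f32::round_ties_even`, stable since Rust 1.77. The free function `non_nan` below is hypothetical and exists only for this example.

    // Mirrors the patch's guard: a fold succeeds only when the result is not NaN.
    fn non_nan(f: f32) -> Option<f32> {
        Some(f).filter(|f| !f.is_nan())
    }

    fn main() {
        // `fadd` of the filetest constants folds: 0x1.fp2 + 0x1.9p3 = 7.75 + 12.5.
        assert_eq!(non_nan(7.75 + 12.5), Some(20.25)); // 0x1.44p4, as checked above
        // Inf - Inf is NaN, so the fold bails and `fsub` is left for run time,
        // where the target produces its own NaN bit pattern.
        assert_eq!(non_nan(f32::INFINITY - f32::INFINITY), None);
        // `nearest` rounds ties to even: 12.5 (0x1.9p3) rounds to 12.0 (0x1.8p3),
        // matching the `%f32_nearest` filetest.
        assert_eq!(non_nan(12.5f32.round_ties_even()), Some(12.0));
    }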