diff --git a/compiler/frontend/src/mir_optimize/constant_folding.rs b/compiler/frontend/src/mir_optimize/constant_folding.rs index e91cad15c..8ea6161fc 100644 --- a/compiler/frontend/src/mir_optimize/constant_folding.rs +++ b/compiler/frontend/src/mir_optimize/constant_folding.rs @@ -41,7 +41,7 @@ use crate::{ use itertools::Itertools; use num_bigint::BigInt; use num_integer::Integer; -use num_traits::{ToPrimitive, Zero}; +use num_traits::{One, ToPrimitive, Zero}; use std::{ borrow::Cow, cmp::Ordering, @@ -171,9 +171,14 @@ fn run_builtin( } BuiltinFunction::IntAdd => { let [a, b] = arguments else { unreachable!() }; - let a: &BigInt = visible.get(*a).try_into().ok()?; - let b: &BigInt = visible.get(*b).try_into().ok()?; - (a + b).into() + match (visible.get(*a), visible.get(*b)) { + // 0 + b = b + (Expression::Int(a), _) if a.is_zero() => Expression::Reference(*b), + // a + 0 = a + (_, Expression::Int(b)) if b.is_zero() => Expression::Reference(*a), + (Expression::Int(a), Expression::Int(b)) => (a + b).into(), + _ => return None, + } } BuiltinFunction::IntBitLength => { let [a] = arguments else { unreachable!() }; @@ -186,9 +191,14 @@ fn run_builtin( return Some(Expression::Reference(*a)); } - let a: &BigInt = visible.get(*a).try_into().ok()?; - let b: &BigInt = visible.get(*b).try_into().ok()?; - (a & b).into() + match (visible.get(*a), visible.get(*b)) { + // 0 & b = a & 0 = 0 + (Expression::Int(zero), _) | (_, Expression::Int(zero)) if zero.is_zero() => { + 0.into() + } + (Expression::Int(a), Expression::Int(b)) => (a & b).into(), + _ => return None, + } } BuiltinFunction::IntBitwiseOr => { let [a, b] = arguments else { unreachable!() }; @@ -196,9 +206,14 @@ fn run_builtin( return Some(Expression::Reference(*a)); } - let a: &BigInt = visible.get(*a).try_into().ok()?; - let b: &BigInt = visible.get(*b).try_into().ok()?; - (a | b).into() + match (visible.get(*a), visible.get(*b)) { + // 0 | b = b + (Expression::Int(a), _) if a.is_zero() => Expression::Reference(*b), + // a | 0 = a + (_, Expression::Int(b)) if b.is_zero() => Expression::Reference(*a), + (Expression::Int(a), Expression::Int(b)) => (a | b).into(), + _ => return None, + } } BuiltinFunction::IntBitwiseXor => { let [a, b] = arguments else { unreachable!() }; @@ -206,9 +221,14 @@ fn run_builtin( return Some(0.into()); } - let a: &BigInt = visible.get(*a).try_into().ok()?; - let b: &BigInt = visible.get(*b).try_into().ok()?; - (a ^ b).into() + match (visible.get(*a), visible.get(*b)) { + // 0 ^ b = b + (Expression::Int(a), _) if a.is_zero() => Expression::Reference(*b), + // a ^ 0 = a + (_, Expression::Int(b)) if b.is_zero() => Expression::Reference(*a), + (Expression::Int(a), Expression::Int(b)) => (a ^ b).into(), + _ => return None, + } } BuiltinFunction::IntCompareTo => { let [a, b] = arguments else { unreachable!() }; @@ -228,9 +248,16 @@ fn run_builtin( return Some(1.into()); } - let dividend: &BigInt = visible.get(*dividend).try_into().ok()?; - let divisor: &BigInt = visible.get(*divisor).try_into().ok()?; - (dividend / divisor).into() + match (visible.get(*dividend), visible.get(*divisor)) { + // dividend / 1 = dividend + (_, Expression::Int(divisor)) if divisor.is_one() => { + Expression::Reference(*dividend) + } + (Expression::Int(dividend), Expression::Int(divisor)) => { + (dividend / divisor).into() + } + _ => return None, + } } BuiltinFunction::IntModulo => { let [dividend, divisor] = arguments else { @@ -248,9 +275,20 @@ fn run_builtin( let [factor_a, factor_b] = arguments else { unreachable!() }; - let factor_a: &BigInt = visible.get(*factor_a).try_into().ok()?; - let factor_b: &BigInt = visible.get(*factor_b).try_into().ok()?; - (factor_a * factor_b).into() + match (visible.get(*factor_a), visible.get(*factor_b)) { + // 1 * factor_b = factor_b + (Expression::Int(factor_a), _) if factor_a.is_one() => { + Expression::Reference(*factor_b) + } + // factor_a * 1 = factor_a + (_, Expression::Int(factor_b)) if factor_b.is_one() => { + Expression::Reference(*factor_a) + } + (Expression::Int(factor_a), Expression::Int(factor_b)) => { + (factor_a * factor_b).into() + } + _ => return None, + } } BuiltinFunction::IntParse => { let [text] = arguments else { unreachable!() }; @@ -280,29 +318,35 @@ fn run_builtin( let [value, amount] = arguments else { unreachable!() }; - let amount: &BigInt = visible.get(*amount).try_into().ok()?; - // TODO: Support larger shift amounts. - let amount: u128 = amount.try_into().unwrap(); - if amount == 0 { - return Some(value.into()); + match (visible.get(*value), visible.get(*amount)) { + // value << 0 = value + (_, Expression::Int(amount)) if amount.is_zero() => Expression::Reference(*value), + // 0 << amount = 0 + (Expression::Int(value), _) if value.is_zero() => 0.into(), + (Expression::Int(value), Expression::Int(amount)) => { + // TODO: Support larger shift amounts. + let amount: u128 = amount.try_into().unwrap(); + (value << amount).into() + } + _ => return None, } - - let value: &BigInt = visible.get(*value).try_into().ok()?; - (value << amount).into() } BuiltinFunction::IntShiftRight => { let [value, amount] = arguments else { unreachable!() }; - let amount: &BigInt = visible.get(*amount).try_into().ok()?; - // TODO: Support larger shift amounts. - let amount: u128 = amount.try_into().unwrap(); - if amount == 0 { - return Some(value.into()); + match (visible.get(*value), visible.get(*amount)) { + // value >> 0 = value + (_, Expression::Int(amount)) if amount.is_zero() => Expression::Reference(*value), + // 0 >> amount = 0 + (Expression::Int(value), _) if value.is_zero() => 0.into(), + (Expression::Int(value), Expression::Int(amount)) => { + // TODO: Support larger shift amounts. + let amount: u128 = amount.try_into().unwrap(); + (value >> amount).into() + } + _ => return None, } - - let value: &BigInt = visible.get(*value).try_into().ok()?; - (value >> amount).into() } BuiltinFunction::IntSubtract => { let [minuend, subtrahend] = arguments else { @@ -312,9 +356,16 @@ fn run_builtin( return Some(Expression::Int(0.into())); } - let minuend: &BigInt = visible.get(*minuend).try_into().ok()?; - let subtrahend: &BigInt = visible.get(*subtrahend).try_into().ok()?; - (minuend - subtrahend).into() + match (visible.get(*minuend), visible.get(*subtrahend)) { + // minuend - 0 = minuend + (_, Expression::Int(subtrahend)) if subtrahend.is_zero() => { + Expression::Reference(*minuend) + } + (Expression::Int(minuend), Expression::Int(subtrahend)) => { + (minuend - subtrahend).into() + } + _ => return None, + } } BuiltinFunction::ListFilled => { let [length, item] = arguments else {