diff --git a/Cargo.lock b/Cargo.lock index b5e177d..548ccb0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -106,7 +106,7 @@ dependencies = [ [[package]] name = "darklua" -version = "0.3.5" +version = "0.3.6" dependencies = [ "insta 0.12.0 (registry+https://github.com/rust-lang/crates.io-index)", "json5 0.2.5 (registry+https://github.com/rust-lang/crates.io-index)", diff --git a/Cargo.toml b/Cargo.toml index 64c7a5d..261f26b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "darklua" -version = "0.3.5" +version = "0.3.6" authors = ["jeparlefrancais "] edition = "2018" readme = "README.md" diff --git a/RULES.md b/RULES.md index c02d390..25b0d99 100644 --- a/RULES.md +++ b/RULES.md @@ -2,6 +2,7 @@ You can find the available rules and their properties here. The default rule stack is: + - [Compute expressions](#compute-expressions) - [Convert local functions to assignments](#convert-local-functions-to-assignments) - [Group local assignments](#group-local-assignments) - [Remove empty do statements](#remove-empty-do-statements) @@ -17,6 +18,30 @@ There are also other rules available for more processing: --- +## Compute expression +```compute_expression``` + +This rule computes expressions and replaces them with their result. An expression will not be replaced if it has any side-effects. This rule is influenced by the evaluation system of darklua. As its capacity increases, the rule will be able to compute more complex expressions. For example, if you use this rule on the following code: + +```lua +return 1 + 1 +``` + +Will produce the following code: + +```lua +return 2 +``` + +### Examples +```json5 +{ + rule: 'compute_expression', +} +``` + +--- + ## Convert local functions to assignments ```convert_local_function_to_assign``` diff --git a/src/nodes/expressions/binary.rs b/src/nodes/expressions/binary.rs index 8033abc..ab06cca 100644 --- a/src/nodes/expressions/binary.rs +++ b/src/nodes/expressions/binary.rs @@ -27,7 +27,7 @@ impl BinaryOperator { } #[inline] - pub fn preceeds_unary_expression(&self) -> bool { + pub fn precedes_unary_expression(&self) -> bool { match self { Self::Caret => true, _ => false, @@ -50,6 +50,36 @@ impl BinaryOperator { } } + pub fn left_needs_parentheses(&self, left: &Expression) -> bool { + match left { + Expression::Binary(left) => { + if self.is_left_associative() { + self.precedes(left.operator()) + } else { + !left.operator().precedes(*self) + } + } + Expression::Unary(_) => { + self.precedes_unary_expression() + } + _ => false, + } + } + + pub fn right_needs_parentheses(&self, right: &Expression) -> bool { + match right { + Expression::Binary(right) => { + if self.is_right_associative() { + self.precedes(right.operator()) + } else { + !right.operator().precedes(*self) + } + } + Expression::Unary(_) => false, + _ => false, + } + } + fn get_precedence(&self) -> u8 { match self { Self::Or => 0, @@ -113,22 +143,27 @@ impl BinaryExpression { } } + #[inline] pub fn mutate_left(&mut self) -> &mut Expression { &mut self.left } + #[inline] pub fn mutate_right(&mut self) -> &mut Expression { &mut self.right } + #[inline] pub fn left(&self) -> &Expression { &self.left } + #[inline] pub fn right(&self) -> &Expression { &self.right } + #[inline] pub fn operator(&self) -> BinaryOperator { self.operator } @@ -136,9 +171,23 @@ impl BinaryExpression { impl ToLua for BinaryExpression { fn to_lua(&self, generator: &mut LuaGenerator) { - self.left.to_lua(generator); + if self.operator.left_needs_parentheses(&self.left) { + generator.push_char('('); + self.left.to_lua(generator); + generator.push_char(')'); + } else { + self.left.to_lua(generator); + } + self.operator.to_lua(generator); - self.right.to_lua(generator); + + if self.operator.right_needs_parentheses(&self.right) { + generator.push_char('('); + self.right.to_lua(generator); + generator.push_char(')'); + } else { + self.right.to_lua(generator); + } } } @@ -168,7 +217,7 @@ mod test { assert!(Caret.precedes(Percent)); assert!(Caret.precedes(Concat)); assert!(!Caret.precedes(Caret)); - assert!(Caret.preceeds_unary_expression()); + assert!(Caret.precedes_unary_expression()); } #[test] @@ -188,7 +237,7 @@ mod test { assert!(!Asterisk.precedes(Percent)); assert!(Asterisk.precedes(Concat)); assert!(!Asterisk.precedes(Caret)); - assert!(!Asterisk.preceeds_unary_expression()); + assert!(!Asterisk.precedes_unary_expression()); } #[test] @@ -208,7 +257,7 @@ mod test { assert!(!Slash.precedes(Percent)); assert!(Slash.precedes(Concat)); assert!(!Slash.precedes(Caret)); - assert!(!Slash.preceeds_unary_expression()); + assert!(!Slash.precedes_unary_expression()); } #[test] @@ -228,7 +277,7 @@ mod test { assert!(!Percent.precedes(Percent)); assert!(Percent.precedes(Concat)); assert!(!Percent.precedes(Caret)); - assert!(!Percent.preceeds_unary_expression()); + assert!(!Percent.precedes_unary_expression()); } #[test] @@ -248,7 +297,7 @@ mod test { assert!(!Plus.precedes(Percent)); assert!(Plus.precedes(Concat)); assert!(!Plus.precedes(Caret)); - assert!(!Plus.preceeds_unary_expression()); + assert!(!Plus.precedes_unary_expression()); } #[test] @@ -268,7 +317,7 @@ mod test { assert!(!Minus.precedes(Percent)); assert!(Minus.precedes(Concat)); assert!(!Minus.precedes(Caret)); - assert!(!Minus.preceeds_unary_expression()); + assert!(!Minus.precedes_unary_expression()); } #[test] @@ -288,7 +337,7 @@ mod test { assert!(!Concat.precedes(Percent)); assert!(!Concat.precedes(Concat)); assert!(!Concat.precedes(Caret)); - assert!(!Concat.preceeds_unary_expression()); + assert!(!Concat.precedes_unary_expression()); } #[test] @@ -308,7 +357,146 @@ mod test { assert!(!And.precedes(Percent)); assert!(!And.precedes(Concat)); assert!(!And.precedes(Caret)); - assert!(!And.preceeds_unary_expression()); + assert!(!And.precedes_unary_expression()); + } + } + + mod to_lua { + use super::*; + + use crate::nodes::{DecimalNumber, UnaryExpression, UnaryOperator}; + + #[test] + fn left_associative_wraps_left_operand_if_has_lower_precedence() { + let expression = BinaryExpression::new( + BinaryOperator::Asterisk, + DecimalNumber::new(2.0).into(), + BinaryExpression::new( + BinaryOperator::Plus, + DecimalNumber::new(1.0).into(), + DecimalNumber::new(3.0).into(), + ).into(), + ); + + assert_eq!("2*(1+3)", expression.to_lua_string()); + } + + #[test] + fn left_associative_wraps_right_operand_if_has_lower_precedence() { + let expression = BinaryExpression::new( + BinaryOperator::And, + Expression::False, + BinaryExpression::new( + BinaryOperator::Or, + Expression::False, + Expression::True, + ).into(), + ); + + assert_eq!("false and(false or true)", expression.to_lua_string()); + } + + #[test] + fn left_associative_wraps_right_operand_if_has_same_precedence() { + let expression = BinaryExpression::new( + BinaryOperator::Equal, + Expression::True, + BinaryExpression::new( + BinaryOperator::LowerThan, + DecimalNumber::new(1.0).into(), + DecimalNumber::new(2.0).into(), + ).into(), + ); + + assert_eq!("true==(1<2)", expression.to_lua_string()); + } + + #[test] + fn right_associative_wrap_unary_left_operand_if_has_lower_precedence() { + let expression = BinaryExpression::new( + BinaryOperator::Caret, + UnaryExpression::new( + UnaryOperator::Minus, + DecimalNumber::new(2.0).into(), + ).into(), + DecimalNumber::new(2.0).into(), + ); + + assert_eq!("(-2)^2", expression.to_lua_string()); + } + + #[test] + fn right_associative_wraps_left_operand_if_has_lower_precedence() { + let expression = BinaryExpression::new( + BinaryOperator::Caret, + BinaryExpression::new( + BinaryOperator::Plus, + DecimalNumber::new(1.0).into(), + DecimalNumber::new(2.0).into(), + ).into(), + DecimalNumber::new(3.0).into(), + ); + + assert_eq!("(1+2)^3", expression.to_lua_string()); + } + + #[test] + fn right_associative_wraps_left_operand_if_has_same_precedence() { + let expression = BinaryExpression::new( + BinaryOperator::Caret, + BinaryExpression::new( + BinaryOperator::Caret, + DecimalNumber::new(2.0).into(), + DecimalNumber::new(2.0).into(), + ).into(), + DecimalNumber::new(3.0).into(), + ); + + assert_eq!("(2^2)^3", expression.to_lua_string()); + } + + #[test] + fn right_associative_does_not_wrap_right_operand_if_unary() { + let expression = BinaryExpression::new( + BinaryOperator::Caret, + DecimalNumber::new(2.0).into(), + UnaryExpression::new( + UnaryOperator::Minus, + DecimalNumber::new(2.0).into(), + ).into(), + ); + + assert_eq!("2^-2", expression.to_lua_string()); + } + + #[test] + fn right_associative_does_not_wrap_right_operand_if_has_same_precedence() { + let expression = BinaryExpression::new( + BinaryOperator::Caret, + DecimalNumber::new(2.0).into(), + BinaryExpression::new( + BinaryOperator::Caret, + DecimalNumber::new(2.0).into(), + DecimalNumber::new(3.0).into(), + ).into(), + ); + + assert_eq!("2^2^3", expression.to_lua_string()); + } + + #[test] + fn right_associative_does_not_wrap_right_operand_if_has_higher_precedence() { + let expression = BinaryExpression::new( + BinaryOperator::Concat, + DecimalNumber::new(3.0).into(), + BinaryExpression::new( + BinaryOperator::Plus, + DecimalNumber::new(9.0).into(), + DecimalNumber::new(3.0).into(), + ).into(), + ); + + assert_eq!("3 ..9+3", expression.to_lua_string()); } } diff --git a/src/nodes/expressions/mod.rs b/src/nodes/expressions/mod.rs index 440fdcd..16fda1c 100644 --- a/src/nodes/expressions/mod.rs +++ b/src/nodes/expressions/mod.rs @@ -21,6 +21,8 @@ pub use unary::*; use crate::lua_generator::{LuaGenerator, ToLua}; use crate::nodes::FunctionCall; +use std::num::FpCategory; + #[derive(Clone, Debug, PartialEq, Eq)] pub enum Expression { Binary(Box), @@ -48,10 +50,55 @@ impl From for Expression { impl From for Expression { fn from(value: f64) -> Expression { - if value < 0.0 { - UnaryExpression::new(UnaryOperator::Minus, Expression::from(value.abs())).into() - } else { - DecimalNumber::new(value).into() + match value.classify() { + FpCategory::Nan => { + BinaryExpression::new( + BinaryOperator::Slash, + DecimalNumber::new(0.0).into(), + DecimalNumber::new(0.0).into(), + ).into() + } + FpCategory::Infinite => { + BinaryExpression::new( + BinaryOperator::Slash, + Expression::from(if value.is_sign_positive() { 1.0 } else { -1.0 }), + DecimalNumber::new(0.0).into(), + ).into() + } + FpCategory::Zero => { + DecimalNumber::new(0.0).into() + } + FpCategory::Subnormal | FpCategory::Normal => { + if value < 0.0 { + UnaryExpression::new( + UnaryOperator::Minus, + Expression::from(value.abs()), + ).into() + } else { + if value < 0.1 { + let exponent = value.log10().floor(); + let new_value = value / 10_f64.powf(exponent); + + DecimalNumber::new(new_value) + .with_exponent(exponent as i64, true) + .into() + } else if value > 999.0 && (value / 100.0).fract() == 0.0 { + let mut exponent = value.log10().floor(); + let mut power = 10_f64.powf(exponent); + + while exponent > 2.0 && (value / power).fract() != 0.0 { + exponent -= 1.0; + power /= 10.0; + } + + DecimalNumber::new(value / power) + .with_exponent(exponent as i64, true) + .into() + } else { + DecimalNumber::new(value).into() + } + } + } } } } @@ -176,6 +223,53 @@ impl ToLua for Expression { mod test { use super::*; + mod numbers { + use super::*; + + macro_rules! snapshots { + ($($name:ident($input:expr)),+) => { + $( + mod $name { + use super::*; + use insta::assert_snapshot; + use insta::assert_debug_snapshot; + + #[test] + fn expression() { + assert_debug_snapshot!( + "expression", + Expression::from($input) + ); + } + + #[test] + fn lua() { + assert_snapshot!( + "lua_float", + Expression::from($input).to_lua_string() + ); + } + } + )+ + }; + } + + snapshots!( + snaphshot_1(1.0), + snaphshot_0_5(0.5), + snaphshot_123(123.0), + snaphshot_0_005(0.005), + snaphshot_nan(0.0/0.0), + snaphshot_positive_infinity(1.0/0.0), + snaphshot_negative_infinity(-1.0/0.0), + snaphshot_very_small(1.2345e-50), + snapshot_thousand(1000.0), + snaphshot_very_large(1.2345e50), + snapshot_float_below_thousand(100.25), + snapshot_float_above_thousand(2000.05) + ); + } + mod to_lua { use super::*; diff --git a/src/nodes/expressions/number.rs b/src/nodes/expressions/number.rs index a871b89..073f366 100644 --- a/src/nodes/expressions/number.rs +++ b/src/nodes/expressions/number.rs @@ -1,5 +1,8 @@ use crate::lua_generator::{LuaGenerator, ToLua}; +use std::fmt::{Display, Formatter, Result as FmtResult}; +use std::str::FromStr; + #[derive(Clone, Debug, PartialEq)] pub struct DecimalNumber { float: f64, @@ -36,14 +39,31 @@ impl DecimalNumber { impl ToLua for DecimalNumber { fn to_lua(&self, generator: &mut LuaGenerator) { - let mut number = format!("{}", self.float); + if self.float.is_nan() { + generator.push_char('('); + generator.push_char('0'); + generator.push_char('/'); + generator.push_char('0'); + generator.push_char(')'); + } else if self.float.is_infinite() { + generator.push_char('('); + if self.float.is_sign_negative() { + generator.push_char('-'); + } + generator.push_char('1'); + generator.push_char('/'); + generator.push_char('0'); + generator.push_char(')'); + } else { + let mut number = format!("{:.}", self.float); - if let Some((exponent, is_uppercase)) = &self.exponent { - number.push(if *is_uppercase { 'E' } else { 'e' }); - number.push_str(&format!("{}", exponent)); - }; + if let Some((exponent, is_uppercase)) = &self.exponent { + number.push(if *is_uppercase { 'E' } else { 'e' }); + number.push_str(&format!("{}", exponent)); + }; - generator.push_str(&number); + generator.push_str(&number); + } } } @@ -136,6 +156,97 @@ impl From for NumberExpression { } } +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum NumberParsingError { + InvalidHexadecimalNumber, + InvalidHexadecimalExponent, + InvalidDecimalNumber, + InvalidDecimalExponent, +} + +impl Display for NumberParsingError { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + use NumberParsingError::*; + + match self { + InvalidHexadecimalNumber => write!(f, "could not parse hexadecimal number"), + InvalidHexadecimalExponent => write!(f, "could not parse hexadecimal exponent"), + InvalidDecimalNumber => write!(f, "could not parse decimal number"), + InvalidDecimalExponent => write!(f, "could not parse decimal exponent"), + } + } +} + +impl FromStr for NumberExpression { + type Err = NumberParsingError; + + fn from_str(value: &str) -> Result { + let number = if value.starts_with("0x") || value.starts_with("0X") { + let is_x_uppercase = value.chars().nth(1) + .map(char::is_uppercase) + .unwrap_or(false); + + if let Some(index) = value.find("p") { + let exponent = value.get(index + 1..) + .and_then(|string| string.parse().ok()) + .ok_or(Self::Err::InvalidHexadecimalExponent)?; + let number = u64::from_str_radix(value.get(2..index).unwrap(), 16) + .map_err(|_| Self::Err::InvalidHexadecimalNumber)?; + + HexNumber::new(number, is_x_uppercase) + .with_exponent(exponent, false) + + } else if let Some(index) = value.find("P") { + let exponent = value.get(index + 1..) + .and_then(|string| string.parse().ok()) + .ok_or(Self::Err::InvalidHexadecimalExponent)?; + let number = u64::from_str_radix(value.get(2..index).unwrap(), 16) + .map_err(|_| Self::Err::InvalidHexadecimalNumber)?; + + HexNumber::new(number, is_x_uppercase) + .with_exponent(exponent, true) + } else { + let number = u64::from_str_radix(value.get(2..) + .unwrap(), 16) + .map_err(|_| Self::Err::InvalidHexadecimalNumber)?; + + HexNumber::new(number, is_x_uppercase) + }.into() + + } else { + if let Some(index) = value.find("e") { + let exponent = value.get(index + 1..) + .and_then(|string| string.parse().ok()) + .ok_or(Self::Err::InvalidDecimalExponent)?; + let number = value.get(0..index) + .and_then(|string| string.parse().ok()) + .ok_or(Self::Err::InvalidDecimalNumber)?; + + DecimalNumber::new(number) + .with_exponent(exponent, false) + + } else if let Some(index) = value.find("E") { + let exponent: i64 = value.get(index + 1..) + .and_then(|string| string.parse().ok()) + .ok_or(Self::Err::InvalidDecimalExponent)?; + let number = value.get(0..index) + .and_then(|string| string.parse().ok()) + .ok_or(Self::Err::InvalidDecimalNumber)?; + + DecimalNumber::new(number) + .with_exponent(exponent, true) + } else { + let number = value.parse::() + .map_err(|_| Self::Err::InvalidDecimalNumber)?; + + DecimalNumber::new(number) + }.into() + }; + + Ok(number) + } +} + impl ToLua for NumberExpression { fn to_lua(&self, generator: &mut LuaGenerator) { match self { @@ -157,7 +268,7 @@ mod test { $( #[test] fn $name() { - let number = NumberExpression::from($input.to_owned()); + let number = NumberExpression::from($input); assert_eq!(number.to_lua_string(), $value); } )* @@ -174,6 +285,71 @@ mod test { ); } + mod parse_number { + use super::*; + + macro_rules! test_numbers { + ($($name:ident($input:literal) => $expect:expr),+) => { + $( + #[test] + fn $name() { + let result: NumberExpression = $input.parse() + .expect("should be a valid number"); + + let expect: NumberExpression = $expect.into(); + + assert_eq!(result, expect); + } + )+ + }; + } + + macro_rules! test_parse_errors { + ($($name:ident($input:literal) => $expect:expr),+) => { + $( + #[test] + fn $name() { + let result = $input.parse::() + .expect_err("should be an invalid number"); + + assert_eq!(result, $expect); + } + )+ + }; + } + + test_numbers!( + parse_zero("0") => DecimalNumber::new(0_f64), + parse_integer("123") => DecimalNumber::new(123_f64), + parse_multiple_decimal("123.24") => DecimalNumber::new(123.24_f64), + parse_float_with_trailing_dot("123.") => DecimalNumber::new(123_f64), + parse_starting_with_dot(".123") => DecimalNumber::new(0.123_f64), + parse_digit_with_exponent("1e10") => DecimalNumber::new(1_f64).with_exponent(10, false), + parse_number_with_exponent("123e456") => DecimalNumber::new(123_f64).with_exponent(456, false), + parse_number_with_exponent_and_plus_symbol("123e+456") => DecimalNumber::new(123_f64).with_exponent(456, false), + parse_number_with_negative_exponent("123e-456") => DecimalNumber::new(123_f64).with_exponent(-456, false), + parse_number_with_upper_exponent("123E4") => DecimalNumber::new(123_f64).with_exponent(4, true), + parse_number_with_upper_negative_exponent("123E-456") => DecimalNumber::new(123_f64).with_exponent(-456, true), + parse_float_with_exponent("10.12e8") => DecimalNumber::new(10.12_f64).with_exponent(8, false), + parse_trailing_dot_with_exponent("10.e8") => DecimalNumber::new(10_f64).with_exponent(8, false), + parse_hex_number("0x12") => HexNumber::new(18, false), + parse_uppercase_hex_number("0X12") => HexNumber::new(18, true), + parse_hex_number_with_lowercase("0x12a") => HexNumber::new(298, false), + parse_hex_number_with_uppercase("0x12A") => HexNumber::new(298, false), + parse_hex_number_with_mixed_case("0x1bF2A") => HexNumber::new(114_474, false), + parse_hex_with_exponent("0x12p4") => HexNumber::new(18, false).with_exponent(4, false), + parse_hex_with_exponent_uppercase("0xABP3") => HexNumber::new(171, false).with_exponent(3, true) + ); + + test_parse_errors!( + parse_empty_string("") => NumberParsingError::InvalidDecimalNumber, + missing_exponent_value("1e") => NumberParsingError::InvalidDecimalExponent, + missing_negative_exponent_value("1e-") => NumberParsingError::InvalidDecimalExponent, + missing_hex_exponent_value("0x1p") => NumberParsingError::InvalidHexadecimalExponent, + negative_hex_exponent_value("0x1p-3") => NumberParsingError::InvalidHexadecimalExponent + ); + } + mod compute_value { use super::*; @@ -182,7 +358,7 @@ mod test { $( #[test] fn $name() { - let number = NumberExpression::from($input.to_owned()); + let number = NumberExpression::from($input); assert_eq!(number.compute_value(), $value as f64); } )* diff --git a/src/nodes/expressions/snapshots/snaphshot_0_005__expression.snap b/src/nodes/expressions/snapshots/snaphshot_0_005__expression.snap new file mode 100644 index 0000000..a31d0f7 --- /dev/null +++ b/src/nodes/expressions/snapshots/snaphshot_0_005__expression.snap @@ -0,0 +1,17 @@ +--- +source: src/nodes/expressions/mod.rs +expression: "Expression::from(0.005)" +--- +Number( + Decimal( + DecimalNumber { + float: 5.0, + exponent: Some( + ( + -3, + true, + ), + ), + }, + ), +) diff --git a/src/nodes/expressions/snapshots/snaphshot_0_005__lua_float.snap b/src/nodes/expressions/snapshots/snaphshot_0_005__lua_float.snap new file mode 100644 index 0000000..6940713 --- /dev/null +++ b/src/nodes/expressions/snapshots/snaphshot_0_005__lua_float.snap @@ -0,0 +1,5 @@ +--- +source: src/nodes/expressions/mod.rs +expression: "Expression::from(0.005).to_lua_string()" +--- +5E-3 diff --git a/src/nodes/expressions/snapshots/snaphshot_0_5__expression.snap b/src/nodes/expressions/snapshots/snaphshot_0_5__expression.snap new file mode 100644 index 0000000..40cad83 --- /dev/null +++ b/src/nodes/expressions/snapshots/snaphshot_0_5__expression.snap @@ -0,0 +1,12 @@ +--- +source: src/nodes/expressions/mod.rs +expression: "Expression::from(0.5)" +--- +Number( + Decimal( + DecimalNumber { + float: 0.5, + exponent: None, + }, + ), +) diff --git a/src/nodes/expressions/snapshots/snaphshot_0_5__lua_float.snap b/src/nodes/expressions/snapshots/snaphshot_0_5__lua_float.snap new file mode 100644 index 0000000..2505b4f --- /dev/null +++ b/src/nodes/expressions/snapshots/snaphshot_0_5__lua_float.snap @@ -0,0 +1,5 @@ +--- +source: src/nodes/expressions/mod.rs +expression: "Expression::from(0.5).to_lua_string()" +--- +0.5 diff --git a/src/nodes/expressions/snapshots/snaphshot_123__expression.snap b/src/nodes/expressions/snapshots/snaphshot_123__expression.snap new file mode 100644 index 0000000..2f187bf --- /dev/null +++ b/src/nodes/expressions/snapshots/snaphshot_123__expression.snap @@ -0,0 +1,12 @@ +--- +source: src/nodes/expressions/mod.rs +expression: "Expression::from(123.0)" +--- +Number( + Decimal( + DecimalNumber { + float: 123.0, + exponent: None, + }, + ), +) diff --git a/src/nodes/expressions/snapshots/snaphshot_123__lua_float.snap b/src/nodes/expressions/snapshots/snaphshot_123__lua_float.snap new file mode 100644 index 0000000..9787848 --- /dev/null +++ b/src/nodes/expressions/snapshots/snaphshot_123__lua_float.snap @@ -0,0 +1,5 @@ +--- +source: src/nodes/expressions/mod.rs +expression: "Expression::from(123.0).to_lua_string()" +--- +123 diff --git a/src/nodes/expressions/snapshots/snaphshot_1__expression.snap b/src/nodes/expressions/snapshots/snaphshot_1__expression.snap new file mode 100644 index 0000000..1b653dc --- /dev/null +++ b/src/nodes/expressions/snapshots/snaphshot_1__expression.snap @@ -0,0 +1,12 @@ +--- +source: src/nodes/expressions/mod.rs +expression: "Expression::from(1.0)" +--- +Number( + Decimal( + DecimalNumber { + float: 1.0, + exponent: None, + }, + ), +) diff --git a/src/nodes/expressions/snapshots/snaphshot_1__lua_float.snap b/src/nodes/expressions/snapshots/snaphshot_1__lua_float.snap new file mode 100644 index 0000000..a38925e --- /dev/null +++ b/src/nodes/expressions/snapshots/snaphshot_1__lua_float.snap @@ -0,0 +1,5 @@ +--- +source: src/nodes/expressions/mod.rs +expression: "Expression::from(1.0).to_lua_string()" +--- +1 diff --git a/src/nodes/expressions/snapshots/snaphshot_nan__expression.snap b/src/nodes/expressions/snapshots/snaphshot_nan__expression.snap new file mode 100644 index 0000000..945db10 --- /dev/null +++ b/src/nodes/expressions/snapshots/snaphshot_nan__expression.snap @@ -0,0 +1,25 @@ +--- +source: src/nodes/expressions/mod.rs +expression: "Expression::from(0.0 / 0.0)" +--- +Binary( + BinaryExpression { + operator: Slash, + left: Number( + Decimal( + DecimalNumber { + float: 0.0, + exponent: None, + }, + ), + ), + right: Number( + Decimal( + DecimalNumber { + float: 0.0, + exponent: None, + }, + ), + ), + }, +) diff --git a/src/nodes/expressions/snapshots/snaphshot_nan__lua_float.snap b/src/nodes/expressions/snapshots/snaphshot_nan__lua_float.snap new file mode 100644 index 0000000..62556e2 --- /dev/null +++ b/src/nodes/expressions/snapshots/snaphshot_nan__lua_float.snap @@ -0,0 +1,5 @@ +--- +source: src/nodes/expressions/mod.rs +expression: "Expression::from(0.0 / 0.0).to_lua_string()" +--- +0/0 diff --git a/src/nodes/expressions/snapshots/snaphshot_negative_infinity__expression.snap b/src/nodes/expressions/snapshots/snaphshot_negative_infinity__expression.snap new file mode 100644 index 0000000..ce1cbe3 --- /dev/null +++ b/src/nodes/expressions/snapshots/snaphshot_negative_infinity__expression.snap @@ -0,0 +1,30 @@ +--- +source: src/nodes/expressions/mod.rs +expression: "Expression::from(-1.0 / 0.0)" +--- +Binary( + BinaryExpression { + operator: Slash, + left: Unary( + UnaryExpression { + operator: Minus, + expression: Number( + Decimal( + DecimalNumber { + float: 1.0, + exponent: None, + }, + ), + ), + }, + ), + right: Number( + Decimal( + DecimalNumber { + float: 0.0, + exponent: None, + }, + ), + ), + }, +) diff --git a/src/nodes/expressions/snapshots/snaphshot_negative_infinity__lua_float.snap b/src/nodes/expressions/snapshots/snaphshot_negative_infinity__lua_float.snap new file mode 100644 index 0000000..7f334e6 --- /dev/null +++ b/src/nodes/expressions/snapshots/snaphshot_negative_infinity__lua_float.snap @@ -0,0 +1,5 @@ +--- +source: src/nodes/expressions/mod.rs +expression: "Expression::from(-1.0 / 0.0).to_lua_string()" +--- +-1/0 diff --git a/src/nodes/expressions/snapshots/snaphshot_positive_infinity__expression.snap b/src/nodes/expressions/snapshots/snaphshot_positive_infinity__expression.snap new file mode 100644 index 0000000..6389e19 --- /dev/null +++ b/src/nodes/expressions/snapshots/snaphshot_positive_infinity__expression.snap @@ -0,0 +1,25 @@ +--- +source: src/nodes/expressions/mod.rs +expression: "Expression::from(1.0 / 0.0)" +--- +Binary( + BinaryExpression { + operator: Slash, + left: Number( + Decimal( + DecimalNumber { + float: 1.0, + exponent: None, + }, + ), + ), + right: Number( + Decimal( + DecimalNumber { + float: 0.0, + exponent: None, + }, + ), + ), + }, +) diff --git a/src/nodes/expressions/snapshots/snaphshot_positive_infinity__lua_float.snap b/src/nodes/expressions/snapshots/snaphshot_positive_infinity__lua_float.snap new file mode 100644 index 0000000..b825948 --- /dev/null +++ b/src/nodes/expressions/snapshots/snaphshot_positive_infinity__lua_float.snap @@ -0,0 +1,5 @@ +--- +source: src/nodes/expressions/mod.rs +expression: "Expression::from(1.0 / 0.0).to_lua_string()" +--- +1/0 diff --git a/src/nodes/expressions/snapshots/snaphshot_very_large__expression.snap b/src/nodes/expressions/snapshots/snaphshot_very_large__expression.snap new file mode 100644 index 0000000..5c72c91 --- /dev/null +++ b/src/nodes/expressions/snapshots/snaphshot_very_large__expression.snap @@ -0,0 +1,17 @@ +--- +source: src/nodes/expressions/mod.rs +expression: "Expression::from(1.2345e50)" +--- +Number( + Decimal( + DecimalNumber { + float: 12345.0, + exponent: Some( + ( + 46, + true, + ), + ), + }, + ), +) diff --git a/src/nodes/expressions/snapshots/snaphshot_very_large__lua_float.snap b/src/nodes/expressions/snapshots/snaphshot_very_large__lua_float.snap new file mode 100644 index 0000000..f0af568 --- /dev/null +++ b/src/nodes/expressions/snapshots/snaphshot_very_large__lua_float.snap @@ -0,0 +1,5 @@ +--- +source: src/nodes/expressions/mod.rs +expression: "Expression::from(1.2345e50).to_lua_string()" +--- +12345E46 diff --git a/src/nodes/expressions/snapshots/snaphshot_very_small__expression.snap b/src/nodes/expressions/snapshots/snaphshot_very_small__expression.snap new file mode 100644 index 0000000..7ed172b --- /dev/null +++ b/src/nodes/expressions/snapshots/snaphshot_very_small__expression.snap @@ -0,0 +1,17 @@ +--- +source: src/nodes/expressions/mod.rs +expression: "Expression::from(1.2345e-50)" +--- +Number( + Decimal( + DecimalNumber { + float: 1.2345, + exponent: Some( + ( + -50, + true, + ), + ), + }, + ), +) diff --git a/src/nodes/expressions/snapshots/snaphshot_very_small__lua_float.snap b/src/nodes/expressions/snapshots/snaphshot_very_small__lua_float.snap new file mode 100644 index 0000000..e4b67c8 --- /dev/null +++ b/src/nodes/expressions/snapshots/snaphshot_very_small__lua_float.snap @@ -0,0 +1,5 @@ +--- +source: src/nodes/expressions/mod.rs +expression: "Expression::from(1.2345e-50).to_lua_string()" +--- +1.2345E-50 diff --git a/src/nodes/expressions/snapshots/snapshot_float_above_thousand__expression.snap b/src/nodes/expressions/snapshots/snapshot_float_above_thousand__expression.snap new file mode 100644 index 0000000..2bcf0af --- /dev/null +++ b/src/nodes/expressions/snapshots/snapshot_float_above_thousand__expression.snap @@ -0,0 +1,12 @@ +--- +source: src/nodes/expressions/mod.rs +expression: "Expression::from(2000.05)" +--- +Number( + Decimal( + DecimalNumber { + float: 2000.05, + exponent: None, + }, + ), +) diff --git a/src/nodes/expressions/snapshots/snapshot_float_above_thousand__lua_float.snap b/src/nodes/expressions/snapshots/snapshot_float_above_thousand__lua_float.snap new file mode 100644 index 0000000..5e81b6a --- /dev/null +++ b/src/nodes/expressions/snapshots/snapshot_float_above_thousand__lua_float.snap @@ -0,0 +1,5 @@ +--- +source: src/nodes/expressions/mod.rs +expression: "Expression::from(2000.05).to_lua_string()" +--- +2000.05 diff --git a/src/nodes/expressions/snapshots/snapshot_float_below_thousand__expression.snap b/src/nodes/expressions/snapshots/snapshot_float_below_thousand__expression.snap new file mode 100644 index 0000000..e633d33 --- /dev/null +++ b/src/nodes/expressions/snapshots/snapshot_float_below_thousand__expression.snap @@ -0,0 +1,12 @@ +--- +source: src/nodes/expressions/mod.rs +expression: "Expression::from(100.25)" +--- +Number( + Decimal( + DecimalNumber { + float: 100.25, + exponent: None, + }, + ), +) diff --git a/src/nodes/expressions/snapshots/snapshot_float_below_thousand__lua_float.snap b/src/nodes/expressions/snapshots/snapshot_float_below_thousand__lua_float.snap new file mode 100644 index 0000000..39b8496 --- /dev/null +++ b/src/nodes/expressions/snapshots/snapshot_float_below_thousand__lua_float.snap @@ -0,0 +1,5 @@ +--- +source: src/nodes/expressions/mod.rs +expression: "Expression::from(100.25).to_lua_string()" +--- +100.25 diff --git a/src/nodes/expressions/snapshots/snapshot_thousand__expression.snap b/src/nodes/expressions/snapshots/snapshot_thousand__expression.snap new file mode 100644 index 0000000..dd9fb3c --- /dev/null +++ b/src/nodes/expressions/snapshots/snapshot_thousand__expression.snap @@ -0,0 +1,17 @@ +--- +source: src/nodes/expressions/mod.rs +expression: "Expression::from(1000.0)" +--- +Number( + Decimal( + DecimalNumber { + float: 1.0, + exponent: Some( + ( + 3, + true, + ), + ), + }, + ), +) diff --git a/src/nodes/expressions/snapshots/snapshot_thousand__lua_float.snap b/src/nodes/expressions/snapshots/snapshot_thousand__lua_float.snap new file mode 100644 index 0000000..487bf42 --- /dev/null +++ b/src/nodes/expressions/snapshots/snapshot_thousand__lua_float.snap @@ -0,0 +1,5 @@ +--- +source: src/nodes/expressions/mod.rs +expression: "Expression::from(1000.0).to_lua_string()" +--- +1E3 diff --git a/src/nodes/expressions/unary.rs b/src/nodes/expressions/unary.rs index 75b5ab1..fde1bc7 100644 --- a/src/nodes/expressions/unary.rs +++ b/src/nodes/expressions/unary.rs @@ -56,13 +56,22 @@ impl UnaryExpression { impl ToLua for UnaryExpression { fn to_lua(&self, generator: &mut LuaGenerator) { self.operator.to_lua(generator); - self.expression.to_lua(generator); + + match &self.expression { + Expression::Binary(binary) if !binary.operator().precedes_unary_expression() => { + generator.push_char('('); + self.expression.to_lua(generator); + generator.push_char(')'); + }, + _ => self.expression.to_lua(generator), + } } } #[cfg(test)] mod test { use super::*; + use crate::nodes::{BinaryExpression, BinaryOperator, DecimalNumber}; #[test] fn generate_unary_expression() { @@ -86,4 +95,32 @@ mod test { assert_eq!(output, "- -a"); } + + #[test] + fn wraps_in_parens_if_an_inner_binary_has_lower_precedence() { + let output = UnaryExpression::new( + UnaryOperator::Not, + BinaryExpression::new( + BinaryOperator::Or, + Expression::False, + Expression::True, + ).into(), + ).to_lua_string(); + + assert_eq!(output, "not(false or true)"); + } + + #[test] + fn does_not_wrap_in_parens_if_an_inner_binary_has_higher_precedence() { + let output = UnaryExpression::new( + UnaryOperator::Minus, + BinaryExpression::new( + BinaryOperator::Caret, + DecimalNumber::new(2.0).into(), + DecimalNumber::new(2.0).into(), + ).into(), + ).to_lua_string(); + + assert_eq!(output, "-2^2"); + } } diff --git a/src/parser.rs b/src/parser.rs index a38c2de..d272366 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -232,68 +232,18 @@ impl From<(Prefix, Expression)> for IndexExpression { } } +impl From<&str> for NumberExpression { + fn from(value: &str) -> Self { + match value.parse() { + Ok(value) => value, + Err(error) => panic!("{}", error), + } + } +} + impl From for NumberExpression { fn from(value: String) -> Self { - if value.starts_with("0x") || value.starts_with("0X") { - let is_x_uppercase = value.chars().nth(1) - .map(char::is_uppercase) - .unwrap_or(false); - - if let Some(index) = value.find("p") { - let exponent = value.get(index + 1..).unwrap() - .parse::() - .expect("could not parse hexadecimal exponent"); - let number = u64::from_str_radix(value.get(2..index).unwrap(), 16) - .expect("could not parse hexadecimal number"); - - HexNumber::new(number, is_x_uppercase) - .with_exponent(exponent, false) - - } else if let Some(index) = value.find("P") { - let exponent = value.get(index + 1..).unwrap() - .parse::() - .expect("could not parse hexadecimal exponent"); - let number = u64::from_str_radix(value.get(2..index).unwrap(), 16) - .expect("could not parse hexadecimal number"); - - HexNumber::new(number, is_x_uppercase) - .with_exponent(exponent, true) - } else { - let number = u64::from_str_radix(value.get(2..).unwrap(), 16) - .expect(&format!("could not parse hexadecimal number: {}", value)); - - HexNumber::new(number, is_x_uppercase) - }.into() - - } else { - if let Some(index) = value.find("e") { - let exponent = value.get(index + 1..).unwrap() - .parse::() - .expect("could not parse decimal exponent"); - let number = value.get(0..index).unwrap() - .parse::() - .expect("could not parse decimal number"); - - DecimalNumber::new(number) - .with_exponent(exponent, false) - - } else if let Some(index) = value.find("E") { - let exponent = value.get(index + 1..).unwrap() - .parse::() - .expect("could not parse decimal exponent"); - let number = value.get(0..index).unwrap() - .parse::() - .expect("could not parse decimal number"); - - DecimalNumber::new(number) - .with_exponent(exponent, true) - } else { - let number = value.parse::() - .expect("could not parse number"); - - DecimalNumber::new(number) - }.into() - } + NumberExpression::from(value.as_str()) } } @@ -351,50 +301,3 @@ impl builders::UnaryOperator for UnaryOperator { fn length() -> Self { Self::Length } fn not() -> Self { Self::Not } } - -#[cfg(test)] -mod test { - use super::*; - - mod number_expression { - use super::*; - - macro_rules! test_numbers { - ($($name:ident($input:literal) => $expect:expr),+) => { - $( - #[test] - fn $name() { - let result = NumberExpression::from($input.to_owned()); - - let expect: NumberExpression = $expect.into(); - - assert_eq!(result, expect); - } - )+ - }; - } - - test_numbers!( - parse_zero("0") => DecimalNumber::new(0_f64), - parse_integer("123") => DecimalNumber::new(123_f64), - parse_multiple_decimal("123.24") => DecimalNumber::new(123.24_f64), - parse_float_with_trailing_dot("123.") => DecimalNumber::new(123_f64), - parse_starting_with_dot(".123") => DecimalNumber::new(0.123_f64), - parse_digit_with_exponent("1e10") => DecimalNumber::new(1_f64).with_exponent(10, false), - parse_number_with_exponent("123e456") => DecimalNumber::new(123_f64).with_exponent(456, false), - parse_number_with_exponent_and_plus_symbol("123e+456") => DecimalNumber::new(123_f64).with_exponent(456, false), - parse_number_with_negative_exponent("123e-456") => DecimalNumber::new(123_f64).with_exponent(-456, false), - parse_number_with_upper_exponent("123E4") => DecimalNumber::new(123_f64).with_exponent(4, true), - parse_number_with_upper_negative_exponent("123E-456") => DecimalNumber::new(123_f64).with_exponent(-456, true), - parse_float_with_exponent("10.12e8") => DecimalNumber::new(10.12_f64).with_exponent(8, false), - parse_trailing_dot_with_exponent("10.e8") => DecimalNumber::new(10_f64).with_exponent(8, false), - parse_hex_number("0x12") => HexNumber::new(18, false), - parse_uppercase_hex_number("0X12") => HexNumber::new(18, true), - parse_hex_number_with_lowercase("0x12a") => HexNumber::new(298, false), - parse_hex_number_with_uppercase("0x12A") => HexNumber::new(298, false), - parse_hex_number_with_mixed_case("0x1bF2A") => HexNumber::new(114_474, false), - parse_hex_with_exponent("0x12p4") => HexNumber::new(18, false).with_exponent(4, false), - parse_hex_with_exponent_uppercase("0xABP3") => HexNumber::new(171, false).with_exponent(3, true) - ); - } -} diff --git a/src/process/evaluator/lua_value.rs b/src/process/evaluator/lua_value.rs index a61255f..2b02cc3 100644 --- a/src/process/evaluator/lua_value.rs +++ b/src/process/evaluator/lua_value.rs @@ -1,3 +1,5 @@ +use crate::nodes::{Expression, NumberExpression, StringExpression}; + /// Represents an evaluated Expression result. #[derive(Debug, Clone, PartialEq)] pub enum LuaValue { @@ -60,6 +62,40 @@ impl LuaValue { _ => Self::Unknown, } } + + /// Attempt to convert the Lua value into an expression node. + pub fn to_expression(self) -> Option { + match self { + Self::False => Some(Expression::False), + Self::True => Some(Expression::True), + Self::Nil => Some(Expression::Nil), + Self::String(value) => Some(StringExpression::from_value(value).into()), + Self::Number(value) => Some(Expression::from(value)), + _ => None + } + } + + /// Attempt to convert the Lua value into a number value. This will convert strings when + /// possible and return the same value otherwise. + pub fn number_coercion(self) -> Self { + match &self { + Self::String(string) => { + let string = string.trim(); + + let number = if string.starts_with('-') { + string.get(1..) + .and_then(|string| string.parse::().ok()) + .map(|number| number.compute_value() * -1.0) + } else { + string.parse::().ok() + .map(|number| number.compute_value()) + }; + + number.map(LuaValue::Number) + } + _ => None, + }.unwrap_or(self) + } } impl Default for LuaValue { @@ -123,4 +159,59 @@ mod test { fn string_value_is_truthy() { assert!(LuaValue::String("".to_owned()).is_truthy().unwrap()); } + + mod number_coercion { + use super::*; + + macro_rules! number_coercion { + ($($name:ident ($string:literal) => $result:expr),*) => { + $( + #[test] + fn $name() { + assert_eq!( + LuaValue::String($string.into()).number_coercion(), + LuaValue::Number($result) + ); + } + )* + }; + } + + macro_rules! no_number_coercion { + ($($name:ident ($string:literal)),*) => { + $( + #[test] + fn $name() { + assert_eq!( + LuaValue::String($string.into()).number_coercion(), + LuaValue::String($string.into()) + ); + } + )* + }; + } + + number_coercion!( + zero("0") => 0.0, + integer("12") => 12.0, + integer_with_leading_zeros("00012") => 12.0, + integer_with_ending_space("12 ") => 12.0, + integer_with_leading_space(" 123") => 123.0, + integer_with_leading_tab("\t123") => 123.0, + negative_integer("-3") => -3.0, + hex_zero("0x0") => 0.0, + hex_integer("0xA") => 10.0, + negative_hex_integer("-0xA") => -10.0, + float("0.5") => 0.5, + negative_float("-0.5") => -0.5, + float_starting_with_dot(".5") => 0.5 + ); + + no_number_coercion!( + letter_suffix("123a"), + hex_prefix("0x"), + space_between_minus("- 1"), + two_seperated_digits(" 1 2") + ); + } } diff --git a/src/process/evaluator/mod.rs b/src/process/evaluator/mod.rs index e0f9400..cfc4604 100644 --- a/src/process/evaluator/mod.rs +++ b/src/process/evaluator/mod.rs @@ -48,7 +48,7 @@ impl Evaluator { | Expression::String(_) | Expression::True | Expression::VariableArguments => false, - | Expression::Binary(binary) => { + Expression::Binary(binary) => { if self.pure_metamethods { self.has_side_effects(binary.left()) || self.has_side_effects(binary.left()) } else { @@ -62,7 +62,7 @@ impl Evaluator { } } - | Expression::Unary(unary) => { + Expression::Unary(unary) => { if self.pure_metamethods { self.has_side_effects(unary.get_expression()) } else { @@ -159,6 +159,14 @@ impl Evaluator { _ => LuaValue::Unknown, } } + BinaryOperator::Plus => self.evaluate_math(expression, |a, b| a + b), + BinaryOperator::Minus => self.evaluate_math(expression, |a, b| a - b), + BinaryOperator::Asterisk => self.evaluate_math(expression, |a, b| a * b), + BinaryOperator::Slash => self.evaluate_math(expression, |a, b| a / b), + BinaryOperator::Caret => self.evaluate_math(expression, |a, b| a.powf(b)), + BinaryOperator::Percent => { + self.evaluate_math(expression, |a, b| a - b * (a / b).floor()) + } _ => LuaValue::Unknown, } } @@ -175,6 +183,24 @@ impl Evaluator { } } + fn evaluate_math(&self, expression: &BinaryExpression, operation: F) -> LuaValue + where F: Fn(f64, f64) -> f64 + { + let left = self.evaluate(expression.left()).number_coercion(); + + if let LuaValue::Number(left) = left { + let right = self.evaluate(expression.right()).number_coercion(); + + if let LuaValue::Number(right) = right { + LuaValue::Number(operation(left, right)) + } else { + LuaValue::Unknown + } + } else { + LuaValue::Unknown + } + } + fn evaluate_unary(&self, expression: &UnaryExpression) -> LuaValue { match expression.operator() { UnaryOperator::Not => { @@ -183,6 +209,12 @@ impl Evaluator { .map(|value| LuaValue::from(!value)) .unwrap_or(LuaValue::Unknown) } + UnaryOperator::Minus => { + match self.evaluate(expression.get_expression()).number_coercion() { + LuaValue::Number(value) => LuaValue::from(-value), + _ => LuaValue::Unknown + } + } _ => LuaValue::Unknown, } } @@ -216,12 +248,29 @@ mod test { use super::*; macro_rules! evaluate_binary_expressions { - ($($name:ident ($operator:expr, $left:expr, $right:expr) => $value:expr),*) => { + ($($name:ident ($operator:expr, $left:expr, $right:expr) => $expect:expr),*) => { $( #[test] fn $name() { let binary = BinaryExpression::new($operator, $left.into(), $right.into()); - assert_eq!($value, Evaluator::default().evaluate(&binary.into())); + + let result = Evaluator::default().evaluate(&binary.into()); + + match (&$expect, &result) { + (LuaValue::Number(expect_float), LuaValue::Number(result))=> { + if expect_float.is_nan() { + assert!(result.is_nan(), "{} should be NaN", result); + } else { + assert!( + result == expect_float || (expect_float - result).abs() < 0.1e-10, + "{} does not approximate {}", result, expect_float + ); + } + } + _ => { + assert_eq!($expect, result); + } + } } )* }; @@ -307,7 +356,67 @@ mod test { BinaryOperator::Or, Expression::Nil, Expression::Nil - ) => LuaValue::Nil + ) => LuaValue::Nil, + one_plus_two( + BinaryOperator::Plus, + Expression::from(1.0), + Expression::from(2.0) + ) => LuaValue::Number(3.0), + one_minus_two( + BinaryOperator::Minus, + Expression::from(1.0), + Expression::from(2.0) + ) => LuaValue::Number(-1.0), + three_times_four( + BinaryOperator::Asterisk, + Expression::from(3.0), + Expression::from(4.0) + ) => LuaValue::Number(12.0), + twelve_divided_by_four( + BinaryOperator::Slash, + Expression::from(12.0), + Expression::from(4.0) + ) => LuaValue::Number(3.0), + one_divided_by_zero( + BinaryOperator::Slash, + Expression::from(1.0), + Expression::from(0.0) + ) => LuaValue::Number(std::f64::INFINITY), + zero_divided_by_zero( + BinaryOperator::Slash, + Expression::from(0.0), + Expression::from(0.0) + ) => LuaValue::Number(std::f64::NAN), + five_mod_two( + BinaryOperator::Percent, + Expression::from(5.0), + Expression::from(2.0) + ) => LuaValue::Number(1.0), + minus_five_mod_two( + BinaryOperator::Percent, + Expression::from(-5.0), + Expression::from(2.0) + ) => LuaValue::Number(1.0), + minus_five_mod_minus_two( + BinaryOperator::Percent, + Expression::from(-5.0), + Expression::from(-2.0) + ) => LuaValue::Number(-1.0), + five_point_two_mod_two( + BinaryOperator::Percent, + Expression::from(5.5), + Expression::from(2.0) + ) => LuaValue::Number(1.5), + five_pow_two( + BinaryOperator::Caret, + Expression::from(5.0), + Expression::from(2.0) + ) => LuaValue::Number(25.0), + string_number_plus_string_number( + BinaryOperator::Plus, + StringExpression::from_value("2"), + StringExpression::from_value("3") + ) => LuaValue::Number(5.0) ); macro_rules! evaluate_equality { @@ -390,6 +499,7 @@ mod test { mod unary_expressions { use super::*; + use UnaryOperator::*; macro_rules! evaluate_unary_expressions { ($($name:ident ($operator:expr, $input:expr) => $value:expr),*) => { @@ -404,16 +514,19 @@ mod test { } evaluate_unary_expressions!( - not_true(UnaryOperator::Not, Expression::True) => LuaValue::False, - not_false(UnaryOperator::Not, Expression::False) => LuaValue::True, - not_nil(UnaryOperator::Not, Expression::Nil) => LuaValue::True, - not_table(UnaryOperator::Not, TableExpression::default()) => LuaValue::False, - not_string(UnaryOperator::Not, StringExpression::from_value("foo")) => LuaValue::False, + not_true(Not, Expression::True) => LuaValue::False, + not_false(Not, Expression::False) => LuaValue::True, + not_nil(Not, Expression::Nil) => LuaValue::True, + not_table(Not, TableExpression::default()) => LuaValue::False, + not_string(Not, StringExpression::from_value("foo")) => LuaValue::False, not_number( - UnaryOperator::Not, + Not, Expression::Number(DecimalNumber::new(10.0).into()) ) => LuaValue::False, - not_identifier(UnaryOperator::Not, Expression::Identifier("foo".to_owned())) => LuaValue::Unknown + not_identifier(Not, Expression::Identifier("foo".to_owned())) => LuaValue::Unknown, + minus_one(Minus, DecimalNumber::new(1.0)) => LuaValue::from(-1.0), + minus_negative_number(Minus, DecimalNumber::new(-5.0)) => LuaValue::from(5.0), + minus_string_converted_to_number(Minus, StringExpression::from_value("1")) => LuaValue::from(-1.0) ); } diff --git a/src/rules/compute_expression.rs b/src/rules/compute_expression.rs new file mode 100644 index 0000000..8c2aecb --- /dev/null +++ b/src/rules/compute_expression.rs @@ -0,0 +1,81 @@ +use crate::nodes::{Block, Expression}; +use crate::process::{DefaultVisitor, Evaluator, NodeProcessor, NodeVisitor}; +use crate::rules::{Rule, RuleConfigurationError, RuleProperties}; + +use std::mem; + +#[derive(Debug, Clone, Default)] +struct Computer { + evaluator: Evaluator, +} + +impl Computer { + fn replace_with(&self, expression: &mut Expression) -> Option { + match expression { + Expression::Unary(_) | Expression::Binary(_) => { + if !self.evaluator.has_side_effects(&expression) { + self.evaluator.evaluate(&expression) + .to_expression() + } else { + None + } + } + _ => None, + } + } +} + +impl NodeProcessor for Computer { + fn process_expression(&mut self, expression: &mut Expression) { + if let Some(replace_with) = self.replace_with(expression) { + mem::replace(expression, replace_with); + } + } +} + +pub const COMPUTE_EXPRESSIONS_RULE_NAME: &'static str = "compute_expression"; + +/// A rule that compute expressions that do not have any side-effects. +#[derive(Debug, Default, PartialEq, Eq)] +pub struct ComputeExpression {} + +impl Rule for ComputeExpression { + fn process(&self, block: &mut Block) { + let mut processor = Computer::default(); + DefaultVisitor::visit_block(block, &mut processor); + } + + fn configure(&mut self, properties: RuleProperties) -> Result<(), RuleConfigurationError> { + for (key, _value) in properties { + return Err(RuleConfigurationError::UnexpectedProperty(key)) + } + + Ok(()) + } + + fn get_name(&self) -> &'static str { + COMPUTE_EXPRESSIONS_RULE_NAME + } + + fn serialize_to_properties(&self) -> RuleProperties { + RuleProperties::new() + } +} + +#[cfg(test)] +mod test { + use super::*; + + use insta::assert_json_snapshot; + + fn new_rule() -> ComputeExpression { + ComputeExpression::default() + } + + #[test] + fn serialize_default_rule() { + let rule: Box = Box::new(new_rule()); + + assert_json_snapshot!("default_compute_expression", rule); + } +} diff --git a/src/rules/mod.rs b/src/rules/mod.rs index e520880..facce02 100644 --- a/src/rules/mod.rs +++ b/src/rules/mod.rs @@ -1,7 +1,8 @@ //! A module that contains the different rules that mutates a Lua block. -mod empty_do; mod call_parens; +mod compute_expression; +mod empty_do; mod group_local; mod inject_value; mod method_def; @@ -10,8 +11,9 @@ mod rename_variables; mod unused_if_branch; mod unused_while; -pub use empty_do::*; pub use call_parens::*; +pub use compute_expression::*; +pub use empty_do::*; pub use group_local::*; pub use inject_value::*; pub use method_def::*; @@ -100,6 +102,7 @@ pub trait Rule { /// processed block will work as much as the original one. pub fn get_default_rules() -> Vec> { vec![ + Box::new(ComputeExpression::default()), Box::new(RemoveUnusedIfBranch::default()), Box::new(RemoveUnusedWhile::default()), Box::new(RemoveEmptyDo::default()), @@ -116,6 +119,7 @@ impl FromStr for Box { fn from_str(string: &str) -> Result { let rule: Box = match string { + COMPUTE_EXPRESSIONS_RULE_NAME => Box::new(ComputeExpression::default()), CONVERT_LOCAL_FUNCTION_TO_ASSIGN_RULE_NAME => Box::new(ConvertLocalFunctionToAssign::default()), GROUP_LOCAL_ASSIGNMENT => Box::new(GroupLocalAssignment::default()), INJECT_GLOBAL_VALUE_RULE_NAME => Box::new(InjectGlobalValue::default()), diff --git a/src/rules/snapshots/test__default_compute_expression.snap b/src/rules/snapshots/test__default_compute_expression.snap new file mode 100644 index 0000000..d645856 --- /dev/null +++ b/src/rules/snapshots/test__default_compute_expression.snap @@ -0,0 +1,5 @@ +--- +source: src/rules/compute_expression.rs +expression: rule +--- +"compute_expression" diff --git a/src/rules/snapshots/test__default_rules.snap b/src/rules/snapshots/test__default_rules.snap index 08c2697..7e539cf 100644 --- a/src/rules/snapshots/test__default_rules.snap +++ b/src/rules/snapshots/test__default_rules.snap @@ -3,6 +3,7 @@ source: src/rules/mod.rs expression: rules --- [ + "compute_expression", "remove_unused_if_branch", "remove_unused_while", "remove_empty_do", diff --git a/tests/fuzz.rs b/tests/fuzz.rs index ccf51f2..c9f66cc 100644 --- a/tests/fuzz.rs +++ b/tests/fuzz.rs @@ -396,13 +396,6 @@ fn get_binary_operator(expression: &Expression) -> Option { _ => None, } } -#[inline] -fn is_unary_expression(expression: &Expression) -> bool { - match expression { - Expression::Unary(_) => true, - _ => false, - } -} impl Fuzz for BinaryExpression { fn fuzz(context: &mut FuzzContext) -> Self { @@ -410,23 +403,11 @@ impl Fuzz for BinaryExpression { let mut left = Expression::fuzz(context); let mut right = Expression::fuzz(context); - if let Some(left_operator) = get_binary_operator(&left) { - if (left_operator == operator && operator.is_right_associative()) - || !left_operator.precedes(operator) - { - left = Expression::Parenthese(left.into()); - } - } else if is_unary_expression(&left) && operator.preceeds_unary_expression() { + if operator.left_needs_parentheses(&left) { left = Expression::Parenthese(left.into()); } - if let Some(right_operator) = get_binary_operator(&right) { - if (right_operator == operator && operator.is_left_associative()) - || !right_operator.precedes(operator) - { - right = Expression::Parenthese(right.into()); - } - } else if is_unary_expression(&right) && operator.preceeds_unary_expression() { + if operator.right_needs_parentheses(&right) { right = Expression::Parenthese(right.into()); } @@ -604,7 +585,7 @@ impl Fuzz for UnaryExpression { let mut expression = Expression::fuzz(context); if let Some(inner_operator) = get_binary_operator(&expression) { - if !inner_operator.preceeds_unary_expression() { + if !inner_operator.precedes_unary_expression() { expression = Expression::Parenthese(expression.into()); } } diff --git a/tests/fuzz_generator.rs b/tests/fuzz_generator.rs index 63a3694..291e11e 100644 --- a/tests/fuzz_generator.rs +++ b/tests/fuzz_generator.rs @@ -1,4 +1,16 @@ -use darklua_core::{nodes::Block, LuaGenerator, ToLua}; +use darklua_core::{ + nodes::{ + BinaryExpression, + BinaryOperator, + Block, + Expression, + LastStatement, + UnaryExpression, + UnaryOperator, + }, + LuaGenerator, + ToLua, +}; use std::time::{Duration, Instant}; mod fuzz; @@ -6,7 +18,73 @@ mod utils; use fuzz::*; -macro_rules! fuzz_test { +macro_rules! fuzz_test_expression { + ($node:expr, $column_span:expr) => { + let node: Expression = $node.into(); + let mut generator = LuaGenerator::new($column_span); + node.to_lua(&mut generator); + let lua_code = format!("return {}", generator.into_string()); + + let mut generated_block = match utils::try_parse_input(&lua_code) { + Ok(block) => block, + Err(error) => panic!( + concat!( + "could not parse content: {:?}\n", + "============================================================\n", + ">>> Lua code input:\n{}\n", + "============================================================\n", + "\n", + "============================================================\n", + ">>> Node that produced the generated code:\n{:?}\n", + "============================================================\n", + ), + error, + lua_code, + node, + ), + }; + + let last_statement = generated_block.mutate_last_statement() + .take() + .expect("should have a last statement"); + + let generated_node = match last_statement { + LastStatement::Return(expressions) => { + if expressions.len() != 1 { + panic!("should have exactly one expression") + } + expressions.into_iter().next().unwrap() + } + _ => panic!("return statement expected"), + }; + + assert_eq!( + node, + generated_node, + concat!( + "\n", + "============================================================\n", + ">>> Generated from node fuzz:\n{:#?}\n", + ">>> Lua code generated:\n{}\n", + "============================================================\n", + "\n", + "============================================================\n", + ">>> Parsed node:\n{:#?}\n", + ">>> Node code generated:\n{}\n", + "============================================================\n", + ), + node, + lua_code, + generated_node, + generated_node.to_lua_string(), + ); + }; + ($node:expr) => { + fuzz_test_expression!($node, 80); + }; +} + +macro_rules! fuzz_test_block { ($context:expr, $column_span:expr) => { let block = Block::fuzz(&mut $context); @@ -39,12 +117,12 @@ macro_rules! fuzz_test { concat!( "\n", "============================================================\n", - ">>> Block generated from node fuzz:\n{:?}\n", + ">>> Generated from block fuzz:\n{:?}\n", ">>> Lua code generated:\n{}\n", "============================================================\n", "\n", "============================================================\n", - ">>> Block generated from parsed generated code:\n{:?}\n", + ">>> Parsed generated block:\n{:?}\n", ">>> Lua code generated:\n{}\n", "============================================================\n", ), @@ -55,7 +133,7 @@ macro_rules! fuzz_test { ); }; ($context:expr) => { - fuzz_test!($context, 80); + fuzz_test_block!($context, 80); }; } @@ -80,31 +158,116 @@ fn get_fuzz_duration() -> Duration { Duration::from_millis(millis) } +#[test] +fn fuzz_three_terms_binary_expressions() { + run_for_minimum_time(|| { + let mut empty_context = FuzzContext::new(0, 0); + let first = Expression::True; + let second = Expression::False; + let third = Expression::Nil; + + let (left, right) = if rand::random() { + ( + BinaryExpression::new( + BinaryOperator::fuzz(&mut empty_context), + first, + second, + ).into(), + third, + ) + } else { + ( + first, + BinaryExpression::new( + BinaryOperator::fuzz(&mut empty_context), + second, + third, + ).into(), + ) + }; + + let operator = BinaryOperator::fuzz(&mut empty_context); + let binary = BinaryExpression::new( + operator, + if operator.left_needs_parentheses(&left) { + Expression::Parenthese(left.into()) + } else { + left + }, + if operator.right_needs_parentheses(&right) { + Expression::Parenthese(right.into()) + } else { + right + } + ); + + fuzz_test_expression!(binary); + }); +} + +#[test] +fn fuzz_binary_expressions_with_one_unary_expression() { + run_for_minimum_time(|| { + let mut empty_context = FuzzContext::new(0, 0); + let first = Expression::True; + let second = Expression::False; + + let (left, right) = if rand::random() { + ( + UnaryExpression::new(UnaryOperator::fuzz(&mut empty_context), first).into(), + second, + ) + } else { + ( + first, + UnaryExpression::new(UnaryOperator::fuzz(&mut empty_context), second).into(), + ) + }; + + let operator = BinaryOperator::fuzz(&mut empty_context); + let binary = BinaryExpression::new( + operator, + if operator.left_needs_parentheses(&left) { + Expression::Parenthese(left.into()) + } else { + left + }, + if operator.right_needs_parentheses(&right) { + Expression::Parenthese(right.into()) + } else { + right + } + ); + + fuzz_test_expression!(binary); + }); +} + #[test] fn fuzz_single_statement() { run_for_minimum_time(|| { - fuzz_test!(FuzzContext::new(1, 5)); + fuzz_test_block!(FuzzContext::new(1, 5)); }); } #[test] fn fuzz_small_block() { run_for_minimum_time(|| { - fuzz_test!(FuzzContext::new(20, 40)); + fuzz_test_block!(FuzzContext::new(20, 40)); }); } #[test] fn fuzz_medium_block() { run_for_minimum_time(|| { - fuzz_test!(FuzzContext::new(100, 200)); + fuzz_test_block!(FuzzContext::new(100, 200)); }); } #[test] fn fuzz_large_block() { run_for_minimum_time(|| { - fuzz_test!(FuzzContext::new(200, 500)); + fuzz_test_block!(FuzzContext::new(200, 500)); }); } @@ -112,7 +275,7 @@ fn fuzz_large_block() { fn fuzz_column_span() { run_for_minimum_time(|| { for i in 0..80 { - fuzz_test!(FuzzContext::new(20, 40), i); + fuzz_test_block!(FuzzContext::new(20, 40), i); } }); } diff --git a/tests/rule_tests/compute_expression.rs b/tests/rule_tests/compute_expression.rs new file mode 100644 index 0000000..bbf1b13 --- /dev/null +++ b/tests/rule_tests/compute_expression.rs @@ -0,0 +1,23 @@ +use darklua_core::rules::{ComputeExpression, Rule}; + +test_rule!( + ComputeExpression::default(), + binary_true_and_false("return true and false") => "return false", + number_addition("return 1 + 2") => "return 3", + multiple_addition("return 1 + 2 + 5") => "return 8", + division("return 1/3") => "return 0.3333333333333333", + division_test("return 3 * 0.3333333333333333") => "return 1", + multiply_small_number("return 2 * 1e-50") => "return 2E-50" +); + +#[test] +fn deserialize_from_object_notation() { + json5::from_str::>(r#"{ + rule: 'compute_expression', + }"#).unwrap(); +} + +#[test] +fn deserialize_from_string() { + json5::from_str::>("'compute_expression'").unwrap(); +} diff --git a/tests/rule_tests/group_local_assignment.rs b/tests/rule_tests/group_local_assignment.rs index 892a92d..a327ce3 100644 --- a/tests/rule_tests/group_local_assignment.rs +++ b/tests/rule_tests/group_local_assignment.rs @@ -1,4 +1,4 @@ -use darklua_core::rules::GroupLocalAssignment; +use darklua_core::rules::{GroupLocalAssignment, Rule}; test_rule!( GroupLocalAssignment::default(), @@ -14,3 +14,15 @@ test_rule_wihout_effects!( two_local_using_the_other("local foo = 1 local bar = foo"), multiple_return_values("local a, b = call() local c = 0") ); + +#[test] +fn deserialize_from_object_notation() { + json5::from_str::>(r#"{ + rule: 'group_local_assignment', + }"#).unwrap(); +} + +#[test] +fn deserialize_from_string() { + json5::from_str::>("'group_local_assignment'").unwrap(); +} diff --git a/tests/rule_tests/mod.rs b/tests/rule_tests/mod.rs index 91ffae4..dff194d 100644 --- a/tests/rule_tests/mod.rs +++ b/tests/rule_tests/mod.rs @@ -40,6 +40,7 @@ macro_rules! test_rule_wihout_effects { }; } +mod compute_expression; mod group_local_assignment; mod inject_value; mod no_local_function;