diff --git a/utils/fixed_decimal/src/decimal.rs b/utils/fixed_decimal/src/decimal.rs index 0fadc7f5b01..67b38a34634 100644 --- a/utils/fixed_decimal/src/decimal.rs +++ b/utils/fixed_decimal/src/decimal.rs @@ -934,63 +934,6 @@ impl FixedDecimal { self.check_invariants(); } - /// Increments the digits by 1. if the digits are empty, it will add - /// an element with value 1. If there are some trailing zeros, - /// it will be reomved from `self.digits`. - fn increment_abs_by_one(&mut self) -> Result<(), Error> { - for (zero_count, digit) in self.digits.iter_mut().rev().enumerate() { - *digit += 1; - if *digit < 10 { - self.digits.truncate(self.digits.len() - zero_count); - #[cfg(debug_assertions)] - self.check_invariants(); - return Ok(()); - } - } - - self.digits.clear(); - - if self.magnitude == i16::MAX { - self.magnitude = 0; - - #[cfg(debug_assertions)] - self.check_invariants(); - return Err(Error::Limit); - } - - // Still a carry, carry one to the next magnitude. - self.digits.push(1); - self.magnitude += 1; - - if self.upper_magnitude < self.magnitude { - self.upper_magnitude = self.magnitude; - } - - #[cfg(debug_assertions)] - self.check_invariants(); - Ok(()) - } - - /// Removes the trailing zeros in `self.digits` - fn remove_trailing_zeros_from_digits_list(&mut self) { - let no_of_trailing_zeros = self - .digits - .iter() - .rev() - .take_while(|&digit| *digit == 0) - .count(); - - self.digits - .truncate(self.digits.len() - no_of_trailing_zeros); - - if self.digits.is_empty() { - self.magnitude = 0; - } - - #[cfg(debug_assertions)] - self.check_invariants(); - } - /// Truncate the number on the right to a particular position, deleting /// digits if necessary. /// @@ -1046,6 +989,7 @@ impl FixedDecimal { /// assert_eq!("1", dec.to_string()); /// ``` pub fn trunc(&mut self, position: i16) { + // 1. Set upper and lower magnitude self.lower_magnitude = cmp::min(position, 0); if position == i16::MIN { // Nothing more to do @@ -1056,11 +1000,32 @@ impl FixedDecimal { let magnitude = position - 1; self.upper_magnitude = cmp::max(self.upper_magnitude, magnitude); - if magnitude <= self.magnitude { - self.digits - .truncate(crate::ops::i16_abs_sub(self.magnitude, magnitude) as usize); - self.remove_trailing_zeros_from_digits_list(); + // 2. If the rounding position is *lower than* the rightmost nonzero digit, exit early + if self.is_zero() || magnitude < self.nonzero_magnitude_end() { + #[cfg(debug_assertions)] + self.check_invariants(); + return; + } + + // 3. If the rounding position is *in the middle* of the nonzero digits + if magnitude < self.magnitude { + // 3a. Calculate the number of digits to retain and remove the rest + let digits_to_retain = crate::ops::i16_abs_sub(self.magnitude, magnitude); + self.digits.truncate(digits_to_retain as usize); + // 3b. Remove trailing zeros from self.digits to retain invariants + // Note: this does not affect visible trailing zeros, + // which is tracked by self.lower_magnitude + let position_past_last_nonzero_digit = self + .digits + .iter() + .rposition(|x| *x != 0) + .map(|x| x + 1) + .unwrap_or(0); + self.digits.truncate(position_past_last_nonzero_digit); + // 3c. By the invariant, there should still be at least 1 nonzero digit + debug_assert!(!self.digits.is_empty()); } else { + // 4. If the rounding position is *above* the leftmost nonzero digit, set to zero self.digits.clear(); self.magnitude = 0; } @@ -1168,32 +1133,60 @@ impl FixedDecimal { /// assert_eq!("2", dec.to_string()); /// ``` pub fn expand(&mut self, position: i16) { - let before_truncate_is_zero = self.is_zero(); - let before_truncate_bottom_magnitude = self.nonzero_magnitude_end(); - let before_truncate_magnitude = self.magnitude; - self.trunc(position); + // 1. Set upper and lower magnitude + self.lower_magnitude = cmp::min(position, 0); + if position == i16::MIN { + // Nothing more to do + #[cfg(debug_assertions)] + self.check_invariants(); + return; + } + let magnitude = position - 1; + self.upper_magnitude = cmp::max(self.upper_magnitude, magnitude); - if before_truncate_is_zero || position <= before_truncate_bottom_magnitude { + // 2. If the rounding position is *lower than* the rightmost nonzero digit, exit early + if self.is_zero() || magnitude < self.nonzero_magnitude_end() { #[cfg(debug_assertions)] self.check_invariants(); return; } - if position <= before_truncate_magnitude { - let result = self.increment_abs_by_one(); - if result.is_err() { - // Do nothing for now. + // 3. If the rounding position is *in the middle* of the nonzero digits + if magnitude < self.magnitude { + // 3a. Calculate the number of digits to retain and remove the rest + let digits_to_retain = crate::ops::i16_abs_sub(self.magnitude, magnitude); + self.digits.truncate(digits_to_retain as usize); + // 3b. Increment the rightmost remaining digit since we are rounding up; this might + // require bubbling the addition to higher magnitudes, like 199 + 1 = 200 + for (zero_count, digit) in self.digits.iter_mut().rev().enumerate() { + *digit += 1; + if *digit < 10 { + self.digits.truncate(self.digits.len() - zero_count); + #[cfg(debug_assertions)] + self.check_invariants(); + return; + } } + // 3c. If we get here, the mantissa is fully saturated, and we continue into case 4 + } + // 4. If the rounding position is *above* the leftmost nonzero digit, OR if we saturated, + // set the value to a single digit 1 at the appropriate position + if self.magnitude == i16::MAX { + // TODO(#2297): Decide on behavior here + self.magnitude = 0; + self.digits.clear(); #[cfg(debug_assertions)] self.check_invariants(); return; } - - debug_assert!(self.digits.is_empty()); + self.digits.clear(); self.digits.push(1); - self.magnitude = position; - self.upper_magnitude = cmp::max(self.upper_magnitude, position); + // If we got here from case 3, we should use self.magnitude + 1 + // If we got here directly from case 4, we should use magnitude + 1 + // We can use cmp::max to pick the right one + self.magnitude = cmp::max(self.magnitude, magnitude) + 1; + self.upper_magnitude = cmp::max(self.upper_magnitude, self.magnitude); #[cfg(debug_assertions)] self.check_invariants(); @@ -3338,6 +3331,19 @@ fn test_rounding() { let mut dec = FixedDecimal::from_str("-0.009").unwrap(); dec.half_expand(-1); assert_eq!("-0.0", dec.to_string()); + + // // Test specific cases + let mut dec = FixedDecimal::from_str("1.108").unwrap(); + dec.half_even(-2); + assert_eq!("1.11", dec.to_string()); + + let mut dec = FixedDecimal::from_str("1.108").unwrap(); + dec.expand(-2); + assert_eq!("1.11", dec.to_string()); + + let mut dec = FixedDecimal::from_str("1.108").unwrap(); + dec.trunc(-2); + assert_eq!("1.10", dec.to_string()); } #[test] diff --git a/utils/fixed_decimal/tests/rounding.rs b/utils/fixed_decimal/tests/rounding.rs new file mode 100644 index 00000000000..112aca81edb --- /dev/null +++ b/utils/fixed_decimal/tests/rounding.rs @@ -0,0 +1,346 @@ +// This file is part of ICU4X. For terms of use, please see the file +// called LICENSE at the top level of the ICU4X source tree +// (online at: https://github.com/unicode-org/icu4x/blob/main/LICENSE ). + +use core::ops::RangeInclusive; +use fixed_decimal::FixedDecimal; +use fixed_decimal::Sign; +use writeable::Writeable; + +#[test] +pub fn test_ecma402_table() { + // Source: + #[allow(clippy::type_complexity)] // best way to make it render like a table + let cases: [( + &'static str, + fn(&mut FixedDecimal, i16), + i32, + i32, + i32, + i32, + i32, + ); 9] = [ + ("ceil", FixedDecimal::ceil, -1, 1, 1, 1, 2), + ("floor", FixedDecimal::floor, -2, 0, 0, 0, 1), + ("expand", FixedDecimal::expand, -2, 1, 1, 1, 2), + ("trunc", FixedDecimal::trunc, -1, 0, 0, 0, 1), + ("half_ceil", FixedDecimal::half_ceil, -1, 0, 1, 1, 2), + ("half_floor", FixedDecimal::half_floor, -2, 0, 0, 1, 1), + ("half_expand", FixedDecimal::half_expand, -2, 0, 1, 1, 2), + ("half_trunc", FixedDecimal::half_trunc, -1, 0, 0, 1, 1), + ("half_even", FixedDecimal::half_even, -2, 0, 0, 1, 2), + ]; + for (rounding_mode, f, e1, e2, e3, e4, e5) in cases.into_iter() { + let mut fd1: FixedDecimal = "-1.5".parse().unwrap(); + let mut fd2: FixedDecimal = "0.4".parse().unwrap(); + let mut fd3: FixedDecimal = "0.5".parse().unwrap(); + let mut fd4: FixedDecimal = "0.6".parse().unwrap(); + let mut fd5: FixedDecimal = "1.5".parse().unwrap(); + f(&mut fd1, 0); + f(&mut fd2, 0); + f(&mut fd3, 0); + f(&mut fd4, 0); + f(&mut fd5, 0); + assert_eq!( + fd1.write_to_string(), + e1.write_to_string(), + "-1.5 failed for {rounding_mode}" + ); + assert_eq!( + fd2.write_to_string(), + e2.write_to_string(), + "0.4 failed for {rounding_mode}" + ); + assert_eq!( + fd3.write_to_string(), + e3.write_to_string(), + "0.5 failed for {rounding_mode}" + ); + assert_eq!( + fd4.write_to_string(), + e4.write_to_string(), + "0.6 failed for {rounding_mode}" + ); + assert_eq!( + fd5.write_to_string(), + e5.write_to_string(), + "1.5 failed for {rounding_mode}" + ); + } +} + +#[test] +pub fn test_within_ranges() { + struct TestCase { + rounding_mode: &'static str, + f: fn(&mut FixedDecimal, i16), + range_n2000: RangeInclusive, + range_n1000: RangeInclusive, + range_0: RangeInclusive, + range_1000: RangeInclusive, + range_2000: RangeInclusive, + } + let cases: [TestCase; 9] = [ + TestCase { + rounding_mode: "ceil", + f: FixedDecimal::ceil, + range_n2000: -2999..=-2000, + range_n1000: -1999..=-1000, + range_0: -999..=0, + range_1000: 1..=1000, + range_2000: 1001..=2000, + }, + TestCase { + rounding_mode: "floor", + f: FixedDecimal::floor, + range_n2000: -2000..=-1001, + range_n1000: -1000..=-1, + range_0: 0..=999, + range_1000: 1000..=1999, + range_2000: 2000..=2999, + }, + TestCase { + rounding_mode: "expand", + f: FixedDecimal::expand, + range_n2000: -2000..=-1001, + range_n1000: -1000..=-1, + range_0: 0..=0, + range_1000: 1..=1000, + range_2000: 1001..=2000, + }, + TestCase { + rounding_mode: "trunc", + f: FixedDecimal::trunc, + range_n2000: -2999..=-2000, + range_n1000: -1999..=-1000, + range_0: -999..=999, + range_1000: 1000..=1999, + range_2000: 2000..=2999, + }, + TestCase { + rounding_mode: "half_ceil", + f: FixedDecimal::half_ceil, + range_n2000: -2500..=-1501, + range_n1000: -1500..=-501, + range_0: -500..=449, + range_1000: 500..=1449, + range_2000: 1500..=2449, + }, + TestCase { + rounding_mode: "half_floor", + f: FixedDecimal::half_floor, + range_n2000: -2449..=-1500, + range_n1000: -1449..=-500, + range_0: -449..=500, + range_1000: 501..=1500, + range_2000: 1501..=2500, + }, + TestCase { + rounding_mode: "half_expand", + f: FixedDecimal::half_expand, + range_n2000: -2449..=-1500, + range_n1000: -1449..=-500, + range_0: -449..=449, + range_1000: 500..=1449, + range_2000: 1500..=2449, + }, + TestCase { + rounding_mode: "half_trunc", + f: FixedDecimal::half_trunc, + range_n2000: -2500..=-1501, + range_n1000: -1500..=-501, + range_0: -500..=500, + range_1000: 501..=1500, + range_2000: 1501..=2500, + }, + TestCase { + rounding_mode: "half_even", + f: FixedDecimal::half_even, + range_n2000: -2500..=-1500, + range_n1000: -1449..=-501, + range_0: -500..=500, + range_1000: 501..=1449, + range_2000: 1500..=2500, + }, + ]; + for TestCase { + rounding_mode, + f, + range_n2000, + range_n1000, + range_0, + range_1000, + range_2000, + } in cases + { + for n in range_n2000 { + let mut fd = FixedDecimal::from(n); + f(&mut fd, 3); + assert_eq!(fd.write_to_string(), "-2000", "{rounding_mode}: {n}"); + let mut fd = FixedDecimal::from(n - 1000000).multiplied_pow10(-5); + f(&mut fd, -2); + assert_eq!( + fd.write_to_string(), + "-10.02", + "{rounding_mode}: {n} ÷ 10^5 ± 10" + ); + } + for n in range_n1000 { + let mut fd = FixedDecimal::from(n); + f(&mut fd, 3); + assert_eq!(fd.write_to_string(), "-1000", "{rounding_mode}: {n}"); + let mut fd = FixedDecimal::from(n - 1000000).multiplied_pow10(-5); + f(&mut fd, -2); + assert_eq!( + fd.write_to_string(), + "-10.01", + "{rounding_mode}: {n} ÷ 10^5 ± 10" + ); + } + for n in range_0 { + let mut fd = FixedDecimal::from(n); + f(&mut fd, 3); + fd.set_sign(Sign::None); // get rid of -0 + assert_eq!(fd.write_to_string(), "000", "{rounding_mode}: {n}"); + let (mut fd, expected) = if n < 0 { + ( + FixedDecimal::from(n - 1000000).multiplied_pow10(-5), + "-10.00", + ) + } else { + ( + FixedDecimal::from(n + 1000000).multiplied_pow10(-5), + "10.00", + ) + }; + f(&mut fd, -2); + assert_eq!( + fd.write_to_string(), + expected, + "{rounding_mode}: {n} ÷ 10^5 ± 10" + ); + } + for n in range_1000 { + let mut fd = FixedDecimal::from(n); + f(&mut fd, 3); + assert_eq!(fd.write_to_string(), "1000", "{rounding_mode}: {n}"); + let mut fd = FixedDecimal::from(n + 1000000).multiplied_pow10(-5); + f(&mut fd, -2); + assert_eq!( + fd.write_to_string(), + "10.01", + "{rounding_mode}: {n} ÷ 10^5 ± 10" + ); + } + for n in range_2000 { + let mut fd = FixedDecimal::from(n); + f(&mut fd, 3); + assert_eq!(fd.write_to_string(), "2000", "{rounding_mode}: {n}"); + let mut fd = FixedDecimal::from(n + 1000000).multiplied_pow10(-5); + f(&mut fd, -2); + assert_eq!( + fd.write_to_string(), + "10.02", + "{rounding_mode}: {n} ÷ 10^5 ± 10" + ); + } + } +} + +#[test] +pub fn extra_rounding_mode_cases() { + struct TestCase { + input: &'static str, + position: i16, + // ceil, floor, expand, trunc, half_ceil, half_floor, half_expand, half_trunc, half_even + all_expected: [&'static str; 9], + } + let cases: [TestCase; 8] = [ + TestCase { + input: "505.050", + position: -3, + all_expected: [ + "505.050", "505.050", "505.050", "505.050", "505.050", "505.050", "505.050", + "505.050", "505.050", + ], + }, + TestCase { + input: "505.050", + position: -2, + all_expected: [ + "505.05", "505.05", "505.05", "505.05", "505.05", "505.05", "505.05", "505.05", + "505.05", + ], + }, + TestCase { + input: "505.050", + position: -1, + all_expected: [ + "505.1", "505.0", "505.1", "505.0", "505.1", "505.0", "505.1", "505.0", "505.0", + ], + }, + TestCase { + input: "505.050", + position: 0, + all_expected: [ + "506", "505", "506", "505", "505", "505", "505", "505", "505", + ], + }, + TestCase { + input: "505.050", + position: 1, + all_expected: [ + "510", "500", "510", "500", "510", "510", "510", "510", "510", + ], + }, + TestCase { + input: "505.050", + position: 2, + all_expected: [ + "600", "500", "600", "500", "500", "500", "500", "500", "500", + ], + }, + TestCase { + input: "505.050", + position: 3, + all_expected: [ + "1000", "000", "1000", "000", "1000", "1000", "1000", "1000", "1000", + ], + }, + TestCase { + input: "505.050", + position: 4, + all_expected: [ + "10000", "0000", "10000", "0000", "0000", "0000", "0000", "0000", "0000", + ], + }, + ]; + #[allow(clippy::type_complexity)] // most compact representation in code + let rounding_modes: [(&'static str, fn(&mut FixedDecimal, i16)); 9] = [ + ("ceil", FixedDecimal::ceil), + ("floor", FixedDecimal::floor), + ("expand", FixedDecimal::expand), + ("trunc", FixedDecimal::trunc), + ("half_ceil", FixedDecimal::half_ceil), + ("half_floor", FixedDecimal::half_floor), + ("half_expand", FixedDecimal::half_expand), + ("half_trunc", FixedDecimal::half_trunc), + ("half_even", FixedDecimal::half_even), + ]; + for TestCase { + input, + position, + all_expected, + } in cases + { + for ((rounding_mode, f), expected) in rounding_modes.iter().zip(all_expected.iter()) { + let mut fd: FixedDecimal = input.parse().unwrap(); + f(&mut fd, position); + assert_eq!( + &*fd.write_to_string(), + *expected, + "{input}: {rounding_mode} @ {position}" + ) + } + } +}