Skip to content

Commit

Permalink
Fix FixedDecimal rounding bug; add more tests (#3644)
Browse files Browse the repository at this point in the history
  • Loading branch information
sffc authored Jul 7, 2023
1 parent c3f3fb3 commit d1cc1a6
Show file tree
Hide file tree
Showing 2 changed files with 426 additions and 74 deletions.
154 changes: 80 additions & 74 deletions utils/fixed_decimal/src/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -934,63 +934,6 @@ impl FixedDecimal {
self.check_invariants();
}

/// Increments the digits by 1. if the digits are empty, it will add
/// an element with value 1. If there are some trailing zeros,
/// it will be reomved from `self.digits`.
fn increment_abs_by_one(&mut self) -> Result<(), Error> {
for (zero_count, digit) in self.digits.iter_mut().rev().enumerate() {
*digit += 1;
if *digit < 10 {
self.digits.truncate(self.digits.len() - zero_count);
#[cfg(debug_assertions)]
self.check_invariants();
return Ok(());
}
}

self.digits.clear();

if self.magnitude == i16::MAX {
self.magnitude = 0;

#[cfg(debug_assertions)]
self.check_invariants();
return Err(Error::Limit);
}

// Still a carry, carry one to the next magnitude.
self.digits.push(1);
self.magnitude += 1;

if self.upper_magnitude < self.magnitude {
self.upper_magnitude = self.magnitude;
}

#[cfg(debug_assertions)]
self.check_invariants();
Ok(())
}

/// Removes the trailing zeros in `self.digits`
fn remove_trailing_zeros_from_digits_list(&mut self) {
let no_of_trailing_zeros = self
.digits
.iter()
.rev()
.take_while(|&digit| *digit == 0)
.count();

self.digits
.truncate(self.digits.len() - no_of_trailing_zeros);

if self.digits.is_empty() {
self.magnitude = 0;
}

#[cfg(debug_assertions)]
self.check_invariants();
}

/// Truncate the number on the right to a particular position, deleting
/// digits if necessary.
///
Expand Down Expand Up @@ -1046,6 +989,7 @@ impl FixedDecimal {
/// assert_eq!("1", dec.to_string());
/// ```
pub fn trunc(&mut self, position: i16) {
// 1. Set upper and lower magnitude
self.lower_magnitude = cmp::min(position, 0);
if position == i16::MIN {
// Nothing more to do
Expand All @@ -1056,11 +1000,32 @@ impl FixedDecimal {
let magnitude = position - 1;
self.upper_magnitude = cmp::max(self.upper_magnitude, magnitude);

if magnitude <= self.magnitude {
self.digits
.truncate(crate::ops::i16_abs_sub(self.magnitude, magnitude) as usize);
self.remove_trailing_zeros_from_digits_list();
// 2. If the rounding position is *lower than* the rightmost nonzero digit, exit early
if self.is_zero() || magnitude < self.nonzero_magnitude_end() {
#[cfg(debug_assertions)]
self.check_invariants();
return;
}

// 3. If the rounding position is *in the middle* of the nonzero digits
if magnitude < self.magnitude {
// 3a. Calculate the number of digits to retain and remove the rest
let digits_to_retain = crate::ops::i16_abs_sub(self.magnitude, magnitude);
self.digits.truncate(digits_to_retain as usize);
// 3b. Remove trailing zeros from self.digits to retain invariants
// Note: this does not affect visible trailing zeros,
// which is tracked by self.lower_magnitude
let position_past_last_nonzero_digit = self
.digits
.iter()
.rposition(|x| *x != 0)
.map(|x| x + 1)
.unwrap_or(0);
self.digits.truncate(position_past_last_nonzero_digit);
// 3c. By the invariant, there should still be at least 1 nonzero digit
debug_assert!(!self.digits.is_empty());
} else {
// 4. If the rounding position is *above* the leftmost nonzero digit, set to zero
self.digits.clear();
self.magnitude = 0;
}
Expand Down Expand Up @@ -1168,32 +1133,60 @@ impl FixedDecimal {
/// assert_eq!("2", dec.to_string());
/// ```
pub fn expand(&mut self, position: i16) {
let before_truncate_is_zero = self.is_zero();
let before_truncate_bottom_magnitude = self.nonzero_magnitude_end();
let before_truncate_magnitude = self.magnitude;
self.trunc(position);
// 1. Set upper and lower magnitude
self.lower_magnitude = cmp::min(position, 0);
if position == i16::MIN {
// Nothing more to do
#[cfg(debug_assertions)]
self.check_invariants();
return;
}
let magnitude = position - 1;
self.upper_magnitude = cmp::max(self.upper_magnitude, magnitude);

if before_truncate_is_zero || position <= before_truncate_bottom_magnitude {
// 2. If the rounding position is *lower than* the rightmost nonzero digit, exit early
if self.is_zero() || magnitude < self.nonzero_magnitude_end() {
#[cfg(debug_assertions)]
self.check_invariants();
return;
}

if position <= before_truncate_magnitude {
let result = self.increment_abs_by_one();
if result.is_err() {
// Do nothing for now.
// 3. If the rounding position is *in the middle* of the nonzero digits
if magnitude < self.magnitude {
// 3a. Calculate the number of digits to retain and remove the rest
let digits_to_retain = crate::ops::i16_abs_sub(self.magnitude, magnitude);
self.digits.truncate(digits_to_retain as usize);
// 3b. Increment the rightmost remaining digit since we are rounding up; this might
// require bubbling the addition to higher magnitudes, like 199 + 1 = 200
for (zero_count, digit) in self.digits.iter_mut().rev().enumerate() {
*digit += 1;
if *digit < 10 {
self.digits.truncate(self.digits.len() - zero_count);
#[cfg(debug_assertions)]
self.check_invariants();
return;
}
}
// 3c. If we get here, the mantissa is fully saturated, and we continue into case 4
}

// 4. If the rounding position is *above* the leftmost nonzero digit, OR if we saturated,
// set the value to a single digit 1 at the appropriate position
if self.magnitude == i16::MAX {
// TODO(#2297): Decide on behavior here
self.magnitude = 0;
self.digits.clear();
#[cfg(debug_assertions)]
self.check_invariants();
return;
}

debug_assert!(self.digits.is_empty());
self.digits.clear();
self.digits.push(1);
self.magnitude = position;
self.upper_magnitude = cmp::max(self.upper_magnitude, position);
// If we got here from case 3, we should use self.magnitude + 1
// If we got here directly from case 4, we should use magnitude + 1
// We can use cmp::max to pick the right one
self.magnitude = cmp::max(self.magnitude, magnitude) + 1;
self.upper_magnitude = cmp::max(self.upper_magnitude, self.magnitude);

#[cfg(debug_assertions)]
self.check_invariants();
Expand Down Expand Up @@ -3338,6 +3331,19 @@ fn test_rounding() {
let mut dec = FixedDecimal::from_str("-0.009").unwrap();
dec.half_expand(-1);
assert_eq!("-0.0", dec.to_string());

// // Test specific cases
let mut dec = FixedDecimal::from_str("1.108").unwrap();
dec.half_even(-2);
assert_eq!("1.11", dec.to_string());

let mut dec = FixedDecimal::from_str("1.108").unwrap();
dec.expand(-2);
assert_eq!("1.11", dec.to_string());

let mut dec = FixedDecimal::from_str("1.108").unwrap();
dec.trunc(-2);
assert_eq!("1.10", dec.to_string());
}

#[test]
Expand Down
Loading

0 comments on commit d1cc1a6

Please sign in to comment.