Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix FixedDecimal rounding bug; add more tests #3644

Merged
merged 7 commits into from
Jul 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 80 additions & 74 deletions utils/fixed_decimal/src/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -934,63 +934,6 @@ impl FixedDecimal {
self.check_invariants();
}

/// Increments the digits by 1. if the digits are empty, it will add
/// an element with value 1. If there are some trailing zeros,
/// it will be reomved from `self.digits`.
fn increment_abs_by_one(&mut self) -> Result<(), Error> {
for (zero_count, digit) in self.digits.iter_mut().rev().enumerate() {
*digit += 1;
if *digit < 10 {
self.digits.truncate(self.digits.len() - zero_count);
#[cfg(debug_assertions)]
self.check_invariants();
return Ok(());
}
}

self.digits.clear();

if self.magnitude == i16::MAX {
self.magnitude = 0;

#[cfg(debug_assertions)]
self.check_invariants();
return Err(Error::Limit);
}

// Still a carry, carry one to the next magnitude.
self.digits.push(1);
self.magnitude += 1;

if self.upper_magnitude < self.magnitude {
self.upper_magnitude = self.magnitude;
}

#[cfg(debug_assertions)]
self.check_invariants();
Ok(())
}

/// Removes the trailing zeros in `self.digits`
fn remove_trailing_zeros_from_digits_list(&mut self) {
let no_of_trailing_zeros = self
.digits
.iter()
.rev()
.take_while(|&digit| *digit == 0)
.count();

self.digits
.truncate(self.digits.len() - no_of_trailing_zeros);

if self.digits.is_empty() {
self.magnitude = 0;
}

#[cfg(debug_assertions)]
self.check_invariants();
}

/// Truncate the number on the right to a particular position, deleting
/// digits if necessary.
///
Expand Down Expand Up @@ -1046,6 +989,7 @@ impl FixedDecimal {
/// assert_eq!("1", dec.to_string());
/// ```
pub fn trunc(&mut self, position: i16) {
// 1. Set upper and lower magnitude
self.lower_magnitude = cmp::min(position, 0);
if position == i16::MIN {
// Nothing more to do
Expand All @@ -1056,11 +1000,32 @@ impl FixedDecimal {
let magnitude = position - 1;
self.upper_magnitude = cmp::max(self.upper_magnitude, magnitude);

if magnitude <= self.magnitude {
self.digits
.truncate(crate::ops::i16_abs_sub(self.magnitude, magnitude) as usize);
self.remove_trailing_zeros_from_digits_list();
// 2. If the rounding position is *lower than* the rightmost nonzero digit, exit early
if self.is_zero() || magnitude < self.nonzero_magnitude_end() {
#[cfg(debug_assertions)]
self.check_invariants();
return;
}

// 3. If the rounding position is *in the middle* of the nonzero digits
if magnitude < self.magnitude {
// 3a. Calculate the number of digits to retain and remove the rest
let digits_to_retain = crate::ops::i16_abs_sub(self.magnitude, magnitude);
self.digits.truncate(digits_to_retain as usize);
// 3b. Remove trailing zeros from self.digits to retain invariants
// Note: this does not affect visible trailing zeros,
// which is tracked by self.lower_magnitude
let position_past_last_nonzero_digit = self
.digits
.iter()
.rposition(|x| *x != 0)
.map(|x| x + 1)
.unwrap_or(0);
self.digits.truncate(position_past_last_nonzero_digit);
// 3c. By the invariant, there should still be at least 1 nonzero digit
debug_assert!(!self.digits.is_empty());
} else {
// 4. If the rounding position is *above* the leftmost nonzero digit, set to zero
self.digits.clear();
self.magnitude = 0;
}
Expand Down Expand Up @@ -1168,32 +1133,60 @@ impl FixedDecimal {
/// assert_eq!("2", dec.to_string());
/// ```
pub fn expand(&mut self, position: i16) {
let before_truncate_is_zero = self.is_zero();
let before_truncate_bottom_magnitude = self.nonzero_magnitude_end();
let before_truncate_magnitude = self.magnitude;
self.trunc(position);
// 1. Set upper and lower magnitude
self.lower_magnitude = cmp::min(position, 0);
if position == i16::MIN {
// Nothing more to do
#[cfg(debug_assertions)]
self.check_invariants();
return;
}
let magnitude = position - 1;
self.upper_magnitude = cmp::max(self.upper_magnitude, magnitude);

if before_truncate_is_zero || position <= before_truncate_bottom_magnitude {
// 2. If the rounding position is *lower than* the rightmost nonzero digit, exit early
if self.is_zero() || magnitude < self.nonzero_magnitude_end() {
#[cfg(debug_assertions)]
self.check_invariants();
return;
}

if position <= before_truncate_magnitude {
let result = self.increment_abs_by_one();
if result.is_err() {
// Do nothing for now.
// 3. If the rounding position is *in the middle* of the nonzero digits
if magnitude < self.magnitude {
// 3a. Calculate the number of digits to retain and remove the rest
let digits_to_retain = crate::ops::i16_abs_sub(self.magnitude, magnitude);
self.digits.truncate(digits_to_retain as usize);
// 3b. Increment the rightmost remaining digit since we are rounding up; this might
// require bubbling the addition to higher magnitudes, like 199 + 1 = 200
for (zero_count, digit) in self.digits.iter_mut().rev().enumerate() {
*digit += 1;
if *digit < 10 {
self.digits.truncate(self.digits.len() - zero_count);
#[cfg(debug_assertions)]
self.check_invariants();
return;
}
}
// 3c. If we get here, the mantissa is fully saturated, and we continue into case 4
}

// 4. If the rounding position is *above* the leftmost nonzero digit, OR if we saturated,
// set the value to a single digit 1 at the appropriate position
if self.magnitude == i16::MAX {
// TODO(#2297): Decide on behavior here
self.magnitude = 0;
self.digits.clear();
#[cfg(debug_assertions)]
self.check_invariants();
return;
}

debug_assert!(self.digits.is_empty());
self.digits.clear();
self.digits.push(1);
self.magnitude = position;
self.upper_magnitude = cmp::max(self.upper_magnitude, position);
// If we got here from case 3, we should use self.magnitude + 1
// If we got here directly from case 4, we should use magnitude + 1
// We can use cmp::max to pick the right one
self.magnitude = cmp::max(self.magnitude, magnitude) + 1;
self.upper_magnitude = cmp::max(self.upper_magnitude, self.magnitude);

#[cfg(debug_assertions)]
self.check_invariants();
Expand Down Expand Up @@ -3338,6 +3331,19 @@ fn test_rounding() {
let mut dec = FixedDecimal::from_str("-0.009").unwrap();
dec.half_expand(-1);
assert_eq!("-0.0", dec.to_string());

// // Test specific cases
let mut dec = FixedDecimal::from_str("1.108").unwrap();
dec.half_even(-2);
assert_eq!("1.11", dec.to_string());

let mut dec = FixedDecimal::from_str("1.108").unwrap();
dec.expand(-2);
assert_eq!("1.11", dec.to_string());

let mut dec = FixedDecimal::from_str("1.108").unwrap();
dec.trunc(-2);
assert_eq!("1.10", dec.to_string());
}

#[test]
Expand Down
Loading
Loading