Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bpo-45876: Correctly rounded stdev() and pstdev() for the Decimal case #29828

Merged
merged 37 commits into from
Dec 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
bbd2da9
Merge pull request #1 from python/master
rhettinger Mar 16, 2021
74bdf1b
Merge branch 'master' of github.com:python/cpython
rhettinger Mar 22, 2021
6c53f1a
Merge branch 'master' of github.com:python/cpython
rhettinger Mar 22, 2021
a487c4f
.
rhettinger Mar 24, 2021
eb56423
.
rhettinger Mar 25, 2021
cc7ba06
.
rhettinger Mar 26, 2021
d024dd0
.
rhettinger Apr 22, 2021
b10f912
merge
rhettinger May 5, 2021
fb6744d
merge
rhettinger May 6, 2021
7f21a1c
Merge branch 'main' of github.com:python/cpython
rhettinger Aug 15, 2021
7da42d4
Merge branch 'main' of github.com:rhettinger/cpython
rhettinger Aug 25, 2021
e31757b
Merge branch 'main' of github.com:python/cpython
rhettinger Aug 31, 2021
f058a6f
Merge branch 'main' of github.com:python/cpython
rhettinger Aug 31, 2021
1fc29bd
Merge branch 'main' of github.com:python/cpython
rhettinger Sep 4, 2021
e5c0184
Merge branch 'main' of github.com:python/cpython
rhettinger Oct 30, 2021
3c86ec1
Merge branch 'main' of github.com:python/cpython
rhettinger Nov 9, 2021
96675e4
Merge branch 'main' of github.com:rhettinger/cpython
rhettinger Nov 9, 2021
de558c6
Merge branch 'main' of github.com:python/cpython
rhettinger Nov 9, 2021
418a07f
Merge branch 'main' of github.com:python/cpython
rhettinger Nov 14, 2021
ea23a8b
Merge branch 'main' of github.com:python/cpython
rhettinger Nov 21, 2021
ba248b7
Merge branch 'main' of github.com:python/cpython
rhettinger Nov 27, 2021
037b5fe
Correctly rounded stdev results for Decimal inputs
rhettinger Nov 29, 2021
f5c091c
Whitespace
rhettinger Nov 29, 2021
70cdade
Rename the functions consistently
rhettinger Nov 29, 2021
82dbec6
Improve comment
rhettinger Nov 29, 2021
1a6c58d
Tweak variable names
rhettinger Nov 29, 2021
b2385b0
Replace Fraction arithmetic with integer arithmetic
rhettinger Nov 29, 2021
594ea27
Add spacing between terms
rhettinger Nov 29, 2021
3911581
Fix type annotation
rhettinger Nov 29, 2021
a09e3c4
Return a Decimal zero when the numerator is zero
rhettinger Nov 29, 2021
152ed3f
Remove unused import
rhettinger Nov 29, 2021
80371c1
Factor lhs of inequality. Rename helper function for consistency.
rhettinger Nov 29, 2021
1c86e7c
Add comment for future work.
rhettinger Nov 29, 2021
0684fac
Fix typo in docstring. Refine wording in comment.
rhettinger Nov 29, 2021
8b5e377
Add more detail to the comment about numerator and denominator sizes
rhettinger Nov 30, 2021
d11d567
Improve variable name
rhettinger Nov 30, 2021
309cb0a
Avoid double rounding in test code
rhettinger Nov 30, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 66 additions & 13 deletions Lib/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@
from itertools import groupby, repeat
from bisect import bisect_left, bisect_right
from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum
from operator import itemgetter, mul
from operator import mul
from collections import Counter, namedtuple

_SQRT2 = sqrt(2.0)
Expand Down Expand Up @@ -248,6 +248,28 @@ def _exact_ratio(x):

x is expected to be an int, Fraction, Decimal or float.
"""

# XXX We should revisit whether using fractions to accumulate exact
# ratios is the right way to go.

# The integer ratios for binary floats can have numerators or
# denominators with over 300 decimal digits. The problem is more
# acute with decimal floats where the default decimal context
# supports a huge range of exponents from Emin=-999999 to
# Emax=999999. When expanded with as_integer_ratio(), numbers like
# Decimal('3.14E+5000') and Decimal('3.14E-5000') have large
# numerators or denominators that will slow computation.

# When the integer ratios are accumulated as fractions, the size
# grows to cover the full range from the smallest magnitude to the
# largest. For example, Fraction(3.14E+300) + Fraction(3.14E-300),
# has a 616 digit numerator. Likewise,
# Fraction(Decimal('3.14E+5000')) + Fraction(Decimal('3.14E-5000'))
# has a 10,003 digit numerator.

# This doesn't seem to have been a problem in practice, but it is a
# potential pitfall.

try:
return x.as_integer_ratio()
except AttributeError:
Expand Down Expand Up @@ -305,28 +327,60 @@ def _fail_neg(values, errmsg='negative value'):
raise StatisticsError(errmsg)
yield x

def _isqrt_frac_rto(n: int, m: int) -> float:

def _integer_sqrt_of_frac_rto(n: int, m: int) -> int:
"""Square root of n/m, rounded to the nearest integer using round-to-odd."""
# Reference: https://www.lri.fr/~melquion/doc/05-imacs17_1-expose.pdf
a = math.isqrt(n // m)
return a | (a*a*m != n)

# For 53 bit precision floats, the _sqrt_frac() shift is 109.
_sqrt_shift: int = 2 * sys.float_info.mant_dig + 3

def _sqrt_frac(n: int, m: int) -> float:
# For 53 bit precision floats, the bit width used in
# _float_sqrt_of_frac() is 109.
_sqrt_bit_width: int = 2 * sys.float_info.mant_dig + 3


def _float_sqrt_of_frac(n: int, m: int) -> float:
"""Square root of n/m as a float, correctly rounded."""
# See principle and proof sketch at: https://bugs.python.org/msg407078
q = (n.bit_length() - m.bit_length() - _sqrt_shift) // 2
q = (n.bit_length() - m.bit_length() - _sqrt_bit_width) // 2
if q >= 0:
numerator = _isqrt_frac_rto(n, m << 2 * q) << q
numerator = _integer_sqrt_of_frac_rto(n, m << 2 * q) << q
denominator = 1
else:
numerator = _isqrt_frac_rto(n << -2 * q, m)
numerator = _integer_sqrt_of_frac_rto(n << -2 * q, m)
denominator = 1 << -q
return numerator / denominator # Convert to float


def _decimal_sqrt_of_frac(n: int, m: int) -> Decimal:
"""Square root of n/m as a Decimal, correctly rounded."""
# Premise: For decimal, computing (n/m).sqrt() can be off
# by 1 ulp from the correctly rounded result.
# Method: Check the result, moving up or down a step if needed.
if n <= 0:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we be checking m here rather than n? It looks as though we multiply through by m below, so we need m to be positive for the inequalities to work.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The case where m is zero is handled before the inequality test. The Decimal(n) / Decimal(m) step raises DivisionByZero which is a subclass of ZeroDivisionError. There is a test for this case.

Copy link
Member

@mdickinson mdickinson Nov 30, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I was really more worried about the case m negative, not m zero. It seems more natural to have the normalization step n, m = -n, -m ensure that the denominator is positive than that the numerator is positive: right now, we're effectively converting -2 / 5 to 2 / (-5), and what we care about below for the maths to work is that m is positive, not that n is positive: you're implicitly converting the inequality n / m > ((root + plus) / 2) ** 2 to the inequality n > ((root + plus) / 2)**2 * m, and that conversion is only valid if m is positive.

It doesn't matter, because if just one of n and m is negative then the (Decimal(n) / Decimal(m)).sqrt() step will fail (except in the case of really extreme m where the division rounds from something tiny and negative to negative zero); it just seems like a surprising choice and for me at least it made the code harder to read and reason about.

if not n:
return Decimal('0.0')
n, m = -n, -m

root = (Decimal(n) / Decimal(m)).sqrt()
nr, dr = root.as_integer_ratio()

plus = root.next_plus()
np, dp = plus.as_integer_ratio()
# test: n / m > ((root + plus) / 2) ** 2
if 4 * n * (dr*dp)**2 > m * (dr*np + dp*nr)**2:
return plus

minus = root.next_minus()
nm, dm = minus.as_integer_ratio()
# test: n / m < ((root + minus) / 2) ** 2
if 4 * n * (dr*dm)**2 < m * (dr*nm + dm*nr)**2:
return minus

return root


# === Measures of central tendency (averages) ===

def mean(data):
Expand Down Expand Up @@ -869,7 +923,7 @@ def stdev(data, xbar=None):
if hasattr(T, 'sqrt'):
var = _convert(mss, T)
return var.sqrt()
return _sqrt_frac(mss.numerator, mss.denominator)
return _float_sqrt_of_frac(mss.numerator, mss.denominator)


def pstdev(data, mu=None):
Expand All @@ -888,10 +942,9 @@ def pstdev(data, mu=None):
raise StatisticsError('pstdev requires at least one data point')
T, ss = _ss(data, mu)
mss = ss / n
if hasattr(T, 'sqrt'):
var = _convert(mss, T)
return var.sqrt()
return _sqrt_frac(mss.numerator, mss.denominator)
if issubclass(T, Decimal):
return _decimal_sqrt_of_frac(mss.numerator, mss.denominator)
return _float_sqrt_of_frac(mss.numerator, mss.denominator)


# === Statistics for relations between two inputs ===
Expand Down
55 changes: 46 additions & 9 deletions Lib/test/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2164,9 +2164,9 @@ def test_center_not_at_mean(self):

class TestSqrtHelpers(unittest.TestCase):

def test_isqrt_frac_rto(self):
def test_integer_sqrt_of_frac_rto(self):
for n, m in itertools.product(range(100), range(1, 1000)):
r = statistics._isqrt_frac_rto(n, m)
r = statistics._integer_sqrt_of_frac_rto(n, m)
self.assertIsInstance(r, int)
if r*r*m == n:
# Root is exact
Expand All @@ -2177,7 +2177,7 @@ def test_isqrt_frac_rto(self):
self.assertTrue(m * (r - 1)**2 < n < m * (r + 1)**2)

@requires_IEEE_754
def test_sqrt_frac(self):
def test_float_sqrt_of_frac(self):

def is_root_correctly_rounded(x: Fraction, root: float) -> bool:
if not x:
Expand All @@ -2204,22 +2204,59 @@ def is_root_correctly_rounded(x: Fraction, root: float) -> bool:
denonimator: int = randrange(10 ** randrange(50)) + 1
with self.subTest(numerator=numerator, denonimator=denonimator):
x: Fraction = Fraction(numerator, denonimator)
root: float = statistics._sqrt_frac(numerator, denonimator)
root: float = statistics._float_sqrt_of_frac(numerator, denonimator)
self.assertTrue(is_root_correctly_rounded(x, root))

# Verify that corner cases and error handling match math.sqrt()
self.assertEqual(statistics._sqrt_frac(0, 1), 0.0)
self.assertEqual(statistics._float_sqrt_of_frac(0, 1), 0.0)
with self.assertRaises(ValueError):
statistics._sqrt_frac(-1, 1)
statistics._float_sqrt_of_frac(-1, 1)
with self.assertRaises(ValueError):
statistics._sqrt_frac(1, -1)
statistics._float_sqrt_of_frac(1, -1)

# Error handling for zero denominator matches that for Fraction(1, 0)
with self.assertRaises(ZeroDivisionError):
statistics._sqrt_frac(1, 0)
statistics._float_sqrt_of_frac(1, 0)

# The result is well defined if both inputs are negative
self.assertAlmostEqual(statistics._sqrt_frac(-2, -1), math.sqrt(2.0))
self.assertEqual(statistics._float_sqrt_of_frac(-2, -1), statistics._float_sqrt_of_frac(2, 1))

def test_decimal_sqrt_of_frac(self):
root: Decimal
numerator: int
denominator: int

for root, numerator, denominator in [
(Decimal('0.4481904599041192673635338663'), 200874688349065940678243576378, 1000000000000000000000000000000), # No adj
(Decimal('0.7924949131383786609961759598'), 628048187350206338833590574929, 1000000000000000000000000000000), # Adj up
(Decimal('0.8500554152289934068192208727'), 722594208960136395984391238251, 1000000000000000000000000000000), # Adj down
]:
with decimal.localcontext(decimal.DefaultContext):
self.assertEqual(statistics._decimal_sqrt_of_frac(numerator, denominator), root)

# Confirm expected root with a quad precision decimal computation
with decimal.localcontext(decimal.DefaultContext) as ctx:
ctx.prec *= 4
high_prec_ratio = Decimal(numerator) / Decimal(denominator)
ctx.rounding = decimal.ROUND_05UP
high_prec_root = high_prec_ratio.sqrt()
with decimal.localcontext(decimal.DefaultContext):
target_root = +high_prec_root
self.assertEqual(root, target_root)

# Verify that corner cases and error handling match Decimal.sqrt()
self.assertEqual(statistics._decimal_sqrt_of_frac(0, 1), 0.0)
with self.assertRaises(decimal.InvalidOperation):
statistics._decimal_sqrt_of_frac(-1, 1)
with self.assertRaises(decimal.InvalidOperation):
statistics._decimal_sqrt_of_frac(1, -1)

# Error handling for zero denominator matches that for Fraction(1, 0)
with self.assertRaises(ZeroDivisionError):
statistics._decimal_sqrt_of_frac(1, 0)

# The result is well defined if both inputs are negative
self.assertEqual(statistics._decimal_sqrt_of_frac(-2, -1), statistics._decimal_sqrt_of_frac(2, 1))


class TestStdev(VarianceStdevMixin, NumericTestCase):
Expand Down