-
-
Notifications
You must be signed in to change notification settings - Fork 30.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
bpo-45876: Correctly rounded stdev() and pstdev() for the Decimal case #29828
Changes from 33 commits
bbd2da9
74bdf1b
6c53f1a
a487c4f
eb56423
cc7ba06
d024dd0
b10f912
fb6744d
7f21a1c
7da42d4
e31757b
f058a6f
1fc29bd
e5c0184
3c86ec1
96675e4
de558c6
418a07f
ea23a8b
ba248b7
037b5fe
f5c091c
70cdade
82dbec6
1a6c58d
b2385b0
594ea27
3911581
a09e3c4
152ed3f
80371c1
1c86e7c
0684fac
8b5e377
d11d567
309cb0a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -137,7 +137,7 @@ | |
from itertools import groupby, repeat | ||
from bisect import bisect_left, bisect_right | ||
from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum | ||
from operator import itemgetter, mul | ||
from operator import mul | ||
from collections import Counter, namedtuple | ||
|
||
_SQRT2 = sqrt(2.0) | ||
|
@@ -248,6 +248,15 @@ def _exact_ratio(x): | |
|
||
x is expected to be an int, Fraction, Decimal or float. | ||
""" | ||
|
||
# XXX We should revisit whether accumulating exact ratios is the | ||
# right way to go. The default decimal context supports a huge range | ||
# of exponents. When expanded with as_integer_ratio(), numbers like | ||
# Decimal('3.14E+5000') and Decimal('3.14E-5000') have large | ||
# numerators or denominators that will slow computation. This | ||
# doesn't seem to have been problem in practice, but it is a | ||
# potential pitfall. | ||
|
||
try: | ||
return x.as_integer_ratio() | ||
except AttributeError: | ||
|
@@ -305,28 +314,58 @@ def _fail_neg(values, errmsg='negative value'): | |
raise StatisticsError(errmsg) | ||
yield x | ||
|
||
def _isqrt_frac_rto(n: int, m: int) -> float: | ||
|
||
def _integer_sqrt_of_frac_rto(n: int, m: int) -> int: | ||
"""Square root of n/m, rounded to the nearest integer using round-to-odd.""" | ||
# Reference: https://www.lri.fr/~melquion/doc/05-imacs17_1-expose.pdf | ||
a = math.isqrt(n // m) | ||
return a | (a*a*m != n) | ||
|
||
# For 53 bit precision floats, the _sqrt_frac() shift is 109. | ||
|
||
# For 53 bit precision floats, the _float_sqrt_of_frac() shift is 109. | ||
_sqrt_shift: int = 2 * sys.float_info.mant_dig + 3 | ||
|
||
def _sqrt_frac(n: int, m: int) -> float: | ||
|
||
def _float_sqrt_of_frac(n: int, m: int) -> float: | ||
"""Square root of n/m as a float, correctly rounded.""" | ||
# See principle and proof sketch at: https://bugs.python.org/msg407078 | ||
q = (n.bit_length() - m.bit_length() - _sqrt_shift) // 2 | ||
if q >= 0: | ||
numerator = _isqrt_frac_rto(n, m << 2 * q) << q | ||
numerator = _integer_sqrt_of_frac_rto(n, m << 2 * q) << q | ||
denominator = 1 | ||
else: | ||
numerator = _isqrt_frac_rto(n << -2 * q, m) | ||
numerator = _integer_sqrt_of_frac_rto(n << -2 * q, m) | ||
denominator = 1 << -q | ||
return numerator / denominator # Convert to float | ||
|
||
|
||
def _decimal_sqrt_of_frac(n: int, m: int) -> Decimal: | ||
"""Square root of n/m as a float, correctly rounded.""" | ||
# Premise: For decimal, computing (n/m).sqrt() can be off by 1 ulp. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unfortunately, while the corresponding result is true for binary floating-point, it's not true for decimal floating-point: for decimal, the error can be more than 1 ulp. Example using 3-digit precision: >>> from decimal import Decimal, getcontext
>>> n, m = 209, 20
>>> getcontext().prec = 3
>>> (Decimal(n) / Decimal(m)).sqrt()
Decimal('3.22')
>>> getcontext().prec = 10
>>> (Decimal(n) / Decimal(m)).sqrt()
Decimal('3.232645975') The correctly-rounded value here is around 1.265 ulps away, at 3.232645975... However, I believe that one can show that the worst case error is at most 1.3ulp (actually 0.5 + sqrt(10)/4 = 1.290569... ulps), in which case it remains true that the correctly-rounded result will always be one of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Self-correction: I guess "off by <= 1ulp" is accurate, so long as it's clear that we're talking about the maximum difference between the computed value and the correctly-rounded value. I was thinking in terms of the max difference between the computed value and the true mathematical value. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Updating the comment. |
||
# Method: Check the result, moving up or down a step if needed. | ||
if n <= 0: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we be checking There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The case where m is zero is handled before the inequality test. The There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I was really more worried about the case m negative, not m zero. It seems more natural to have the normalization step It doesn't matter, because if just one of |
||
if not n: | ||
return Decimal('0.0') | ||
n, m = -n, -m | ||
|
||
root = (Decimal(n) / Decimal(m)).sqrt() | ||
nr, dr = root.as_integer_ratio() | ||
|
||
plus = root.next_plus() | ||
np, dp = plus.as_integer_ratio() | ||
# test: n / m > ((root + plus) / 2) ** 2 | ||
if 4 * n * (dr*dp)**2 > m * (dr*np + dp*nr)**2: | ||
return plus | ||
|
||
minus = root.next_minus() | ||
nm, dm = minus.as_integer_ratio() | ||
# test: n / m < ((root + minus) / 2) ** 2 | ||
if 4 * n * (dr*dm)**2 < m * (dr*nm + dm*nr)**2: | ||
return minus | ||
|
||
return root | ||
|
||
|
||
# === Measures of central tendency (averages) === | ||
|
||
def mean(data): | ||
|
@@ -869,7 +908,7 @@ def stdev(data, xbar=None): | |
if hasattr(T, 'sqrt'): | ||
var = _convert(mss, T) | ||
return var.sqrt() | ||
return _sqrt_frac(mss.numerator, mss.denominator) | ||
return _float_sqrt_of_frac(mss.numerator, mss.denominator) | ||
|
||
|
||
def pstdev(data, mu=None): | ||
|
@@ -888,10 +927,9 @@ def pstdev(data, mu=None): | |
raise StatisticsError('pstdev requires at least one data point') | ||
T, ss = _ss(data, mu) | ||
mss = ss / n | ||
if hasattr(T, 'sqrt'): | ||
var = _convert(mss, T) | ||
return var.sqrt() | ||
return _sqrt_frac(mss.numerator, mss.denominator) | ||
if issubclass(T, Decimal): | ||
return _decimal_sqrt_of_frac(mss.numerator, mss.denominator) | ||
return _float_sqrt_of_frac(mss.numerator, mss.denominator) | ||
|
||
|
||
# === Statistics for relations between two inputs === | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
copypasta: "as a float"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.