diff --git a/Lib/statistics.py b/Lib/statistics.py index cf8eaa0a61e624..9f1efa21b15e3c 100644 --- a/Lib/statistics.py +++ b/Lib/statistics.py @@ -137,7 +137,7 @@ from itertools import groupby, repeat from bisect import bisect_left, bisect_right from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum -from operator import itemgetter, mul +from operator import mul from collections import Counter, namedtuple _SQRT2 = sqrt(2.0) @@ -248,6 +248,28 @@ def _exact_ratio(x): x is expected to be an int, Fraction, Decimal or float. """ + + # XXX We should revisit whether using fractions to accumulate exact + # ratios is the right way to go. + + # The integer ratios for binary floats can have numerators or + # denominators with over 300 decimal digits. The problem is more + # acute with decimal floats where the the default decimal context + # supports a huge range of exponents from Emin=-999999 to + # Emax=999999. When expanded with as_integer_ratio(), numbers like + # Decimal('3.14E+5000') and Decimal('3.14E-5000') have large + # numerators or denominators that will slow computation. + + # When the integer ratios are accumulated as fractions, the size + # grows to cover the full range from the smallest magnitude to the + # largest. For example, Fraction(3.14E+300) + Fraction(3.14E-300), + # has a 616 digit numerator. Likewise, + # Fraction(Decimal('3.14E+5000')) + Fraction(Decimal('3.14E-5000')) + # has 10,003 digit numerator. + + # This doesn't seem to have been problem in practice, but it is a + # potential pitfall. + try: return x.as_integer_ratio() except AttributeError: @@ -305,28 +327,60 @@ def _fail_neg(values, errmsg='negative value'): raise StatisticsError(errmsg) yield x -def _isqrt_frac_rto(n: int, m: int) -> float: + +def _integer_sqrt_of_frac_rto(n: int, m: int) -> int: """Square root of n/m, rounded to the nearest integer using round-to-odd.""" # Reference: https://www.lri.fr/~melquion/doc/05-imacs17_1-expose.pdf a = math.isqrt(n // m) return a | (a*a*m != n) -# For 53 bit precision floats, the _sqrt_frac() shift is 109. -_sqrt_shift: int = 2 * sys.float_info.mant_dig + 3 -def _sqrt_frac(n: int, m: int) -> float: +# For 53 bit precision floats, the bit width used in +# _float_sqrt_of_frac() is 109. +_sqrt_bit_width: int = 2 * sys.float_info.mant_dig + 3 + + +def _float_sqrt_of_frac(n: int, m: int) -> float: """Square root of n/m as a float, correctly rounded.""" # See principle and proof sketch at: https://bugs.python.org/msg407078 - q = (n.bit_length() - m.bit_length() - _sqrt_shift) // 2 + q = (n.bit_length() - m.bit_length() - _sqrt_bit_width) // 2 if q >= 0: - numerator = _isqrt_frac_rto(n, m << 2 * q) << q + numerator = _integer_sqrt_of_frac_rto(n, m << 2 * q) << q denominator = 1 else: - numerator = _isqrt_frac_rto(n << -2 * q, m) + numerator = _integer_sqrt_of_frac_rto(n << -2 * q, m) denominator = 1 << -q return numerator / denominator # Convert to float +def _decimal_sqrt_of_frac(n: int, m: int) -> Decimal: + """Square root of n/m as a Decimal, correctly rounded.""" + # Premise: For decimal, computing (n/m).sqrt() can be off + # by 1 ulp from the correctly rounded result. + # Method: Check the result, moving up or down a step if needed. + if n <= 0: + if not n: + return Decimal('0.0') + n, m = -n, -m + + root = (Decimal(n) / Decimal(m)).sqrt() + nr, dr = root.as_integer_ratio() + + plus = root.next_plus() + np, dp = plus.as_integer_ratio() + # test: n / m > ((root + plus) / 2) ** 2 + if 4 * n * (dr*dp)**2 > m * (dr*np + dp*nr)**2: + return plus + + minus = root.next_minus() + nm, dm = minus.as_integer_ratio() + # test: n / m < ((root + minus) / 2) ** 2 + if 4 * n * (dr*dm)**2 < m * (dr*nm + dm*nr)**2: + return minus + + return root + + # === Measures of central tendency (averages) === def mean(data): @@ -869,7 +923,7 @@ def stdev(data, xbar=None): if hasattr(T, 'sqrt'): var = _convert(mss, T) return var.sqrt() - return _sqrt_frac(mss.numerator, mss.denominator) + return _float_sqrt_of_frac(mss.numerator, mss.denominator) def pstdev(data, mu=None): @@ -888,10 +942,9 @@ def pstdev(data, mu=None): raise StatisticsError('pstdev requires at least one data point') T, ss = _ss(data, mu) mss = ss / n - if hasattr(T, 'sqrt'): - var = _convert(mss, T) - return var.sqrt() - return _sqrt_frac(mss.numerator, mss.denominator) + if issubclass(T, Decimal): + return _decimal_sqrt_of_frac(mss.numerator, mss.denominator) + return _float_sqrt_of_frac(mss.numerator, mss.denominator) # === Statistics for relations between two inputs === diff --git a/Lib/test/test_statistics.py b/Lib/test/test_statistics.py index 771a03e707ee01..bacb76a9b036bf 100644 --- a/Lib/test/test_statistics.py +++ b/Lib/test/test_statistics.py @@ -2164,9 +2164,9 @@ def test_center_not_at_mean(self): class TestSqrtHelpers(unittest.TestCase): - def test_isqrt_frac_rto(self): + def test_integer_sqrt_of_frac_rto(self): for n, m in itertools.product(range(100), range(1, 1000)): - r = statistics._isqrt_frac_rto(n, m) + r = statistics._integer_sqrt_of_frac_rto(n, m) self.assertIsInstance(r, int) if r*r*m == n: # Root is exact @@ -2177,7 +2177,7 @@ def test_isqrt_frac_rto(self): self.assertTrue(m * (r - 1)**2 < n < m * (r + 1)**2) @requires_IEEE_754 - def test_sqrt_frac(self): + def test_float_sqrt_of_frac(self): def is_root_correctly_rounded(x: Fraction, root: float) -> bool: if not x: @@ -2204,22 +2204,59 @@ def is_root_correctly_rounded(x: Fraction, root: float) -> bool: denonimator: int = randrange(10 ** randrange(50)) + 1 with self.subTest(numerator=numerator, denonimator=denonimator): x: Fraction = Fraction(numerator, denonimator) - root: float = statistics._sqrt_frac(numerator, denonimator) + root: float = statistics._float_sqrt_of_frac(numerator, denonimator) self.assertTrue(is_root_correctly_rounded(x, root)) # Verify that corner cases and error handling match math.sqrt() - self.assertEqual(statistics._sqrt_frac(0, 1), 0.0) + self.assertEqual(statistics._float_sqrt_of_frac(0, 1), 0.0) with self.assertRaises(ValueError): - statistics._sqrt_frac(-1, 1) + statistics._float_sqrt_of_frac(-1, 1) with self.assertRaises(ValueError): - statistics._sqrt_frac(1, -1) + statistics._float_sqrt_of_frac(1, -1) # Error handling for zero denominator matches that for Fraction(1, 0) with self.assertRaises(ZeroDivisionError): - statistics._sqrt_frac(1, 0) + statistics._float_sqrt_of_frac(1, 0) # The result is well defined if both inputs are negative - self.assertAlmostEqual(statistics._sqrt_frac(-2, -1), math.sqrt(2.0)) + self.assertEqual(statistics._float_sqrt_of_frac(-2, -1), statistics._float_sqrt_of_frac(2, 1)) + + def test_decimal_sqrt_of_frac(self): + root: Decimal + numerator: int + denominator: int + + for root, numerator, denominator in [ + (Decimal('0.4481904599041192673635338663'), 200874688349065940678243576378, 1000000000000000000000000000000), # No adj + (Decimal('0.7924949131383786609961759598'), 628048187350206338833590574929, 1000000000000000000000000000000), # Adj up + (Decimal('0.8500554152289934068192208727'), 722594208960136395984391238251, 1000000000000000000000000000000), # Adj down + ]: + with decimal.localcontext(decimal.DefaultContext): + self.assertEqual(statistics._decimal_sqrt_of_frac(numerator, denominator), root) + + # Confirm expected root with a quad precision decimal computation + with decimal.localcontext(decimal.DefaultContext) as ctx: + ctx.prec *= 4 + high_prec_ratio = Decimal(numerator) / Decimal(denominator) + ctx.rounding = decimal.ROUND_05UP + high_prec_root = high_prec_ratio.sqrt() + with decimal.localcontext(decimal.DefaultContext): + target_root = +high_prec_root + self.assertEqual(root, target_root) + + # Verify that corner cases and error handling match Decimal.sqrt() + self.assertEqual(statistics._decimal_sqrt_of_frac(0, 1), 0.0) + with self.assertRaises(decimal.InvalidOperation): + statistics._decimal_sqrt_of_frac(-1, 1) + with self.assertRaises(decimal.InvalidOperation): + statistics._decimal_sqrt_of_frac(1, -1) + + # Error handling for zero denominator matches that for Fraction(1, 0) + with self.assertRaises(ZeroDivisionError): + statistics._decimal_sqrt_of_frac(1, 0) + + # The result is well defined if both inputs are negative + self.assertEqual(statistics._decimal_sqrt_of_frac(-2, -1), statistics._decimal_sqrt_of_frac(2, 1)) class TestStdev(VarianceStdevMixin, NumericTestCase):