summaryrefslogtreecommitdiffstats
path: root/Lib/statistics.py
diff options
context:
space:
mode:
authorRaymond Hettinger <rhettinger@users.noreply.github.com>2021-12-01 00:20:08 (GMT)
committerGitHub <noreply@github.com>2021-12-01 00:20:08 (GMT)
commita39f46afdead515e7ac3722464b5ee8d7b0b2c9b (patch)
treed6f13232e73c75e9f8514411930add6a531b5ea6 /Lib/statistics.py
parent8a45ca542a65ea27e7acaa44a4c833a27830e796 (diff)
downloadcpython-a39f46afdead515e7ac3722464b5ee8d7b0b2c9b.zip
cpython-a39f46afdead515e7ac3722464b5ee8d7b0b2c9b.tar.gz
cpython-a39f46afdead515e7ac3722464b5ee8d7b0b2c9b.tar.bz2
bpo-45876: Correctly rounded stdev() and pstdev() for the Decimal case (GH-29828)
Diffstat (limited to 'Lib/statistics.py')
-rw-r--r--Lib/statistics.py79
1 files changed, 66 insertions, 13 deletions
diff --git a/Lib/statistics.py b/Lib/statistics.py
index cf8eaa0..9f1efa2 100644
--- a/Lib/statistics.py
+++ b/Lib/statistics.py
@@ -137,7 +137,7 @@ from decimal import Decimal
from itertools import groupby, repeat
from bisect import bisect_left, bisect_right
from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum
-from operator import itemgetter, mul
+from operator import mul
from collections import Counter, namedtuple
_SQRT2 = sqrt(2.0)
@@ -248,6 +248,28 @@ def _exact_ratio(x):
x is expected to be an int, Fraction, Decimal or float.
"""
+
+ # XXX We should revisit whether using fractions to accumulate exact
+ # ratios is the right way to go.
+
+ # The integer ratios for binary floats can have numerators or
+ # denominators with over 300 decimal digits. The problem is more
+ # acute with decimal floats where the the default decimal context
+ # supports a huge range of exponents from Emin=-999999 to
+ # Emax=999999. When expanded with as_integer_ratio(), numbers like
+ # Decimal('3.14E+5000') and Decimal('3.14E-5000') have large
+ # numerators or denominators that will slow computation.
+
+ # When the integer ratios are accumulated as fractions, the size
+ # grows to cover the full range from the smallest magnitude to the
+ # largest. For example, Fraction(3.14E+300) + Fraction(3.14E-300),
+ # has a 616 digit numerator. Likewise,
+ # Fraction(Decimal('3.14E+5000')) + Fraction(Decimal('3.14E-5000'))
+ # has 10,003 digit numerator.
+
+ # This doesn't seem to have been problem in practice, but it is a
+ # potential pitfall.
+
try:
return x.as_integer_ratio()
except AttributeError:
@@ -305,28 +327,60 @@ def _fail_neg(values, errmsg='negative value'):
raise StatisticsError(errmsg)
yield x
-def _isqrt_frac_rto(n: int, m: int) -> float:
+
+def _integer_sqrt_of_frac_rto(n: int, m: int) -> int:
"""Square root of n/m, rounded to the nearest integer using round-to-odd."""
# Reference: https://www.lri.fr/~melquion/doc/05-imacs17_1-expose.pdf
a = math.isqrt(n // m)
return a | (a*a*m != n)
-# For 53 bit precision floats, the _sqrt_frac() shift is 109.
-_sqrt_shift: int = 2 * sys.float_info.mant_dig + 3
-def _sqrt_frac(n: int, m: int) -> float:
+# For 53 bit precision floats, the bit width used in
+# _float_sqrt_of_frac() is 109.
+_sqrt_bit_width: int = 2 * sys.float_info.mant_dig + 3
+
+
+def _float_sqrt_of_frac(n: int, m: int) -> float:
"""Square root of n/m as a float, correctly rounded."""
# See principle and proof sketch at: https://bugs.python.org/msg407078
- q = (n.bit_length() - m.bit_length() - _sqrt_shift) // 2
+ q = (n.bit_length() - m.bit_length() - _sqrt_bit_width) // 2
if q >= 0:
- numerator = _isqrt_frac_rto(n, m << 2 * q) << q
+ numerator = _integer_sqrt_of_frac_rto(n, m << 2 * q) << q
denominator = 1
else:
- numerator = _isqrt_frac_rto(n << -2 * q, m)
+ numerator = _integer_sqrt_of_frac_rto(n << -2 * q, m)
denominator = 1 << -q
return numerator / denominator # Convert to float
+def _decimal_sqrt_of_frac(n: int, m: int) -> Decimal:
+ """Square root of n/m as a Decimal, correctly rounded."""
+ # Premise: For decimal, computing (n/m).sqrt() can be off
+ # by 1 ulp from the correctly rounded result.
+ # Method: Check the result, moving up or down a step if needed.
+ if n <= 0:
+ if not n:
+ return Decimal('0.0')
+ n, m = -n, -m
+
+ root = (Decimal(n) / Decimal(m)).sqrt()
+ nr, dr = root.as_integer_ratio()
+
+ plus = root.next_plus()
+ np, dp = plus.as_integer_ratio()
+ # test: n / m > ((root + plus) / 2) ** 2
+ if 4 * n * (dr*dp)**2 > m * (dr*np + dp*nr)**2:
+ return plus
+
+ minus = root.next_minus()
+ nm, dm = minus.as_integer_ratio()
+ # test: n / m < ((root + minus) / 2) ** 2
+ if 4 * n * (dr*dm)**2 < m * (dr*nm + dm*nr)**2:
+ return minus
+
+ return root
+
+
# === Measures of central tendency (averages) ===
def mean(data):
@@ -869,7 +923,7 @@ def stdev(data, xbar=None):
if hasattr(T, 'sqrt'):
var = _convert(mss, T)
return var.sqrt()
- return _sqrt_frac(mss.numerator, mss.denominator)
+ return _float_sqrt_of_frac(mss.numerator, mss.denominator)
def pstdev(data, mu=None):
@@ -888,10 +942,9 @@ def pstdev(data, mu=None):
raise StatisticsError('pstdev requires at least one data point')
T, ss = _ss(data, mu)
mss = ss / n
- if hasattr(T, 'sqrt'):
- var = _convert(mss, T)
- return var.sqrt()
- return _sqrt_frac(mss.numerator, mss.denominator)
+ if issubclass(T, Decimal):
+ return _decimal_sqrt_of_frac(mss.numerator, mss.denominator)
+ return _float_sqrt_of_frac(mss.numerator, mss.denominator)
# === Statistics for relations between two inputs ===