summaryrefslogtreecommitdiffstats
path: root/Lib/statistics.py
diff options
context:
space:
mode:
authorRaymond Hettinger <rhettinger@users.noreply.github.com>2021-11-27 05:54:50 (GMT)
committerGitHub <noreply@github.com>2021-11-27 05:54:50 (GMT)
commitaf9ee57b96cb872df6574e36027cc753417605f9 (patch)
treef0cce757d3ce53ff64b31706875b3b8290f6f0ab /Lib/statistics.py
parentdb55f3fabafc046e4fca907210ced4ce16bf58d6 (diff)
downloadcpython-af9ee57b96cb872df6574e36027cc753417605f9.zip
cpython-af9ee57b96cb872df6574e36027cc753417605f9.tar.gz
cpython-af9ee57b96cb872df6574e36027cc753417605f9.tar.bz2
bpo-45876: Improve accuracy for stdev() and pstdev() in statistics (GH-29736)
* Inlined code from variance functions * Added helper functions for the float square root of a fraction * Call helper functions * Add blurb * Fix over-specified test * Add a test for the _sqrt_frac() helper function * Increase the tested range * Add type hints to the internal function. * Fix test for correct rounding * Simplify ⌊√(n/m)⌋ calculation Co-authored-by: Mark Dickinson <dickinsm@gmail.com> * Add comment and beef-up tests * Test for zero denominator * Add algorithmic references * Add test for the _isqrt_frac_rto() helper function. * Compute the 109 instead of hard-wiring it * Stronger test for _isqrt_frac_rto() * Bigger range * Bigger range * Replace float() call with int/int division to be parallel with the other code path. * Factor out division. Update proof link. Remove internal type declaration Co-authored-by: Mark Dickinson <dickinsm@gmail.com>
Diffstat (limited to 'Lib/statistics.py')
-rw-r--r--Lib/statistics.py56
1 files changed, 42 insertions, 14 deletions
diff --git a/Lib/statistics.py b/Lib/statistics.py
index 5c3f77d..cf8eaa0 100644
--- a/Lib/statistics.py
+++ b/Lib/statistics.py
@@ -130,6 +130,7 @@ __all__ = [
import math
import numbers
import random
+import sys
from fractions import Fraction
from decimal import Decimal
@@ -304,6 +305,27 @@ def _fail_neg(values, errmsg='negative value'):
raise StatisticsError(errmsg)
yield x
+def _isqrt_frac_rto(n: int, m: int) -> float:
+ """Square root of n/m, rounded to the nearest integer using round-to-odd."""
+ # Reference: https://www.lri.fr/~melquion/doc/05-imacs17_1-expose.pdf
+ a = math.isqrt(n // m)
+ return a | (a*a*m != n)
+
+# For 53 bit precision floats, the _sqrt_frac() shift is 109.
+_sqrt_shift: int = 2 * sys.float_info.mant_dig + 3
+
+def _sqrt_frac(n: int, m: int) -> float:
+ """Square root of n/m as a float, correctly rounded."""
+ # See principle and proof sketch at: https://bugs.python.org/msg407078
+ q = (n.bit_length() - m.bit_length() - _sqrt_shift) // 2
+ if q >= 0:
+ numerator = _isqrt_frac_rto(n, m << 2 * q) << q
+ denominator = 1
+ else:
+ numerator = _isqrt_frac_rto(n << -2 * q, m)
+ denominator = 1 << -q
+ return numerator / denominator # Convert to float
+
# === Measures of central tendency (averages) ===
@@ -837,14 +859,17 @@ def stdev(data, xbar=None):
1.0810874155219827
"""
- # Fixme: Despite the exact sum of squared deviations, some inaccuracy
- # remain because there are two rounding steps. The first occurs in
- # the _convert() step for variance(), the second occurs in math.sqrt().
- var = variance(data, xbar)
- try:
+ if iter(data) is data:
+ data = list(data)
+ n = len(data)
+ if n < 2:
+ raise StatisticsError('stdev requires at least two data points')
+ T, ss = _ss(data, xbar)
+ mss = ss / (n - 1)
+ if hasattr(T, 'sqrt'):
+ var = _convert(mss, T)
return var.sqrt()
- except AttributeError:
- return math.sqrt(var)
+ return _sqrt_frac(mss.numerator, mss.denominator)
def pstdev(data, mu=None):
@@ -856,14 +881,17 @@ def pstdev(data, mu=None):
0.986893273527251
"""
- # Fixme: Despite the exact sum of squared deviations, some inaccuracy
- # remain because there are two rounding steps. The first occurs in
- # the _convert() step for pvariance(), the second occurs in math.sqrt().
- var = pvariance(data, mu)
- try:
+ if iter(data) is data:
+ data = list(data)
+ n = len(data)
+ if n < 1:
+ raise StatisticsError('pstdev requires at least one data point')
+ T, ss = _ss(data, mu)
+ mss = ss / n
+ if hasattr(T, 'sqrt'):
+ var = _convert(mss, T)
return var.sqrt()
- except AttributeError:
- return math.sqrt(var)
+ return _sqrt_frac(mss.numerator, mss.denominator)
# === Statistics for relations between two inputs ===