diff options
author | Raymond Hettinger <rhettinger@users.noreply.github.com> | 2021-11-27 05:54:50 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-11-27 05:54:50 (GMT) |
commit | af9ee57b96cb872df6574e36027cc753417605f9 (patch) | |
tree | f0cce757d3ce53ff64b31706875b3b8290f6f0ab /Lib/statistics.py | |
parent | db55f3fabafc046e4fca907210ced4ce16bf58d6 (diff) | |
download | cpython-af9ee57b96cb872df6574e36027cc753417605f9.zip cpython-af9ee57b96cb872df6574e36027cc753417605f9.tar.gz cpython-af9ee57b96cb872df6574e36027cc753417605f9.tar.bz2 |
bpo-45876: Improve accuracy for stdev() and pstdev() in statistics (GH-29736)
* Inlined code from variance functions
* Added helper functions for the float square root of a fraction
* Call helper functions
* Add blurb
* Fix over-specified test
* Add a test for the _sqrt_frac() helper function
* Increase the tested range
* Add type hints to the internal function.
* Fix test for correct rounding
* Simplify ⌊√(n/m)⌋ calculation
Co-authored-by: Mark Dickinson <dickinsm@gmail.com>
* Add comment and beef-up tests
* Test for zero denominator
* Add algorithmic references
* Add test for the _isqrt_frac_rto() helper function.
* Compute the 109 instead of hard-wiring it
* Stronger test for _isqrt_frac_rto()
* Bigger range
* Bigger range
* Replace float() call with int/int division to be parallel with the other code path.
* Factor out division. Update proof link. Remove internal type declaration
Co-authored-by: Mark Dickinson <dickinsm@gmail.com>
Diffstat (limited to 'Lib/statistics.py')
-rw-r--r-- | Lib/statistics.py | 56 |
1 files changed, 42 insertions, 14 deletions
diff --git a/Lib/statistics.py b/Lib/statistics.py index 5c3f77d..cf8eaa0 100644 --- a/Lib/statistics.py +++ b/Lib/statistics.py @@ -130,6 +130,7 @@ __all__ = [ import math import numbers import random +import sys from fractions import Fraction from decimal import Decimal @@ -304,6 +305,27 @@ def _fail_neg(values, errmsg='negative value'): raise StatisticsError(errmsg) yield x +def _isqrt_frac_rto(n: int, m: int) -> float: + """Square root of n/m, rounded to the nearest integer using round-to-odd.""" + # Reference: https://www.lri.fr/~melquion/doc/05-imacs17_1-expose.pdf + a = math.isqrt(n // m) + return a | (a*a*m != n) + +# For 53 bit precision floats, the _sqrt_frac() shift is 109. +_sqrt_shift: int = 2 * sys.float_info.mant_dig + 3 + +def _sqrt_frac(n: int, m: int) -> float: + """Square root of n/m as a float, correctly rounded.""" + # See principle and proof sketch at: https://bugs.python.org/msg407078 + q = (n.bit_length() - m.bit_length() - _sqrt_shift) // 2 + if q >= 0: + numerator = _isqrt_frac_rto(n, m << 2 * q) << q + denominator = 1 + else: + numerator = _isqrt_frac_rto(n << -2 * q, m) + denominator = 1 << -q + return numerator / denominator # Convert to float + # === Measures of central tendency (averages) === @@ -837,14 +859,17 @@ def stdev(data, xbar=None): 1.0810874155219827 """ - # Fixme: Despite the exact sum of squared deviations, some inaccuracy - # remain because there are two rounding steps. The first occurs in - # the _convert() step for variance(), the second occurs in math.sqrt(). - var = variance(data, xbar) - try: + if iter(data) is data: + data = list(data) + n = len(data) + if n < 2: + raise StatisticsError('stdev requires at least two data points') + T, ss = _ss(data, xbar) + mss = ss / (n - 1) + if hasattr(T, 'sqrt'): + var = _convert(mss, T) return var.sqrt() - except AttributeError: - return math.sqrt(var) + return _sqrt_frac(mss.numerator, mss.denominator) def pstdev(data, mu=None): @@ -856,14 +881,17 @@ def pstdev(data, mu=None): 0.986893273527251 """ - # Fixme: Despite the exact sum of squared deviations, some inaccuracy - # remain because there are two rounding steps. The first occurs in - # the _convert() step for pvariance(), the second occurs in math.sqrt(). - var = pvariance(data, mu) - try: + if iter(data) is data: + data = list(data) + n = len(data) + if n < 1: + raise StatisticsError('pstdev requires at least one data point') + T, ss = _ss(data, mu) + mss = ss / n + if hasattr(T, 'sqrt'): + var = _convert(mss, T) return var.sqrt() - except AttributeError: - return math.sqrt(var) + return _sqrt_frac(mss.numerator, mss.denominator) # === Statistics for relations between two inputs === |