From af9ee57b96cb872df6574e36027cc753417605f9 Mon Sep 17 00:00:00 2001 From: Raymond Hettinger Date: Fri, 26 Nov 2021 22:54:50 -0700 Subject: bpo-45876: Improve accuracy for stdev() and pstdev() in statistics (GH-29736) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Inlined code from variance functions * Added helper functions for the float square root of a fraction * Call helper functions * Add blurb * Fix over-specified test * Add a test for the _sqrt_frac() helper function * Increase the tested range * Add type hints to the internal function. * Fix test for correct rounding * Simplify ⌊√(n/m)⌋ calculation Co-authored-by: Mark Dickinson * Add comment and beef-up tests * Test for zero denominator * Add algorithmic references * Add test for the _isqrt_frac_rto() helper function. * Compute the 109 instead of hard-wiring it * Stronger test for _isqrt_frac_rto() * Bigger range * Bigger range * Replace float() call with int/int division to be parallel with the other code path. * Factor out division. Update proof link. Remove internal type declaration Co-authored-by: Mark Dickinson --- Lib/statistics.py | 56 ++++++++++++++----- Lib/test/test_statistics.py | 65 +++++++++++++++++++++- .../2021-11-23-15-36-56.bpo-45876.NO8Yaj.rst | 2 + 3 files changed, 107 insertions(+), 16 deletions(-) create mode 100644 Misc/NEWS.d/next/Library/2021-11-23-15-36-56.bpo-45876.NO8Yaj.rst diff --git a/Lib/statistics.py b/Lib/statistics.py index 5c3f77d..cf8eaa0 100644 --- a/Lib/statistics.py +++ b/Lib/statistics.py @@ -130,6 +130,7 @@ __all__ = [ import math import numbers import random +import sys from fractions import Fraction from decimal import Decimal @@ -304,6 +305,27 @@ def _fail_neg(values, errmsg='negative value'): raise StatisticsError(errmsg) yield x +def _isqrt_frac_rto(n: int, m: int) -> float: + """Square root of n/m, rounded to the nearest integer using round-to-odd.""" + # Reference: https://www.lri.fr/~melquion/doc/05-imacs17_1-expose.pdf + a = math.isqrt(n // m) + return a | (a*a*m != n) + +# For 53 bit precision floats, the _sqrt_frac() shift is 109. +_sqrt_shift: int = 2 * sys.float_info.mant_dig + 3 + +def _sqrt_frac(n: int, m: int) -> float: + """Square root of n/m as a float, correctly rounded.""" + # See principle and proof sketch at: https://bugs.python.org/msg407078 + q = (n.bit_length() - m.bit_length() - _sqrt_shift) // 2 + if q >= 0: + numerator = _isqrt_frac_rto(n, m << 2 * q) << q + denominator = 1 + else: + numerator = _isqrt_frac_rto(n << -2 * q, m) + denominator = 1 << -q + return numerator / denominator # Convert to float + # === Measures of central tendency (averages) === @@ -837,14 +859,17 @@ def stdev(data, xbar=None): 1.0810874155219827 """ - # Fixme: Despite the exact sum of squared deviations, some inaccuracy - # remain because there are two rounding steps. The first occurs in - # the _convert() step for variance(), the second occurs in math.sqrt(). - var = variance(data, xbar) - try: + if iter(data) is data: + data = list(data) + n = len(data) + if n < 2: + raise StatisticsError('stdev requires at least two data points') + T, ss = _ss(data, xbar) + mss = ss / (n - 1) + if hasattr(T, 'sqrt'): + var = _convert(mss, T) return var.sqrt() - except AttributeError: - return math.sqrt(var) + return _sqrt_frac(mss.numerator, mss.denominator) def pstdev(data, mu=None): @@ -856,14 +881,17 @@ def pstdev(data, mu=None): 0.986893273527251 """ - # Fixme: Despite the exact sum of squared deviations, some inaccuracy - # remain because there are two rounding steps. The first occurs in - # the _convert() step for pvariance(), the second occurs in math.sqrt(). - var = pvariance(data, mu) - try: + if iter(data) is data: + data = list(data) + n = len(data) + if n < 1: + raise StatisticsError('pstdev requires at least one data point') + T, ss = _ss(data, mu) + mss = ss / n + if hasattr(T, 'sqrt'): + var = _convert(mss, T) return var.sqrt() - except AttributeError: - return math.sqrt(var) + return _sqrt_frac(mss.numerator, mss.denominator) # === Statistics for relations between two inputs === diff --git a/Lib/test/test_statistics.py b/Lib/test/test_statistics.py index c0e427d..771a03e 100644 --- a/Lib/test/test_statistics.py +++ b/Lib/test/test_statistics.py @@ -9,13 +9,14 @@ import collections.abc import copy import decimal import doctest +import itertools import math import pickle import random import sys import unittest from test import support -from test.support import import_helper +from test.support import import_helper, requires_IEEE_754 from decimal import Decimal from fractions import Fraction @@ -2161,6 +2162,66 @@ class TestPStdev(VarianceStdevMixin, NumericTestCase): self.assertEqual(self.func(data), 2.5) self.assertEqual(self.func(data, mu=0.5), 6.5) +class TestSqrtHelpers(unittest.TestCase): + + def test_isqrt_frac_rto(self): + for n, m in itertools.product(range(100), range(1, 1000)): + r = statistics._isqrt_frac_rto(n, m) + self.assertIsInstance(r, int) + if r*r*m == n: + # Root is exact + continue + # Inexact, so the root should be odd + self.assertEqual(r&1, 1) + # Verify correct rounding + self.assertTrue(m * (r - 1)**2 < n < m * (r + 1)**2) + + @requires_IEEE_754 + def test_sqrt_frac(self): + + def is_root_correctly_rounded(x: Fraction, root: float) -> bool: + if not x: + return root == 0.0 + + # Extract adjacent representable floats + r_up: float = math.nextafter(root, math.inf) + r_down: float = math.nextafter(root, -math.inf) + assert r_down < root < r_up + + # Convert to fractions for exact arithmetic + frac_root: Fraction = Fraction(root) + half_way_up: Fraction = (frac_root + Fraction(r_up)) / 2 + half_way_down: Fraction = (frac_root + Fraction(r_down)) / 2 + + # Check a closed interval. + # Does not test for a midpoint rounding rule. + return half_way_down ** 2 <= x <= half_way_up ** 2 + + randrange = random.randrange + + for i in range(60_000): + numerator: int = randrange(10 ** randrange(50)) + denonimator: int = randrange(10 ** randrange(50)) + 1 + with self.subTest(numerator=numerator, denonimator=denonimator): + x: Fraction = Fraction(numerator, denonimator) + root: float = statistics._sqrt_frac(numerator, denonimator) + self.assertTrue(is_root_correctly_rounded(x, root)) + + # Verify that corner cases and error handling match math.sqrt() + self.assertEqual(statistics._sqrt_frac(0, 1), 0.0) + with self.assertRaises(ValueError): + statistics._sqrt_frac(-1, 1) + with self.assertRaises(ValueError): + statistics._sqrt_frac(1, -1) + + # Error handling for zero denominator matches that for Fraction(1, 0) + with self.assertRaises(ZeroDivisionError): + statistics._sqrt_frac(1, 0) + + # The result is well defined if both inputs are negative + self.assertAlmostEqual(statistics._sqrt_frac(-2, -1), math.sqrt(2.0)) + + class TestStdev(VarianceStdevMixin, NumericTestCase): # Tests for sample standard deviation. def setUp(self): @@ -2175,7 +2236,7 @@ class TestStdev(VarianceStdevMixin, NumericTestCase): # Test that stdev is, in fact, the square root of variance. data = [random.uniform(-2, 9) for _ in range(1000)] expected = math.sqrt(statistics.variance(data)) - self.assertEqual(self.func(data), expected) + self.assertAlmostEqual(self.func(data), expected) def test_center_not_at_mean(self): data = (1.0, 2.0) diff --git a/Misc/NEWS.d/next/Library/2021-11-23-15-36-56.bpo-45876.NO8Yaj.rst b/Misc/NEWS.d/next/Library/2021-11-23-15-36-56.bpo-45876.NO8Yaj.rst new file mode 100644 index 0000000..889ed6c --- /dev/null +++ b/Misc/NEWS.d/next/Library/2021-11-23-15-36-56.bpo-45876.NO8Yaj.rst @@ -0,0 +1,2 @@ +Improve the accuracy of stdev() and pstdev() in the statistics module. When +the inputs are floats or fractions, the output is a correctly rounded float -- cgit v0.12