summaryrefslogtreecommitdiffstats
path: root/Lib/statistics.py
diff options
context:
space:
mode:
authorRaymond Hettinger <rhettinger@users.noreply.github.com>2022-01-05 15:39:10 (GMT)
committerGitHub <noreply@github.com>2022-01-05 15:39:10 (GMT)
commit43aac29cbbb8a963a22c334b5b795d1e43417d6b (patch)
treef8cdea663608c13ba7ae12eac72f4b5c986406ac /Lib/statistics.py
parent46e4c257e7c26c813620232135781e6c53fe8d4d (diff)
downloadcpython-43aac29cbbb8a963a22c334b5b795d1e43417d6b.zip
cpython-43aac29cbbb8a963a22c334b5b795d1e43417d6b.tar.gz
cpython-43aac29cbbb8a963a22c334b5b795d1e43417d6b.tar.bz2
bpo-46257: Convert statistics._ss() to a single pass algorithm (GH-30403)
Diffstat (limited to 'Lib/statistics.py')
-rw-r--r--Lib/statistics.py100
1 files changed, 43 insertions, 57 deletions
diff --git a/Lib/statistics.py b/Lib/statistics.py
index c104571..eef2453 100644
--- a/Lib/statistics.py
+++ b/Lib/statistics.py
@@ -138,7 +138,7 @@ from itertools import groupby, repeat
from bisect import bisect_left, bisect_right
from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum
from operator import mul
-from collections import Counter, namedtuple
+from collections import Counter, namedtuple, defaultdict
_SQRT2 = sqrt(2.0)
@@ -202,6 +202,43 @@ def _sum(data):
return (T, total, count)
+def _ss(data, c=None):
+ """Return sum of square deviations of sequence data.
+
+ If ``c`` is None, the mean is calculated in one pass, and the deviations
+ from the mean are calculated in a second pass. Otherwise, deviations are
+ calculated from ``c`` as given. Use the second case with care, as it can
+ lead to garbage results.
+ """
+ if c is not None:
+ T, total, count = _sum((d := x - c) * d for x in data)
+ return (T, total, count)
+ count = 0
+ sx_partials = defaultdict(int)
+ sxx_partials = defaultdict(int)
+ T = int
+ for typ, values in groupby(data, type):
+ T = _coerce(T, typ) # or raise TypeError
+ for n, d in map(_exact_ratio, values):
+ count += 1
+ sx_partials[d] += n
+ sxx_partials[d] += n * n
+ if not count:
+ total = Fraction(0)
+ elif None in sx_partials:
+ # The sum will be a NAN or INF. We can ignore all the finite
+ # partials, and just look at this special one.
+ total = sx_partials[None]
+ assert not _isfinite(total)
+ else:
+ sx = sum(Fraction(n, d) for d, n in sx_partials.items())
+ sxx = sum(Fraction(n, d*d) for d, n in sxx_partials.items())
+ # This formula has poor numeric properties for floats,
+ # but with fractions it is exact.
+ total = (count * sxx - sx * sx) / count
+ return (T, total, count)
+
+
def _isfinite(x):
try:
return x.is_finite() # Likely a Decimal.
@@ -399,13 +436,9 @@ def mean(data):
If ``data`` is empty, StatisticsError will be raised.
"""
- if iter(data) is data:
- data = list(data)
- n = len(data)
+ T, total, n = _sum(data)
if n < 1:
raise StatisticsError('mean requires at least one data point')
- T, total, count = _sum(data)
- assert count == n
return _convert(total / n, T)
@@ -776,41 +809,6 @@ def quantiles(data, *, n=4, method='exclusive'):
# See http://mathworld.wolfram.com/Variance.html
# http://mathworld.wolfram.com/SampleVariance.html
-# http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
-#
-# Under no circumstances use the so-called "computational formula for
-# variance", as that is only suitable for hand calculations with a small
-# amount of low-precision data. It has terrible numeric properties.
-#
-# See a comparison of three computational methods here:
-# http://www.johndcook.com/blog/2008/09/26/comparing-three-methods-of-computing-standard-deviation/
-
-def _ss(data, c=None):
- """Return sum of square deviations of sequence data.
-
- If ``c`` is None, the mean is calculated in one pass, and the deviations
- from the mean are calculated in a second pass. Otherwise, deviations are
- calculated from ``c`` as given. Use the second case with care, as it can
- lead to garbage results.
- """
- if c is not None:
- T, total, count = _sum((d := x - c) * d for x in data)
- return (T, total)
- T, total, count = _sum(data)
- mean_n, mean_d = (total / count).as_integer_ratio()
- partials = Counter()
- for n, d in map(_exact_ratio, data):
- diff_n = n * mean_d - d * mean_n
- diff_d = d * mean_d
- partials[diff_d * diff_d] += diff_n * diff_n
- if None in partials:
- # The sum will be a NAN or INF. We can ignore all the finite
- # partials, and just look at this special one.
- total = partials[None]
- assert not _isfinite(total)
- else:
- total = sum(Fraction(n, d) for d, n in partials.items())
- return (T, total)
def variance(data, xbar=None):
@@ -851,12 +849,9 @@ def variance(data, xbar=None):
Fraction(67, 108)
"""
- if iter(data) is data:
- data = list(data)
- n = len(data)
+ T, ss, n = _ss(data, xbar)
if n < 2:
raise StatisticsError('variance requires at least two data points')
- T, ss = _ss(data, xbar)
return _convert(ss / (n - 1), T)
@@ -895,12 +890,9 @@ def pvariance(data, mu=None):
Fraction(13, 72)
"""
- if iter(data) is data:
- data = list(data)
- n = len(data)
+ T, ss, n = _ss(data, mu)
if n < 1:
raise StatisticsError('pvariance requires at least one data point')
- T, ss = _ss(data, mu)
return _convert(ss / n, T)
@@ -913,12 +905,9 @@ def stdev(data, xbar=None):
1.0810874155219827
"""
- if iter(data) is data:
- data = list(data)
- n = len(data)
+ T, ss, n = _ss(data, xbar)
if n < 2:
raise StatisticsError('stdev requires at least two data points')
- T, ss = _ss(data, xbar)
mss = ss / (n - 1)
if issubclass(T, Decimal):
return _decimal_sqrt_of_frac(mss.numerator, mss.denominator)
@@ -934,12 +923,9 @@ def pstdev(data, mu=None):
0.986893273527251
"""
- if iter(data) is data:
- data = list(data)
- n = len(data)
+ T, ss, n = _ss(data, mu)
if n < 1:
raise StatisticsError('pstdev requires at least one data point')
- T, ss = _ss(data, mu)
mss = ss / n
if issubclass(T, Decimal):
return _decimal_sqrt_of_frac(mss.numerator, mss.denominator)