diff options
author | Raymond Hettinger <rhettinger@users.noreply.github.com> | 2021-05-21 03:22:26 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-05-21 03:22:26 (GMT) |
commit | be4dd7fcd93ed29d362c4bbcc48151bc619d6595 (patch) | |
tree | fca75e6315657f7d7fc8ad1355a31e774e1ee4bf /Lib | |
parent | 18f41c04ff4161531f4d08631059fd3ed37c0218 (diff) | |
download | cpython-be4dd7fcd93ed29d362c4bbcc48151bc619d6595.zip cpython-be4dd7fcd93ed29d362c4bbcc48151bc619d6595.tar.gz cpython-be4dd7fcd93ed29d362c4bbcc48151bc619d6595.tar.bz2 |
bpo-44150: Support optional weights parameter for fmean() (GH-26175)
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/statistics.py | 25 | ||||
-rw-r--r-- | Lib/test/test_statistics.py | 21 |
2 files changed, 39 insertions, 7 deletions
```diff
diff --git a/Lib/statistics.py b/Lib/statistics.py
index 5d38f85..bd3813c 100644
--- a/Lib/statistics.py
+++ b/Lib/statistics.py
@@ -136,7 +136,7 @@ from decimal import Decimal
 from itertools import groupby, repeat
 from bisect import bisect_left, bisect_right
 from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum
-from operator import itemgetter
+from operator import itemgetter, mul
 from collections import Counter, namedtuple
 
 # === Exceptions ===
@@ -345,7 +345,7 @@ def mean(data):
     return _convert(total / n, T)
 
 
-def fmean(data):
+def fmean(data, weights=None):
     """Convert data to floats and compute the arithmetic mean.
 
     This runs faster than the mean() function and it always returns a float.
@@ -363,13 +363,24 @@ def fmean(data):
             nonlocal n
             for n, x in enumerate(iterable, start=1):
                 yield x
-        total = fsum(count(data))
-    else:
+        data = count(data)
+    if weights is None:
         total = fsum(data)
-    try:
+        if not n:
+            raise StatisticsError('fmean requires at least one data point')
         return total / n
-    except ZeroDivisionError:
-        raise StatisticsError('fmean requires at least one data point') from None
+    try:
+        num_weights = len(weights)
+    except TypeError:
+        weights = list(weights)
+        num_weights = len(weights)
+    num = fsum(map(mul, data, weights))
+    if n != num_weights:
+        raise StatisticsError('data and weights must be the same length')
+    den = fsum(weights)
+    if not den:
+        raise StatisticsError('sum of weights must be non-zero')
+    return num / den
 
 
 def geometric_mean(data):
diff --git a/Lib/test/test_statistics.py b/Lib/test/test_statistics.py
index 70d269d..3e6e17a 100644
--- a/Lib/test/test_statistics.py
+++ b/Lib/test/test_statistics.py
@@ -1972,6 +1972,27 @@ class TestFMean(unittest.TestCase):
         with self.assertRaises(ValueError):
             fmean([Inf, -Inf])
 
+    def test_weights(self):
+        fmean = statistics.fmean
+        StatisticsError = statistics.StatisticsError
+        self.assertEqual(
+            fmean([10, 10, 10, 50], [0.25] * 4),
+            fmean([10, 10, 10, 50]))
+        self.assertEqual(
+            fmean([10, 10, 20], [0.25, 0.25, 0.50]),
+            fmean([10, 10, 20, 20]))
+        self.assertEqual(                           # inputs are iterators
+            fmean(iter([10, 10, 20]), iter([0.25, 0.25, 0.50])),
+            fmean([10, 10, 20, 20]))
+        with self.assertRaises(StatisticsError):
+            fmean([10, 20, 30], [1, 2])             # unequal lengths
+        with self.assertRaises(StatisticsError):
+            fmean(iter([10, 20, 30]), iter([1, 2])) # unequal lengths
+        with self.assertRaises(StatisticsError):
+            fmean([10, 20], [-1, 1])                # sum of weights is zero
+        with self.assertRaises(StatisticsError):
+            fmean(iter([10, 20]), iter([-1, 1]))    # sum of weights is zero
+
 
 # === Tests for variances and standard deviations ===
 
```