diff options
author | Raymond Hettinger <rhettinger@users.noreply.github.com> | 2023-03-12 17:48:25 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-03-12 17:48:25 (GMT) |
commit | 6cd7572f859a32a1f4626644c3e8139055df59e3 (patch) | |
tree | fd2e5d887c2ec6d274aaa2dfdfc54041145adc33 | |
parent | e6210621bee4ac10e18b4adc11229b8cc1ee788d (diff) | |
download | cpython-6cd7572f859a32a1f4626644c3e8139055df59e3.zip cpython-6cd7572f859a32a1f4626644c3e8139055df59e3.tar.gz cpython-6cd7572f859a32a1f4626644c3e8139055df59e3.tar.bz2 |
Optimize fmean() weighted average (#102626)
-rw-r--r-- | Lib/statistics.py | 34 |
1 file changed, 16 insertions, 18 deletions
```diff
diff --git a/Lib/statistics.py b/Lib/statistics.py
index 07d1fd5..7d5d750 100644
--- a/Lib/statistics.py
+++ b/Lib/statistics.py
@@ -136,9 +136,9 @@ from fractions import Fraction
 from decimal import Decimal
 from itertools import count, groupby, repeat
 from bisect import bisect_left, bisect_right
-from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum
+from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum, sumprod
 from functools import reduce
-from operator import mul, itemgetter
+from operator import itemgetter
 from collections import Counter, namedtuple, defaultdict
 
 _SQRT2 = sqrt(2.0)
@@ -496,28 +496,26 @@ def fmean(data, weights=None):
     >>> fmean([3.5, 4.0, 5.25])
     4.25
     """
-    try:
-        n = len(data)
-    except TypeError:
-        # Handle iterators that do not define __len__().
-        n = 0
-        def count(iterable):
-            nonlocal n
-            for n, x in enumerate(iterable, start=1):
-                yield x
-        data = count(data)
     if weights is None:
+        try:
+            n = len(data)
+        except TypeError:
+            # Handle iterators that do not define __len__().
+            n = 0
+            def count(iterable):
+                nonlocal n
+                for n, x in enumerate(iterable, start=1):
+                    yield x
+            data = count(data)
         total = fsum(data)
         if not n:
             raise StatisticsError('fmean requires at least one data point')
         return total / n
-    try:
-        num_weights = len(weights)
-    except TypeError:
+    if not isinstance(weights, (list, tuple)):
         weights = list(weights)
-        num_weights = len(weights)
-    num = fsum(map(mul, data, weights))
-    if n != num_weights:
+    try:
+        num = sumprod(data, weights)
+    except ValueError:
         raise StatisticsError('data and weights must be the same length')
     den = fsum(weights)
     if not den:
```