summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorRaymond Hettinger <rhettinger@users.noreply.github.com>2023-03-12 17:48:25 (GMT)
committerGitHub <noreply@github.com>2023-03-12 17:48:25 (GMT)
commit6cd7572f859a32a1f4626644c3e8139055df59e3 (patch)
treefd2e5d887c2ec6d274aaa2dfdfc54041145adc33
parente6210621bee4ac10e18b4adc11229b8cc1ee788d (diff)
downloadcpython-6cd7572f859a32a1f4626644c3e8139055df59e3.zip
cpython-6cd7572f859a32a1f4626644c3e8139055df59e3.tar.gz
cpython-6cd7572f859a32a1f4626644c3e8139055df59e3.tar.bz2
Optimize fmean() weighted average (#102626)
-rw-r--r--Lib/statistics.py34
1 files changed, 16 insertions, 18 deletions
diff --git a/Lib/statistics.py b/Lib/statistics.py
index 07d1fd5..7d5d750 100644
--- a/Lib/statistics.py
+++ b/Lib/statistics.py
@@ -136,9 +136,9 @@ from fractions import Fraction
from decimal import Decimal
from itertools import count, groupby, repeat
from bisect import bisect_left, bisect_right
-from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum
+from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum, sumprod
from functools import reduce
-from operator import mul, itemgetter
+from operator import itemgetter
from collections import Counter, namedtuple, defaultdict
_SQRT2 = sqrt(2.0)
@@ -496,28 +496,26 @@ def fmean(data, weights=None):
>>> fmean([3.5, 4.0, 5.25])
4.25
"""
- try:
- n = len(data)
- except TypeError:
- # Handle iterators that do not define __len__().
- n = 0
- def count(iterable):
- nonlocal n
- for n, x in enumerate(iterable, start=1):
- yield x
- data = count(data)
if weights is None:
+ try:
+ n = len(data)
+ except TypeError:
+ # Handle iterators that do not define __len__().
+ n = 0
+ def count(iterable):
+ nonlocal n
+ for n, x in enumerate(iterable, start=1):
+ yield x
+ data = count(data)
total = fsum(data)
if not n:
raise StatisticsError('fmean requires at least one data point')
return total / n
- try:
- num_weights = len(weights)
- except TypeError:
+ if not isinstance(weights, (list, tuple)):
weights = list(weights)
- num_weights = len(weights)
- num = fsum(map(mul, data, weights))
- if n != num_weights:
+ try:
+ num = sumprod(data, weights)
+ except ValueError:
raise StatisticsError('data and weights must be the same length')
den = fsum(weights)
if not den: