summaryrefslogtreecommitdiffstats
path: root/Lib/statistics.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/statistics.py')
-rw-r--r--Lib/statistics.py418
1 files changed, 359 insertions, 59 deletions
diff --git a/Lib/statistics.py b/Lib/statistics.py
index 4f5c1c1..7d53e0c 100644
--- a/Lib/statistics.py
+++ b/Lib/statistics.py
@@ -1,20 +1,3 @@
-## Module statistics.py
-##
-## Copyright (c) 2013 Steven D'Aprano <steve+python@pearwood.info>.
-##
-## Licensed under the Apache License, Version 2.0 (the "License");
-## you may not use this file except in compliance with the License.
-## You may obtain a copy of the License at
-##
-## http://www.apache.org/licenses/LICENSE-2.0
-##
-## Unless required by applicable law or agreed to in writing, software
-## distributed under the License is distributed on an "AS IS" BASIS,
-## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-## See the License for the specific language governing permissions and
-## limitations under the License.
-
-
"""
Basic statistics module.
@@ -28,6 +11,8 @@ Calculating averages
Function Description
================== =============================================
mean Arithmetic mean (average) of data.
+geometric_mean Geometric mean of data.
+harmonic_mean Harmonic mean of data.
median Median (middle value) of data.
median_low Low median of data.
median_high High median of data.
@@ -95,16 +80,18 @@ A single exception is defined: StatisticsError is a subclass of ValueError.
__all__ = [ 'StatisticsError',
'pstdev', 'pvariance', 'stdev', 'variance',
'median', 'median_low', 'median_high', 'median_grouped',
- 'mean', 'mode',
+ 'mean', 'mode', 'geometric_mean', 'harmonic_mean',
]
-
import collections
+import decimal
import math
+import numbers
from fractions import Fraction
from decimal import Decimal
-from itertools import groupby
+from itertools import groupby, chain
+from bisect import bisect_left, bisect_right
@@ -134,7 +121,8 @@ def _sum(data, start=0):
Some sources of round-off error will be avoided:
- >>> _sum([1e50, 1, -1e50] * 1000) # Built-in sum returns zero.
+ # Built-in sum returns zero.
+ >>> _sum([1e50, 1, -1e50] * 1000)
(<class 'float'>, Fraction(1000, 1), 3000)
Fractions and Decimals are also supported:
@@ -223,56 +211,26 @@ def _exact_ratio(x):
# Optimise the common case of floats. We expect that the most often
# used numeric type will be builtin floats, so try to make this as
# fast as possible.
- if type(x) is float:
+ if type(x) is float or type(x) is Decimal:
return x.as_integer_ratio()
try:
# x may be an int, Fraction, or Integral ABC.
return (x.numerator, x.denominator)
except AttributeError:
try:
- # x may be a float subclass.
+ # x may be a float or Decimal subclass.
return x.as_integer_ratio()
except AttributeError:
- try:
- # x may be a Decimal.
- return _decimal_to_ratio(x)
- except AttributeError:
- # Just give up?
- pass
+ # Just give up?
+ pass
except (OverflowError, ValueError):
# float NAN or INF.
- assert not math.isfinite(x)
+ assert not _isfinite(x)
return (x, None)
msg = "can't convert type '{}' to numerator/denominator"
raise TypeError(msg.format(type(x).__name__))
-# FIXME This is faster than Fraction.from_decimal, but still too slow.
-def _decimal_to_ratio(d):
- """Convert Decimal d to exact integer ratio (numerator, denominator).
-
- >>> from decimal import Decimal
- >>> _decimal_to_ratio(Decimal("2.6"))
- (26, 10)
-
- """
- sign, digits, exp = d.as_tuple()
- if exp in ('F', 'n', 'N'): # INF, NAN, sNAN
- assert not d.is_finite()
- return (d, None)
- num = 0
- for digit in digits:
- num = num*10 + digit
- if exp < 0:
- den = 10**-exp
- else:
- num *= 10**exp
- den = 1
- if sign:
- num = -num
- return (num, den)
-
-
def _convert(value, T):
"""Convert value to given numeric type T."""
if type(value) is T:
@@ -305,6 +263,253 @@ def _counts(data):
return table
+def _find_lteq(a, x):
+ 'Locate the leftmost value exactly equal to x'
+ i = bisect_left(a, x)
+ if i != len(a) and a[i] == x:
+ return i
+ raise ValueError
+
+
+def _find_rteq(a, l, x):
+ 'Locate the rightmost value exactly equal to x'
+ i = bisect_right(a, x, lo=l)
+ if i != (len(a)+1) and a[i-1] == x:
+ return i-1
+ raise ValueError
+
+
+def _fail_neg(values, errmsg='negative value'):
+ """Iterate over values, failing if any are less than zero."""
+ for x in values:
+ if x < 0:
+ raise StatisticsError(errmsg)
+ yield x
+
+
+class _nroot_NS:
+ """Hands off! Don't touch!
+
+ Everything inside this namespace (class) is an even-more-private
+ implementation detail of the private _nth_root function.
+ """
+ # This class exists only to be used as a namespace, for convenience
+ # of being able to keep the related functions together, and to
+ # collapse the group in an editor. If this were C# or C++, I would
+ # use a Namespace, but the closest Python has is a class.
+ #
+ # FIXME possibly move this out into a separate module?
+ # That feels like overkill, and may encourage people to treat it as
+ # a public feature.
+ def __init__(self):
+ raise TypeError('namespace only, do not instantiate')
+
+ def nth_root(x, n):
+ """Return the positive nth root of numeric x.
+
+ This may be more accurate than ** or pow():
+
+ >>> math.pow(1000, 1.0/3) #doctest:+SKIP
+ 9.999999999999998
+
+ >>> _nth_root(1000, 3)
+ 10.0
+ >>> _nth_root(11**5, 5)
+ 11.0
+ >>> _nth_root(2, 12)
+ 1.0594630943592953
+
+ """
+ if not isinstance(n, int):
+ raise TypeError('degree n must be an int')
+ if n < 2:
+ raise ValueError('degree n must be 2 or more')
+ if isinstance(x, decimal.Decimal):
+ return _nroot_NS.decimal_nroot(x, n)
+ elif isinstance(x, numbers.Real):
+ return _nroot_NS.float_nroot(x, n)
+ else:
+ raise TypeError('expected a number, got %s') % type(x).__name__
+
+ def float_nroot(x, n):
+ """Handle nth root of Reals, treated as a float."""
+ assert isinstance(n, int) and n > 1
+ if x < 0:
+ raise ValueError('domain error: root of negative number')
+ elif x == 0:
+ return math.copysign(0.0, x)
+ elif x > 0:
+ try:
+ isinfinity = math.isinf(x)
+ except OverflowError:
+ return _nroot_NS.bignum_nroot(x, n)
+ else:
+ if isinfinity:
+ return float('inf')
+ else:
+ return _nroot_NS.nroot(x, n)
+ else:
+ assert math.isnan(x)
+ return float('nan')
+
+ def nroot(x, n):
+ """Calculate x**(1/n), then improve the answer."""
+ # This uses math.pow() to calculate an initial guess for the root,
+ # then uses the iterated nroot algorithm to improve it.
+ #
+ # By my testing, about 8% of the time the iterated algorithm ends
+ # up converging to a result which is less accurate than the initial
+ # guess. [FIXME: is this still true?] In that case, we use the
+ # guess instead of the "improved" value. This way, we're never
+ # less accurate than math.pow().
+ r1 = math.pow(x, 1.0/n)
+ eps1 = abs(r1**n - x)
+ if eps1 == 0.0:
+ # r1 is the exact root, so we're done. By my testing, this
+ # occurs about 80% of the time for x < 1 and 30% of the
+ # time for x > 1.
+ return r1
+ else:
+ try:
+ r2 = _nroot_NS.iterated_nroot(x, n, r1)
+ except RuntimeError:
+ return r1
+ else:
+ eps2 = abs(r2**n - x)
+ if eps1 < eps2:
+ return r1
+ return r2
+
+ def iterated_nroot(a, n, g):
+ """Return the nth root of a, starting with guess g.
+
+ This is a special case of Newton's Method.
+ https://en.wikipedia.org/wiki/Nth_root_algorithm
+ """
+ np = n - 1
+ def iterate(r):
+ try:
+ return (np*r + a/math.pow(r, np))/n
+ except OverflowError:
+ # If r is large enough, r**np may overflow. If that
+ # happens, r**-np will be small, but not necessarily zero.
+ return (np*r + a*math.pow(r, -np))/n
+ # With a good guess, such as g = a**(1/n), this will converge in
+ # only a few iterations. However a poor guess can take thousands
+ # of iterations to converge, if at all. We guard against poor
+ # guesses by setting an upper limit to the number of iterations.
+ r1 = g
+ r2 = iterate(g)
+ for i in range(1000):
+ if r1 == r2:
+ break
+ # Use Floyd's cycle-finding algorithm to avoid being trapped
+ # in a cycle.
+ # https://en.wikipedia.org/wiki/Cycle_detection#Tortoise_and_hare
+ r1 = iterate(r1)
+ r2 = iterate(iterate(r2))
+ else:
+ # If the guess is particularly bad, the above may fail to
+ # converge in any reasonable time.
+ raise RuntimeError('nth-root failed to converge')
+ return r2
+
+ def decimal_nroot(x, n):
+ """Handle nth root of Decimals."""
+ assert isinstance(x, decimal.Decimal)
+ assert isinstance(n, int)
+ if x.is_snan():
+ # Signalling NANs always raise.
+ raise decimal.InvalidOperation('nth-root of snan')
+ if x.is_qnan():
+ # Quiet NANs only raise if the context is set to raise,
+ # otherwise return a NAN.
+ ctx = decimal.getcontext()
+ if ctx.traps[decimal.InvalidOperation]:
+ raise decimal.InvalidOperation('nth-root of nan')
+ else:
+ # Preserve the input NAN.
+ return x
+ if x < 0:
+ raise ValueError('domain error: root of negative number')
+ if x.is_infinite():
+ return x
+ # FIXME this hasn't had the extensive testing of the float
+ # version _iterated_nroot so there's possibly some buggy
+ # corner cases buried in here. Can it overflow? Fail to
+ # converge or get trapped in a cycle? Converge to a less
+ # accurate root?
+ np = n - 1
+ def iterate(r):
+ return (np*r + x/r**np)/n
+ r0 = x**(decimal.Decimal(1)/n)
+ assert isinstance(r0, decimal.Decimal)
+ r1 = iterate(r0)
+ while True:
+ if r1 == r0:
+ return r1
+ r0, r1 = r1, iterate(r1)
+
+ def bignum_nroot(x, n):
+ """Return the nth root of a positive huge number."""
+ assert x > 0
+ # I state without proof that ⁿ√x ≈ ⁿ√2·ⁿ√(x//2)
+ # and that for sufficiently big x the error is acceptable.
+ # We now halve x until it is small enough to get the root.
+ m = 0
+ while True:
+ x //= 2
+ m += 1
+ try:
+ y = float(x)
+ except OverflowError:
+ continue
+ break
+ a = _nroot_NS.nroot(y, n)
+ # At this point, we want the nth-root of 2**m, or 2**(m/n).
+ # We can write that as 2**(q + r/n) = 2**q * ⁿ√2**r where q = m//n.
+ q, r = divmod(m, n)
+ b = 2**q * _nroot_NS.nroot(2**r, n)
+ return a * b
+
+
+# This is the (private) function for calculating nth roots:
+_nth_root = _nroot_NS.nth_root
+assert type(_nth_root) is type(lambda: None)
+
+
+def _product(values):
+ """Return product of values as (exponent, mantissa)."""
+ errmsg = 'mixed Decimal and float is not supported'
+ prod = 1
+ for x in values:
+ if isinstance(x, float):
+ break
+ prod *= x
+ else:
+ return (0, prod)
+ if isinstance(prod, Decimal):
+ raise TypeError(errmsg)
+ # Since floats can overflow easily, we calculate the product as a
+ # sort of poor-man's BigFloat. Given that:
+ #
+ # x = 2**p * m # p == power or exponent (scale), m = mantissa
+ #
+ # we can calculate the product of two (or more) x values as:
+ #
+ # x1*x2 = 2**p1*m1 * 2**p2*m2 = 2**(p1+p2)*(m1*m2)
+ #
+ mant, scale = 1, 0 #math.frexp(prod) # FIXME
+ for y in chain([x], values):
+ if isinstance(y, Decimal):
+ raise TypeError(errmsg)
+ m1, e1 = math.frexp(y)
+ m2, e2 = math.frexp(mant)
+ scale += (e1 + e2)
+ mant = m1*m2
+ return (scale, mant)
+
+
# === Measures of central tendency (averages) ===
def mean(data):
@@ -333,6 +538,95 @@ def mean(data):
return _convert(total/n, T)
+def geometric_mean(data):
+ """Return the geometric mean of data.
+
+ The geometric mean is appropriate when averaging quantities which
+ are multiplied together rather than added, for example growth rates.
+ Suppose an investment grows by 10% in the first year, falls by 5% in
+ the second, then grows by 12% in the third, what is the average rate
+ of growth over the three years?
+
+ >>> geometric_mean([1.10, 0.95, 1.12])
+ 1.0538483123382172
+
+ giving an average growth of 5.385%. Using the arithmetic mean will
+ give approximately 5.667%, which is too high.
+
+ ``StatisticsError`` will be raised if ``data`` is empty, or any
+ element is less than zero.
+ """
+ if iter(data) is data:
+ data = list(data)
+ errmsg = 'geometric mean does not support negative values'
+ n = len(data)
+ if n < 1:
+ raise StatisticsError('geometric_mean requires at least one data point')
+ elif n == 1:
+ x = data[0]
+ if isinstance(g, (numbers.Real, Decimal)):
+ if x < 0:
+ raise StatisticsError(errmsg)
+ return x
+ else:
+ raise TypeError('unsupported type')
+ else:
+ scale, prod = _product(_fail_neg(data, errmsg))
+ r = _nth_root(prod, n)
+ if scale:
+ p, q = divmod(scale, n)
+ s = 2**p * _nth_root(2**q, n)
+ else:
+ s = 1
+ return s*r
+
+
+def harmonic_mean(data):
+ """Return the harmonic mean of data.
+
+ The harmonic mean, sometimes called the subcontrary mean, is the
+ reciprocal of the arithmetic mean of the reciprocals of the data,
+ and is often appropriate when averaging quantities which are rates
+ or ratios, for example speeds. Example:
+
+ Suppose an investor purchases an equal value of shares in each of
+ three companies, with P/E (price/earning) ratios of 2.5, 3 and 10.
+ What is the average P/E ratio for the investor's portfolio?
+
+ >>> harmonic_mean([2.5, 3, 10]) # For an equal investment portfolio.
+ 3.6
+
+ Using the arithmetic mean would give an average of about 5.167, which
+ is too high.
+
+ If ``data`` is empty, or any element is less than zero,
+ ``harmonic_mean`` will raise ``StatisticsError``.
+ """
+ # For a justification for using harmonic mean for P/E ratios, see
+ # http://fixthepitch.pellucid.com/comps-analysis-the-missing-harmony-of-summary-statistics/
+ # http://papers.ssrn.com/sol3/papers.cfm?abstract_id=2621087
+ if iter(data) is data:
+ data = list(data)
+ errmsg = 'harmonic mean does not support negative values'
+ n = len(data)
+ if n < 1:
+ raise StatisticsError('harmonic_mean requires at least one data point')
+ elif n == 1:
+ x = data[0]
+ if isinstance(x, (numbers.Real, Decimal)):
+ if x < 0:
+ raise StatisticsError(errmsg)
+ return x
+ else:
+ raise TypeError('unsupported type')
+ try:
+ T, total, count = _sum(1/x for x in _fail_neg(data, errmsg))
+ except ZeroDivisionError:
+ return 0
+ assert count == n
+ return _convert(n/total, T)
+
+
# FIXME: investigate ways to calculate medians without sorting? Quickselect?
def median(data):
"""Return the median (middle value) of numeric data.
@@ -442,9 +736,15 @@ def median_grouped(data, interval=1):
except TypeError:
# Mixed type. For now we just coerce to float.
L = float(x) - float(interval)/2
- cf = data.index(x) # Number of values below the median interval.
- # FIXME The following line could be more efficient for big lists.
- f = data.count(x) # Number of data points in the median interval.
+
+ # Uses bisection search to search for x in data with log(n) time complexity
+ # Find the position of leftmost occurrence of x in data
+ l1 = _find_lteq(data, x)
+ # Find the position of rightmost occurrence of x in data[l1...len(data)]
+ # Assuming always l1 <= l2
+ l2 = _find_rteq(data, l1, x)
+ cf = l1
+ f = l2 - l1 + 1
return L + interval*(n/2 - cf)/f