diff options
author | Steven D'Aprano <steve@pearwood.info> | 2016-08-09 03:58:10 (GMT) |
---|---|---|
committer | Steven D'Aprano <steve@pearwood.info> | 2016-08-09 03:58:10 (GMT) |
commit | 9a2be91c6b974c82d61ce7160f36d0b293e66a2b (patch) | |
tree | bafafcd919fe0cad84648e44acd805514d518726 | |
parent | e7fef52f9865b29785be04d1dc7d9b0b52f3c085 (diff) | |
download | cpython-9a2be91c6b974c82d61ce7160f36d0b293e66a2b.zip cpython-9a2be91c6b974c82d61ce7160f36d0b293e66a2b.tar.gz cpython-9a2be91c6b974c82d61ce7160f36d0b293e66a2b.tar.bz2 |
Issue27181 add geometric mean.
-rw-r--r-- | Lib/statistics.py | 267 | ||||
-rw-r--r-- | Lib/test/test_statistics.py | 285 |
2 files changed, 552 insertions, 0 deletions
diff --git a/Lib/statistics.py b/Lib/statistics.py index 8c41dd3..f4b49b5 100644 --- a/Lib/statistics.py +++ b/Lib/statistics.py @@ -303,6 +303,230 @@ def _fail_neg(values, errmsg='negative value'): yield x +class _nroot_NS: + """Hands off! Don't touch! + + Everything inside this namespace (class) is an even-more-private + implementation detail of the private _nth_root function. + """ + # This class exists only to be used as a namespace, for convenience + # of being able to keep the related functions together, and to + # collapse the group in an editor. If this were C# or C++, I would + # use a Namespace, but the closest Python has is a class. + # + # FIXME possibly move this out into a separate module? + # That feels like overkill, and may encourage people to treat it as + # a public feature. + def __init__(self): + raise TypeError('namespace only, do not instantiate') + + def nth_root(x, n): + """Return the positive nth root of numeric x. + + This may be more accurate than ** or pow(): + + >>> math.pow(1000, 1.0/3) #doctest:+SKIP + 9.999999999999998 + + >>> _nth_root(1000, 3) + 10.0 + >>> _nth_root(11**5, 5) + 11.0 + >>> _nth_root(2, 12) + 1.0594630943592953 + + """ + if not isinstance(n, int): + raise TypeError('degree n must be an int') + if n < 2: + raise ValueError('degree n must be 2 or more') + if isinstance(x, decimal.Decimal): + return _nroot_NS.decimal_nroot(x, n) + elif isinstance(x, numbers.Real): + return _nroot_NS.float_nroot(x, n) + else: + raise TypeError('expected a number, got %s') % type(x).__name__ + + def float_nroot(x, n): + """Handle nth root of Reals, treated as a float.""" + assert isinstance(n, int) and n > 1 + if x < 0: + if n%2 == 0: + raise ValueError('domain error: even root of negative number') + else: + return -_nroot_NS.nroot(-x, n) + elif x == 0: + return math.copysign(0.0, x) + elif x > 0: + try: + isinfinity = math.isinf(x) + except OverflowError: + return _nroot_NS.bignum_nroot(x, n) + else: + if isinfinity: + return float('inf') + else: + return _nroot_NS.nroot(x, n) + else: + assert math.isnan(x) + return float('nan') + + def nroot(x, n): + """Calculate x**(1/n), then improve the answer.""" + # This uses math.pow() to calculate an initial guess for the root, + # then uses the iterated nroot algorithm to improve it. + # + # By my testing, about 8% of the time the iterated algorithm ends + # up converging to a result which is less accurate than the initial + # guess. [FIXME: is this still true?] In that case, we use the + # guess instead of the "improved" value. This way, we're never + # less accurate than math.pow(). + r1 = math.pow(x, 1.0/n) + eps1 = abs(r1**n - x) + if eps1 == 0.0: + # r1 is the exact root, so we're done. By my testing, this + # occurs about 80% of the time for x < 1 and 30% of the + # time for x > 1. + return r1 + else: + try: + r2 = _nroot_NS.iterated_nroot(x, n, r1) + except RuntimeError: + return r1 + else: + eps2 = abs(r2**n - x) + if eps1 < eps2: + return r1 + return r2 + + def iterated_nroot(a, n, g): + """Return the nth root of a, starting with guess g. + + This is a special case of Newton's Method. + https://en.wikipedia.org/wiki/Nth_root_algorithm + """ + np = n - 1 + def iterate(r): + try: + return (np*r + a/math.pow(r, np))/n + except OverflowError: + # If r is large enough, r**np may overflow. If that + # happens, r**-np will be small, but not necessarily zero. + return (np*r + a*math.pow(r, -np))/n + # With a good guess, such as g = a**(1/n), this will converge in + # only a few iterations. However a poor guess can take thousands + # of iterations to converge, if at all. We guard against poor + # guesses by setting an upper limit to the number of iterations. + r1 = g + r2 = iterate(g) + for i in range(1000): + if r1 == r2: + break + # Use Floyd's cycle-finding algorithm to avoid being trapped + # in a cycle. + # https://en.wikipedia.org/wiki/Cycle_detection#Tortoise_and_hare + r1 = iterate(r1) + r2 = iterate(iterate(r2)) + else: + # If the guess is particularly bad, the above may fail to + # converge in any reasonable time. + raise RuntimeError('nth-root failed to converge') + return r2 + + def decimal_nroot(x, n): + """Handle nth root of Decimals.""" + assert isinstance(x, decimal.Decimal) + assert isinstance(n, int) + if x.is_snan(): + # Signalling NANs always raise. + raise decimal.InvalidOperation('nth-root of snan') + if x.is_qnan(): + # Quiet NANs only raise if the context is set to raise, + # otherwise return a NAN. + ctx = decimal.getcontext() + if ctx.traps[decimal.InvalidOperation]: + raise decimal.InvalidOperation('nth-root of nan') + else: + # Preserve the input NAN. + return x + if x.is_infinite(): + return x + # FIXME this hasn't had the extensive testing of the float + # version _iterated_nroot so there's possibly some buggy + # corner cases buried in here. Can it overflow? Fail to + # converge or get trapped in a cycle? Converge to a less + # accurate root? + np = n - 1 + def iterate(r): + return (np*r + x/r**np)/n + r0 = x**(decimal.Decimal(1)/n) + assert isinstance(r0, decimal.Decimal) + r1 = iterate(r0) + while True: + if r1 == r0: + return r1 + r0, r1 = r1, iterate(r1) + + def bignum_nroot(x, n): + """Return the nth root of a positive huge number.""" + assert x > 0 + # I state without proof that ⁿ√x ≈ ⁿ√2·ⁿ√(x//2) + # and that for sufficiently big x the error is acceptible. + # We now halve x until it is small enough to get the root. + m = 0 + while True: + x //= 2 + m += 1 + try: + y = float(x) + except OverflowError: + continue + break + a = _nroot_NS.nroot(y, n) + # At this point, we want the nth-root of 2**m, or 2**(m/n). + # We can write that as 2**(q + r/n) = 2**q * ⁿ√2**r where q = m//n. + q, r = divmod(m, n) + b = 2**q * _nroot_NS.nroot(2**r, n) + return a * b + + +# This is the (private) function for calculating nth roots: +_nth_root = _nroot_NS.nth_root +assert type(_nth_root) is type(lambda: None) + + +def _product(values): + """Return product of values as (exponent, mantissa).""" + errmsg = 'mixed Decimal and float is not supported' + prod = 1 + for x in values: + if isinstance(x, float): + break + prod *= x + else: + return (0, prod) + if isinstance(prod, Decimal): + raise TypeError(errmsg) + # Since floats can overflow easily, we calculate the product as a + # sort of poor-man's BigFloat. Given that: + # + # x = 2**p * m # p == power or exponent (scale), m = mantissa + # + # we can calculate the product of two (or more) x values as: + # + # x1*x2 = 2**p1*m1 * 2**p2*m2 = 2**(p1+p2)*(m1*m2) + # + mant, scale = 1, 0 #math.frexp(prod) # FIXME + for y in chain([x], values): + if isinstance(y, Decimal): + raise TypeError(errmsg) + m1, e1 = math.frexp(y) + m2, e2 = math.frexp(mant) + scale += (e1 + e2) + mant = m1*m2 + return (scale, mant) + + # === Measures of central tendency (averages) === def mean(data): @@ -331,6 +555,49 @@ def mean(data): return _convert(total/n, T) +def geometric_mean(data): + """Return the geometric mean of data. + + The geometric mean is appropriate when averaging quantities which + are multiplied together rather than added, for example growth rates. + Suppose an investment grows by 10% in the first year, falls by 5% in + the second, then grows by 12% in the third, what is the average rate + of growth over the three years? + + >>> geometric_mean([1.10, 0.95, 1.12]) + 1.0538483123382172 + + giving an average growth of 5.385%. Using the arithmetic mean will + give approximately 5.667%, which is too high. + + ``StatisticsError`` will be raised if ``data`` is empty, or any + element is less than zero. + """ + if iter(data) is data: + data = list(data) + errmsg = 'geometric mean does not support negative values' + n = len(data) + if n < 1: + raise StatisticsError('geometric_mean requires at least one data point') + elif n == 1: + x = data[0] + if isinstance(g, (numbers.Real, Decimal)): + if x < 0: + raise StatisticsError(errmsg) + return x + else: + raise TypeError('unsupported type') + else: + scale, prod = _product(_fail_neg(data, errmsg)) + r = _nth_root(prod, n) + if scale: + p, q = divmod(scale, n) + s = 2**p * _nth_root(2**q, n) + else: + s = 1 + return s*r + + def harmonic_mean(data): """Return the harmonic mean of data. diff --git a/Lib/test/test_statistics.py b/Lib/test/test_statistics.py index 4b3fd36..8b0c01f 100644 --- a/Lib/test/test_statistics.py +++ b/Lib/test/test_statistics.py @@ -1010,6 +1010,291 @@ class FailNegTest(unittest.TestCase): self.assertEqual(errmsg, msg) +class Test_Product(NumericTestCase): + """Test the private _product function.""" + + def test_ints(self): + data = [1, 2, 5, 7, 9] + self.assertEqual(statistics._product(data), (0, 630)) + self.assertEqual(statistics._product(data*100), (0, 630**100)) + + def test_floats(self): + data = [1.0, 2.0, 4.0, 8.0] + self.assertEqual(statistics._product(data), (8, 0.25)) + + def test_overflow(self): + # Test with floats that overflow. + data = [1e300]*5 + self.assertEqual(statistics._product(data), (5980, 0.6928287951283193)) + + def test_fractions(self): + F = Fraction + data = [F(14, 23), F(69, 1), F(665, 529), F(299, 105), F(1683, 39)] + exp, mant = statistics._product(data) + self.assertEqual(exp, 0) + self.assertEqual(mant, F(2*3*7*11*17*19, 23)) + self.assertTrue(isinstance(mant, F)) + # Mixed Fraction and int. + data = [3, 25, F(2, 15)] + exp, mant = statistics._product(data) + self.assertEqual(exp, 0) + self.assertEqual(mant, F(10)) + self.assertTrue(isinstance(mant, F)) + + @unittest.expectedFailure + def test_decimal(self): + D = Decimal + data = [D('24.5'), D('17.6'), D('0.025'), D('1.3')] + assert False + + def test_mixed_decimal_float(self): + # Test that mixed Decimal and float raises. + self.assertRaises(TypeError, statistics._product, [1.0, Decimal(1)]) + self.assertRaises(TypeError, statistics._product, [Decimal(1), 1.0]) + + +class Test_Nth_Root(NumericTestCase): + """Test the functionality of the private _nth_root function.""" + + def setUp(self): + self.nroot = statistics._nth_root + + # --- Special values (infinities, NANs, zeroes) --- + + def test_float_NAN(self): + # Test that the root of a float NAN is a float NAN. + NAN = float('nan') + for n in range(2, 9): + with self.subTest(n=n): + result = self.nroot(NAN, n) + self.assertTrue(math.isnan(result)) + + def test_decimal_QNAN(self): + # Test the behaviour when taking the root of a Decimal quiet NAN. + NAN = decimal.Decimal('nan') + with decimal.localcontext() as ctx: + ctx.traps[decimal.InvalidOperation] = 1 + self.assertRaises(decimal.InvalidOperation, self.nroot, NAN, 5) + ctx.traps[decimal.InvalidOperation] = 0 + self.assertTrue(self.nroot(NAN, 5).is_qnan()) + + def test_decimal_SNAN(self): + # Test that taking the root of a Decimal sNAN always raises. + sNAN = decimal.Decimal('snan') + with decimal.localcontext() as ctx: + ctx.traps[decimal.InvalidOperation] = 1 + self.assertRaises(decimal.InvalidOperation, self.nroot, sNAN, 5) + ctx.traps[decimal.InvalidOperation] = 0 + self.assertRaises(decimal.InvalidOperation, self.nroot, sNAN, 5) + + def test_inf(self): + # Test that the root of infinity is infinity. + for INF in (float('inf'), decimal.Decimal('inf')): + for n in range(2, 9): + with self.subTest(n=n, inf=INF): + self.assertEqual(self.nroot(INF, n), INF) + + def testNInf(self): + # Test that the root of -inf is -inf for odd n. + for NINF in (float('-inf'), decimal.Decimal('-inf')): + for n in range(3, 11, 2): + with self.subTest(n=n, inf=NINF): + self.assertEqual(self.nroot(NINF, n), NINF) + + # FIXME: need to check Decimal zeroes too. + def test_zero(self): + # Test that the root of +0.0 is +0.0. + for n in range(2, 11): + with self.subTest(n=n): + result = self.nroot(+0.0, n) + self.assertEqual(result, 0.0) + self.assertEqual(sign(result), +1) + + # FIXME: need to check Decimal zeroes too. + def test_neg_zero(self): + # Test that the root of -0.0 is -0.0. + for n in range(2, 11): + with self.subTest(n=n): + result = self.nroot(-0.0, n) + self.assertEqual(result, 0.0) + self.assertEqual(sign(result), -1) + + # --- Test return types --- + + def check_result_type(self, x, n, outtype): + self.assertIsInstance(self.nroot(x, n), outtype) + class MySubclass(type(x)): + pass + self.assertIsInstance(self.nroot(MySubclass(x), n), outtype) + + def testDecimal(self): + # Test that Decimal arguments return Decimal results. + self.check_result_type(decimal.Decimal('33.3'), 3, decimal.Decimal) + + def testFloat(self): + # Test that other arguments return float results. + for x in (0.2, Fraction(11, 7), 91): + self.check_result_type(x, 6, float) + + # --- Test bad input --- + + def testBadOrderTypes(self): + # Test that nroot raises correctly when n has the wrong type. + for n in (5.0, 2j, None, 'x', b'x', [], {}, set(), sign): + with self.subTest(n=n): + self.assertRaises(TypeError, self.nroot, 2.5, n) + + def testBadOrderValues(self): + # Test that nroot raises correctly when n has a wrong value. + for n in (1, 0, -1, -2, -87): + with self.subTest(n=n): + self.assertRaises(ValueError, self.nroot, 2.5, n) + + def testBadTypes(self): + # Test that nroot raises correctly when x has the wrong type. + for x in (None, 'x', b'x', [], {}, set(), sign): + with self.subTest(x=x): + self.assertRaises(TypeError, self.nroot, x, 3) + + def testNegativeEvenPower(self): + # Test negative x with even n raises correctly. + x = random.uniform(-20.0, -0.1) + assert x < 0 + for n in range(2, 9, 2): + with self.subTest(x=x, n=n): + self.assertRaises(ValueError, self.nroot, x, n) + + # --- Test that nroot is never worse than calling math.pow() --- + + def check_error_is_no_worse(self, x, n): + y = math.pow(x, n) + with self.subTest(x=x, n=n, y=y): + err1 = abs(self.nroot(y, n) - x) + err2 = abs(math.pow(y, 1.0/n) - x) + self.assertLessEqual(err1, err2) + + def testCompareWithPowSmall(self): + # Compare nroot with pow for small values of x. + for i in range(200): + x = random.uniform(1e-9, 1.0-1e-9) + n = random.choice(range(2, 16)) + self.check_error_is_no_worse(x, n) + + def testCompareWithPowMedium(self): + # Compare nroot with pow for medium-sized values of x. + for i in range(200): + x = random.uniform(1.0, 100.0) + n = random.choice(range(2, 16)) + self.check_error_is_no_worse(x, n) + + def testCompareWithPowLarge(self): + # Compare nroot with pow for largish values of x. + for i in range(200): + x = random.uniform(100.0, 10000.0) + n = random.choice(range(2, 16)) + self.check_error_is_no_worse(x, n) + + def testCompareWithPowHuge(self): + # Compare nroot with pow for huge values of x. + for i in range(200): + x = random.uniform(1e20, 1e50) + # We restrict the order here to avoid an Overflow error. + n = random.choice(range(2, 7)) + self.check_error_is_no_worse(x, n) + + # --- Test for numerically correct answers --- + + def testExactPowers(self): + # Test that small integer powers are calculated exactly. + for i in range(1, 51): + for n in range(2, 16): + if (i, n) == (35, 13): + # See testExpectedFailure35p13 + continue + with self.subTest(i=i, n=n): + x = i**n + self.assertEqual(self.nroot(x, n), i) + + def testExactPowersNegatives(self): + # Test that small negative integer powers are calculated exactly. + for i in range(-1, -51, -1): + for n in range(3, 16, 2): + if (i, n) == (-35, 13): + # See testExpectedFailure35p13 + continue + with self.subTest(i=i, n=n): + x = i**n + assert sign(x) == -1 + self.assertEqual(self.nroot(x, n), i) + + def testExpectedFailure35p13(self): + # Test the expected failure 35**13 is almost exact. + x = 35**13 + err = abs(self.nroot(x, 13) - 35) + self.assertLessEqual(err, 0.000000001) + err = abs(self.nroot(-x, 13) + 35) + self.assertLessEqual(err, 0.000000001) + + def testOne(self): + # Test that the root of 1.0 is 1.0. + for n in range(2, 11): + with self.subTest(n=n): + self.assertEqual(self.nroot(1.0, n), 1.0) + + def testFraction(self): + # Test Fraction results. + x = Fraction(89, 75) + self.assertEqual(self.nroot(x**12, 12), float(x)) + + def testInt(self): + # Test int results. + x = 276 + self.assertEqual(self.nroot(x**24, 24), x) + + def testBigInt(self): + # Test that ints too big to convert to floats work. + bignum = 10**20 # That's not that big... + self.assertEqual(self.nroot(bignum**280, 280), bignum) + # Can we make it bigger? + hugenum = bignum**50 + # Make sure that it is too big to convert to a float. + try: + y = float(hugenum) + except OverflowError: + pass + else: + raise AssertionError('hugenum is not big enough') + self.assertEqual(self.nroot(hugenum, 50), float(bignum)) + + def testDecimal(self): + # Test Decimal results. + for s in '3.759 64.027 5234.338'.split(): + x = decimal.Decimal(s) + with self.subTest(x=x): + a = self.nroot(x**5, 5) + self.assertEqual(a, x) + a = self.nroot(x**17, 17) + self.assertEqual(a, x) + + def testFloat(self): + # Test float results. + for x in (3.04e-16, 18.25, 461.3, 1.9e17): + with self.subTest(x=x): + self.assertEqual(self.nroot(x**3, 3), x) + self.assertEqual(self.nroot(x**8, 8), x) + self.assertEqual(self.nroot(x**11, 11), x) + + +class Test_NthRoot_NS(unittest.TestCase): + """Test internals of the nth_root function, hidden in _nroot_NS.""" + + def test_class_cannot_be_instantiated(self): + # Test that _nroot_NS cannot be instantiated. + # It should be a namespace, like in C++ or C#, but Python + # lacks that feature and so we have to make do with a class. + self.assertRaises(TypeError, statistics._nroot_NS) + + # === Tests for public functions === class UnivariateCommonMixin: |