diff options
Diffstat (limited to 'Lib/test/test_functools.py')
-rw-r--r-- | Lib/test/test_functools.py | 71 |
1 files changed, 70 insertions, 1 deletions
diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index 8dc185b..01d6cd2 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -259,6 +259,74 @@ class TestWraps(TestUpdateWrapper): self.assertEqual(wrapper.attr, 'This is a different test') self.assertEqual(wrapper.dict_attr, f.dict_attr) +class TestReduce(unittest.TestCase): + func = functools.reduce + + def test_reduce(self): + class Squares: + def __init__(self, max): + self.max = max + self.sofar = [] + + def __len__(self): + return len(self.sofar) + + def __getitem__(self, i): + if not 0 <= i < self.max: raise IndexError + n = len(self.sofar) + while n <= i: + self.sofar.append(n*n) + n += 1 + return self.sofar[i] + + self.assertEqual(self.func(lambda x, y: x+y, ['a', 'b', 'c'], ''), 'abc') + self.assertEqual( + self.func(lambda x, y: x+y, [['a', 'c'], [], ['d', 'w']], []), + ['a','c','d','w'] + ) + self.assertEqual(self.func(lambda x, y: x*y, range(2,8), 1), 5040) + self.assertEqual( + self.func(lambda x, y: x*y, range(2,21), 1L), + 2432902008176640000L + ) + self.assertEqual(self.func(lambda x, y: x+y, Squares(10)), 285) + self.assertEqual(self.func(lambda x, y: x+y, Squares(10), 0), 285) + self.assertEqual(self.func(lambda x, y: x+y, Squares(0), 0), 0) + self.assertRaises(TypeError, self.func) + self.assertRaises(TypeError, self.func, 42, 42) + self.assertRaises(TypeError, self.func, 42, 42, 42) + self.assertEqual(self.func(42, "1"), "1") # func is never called with one item + self.assertEqual(self.func(42, "", "1"), "1") # func is never called with one item + self.assertRaises(TypeError, self.func, 42, (42, 42)) + + class BadSeq: + def __getitem__(self, index): + raise ValueError + self.assertRaises(ValueError, self.func, 42, BadSeq()) + + # Test reduce()'s use of iterators. + def test_iterator_usage(self): + class SequenceClass: + def __init__(self, n): + self.n = n + def __getitem__(self, i): + if 0 <= i < self.n: + return i + else: + raise IndexError + + from operator import add + self.assertEqual(self.func(add, SequenceClass(5)), 10) + self.assertEqual(self.func(add, SequenceClass(5), 42), 52) + self.assertRaises(TypeError, self.func, add, SequenceClass(0)) + self.assertEqual(self.func(add, SequenceClass(0), 42), 42) + self.assertEqual(self.func(add, SequenceClass(1)), 0) + self.assertEqual(self.func(add, SequenceClass(1), 42), 42) + + d = {"one": 1, "two": 2, "three": 3} + self.assertEqual(self.func(add, d), "".join(d.keys())) + + def test_main(verbose=None): @@ -268,7 +336,8 @@ def test_main(verbose=None): TestPartialSubclass, TestPythonPartial, TestUpdateWrapper, - TestWraps + TestWraps, + TestReduce ) test_support.run_unittest(*test_classes) |