diff options
-rw-r--r-- | Lib/test/test_itertools.py | 29 | ||||
-rw-r--r-- | Modules/itertoolsmodule.c | 14 |
2 files changed, 38 insertions, 5 deletions
diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py index 6ae6785..02f84b7 100644 --- a/Lib/test/test_itertools.py +++ b/Lib/test/test_itertools.py @@ -53,6 +53,10 @@ class TestBasicOps(unittest.TestCase): self.assertRaises(TypeError, count, 'a') c = count(sys.maxint-2) # verify that rollover doesn't crash c.next(); c.next(); c.next(); c.next(); c.next() + c = count(3) + self.assertEqual(repr(c), 'count(3)') + c.next() + self.assertEqual(repr(c), 'count(4)') def test_cycle(self): self.assertEqual(take(10, cycle('abc')), list('abcabcabca')) @@ -67,6 +71,7 @@ class TestBasicOps(unittest.TestCase): self.assertEqual([], list(groupby([], key=id))) self.assertRaises(TypeError, list, groupby('abc', [])) self.assertRaises(TypeError, groupby, None) + self.assertRaises(TypeError, groupby, 'abc', lambda x:x, 10) # Check normal input s = [(0, 10, 20), (0, 11,21), (0,12,21), (1,13,21), (1,14,22), @@ -199,6 +204,12 @@ class TestBasicOps(unittest.TestCase): self.assertRaises(TypeError, repeat) self.assertRaises(TypeError, repeat, None, 3, 4) self.assertRaises(TypeError, repeat, None, 'a') + r = repeat(1+0j) + self.assertEqual(repr(r), 'repeat((1+0j))') + r = repeat(1+0j, 5) + self.assertEqual(repr(r), 'repeat((1+0j), 5)') + list(r) + self.assertEqual(repr(r), 'repeat((1+0j), 0)') def test_imap(self): self.assertEqual(list(imap(operator.pow, range(3), range(1,7))), @@ -275,6 +286,9 @@ class TestBasicOps(unittest.TestCase): self.assertRaises(TypeError, takewhile, operator.pow, [(4,5)], 'extra') self.assertRaises(TypeError, takewhile(10, [(4,5)]).next) self.assertRaises(ValueError, takewhile(errfunc, [(4,5)]).next) + t = takewhile(bool, [1, 1, 1, 0, 0, 0]) + self.assertEqual(list(t), [1, 1, 1]) + self.assertRaises(StopIteration, t.next) def test_dropwhile(self): data = [1, 3, 5, 20, 2, 4, 6, 8] @@ -347,11 +361,26 @@ class TestBasicOps(unittest.TestCase): self.assertEqual(list(a), range(100,2000)) self.assertEqual(list(c), range(2,2000)) + # test values of n + self.assertRaises(TypeError, tee, 'abc', 'invalid') + for n in xrange(5): + result = tee('abc', n) + self.assertEqual(type(result), tuple) + self.assertEqual(len(result), n) + self.assertEqual(map(list, result), [list('abc')]*n) + # tee pass-through to copyable iterator a, b = tee('abc') c, d = tee(a) self.assert_(a is c) + # test tee_new + t1, t2 = tee('abc') + tnew = type(t1) + self.assertRaises(TypeError, tnew) + self.assertRaises(TypeError, tnew, 10) + t3 = tnew(t1) + self.assert_(list(t1) == list(t2) == list(t3) == list('abc')) def test_StopIteration(self): self.assertRaises(StopIteration, izip().next) diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c index bf2c493..3da0258 100644 --- a/Modules/itertoolsmodule.c +++ b/Modules/itertoolsmodule.c @@ -75,7 +75,7 @@ groupby_traverse(groupbyobject *gbo, visitproc visit, void *arg) static PyObject * groupby_next(groupbyobject *gbo) { - PyObject *newvalue, *newkey, *r, *grouper; + PyObject *newvalue, *newkey, *r, *grouper, *tmp; /* skip to next iteration group */ for (;;) { @@ -110,15 +110,19 @@ groupby_next(groupbyobject *gbo) } } - Py_XDECREF(gbo->currkey); + tmp = gbo->currkey; gbo->currkey = newkey; - Py_XDECREF(gbo->currvalue); + Py_XDECREF(tmp); + + tmp = gbo->currvalue; gbo->currvalue = newvalue; + Py_XDECREF(tmp); } - Py_XDECREF(gbo->tgtkey); - gbo->tgtkey = gbo->currkey; Py_INCREF(gbo->currkey); + tmp = gbo->tgtkey; + gbo->tgtkey = gbo->currkey; + Py_XDECREF(tmp); grouper = _grouper_create(gbo, gbo->tgtkey); if (grouper == NULL) |