summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/test/test_itertools.py2
-rw-r--r--Modules/itertoolsmodule.c4
2 files changed, 4 insertions, 2 deletions
diff --git a/Lib/test/test_itertools.py b/Lib/test/test_itertools.py
index f5dd069..dc9081e 100644
--- a/Lib/test/test_itertools.py
+++ b/Lib/test/test_itertools.py
@@ -171,6 +171,7 @@ class TestBasicOps(unittest.TestCase):
def test_ifilter(self):
self.assertEqual(list(ifilter(isEven, range(6))), [0,2,4])
self.assertEqual(list(ifilter(None, [0,1,0,2,0])), [1,2])
+ self.assertEqual(list(ifilter(bool, [0,1,0,2,0])), [1,2])
self.assertEqual(take(4, ifilter(isEven, count())), [0,2,4,6])
self.assertRaises(TypeError, ifilter)
self.assertRaises(TypeError, ifilter, lambda x:x)
@@ -181,6 +182,7 @@ class TestBasicOps(unittest.TestCase):
def test_ifilterfalse(self):
self.assertEqual(list(ifilterfalse(isEven, range(6))), [1,3,5])
self.assertEqual(list(ifilterfalse(None, [0,1,0,2,0])), [0,0,0])
+ self.assertEqual(list(ifilterfalse(bool, [0,1,0,2,0])), [0,0,0])
self.assertEqual(take(4, ifilterfalse(isEven, count())), [1,3,5,7])
self.assertRaises(TypeError, ifilterfalse)
self.assertRaises(TypeError, ifilterfalse, lambda x:x)
diff --git a/Modules/itertoolsmodule.c b/Modules/itertoolsmodule.c
index ef15a39..e53c353 100644
--- a/Modules/itertoolsmodule.c
+++ b/Modules/itertoolsmodule.c
@@ -2055,7 +2055,7 @@ ifilter_next(ifilterobject *lz)
if (item == NULL)
return NULL;
- if (lz->func == Py_None) {
+ if (lz->func == Py_None || lz->func == PyBool_Type) {
ok = PyObject_IsTrue(item);
} else {
PyObject *good;
@@ -2199,7 +2199,7 @@ ifilterfalse_next(ifilterfalseobject *lz)
if (item == NULL)
return NULL;
- if (lz->func == Py_None) {
+ if (lz->func == Py_None || lz->func == PyBool_Type) {
ok = PyObject_IsTrue(item);
} else {
PyObject *good;