diff options
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/functools.py | 7 | ||||
-rw-r--r-- | Lib/test/test_functools.py | 66 |
2 files changed, 71 insertions, 2 deletions
diff --git a/Lib/functools.py b/Lib/functools.py index 3bffbac..098f6b6 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -97,7 +97,7 @@ def cmp_to_key(mycmp): """Convert a cmp= function into a key= function""" class K(object): __slots__ = ['obj'] - def __init__(self, obj, *args): + def __init__(self, obj): self.obj = obj def __lt__(self, other): return mycmp(self.obj, other.obj) < 0 @@ -115,6 +115,11 @@ def cmp_to_key(mycmp): raise TypeError('hash not implemented') return K +try: + from _functools import cmp_to_key +except ImportError: + pass + _CacheInfo = namedtuple("CacheInfo", "hits misses maxsize currsize") def lru_cache(maxsize=100): diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index 73a77d6..c50336e 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -435,18 +435,81 @@ class TestReduce(unittest.TestCase): self.assertEqual(self.func(add, d), "".join(d.keys())) class TestCmpToKey(unittest.TestCase): + def test_cmp_to_key(self): + def cmp1(x, y): + return (x > y) - (x < y) + key = functools.cmp_to_key(cmp1) + self.assertEqual(key(3), key(3)) + self.assertGreater(key(3), key(1)) + def cmp2(x, y): + return int(x) - int(y) + key = functools.cmp_to_key(cmp2) + self.assertEqual(key(4.0), key('4')) + self.assertLess(key(2), key('35')) + + def test_cmp_to_key_arguments(self): + def cmp1(x, y): + return (x > y) - (x < y) + key = functools.cmp_to_key(mycmp=cmp1) + self.assertEqual(key(obj=3), key(obj=3)) + self.assertGreater(key(obj=3), key(obj=1)) + with self.assertRaises((TypeError, AttributeError)): + key(3) > 1 # rhs is not a K object + with self.assertRaises((TypeError, AttributeError)): + 1 < key(3) # lhs is not a K object + with self.assertRaises(TypeError): + key = functools.cmp_to_key() # too few args + with self.assertRaises(TypeError): + key = functools.cmp_to_key(cmp1, None) # too many args + key = functools.cmp_to_key(cmp1) + with self.assertRaises(TypeError): + key() # too few args + with self.assertRaises(TypeError): + key(None, None) # too many args + + def test_bad_cmp(self): + def cmp1(x, y): + raise ZeroDivisionError + key = functools.cmp_to_key(cmp1) + with self.assertRaises(ZeroDivisionError): + key(3) > key(1) + + class BadCmp: + def __lt__(self, other): + raise ZeroDivisionError + def cmp1(x, y): + return BadCmp() + with self.assertRaises(ZeroDivisionError): + key(3) > key(1) + + def test_obj_field(self): + def cmp1(x, y): + return (x > y) - (x < y) + key = functools.cmp_to_key(mycmp=cmp1) + self.assertEqual(key(50).obj, 50) + + def test_sort_int(self): def mycmp(x, y): return y - x self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)), [4, 3, 2, 1, 0]) + def test_sort_int_str(self): + def mycmp(x, y): + x, y = int(x), int(y) + return (x > y) - (x < y) + values = [5, '3', 7, 2, '0', '1', 4, '10', 1] + values = sorted(values, key=functools.cmp_to_key(mycmp)) + self.assertEqual([int(value) for value in values], + [0, 1, 1, 2, 3, 4, 5, 7, 10]) + def test_hash(self): def mycmp(x, y): return y - x key = functools.cmp_to_key(mycmp) k = key(10) - self.assertRaises(TypeError, hash(k)) + self.assertRaises(TypeError, hash, k) class TestTotalOrdering(unittest.TestCase): @@ -655,6 +718,7 @@ class TestLRU(unittest.TestCase): def test_main(verbose=None): test_classes = ( + TestCmpToKey, TestPartial, TestPartialSubclass, TestPythonPartial, |