diff options
Diffstat (limited to 'Lib/test/test_functools.py')
| -rw-r--r-- | Lib/test/test_functools.py | 65 | 
1 files changed, 64 insertions, 1 deletions
diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index 7d11b53..97d7524 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -436,19 +436,82 @@ class TestReduce(unittest.TestCase):          self.assertEqual(self.func(add, d), "".join(d.keys()))  class TestCmpToKey(unittest.TestCase): +      def test_cmp_to_key(self): +        def cmp1(x, y): +            return (x > y) - (x < y) +        key = functools.cmp_to_key(cmp1) +        self.assertEqual(key(3), key(3)) +        self.assertGreater(key(3), key(1)) +        def cmp2(x, y): +            return int(x) - int(y) +        key = functools.cmp_to_key(cmp2) +        self.assertEqual(key(4.0), key('4')) +        self.assertLess(key(2), key('35')) + +    def test_cmp_to_key_arguments(self): +        def cmp1(x, y): +            return (x > y) - (x < y) +        key = functools.cmp_to_key(mycmp=cmp1) +        self.assertEqual(key(obj=3), key(obj=3)) +        self.assertGreater(key(obj=3), key(obj=1)) +        with self.assertRaises((TypeError, AttributeError)): +            key(3) > 1    # rhs is not a K object +        with self.assertRaises((TypeError, AttributeError)): +            1 < key(3)    # lhs is not a K object +        with self.assertRaises(TypeError): +            key = functools.cmp_to_key()             # too few args +        with self.assertRaises(TypeError): +            key = functools.cmp_to_key(cmp1, None)   # too many args +        key = functools.cmp_to_key(cmp1) +        with self.assertRaises(TypeError): +            key()                                    # too few args +        with self.assertRaises(TypeError): +            key(None, None)                          # too many args + +    def test_bad_cmp(self): +        def cmp1(x, y): +            raise ZeroDivisionError +        key = functools.cmp_to_key(cmp1) +        with self.assertRaises(ZeroDivisionError): +            key(3) > key(1) + +        class BadCmp: +            def __lt__(self, other): +                raise ZeroDivisionError +        def cmp1(x, y): +            return BadCmp() +        with self.assertRaises(ZeroDivisionError): +            key(3) > key(1) + +    def test_obj_field(self): +        def cmp1(x, y): +            return (x > y) - (x < y) +        key = functools.cmp_to_key(mycmp=cmp1) +        self.assertEqual(key(50).obj, 50) + +    def test_sort_int(self):          def mycmp(x, y):              return y - x          self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)),                           [4, 3, 2, 1, 0]) +    def test_sort_int_str(self): +        def mycmp(x, y): +            x, y = int(x), int(y) +            return (x > y) - (x < y) +        values = [5, '3', 7, 2, '0', '1', 4, '10', 1] +        values = sorted(values, key=functools.cmp_to_key(mycmp)) +        self.assertEqual([int(value) for value in values], +                         [0, 1, 1, 2, 3, 4, 5, 7, 10]) +      def test_hash(self):          def mycmp(x, y):              return y - x          key = functools.cmp_to_key(mycmp)          k = key(10)          self.assertRaises(TypeError, hash, k) -        self.assertFalse(isinstance(k, collections.Hashable)) +        self.assertNotIsInstance(k, collections.Hashable)  class TestTotalOrdering(unittest.TestCase):  | 
