diff options
author | Raymond Hettinger <python@rcn.com> | 2010-04-05 18:56:31 (GMT) |
---|---|---|
committer | Raymond Hettinger <python@rcn.com> | 2010-04-05 18:56:31 (GMT) |
commit | c50846aaef3e38d466ac9a0a87f72f09238e2061 (patch) | |
tree | f6ae48bcfbabb5107c971c240f3b06a549084f98 /Lib | |
parent | 5daab45158094e577b9791cda7d8a0f4e34f45cb (diff) | |
download | cpython-c50846aaef3e38d466ac9a0a87f72f09238e2061.zip cpython-c50846aaef3e38d466ac9a0a87f72f09238e2061.tar.gz cpython-c50846aaef3e38d466ac9a0a87f72f09238e2061.tar.bz2 |
Forward port total_ordering() and cmp_to_key().
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/functools.py | 47 | ||||
-rw-r--r-- | Lib/pstats.py | 12 | ||||
-rw-r--r-- | Lib/test/test_functools.py | 84 | ||||
-rw-r--r-- | Lib/unittest/loader.py | 3 | ||||
-rw-r--r-- | Lib/unittest/util.py | 9 |
5 files changed, 134 insertions, 21 deletions
diff --git a/Lib/functools.py b/Lib/functools.py index a54f030..539dc90 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -49,3 +49,50 @@ def wraps(wrapped, """ return partial(update_wrapper, wrapped=wrapped, assigned=assigned, updated=updated) + +def total_ordering(cls): + 'Class decorator that fills-in missing ordering methods' + convert = { + '__lt__': [('__gt__', lambda self, other: other < self), + ('__le__', lambda self, other: not other < self), + ('__ge__', lambda self, other: not self < other)], + '__le__': [('__ge__', lambda self, other: other <= self), + ('__lt__', lambda self, other: not other <= self), + ('__gt__', lambda self, other: not self <= other)], + '__gt__': [('__lt__', lambda self, other: other > self), + ('__ge__', lambda self, other: not other > self), + ('__le__', lambda self, other: not self > other)], + '__ge__': [('__le__', lambda self, other: other >= self), + ('__gt__', lambda self, other: not other >= self), + ('__lt__', lambda self, other: not self >= other)] + } + roots = set(dir(cls)) & set(convert) + assert roots, 'must define at least one ordering operation: < > <= >=' + root = max(roots) # prefer __lt __ to __le__ to __gt__ to __ge__ + for opname, opfunc in convert[root]: + if opname not in roots: + opfunc.__name__ = opname + opfunc.__doc__ = getattr(int, opname).__doc__ + setattr(cls, opname, opfunc) + return cls + +def cmp_to_key(mycmp): + 'Convert a cmp= function into a key= function' + class K(object): + def __init__(self, obj, *args): + self.obj = obj + def __lt__(self, other): + return mycmp(self.obj, other.obj) < 0 + def __gt__(self, other): + return mycmp(self.obj, other.obj) > 0 + def __eq__(self, other): + return mycmp(self.obj, other.obj) == 0 + def __le__(self, other): + return mycmp(self.obj, other.obj) <= 0 + def __ge__(self, other): + return mycmp(self.obj, other.obj) >= 0 + def __ne__(self, other): + return mycmp(self.obj, other.obj) != 0 + def __hash__(self): + raise TypeError('hash not implemented') + return K diff --git a/Lib/pstats.py b/Lib/pstats.py index e2fee37..14c4606 100644 --- a/Lib/pstats.py +++ b/Lib/pstats.py @@ -37,6 +37,7 @@ import os import time import marshal import re +from functools import cmp_to_key __all__ = ["Stats"] @@ -226,7 +227,7 @@ class Stats: stats_list.append((cc, nc, tt, ct) + func + (func_std_string(func), func)) - stats_list.sort(key=CmpToKey(TupleComp(sort_tuple).compare)) + stats_list.sort(key=cmp_to_key(TupleComp(sort_tuple).compare)) self.fcn_list = fcn_list = [] for tuple in stats_list: @@ -458,15 +459,6 @@ class TupleComp: return direction return 0 -def CmpToKey(mycmp): - 'Convert a cmp= function into a key= function' - class K(object): - def __init__(self, obj): - self.obj = obj - def __lt__(self, other): - return mycmp(self.obj, other.obj) == -1 - return K - #************************************************************************** # func_name is a triple (file:string, line:int, name:string) diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py index ae47dae..5cc2a50 100644 --- a/Lib/test/test_functools.py +++ b/Lib/test/test_functools.py @@ -364,7 +364,89 @@ class TestReduce(unittest.TestCase): d = {"one": 1, "two": 2, "three": 3} self.assertEqual(self.func(add, d), "".join(d.keys())) - +class TestCmpToKey(unittest.TestCase): + def test_cmp_to_key(self): + def mycmp(x, y): + return y - x + self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)), + [4, 3, 2, 1, 0]) + + def test_hash(self): + def mycmp(x, y): + return y - x + key = functools.cmp_to_key(mycmp) + k = key(10) + self.assertRaises(TypeError, hash(k)) + +class TestTotalOrdering(unittest.TestCase): + + def test_total_ordering_lt(self): + @functools.total_ordering + class A: + def __init__(self, value): + self.value = value + def __lt__(self, other): + return self.value < other.value + self.assert_(A(1) < A(2)) + self.assert_(A(2) > A(1)) + self.assert_(A(1) <= A(2)) + self.assert_(A(2) >= A(1)) + self.assert_(A(2) <= A(2)) + self.assert_(A(2) >= A(2)) + + def test_total_ordering_le(self): + @functools.total_ordering + class A: + def __init__(self, value): + self.value = value + def __le__(self, other): + return self.value <= other.value + self.assert_(A(1) < A(2)) + self.assert_(A(2) > A(1)) + self.assert_(A(1) <= A(2)) + self.assert_(A(2) >= A(1)) + self.assert_(A(2) <= A(2)) + self.assert_(A(2) >= A(2)) + + def test_total_ordering_gt(self): + @functools.total_ordering + class A: + def __init__(self, value): + self.value = value + def __gt__(self, other): + return self.value > other.value + self.assert_(A(1) < A(2)) + self.assert_(A(2) > A(1)) + self.assert_(A(1) <= A(2)) + self.assert_(A(2) >= A(1)) + self.assert_(A(2) <= A(2)) + self.assert_(A(2) >= A(2)) + + def test_total_ordering_ge(self): + @functools.total_ordering + class A: + def __init__(self, value): + self.value = value + def __ge__(self, other): + return self.value >= other.value + self.assert_(A(1) < A(2)) + self.assert_(A(2) > A(1)) + self.assert_(A(1) <= A(2)) + self.assert_(A(2) >= A(1)) + self.assert_(A(2) <= A(2)) + self.assert_(A(2) >= A(2)) + + def test_total_ordering_no_overwrite(self): + # new methods should not overwrite existing + @functools.total_ordering + class A(int): + raise Exception() + self.assert_(A(1) < A(2)) + self.assert_(A(2) > A(1)) + self.assert_(A(1) <= A(2)) + self.assert_(A(2) >= A(1)) + self.assert_(A(2) <= A(2)) + self.assert_(A(2) >= A(2)) def test_main(verbose=None): diff --git a/Lib/unittest/loader.py b/Lib/unittest/loader.py index 5d11b6e..f00f38d 100644 --- a/Lib/unittest/loader.py +++ b/Lib/unittest/loader.py @@ -5,6 +5,7 @@ import re import sys import traceback import types +import functools from fnmatch import fnmatch @@ -141,7 +142,7 @@ class TestLoader(object): testFnNames = testFnNames = list(filter(isTestMethod, dir(testCaseClass))) if self.sortTestMethodsUsing: - testFnNames.sort(key=util.CmpToKey(self.sortTestMethodsUsing)) + testFnNames.sort(key=functools.cmp_to_key(self.sortTestMethodsUsing)) return testFnNames def discover(self, start_dir, pattern='test*.py', top_level_dir=None): diff --git a/Lib/unittest/util.py b/Lib/unittest/util.py index 736c202..ea8a68d 100644 --- a/Lib/unittest/util.py +++ b/Lib/unittest/util.py @@ -70,15 +70,6 @@ def unorderable_list_difference(expected, actual): # anything left in actual is unexpected return missing, actual -def CmpToKey(mycmp): - 'Convert a cmp= function into a key= function' - class K(object): - def __init__(self, obj, *args): - self.obj = obj - def __lt__(self, other): - return mycmp(self.obj, other.obj) == -1 - return K - def three_way_cmp(x, y): """Return -1 if x < y, 0 if x == y and 1 if x > y""" return (x > y) - (x < y) |