summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorÉric Araujo <merwok@netwok.org>2011-03-19 03:29:36 (GMT)
committerÉric Araujo <merwok@netwok.org>2011-03-19 03:29:36 (GMT)
commit374274db7f7536836d0acd598ff6001aa5ed06c8 (patch)
tree178c43060f69ac301fcd14db9b81b74a13626c00
parent5136867c1b8573de23f7e7f95f8292846aab507e (diff)
downloadcpython-374274db7f7536836d0acd598ff6001aa5ed06c8.zip
cpython-374274db7f7536836d0acd598ff6001aa5ed06c8.tar.gz
cpython-374274db7f7536836d0acd598ff6001aa5ed06c8.tar.bz2
Fix the total_ordering decorator to handle cross-type comparisons
that could lead to infinite recursion (closes #10042).
-rw-r--r--Lib/functools.py16
-rw-r--r--Lib/test/test_functools.py24
-rw-r--r--Misc/ACKS1
-rw-r--r--Misc/NEWS3
4 files changed, 36 insertions, 8 deletions
diff --git a/Lib/functools.py b/Lib/functools.py
index 1a1f22e..01889cb 100644
--- a/Lib/functools.py
+++ b/Lib/functools.py
@@ -53,17 +53,17 @@ def wraps(wrapped,
def total_ordering(cls):
"""Class decorator that fills in missing ordering methods"""
convert = {
- '__lt__': [('__gt__', lambda self, other: other < self),
- ('__le__', lambda self, other: not other < self),
+ '__lt__': [('__gt__', lambda self, other: not (self < other or self == other)),
+ ('__le__', lambda self, other: self < other or self == other),
('__ge__', lambda self, other: not self < other)],
- '__le__': [('__ge__', lambda self, other: other <= self),
- ('__lt__', lambda self, other: not other <= self),
+ '__le__': [('__ge__', lambda self, other: not self <= other or self == other),
+ ('__lt__', lambda self, other: self <= other and not self == other),
('__gt__', lambda self, other: not self <= other)],
- '__gt__': [('__lt__', lambda self, other: other > self),
- ('__ge__', lambda self, other: not other > self),
+ '__gt__': [('__lt__', lambda self, other: not (self > other or self == other)),
+ ('__ge__', lambda self, other: self > other or self == other),
('__le__', lambda self, other: not self > other)],
- '__ge__': [('__le__', lambda self, other: other >= self),
- ('__gt__', lambda self, other: not other >= self),
+ '__ge__': [('__le__', lambda self, other: (not self >= other) or self == other),
+ ('__gt__', lambda self, other: self >= other and not self == other),
('__lt__', lambda self, other: not self >= other)]
}
roots = set(dir(cls)) & set(convert)
diff --git a/Lib/test/test_functools.py b/Lib/test/test_functools.py
index 6bc7b2b..a713314 100644
--- a/Lib/test/test_functools.py
+++ b/Lib/test/test_functools.py
@@ -361,6 +361,8 @@ class TestTotalOrdering(unittest.TestCase):
self.value = value
def __lt__(self, other):
return self.value < other.value
+ def __eq__(self, other):
+ return self.value == other.value
self.assertTrue(A(1) < A(2))
self.assertTrue(A(2) > A(1))
self.assertTrue(A(1) <= A(2))
@@ -375,6 +377,8 @@ class TestTotalOrdering(unittest.TestCase):
self.value = value
def __le__(self, other):
return self.value <= other.value
+ def __eq__(self, other):
+ return self.value == other.value
self.assertTrue(A(1) < A(2))
self.assertTrue(A(2) > A(1))
self.assertTrue(A(1) <= A(2))
@@ -389,6 +393,8 @@ class TestTotalOrdering(unittest.TestCase):
self.value = value
def __gt__(self, other):
return self.value > other.value
+ def __eq__(self, other):
+ return self.value == other.value
self.assertTrue(A(1) < A(2))
self.assertTrue(A(2) > A(1))
self.assertTrue(A(1) <= A(2))
@@ -403,6 +409,8 @@ class TestTotalOrdering(unittest.TestCase):
self.value = value
def __ge__(self, other):
return self.value >= other.value
+ def __eq__(self, other):
+ return self.value == other.value
self.assertTrue(A(1) < A(2))
self.assertTrue(A(2) > A(1))
self.assertTrue(A(1) <= A(2))
@@ -428,6 +436,22 @@ class TestTotalOrdering(unittest.TestCase):
class A:
pass
+ def test_bug_10042(self):
+ @functools.total_ordering
+ class TestTO:
+ def __init__(self, value):
+ self.value = value
+ def __eq__(self, other):
+ if isinstance(other, TestTO):
+ return self.value == other.value
+ return False
+ def __lt__(self, other):
+ if isinstance(other, TestTO):
+ return self.value < other.value
+ raise TypeError
+ with self.assertRaises(TypeError):
+ TestTO(8) <= ()
+
def test_main(verbose=None):
test_classes = (
TestPartial,
diff --git a/Misc/ACKS b/Misc/ACKS
index 5d617da..66ace9a 100644
--- a/Misc/ACKS
+++ b/Misc/ACKS
@@ -669,6 +669,7 @@ Bernhard Reiter
Steven Reiz
Roeland Rengelink
Tim Rice
+Francesco Ricciardi
Jan Pieter Riegel
Armin Rigo
Nicholas Riley
diff --git a/Misc/NEWS b/Misc/NEWS
index 9c136aa..c42d89b 100644
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -43,6 +43,9 @@ Core and Builtins
Library
-------
+- Issue #10042: Fixed the total_ordering decorator to handle cross-type
+ comparisons that could lead to infinite recursion.
+
- Issue #10979: unittest stdout buffering now works with class and module
setup and teardown.