summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorMark Dickinson <dickinsm@gmail.com>2010-06-11 10:44:52 (GMT)
committerMark Dickinson <dickinsm@gmail.com>2010-06-11 10:44:52 (GMT)
commit08ade6faa0369a9167d150a1d85265b1a9ea58ef (patch)
tree287b48886b2f9dc890875d08dba06d200d7468d9 /Lib
parentbfd73faf86cc0dc82754372d96318f95c43170c5 (diff)
downloadcpython-08ade6faa0369a9167d150a1d85265b1a9ea58ef.zip
cpython-08ade6faa0369a9167d150a1d85265b1a9ea58ef.tar.gz
cpython-08ade6faa0369a9167d150a1d85265b1a9ea58ef.tar.bz2
Issue #8188: Comparisons between Decimal objects and other numeric
objects (Fraction, float, complex, int) now all function as expected.
Diffstat (limited to 'Lib')
-rw-r--r--Lib/decimal.py43
-rw-r--r--Lib/test/test_fractions.py3
-rw-r--r--Lib/test/test_numeric_tower.py57
3 files changed, 94 insertions, 9 deletions
diff --git a/Lib/decimal.py b/Lib/decimal.py
index 29ce398..0e6a966 100644
--- a/Lib/decimal.py
+++ b/Lib/decimal.py
@@ -862,7 +862,7 @@ class Decimal(object):
# that specified by IEEE 754.
def __eq__(self, other, context=None):
- other = _convert_other(other, allow_float = True)
+ self, other = _convert_for_comparison(self, other, equality_op=True)
if other is NotImplemented:
return other
if self._check_nans(other, context):
@@ -870,7 +870,7 @@ class Decimal(object):
return self._cmp(other) == 0
def __ne__(self, other, context=None):
- other = _convert_other(other, allow_float = True)
+ self, other = _convert_for_comparison(self, other, equality_op=True)
if other is NotImplemented:
return other
if self._check_nans(other, context):
@@ -879,7 +879,7 @@ class Decimal(object):
def __lt__(self, other, context=None):
- other = _convert_other(other, allow_float = True)
+ self, other = _convert_for_comparison(self, other)
if other is NotImplemented:
return other
ans = self._compare_check_nans(other, context)
@@ -888,7 +888,7 @@ class Decimal(object):
return self._cmp(other) < 0
def __le__(self, other, context=None):
- other = _convert_other(other, allow_float = True)
+ self, other = _convert_for_comparison(self, other)
if other is NotImplemented:
return other
ans = self._compare_check_nans(other, context)
@@ -897,7 +897,7 @@ class Decimal(object):
return self._cmp(other) <= 0
def __gt__(self, other, context=None):
- other = _convert_other(other, allow_float = True)
+ self, other = _convert_for_comparison(self, other)
if other is NotImplemented:
return other
ans = self._compare_check_nans(other, context)
@@ -906,7 +906,7 @@ class Decimal(object):
return self._cmp(other) > 0
def __ge__(self, other, context=None):
- other = _convert_other(other, allow_float = True)
+ self, other = _convert_for_comparison(self, other)
if other is NotImplemented:
return other
ans = self._compare_check_nans(other, context)
@@ -5860,6 +5860,37 @@ def _convert_other(other, raiseit=False, allow_float=False):
raise TypeError("Unable to convert %s to Decimal" % other)
return NotImplemented
+def _convert_for_comparison(self, other, equality_op=False):
+ """Given a Decimal instance self and a Python object other, return
+ an pair (s, o) of Decimal instances such that "s op o" is
+ equivalent to "self op other" for any of the 6 comparison
+ operators "op".
+
+ """
+ if isinstance(other, Decimal):
+ return self, other
+
+ # Comparison with a Rational instance (also includes integers):
+ # self op n/d <=> self*d op n (for n and d integers, d positive).
+ # A NaN or infinity can be left unchanged without affecting the
+ # comparison result.
+ if isinstance(other, _numbers.Rational):
+ if not self._is_special:
+ self = _dec_from_triple(self._sign,
+ str(int(self._int) * other.denominator),
+ self._exp)
+ return self, Decimal(other.numerator)
+
+ # Comparisons with float and complex types. == and != comparisons
+ # with complex numbers should succeed, returning either True or False
+ # as appropriate. Other comparisons return NotImplemented.
+ if equality_op and isinstance(other, _numbers.Complex) and other.imag == 0:
+ other = other.real
+ if isinstance(other, float):
+ return self, Decimal.from_float(other)
+ return NotImplemented, NotImplemented
+
+
##### Setup Specific Contexts ############################################
# The default context prototype used by Context()
diff --git a/Lib/test/test_fractions.py b/Lib/test/test_fractions.py
index dd51f9b..a41fd9c 100644
--- a/Lib/test/test_fractions.py
+++ b/Lib/test/test_fractions.py
@@ -395,12 +395,11 @@ class FractionTest(unittest.TestCase):
self.assertTypedEquals(1.0 + 0j, (1.0 + 0j) ** F(1, 10))
def testMixingWithDecimal(self):
- # Decimal refuses mixed comparisons.
+ # Decimal refuses mixed arithmetic (but not mixed comparisons)
self.assertRaisesMessage(
TypeError,
"unsupported operand type(s) for +: 'Fraction' and 'Decimal'",
operator.add, F(3,11), Decimal('3.1415926'))
- self.assertNotEquals(F(5, 2), Decimal('2.5'))
def testComparisons(self):
self.assertTrue(F(1, 2) < F(2, 3))
diff --git a/Lib/test/test_numeric_tower.py b/Lib/test/test_numeric_tower.py
index eafdb0f..b0c9537 100644
--- a/Lib/test/test_numeric_tower.py
+++ b/Lib/test/test_numeric_tower.py
@@ -143,9 +143,64 @@ class HashTest(unittest.TestCase):
x = {'halibut', HalibutProxy()}
self.assertEqual(len(x), 1)
+class ComparisonTest(unittest.TestCase):
+ def test_mixed_comparisons(self):
+
+ # ordered list of distinct test values of various types:
+ # int, float, Fraction, Decimal
+ test_values = [
+ float('-inf'),
+ D('-1e999999999'),
+ -1e308,
+ F(-22, 7),
+ -3.14,
+ -2,
+ 0.0,
+ 1e-320,
+ True,
+ F('1.2'),
+ D('1.3'),
+ float('1.4'),
+ F(275807, 195025),
+ D('1.414213562373095048801688724'),
+ F(114243, 80782),
+ F(473596569, 84615),
+ 7e200,
+ D('infinity'),
+ ]
+ for i, first in enumerate(test_values):
+ for second in test_values[i+1:]:
+ self.assertLess(first, second)
+ self.assertLessEqual(first, second)
+ self.assertGreater(second, first)
+ self.assertGreaterEqual(second, first)
+
+ def test_complex(self):
+ # comparisons with complex are special: equality and inequality
+ # comparisons should always succeed, but order comparisons should
+ # raise TypeError.
+ z = 1.0 + 0j
+ w = -3.14 + 2.7j
+
+ for v in 1, 1.0, F(1), D(1), complex(1):
+ self.assertEqual(z, v)
+ self.assertEqual(v, z)
+
+ for v in 2, 2.0, F(2), D(2), complex(2):
+ self.assertNotEqual(z, v)
+ self.assertNotEqual(v, z)
+ self.assertNotEqual(w, v)
+ self.assertNotEqual(v, w)
+
+ for v in (1, 1.0, F(1), D(1), complex(1),
+ 2, 2.0, F(2), D(2), complex(2), w):
+ for op in operator.le, operator.lt, operator.ge, operator.gt:
+ self.assertRaises(TypeError, op, z, v)
+ self.assertRaises(TypeError, op, v, z)
+
def test_main():
- run_unittest(HashTest)
+ run_unittest(HashTest, ComparisonTest)
if __name__ == '__main__':
test_main()