summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_decimal.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/test/test_decimal.py')
-rw-r--r--Lib/test/test_decimal.py87
1 files changed, 78 insertions, 9 deletions
diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py
index 51bdf9c..7de2400 100644
--- a/Lib/test/test_decimal.py
+++ b/Lib/test/test_decimal.py
@@ -26,6 +26,7 @@ with the corresponding argument.
import math
import os, sys
+import operator
import pickle, copy
import unittest
from decimal import *
@@ -1096,18 +1097,56 @@ class DecimalArithmeticOperatorsTest(unittest.TestCase):
self.assertEqual(abs(Decimal(45)), abs(Decimal(-45))) # abs
def test_nan_comparisons(self):
+ # comparisons involving signaling nans signal InvalidOperation
+
+ # order comparisons (<, <=, >, >=) involving only quiet nans
+ # also signal InvalidOperation
+
+ # equality comparisons (==, !=) involving only quiet nans
+ # don't signal, but return False or True respectively.
+
n = Decimal('NaN')
s = Decimal('sNaN')
i = Decimal('Inf')
f = Decimal('2')
- for x, y in [(n, n), (n, i), (i, n), (n, f), (f, n),
- (s, n), (n, s), (s, i), (i, s), (s, f), (f, s), (s, s)]:
- self.assertTrue(x != y)
- self.assertTrue(not (x == y))
- self.assertTrue(not (x < y))
- self.assertTrue(not (x <= y))
- self.assertTrue(not (x > y))
- self.assertTrue(not (x >= y))
+
+ qnan_pairs = (n, n), (n, i), (i, n), (n, f), (f, n)
+ snan_pairs = (s, n), (n, s), (s, i), (i, s), (s, f), (f, s), (s, s)
+ order_ops = operator.lt, operator.le, operator.gt, operator.ge
+ equality_ops = operator.eq, operator.ne
+
+ # results when InvalidOperation is not trapped
+ for x, y in qnan_pairs + snan_pairs:
+ for op in order_ops + equality_ops:
+ got = op(x, y)
+ expected = True if op is operator.ne else False
+ self.assertIs(expected, got,
+ "expected {0!r} for operator.{1}({2!r}, {3!r}); "
+ "got {4!r}".format(
+ expected, op.__name__, x, y, got))
+
+ # repeat the above, but this time trap the InvalidOperation
+ with localcontext() as ctx:
+ ctx.traps[InvalidOperation] = 1
+
+ for x, y in qnan_pairs:
+ for op in equality_ops:
+ got = op(x, y)
+ expected = True if op is operator.ne else False
+ self.assertIs(expected, got,
+ "expected {0!r} for "
+ "operator.{1}({2!r}, {3!r}); "
+ "got {4!r}".format(
+ expected, op.__name__, x, y, got))
+
+ for x, y in snan_pairs:
+ for op in equality_ops:
+ self.assertRaises(InvalidOperation, operator.eq, x, y)
+ self.assertRaises(InvalidOperation, operator.ne, x, y)
+
+ for x, y in qnan_pairs + snan_pairs:
+ for op in order_ops:
+ self.assertRaises(InvalidOperation, op, x, y)
def test_copy_sign(self):
d = Decimal(1).copy_sign(Decimal(-2))
@@ -1213,6 +1252,23 @@ class DecimalUsabilityTest(unittest.TestCase):
a.sort()
self.assertEqual(a, b)
+ def test_decimal_float_comparison(self):
+ da = Decimal('0.25')
+ db = Decimal('3.0')
+ self.assert_(da < 3.0)
+ self.assert_(da <= 3.0)
+ self.assert_(db > 0.25)
+ self.assert_(db >= 0.25)
+ self.assert_(da != 1.5)
+ self.assert_(da == 0.25)
+ self.assert_(3.0 > da)
+ self.assert_(3.0 >= da)
+ self.assert_(0.25 < db)
+ self.assert_(0.25 <= db)
+ self.assert_(0.25 != db)
+ self.assert_(3.0 == db)
+ self.assert_(0.1 != Decimal('0.1'))
+
def test_copy_and_deepcopy_methods(self):
d = Decimal('43.24')
c = copy.copy(d)
@@ -1223,6 +1279,10 @@ class DecimalUsabilityTest(unittest.TestCase):
def test_hash_method(self):
#just that it's hashable
hash(Decimal(23))
+ hash(Decimal('Infinity'))
+ hash(Decimal('-Infinity'))
+ hash(Decimal('nan123'))
+ hash(Decimal('-NaN'))
test_values = [Decimal(sign*(2**m + n))
for m in [0, 14, 15, 16, 17, 30, 31,
@@ -1257,10 +1317,19 @@ class DecimalUsabilityTest(unittest.TestCase):
#the same hash that to an int
self.assertEqual(hash(Decimal(23)), hash(23))
- self.assertRaises(TypeError, hash, Decimal('NaN'))
+ self.assertRaises(TypeError, hash, Decimal('sNaN'))
self.assertTrue(hash(Decimal('Inf')))
self.assertTrue(hash(Decimal('-Inf')))
+ # check that the hashes of a Decimal float match when they
+ # represent exactly the same values
+ test_strings = ['inf', '-Inf', '0.0', '-.0e1',
+ '34.0', '2.5', '112390.625', '-0.515625']
+ for s in test_strings:
+ f = float(s)
+ d = Decimal(s)
+ self.assertEqual(hash(f), hash(d))
+
# check that the value of the hash doesn't depend on the
# current context (issue #1757)
c = getcontext()