diff options
author | Raymond Hettinger <python@rcn.com> | 2004-07-05 22:53:03 (GMT) |
---|---|---|
committer | Raymond Hettinger <python@rcn.com> | 2004-07-05 22:53:03 (GMT) |
commit | 0aeac107cadd1472e5ea109f7f29e6a6527ae1ec (patch) | |
tree | 4d0d33fdc1199cdc32794cdfcc20b2f90a44f4b9 | |
parent | 10959b1c2acb48ebcdc909421354543043cfd671 (diff) | |
download | cpython-0aeac107cadd1472e5ea109f7f29e6a6527ae1ec.zip cpython-0aeac107cadd1472e5ea109f7f29e6a6527ae1ec.tar.gz cpython-0aeac107cadd1472e5ea109f7f29e6a6527ae1ec.tar.bz2 |
* Add __eq__ and __ne__ so that things like list.index() work properly
for lists of mixed types.
* Test that sort works.
-rw-r--r-- | Lib/decimal.py | 14 | ||||
-rw-r--r-- | Lib/test/test_decimal.py | 27 |
2 files changed, 26 insertions, 15 deletions
diff --git a/Lib/decimal.py b/Lib/decimal.py index 500ba07..1d13767 100644 --- a/Lib/decimal.py +++ b/Lib/decimal.py @@ -8,10 +8,6 @@ # and Tim Peters -# Todo: -# Add rich comparisons for equality testing with other types - - """ This is a Py2.3 implementation of decimal floating point arithmetic based on the General Decimal Arithmetic Specification: @@ -644,6 +640,16 @@ class Decimal(object): return -1 return 1 + def __eq__(self, other): + if not isinstance(other, (Decimal, int, long)): + return False + return self.__cmp__(other) == 0 + + def __ne__(self, other): + if not isinstance(other, (Decimal, int, long)): + return True + return self.__cmp__(other) != 0 + def compare(self, other, context=None): """Compares one to another. diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py index 51b3528..a03b784 100644 --- a/Lib/test/test_decimal.py +++ b/Lib/test/test_decimal.py @@ -33,6 +33,7 @@ import pickle, copy from decimal import * from test.test_support import TestSkipped, run_unittest, run_doctest, is_resource_enabled import threading +import random # Tests are built around these assumed context defaults DefaultContext.prec=9 @@ -841,17 +842,17 @@ class DecimalUsabilityTest(unittest.TestCase): self.assertEqual(cmp(dc,45), 0) #a Decimal and uncomparable - try: da == 'ugly' - except TypeError: pass - else: self.fail('Did not raised an error!') - - try: da == '32.7' - except TypeError: pass - else: self.fail('Did not raised an error!') - - try: da == object - except TypeError: pass - else: self.fail('Did not raised an error!') + self.assertNotEqual(da, 'ugly') + self.assertNotEqual(da, 32.7) + self.assertNotEqual(da, object()) + self.assertNotEqual(da, object) + + # sortable + a = map(Decimal, xrange(100)) + b = a[:] + random.shuffle(a) + a.sort() + self.assertEqual(a, b) def test_copy_and_deepcopy_methods(self): d = Decimal('43.24') @@ -1078,6 +1079,10 @@ class ContextAPItests(unittest.TestCase): v2 = vars(e)[k] self.assertEqual(v1, v2) + def test_equality_with_other_types(self): + self.assert_(Decimal(10) in ['a', 1.0, Decimal(10), (1,2), {}]) + self.assert_(Decimal(10) not in ['a', 1.0, (1,2), {}]) + def test_main(arith=False, verbose=None): """ Execute the tests. |