diff options
author | Raymond Hettinger <python@rcn.com> | 2005-03-27 10:55:27 (GMT) |
---|---|---|
committer | Raymond Hettinger <python@rcn.com> | 2005-03-27 10:55:27 (GMT) |
commit | 680db018be01f1c9e442b1ae887dad1d750f680e (patch) | |
tree | e6035aff32091cd7d7a3ee308610aac7d6dfd41d /Lib | |
parent | 3589d815b4ecab7bc033e971d6a345d7ad801e1a (diff) | |
download | cpython-680db018be01f1c9e442b1ae887dad1d750f680e.zip cpython-680db018be01f1c9e442b1ae887dad1d750f680e.tar.gz cpython-680db018be01f1c9e442b1ae887dad1d750f680e.tar.bz2 |
- Fixed decimal operator and comparison methods to return NotImplemented
instead of raising a TypeError when interacting with other types.
Allows other classes to successfully implement __radd__ style methods.
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/decimal.py | 43 | ||||
-rw-r--r-- | Lib/test/test_decimal.py | 57 |
2 files changed, 89 insertions, 11 deletions
diff --git a/Lib/decimal.py b/Lib/decimal.py index 539fe8d..18f1c90 100644 --- a/Lib/decimal.py +++ b/Lib/decimal.py @@ -645,6 +645,8 @@ class Decimal(object): def __cmp__(self, other, context=None): other = _convert_other(other) + if other is NotImplemented: + return other if self._is_special or other._is_special: ans = self._check_nans(other, context) @@ -696,12 +698,12 @@ class Decimal(object): def __eq__(self, other): if not isinstance(other, (Decimal, int, long)): - return False + return NotImplemented return self.__cmp__(other) == 0 def __ne__(self, other): if not isinstance(other, (Decimal, int, long)): - return True + return NotImplemented return self.__cmp__(other) != 0 def compare(self, other, context=None): @@ -714,6 +716,8 @@ class Decimal(object): Like __cmp__, but returns Decimal instances. """ other = _convert_other(other) + if other is NotImplemented: + return other #compare(NaN, NaN) = NaN if (self._is_special or other and other._is_special): @@ -919,6 +923,8 @@ class Decimal(object): -INF + INF (or the reverse) cause InvalidOperation errors. """ other = _convert_other(other) + if other is NotImplemented: + return other if context is None: context = getcontext() @@ -1006,6 +1012,8 @@ class Decimal(object): def __sub__(self, other, context=None): """Return self + (-other)""" other = _convert_other(other) + if other is NotImplemented: + return other if self._is_special or other._is_special: ans = self._check_nans(other, context=context) @@ -1023,6 +1031,8 @@ class Decimal(object): def __rsub__(self, other, context=None): """Return other + (-self)""" other = _convert_other(other) + if other is NotImplemented: + return other tmp = Decimal(self) tmp._sign = 1 - tmp._sign @@ -1068,6 +1078,8 @@ class Decimal(object): (+-) INF * 0 (or its reverse) raise InvalidOperation. """ other = _convert_other(other) + if other is NotImplemented: + return other if context is None: context = getcontext() @@ -1140,6 +1152,10 @@ class Decimal(object): computing the other value are not raised. """ other = _convert_other(other) + if other is NotImplemented: + if divmod in (0, 1): + return NotImplemented + return (NotImplemented, NotImplemented) if context is None: context = getcontext() @@ -1292,6 +1308,8 @@ class Decimal(object): def __rdiv__(self, other, context=None): """Swaps self/other and returns __div__.""" other = _convert_other(other) + if other is NotImplemented: + return other return other.__div__(self, context=context) __rtruediv__ = __rdiv__ @@ -1304,6 +1322,8 @@ class Decimal(object): def __rdivmod__(self, other, context=None): """Swaps self/other and returns __divmod__.""" other = _convert_other(other) + if other is NotImplemented: + return other return other.__divmod__(self, context=context) def __mod__(self, other, context=None): @@ -1311,6 +1331,8 @@ class Decimal(object): self % other """ other = _convert_other(other) + if other is NotImplemented: + return other if self._is_special or other._is_special: ans = self._check_nans(other, context) @@ -1325,6 +1347,8 @@ class Decimal(object): def __rmod__(self, other, context=None): """Swaps self/other and returns __mod__.""" other = _convert_other(other) + if other is NotImplemented: + return other return other.__mod__(self, context=context) def remainder_near(self, other, context=None): @@ -1332,6 +1356,8 @@ class Decimal(object): Remainder nearest to 0- abs(remainder-near) <= other/2 """ other = _convert_other(other) + if other is NotImplemented: + return other if self._is_special or other._is_special: ans = self._check_nans(other, context) @@ -1411,6 +1437,8 @@ class Decimal(object): def __rfloordiv__(self, other, context=None): """Swaps self/other and returns __floordiv__.""" other = _convert_other(other) + if other is NotImplemented: + return other return other.__floordiv__(self, context=context) def __float__(self): @@ -1661,6 +1689,8 @@ class Decimal(object): If modulo is None (default), don't take it mod modulo. """ n = _convert_other(n) + if n is NotImplemented: + return n if context is None: context = getcontext() @@ -1747,6 +1777,8 @@ class Decimal(object): def __rpow__(self, other, context=None): """Swaps self/other and returns __pow__.""" other = _convert_other(other) + if other is NotImplemented: + return other return other.__pow__(self, context=context) def normalize(self, context=None): @@ -2001,6 +2033,8 @@ class Decimal(object): NaN (and signals if one is sNaN). Also rounds. """ other = _convert_other(other) + if other is NotImplemented: + return other if self._is_special or other._is_special: # if one operand is a quiet NaN and the other is number, then the @@ -2048,6 +2082,8 @@ class Decimal(object): NaN (and signals if one is sNaN). Also rounds. """ other = _convert_other(other) + if other is NotImplemented: + return other if self._is_special or other._is_special: # if one operand is a quiet NaN and the other is number, then the @@ -2874,8 +2910,7 @@ def _convert_other(other): return other if isinstance(other, (int, long)): return Decimal(other) - - raise TypeError, "You can interact Decimal only with int, long or Decimal data types." + return NotImplemented _infinity_map = { 'inf' : 1, diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py index fc1e048..34f034b 100644 --- a/Lib/test/test_decimal.py +++ b/Lib/test/test_decimal.py @@ -24,8 +24,6 @@ you're working through IDLE, you can import this test module and call test_main( with the corresponding argument. """ -from __future__ import division - import unittest import glob import os, sys @@ -54,9 +52,9 @@ if __name__ == '__main__': else: file = __file__ testdir = os.path.dirname(file) or os.curdir -dir = testdir + os.sep + TESTDATADIR + os.sep +directory = testdir + os.sep + TESTDATADIR + os.sep -skip_expected = not os.path.isdir(dir) +skip_expected = not os.path.isdir(directory) # Make sure it actually raises errors when not expected and caught in flags # Slower, since it runs some things several times. @@ -109,7 +107,6 @@ class DecimalTest(unittest.TestCase): Changed for unittest. """ def setUp(self): - global dir self.context = Context() for key in DefaultContext.traps.keys(): DefaultContext.traps[key] = 1 @@ -302,11 +299,11 @@ class DecimalTest(unittest.TestCase): # Dynamically build custom test definition for each file in the test # directory and add the definitions to the DecimalTest class. This # procedure insures that new files do not get skipped. -for filename in os.listdir(dir): +for filename in os.listdir(directory): if '.decTest' not in filename: continue head, tail = filename.split('.') - tester = lambda self, f=filename: self.eval_file(dir + f) + tester = lambda self, f=filename: self.eval_file(directory + f) setattr(DecimalTest, 'test_' + head, tester) del filename, head, tail, tester @@ -476,6 +473,52 @@ class DecimalImplicitConstructionTest(unittest.TestCase): def test_implicit_from_Decimal(self): self.assertEqual(Decimal(5) + Decimal(45), Decimal(50)) + def test_rop(self): + # Allow other classes to be trained to interact with Decimals + class E: + def __divmod__(self, other): + return 'divmod ' + str(other) + def __rdivmod__(self, other): + return str(other) + ' rdivmod' + def __lt__(self, other): + return 'lt ' + str(other) + def __gt__(self, other): + return 'gt ' + str(other) + def __le__(self, other): + return 'le ' + str(other) + def __ge__(self, other): + return 'ge ' + str(other) + def __eq__(self, other): + return 'eq ' + str(other) + def __ne__(self, other): + return 'ne ' + str(other) + + self.assertEqual(divmod(E(), Decimal(10)), 'divmod 10') + self.assertEqual(divmod(Decimal(10), E()), '10 rdivmod') + self.assertEqual(eval('Decimal(10) < E()'), 'gt 10') + self.assertEqual(eval('Decimal(10) > E()'), 'lt 10') + self.assertEqual(eval('Decimal(10) <= E()'), 'ge 10') + self.assertEqual(eval('Decimal(10) >= E()'), 'le 10') + self.assertEqual(eval('Decimal(10) == E()'), 'eq 10') + self.assertEqual(eval('Decimal(10) != E()'), 'ne 10') + + # insert operator methods and then exercise them + for sym, lop, rop in ( + ('+', '__add__', '__radd__'), + ('-', '__sub__', '__rsub__'), + ('*', '__mul__', '__rmul__'), + ('/', '__div__', '__rdiv__'), + ('%', '__mod__', '__rmod__'), + ('//', '__floordiv__', '__rfloordiv__'), + ('**', '__pow__', '__rpow__'), + ): + + setattr(E, lop, lambda self, other: 'str' + lop + str(other)) + setattr(E, rop, lambda self, other: str(other) + rop + 'str') + self.assertEqual(eval('E()' + sym + 'Decimal(10)'), + 'str' + lop + '10') + self.assertEqual(eval('Decimal(10)' + sym + 'E()'), + '10' + rop + 'str') class DecimalArithmeticOperatorsTest(unittest.TestCase): '''Unit tests for all arithmetic operators, binary and unary.''' |