diff options
Diffstat (limited to 'Lib/functools.py')
-rw-r--r-- | Lib/functools.py | 157 |
1 files changed, 81 insertions, 76 deletions
diff --git a/Lib/functools.py b/Lib/functools.py index db8cc82..6f79472 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -89,110 +89,115 @@ def wraps(wrapped, ### total_ordering class decorator ################################################################################ -# The correct way to indicate that a comparison operation doesn't -# recognise the other type is to return NotImplemented and let the -# interpreter handle raising TypeError if both operands return -# NotImplemented from their respective comparison methods -# -# This makes the implementation of total_ordering more complicated, since -# we need to be careful not to trigger infinite recursion when two -# different types that both use this decorator encounter each other. -# -# For example, if a type implements __lt__, it's natural to define -# __gt__ as something like: -# -# lambda self, other: not self < other and not self == other -# -# However, using the operator syntax like that ends up invoking the full -# type checking machinery again and means we can end up bouncing back and -# forth between the two operands until we run out of stack space. -# -# The solution is to define helper functions that invoke the appropriate -# magic methods directly, ensuring we only try each operand once, and -# return NotImplemented immediately if it is returned from the -# underlying user provided method. Using this scheme, the __gt__ derived -# from a user provided __lt__ becomes: -# -# 'def __gt__(self, other):' + _not_op_and_not_eq % '__lt__' - -# "not a < b" handles "a >= b" -# "not a <= b" handles "a > b" -# "not a >= b" handles "a < b" -# "not a > b" handles "a <= b" -_not_op = ''' - op_result = self.%s(other) +# The total ordering functions all invoke the root magic method directly +# rather than using the corresponding operator. This avoids possible +# infinite recursion that could occur when the operator dispatch logic +# detects a NotImplemented result and then calls a reflected method. + +def _gt_from_lt(self, other): + 'Return a > b. Computed by @total_ordering from (not a < b) and (a != b).' + op_result = self.__lt__(other) + if op_result is NotImplemented: + return NotImplemented + return not op_result and self != other + +def _le_from_lt(self, other): + 'Return a <= b. Computed by @total_ordering from (a < b) or (a == b).' + op_result = self.__lt__(other) + return op_result or self == other + +def _ge_from_lt(self, other): + 'Return a >= b. Computed by @total_ordering from (not a < b).' + op_result = self.__lt__(other) if op_result is NotImplemented: return NotImplemented return not op_result -''' -# "a > b or a == b" handles "a >= b" -# "a < b or a == b" handles "a <= b" -_op_or_eq = ''' - op_result = self.%s(other) +def _ge_from_le(self, other): + 'Return a >= b. Computed by @total_ordering from (not a <= b) or (a == b).' + op_result = self.__le__(other) if op_result is NotImplemented: return NotImplemented - return op_result or self == other -''' - -# "not (a < b or a == b)" handles "a > b" -# "not a < b and a != b" is equivalent -# "not (a > b or a == b)" handles "a < b" -# "not a > b and a != b" is equivalent -_not_op_and_not_eq = ''' - op_result = self.%s(other) + return not op_result or self == other + +def _lt_from_le(self, other): + 'Return a < b. Computed by @total_ordering from (a <= b) and (a != b).' + op_result = self.__le__(other) + if op_result is NotImplemented: + return NotImplemented + return op_result and self != other + +def _gt_from_le(self, other): + 'Return a > b. Computed by @total_ordering from (not a <= b).' + op_result = self.__le__(other) + if op_result is NotImplemented: + return NotImplemented + return not op_result + +def _lt_from_gt(self, other): + 'Return a < b. Computed by @total_ordering from (not a > b) and (a != b).' + op_result = self.__gt__(other) if op_result is NotImplemented: return NotImplemented return not op_result and self != other -''' -# "not a <= b or a == b" handles "a >= b" -# "not a >= b or a == b" handles "a <= b" -_not_op_or_eq = ''' - op_result = self.%s(other) +def _ge_from_gt(self, other): + 'Return a >= b. Computed by @total_ordering from (a > b) or (a == b).' + op_result = self.__gt__(other) + return op_result or self == other + +def _le_from_gt(self, other): + 'Return a <= b. Computed by @total_ordering from (not a > b).' + op_result = self.__gt__(other) + if op_result is NotImplemented: + return NotImplemented + return not op_result + +def _le_from_ge(self, other): + 'Return a <= b. Computed by @total_ordering from (not a >= b) or (a == b).' + op_result = self.__ge__(other) if op_result is NotImplemented: return NotImplemented return not op_result or self == other -''' -# "a <= b and not a == b" handles "a < b" -# "a >= b and not a == b" handles "a > b" -_op_and_not_eq = ''' - op_result = self.%s(other) +def _gt_from_ge(self, other): + 'Return a > b. Computed by @total_ordering from (a >= b) and (a != b).' + op_result = self.__ge__(other) if op_result is NotImplemented: return NotImplemented return op_result and self != other -''' + +def _lt_from_ge(self, other): + 'Return a < b. Computed by @total_ordering from (not a >= b).' + op_result = self.__ge__(other) + if op_result is NotImplemented: + return NotImplemented + return not op_result def total_ordering(cls): """Class decorator that fills in missing ordering methods""" convert = { - '__lt__': {'__gt__': _not_op_and_not_eq, - '__le__': _op_or_eq, - '__ge__': _not_op}, - '__le__': {'__ge__': _not_op_or_eq, - '__lt__': _op_and_not_eq, - '__gt__': _not_op}, - '__gt__': {'__lt__': _not_op_and_not_eq, - '__ge__': _op_or_eq, - '__le__': _not_op}, - '__ge__': {'__le__': _not_op_or_eq, - '__gt__': _op_and_not_eq, - '__lt__': _not_op} + '__lt__': [('__gt__', _gt_from_lt), + ('__le__', _le_from_lt), + ('__ge__', _ge_from_lt)], + '__le__': [('__ge__', _ge_from_le), + ('__lt__', _lt_from_le), + ('__gt__', _gt_from_le)], + '__gt__': [('__lt__', _lt_from_gt), + ('__ge__', _ge_from_gt), + ('__le__', _le_from_gt)], + '__ge__': [('__le__', _le_from_ge), + ('__gt__', _gt_from_ge), + ('__lt__', _lt_from_ge)] } # Find user-defined comparisons (not those inherited from object). roots = [op for op in convert if getattr(cls, op, None) is not getattr(object, op, None)] if not roots: raise ValueError('must define at least one ordering operation: < > <= >=') root = max(roots) # prefer __lt__ to __le__ to __gt__ to __ge__ - for opname, opfunc in convert[root].items(): + for opname, opfunc in convert[root]: if opname not in roots: - namespace = {} - exec('def %s(self, other):%s' % (opname, opfunc % root), namespace) - opfunc = namespace[opname] - opfunc.__qualname__ = '%s.%s' % (cls.__qualname__, opname) - opfunc.__module__ = cls.__module__ - opfunc.__doc__ = getattr(int, opname).__doc__ + opfunc.__name__ = opname setattr(cls, opname, opfunc) return cls |