diff options
Diffstat (limited to 'Lib/functools.py')
| -rw-r--r-- | Lib/functools.py | 94 | 
1 files changed, 82 insertions, 12 deletions
| diff --git a/Lib/functools.py b/Lib/functools.py index 19f88c7..6a6974f 100644 --- a/Lib/functools.py +++ b/Lib/functools.py @@ -89,21 +89,91 @@ def wraps(wrapped,  ### total_ordering class decorator  ################################################################################ +# The correct way to indicate that a comparison operation doesn't +# recognise the other type is to return NotImplemented and let the +# interpreter handle raising TypeError if both operands return +# NotImplemented from their respective comparison methods +# +# This makes the implementation of total_ordering more complicated, since +# we need to be careful not to trigger infinite recursion when two +# different types that both use this decorator encounter each other. +# +# For example, if a type implements __lt__, it's natural to define +# __gt__ as something like: +# +#    lambda self, other: not self < other and not self == other +# +# However, using the operator syntax like that ends up invoking the full +# type checking machinery again and means we can end up bouncing back and +# forth between the two operands until we run out of stack space. +# +# The solution is to define helper functions that invoke the appropriate +# magic methods directly, ensuring we only try each operand once, and +# return NotImplemented immediately if it is returned from the +# underlying user provided method. Using this scheme, the __gt__ derived +# from a user provided __lt__ becomes: +# +#    lambda self, other: _not_op_and_not_eq(self.__lt__, self, other)) + +def _not_op(op, other): +    # "not a < b" handles "a >= b" +    # "not a <= b" handles "a > b" +    # "not a >= b" handles "a < b" +    # "not a > b" handles "a <= b" +    op_result = op(other) +    if op_result is NotImplemented: +        return NotImplemented +    return not op_result + +def _op_or_eq(op, self, other): +    # "a < b or a == b" handles "a <= b" +    # "a > b or a == b" handles "a >= b" +    op_result = op(other) +    if op_result is NotImplemented: +        return NotImplemented +    return op_result or self == other + +def _not_op_and_not_eq(op, self, other): +    # "not (a < b or a == b)" handles "a > b" +    # "not a < b and a != b" is equivalent +    # "not (a > b or a == b)" handles "a < b" +    # "not a > b and a != b" is equivalent +    op_result = op(other) +    if op_result is NotImplemented: +        return NotImplemented +    return not op_result and self != other + +def _not_op_or_eq(op, self, other): +    # "not a <= b or a == b" handles "a >= b" +    # "not a >= b or a == b" handles "a <= b" +    op_result = op(other) +    if op_result is NotImplemented: +        return NotImplemented +    return not op_result or self == other + +def _op_and_not_eq(op, self, other): +    # "a <= b and not a == b" handles "a < b" +    # "a >= b and not a == b" handles "a > b" +    op_result = op(other) +    if op_result is NotImplemented: +        return NotImplemented +    return op_result and self != other +  def total_ordering(cls):      """Class decorator that fills in missing ordering methods"""      convert = { -        '__lt__': [('__gt__', lambda self, other: not (self < other or self == other)), -                   ('__le__', lambda self, other: self < other or self == other), -                   ('__ge__', lambda self, other: not self < other)], -        '__le__': [('__ge__', lambda self, other: not self <= other or self == other), -                   ('__lt__', lambda self, other: self <= other and not self == other), -                   ('__gt__', lambda self, other: not self <= other)], -        '__gt__': [('__lt__', lambda self, other: not (self > other or self == other)), -                   ('__ge__', lambda self, other: self > other or self == other), -                   ('__le__', lambda self, other: not self > other)], -        '__ge__': [('__le__', lambda self, other: (not self >= other) or self == other), -                   ('__gt__', lambda self, other: self >= other and not self == other), -                   ('__lt__', lambda self, other: not self >= other)] +        '__lt__': [('__gt__', lambda self, other: _not_op_and_not_eq(self.__lt__, self, other)), +                   ('__le__', lambda self, other: _op_or_eq(self.__lt__, self, other)), +                   ('__ge__', lambda self, other: _not_op(self.__lt__, other))], +        '__le__': [('__ge__', lambda self, other: _not_op_or_eq(self.__le__, self, other)), +                   ('__lt__', lambda self, other: _op_and_not_eq(self.__le__, self, other)), +                   ('__gt__', lambda self, other: _not_op(self.__le__, other))], +        '__gt__': [('__lt__', lambda self, other: _not_op_and_not_eq(self.__gt__, self, other)), +                   ('__ge__', lambda self, other: _op_or_eq(self.__gt__, self, other)), +                   ('__le__', lambda self, other: _not_op(self.__gt__, other))], +        '__ge__': [('__le__', lambda self, other: _not_op_or_eq(self.__ge__, self, other)), +                   ('__gt__', lambda self, other: _op_and_not_eq(self.__ge__, self, other)), +                   ('__lt__', lambda self, other: _not_op(self.__ge__, other))]      }      # Find user-defined comparisons (not those inherited from object).      roots = [op for op in convert if getattr(cls, op, None) is not getattr(object, op, None)] | 
