diff options
author | Ethan Furman <ethan@stoneleaf.us> | 2022-01-23 02:27:52 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-01-23 02:27:52 (GMT) |
commit | 353e3b2820bed38da16140276786eef9ba33d3bd (patch) | |
tree | 022c0e1482678051795c6645511890e30396b0e0 /Lib/enum.py | |
parent | 976dec9b3b35fddbaa893c99297e0c54731451b5 (diff) | |
download | cpython-353e3b2820bed38da16140276786eef9ba33d3bd.zip cpython-353e3b2820bed38da16140276786eef9ba33d3bd.tar.gz cpython-353e3b2820bed38da16140276786eef9ba33d3bd.tar.bz2 |
bpo-46477: [Enum] ensure Flag subclasses have correct bitwise methods (GH-30816)
Diffstat (limited to 'Lib/enum.py')
-rw-r--r-- | Lib/enum.py | 80 |
1 files changed, 41 insertions, 39 deletions
diff --git a/Lib/enum.py b/Lib/enum.py index b510467..85245c9 100644 --- a/Lib/enum.py +++ b/Lib/enum.py @@ -618,6 +618,18 @@ class EnumType(type): if name not in classdict: setattr(enum_class, name, getattr(first_enum, name)) # + # for Flag, add __or__, __and__, __xor__, and __invert__ + if Flag is not None and issubclass(enum_class, Flag): + for name in ( + '__or__', '__and__', '__xor__', + '__ror__', '__rand__', '__rxor__', + '__invert__' + ): + if name not in classdict: + enum_method = getattr(Flag, name) + setattr(enum_class, name, enum_method) + classdict[name] = enum_method + # # replace any other __new__ with our own (as long as Enum is not None, # anyway) -- again, this is to support pickle if Enum is not None: @@ -1467,43 +1479,9 @@ class Flag(Enum, boundary=STRICT): return bool(self._value_) def __or__(self, other): - if not isinstance(other, self.__class__): - return NotImplemented - return self.__class__(self._value_ | other._value_) - - def __and__(self, other): - if not isinstance(other, self.__class__): - return NotImplemented - return self.__class__(self._value_ & other._value_) - - def __xor__(self, other): - if not isinstance(other, self.__class__): - return NotImplemented - return self.__class__(self._value_ ^ other._value_) - - def __invert__(self): - if self._inverted_ is None: - if self._boundary_ is KEEP: - # use all bits - self._inverted_ = self.__class__(~self._value_) - else: - # calculate flags not in this member - self._inverted_ = self.__class__(self._flag_mask_ ^ self._value_) - if isinstance(self._inverted_, self.__class__): - self._inverted_._inverted_ = self - return self._inverted_ - - -class IntFlag(int, ReprEnum, Flag, boundary=EJECT): - """ - Support for integer-based Flags - """ - - - def __or__(self, other): if isinstance(other, self.__class__): other = other._value_ - elif isinstance(other, int): + elif self._member_type_ is not object and isinstance(other, self._member_type_): other = other else: return NotImplemented @@ -1513,7 +1491,7 @@ class IntFlag(int, ReprEnum, Flag, boundary=EJECT): def __and__(self, other): if isinstance(other, self.__class__): other = other._value_ - elif isinstance(other, int): + elif self._member_type_ is not object and isinstance(other, self._member_type_): other = other else: return NotImplemented @@ -1523,17 +1501,34 @@ class IntFlag(int, ReprEnum, Flag, boundary=EJECT): def __xor__(self, other): if isinstance(other, self.__class__): other = other._value_ - elif isinstance(other, int): + elif self._member_type_ is not object and isinstance(other, self._member_type_): other = other else: return NotImplemented value = self._value_ return self.__class__(value ^ other) - __ror__ = __or__ + def __invert__(self): + if self._inverted_ is None: + if self._boundary_ is KEEP: + # use all bits + self._inverted_ = self.__class__(~self._value_) + else: + # calculate flags not in this member + self._inverted_ = self.__class__(self._flag_mask_ ^ self._value_) + if isinstance(self._inverted_, self.__class__): + self._inverted_._inverted_ = self + return self._inverted_ + __rand__ = __and__ + __ror__ = __or__ __rxor__ = __xor__ - __invert__ = Flag.__invert__ + + +class IntFlag(int, ReprEnum, Flag, boundary=EJECT): + """ + Support for integer-based Flags + """ def _high_bit(value): @@ -1662,6 +1657,13 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None): body['_flag_mask_'] = None body['_all_bits_'] = None body['_inverted_'] = None + body['__or__'] = Flag.__or__ + body['__xor__'] = Flag.__xor__ + body['__and__'] = Flag.__and__ + body['__ror__'] = Flag.__ror__ + body['__rxor__'] = Flag.__rxor__ + body['__rand__'] = Flag.__rand__ + body['__invert__'] = Flag.__invert__ for name, obj in cls.__dict__.items(): if name in ('__dict__', '__weakref__'): continue |