summaryrefslogtreecommitdiffstats
path: root/Lib/enum.py
diff options
context:
space:
mode:
authorEthan Furman <ethan@stoneleaf.us>2022-01-23 02:27:52 (GMT)
committerGitHub <noreply@github.com>2022-01-23 02:27:52 (GMT)
commit353e3b2820bed38da16140276786eef9ba33d3bd (patch)
tree022c0e1482678051795c6645511890e30396b0e0 /Lib/enum.py
parent976dec9b3b35fddbaa893c99297e0c54731451b5 (diff)
downloadcpython-353e3b2820bed38da16140276786eef9ba33d3bd.zip
cpython-353e3b2820bed38da16140276786eef9ba33d3bd.tar.gz
cpython-353e3b2820bed38da16140276786eef9ba33d3bd.tar.bz2
bpo-46477: [Enum] ensure Flag subclasses have correct bitwise methods (GH-30816)
Diffstat (limited to 'Lib/enum.py')
-rw-r--r--Lib/enum.py80
1 files changed, 41 insertions, 39 deletions
diff --git a/Lib/enum.py b/Lib/enum.py
index b510467..85245c9 100644
--- a/Lib/enum.py
+++ b/Lib/enum.py
@@ -618,6 +618,18 @@ class EnumType(type):
if name not in classdict:
setattr(enum_class, name, getattr(first_enum, name))
#
+ # for Flag, add __or__, __and__, __xor__, and __invert__
+ if Flag is not None and issubclass(enum_class, Flag):
+ for name in (
+ '__or__', '__and__', '__xor__',
+ '__ror__', '__rand__', '__rxor__',
+ '__invert__'
+ ):
+ if name not in classdict:
+ enum_method = getattr(Flag, name)
+ setattr(enum_class, name, enum_method)
+ classdict[name] = enum_method
+ #
# replace any other __new__ with our own (as long as Enum is not None,
# anyway) -- again, this is to support pickle
if Enum is not None:
@@ -1467,43 +1479,9 @@ class Flag(Enum, boundary=STRICT):
return bool(self._value_)
def __or__(self, other):
- if not isinstance(other, self.__class__):
- return NotImplemented
- return self.__class__(self._value_ | other._value_)
-
- def __and__(self, other):
- if not isinstance(other, self.__class__):
- return NotImplemented
- return self.__class__(self._value_ & other._value_)
-
- def __xor__(self, other):
- if not isinstance(other, self.__class__):
- return NotImplemented
- return self.__class__(self._value_ ^ other._value_)
-
- def __invert__(self):
- if self._inverted_ is None:
- if self._boundary_ is KEEP:
- # use all bits
- self._inverted_ = self.__class__(~self._value_)
- else:
- # calculate flags not in this member
- self._inverted_ = self.__class__(self._flag_mask_ ^ self._value_)
- if isinstance(self._inverted_, self.__class__):
- self._inverted_._inverted_ = self
- return self._inverted_
-
-
-class IntFlag(int, ReprEnum, Flag, boundary=EJECT):
- """
- Support for integer-based Flags
- """
-
-
- def __or__(self, other):
if isinstance(other, self.__class__):
other = other._value_
- elif isinstance(other, int):
+ elif self._member_type_ is not object and isinstance(other, self._member_type_):
other = other
else:
return NotImplemented
@@ -1513,7 +1491,7 @@ class IntFlag(int, ReprEnum, Flag, boundary=EJECT):
def __and__(self, other):
if isinstance(other, self.__class__):
other = other._value_
- elif isinstance(other, int):
+ elif self._member_type_ is not object and isinstance(other, self._member_type_):
other = other
else:
return NotImplemented
@@ -1523,17 +1501,34 @@ class IntFlag(int, ReprEnum, Flag, boundary=EJECT):
def __xor__(self, other):
if isinstance(other, self.__class__):
other = other._value_
- elif isinstance(other, int):
+ elif self._member_type_ is not object and isinstance(other, self._member_type_):
other = other
else:
return NotImplemented
value = self._value_
return self.__class__(value ^ other)
- __ror__ = __or__
+ def __invert__(self):
+ if self._inverted_ is None:
+ if self._boundary_ is KEEP:
+ # use all bits
+ self._inverted_ = self.__class__(~self._value_)
+ else:
+ # calculate flags not in this member
+ self._inverted_ = self.__class__(self._flag_mask_ ^ self._value_)
+ if isinstance(self._inverted_, self.__class__):
+ self._inverted_._inverted_ = self
+ return self._inverted_
+
__rand__ = __and__
+ __ror__ = __or__
__rxor__ = __xor__
- __invert__ = Flag.__invert__
+
+
+class IntFlag(int, ReprEnum, Flag, boundary=EJECT):
+ """
+ Support for integer-based Flags
+ """
def _high_bit(value):
@@ -1662,6 +1657,13 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None):
body['_flag_mask_'] = None
body['_all_bits_'] = None
body['_inverted_'] = None
+ body['__or__'] = Flag.__or__
+ body['__xor__'] = Flag.__xor__
+ body['__and__'] = Flag.__and__
+ body['__ror__'] = Flag.__ror__
+ body['__rxor__'] = Flag.__rxor__
+ body['__rand__'] = Flag.__rand__
+ body['__invert__'] = Flag.__invert__
for name, obj in cls.__dict__.items():
if name in ('__dict__', '__weakref__'):
continue