summaryrefslogtreecommitdiffstats
path: root/Lib/enum.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/enum.py')
-rw-r--r--Lib/enum.py67
1 files changed, 39 insertions, 28 deletions
diff --git a/Lib/enum.py b/Lib/enum.py
index 10902c4..432d745 100644
--- a/Lib/enum.py
+++ b/Lib/enum.py
@@ -275,6 +275,13 @@ class _proto_member:
enum_member.__objclass__ = enum_class
enum_member.__init__(*args)
enum_member._sort_order_ = len(enum_class._member_names_)
+
+ if Flag is not None and issubclass(enum_class, Flag):
+ enum_class._flag_mask_ |= value
+ if _is_single_bit(value):
+ enum_class._singles_mask_ |= value
+ enum_class._all_bits_ = 2 ** ((enum_class._flag_mask_).bit_length()) - 1
+
# If another member with the same value was already defined, the
# new member becomes an alias to the existing one.
try:
@@ -532,12 +539,8 @@ class EnumType(type):
classdict['_use_args_'] = use_args
#
# convert future enum members into temporary _proto_members
- # and record integer values in case this will be a Flag
- flag_mask = 0
for name in member_names:
value = classdict[name]
- if isinstance(value, int):
- flag_mask |= value
classdict[name] = _proto_member(value)
#
# house-keeping structures
@@ -554,8 +557,9 @@ class EnumType(type):
boundary
or getattr(first_enum, '_boundary_', None)
)
- classdict['_flag_mask_'] = flag_mask
- classdict['_all_bits_'] = 2 ** ((flag_mask).bit_length()) - 1
+ classdict['_flag_mask_'] = 0
+ classdict['_singles_mask_'] = 0
+ classdict['_all_bits_'] = 0
classdict['_inverted_'] = None
try:
exc = None
@@ -644,21 +648,10 @@ class EnumType(type):
):
delattr(enum_class, '_boundary_')
delattr(enum_class, '_flag_mask_')
+ delattr(enum_class, '_singles_mask_')
delattr(enum_class, '_all_bits_')
delattr(enum_class, '_inverted_')
elif Flag is not None and issubclass(enum_class, Flag):
- # ensure _all_bits_ is correct and there are no missing flags
- single_bit_total = 0
- multi_bit_total = 0
- for flag in enum_class._member_map_.values():
- flag_value = flag._value_
- if _is_single_bit(flag_value):
- single_bit_total |= flag_value
- else:
- # multi-bit flags are considered aliases
- multi_bit_total |= flag_value
- enum_class._flag_mask_ = single_bit_total
- #
# set correct __iter__
member_list = [m._value_ for m in enum_class]
if member_list != sorted(member_list):
@@ -1303,8 +1296,8 @@ def _reduce_ex_by_global_name(self, proto):
class FlagBoundary(StrEnum):
"""
control how out of range values are handled
- "strict" -> error is raised
- "conform" -> extra bits are discarded [default for Flag]
+ "strict" -> error is raised [default for Flag]
+ "conform" -> extra bits are discarded
"eject" -> lose flag status
"keep" -> keep flag status and all bits [default for IntFlag]
"""
@@ -1315,7 +1308,7 @@ class FlagBoundary(StrEnum):
STRICT, CONFORM, EJECT, KEEP = FlagBoundary
-class Flag(Enum, boundary=CONFORM):
+class Flag(Enum, boundary=STRICT):
"""
Support for flags
"""
@@ -1394,6 +1387,7 @@ class Flag(Enum, boundary=CONFORM):
# - value must not include any skipped flags (e.g. if bit 2 is not
# defined, then 0d10 is invalid)
flag_mask = cls._flag_mask_
+ singles_mask = cls._singles_mask_
all_bits = cls._all_bits_
neg_value = None
if (
@@ -1425,7 +1419,8 @@ class Flag(Enum, boundary=CONFORM):
value = all_bits + 1 + value
# get members and unknown
unknown = value & ~flag_mask
- member_value = value & flag_mask
+ aliases = value & ~singles_mask
+ member_value = value & singles_mask
if unknown and cls._boundary_ is not KEEP:
raise ValueError(
'%s(%r) --> unknown values %r [%s]'
@@ -1439,11 +1434,25 @@ class Flag(Enum, boundary=CONFORM):
pseudo_member = cls._member_type_.__new__(cls, value)
if not hasattr(pseudo_member, '_value_'):
pseudo_member._value_ = value
- if member_value:
- pseudo_member._name_ = '|'.join([
- m._name_ for m in cls._iter_member_(member_value)
- ])
- if unknown:
+ if member_value or aliases:
+ members = []
+ combined_value = 0
+ for m in cls._iter_member_(member_value):
+ members.append(m)
+ combined_value |= m._value_
+ if aliases:
+ value = member_value | aliases
+ for n, pm in cls._member_map_.items():
+ if pm not in members and pm._value_ and pm._value_ & value == pm._value_:
+ members.append(pm)
+ combined_value |= pm._value_
+ unknown = value ^ combined_value
+ pseudo_member._name_ = '|'.join([m._name_ for m in members])
+ if not combined_value:
+ pseudo_member._name_ = None
+ elif unknown and cls._boundary_ is STRICT:
+ raise ValueError('%r: no members with value %r' % (cls, unknown))
+ elif unknown:
pseudo_member._name_ += '|%s' % cls._numeric_repr_(unknown)
else:
pseudo_member._name_ = None
@@ -1675,6 +1684,7 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None):
body['_boundary_'] = boundary or etype._boundary_
body['_flag_mask_'] = None
body['_all_bits_'] = None
+ body['_singles_mask_'] = None
body['_inverted_'] = None
body['__or__'] = Flag.__or__
body['__xor__'] = Flag.__xor__
@@ -1750,7 +1760,8 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None):
else:
multi_bits |= value
gnv_last_values.append(value)
- enum_class._flag_mask_ = single_bits
+ enum_class._flag_mask_ = single_bits | multi_bits
+ enum_class._singles_mask_ = single_bits
enum_class._all_bits_ = 2 ** ((single_bits|multi_bits).bit_length()) - 1
# set correct __iter__
member_list = [m._value_ for m in enum_class]