diff options
Diffstat (limited to 'Lib/enum.py')
-rw-r--r-- | Lib/enum.py | 152 |
1 files changed, 94 insertions, 58 deletions
diff --git a/Lib/enum.py b/Lib/enum.py index d830320..4beb187 100644 --- a/Lib/enum.py +++ b/Lib/enum.py @@ -1,7 +1,7 @@ import sys from types import MappingProxyType, DynamicClassAttribute from functools import reduce -from operator import or_ as _or_ +from operator import or_ as _or_, and_ as _and_, xor, neg # try _collections first to reduce startup cost try: @@ -47,11 +47,12 @@ def _make_class_unpicklable(cls): cls.__reduce_ex__ = _break_on_call_reduce cls.__module__ = '<unknown>' +_auto_null = object() class auto: """ Instances are replaced with an appropriate value in Enum class suites. """ - pass + value = _auto_null class _EnumDict(dict): @@ -77,7 +78,7 @@ class _EnumDict(dict): """ if _is_sunder(key): if key not in ( - '_order_', '_create_pseudo_member_', '_decompose_', + '_order_', '_create_pseudo_member_', '_generate_next_value_', '_missing_', ): raise ValueError('_names_ are reserved for future Enum use') @@ -94,7 +95,9 @@ class _EnumDict(dict): # enum overwriting a descriptor? raise TypeError('%r already defined as: %r' % (key, self[key])) if isinstance(value, auto): - value = self._generate_next_value(key, 1, len(self._member_names), self._last_values[:]) + if value.value == _auto_null: + value.value = self._generate_next_value(key, 1, len(self._member_names), self._last_values[:]) + value = value.value self._member_names.append(key) self._last_values.append(value) super().__setitem__(key, value) @@ -658,7 +661,7 @@ class Flag(Enum): try: high_bit = _high_bit(last_value) break - except TypeError: + except Exception: raise TypeError('Invalid Flag value: %r' % last_value) from None return 2 ** (high_bit+1) @@ -668,61 +671,38 @@ class Flag(Enum): if value < 0: value = ~value possible_member = cls._create_pseudo_member_(value) - for member in possible_member._decompose_(): - if member._name_ is None and member._value_ != 0: - raise ValueError('%r is not a valid %s' % (original_value, cls.__name__)) if original_value < 0: possible_member = ~possible_member return possible_member @classmethod def _create_pseudo_member_(cls, value): + """ + Create a composite member iff value contains only members. + """ pseudo_member = cls._value2member_map_.get(value, None) if pseudo_member is None: - # construct a non-singleton enum pseudo-member + # verify all bits are accounted for + _, extra_flags = _decompose(cls, value) + if extra_flags: + raise ValueError("%r is not a valid %s" % (value, cls.__name__)) + # construct a singleton enum pseudo-member pseudo_member = object.__new__(cls) pseudo_member._name_ = None pseudo_member._value_ = value cls._value2member_map_[value] = pseudo_member return pseudo_member - def _decompose_(self): - """Extract all members from the value.""" - value = self._value_ - members = [] - cls = self.__class__ - for member in sorted(cls, key=lambda m: m._value_, reverse=True): - while _high_bit(value) > _high_bit(member._value_): - unknown = self._create_pseudo_member_(2 ** _high_bit(value)) - members.append(unknown) - value &= ~unknown._value_ - if ( - (value & member._value_ == member._value_) - and (member._value_ or not members) - ): - value &= ~member._value_ - members.append(member) - if not members or value: - members.append(self._create_pseudo_member_(value)) - members = list(members) - return members - def __contains__(self, other): if not isinstance(other, self.__class__): return NotImplemented return other._value_ & self._value_ == other._value_ - def __iter__(self): - if self.value == 0: - return iter([]) - else: - return iter(self._decompose_()) - def __repr__(self): cls = self.__class__ if self._name_ is not None: return '<%s.%s: %r>' % (cls.__name__, self._name_, self._value_) - members = self._decompose_() + members, uncovered = _decompose(cls, self._value_) return '<%s.%s: %r>' % ( cls.__name__, '|'.join([str(m._name_ or m._value_) for m in members]), @@ -733,7 +713,7 @@ class Flag(Enum): cls = self.__class__ if self._name_ is not None: return '%s.%s' % (cls.__name__, self._name_) - members = self._decompose_() + members, uncovered = _decompose(cls, self._value_) if len(members) == 1 and members[0]._name_ is None: return '%s.%r' % (cls.__name__, members[0]._value_) else: @@ -761,8 +741,11 @@ class Flag(Enum): return self.__class__(self._value_ ^ other._value_) def __invert__(self): - members = self._decompose_() - inverted_members = [m for m in self.__class__ if m not in members and not m._value_ & self._value_] + members, uncovered = _decompose(self.__class__, self._value_) + inverted_members = [ + m for m in self.__class__ + if m not in members and not m._value_ & self._value_ + ] inverted = reduce(_or_, inverted_members, self.__class__(0)) return self.__class__(inverted) @@ -771,25 +754,45 @@ class IntFlag(int, Flag): """Support for integer-based Flags""" @classmethod + def _missing_(cls, value): + if not isinstance(value, int): + raise ValueError("%r is not a valid %s" % (value, cls.__name__)) + new_member = cls._create_pseudo_member_(value) + return new_member + + @classmethod def _create_pseudo_member_(cls, value): pseudo_member = cls._value2member_map_.get(value, None) if pseudo_member is None: - # construct a non-singleton enum pseudo-member - pseudo_member = int.__new__(cls, value) - pseudo_member._name_ = None - pseudo_member._value_ = value - cls._value2member_map_[value] = pseudo_member + need_to_create = [value] + # get unaccounted for bits + _, extra_flags = _decompose(cls, value) + # timer = 10 + while extra_flags: + # timer -= 1 + bit = _high_bit(extra_flags) + flag_value = 2 ** bit + if (flag_value not in cls._value2member_map_ and + flag_value not in need_to_create + ): + need_to_create.append(flag_value) + if extra_flags == -flag_value: + extra_flags = 0 + else: + extra_flags ^= flag_value + for value in reversed(need_to_create): + # construct singleton pseudo-members + pseudo_member = int.__new__(cls, value) + pseudo_member._name_ = None + pseudo_member._value_ = value + cls._value2member_map_[value] = pseudo_member return pseudo_member - @classmethod - def _missing_(cls, value): - possible_member = cls._create_pseudo_member_(value) - return possible_member - def __or__(self, other): if not isinstance(other, (self.__class__, int)): return NotImplemented - return self.__class__(self._value_ | self.__class__(other)._value_) + result = self.__class__(self._value_ | self.__class__(other)._value_) + return result def __and__(self, other): if not isinstance(other, (self.__class__, int)): @@ -806,17 +809,13 @@ class IntFlag(int, Flag): __rxor__ = __xor__ def __invert__(self): - # members = self._decompose_() - # inverted_members = [m for m in self.__class__ if m not in members and not m._value_ & self._value_] - # inverted = reduce(_or_, inverted_members, self.__class__(0)) - return self.__class__(~self._value_) - - + result = self.__class__(~self._value_) + return result def _high_bit(value): """returns index of highest bit, or -1 if value is zero or negative""" - return value.bit_length() - 1 if value > 0 else -1 + return value.bit_length() - 1 def unique(enumeration): """Class decorator for enumerations ensuring unique member values.""" @@ -830,3 +829,40 @@ def unique(enumeration): raise ValueError('duplicate values found in %r: %s' % (enumeration, alias_details)) return enumeration + +def _decompose(flag, value): + """Extract all members from the value.""" + # _decompose is only called if the value is not named + not_covered = value + negative = value < 0 + if negative: + # only check for named flags + flags_to_check = [ + (m, v) + for v, m in flag._value2member_map_.items() + if m.name is not None + ] + else: + # check for named flags and powers-of-two flags + flags_to_check = [ + (m, v) + for v, m in flag._value2member_map_.items() + if m.name is not None or _power_of_two(v) + ] + members = [] + for member, member_value in flags_to_check: + if member_value and member_value & value == member_value: + members.append(member) + not_covered &= ~member_value + if not members and value in flag._value2member_map_: + members.append(flag._value2member_map_[value]) + members.sort(key=lambda m: m._value_, reverse=True) + if len(members) > 1 and members[0].value == value: + # we have the breakdown, don't need the value member itself + members.pop(0) + return members, not_covered + +def _power_of_two(value): + if value < 1: + return False + return value == 2 ** _high_bit(value) |