diff options
Diffstat (limited to 'Lib/enum.py')
-rw-r--r-- | Lib/enum.py | 324 |
1 files changed, 309 insertions, 15 deletions
diff --git a/Lib/enum.py b/Lib/enum.py index b8787d1..4beb187 100644 --- a/Lib/enum.py +++ b/Lib/enum.py @@ -1,8 +1,20 @@ import sys -from collections import OrderedDict from types import MappingProxyType, DynamicClassAttribute +from functools import reduce +from operator import or_ as _or_, and_ as _and_, xor, neg -__all__ = ['Enum', 'IntEnum', 'unique'] +# try _collections first to reduce startup cost +try: + from _collections import OrderedDict +except ImportError: + from collections import OrderedDict + + +__all__ = [ + 'EnumMeta', + 'Enum', 'IntEnum', 'Flag', 'IntFlag', + 'auto', 'unique', + ] def _is_descriptor(obj): @@ -28,7 +40,6 @@ def _is_sunder(name): name[-2:-1] != '_' and len(name) > 2) - def _make_class_unpicklable(cls): """Make the given class un-picklable.""" def _break_on_call_reduce(self, proto): @@ -36,6 +47,13 @@ def _make_class_unpicklable(cls): cls.__reduce_ex__ = _break_on_call_reduce cls.__module__ = '<unknown>' +_auto_null = object() +class auto: + """ + Instances are replaced with an appropriate value in Enum class suites. + """ + value = _auto_null + class _EnumDict(dict): """Track enum member order and ensure member names are not reused. @@ -47,6 +65,7 @@ class _EnumDict(dict): def __init__(self): super().__init__() self._member_names = [] + self._last_values = [] def __setitem__(self, key, value): """Changes anything not dundered or not a descriptor. @@ -58,21 +77,32 @@ class _EnumDict(dict): """ if _is_sunder(key): - raise ValueError('_names_ are reserved for future Enum use') + if key not in ( + '_order_', '_create_pseudo_member_', + '_generate_next_value_', '_missing_', + ): + raise ValueError('_names_ are reserved for future Enum use') + if key == '_generate_next_value_': + setattr(self, '_generate_next_value', value) elif _is_dunder(key): - pass + if key == '__order__': + key = '_order_' elif key in self._member_names: # descriptor overwriting an enum? raise TypeError('Attempted to reuse key: %r' % key) elif not _is_descriptor(value): if key in self: # enum overwriting a descriptor? - raise TypeError('Key already defined as: %r' % self[key]) + raise TypeError('%r already defined as: %r' % (key, self[key])) + if isinstance(value, auto): + if value.value == _auto_null: + value.value = self._generate_next_value(key, 1, len(self._member_names), self._last_values[:]) + value = value.value self._member_names.append(key) + self._last_values.append(value) super().__setitem__(key, value) - # Dummy value for Enum as EnumMeta explicitly checks for it, but of course # until EnumMeta finishes running the first time the Enum class doesn't exist. # This is also why there are checks in EnumMeta like `if Enum is not None` @@ -83,7 +113,13 @@ class EnumMeta(type): """Metaclass for Enum""" @classmethod def __prepare__(metacls, cls, bases): - return _EnumDict() + # create the namespace dict + enum_dict = _EnumDict() + # inherit previous flags and _generate_next_value_ function + member_type, first_enum = metacls._get_mixins_(bases) + if first_enum is not None: + enum_dict['_generate_next_value_'] = getattr(first_enum, '_generate_next_value_', None) + return enum_dict def __new__(metacls, cls, bases, classdict): # an Enum class is final once enumeration items have been defined; it @@ -96,12 +132,15 @@ class EnumMeta(type): # save enum items into separate mapping so they don't get baked into # the new class - members = {k: classdict[k] for k in classdict._member_names} + enum_members = {k: classdict[k] for k in classdict._member_names} for name in classdict._member_names: del classdict[name] + # adjust the sunders + _order_ = classdict.pop('_order_', None) + # check for illegal enum names (any others?) - invalid_names = set(members) & {'mro', } + invalid_names = set(enum_members) & {'mro', } if invalid_names: raise ValueError('Invalid enum member name: {0}'.format( ','.join(invalid_names))) @@ -145,7 +184,7 @@ class EnumMeta(type): # a custom __new__ is doing something funky with the values -- such as # auto-numbering ;) for member_name in classdict._member_names: - value = members[member_name] + value = enum_members[member_name] if not isinstance(value, tuple): args = (value, ) else: @@ -159,7 +198,10 @@ class EnumMeta(type): else: enum_member = __new__(enum_class, *args) if not hasattr(enum_member, '_value_'): - enum_member._value_ = member_type(*args) + if member_type is object: + enum_member._value_ = value + else: + enum_member._value_ = member_type(*args) value = enum_member._value_ enum_member._name_ = member_name enum_member.__objclass__ = enum_class @@ -204,6 +246,14 @@ class EnumMeta(type): if save_new: enum_class.__new_member__ = __new__ enum_class.__new__ = Enum.__new__ + + # py3 support for definition order (helps keep py2/py3 code in sync) + if _order_ is not None: + if isinstance(_order_, str): + _order_ = _order_.replace(',', ' ').split() + if _order_ != enum_class._member_names_: + raise TypeError('member order does not match _order_') + return enum_class def __bool__(self): @@ -325,13 +375,19 @@ class EnumMeta(type): """ metacls = cls.__class__ bases = (cls, ) if type is None else (type, cls) + _, first_enum = cls._get_mixins_(bases) classdict = metacls.__prepare__(class_name, bases) # special processing needed for names? if isinstance(names, str): names = names.replace(',', ' ').split() if isinstance(names, (tuple, list)) and isinstance(names[0], str): - names = [(e, i) for (i, e) in enumerate(names, start)] + original_names, names = names, [] + last_values = [] + for count, name in enumerate(original_names): + value = first_enum._generate_next_value_(name, start, count, last_values[:]) + last_values.append(value) + names.append((name, value)) # Here, names is either an iterable of (name, value) or a mapping. for item in names: @@ -473,6 +529,20 @@ class Enum(metaclass=EnumMeta): for member in cls._member_map_.values(): if member._value_ == value: return member + # still not found -- try _missing_ hook + return cls._missing_(value) + + def _generate_next_value_(name, start, count, last_values): + for last_value in reversed(last_values): + try: + return last_value + 1 + except TypeError: + pass + else: + return start + + @classmethod + def _missing_(cls, value): raise ValueError("%r is not a valid %s" % (value, cls.__name__)) def __repr__(self): @@ -544,8 +614,21 @@ class Enum(metaclass=EnumMeta): source = vars(source) else: source = module_globals - members = {name: value for name, value in source.items() - if filter(name)} + # We use an OrderedDict of sorted source keys so that the + # _value2member_map is populated in the same order every time + # for a consistent reverse mapping of number to name when there + # are multiple names for the same number rather than varying + # between runs due to hash randomization of the module dictionary. + members = [ + (name, source[name]) + for name in source.keys() + if filter(name)] + try: + # sort by value + members.sort(key=lambda t: (t[1], t[0])) + except TypeError: + # unless some values aren't comparable, in which case sort by name + members.sort(key=lambda t: t[0]) cls = cls(name, members, module=module) cls.__reduce_ex__ = _reduce_ex_by_name module_globals.update(cls.__members__) @@ -560,6 +643,180 @@ class IntEnum(int, Enum): def _reduce_ex_by_name(self, proto): return self.name +class Flag(Enum): + """Support for flags""" + + def _generate_next_value_(name, start, count, last_values): + """ + Generate the next value when not given. + + name: the name of the member + start: the initital start value or None + count: the number of existing members + last_value: the last value assigned or None + """ + if not count: + return start if start is not None else 1 + for last_value in reversed(last_values): + try: + high_bit = _high_bit(last_value) + break + except Exception: + raise TypeError('Invalid Flag value: %r' % last_value) from None + return 2 ** (high_bit+1) + + @classmethod + def _missing_(cls, value): + original_value = value + if value < 0: + value = ~value + possible_member = cls._create_pseudo_member_(value) + if original_value < 0: + possible_member = ~possible_member + return possible_member + + @classmethod + def _create_pseudo_member_(cls, value): + """ + Create a composite member iff value contains only members. + """ + pseudo_member = cls._value2member_map_.get(value, None) + if pseudo_member is None: + # verify all bits are accounted for + _, extra_flags = _decompose(cls, value) + if extra_flags: + raise ValueError("%r is not a valid %s" % (value, cls.__name__)) + # construct a singleton enum pseudo-member + pseudo_member = object.__new__(cls) + pseudo_member._name_ = None + pseudo_member._value_ = value + cls._value2member_map_[value] = pseudo_member + return pseudo_member + + def __contains__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return other._value_ & self._value_ == other._value_ + + def __repr__(self): + cls = self.__class__ + if self._name_ is not None: + return '<%s.%s: %r>' % (cls.__name__, self._name_, self._value_) + members, uncovered = _decompose(cls, self._value_) + return '<%s.%s: %r>' % ( + cls.__name__, + '|'.join([str(m._name_ or m._value_) for m in members]), + self._value_, + ) + + def __str__(self): + cls = self.__class__ + if self._name_ is not None: + return '%s.%s' % (cls.__name__, self._name_) + members, uncovered = _decompose(cls, self._value_) + if len(members) == 1 and members[0]._name_ is None: + return '%s.%r' % (cls.__name__, members[0]._value_) + else: + return '%s.%s' % ( + cls.__name__, + '|'.join([str(m._name_ or m._value_) for m in members]), + ) + + def __bool__(self): + return bool(self._value_) + + def __or__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return self.__class__(self._value_ | other._value_) + + def __and__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return self.__class__(self._value_ & other._value_) + + def __xor__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return self.__class__(self._value_ ^ other._value_) + + def __invert__(self): + members, uncovered = _decompose(self.__class__, self._value_) + inverted_members = [ + m for m in self.__class__ + if m not in members and not m._value_ & self._value_ + ] + inverted = reduce(_or_, inverted_members, self.__class__(0)) + return self.__class__(inverted) + + +class IntFlag(int, Flag): + """Support for integer-based Flags""" + + @classmethod + def _missing_(cls, value): + if not isinstance(value, int): + raise ValueError("%r is not a valid %s" % (value, cls.__name__)) + new_member = cls._create_pseudo_member_(value) + return new_member + + @classmethod + def _create_pseudo_member_(cls, value): + pseudo_member = cls._value2member_map_.get(value, None) + if pseudo_member is None: + need_to_create = [value] + # get unaccounted for bits + _, extra_flags = _decompose(cls, value) + # timer = 10 + while extra_flags: + # timer -= 1 + bit = _high_bit(extra_flags) + flag_value = 2 ** bit + if (flag_value not in cls._value2member_map_ and + flag_value not in need_to_create + ): + need_to_create.append(flag_value) + if extra_flags == -flag_value: + extra_flags = 0 + else: + extra_flags ^= flag_value + for value in reversed(need_to_create): + # construct singleton pseudo-members + pseudo_member = int.__new__(cls, value) + pseudo_member._name_ = None + pseudo_member._value_ = value + cls._value2member_map_[value] = pseudo_member + return pseudo_member + + def __or__(self, other): + if not isinstance(other, (self.__class__, int)): + return NotImplemented + result = self.__class__(self._value_ | self.__class__(other)._value_) + return result + + def __and__(self, other): + if not isinstance(other, (self.__class__, int)): + return NotImplemented + return self.__class__(self._value_ & self.__class__(other)._value_) + + def __xor__(self, other): + if not isinstance(other, (self.__class__, int)): + return NotImplemented + return self.__class__(self._value_ ^ self.__class__(other)._value_) + + __ror__ = __or__ + __rand__ = __and__ + __rxor__ = __xor__ + + def __invert__(self): + result = self.__class__(~self._value_) + return result + + +def _high_bit(value): + """returns index of highest bit, or -1 if value is zero or negative""" + return value.bit_length() - 1 + def unique(enumeration): """Class decorator for enumerations ensuring unique member values.""" duplicates = [] @@ -572,3 +829,40 @@ def unique(enumeration): raise ValueError('duplicate values found in %r: %s' % (enumeration, alias_details)) return enumeration + +def _decompose(flag, value): + """Extract all members from the value.""" + # _decompose is only called if the value is not named + not_covered = value + negative = value < 0 + if negative: + # only check for named flags + flags_to_check = [ + (m, v) + for v, m in flag._value2member_map_.items() + if m.name is not None + ] + else: + # check for named flags and powers-of-two flags + flags_to_check = [ + (m, v) + for v, m in flag._value2member_map_.items() + if m.name is not None or _power_of_two(v) + ] + members = [] + for member, member_value in flags_to_check: + if member_value and member_value & value == member_value: + members.append(member) + not_covered &= ~member_value + if not members and value in flag._value2member_map_: + members.append(flag._value2member_map_[value]) + members.sort(key=lambda m: m._value_, reverse=True) + if len(members) > 1 and members[0].value == value: + # we have the breakdown, don't need the value member itself + members.pop(0) + return members, not_covered + +def _power_of_two(value): + if value < 1: + return False + return value == 2 ** _high_bit(value) |