diff options
author | Ethan Furman <ethan@stoneleaf.us> | 2016-08-31 07:12:15 (GMT) |
---|---|---|
committer | Ethan Furman <ethan@stoneleaf.us> | 2016-08-31 07:12:15 (GMT) |
commit | ee47e5cf8ad0e52e2c5291662b9b15c2ba8848ea (patch) | |
tree | 5b80118379916e520e54f6652398c2fbc26cc34b /Lib/enum.py | |
parent | bfbaa6b206abdb8b1c3861926f4334b879ec91cc (diff) | |
download | cpython-ee47e5cf8ad0e52e2c5291662b9b15c2ba8848ea.zip cpython-ee47e5cf8ad0e52e2c5291662b9b15c2ba8848ea.tar.gz cpython-ee47e5cf8ad0e52e2c5291662b9b15c2ba8848ea.tar.bz2 |
issue23591: add Flags, IntFlags, and tests
Diffstat (limited to 'Lib/enum.py')
-rw-r--r-- | Lib/enum.py | 227 |
1 files changed, 217 insertions, 10 deletions
diff --git a/Lib/enum.py b/Lib/enum.py index e7889a8..e89c17d 100644 --- a/Lib/enum.py +++ b/Lib/enum.py @@ -1,5 +1,7 @@ import sys from types import MappingProxyType, DynamicClassAttribute +from functools import reduce +from operator import or_ as _or_ # try _collections first to reduce startup cost try: @@ -8,7 +10,7 @@ except ImportError: from collections import OrderedDict -__all__ = ['EnumMeta', 'Enum', 'IntEnum', 'unique'] +__all__ = ['EnumMeta', 'Enum', 'IntEnum', 'Flags', 'IntFlags', 'unique'] def _is_descriptor(obj): @@ -64,7 +66,10 @@ class _EnumDict(dict): """ if _is_sunder(key): - if key not in ('_order_', ): + if key not in ( + '_order_', '_create_pseudo_member_', '_decompose_', + '_generate_next_value_', '_missing_', + ): raise ValueError('_names_ are reserved for future Enum use') elif _is_dunder(key): if key == '__order__': @@ -75,7 +80,7 @@ class _EnumDict(dict): elif not _is_descriptor(value): if key in self: # enum overwriting a descriptor? - raise TypeError('Key already defined as: %r' % self[key]) + raise TypeError('%r already defined as: %r' % (key, self[key])) self._member_names.append(key) super().__setitem__(key, value) @@ -91,9 +96,15 @@ class EnumMeta(type): """Metaclass for Enum""" @classmethod def __prepare__(metacls, cls, bases): - return _EnumDict() + # create the namespace dict + enum_dict = _EnumDict() + # inherit previous flags and _generate_next_value_ function + member_type, first_enum = metacls._get_mixins_(bases) + if first_enum is not None: + enum_dict['_generate_next_value_'] = getattr(first_enum, '_generate_next_value_', None) + return enum_dict - def __new__(metacls, cls, bases, classdict): + def __new__(metacls, cls, bases, classdict, **kwds): # an Enum class is final once enumeration items have been defined; it # cannot be mixed with other types (int, float, etc.) if it has an # inherited __new__ unless a new __new__ is defined (or the resulting @@ -104,7 +115,7 @@ class EnumMeta(type): # save enum items into separate mapping so they don't get baked into # the new class - members = {k: classdict[k] for k in classdict._member_names} + enum_members = {k: classdict[k] for k in classdict._member_names} for name in classdict._member_names: del classdict[name] @@ -112,7 +123,7 @@ class EnumMeta(type): _order_ = classdict.pop('_order_', None) # check for illegal enum names (any others?) - invalid_names = set(members) & {'mro', } + invalid_names = set(enum_members) & {'mro', } if invalid_names: raise ValueError('Invalid enum member name: {0}'.format( ','.join(invalid_names))) @@ -156,7 +167,7 @@ class EnumMeta(type): # a custom __new__ is doing something funky with the values -- such as # auto-numbering ;) for member_name in classdict._member_names: - value = members[member_name] + value = enum_members[member_name] if not isinstance(value, tuple): args = (value, ) else: @@ -170,7 +181,10 @@ class EnumMeta(type): else: enum_member = __new__(enum_class, *args) if not hasattr(enum_member, '_value_'): - enum_member._value_ = member_type(*args) + if member_type is object: + enum_member._value_ = value + else: + enum_member._value_ = member_type(*args) value = enum_member._value_ enum_member._name_ = member_name enum_member.__objclass__ = enum_class @@ -344,13 +358,18 @@ class EnumMeta(type): """ metacls = cls.__class__ bases = (cls, ) if type is None else (type, cls) + _, first_enum = cls._get_mixins_(bases) classdict = metacls.__prepare__(class_name, bases) # special processing needed for names? if isinstance(names, str): names = names.replace(',', ' ').split() if isinstance(names, (tuple, list)) and isinstance(names[0], str): - names = [(e, i) for (i, e) in enumerate(names, start)] + original_names, names = names, [] + last_value = None + for count, name in enumerate(original_names): + last_value = first_enum._generate_next_value_(name, start, count, last_value) + names.append((name, last_value)) # Here, names is either an iterable of (name, value) or a mapping. for item in names: @@ -492,6 +511,16 @@ class Enum(metaclass=EnumMeta): for member in cls._member_map_.values(): if member._value_ == value: return member + # still not found -- try _missing_ hook + return cls._missing_(value) + + @staticmethod + def _generate_next_value_(name, start, count, last_value): + if not count: + return start + return last_value + 1 + @classmethod + def _missing_(cls, value): raise ValueError("%r is not a valid %s" % (value, cls.__name__)) def __repr__(self): @@ -585,6 +614,184 @@ class IntEnum(int, Enum): def _reduce_ex_by_name(self, proto): return self.name +class Flags(Enum): + """Support for flags""" + @staticmethod + def _generate_next_value_(name, start, count, last_value): + """ + Generate the next value when not given. + + name: the name of the member + start: the initital start value or None + count: the number of existing members + last_value: the last value assigned or None + """ + if not count: + return start if start is not None else 1 + high_bit = _high_bit(last_value) + return 2 ** (high_bit+1) + + @classmethod + def _missing_(cls, value): + original_value = value + if value < 0: + value = ~value + possible_member = cls._create_pseudo_member_(value) + for member in possible_member._decompose_(): + if member._name_ is None and member._value_ != 0: + raise ValueError('%r is not a valid %s' % (original_value, cls.__name__)) + if original_value < 0: + possible_member = ~possible_member + return possible_member + + @classmethod + def _create_pseudo_member_(cls, value): + pseudo_member = cls._value2member_map_.get(value, None) + if pseudo_member is None: + # construct a non-singleton enum pseudo-member + pseudo_member = object.__new__(cls) + pseudo_member._name_ = None + pseudo_member._value_ = value + cls._value2member_map_[value] = pseudo_member + return pseudo_member + + def _decompose_(self): + """Extract all members from the value.""" + value = self._value_ + members = [] + cls = self.__class__ + for member in sorted(cls, key=lambda m: m._value_, reverse=True): + while _high_bit(value) > _high_bit(member._value_): + unknown = self._create_pseudo_member_(2 ** _high_bit(value)) + members.append(unknown) + value &= ~unknown._value_ + if ( + (value & member._value_ == member._value_) + and (member._value_ or not members) + ): + value &= ~member._value_ + members.append(member) + if not members or value: + members.append(self._create_pseudo_member_(value)) + members = list(members) + return members + + def __contains__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return other._value_ & self._value_ == other._value_ + + def __iter__(self): + if self.value == 0: + return iter([]) + else: + return iter(self._decompose_()) + + def __repr__(self): + cls = self.__class__ + if self._name_ is not None: + return '<%s.%s: %r>' % (cls.__name__, self._name_, self._value_) + members = self._decompose_() + if len(members) == 1 and members[0]._name_ is None: + return '<%s: %r>' % (cls.__name__, members[0]._value_) + else: + return '<%s.%s: %r>' % ( + cls.__name__, + '|'.join([str(m._name_ or m._value_) for m in members]), + self._value_, + ) + + def __str__(self): + cls = self.__class__ + if self._name_ is not None: + return '%s.%s' % (cls.__name__, self._name_) + members = self._decompose_() + if len(members) == 1 and members[0]._name_ is None: + return '%s.%r' % (cls.__name__, members[0]._value_) + else: + return '%s.%s' % ( + cls.__name__, + '|'.join([str(m._name_ or m._value_) for m in members]), + ) + + def __or__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return self.__class__(self._value_ | other._value_) + + def __and__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return self.__class__(self._value_ & other._value_) + + def __xor__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented + return self.__class__(self._value_ ^ other._value_) + + def __invert__(self): + members = self._decompose_() + inverted_members = [m for m in self.__class__ if m not in members and not m._value_ & self._value_] + inverted = reduce(_or_, inverted_members, self.__class__(0)) + return self.__class__(inverted) + + +class IntFlags(int, Flags): + """Support for integer-based Flags""" + + @classmethod + def _create_pseudo_member_(cls, value): + pseudo_member = cls._value2member_map_.get(value, None) + if pseudo_member is None: + # construct a non-singleton enum pseudo-member + pseudo_member = int.__new__(cls, value) + pseudo_member._name_ = None + pseudo_member._value_ = value + cls._value2member_map_[value] = pseudo_member + return pseudo_member + + @classmethod + def _missing_(cls, value): + possible_member = cls._create_pseudo_member_(value) + return possible_member + + def __or__(self, other): + if not isinstance(other, (self.__class__, int)): + return NotImplemented + return self.__class__(self._value_ | self.__class__(other)._value_) + + def __and__(self, other): + if not isinstance(other, (self.__class__, int)): + return NotImplemented + return self.__class__(self._value_ & self.__class__(other)._value_) + + def __xor__(self, other): + if not isinstance(other, (self.__class__, int)): + return NotImplemented + return self.__class__(self._value_ ^ self.__class__(other)._value_) + + __ror__ = __or__ + __rand__ = __and__ + __rxor__ = __xor__ + + def __invert__(self): + # members = self._decompose_() + # inverted_members = [m for m in self.__class__ if m not in members and not m._value_ & self._value_] + # inverted = reduce(_or_, inverted_members, self.__class__(0)) + return self.__class__(~self._value_) + + + + +def _high_bit(value): + """return the highest bit set in value""" + bit = 0 + while 'looking for the highest bit': + limit = 2 ** bit + if limit > value: + return bit - 1 + bit += 1 + def unique(enumeration): """Class decorator for enumerations ensuring unique member values.""" duplicates = [] |