diff options
-rw-r--r-- | Doc/library/enum.rst | 22 | ||||
-rw-r--r-- | Lib/enum.py | 152 | ||||
-rw-r--r-- | Lib/test/test_enum.py | 105 |
3 files changed, 193 insertions, 86 deletions
diff --git a/Doc/library/enum.rst b/Doc/library/enum.rst index eb8b94b..87aa8b1 100644 --- a/Doc/library/enum.rst +++ b/Doc/library/enum.rst @@ -674,6 +674,8 @@ while combinations of flags won't:: ... green = auto() ... white = red | blue | green ... + >>> Color.white + <Color.white: 7> Giving a name to the "no flags set" condition does not change its boolean value:: @@ -1068,3 +1070,23 @@ but not of the class:: >>> dir(Planet.EARTH) ['__class__', '__doc__', '__module__', 'name', 'surface_gravity', 'value'] + +Combining members of ``Flag`` +""""""""""""""""""""""""""""" + +If a combination of Flag members is not named, the :func:`repr` will include +all named flags and all named combinations of flags that are in the value:: + + >>> class Color(Flag): + ... red = auto() + ... green = auto() + ... blue = auto() + ... magenta = red | blue + ... yellow = red | green + ... cyan = green | blue + ... + >>> Color(3) # named combination + <Color.yellow: 3> + >>> Color(7) # not named combination + <Color.cyan|magenta|blue|yellow|green|red: 7> + diff --git a/Lib/enum.py b/Lib/enum.py index d830320..4beb187 100644 --- a/Lib/enum.py +++ b/Lib/enum.py @@ -1,7 +1,7 @@ import sys from types import MappingProxyType, DynamicClassAttribute from functools import reduce -from operator import or_ as _or_ +from operator import or_ as _or_, and_ as _and_, xor, neg # try _collections first to reduce startup cost try: @@ -47,11 +47,12 @@ def _make_class_unpicklable(cls): cls.__reduce_ex__ = _break_on_call_reduce cls.__module__ = '<unknown>' +_auto_null = object() class auto: """ Instances are replaced with an appropriate value in Enum class suites. """ - pass + value = _auto_null class _EnumDict(dict): @@ -77,7 +78,7 @@ class _EnumDict(dict): """ if _is_sunder(key): if key not in ( - '_order_', '_create_pseudo_member_', '_decompose_', + '_order_', '_create_pseudo_member_', '_generate_next_value_', '_missing_', ): raise ValueError('_names_ are reserved for future Enum use') @@ -94,7 +95,9 @@ class _EnumDict(dict): # enum overwriting a descriptor? raise TypeError('%r already defined as: %r' % (key, self[key])) if isinstance(value, auto): - value = self._generate_next_value(key, 1, len(self._member_names), self._last_values[:]) + if value.value == _auto_null: + value.value = self._generate_next_value(key, 1, len(self._member_names), self._last_values[:]) + value = value.value self._member_names.append(key) self._last_values.append(value) super().__setitem__(key, value) @@ -658,7 +661,7 @@ class Flag(Enum): try: high_bit = _high_bit(last_value) break - except TypeError: + except Exception: raise TypeError('Invalid Flag value: %r' % last_value) from None return 2 ** (high_bit+1) @@ -668,61 +671,38 @@ class Flag(Enum): if value < 0: value = ~value possible_member = cls._create_pseudo_member_(value) - for member in possible_member._decompose_(): - if member._name_ is None and member._value_ != 0: - raise ValueError('%r is not a valid %s' % (original_value, cls.__name__)) if original_value < 0: possible_member = ~possible_member return possible_member @classmethod def _create_pseudo_member_(cls, value): + """ + Create a composite member iff value contains only members. + """ pseudo_member = cls._value2member_map_.get(value, None) if pseudo_member is None: - # construct a non-singleton enum pseudo-member + # verify all bits are accounted for + _, extra_flags = _decompose(cls, value) + if extra_flags: + raise ValueError("%r is not a valid %s" % (value, cls.__name__)) + # construct a singleton enum pseudo-member pseudo_member = object.__new__(cls) pseudo_member._name_ = None pseudo_member._value_ = value cls._value2member_map_[value] = pseudo_member return pseudo_member - def _decompose_(self): - """Extract all members from the value.""" - value = self._value_ - members = [] - cls = self.__class__ - for member in sorted(cls, key=lambda m: m._value_, reverse=True): - while _high_bit(value) > _high_bit(member._value_): - unknown = self._create_pseudo_member_(2 ** _high_bit(value)) - members.append(unknown) - value &= ~unknown._value_ - if ( - (value & member._value_ == member._value_) - and (member._value_ or not members) - ): - value &= ~member._value_ - members.append(member) - if not members or value: - members.append(self._create_pseudo_member_(value)) - members = list(members) - return members - def __contains__(self, other): if not isinstance(other, self.__class__): return NotImplemented return other._value_ & self._value_ == other._value_ - def __iter__(self): - if self.value == 0: - return iter([]) - else: - return iter(self._decompose_()) - def __repr__(self): cls = self.__class__ if self._name_ is not None: return '<%s.%s: %r>' % (cls.__name__, self._name_, self._value_) - members = self._decompose_() + members, uncovered = _decompose(cls, self._value_) return '<%s.%s: %r>' % ( cls.__name__, '|'.join([str(m._name_ or m._value_) for m in members]), @@ -733,7 +713,7 @@ class Flag(Enum): cls = self.__class__ if self._name_ is not None: return '%s.%s' % (cls.__name__, self._name_) - members = self._decompose_() + members, uncovered = _decompose(cls, self._value_) if len(members) == 1 and members[0]._name_ is None: return '%s.%r' % (cls.__name__, members[0]._value_) else: @@ -761,8 +741,11 @@ class Flag(Enum): return self.__class__(self._value_ ^ other._value_) def __invert__(self): - members = self._decompose_() - inverted_members = [m for m in self.__class__ if m not in members and not m._value_ & self._value_] + members, uncovered = _decompose(self.__class__, self._value_) + inverted_members = [ + m for m in self.__class__ + if m not in members and not m._value_ & self._value_ + ] inverted = reduce(_or_, inverted_members, self.__class__(0)) return self.__class__(inverted) @@ -771,25 +754,45 @@ class IntFlag(int, Flag): """Support for integer-based Flags""" @classmethod + def _missing_(cls, value): + if not isinstance(value, int): + raise ValueError("%r is not a valid %s" % (value, cls.__name__)) + new_member = cls._create_pseudo_member_(value) + return new_member + + @classmethod def _create_pseudo_member_(cls, value): pseudo_member = cls._value2member_map_.get(value, None) if pseudo_member is None: - # construct a non-singleton enum pseudo-member - pseudo_member = int.__new__(cls, value) - pseudo_member._name_ = None - pseudo_member._value_ = value - cls._value2member_map_[value] = pseudo_member + need_to_create = [value] + # get unaccounted for bits + _, extra_flags = _decompose(cls, value) + # timer = 10 + while extra_flags: + # timer -= 1 + bit = _high_bit(extra_flags) + flag_value = 2 ** bit + if (flag_value not in cls._value2member_map_ and + flag_value not in need_to_create + ): + need_to_create.append(flag_value) + if extra_flags == -flag_value: + extra_flags = 0 + else: + extra_flags ^= flag_value + for value in reversed(need_to_create): + # construct singleton pseudo-members + pseudo_member = int.__new__(cls, value) + pseudo_member._name_ = None + pseudo_member._value_ = value + cls._value2member_map_[value] = pseudo_member return pseudo_member - @classmethod - def _missing_(cls, value): - possible_member = cls._create_pseudo_member_(value) - return possible_member - def __or__(self, other): if not isinstance(other, (self.__class__, int)): return NotImplemented - return self.__class__(self._value_ | self.__class__(other)._value_) + result = self.__class__(self._value_ | self.__class__(other)._value_) + return result def __and__(self, other): if not isinstance(other, (self.__class__, int)): @@ -806,17 +809,13 @@ class IntFlag(int, Flag): __rxor__ = __xor__ def __invert__(self): - # members = self._decompose_() - # inverted_members = [m for m in self.__class__ if m not in members and not m._value_ & self._value_] - # inverted = reduce(_or_, inverted_members, self.__class__(0)) - return self.__class__(~self._value_) - - + result = self.__class__(~self._value_) + return result def _high_bit(value): """returns index of highest bit, or -1 if value is zero or negative""" - return value.bit_length() - 1 if value > 0 else -1 + return value.bit_length() - 1 def unique(enumeration): """Class decorator for enumerations ensuring unique member values.""" @@ -830,3 +829,40 @@ def unique(enumeration): raise ValueError('duplicate values found in %r: %s' % (enumeration, alias_details)) return enumeration + +def _decompose(flag, value): + """Extract all members from the value.""" + # _decompose is only called if the value is not named + not_covered = value + negative = value < 0 + if negative: + # only check for named flags + flags_to_check = [ + (m, v) + for v, m in flag._value2member_map_.items() + if m.name is not None + ] + else: + # check for named flags and powers-of-two flags + flags_to_check = [ + (m, v) + for v, m in flag._value2member_map_.items() + if m.name is not None or _power_of_two(v) + ] + members = [] + for member, member_value in flags_to_check: + if member_value and member_value & value == member_value: + members.append(member) + not_covered &= ~member_value + if not members and value in flag._value2member_map_: + members.append(flag._value2member_map_[value]) + members.sort(key=lambda m: m._value_, reverse=True) + if len(members) > 1 and members[0].value == value: + # we have the breakdown, don't need the value member itself + members.pop(0) + return members, not_covered + +def _power_of_two(value): + if value < 1: + return False + return value == 2 ** _high_bit(value) diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py index 153bfb4..2b3bfea 100644 --- a/Lib/test/test_enum.py +++ b/Lib/test/test_enum.py @@ -1634,6 +1634,13 @@ class TestEnum(unittest.TestCase): self.assertEqual(Color.blue.value, 2) self.assertEqual(Color.green.value, 3) + def test_duplicate_auto(self): + class Dupes(Enum): + first = primero = auto() + second = auto() + third = auto() + self.assertEqual([Dupes.first, Dupes.second, Dupes.third], list(Dupes)) + class TestOrder(unittest.TestCase): @@ -1731,7 +1738,7 @@ class TestFlag(unittest.TestCase): self.assertEqual(str(Open.AC), 'Open.AC') self.assertEqual(str(Open.RO | Open.CE), 'Open.CE') self.assertEqual(str(Open.WO | Open.CE), 'Open.CE|WO') - self.assertEqual(str(~Open.RO), 'Open.CE|AC') + self.assertEqual(str(~Open.RO), 'Open.CE|AC|RW|WO') self.assertEqual(str(~Open.WO), 'Open.CE|RW') self.assertEqual(str(~Open.AC), 'Open.CE') self.assertEqual(str(~(Open.RO | Open.CE)), 'Open.AC') @@ -1758,7 +1765,7 @@ class TestFlag(unittest.TestCase): self.assertEqual(repr(Open.AC), '<Open.AC: 3>') self.assertEqual(repr(Open.RO | Open.CE), '<Open.CE: 524288>') self.assertEqual(repr(Open.WO | Open.CE), '<Open.CE|WO: 524289>') - self.assertEqual(repr(~Open.RO), '<Open.CE|AC: 524291>') + self.assertEqual(repr(~Open.RO), '<Open.CE|AC|RW|WO: 524291>') self.assertEqual(repr(~Open.WO), '<Open.CE|RW: 524290>') self.assertEqual(repr(~Open.AC), '<Open.CE: 524288>') self.assertEqual(repr(~(Open.RO | Open.CE)), '<Open.AC: 3>') @@ -1949,6 +1956,33 @@ class TestFlag(unittest.TestCase): red = 'not an int' blue = auto() + def test_cascading_failure(self): + class Bizarre(Flag): + c = 3 + d = 4 + f = 6 + # Bizarre.c | Bizarre.d + self.assertRaisesRegex(ValueError, "5 is not a valid Bizarre", Bizarre, 5) + self.assertRaisesRegex(ValueError, "5 is not a valid Bizarre", Bizarre, 5) + self.assertRaisesRegex(ValueError, "2 is not a valid Bizarre", Bizarre, 2) + self.assertRaisesRegex(ValueError, "2 is not a valid Bizarre", Bizarre, 2) + self.assertRaisesRegex(ValueError, "1 is not a valid Bizarre", Bizarre, 1) + self.assertRaisesRegex(ValueError, "1 is not a valid Bizarre", Bizarre, 1) + + def test_duplicate_auto(self): + class Dupes(Enum): + first = primero = auto() + second = auto() + third = auto() + self.assertEqual([Dupes.first, Dupes.second, Dupes.third], list(Dupes)) + + def test_bizarre(self): + class Bizarre(Flag): + b = 3 + c = 4 + d = 6 + self.assertEqual(repr(Bizarre(7)), '<Bizarre.d|c|b: 7>') + class TestIntFlag(unittest.TestCase): """Tests of the IntFlags.""" @@ -1965,6 +1999,21 @@ class TestIntFlag(unittest.TestCase): AC = 3 CE = 1<<19 + def test_type(self): + Perm = self.Perm + Open = self.Open + for f in Perm: + self.assertTrue(isinstance(f, Perm)) + self.assertEqual(f, f.value) + self.assertTrue(isinstance(Perm.W | Perm.X, Perm)) + self.assertEqual(Perm.W | Perm.X, 3) + for f in Open: + self.assertTrue(isinstance(f, Open)) + self.assertEqual(f, f.value) + self.assertTrue(isinstance(Open.WO | Open.RW, Open)) + self.assertEqual(Open.WO | Open.RW, 3) + + def test_str(self): Perm = self.Perm self.assertEqual(str(Perm.R), 'Perm.R') @@ -1975,14 +2024,14 @@ class TestIntFlag(unittest.TestCase): self.assertEqual(str(Perm.R | 8), 'Perm.8|R') self.assertEqual(str(Perm(0)), 'Perm.0') self.assertEqual(str(Perm(8)), 'Perm.8') - self.assertEqual(str(~Perm.R), 'Perm.W|X|-8') - self.assertEqual(str(~Perm.W), 'Perm.R|X|-8') - self.assertEqual(str(~Perm.X), 'Perm.R|W|-8') - self.assertEqual(str(~(Perm.R | Perm.W)), 'Perm.X|-8') + self.assertEqual(str(~Perm.R), 'Perm.W|X') + self.assertEqual(str(~Perm.W), 'Perm.R|X') + self.assertEqual(str(~Perm.X), 'Perm.R|W') + self.assertEqual(str(~(Perm.R | Perm.W)), 'Perm.X') self.assertEqual(str(~(Perm.R | Perm.W | Perm.X)), 'Perm.-8') - self.assertEqual(str(~(Perm.R | 8)), 'Perm.W|X|-16') - self.assertEqual(str(Perm(~0)), 'Perm.R|W|X|-8') - self.assertEqual(str(Perm(~8)), 'Perm.R|W|X|-16') + self.assertEqual(str(~(Perm.R | 8)), 'Perm.W|X') + self.assertEqual(str(Perm(~0)), 'Perm.R|W|X') + self.assertEqual(str(Perm(~8)), 'Perm.R|W|X') Open = self.Open self.assertEqual(str(Open.RO), 'Open.RO') @@ -1991,12 +2040,12 @@ class TestIntFlag(unittest.TestCase): self.assertEqual(str(Open.RO | Open.CE), 'Open.CE') self.assertEqual(str(Open.WO | Open.CE), 'Open.CE|WO') self.assertEqual(str(Open(4)), 'Open.4') - self.assertEqual(str(~Open.RO), 'Open.CE|AC|-524292') - self.assertEqual(str(~Open.WO), 'Open.CE|RW|-524292') - self.assertEqual(str(~Open.AC), 'Open.CE|-524292') - self.assertEqual(str(~(Open.RO | Open.CE)), 'Open.AC|-524292') - self.assertEqual(str(~(Open.WO | Open.CE)), 'Open.RW|-524292') - self.assertEqual(str(Open(~4)), 'Open.CE|AC|-524296') + self.assertEqual(str(~Open.RO), 'Open.CE|AC|RW|WO') + self.assertEqual(str(~Open.WO), 'Open.CE|RW') + self.assertEqual(str(~Open.AC), 'Open.CE') + self.assertEqual(str(~(Open.RO | Open.CE)), 'Open.AC|RW|WO') + self.assertEqual(str(~(Open.WO | Open.CE)), 'Open.RW') + self.assertEqual(str(Open(~4)), 'Open.CE|AC|RW|WO') def test_repr(self): Perm = self.Perm @@ -2008,14 +2057,14 @@ class TestIntFlag(unittest.TestCase): self.assertEqual(repr(Perm.R | 8), '<Perm.8|R: 12>') self.assertEqual(repr(Perm(0)), '<Perm.0: 0>') self.assertEqual(repr(Perm(8)), '<Perm.8: 8>') - self.assertEqual(repr(~Perm.R), '<Perm.W|X|-8: -5>') - self.assertEqual(repr(~Perm.W), '<Perm.R|X|-8: -3>') - self.assertEqual(repr(~Perm.X), '<Perm.R|W|-8: -2>') - self.assertEqual(repr(~(Perm.R | Perm.W)), '<Perm.X|-8: -7>') + self.assertEqual(repr(~Perm.R), '<Perm.W|X: -5>') + self.assertEqual(repr(~Perm.W), '<Perm.R|X: -3>') + self.assertEqual(repr(~Perm.X), '<Perm.R|W: -2>') + self.assertEqual(repr(~(Perm.R | Perm.W)), '<Perm.X: -7>') self.assertEqual(repr(~(Perm.R | Perm.W | Perm.X)), '<Perm.-8: -8>') - self.assertEqual(repr(~(Perm.R | 8)), '<Perm.W|X|-16: -13>') - self.assertEqual(repr(Perm(~0)), '<Perm.R|W|X|-8: -1>') - self.assertEqual(repr(Perm(~8)), '<Perm.R|W|X|-16: -9>') + self.assertEqual(repr(~(Perm.R | 8)), '<Perm.W|X: -13>') + self.assertEqual(repr(Perm(~0)), '<Perm.R|W|X: -1>') + self.assertEqual(repr(Perm(~8)), '<Perm.R|W|X: -9>') Open = self.Open self.assertEqual(repr(Open.RO), '<Open.RO: 0>') @@ -2024,12 +2073,12 @@ class TestIntFlag(unittest.TestCase): self.assertEqual(repr(Open.RO | Open.CE), '<Open.CE: 524288>') self.assertEqual(repr(Open.WO | Open.CE), '<Open.CE|WO: 524289>') self.assertEqual(repr(Open(4)), '<Open.4: 4>') - self.assertEqual(repr(~Open.RO), '<Open.CE|AC|-524292: -1>') - self.assertEqual(repr(~Open.WO), '<Open.CE|RW|-524292: -2>') - self.assertEqual(repr(~Open.AC), '<Open.CE|-524292: -4>') - self.assertEqual(repr(~(Open.RO | Open.CE)), '<Open.AC|-524292: -524289>') - self.assertEqual(repr(~(Open.WO | Open.CE)), '<Open.RW|-524292: -524290>') - self.assertEqual(repr(Open(~4)), '<Open.CE|AC|-524296: -5>') + self.assertEqual(repr(~Open.RO), '<Open.CE|AC|RW|WO: -1>') + self.assertEqual(repr(~Open.WO), '<Open.CE|RW: -2>') + self.assertEqual(repr(~Open.AC), '<Open.CE: -4>') + self.assertEqual(repr(~(Open.RO | Open.CE)), '<Open.AC|RW|WO: -524289>') + self.assertEqual(repr(~(Open.WO | Open.CE)), '<Open.RW: -524290>') + self.assertEqual(repr(Open(~4)), '<Open.CE|AC|RW|WO: -5>') def test_or(self): Perm = self.Perm |