diff options
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/enum.py | 107 | ||||
-rw-r--r-- | Lib/test/test_enum.py | 144 |
2 files changed, 225 insertions, 26 deletions
diff --git a/Lib/enum.py b/Lib/enum.py index 01f4310..f74cc8c 100644 --- a/Lib/enum.py +++ b/Lib/enum.py @@ -6,10 +6,10 @@ from builtins import property as _bltin_property, bin as _bltin_bin __all__ = [ 'EnumType', 'EnumMeta', 'Enum', 'IntEnum', 'StrEnum', 'Flag', 'IntFlag', - 'auto', 'unique', - 'property', + 'auto', 'unique', 'property', 'verify', 'FlagBoundary', 'STRICT', 'CONFORM', 'EJECT', 'KEEP', 'global_flag_repr', 'global_enum_repr', 'global_enum', + 'EnumCheck', 'CONTINUOUS', 'NAMED_FLAGS', 'UNIQUE', ] @@ -89,6 +89,9 @@ def _make_class_unpicklable(obj): setattr(obj, '__module__', '<unknown>') def _iter_bits_lsb(num): + # num must be an integer + if isinstance(num, Enum): + num = num.value while num: b = num & (~num + 1) yield b @@ -538,13 +541,6 @@ class EnumType(type): else: # multi-bit flags are considered aliases multi_bit_total |= flag_value - if enum_class._boundary_ is not KEEP: - missed = list(_iter_bits_lsb(multi_bit_total & ~single_bit_total)) - if missed: - raise TypeError( - 'invalid Flag %r -- missing values: %s' - % (cls, ', '.join((str(i) for i in missed))) - ) enum_class._flag_mask_ = single_bit_total # # set correct __iter__ @@ -688,7 +684,10 @@ class EnumType(type): return MappingProxyType(cls._member_map_) def __repr__(cls): - return "<enum %r>" % cls.__name__ + if Flag is not None and issubclass(cls, Flag): + return "<flag %r>" % cls.__name__ + else: + return "<enum %r>" % cls.__name__ def __reversed__(cls): """ @@ -1303,7 +1302,8 @@ class Flag(Enum, boundary=STRICT): else: # calculate flags not in this member self._inverted_ = self.__class__(self._flag_mask_ ^ self._value_) - self._inverted_._inverted_ = self + if isinstance(self._inverted_, self.__class__): + self._inverted_._inverted_ = self return self._inverted_ @@ -1561,6 +1561,91 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None): return enum_class return convert_class +@_simple_enum(StrEnum) +class EnumCheck: + """ + various conditions to check an enumeration for + """ + CONTINUOUS = "no skipped integer values" + NAMED_FLAGS = "multi-flag aliases may not contain unnamed flags" + UNIQUE = "one name per value" +CONTINUOUS, NAMED_FLAGS, UNIQUE = EnumCheck + + +class verify: + """ + Check an enumeration for various constraints. (see EnumCheck) + """ + def __init__(self, *checks): + self.checks = checks + def __call__(self, enumeration): + checks = self.checks + cls_name = enumeration.__name__ + if Flag is not None and issubclass(enumeration, Flag): + enum_type = 'flag' + elif issubclass(enumeration, Enum): + enum_type = 'enum' + else: + raise TypeError("the 'verify' decorator only works with Enum and Flag") + for check in checks: + if check is UNIQUE: + # check for duplicate names + duplicates = [] + for name, member in enumeration.__members__.items(): + if name != member.name: + duplicates.append((name, member.name)) + if duplicates: + alias_details = ', '.join( + ["%s -> %s" % (alias, name) for (alias, name) in duplicates]) + raise ValueError('aliases found in %r: %s' % + (enumeration, alias_details)) + elif check is CONTINUOUS: + values = set(e.value for e in enumeration) + if len(values) < 2: + continue + low, high = min(values), max(values) + missing = [] + if enum_type == 'flag': + # check for powers of two + for i in range(_high_bit(low)+1, _high_bit(high)): + if 2**i not in values: + missing.append(2**i) + elif enum_type == 'enum': + # check for powers of one + for i in range(low+1, high): + if i not in values: + missing.append(i) + else: + raise Exception('verify: unknown type %r' % enum_type) + if missing: + raise ValueError('invalid %s %r: missing values %s' % ( + enum_type, cls_name, ', '.join((str(m) for m in missing))) + ) + elif check is NAMED_FLAGS: + # examine each alias and check for unnamed flags + member_names = enumeration._member_names_ + member_values = [m.value for m in enumeration] + missing = [] + for name, alias in enumeration._member_map_.items(): + if name in member_names: + # not an alias + continue + values = list(_iter_bits_lsb(alias.value)) + missed = [v for v in values if v not in member_values] + if missed: + plural = ('', 's')[len(missed) > 1] + a = ('a ', '')[len(missed) > 1] + missing.append('%r is missing %snamed flag%s for value%s %s' % ( + name, a, plural, plural, + ', '.join(str(v) for v in missed) + )) + if missing: + raise ValueError( + 'invalid Flag %r: %s' + % (cls_name, '; '.join(missing)) + ) + return enumeration + def _test_simple_enum(checked_enum, simple_enum): """ A function that can be used to test an enum created with :func:`_simple_enum` diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py index e918b03..34b190b 100644 --- a/Lib/test/test_enum.py +++ b/Lib/test/test_enum.py @@ -9,6 +9,7 @@ import threading from collections import OrderedDict from enum import Enum, IntEnum, StrEnum, EnumType, Flag, IntFlag, unique, auto from enum import STRICT, CONFORM, EJECT, KEEP, _simple_enum, _test_simple_enum +from enum import verify, UNIQUE, CONTINUOUS, NAMED_FLAGS from io import StringIO from pickle import dumps, loads, PicklingError, HIGHEST_PROTOCOL from test import support @@ -2774,13 +2775,6 @@ class TestFlag(unittest.TestCase): third = auto() self.assertEqual([Dupes.first, Dupes.second, Dupes.third], list(Dupes)) - def test_bizarre(self): - with self.assertRaisesRegex(TypeError, "invalid Flag 'Bizarre' -- missing values: 1, 2"): - class Bizarre(Flag): - b = 3 - c = 4 - d = 6 - def test_multiple_mixin(self): class AllMixin: @classproperty @@ -3345,12 +3339,6 @@ class TestIntFlag(unittest.TestCase): for f in Open: self.assertEqual(bool(f.value), bool(f)) - def test_bizarre(self): - with self.assertRaisesRegex(TypeError, "invalid Flag 'Bizarre' -- missing values: 1, 2"): - class Bizarre(IntFlag): - b = 3 - c = 4 - d = 6 def test_multiple_mixin(self): class AllMixin: @@ -3459,6 +3447,7 @@ class TestUnique(unittest.TestCase): one = 1 two = 'dos' tres = 4.0 + # @unique class Cleaner(IntEnum): single = 1 @@ -3484,12 +3473,137 @@ class TestUnique(unittest.TestCase): turkey = 3 def test_unique_with_name(self): - @unique + @verify(UNIQUE) class Silly(Enum): one = 1 two = 'dos' name = 3 - @unique + # + @verify(UNIQUE) + class Sillier(IntEnum): + single = 1 + name = 2 + triple = 3 + value = 4 + +class TestVerify(unittest.TestCase): + + def test_continuous(self): + @verify(CONTINUOUS) + class Auto(Enum): + FIRST = auto() + SECOND = auto() + THIRD = auto() + FORTH = auto() + # + @verify(CONTINUOUS) + class Manual(Enum): + FIRST = 3 + SECOND = 4 + THIRD = 5 + FORTH = 6 + # + with self.assertRaisesRegex(ValueError, 'invalid enum .Missing.: missing values 5, 6, 7, 8, 9, 10, 12'): + @verify(CONTINUOUS) + class Missing(Enum): + FIRST = 3 + SECOND = 4 + THIRD = 11 + FORTH = 13 + # + with self.assertRaisesRegex(ValueError, 'invalid flag .Incomplete.: missing values 32'): + @verify(CONTINUOUS) + class Incomplete(Flag): + FIRST = 4 + SECOND = 8 + THIRD = 16 + FORTH = 64 + # + with self.assertRaisesRegex(ValueError, 'invalid flag .StillIncomplete.: missing values 16'): + @verify(CONTINUOUS) + class StillIncomplete(Flag): + FIRST = 4 + SECOND = 8 + THIRD = 11 + FORTH = 32 + + + def test_composite(self): + class Bizarre(Flag): + b = 3 + c = 4 + d = 6 + self.assertEqual(list(Bizarre), [Bizarre.c]) + self.assertEqual(Bizarre.b.value, 3) + self.assertEqual(Bizarre.c.value, 4) + self.assertEqual(Bizarre.d.value, 6) + with self.assertRaisesRegex( + ValueError, + "invalid Flag 'Bizarre': 'b' is missing named flags for values 1, 2; 'd' is missing a named flag for value 2", + ): + @verify(NAMED_FLAGS) + class Bizarre(Flag): + b = 3 + c = 4 + d = 6 + # + class Bizarre(IntFlag): + b = 3 + c = 4 + d = 6 + self.assertEqual(list(Bizarre), [Bizarre.c]) + self.assertEqual(Bizarre.b.value, 3) + self.assertEqual(Bizarre.c.value, 4) + self.assertEqual(Bizarre.d.value, 6) + with self.assertRaisesRegex( + ValueError, + "invalid Flag 'Bizarre': 'b' is missing named flags for values 1, 2; 'd' is missing a named flag for value 2", + ): + @verify(NAMED_FLAGS) + class Bizarre(IntFlag): + b = 3 + c = 4 + d = 6 + + def test_unique_clean(self): + @verify(UNIQUE) + class Clean(Enum): + one = 1 + two = 'dos' + tres = 4.0 + # + @verify(UNIQUE) + class Cleaner(IntEnum): + single = 1 + double = 2 + triple = 3 + + def test_unique_dirty(self): + with self.assertRaisesRegex(ValueError, 'tres.*one'): + @verify(UNIQUE) + class Dirty(Enum): + one = 1 + two = 'dos' + tres = 1 + with self.assertRaisesRegex( + ValueError, + 'double.*single.*turkey.*triple', + ): + @verify(UNIQUE) + class Dirtier(IntEnum): + single = 1 + double = 1 + triple = 3 + turkey = 3 + + def test_unique_with_name(self): + @verify(UNIQUE) + class Silly(Enum): + one = 1 + two = 'dos' + name = 3 + # + @verify(UNIQUE) class Sillier(IntEnum): single = 1 name = 2 |