diff options
Diffstat (limited to 'Lib/enum.py')
-rw-r--r-- | Lib/enum.py | 107 |
1 files changed, 96 insertions, 11 deletions
diff --git a/Lib/enum.py b/Lib/enum.py index 01f4310..f74cc8c 100644 --- a/Lib/enum.py +++ b/Lib/enum.py @@ -6,10 +6,10 @@ from builtins import property as _bltin_property, bin as _bltin_bin __all__ = [ 'EnumType', 'EnumMeta', 'Enum', 'IntEnum', 'StrEnum', 'Flag', 'IntFlag', - 'auto', 'unique', - 'property', + 'auto', 'unique', 'property', 'verify', 'FlagBoundary', 'STRICT', 'CONFORM', 'EJECT', 'KEEP', 'global_flag_repr', 'global_enum_repr', 'global_enum', + 'EnumCheck', 'CONTINUOUS', 'NAMED_FLAGS', 'UNIQUE', ] @@ -89,6 +89,9 @@ def _make_class_unpicklable(obj): setattr(obj, '__module__', '<unknown>') def _iter_bits_lsb(num): + # num must be an integer + if isinstance(num, Enum): + num = num.value while num: b = num & (~num + 1) yield b @@ -538,13 +541,6 @@ class EnumType(type): else: # multi-bit flags are considered aliases multi_bit_total |= flag_value - if enum_class._boundary_ is not KEEP: - missed = list(_iter_bits_lsb(multi_bit_total & ~single_bit_total)) - if missed: - raise TypeError( - 'invalid Flag %r -- missing values: %s' - % (cls, ', '.join((str(i) for i in missed))) - ) enum_class._flag_mask_ = single_bit_total # # set correct __iter__ @@ -688,7 +684,10 @@ class EnumType(type): return MappingProxyType(cls._member_map_) def __repr__(cls): - return "<enum %r>" % cls.__name__ + if Flag is not None and issubclass(cls, Flag): + return "<flag %r>" % cls.__name__ + else: + return "<enum %r>" % cls.__name__ def __reversed__(cls): """ @@ -1303,7 +1302,8 @@ class Flag(Enum, boundary=STRICT): else: # calculate flags not in this member self._inverted_ = self.__class__(self._flag_mask_ ^ self._value_) - self._inverted_._inverted_ = self + if isinstance(self._inverted_, self.__class__): + self._inverted_._inverted_ = self return self._inverted_ @@ -1561,6 +1561,91 @@ def _simple_enum(etype=Enum, *, boundary=None, use_args=None): return enum_class return convert_class +@_simple_enum(StrEnum) +class EnumCheck: + """ + various conditions to check an enumeration for + """ + CONTINUOUS = "no skipped integer values" + NAMED_FLAGS = "multi-flag aliases may not contain unnamed flags" + UNIQUE = "one name per value" +CONTINUOUS, NAMED_FLAGS, UNIQUE = EnumCheck + + +class verify: + """ + Check an enumeration for various constraints. (see EnumCheck) + """ + def __init__(self, *checks): + self.checks = checks + def __call__(self, enumeration): + checks = self.checks + cls_name = enumeration.__name__ + if Flag is not None and issubclass(enumeration, Flag): + enum_type = 'flag' + elif issubclass(enumeration, Enum): + enum_type = 'enum' + else: + raise TypeError("the 'verify' decorator only works with Enum and Flag") + for check in checks: + if check is UNIQUE: + # check for duplicate names + duplicates = [] + for name, member in enumeration.__members__.items(): + if name != member.name: + duplicates.append((name, member.name)) + if duplicates: + alias_details = ', '.join( + ["%s -> %s" % (alias, name) for (alias, name) in duplicates]) + raise ValueError('aliases found in %r: %s' % + (enumeration, alias_details)) + elif check is CONTINUOUS: + values = set(e.value for e in enumeration) + if len(values) < 2: + continue + low, high = min(values), max(values) + missing = [] + if enum_type == 'flag': + # check for powers of two + for i in range(_high_bit(low)+1, _high_bit(high)): + if 2**i not in values: + missing.append(2**i) + elif enum_type == 'enum': + # check for powers of one + for i in range(low+1, high): + if i not in values: + missing.append(i) + else: + raise Exception('verify: unknown type %r' % enum_type) + if missing: + raise ValueError('invalid %s %r: missing values %s' % ( + enum_type, cls_name, ', '.join((str(m) for m in missing))) + ) + elif check is NAMED_FLAGS: + # examine each alias and check for unnamed flags + member_names = enumeration._member_names_ + member_values = [m.value for m in enumeration] + missing = [] + for name, alias in enumeration._member_map_.items(): + if name in member_names: + # not an alias + continue + values = list(_iter_bits_lsb(alias.value)) + missed = [v for v in values if v not in member_values] + if missed: + plural = ('', 's')[len(missed) > 1] + a = ('a ', '')[len(missed) > 1] + missing.append('%r is missing %snamed flag%s for value%s %s' % ( + name, a, plural, plural, + ', '.join(str(v) for v in missed) + )) + if missing: + raise ValueError( + 'invalid Flag %r: %s' + % (cls_name, '; '.join(missing)) + ) + return enumeration + def _test_simple_enum(checked_enum, simple_enum): """ A function that can be used to test an enum created with :func:`_simple_enum` |