diff options
author | Ethan Furman <ethan@stoneleaf.us> | 2016-08-05 23:03:16 (GMT) |
---|---|---|
committer | Ethan Furman <ethan@stoneleaf.us> | 2016-08-05 23:03:16 (GMT) |
commit | 73fc586d9fbfc6e5f5aef20691b21b569eb56ae8 (patch) | |
tree | 011cb550c07650ea215977d80049f5df4ec5ffda /Lib | |
parent | 20bd9f033af72b4e886ab20d46b1558c4dbf3a3f (diff) | |
download | cpython-73fc586d9fbfc6e5f5aef20691b21b569eb56ae8.zip cpython-73fc586d9fbfc6e5f5aef20691b21b569eb56ae8.tar.gz cpython-73fc586d9fbfc6e5f5aef20691b21b569eb56ae8.tar.bz2 |
Add AutoEnum: automatically provides next value if missing. Issue 26988.
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/enum.py | 135 | ||||
-rw-r--r-- | Lib/test/test_enum.py | 324 |
2 files changed, 445 insertions, 14 deletions
diff --git a/Lib/enum.py b/Lib/enum.py index 99db9e6..eaf5040 100644 --- a/Lib/enum.py +++ b/Lib/enum.py @@ -8,7 +8,9 @@ except ImportError: from collections import OrderedDict -__all__ = ['EnumMeta', 'Enum', 'IntEnum', 'unique'] +__all__ = [ + 'EnumMeta', 'Enum', 'IntEnum', 'AutoEnum', 'unique', + ] def _is_descriptor(obj): @@ -52,7 +54,30 @@ class _EnumDict(dict): """ def __init__(self): super().__init__() + # list of enum members self._member_names = [] + # starting value + self._start = None + # last assigned value + self._last_value = None + # when the magic turns off + self._locked = True + # list of temporary names + self._ignore = [] + + def __getitem__(self, key): + if ( + self._generate_next_value_ is None + or self._locked + or key in self + or key in self._ignore + or _is_sunder(key) + or _is_dunder(key) + ): + return super(_EnumDict, self).__getitem__(key) + next_value = self._generate_next_value_(key, self._start, len(self._member_names), self._last_value) + self[key] = next_value + return next_value def __setitem__(self, key, value): """Changes anything not dundered or not a descriptor. @@ -64,19 +89,55 @@ class _EnumDict(dict): """ if _is_sunder(key): - raise ValueError('_names_ are reserved for future Enum use') + if key not in ('_settings_', '_order_', '_ignore_', '_start_', '_generate_next_value_'): + raise ValueError('_names_ are reserved for future Enum use') + elif key == '_generate_next_value_': + if isinstance(value, staticmethod): + value = value.__get__(None, self) + self._generate_next_value_ = value + self._locked = False + elif key == '_ignore_': + if isinstance(value, str): + value = value.split() + else: + value = list(value) + self._ignore = value + already = set(value) & set(self._member_names) + if already: + raise ValueError( + '_ignore_ cannot specify already set names: %r' + % (already, )) + elif key == '_start_': + self._start = value + self._locked = False elif _is_dunder(key): - pass + if key == '__order__': + key = '_order_' + if _is_descriptor(value): + self._locked = True elif key in self._member_names: # descriptor overwriting an enum? raise TypeError('Attempted to reuse key: %r' % key) + elif key in self._ignore: + pass elif not _is_descriptor(value): if key in self: # enum overwriting a descriptor? - raise TypeError('Key already defined as: %r' % self[key]) + raise TypeError('%r already defined as: %r' % (key, self[key])) self._member_names.append(key) + if self._generate_next_value_ is not None: + self._last_value = value + else: + # not a new member, turn off the autoassign magic + self._locked = True super().__setitem__(key, value) + # for magic "auto values" an Enum class should specify a `_generate_next_value_` + # method; that method will be used to generate missing values, and is + # implicitly a staticmethod; + # the signature should be `def _generate_next_value_(name, last_value)` + # last_value will be the last value created and/or assigned, or None + _generate_next_value_ = None # Dummy value for Enum as EnumMeta explicitly checks for it, but of course @@ -84,14 +145,31 @@ class _EnumDict(dict): # This is also why there are checks in EnumMeta like `if Enum is not None` Enum = None - +_ignore_sentinel = object() class EnumMeta(type): """Metaclass for Enum""" @classmethod - def __prepare__(metacls, cls, bases): - return _EnumDict() - - def __new__(metacls, cls, bases, classdict): + def __prepare__(metacls, cls, bases, start=None, ignore=_ignore_sentinel): + # create the namespace dict + enum_dict = _EnumDict() + # inherit previous flags and _generate_next_value_ function + member_type, first_enum = metacls._get_mixins_(bases) + if first_enum is not None: + enum_dict['_generate_next_value_'] = getattr(first_enum, '_generate_next_value_', None) + if start is None: + start = getattr(first_enum, '_start_', None) + if ignore is _ignore_sentinel: + enum_dict['_ignore_'] = 'property classmethod staticmethod'.split() + elif ignore: + enum_dict['_ignore_'] = ignore + if start is not None: + enum_dict['_start_'] = start + return enum_dict + + def __init__(cls, *args , **kwds): + super(EnumMeta, cls).__init__(*args) + + def __new__(metacls, cls, bases, classdict, **kwds): # an Enum class is final once enumeration items have been defined; it # cannot be mixed with other types (int, float, etc.) if it has an # inherited __new__ unless a new __new__ is defined (or the resulting @@ -102,12 +180,24 @@ class EnumMeta(type): # save enum items into separate mapping so they don't get baked into # the new class - members = {k: classdict[k] for k in classdict._member_names} + enum_members = {k: classdict[k] for k in classdict._member_names} for name in classdict._member_names: del classdict[name] + # adjust the sunders + _order_ = classdict.pop('_order_', None) + classdict.pop('_ignore_', None) + + # py3 support for definition order (helps keep py2/py3 code in sync) + if _order_ is not None: + if isinstance(_order_, str): + _order_ = _order_.replace(',', ' ').split() + unique_members = [n for n in clsdict._member_names if n in _order_] + if _order_ != unique_members: + raise TypeError('member order does not match _order_') + # check for illegal enum names (any others?) - invalid_names = set(members) & {'mro', } + invalid_names = set(enum_members) & {'mro', } if invalid_names: raise ValueError('Invalid enum member name: {0}'.format( ','.join(invalid_names))) @@ -151,7 +241,7 @@ class EnumMeta(type): # a custom __new__ is doing something funky with the values -- such as # auto-numbering ;) for member_name in classdict._member_names: - value = members[member_name] + value = enum_members[member_name] if not isinstance(value, tuple): args = (value, ) else: @@ -165,7 +255,10 @@ class EnumMeta(type): else: enum_member = __new__(enum_class, *args) if not hasattr(enum_member, '_value_'): - enum_member._value_ = member_type(*args) + if member_type is object: + enum_member._value_ = value + else: + enum_member._value_ = member_type(*args) value = enum_member._value_ enum_member._name_ = member_name enum_member.__objclass__ = enum_class @@ -572,6 +665,22 @@ class IntEnum(int, Enum): def _reduce_ex_by_name(self, proto): return self.name +class AutoEnum(Enum): + """Enum where values are automatically assigned.""" + def _generate_next_value_(name, start, count, last_value): + """ + Generate the next value when not given. + + name: the name of the member + start: the initital start value or None + count: the number of existing members + last_value: the last value assigned or None + """ + # add one to the last assigned value + if not count: + return start if start is not None else 1 + return last_value + 1 + def unique(enumeration): """Class decorator for enumerations ensuring unique member values.""" duplicates = [] diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py index 564c0e9..4a732f9 100644 --- a/Lib/test/test_enum.py +++ b/Lib/test/test_enum.py @@ -3,7 +3,7 @@ import inspect import pydoc import unittest from collections import OrderedDict -from enum import Enum, IntEnum, EnumMeta, unique +from enum import EnumMeta, Enum, IntEnum, AutoEnum, unique from io import StringIO from pickle import dumps, loads, PicklingError, HIGHEST_PROTOCOL from test import support @@ -1570,6 +1570,328 @@ class TestEnum(unittest.TestCase): self.assertEqual(LabelledList.unprocessed, 1) self.assertEqual(LabelledList(1), LabelledList.unprocessed) + def test_ignore_as_str(self): + from datetime import timedelta + class Period(Enum, ignore='Period i'): + """ + different lengths of time + """ + def __new__(cls, value, period): + obj = object.__new__(cls) + obj._value_ = value + obj.period = period + return obj + Period = vars() + for i in range(367): + Period['Day%d' % i] = timedelta(days=i), 'day' + for i in range(53): + Period['Week%d' % i] = timedelta(days=i*7), 'week' + for i in range(13): + Period['Month%d' % i] = i, 'month' + OneDay = Day1 + OneWeek = Week1 + self.assertEqual(Period.Day7.value, timedelta(days=7)) + self.assertEqual(Period.Day7.period, 'day') + + def test_ignore_as_list(self): + from datetime import timedelta + class Period(Enum, ignore=['Period', 'i']): + """ + different lengths of time + """ + def __new__(cls, value, period): + obj = object.__new__(cls) + obj._value_ = value + obj.period = period + return obj + Period = vars() + for i in range(367): + Period['Day%d' % i] = timedelta(days=i), 'day' + for i in range(53): + Period['Week%d' % i] = timedelta(days=i*7), 'week' + for i in range(13): + Period['Month%d' % i] = i, 'month' + OneDay = Day1 + OneWeek = Week1 + self.assertEqual(Period.Day7.value, timedelta(days=7)) + self.assertEqual(Period.Day7.period, 'day') + + def test_new_with_no_value_and_int_base_class(self): + class NoValue(int, Enum): + def __new__(cls, value): + obj = int.__new__(cls, value) + obj.index = len(cls.__members__) + return obj + this = 1 + that = 2 + self.assertEqual(list(NoValue), [NoValue.this, NoValue.that]) + self.assertEqual(NoValue.this, 1) + self.assertEqual(NoValue.this.value, 1) + self.assertEqual(NoValue.this.index, 0) + self.assertEqual(NoValue.that, 2) + self.assertEqual(NoValue.that.value, 2) + self.assertEqual(NoValue.that.index, 1) + + def test_new_with_no_value(self): + class NoValue(Enum): + def __new__(cls, value): + obj = object.__new__(cls) + obj.index = len(cls.__members__) + return obj + this = 1 + that = 2 + self.assertEqual(list(NoValue), [NoValue.this, NoValue.that]) + self.assertEqual(NoValue.this.value, 1) + self.assertEqual(NoValue.this.index, 0) + self.assertEqual(NoValue.that.value, 2) + self.assertEqual(NoValue.that.index, 1) + + +class TestAutoNumber(unittest.TestCase): + + def test_autonumbering(self): + class Color(AutoEnum): + red + green + blue + self.assertEqual(list(Color), [Color.red, Color.green, Color.blue]) + self.assertEqual(Color.red.value, 1) + self.assertEqual(Color.green.value, 2) + self.assertEqual(Color.blue.value, 3) + + def test_autointnumbering(self): + class Color(int, AutoEnum): + red + green + blue + self.assertTrue(isinstance(Color.red, int)) + self.assertEqual(Color.green, 2) + self.assertTrue(Color.blue > Color.red) + + def test_autonumbering_with_start(self): + class Color(AutoEnum, start=7): + red + green + blue + self.assertEqual(list(Color), [Color.red, Color.green, Color.blue]) + self.assertEqual(Color.red.value, 7) + self.assertEqual(Color.green.value, 8) + self.assertEqual(Color.blue.value, 9) + + def test_autonumbering_with_start_and_skip(self): + class Color(AutoEnum, start=7): + red + green + blue = 11 + brown + self.assertEqual(list(Color), [Color.red, Color.green, Color.blue, Color.brown]) + self.assertEqual(Color.red.value, 7) + self.assertEqual(Color.green.value, 8) + self.assertEqual(Color.blue.value, 11) + self.assertEqual(Color.brown.value, 12) + + + def test_badly_overridden_ignore(self): + with self.assertRaisesRegex(TypeError, "'int' object is not callable"): + class Color(AutoEnum): + _ignore_ = () + red + green + blue + @property + def whatever(self): + pass + with self.assertRaisesRegex(TypeError, "'int' object is not callable"): + class Color(AutoEnum, ignore=None): + red + green + blue + @property + def whatever(self): + pass + with self.assertRaisesRegex(TypeError, "'int' object is not callable"): + class Color(AutoEnum, ignore='classmethod staticmethod'): + red + green + blue + @property + def whatever(self): + pass + + def test_property(self): + class Color(AutoEnum): + red + green + blue + @property + def cap_name(self): + return self.name.title() + self.assertEqual(Color.blue.cap_name, 'Blue') + + def test_magic_turns_off(self): + with self.assertRaisesRegex(NameError, "brown"): + class Color(AutoEnum): + red + green + blue + @property + def cap_name(self): + return self.name.title() + brown + + with self.assertRaisesRegex(NameError, "rose"): + class Color(AutoEnum): + red + green + blue + def hello(self): + print('Hello! My serial is %s.' % self.value) + rose + + with self.assertRaisesRegex(NameError, "cyan"): + class Color(AutoEnum): + red + green + blue + def __init__(self, *args): + pass + cyan + + +class TestGenerateMethod(unittest.TestCase): + + def test_autonaming(self): + class Color(Enum): + def _generate_next_value_(name, start, count, last_value): + return name + Red + Green + Blue + self.assertEqual(list(Color), [Color.Red, Color.Green, Color.Blue]) + self.assertEqual(Color.Red.value, 'Red') + self.assertEqual(Color.Green.value, 'Green') + self.assertEqual(Color.Blue.value, 'Blue') + + def test_autonamestr(self): + class Color(str, Enum): + def _generate_next_value_(name, start, count, last_value): + return name + Red + Green + Blue + self.assertTrue(isinstance(Color.Red, str)) + self.assertEqual(Color.Green, 'Green') + self.assertTrue(Color.Blue < Color.Red) + + def test_generate_as_staticmethod(self): + class Color(str, Enum): + @staticmethod + def _generate_next_value_(name, start, count, last_value): + return name.lower() + Red + Green + Blue + self.assertTrue(isinstance(Color.Red, str)) + self.assertEqual(Color.Green, 'green') + self.assertTrue(Color.Blue < Color.Red) + + + def test_overridden_ignore(self): + with self.assertRaisesRegex(TypeError, "'str' object is not callable"): + class Color(Enum): + def _generate_next_value_(name, start, count, last_value): + return name + _ignore_ = () + red + green + blue + @property + def whatever(self): + pass + with self.assertRaisesRegex(TypeError, "'str' object is not callable"): + class Color(Enum, ignore=None): + def _generate_next_value_(name, start, count, last_value): + return name + red + green + blue + @property + def whatever(self): + pass + + def test_property(self): + class Color(Enum): + def _generate_next_value_(name, start, count, last_value): + return name + red + green + blue + @property + def upper_name(self): + return self.name.upper() + self.assertEqual(Color.blue.upper_name, 'BLUE') + + def test_magic_turns_off(self): + with self.assertRaisesRegex(NameError, "brown"): + class Color(Enum): + def _generate_next_value_(name, start, count, last_value): + return name + red + green + blue + @property + def cap_name(self): + return self.name.title() + brown + + with self.assertRaisesRegex(NameError, "rose"): + class Color(Enum): + def _generate_next_value_(name, start, count, last_value): + return name + red + green + blue + def hello(self): + print('Hello! My value %s.' % self.value) + rose + + with self.assertRaisesRegex(NameError, "cyan"): + class Color(Enum): + def _generate_next_value_(name, start, count, last_value): + return name + red + green + blue + def __init__(self, *args): + pass + cyan + + def test_powers_of_two(self): + class Bits(Enum): + def _generate_next_value_(name, start, count, last_value): + return 2 ** count + one + two + four + eight + self.assertEqual(Bits.one.value, 1) + self.assertEqual(Bits.two.value, 2) + self.assertEqual(Bits.four.value, 4) + self.assertEqual(Bits.eight.value, 8) + + def test_powers_of_two_as_int(self): + class Bits(int, Enum): + def _generate_next_value_(name, start, count, last_value): + return 2 ** count + one + two + four + eight + self.assertEqual(Bits.one, 1) + self.assertEqual(Bits.two, 2) + self.assertEqual(Bits.four, 4) + self.assertEqual(Bits.eight, 8) + class TestUnique(unittest.TestCase): |