summaryrefslogtreecommitdiffstats
path: root/Lib
diff options
context:
space:
mode:
authorEthan Furman <ethan@stoneleaf.us>2016-08-05 23:03:16 (GMT)
committerEthan Furman <ethan@stoneleaf.us>2016-08-05 23:03:16 (GMT)
commit73fc586d9fbfc6e5f5aef20691b21b569eb56ae8 (patch)
tree011cb550c07650ea215977d80049f5df4ec5ffda /Lib
parent20bd9f033af72b4e886ab20d46b1558c4dbf3a3f (diff)
downloadcpython-73fc586d9fbfc6e5f5aef20691b21b569eb56ae8.zip
cpython-73fc586d9fbfc6e5f5aef20691b21b569eb56ae8.tar.gz
cpython-73fc586d9fbfc6e5f5aef20691b21b569eb56ae8.tar.bz2
Add AutoEnum: automatically provides next value if missing. Issue 26988.
Diffstat (limited to 'Lib')
-rw-r--r--Lib/enum.py135
-rw-r--r--Lib/test/test_enum.py324
2 files changed, 445 insertions, 14 deletions
diff --git a/Lib/enum.py b/Lib/enum.py
index 99db9e6..eaf5040 100644
--- a/Lib/enum.py
+++ b/Lib/enum.py
@@ -8,7 +8,9 @@ except ImportError:
from collections import OrderedDict
-__all__ = ['EnumMeta', 'Enum', 'IntEnum', 'unique']
+__all__ = [
+ 'EnumMeta', 'Enum', 'IntEnum', 'AutoEnum', 'unique',
+ ]
def _is_descriptor(obj):
@@ -52,7 +54,30 @@ class _EnumDict(dict):
"""
def __init__(self):
super().__init__()
+ # list of enum members
self._member_names = []
+ # starting value
+ self._start = None
+ # last assigned value
+ self._last_value = None
+ # when the magic turns off
+ self._locked = True
+ # list of temporary names
+ self._ignore = []
+
+ def __getitem__(self, key):
+ if (
+ self._generate_next_value_ is None
+ or self._locked
+ or key in self
+ or key in self._ignore
+ or _is_sunder(key)
+ or _is_dunder(key)
+ ):
+ return super(_EnumDict, self).__getitem__(key)
+ next_value = self._generate_next_value_(key, self._start, len(self._member_names), self._last_value)
+ self[key] = next_value
+ return next_value
def __setitem__(self, key, value):
"""Changes anything not dundered or not a descriptor.
@@ -64,19 +89,55 @@ class _EnumDict(dict):
"""
if _is_sunder(key):
- raise ValueError('_names_ are reserved for future Enum use')
+ if key not in ('_settings_', '_order_', '_ignore_', '_start_', '_generate_next_value_'):
+ raise ValueError('_names_ are reserved for future Enum use')
+ elif key == '_generate_next_value_':
+ if isinstance(value, staticmethod):
+ value = value.__get__(None, self)
+ self._generate_next_value_ = value
+ self._locked = False
+ elif key == '_ignore_':
+ if isinstance(value, str):
+ value = value.split()
+ else:
+ value = list(value)
+ self._ignore = value
+ already = set(value) & set(self._member_names)
+ if already:
+ raise ValueError(
+ '_ignore_ cannot specify already set names: %r'
+ % (already, ))
+ elif key == '_start_':
+ self._start = value
+ self._locked = False
elif _is_dunder(key):
- pass
+ if key == '__order__':
+ key = '_order_'
+ if _is_descriptor(value):
+ self._locked = True
elif key in self._member_names:
# descriptor overwriting an enum?
raise TypeError('Attempted to reuse key: %r' % key)
+ elif key in self._ignore:
+ pass
elif not _is_descriptor(value):
if key in self:
# enum overwriting a descriptor?
- raise TypeError('Key already defined as: %r' % self[key])
+ raise TypeError('%r already defined as: %r' % (key, self[key]))
self._member_names.append(key)
+ if self._generate_next_value_ is not None:
+ self._last_value = value
+ else:
+ # not a new member, turn off the autoassign magic
+ self._locked = True
super().__setitem__(key, value)
+ # for magic "auto values" an Enum class should specify a `_generate_next_value_`
+ # method; that method will be used to generate missing values, and is
+ # implicitly a staticmethod;
+ # the signature should be `def _generate_next_value_(name, last_value)`
+ # last_value will be the last value created and/or assigned, or None
+ _generate_next_value_ = None
# Dummy value for Enum as EnumMeta explicitly checks for it, but of course
@@ -84,14 +145,31 @@ class _EnumDict(dict):
# This is also why there are checks in EnumMeta like `if Enum is not None`
Enum = None
-
+_ignore_sentinel = object()
class EnumMeta(type):
"""Metaclass for Enum"""
@classmethod
- def __prepare__(metacls, cls, bases):
- return _EnumDict()
-
- def __new__(metacls, cls, bases, classdict):
+ def __prepare__(metacls, cls, bases, start=None, ignore=_ignore_sentinel):
+ # create the namespace dict
+ enum_dict = _EnumDict()
+ # inherit previous flags and _generate_next_value_ function
+ member_type, first_enum = metacls._get_mixins_(bases)
+ if first_enum is not None:
+ enum_dict['_generate_next_value_'] = getattr(first_enum, '_generate_next_value_', None)
+ if start is None:
+ start = getattr(first_enum, '_start_', None)
+ if ignore is _ignore_sentinel:
+ enum_dict['_ignore_'] = 'property classmethod staticmethod'.split()
+ elif ignore:
+ enum_dict['_ignore_'] = ignore
+ if start is not None:
+ enum_dict['_start_'] = start
+ return enum_dict
+
+ def __init__(cls, *args , **kwds):
+ super(EnumMeta, cls).__init__(*args)
+
+ def __new__(metacls, cls, bases, classdict, **kwds):
# an Enum class is final once enumeration items have been defined; it
# cannot be mixed with other types (int, float, etc.) if it has an
# inherited __new__ unless a new __new__ is defined (or the resulting
@@ -102,12 +180,24 @@ class EnumMeta(type):
# save enum items into separate mapping so they don't get baked into
# the new class
- members = {k: classdict[k] for k in classdict._member_names}
+ enum_members = {k: classdict[k] for k in classdict._member_names}
for name in classdict._member_names:
del classdict[name]
+ # adjust the sunders
+ _order_ = classdict.pop('_order_', None)
+ classdict.pop('_ignore_', None)
+
+ # py3 support for definition order (helps keep py2/py3 code in sync)
+ if _order_ is not None:
+ if isinstance(_order_, str):
+ _order_ = _order_.replace(',', ' ').split()
+ unique_members = [n for n in clsdict._member_names if n in _order_]
+ if _order_ != unique_members:
+ raise TypeError('member order does not match _order_')
+
# check for illegal enum names (any others?)
- invalid_names = set(members) & {'mro', }
+ invalid_names = set(enum_members) & {'mro', }
if invalid_names:
raise ValueError('Invalid enum member name: {0}'.format(
','.join(invalid_names)))
@@ -151,7 +241,7 @@ class EnumMeta(type):
# a custom __new__ is doing something funky with the values -- such as
# auto-numbering ;)
for member_name in classdict._member_names:
- value = members[member_name]
+ value = enum_members[member_name]
if not isinstance(value, tuple):
args = (value, )
else:
@@ -165,7 +255,10 @@ class EnumMeta(type):
else:
enum_member = __new__(enum_class, *args)
if not hasattr(enum_member, '_value_'):
- enum_member._value_ = member_type(*args)
+ if member_type is object:
+ enum_member._value_ = value
+ else:
+ enum_member._value_ = member_type(*args)
value = enum_member._value_
enum_member._name_ = member_name
enum_member.__objclass__ = enum_class
@@ -572,6 +665,22 @@ class IntEnum(int, Enum):
def _reduce_ex_by_name(self, proto):
return self.name
+class AutoEnum(Enum):
+ """Enum where values are automatically assigned."""
+ def _generate_next_value_(name, start, count, last_value):
+ """
+ Generate the next value when not given.
+
+ name: the name of the member
+ start: the initital start value or None
+ count: the number of existing members
+ last_value: the last value assigned or None
+ """
+ # add one to the last assigned value
+ if not count:
+ return start if start is not None else 1
+ return last_value + 1
+
def unique(enumeration):
"""Class decorator for enumerations ensuring unique member values."""
duplicates = []
diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py
index 564c0e9..4a732f9 100644
--- a/Lib/test/test_enum.py
+++ b/Lib/test/test_enum.py
@@ -3,7 +3,7 @@ import inspect
import pydoc
import unittest
from collections import OrderedDict
-from enum import Enum, IntEnum, EnumMeta, unique
+from enum import EnumMeta, Enum, IntEnum, AutoEnum, unique
from io import StringIO
from pickle import dumps, loads, PicklingError, HIGHEST_PROTOCOL
from test import support
@@ -1570,6 +1570,328 @@ class TestEnum(unittest.TestCase):
self.assertEqual(LabelledList.unprocessed, 1)
self.assertEqual(LabelledList(1), LabelledList.unprocessed)
+ def test_ignore_as_str(self):
+ from datetime import timedelta
+ class Period(Enum, ignore='Period i'):
+ """
+ different lengths of time
+ """
+ def __new__(cls, value, period):
+ obj = object.__new__(cls)
+ obj._value_ = value
+ obj.period = period
+ return obj
+ Period = vars()
+ for i in range(367):
+ Period['Day%d' % i] = timedelta(days=i), 'day'
+ for i in range(53):
+ Period['Week%d' % i] = timedelta(days=i*7), 'week'
+ for i in range(13):
+ Period['Month%d' % i] = i, 'month'
+ OneDay = Day1
+ OneWeek = Week1
+ self.assertEqual(Period.Day7.value, timedelta(days=7))
+ self.assertEqual(Period.Day7.period, 'day')
+
+ def test_ignore_as_list(self):
+ from datetime import timedelta
+ class Period(Enum, ignore=['Period', 'i']):
+ """
+ different lengths of time
+ """
+ def __new__(cls, value, period):
+ obj = object.__new__(cls)
+ obj._value_ = value
+ obj.period = period
+ return obj
+ Period = vars()
+ for i in range(367):
+ Period['Day%d' % i] = timedelta(days=i), 'day'
+ for i in range(53):
+ Period['Week%d' % i] = timedelta(days=i*7), 'week'
+ for i in range(13):
+ Period['Month%d' % i] = i, 'month'
+ OneDay = Day1
+ OneWeek = Week1
+ self.assertEqual(Period.Day7.value, timedelta(days=7))
+ self.assertEqual(Period.Day7.period, 'day')
+
+ def test_new_with_no_value_and_int_base_class(self):
+ class NoValue(int, Enum):
+ def __new__(cls, value):
+ obj = int.__new__(cls, value)
+ obj.index = len(cls.__members__)
+ return obj
+ this = 1
+ that = 2
+ self.assertEqual(list(NoValue), [NoValue.this, NoValue.that])
+ self.assertEqual(NoValue.this, 1)
+ self.assertEqual(NoValue.this.value, 1)
+ self.assertEqual(NoValue.this.index, 0)
+ self.assertEqual(NoValue.that, 2)
+ self.assertEqual(NoValue.that.value, 2)
+ self.assertEqual(NoValue.that.index, 1)
+
+ def test_new_with_no_value(self):
+ class NoValue(Enum):
+ def __new__(cls, value):
+ obj = object.__new__(cls)
+ obj.index = len(cls.__members__)
+ return obj
+ this = 1
+ that = 2
+ self.assertEqual(list(NoValue), [NoValue.this, NoValue.that])
+ self.assertEqual(NoValue.this.value, 1)
+ self.assertEqual(NoValue.this.index, 0)
+ self.assertEqual(NoValue.that.value, 2)
+ self.assertEqual(NoValue.that.index, 1)
+
+
+class TestAutoNumber(unittest.TestCase):
+
+ def test_autonumbering(self):
+ class Color(AutoEnum):
+ red
+ green
+ blue
+ self.assertEqual(list(Color), [Color.red, Color.green, Color.blue])
+ self.assertEqual(Color.red.value, 1)
+ self.assertEqual(Color.green.value, 2)
+ self.assertEqual(Color.blue.value, 3)
+
+ def test_autointnumbering(self):
+ class Color(int, AutoEnum):
+ red
+ green
+ blue
+ self.assertTrue(isinstance(Color.red, int))
+ self.assertEqual(Color.green, 2)
+ self.assertTrue(Color.blue > Color.red)
+
+ def test_autonumbering_with_start(self):
+ class Color(AutoEnum, start=7):
+ red
+ green
+ blue
+ self.assertEqual(list(Color), [Color.red, Color.green, Color.blue])
+ self.assertEqual(Color.red.value, 7)
+ self.assertEqual(Color.green.value, 8)
+ self.assertEqual(Color.blue.value, 9)
+
+ def test_autonumbering_with_start_and_skip(self):
+ class Color(AutoEnum, start=7):
+ red
+ green
+ blue = 11
+ brown
+ self.assertEqual(list(Color), [Color.red, Color.green, Color.blue, Color.brown])
+ self.assertEqual(Color.red.value, 7)
+ self.assertEqual(Color.green.value, 8)
+ self.assertEqual(Color.blue.value, 11)
+ self.assertEqual(Color.brown.value, 12)
+
+
+ def test_badly_overridden_ignore(self):
+ with self.assertRaisesRegex(TypeError, "'int' object is not callable"):
+ class Color(AutoEnum):
+ _ignore_ = ()
+ red
+ green
+ blue
+ @property
+ def whatever(self):
+ pass
+ with self.assertRaisesRegex(TypeError, "'int' object is not callable"):
+ class Color(AutoEnum, ignore=None):
+ red
+ green
+ blue
+ @property
+ def whatever(self):
+ pass
+ with self.assertRaisesRegex(TypeError, "'int' object is not callable"):
+ class Color(AutoEnum, ignore='classmethod staticmethod'):
+ red
+ green
+ blue
+ @property
+ def whatever(self):
+ pass
+
+ def test_property(self):
+ class Color(AutoEnum):
+ red
+ green
+ blue
+ @property
+ def cap_name(self):
+ return self.name.title()
+ self.assertEqual(Color.blue.cap_name, 'Blue')
+
+ def test_magic_turns_off(self):
+ with self.assertRaisesRegex(NameError, "brown"):
+ class Color(AutoEnum):
+ red
+ green
+ blue
+ @property
+ def cap_name(self):
+ return self.name.title()
+ brown
+
+ with self.assertRaisesRegex(NameError, "rose"):
+ class Color(AutoEnum):
+ red
+ green
+ blue
+ def hello(self):
+ print('Hello! My serial is %s.' % self.value)
+ rose
+
+ with self.assertRaisesRegex(NameError, "cyan"):
+ class Color(AutoEnum):
+ red
+ green
+ blue
+ def __init__(self, *args):
+ pass
+ cyan
+
+
+class TestGenerateMethod(unittest.TestCase):
+
+ def test_autonaming(self):
+ class Color(Enum):
+ def _generate_next_value_(name, start, count, last_value):
+ return name
+ Red
+ Green
+ Blue
+ self.assertEqual(list(Color), [Color.Red, Color.Green, Color.Blue])
+ self.assertEqual(Color.Red.value, 'Red')
+ self.assertEqual(Color.Green.value, 'Green')
+ self.assertEqual(Color.Blue.value, 'Blue')
+
+ def test_autonamestr(self):
+ class Color(str, Enum):
+ def _generate_next_value_(name, start, count, last_value):
+ return name
+ Red
+ Green
+ Blue
+ self.assertTrue(isinstance(Color.Red, str))
+ self.assertEqual(Color.Green, 'Green')
+ self.assertTrue(Color.Blue < Color.Red)
+
+ def test_generate_as_staticmethod(self):
+ class Color(str, Enum):
+ @staticmethod
+ def _generate_next_value_(name, start, count, last_value):
+ return name.lower()
+ Red
+ Green
+ Blue
+ self.assertTrue(isinstance(Color.Red, str))
+ self.assertEqual(Color.Green, 'green')
+ self.assertTrue(Color.Blue < Color.Red)
+
+
+ def test_overridden_ignore(self):
+ with self.assertRaisesRegex(TypeError, "'str' object is not callable"):
+ class Color(Enum):
+ def _generate_next_value_(name, start, count, last_value):
+ return name
+ _ignore_ = ()
+ red
+ green
+ blue
+ @property
+ def whatever(self):
+ pass
+ with self.assertRaisesRegex(TypeError, "'str' object is not callable"):
+ class Color(Enum, ignore=None):
+ def _generate_next_value_(name, start, count, last_value):
+ return name
+ red
+ green
+ blue
+ @property
+ def whatever(self):
+ pass
+
+ def test_property(self):
+ class Color(Enum):
+ def _generate_next_value_(name, start, count, last_value):
+ return name
+ red
+ green
+ blue
+ @property
+ def upper_name(self):
+ return self.name.upper()
+ self.assertEqual(Color.blue.upper_name, 'BLUE')
+
+ def test_magic_turns_off(self):
+ with self.assertRaisesRegex(NameError, "brown"):
+ class Color(Enum):
+ def _generate_next_value_(name, start, count, last_value):
+ return name
+ red
+ green
+ blue
+ @property
+ def cap_name(self):
+ return self.name.title()
+ brown
+
+ with self.assertRaisesRegex(NameError, "rose"):
+ class Color(Enum):
+ def _generate_next_value_(name, start, count, last_value):
+ return name
+ red
+ green
+ blue
+ def hello(self):
+ print('Hello! My value %s.' % self.value)
+ rose
+
+ with self.assertRaisesRegex(NameError, "cyan"):
+ class Color(Enum):
+ def _generate_next_value_(name, start, count, last_value):
+ return name
+ red
+ green
+ blue
+ def __init__(self, *args):
+ pass
+ cyan
+
+ def test_powers_of_two(self):
+ class Bits(Enum):
+ def _generate_next_value_(name, start, count, last_value):
+ return 2 ** count
+ one
+ two
+ four
+ eight
+ self.assertEqual(Bits.one.value, 1)
+ self.assertEqual(Bits.two.value, 2)
+ self.assertEqual(Bits.four.value, 4)
+ self.assertEqual(Bits.eight.value, 8)
+
+ def test_powers_of_two_as_int(self):
+ class Bits(int, Enum):
+ def _generate_next_value_(name, start, count, last_value):
+ return 2 ** count
+ one
+ two
+ four
+ eight
+ self.assertEqual(Bits.one, 1)
+ self.assertEqual(Bits.two, 2)
+ self.assertEqual(Bits.four, 4)
+ self.assertEqual(Bits.eight, 8)
+
class TestUnique(unittest.TestCase):