diff options
author | Ethan Furman <ethan@stoneleaf.us> | 2016-09-11 06:36:59 (GMT) |
---|---|---|
committer | Ethan Furman <ethan@stoneleaf.us> | 2016-09-11 06:36:59 (GMT) |
commit | c16595e567d51a1773be30c34622620b52be7acf (patch) | |
tree | 3126fe0d721c7eb7b6eb279d20e3f7ea79c6bc63 /Lib | |
parent | 944368e1cc90a0bebaaf1a0a6f4346a81d8f46ad (diff) | |
download | cpython-c16595e567d51a1773be30c34622620b52be7acf.zip cpython-c16595e567d51a1773be30c34622620b52be7acf.tar.gz cpython-c16595e567d51a1773be30c34622620b52be7acf.tar.bz2 |
issue23591: add auto() for auto-generating Enum member values
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/enum.py | 50 | ||||
-rw-r--r-- | Lib/test/test_enum.py | 77 |
2 files changed, 112 insertions, 15 deletions
diff --git a/Lib/enum.py b/Lib/enum.py index 6a18999..1f87664 100644 --- a/Lib/enum.py +++ b/Lib/enum.py @@ -10,7 +10,11 @@ except ImportError: from collections import OrderedDict -__all__ = ['EnumMeta', 'Enum', 'IntEnum', 'Flag', 'IntFlag', 'unique'] +__all__ = [ + 'EnumMeta', + 'Enum', 'IntEnum', 'Flag', 'IntFlag', + 'auto', 'unique', + ] def _is_descriptor(obj): @@ -36,7 +40,6 @@ def _is_sunder(name): name[-2:-1] != '_' and len(name) > 2) - def _make_class_unpicklable(cls): """Make the given class un-picklable.""" def _break_on_call_reduce(self, proto): @@ -44,6 +47,12 @@ def _make_class_unpicklable(cls): cls.__reduce_ex__ = _break_on_call_reduce cls.__module__ = '<unknown>' +class auto: + """ + Instances are replaced with an appropriate value in Enum class suites. + """ + pass + class _EnumDict(dict): """Track enum member order and ensure member names are not reused. @@ -55,6 +64,7 @@ class _EnumDict(dict): def __init__(self): super().__init__() self._member_names = [] + self._last_values = [] def __setitem__(self, key, value): """Changes anything not dundered or not a descriptor. @@ -71,6 +81,8 @@ class _EnumDict(dict): '_generate_next_value_', '_missing_', ): raise ValueError('_names_ are reserved for future Enum use') + if key == '_generate_next_value_': + setattr(self, '_generate_next_value', value) elif _is_dunder(key): if key == '__order__': key = '_order_' @@ -81,11 +93,13 @@ class _EnumDict(dict): if key in self: # enum overwriting a descriptor? raise TypeError('%r already defined as: %r' % (key, self[key])) + if isinstance(value, auto): + value = self._generate_next_value(key, 1, len(self._member_names), self._last_values[:]) self._member_names.append(key) + self._last_values.append(value) super().__setitem__(key, value) - # Dummy value for Enum as EnumMeta explicitly checks for it, but of course # until EnumMeta finishes running the first time the Enum class doesn't exist. # This is also why there are checks in EnumMeta like `if Enum is not None` @@ -366,10 +380,11 @@ class EnumMeta(type): names = names.replace(',', ' ').split() if isinstance(names, (tuple, list)) and isinstance(names[0], str): original_names, names = names, [] - last_value = None + last_values = [] for count, name in enumerate(original_names): - last_value = first_enum._generate_next_value_(name, start, count, last_value) - names.append((name, last_value)) + value = first_enum._generate_next_value_(name, start, count, last_values[:]) + last_values.append(value) + names.append((name, value)) # Here, names is either an iterable of (name, value) or a mapping. for item in names: @@ -514,11 +529,15 @@ class Enum(metaclass=EnumMeta): # still not found -- try _missing_ hook return cls._missing_(value) - @staticmethod - def _generate_next_value_(name, start, count, last_value): - if not count: + def _generate_next_value_(name, start, count, last_values): + for last_value in reversed(last_values): + try: + return last_value + 1 + except TypeError: + pass + else: return start - return last_value + 1 + @classmethod def _missing_(cls, value): raise ValueError("%r is not a valid %s" % (value, cls.__name__)) @@ -616,8 +635,8 @@ def _reduce_ex_by_name(self, proto): class Flag(Enum): """Support for flags""" - @staticmethod - def _generate_next_value_(name, start, count, last_value): + + def _generate_next_value_(name, start, count, last_values): """ Generate the next value when not given. @@ -628,7 +647,12 @@ class Flag(Enum): """ if not count: return start if start is not None else 1 - high_bit = _high_bit(last_value) + for last_value in reversed(last_values): + try: + high_bit = _high_bit(last_value) + break + except TypeError: + raise TypeError('Invalid Flag value: %r' % last_value) from None return 2 ** (high_bit+1) @classmethod diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py index 698fd30..153bfb4 100644 --- a/Lib/test/test_enum.py +++ b/Lib/test/test_enum.py @@ -3,7 +3,7 @@ import inspect import pydoc import unittest from collections import OrderedDict -from enum import Enum, IntEnum, EnumMeta, Flag, IntFlag, unique +from enum import Enum, IntEnum, EnumMeta, Flag, IntFlag, unique, auto from io import StringIO from pickle import dumps, loads, PicklingError, HIGHEST_PROTOCOL from test import support @@ -113,6 +113,7 @@ class TestHelpers(unittest.TestCase): '__', '___', '____', '_____',): self.assertFalse(enum._is_dunder(s)) +# tests class TestEnum(unittest.TestCase): @@ -1578,6 +1579,61 @@ class TestEnum(unittest.TestCase): self.assertEqual(LabelledList.unprocessed, 1) self.assertEqual(LabelledList(1), LabelledList.unprocessed) + def test_auto_number(self): + class Color(Enum): + red = auto() + blue = auto() + green = auto() + + self.assertEqual(list(Color), [Color.red, Color.blue, Color.green]) + self.assertEqual(Color.red.value, 1) + self.assertEqual(Color.blue.value, 2) + self.assertEqual(Color.green.value, 3) + + def test_auto_name(self): + class Color(Enum): + def _generate_next_value_(name, start, count, last): + return name + red = auto() + blue = auto() + green = auto() + + self.assertEqual(list(Color), [Color.red, Color.blue, Color.green]) + self.assertEqual(Color.red.value, 'red') + self.assertEqual(Color.blue.value, 'blue') + self.assertEqual(Color.green.value, 'green') + + def test_auto_name_inherit(self): + class AutoNameEnum(Enum): + def _generate_next_value_(name, start, count, last): + return name + class Color(AutoNameEnum): + red = auto() + blue = auto() + green = auto() + + self.assertEqual(list(Color), [Color.red, Color.blue, Color.green]) + self.assertEqual(Color.red.value, 'red') + self.assertEqual(Color.blue.value, 'blue') + self.assertEqual(Color.green.value, 'green') + + def test_auto_garbage(self): + class Color(Enum): + red = 'red' + blue = auto() + self.assertEqual(Color.blue.value, 1) + + def test_auto_garbage_corrected(self): + class Color(Enum): + red = 'red' + blue = 2 + green = auto() + + self.assertEqual(list(Color), [Color.red, Color.blue, Color.green]) + self.assertEqual(Color.red.value, 'red') + self.assertEqual(Color.blue.value, 2) + self.assertEqual(Color.green.value, 3) + class TestOrder(unittest.TestCase): @@ -1856,7 +1912,6 @@ class TestFlag(unittest.TestCase): test_pickle_dump_load(self.assertIs, FlagStooges.CURLY|FlagStooges.MOE) test_pickle_dump_load(self.assertIs, FlagStooges) - def test_containment(self): Perm = self.Perm R, W, X = Perm @@ -1877,6 +1932,24 @@ class TestFlag(unittest.TestCase): self.assertFalse(W in RX) self.assertFalse(X in RW) + def test_auto_number(self): + class Color(Flag): + red = auto() + blue = auto() + green = auto() + + self.assertEqual(list(Color), [Color.red, Color.blue, Color.green]) + self.assertEqual(Color.red.value, 1) + self.assertEqual(Color.blue.value, 2) + self.assertEqual(Color.green.value, 4) + + def test_auto_number_garbage(self): + with self.assertRaisesRegex(TypeError, 'Invalid Flag value: .not an int.'): + class Color(Flag): + red = 'not an int' + blue = auto() + + class TestIntFlag(unittest.TestCase): """Tests of the IntFlags.""" |