diff options
-rw-r--r-- | Lib/enum.py | 15 | ||||
-rw-r--r-- | Lib/test/test_enum.py | 90 |
2 files changed, 99 insertions, 6 deletions
diff --git a/Lib/enum.py b/Lib/enum.py index e79b038..056400d 100644 --- a/Lib/enum.py +++ b/Lib/enum.py @@ -690,7 +690,9 @@ class Flag(Enum): pseudo_member = object.__new__(cls) pseudo_member._name_ = None pseudo_member._value_ = value - cls._value2member_map_[value] = pseudo_member + # use setdefault in case another thread already created a composite + # with this value + pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member) return pseudo_member def __contains__(self, other): @@ -785,7 +787,9 @@ class IntFlag(int, Flag): pseudo_member = int.__new__(cls, value) pseudo_member._name_ = None pseudo_member._value_ = value - cls._value2member_map_[value] = pseudo_member + # use setdefault in case another thread already created a composite + # with this value + pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member) return pseudo_member def __or__(self, other): @@ -835,18 +839,21 @@ def _decompose(flag, value): # _decompose is only called if the value is not named not_covered = value negative = value < 0 + # issue29167: wrap accesses to _value2member_map_ in a list to avoid race + # conditions between iterating over it and having more psuedo- + # members added to it if negative: # only check for named flags flags_to_check = [ (m, v) - for v, m in flag._value2member_map_.items() + for v, m in list(flag._value2member_map_.items()) if m.name is not None ] else: # check for named flags and powers-of-two flags flags_to_check = [ (m, v) - for v, m in flag._value2member_map_.items() + for v, m in list(flag._value2member_map_.items()) if m.name is not None or _power_of_two(v) ] members = [] diff --git a/Lib/test/test_enum.py b/Lib/test/test_enum.py index e97ef94..13a89fc 100644 --- a/Lib/test/test_enum.py +++ b/Lib/test/test_enum.py @@ -7,6 +7,11 @@ from enum import Enum, IntEnum, EnumMeta, Flag, IntFlag, unique, auto from io import StringIO from pickle import dumps, loads, PicklingError, HIGHEST_PROTOCOL from test import support +try: + import threading +except ImportError: + threading = None + # for pickle tests try: @@ -1983,6 +1988,45 @@ class TestFlag(unittest.TestCase): d = 6 self.assertEqual(repr(Bizarre(7)), '<Bizarre.d|c|b: 7>') + @unittest.skipUnless(threading, 'Threading required for this test.') + @support.reap_threads + def test_unique_composite(self): + # override __eq__ to be identity only + class TestFlag(Flag): + one = auto() + two = auto() + three = auto() + four = auto() + five = auto() + six = auto() + seven = auto() + eight = auto() + def __eq__(self, other): + return self is other + def __hash__(self): + return hash(self._value_) + # have multiple threads competing to complete the composite members + seen = set() + failed = False + def cycle_enum(): + nonlocal failed + try: + for i in range(256): + seen.add(TestFlag(i)) + except Exception: + failed = True + threads = [ + threading.Thread(target=cycle_enum) + for _ in range(8) + ] + with support.start_threads(threads): + pass + # check that only 248 members were created + self.assertFalse( + failed, + 'at least one thread failed while creating composite members') + self.assertEqual(256, len(seen), 'too many composite members created') + class TestIntFlag(unittest.TestCase): """Tests of the IntFlags.""" @@ -2275,6 +2319,46 @@ class TestIntFlag(unittest.TestCase): for f in Open: self.assertEqual(bool(f.value), bool(f)) + @unittest.skipUnless(threading, 'Threading required for this test.') + @support.reap_threads + def test_unique_composite(self): + # override __eq__ to be identity only + class TestFlag(IntFlag): + one = auto() + two = auto() + three = auto() + four = auto() + five = auto() + six = auto() + seven = auto() + eight = auto() + def __eq__(self, other): + return self is other + def __hash__(self): + return hash(self._value_) + # have multiple threads competing to complete the composite members + seen = set() + failed = False + def cycle_enum(): + nonlocal failed + try: + for i in range(256): + seen.add(TestFlag(i)) + except Exception: + failed = True + threads = [ + threading.Thread(target=cycle_enum) + for _ in range(8) + ] + with support.start_threads(threads): + pass + # check that only 248 members were created + self.assertFalse( + failed, + 'at least one thread failed while creating composite members') + self.assertEqual(256, len(seen), 'too many composite members created') + + class TestUnique(unittest.TestCase): def test_unique_clean(self): @@ -2484,7 +2568,8 @@ CONVERT_TEST_NAME_F = 5 class TestIntEnumConvert(unittest.TestCase): def test_convert_value_lookup_priority(self): test_type = enum.IntEnum._convert( - 'UnittestConvert', 'test.test_enum', + 'UnittestConvert', + ('test.test_enum', '__main__')[__name__=='__main__'], filter=lambda x: x.startswith('CONVERT_TEST_')) # We don't want the reverse lookup value to vary when there are # multiple possible names for a given value. It should always @@ -2493,7 +2578,8 @@ class TestIntEnumConvert(unittest.TestCase): def test_convert(self): test_type = enum.IntEnum._convert( - 'UnittestConvert', 'test.test_enum', + 'UnittestConvert', + ('test.test_enum', '__main__')[__name__=='__main__'], filter=lambda x: x.startswith('CONVERT_TEST_')) # Ensure that test_type has all of the desired names and values. self.assertEqual(test_type.CONVERT_TEST_NAME_F, |