Skip to content

Commit 28cf663

Browse files
committed
closes issue29167: fix race condition in (Int)Flag
1 parent 3831b0a commit 28cf663

File tree

2 files changed

+99
-6
lines changed

2 files changed

+99
-6
lines changed

Lib/enum.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -690,7 +690,9 @@ def _create_pseudo_member_(cls, value):
690690
pseudo_member = object.__new__(cls)
691691
pseudo_member._name_ = None
692692
pseudo_member._value_ = value
693-
cls._value2member_map_[value] = pseudo_member
693+
# use setdefault in case another thread already created a composite
694+
# with this value
695+
pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member)
694696
return pseudo_member
695697

696698
def __contains__(self, other):
@@ -785,7 +787,9 @@ def _create_pseudo_member_(cls, value):
785787
pseudo_member = int.__new__(cls, value)
786788
pseudo_member._name_ = None
787789
pseudo_member._value_ = value
788-
cls._value2member_map_[value] = pseudo_member
790+
# use setdefault in case another thread already created a composite
791+
# with this value
792+
pseudo_member = cls._value2member_map_.setdefault(value, pseudo_member)
789793
return pseudo_member
790794

791795
def __or__(self, other):
@@ -835,18 +839,21 @@ def _decompose(flag, value):
835839
# _decompose is only called if the value is not named
836840
not_covered = value
837841
negative = value < 0
842+
# issue29167: wrap accesses to _value2member_map_ in a list to avoid race
843+
# conditions between iterating over it and having more psuedo-
844+
# members added to it
838845
if negative:
839846
# only check for named flags
840847
flags_to_check = [
841848
(m, v)
842-
for v, m in flag._value2member_map_.items()
849+
for v, m in list(flag._value2member_map_.items())
843850
if m.name is not None
844851
]
845852
else:
846853
# check for named flags and powers-of-two flags
847854
flags_to_check = [
848855
(m, v)
849-
for v, m in flag._value2member_map_.items()
856+
for v, m in list(flag._value2member_map_.items())
850857
if m.name is not None or _power_of_two(v)
851858
]
852859
members = []

Lib/test/test_enum.py

Lines changed: 88 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@
77
from io import StringIO
88
from pickle import dumps, loads, PicklingError, HIGHEST_PROTOCOL
99
from test import support
10+
try:
11+
import threading
12+
except ImportError:
13+
threading = None
14+
1015

1116
# for pickle tests
1217
try:
@@ -1983,6 +1988,45 @@ class Bizarre(Flag):
19831988
d = 6
19841989
self.assertEqual(repr(Bizarre(7)), '<Bizarre.d|c|b: 7>')
19851990

1991+
@unittest.skipUnless(threading, 'Threading required for this test.')
1992+
@support.reap_threads
1993+
def test_unique_composite(self):
1994+
# override __eq__ to be identity only
1995+
class TestFlag(Flag):
1996+
one = auto()
1997+
two = auto()
1998+
three = auto()
1999+
four = auto()
2000+
five = auto()
2001+
six = auto()
2002+
seven = auto()
2003+
eight = auto()
2004+
def __eq__(self, other):
2005+
return self is other
2006+
def __hash__(self):
2007+
return hash(self._value_)
2008+
# have multiple threads competing to complete the composite members
2009+
seen = set()
2010+
failed = False
2011+
def cycle_enum():
2012+
nonlocal failed
2013+
try:
2014+
for i in range(256):
2015+
seen.add(TestFlag(i))
2016+
except Exception:
2017+
failed = True
2018+
threads = [
2019+
threading.Thread(target=cycle_enum)
2020+
for _ in range(8)
2021+
]
2022+
with support.start_threads(threads):
2023+
pass
2024+
# check that only 248 members were created
2025+
self.assertFalse(
2026+
failed,
2027+
'at least one thread failed while creating composite members')
2028+
self.assertEqual(256, len(seen), 'too many composite members created')
2029+
19862030

19872031
class TestIntFlag(unittest.TestCase):
19882032
"""Tests of the IntFlags."""
@@ -2275,6 +2319,46 @@ def test_bool(self):
22752319
for f in Open:
22762320
self.assertEqual(bool(f.value), bool(f))
22772321

2322+
@unittest.skipUnless(threading, 'Threading required for this test.')
2323+
@support.reap_threads
2324+
def test_unique_composite(self):
2325+
# override __eq__ to be identity only
2326+
class TestFlag(IntFlag):
2327+
one = auto()
2328+
two = auto()
2329+
three = auto()
2330+
four = auto()
2331+
five = auto()
2332+
six = auto()
2333+
seven = auto()
2334+
eight = auto()
2335+
def __eq__(self, other):
2336+
return self is other
2337+
def __hash__(self):
2338+
return hash(self._value_)
2339+
# have multiple threads competing to complete the composite members
2340+
seen = set()
2341+
failed = False
2342+
def cycle_enum():
2343+
nonlocal failed
2344+
try:
2345+
for i in range(256):
2346+
seen.add(TestFlag(i))
2347+
except Exception:
2348+
failed = True
2349+
threads = [
2350+
threading.Thread(target=cycle_enum)
2351+
for _ in range(8)
2352+
]
2353+
with support.start_threads(threads):
2354+
pass
2355+
# check that only 248 members were created
2356+
self.assertFalse(
2357+
failed,
2358+
'at least one thread failed while creating composite members')
2359+
self.assertEqual(256, len(seen), 'too many composite members created')
2360+
2361+
22782362
class TestUnique(unittest.TestCase):
22792363

22802364
def test_unique_clean(self):
@@ -2484,7 +2568,8 @@ def test__all__(self):
24842568
class TestIntEnumConvert(unittest.TestCase):
24852569
def test_convert_value_lookup_priority(self):
24862570
test_type = enum.IntEnum._convert(
2487-
'UnittestConvert', 'test.test_enum',
2571+
'UnittestConvert',
2572+
('test.test_enum', '__main__')[__name__=='__main__'],
24882573
filter=lambda x: x.startswith('CONVERT_TEST_'))
24892574
# We don't want the reverse lookup value to vary when there are
24902575
# multiple possible names for a given value. It should always
@@ -2493,7 +2578,8 @@ def test_convert_value_lookup_priority(self):
24932578

24942579
def test_convert(self):
24952580
test_type = enum.IntEnum._convert(
2496-
'UnittestConvert', 'test.test_enum',
2581+
'UnittestConvert',
2582+
('test.test_enum', '__main__')[__name__=='__main__'],
24972583
filter=lambda x: x.startswith('CONVERT_TEST_'))
24982584
# Ensure that test_type has all of the desired names and values.
24992585
self.assertEqual(test_type.CONVERT_TEST_NAME_F,

0 commit comments

Comments
 (0)