summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Include/internal/pycore_dict.h2
-rw-r--r--Lib/test/test_free_threading/test_dict.py141
-rw-r--r--Objects/dictobject.c140
-rw-r--r--Objects/object.c4
4 files changed, 216 insertions, 71 deletions
diff --git a/Include/internal/pycore_dict.h b/Include/internal/pycore_dict.h
index 3ba8ee7..cb7d4c3 100644
--- a/Include/internal/pycore_dict.h
+++ b/Include/internal/pycore_dict.h
@@ -105,10 +105,10 @@ PyAPI_FUNC(PyObject *)_PyDict_LoadGlobal(PyDictObject *, PyDictObject *, PyObjec
/* Consumes references to key and value */
PyAPI_FUNC(int) _PyDict_SetItem_Take2(PyDictObject *op, PyObject *key, PyObject *value);
-extern int _PyObjectDict_SetItem(PyTypeObject *tp, PyObject **dictptr, PyObject *name, PyObject *value);
extern int _PyDict_SetItem_LockHeld(PyDictObject *dict, PyObject *name, PyObject *value);
extern int _PyDict_GetItemRef_Unicode_LockHeld(PyDictObject *op, PyObject *key, PyObject **result);
extern int _PyDict_GetItemRef_KnownHash(PyDictObject *op, PyObject *key, Py_hash_t hash, PyObject **result);
+extern int _PyObjectDict_SetItem(PyTypeObject *tp, PyObject *obj, PyObject **dictptr, PyObject *name, PyObject *value);
extern int _PyDict_Pop_KnownHash(
PyDictObject *dict,
diff --git a/Lib/test/test_free_threading/test_dict.py b/Lib/test/test_free_threading/test_dict.py
new file mode 100644
index 0000000..6a909dd
--- /dev/null
+++ b/Lib/test/test_free_threading/test_dict.py
@@ -0,0 +1,141 @@
+import gc
+import time
+import unittest
+import weakref
+
+from ast import Or
+from functools import partial
+from threading import Thread
+from unittest import TestCase
+
+from test.support import threading_helper
+
+
+@threading_helper.requires_working_threading()
+class TestDict(TestCase):
+ def test_racing_creation_shared_keys(self):
+ """Verify that creating dictionaries is thread safe when we
+ have a type with shared keys"""
+ class C(int):
+ pass
+
+ self.racing_creation(C)
+
+ def test_racing_creation_no_shared_keys(self):
+ """Verify that creating dictionaries is thread safe when we
+ have a type with an ordinary dict"""
+ self.racing_creation(Or)
+
+ def test_racing_creation_inline_values_invalid(self):
+ """Verify that re-creating a dict after we have invalid inline values
+ is thread safe"""
+ class C:
+ pass
+
+ def make_obj():
+ a = C()
+ # Make object, make inline values invalid, and then delete dict
+ a.__dict__ = {}
+ del a.__dict__
+ return a
+
+ self.racing_creation(make_obj)
+
+ def test_racing_creation_nonmanaged_dict(self):
+ """Verify that explicit creation of an unmanaged dict is thread safe
+ outside of the normal attribute setting code path"""
+ def make_obj():
+ def f(): pass
+ return f
+
+ def set(func, name, val):
+ # Force creation of the dict via PyObject_GenericGetDict
+ func.__dict__[name] = val
+
+ self.racing_creation(make_obj, set)
+
+ def racing_creation(self, cls, set=setattr):
+ objects = []
+ processed = []
+
+ OBJECT_COUNT = 100
+ THREAD_COUNT = 10
+ CUR = 0
+
+ for i in range(OBJECT_COUNT):
+ objects.append(cls())
+
+ def writer_func(name):
+ last = -1
+ while True:
+ if CUR == last:
+ continue
+ elif CUR == OBJECT_COUNT:
+ break
+
+ obj = objects[CUR]
+ set(obj, name, name)
+ last = CUR
+ processed.append(name)
+
+ writers = []
+ for x in range(THREAD_COUNT):
+ writer = Thread(target=partial(writer_func, f"a{x:02}"))
+ writers.append(writer)
+ writer.start()
+
+ for i in range(OBJECT_COUNT):
+ CUR = i
+ while len(processed) != THREAD_COUNT:
+ time.sleep(0.001)
+ processed.clear()
+
+ CUR = OBJECT_COUNT
+
+ for writer in writers:
+ writer.join()
+
+ for obj_idx, obj in enumerate(objects):
+ assert (
+ len(obj.__dict__) == THREAD_COUNT
+ ), f"{len(obj.__dict__)} {obj.__dict__!r} {obj_idx}"
+ for i in range(THREAD_COUNT):
+ assert f"a{i:02}" in obj.__dict__, f"a{i:02} missing at {obj_idx}"
+
+ def test_racing_set_dict(self):
+ """Races assigning to __dict__ should be thread safe"""
+
+ def f(): pass
+ l = []
+ THREAD_COUNT = 10
+ class MyDict(dict): pass
+
+ def writer_func(l):
+ for i in range(1000):
+ d = MyDict()
+ l.append(weakref.ref(d))
+ f.__dict__ = d
+
+ lists = []
+ writers = []
+ for x in range(THREAD_COUNT):
+ thread_list = []
+ lists.append(thread_list)
+ writer = Thread(target=partial(writer_func, thread_list))
+ writers.append(writer)
+
+ for writer in writers:
+ writer.start()
+
+ for writer in writers:
+ writer.join()
+
+ f.__dict__ = {}
+ gc.collect()
+
+ for thread_list in lists:
+ for ref in thread_list:
+ self.assertIsNone(ref())
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/Objects/dictobject.c b/Objects/dictobject.c
index 3e662e0..b0fce09 100644
--- a/Objects/dictobject.c
+++ b/Objects/dictobject.c
@@ -924,16 +924,15 @@ new_dict(PyInterpreterState *interp,
return (PyObject *)mp;
}
-/* Consumes a reference to the keys object */
static PyObject *
new_dict_with_shared_keys(PyInterpreterState *interp, PyDictKeysObject *keys)
{
size_t size = shared_keys_usable_size(keys);
PyDictValues *values = new_values(size);
if (values == NULL) {
- dictkeys_decref(interp, keys, false);
return PyErr_NoMemory();
}
+ dictkeys_incref(keys);
for (size_t i = 0; i < size; i++) {
values->values[i] = NULL;
}
@@ -6693,8 +6692,6 @@ materialize_managed_dict_lock_held(PyObject *obj)
{
_Py_CRITICAL_SECTION_ASSERT_OBJECT_LOCKED(obj);
- OBJECT_STAT_INC(dict_materialized_on_request);
-
PyDictValues *values = _PyObject_InlineValues(obj);
PyInterpreterState *interp = _PyInterpreterState_GET();
PyDictKeysObject *keys = CACHED_KEYS(Py_TYPE(obj));
@@ -7186,35 +7183,77 @@ _PyDict_DetachFromObject(PyDictObject *mp, PyObject *obj)
return 0;
}
-PyObject *
-PyObject_GenericGetDict(PyObject *obj, void *context)
+static inline PyObject *
+ensure_managed_dict(PyObject *obj)
{
- PyInterpreterState *interp = _PyInterpreterState_GET();
- PyTypeObject *tp = Py_TYPE(obj);
- PyDictObject *dict;
- if (_PyType_HasFeature(tp, Py_TPFLAGS_MANAGED_DICT)) {
- dict = _PyObject_GetManagedDict(obj);
- if (dict == NULL &&
- (tp->tp_flags & Py_TPFLAGS_INLINE_VALUES) &&
+ PyDictObject *dict = _PyObject_GetManagedDict(obj);
+ if (dict == NULL) {
+ PyTypeObject *tp = Py_TYPE(obj);
+ if ((tp->tp_flags & Py_TPFLAGS_INLINE_VALUES) &&
FT_ATOMIC_LOAD_UINT8(_PyObject_InlineValues(obj)->valid)) {
dict = _PyObject_MaterializeManagedDict(obj);
}
- else if (dict == NULL) {
- Py_BEGIN_CRITICAL_SECTION(obj);
-
+ else {
+#ifdef Py_GIL_DISABLED
// Check again that we're not racing with someone else creating the dict
+ Py_BEGIN_CRITICAL_SECTION(obj);
dict = _PyObject_GetManagedDict(obj);
- if (dict == NULL) {
- OBJECT_STAT_INC(dict_materialized_on_request);
- dictkeys_incref(CACHED_KEYS(tp));
- dict = (PyDictObject *)new_dict_with_shared_keys(interp, CACHED_KEYS(tp));
- FT_ATOMIC_STORE_PTR_RELEASE(_PyObject_ManagedDictPointer(obj)->dict,
- (PyDictObject *)dict);
+ if (dict != NULL) {
+ goto done;
}
+#endif
+ dict = (PyDictObject *)new_dict_with_shared_keys(_PyInterpreterState_GET(),
+ CACHED_KEYS(tp));
+ FT_ATOMIC_STORE_PTR_RELEASE(_PyObject_ManagedDictPointer(obj)->dict,
+ (PyDictObject *)dict);
+#ifdef Py_GIL_DISABLED
+done:
Py_END_CRITICAL_SECTION();
+#endif
}
- return Py_XNewRef((PyObject *)dict);
+ }
+ return (PyObject *)dict;
+}
+
+static inline PyObject *
+ensure_nonmanaged_dict(PyObject *obj, PyObject **dictptr)
+{
+ PyDictKeysObject *cached;
+
+ PyObject *dict = FT_ATOMIC_LOAD_PTR_ACQUIRE(*dictptr);
+ if (dict == NULL) {
+#ifdef Py_GIL_DISABLED
+ Py_BEGIN_CRITICAL_SECTION(obj);
+ dict = *dictptr;
+ if (dict != NULL) {
+ goto done;
+ }
+#endif
+ PyTypeObject *tp = Py_TYPE(obj);
+ if (_PyType_HasFeature(tp, Py_TPFLAGS_HEAPTYPE) && (cached = CACHED_KEYS(tp))) {
+ PyInterpreterState *interp = _PyInterpreterState_GET();
+ assert(!_PyType_HasFeature(tp, Py_TPFLAGS_INLINE_VALUES));
+ dict = new_dict_with_shared_keys(interp, cached);
+ }
+ else {
+ dict = PyDict_New();
+ }
+ FT_ATOMIC_STORE_PTR_RELEASE(*dictptr, dict);
+#ifdef Py_GIL_DISABLED
+done:
+ Py_END_CRITICAL_SECTION();
+#endif
+ }
+ return dict;
+}
+
+PyObject *
+PyObject_GenericGetDict(PyObject *obj, void *context)
+{
+ PyTypeObject *tp = Py_TYPE(obj);
+ if (_PyType_HasFeature(tp, Py_TPFLAGS_MANAGED_DICT)) {
+ return Py_XNewRef(ensure_managed_dict(obj));
}
else {
PyObject **dictptr = _PyObject_ComputedDictPointer(obj);
@@ -7223,65 +7262,28 @@ PyObject_GenericGetDict(PyObject *obj, void *context)
"This object has no __dict__");
return NULL;
}
- PyObject *dict = *dictptr;
- if (dict == NULL) {
- PyTypeObject *tp = Py_TYPE(obj);
- if (_PyType_HasFeature(tp, Py_TPFLAGS_HEAPTYPE) && CACHED_KEYS(tp)) {
- dictkeys_incref(CACHED_KEYS(tp));
- *dictptr = dict = new_dict_with_shared_keys(
- interp, CACHED_KEYS(tp));
- }
- else {
- *dictptr = dict = PyDict_New();
- }
- }
- return Py_XNewRef(dict);
+
+ return Py_XNewRef(ensure_nonmanaged_dict(obj, dictptr));
}
}
int
-_PyObjectDict_SetItem(PyTypeObject *tp, PyObject **dictptr,
+_PyObjectDict_SetItem(PyTypeObject *tp, PyObject *obj, PyObject **dictptr,
PyObject *key, PyObject *value)
{
PyObject *dict;
int res;
- PyDictKeysObject *cached;
- PyInterpreterState *interp = _PyInterpreterState_GET();
assert(dictptr != NULL);
- if ((tp->tp_flags & Py_TPFLAGS_HEAPTYPE) && (cached = CACHED_KEYS(tp))) {
- assert(dictptr != NULL);
- dict = *dictptr;
- if (dict == NULL) {
- assert(!_PyType_HasFeature(tp, Py_TPFLAGS_INLINE_VALUES));
- dictkeys_incref(cached);
- dict = new_dict_with_shared_keys(interp, cached);
- if (dict == NULL)
- return -1;
- *dictptr = dict;
- }
- if (value == NULL) {
- res = PyDict_DelItem(dict, key);
- }
- else {
- res = PyDict_SetItem(dict, key, value);
- }
- } else {
- dict = *dictptr;
- if (dict == NULL) {
- dict = PyDict_New();
- if (dict == NULL)
- return -1;
- *dictptr = dict;
- }
- if (value == NULL) {
- res = PyDict_DelItem(dict, key);
- } else {
- res = PyDict_SetItem(dict, key, value);
- }
+ dict = ensure_nonmanaged_dict(obj, dictptr);
+ if (dict == NULL) {
+ return -1;
}
+ Py_BEGIN_CRITICAL_SECTION(dict);
+ res = _PyDict_SetItem_LockHeld((PyDictObject *)dict, key, value);
ASSERT_CONSISTENT(dict);
+ Py_END_CRITICAL_SECTION();
return res;
}
diff --git a/Objects/object.c b/Objects/object.c
index effbd51..8ad0389 100644
--- a/Objects/object.c
+++ b/Objects/object.c
@@ -1731,7 +1731,7 @@ _PyObject_GenericSetAttrWithDict(PyObject *obj, PyObject *name,
goto done;
}
else {
- res = _PyObjectDict_SetItem(tp, dictptr, name, value);
+ res = _PyObjectDict_SetItem(tp, obj, dictptr, name, value);
}
}
else {
@@ -1789,7 +1789,9 @@ PyObject_GenericSetDict(PyObject *obj, PyObject *value, void *context)
"not a '%.200s'", Py_TYPE(value)->tp_name);
return -1;
}
+ Py_BEGIN_CRITICAL_SECTION(obj);
Py_XSETREF(*dictptr, Py_NewRef(value));
+ Py_END_CRITICAL_SECTION();
return 0;
}