summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSam Gross <colesbury@gmail.com>2024-08-02 13:32:08 (GMT)
committerGitHub <noreply@github.com>2024-08-02 13:32:08 (GMT)
commitb5e6fb39a246bf7ee470d58632cdf588bb9d0298 (patch)
tree60b36ecce3d74af6816851715609d4322ccbb2b7
parentfb864c76cd5e450e789a7b4095832e118cc49a39 (diff)
downloadcpython-b5e6fb39a246bf7ee470d58632cdf588bb9d0298.zip
cpython-b5e6fb39a246bf7ee470d58632cdf588bb9d0298.tar.gz
cpython-b5e6fb39a246bf7ee470d58632cdf588bb9d0298.tar.bz2
gh-120974: Make asyncio `swap_current_task` safe in free-threaded build (#122317)
* gh-120974: Make asyncio `swap_current_task` safe in free-threaded build
-rw-r--r--Include/internal/pycore_dict.h7
-rw-r--r--Modules/_asynciomodule.c37
-rw-r--r--Objects/dictobject.c54
3 files changed, 67 insertions, 31 deletions
diff --git a/Include/internal/pycore_dict.h b/Include/internal/pycore_dict.h
index fc304ac..a84246e 100644
--- a/Include/internal/pycore_dict.h
+++ b/Include/internal/pycore_dict.h
@@ -108,8 +108,13 @@ PyAPI_FUNC(PyObject *)_PyDict_LoadGlobal(PyDictObject *, PyDictObject *, PyObjec
/* Consumes references to key and value */
PyAPI_FUNC(int) _PyDict_SetItem_Take2(PyDictObject *op, PyObject *key, PyObject *value);
extern int _PyDict_SetItem_LockHeld(PyDictObject *dict, PyObject *name, PyObject *value);
-extern int _PyDict_GetItemRef_Unicode_LockHeld(PyDictObject *op, PyObject *key, PyObject **result);
+// Export for '_asyncio' shared extension
+PyAPI_FUNC(int) _PyDict_SetItem_KnownHash_LockHeld(PyDictObject *mp, PyObject *key,
+ PyObject *value, Py_hash_t hash);
+// Export for '_asyncio' shared extension
+PyAPI_FUNC(int) _PyDict_GetItemRef_KnownHash_LockHeld(PyDictObject *op, PyObject *key, Py_hash_t hash, PyObject **result);
extern int _PyDict_GetItemRef_KnownHash(PyDictObject *op, PyObject *key, Py_hash_t hash, PyObject **result);
+extern int _PyDict_GetItemRef_Unicode_LockHeld(PyDictObject *op, PyObject *key, PyObject **result);
extern int _PyObjectDict_SetItem(PyTypeObject *tp, PyObject *obj, PyObject **dictptr, PyObject *name, PyObject *value);
extern int _PyDict_Pop_KnownHash(
diff --git a/Modules/_asynciomodule.c b/Modules/_asynciomodule.c
index 873c17c..c6eb43f 100644
--- a/Modules/_asynciomodule.c
+++ b/Modules/_asynciomodule.c
@@ -2027,6 +2027,24 @@ leave_task(asyncio_state *state, PyObject *loop, PyObject *task)
}
static PyObject *
+swap_current_task_lock_held(PyDictObject *current_tasks, PyObject *loop,
+ Py_hash_t hash, PyObject *task)
+{
+ PyObject *prev_task;
+ if (_PyDict_GetItemRef_KnownHash_LockHeld(current_tasks, loop, hash, &prev_task) < 0) {
+ return NULL;
+ }
+ if (_PyDict_SetItem_KnownHash_LockHeld(current_tasks, loop, task, hash) < 0) {
+ Py_XDECREF(prev_task);
+ return NULL;
+ }
+ if (prev_task == NULL) {
+ Py_RETURN_NONE;
+ }
+ return prev_task;
+}
+
+static PyObject *
swap_current_task(asyncio_state *state, PyObject *loop, PyObject *task)
{
PyObject *prev_task;
@@ -2041,24 +2059,15 @@ swap_current_task(asyncio_state *state, PyObject *loop, PyObject *task)
return prev_task;
}
- Py_hash_t hash;
- hash = PyObject_Hash(loop);
+ Py_hash_t hash = PyObject_Hash(loop);
if (hash == -1) {
return NULL;
}
- prev_task = _PyDict_GetItem_KnownHash(state->current_tasks, loop, hash);
- if (prev_task == NULL) {
- if (PyErr_Occurred()) {
- return NULL;
- }
- prev_task = Py_None;
- }
- Py_INCREF(prev_task);
- if (_PyDict_SetItem_KnownHash(state->current_tasks, loop, task, hash) == -1) {
- Py_DECREF(prev_task);
- return NULL;
- }
+ PyDictObject *current_tasks = (PyDictObject *)state->current_tasks;
+ Py_BEGIN_CRITICAL_SECTION(current_tasks);
+ prev_task = swap_current_task_lock_held(current_tasks, loop, hash, task);
+ Py_END_CRITICAL_SECTION();
return prev_task;
}
diff --git a/Objects/dictobject.c b/Objects/dictobject.c
index 6a16a04..3e9f982 100644
--- a/Objects/dictobject.c
+++ b/Objects/dictobject.c
@@ -2221,6 +2221,29 @@ _PyDict_GetItem_KnownHash(PyObject *op, PyObject *key, Py_hash_t hash)
* exception occurred.
*/
int
+_PyDict_GetItemRef_KnownHash_LockHeld(PyDictObject *op, PyObject *key,
+ Py_hash_t hash, PyObject **result)
+{
+ PyObject *value;
+ Py_ssize_t ix = _Py_dict_lookup(op, key, hash, &value);
+ assert(ix >= 0 || value == NULL);
+ if (ix == DKIX_ERROR) {
+ *result = NULL;
+ return -1;
+ }
+ if (value == NULL) {
+ *result = NULL;
+ return 0; // missing key
+ }
+ *result = Py_NewRef(value);
+ return 1; // key is present
+}
+
+/* Gets an item and provides a new reference if the value is present.
+ * Returns 1 if the key is present, 0 if the key is missing, and -1 if an
+ * exception occurred.
+*/
+int
_PyDict_GetItemRef_KnownHash(PyDictObject *op, PyObject *key, Py_hash_t hash, PyObject **result)
{
PyObject *value;
@@ -2460,11 +2483,21 @@ setitem_lock_held(PyDictObject *mp, PyObject *key, PyObject *value)
int
-_PyDict_SetItem_KnownHash(PyObject *op, PyObject *key, PyObject *value,
- Py_hash_t hash)
+_PyDict_SetItem_KnownHash_LockHeld(PyDictObject *mp, PyObject *key, PyObject *value,
+ Py_hash_t hash)
{
- PyDictObject *mp;
+ PyInterpreterState *interp = _PyInterpreterState_GET();
+ if (mp->ma_keys == Py_EMPTY_KEYS) {
+ return insert_to_emptydict(interp, mp, Py_NewRef(key), hash, Py_NewRef(value));
+ }
+ /* insertdict() handles any resizing that might be necessary */
+ return insertdict(interp, mp, Py_NewRef(key), hash, Py_NewRef(value));
+}
+int
+_PyDict_SetItem_KnownHash(PyObject *op, PyObject *key, PyObject *value,
+ Py_hash_t hash)
+{
if (!PyDict_Check(op)) {
PyErr_BadInternalCall();
return -1;
@@ -2472,21 +2505,10 @@ _PyDict_SetItem_KnownHash(PyObject *op, PyObject *key, PyObject *value,
assert(key);
assert(value);
assert(hash != -1);
- mp = (PyDictObject *)op;
int res;
- PyInterpreterState *interp = _PyInterpreterState_GET();
-
- Py_BEGIN_CRITICAL_SECTION(mp);
-
- if (mp->ma_keys == Py_EMPTY_KEYS) {
- res = insert_to_emptydict(interp, mp, Py_NewRef(key), hash, Py_NewRef(value));
- }
- else {
- /* insertdict() handles any resizing that might be necessary */
- res = insertdict(interp, mp, Py_NewRef(key), hash, Py_NewRef(value));
- }
-
+ Py_BEGIN_CRITICAL_SECTION(op);
+ res = _PyDict_SetItem_KnownHash_LockHeld((PyDictObject *)op, key, value, hash);
Py_END_CRITICAL_SECTION();
return res;
}