summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorErlend Egeberg Aasland <erlend.aasland@innova.no>2021-08-24 12:24:09 (GMT)
committerGitHub <noreply@github.com>2021-08-24 12:24:09 (GMT)
commit9ed523159c7ba840dbf403e02498eeae1b5d3ed9 (patch)
treee194a41d3a15d9d72f3549f7f6fdc810bceb78cd
parent7179930ab5f5b2dea039023bec968aadc03e3775 (diff)
downloadcpython-9ed523159c7ba840dbf403e02498eeae1b5d3ed9.zip
cpython-9ed523159c7ba840dbf403e02498eeae1b5d3ed9.tar.gz
cpython-9ed523159c7ba840dbf403e02498eeae1b5d3ed9.tar.bz2
bpo-42064: Pass module state to `sqlite3` UDF callbacks (GH-27456)
- Establish common callback context struct - Convert UDF callbacks to fetch module state from callback context
-rw-r--r--Modules/_sqlite/connection.c89
-rw-r--r--Modules/_sqlite/connection.h6
2 files changed, 64 insertions, 31 deletions
diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c
index 0645367..8ad5f5f 100644
--- a/Modules/_sqlite/connection.c
+++ b/Modules/_sqlite/connection.c
@@ -612,8 +612,10 @@ set_sqlite_error(sqlite3_context *context, const char *msg)
else {
sqlite3_result_error(context, msg, -1);
}
- pysqlite_state *state = pysqlite_get_state(NULL);
- if (state->enable_callback_tracebacks) {
+ callback_context *ctx = (callback_context *)sqlite3_user_data(context);
+ assert(ctx != NULL);
+ assert(ctx->state != NULL);
+ if (ctx->state->enable_callback_tracebacks) {
PyErr_Print();
}
else {
@@ -625,7 +627,6 @@ static void
_pysqlite_func_callback(sqlite3_context *context, int argc, sqlite3_value **argv)
{
PyObject* args;
- PyObject* py_func;
PyObject* py_retval = NULL;
int ok;
@@ -633,11 +634,11 @@ _pysqlite_func_callback(sqlite3_context *context, int argc, sqlite3_value **argv
threadstate = PyGILState_Ensure();
- py_func = (PyObject*)sqlite3_user_data(context);
-
args = _pysqlite_build_py_params(context, argc, argv);
if (args) {
- py_retval = PyObject_CallObject(py_func, args);
+ callback_context *ctx = (callback_context *)sqlite3_user_data(context);
+ assert(ctx != NULL);
+ py_retval = PyObject_CallObject(ctx->callable, args);
Py_DECREF(args);
}
@@ -657,7 +658,6 @@ static void _pysqlite_step_callback(sqlite3_context *context, int argc, sqlite3_
{
PyObject* args;
PyObject* function_result = NULL;
- PyObject* aggregate_class;
PyObject** aggregate_instance;
PyObject* stepmethod = NULL;
@@ -665,12 +665,12 @@ static void _pysqlite_step_callback(sqlite3_context *context, int argc, sqlite3_
threadstate = PyGILState_Ensure();
- aggregate_class = (PyObject*)sqlite3_user_data(context);
-
aggregate_instance = (PyObject**)sqlite3_aggregate_context(context, sizeof(PyObject*));
if (*aggregate_instance == NULL) {
- *aggregate_instance = _PyObject_CallNoArg(aggregate_class);
+ callback_context *ctx = (callback_context *)sqlite3_user_data(context);
+ assert(ctx != NULL);
+ *aggregate_instance = _PyObject_CallNoArg(ctx->callable);
if (!*aggregate_instance) {
set_sqlite_error(context,
"user-defined aggregate's '__init__' method raised error");
@@ -784,14 +784,35 @@ static void _pysqlite_drop_unused_cursor_references(pysqlite_Connection* self)
Py_SETREF(self->cursors, new_list);
}
-static void _destructor(void* args)
+static callback_context *
+create_callback_context(pysqlite_state *state, PyObject *callable)
{
- // This function may be called without the GIL held, so we need to ensure
- // that we destroy 'args' with the GIL
- PyGILState_STATE gstate;
- gstate = PyGILState_Ensure();
- Py_DECREF((PyObject*)args);
+ PyGILState_STATE gstate = PyGILState_Ensure();
+ callback_context *ctx = PyMem_Malloc(sizeof(callback_context));
+ if (ctx != NULL) {
+ ctx->callable = Py_NewRef(callable);
+ ctx->state = state;
+ }
PyGILState_Release(gstate);
+ return ctx;
+}
+
+static void
+free_callback_context(callback_context *ctx)
+{
+ if (ctx != NULL) {
+ // This function may be called without the GIL held, so we need to
+ // ensure that we destroy 'ctx' with the GIL held.
+ PyGILState_STATE gstate = PyGILState_Ensure();
+ Py_DECREF(ctx->callable);
+ PyMem_Free(ctx);
+ PyGILState_Release(gstate);
+ }
+}
+
+static void _destructor(void* args)
+{
+ free_callback_context((callback_context *)args);
}
/*[clinic input]
@@ -833,11 +854,11 @@ pysqlite_connection_create_function_impl(pysqlite_Connection *self,
flags |= SQLITE_DETERMINISTIC;
#endif
}
- rc = sqlite3_create_function_v2(self->db,
- name,
- narg,
- flags,
- (void*)Py_NewRef(func),
+ callback_context *ctx = create_callback_context(self->state, func);
+ if (ctx == NULL) {
+ return NULL;
+ }
+ rc = sqlite3_create_function_v2(self->db, name, narg, flags, ctx,
_pysqlite_func_callback,
NULL,
NULL,
@@ -873,11 +894,12 @@ pysqlite_connection_create_aggregate_impl(pysqlite_Connection *self,
return NULL;
}
- rc = sqlite3_create_function_v2(self->db,
- name,
- n_arg,
- SQLITE_UTF8,
- (void*)Py_NewRef(aggregate_class),
+ callback_context *ctx = create_callback_context(self->state,
+ aggregate_class);
+ if (ctx == NULL) {
+ return NULL;
+ }
+ rc = sqlite3_create_function_v2(self->db, name, n_arg, SQLITE_UTF8, ctx,
0,
&_pysqlite_step_callback,
&_pysqlite_final_callback,
@@ -1439,7 +1461,6 @@ pysqlite_collation_callback(
int text1_length, const void* text1_data,
int text2_length, const void* text2_data)
{
- PyObject* callback = (PyObject*)context;
PyObject* string1 = 0;
PyObject* string2 = 0;
PyGILState_STATE gilstate;
@@ -1459,8 +1480,10 @@ pysqlite_collation_callback(
goto finally; /* failed to allocate strings */
}
+ callback_context *ctx = (callback_context *)context;
+ assert(ctx != NULL);
PyObject *args[] = { string1, string2 }; // Borrowed refs.
- retval = PyObject_Vectorcall(callback, args, 2, NULL);
+ retval = PyObject_Vectorcall(ctx->callable, args, 2, NULL);
if (retval == NULL) {
/* execution failed */
goto finally;
@@ -1690,6 +1713,7 @@ pysqlite_connection_create_collation_impl(pysqlite_Connection *self,
return NULL;
}
+ callback_context *ctx = NULL;
int rc;
int flags = SQLITE_UTF8;
if (callable == Py_None) {
@@ -1701,8 +1725,11 @@ pysqlite_connection_create_collation_impl(pysqlite_Connection *self,
PyErr_SetString(PyExc_TypeError, "parameter must be callable");
return NULL;
}
- rc = sqlite3_create_collation_v2(self->db, name, flags,
- Py_NewRef(callable),
+ ctx = create_callback_context(self->state, callable);
+ if (ctx == NULL) {
+ return NULL;
+ }
+ rc = sqlite3_create_collation_v2(self->db, name, flags, ctx,
&pysqlite_collation_callback,
&_destructor);
}
@@ -1713,7 +1740,7 @@ pysqlite_connection_create_collation_impl(pysqlite_Connection *self,
* the context before returning.
*/
if (callable != Py_None) {
- Py_DECREF(callable);
+ free_callback_context(ctx);
}
_pysqlite_seterror(self->state, self->db);
return NULL;
diff --git a/Modules/_sqlite/connection.h b/Modules/_sqlite/connection.h
index 4f08a6d..11b3a80 100644
--- a/Modules/_sqlite/connection.h
+++ b/Modules/_sqlite/connection.h
@@ -32,6 +32,12 @@
#include "sqlite3.h"
+typedef struct _callback_context
+{
+ PyObject *callable;
+ pysqlite_state *state;
+} callback_context;
+
typedef struct
{
PyObject_HEAD