diff options
-rw-r--r-- | Modules/_sqlite/connection.c | 176 | ||||
-rw-r--r-- | Modules/_sqlite/connection.h | 8 |
2 files changed, 111 insertions, 73 deletions
diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c index 0780d41..bf80337 100644 --- a/Modules/_sqlite/connection.c +++ b/Modules/_sqlite/connection.c @@ -55,6 +55,9 @@ static const char * const begin_statements[] = { static int pysqlite_connection_set_isolation_level(pysqlite_Connection* self, PyObject* isolation_level, void *Py_UNUSED(ignored)); static void _pysqlite_drop_unused_cursor_references(pysqlite_Connection* self); +static void free_callback_context(callback_context *ctx); +static void set_callback_context(callback_context **ctx_pp, + callback_context *ctx); static PyObject * new_statement_cache(pysqlite_Connection *self, int maxsize) @@ -170,9 +173,9 @@ pysqlite_connection_init_impl(pysqlite_Connection *self, self->thread_ident = PyThread_get_thread_ident(); self->check_same_thread = check_same_thread; - self->function_pinboard_trace_callback = NULL; - self->function_pinboard_progress_handler = NULL; - self->function_pinboard_authorizer_cb = NULL; + set_callback_context(&self->trace_ctx, NULL); + set_callback_context(&self->progress_ctx, NULL); + set_callback_context(&self->authorizer_ctx, NULL); self->Warning = state->Warning; self->Error = state->Error; @@ -216,6 +219,13 @@ pysqlite_do_all_statements(pysqlite_Connection *self) } } +#define VISIT_CALLBACK_CONTEXT(ctx) \ +do { \ + if (ctx) { \ + Py_VISIT(ctx->callable); \ + } \ +} while (0) + static int connection_traverse(pysqlite_Connection *self, visitproc visit, void *arg) { @@ -225,12 +235,21 @@ connection_traverse(pysqlite_Connection *self, visitproc visit, void *arg) Py_VISIT(self->cursors); Py_VISIT(self->row_factory); Py_VISIT(self->text_factory); - Py_VISIT(self->function_pinboard_trace_callback); - Py_VISIT(self->function_pinboard_progress_handler); - Py_VISIT(self->function_pinboard_authorizer_cb); + VISIT_CALLBACK_CONTEXT(self->trace_ctx); + VISIT_CALLBACK_CONTEXT(self->progress_ctx); + VISIT_CALLBACK_CONTEXT(self->authorizer_ctx); +#undef VISIT_CALLBACK_CONTEXT return 0; } +static inline void +clear_callback_context(callback_context *ctx) +{ + if (ctx != NULL) { + Py_CLEAR(ctx->callable); + } +} + static int connection_clear(pysqlite_Connection *self) { @@ -239,9 +258,9 @@ connection_clear(pysqlite_Connection *self) Py_CLEAR(self->cursors); Py_CLEAR(self->row_factory); Py_CLEAR(self->text_factory); - Py_CLEAR(self->function_pinboard_trace_callback); - Py_CLEAR(self->function_pinboard_progress_handler); - Py_CLEAR(self->function_pinboard_authorizer_cb); + clear_callback_context(self->trace_ctx); + clear_callback_context(self->progress_ctx); + clear_callback_context(self->authorizer_ctx); return 0; } @@ -256,6 +275,14 @@ connection_close(pysqlite_Connection *self) } static void +free_callback_contexts(pysqlite_Connection *self) +{ + set_callback_context(&self->trace_ctx, NULL); + set_callback_context(&self->progress_ctx, NULL); + set_callback_context(&self->authorizer_ctx, NULL); +} + +static void connection_dealloc(pysqlite_Connection *self) { PyTypeObject *tp = Py_TYPE(self); @@ -264,6 +291,7 @@ connection_dealloc(pysqlite_Connection *self) /* Clean up if user has not called .close() explicitly. */ connection_close(self); + free_callback_contexts(self); tp->tp_free(self); Py_DECREF(tp); @@ -600,6 +628,19 @@ error: return NULL; } +static void +print_or_clear_traceback(callback_context *ctx) +{ + assert(ctx != NULL); + assert(ctx->state != NULL); + if (ctx->state->enable_callback_tracebacks) { + PyErr_Print(); + } + else { + PyErr_Clear(); + } +} + // Checks the Python exception and sets the appropriate SQLite error code. static void set_sqlite_error(sqlite3_context *context, const char *msg) @@ -615,14 +656,7 @@ set_sqlite_error(sqlite3_context *context, const char *msg) sqlite3_result_error(context, msg, -1); } callback_context *ctx = (callback_context *)sqlite3_user_data(context); - assert(ctx != NULL); - assert(ctx->state != NULL); - if (ctx->state->enable_callback_tracebacks) { - PyErr_Print(); - } - else { - PyErr_Clear(); - } + print_or_clear_traceback(ctx); } static void @@ -796,11 +830,22 @@ static void free_callback_context(callback_context *ctx) { assert(ctx != NULL); - Py_DECREF(ctx->callable); + Py_XDECREF(ctx->callable); PyMem_Free(ctx); } static void +set_callback_context(callback_context **ctx_pp, callback_context *ctx) +{ + assert(ctx_pp != NULL); + callback_context *tmp = *ctx_pp; + *ctx_pp = ctx; + if (tmp != NULL) { + free_callback_context(tmp); + } +} + +static void destructor_callback(void *ctx) { if (ctx != NULL) { @@ -917,33 +962,22 @@ authorizer_callback(void *ctx, int action, const char *arg1, PyGILState_STATE gilstate = PyGILState_Ensure(); PyObject *ret; - int rc; + int rc = SQLITE_DENY; - ret = PyObject_CallFunction((PyObject*)ctx, "issss", action, arg1, arg2, - dbname, access_attempt_source); + assert(ctx != NULL); + PyObject *callable = ((callback_context *)ctx)->callable; + ret = PyObject_CallFunction(callable, "issss", action, arg1, arg2, dbname, + access_attempt_source); if (ret == NULL) { - pysqlite_state *state = pysqlite_get_state(NULL); - if (state->enable_callback_tracebacks) { - PyErr_Print(); - } - else { - PyErr_Clear(); - } - + print_or_clear_traceback(ctx); rc = SQLITE_DENY; } else { if (PyLong_Check(ret)) { rc = _PyLong_AsInt(ret); if (rc == -1 && PyErr_Occurred()) { - pysqlite_state *state = pysqlite_get_state(NULL); - if (state->enable_callback_tracebacks) { - PyErr_Print(); - } - else { - PyErr_Clear(); - } + print_or_clear_traceback(ctx); rc = SQLITE_DENY; } } @@ -964,8 +998,10 @@ progress_callback(void *ctx) int rc; PyObject *ret; - ret = _PyObject_CallNoArg((PyObject*)ctx); + assert(ctx != NULL); + PyObject *callable = ((callback_context *)ctx)->callable; + ret = _PyObject_CallNoArg(callable); if (!ret) { /* abort query if error occurred */ rc = -1; @@ -975,13 +1011,7 @@ progress_callback(void *ctx) Py_DECREF(ret); } if (rc < 0) { - pysqlite_state *state = pysqlite_get_state(NULL); - if (state->enable_callback_tracebacks) { - PyErr_Print(); - } - else { - PyErr_Clear(); - } + print_or_clear_traceback(ctx); } PyGILState_Release(gilstate); @@ -1015,21 +1045,18 @@ trace_callback(void *ctx, const char *statement_string) PyObject *ret = NULL; py_statement = PyUnicode_DecodeUTF8(statement_string, strlen(statement_string), "replace"); + assert(ctx != NULL); if (py_statement) { - ret = PyObject_CallOneArg((PyObject*)ctx, py_statement); + PyObject *callable = ((callback_context *)ctx)->callable; + ret = PyObject_CallOneArg(callable, py_statement); Py_DECREF(py_statement); } if (ret) { Py_DECREF(ret); - } else { - pysqlite_state *state = pysqlite_get_state(NULL); - if (state->enable_callback_tracebacks) { - PyErr_Print(); - } - else { - PyErr_Clear(); - } + } + else { + print_or_clear_traceback(ctx); } PyGILState_Release(gilstate); @@ -1058,17 +1085,20 @@ pysqlite_connection_set_authorizer_impl(pysqlite_Connection *self, int rc; if (callable == Py_None) { rc = sqlite3_set_authorizer(self->db, NULL, NULL); - Py_XSETREF(self->function_pinboard_authorizer_cb, NULL); + set_callback_context(&self->authorizer_ctx, NULL); } else { - Py_INCREF(callable); - Py_XSETREF(self->function_pinboard_authorizer_cb, callable); - rc = sqlite3_set_authorizer(self->db, authorizer_callback, callable); + callback_context *ctx = create_callback_context(self->state, callable); + if (ctx == NULL) { + return NULL; + } + rc = sqlite3_set_authorizer(self->db, authorizer_callback, ctx); + set_callback_context(&self->authorizer_ctx, ctx); } if (rc != SQLITE_OK) { PyErr_SetString(self->OperationalError, "Error setting authorizer callback"); - Py_XSETREF(self->function_pinboard_authorizer_cb, NULL); + set_callback_context(&self->authorizer_ctx, NULL); return NULL; } Py_RETURN_NONE; @@ -1095,11 +1125,15 @@ pysqlite_connection_set_progress_handler_impl(pysqlite_Connection *self, if (callable == Py_None) { /* None clears the progress handler previously set */ sqlite3_progress_handler(self->db, 0, 0, (void*)0); - Py_XSETREF(self->function_pinboard_progress_handler, NULL); - } else { - sqlite3_progress_handler(self->db, n, progress_callback, callable); - Py_INCREF(callable); - Py_XSETREF(self->function_pinboard_progress_handler, callable); + set_callback_context(&self->progress_ctx, NULL); + } + else { + callback_context *ctx = create_callback_context(self->state, callable); + if (ctx == NULL) { + return NULL; + } + sqlite3_progress_handler(self->db, n, progress_callback, ctx); + set_callback_context(&self->progress_ctx, ctx); } Py_RETURN_NONE; } @@ -1136,15 +1170,19 @@ pysqlite_connection_set_trace_callback_impl(pysqlite_Connection *self, #else sqlite3_trace(self->db, 0, (void*)0); #endif - Py_XSETREF(self->function_pinboard_trace_callback, NULL); - } else { + set_callback_context(&self->trace_ctx, NULL); + } + else { + callback_context *ctx = create_callback_context(self->state, callable); + if (ctx == NULL) { + return NULL; + } #ifdef HAVE_TRACE_V2 - sqlite3_trace_v2(self->db, SQLITE_TRACE_STMT, trace_callback, callable); + sqlite3_trace_v2(self->db, SQLITE_TRACE_STMT, trace_callback, ctx); #else - sqlite3_trace(self->db, trace_callback, callable); + sqlite3_trace(self->db, trace_callback, ctx); #endif - Py_INCREF(callable); - Py_XSETREF(self->function_pinboard_trace_callback, callable); + set_callback_context(&self->trace_ctx, ctx); } Py_RETURN_NONE; diff --git a/Modules/_sqlite/connection.h b/Modules/_sqlite/connection.h index 11b3a80..c4cec85 100644 --- a/Modules/_sqlite/connection.h +++ b/Modules/_sqlite/connection.h @@ -82,10 +82,10 @@ typedef struct */ PyObject* text_factory; - /* remember references to object used in trace_callback/progress_handler/authorizer_cb */ - PyObject* function_pinboard_trace_callback; - PyObject* function_pinboard_progress_handler; - PyObject* function_pinboard_authorizer_cb; + // Remember contexts used by the trace, progress, and authoriser callbacks + callback_context *trace_ctx; + callback_context *progress_ctx; + callback_context *authorizer_ctx; /* Exception objects: borrowed refs. */ PyObject* Warning; |