diff options
Diffstat (limited to 'Modules/_sqlite/connection.c')
-rw-r--r-- | Modules/_sqlite/connection.c | 99 |
1 files changed, 84 insertions, 15 deletions
diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c index d52bea4..28bd647 100644 --- a/Modules/_sqlite/connection.c +++ b/Modules/_sqlite/connection.c @@ -677,7 +677,7 @@ void _pysqlite_final_callback(sqlite3_context* context) { PyObject* function_result; PyObject** aggregate_instance; - PyObject* aggregate_class; + _Py_IDENTIFIER(finalize); int ok; #ifdef WITH_THREAD @@ -686,8 +686,6 @@ void _pysqlite_final_callback(sqlite3_context* context) threadstate = PyGILState_Ensure(); #endif - aggregate_class = (PyObject*)sqlite3_user_data(context); - aggregate_instance = (PyObject**)sqlite3_aggregate_context(context, sizeof(PyObject*)); if (!*aggregate_instance) { /* this branch is executed if there was an exception in the aggregate's @@ -696,7 +694,7 @@ void _pysqlite_final_callback(sqlite3_context* context) goto error; } - function_result = PyObject_CallMethod(*aggregate_instance, "finalize", ""); + function_result = _PyObject_CallMethodId(*aggregate_instance, &PyId_finalize, ""); Py_DECREF(*aggregate_instance); ok = 0; @@ -916,6 +914,38 @@ static int _progress_handler(void* user_arg) return rc; } +static void _trace_callback(void* user_arg, const char* statement_string) +{ + PyObject *py_statement = NULL; + PyObject *ret = NULL; + +#ifdef WITH_THREAD + PyGILState_STATE gilstate; + + gilstate = PyGILState_Ensure(); +#endif + py_statement = PyUnicode_DecodeUTF8(statement_string, + strlen(statement_string), "replace"); + if (py_statement) { + ret = PyObject_CallFunctionObjArgs((PyObject*)user_arg, py_statement, NULL); + Py_DECREF(py_statement); + } + + if (ret) { + Py_DECREF(ret); + } else { + if (_enable_callback_tracebacks) { + PyErr_Print(); + } else { + PyErr_Clear(); + } + } + +#ifdef WITH_THREAD + PyGILState_Release(gilstate); +#endif +} + static PyObject* pysqlite_connection_set_authorizer(pysqlite_Connection* self, PyObject* args, PyObject* kwargs) { PyObject* authorizer_cb; @@ -975,6 +1005,34 @@ static PyObject* pysqlite_connection_set_progress_handler(pysqlite_Connection* s return Py_None; } +static PyObject* pysqlite_connection_set_trace_callback(pysqlite_Connection* self, PyObject* args, PyObject* kwargs) +{ + PyObject* trace_callback; + + static char *kwlist[] = { "trace_callback", NULL }; + + if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) { + return NULL; + } + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O:set_trace_callback", + kwlist, &trace_callback)) { + return NULL; + } + + if (trace_callback == Py_None) { + /* None clears the trace callback previously set */ + sqlite3_trace(self->db, 0, (void*)0); + } else { + if (PyDict_SetItem(self->function_pinboard, trace_callback, Py_None) == -1) + return NULL; + sqlite3_trace(self->db, _trace_callback, trace_callback); + } + + Py_INCREF(Py_None); + return Py_None; +} + #ifdef HAVE_LOAD_EXTENSION static PyObject* pysqlite_enable_load_extension(pysqlite_Connection* self, PyObject* args) { @@ -1180,8 +1238,9 @@ PyObject* pysqlite_connection_execute(pysqlite_Connection* self, PyObject* args, PyObject* cursor = 0; PyObject* result = 0; PyObject* method = 0; + _Py_IDENTIFIER(cursor); - cursor = PyObject_CallMethod((PyObject*)self, "cursor", ""); + cursor = _PyObject_CallMethodId((PyObject*)self, &PyId_cursor, ""); if (!cursor) { goto error; } @@ -1209,8 +1268,9 @@ PyObject* pysqlite_connection_executemany(pysqlite_Connection* self, PyObject* a PyObject* cursor = 0; PyObject* result = 0; PyObject* method = 0; + _Py_IDENTIFIER(cursor); - cursor = PyObject_CallMethod((PyObject*)self, "cursor", ""); + cursor = _PyObject_CallMethodId((PyObject*)self, &PyId_cursor, ""); if (!cursor) { goto error; } @@ -1238,8 +1298,9 @@ PyObject* pysqlite_connection_executescript(pysqlite_Connection* self, PyObject* PyObject* cursor = 0; PyObject* result = 0; PyObject* method = 0; + _Py_IDENTIFIER(cursor); - cursor = PyObject_CallMethod((PyObject*)self, "cursor", ""); + cursor = _PyObject_CallMethodId((PyObject*)self, &PyId_cursor, ""); if (!cursor) { goto error; } @@ -1394,10 +1455,12 @@ pysqlite_connection_create_collation(pysqlite_Connection* self, PyObject* args) PyObject* uppercase_name = 0; PyObject* name; PyObject* retval; - Py_UNICODE* chk; Py_ssize_t i, len; + _Py_IDENTIFIER(upper); char *uppercase_name_str; int rc; + unsigned int kind; + void *data; if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) { goto finally; @@ -1407,17 +1470,21 @@ pysqlite_connection_create_collation(pysqlite_Connection* self, PyObject* args) goto finally; } - uppercase_name = PyObject_CallMethod(name, "upper", ""); + uppercase_name = _PyObject_CallMethodId(name, &PyId_upper, ""); if (!uppercase_name) { goto finally; } - len = PyUnicode_GET_SIZE(uppercase_name); - chk = PyUnicode_AS_UNICODE(uppercase_name); - for (i=0; i<len; i++, chk++) { - if ((*chk >= '0' && *chk <= '9') - || (*chk >= 'A' && *chk <= 'Z') - || (*chk == '_')) + if (PyUnicode_READY(uppercase_name)) + goto finally; + len = PyUnicode_GET_LENGTH(uppercase_name); + kind = PyUnicode_KIND(uppercase_name); + data = PyUnicode_DATA(uppercase_name); + for (i=0; i<len; i++) { + Py_UCS4 ch = PyUnicode_READ(kind, data, i); + if ((ch >= '0' && ch <= '9') + || (ch >= 'A' && ch <= 'Z') + || (ch == '_')) { continue; } else { @@ -1536,6 +1603,8 @@ static PyMethodDef connection_methods[] = { #endif {"set_progress_handler", (PyCFunction)pysqlite_connection_set_progress_handler, METH_VARARGS|METH_KEYWORDS, PyDoc_STR("Sets progress handler callback. Non-standard.")}, + {"set_trace_callback", (PyCFunction)pysqlite_connection_set_trace_callback, METH_VARARGS|METH_KEYWORDS, + PyDoc_STR("Sets a trace callback called for each SQL statement (passed as unicode). Non-standard.")}, {"execute", (PyCFunction)pysqlite_connection_execute, METH_VARARGS, PyDoc_STR("Executes a SQL statement. Non-standard.")}, {"executemany", (PyCFunction)pysqlite_connection_executemany, METH_VARARGS, |