diff options
author | Erlend Egeberg Aasland <erlend.aasland@innova.no> | 2022-04-12 00:55:59 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-04-12 00:55:59 (GMT) |
commit | 9ebcece82fe11b87cc3d6e6b4c439aab9e3ab1e6 (patch) | |
tree | ef6b3c2d043f9b85ed4b15aa684eab941e25347f /Modules/_sqlite/connection.c | |
parent | f45aa8f304a12990c2ca687f2088f04b07906033 (diff) | |
download | cpython-9ebcece82fe11b87cc3d6e6b4c439aab9e3ab1e6.zip cpython-9ebcece82fe11b87cc3d6e6b4c439aab9e3ab1e6.tar.gz cpython-9ebcece82fe11b87cc3d6e6b4c439aab9e3ab1e6.tar.bz2 |
gh-79097: Add support for aggregate window functions in sqlite3 (GH-20903)
Diffstat (limited to 'Modules/_sqlite/connection.c')
-rw-r--r-- | Modules/_sqlite/connection.c | 179 |
1 files changed, 172 insertions, 7 deletions
diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c index 9d187cf..d7c0a9e 100644 --- a/Modules/_sqlite/connection.c +++ b/Modules/_sqlite/connection.c @@ -33,6 +33,10 @@ #define HAVE_TRACE_V2 #endif +#if SQLITE_VERSION_NUMBER >= 3025000 +#define HAVE_WINDOW_FUNCTIONS +#endif + static const char * get_isolation_level(const char *level) { @@ -799,7 +803,7 @@ final_callback(sqlite3_context *context) goto error; } - /* Keep the exception (if any) of the last call to step() */ + // Keep the exception (if any) of the last call to step, value, or inverse PyErr_Fetch(&exception, &value, &tb); callback_context *ctx = (callback_context *)sqlite3_user_data(context); @@ -814,13 +818,20 @@ final_callback(sqlite3_context *context) Py_DECREF(function_result); } if (!ok) { - set_sqlite_error(context, - "user-defined aggregate's 'finalize' method raised error"); - } + int attr_err = PyErr_ExceptionMatches(PyExc_AttributeError); + _PyErr_ChainExceptions(exception, value, tb); - /* Restore the exception (if any) of the last call to step(), - but clear also the current exception if finalize() failed */ - PyErr_Restore(exception, value, tb); + /* Note: contrary to the step, value, and inverse callbacks, SQLite + * does _not_, as of SQLite 3.38.0, propagate errors to sqlite3_step() + * from the finalize callback. This implies that execute*() will not + * raise OperationalError, as it normally would. */ + set_sqlite_error(context, attr_err + ? "user-defined aggregate's 'finalize' method not defined" + : "user-defined aggregate's 'finalize' method raised error"); + } + else { + PyErr_Restore(exception, value, tb); + } error: PyGILState_Release(threadstate); @@ -968,6 +979,159 @@ pysqlite_connection_create_function_impl(pysqlite_Connection *self, Py_RETURN_NONE; } +#ifdef HAVE_WINDOW_FUNCTIONS +/* + * Regarding the 'inverse' aggregate callback: + * This method is only required by window aggregate functions, not + * ordinary aggregate function implementations. It is invoked to remove + * a row from the current window. The function arguments, if any, + * correspond to the row being removed. + */ +static void +inverse_callback(sqlite3_context *context, int argc, sqlite3_value **params) +{ + PyGILState_STATE gilstate = PyGILState_Ensure(); + + callback_context *ctx = (callback_context *)sqlite3_user_data(context); + assert(ctx != NULL); + + int size = sizeof(PyObject *); + PyObject **cls = (PyObject **)sqlite3_aggregate_context(context, size); + assert(cls != NULL); + assert(*cls != NULL); + + PyObject *method = PyObject_GetAttr(*cls, ctx->state->str_inverse); + if (method == NULL) { + set_sqlite_error(context, + "user-defined aggregate's 'inverse' method not defined"); + goto exit; + } + + PyObject *args = _pysqlite_build_py_params(context, argc, params); + if (args == NULL) { + set_sqlite_error(context, + "unable to build arguments for user-defined aggregate's " + "'inverse' method"); + goto exit; + } + + PyObject *res = PyObject_CallObject(method, args); + Py_DECREF(args); + if (res == NULL) { + set_sqlite_error(context, + "user-defined aggregate's 'inverse' method raised error"); + goto exit; + } + Py_DECREF(res); + +exit: + Py_XDECREF(method); + PyGILState_Release(gilstate); +} + +/* + * Regarding the 'value' aggregate callback: + * This method is only required by window aggregate functions, not + * ordinary aggregate function implementations. It is invoked to return + * the current value of the aggregate. + */ +static void +value_callback(sqlite3_context *context) +{ + PyGILState_STATE gilstate = PyGILState_Ensure(); + + callback_context *ctx = (callback_context *)sqlite3_user_data(context); + assert(ctx != NULL); + + int size = sizeof(PyObject *); + PyObject **cls = (PyObject **)sqlite3_aggregate_context(context, size); + assert(cls != NULL); + assert(*cls != NULL); + + PyObject *res = PyObject_CallMethodNoArgs(*cls, ctx->state->str_value); + if (res == NULL) { + int attr_err = PyErr_ExceptionMatches(PyExc_AttributeError); + set_sqlite_error(context, attr_err + ? "user-defined aggregate's 'value' method not defined" + : "user-defined aggregate's 'value' method raised error"); + } + else { + int rc = _pysqlite_set_result(context, res); + Py_DECREF(res); + if (rc < 0) { + set_sqlite_error(context, + "unable to set result from user-defined aggregate's " + "'value' method"); + } + } + + PyGILState_Release(gilstate); +} + +/*[clinic input] +_sqlite3.Connection.create_window_function as create_window_function + + cls: defining_class + name: str + The name of the SQL aggregate window function to be created or + redefined. + num_params: int + The number of arguments the step and inverse methods takes. + aggregate_class: object + A class with step(), finalize(), value(), and inverse() methods. + Set to None to clear the window function. + / + +Creates or redefines an aggregate window function. Non-standard. +[clinic start generated code]*/ + +static PyObject * +create_window_function_impl(pysqlite_Connection *self, PyTypeObject *cls, + const char *name, int num_params, + PyObject *aggregate_class) +/*[clinic end generated code: output=5332cd9464522235 input=46d57a54225b5228]*/ +{ + if (sqlite3_libversion_number() < 3025000) { + PyErr_SetString(self->NotSupportedError, + "create_window_function() requires " + "SQLite 3.25.0 or higher"); + return NULL; + } + + if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) { + return NULL; + } + + int flags = SQLITE_UTF8; + int rc; + if (Py_IsNone(aggregate_class)) { + rc = sqlite3_create_window_function(self->db, name, num_params, flags, + 0, 0, 0, 0, 0, 0); + } + else { + callback_context *ctx = create_callback_context(cls, aggregate_class); + if (ctx == NULL) { + return NULL; + } + rc = sqlite3_create_window_function(self->db, name, num_params, flags, + ctx, + &step_callback, + &final_callback, + &value_callback, + &inverse_callback, + &destructor_callback); + } + + if (rc != SQLITE_OK) { + // Errors are not set on the database connection, so we cannot + // use _pysqlite_seterror(). + PyErr_SetString(self->ProgrammingError, sqlite3_errstr(rc)); + return NULL; + } + Py_RETURN_NONE; +} +#endif + /*[clinic input] _sqlite3.Connection.create_aggregate as pysqlite_connection_create_aggregate @@ -2092,6 +2256,7 @@ static PyMethodDef connection_methods[] = { GETLIMIT_METHODDEF SERIALIZE_METHODDEF DESERIALIZE_METHODDEF + CREATE_WINDOW_FUNCTION_METHODDEF {NULL, NULL} }; |