summaryrefslogtreecommitdiffstats
path: root/Modules/_sqlite/connection.c
diff options
context:
space:
mode:
authorErlend Egeberg Aasland <erlend.aasland@innova.no>2022-04-12 00:55:59 (GMT)
committerGitHub <noreply@github.com>2022-04-12 00:55:59 (GMT)
commit9ebcece82fe11b87cc3d6e6b4c439aab9e3ab1e6 (patch)
treeef6b3c2d043f9b85ed4b15aa684eab941e25347f /Modules/_sqlite/connection.c
parentf45aa8f304a12990c2ca687f2088f04b07906033 (diff)
downloadcpython-9ebcece82fe11b87cc3d6e6b4c439aab9e3ab1e6.zip
cpython-9ebcece82fe11b87cc3d6e6b4c439aab9e3ab1e6.tar.gz
cpython-9ebcece82fe11b87cc3d6e6b4c439aab9e3ab1e6.tar.bz2
gh-79097: Add support for aggregate window functions in sqlite3 (GH-20903)
Diffstat (limited to 'Modules/_sqlite/connection.c')
-rw-r--r--Modules/_sqlite/connection.c179
1 files changed, 172 insertions, 7 deletions
diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c
index 9d187cf..d7c0a9e 100644
--- a/Modules/_sqlite/connection.c
+++ b/Modules/_sqlite/connection.c
@@ -33,6 +33,10 @@
#define HAVE_TRACE_V2
#endif
+#if SQLITE_VERSION_NUMBER >= 3025000
+#define HAVE_WINDOW_FUNCTIONS
+#endif
+
static const char *
get_isolation_level(const char *level)
{
@@ -799,7 +803,7 @@ final_callback(sqlite3_context *context)
goto error;
}
- /* Keep the exception (if any) of the last call to step() */
+ // Keep the exception (if any) of the last call to step, value, or inverse
PyErr_Fetch(&exception, &value, &tb);
callback_context *ctx = (callback_context *)sqlite3_user_data(context);
@@ -814,13 +818,20 @@ final_callback(sqlite3_context *context)
Py_DECREF(function_result);
}
if (!ok) {
- set_sqlite_error(context,
- "user-defined aggregate's 'finalize' method raised error");
- }
+ int attr_err = PyErr_ExceptionMatches(PyExc_AttributeError);
+ _PyErr_ChainExceptions(exception, value, tb);
- /* Restore the exception (if any) of the last call to step(),
- but clear also the current exception if finalize() failed */
- PyErr_Restore(exception, value, tb);
+ /* Note: contrary to the step, value, and inverse callbacks, SQLite
+ * does _not_, as of SQLite 3.38.0, propagate errors to sqlite3_step()
+ * from the finalize callback. This implies that execute*() will not
+ * raise OperationalError, as it normally would. */
+ set_sqlite_error(context, attr_err
+ ? "user-defined aggregate's 'finalize' method not defined"
+ : "user-defined aggregate's 'finalize' method raised error");
+ }
+ else {
+ PyErr_Restore(exception, value, tb);
+ }
error:
PyGILState_Release(threadstate);
@@ -968,6 +979,159 @@ pysqlite_connection_create_function_impl(pysqlite_Connection *self,
Py_RETURN_NONE;
}
+#ifdef HAVE_WINDOW_FUNCTIONS
+/*
+ * Regarding the 'inverse' aggregate callback:
+ * This method is only required by window aggregate functions, not
+ * ordinary aggregate function implementations. It is invoked to remove
+ * a row from the current window. The function arguments, if any,
+ * correspond to the row being removed.
+ */
+static void
+inverse_callback(sqlite3_context *context, int argc, sqlite3_value **params)
+{
+ PyGILState_STATE gilstate = PyGILState_Ensure();
+
+ callback_context *ctx = (callback_context *)sqlite3_user_data(context);
+ assert(ctx != NULL);
+
+ int size = sizeof(PyObject *);
+ PyObject **cls = (PyObject **)sqlite3_aggregate_context(context, size);
+ assert(cls != NULL);
+ assert(*cls != NULL);
+
+ PyObject *method = PyObject_GetAttr(*cls, ctx->state->str_inverse);
+ if (method == NULL) {
+ set_sqlite_error(context,
+ "user-defined aggregate's 'inverse' method not defined");
+ goto exit;
+ }
+
+ PyObject *args = _pysqlite_build_py_params(context, argc, params);
+ if (args == NULL) {
+ set_sqlite_error(context,
+ "unable to build arguments for user-defined aggregate's "
+ "'inverse' method");
+ goto exit;
+ }
+
+ PyObject *res = PyObject_CallObject(method, args);
+ Py_DECREF(args);
+ if (res == NULL) {
+ set_sqlite_error(context,
+ "user-defined aggregate's 'inverse' method raised error");
+ goto exit;
+ }
+ Py_DECREF(res);
+
+exit:
+ Py_XDECREF(method);
+ PyGILState_Release(gilstate);
+}
+
+/*
+ * Regarding the 'value' aggregate callback:
+ * This method is only required by window aggregate functions, not
+ * ordinary aggregate function implementations. It is invoked to return
+ * the current value of the aggregate.
+ */
+static void
+value_callback(sqlite3_context *context)
+{
+ PyGILState_STATE gilstate = PyGILState_Ensure();
+
+ callback_context *ctx = (callback_context *)sqlite3_user_data(context);
+ assert(ctx != NULL);
+
+ int size = sizeof(PyObject *);
+ PyObject **cls = (PyObject **)sqlite3_aggregate_context(context, size);
+ assert(cls != NULL);
+ assert(*cls != NULL);
+
+ PyObject *res = PyObject_CallMethodNoArgs(*cls, ctx->state->str_value);
+ if (res == NULL) {
+ int attr_err = PyErr_ExceptionMatches(PyExc_AttributeError);
+ set_sqlite_error(context, attr_err
+ ? "user-defined aggregate's 'value' method not defined"
+ : "user-defined aggregate's 'value' method raised error");
+ }
+ else {
+ int rc = _pysqlite_set_result(context, res);
+ Py_DECREF(res);
+ if (rc < 0) {
+ set_sqlite_error(context,
+ "unable to set result from user-defined aggregate's "
+ "'value' method");
+ }
+ }
+
+ PyGILState_Release(gilstate);
+}
+
+/*[clinic input]
+_sqlite3.Connection.create_window_function as create_window_function
+
+ cls: defining_class
+ name: str
+ The name of the SQL aggregate window function to be created or
+ redefined.
+ num_params: int
+ The number of arguments the step and inverse methods takes.
+ aggregate_class: object
+ A class with step(), finalize(), value(), and inverse() methods.
+ Set to None to clear the window function.
+ /
+
+Creates or redefines an aggregate window function. Non-standard.
+[clinic start generated code]*/
+
+static PyObject *
+create_window_function_impl(pysqlite_Connection *self, PyTypeObject *cls,
+ const char *name, int num_params,
+ PyObject *aggregate_class)
+/*[clinic end generated code: output=5332cd9464522235 input=46d57a54225b5228]*/
+{
+ if (sqlite3_libversion_number() < 3025000) {
+ PyErr_SetString(self->NotSupportedError,
+ "create_window_function() requires "
+ "SQLite 3.25.0 or higher");
+ return NULL;
+ }
+
+ if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) {
+ return NULL;
+ }
+
+ int flags = SQLITE_UTF8;
+ int rc;
+ if (Py_IsNone(aggregate_class)) {
+ rc = sqlite3_create_window_function(self->db, name, num_params, flags,
+ 0, 0, 0, 0, 0, 0);
+ }
+ else {
+ callback_context *ctx = create_callback_context(cls, aggregate_class);
+ if (ctx == NULL) {
+ return NULL;
+ }
+ rc = sqlite3_create_window_function(self->db, name, num_params, flags,
+ ctx,
+ &step_callback,
+ &final_callback,
+ &value_callback,
+ &inverse_callback,
+ &destructor_callback);
+ }
+
+ if (rc != SQLITE_OK) {
+ // Errors are not set on the database connection, so we cannot
+ // use _pysqlite_seterror().
+ PyErr_SetString(self->ProgrammingError, sqlite3_errstr(rc));
+ return NULL;
+ }
+ Py_RETURN_NONE;
+}
+#endif
+
/*[clinic input]
_sqlite3.Connection.create_aggregate as pysqlite_connection_create_aggregate
@@ -2092,6 +2256,7 @@ static PyMethodDef connection_methods[] = {
GETLIMIT_METHODDEF
SERIALIZE_METHODDEF
DESERIALIZE_METHODDEF
+ CREATE_WINDOW_FUNCTION_METHODDEF
{NULL, NULL}
};