summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/test/test_sqlite3/test_userfunctions.py6
-rw-r--r--Modules/_sqlite/connection.c10
-rw-r--r--Modules/_sqlite/module.c2
-rw-r--r--Modules/_sqlite/module.h1
4 files changed, 13 insertions, 6 deletions
diff --git a/Lib/test/test_sqlite3/test_userfunctions.py b/Lib/test/test_sqlite3/test_userfunctions.py
index 23ecfb4..2588cae 100644
--- a/Lib/test/test_sqlite3/test_userfunctions.py
+++ b/Lib/test/test_sqlite3/test_userfunctions.py
@@ -502,11 +502,13 @@ class AggregateTests(unittest.TestCase):
with self.assertRaises(sqlite.OperationalError):
self.con.create_function("bla", -100, AggrSum)
+ @with_tracebacks(AttributeError, name="AggrNoStep")
def test_aggr_no_step(self):
cur = self.con.cursor()
- with self.assertRaises(AttributeError) as cm:
+ with self.assertRaises(sqlite.OperationalError) as cm:
cur.execute("select nostep(t) from test")
- self.assertEqual(str(cm.exception), "'AggrNoStep' object has no attribute 'step'")
+ self.assertEqual(str(cm.exception),
+ "user-defined aggregate's 'step' method not defined")
def test_aggr_no_finalize(self):
cur = self.con.cursor()
diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c
index 0efb5ae..9f12e69 100644
--- a/Modules/_sqlite/connection.c
+++ b/Modules/_sqlite/connection.c
@@ -734,11 +734,11 @@ step_callback(sqlite3_context *context, int argc, sqlite3_value **params)
PyObject** aggregate_instance;
PyObject* stepmethod = NULL;
- aggregate_instance = (PyObject**)sqlite3_aggregate_context(context, sizeof(PyObject*));
+ callback_context *ctx = (callback_context *)sqlite3_user_data(context);
+ assert(ctx != NULL);
+ aggregate_instance = (PyObject**)sqlite3_aggregate_context(context, sizeof(PyObject*));
if (*aggregate_instance == NULL) {
- callback_context *ctx = (callback_context *)sqlite3_user_data(context);
- assert(ctx != NULL);
*aggregate_instance = PyObject_CallNoArgs(ctx->callable);
if (!*aggregate_instance) {
set_sqlite_error(context,
@@ -747,8 +747,10 @@ step_callback(sqlite3_context *context, int argc, sqlite3_value **params)
}
}
- stepmethod = PyObject_GetAttrString(*aggregate_instance, "step");
+ stepmethod = PyObject_GetAttr(*aggregate_instance, ctx->state->str_step);
if (!stepmethod) {
+ set_sqlite_error(context,
+ "user-defined aggregate's 'step' method not defined");
goto error;
}
diff --git a/Modules/_sqlite/module.c b/Modules/_sqlite/module.c
index 70fde49..563105c 100644
--- a/Modules/_sqlite/module.c
+++ b/Modules/_sqlite/module.c
@@ -627,6 +627,7 @@ module_clear(PyObject *module)
Py_CLEAR(state->str___conform__);
Py_CLEAR(state->str_executescript);
Py_CLEAR(state->str_finalize);
+ Py_CLEAR(state->str_step);
Py_CLEAR(state->str_upper);
return 0;
@@ -713,6 +714,7 @@ module_exec(PyObject *module)
ADD_INTERNED(state, __conform__);
ADD_INTERNED(state, executescript);
ADD_INTERNED(state, finalize);
+ ADD_INTERNED(state, step);
ADD_INTERNED(state, upper);
/* Set error constants */
diff --git a/Modules/_sqlite/module.h b/Modules/_sqlite/module.h
index 35c6f38..cca52d1 100644
--- a/Modules/_sqlite/module.h
+++ b/Modules/_sqlite/module.h
@@ -64,6 +64,7 @@ typedef struct {
PyObject *str___conform__;
PyObject *str_executescript;
PyObject *str_finalize;
+ PyObject *str_step;
PyObject *str_upper;
} pysqlite_state;