summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_sqlite3
diff options
context:
space:
mode:
authorErlend Egeberg Aasland <erlend.aasland@innova.no>2022-04-12 00:55:59 (GMT)
committerGitHub <noreply@github.com>2022-04-12 00:55:59 (GMT)
commit9ebcece82fe11b87cc3d6e6b4c439aab9e3ab1e6 (patch)
treeef6b3c2d043f9b85ed4b15aa684eab941e25347f /Lib/test/test_sqlite3
parentf45aa8f304a12990c2ca687f2088f04b07906033 (diff)
downloadcpython-9ebcece82fe11b87cc3d6e6b4c439aab9e3ab1e6.zip
cpython-9ebcece82fe11b87cc3d6e6b4c439aab9e3ab1e6.tar.gz
cpython-9ebcece82fe11b87cc3d6e6b4c439aab9e3ab1e6.tar.bz2
gh-79097: Add support for aggregate window functions in sqlite3 (GH-20903)
Diffstat (limited to 'Lib/test/test_sqlite3')
-rw-r--r--Lib/test/test_sqlite3/test_dbapi.py2
-rw-r--r--Lib/test/test_sqlite3/test_userfunctions.py168
2 files changed, 165 insertions, 5 deletions
diff --git a/Lib/test/test_sqlite3/test_dbapi.py b/Lib/test/test_sqlite3/test_dbapi.py
index 0248281..2d2e58a 100644
--- a/Lib/test/test_sqlite3/test_dbapi.py
+++ b/Lib/test/test_sqlite3/test_dbapi.py
@@ -1084,6 +1084,8 @@ class ThreadTests(unittest.TestCase):
if hasattr(sqlite.Connection, "serialize"):
fns.append(lambda: self.con.serialize())
fns.append(lambda: self.con.deserialize(b""))
+ if sqlite.sqlite_version_info >= (3, 25, 0):
+ fns.append(lambda: self.con.create_window_function("foo", 0, None))
for fn in fns:
with self.subTest(fn=fn):
diff --git a/Lib/test/test_sqlite3/test_userfunctions.py b/Lib/test/test_sqlite3/test_userfunctions.py
index 9070c9e..0970b03 100644
--- a/Lib/test/test_sqlite3/test_userfunctions.py
+++ b/Lib/test/test_sqlite3/test_userfunctions.py
@@ -27,9 +27,9 @@ import io
import re
import sys
import unittest
-import unittest.mock
import sqlite3 as sqlite
+from unittest.mock import Mock, patch
from test.support import bigmemtest, catch_unraisable_exception, gc_collect
from test.test_sqlite3.test_dbapi import cx_limit
@@ -393,7 +393,7 @@ class FunctionTests(unittest.TestCase):
# indices, which allows testing based on syntax, iso. the query optimizer.
@unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "Requires SQLite 3.8.3 or higher")
def test_func_non_deterministic(self):
- mock = unittest.mock.Mock(return_value=None)
+ mock = Mock(return_value=None)
self.con.create_function("nondeterministic", 0, mock, deterministic=False)
if sqlite.sqlite_version_info < (3, 15, 0):
self.con.execute("select nondeterministic() = nondeterministic()")
@@ -404,7 +404,7 @@ class FunctionTests(unittest.TestCase):
@unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "Requires SQLite 3.8.3 or higher")
def test_func_deterministic(self):
- mock = unittest.mock.Mock(return_value=None)
+ mock = Mock(return_value=None)
self.con.create_function("deterministic", 0, mock, deterministic=True)
if sqlite.sqlite_version_info < (3, 15, 0):
self.con.execute("select deterministic() = deterministic()")
@@ -482,6 +482,164 @@ class FunctionTests(unittest.TestCase):
self.con.execute, "select badreturn()")
+class WindowSumInt:
+ def __init__(self):
+ self.count = 0
+
+ def step(self, value):
+ self.count += value
+
+ def value(self):
+ return self.count
+
+ def inverse(self, value):
+ self.count -= value
+
+ def finalize(self):
+ return self.count
+
+class BadWindow(Exception):
+ pass
+
+
+@unittest.skipIf(sqlite.sqlite_version_info < (3, 25, 0),
+ "Requires SQLite 3.25.0 or newer")
+class WindowFunctionTests(unittest.TestCase):
+ def setUp(self):
+ self.con = sqlite.connect(":memory:")
+ self.cur = self.con.cursor()
+
+ # Test case taken from https://www.sqlite.org/windowfunctions.html#udfwinfunc
+ values = [
+ ("a", 4),
+ ("b", 5),
+ ("c", 3),
+ ("d", 8),
+ ("e", 1),
+ ]
+ with self.con:
+ self.con.execute("create table test(x, y)")
+ self.con.executemany("insert into test values(?, ?)", values)
+ self.expected = [
+ ("a", 9),
+ ("b", 12),
+ ("c", 16),
+ ("d", 12),
+ ("e", 9),
+ ]
+ self.query = """
+ select x, %s(y) over (
+ order by x rows between 1 preceding and 1 following
+ ) as sum_y
+ from test order by x
+ """
+ self.con.create_window_function("sumint", 1, WindowSumInt)
+
+ def test_win_sum_int(self):
+ self.cur.execute(self.query % "sumint")
+ self.assertEqual(self.cur.fetchall(), self.expected)
+
+ def test_win_error_on_create(self):
+ self.assertRaises(sqlite.ProgrammingError,
+ self.con.create_window_function,
+ "shouldfail", -100, WindowSumInt)
+
+ @with_tracebacks(BadWindow)
+ def test_win_exception_in_method(self):
+ for meth in "__init__", "step", "value", "inverse":
+ with self.subTest(meth=meth):
+ with patch.object(WindowSumInt, meth, side_effect=BadWindow):
+ name = f"exc_{meth}"
+ self.con.create_window_function(name, 1, WindowSumInt)
+ msg = f"'{meth}' method raised error"
+ with self.assertRaisesRegex(sqlite.OperationalError, msg):
+ self.cur.execute(self.query % name)
+ self.cur.fetchall()
+
+ @with_tracebacks(BadWindow)
+ def test_win_exception_in_finalize(self):
+ # Note: SQLite does not (as of version 3.38.0) propagate finalize
+ # callback errors to sqlite3_step(); this implies that OperationalError
+ # is _not_ raised.
+ with patch.object(WindowSumInt, "finalize", side_effect=BadWindow):
+ name = f"exception_in_finalize"
+ self.con.create_window_function(name, 1, WindowSumInt)
+ self.cur.execute(self.query % name)
+ self.cur.fetchall()
+
+ @with_tracebacks(AttributeError)
+ def test_win_missing_method(self):
+ class MissingValue:
+ def step(self, x): pass
+ def inverse(self, x): pass
+ def finalize(self): return 42
+
+ class MissingInverse:
+ def step(self, x): pass
+ def value(self): return 42
+ def finalize(self): return 42
+
+ class MissingStep:
+ def value(self): return 42
+ def inverse(self, x): pass
+ def finalize(self): return 42
+
+ dataset = (
+ ("step", MissingStep),
+ ("value", MissingValue),
+ ("inverse", MissingInverse),
+ )
+ for meth, cls in dataset:
+ with self.subTest(meth=meth, cls=cls):
+ name = f"exc_{meth}"
+ self.con.create_window_function(name, 1, cls)
+ with self.assertRaisesRegex(sqlite.OperationalError,
+ f"'{meth}' method not defined"):
+ self.cur.execute(self.query % name)
+ self.cur.fetchall()
+
+ @with_tracebacks(AttributeError)
+ def test_win_missing_finalize(self):
+ # Note: SQLite does not (as of version 3.38.0) propagate finalize
+ # callback errors to sqlite3_step(); this implies that OperationalError
+ # is _not_ raised.
+ class MissingFinalize:
+ def step(self, x): pass
+ def value(self): return 42
+ def inverse(self, x): pass
+
+ name = "missing_finalize"
+ self.con.create_window_function(name, 1, MissingFinalize)
+ self.cur.execute(self.query % name)
+ self.cur.fetchall()
+
+ def test_win_clear_function(self):
+ self.con.create_window_function("sumint", 1, None)
+ self.assertRaises(sqlite.OperationalError, self.cur.execute,
+ self.query % "sumint")
+
+ def test_win_redefine_function(self):
+ # Redefine WindowSumInt; adjust the expected results accordingly.
+ class Redefined(WindowSumInt):
+ def step(self, value): self.count += value * 2
+ def inverse(self, value): self.count -= value * 2
+ expected = [(v[0], v[1]*2) for v in self.expected]
+
+ self.con.create_window_function("sumint", 1, Redefined)
+ self.cur.execute(self.query % "sumint")
+ self.assertEqual(self.cur.fetchall(), expected)
+
+ def test_win_error_value_return(self):
+ class ErrorValueReturn:
+ def __init__(self): pass
+ def step(self, x): pass
+ def value(self): return 1 << 65
+
+ self.con.create_window_function("err_val_ret", 1, ErrorValueReturn)
+ self.assertRaisesRegex(sqlite.DataError, "string or blob too big",
+ self.cur.execute, self.query % "err_val_ret")
+
+
class AggregateTests(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
@@ -527,10 +685,10 @@ class AggregateTests(unittest.TestCase):
def test_aggr_no_finalize(self):
cur = self.con.cursor()
- with self.assertRaises(sqlite.OperationalError) as cm:
+ msg = "user-defined aggregate's 'finalize' method not defined"
+ with self.assertRaisesRegex(sqlite.OperationalError, msg):
cur.execute("select nofinalize(t) from test")
val = cur.fetchone()[0]
- self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")
@with_tracebacks(ZeroDivisionError, name="AggrExceptionInInit")
def test_aggr_exception_in_init(self):