diff options
Diffstat (limited to 'Lib/test/test_sqlite3')
-rw-r--r-- | Lib/test/test_sqlite3/test_dbapi.py | 2 | ||||
-rw-r--r-- | Lib/test/test_sqlite3/test_userfunctions.py | 168 |
2 files changed, 165 insertions, 5 deletions
diff --git a/Lib/test/test_sqlite3/test_dbapi.py b/Lib/test/test_sqlite3/test_dbapi.py index 0248281..2d2e58a 100644 --- a/Lib/test/test_sqlite3/test_dbapi.py +++ b/Lib/test/test_sqlite3/test_dbapi.py @@ -1084,6 +1084,8 @@ class ThreadTests(unittest.TestCase): if hasattr(sqlite.Connection, "serialize"): fns.append(lambda: self.con.serialize()) fns.append(lambda: self.con.deserialize(b"")) + if sqlite.sqlite_version_info >= (3, 25, 0): + fns.append(lambda: self.con.create_window_function("foo", 0, None)) for fn in fns: with self.subTest(fn=fn): diff --git a/Lib/test/test_sqlite3/test_userfunctions.py b/Lib/test/test_sqlite3/test_userfunctions.py index 9070c9e..0970b03 100644 --- a/Lib/test/test_sqlite3/test_userfunctions.py +++ b/Lib/test/test_sqlite3/test_userfunctions.py @@ -27,9 +27,9 @@ import io import re import sys import unittest -import unittest.mock import sqlite3 as sqlite +from unittest.mock import Mock, patch from test.support import bigmemtest, catch_unraisable_exception, gc_collect from test.test_sqlite3.test_dbapi import cx_limit @@ -393,7 +393,7 @@ class FunctionTests(unittest.TestCase): # indices, which allows testing based on syntax, iso. the query optimizer. @unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "Requires SQLite 3.8.3 or higher") def test_func_non_deterministic(self): - mock = unittest.mock.Mock(return_value=None) + mock = Mock(return_value=None) self.con.create_function("nondeterministic", 0, mock, deterministic=False) if sqlite.sqlite_version_info < (3, 15, 0): self.con.execute("select nondeterministic() = nondeterministic()") @@ -404,7 +404,7 @@ class FunctionTests(unittest.TestCase): @unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "Requires SQLite 3.8.3 or higher") def test_func_deterministic(self): - mock = unittest.mock.Mock(return_value=None) + mock = Mock(return_value=None) self.con.create_function("deterministic", 0, mock, deterministic=True) if sqlite.sqlite_version_info < (3, 15, 0): self.con.execute("select deterministic() = deterministic()") @@ -482,6 +482,164 @@ class FunctionTests(unittest.TestCase): self.con.execute, "select badreturn()") +class WindowSumInt: + def __init__(self): + self.count = 0 + + def step(self, value): + self.count += value + + def value(self): + return self.count + + def inverse(self, value): + self.count -= value + + def finalize(self): + return self.count + +class BadWindow(Exception): + pass + + +@unittest.skipIf(sqlite.sqlite_version_info < (3, 25, 0), + "Requires SQLite 3.25.0 or newer") +class WindowFunctionTests(unittest.TestCase): + def setUp(self): + self.con = sqlite.connect(":memory:") + self.cur = self.con.cursor() + + # Test case taken from https://www.sqlite.org/windowfunctions.html#udfwinfunc + values = [ + ("a", 4), + ("b", 5), + ("c", 3), + ("d", 8), + ("e", 1), + ] + with self.con: + self.con.execute("create table test(x, y)") + self.con.executemany("insert into test values(?, ?)", values) + self.expected = [ + ("a", 9), + ("b", 12), + ("c", 16), + ("d", 12), + ("e", 9), + ] + self.query = """ + select x, %s(y) over ( + order by x rows between 1 preceding and 1 following + ) as sum_y + from test order by x + """ + self.con.create_window_function("sumint", 1, WindowSumInt) + + def test_win_sum_int(self): + self.cur.execute(self.query % "sumint") + self.assertEqual(self.cur.fetchall(), self.expected) + + def test_win_error_on_create(self): + self.assertRaises(sqlite.ProgrammingError, + self.con.create_window_function, + "shouldfail", -100, WindowSumInt) + + @with_tracebacks(BadWindow) + def test_win_exception_in_method(self): + for meth in "__init__", "step", "value", "inverse": + with self.subTest(meth=meth): + with patch.object(WindowSumInt, meth, side_effect=BadWindow): + name = f"exc_{meth}" + self.con.create_window_function(name, 1, WindowSumInt) + msg = f"'{meth}' method raised error" + with self.assertRaisesRegex(sqlite.OperationalError, msg): + self.cur.execute(self.query % name) + self.cur.fetchall() + + @with_tracebacks(BadWindow) + def test_win_exception_in_finalize(self): + # Note: SQLite does not (as of version 3.38.0) propagate finalize + # callback errors to sqlite3_step(); this implies that OperationalError + # is _not_ raised. + with patch.object(WindowSumInt, "finalize", side_effect=BadWindow): + name = f"exception_in_finalize" + self.con.create_window_function(name, 1, WindowSumInt) + self.cur.execute(self.query % name) + self.cur.fetchall() + + @with_tracebacks(AttributeError) + def test_win_missing_method(self): + class MissingValue: + def step(self, x): pass + def inverse(self, x): pass + def finalize(self): return 42 + + class MissingInverse: + def step(self, x): pass + def value(self): return 42 + def finalize(self): return 42 + + class MissingStep: + def value(self): return 42 + def inverse(self, x): pass + def finalize(self): return 42 + + dataset = ( + ("step", MissingStep), + ("value", MissingValue), + ("inverse", MissingInverse), + ) + for meth, cls in dataset: + with self.subTest(meth=meth, cls=cls): + name = f"exc_{meth}" + self.con.create_window_function(name, 1, cls) + with self.assertRaisesRegex(sqlite.OperationalError, + f"'{meth}' method not defined"): + self.cur.execute(self.query % name) + self.cur.fetchall() + + @with_tracebacks(AttributeError) + def test_win_missing_finalize(self): + # Note: SQLite does not (as of version 3.38.0) propagate finalize + # callback errors to sqlite3_step(); this implies that OperationalError + # is _not_ raised. + class MissingFinalize: + def step(self, x): pass + def value(self): return 42 + def inverse(self, x): pass + + name = "missing_finalize" + self.con.create_window_function(name, 1, MissingFinalize) + self.cur.execute(self.query % name) + self.cur.fetchall() + + def test_win_clear_function(self): + self.con.create_window_function("sumint", 1, None) + self.assertRaises(sqlite.OperationalError, self.cur.execute, + self.query % "sumint") + + def test_win_redefine_function(self): + # Redefine WindowSumInt; adjust the expected results accordingly. + class Redefined(WindowSumInt): + def step(self, value): self.count += value * 2 + def inverse(self, value): self.count -= value * 2 + expected = [(v[0], v[1]*2) for v in self.expected] + + self.con.create_window_function("sumint", 1, Redefined) + self.cur.execute(self.query % "sumint") + self.assertEqual(self.cur.fetchall(), expected) + + def test_win_error_value_return(self): + class ErrorValueReturn: + def __init__(self): pass + def step(self, x): pass + def value(self): return 1 << 65 + + self.con.create_window_function("err_val_ret", 1, ErrorValueReturn) + self.assertRaisesRegex(sqlite.DataError, "string or blob too big", + self.cur.execute, self.query % "err_val_ret") + + class AggregateTests(unittest.TestCase): def setUp(self): self.con = sqlite.connect(":memory:") @@ -527,10 +685,10 @@ class AggregateTests(unittest.TestCase): def test_aggr_no_finalize(self): cur = self.con.cursor() - with self.assertRaises(sqlite.OperationalError) as cm: + msg = "user-defined aggregate's 'finalize' method not defined" + with self.assertRaisesRegex(sqlite.OperationalError, msg): cur.execute("select nofinalize(t) from test") val = cur.fetchone()[0] - self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error") @with_tracebacks(ZeroDivisionError, name="AggrExceptionInInit") def test_aggr_exception_in_init(self): |