diff options
author | Erlend E. Aasland <erlend@python.org> | 2023-08-17 06:45:48 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-17 06:45:48 (GMT) |
commit | 1344cfac43a1920c596b0e8718ca0567889e697b (patch) | |
tree | 8cc0f9288d1c9f58ab92bd077462e532ac3bd057 /Lib/test/test_sqlite3/test_hooks.py | |
parent | c9d83f93d804b80ee14480466ebee63a6f97dac2 (diff) | |
download | cpython-1344cfac43a1920c596b0e8718ca0567889e697b.zip cpython-1344cfac43a1920c596b0e8718ca0567889e697b.tar.gz cpython-1344cfac43a1920c596b0e8718ca0567889e697b.tar.bz2 |
gh-105539: Explict resource management for connection objects in sqlite3 tests (#108017)
- Use memory_database() helper
- Move test utility functions to util.py
- Add convenience memory database mixin
- Add check() helper for closed connection tests
Diffstat (limited to 'Lib/test/test_sqlite3/test_hooks.py')
-rw-r--r-- | Lib/test/test_sqlite3/test_hooks.py | 79 |
1 files changed, 36 insertions, 43 deletions
diff --git a/Lib/test/test_sqlite3/test_hooks.py b/Lib/test/test_sqlite3/test_hooks.py index 89230c0..33f0af9 100644 --- a/Lib/test/test_sqlite3/test_hooks.py +++ b/Lib/test/test_sqlite3/test_hooks.py @@ -26,34 +26,31 @@ import unittest from test.support.os_helper import TESTFN, unlink -from test.test_sqlite3.test_dbapi import memory_database, cx_limit -from test.test_sqlite3.test_userfunctions import with_tracebacks +from .util import memory_database, cx_limit, with_tracebacks +from .util import MemoryDatabaseMixin -class CollationTests(unittest.TestCase): +class CollationTests(MemoryDatabaseMixin, unittest.TestCase): + def test_create_collation_not_string(self): - con = sqlite.connect(":memory:") with self.assertRaises(TypeError): - con.create_collation(None, lambda x, y: (x > y) - (x < y)) + self.con.create_collation(None, lambda x, y: (x > y) - (x < y)) def test_create_collation_not_callable(self): - con = sqlite.connect(":memory:") with self.assertRaises(TypeError) as cm: - con.create_collation("X", 42) + self.con.create_collation("X", 42) self.assertEqual(str(cm.exception), 'parameter must be callable') def test_create_collation_not_ascii(self): - con = sqlite.connect(":memory:") - con.create_collation("collä", lambda x, y: (x > y) - (x < y)) + self.con.create_collation("collä", lambda x, y: (x > y) - (x < y)) def test_create_collation_bad_upper(self): class BadUpperStr(str): def upper(self): return None - con = sqlite.connect(":memory:") mycoll = lambda x, y: -((x > y) - (x < y)) - con.create_collation(BadUpperStr("mycoll"), mycoll) - result = con.execute(""" + self.con.create_collation(BadUpperStr("mycoll"), mycoll) + result = self.con.execute(""" select x from ( select 'a' as x union @@ -68,8 +65,7 @@ class CollationTests(unittest.TestCase): # reverse order return -((x > y) - (x < y)) - con = sqlite.connect(":memory:") - con.create_collation("mycoll", mycoll) + self.con.create_collation("mycoll", mycoll) sql = """ select x from ( select 'a' as x @@ -79,21 +75,20 @@ class CollationTests(unittest.TestCase): select 'c' as x ) order by x collate mycoll """ - result = con.execute(sql).fetchall() + result = self.con.execute(sql).fetchall() self.assertEqual(result, [('c',), ('b',), ('a',)], msg='the expected order was not returned') - con.create_collation("mycoll", None) + self.con.create_collation("mycoll", None) with self.assertRaises(sqlite.OperationalError) as cm: - result = con.execute(sql).fetchall() + result = self.con.execute(sql).fetchall() self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll') def test_collation_returns_large_integer(self): def mycoll(x, y): # reverse order return -((x > y) - (x < y)) * 2**32 - con = sqlite.connect(":memory:") - con.create_collation("mycoll", mycoll) + self.con.create_collation("mycoll", mycoll) sql = """ select x from ( select 'a' as x @@ -103,7 +98,7 @@ class CollationTests(unittest.TestCase): select 'c' as x ) order by x collate mycoll """ - result = con.execute(sql).fetchall() + result = self.con.execute(sql).fetchall() self.assertEqual(result, [('c',), ('b',), ('a',)], msg="the expected order was not returned") @@ -112,7 +107,7 @@ class CollationTests(unittest.TestCase): Register two different collation functions under the same name. Verify that the last one is actually used. """ - con = sqlite.connect(":memory:") + con = self.con con.create_collation("mycoll", lambda x, y: (x > y) - (x < y)) con.create_collation("mycoll", lambda x, y: -((x > y) - (x < y))) result = con.execute(""" @@ -126,25 +121,26 @@ class CollationTests(unittest.TestCase): Register a collation, then deregister it. Make sure an error is raised if we try to use it. """ - con = sqlite.connect(":memory:") + con = self.con con.create_collation("mycoll", lambda x, y: (x > y) - (x < y)) con.create_collation("mycoll", None) with self.assertRaises(sqlite.OperationalError) as cm: con.execute("select 'a' as x union select 'b' as x order by x collate mycoll") self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll') -class ProgressTests(unittest.TestCase): + +class ProgressTests(MemoryDatabaseMixin, unittest.TestCase): + def test_progress_handler_used(self): """ Test that the progress handler is invoked once it is set. """ - con = sqlite.connect(":memory:") progress_calls = [] def progress(): progress_calls.append(None) return 0 - con.set_progress_handler(progress, 1) - con.execute(""" + self.con.set_progress_handler(progress, 1) + self.con.execute(""" create table foo(a, b) """) self.assertTrue(progress_calls) @@ -153,7 +149,7 @@ class ProgressTests(unittest.TestCase): """ Test that the opcode argument is respected. """ - con = sqlite.connect(":memory:") + con = self.con progress_calls = [] def progress(): progress_calls.append(None) @@ -176,11 +172,10 @@ class ProgressTests(unittest.TestCase): """ Test that returning a non-zero value stops the operation in progress. """ - con = sqlite.connect(":memory:") def progress(): return 1 - con.set_progress_handler(progress, 1) - curs = con.cursor() + self.con.set_progress_handler(progress, 1) + curs = self.con.cursor() self.assertRaises( sqlite.OperationalError, curs.execute, @@ -190,7 +185,7 @@ class ProgressTests(unittest.TestCase): """ Test that setting the progress handler to None clears the previously set handler. """ - con = sqlite.connect(":memory:") + con = self.con action = 0 def progress(): nonlocal action @@ -203,31 +198,30 @@ class ProgressTests(unittest.TestCase): @with_tracebacks(ZeroDivisionError, name="bad_progress") def test_error_in_progress_handler(self): - con = sqlite.connect(":memory:") def bad_progress(): 1 / 0 - con.set_progress_handler(bad_progress, 1) + self.con.set_progress_handler(bad_progress, 1) with self.assertRaises(sqlite.OperationalError): - con.execute(""" + self.con.execute(""" create table foo(a, b) """) @with_tracebacks(ZeroDivisionError, name="bad_progress") def test_error_in_progress_handler_result(self): - con = sqlite.connect(":memory:") class BadBool: def __bool__(self): 1 / 0 def bad_progress(): return BadBool() - con.set_progress_handler(bad_progress, 1) + self.con.set_progress_handler(bad_progress, 1) with self.assertRaises(sqlite.OperationalError): - con.execute(""" + self.con.execute(""" create table foo(a, b) """) -class TraceCallbackTests(unittest.TestCase): +class TraceCallbackTests(MemoryDatabaseMixin, unittest.TestCase): + @contextlib.contextmanager def check_stmt_trace(self, cx, expected): try: @@ -242,12 +236,11 @@ class TraceCallbackTests(unittest.TestCase): """ Test that the trace callback is invoked once it is set. """ - con = sqlite.connect(":memory:") traced_statements = [] def trace(statement): traced_statements.append(statement) - con.set_trace_callback(trace) - con.execute("create table foo(a, b)") + self.con.set_trace_callback(trace) + self.con.execute("create table foo(a, b)") self.assertTrue(traced_statements) self.assertTrue(any("create table foo" in stmt for stmt in traced_statements)) @@ -255,7 +248,7 @@ class TraceCallbackTests(unittest.TestCase): """ Test that setting the trace callback to None clears the previously set callback. """ - con = sqlite.connect(":memory:") + con = self.con traced_statements = [] def trace(statement): traced_statements.append(statement) @@ -269,7 +262,7 @@ class TraceCallbackTests(unittest.TestCase): Test that the statement can contain unicode literals. """ unicode_value = '\xf6\xe4\xfc\xd6\xc4\xdc\xdf\u20ac' - con = sqlite.connect(":memory:") + con = self.con traced_statements = [] def trace(statement): traced_statements.append(statement) |