summaryrefslogtreecommitdiffstats
path: root/Lib/test/test_sqlite3/test_hooks.py
diff options
context:
space:
mode:
authorErlend E. Aasland <erlend@python.org>2023-08-17 06:45:48 (GMT)
committerGitHub <noreply@github.com>2023-08-17 06:45:48 (GMT)
commit1344cfac43a1920c596b0e8718ca0567889e697b (patch)
tree8cc0f9288d1c9f58ab92bd077462e532ac3bd057 /Lib/test/test_sqlite3/test_hooks.py
parentc9d83f93d804b80ee14480466ebee63a6f97dac2 (diff)
downloadcpython-1344cfac43a1920c596b0e8718ca0567889e697b.zip
cpython-1344cfac43a1920c596b0e8718ca0567889e697b.tar.gz
cpython-1344cfac43a1920c596b0e8718ca0567889e697b.tar.bz2
gh-105539: Explict resource management for connection objects in sqlite3 tests (#108017)
- Use memory_database() helper - Move test utility functions to util.py - Add convenience memory database mixin - Add check() helper for closed connection tests
Diffstat (limited to 'Lib/test/test_sqlite3/test_hooks.py')
-rw-r--r--Lib/test/test_sqlite3/test_hooks.py79
1 files changed, 36 insertions, 43 deletions
diff --git a/Lib/test/test_sqlite3/test_hooks.py b/Lib/test/test_sqlite3/test_hooks.py
index 89230c0..33f0af9 100644
--- a/Lib/test/test_sqlite3/test_hooks.py
+++ b/Lib/test/test_sqlite3/test_hooks.py
@@ -26,34 +26,31 @@ import unittest
from test.support.os_helper import TESTFN, unlink
-from test.test_sqlite3.test_dbapi import memory_database, cx_limit
-from test.test_sqlite3.test_userfunctions import with_tracebacks
+from .util import memory_database, cx_limit, with_tracebacks
+from .util import MemoryDatabaseMixin
-class CollationTests(unittest.TestCase):
+class CollationTests(MemoryDatabaseMixin, unittest.TestCase):
+
def test_create_collation_not_string(self):
- con = sqlite.connect(":memory:")
with self.assertRaises(TypeError):
- con.create_collation(None, lambda x, y: (x > y) - (x < y))
+ self.con.create_collation(None, lambda x, y: (x > y) - (x < y))
def test_create_collation_not_callable(self):
- con = sqlite.connect(":memory:")
with self.assertRaises(TypeError) as cm:
- con.create_collation("X", 42)
+ self.con.create_collation("X", 42)
self.assertEqual(str(cm.exception), 'parameter must be callable')
def test_create_collation_not_ascii(self):
- con = sqlite.connect(":memory:")
- con.create_collation("collä", lambda x, y: (x > y) - (x < y))
+ self.con.create_collation("collä", lambda x, y: (x > y) - (x < y))
def test_create_collation_bad_upper(self):
class BadUpperStr(str):
def upper(self):
return None
- con = sqlite.connect(":memory:")
mycoll = lambda x, y: -((x > y) - (x < y))
- con.create_collation(BadUpperStr("mycoll"), mycoll)
- result = con.execute("""
+ self.con.create_collation(BadUpperStr("mycoll"), mycoll)
+ result = self.con.execute("""
select x from (
select 'a' as x
union
@@ -68,8 +65,7 @@ class CollationTests(unittest.TestCase):
# reverse order
return -((x > y) - (x < y))
- con = sqlite.connect(":memory:")
- con.create_collation("mycoll", mycoll)
+ self.con.create_collation("mycoll", mycoll)
sql = """
select x from (
select 'a' as x
@@ -79,21 +75,20 @@ class CollationTests(unittest.TestCase):
select 'c' as x
) order by x collate mycoll
"""
- result = con.execute(sql).fetchall()
+ result = self.con.execute(sql).fetchall()
self.assertEqual(result, [('c',), ('b',), ('a',)],
msg='the expected order was not returned')
- con.create_collation("mycoll", None)
+ self.con.create_collation("mycoll", None)
with self.assertRaises(sqlite.OperationalError) as cm:
- result = con.execute(sql).fetchall()
+ result = self.con.execute(sql).fetchall()
self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')
def test_collation_returns_large_integer(self):
def mycoll(x, y):
# reverse order
return -((x > y) - (x < y)) * 2**32
- con = sqlite.connect(":memory:")
- con.create_collation("mycoll", mycoll)
+ self.con.create_collation("mycoll", mycoll)
sql = """
select x from (
select 'a' as x
@@ -103,7 +98,7 @@ class CollationTests(unittest.TestCase):
select 'c' as x
) order by x collate mycoll
"""
- result = con.execute(sql).fetchall()
+ result = self.con.execute(sql).fetchall()
self.assertEqual(result, [('c',), ('b',), ('a',)],
msg="the expected order was not returned")
@@ -112,7 +107,7 @@ class CollationTests(unittest.TestCase):
Register two different collation functions under the same name.
Verify that the last one is actually used.
"""
- con = sqlite.connect(":memory:")
+ con = self.con
con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
con.create_collation("mycoll", lambda x, y: -((x > y) - (x < y)))
result = con.execute("""
@@ -126,25 +121,26 @@ class CollationTests(unittest.TestCase):
Register a collation, then deregister it. Make sure an error is raised if we try
to use it.
"""
- con = sqlite.connect(":memory:")
+ con = self.con
con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
con.create_collation("mycoll", None)
with self.assertRaises(sqlite.OperationalError) as cm:
con.execute("select 'a' as x union select 'b' as x order by x collate mycoll")
self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')
-class ProgressTests(unittest.TestCase):
+
+class ProgressTests(MemoryDatabaseMixin, unittest.TestCase):
+
def test_progress_handler_used(self):
"""
Test that the progress handler is invoked once it is set.
"""
- con = sqlite.connect(":memory:")
progress_calls = []
def progress():
progress_calls.append(None)
return 0
- con.set_progress_handler(progress, 1)
- con.execute("""
+ self.con.set_progress_handler(progress, 1)
+ self.con.execute("""
create table foo(a, b)
""")
self.assertTrue(progress_calls)
@@ -153,7 +149,7 @@ class ProgressTests(unittest.TestCase):
"""
Test that the opcode argument is respected.
"""
- con = sqlite.connect(":memory:")
+ con = self.con
progress_calls = []
def progress():
progress_calls.append(None)
@@ -176,11 +172,10 @@ class ProgressTests(unittest.TestCase):
"""
Test that returning a non-zero value stops the operation in progress.
"""
- con = sqlite.connect(":memory:")
def progress():
return 1
- con.set_progress_handler(progress, 1)
- curs = con.cursor()
+ self.con.set_progress_handler(progress, 1)
+ curs = self.con.cursor()
self.assertRaises(
sqlite.OperationalError,
curs.execute,
@@ -190,7 +185,7 @@ class ProgressTests(unittest.TestCase):
"""
Test that setting the progress handler to None clears the previously set handler.
"""
- con = sqlite.connect(":memory:")
+ con = self.con
action = 0
def progress():
nonlocal action
@@ -203,31 +198,30 @@ class ProgressTests(unittest.TestCase):
@with_tracebacks(ZeroDivisionError, name="bad_progress")
def test_error_in_progress_handler(self):
- con = sqlite.connect(":memory:")
def bad_progress():
1 / 0
- con.set_progress_handler(bad_progress, 1)
+ self.con.set_progress_handler(bad_progress, 1)
with self.assertRaises(sqlite.OperationalError):
- con.execute("""
+ self.con.execute("""
create table foo(a, b)
""")
@with_tracebacks(ZeroDivisionError, name="bad_progress")
def test_error_in_progress_handler_result(self):
- con = sqlite.connect(":memory:")
class BadBool:
def __bool__(self):
1 / 0
def bad_progress():
return BadBool()
- con.set_progress_handler(bad_progress, 1)
+ self.con.set_progress_handler(bad_progress, 1)
with self.assertRaises(sqlite.OperationalError):
- con.execute("""
+ self.con.execute("""
create table foo(a, b)
""")
-class TraceCallbackTests(unittest.TestCase):
+class TraceCallbackTests(MemoryDatabaseMixin, unittest.TestCase):
+
@contextlib.contextmanager
def check_stmt_trace(self, cx, expected):
try:
@@ -242,12 +236,11 @@ class TraceCallbackTests(unittest.TestCase):
"""
Test that the trace callback is invoked once it is set.
"""
- con = sqlite.connect(":memory:")
traced_statements = []
def trace(statement):
traced_statements.append(statement)
- con.set_trace_callback(trace)
- con.execute("create table foo(a, b)")
+ self.con.set_trace_callback(trace)
+ self.con.execute("create table foo(a, b)")
self.assertTrue(traced_statements)
self.assertTrue(any("create table foo" in stmt for stmt in traced_statements))
@@ -255,7 +248,7 @@ class TraceCallbackTests(unittest.TestCase):
"""
Test that setting the trace callback to None clears the previously set callback.
"""
- con = sqlite.connect(":memory:")
+ con = self.con
traced_statements = []
def trace(statement):
traced_statements.append(statement)
@@ -269,7 +262,7 @@ class TraceCallbackTests(unittest.TestCase):
Test that the statement can contain unicode literals.
"""
unicode_value = '\xf6\xe4\xfc\xd6\xc4\xdc\xdf\u20ac'
- con = sqlite.connect(":memory:")
+ con = self.con
traced_statements = []
def trace(statement):
traced_statements.append(statement)