summaryrefslogtreecommitdiffstats
path: root/Lib/sqlite3/test/hooks.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/sqlite3/test/hooks.py')
-rw-r--r--Lib/sqlite3/test/hooks.py137
1 files changed, 42 insertions, 95 deletions
diff --git a/Lib/sqlite3/test/hooks.py b/Lib/sqlite3/test/hooks.py
index d74e74b..a5b1632 100644
--- a/Lib/sqlite3/test/hooks.py
+++ b/Lib/sqlite3/test/hooks.py
@@ -1,4 +1,4 @@
-#-*- coding: iso-8859-1 -*-
+#-*- coding: ISO-8859-1 -*-
# pysqlite2/test/hooks.py: tests for various SQLite-specific hooks
#
# Copyright (C) 2006-2007 Gerhard Häring <gh@ghaering.de>
@@ -21,12 +21,16 @@
# misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution.
-import unittest
+import os, unittest
import sqlite3 as sqlite
-from test.support import TESTFN, unlink
-
class CollationTests(unittest.TestCase):
+ def setUp(self):
+ pass
+
+ def tearDown(self):
+ pass
+
def CheckCreateCollationNotString(self):
con = sqlite.connect(":memory:")
with self.assertRaises(TypeError):
@@ -34,14 +38,19 @@ class CollationTests(unittest.TestCase):
def CheckCreateCollationNotCallable(self):
con = sqlite.connect(":memory:")
- with self.assertRaises(TypeError) as cm:
+ try:
con.create_collation("X", 42)
- self.assertEqual(str(cm.exception), 'parameter must be callable')
+ self.fail("should have raised a TypeError")
+ except TypeError, e:
+ self.assertEqual(e.args[0], "parameter must be callable")
def CheckCreateCollationNotAscii(self):
con = sqlite.connect(":memory:")
- with self.assertRaises(sqlite.ProgrammingError):
- con.create_collation("collä", lambda x, y: (x > y) - (x < y))
+ try:
+ con.create_collation("collä", cmp)
+ self.fail("should have raised a ProgrammingError")
+ except sqlite.ProgrammingError, e:
+ pass
def CheckCreateCollationBadUpper(self):
class BadUpperStr(str):
@@ -60,12 +69,12 @@ class CollationTests(unittest.TestCase):
self.assertEqual(result[0][0], 'b')
self.assertEqual(result[1][0], 'a')
- @unittest.skipIf(sqlite.sqlite_version_info < (3, 2, 1),
- 'old SQLite versions crash on this test')
def CheckCollationIsUsed(self):
+ if sqlite.version_info < (3, 2, 1): # old SQLite versions crash on this test
+ return
def mycoll(x, y):
# reverse order
- return -((x > y) - (x < y))
+ return -cmp(x, y)
con = sqlite.connect(":memory:")
con.create_collation("mycoll", mycoll)
@@ -79,13 +88,15 @@ class CollationTests(unittest.TestCase):
) order by x collate mycoll
"""
result = con.execute(sql).fetchall()
- self.assertEqual(result, [('c',), ('b',), ('a',)],
- msg='the expected order was not returned')
+ if result[0][0] != "c" or result[1][0] != "b" or result[2][0] != "a":
+ self.fail("the expected order was not returned")
con.create_collation("mycoll", None)
- with self.assertRaises(sqlite.OperationalError) as cm:
+ try:
result = con.execute(sql).fetchall()
- self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')
+ self.fail("should have raised an OperationalError")
+ except sqlite.OperationalError, e:
+ self.assertEqual(e.args[0].lower(), "no such collation sequence: mycoll")
def CheckCollationReturnsLargeInteger(self):
def mycoll(x, y):
@@ -112,13 +123,13 @@ class CollationTests(unittest.TestCase):
Verify that the last one is actually used.
"""
con = sqlite.connect(":memory:")
- con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
- con.create_collation("mycoll", lambda x, y: -((x > y) - (x < y)))
+ con.create_collation("mycoll", cmp)
+ con.create_collation("mycoll", lambda x, y: -cmp(x, y))
result = con.execute("""
select x from (select 'a' as x union select 'b' as x) order by x collate mycoll
""").fetchall()
- self.assertEqual(result[0][0], 'b')
- self.assertEqual(result[1][0], 'a')
+ if result[0][0] != 'b' or result[1][0] != 'a':
+ self.fail("wrong collation function is used")
def CheckDeregisterCollation(self):
"""
@@ -126,11 +137,14 @@ class CollationTests(unittest.TestCase):
to use it.
"""
con = sqlite.connect(":memory:")
- con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
+ con.create_collation("mycoll", cmp)
con.create_collation("mycoll", None)
- with self.assertRaises(sqlite.OperationalError) as cm:
+ try:
con.execute("select 'a' as x union select 'b' as x order by x collate mycoll")
- self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')
+ self.fail("should have raised an OperationalError")
+ except sqlite.OperationalError, e:
+ if not e.args[0].startswith("no such collation sequence"):
+ self.fail("wrong OperationalError raised")
class ProgressTests(unittest.TestCase):
def CheckProgressHandlerUsed(self):
@@ -177,7 +191,9 @@ class ProgressTests(unittest.TestCase):
Test that returning a non-zero value stops the operation in progress.
"""
con = sqlite.connect(":memory:")
+ progress_calls = []
def progress():
+ progress_calls.append(None)
return 1
con.set_progress_handler(progress, 1)
curs = con.cursor()
@@ -191,88 +207,19 @@ class ProgressTests(unittest.TestCase):
Test that setting the progress handler to None clears the previously set handler.
"""
con = sqlite.connect(":memory:")
- action = 0
+ action = []
def progress():
- nonlocal action
- action = 1
+ action.append(1)
return 0
con.set_progress_handler(progress, 1)
con.set_progress_handler(None, 1)
con.execute("select 1 union select 2 union select 3").fetchall()
- self.assertEqual(action, 0, "progress handler was not cleared")
-
-class TraceCallbackTests(unittest.TestCase):
- def CheckTraceCallbackUsed(self):
- """
- Test that the trace callback is invoked once it is set.
- """
- con = sqlite.connect(":memory:")
- traced_statements = []
- def trace(statement):
- traced_statements.append(statement)
- con.set_trace_callback(trace)
- con.execute("create table foo(a, b)")
- self.assertTrue(traced_statements)
- self.assertTrue(any("create table foo" in stmt for stmt in traced_statements))
-
- def CheckClearTraceCallback(self):
- """
- Test that setting the trace callback to None clears the previously set callback.
- """
- con = sqlite.connect(":memory:")
- traced_statements = []
- def trace(statement):
- traced_statements.append(statement)
- con.set_trace_callback(trace)
- con.set_trace_callback(None)
- con.execute("create table foo(a, b)")
- self.assertFalse(traced_statements, "trace callback was not cleared")
-
- def CheckUnicodeContent(self):
- """
- Test that the statement can contain unicode literals.
- """
- unicode_value = '\xf6\xe4\xfc\xd6\xc4\xdc\xdf\u20ac'
- con = sqlite.connect(":memory:")
- traced_statements = []
- def trace(statement):
- traced_statements.append(statement)
- con.set_trace_callback(trace)
- con.execute("create table foo(x)")
- # Can't execute bound parameters as their values don't appear
- # in traced statements before SQLite 3.6.21
- # (cf. http://www.sqlite.org/draft/releaselog/3_6_21.html)
- con.execute('insert into foo(x) values ("%s")' % unicode_value)
- con.commit()
- self.assertTrue(any(unicode_value in stmt for stmt in traced_statements),
- "Unicode data %s garbled in trace callback: %s"
- % (ascii(unicode_value), ', '.join(map(ascii, traced_statements))))
-
- @unittest.skipIf(sqlite.sqlite_version_info < (3, 3, 9), "sqlite3_prepare_v2 is not available")
- def CheckTraceCallbackContent(self):
- # set_trace_callback() shouldn't produce duplicate content (bpo-26187)
- traced_statements = []
- def trace(statement):
- traced_statements.append(statement)
-
- queries = ["create table foo(x)",
- "insert into foo(x) values(1)"]
- self.addCleanup(unlink, TESTFN)
- con1 = sqlite.connect(TESTFN, isolation_level=None)
- con2 = sqlite.connect(TESTFN)
- con1.set_trace_callback(trace)
- cur = con1.cursor()
- cur.execute(queries[0])
- con2.execute("create table bar(x)")
- cur.execute(queries[1])
- self.assertEqual(traced_statements, queries)
-
+ self.assertEqual(len(action), 0, "progress handler was not cleared")
def suite():
collation_suite = unittest.makeSuite(CollationTests, "Check")
progress_suite = unittest.makeSuite(ProgressTests, "Check")
- trace_suite = unittest.makeSuite(TraceCallbackTests, "Check")
- return unittest.TestSuite((collation_suite, progress_suite, trace_suite))
+ return unittest.TestSuite((collation_suite, progress_suite))
def test():
runner = unittest.TextTestRunner()