diff options
Diffstat (limited to 'Lib/sqlite3')
-rw-r--r-- | Lib/sqlite3/test/hooks.py | 52 | ||||
-rw-r--r-- | Lib/sqlite3/test/regression.py | 22 | ||||
-rw-r--r-- | Lib/sqlite3/test/types.py | 2 |
3 files changed, 74 insertions, 2 deletions
diff --git a/Lib/sqlite3/test/hooks.py b/Lib/sqlite3/test/hooks.py index a6161fa..dad35d9 100644 --- a/Lib/sqlite3/test/hooks.py +++ b/Lib/sqlite3/test/hooks.py @@ -175,10 +175,60 @@ class ProgressTests(unittest.TestCase): con.execute("select 1 union select 2 union select 3").fetchall() self.assertEqual(action, 0, "progress handler was not cleared") +class TraceCallbackTests(unittest.TestCase): + def CheckTraceCallbackUsed(self): + """ + Test that the trace callback is invoked once it is set. + """ + con = sqlite.connect(":memory:") + traced_statements = [] + def trace(statement): + traced_statements.append(statement) + con.set_trace_callback(trace) + con.execute("create table foo(a, b)") + self.assertTrue(traced_statements) + self.assertTrue(any("create table foo" in stmt for stmt in traced_statements)) + + def CheckClearTraceCallback(self): + """ + Test that setting the trace callback to None clears the previously set callback. + """ + con = sqlite.connect(":memory:") + traced_statements = [] + def trace(statement): + traced_statements.append(statement) + con.set_trace_callback(trace) + con.set_trace_callback(None) + con.execute("create table foo(a, b)") + self.assertFalse(traced_statements, "trace callback was not cleared") + + def CheckUnicodeContent(self): + """ + Test that the statement can contain unicode literals. + """ + unicode_value = '\xf6\xe4\xfc\xd6\xc4\xdc\xdf\u20ac' + con = sqlite.connect(":memory:") + traced_statements = [] + def trace(statement): + traced_statements.append(statement) + con.set_trace_callback(trace) + con.execute("create table foo(x)") + # Can't execute bound parameters as their values don't appear + # in traced statements before SQLite 3.6.21 + # (cf. http://www.sqlite.org/draft/releaselog/3_6_21.html) + con.execute('insert into foo(x) values ("%s")' % unicode_value) + con.commit() + self.assertTrue(any(unicode_value in stmt for stmt in traced_statements), + "Unicode data %s garbled in trace callback: %s" + % (ascii(unicode_value), ', '.join(map(ascii, traced_statements)))) + + + def suite(): collation_suite = unittest.makeSuite(CollationTests, "Check") progress_suite = unittest.makeSuite(ProgressTests, "Check") - return unittest.TestSuite((collation_suite, progress_suite)) + trace_suite = unittest.makeSuite(TraceCallbackTests, "Check") + return unittest.TestSuite((collation_suite, progress_suite, trace_suite)) def test(): runner = unittest.TextTestRunner() diff --git a/Lib/sqlite3/test/regression.py b/Lib/sqlite3/test/regression.py index 7d0553d..c7551e3 100644 --- a/Lib/sqlite3/test/regression.py +++ b/Lib/sqlite3/test/regression.py @@ -281,6 +281,28 @@ class RegressionTests(unittest.TestCase): # Lone surrogate cannot be encoded to the default encoding (utf8) "\uDC80", collation_cb) + def CheckRecursiveCursorUse(self): + """ + http://bugs.python.org/issue10811 + + Recursively using a cursor, such as when reusing it from a generator led to segfaults. + Now we catch recursive cursor usage and raise a ProgrammingError. + """ + con = sqlite.connect(":memory:") + + cur = con.cursor() + cur.execute("create table a (bar)") + cur.execute("create table b (baz)") + + def foo(): + cur.execute("insert into a (bar) values (?)", (1,)) + yield 1 + + with self.assertRaises(sqlite.ProgrammingError): + cur.executemany("insert into b (baz) values (?)", + ((i,) for i in foo())) + + def suite(): regression_suite = unittest.makeSuite(RegressionTests, "Check") return unittest.TestSuite((regression_suite,)) diff --git a/Lib/sqlite3/test/types.py b/Lib/sqlite3/test/types.py index 29413e1..d214f3d 100644 --- a/Lib/sqlite3/test/types.py +++ b/Lib/sqlite3/test/types.py @@ -85,7 +85,7 @@ class DeclTypesTests(unittest.TestCase): if isinstance(_val, bytes): # sqlite3 always calls __init__ with a bytes created from a # UTF-8 string when __conform__ was used to store the object. - _val = _val.decode('utf8') + _val = _val.decode('utf-8') self.val = _val def __cmp__(self, other): |