summaryrefslogtreecommitdiffstats
path: root/Lib/sqlite3/test/dbapi.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/sqlite3/test/dbapi.py')
-rw-r--r--Lib/sqlite3/test/dbapi.py75
1 files changed, 66 insertions, 9 deletions
diff --git a/Lib/sqlite3/test/dbapi.py b/Lib/sqlite3/test/dbapi.py
index fbf3072..518b8ae 100644
--- a/Lib/sqlite3/test/dbapi.py
+++ b/Lib/sqlite3/test/dbapi.py
@@ -1,7 +1,7 @@
#-*- coding: ISO-8859-1 -*-
# pysqlite2/test/dbapi.py: tests for DB-API compliance
#
-# Copyright (C) 2004-2007 Gerhard Häring <gh@ghaering.de>
+# Copyright (C) 2004-2010 Gerhard Häring <gh@ghaering.de>
#
# This file is part of pysqlite.
#
@@ -22,8 +22,11 @@
# 3. This notice may not be removed or altered from any source distribution.
import unittest
-import threading
import sqlite3 as sqlite
+try:
+ import threading
+except ImportError:
+ threading = None
class ModuleTests(unittest.TestCase):
def CheckAPILevel(self):
@@ -81,6 +84,7 @@ class ModuleTests(unittest.TestCase):
"NotSupportedError is not a subclass of DatabaseError")
class ConnectionTests(unittest.TestCase):
+
def setUp(self):
self.cx = sqlite.connect(":memory:")
cu = self.cx.cursor()
@@ -137,6 +141,28 @@ class ConnectionTests(unittest.TestCase):
self.assertEqual(self.cx.ProgrammingError, sqlite.ProgrammingError)
self.assertEqual(self.cx.NotSupportedError, sqlite.NotSupportedError)
+ def CheckInTransaction(self):
+ # Can't use db from setUp because we want to test initial state.
+ cx = sqlite.connect(":memory:")
+ cu = cx.cursor()
+ self.assertEqual(cx.in_transaction, False)
+ cu.execute("create table transactiontest(id integer primary key, name text)")
+ self.assertEqual(cx.in_transaction, False)
+ cu.execute("insert into transactiontest(name) values (?)", ("foo",))
+ self.assertEqual(cx.in_transaction, True)
+ cu.execute("select name from transactiontest where name=?", ["foo"])
+ row = cu.fetchone()
+ self.assertEqual(cx.in_transaction, True)
+ cx.commit()
+ self.assertEqual(cx.in_transaction, False)
+ cu.execute("select name from transactiontest where name=?", ["foo"])
+ row = cu.fetchone()
+ self.assertEqual(cx.in_transaction, False)
+
+ def CheckInTransactionRO(self):
+ with self.assertRaises(AttributeError):
+ self.cx.in_transaction = True
+
class CursorTests(unittest.TestCase):
def setUp(self):
self.cx = sqlite.connect(":memory:")
@@ -460,6 +486,7 @@ class CursorTests(unittest.TestCase):
except TypeError:
pass
+@unittest.skipUnless(threading, 'This test requires threading.')
class ThreadTests(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
@@ -653,13 +680,13 @@ class ExtensionTests(unittest.TestCase):
res = cur.fetchone()[0]
self.assertEqual(res, 5)
- def CheckScriptErrorIncomplete(self):
+ def CheckScriptSyntaxError(self):
con = sqlite.connect(":memory:")
cur = con.cursor()
raised = False
try:
- cur.executescript("create table test(sadfsadfdsa")
- except sqlite.ProgrammingError:
+ cur.executescript("create table test(x); asdf; create table test2(x)")
+ except sqlite.OperationalError:
raised = True
self.assertEqual(raised, True, "should have raised an exception")
@@ -692,7 +719,7 @@ class ExtensionTests(unittest.TestCase):
result = con.execute("select foo from test").fetchone()[0]
self.assertEqual(result, 5, "Basic test of Connection.executescript")
-class ClosedTests(unittest.TestCase):
+class ClosedConTests(unittest.TestCase):
def setUp(self):
pass
@@ -744,7 +771,6 @@ class ClosedTests(unittest.TestCase):
except:
self.fail("Should have raised a ProgrammingError")
-
def CheckClosedCreateFunction(self):
con = sqlite.connect(":memory:")
con.close()
@@ -811,6 +837,36 @@ class ClosedTests(unittest.TestCase):
except:
self.fail("Should have raised a ProgrammingError")
+class ClosedCurTests(unittest.TestCase):
+ def setUp(self):
+ pass
+
+ def tearDown(self):
+ pass
+
+ def CheckClosed(self):
+ con = sqlite.connect(":memory:")
+ cur = con.cursor()
+ cur.close()
+
+ for method_name in ("execute", "executemany", "executescript", "fetchall", "fetchmany", "fetchone"):
+ if method_name in ("execute", "executescript"):
+ params = ("select 4 union select 5",)
+ elif method_name == "executemany":
+ params = ("insert into foo(bar) values (?)", [(3,), (4,)])
+ else:
+ params = []
+
+ try:
+ method = getattr(cur, method_name)
+
+ method(*params)
+ self.fail("Should have raised a ProgrammingError: method " + method_name)
+ except sqlite.ProgrammingError:
+ pass
+ except:
+ self.fail("Should have raised a ProgrammingError: " + method_name)
+
def suite():
module_suite = unittest.makeSuite(ModuleTests, "Check")
connection_suite = unittest.makeSuite(ConnectionTests, "Check")
@@ -818,8 +874,9 @@ def suite():
thread_suite = unittest.makeSuite(ThreadTests, "Check")
constructor_suite = unittest.makeSuite(ConstructorTests, "Check")
ext_suite = unittest.makeSuite(ExtensionTests, "Check")
- closed_suite = unittest.makeSuite(ClosedTests, "Check")
- return unittest.TestSuite((module_suite, connection_suite, cursor_suite, thread_suite, constructor_suite, ext_suite, closed_suite))
+ closed_con_suite = unittest.makeSuite(ClosedConTests, "Check")
+ closed_cur_suite = unittest.makeSuite(ClosedCurTests, "Check")
+ return unittest.TestSuite((module_suite, connection_suite, cursor_suite, thread_suite, constructor_suite, ext_suite, closed_con_suite, closed_cur_suite))
def test():
runner = unittest.TextTestRunner()