diff options
-rw-r--r-- | Lib/sqlite3/test/dbapi.py | 51 | ||||
-rw-r--r-- | Lib/sqlite3/test/factory.py | 2 | ||||
-rw-r--r-- | Lib/sqlite3/test/hooks.py | 75 | ||||
-rw-r--r-- | Lib/sqlite3/test/regression.py | 62 | ||||
-rw-r--r-- | Lib/sqlite3/test/transactions.py | 19 | ||||
-rw-r--r-- | Lib/sqlite3/test/types.py | 15 | ||||
-rw-r--r-- | Modules/_sqlite/cache.c | 2 | ||||
-rw-r--r-- | Modules/_sqlite/cache.h | 2 | ||||
-rw-r--r-- | Modules/_sqlite/connection.c | 100 | ||||
-rw-r--r-- | Modules/_sqlite/connection.h | 2 | ||||
-rw-r--r-- | Modules/_sqlite/cursor.c | 126 | ||||
-rw-r--r-- | Modules/_sqlite/cursor.h | 4 | ||||
-rw-r--r-- | Modules/_sqlite/microprotocols.h | 4 | ||||
-rw-r--r-- | Modules/_sqlite/module.c | 54 | ||||
-rw-r--r-- | Modules/_sqlite/module.h | 5 | ||||
-rw-r--r-- | Modules/_sqlite/prepare_protocol.h | 2 | ||||
-rw-r--r-- | Modules/_sqlite/row.c | 7 | ||||
-rw-r--r-- | Modules/_sqlite/row.h | 2 | ||||
-rw-r--r-- | Modules/_sqlite/statement.c | 225 | ||||
-rw-r--r-- | Modules/_sqlite/statement.h | 6 | ||||
-rw-r--r-- | Modules/_sqlite/util.c | 9 | ||||
-rw-r--r-- | Modules/_sqlite/util.h | 4 |
22 files changed, 582 insertions, 196 deletions
diff --git a/Lib/sqlite3/test/dbapi.py b/Lib/sqlite3/test/dbapi.py index 6d4c4fe..8327aa1 100644 --- a/Lib/sqlite3/test/dbapi.py +++ b/Lib/sqlite3/test/dbapi.py @@ -1,7 +1,7 @@ #-*- coding: ISO-8859-1 -*- # pysqlite2/test/dbapi.py: tests for DB-API compliance # -# Copyright (C) 2004-2005 Gerhard Häring <gh@ghaering.de> +# Copyright (C) 2004-2007 Gerhard Häring <gh@ghaering.de> # # This file is part of pysqlite. # @@ -223,12 +223,41 @@ class CursorTests(unittest.TestCase): except sqlite.ProgrammingError: pass + def CheckExecuteParamList(self): + self.cu.execute("insert into test(name) values ('foo')") + self.cu.execute("select name from test where name=?", ["foo"]) + row = self.cu.fetchone() + self.failUnlessEqual(row[0], "foo") + + def CheckExecuteParamSequence(self): + class L(object): + def __len__(self): + return 1 + def __getitem__(self, x): + assert x == 0 + return "foo" + + self.cu.execute("insert into test(name) values ('foo')") + self.cu.execute("select name from test where name=?", L()) + row = self.cu.fetchone() + self.failUnlessEqual(row[0], "foo") + def CheckExecuteDictMapping(self): self.cu.execute("insert into test(name) values ('foo')") self.cu.execute("select name from test where name=:name", {"name": "foo"}) row = self.cu.fetchone() self.failUnlessEqual(row[0], "foo") + def CheckExecuteDictMapping_Mapping(self): + class D(dict): + def __missing__(self, key): + return "foo" + + self.cu.execute("insert into test(name) values ('foo')") + self.cu.execute("select name from test where name=:name", D()) + row = self.cu.fetchone() + self.failUnlessEqual(row[0], "foo") + def CheckExecuteDictMappingTooLittleArgs(self): self.cu.execute("insert into test(name) values ('foo')") try: @@ -378,6 +407,12 @@ class CursorTests(unittest.TestCase): res = self.cu.fetchmany(100) self.failUnlessEqual(res, []) + def CheckFetchmanyKwArg(self): + """Checks if fetchmany works with keyword arguments""" + self.cu.execute("select name from test") + res = self.cu.fetchmany(size=100) + self.failUnlessEqual(len(res), 1) + def CheckFetchall(self): self.cu.execute("select name from test") res = self.cu.fetchall() @@ -609,20 +644,6 @@ class ExtensionTests(unittest.TestCase): res = cur.fetchone()[0] self.failUnlessEqual(res, 5) - def CheckScriptStringUnicode(self): - con = sqlite.connect(":memory:") - cur = con.cursor() - cur.executescript(""" - create table a(i); - insert into a(i) values (5); - select i from a; - delete from a; - insert into a(i) values (6); - """) - cur.execute("select i from a") - res = cur.fetchone()[0] - self.failUnlessEqual(res, 6) - def CheckScriptErrorIncomplete(self): con = sqlite.connect(":memory:") cur = con.cursor() diff --git a/Lib/sqlite3/test/factory.py b/Lib/sqlite3/test/factory.py index a9a828f..bc56caa 100644 --- a/Lib/sqlite3/test/factory.py +++ b/Lib/sqlite3/test/factory.py @@ -1,7 +1,7 @@ #-*- coding: ISO-8859-1 -*- # pysqlite2/test/factory.py: tests for the various factories in pysqlite # -# Copyright (C) 2005 Gerhard Häring <gh@ghaering.de> +# Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de> # # This file is part of pysqlite. # diff --git a/Lib/sqlite3/test/hooks.py b/Lib/sqlite3/test/hooks.py index 28f2404..6872fd6 100644 --- a/Lib/sqlite3/test/hooks.py +++ b/Lib/sqlite3/test/hooks.py @@ -1,7 +1,7 @@ #-*- coding: ISO-8859-1 -*- # pysqlite2/test/hooks.py: tests for various SQLite-specific hooks # -# Copyright (C) 2006 Gerhard Häring <gh@ghaering.de> +# Copyright (C) 2006-2007 Gerhard Häring <gh@ghaering.de> # # This file is part of pysqlite. # @@ -105,9 +105,80 @@ class CollationTests(unittest.TestCase): if not e.args[0].startswith("no such collation sequence"): self.fail("wrong OperationalError raised") +class ProgressTests(unittest.TestCase): + def CheckProgressHandlerUsed(self): + """ + Test that the progress handler is invoked once it is set. + """ + con = sqlite.connect(":memory:") + progress_calls = [] + def progress(): + progress_calls.append(None) + return 0 + con.set_progress_handler(progress, 1) + con.execute(""" + create table foo(a, b) + """) + self.failUnless(progress_calls) + + + def CheckOpcodeCount(self): + """ + Test that the opcode argument is respected. + """ + con = sqlite.connect(":memory:") + progress_calls = [] + def progress(): + progress_calls.append(None) + return 0 + con.set_progress_handler(progress, 1) + curs = con.cursor() + curs.execute(""" + create table foo (a, b) + """) + first_count = len(progress_calls) + progress_calls = [] + con.set_progress_handler(progress, 2) + curs.execute(""" + create table bar (a, b) + """) + second_count = len(progress_calls) + self.failUnless(first_count > second_count) + + def CheckCancelOperation(self): + """ + Test that returning a non-zero value stops the operation in progress. + """ + con = sqlite.connect(":memory:") + progress_calls = [] + def progress(): + progress_calls.append(None) + return 1 + con.set_progress_handler(progress, 1) + curs = con.cursor() + self.assertRaises( + sqlite.OperationalError, + curs.execute, + "create table bar (a, b)") + + def CheckClearHandler(self): + """ + Test that setting the progress handler to None clears the previously set handler. + """ + con = sqlite.connect(":memory:") + action = 0 + def progress(): + action = 1 + return 0 + con.set_progress_handler(progress, 1) + con.set_progress_handler(None, 1) + con.execute("select 1 union select 2 union select 3").fetchall() + self.failUnlessEqual(action, 0, "progress handler was not cleared") + def suite(): collation_suite = unittest.makeSuite(CollationTests, "Check") - return unittest.TestSuite((collation_suite,)) + progress_suite = unittest.makeSuite(ProgressTests, "Check") + return unittest.TestSuite((collation_suite, progress_suite)) def test(): runner = unittest.TextTestRunner() diff --git a/Lib/sqlite3/test/regression.py b/Lib/sqlite3/test/regression.py index 4a68d9d..5e89a6c 100644 --- a/Lib/sqlite3/test/regression.py +++ b/Lib/sqlite3/test/regression.py @@ -21,6 +21,7 @@ # misrepresented as being the original software. # 3. This notice may not be removed or altered from any source distribution. +import datetime import unittest import sqlite3 as sqlite @@ -79,6 +80,67 @@ class RegressionTests(unittest.TestCase): cur.fetchone() cur.fetchone() + def CheckStatementFinalizationOnCloseDb(self): + # pysqlite versions <= 2.3.3 only finalized statements in the statement + # cache when closing the database. statements that were still + # referenced in cursors weren't closed an could provoke " + # "OperationalError: Unable to close due to unfinalised statements". + con = sqlite.connect(":memory:") + cursors = [] + # default statement cache size is 100 + for i in range(105): + cur = con.cursor() + cursors.append(cur) + cur.execute("select 1 x union select " + str(i)) + con.close() + + def CheckOnConflictRollback(self): + if sqlite.sqlite_version_info < (3, 2, 2): + return + con = sqlite.connect(":memory:") + con.execute("create table foo(x, unique(x) on conflict rollback)") + con.execute("insert into foo(x) values (1)") + try: + con.execute("insert into foo(x) values (1)") + except sqlite.DatabaseError: + pass + con.execute("insert into foo(x) values (2)") + try: + con.commit() + except sqlite.OperationalError: + self.fail("pysqlite knew nothing about the implicit ROLLBACK") + + def CheckWorkaroundForBuggySqliteTransferBindings(self): + """ + pysqlite would crash with older SQLite versions unless + a workaround is implemented. + """ + self.con.execute("create table foo(bar)") + self.con.execute("drop table foo") + self.con.execute("create table foo(bar)") + + def CheckEmptyStatement(self): + """ + pysqlite used to segfault with SQLite versions 3.5.x. These return NULL + for "no-operation" statements + """ + self.con.execute("") + + def CheckTypeMapUsage(self): + """ + pysqlite until 2.4.1 did not rebuild the row_cast_map when recompiling + a statement. This test exhibits the problem. + """ + SELECT = "select * from foo" + con = sqlite.connect(":memory:",detect_types=sqlite.PARSE_DECLTYPES) + con.execute("create table foo(bar timestamp)") + con.execute("insert into foo(bar) values (?)", (datetime.datetime.now(),)) + con.execute(SELECT) + con.execute("drop table foo") + con.execute("create table foo(bar integer)") + con.execute("insert into foo(bar) values (5)") + con.execute(SELECT) + def CheckErrorMsgDecodeError(self): # When porting the module to Python 3.0, the error message about # decoding errors disappeared. This verifies they're back again. diff --git a/Lib/sqlite3/test/transactions.py b/Lib/sqlite3/test/transactions.py index 9edc4ac..da5bd21 100644 --- a/Lib/sqlite3/test/transactions.py +++ b/Lib/sqlite3/test/transactions.py @@ -1,7 +1,7 @@ #-*- coding: ISO-8859-1 -*- # pysqlite2/test/transactions.py: tests transactions # -# Copyright (C) 2005 Gerhard Häring <gh@ghaering.de> +# Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de> # # This file is part of pysqlite. # @@ -122,6 +122,23 @@ class TransactionTests(unittest.TestCase): except: self.fail("should have raised an OperationalError") + def CheckLocking(self): + """ + This tests the improved concurrency with pysqlite 2.3.4. You needed + to roll back con2 before you could commit con1. + """ + self.cur1.execute("create table test(i)") + self.cur1.execute("insert into test(i) values (5)") + try: + self.cur2.execute("insert into test(i) values (5)") + self.fail("should have raised an OperationalError") + except sqlite.OperationalError: + pass + except: + self.fail("should have raised an OperationalError") + # NO self.con2.rollback() HERE!!! + self.con1.commit() + class SpecialCommandTests(unittest.TestCase): def setUp(self): self.con = sqlite.connect(":memory:") diff --git a/Lib/sqlite3/test/types.py b/Lib/sqlite3/test/types.py index 46bed7d..ce740a5 100644 --- a/Lib/sqlite3/test/types.py +++ b/Lib/sqlite3/test/types.py @@ -21,7 +21,7 @@ # misrepresented as being the original software. # 3. This notice may not be removed or altered from any source distribution. -import bz2, datetime +import zlib, datetime import unittest import sqlite3 as sqlite @@ -221,11 +221,13 @@ class ColNamesTests(unittest.TestCase): self.cur = self.con.cursor() self.cur.execute("create table test(x foo)") - sqlite.converters["BAR"] = lambda x: b"<" + x + b">" + sqlite.converters["FOO"] = lambda x: "[%s]" % x.decode("ascii") + sqlite.converters["BAR"] = lambda x: "<%s>" % x.decode("ascii") sqlite.converters["EXC"] = lambda x: 5/0 sqlite.converters["B1B1"] = lambda x: "MARKER" def tearDown(self): + del sqlite.converters["FOO"] del sqlite.converters["BAR"] del sqlite.converters["EXC"] del sqlite.converters["B1B1"] @@ -252,7 +254,7 @@ class ColNamesTests(unittest.TestCase): self.cur.execute("insert into test(x) values (?)", ("xxx",)) self.cur.execute('select x as "x [bar]" from test') val = self.cur.fetchone()[0] - self.failUnlessEqual(val, b"<xxx>") + self.failUnlessEqual(val, "<xxx>") # Check if the stripping of colnames works. Everything after the first # whitespace should be stripped. @@ -297,7 +299,7 @@ class ObjectAdaptationTests(unittest.TestCase): class BinaryConverterTests(unittest.TestCase): def convert(s): - return bz2.decompress(s) + return zlib.decompress(s) convert = staticmethod(convert) def setUp(self): @@ -309,7 +311,7 @@ class BinaryConverterTests(unittest.TestCase): def CheckBinaryInputForConverter(self): testdata = b"abcdefg" * 10 - result = self.con.execute('select ? as "x [bin]"', (memoryview(bz2.compress(testdata)),)).fetchone()[0] + result = self.con.execute('select ? as "x [bin]"', (memoryview(zlib.compress(testdata)),)).fetchone()[0] self.failUnlessEqual(testdata, result) class DateTimeTests(unittest.TestCase): @@ -341,7 +343,8 @@ class DateTimeTests(unittest.TestCase): if sqlite.sqlite_version_info < (3, 1): return - now = datetime.datetime.utcnow() + # SQLite's current_timestamp uses UTC time, while datetime.datetime.now() uses local time. + now = datetime.datetime.now() self.cur.execute("insert into test(ts) values (current_timestamp)") self.cur.execute("select ts from test") ts = self.cur.fetchone()[0] diff --git a/Modules/_sqlite/cache.c b/Modules/_sqlite/cache.c index 3cb540a..ed53958 100644 --- a/Modules/_sqlite/cache.c +++ b/Modules/_sqlite/cache.c @@ -1,6 +1,6 @@ /* cache .c - a LRU cache * - * Copyright (C) 2004-2006 Gerhard Häring <gh@ghaering.de> + * Copyright (C) 2004-2007 Gerhard Häring <gh@ghaering.de> * * This file is part of pysqlite. * diff --git a/Modules/_sqlite/cache.h b/Modules/_sqlite/cache.h index 158bf5a..d6f7f13 100644 --- a/Modules/_sqlite/cache.h +++ b/Modules/_sqlite/cache.h @@ -1,6 +1,6 @@ /* cache.h - definitions for the LRU cache * - * Copyright (C) 2004-2006 Gerhard Häring <gh@ghaering.de> + * Copyright (C) 2004-2007 Gerhard Häring <gh@ghaering.de> * * This file is part of pysqlite. * diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c index 11a14a1..667c3f0 100644 --- a/Modules/_sqlite/connection.c +++ b/Modules/_sqlite/connection.c @@ -1,6 +1,6 @@ /* connection.c - the connection type * - * Copyright (C) 2004-2006 Gerhard H�ring <gh@ghaering.de> + * Copyright (C) 2004-2007 Gerhard Häring <gh@ghaering.de> * * This file is part of pysqlite. * @@ -32,6 +32,9 @@ #include "pythread.h" +#define ACTION_FINALIZE 1 +#define ACTION_RESET 2 + static int pysqlite_connection_set_isolation_level(pysqlite_Connection* self, PyObject* isolation_level); @@ -63,7 +66,7 @@ int pysqlite_connection_init(pysqlite_Connection* self, PyObject* args, PyObject if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s|diOiOi", kwlist, &database, &timeout, &detect_types, &isolation_level, &check_same_thread, &factory, &cached_statements)) { - return -1; + return -1; } self->begin_statement = NULL; @@ -82,7 +85,7 @@ int pysqlite_connection_init(pysqlite_Connection* self, PyObject* args, PyObject Py_END_ALLOW_THREADS if (rc != SQLITE_OK) { - _pysqlite_seterror(self->db); + _pysqlite_seterror(self->db, NULL); return -1; } @@ -169,7 +172,8 @@ void pysqlite_flush_statement_cache(pysqlite_Connection* self) self->statement_cache->decref_factory = 0; } -void pysqlite_reset_all_statements(pysqlite_Connection* self) +/* action in (ACTION_RESET, ACTION_FINALIZE) */ +void pysqlite_do_all_statements(pysqlite_Connection* self, int action) { int i; PyObject* weakref; @@ -179,7 +183,11 @@ void pysqlite_reset_all_statements(pysqlite_Connection* self) weakref = PyList_GetItem(self->statements, i); statement = PyWeakref_GetObject(weakref); if (statement != Py_None) { - (void)pysqlite_statement_reset((pysqlite_Statement*)statement); + if (action == ACTION_RESET) { + (void)pysqlite_statement_reset((pysqlite_Statement*)statement); + } else { + (void)pysqlite_statement_finalize((pysqlite_Statement*)statement); + } } } } @@ -247,7 +255,7 @@ PyObject* pysqlite_connection_close(pysqlite_Connection* self, PyObject* args) return NULL; } - pysqlite_flush_statement_cache(self); + pysqlite_do_all_statements(self, ACTION_FINALIZE); if (self->db) { Py_BEGIN_ALLOW_THREADS @@ -255,7 +263,7 @@ PyObject* pysqlite_connection_close(pysqlite_Connection* self, PyObject* args) Py_END_ALLOW_THREADS if (rc != SQLITE_OK) { - _pysqlite_seterror(self->db); + _pysqlite_seterror(self->db, NULL); return NULL; } else { self->db = NULL; @@ -292,7 +300,7 @@ PyObject* _pysqlite_connection_begin(pysqlite_Connection* self) Py_END_ALLOW_THREADS if (rc != SQLITE_OK) { - _pysqlite_seterror(self->db); + _pysqlite_seterror(self->db, statement); goto error; } @@ -300,7 +308,7 @@ PyObject* _pysqlite_connection_begin(pysqlite_Connection* self) if (rc == SQLITE_DONE) { self->inTransaction = 1; } else { - _pysqlite_seterror(self->db); + _pysqlite_seterror(self->db, statement); } Py_BEGIN_ALLOW_THREADS @@ -308,7 +316,7 @@ PyObject* _pysqlite_connection_begin(pysqlite_Connection* self) Py_END_ALLOW_THREADS if (rc != SQLITE_OK && !PyErr_Occurred()) { - _pysqlite_seterror(self->db); + _pysqlite_seterror(self->db, NULL); } error: @@ -335,7 +343,7 @@ PyObject* pysqlite_connection_commit(pysqlite_Connection* self, PyObject* args) rc = sqlite3_prepare(self->db, "COMMIT", -1, &statement, &tail); Py_END_ALLOW_THREADS if (rc != SQLITE_OK) { - _pysqlite_seterror(self->db); + _pysqlite_seterror(self->db, NULL); goto error; } @@ -343,14 +351,14 @@ PyObject* pysqlite_connection_commit(pysqlite_Connection* self, PyObject* args) if (rc == SQLITE_DONE) { self->inTransaction = 0; } else { - _pysqlite_seterror(self->db); + _pysqlite_seterror(self->db, statement); } Py_BEGIN_ALLOW_THREADS rc = sqlite3_finalize(statement); Py_END_ALLOW_THREADS if (rc != SQLITE_OK && !PyErr_Occurred()) { - _pysqlite_seterror(self->db); + _pysqlite_seterror(self->db, NULL); } } @@ -375,13 +383,13 @@ PyObject* pysqlite_connection_rollback(pysqlite_Connection* self, PyObject* args } if (self->inTransaction) { - pysqlite_reset_all_statements(self); + pysqlite_do_all_statements(self, ACTION_RESET); Py_BEGIN_ALLOW_THREADS rc = sqlite3_prepare(self->db, "ROLLBACK", -1, &statement, &tail); Py_END_ALLOW_THREADS if (rc != SQLITE_OK) { - _pysqlite_seterror(self->db); + _pysqlite_seterror(self->db, NULL); goto error; } @@ -389,14 +397,14 @@ PyObject* pysqlite_connection_rollback(pysqlite_Connection* self, PyObject* args if (rc == SQLITE_DONE) { self->inTransaction = 0; } else { - _pysqlite_seterror(self->db); + _pysqlite_seterror(self->db, statement); } Py_BEGIN_ALLOW_THREADS rc = sqlite3_finalize(statement); Py_END_ALLOW_THREADS if (rc != SQLITE_OK && !PyErr_Occurred()) { - _pysqlite_seterror(self->db); + _pysqlite_seterror(self->db, NULL); } } @@ -746,6 +754,33 @@ static int _authorizer_callback(void* user_arg, int action, const char* arg1, co return rc; } +static int _progress_handler(void* user_arg) +{ + int rc; + PyObject *ret; + PyGILState_STATE gilstate; + + gilstate = PyGILState_Ensure(); + ret = PyObject_CallFunction((PyObject*)user_arg, ""); + + if (!ret) { + if (_enable_callback_tracebacks) { + PyErr_Print(); + } else { + PyErr_Clear(); + } + + /* abort query if error occured */ + rc = 1; + } else { + rc = (int)PyObject_IsTrue(ret); + Py_DECREF(ret); + } + + PyGILState_Release(gilstate); + return rc; +} + PyObject* pysqlite_connection_set_authorizer(pysqlite_Connection* self, PyObject* args, PyObject* kwargs) { PyObject* authorizer_cb; @@ -771,6 +806,30 @@ PyObject* pysqlite_connection_set_authorizer(pysqlite_Connection* self, PyObject } } +PyObject* pysqlite_connection_set_progress_handler(pysqlite_Connection* self, PyObject* args, PyObject* kwargs) +{ + PyObject* progress_handler; + int n; + + static char *kwlist[] = { "progress_handler", "n", NULL }; + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "Oi:set_progress_handler", + kwlist, &progress_handler, &n)) { + return NULL; + } + + if (progress_handler == Py_None) { + /* None clears the progress handler previously set */ + sqlite3_progress_handler(self->db, 0, 0, (void*)0); + } else { + sqlite3_progress_handler(self->db, n, _progress_handler, progress_handler); + PyDict_SetItem(self->function_pinboard, progress_handler, Py_None); + } + + Py_INCREF(Py_None); + return Py_None; +} + int pysqlite_check_thread(pysqlite_Connection* self) { if (self->check_same_thread) { @@ -881,7 +940,8 @@ PyObject* pysqlite_connection_call(pysqlite_Connection* self, PyObject* args, Py } else if (rc == PYSQLITE_SQL_WRONG_TYPE) { PyErr_SetString(pysqlite_Warning, "SQL is of wrong type. Must be string or unicode."); } else { - _pysqlite_seterror(self->db); + (void)pysqlite_statement_reset(statement); + _pysqlite_seterror(self->db, NULL); } Py_DECREF(statement); @@ -1169,7 +1229,7 @@ pysqlite_connection_create_collation(pysqlite_Connection* self, PyObject* args) (callable != Py_None) ? pysqlite_collation_callback : NULL); if (rc != SQLITE_OK) { PyDict_DelItem(self->collations, uppercase_name); - _pysqlite_seterror(self->db); + _pysqlite_seterror(self->db, NULL); goto finally; } @@ -1247,6 +1307,8 @@ static PyMethodDef connection_methods[] = { PyDoc_STR("Creates a new aggregate. Non-standard.")}, {"set_authorizer", (PyCFunction)pysqlite_connection_set_authorizer, METH_VARARGS|METH_KEYWORDS, PyDoc_STR("Sets authorizer callback. Non-standard.")}, + {"set_progress_handler", (PyCFunction)pysqlite_connection_set_progress_handler, METH_VARARGS|METH_KEYWORDS, + PyDoc_STR("Sets progress handler callback. Non-standard.")}, {"execute", (PyCFunction)pysqlite_connection_execute, METH_VARARGS, PyDoc_STR("Executes a SQL statement. Non-standard.")}, {"executemany", (PyCFunction)pysqlite_connection_executemany, METH_VARARGS, diff --git a/Modules/_sqlite/connection.h b/Modules/_sqlite/connection.h index 21fcd2a..dd177ae 100644 --- a/Modules/_sqlite/connection.h +++ b/Modules/_sqlite/connection.h @@ -1,6 +1,6 @@ /* connection.h - definitions for the connection type * - * Copyright (C) 2004-2006 Gerhard Häring <gh@ghaering.de> + * Copyright (C) 2004-2007 Gerhard Häring <gh@ghaering.de> * * This file is part of pysqlite. * diff --git a/Modules/_sqlite/cursor.c b/Modules/_sqlite/cursor.c index 0bd317a..7fd7db3 100644 --- a/Modules/_sqlite/cursor.c +++ b/Modules/_sqlite/cursor.c @@ -1,6 +1,6 @@ /* cursor.c - the cursor type * - * Copyright (C) 2004-2006 Gerhard Häring <gh@ghaering.de> + * Copyright (C) 2004-2007 Gerhard Häring <gh@ghaering.de> * * This file is part of pysqlite. * @@ -80,7 +80,7 @@ int pysqlite_cursor_init(pysqlite_Cursor* self, PyObject* args, PyObject* kwargs if (!PyArg_ParseTuple(args, "O!", &pysqlite_ConnectionType, &connection)) { - return -1; + return -1; } Py_INCREF(connection); @@ -255,23 +255,6 @@ PyObject* _pysqlite_build_column_name(const char* colname) PyObject* pysqlite_unicode_from_string(const char* val_str, int optimize) { - const char* check; - int is_ascii = 0; - - if (optimize) { - is_ascii = 1; - - check = val_str; - while (*check) { - if (*check & 0x80) { - is_ascii = 0; - break; - } - - check++; - } - } - return PyUnicode_FromString(val_str); } @@ -432,10 +415,14 @@ PyObject* _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject* PyObject* descriptor; PyObject* second_argument = NULL; long rowcount = 0; + int allow_8bit_chars; if (!pysqlite_check_thread(self->connection) || !pysqlite_check_connection(self->connection)) { return NULL; } + /* Make shooting yourself in the foot with not utf-8 decodable 8-bit-strings harder */ + allow_8bit_chars = ((self->connection->text_factory != (PyObject*)&PyUnicode_Type) && + (self->connection->text_factory != (PyObject*)&PyUnicode_Type && pysqlite_OptimizedUnicode)); Py_XDECREF(self->next_row); self->next_row = NULL; @@ -443,7 +430,7 @@ PyObject* _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject* if (multiple) { /* executemany() */ if (!PyArg_ParseTuple(args, "OO", &operation, &second_argument)) { - return NULL; + return NULL; } if (!PyUnicode_Check(operation)) { @@ -465,7 +452,7 @@ PyObject* _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject* } else { /* execute() */ if (!PyArg_ParseTuple(args, "O|O", &operation, &second_argument)) { - return NULL; + return NULL; } if (!PyUnicode_Check(operation)) { @@ -507,17 +494,48 @@ PyObject* _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject* if (operation == NULL) goto error; - /* reset description and rowcount */ + /* reset description */ Py_DECREF(self->description); Py_INCREF(Py_None); self->description = Py_None; - Py_DECREF(self->rowcount); - self->rowcount = PyLong_FromLong(-1L); - if (!self->rowcount) { + func_args = PyTuple_New(1); + if (!func_args) { + goto error; + } + Py_INCREF(operation); + if (PyTuple_SetItem(func_args, 0, operation) != 0) { + goto error; + } + + if (self->statement) { + (void)pysqlite_statement_reset(self->statement); + Py_DECREF(self->statement); + } + + self->statement = (pysqlite_Statement*)pysqlite_cache_get(self->connection->statement_cache, func_args); + Py_DECREF(func_args); + + if (!self->statement) { goto error; } + if (self->statement->in_use) { + Py_DECREF(self->statement); + self->statement = PyObject_New(pysqlite_Statement, &pysqlite_StatementType); + if (!self->statement) { + goto error; + } + rc = pysqlite_statement_create(self->statement, self->connection, operation); + if (rc != SQLITE_OK) { + self->statement = 0; + goto error; + } + } + + pysqlite_statement_reset(self->statement); + pysqlite_statement_mark_dirty(self->statement); + statement_type = detect_statement_type(operation_cstr); if (self->connection->begin_statement) { switch (statement_type) { @@ -599,7 +617,7 @@ PyObject* _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject* pysqlite_statement_mark_dirty(self->statement); - pysqlite_statement_bind_parameters(self->statement, parameters); + pysqlite_statement_bind_parameters(self->statement, parameters, allow_8bit_chars); if (PyErr_Occurred()) { goto error; } @@ -627,7 +645,8 @@ PyObject* _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject* continue; } else { /* If the database gave us an error, promote it to Python. */ - _pysqlite_seterror(self->connection->db); + (void)pysqlite_statement_reset(self->statement); + _pysqlite_seterror(self->connection->db, NULL); goto error; } } else { @@ -639,17 +658,27 @@ PyObject* _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject* PyErr_Clear(); } } - _pysqlite_seterror(self->connection->db); + (void)pysqlite_statement_reset(self->statement); + _pysqlite_seterror(self->connection->db, NULL); goto error; } } + if (pysqlite_build_row_cast_map(self) != 0) { + PyErr_SetString(pysqlite_OperationalError, "Error while building row_cast_map"); + goto error; + } + if (rc == SQLITE_ROW || (rc == SQLITE_DONE && statement_type == STATEMENT_SELECT)) { Py_BEGIN_ALLOW_THREADS numcols = sqlite3_column_count(self->statement->st); Py_END_ALLOW_THREADS if (self->description == Py_None) { + Py_BEGIN_ALLOW_THREADS + numcols = sqlite3_column_count(self->statement->st); + Py_END_ALLOW_THREADS + Py_DECREF(self->description); self->description = PyTuple_New(numcols); if (!self->description) { @@ -690,15 +719,11 @@ PyObject* _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject* case STATEMENT_DELETE: case STATEMENT_INSERT: case STATEMENT_REPLACE: - Py_BEGIN_ALLOW_THREADS rowcount += (long)sqlite3_changes(self->connection->db); - Py_END_ALLOW_THREADS - Py_DECREF(self->rowcount); - self->rowcount = PyLong_FromLong(rowcount); } Py_DECREF(self->lastrowid); - if (statement_type == STATEMENT_INSERT) { + if (!multiple && statement_type == STATEMENT_INSERT) { Py_BEGIN_ALLOW_THREADS lastrowid = sqlite3_last_insert_rowid(self->connection->db); Py_END_ALLOW_THREADS @@ -715,13 +740,26 @@ PyObject* _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject* } error: + /* just to be sure (implicit ROLLBACKs with ON CONFLICT ROLLBACK/OR + * ROLLBACK could have happened */ + #ifdef SQLITE_VERSION_NUMBER + #if SQLITE_VERSION_NUMBER >= 3002002 + self->connection->inTransaction = !sqlite3_get_autocommit(self->connection->db); + #endif + #endif + Py_XDECREF(parameters); Py_XDECREF(parameters_iter); Py_XDECREF(parameters_list); if (PyErr_Occurred()) { + Py_DECREF(self->rowcount); + self->rowcount = PyLong_FromLong(-1L); return NULL; } else { + Py_DECREF(self->rowcount); + self->rowcount = PyLong_FromLong(rowcount); + Py_INCREF(self); return (PyObject*)self; } @@ -748,7 +786,7 @@ PyObject* pysqlite_cursor_executescript(pysqlite_Cursor* self, PyObject* args) int statement_completed = 0; if (!PyArg_ParseTuple(args, "O", &script_obj)) { - return NULL; + return NULL; } if (!pysqlite_check_thread(self->connection) || !pysqlite_check_connection(self->connection)) { @@ -784,7 +822,7 @@ PyObject* pysqlite_cursor_executescript(pysqlite_Cursor* self, PyObject* args) &statement, &script_cstr); if (rc != SQLITE_OK) { - _pysqlite_seterror(self->connection->db); + _pysqlite_seterror(self->connection->db, NULL); goto error; } @@ -792,17 +830,18 @@ PyObject* pysqlite_cursor_executescript(pysqlite_Cursor* self, PyObject* args) rc = SQLITE_ROW; while (rc == SQLITE_ROW) { rc = _sqlite_step_with_busyhandler(statement, self->connection); + /* TODO: we probably need more error handling here */ } if (rc != SQLITE_DONE) { (void)sqlite3_finalize(statement); - _pysqlite_seterror(self->connection->db); + _pysqlite_seterror(self->connection->db, NULL); goto error; } rc = sqlite3_finalize(statement); if (rc != SQLITE_OK) { - _pysqlite_seterror(self->connection->db); + _pysqlite_seterror(self->connection->db, NULL); goto error; } } @@ -860,8 +899,9 @@ PyObject* pysqlite_cursor_iternext(pysqlite_Cursor *self) if (self->statement) { rc = _sqlite_step_with_busyhandler(self->statement->st, self->connection); if (rc != SQLITE_DONE && rc != SQLITE_ROW) { + (void)pysqlite_statement_reset(self->statement); Py_DECREF(next_row); - _pysqlite_seterror(self->connection->db); + _pysqlite_seterror(self->connection->db, NULL); return NULL; } @@ -886,15 +926,17 @@ PyObject* pysqlite_cursor_fetchone(pysqlite_Cursor* self, PyObject* args) return row; } -PyObject* pysqlite_cursor_fetchmany(pysqlite_Cursor* self, PyObject* args) +PyObject* pysqlite_cursor_fetchmany(pysqlite_Cursor* self, PyObject* args, PyObject* kwargs) { + static char *kwlist[] = {"size", NULL, NULL}; + PyObject* row; PyObject* list; int maxrows = self->arraysize; int counter = 0; - if (!PyArg_ParseTuple(args, "|i", &maxrows)) { - return NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|i:fetchmany", kwlist, &maxrows)) { + return NULL; } list = PyList_New(0); @@ -988,7 +1030,7 @@ static PyMethodDef cursor_methods[] = { PyDoc_STR("Executes a multiple SQL statements at once. Non-standard.")}, {"fetchone", (PyCFunction)pysqlite_cursor_fetchone, METH_NOARGS, PyDoc_STR("Fetches one row from the resultset.")}, - {"fetchmany", (PyCFunction)pysqlite_cursor_fetchmany, METH_VARARGS, + {"fetchmany", (PyCFunction)pysqlite_cursor_fetchmany, METH_VARARGS|METH_KEYWORDS, PyDoc_STR("Fetches several rows from the resultset.")}, {"fetchall", (PyCFunction)pysqlite_cursor_fetchall, METH_NOARGS, PyDoc_STR("Fetches all rows from the resultset.")}, diff --git a/Modules/_sqlite/cursor.h b/Modules/_sqlite/cursor.h index 5fce64a..d916ca5 100644 --- a/Modules/_sqlite/cursor.h +++ b/Modules/_sqlite/cursor.h @@ -1,6 +1,6 @@ /* cursor.h - definitions for the cursor type * - * Copyright (C) 2004-2006 Gerhard Häring <gh@ghaering.de> + * Copyright (C) 2004-2007 Gerhard Häring <gh@ghaering.de> * * This file is part of pysqlite. * @@ -60,7 +60,7 @@ PyObject* pysqlite_cursor_executemany(pysqlite_Cursor* self, PyObject* args); PyObject* pysqlite_cursor_getiter(pysqlite_Cursor *self); PyObject* pysqlite_cursor_iternext(pysqlite_Cursor *self); PyObject* pysqlite_cursor_fetchone(pysqlite_Cursor* self, PyObject* args); -PyObject* pysqlite_cursor_fetchmany(pysqlite_Cursor* self, PyObject* args); +PyObject* pysqlite_cursor_fetchmany(pysqlite_Cursor* self, PyObject* args, PyObject* kwargs); PyObject* pysqlite_cursor_fetchall(pysqlite_Cursor* self, PyObject* args); PyObject* pysqlite_noop(pysqlite_Connection* self, PyObject* args); PyObject* pysqlite_cursor_close(pysqlite_Cursor* self, PyObject* args); diff --git a/Modules/_sqlite/microprotocols.h b/Modules/_sqlite/microprotocols.h index d84ec93..c911c81 100644 --- a/Modules/_sqlite/microprotocols.h +++ b/Modules/_sqlite/microprotocols.h @@ -28,10 +28,6 @@ #include <Python.h> -#ifdef __cplusplus -extern "C" { -#endif - /** adapters registry **/ extern PyObject *psyco_adapters; diff --git a/Modules/_sqlite/module.c b/Modules/_sqlite/module.c index 5c2aaa7..2284eaa 100644 --- a/Modules/_sqlite/module.c +++ b/Modules/_sqlite/module.c @@ -1,25 +1,25 @@ - /* module.c - the module itself - * - * Copyright (C) 2004-2006 Gerhard Häring <gh@ghaering.de> - * - * This file is part of pysqlite. - * - * This software is provided 'as-is', without any express or implied - * warranty. In no event will the authors be held liable for any damages - * arising from the use of this software. - * - * Permission is granted to anyone to use this software for any purpose, - * including commercial applications, and to alter it and redistribute it - * freely, subject to the following restrictions: - * - * 1. The origin of this software must not be misrepresented; you must not - * claim that you wrote the original software. If you use this software - * in a product, an acknowledgment in the product documentation would be - * appreciated but is not required. - * 2. Altered source versions must be plainly marked as such, and must not be - * misrepresented as being the original software. - * 3. This notice may not be removed or altered from any source distribution. - */ +/* module.c - the module itself + * + * Copyright (C) 2004-2007 Gerhard Häring <gh@ghaering.de> + * + * This file is part of pysqlite. + * + * This software is provided 'as-is', without any express or implied + * warranty. In no event will the authors be held liable for any damages + * arising from the use of this software. + * + * Permission is granted to anyone to use this software for any purpose, + * including commercial applications, and to alter it and redistribute it + * freely, subject to the following restrictions: + * + * 1. The origin of this software must not be misrepresented; you must not + * claim that you wrote the original software. If you use this software + * in a product, an acknowledgment in the product documentation would be + * appreciated but is not required. + * 2. Altered source versions must be plainly marked as such, and must not be + * misrepresented as being the original software. + * 3. This notice may not be removed or altered from any source distribution. + */ #include "connection.h" #include "statement.h" @@ -41,6 +41,7 @@ PyObject* pysqlite_Error, *pysqlite_Warning, *pysqlite_InterfaceError, *pysqlite PyObject* converters; int _enable_callback_tracebacks; +int pysqlite_BaseTypeAdapted; static PyObject* module_connect(PyObject* self, PyObject* args, PyObject* kwargs) @@ -133,6 +134,13 @@ static PyObject* module_register_adapter(PyObject* self, PyObject* args, PyObjec return NULL; } + /* a basic type is adapted; there's a performance optimization if that's not the case + * (99 % of all usages) */ + if (type == &PyLong_Type || type == &PyFloat_Type + || type == &PyUnicode_Type || type == &PyBytes_Type) { + pysqlite_BaseTypeAdapted = 1; + } + microprotocols_add(type, (PyObject*)&pysqlite_PrepareProtocolType, caster); Py_INCREF(Py_None); @@ -379,6 +387,8 @@ PyMODINIT_FUNC init_sqlite3(void) _enable_callback_tracebacks = 0; + pysqlite_BaseTypeAdapted = 0; + /* Original comment form _bsddb.c in the Python core. This is also still * needed nowadays for Python 2.3/2.4. * diff --git a/Modules/_sqlite/module.h b/Modules/_sqlite/module.h index ada6b4c..b14be2a 100644 --- a/Modules/_sqlite/module.h +++ b/Modules/_sqlite/module.h @@ -1,6 +1,6 @@ /* module.h - definitions for the module * - * Copyright (C) 2004-2006 Gerhard Häring <gh@ghaering.de> + * Copyright (C) 2004-2007 Gerhard Häring <gh@ghaering.de> * * This file is part of pysqlite. * @@ -25,7 +25,7 @@ #define PYSQLITE_MODULE_H #include "Python.h" -#define PYSQLITE_VERSION "2.3.3" +#define PYSQLITE_VERSION "2.4.1" extern PyObject* pysqlite_Error; extern PyObject* pysqlite_Warning; @@ -51,6 +51,7 @@ extern PyObject* time_sleep; extern PyObject* converters; extern int _enable_callback_tracebacks; +extern int pysqlite_BaseTypeAdapted; #define PARSE_DECLTYPES 1 #define PARSE_COLNAMES 2 diff --git a/Modules/_sqlite/prepare_protocol.h b/Modules/_sqlite/prepare_protocol.h index 4c1e4f3..153472e 100644 --- a/Modules/_sqlite/prepare_protocol.h +++ b/Modules/_sqlite/prepare_protocol.h @@ -1,6 +1,6 @@ /* prepare_protocol.h - the protocol for preparing values for SQLite * - * Copyright (C) 2005 Gerhard Häring <gh@ghaering.de> + * Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de> * * This file is part of pysqlite. * diff --git a/Modules/_sqlite/row.c b/Modules/_sqlite/row.c index a851579..47b91ed 100644 --- a/Modules/_sqlite/row.c +++ b/Modules/_sqlite/row.c @@ -154,6 +154,11 @@ PyObject* pysqlite_row_keys(pysqlite_Row* self, PyObject* args, PyObject* kwargs return list; } +static int pysqlite_row_print(pysqlite_Row* self, FILE *fp, int flags) +{ + return (&PyTuple_Type)->tp_print(self->data, fp, flags); +} + static PyObject* pysqlite_iter(pysqlite_Row* self) { return PyObject_GetIter(self->data); @@ -178,7 +183,7 @@ PyTypeObject pysqlite_RowType = { sizeof(pysqlite_Row), /* tp_basicsize */ 0, /* tp_itemsize */ (destructor)pysqlite_row_dealloc, /* tp_dealloc */ - 0, /* tp_print */ + (printfunc)pysqlite_row_print, /* tp_print */ 0, /* tp_getattr */ 0, /* tp_setattr */ 0, /* tp_compare */ diff --git a/Modules/_sqlite/row.h b/Modules/_sqlite/row.h index b92225b..8ed69ae 100644 --- a/Modules/_sqlite/row.h +++ b/Modules/_sqlite/row.h @@ -1,6 +1,6 @@ /* row.h - an enhanced tuple for database rows * - * Copyright (C) 2005 Gerhard Häring <gh@ghaering.de> + * Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de> * * This file is part of pysqlite. * diff --git a/Modules/_sqlite/statement.c b/Modules/_sqlite/statement.c index d280a67..66adff3 100644 --- a/Modules/_sqlite/statement.c +++ b/Modules/_sqlite/statement.c @@ -1,6 +1,6 @@ /* statement.c - the statement type * - * Copyright (C) 2005-2006 Gerhard Häring <gh@ghaering.de> + * Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de> * * This file is part of pysqlite. * @@ -40,6 +40,15 @@ typedef enum { NORMAL } parse_remaining_sql_state; +typedef enum { + TYPE_LONG, + TYPE_FLOAT, + TYPE_STRING, + TYPE_UNICODE, + TYPE_BUFFER, + TYPE_UNKNOWN +} parameter_type; + int pysqlite_statement_create(pysqlite_Statement* self, pysqlite_Connection* connection, PyObject* sql) { const char* tail; @@ -77,52 +86,102 @@ int pysqlite_statement_create(pysqlite_Statement* self, pysqlite_Connection* con return rc; } -int pysqlite_statement_bind_parameter(pysqlite_Statement* self, int pos, PyObject* parameter) +int pysqlite_statement_bind_parameter(pysqlite_Statement* self, int pos, PyObject* parameter, int allow_8bit_chars) { int rc = SQLITE_OK; + long longval; #ifdef HAVE_LONG_LONG PY_LONG_LONG longlongval; -#else - long longval; #endif const char* buffer; char* string; Py_ssize_t buflen; + parameter_type paramtype; + char* c; if (parameter == Py_None) { rc = sqlite3_bind_null(self->st, pos); -#ifdef HAVE_LONG_LONG - } else if (PyLong_Check(parameter)) { - longlongval = PyLong_AsLongLong(parameter); - /* in the overflow error case, longlongval is -1, and an exception is set */ - rc = sqlite3_bind_int64(self->st, pos, (sqlite_int64)longlongval); -#else + goto final; + } + + if (PyLong_CheckExact(parameter)) { + paramtype = TYPE_LONG; + } else if (PyFloat_CheckExact(parameter)) { + paramtype = TYPE_FLOAT; + } else if (PyUnicode_CheckExact(parameter)) { + paramtype = TYPE_UNICODE; } else if (PyLong_Check(parameter)) { - longval = PyLong_AsLong(parameter); - /* in the overflow error case, longval is -1, and an exception is set */ - rc = sqlite3_bind_int64(self->st, pos, (sqlite_int64)longval); -#endif + paramtype = TYPE_LONG; } else if (PyFloat_Check(parameter)) { - rc = sqlite3_bind_double(self->st, pos, PyFloat_AsDouble(parameter)); - } else if PyUnicode_Check(parameter) { - string = PyUnicode_AsString(parameter); - - rc = sqlite3_bind_text(self->st, pos, string, -1, SQLITE_TRANSIENT); + paramtype = TYPE_FLOAT; + } else if (PyUnicode_Check(parameter)) { + paramtype = TYPE_STRING; } else if (PyObject_CheckBuffer(parameter)) { - if (PyObject_AsCharBuffer(parameter, &buffer, &buflen) == 0) { - rc = sqlite3_bind_blob(self->st, pos, buffer, buflen, SQLITE_TRANSIENT); - } else { - PyErr_SetString(PyExc_ValueError, "could not convert BLOB to buffer"); - rc = -1; - } + paramtype = TYPE_BUFFER; } else { - rc = -1; + paramtype = TYPE_UNKNOWN; } + if (paramtype == TYPE_STRING && !allow_8bit_chars) { + string = PyString_AS_STRING(parameter); + for (c = string; *c != 0; c++) { + if (*c & 0x80) { + PyErr_SetString(pysqlite_ProgrammingError, "You must not use 8-bit bytestrings unless you use a text_factory that can interpret 8-bit bytestrings (like text_factory = str). It is highly recommended that you instead just switch your application to Unicode strings."); + rc = -1; + goto final; + } + } + } + + switch (paramtype) { + case TYPE_LONG: + /* in the overflow error case, longval/longlongval is -1, and an exception is set */ +#ifdef HAVE_LONG_LONG + longlongval = PyLong_AsLongLong(parameter); + rc = sqlite3_bind_int64(self->st, pos, (sqlite_int64)longlongval); +#else + rc = sqlite3_bind_int64(self->st, pos, (sqlite_int64)longval); +#endif + break; + case TYPE_FLOAT: + rc = sqlite3_bind_double(self->st, pos, PyFloat_AsDouble(parameter)); + break; + case TYPE_UNICODE: + string = PyUnicode_AsString(parameter); + rc = sqlite3_bind_text(self->st, pos, string, -1, SQLITE_TRANSIENT); + break; + case TYPE_BUFFER: + if (PyObject_AsCharBuffer(parameter, &buffer, &buflen) == 0) { + rc = sqlite3_bind_blob(self->st, pos, buffer, buflen, SQLITE_TRANSIENT); + } else { + PyErr_SetString(PyExc_ValueError, "could not convert BLOB to buffer"); + rc = -1; + } + break; + case TYPE_UNKNOWN: + rc = -1; + } + +final: return rc; } -void pysqlite_statement_bind_parameters(pysqlite_Statement* self, PyObject* parameters) +/* returns 0 if the object is one of Python's internal ones that don't need to be adapted */ +static int _need_adapt(PyObject* obj) +{ + if (pysqlite_BaseTypeAdapted) { + return 1; + } + + if (PyLong_CheckExact(obj) || PyFloat_CheckExact(obj) + || PyUnicode_CheckExact(obj) || PyBytes_CheckExact(obj)) { + return 0; + } else { + return 1; + } +} + +void pysqlite_statement_bind_parameters(pysqlite_Statement* self, PyObject* parameters, int allow_8bit_chars) { PyObject* current_param; PyObject* adapted; @@ -136,7 +195,57 @@ void pysqlite_statement_bind_parameters(pysqlite_Statement* self, PyObject* para num_params_needed = sqlite3_bind_parameter_count(self->st); Py_END_ALLOW_THREADS - if (PyDict_Check(parameters)) { + if (PyTuple_CheckExact(parameters) || PyList_CheckExact(parameters) || (!PyDict_Check(parameters) && PySequence_Check(parameters))) { + /* parameters passed as sequence */ + if (PyTuple_CheckExact(parameters)) { + num_params = PyTuple_GET_SIZE(parameters); + } else if (PyList_CheckExact(parameters)) { + num_params = PyList_GET_SIZE(parameters); + } else { + num_params = PySequence_Size(parameters); + } + if (num_params != num_params_needed) { + PyErr_Format(pysqlite_ProgrammingError, "Incorrect number of bindings supplied. The current statement uses %d, and there are %d supplied.", + num_params_needed, num_params); + return; + } + for (i = 0; i < num_params; i++) { + if (PyTuple_CheckExact(parameters)) { + current_param = PyTuple_GET_ITEM(parameters, i); + Py_XINCREF(current_param); + } else if (PyList_CheckExact(parameters)) { + current_param = PyList_GET_ITEM(parameters, i); + Py_XINCREF(current_param); + } else { + current_param = PySequence_GetItem(parameters, i); + } + if (!current_param) { + return; + } + + if (!_need_adapt(current_param)) { + adapted = current_param; + } else { + adapted = microprotocols_adapt(current_param, (PyObject*)&pysqlite_PrepareProtocolType, NULL); + if (adapted) { + Py_DECREF(current_param); + } else { + PyErr_Clear(); + adapted = current_param; + } + } + + rc = pysqlite_statement_bind_parameter(self, i + 1, adapted, allow_8bit_chars); + Py_DECREF(adapted); + + if (rc != SQLITE_OK) { + if (!PyErr_Occurred()) { + PyErr_Format(pysqlite_InterfaceError, "Error binding parameter %d - probably unsupported type.", i); + } + return; + } + } + } else if (PyDict_Check(parameters)) { /* parameters passed as dictionary */ for (i = 1; i <= num_params_needed; i++) { Py_BEGIN_ALLOW_THREADS @@ -148,59 +257,41 @@ void pysqlite_statement_bind_parameters(pysqlite_Statement* self, PyObject* para } binding_name++; /* skip first char (the colon) */ - current_param = PyDict_GetItemString(parameters, binding_name); + if (PyDict_CheckExact(parameters)) { + current_param = PyDict_GetItemString(parameters, binding_name); + Py_XINCREF(current_param); + } else { + current_param = PyMapping_GetItemString(parameters, (char*)binding_name); + } if (!current_param) { PyErr_Format(pysqlite_ProgrammingError, "You did not supply a value for binding %d.", i); return; } - Py_INCREF(current_param); - adapted = microprotocols_adapt(current_param, (PyObject*)&pysqlite_PrepareProtocolType, NULL); - if (adapted) { - Py_DECREF(current_param); - } else { - PyErr_Clear(); + if (!_need_adapt(current_param)) { adapted = current_param; + } else { + adapted = microprotocols_adapt(current_param, (PyObject*)&pysqlite_PrepareProtocolType, NULL); + if (adapted) { + Py_DECREF(current_param); + } else { + PyErr_Clear(); + adapted = current_param; + } } - rc = pysqlite_statement_bind_parameter(self, i, adapted); + rc = pysqlite_statement_bind_parameter(self, i, adapted, allow_8bit_chars); Py_DECREF(adapted); if (rc != SQLITE_OK) { - PyErr_Format(pysqlite_InterfaceError, "Error binding parameter :%s - probably unsupported type.", binding_name); + if (!PyErr_Occurred()) { + PyErr_Format(pysqlite_InterfaceError, "Error binding parameter :%s - probably unsupported type.", binding_name); + } return; } } } else { - /* parameters passed as sequence */ - num_params = PySequence_Length(parameters); - if (num_params != num_params_needed) { - PyErr_Format(pysqlite_ProgrammingError, "Incorrect number of bindings supplied. The current statement uses %d, and there are %d supplied.", - num_params_needed, num_params); - return; - } - for (i = 0; i < num_params; i++) { - current_param = PySequence_GetItem(parameters, i); - if (!current_param) { - return; - } - adapted = microprotocols_adapt(current_param, (PyObject*)&pysqlite_PrepareProtocolType, NULL); - - if (adapted) { - Py_DECREF(current_param); - } else { - PyErr_Clear(); - adapted = current_param; - } - - rc = pysqlite_statement_bind_parameter(self, i + 1, adapted); - Py_DECREF(adapted); - - if (rc != SQLITE_OK) { - PyErr_Format(pysqlite_InterfaceError, "Error binding parameter %d - probably unsupported type.", i); - return; - } - } + PyErr_SetString(PyExc_ValueError, "parameters are of unsupported type"); } } @@ -400,7 +491,7 @@ PyTypeObject pysqlite_StatementType = { 0, /* tp_getattro */ 0, /* tp_setattro */ 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ 0, /* tp_doc */ 0, /* tp_traverse */ 0, /* tp_clear */ diff --git a/Modules/_sqlite/statement.h b/Modules/_sqlite/statement.h index 10b8823..bfa2091 100644 --- a/Modules/_sqlite/statement.h +++ b/Modules/_sqlite/statement.h @@ -1,6 +1,6 @@ /* statement.h - definitions for the statement type * - * Copyright (C) 2005 Gerhard Häring <gh@ghaering.de> + * Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de> * * This file is part of pysqlite. * @@ -46,8 +46,8 @@ extern PyTypeObject pysqlite_StatementType; int pysqlite_statement_create(pysqlite_Statement* self, pysqlite_Connection* connection, PyObject* sql); void pysqlite_statement_dealloc(pysqlite_Statement* self); -int pysqlite_statement_bind_parameter(pysqlite_Statement* self, int pos, PyObject* parameter); -void pysqlite_statement_bind_parameters(pysqlite_Statement* self, PyObject* parameters); +int pysqlite_statement_bind_parameter(pysqlite_Statement* self, int pos, PyObject* parameter, int allow_8bit_chars); +void pysqlite_statement_bind_parameters(pysqlite_Statement* self, PyObject* parameters, int allow_8bit_chars); int pysqlite_statement_recompile(pysqlite_Statement* self, PyObject* parameters); int pysqlite_statement_finalize(pysqlite_Statement* self); diff --git a/Modules/_sqlite/util.c b/Modules/_sqlite/util.c index 5e78d58..e06c299 100644 --- a/Modules/_sqlite/util.c +++ b/Modules/_sqlite/util.c @@ -1,6 +1,6 @@ /* util.c - various utility functions * - * Copyright (C) 2005-2006 Gerhard Häring <gh@ghaering.de> + * Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de> * * This file is part of pysqlite. * @@ -45,10 +45,15 @@ int _sqlite_step_with_busyhandler(sqlite3_stmt* statement, pysqlite_Connection* * Checks the SQLite error code and sets the appropriate DB-API exception. * Returns the error code (0 means no error occurred). */ -int _pysqlite_seterror(sqlite3* db) +int _pysqlite_seterror(sqlite3* db, sqlite3_stmt* st) { int errorcode; + /* SQLite often doesn't report anything useful, unless you reset the statement first */ + if (st != NULL) { + (void)sqlite3_reset(st); + } + errorcode = sqlite3_errcode(db); switch (errorcode) diff --git a/Modules/_sqlite/util.h b/Modules/_sqlite/util.h index 969c5e5..179be78 100644 --- a/Modules/_sqlite/util.h +++ b/Modules/_sqlite/util.h @@ -1,6 +1,6 @@ /* util.h - various utility functions * - * Copyright (C) 2005-2006 Gerhard Häring <gh@ghaering.de> + * Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de> * * This file is part of pysqlite. * @@ -34,5 +34,5 @@ int _sqlite_step_with_busyhandler(sqlite3_stmt* statement, pysqlite_Connection* * Checks the SQLite error code and sets the appropriate DB-API exception. * Returns the error code (0 means no error occurred). */ -int _pysqlite_seterror(sqlite3* db); +int _pysqlite_seterror(sqlite3* db, sqlite3_stmt* st); #endif |