diff options
| -rw-r--r-- | Lib/sqlite3/test/dbapi.py | 42 | ||||
| -rw-r--r-- | Lib/sqlite3/test/hooks.py | 77 | ||||
| -rw-r--r-- | Lib/sqlite3/test/py25tests.py | 80 | ||||
| -rw-r--r-- | Lib/sqlite3/test/regression.py | 76 | ||||
| -rw-r--r-- | Lib/sqlite3/test/transactions.py | 20 | ||||
| -rw-r--r-- | Lib/sqlite3/test/types.py | 11 | ||||
| -rw-r--r-- | Lib/test/test_sqlite.py | 6 | ||||
| -rw-r--r-- | Modules/_sqlite/connection.c | 231 | ||||
| -rw-r--r-- | Modules/_sqlite/connection.h | 7 | ||||
| -rw-r--r-- | Modules/_sqlite/cursor.c | 149 | ||||
| -rw-r--r-- | Modules/_sqlite/cursor.h | 4 | ||||
| -rw-r--r-- | Modules/_sqlite/microprotocols.h | 4 | ||||
| -rw-r--r-- | Modules/_sqlite/module.c | 58 | ||||
| -rw-r--r-- | Modules/_sqlite/module.h | 5 | ||||
| -rw-r--r-- | Modules/_sqlite/statement.c | 216 | ||||
| -rw-r--r-- | Modules/_sqlite/util.c | 9 | ||||
| -rw-r--r-- | Modules/_sqlite/util.h | 2 | 
17 files changed, 781 insertions, 216 deletions
diff --git a/Lib/sqlite3/test/dbapi.py b/Lib/sqlite3/test/dbapi.py index b08da9c..b27486d 100644 --- a/Lib/sqlite3/test/dbapi.py +++ b/Lib/sqlite3/test/dbapi.py @@ -1,7 +1,7 @@  #-*- coding: ISO-8859-1 -*-  # pysqlite2/test/dbapi.py: tests for DB-API compliance  # -# Copyright (C) 2004-2005 Gerhard Häring <gh@ghaering.de> +# Copyright (C) 2004-2007 Gerhard Häring <gh@ghaering.de>  #  # This file is part of pysqlite.  # @@ -22,6 +22,7 @@  # 3. This notice may not be removed or altered from any source distribution.  import unittest +import sys  import threading  import sqlite3 as sqlite @@ -223,12 +224,45 @@ class CursorTests(unittest.TestCase):          except sqlite.ProgrammingError:              pass +    def CheckExecuteParamList(self): +        self.cu.execute("insert into test(name) values ('foo')") +        self.cu.execute("select name from test where name=?", ["foo"]) +        row = self.cu.fetchone() +        self.failUnlessEqual(row[0], "foo") + +    def CheckExecuteParamSequence(self): +        class L(object): +            def __len__(self): +                return 1 +            def __getitem__(self, x): +                assert x == 0 +                return "foo" + +        self.cu.execute("insert into test(name) values ('foo')") +        self.cu.execute("select name from test where name=?", L()) +        row = self.cu.fetchone() +        self.failUnlessEqual(row[0], "foo") +      def CheckExecuteDictMapping(self):          self.cu.execute("insert into test(name) values ('foo')")          self.cu.execute("select name from test where name=:name", {"name": "foo"})          row = self.cu.fetchone()          self.failUnlessEqual(row[0], "foo") +    def CheckExecuteDictMapping_Mapping(self): +        # Test only works with Python 2.5 or later +        if sys.version_info < (2, 5, 0): +            return + +        class D(dict): +            def __missing__(self, key): +                return "foo" + +        self.cu.execute("insert into test(name) values ('foo')") +        self.cu.execute("select name from test where name=:name", D()) +        row = self.cu.fetchone() +        self.failUnlessEqual(row[0], "foo") +      def CheckExecuteDictMappingTooLittleArgs(self):          self.cu.execute("insert into test(name) values ('foo')")          try: @@ -378,6 +412,12 @@ class CursorTests(unittest.TestCase):          res = self.cu.fetchmany(100)          self.failUnlessEqual(res, []) +    def CheckFetchmanyKwArg(self): +        """Checks if fetchmany works with keyword arguments""" +        self.cu.execute("select name from test") +        res = self.cu.fetchmany(size=100) +        self.failUnlessEqual(len(res), 1) +      def CheckFetchall(self):          self.cu.execute("select name from test")          res = self.cu.fetchall() diff --git a/Lib/sqlite3/test/hooks.py b/Lib/sqlite3/test/hooks.py index cb0a621..547dc65 100644 --- a/Lib/sqlite3/test/hooks.py +++ b/Lib/sqlite3/test/hooks.py @@ -1,7 +1,7 @@  #-*- coding: ISO-8859-1 -*-  # pysqlite2/test/hooks.py: tests for various SQLite-specific hooks  # -# Copyright (C) 2006 Gerhard Häring <gh@ghaering.de> +# Copyright (C) 2006-2007 Gerhard Häring <gh@ghaering.de>  #  # This file is part of pysqlite.  # @@ -21,7 +21,7 @@  #    misrepresented as being the original software.  # 3. This notice may not be removed or altered from any source distribution. -import unittest +import os, unittest  import sqlite3 as sqlite  class CollationTests(unittest.TestCase): @@ -105,9 +105,80 @@ class CollationTests(unittest.TestCase):              if not e.args[0].startswith("no such collation sequence"):                  self.fail("wrong OperationalError raised") +class ProgressTests(unittest.TestCase): +    def CheckProgressHandlerUsed(self): +        """ +        Test that the progress handler is invoked once it is set. +        """ +        con = sqlite.connect(":memory:") +        progress_calls = [] +        def progress(): +            progress_calls.append(None) +            return 0 +        con.set_progress_handler(progress, 1) +        con.execute(""" +            create table foo(a, b) +            """) +        self.failUnless(progress_calls) + + +    def CheckOpcodeCount(self): +        """ +        Test that the opcode argument is respected. +        """ +        con = sqlite.connect(":memory:") +        progress_calls = [] +        def progress(): +            progress_calls.append(None) +            return 0 +        con.set_progress_handler(progress, 1) +        curs = con.cursor() +        curs.execute(""" +            create table foo (a, b) +            """) +        first_count = len(progress_calls) +        progress_calls = [] +        con.set_progress_handler(progress, 2) +        curs.execute(""" +            create table bar (a, b) +            """) +        second_count = len(progress_calls) +        self.failUnless(first_count > second_count) + +    def CheckCancelOperation(self): +        """ +        Test that returning a non-zero value stops the operation in progress. +        """ +        con = sqlite.connect(":memory:") +        progress_calls = [] +        def progress(): +            progress_calls.append(None) +            return 1 +        con.set_progress_handler(progress, 1) +        curs = con.cursor() +        self.assertRaises( +            sqlite.OperationalError, +            curs.execute, +            "create table bar (a, b)") + +    def CheckClearHandler(self): +        """ +        Test that setting the progress handler to None clears the previously set handler. +        """ +        con = sqlite.connect(":memory:") +        action = 0 +        def progress(): +            action = 1 +            return 0 +        con.set_progress_handler(progress, 1) +        con.set_progress_handler(None, 1) +        con.execute("select 1 union select 2 union select 3").fetchall() +        self.failUnlessEqual(action, 0, "progress handler was not cleared") +  def suite():      collation_suite = unittest.makeSuite(CollationTests, "Check") -    return unittest.TestSuite((collation_suite,)) +    progress_suite = unittest.makeSuite(ProgressTests, "Check") +    return unittest.TestSuite((collation_suite, progress_suite))  def test():      runner = unittest.TextTestRunner() diff --git a/Lib/sqlite3/test/py25tests.py b/Lib/sqlite3/test/py25tests.py new file mode 100644 index 0000000..bce26b9 --- /dev/null +++ b/Lib/sqlite3/test/py25tests.py @@ -0,0 +1,80 @@ +#-*- coding: ISO-8859-1 -*- +# pysqlite2/test/regression.py: pysqlite regression tests +# +# Copyright (C) 2007 Gerhard Häring <gh@ghaering.de> +# +# This file is part of pysqlite. +# +# This software is provided 'as-is', without any express or implied +# warranty.  In no event will the authors be held liable for any damages +# arising from the use of this software. +# +# Permission is granted to anyone to use this software for any purpose, +# including commercial applications, and to alter it and redistribute it +# freely, subject to the following restrictions: +# +# 1. The origin of this software must not be misrepresented; you must not +#    claim that you wrote the original software. If you use this software +#    in a product, an acknowledgment in the product documentation would be +#    appreciated but is not required. +# 2. Altered source versions must be plainly marked as such, and must not be +#    misrepresented as being the original software. +# 3. This notice may not be removed or altered from any source distribution. + +from __future__ import with_statement +import unittest +import sqlite3 as sqlite + +did_rollback = False + +class MyConnection(sqlite.Connection): +    def rollback(self): +        global did_rollback +        did_rollback = True +        sqlite.Connection.rollback(self) + +class ContextTests(unittest.TestCase): +    def setUp(self): +        global did_rollback +        self.con = sqlite.connect(":memory:", factory=MyConnection) +        self.con.execute("create table test(c unique)") +        did_rollback = False + +    def tearDown(self): +        self.con.close() + +    def CheckContextManager(self): +        """Can the connection be used as a context manager at all?""" +        with self.con: +            pass + +    def CheckContextManagerCommit(self): +        """Is a commit called in the context manager?""" +        with self.con: +            self.con.execute("insert into test(c) values ('foo')") +        self.con.rollback() +        count = self.con.execute("select count(*) from test").fetchone()[0] +        self.failUnlessEqual(count, 1) + +    def CheckContextManagerRollback(self): +        """Is a rollback called in the context manager?""" +        global did_rollback +        self.failUnlessEqual(did_rollback, False) +        try: +            with self.con: +                self.con.execute("insert into test(c) values (4)") +                self.con.execute("insert into test(c) values (4)") +        except sqlite.IntegrityError: +            pass +        self.failUnlessEqual(did_rollback, True) + +def suite(): +    ctx_suite = unittest.makeSuite(ContextTests, "Check") +    return unittest.TestSuite((ctx_suite,)) + +def test(): +    runner = unittest.TextTestRunner() +    runner.run(suite()) + +if __name__ == "__main__": +    test() diff --git a/Lib/sqlite3/test/regression.py b/Lib/sqlite3/test/regression.py index addedb1..45eae90 100644 --- a/Lib/sqlite3/test/regression.py +++ b/Lib/sqlite3/test/regression.py @@ -1,7 +1,7 @@  #-*- coding: ISO-8859-1 -*-  # pysqlite2/test/regression.py: pysqlite regression tests  # -# Copyright (C) 2006 Gerhard Häring <gh@ghaering.de> +# Copyright (C) 2006-2007 Gerhard Häring <gh@ghaering.de>  #  # This file is part of pysqlite.  # @@ -21,6 +21,7 @@  #    misrepresented as being the original software.  # 3. This notice may not be removed or altered from any source distribution. +import datetime  import unittest  import sqlite3 as sqlite @@ -79,6 +80,79 @@ class RegressionTests(unittest.TestCase):          cur.fetchone()          cur.fetchone() +    def CheckStatementFinalizationOnCloseDb(self): +        # pysqlite versions <= 2.3.3 only finalized statements in the statement +        # cache when closing the database. statements that were still +        # referenced in cursors weren't closed an could provoke " +        # "OperationalError: Unable to close due to unfinalised statements". +        con = sqlite.connect(":memory:") +        cursors = [] +        # default statement cache size is 100 +        for i in range(105): +            cur = con.cursor() +            cursors.append(cur) +            cur.execute("select 1 x union select " + str(i)) +        con.close() + +    def CheckOnConflictRollback(self): +        if sqlite.sqlite_version_info < (3, 2, 2): +            return +        con = sqlite.connect(":memory:") +        con.execute("create table foo(x, unique(x) on conflict rollback)") +        con.execute("insert into foo(x) values (1)") +        try: +            con.execute("insert into foo(x) values (1)") +        except sqlite.DatabaseError: +            pass +        con.execute("insert into foo(x) values (2)") +        try: +            con.commit() +        except sqlite.OperationalError: +            self.fail("pysqlite knew nothing about the implicit ROLLBACK") + +    def CheckWorkaroundForBuggySqliteTransferBindings(self): +        """ +        pysqlite would crash with older SQLite versions unless +        a workaround is implemented. +        """ +        self.con.execute("create table if not exists foo(bar)") +        self.con.execute("create table if not exists foo(bar)") + +    def CheckEmptyStatement(self): +        """ +        pysqlite used to segfault with SQLite versions 3.5.x. These return NULL +        for "no-operation" statements +        """ +        self.con.execute("") + +    def CheckUnicodeConnect(self): +        """ +        With pysqlite 2.4.0 you needed to use a string or a APSW connection +        object for opening database connections. + +        Formerly, both bytestrings and unicode strings used to work. + +        Let's make sure unicode strings work in the future. +        """ +        con = sqlite.connect(u":memory:") +        con.close() + +    def CheckTypeMapUsage(self): +        """ +        pysqlite until 2.4.1 did not rebuild the row_cast_map when recompiling +        a statement. This test exhibits the problem. +        """ +        SELECT = "select * from foo" +        con = sqlite.connect(":memory:",detect_types=sqlite.PARSE_DECLTYPES) +        con.execute("create table foo(bar timestamp)") +        con.execute("insert into foo(bar) values (?)", (datetime.datetime.now(),)) +        con.execute(SELECT) +        con.execute("drop table foo") +        con.execute("create table foo(bar integer)") +        con.execute("insert into foo(bar) values (5)") +        con.execute(SELECT) + +  def suite():      regression_suite = unittest.makeSuite(RegressionTests, "Check")      return unittest.TestSuite((regression_suite,)) diff --git a/Lib/sqlite3/test/transactions.py b/Lib/sqlite3/test/transactions.py index 1f0b19a..14cae25 100644 --- a/Lib/sqlite3/test/transactions.py +++ b/Lib/sqlite3/test/transactions.py @@ -1,7 +1,7 @@  #-*- coding: ISO-8859-1 -*-  # pysqlite2/test/transactions.py: tests transactions  # -# Copyright (C) 2005 Gerhard Häring <gh@ghaering.de> +# Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de>  #  # This file is part of pysqlite.  # @@ -21,6 +21,7 @@  #    misrepresented as being the original software.  # 3. This notice may not be removed or altered from any source distribution. +import sys  import os, unittest  import sqlite3 as sqlite @@ -119,6 +120,23 @@ class TransactionTests(unittest.TestCase):          except:              self.fail("should have raised an OperationalError") +    def CheckLocking(self): +        """ +        This tests the improved concurrency with pysqlite 2.3.4. You needed +        to roll back con2 before you could commit con1. +        """ +        self.cur1.execute("create table test(i)") +        self.cur1.execute("insert into test(i) values (5)") +        try: +            self.cur2.execute("insert into test(i) values (5)") +            self.fail("should have raised an OperationalError") +        except sqlite.OperationalError: +            pass +        except: +            self.fail("should have raised an OperationalError") +        # NO self.con2.rollback() HERE!!! +        self.con1.commit() +  class SpecialCommandTests(unittest.TestCase):      def setUp(self):          self.con = sqlite.connect(":memory:") diff --git a/Lib/sqlite3/test/types.py b/Lib/sqlite3/test/types.py index 3cc9aff..1970401 100644 --- a/Lib/sqlite3/test/types.py +++ b/Lib/sqlite3/test/types.py @@ -1,7 +1,7 @@  #-*- coding: ISO-8859-1 -*-  # pysqlite2/test/types.py: tests for type conversion and detection  # -# Copyright (C) 2005 Gerhard Häring <gh@ghaering.de> +# Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de>  #  # This file is part of pysqlite.  # @@ -21,7 +21,7 @@  #    misrepresented as being the original software.  # 3. This notice may not be removed or altered from any source distribution. -import bz2, datetime +import zlib, datetime  import unittest  import sqlite3 as sqlite @@ -287,7 +287,7 @@ class ObjectAdaptationTests(unittest.TestCase):  class BinaryConverterTests(unittest.TestCase):      def convert(s): -        return bz2.decompress(s) +        return zlib.decompress(s)      convert = staticmethod(convert)      def setUp(self): @@ -299,7 +299,7 @@ class BinaryConverterTests(unittest.TestCase):      def CheckBinaryInputForConverter(self):          testdata = "abcdefg" * 10 -        result = self.con.execute('select ? as "x [bin]"', (buffer(bz2.compress(testdata)),)).fetchone()[0] +        result = self.con.execute('select ? as "x [bin]"', (buffer(zlib.compress(testdata)),)).fetchone()[0]          self.failUnlessEqual(testdata, result)  class DateTimeTests(unittest.TestCase): @@ -331,7 +331,8 @@ class DateTimeTests(unittest.TestCase):          if sqlite.sqlite_version_info < (3, 1):              return -        now = datetime.datetime.utcnow() +        # SQLite's current_timestamp uses UTC time, while datetime.datetime.now() uses local time. +        now = datetime.datetime.now()          self.cur.execute("insert into test(ts) values (current_timestamp)")          self.cur.execute("select ts from test")          ts = self.cur.fetchone()[0] diff --git a/Lib/test/test_sqlite.py b/Lib/test/test_sqlite.py index c1523e1..3566f31 100644 --- a/Lib/test/test_sqlite.py +++ b/Lib/test/test_sqlite.py @@ -4,13 +4,13 @@ try:      import _sqlite3  except ImportError:      raise TestSkipped('no sqlite available') -from sqlite3.test import (dbapi, types, userfunctions, +from sqlite3.test import (dbapi, types, userfunctions, py25tests,                                  factory, transactions, hooks, regression)  def test_main():      run_unittest(dbapi.suite(), types.suite(), userfunctions.suite(), -                 factory.suite(), transactions.suite(), hooks.suite(), -                 regression.suite()) +                 py25tests.suite(), factory.suite(), transactions.suite(), +                 hooks.suite(), regression.suite())  if __name__ == "__main__":      test_main() diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c index f65748a..1ce275c 100644 --- a/Modules/_sqlite/connection.c +++ b/Modules/_sqlite/connection.c @@ -1,6 +1,6 @@  /* connection.c - the connection type   * - * Copyright (C) 2004-2006 Gerhard Häring <gh@ghaering.de> + * Copyright (C) 2004-2007 Gerhard Häring <gh@ghaering.de>   *   * This file is part of pysqlite.   *  @@ -32,6 +32,9 @@  #include "pythread.h" +#define ACTION_FINALIZE 1 +#define ACTION_RESET 2 +  static int pysqlite_connection_set_isolation_level(pysqlite_Connection* self, PyObject* isolation_level); @@ -51,7 +54,7 @@ int pysqlite_connection_init(pysqlite_Connection* self, PyObject* args, PyObject  {      static char *kwlist[] = {"database", "timeout", "detect_types", "isolation_level", "check_same_thread", "factory", "cached_statements", NULL, NULL}; -    char* database; +    PyObject* database;      int detect_types = 0;      PyObject* isolation_level = NULL;      PyObject* factory = NULL; @@ -59,11 +62,15 @@ int pysqlite_connection_init(pysqlite_Connection* self, PyObject* args, PyObject      int cached_statements = 100;      double timeout = 5.0;      int rc; +    PyObject* class_attr = NULL; +    PyObject* class_attr_str = NULL; +    int is_apsw_connection = 0; +    PyObject* database_utf8; -    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s|diOiOi", kwlist, +    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|diOiOi", kwlist,                                       &database, &timeout, &detect_types, &isolation_level, &check_same_thread, &factory, &cached_statements))      { -        return -1;  +        return -1;      }      self->begin_statement = NULL; @@ -77,13 +84,53 @@ int pysqlite_connection_init(pysqlite_Connection* self, PyObject* args, PyObject      Py_INCREF(&PyUnicode_Type);      self->text_factory = (PyObject*)&PyUnicode_Type; -    Py_BEGIN_ALLOW_THREADS -    rc = sqlite3_open(database, &self->db); -    Py_END_ALLOW_THREADS +    if (PyString_Check(database) || PyUnicode_Check(database)) { +        if (PyString_Check(database)) { +            database_utf8 = database; +            Py_INCREF(database_utf8); +        } else { +            database_utf8 = PyUnicode_AsUTF8String(database); +            if (!database_utf8) { +                return -1; +            } +        } -    if (rc != SQLITE_OK) { -        _pysqlite_seterror(self->db); -        return -1; +        Py_BEGIN_ALLOW_THREADS +        rc = sqlite3_open(PyString_AsString(database_utf8), &self->db); +        Py_END_ALLOW_THREADS + +        Py_DECREF(database_utf8); + +        if (rc != SQLITE_OK) { +            _pysqlite_seterror(self->db, NULL); +            return -1; +        } +    } else { +        /* Create a pysqlite connection from a APSW connection */ +        class_attr = PyObject_GetAttrString(database, "__class__"); +        if (class_attr) { +            class_attr_str = PyObject_Str(class_attr); +            if (class_attr_str) { +                if (strcmp(PyString_AsString(class_attr_str), "<type 'apsw.Connection'>") == 0) { +                    /* In the APSW Connection object, the first entry after +                     * PyObject_HEAD is the sqlite3* we want to get hold of. +                     * Luckily, this is the same layout as we have in our +                     * pysqlite_Connection */ +                    self->db = ((pysqlite_Connection*)database)->db; + +                    Py_INCREF(database); +                    self->apsw_connection = database; +                    is_apsw_connection = 1; +                } +            } +        } +        Py_XDECREF(class_attr_str); +        Py_XDECREF(class_attr); + +        if (!is_apsw_connection) { +            PyErr_SetString(PyExc_ValueError, "database parameter must be string or APSW Connection object"); +            return -1; +        }      }      if (!isolation_level) { @@ -169,7 +216,8 @@ void pysqlite_flush_statement_cache(pysqlite_Connection* self)      self->statement_cache->decref_factory = 0;  } -void pysqlite_reset_all_statements(pysqlite_Connection* self) +/* action in (ACTION_RESET, ACTION_FINALIZE) */ +void pysqlite_do_all_statements(pysqlite_Connection* self, int action)  {      int i;      PyObject* weakref; @@ -179,13 +227,19 @@ void pysqlite_reset_all_statements(pysqlite_Connection* self)          weakref = PyList_GetItem(self->statements, i);          statement = PyWeakref_GetObject(weakref);          if (statement != Py_None) { -            (void)pysqlite_statement_reset((pysqlite_Statement*)statement); +            if (action == ACTION_RESET) { +                (void)pysqlite_statement_reset((pysqlite_Statement*)statement); +            } else { +                (void)pysqlite_statement_finalize((pysqlite_Statement*)statement); +            }          }      }  }  void pysqlite_connection_dealloc(pysqlite_Connection* self)  { +    PyObject* ret = NULL; +      Py_XDECREF(self->statement_cache);      /* Clean up if user has not called .close() explicitly. */ @@ -193,6 +247,10 @@ void pysqlite_connection_dealloc(pysqlite_Connection* self)          Py_BEGIN_ALLOW_THREADS          sqlite3_close(self->db);          Py_END_ALLOW_THREADS +    } else if (self->apsw_connection) { +        ret = PyObject_CallMethod(self->apsw_connection, "close", ""); +        Py_XDECREF(ret); +        Py_XDECREF(self->apsw_connection);      }      if (self->begin_statement) { @@ -205,7 +263,7 @@ void pysqlite_connection_dealloc(pysqlite_Connection* self)      Py_XDECREF(self->collations);      Py_XDECREF(self->statements); -    Py_TYPE(self)->tp_free((PyObject*)self); +    self->ob_type->tp_free((PyObject*)self);  }  PyObject* pysqlite_connection_cursor(pysqlite_Connection* self, PyObject* args, PyObject* kwargs) @@ -241,24 +299,33 @@ PyObject* pysqlite_connection_cursor(pysqlite_Connection* self, PyObject* args,  PyObject* pysqlite_connection_close(pysqlite_Connection* self, PyObject* args)  { +    PyObject* ret;      int rc;      if (!pysqlite_check_thread(self)) {          return NULL;      } -    pysqlite_flush_statement_cache(self); +    pysqlite_do_all_statements(self, ACTION_FINALIZE);      if (self->db) { -        Py_BEGIN_ALLOW_THREADS -        rc = sqlite3_close(self->db); -        Py_END_ALLOW_THREADS - -        if (rc != SQLITE_OK) { -            _pysqlite_seterror(self->db); -            return NULL; -        } else { +        if (self->apsw_connection) { +            ret = PyObject_CallMethod(self->apsw_connection, "close", ""); +            Py_XDECREF(ret); +            Py_XDECREF(self->apsw_connection); +            self->apsw_connection = NULL;              self->db = NULL; +        } else { +            Py_BEGIN_ALLOW_THREADS +            rc = sqlite3_close(self->db); +            Py_END_ALLOW_THREADS + +            if (rc != SQLITE_OK) { +                _pysqlite_seterror(self->db, NULL); +                return NULL; +            } else { +                self->db = NULL; +            }          }      } @@ -292,7 +359,7 @@ PyObject* _pysqlite_connection_begin(pysqlite_Connection* self)      Py_END_ALLOW_THREADS      if (rc != SQLITE_OK) { -        _pysqlite_seterror(self->db); +        _pysqlite_seterror(self->db, statement);          goto error;      } @@ -300,7 +367,7 @@ PyObject* _pysqlite_connection_begin(pysqlite_Connection* self)      if (rc == SQLITE_DONE) {          self->inTransaction = 1;      } else { -        _pysqlite_seterror(self->db); +        _pysqlite_seterror(self->db, statement);      }      Py_BEGIN_ALLOW_THREADS @@ -308,7 +375,7 @@ PyObject* _pysqlite_connection_begin(pysqlite_Connection* self)      Py_END_ALLOW_THREADS      if (rc != SQLITE_OK && !PyErr_Occurred()) { -        _pysqlite_seterror(self->db); +        _pysqlite_seterror(self->db, NULL);      }  error: @@ -335,7 +402,7 @@ PyObject* pysqlite_connection_commit(pysqlite_Connection* self, PyObject* args)          rc = sqlite3_prepare(self->db, "COMMIT", -1, &statement, &tail);          Py_END_ALLOW_THREADS          if (rc != SQLITE_OK) { -            _pysqlite_seterror(self->db); +            _pysqlite_seterror(self->db, NULL);              goto error;          } @@ -343,14 +410,14 @@ PyObject* pysqlite_connection_commit(pysqlite_Connection* self, PyObject* args)          if (rc == SQLITE_DONE) {              self->inTransaction = 0;          } else { -            _pysqlite_seterror(self->db); +            _pysqlite_seterror(self->db, statement);          }          Py_BEGIN_ALLOW_THREADS          rc = sqlite3_finalize(statement);          Py_END_ALLOW_THREADS          if (rc != SQLITE_OK && !PyErr_Occurred()) { -            _pysqlite_seterror(self->db); +            _pysqlite_seterror(self->db, NULL);          }      } @@ -375,13 +442,13 @@ PyObject* pysqlite_connection_rollback(pysqlite_Connection* self, PyObject* args      }      if (self->inTransaction) { -        pysqlite_reset_all_statements(self); +        pysqlite_do_all_statements(self, ACTION_RESET);          Py_BEGIN_ALLOW_THREADS          rc = sqlite3_prepare(self->db, "ROLLBACK", -1, &statement, &tail);          Py_END_ALLOW_THREADS          if (rc != SQLITE_OK) { -            _pysqlite_seterror(self->db); +            _pysqlite_seterror(self->db, NULL);              goto error;          } @@ -389,14 +456,14 @@ PyObject* pysqlite_connection_rollback(pysqlite_Connection* self, PyObject* args          if (rc == SQLITE_DONE) {              self->inTransaction = 0;          } else { -            _pysqlite_seterror(self->db); +            _pysqlite_seterror(self->db, statement);          }          Py_BEGIN_ALLOW_THREADS          rc = sqlite3_finalize(statement);          Py_END_ALLOW_THREADS          if (rc != SQLITE_OK && !PyErr_Occurred()) { -            _pysqlite_seterror(self->db); +            _pysqlite_seterror(self->db, NULL);          }      } @@ -762,6 +829,33 @@ static int _authorizer_callback(void* user_arg, int action, const char* arg1, co      return rc;  } +static int _progress_handler(void* user_arg) +{ +    int rc; +    PyObject *ret; +    PyGILState_STATE gilstate; + +    gilstate = PyGILState_Ensure(); +    ret = PyObject_CallFunction((PyObject*)user_arg, ""); + +    if (!ret) { +        if (_enable_callback_tracebacks) { +            PyErr_Print(); +        } else { +            PyErr_Clear(); +        } + +        /* abort query if error occured */ +        rc = 1;  +    } else { +        rc = (int)PyObject_IsTrue(ret); +    } + +    Py_DECREF(ret); +    PyGILState_Release(gilstate); +    return rc; +} +  PyObject* pysqlite_connection_set_authorizer(pysqlite_Connection* self, PyObject* args, PyObject* kwargs)  {      PyObject* authorizer_cb; @@ -787,6 +881,30 @@ PyObject* pysqlite_connection_set_authorizer(pysqlite_Connection* self, PyObject      }  } +PyObject* pysqlite_connection_set_progress_handler(pysqlite_Connection* self, PyObject* args, PyObject* kwargs) +{ +    PyObject* progress_handler; +    int n; + +    static char *kwlist[] = { "progress_handler", "n", NULL }; + +    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "Oi:set_progress_handler", +                                      kwlist, &progress_handler, &n)) { +        return NULL; +    } + +    if (progress_handler == Py_None) { +        /* None clears the progress handler previously set */ +        sqlite3_progress_handler(self->db, 0, 0, (void*)0); +    } else { +        sqlite3_progress_handler(self->db, n, _progress_handler, progress_handler); +        PyDict_SetItem(self->function_pinboard, progress_handler, Py_None); +    } + +    Py_INCREF(Py_None); +    return Py_None; +} +  int pysqlite_check_thread(pysqlite_Connection* self)  {      if (self->check_same_thread) { @@ -892,7 +1010,8 @@ PyObject* pysqlite_connection_call(pysqlite_Connection* self, PyObject* args, Py          } else if (rc == PYSQLITE_SQL_WRONG_TYPE) {              PyErr_SetString(pysqlite_Warning, "SQL is of wrong type. Must be string or unicode.");          } else { -            _pysqlite_seterror(self->db); +            (void)pysqlite_statement_reset(statement); +            _pysqlite_seterror(self->db, NULL);          }          Py_DECREF(statement); @@ -1134,7 +1253,7 @@ pysqlite_connection_create_collation(pysqlite_Connection* self, PyObject* args)                                    (callable != Py_None) ? pysqlite_collation_callback : NULL);      if (rc != SQLITE_OK) {          PyDict_DelItem(self->collations, uppercase_name); -        _pysqlite_seterror(self->db); +        _pysqlite_seterror(self->db, NULL);          goto finally;      } @@ -1151,6 +1270,44 @@ finally:      return retval;  } +/* Called when the connection is used as a context manager. Returns itself as a + * convenience to the caller. */ +static PyObject * +pysqlite_connection_enter(pysqlite_Connection* self, PyObject* args) +{ +    Py_INCREF(self); +    return (PyObject*)self; +} + +/** Called when the connection is used as a context manager. If there was any + * exception, a rollback takes place; otherwise we commit. */ +static PyObject * +pysqlite_connection_exit(pysqlite_Connection* self, PyObject* args) +{ +    PyObject* exc_type, *exc_value, *exc_tb; +    char* method_name; +    PyObject* result; + +    if (!PyArg_ParseTuple(args, "OOO", &exc_type, &exc_value, &exc_tb)) { +        return NULL; +    } + +    if (exc_type == Py_None && exc_value == Py_None && exc_tb == Py_None) { +        method_name = "commit"; +    } else { +        method_name = "rollback"; +    } + +    result = PyObject_CallMethod((PyObject*)self, method_name, ""); +    if (!result) { +        return NULL; +    } +    Py_DECREF(result); + +    Py_INCREF(Py_False); +    return Py_False; +} +  static char connection_doc[] =  PyDoc_STR("SQLite database connection object."); @@ -1175,6 +1332,8 @@ static PyMethodDef connection_methods[] = {          PyDoc_STR("Creates a new aggregate. Non-standard.")},      {"set_authorizer", (PyCFunction)pysqlite_connection_set_authorizer, METH_VARARGS|METH_KEYWORDS,          PyDoc_STR("Sets authorizer callback. Non-standard.")}, +    {"set_progress_handler", (PyCFunction)pysqlite_connection_set_progress_handler, METH_VARARGS|METH_KEYWORDS, +        PyDoc_STR("Sets progress handler callback. Non-standard.")},      {"execute", (PyCFunction)pysqlite_connection_execute, METH_VARARGS,          PyDoc_STR("Executes a SQL statement. Non-standard.")},      {"executemany", (PyCFunction)pysqlite_connection_executemany, METH_VARARGS, @@ -1185,6 +1344,10 @@ static PyMethodDef connection_methods[] = {          PyDoc_STR("Creates a collation function. Non-standard.")},      {"interrupt", (PyCFunction)pysqlite_connection_interrupt, METH_NOARGS,          PyDoc_STR("Abort any pending database operation. Non-standard.")}, +    {"__enter__", (PyCFunction)pysqlite_connection_enter, METH_NOARGS, +        PyDoc_STR("For context manager. Non-standard.")}, +    {"__exit__", (PyCFunction)pysqlite_connection_exit, METH_VARARGS, +        PyDoc_STR("For context manager. Non-standard.")},      {NULL, NULL}  }; diff --git a/Modules/_sqlite/connection.h b/Modules/_sqlite/connection.h index 21fcd2a..3b1c632 100644 --- a/Modules/_sqlite/connection.h +++ b/Modules/_sqlite/connection.h @@ -1,6 +1,6 @@  /* connection.h - definitions for the connection type   * - * Copyright (C) 2004-2006 Gerhard Häring <gh@ghaering.de> + * Copyright (C) 2004-2007 Gerhard Häring <gh@ghaering.de>   *   * This file is part of pysqlite.   * @@ -95,6 +95,11 @@ typedef struct      /* a dictionary of registered collation name => collation callable mappings */      PyObject* collations; +    /* if our connection was created from a APSW connection, we keep a +     * reference to the APSW connection around and get rid of it in our +     * destructor */ +    PyObject* apsw_connection; +      /* Exception objects */      PyObject* Warning;      PyObject* Error; diff --git a/Modules/_sqlite/cursor.c b/Modules/_sqlite/cursor.c index 875d55b..566e4ff 100644 --- a/Modules/_sqlite/cursor.c +++ b/Modules/_sqlite/cursor.c @@ -1,6 +1,6 @@  /* cursor.c - the cursor type   * - * Copyright (C) 2004-2006 Gerhard Häring <gh@ghaering.de> + * Copyright (C) 2004-2007 Gerhard Häring <gh@ghaering.de>   *   * This file is part of pysqlite.   * @@ -80,7 +80,7 @@ int pysqlite_cursor_init(pysqlite_Cursor* self, PyObject* args, PyObject* kwargs      if (!PyArg_ParseTuple(args, "O!", &pysqlite_ConnectionType, &connection))      { -        return -1;  +        return -1;      }      Py_INCREF(connection); @@ -435,7 +435,7 @@ PyObject* _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject*      if (multiple) {          /* executemany() */          if (!PyArg_ParseTuple(args, "OO", &operation, &second_argument)) { -            return NULL;  +            return NULL;          }          if (!PyString_Check(operation) && !PyUnicode_Check(operation)) { @@ -457,7 +457,7 @@ PyObject* _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject*      } else {          /* execute() */          if (!PyArg_ParseTuple(args, "O|O", &operation, &second_argument)) { -            return NULL;  +            return NULL;          }          if (!PyString_Check(operation) && !PyUnicode_Check(operation)) { @@ -506,16 +506,47 @@ PyObject* _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject*          operation_cstr = PyString_AsString(operation_bytestr);      } -    /* reset description and rowcount */ +    /* reset description */      Py_DECREF(self->description);      Py_INCREF(Py_None);      self->description = Py_None; -    Py_DECREF(self->rowcount); -    self->rowcount = PyInt_FromLong(-1L); -    if (!self->rowcount) { +    func_args = PyTuple_New(1); +    if (!func_args) {          goto error;      } +    Py_INCREF(operation); +    if (PyTuple_SetItem(func_args, 0, operation) != 0) { +        goto error; +    } + +    if (self->statement) { +        (void)pysqlite_statement_reset(self->statement); +        Py_DECREF(self->statement); +    } + +    self->statement = (pysqlite_Statement*)pysqlite_cache_get(self->connection->statement_cache, func_args); +    Py_DECREF(func_args); + +    if (!self->statement) { +        goto error; +    } + +    if (self->statement->in_use) { +        Py_DECREF(self->statement); +        self->statement = PyObject_New(pysqlite_Statement, &pysqlite_StatementType); +        if (!self->statement) { +            goto error; +        } +        rc = pysqlite_statement_create(self->statement, self->connection, operation); +        if (rc != SQLITE_OK) { +            self->statement = 0; +            goto error; +        } +    } + +    pysqlite_statement_reset(self->statement); +    pysqlite_statement_mark_dirty(self->statement);      statement_type = detect_statement_type(operation_cstr);      if (self->connection->begin_statement) { @@ -553,43 +584,6 @@ PyObject* _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject*          }      } -    func_args = PyTuple_New(1); -    if (!func_args) { -        goto error; -    } -    Py_INCREF(operation); -    if (PyTuple_SetItem(func_args, 0, operation) != 0) { -        goto error; -    } - -    if (self->statement) { -        (void)pysqlite_statement_reset(self->statement); -        Py_DECREF(self->statement); -    } - -    self->statement = (pysqlite_Statement*)pysqlite_cache_get(self->connection->statement_cache, func_args); -    Py_DECREF(func_args); - -    if (!self->statement) { -        goto error; -    } - -    if (self->statement->in_use) { -        Py_DECREF(self->statement); -        self->statement = PyObject_New(pysqlite_Statement, &pysqlite_StatementType); -        if (!self->statement) { -            goto error; -        } -        rc = pysqlite_statement_create(self->statement, self->connection, operation); -        if (rc != SQLITE_OK) { -            self->statement = 0; -            goto error; -        } -    } - -    pysqlite_statement_reset(self->statement); -    pysqlite_statement_mark_dirty(self->statement); -      while (1) {          parameters = PyIter_Next(parameters_iter);          if (!parameters) { @@ -603,11 +597,6 @@ PyObject* _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject*              goto error;          } -        if (pysqlite_build_row_cast_map(self) != 0) { -            PyErr_SetString(pysqlite_OperationalError, "Error while building row_cast_map"); -            goto error; -        } -          /* Keep trying the SQL statement until the schema stops changing. */          while (1) {              /* Actually execute the SQL statement. */ @@ -626,7 +615,8 @@ PyObject* _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject*                      continue;                  } else {                      /* If the database gave us an error, promote it to Python. */ -                    _pysqlite_seterror(self->connection->db); +                    (void)pysqlite_statement_reset(self->statement); +                    _pysqlite_seterror(self->connection->db, NULL);                      goto error;                  }              } else { @@ -638,17 +628,23 @@ PyObject* _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject*                          PyErr_Clear();                      }                  } -                _pysqlite_seterror(self->connection->db); +                (void)pysqlite_statement_reset(self->statement); +                _pysqlite_seterror(self->connection->db, NULL);                  goto error;              }          } -        if (rc == SQLITE_ROW || (rc == SQLITE_DONE && statement_type == STATEMENT_SELECT)) { -            Py_BEGIN_ALLOW_THREADS -            numcols = sqlite3_column_count(self->statement->st); -            Py_END_ALLOW_THREADS +        if (pysqlite_build_row_cast_map(self) != 0) { +            PyErr_SetString(pysqlite_OperationalError, "Error while building row_cast_map"); +            goto error; +        } +        if (rc == SQLITE_ROW || (rc == SQLITE_DONE && statement_type == STATEMENT_SELECT)) {              if (self->description == Py_None) { +                Py_BEGIN_ALLOW_THREADS +                numcols = sqlite3_column_count(self->statement->st); +                Py_END_ALLOW_THREADS +                  Py_DECREF(self->description);                  self->description = PyTuple_New(numcols);                  if (!self->description) { @@ -689,15 +685,11 @@ PyObject* _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject*              case STATEMENT_DELETE:              case STATEMENT_INSERT:              case STATEMENT_REPLACE: -                Py_BEGIN_ALLOW_THREADS                  rowcount += (long)sqlite3_changes(self->connection->db); -                Py_END_ALLOW_THREADS -                Py_DECREF(self->rowcount); -                self->rowcount = PyInt_FromLong(rowcount);          }          Py_DECREF(self->lastrowid); -        if (statement_type == STATEMENT_INSERT) { +        if (!multiple && statement_type == STATEMENT_INSERT) {              Py_BEGIN_ALLOW_THREADS              lastrowid = sqlite3_last_insert_rowid(self->connection->db);              Py_END_ALLOW_THREADS @@ -714,14 +706,27 @@ PyObject* _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject*      }  error: +    /* just to be sure (implicit ROLLBACKs with ON CONFLICT ROLLBACK/OR +     * ROLLBACK could have happened */ +    #ifdef SQLITE_VERSION_NUMBER +    #if SQLITE_VERSION_NUMBER >= 3002002 +    self->connection->inTransaction = !sqlite3_get_autocommit(self->connection->db); +    #endif +    #endif +      Py_XDECREF(operation_bytestr);      Py_XDECREF(parameters);      Py_XDECREF(parameters_iter);      Py_XDECREF(parameters_list);      if (PyErr_Occurred()) { +        Py_DECREF(self->rowcount); +        self->rowcount = PyInt_FromLong(-1L);          return NULL;      } else { +        Py_DECREF(self->rowcount); +        self->rowcount = PyInt_FromLong(rowcount); +          Py_INCREF(self);          return (PyObject*)self;      } @@ -748,7 +753,7 @@ PyObject* pysqlite_cursor_executescript(pysqlite_Cursor* self, PyObject* args)      int statement_completed = 0;      if (!PyArg_ParseTuple(args, "O", &script_obj)) { -        return NULL;  +        return NULL;      }      if (!pysqlite_check_thread(self->connection) || !pysqlite_check_connection(self->connection)) { @@ -788,7 +793,7 @@ PyObject* pysqlite_cursor_executescript(pysqlite_Cursor* self, PyObject* args)                               &statement,                               &script_cstr);          if (rc != SQLITE_OK) { -            _pysqlite_seterror(self->connection->db); +            _pysqlite_seterror(self->connection->db, NULL);              goto error;          } @@ -796,17 +801,18 @@ PyObject* pysqlite_cursor_executescript(pysqlite_Cursor* self, PyObject* args)          rc = SQLITE_ROW;          while (rc == SQLITE_ROW) {              rc = _sqlite_step_with_busyhandler(statement, self->connection); +            /* TODO: we probably need more error handling here */          }          if (rc != SQLITE_DONE) {              (void)sqlite3_finalize(statement); -            _pysqlite_seterror(self->connection->db); +            _pysqlite_seterror(self->connection->db, NULL);              goto error;          }          rc = sqlite3_finalize(statement);          if (rc != SQLITE_OK) { -            _pysqlite_seterror(self->connection->db); +            _pysqlite_seterror(self->connection->db, NULL);              goto error;          }      } @@ -864,8 +870,9 @@ PyObject* pysqlite_cursor_iternext(pysqlite_Cursor *self)      if (self->statement) {          rc = _sqlite_step_with_busyhandler(self->statement->st, self->connection);          if (rc != SQLITE_DONE && rc != SQLITE_ROW) { +            (void)pysqlite_statement_reset(self->statement);              Py_DECREF(next_row); -            _pysqlite_seterror(self->connection->db); +            _pysqlite_seterror(self->connection->db, NULL);              return NULL;          } @@ -890,15 +897,17 @@ PyObject* pysqlite_cursor_fetchone(pysqlite_Cursor* self, PyObject* args)      return row;  } -PyObject* pysqlite_cursor_fetchmany(pysqlite_Cursor* self, PyObject* args) +PyObject* pysqlite_cursor_fetchmany(pysqlite_Cursor* self, PyObject* args, PyObject* kwargs)  { +    static char *kwlist[] = {"size", NULL, NULL}; +      PyObject* row;      PyObject* list;      int maxrows = self->arraysize;      int counter = 0; -    if (!PyArg_ParseTuple(args, "|i", &maxrows)) { -        return NULL;  +    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|i:fetchmany", kwlist, &maxrows)) { +        return NULL;      }      list = PyList_New(0); @@ -992,7 +1001,7 @@ static PyMethodDef cursor_methods[] = {          PyDoc_STR("Executes a multiple SQL statements at once. Non-standard.")},      {"fetchone", (PyCFunction)pysqlite_cursor_fetchone, METH_NOARGS,          PyDoc_STR("Fetches one row from the resultset.")}, -    {"fetchmany", (PyCFunction)pysqlite_cursor_fetchmany, METH_VARARGS, +    {"fetchmany", (PyCFunction)pysqlite_cursor_fetchmany, METH_VARARGS|METH_KEYWORDS,          PyDoc_STR("Fetches several rows from the resultset.")},      {"fetchall", (PyCFunction)pysqlite_cursor_fetchall, METH_NOARGS,          PyDoc_STR("Fetches all rows from the resultset.")}, diff --git a/Modules/_sqlite/cursor.h b/Modules/_sqlite/cursor.h index 5fce64a..d916ca5 100644 --- a/Modules/_sqlite/cursor.h +++ b/Modules/_sqlite/cursor.h @@ -1,6 +1,6 @@  /* cursor.h - definitions for the cursor type   * - * Copyright (C) 2004-2006 Gerhard Häring <gh@ghaering.de> + * Copyright (C) 2004-2007 Gerhard Häring <gh@ghaering.de>   *   * This file is part of pysqlite.   * @@ -60,7 +60,7 @@ PyObject* pysqlite_cursor_executemany(pysqlite_Cursor* self, PyObject* args);  PyObject* pysqlite_cursor_getiter(pysqlite_Cursor *self);  PyObject* pysqlite_cursor_iternext(pysqlite_Cursor *self);  PyObject* pysqlite_cursor_fetchone(pysqlite_Cursor* self, PyObject* args); -PyObject* pysqlite_cursor_fetchmany(pysqlite_Cursor* self, PyObject* args); +PyObject* pysqlite_cursor_fetchmany(pysqlite_Cursor* self, PyObject* args, PyObject* kwargs);  PyObject* pysqlite_cursor_fetchall(pysqlite_Cursor* self, PyObject* args);  PyObject* pysqlite_noop(pysqlite_Connection* self, PyObject* args);  PyObject* pysqlite_cursor_close(pysqlite_Cursor* self, PyObject* args); diff --git a/Modules/_sqlite/microprotocols.h b/Modules/_sqlite/microprotocols.h index d84ec93..c911c81 100644 --- a/Modules/_sqlite/microprotocols.h +++ b/Modules/_sqlite/microprotocols.h @@ -28,10 +28,6 @@  #include <Python.h> -#ifdef __cplusplus -extern "C" { -#endif -  /** adapters registry **/  extern PyObject *psyco_adapters; diff --git a/Modules/_sqlite/module.c b/Modules/_sqlite/module.c index 8844d81..af7eace 100644 --- a/Modules/_sqlite/module.c +++ b/Modules/_sqlite/module.c @@ -1,25 +1,25 @@ -    /* module.c - the module itself -     * -     * Copyright (C) 2004-2006 Gerhard Häring <gh@ghaering.de> -     * -     * This file is part of pysqlite. -     * -     * This software is provided 'as-is', without any express or implied -     * warranty.  In no event will the authors be held liable for any damages -     * arising from the use of this software. -     * -     * Permission is granted to anyone to use this software for any purpose, -     * including commercial applications, and to alter it and redistribute it -     * freely, subject to the following restrictions: -     * -     * 1. The origin of this software must not be misrepresented; you must not -     *    claim that you wrote the original software. If you use this software -     *    in a product, an acknowledgment in the product documentation would be -     *    appreciated but is not required. -     * 2. Altered source versions must be plainly marked as such, and must not be -     *    misrepresented as being the original software. -     * 3. This notice may not be removed or altered from any source distribution. -     */ +/* module.c - the module itself + * + * Copyright (C) 2004-2007 Gerhard Häring <gh@ghaering.de> + * + * This file is part of pysqlite. + * + * This software is provided 'as-is', without any express or implied + * warranty.  In no event will the authors be held liable for any damages + * arising from the use of this software. + * + * Permission is granted to anyone to use this software for any purpose, + * including commercial applications, and to alter it and redistribute it + * freely, subject to the following restrictions: + * + * 1. The origin of this software must not be misrepresented; you must not + *    claim that you wrote the original software. If you use this software + *    in a product, an acknowledgment in the product documentation would be + *    appreciated but is not required. + * 2. Altered source versions must be plainly marked as such, and must not be + *    misrepresented as being the original software. + * 3. This notice may not be removed or altered from any source distribution. + */  #include "connection.h"  #include "statement.h" @@ -41,6 +41,7 @@ PyObject* pysqlite_Error, *pysqlite_Warning, *pysqlite_InterfaceError, *pysqlite  PyObject* converters;  int _enable_callback_tracebacks; +int pysqlite_BaseTypeAdapted;  static PyObject* module_connect(PyObject* self, PyObject* args, PyObject*          kwargs) @@ -50,7 +51,7 @@ static PyObject* module_connect(PyObject* self, PyObject* args, PyObject*       * connection.c and must always be copied from there ... */      static char *kwlist[] = {"database", "timeout", "detect_types", "isolation_level", "check_same_thread", "factory", "cached_statements", NULL, NULL}; -    char* database; +    PyObject* database;      int detect_types = 0;      PyObject* isolation_level;      PyObject* factory = NULL; @@ -60,7 +61,7 @@ static PyObject* module_connect(PyObject* self, PyObject* args, PyObject*      PyObject* result; -    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s|diOiOi", kwlist, +    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|diOiOi", kwlist,                                       &database, &timeout, &detect_types, &isolation_level, &check_same_thread, &factory, &cached_statements))      {          return NULL;  @@ -133,6 +134,13 @@ static PyObject* module_register_adapter(PyObject* self, PyObject* args, PyObjec          return NULL;      } +    /* a basic type is adapted; there's a performance optimization if that's not the case +     * (99 % of all usages) */ +    if (type == &PyInt_Type || type == &PyLong_Type || type == &PyFloat_Type +            || type == &PyString_Type || type == &PyUnicode_Type || type == &PyBuffer_Type) { +        pysqlite_BaseTypeAdapted = 1; +    } +      microprotocols_add(type, (PyObject*)&pysqlite_PrepareProtocolType, caster);      Py_INCREF(Py_None); @@ -379,6 +387,8 @@ PyMODINIT_FUNC init_sqlite3(void)      _enable_callback_tracebacks = 0; +    pysqlite_BaseTypeAdapted = 0; +      /* Original comment form _bsddb.c in the Python core. This is also still       * needed nowadays for Python 2.3/2.4.       *  diff --git a/Modules/_sqlite/module.h b/Modules/_sqlite/module.h index ada6b4c..b14be2a 100644 --- a/Modules/_sqlite/module.h +++ b/Modules/_sqlite/module.h @@ -1,6 +1,6 @@  /* module.h - definitions for the module   * - * Copyright (C) 2004-2006 Gerhard Häring <gh@ghaering.de> + * Copyright (C) 2004-2007 Gerhard Häring <gh@ghaering.de>   *   * This file is part of pysqlite.   * @@ -25,7 +25,7 @@  #define PYSQLITE_MODULE_H  #include "Python.h" -#define PYSQLITE_VERSION "2.3.3" +#define PYSQLITE_VERSION "2.4.1"  extern PyObject* pysqlite_Error;  extern PyObject* pysqlite_Warning; @@ -51,6 +51,7 @@ extern PyObject* time_sleep;  extern PyObject* converters;  extern int _enable_callback_tracebacks; +extern int pysqlite_BaseTypeAdapted;  #define PARSE_DECLTYPES 1  #define PARSE_COLNAMES 2 diff --git a/Modules/_sqlite/statement.c b/Modules/_sqlite/statement.c index 83c0790..556ea01 100644 --- a/Modules/_sqlite/statement.c +++ b/Modules/_sqlite/statement.c @@ -1,6 +1,6 @@  /* statement.c - the statement type   * - * Copyright (C) 2005-2006 Gerhard Häring <gh@ghaering.de> + * Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de>   *   * This file is part of pysqlite.   * @@ -40,6 +40,16 @@ typedef enum {      NORMAL  } parse_remaining_sql_state; +typedef enum { +    TYPE_INT, +    TYPE_LONG, +    TYPE_FLOAT, +    TYPE_STRING, +    TYPE_UNICODE, +    TYPE_BUFFER, +    TYPE_UNKNOWN +} parameter_type; +  int pysqlite_statement_create(pysqlite_Statement* self, pysqlite_Connection* connection, PyObject* sql)  {      const char* tail; @@ -97,42 +107,96 @@ int pysqlite_statement_bind_parameter(pysqlite_Statement* self, int pos, PyObjec      char* string;      Py_ssize_t buflen;      PyObject* stringval; +    parameter_type paramtype;      if (parameter == Py_None) {          rc = sqlite3_bind_null(self->st, pos); +        goto final; +    } + +    if (PyInt_CheckExact(parameter)) { +        paramtype = TYPE_INT; +    } else if (PyLong_CheckExact(parameter)) { +        paramtype = TYPE_LONG; +    } else if (PyFloat_CheckExact(parameter)) { +        paramtype = TYPE_FLOAT; +    } else if (PyString_CheckExact(parameter)) { +        paramtype = TYPE_STRING; +    } else if (PyUnicode_CheckExact(parameter)) { +        paramtype = TYPE_UNICODE; +    } else if (PyBuffer_Check(parameter)) { +        paramtype = TYPE_BUFFER;      } else if (PyInt_Check(parameter)) { -        longval = PyInt_AsLong(parameter); -        rc = sqlite3_bind_int64(self->st, pos, (sqlite_int64)longval); -#ifdef HAVE_LONG_LONG +        paramtype = TYPE_INT;      } else if (PyLong_Check(parameter)) { -        longlongval = PyLong_AsLongLong(parameter); -        /* in the overflow error case, longlongval is -1, and an exception is set */ -        rc = sqlite3_bind_int64(self->st, pos, (sqlite_int64)longlongval); -#endif +        paramtype = TYPE_LONG;      } else if (PyFloat_Check(parameter)) { -        rc = sqlite3_bind_double(self->st, pos, PyFloat_AsDouble(parameter)); -    } else if (PyBuffer_Check(parameter)) { -        if (PyObject_AsCharBuffer(parameter, &buffer, &buflen) == 0) { -            rc = sqlite3_bind_blob(self->st, pos, buffer, buflen, SQLITE_TRANSIENT); -        } else { -            PyErr_SetString(PyExc_ValueError, "could not convert BLOB to buffer"); -            rc = -1; -        } -    } else if PyString_Check(parameter) { -        string = PyString_AsString(parameter); -        rc = sqlite3_bind_text(self->st, pos, string, -1, SQLITE_TRANSIENT); -    } else if PyUnicode_Check(parameter) { -        stringval = PyUnicode_AsUTF8String(parameter); -        string = PyString_AsString(stringval); -        rc = sqlite3_bind_text(self->st, pos, string, -1, SQLITE_TRANSIENT); -        Py_DECREF(stringval); +        paramtype = TYPE_FLOAT; +    } else if (PyString_Check(parameter)) { +        paramtype = TYPE_STRING; +    } else if (PyUnicode_Check(parameter)) { +        paramtype = TYPE_UNICODE;      } else { -        rc = -1; +        paramtype = TYPE_UNKNOWN;      } +    switch (paramtype) { +        case TYPE_INT: +            longval = PyInt_AsLong(parameter); +            rc = sqlite3_bind_int64(self->st, pos, (sqlite_int64)longval); +            break; +#ifdef HAVE_LONG_LONG +        case TYPE_LONG: +            longlongval = PyLong_AsLongLong(parameter); +            /* in the overflow error case, longlongval is -1, and an exception is set */ +            rc = sqlite3_bind_int64(self->st, pos, (sqlite_int64)longlongval); +            break; +#endif +        case TYPE_FLOAT: +            rc = sqlite3_bind_double(self->st, pos, PyFloat_AsDouble(parameter)); +            break; +        case TYPE_STRING: +            string = PyString_AS_STRING(parameter); +            rc = sqlite3_bind_text(self->st, pos, string, -1, SQLITE_TRANSIENT); +            break; +        case TYPE_UNICODE: +            stringval = PyUnicode_AsUTF8String(parameter); +            string = PyString_AsString(stringval); +            rc = sqlite3_bind_text(self->st, pos, string, -1, SQLITE_TRANSIENT); +            Py_DECREF(stringval); +            break; +        case TYPE_BUFFER: +            if (PyObject_AsCharBuffer(parameter, &buffer, &buflen) == 0) { +                rc = sqlite3_bind_blob(self->st, pos, buffer, buflen, SQLITE_TRANSIENT); +            } else { +                PyErr_SetString(PyExc_ValueError, "could not convert BLOB to buffer"); +                rc = -1; +            } +            break; +        case TYPE_UNKNOWN: +            rc = -1; +    } + +final:      return rc;  } +/* returns 0 if the object is one of Python's internal ones that don't need to be adapted */ +static int _need_adapt(PyObject* obj) +{ +    if (pysqlite_BaseTypeAdapted) { +        return 1; +    } + +    if (PyInt_CheckExact(obj) || PyLong_CheckExact(obj)  +            || PyFloat_CheckExact(obj) || PyString_CheckExact(obj) +            || PyUnicode_CheckExact(obj) || PyBuffer_Check(obj)) { +        return 0; +    } else { +        return 1; +    } +} +  void pysqlite_statement_bind_parameters(pysqlite_Statement* self, PyObject* parameters)  {      PyObject* current_param; @@ -147,7 +211,55 @@ void pysqlite_statement_bind_parameters(pysqlite_Statement* self, PyObject* para      num_params_needed = sqlite3_bind_parameter_count(self->st);      Py_END_ALLOW_THREADS -    if (PyDict_Check(parameters)) { +    if (PyTuple_CheckExact(parameters) || PyList_CheckExact(parameters) || (!PyDict_Check(parameters) && PySequence_Check(parameters))) { +        /* parameters passed as sequence */ +        if (PyTuple_CheckExact(parameters)) { +            num_params = PyTuple_GET_SIZE(parameters); +        } else if (PyList_CheckExact(parameters)) { +            num_params = PyList_GET_SIZE(parameters); +        } else { +            num_params = PySequence_Size(parameters); +        } +        if (num_params != num_params_needed) { +            PyErr_Format(pysqlite_ProgrammingError, "Incorrect number of bindings supplied. The current statement uses %d, and there are %d supplied.", +                         num_params_needed, num_params); +            return; +        } +        for (i = 0; i < num_params; i++) { +            if (PyTuple_CheckExact(parameters)) { +                current_param = PyTuple_GET_ITEM(parameters, i); +                Py_XINCREF(current_param); +            } else if (PyList_CheckExact(parameters)) { +                current_param = PyList_GET_ITEM(parameters, i); +                Py_XINCREF(current_param); +            } else { +                current_param = PySequence_GetItem(parameters, i); +            } +            if (!current_param) { +                return; +            } + +            if (!_need_adapt(current_param)) { +                adapted = current_param; +            } else { +                adapted = microprotocols_adapt(current_param, (PyObject*)&pysqlite_PrepareProtocolType, NULL); +                if (adapted) { +                    Py_DECREF(current_param); +                } else { +                    PyErr_Clear(); +                    adapted = current_param; +                } +            } + +            rc = pysqlite_statement_bind_parameter(self, i + 1, adapted); +            Py_DECREF(adapted); + +            if (rc != SQLITE_OK) { +                PyErr_Format(pysqlite_InterfaceError, "Error binding parameter %d - probably unsupported type.", i); +                return; +            } +        } +    } else if (PyDict_Check(parameters)) {          /* parameters passed as dictionary */          for (i = 1; i <= num_params_needed; i++) {              Py_BEGIN_ALLOW_THREADS @@ -159,19 +271,27 @@ void pysqlite_statement_bind_parameters(pysqlite_Statement* self, PyObject* para              }              binding_name++; /* skip first char (the colon) */ -            current_param = PyDict_GetItemString(parameters, binding_name); +            if (PyDict_CheckExact(parameters)) { +                current_param = PyDict_GetItemString(parameters, binding_name); +                Py_XINCREF(current_param); +            } else { +                current_param = PyMapping_GetItemString(parameters, (char*)binding_name); +            }              if (!current_param) {                  PyErr_Format(pysqlite_ProgrammingError, "You did not supply a value for binding %d.", i);                  return;              } -            Py_INCREF(current_param); -            adapted = microprotocols_adapt(current_param, (PyObject*)&pysqlite_PrepareProtocolType, NULL); -            if (adapted) { -                Py_DECREF(current_param); -            } else { -                PyErr_Clear(); +            if (!_need_adapt(current_param)) {                  adapted = current_param; +            } else { +                adapted = microprotocols_adapt(current_param, (PyObject*)&pysqlite_PrepareProtocolType, NULL); +                if (adapted) { +                    Py_DECREF(current_param); +                } else { +                    PyErr_Clear(); +                    adapted = current_param; +                }              }              rc = pysqlite_statement_bind_parameter(self, i, adapted); @@ -183,35 +303,7 @@ void pysqlite_statement_bind_parameters(pysqlite_Statement* self, PyObject* para             }          }      } else { -        /* parameters passed as sequence */ -        num_params = PySequence_Length(parameters); -        if (num_params != num_params_needed) { -            PyErr_Format(pysqlite_ProgrammingError, "Incorrect number of bindings supplied. The current statement uses %d, and there are %d supplied.", -                         num_params_needed, num_params); -            return; -        } -        for (i = 0; i < num_params; i++) { -            current_param = PySequence_GetItem(parameters, i); -            if (!current_param) { -                return; -            } -            adapted = microprotocols_adapt(current_param, (PyObject*)&pysqlite_PrepareProtocolType, NULL); - -            if (adapted) { -                Py_DECREF(current_param); -            } else { -                PyErr_Clear(); -                adapted = current_param; -            } - -            rc = pysqlite_statement_bind_parameter(self, i + 1, adapted); -            Py_DECREF(adapted); - -            if (rc != SQLITE_OK) { -                PyErr_Format(pysqlite_InterfaceError, "Error binding parameter %d - probably unsupported type.", i); -                return; -            } -        } +        PyErr_SetString(PyExc_ValueError, "parameters are of unsupported type");      }  } diff --git a/Modules/_sqlite/util.c b/Modules/_sqlite/util.c index 5e78d58..e06c299 100644 --- a/Modules/_sqlite/util.c +++ b/Modules/_sqlite/util.c @@ -1,6 +1,6 @@  /* util.c - various utility functions   * - * Copyright (C) 2005-2006 Gerhard Häring <gh@ghaering.de> + * Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de>   *   * This file is part of pysqlite.   * @@ -45,10 +45,15 @@ int _sqlite_step_with_busyhandler(sqlite3_stmt* statement, pysqlite_Connection*   * Checks the SQLite error code and sets the appropriate DB-API exception.   * Returns the error code (0 means no error occurred).   */ -int _pysqlite_seterror(sqlite3* db) +int _pysqlite_seterror(sqlite3* db, sqlite3_stmt* st)  {      int errorcode; +    /* SQLite often doesn't report anything useful, unless you reset the statement first */ +    if (st != NULL) { +        (void)sqlite3_reset(st); +    } +      errorcode = sqlite3_errcode(db);      switch (errorcode) diff --git a/Modules/_sqlite/util.h b/Modules/_sqlite/util.h index 969c5e5..6c34329 100644 --- a/Modules/_sqlite/util.h +++ b/Modules/_sqlite/util.h @@ -34,5 +34,5 @@ int _sqlite_step_with_busyhandler(sqlite3_stmt* statement, pysqlite_Connection*   * Checks the SQLite error code and sets the appropriate DB-API exception.   * Returns the error code (0 means no error occurred).   */ -int _pysqlite_seterror(sqlite3* db); +int _pysqlite_seterror(sqlite3* db, sqlite3_stmt* st);  #endif  | 
