summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorErlend Egeberg Aasland <erlend.aasland@innova.no>2021-11-16 14:53:35 (GMT)
committerGitHub <noreply@github.com>2021-11-16 14:53:35 (GMT)
commit9d6215a54c177a5e359c37ecd1c50b594b194f41 (patch)
treed958d7845d5d5c7f2e9b8c41a80207b58a522799
parent6a84d61c55f2e543cf5fa84522d8781a795bba33 (diff)
downloadcpython-9d6215a54c177a5e359c37ecd1c50b594b194f41.zip
cpython-9d6215a54c177a5e359c37ecd1c50b594b194f41.tar.gz
cpython-9d6215a54c177a5e359c37ecd1c50b594b194f41.tar.bz2
bpo-45126: Harden `sqlite3` connection initialisation (GH-28227)
-rw-r--r--Lib/test/test_sqlite3/test_dbapi.py38
-rw-r--r--Modules/_sqlite/clinic/connection.c.h12
-rw-r--r--Modules/_sqlite/connection.c122
3 files changed, 110 insertions, 62 deletions
diff --git a/Lib/test/test_sqlite3/test_dbapi.py b/Lib/test/test_sqlite3/test_dbapi.py
index 802a691..18359e1 100644
--- a/Lib/test/test_sqlite3/test_dbapi.py
+++ b/Lib/test/test_sqlite3/test_dbapi.py
@@ -523,6 +523,44 @@ class ConnectionTests(unittest.TestCase):
with memory_database(isolation_level=level) as cx:
cx.execute("select 'ok'")
+ def test_connection_reinit(self):
+ db = ":memory:"
+ cx = sqlite.connect(db)
+ cx.text_factory = bytes
+ cx.row_factory = sqlite.Row
+ cu = cx.cursor()
+ cu.execute("create table foo (bar)")
+ cu.executemany("insert into foo (bar) values (?)",
+ ((str(v),) for v in range(4)))
+ cu.execute("select bar from foo")
+
+ rows = [r for r in cu.fetchmany(2)]
+ self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows))
+ self.assertEqual([r[0] for r in rows], [b"0", b"1"])
+
+ cx.__init__(db)
+ cx.execute("create table foo (bar)")
+ cx.executemany("insert into foo (bar) values (?)",
+ ((v,) for v in ("a", "b", "c", "d")))
+
+ # This uses the old database, old row factory, but new text factory
+ rows = [r for r in cu.fetchall()]
+ self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows))
+ self.assertEqual([r[0] for r in rows], ["2", "3"])
+
+ def test_connection_bad_reinit(self):
+ cx = sqlite.connect(":memory:")
+ with cx:
+ cx.execute("create table t(t)")
+ with temp_dir() as db:
+ self.assertRaisesRegex(sqlite.OperationalError,
+ "unable to open database file",
+ cx.__init__, db)
+ self.assertRaisesRegex(sqlite.ProgrammingError,
+ "Base Connection.__init__ not called",
+ cx.executemany, "insert into t values(?)",
+ ((v,) for v in range(3)))
+
class UninitialisedConnectionTests(unittest.TestCase):
def setUp(self):
diff --git a/Modules/_sqlite/clinic/connection.c.h b/Modules/_sqlite/clinic/connection.c.h
index 5bfc589..3a3ae04 100644
--- a/Modules/_sqlite/clinic/connection.c.h
+++ b/Modules/_sqlite/clinic/connection.c.h
@@ -7,7 +7,7 @@ pysqlite_connection_init_impl(pysqlite_Connection *self,
const char *database, double timeout,
int detect_types, const char *isolation_level,
int check_same_thread, PyObject *factory,
- int cached_statements, int uri);
+ int cache_size, int uri);
static int
pysqlite_connection_init(PyObject *self, PyObject *args, PyObject *kwargs)
@@ -25,7 +25,7 @@ pysqlite_connection_init(PyObject *self, PyObject *args, PyObject *kwargs)
const char *isolation_level = "";
int check_same_thread = 1;
PyObject *factory = (PyObject*)clinic_state()->ConnectionType;
- int cached_statements = 128;
+ int cache_size = 128;
int uri = 0;
fastargs = _PyArg_UnpackKeywords(_PyTuple_CAST(args)->ob_item, nargs, kwargs, NULL, &_parser, 1, 8, 0, argsbuf);
@@ -101,8 +101,8 @@ pysqlite_connection_init(PyObject *self, PyObject *args, PyObject *kwargs)
}
}
if (fastargs[6]) {
- cached_statements = _PyLong_AsInt(fastargs[6]);
- if (cached_statements == -1 && PyErr_Occurred()) {
+ cache_size = _PyLong_AsInt(fastargs[6]);
+ if (cache_size == -1 && PyErr_Occurred()) {
goto exit;
}
if (!--noptargs) {
@@ -114,7 +114,7 @@ pysqlite_connection_init(PyObject *self, PyObject *args, PyObject *kwargs)
goto exit;
}
skip_optional_pos:
- return_value = pysqlite_connection_init_impl((pysqlite_Connection *)self, database, timeout, detect_types, isolation_level, check_same_thread, factory, cached_statements, uri);
+ return_value = pysqlite_connection_init_impl((pysqlite_Connection *)self, database, timeout, detect_types, isolation_level, check_same_thread, factory, cache_size, uri);
exit:
/* Cleanup for database */
@@ -851,4 +851,4 @@ exit:
#ifndef PYSQLITE_CONNECTION_LOAD_EXTENSION_METHODDEF
#define PYSQLITE_CONNECTION_LOAD_EXTENSION_METHODDEF
#endif /* !defined(PYSQLITE_CONNECTION_LOAD_EXTENSION_METHODDEF) */
-/*[clinic end generated code: output=663b1e9e71128f19 input=a9049054013a1b77]*/
+/*[clinic end generated code: output=6f267f20e77f92d0 input=a9049054013a1b77]*/
diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c
index b902dc8..e794767 100644
--- a/Modules/_sqlite/connection.c
+++ b/Modules/_sqlite/connection.c
@@ -83,15 +83,17 @@ static void _pysqlite_drop_unused_cursor_references(pysqlite_Connection* self);
static void free_callback_context(callback_context *ctx);
static void set_callback_context(callback_context **ctx_pp,
callback_context *ctx);
+static void connection_close(pysqlite_Connection *self);
static PyObject *
-new_statement_cache(pysqlite_Connection *self, int maxsize)
+new_statement_cache(pysqlite_Connection *self, pysqlite_state *state,
+ int maxsize)
{
PyObject *args[] = { NULL, PyLong_FromLong(maxsize), };
if (args[1] == NULL) {
return NULL;
}
- PyObject *lru_cache = self->state->lru_cache;
+ PyObject *lru_cache = state->lru_cache;
size_t nargsf = 1 | PY_VECTORCALL_ARGUMENTS_OFFSET;
PyObject *inner = PyObject_Vectorcall(lru_cache, args + 1, nargsf, NULL);
Py_DECREF(args[1]);
@@ -153,7 +155,7 @@ _sqlite3.Connection.__init__ as pysqlite_connection_init
isolation_level: str(accept={str, NoneType}) = ""
check_same_thread: bool(accept={int}) = True
factory: object(c_default='(PyObject*)clinic_state()->ConnectionType') = ConnectionType
- cached_statements: int = 128
+ cached_statements as cache_size: int = 128
uri: bool = False
[clinic start generated code]*/
@@ -162,78 +164,82 @@ pysqlite_connection_init_impl(pysqlite_Connection *self,
const char *database, double timeout,
int detect_types, const char *isolation_level,
int check_same_thread, PyObject *factory,
- int cached_statements, int uri)
-/*[clinic end generated code: output=d8c37afc46d318b0 input=adfb29ac461f9e61]*/
+ int cache_size, int uri)
+/*[clinic end generated code: output=7d640ae1d83abfd4 input=35e316f66d9f70fd]*/
{
- int rc;
-
if (PySys_Audit("sqlite3.connect", "s", database) < 0) {
return -1;
}
- pysqlite_state *state = pysqlite_get_state_by_type(Py_TYPE(self));
- self->state = state;
-
- Py_CLEAR(self->statement_cache);
- Py_CLEAR(self->cursors);
-
- Py_INCREF(Py_None);
- Py_XSETREF(self->row_factory, Py_None);
-
- Py_INCREF(&PyUnicode_Type);
- Py_XSETREF(self->text_factory, (PyObject*)&PyUnicode_Type);
+ if (self->initialized) {
+ PyTypeObject *tp = Py_TYPE(self);
+ tp->tp_clear((PyObject *)self);
+ connection_close(self);
+ self->initialized = 0;
+ }
+ // Create and configure SQLite database object.
+ sqlite3 *db;
+ int rc;
Py_BEGIN_ALLOW_THREADS
- rc = sqlite3_open_v2(database, &self->db,
+ rc = sqlite3_open_v2(database, &db,
SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE |
(uri ? SQLITE_OPEN_URI : 0), NULL);
+ if (rc == SQLITE_OK) {
+ (void)sqlite3_busy_timeout(db, (int)(timeout*1000));
+ }
Py_END_ALLOW_THREADS
- if (self->db == NULL && rc == SQLITE_NOMEM) {
+ if (db == NULL && rc == SQLITE_NOMEM) {
PyErr_NoMemory();
return -1;
}
+
+ pysqlite_state *state = pysqlite_get_state_by_type(Py_TYPE(self));
if (rc != SQLITE_OK) {
- _pysqlite_seterror(state, self->db);
+ _pysqlite_seterror(state, db);
return -1;
}
- if (isolation_level) {
- const char *stmt = get_begin_statement(isolation_level);
- if (stmt == NULL) {
+ // Convert isolation level to begin statement.
+ const char *begin_statement = NULL;
+ if (isolation_level != NULL) {
+ begin_statement = get_begin_statement(isolation_level);
+ if (begin_statement == NULL) {
return -1;
}
- self->begin_statement = stmt;
- }
- else {
- self->begin_statement = NULL;
}
- self->statement_cache = new_statement_cache(self, cached_statements);
- if (self->statement_cache == NULL) {
- return -1;
- }
- if (PyErr_Occurred()) {
+ // Create LRU statement cache; returns a new reference.
+ PyObject *statement_cache = new_statement_cache(self, state, cache_size);
+ if (statement_cache == NULL) {
return -1;
}
- self->created_cursors = 0;
-
- /* Create list of weak references to cursors */
- self->cursors = PyList_New(0);
- if (self->cursors == NULL) {
+ // Create list of weak references to cursors.
+ PyObject *cursors = PyList_New(0);
+ if (cursors == NULL) {
+ Py_DECREF(statement_cache);
return -1;
}
+ // Init connection state members.
+ self->db = db;
+ self->state = state;
self->detect_types = detect_types;
- (void)sqlite3_busy_timeout(self->db, (int)(timeout*1000));
- self->thread_ident = PyThread_get_thread_ident();
+ self->begin_statement = begin_statement;
self->check_same_thread = check_same_thread;
+ self->thread_ident = PyThread_get_thread_ident();
+ self->statement_cache = statement_cache;
+ self->cursors = cursors;
+ self->created_cursors = 0;
+ self->row_factory = Py_NewRef(Py_None);
+ self->text_factory = Py_NewRef(&PyUnicode_Type);
+ self->trace_ctx = NULL;
+ self->progress_ctx = NULL;
+ self->authorizer_ctx = NULL;
- set_callback_context(&self->trace_ctx, NULL);
- set_callback_context(&self->progress_ctx, NULL);
- set_callback_context(&self->authorizer_ctx, NULL);
-
+ // Borrowed refs
self->Warning = state->Warning;
self->Error = state->Error;
self->InterfaceError = state->InterfaceError;
@@ -250,7 +256,6 @@ pysqlite_connection_init_impl(pysqlite_Connection *self,
}
self->initialized = 1;
-
return 0;
}
@@ -322,16 +327,6 @@ connection_clear(pysqlite_Connection *self)
}
static void
-connection_close(pysqlite_Connection *self)
-{
- if (self->db) {
- int rc = sqlite3_close_v2(self->db);
- assert(rc == SQLITE_OK), (void)rc;
- self->db = NULL;
- }
-}
-
-static void
free_callback_contexts(pysqlite_Connection *self)
{
set_callback_context(&self->trace_ctx, NULL);
@@ -340,6 +335,22 @@ free_callback_contexts(pysqlite_Connection *self)
}
static void
+connection_close(pysqlite_Connection *self)
+{
+ if (self->db) {
+ free_callback_contexts(self);
+
+ sqlite3 *db = self->db;
+ self->db = NULL;
+
+ Py_BEGIN_ALLOW_THREADS
+ int rc = sqlite3_close_v2(db);
+ assert(rc == SQLITE_OK), (void)rc;
+ Py_END_ALLOW_THREADS
+ }
+}
+
+static void
connection_dealloc(pysqlite_Connection *self)
{
PyTypeObject *tp = Py_TYPE(self);
@@ -348,7 +359,6 @@ connection_dealloc(pysqlite_Connection *self)
/* Clean up if user has not called .close() explicitly. */
connection_close(self);
- free_callback_contexts(self);
tp->tp_free(self);
Py_DECREF(tp);