From 9d6215a54c177a5e359c37ecd1c50b594b194f41 Mon Sep 17 00:00:00 2001 From: Erlend Egeberg Aasland Date: Tue, 16 Nov 2021 15:53:35 +0100 Subject: bpo-45126: Harden `sqlite3` connection initialisation (GH-28227) --- Lib/test/test_sqlite3/test_dbapi.py | 38 +++++++++++ Modules/_sqlite/clinic/connection.c.h | 12 ++-- Modules/_sqlite/connection.c | 122 ++++++++++++++++++---------------- 3 files changed, 110 insertions(+), 62 deletions(-) diff --git a/Lib/test/test_sqlite3/test_dbapi.py b/Lib/test/test_sqlite3/test_dbapi.py index 802a691..18359e1 100644 --- a/Lib/test/test_sqlite3/test_dbapi.py +++ b/Lib/test/test_sqlite3/test_dbapi.py @@ -523,6 +523,44 @@ class ConnectionTests(unittest.TestCase): with memory_database(isolation_level=level) as cx: cx.execute("select 'ok'") + def test_connection_reinit(self): + db = ":memory:" + cx = sqlite.connect(db) + cx.text_factory = bytes + cx.row_factory = sqlite.Row + cu = cx.cursor() + cu.execute("create table foo (bar)") + cu.executemany("insert into foo (bar) values (?)", + ((str(v),) for v in range(4))) + cu.execute("select bar from foo") + + rows = [r for r in cu.fetchmany(2)] + self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows)) + self.assertEqual([r[0] for r in rows], [b"0", b"1"]) + + cx.__init__(db) + cx.execute("create table foo (bar)") + cx.executemany("insert into foo (bar) values (?)", + ((v,) for v in ("a", "b", "c", "d"))) + + # This uses the old database, old row factory, but new text factory + rows = [r for r in cu.fetchall()] + self.assertTrue(all(isinstance(r, sqlite.Row) for r in rows)) + self.assertEqual([r[0] for r in rows], ["2", "3"]) + + def test_connection_bad_reinit(self): + cx = sqlite.connect(":memory:") + with cx: + cx.execute("create table t(t)") + with temp_dir() as db: + self.assertRaisesRegex(sqlite.OperationalError, + "unable to open database file", + cx.__init__, db) + self.assertRaisesRegex(sqlite.ProgrammingError, + "Base Connection.__init__ not called", + cx.executemany, "insert into t values(?)", + ((v,) for v in range(3))) + class UninitialisedConnectionTests(unittest.TestCase): def setUp(self): diff --git a/Modules/_sqlite/clinic/connection.c.h b/Modules/_sqlite/clinic/connection.c.h index 5bfc589..3a3ae04 100644 --- a/Modules/_sqlite/clinic/connection.c.h +++ b/Modules/_sqlite/clinic/connection.c.h @@ -7,7 +7,7 @@ pysqlite_connection_init_impl(pysqlite_Connection *self, const char *database, double timeout, int detect_types, const char *isolation_level, int check_same_thread, PyObject *factory, - int cached_statements, int uri); + int cache_size, int uri); static int pysqlite_connection_init(PyObject *self, PyObject *args, PyObject *kwargs) @@ -25,7 +25,7 @@ pysqlite_connection_init(PyObject *self, PyObject *args, PyObject *kwargs) const char *isolation_level = ""; int check_same_thread = 1; PyObject *factory = (PyObject*)clinic_state()->ConnectionType; - int cached_statements = 128; + int cache_size = 128; int uri = 0; fastargs = _PyArg_UnpackKeywords(_PyTuple_CAST(args)->ob_item, nargs, kwargs, NULL, &_parser, 1, 8, 0, argsbuf); @@ -101,8 +101,8 @@ pysqlite_connection_init(PyObject *self, PyObject *args, PyObject *kwargs) } } if (fastargs[6]) { - cached_statements = _PyLong_AsInt(fastargs[6]); - if (cached_statements == -1 && PyErr_Occurred()) { + cache_size = _PyLong_AsInt(fastargs[6]); + if (cache_size == -1 && PyErr_Occurred()) { goto exit; } if (!--noptargs) { @@ -114,7 +114,7 @@ pysqlite_connection_init(PyObject *self, PyObject *args, PyObject *kwargs) goto exit; } skip_optional_pos: - return_value = pysqlite_connection_init_impl((pysqlite_Connection *)self, database, timeout, detect_types, isolation_level, check_same_thread, factory, cached_statements, uri); + return_value = pysqlite_connection_init_impl((pysqlite_Connection *)self, database, timeout, detect_types, isolation_level, check_same_thread, factory, cache_size, uri); exit: /* Cleanup for database */ @@ -851,4 +851,4 @@ exit: #ifndef PYSQLITE_CONNECTION_LOAD_EXTENSION_METHODDEF #define PYSQLITE_CONNECTION_LOAD_EXTENSION_METHODDEF #endif /* !defined(PYSQLITE_CONNECTION_LOAD_EXTENSION_METHODDEF) */ -/*[clinic end generated code: output=663b1e9e71128f19 input=a9049054013a1b77]*/ +/*[clinic end generated code: output=6f267f20e77f92d0 input=a9049054013a1b77]*/ diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c index b902dc8..e794767 100644 --- a/Modules/_sqlite/connection.c +++ b/Modules/_sqlite/connection.c @@ -83,15 +83,17 @@ static void _pysqlite_drop_unused_cursor_references(pysqlite_Connection* self); static void free_callback_context(callback_context *ctx); static void set_callback_context(callback_context **ctx_pp, callback_context *ctx); +static void connection_close(pysqlite_Connection *self); static PyObject * -new_statement_cache(pysqlite_Connection *self, int maxsize) +new_statement_cache(pysqlite_Connection *self, pysqlite_state *state, + int maxsize) { PyObject *args[] = { NULL, PyLong_FromLong(maxsize), }; if (args[1] == NULL) { return NULL; } - PyObject *lru_cache = self->state->lru_cache; + PyObject *lru_cache = state->lru_cache; size_t nargsf = 1 | PY_VECTORCALL_ARGUMENTS_OFFSET; PyObject *inner = PyObject_Vectorcall(lru_cache, args + 1, nargsf, NULL); Py_DECREF(args[1]); @@ -153,7 +155,7 @@ _sqlite3.Connection.__init__ as pysqlite_connection_init isolation_level: str(accept={str, NoneType}) = "" check_same_thread: bool(accept={int}) = True factory: object(c_default='(PyObject*)clinic_state()->ConnectionType') = ConnectionType - cached_statements: int = 128 + cached_statements as cache_size: int = 128 uri: bool = False [clinic start generated code]*/ @@ -162,78 +164,82 @@ pysqlite_connection_init_impl(pysqlite_Connection *self, const char *database, double timeout, int detect_types, const char *isolation_level, int check_same_thread, PyObject *factory, - int cached_statements, int uri) -/*[clinic end generated code: output=d8c37afc46d318b0 input=adfb29ac461f9e61]*/ + int cache_size, int uri) +/*[clinic end generated code: output=7d640ae1d83abfd4 input=35e316f66d9f70fd]*/ { - int rc; - if (PySys_Audit("sqlite3.connect", "s", database) < 0) { return -1; } - pysqlite_state *state = pysqlite_get_state_by_type(Py_TYPE(self)); - self->state = state; - - Py_CLEAR(self->statement_cache); - Py_CLEAR(self->cursors); - - Py_INCREF(Py_None); - Py_XSETREF(self->row_factory, Py_None); - - Py_INCREF(&PyUnicode_Type); - Py_XSETREF(self->text_factory, (PyObject*)&PyUnicode_Type); + if (self->initialized) { + PyTypeObject *tp = Py_TYPE(self); + tp->tp_clear((PyObject *)self); + connection_close(self); + self->initialized = 0; + } + // Create and configure SQLite database object. + sqlite3 *db; + int rc; Py_BEGIN_ALLOW_THREADS - rc = sqlite3_open_v2(database, &self->db, + rc = sqlite3_open_v2(database, &db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | (uri ? SQLITE_OPEN_URI : 0), NULL); + if (rc == SQLITE_OK) { + (void)sqlite3_busy_timeout(db, (int)(timeout*1000)); + } Py_END_ALLOW_THREADS - if (self->db == NULL && rc == SQLITE_NOMEM) { + if (db == NULL && rc == SQLITE_NOMEM) { PyErr_NoMemory(); return -1; } + + pysqlite_state *state = pysqlite_get_state_by_type(Py_TYPE(self)); if (rc != SQLITE_OK) { - _pysqlite_seterror(state, self->db); + _pysqlite_seterror(state, db); return -1; } - if (isolation_level) { - const char *stmt = get_begin_statement(isolation_level); - if (stmt == NULL) { + // Convert isolation level to begin statement. + const char *begin_statement = NULL; + if (isolation_level != NULL) { + begin_statement = get_begin_statement(isolation_level); + if (begin_statement == NULL) { return -1; } - self->begin_statement = stmt; - } - else { - self->begin_statement = NULL; } - self->statement_cache = new_statement_cache(self, cached_statements); - if (self->statement_cache == NULL) { - return -1; - } - if (PyErr_Occurred()) { + // Create LRU statement cache; returns a new reference. + PyObject *statement_cache = new_statement_cache(self, state, cache_size); + if (statement_cache == NULL) { return -1; } - self->created_cursors = 0; - - /* Create list of weak references to cursors */ - self->cursors = PyList_New(0); - if (self->cursors == NULL) { + // Create list of weak references to cursors. + PyObject *cursors = PyList_New(0); + if (cursors == NULL) { + Py_DECREF(statement_cache); return -1; } + // Init connection state members. + self->db = db; + self->state = state; self->detect_types = detect_types; - (void)sqlite3_busy_timeout(self->db, (int)(timeout*1000)); - self->thread_ident = PyThread_get_thread_ident(); + self->begin_statement = begin_statement; self->check_same_thread = check_same_thread; + self->thread_ident = PyThread_get_thread_ident(); + self->statement_cache = statement_cache; + self->cursors = cursors; + self->created_cursors = 0; + self->row_factory = Py_NewRef(Py_None); + self->text_factory = Py_NewRef(&PyUnicode_Type); + self->trace_ctx = NULL; + self->progress_ctx = NULL; + self->authorizer_ctx = NULL; - set_callback_context(&self->trace_ctx, NULL); - set_callback_context(&self->progress_ctx, NULL); - set_callback_context(&self->authorizer_ctx, NULL); - + // Borrowed refs self->Warning = state->Warning; self->Error = state->Error; self->InterfaceError = state->InterfaceError; @@ -250,7 +256,6 @@ pysqlite_connection_init_impl(pysqlite_Connection *self, } self->initialized = 1; - return 0; } @@ -322,16 +327,6 @@ connection_clear(pysqlite_Connection *self) } static void -connection_close(pysqlite_Connection *self) -{ - if (self->db) { - int rc = sqlite3_close_v2(self->db); - assert(rc == SQLITE_OK), (void)rc; - self->db = NULL; - } -} - -static void free_callback_contexts(pysqlite_Connection *self) { set_callback_context(&self->trace_ctx, NULL); @@ -340,6 +335,22 @@ free_callback_contexts(pysqlite_Connection *self) } static void +connection_close(pysqlite_Connection *self) +{ + if (self->db) { + free_callback_contexts(self); + + sqlite3 *db = self->db; + self->db = NULL; + + Py_BEGIN_ALLOW_THREADS + int rc = sqlite3_close_v2(db); + assert(rc == SQLITE_OK), (void)rc; + Py_END_ALLOW_THREADS + } +} + +static void connection_dealloc(pysqlite_Connection *self) { PyTypeObject *tp = Py_TYPE(self); @@ -348,7 +359,6 @@ connection_dealloc(pysqlite_Connection *self) /* Clean up if user has not called .close() explicitly. */ connection_close(self); - free_callback_contexts(self); tp->tp_free(self); Py_DECREF(tp); -- cgit v0.12