diff options
author | Erlend Egeberg Aasland <erlend.aasland@innova.no> | 2022-06-14 13:05:36 (GMT) |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-06-14 13:05:36 (GMT) |
commit | 2229d34a6ef9c03e476882f46bba4d7c83552d1e (patch) | |
tree | a699290af00f9bedf002fc4af83fbf152eb10193 /Modules | |
parent | f9585e2adc36fb2886f96ef86a5aa14215fc1151 (diff) | |
download | cpython-2229d34a6ef9c03e476882f46bba4d7c83552d1e.zip cpython-2229d34a6ef9c03e476882f46bba4d7c83552d1e.tar.gz cpython-2229d34a6ef9c03e476882f46bba4d7c83552d1e.tar.bz2 |
[3.10] gh-79579: Improve DML query detection in sqlite3 (GH-93623) (#93801)
The fix involves using pysqlite_check_remaining_sql(), not only to check
for multiple statements, but now also to strip leading comments and
whitespace from SQL statements, so we can improve DML query detection.
pysqlite_check_remaining_sql() is renamed lstrip_sql(), to more
accurately reflect its function, and hardened to handle more SQL comment
corner cases.
(cherry picked from commit 46740073ef32bf83964c39609c7a7a4772c51ce3)
Diffstat (limited to 'Modules')
-rw-r--r-- | Modules/_sqlite/statement.c | 120 |
1 files changed, 45 insertions, 75 deletions
diff --git a/Modules/_sqlite/statement.c b/Modules/_sqlite/statement.c index 3bc8642..adc531f 100644 --- a/Modules/_sqlite/statement.c +++ b/Modules/_sqlite/statement.c @@ -29,16 +29,7 @@ #include "util.h" /* prototypes */ -static int pysqlite_check_remaining_sql(const char* tail); - -typedef enum { - LINECOMMENT_1, - IN_LINECOMMENT, - COMMENTSTART_1, - IN_COMMENT, - COMMENTEND_1, - NORMAL -} parse_remaining_sql_state; +static const char *lstrip_sql(const char *sql); typedef enum { TYPE_LONG, @@ -55,7 +46,6 @@ pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql) int rc; const char* sql_cstr; Py_ssize_t sql_cstr_len; - const char* p; assert(PyUnicode_Check(sql)); @@ -87,20 +77,12 @@ pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql) /* Determine if the statement is a DML statement. SELECT is the only exception. See #9924. */ - for (p = sql_cstr; *p != 0; p++) { - switch (*p) { - case ' ': - case '\r': - case '\n': - case '\t': - continue; - } - + const char *p = lstrip_sql(sql_cstr); + if (p != NULL) { self->is_dml = (PyOS_strnicmp(p, "insert", 6) == 0) || (PyOS_strnicmp(p, "update", 6) == 0) || (PyOS_strnicmp(p, "delete", 6) == 0) || (PyOS_strnicmp(p, "replace", 7) == 0); - break; } Py_BEGIN_ALLOW_THREADS @@ -118,7 +100,7 @@ pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql) goto error; } - if (rc == SQLITE_OK && pysqlite_check_remaining_sql(tail)) { + if (rc == SQLITE_OK && lstrip_sql(tail)) { (void)sqlite3_finalize(self->st); self->st = NULL; PyErr_SetString(pysqlite_Warning, @@ -431,73 +413,61 @@ stmt_traverse(pysqlite_Statement *self, visitproc visit, void *arg) } /* - * Checks if there is anything left in an SQL string after SQLite compiled it. - * This is used to check if somebody tried to execute more than one SQL command - * with one execute()/executemany() command, which the DB-API and we don't - * allow. + * Strip leading whitespace and comments from incoming SQL (null terminated C + * string) and return a pointer to the first non-whitespace, non-comment + * character. + * + * This is used to check if somebody tries to execute more than one SQL query + * with one execute()/executemany() command, which the DB-API don't allow. * - * Returns 1 if there is more left than should be. 0 if ok. + * It is also used to harden DML query detection. */ -static int pysqlite_check_remaining_sql(const char* tail) +static inline const char * +lstrip_sql(const char *sql) { - const char* pos = tail; - - parse_remaining_sql_state state = NORMAL; - - for (;;) { + // This loop is borrowed from the SQLite source code. + for (const char *pos = sql; *pos; pos++) { switch (*pos) { - case 0: - return 0; - case '-': - if (state == NORMAL) { - state = LINECOMMENT_1; - } else if (state == LINECOMMENT_1) { - state = IN_LINECOMMENT; - } - break; case ' ': case '\t': - break; + case '\f': case '\n': - case 13: - if (state == IN_LINECOMMENT) { - state = NORMAL; - } + case '\r': + // Skip whitespace. break; - case '/': - if (state == NORMAL) { - state = COMMENTSTART_1; - } else if (state == COMMENTEND_1) { - state = NORMAL; - } else if (state == COMMENTSTART_1) { - return 1; + case '-': + // Skip line comments. + if (pos[1] == '-') { + pos += 2; + while (pos[0] && pos[0] != '\n') { + pos++; + } + if (pos[0] == '\0') { + return NULL; + } + continue; } - break; - case '*': - if (state == NORMAL) { - return 1; - } else if (state == LINECOMMENT_1) { - return 1; - } else if (state == COMMENTSTART_1) { - state = IN_COMMENT; - } else if (state == IN_COMMENT) { - state = COMMENTEND_1; + return pos; + case '/': + // Skip C style comments. + if (pos[1] == '*') { + pos += 2; + while (pos[0] && (pos[0] != '*' || pos[1] != '/')) { + pos++; + } + if (pos[0] == '\0') { + return NULL; + } + pos++; + continue; } - break; + return pos; default: - if (state == COMMENTEND_1) { - state = IN_COMMENT; - } else if (state == IN_LINECOMMENT) { - } else if (state == IN_COMMENT) { - } else { - return 1; - } + return pos; } - - pos++; } - return 0; + return NULL; } static PyMemberDef stmt_members[] = { |