Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-79579: Improve DML query detection in sqlite3 #93623

Merged
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions Lib/test/test_sqlite3/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,6 +906,30 @@ def test_rowcount_update_returning(self):
self.assertEqual(self.cu.fetchone()[0], 1)
self.assertEqual(self.cu.rowcount, 1)

def test_rowcount_prefixed_with_comment(self):
# gh-79579: rowcount is updated even if query is prefixed with comments
self.cu.execute("""
-- foo
insert into test(name) values ('foo')
""")
self.assertEqual(self.cu.rowcount, 1)
self.cu.execute("""
/* -- messy /* /* ** *- *--
*/
/* one more */ insert into test(name) values ('messy')
""")
self.assertEqual(self.cu.rowcount, 1)
self.cu.execute("/* bar */ update test set name='bar' where name='foo'")
self.assertEqual(self.cu.rowcount, 2)

def test_rowcount_vaccuum(self):
data = ((1,), (2,), (3,))
self.cu.executemany("insert into test(income) values(?)", data)
self.assertEqual(self.cu.rowcount, 3)
self.cx.commit()
self.cu.execute("vacuum")
self.assertEqual(self.cu.rowcount, -1)

def test_total_changes(self):
self.cu.execute("insert into test(name) values ('foo')")
self.cu.execute("insert into test(name) values ('foo')")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
:mod:`sqlite3` now correctly detects DML queries with leading comments.
Patch by Erlend E. Aasland.
49 changes: 23 additions & 26 deletions Modules/_sqlite/statement.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
#include "util.h"

/* prototypes */
static int pysqlite_check_remaining_sql(const char* tail);
static const char *lstrip_sql(const char *sql);

typedef enum {
LINECOMMENT_1,
Expand Down Expand Up @@ -73,7 +73,7 @@ pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql)
return NULL;
}

if (pysqlite_check_remaining_sql(tail)) {
if (lstrip_sql(tail) != NULL) {
PyErr_SetString(connection->ProgrammingError,
"You can only execute one statement at a time.");
goto error;
Expand All @@ -82,20 +82,12 @@ pysqlite_statement_create(pysqlite_Connection *connection, PyObject *sql)
/* Determine if the statement is a DML statement.
SELECT is the only exception. See #9924. */
int is_dml = 0;
for (const char *p = sql_cstr; *p != 0; p++) {
switch (*p) {
case ' ':
case '\r':
case '\n':
case '\t':
continue;
}

const char *p = lstrip_sql(sql_cstr);
if (p != NULL) {
is_dml = (PyOS_strnicmp(p, "insert", 6) == 0)
|| (PyOS_strnicmp(p, "update", 6) == 0)
|| (PyOS_strnicmp(p, "delete", 6) == 0)
|| (PyOS_strnicmp(p, "replace", 7) == 0);
erlend-aasland marked this conversation as resolved.
Show resolved Hide resolved
break;
}

pysqlite_Statement *self = PyObject_GC_New(pysqlite_Statement,
Expand Down Expand Up @@ -139,23 +131,25 @@ stmt_traverse(pysqlite_Statement *self, visitproc visit, void *arg)
}

/*
* Checks if there is anything left in an SQL string after SQLite compiled it.
* This is used to check if somebody tried to execute more than one SQL command
* with one execute()/executemany() command, which the DB-API and we don't
* allow.
* Strip leading whitespace and comments from SQL string and return a
* pointer to the first non-whitespace, non-comment character.
*
* This is used to check if somebody tries to execute more than one SQL query
* with one execute()/executemany() command, which the DB-API don't allow.
*
* Returns 1 if there is more left than should be. 0 if ok.
* It is also used to harden DML query detection.
*/
static int pysqlite_check_remaining_sql(const char* tail)
static const char *
lstrip_sql(const char *sql)
{
const char* pos = tail;
const char *pos = sql;

parse_remaining_sql_state state = NORMAL;

for (;;) {
switch (*pos) {
case 0:
return 0;
return NULL;
erlend-aasland marked this conversation as resolved.
Show resolved Hide resolved
case '-':
if (state == NORMAL) {
state = LINECOMMENT_1;
Expand All @@ -165,9 +159,12 @@ static int pysqlite_check_remaining_sql(const char* tail)
break;
case ' ':
case '\t':
if (state == COMMENTEND_1) {
state = IN_COMMENT;
}
break;
case '\n':
case 13:
case '\r':
if (state == IN_LINECOMMENT) {
state = NORMAL;
}
Expand All @@ -178,14 +175,14 @@ static int pysqlite_check_remaining_sql(const char* tail)
} else if (state == COMMENTEND_1) {
state = NORMAL;
} else if (state == COMMENTSTART_1) {
return 1;
return pos;
}
break;
case '*':
if (state == NORMAL) {
return 1;
return pos;
} else if (state == LINECOMMENT_1) {
return 1;
return pos;
} else if (state == COMMENTSTART_1) {
state = IN_COMMENT;
} else if (state == IN_COMMENT) {
Expand All @@ -198,14 +195,14 @@ static int pysqlite_check_remaining_sql(const char* tail)
} else if (state == IN_LINECOMMENT) {
} else if (state == IN_COMMENT) {
} else {
return 1;
return pos;
}
}

pos++;
}

return 0;
return NULL;
}

static PyType_Slot stmt_slots[] = {
Expand Down