Skip to content
Closed
81 changes: 80 additions & 1 deletion Lib/sqlite3/test/transactions.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,11 @@ def CheckDdlDoesNotAutostartTransaction(self):
result = self.con.execute("select * from test").fetchall()
self.assertEqual(result, [])

self.con.execute("alter table test rename to test2")
self.con.rollback()
result = self.con.execute("select * from test2").fetchall()
self.assertEqual(result, [])

def CheckImmediateTransactionalDDL(self):
# You can achieve transactional DDL by issuing a BEGIN
# statement manually.
Expand All @@ -200,11 +205,85 @@ def CheckTransactionalDDL(self):
def tearDown(self):
self.con.close()


class DMLStatementDetectionTestCase(unittest.TestCase):
"""
Test behavior of sqlite3_stmt_readonly() in determining if a statement is
DML or not.
"""
@unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), 'needs sqlite 3.8.3 or newer')
def test_dml_detection_cte(self):
conn = sqlite.connect(':memory:')
conn.execute('CREATE TABLE kv ("key" TEXT, "val" INTEGER)')
self.assertFalse(conn.in_transaction)
conn.execute('INSERT INTO kv (key, val) VALUES (?, ?), (?, ?)',
('k1', 1, 'k2', 2))
self.assertTrue(conn.in_transaction)
conn.commit()
self.assertFalse(conn.in_transaction)

rc = conn.execute('UPDATE kv SET val=val + ?', (10,))
self.assertEqual(rc.rowcount, 2)
self.assertTrue(conn.in_transaction)
conn.commit()
self.assertFalse(conn.in_transaction)

rc = conn.execute(
'WITH c(k, v) AS (SELECT key, val + ? FROM kv) '
'UPDATE kv SET val=(SELECT v FROM c WHERE k=kv.key)',
(100,)
)
self.assertEqual(rc.rowcount, 2)
self.assertTrue(conn.in_transaction)

curs = conn.execute('SELECT key, val FROM kv ORDER BY key')
self.assertEqual(curs.fetchall(), [('k1', 111), ('k2', 112)])

@unittest.skipIf(sqlite.sqlite_version_info < (3, 7, 11), 'needs sqlite 3.7.11 or newer')
def test_dml_detection_sql_comment(self):
conn = sqlite.connect(':memory:')
conn.execute('CREATE TABLE kv ("key" TEXT, "val" INTEGER)')
self.assertFalse(conn.in_transaction)
conn.execute('INSERT INTO kv (key, val) VALUES (?, ?), (?, ?)',
('k1', 1, 'k2', 2))
conn.commit()
self.assertFalse(conn.in_transaction)

rc = conn.execute('-- a comment\nUPDATE kv SET val=val + ?', (10,))
self.assertEqual(rc.rowcount, 2)
self.assertTrue(conn.in_transaction)

curs = conn.execute('SELECT key, val FROM kv ORDER BY key')
self.assertEqual(curs.fetchall(), [('k1', 11), ('k2', 12)])
conn.rollback()
self.assertFalse(conn.in_transaction)
# Fetch again after rollback.
curs = conn.execute('SELECT key, val FROM kv ORDER BY key')
self.assertEqual(curs.fetchall(), [('k1', 1), ('k2', 2)])

def test_dml_detection_begin_exclusive(self):
# sqlite3_stmt_readonly() reports BEGIN EXCLUSIVE as being a
# non-read-only statement. To retain compatibility with the
# transactional behavior, we add a special exclusion for these
# statements.
conn = sqlite.connect(':memory:')
conn.execute('BEGIN EXCLUSIVE')
self.assertTrue(conn.in_transaction)
conn.execute('ROLLBACK')
self.assertFalse(conn.in_transaction)

def test_dml_detection_vacuum(self):
conn = sqlite.connect(':memory:')
conn.execute('vacuum')
self.assertFalse(conn.in_transaction)


def suite():
default_suite = unittest.makeSuite(TransactionTests, "Check")
special_command_suite = unittest.makeSuite(SpecialCommandTests, "Check")
ddl_suite = unittest.makeSuite(TransactionalDDL, "Check")
return unittest.TestSuite((default_suite, special_command_suite, ddl_suite))
dml_suite = unittest.makeSuite(DMLStatementDetectionTestCase)
return unittest.TestSuite((default_suite, special_command_suite, ddl_suite, dml_suite))

def test():
runner = unittest.TextTestRunner()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Use ``sqlite3_stmt_readonly()`` internally to determine if a SQL statement is
data-modifying. Requires sqlite3 3.7.11 or newer. Patch by Charles Leifer.
77 changes: 57 additions & 20 deletions Modules/_sqlite/statement.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
#include "prepare_protocol.h"
#include "util.h"

#if SQLITE_VERSION_NUMBER >= 3007004
#define HAVE_SQLITE3_STMT_READONLY
#endif

/* prototypes */
static int pysqlite_check_remaining_sql(const char* tail);

Expand All @@ -48,13 +52,64 @@ typedef enum {
TYPE_UNKNOWN
} parameter_type;

static int pysqlite_statement_is_dml(sqlite3_stmt *statement, const char *sql)
{
const char* p;
int is_dml = 0;

#ifdef HAVE_SQLITE3_STMT_READONLY
is_dml = !sqlite3_stmt_readonly(statement);
if (is_dml) {
/* Retain backwards-compatibility, as sqlite3_stmt_readonly will return
* false for BEGIN [IMMEDIATE|EXCLUSIVE] or DDL statements. */
for (p = sql; *p != 0; p++) {
switch (*p) {
case ' ':
case '\r':
case '\n':
case '\t':
continue;
}

is_dml = (PyOS_strnicmp(p, "begin", 5) &&
PyOS_strnicmp(p, "create", 6) &&
PyOS_strnicmp(p, "drop", 4) &&
PyOS_strnicmp(p, "alter", 5) &&
PyOS_strnicmp(p, "analyze", 7) &&
PyOS_strnicmp(p, "reindex", 7) &&
PyOS_strnicmp(p, "vacuum", 6));
break;
}
}
#else
/* Determine if the statement is a DML statement. SELECT is the only
* exception. This is a fallback for older versions of SQLite which do not
* support the sqlite3_stmt_readonly() API. */
for (p = sql; *p != 0; p++) {
switch (*p) {
case ' ':
case '\r':
case '\n':
case '\t':
continue;
}

is_dml = (PyOS_strnicmp(p, "insert", 6) == 0)
|| (PyOS_strnicmp(p, "update", 6) == 0)
|| (PyOS_strnicmp(p, "delete", 6) == 0)
|| (PyOS_strnicmp(p, "replace", 7) == 0);
break;
}
#endif
return is_dml;
}

int pysqlite_statement_create(pysqlite_Statement* self, pysqlite_Connection* connection, PyObject* sql)
{
const char* tail;
int rc;
const char* sql_cstr;
Py_ssize_t sql_cstr_len;
const char* p;

self->st = NULL;
self->in_use = 0;
Expand All @@ -73,25 +128,6 @@ int pysqlite_statement_create(pysqlite_Statement* self, pysqlite_Connection* con
Py_INCREF(sql);
self->sql = sql;

/* Determine if the statement is a DML statement.
SELECT is the only exception. See #9924. */
self->is_dml = 0;
for (p = sql_cstr; *p != 0; p++) {
switch (*p) {
case ' ':
case '\r':
case '\n':
case '\t':
continue;
}

self->is_dml = (PyOS_strnicmp(p, "insert", 6) == 0)
|| (PyOS_strnicmp(p, "update", 6) == 0)
|| (PyOS_strnicmp(p, "delete", 6) == 0)
|| (PyOS_strnicmp(p, "replace", 7) == 0);
break;
}

Py_BEGIN_ALLOW_THREADS
rc = sqlite3_prepare_v2(connection->db,
sql_cstr,
Expand All @@ -101,6 +137,7 @@ int pysqlite_statement_create(pysqlite_Statement* self, pysqlite_Connection* con
Py_END_ALLOW_THREADS

self->db = connection->db;
self->is_dml = pysqlite_statement_is_dml(self->st, sql_cstr);

if (rc == SQLITE_OK && pysqlite_check_remaining_sql(tail)) {
(void)sqlite3_finalize(self->st);
Expand Down