From e533567dbb4ea14a7e4862ebb4f55043cf5244d2 Mon Sep 17 00:00:00 2001 From: ankiaga Date: Thu, 23 Nov 2023 18:13:52 +0530 Subject: [PATCH] feat: Implementation for Begin and Rollback clientside statements --- .../client_side_statement_executor.py | 7 ++ .../client_side_statement_parser.py | 10 ++ google/cloud/spanner_dbapi/connection.py | 48 ++++++-- google/cloud/spanner_dbapi/cursor.py | 14 +-- .../cloud/spanner_dbapi/parsed_statement.py | 1 + tests/system/test_dbapi.py | 103 +++++++++++++++--- tests/unit/spanner_dbapi/test_connection.py | 50 ++++++++- tests/unit/spanner_dbapi/test_parse_utils.py | 6 + 8 files changed, 200 insertions(+), 39 deletions(-) diff --git a/google/cloud/spanner_dbapi/client_side_statement_executor.py b/google/cloud/spanner_dbapi/client_side_statement_executor.py index f65e8ada1a1..e75e3a611fc 100644 --- a/google/cloud/spanner_dbapi/client_side_statement_executor.py +++ b/google/cloud/spanner_dbapi/client_side_statement_executor.py @@ -22,8 +22,15 @@ def execute(connection, parsed_statement: ParsedStatement): It is an internal method that can make backwards-incompatible changes. + :type connection: Connection + :param connection: Connection object of the dbApi + :type parsed_statement: ParsedStatement :param parsed_statement: parsed_statement based on the sql query """ if parsed_statement.client_side_statement_type == ClientSideStatementType.COMMIT: return connection.commit() + if parsed_statement.client_side_statement_type == ClientSideStatementType.BEGIN: + return connection.begin() + if parsed_statement.client_side_statement_type == ClientSideStatementType.ROLLBACK: + return connection.rollback() diff --git a/google/cloud/spanner_dbapi/client_side_statement_parser.py b/google/cloud/spanner_dbapi/client_side_statement_parser.py index e93b71f3e1a..ce1474e809b 100644 --- a/google/cloud/spanner_dbapi/client_side_statement_parser.py +++ b/google/cloud/spanner_dbapi/client_side_statement_parser.py @@ -20,7 +20,9 @@ ClientSideStatementType, ) +RE_BEGIN = re.compile(r"^\s*(BEGIN|START)(TRANSACTION)?", re.IGNORECASE) RE_COMMIT = re.compile(r"^\s*(COMMIT)(TRANSACTION)?", re.IGNORECASE) +RE_ROLLBACK = re.compile(r"^\s*(ROLLBACK)(TRANSACTION)?", re.IGNORECASE) def parse_stmt(query): @@ -39,4 +41,12 @@ def parse_stmt(query): return ParsedStatement( StatementType.CLIENT_SIDE, query, ClientSideStatementType.COMMIT ) + if RE_BEGIN.match(query): + return ParsedStatement( + StatementType.CLIENT_SIDE, query, ClientSideStatementType.BEGIN + ) + if RE_ROLLBACK.match(query): + return ParsedStatement( + StatementType.CLIENT_SIDE, query, ClientSideStatementType.ROLLBACK + ) return None diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index efbdc80f3ff..2ebb7d4eab2 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -34,7 +34,9 @@ from google.rpc.code_pb2 import ABORTED -AUTOCOMMIT_MODE_WARNING = "This method is non-operational in autocommit mode" +TRANSACTION_NOT_BEGUN_WARNING = ( + "This method is non-operational as transaction has not begun" +) MAX_INTERNAL_RETRIES = 50 @@ -104,6 +106,7 @@ def __init__(self, instance, database=None, read_only=False): self._read_only = read_only self._staleness = None self.request_priority = None + self._transaction_begin_marked = False @property def autocommit(self): @@ -141,7 +144,7 @@ def inside_transaction(self): """Flag: transaction is started. Returns: - bool: True if transaction begun, False otherwise. + bool: True if transaction started, False otherwise. """ return ( self._transaction @@ -149,6 +152,15 @@ def inside_transaction(self): and not self._transaction.rolled_back ) + @property + def transaction_begun(self): + """Flag: transaction has begun + + Returns: + bool: True if transaction begun, False otherwise. + """ + return (not self._autocommit) or self._transaction_begin_marked + @property def instance(self): """Instance to which this connection relates. @@ -333,12 +345,10 @@ def transaction_checkout(self): Begin a new transaction, if there is no transaction in this connection yet. Return the begun one otherwise. - The method is non operational in autocommit mode. - :rtype: :class:`google.cloud.spanner_v1.transaction.Transaction` :returns: A Cloud Spanner transaction object, ready to use. """ - if not self.autocommit: + if self.transaction_begun: if not self.inside_transaction: self._transaction = self._session_checkout().transaction() self._transaction.begin() @@ -354,7 +364,7 @@ def snapshot_checkout(self): :rtype: :class:`google.cloud.spanner_v1.snapshot.Snapshot` :returns: A Cloud Spanner snapshot object, ready to use. """ - if self.read_only and not self.autocommit: + if self.read_only and self.transaction_begun: if not self._snapshot: self._snapshot = Snapshot( self._session_checkout(), multi_use=True, **self.staleness @@ -377,6 +387,22 @@ def close(self): self.is_closed = True + @check_not_closed + def begin(self): + """ + Marks the transaction as started. + + :raises: :class:`InterfaceError`: if this connection is closed. + :raises: :class:`OperationalError`: if there is an existing transaction that has begin or is running + """ + if self._transaction_begin_marked: + raise OperationalError("A transaction has already begun") + if self.inside_transaction: + raise OperationalError( + "Beginning a new transaction is not allowed when a transaction is already running" + ) + self._transaction_begin_marked = True + def commit(self): """Commits any pending transaction to the database. @@ -386,8 +412,8 @@ def commit(self): raise ValueError("Database needs to be passed for this operation") self._snapshot = None - if self._autocommit: - warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2) + if not self.transaction_begun: + warnings.warn(TRANSACTION_NOT_BEGUN_WARNING, UserWarning, stacklevel=2) return self.run_prior_DDL_statements() @@ -398,6 +424,7 @@ def commit(self): self._release_session() self._statements = [] + self._transaction_begin_marked = False except Aborted: self.retry_transaction() self.commit() @@ -410,14 +437,15 @@ def rollback(self): """ self._snapshot = None - if self._autocommit: - warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2) + if not self.transaction_begun: + warnings.warn(TRANSACTION_NOT_BEGUN_WARNING, UserWarning, stacklevel=2) elif self._transaction: if not self.read_only: self._transaction.rollback() self._release_session() self._statements = [] + self._transaction_begin_marked = False @check_not_closed def cursor(self): diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 95d20f5730d..790ed333f79 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -250,7 +250,7 @@ def execute(self, sql, args=None): ) if parsed_statement.statement_type == StatementType.DDL: self._batch_DDLs(sql) - if self.connection.autocommit: + if not self.connection.transaction_begun: self.connection.run_prior_DDL_statements() return @@ -264,7 +264,7 @@ def execute(self, sql, args=None): sql, args = sql_pyformat_args_to_spanner(sql, args or None) - if not self.connection.autocommit: + if self.connection.transaction_begun: statement = Statement( sql, args, @@ -348,7 +348,7 @@ def executemany(self, operation, seq_of_params): ) statements.append((sql, params, get_param_types(params))) - if self.connection.autocommit: + if self.connection.transaction_begun: self.connection.database.run_in_transaction( self._do_batch_update, statements, many_result_set ) @@ -396,7 +396,7 @@ def fetchone(self): sequence, or None when no more data is available.""" try: res = next(self) - if not self.connection.autocommit and not self.connection.read_only: + if self.connection.transaction_begun and not self.connection.read_only: self._checksum.consume_result(res) return res except StopIteration: @@ -414,7 +414,7 @@ def fetchall(self): res = [] try: for row in self: - if not self.connection.autocommit and not self.connection.read_only: + if self.connection.transaction_begun and not self.connection.read_only: self._checksum.consume_result(row) res.append(row) except Aborted: @@ -443,7 +443,7 @@ def fetchmany(self, size=None): for _ in range(size): try: res = next(self) - if not self.connection.autocommit and not self.connection.read_only: + if self.connection.transaction_begun and not self.connection.read_only: self._checksum.consume_result(res) items.append(res) except StopIteration: @@ -473,7 +473,7 @@ def _handle_DQL(self, sql, params): if self.connection.database is None: raise ValueError("Database needs to be passed for this operation") sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params) - if self.connection.read_only and not self.connection.autocommit: + if self.connection.read_only and self.connection.transaction_begun: # initiate or use the existing multi-use snapshot self._handle_DQL_with_snapshot( self.connection.snapshot_checkout(), sql, params diff --git a/google/cloud/spanner_dbapi/parsed_statement.py b/google/cloud/spanner_dbapi/parsed_statement.py index c36bc1d81cf..28705b69ed2 100644 --- a/google/cloud/spanner_dbapi/parsed_statement.py +++ b/google/cloud/spanner_dbapi/parsed_statement.py @@ -27,6 +27,7 @@ class StatementType(Enum): class ClientSideStatementType(Enum): COMMIT = 1 BEGIN = 2 + ROLLBACK = 3 @dataclass diff --git a/tests/system/test_dbapi.py b/tests/system/test_dbapi.py index ada21fef2cc..a114166696b 100644 --- a/tests/system/test_dbapi.py +++ b/tests/system/test_dbapi.py @@ -22,7 +22,7 @@ from google.cloud._helpers import UTC from google.cloud.spanner_dbapi.connection import Connection, connect -from google.cloud.spanner_dbapi.exceptions import ProgrammingError +from google.cloud.spanner_dbapi.exceptions import ProgrammingError, OperationalError from google.cloud.spanner_v1 import JsonObject from google.cloud.spanner_v1 import gapic_version as package_version from . import _helpers @@ -80,32 +80,28 @@ def init_connection(self, request, shared_instance, dbapi_database): self._cursor.close() self._conn.close() - @pytest.fixture - def execute_common_statements(self): + def _execute_common_statements(self, cursor): # execute several DML statements within one transaction - self._cursor.execute( + cursor.execute( """ INSERT INTO contacts (contact_id, first_name, last_name, email) VALUES (1, 'first-name', 'last-name', 'test.email@domen.ru') """ ) - self._cursor.execute( + cursor.execute( """ UPDATE contacts SET first_name = 'updated-first-name' WHERE first_name = 'first-name' """ ) - self._cursor.execute( + cursor.execute( """ UPDATE contacts SET email = 'test.email_updated@domen.ru' WHERE email = 'test.email@domen.ru' """ ) - - @pytest.fixture - def updated_row(self, execute_common_statements): return ( 1, "updated-first-name", @@ -113,9 +109,14 @@ def updated_row(self, execute_common_statements): "test.email_updated@domen.ru", ) - def test_commit(self, updated_row): + @pytest.mark.parametrize("client_side", [False, True]) + def test_commit(self, client_side): """Test committing a transaction with several statements.""" - self._conn.commit() + updated_row = self._execute_common_statements(self._cursor) + if client_side: + self._cursor.execute("""COMMIT""") + else: + self._conn.commit() # read the resulting data from the database self._cursor.execute("SELECT * FROM contacts") @@ -124,18 +125,80 @@ def test_commit(self, updated_row): assert got_rows == [updated_row] - def test_commit_client_side(self, updated_row): - """Test committing a transaction with several statements.""" - self._cursor.execute("""COMMIT""") + @pytest.mark.noautofixt + def test_begin_client_side(self, shared_instance, dbapi_database): + """Test beginning a transaction using client side statement, + where connection is in autocommit mode.""" + + conn1 = Connection(shared_instance, dbapi_database) + conn1.autocommit = True + cursor1 = conn1.cursor() + cursor1.execute("begin transaction") + updated_row = self._execute_common_statements(cursor1) + + # As the connection conn1 is not committed a new connection wont see its results + conn2 = Connection(shared_instance, dbapi_database) + cursor2 = conn2.cursor() + cursor2.execute("SELECT * FROM contacts") + conn2.commit() + got_rows = cursor2.fetchall() + assert got_rows != [updated_row] + + assert conn1._transaction_begin_marked is True + conn1.commit() + assert conn1._transaction_begin_marked is False + + # As the connection conn1 is committed a new connection should see its results + conn3 = Connection(shared_instance, dbapi_database) + cursor3 = conn3.cursor() + cursor3.execute("SELECT * FROM contacts") + conn3.commit() + got_rows = cursor3.fetchall() + assert got_rows == [updated_row] - # read the resulting data from the database + conn1.close() + conn2.close() + conn3.close() + cursor1.close() + cursor2.close() + cursor3.close() + + def test_begin_success_post_commit(self): + """Test beginning a new transaction post commiting an existing transaction + is possible on a connection, when connection is in autocommit mode.""" + want_row = (2, "first-name", "last-name", "test.email@domen.ru") + self._conn.autocommit = True + self._cursor.execute("begin transaction") + self._cursor.execute( + """ + INSERT INTO contacts (contact_id, first_name, last_name, email) + VALUES (2, 'first-name', 'last-name', 'test.email@domen.ru') + """ + ) + self._conn.commit() + + self._cursor.execute("begin transaction") self._cursor.execute("SELECT * FROM contacts") got_rows = self._cursor.fetchall() self._conn.commit() + assert got_rows == [want_row] - assert got_rows == [updated_row] + def test_begin_error_before_commit(self): + """Test beginning a new transaction before commiting an existing transaction is not possible on a connection, when connection is in autocommit mode.""" + self._conn.autocommit = True + self._cursor.execute("begin transaction") + self._cursor.execute( + """ + INSERT INTO contacts (contact_id, first_name, last_name, email) + VALUES (2, 'first-name', 'last-name', 'test.email@domen.ru') + """ + ) + + with pytest.raises(OperationalError): + self._cursor.execute("begin transaction") - def test_rollback(self): + @pytest.mark.parametrize("client_side", [False, True]) + def test_rollback(self, client_side): """Test rollbacking a transaction with several statements.""" want_row = (2, "first-name", "last-name", "test.email@domen.ru") @@ -162,7 +225,11 @@ def test_rollback(self): WHERE email = 'test.email@domen.ru' """ ) - self._conn.rollback() + + if client_side: + self._cursor.execute("ROLLBACK") + else: + self._conn.rollback() # read the resulting data from the database self._cursor.execute("SELECT * FROM contacts") diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 1628f840624..2cf2b828d32 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -280,7 +280,7 @@ def test_close(self, mock_client): @mock.patch.object(warnings, "warn") def test_commit(self, mock_warn): from google.cloud.spanner_dbapi import Connection - from google.cloud.spanner_dbapi.connection import AUTOCOMMIT_MODE_WARNING + from google.cloud.spanner_dbapi.connection import TRANSACTION_NOT_BEGUN_WARNING connection = Connection(INSTANCE, DATABASE) @@ -307,7 +307,7 @@ def test_commit(self, mock_warn): connection._autocommit = True connection.commit() mock_warn.assert_called_once_with( - AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2 + TRANSACTION_NOT_BEGUN_WARNING, UserWarning, stacklevel=2 ) def test_commit_database_error(self): @@ -321,7 +321,7 @@ def test_commit_database_error(self): @mock.patch.object(warnings, "warn") def test_rollback(self, mock_warn): from google.cloud.spanner_dbapi import Connection - from google.cloud.spanner_dbapi.connection import AUTOCOMMIT_MODE_WARNING + from google.cloud.spanner_dbapi.connection import TRANSACTION_NOT_BEGUN_WARNING connection = Connection(INSTANCE, DATABASE) @@ -348,7 +348,7 @@ def test_rollback(self, mock_warn): connection._autocommit = True connection.rollback() mock_warn.assert_called_once_with( - AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2 + TRANSACTION_NOT_BEGUN_WARNING, UserWarning, stacklevel=2 ) @mock.patch("google.cloud.spanner_v1.database.Database", autospec=True) @@ -385,6 +385,48 @@ def test_as_context_manager(self): self.assertTrue(connection.is_closed) + def test_begin_cursor_closed(self): + from google.cloud.spanner_dbapi.exceptions import InterfaceError + + connection = self._make_connection() + connection.close() + + with self.assertRaises(InterfaceError): + connection.begin() + + self.assertEqual(connection._transaction_begin_marked, False) + + def test_begin_transaction_begin_marked(self): + from google.cloud.spanner_dbapi.exceptions import OperationalError + + connection = self._make_connection() + connection._transaction_begin_marked = True + + with self.assertRaises(OperationalError): + connection.begin() + + self.assertEqual(connection._transaction_begin_marked, False) + + def test_begin_inside_transaction(self): + from google.cloud.spanner_dbapi.exceptions import OperationalError + + connection = self._make_connection() + mock_transaction = mock.MagicMock() + mock_transaction.committed = mock_transaction.rolled_back = False + connection._transaction = mock_transaction + + with self.assertRaises(OperationalError): + connection.begin() + + self.assertEqual(connection._transaction_begin_marked, False) + + def test_begin(self): + connection = self._make_connection() + + connection.begin() + + self.assertEqual(connection._transaction_begin_marked, True) + def test_run_statement_wo_retried(self): """Check that Connection remembers executed statements.""" from google.cloud.spanner_dbapi.checksum import ResultsChecksum diff --git a/tests/unit/spanner_dbapi/test_parse_utils.py b/tests/unit/spanner_dbapi/test_parse_utils.py index 162535349fc..06819c3a3d6 100644 --- a/tests/unit/spanner_dbapi/test_parse_utils.py +++ b/tests/unit/spanner_dbapi/test_parse_utils.py @@ -53,6 +53,12 @@ def test_classify_stmt(self): ("CREATE ROLE parent", StatementType.DDL), ("commit", StatementType.CLIENT_SIDE), (" commit TRANSACTION ", StatementType.CLIENT_SIDE), + ("begin", StatementType.CLIENT_SIDE), + ("start", StatementType.CLIENT_SIDE), + ("begin transaction", StatementType.CLIENT_SIDE), + ("start transaction", StatementType.CLIENT_SIDE), + ("rollback", StatementType.CLIENT_SIDE), + (" rollback TRANSACTION ", StatementType.CLIENT_SIDE), ("GRANT SELECT ON TABLE Singers TO ROLE parent", StatementType.DDL), ("REVOKE SELECT ON TABLE Singers TO ROLE parent", StatementType.DDL), ("GRANT ROLE parent TO ROLE child", StatementType.DDL),