diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index 91b63a2da1..9fa2269eae 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -20,6 +20,7 @@ from google.api_core.exceptions import Aborted from google.api_core.gapic_v1.client_info import ClientInfo from google.cloud import spanner_v1 as spanner +from google.cloud.spanner_v1 import RequestOptions from google.cloud.spanner_v1.session import _get_retry_delay from google.cloud.spanner_v1.snapshot import Snapshot @@ -103,6 +104,7 @@ def __init__(self, instance, database, read_only=False): self._own_pool = True self._read_only = read_only self._staleness = None + self.request_priority = None @property def autocommit(self): @@ -442,11 +444,18 @@ def run_statement(self, statement, retried=False): ResultsChecksum() if retried else statement.checksum, ) + if self.request_priority is not None: + req_opts = RequestOptions(priority=self.request_priority) + self.request_priority = None + else: + req_opts = None + return ( transaction.execute_sql( statement.sql, statement.params, param_types=statement.param_types, + request_options=req_opts, ), ResultsChecksum() if retried else statement.checksum, ) diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index e15f6af33b..23fc098afc 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -883,6 +883,42 @@ def test_staleness_single_use_readonly_autocommit(self): connection.database.snapshot.assert_called_with(read_timestamp=timestamp) + def test_request_priority(self): + from google.cloud.spanner_dbapi.checksum import ResultsChecksum + from google.cloud.spanner_dbapi.cursor import Statement + from google.cloud.spanner_v1 import RequestOptions + + sql = "SELECT 1" + params = [] + param_types = {} + priority = 2 + + connection = self._make_connection() + connection._transaction = mock.Mock(committed=False, rolled_back=False) + connection._transaction.execute_sql = mock.Mock() + + connection.request_priority = priority + + req_opts = RequestOptions(priority=priority) + + connection.run_statement( + Statement(sql, params, param_types, ResultsChecksum(), False) + ) + + connection._transaction.execute_sql.assert_called_with( + sql, params, param_types=param_types, request_options=req_opts + ) + assert connection.request_priority is None + + # check that priority is applied for only one request + connection.run_statement( + Statement(sql, params, param_types, ResultsChecksum(), False) + ) + + connection._transaction.execute_sql.assert_called_with( + sql, params, param_types=param_types, request_options=None + ) + def exit_ctx_func(self, exc_type, exc_value, traceback): """Context __exit__ method mock."""