Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Removing ClientSideStatementParamKey enum
  • Loading branch information
ankiaga committed Jan 4, 2024
commit fd8db526b8bc679da37b47998751513854b0905a
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from google.cloud.spanner_dbapi.parsed_statement import (
ParsedStatement,
ClientSideStatementType,
ClientSideStatementParamKey,
)
from google.cloud.spanner_v1 import (
Type,
Expand Down Expand Up @@ -102,9 +101,7 @@ def execute(cursor: "Cursor", parsed_statement: ParsedStatement):
)
if statement_type == ClientSideStatementType.RUN_PARTITION:
return connection.run_partition(
parsed_statement.client_side_statement_params[
ClientSideStatementParamKey.PARTITION_ID
]
parsed_statement.client_side_statement_params[0]
)


Expand Down
11 changes: 3 additions & 8 deletions google/cloud/spanner_dbapi/client_side_statement_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
StatementType,
ClientSideStatementType,
Statement,
ClientSideStatementParamKey,
)

RE_BEGIN = re.compile(r"^\s*(BEGIN|START)(TRANSACTION)?", re.IGNORECASE)
Expand Down Expand Up @@ -51,7 +50,7 @@ def parse_stmt(query):
:returns: ParsedStatement object.
"""
client_side_statement_type = None
client_side_statement_params = {}
client_side_statement_params = []
if RE_COMMIT.match(query):
client_side_statement_type = ClientSideStatementType.COMMIT
if RE_BEGIN.match(query):
Expand All @@ -70,15 +69,11 @@ def parse_stmt(query):
client_side_statement_type = ClientSideStatementType.ABORT_BATCH
if RE_PARTITION_QUERY.match(query):
match = re.search(RE_PARTITION_QUERY, query)
client_side_statement_params[
ClientSideStatementParamKey.PARTITIONED_SQL_QUERY
] = match.group(2)
client_side_statement_params.append(match.group(2))
client_side_statement_type = ClientSideStatementType.PARTITION_QUERY
if RE_RUN_PARTITION.match(query):
match = re.search(RE_RUN_PARTITION, query)
client_side_statement_params[
ClientSideStatementParamKey.PARTITION_ID
] = match.group(3)
client_side_statement_params.append(match.group(3))
client_side_statement_type = ClientSideStatementType.RUN_PARTITION
if client_side_statement_type is not None:
return ParsedStatement(
Expand Down
5 changes: 1 addition & 4 deletions google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
ParsedStatement,
Statement,
StatementType,
ClientSideStatementParamKey,
)
from google.cloud.spanner_dbapi.partition_helper import PartitionId
from google.cloud.spanner_v1 import RequestOptions
Expand Down Expand Up @@ -600,9 +599,7 @@ def partition_query(
query_options=None,
):
statement = parsed_statement.statement
partitioned_query = parsed_statement.client_side_statement_params[
ClientSideStatementParamKey.PARTITIONED_SQL_QUERY
]
partitioned_query = parsed_statement.client_side_statement_params[0]
if _get_statement_type(Statement(partitioned_query)) is not StatementType.QUERY:
raise ProgrammingError(
"Only queries can be partitioned. Invalid statement: " + statement.sql
Expand Down
9 changes: 2 additions & 7 deletions google/cloud/spanner_dbapi/parsed_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict
from typing import Any, List

from google.cloud.spanner_dbapi.checksum import ResultsChecksum

Expand All @@ -39,11 +39,6 @@ class ClientSideStatementType(Enum):
RUN_PARTITION = 10


class ClientSideStatementParamKey(Enum):
PARTITIONED_SQL_QUERY = 1
PARTITION_ID = 2


@dataclass
class Statement:
sql: str
Expand All @@ -60,4 +55,4 @@ class ParsedStatement:
statement_type: StatementType
statement: Statement
client_side_statement_type: ClientSideStatementType = None
client_side_statement_params: Dict[ClientSideStatementParamKey, Any] = None
client_side_statement_params: List[Any] = None
7 changes: 2 additions & 5 deletions tests/unit/spanner_dbapi/test_parse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
ParsedStatement,
Statement,
ClientSideStatementType,
ClientSideStatementParamKey,
)
from google.cloud.spanner_v1 import param_types
from google.cloud.spanner_v1 import JsonObject
Expand Down Expand Up @@ -86,9 +85,7 @@ def test_partition_query_classify_stmt(self):
StatementType.CLIENT_SIDE,
Statement("PARTITION SELECT s.SongName FROM Songs AS s"),
ClientSideStatementType.PARTITION_QUERY,
{
ClientSideStatementParamKey.PARTITIONED_SQL_QUERY: "SELECT s.SongName FROM Songs AS s"
},
["SELECT s.SongName FROM Songs AS s"],
),
)

Expand All @@ -100,7 +97,7 @@ def test_run_partition_classify_stmt(self):
StatementType.CLIENT_SIDE,
Statement("RUN PARTITION bj2bjb2j2bj2ebbh"),
ClientSideStatementType.RUN_PARTITION,
{ClientSideStatementParamKey.PARTITION_ID: "bj2bjb2j2bj2ebbh"},
["bj2bjb2j2bj2ebbh"],
),
)

Expand Down