feat: Implementation of run partition query #1080

Merged · 5 commits · Jan 24, 2024
Changes from 2 commits
@@ -103,6 +103,8 @@ def execute(cursor: "Cursor", parsed_statement: ParsedStatement):
return connection.run_partition(
parsed_statement.client_side_statement_params[0]
)
if statement_type == ClientSideStatementType.RUN_PARTITIONED_QUERY:
return connection.run_partitioned_query(parsed_statement)


def _get_streamed_result_set(column_name, type_code, column_values):
27 changes: 17 additions & 10 deletions google/cloud/spanner_dbapi/client_side_statement_parser.py
@@ -35,6 +35,9 @@
RE_ABORT_BATCH = re.compile(r"^\s*(ABORT)\s+(BATCH)", re.IGNORECASE)
RE_PARTITION_QUERY = re.compile(r"^\s*(PARTITION)\s+(.+)", re.IGNORECASE)
RE_RUN_PARTITION = re.compile(r"^\s*(RUN)\s+(PARTITION)\s+(.+)", re.IGNORECASE)
RE_RUN_PARTITIONED_QUERY = re.compile(
r"^\s*(RUN)\s+(PARTITIONED)\s+(QUERY)\s+(.+)", re.IGNORECASE
)


def parse_stmt(query):
@@ -53,25 +56,29 @@ def parse_stmt(query):
client_side_statement_params = []
if RE_COMMIT.match(query):
client_side_statement_type = ClientSideStatementType.COMMIT
if RE_BEGIN.match(query):
client_side_statement_type = ClientSideStatementType.BEGIN
if RE_ROLLBACK.match(query):
elif RE_ROLLBACK.match(query):
client_side_statement_type = ClientSideStatementType.ROLLBACK
if RE_SHOW_COMMIT_TIMESTAMP.match(query):
elif RE_SHOW_COMMIT_TIMESTAMP.match(query):
client_side_statement_type = ClientSideStatementType.SHOW_COMMIT_TIMESTAMP
if RE_SHOW_READ_TIMESTAMP.match(query):
elif RE_SHOW_READ_TIMESTAMP.match(query):
client_side_statement_type = ClientSideStatementType.SHOW_READ_TIMESTAMP
if RE_START_BATCH_DML.match(query):
elif RE_START_BATCH_DML.match(query):
client_side_statement_type = ClientSideStatementType.START_BATCH_DML
if RE_RUN_BATCH.match(query):
elif RE_BEGIN.match(query):
client_side_statement_type = ClientSideStatementType.BEGIN
elif RE_RUN_BATCH.match(query):
client_side_statement_type = ClientSideStatementType.RUN_BATCH
if RE_ABORT_BATCH.match(query):
elif RE_ABORT_BATCH.match(query):
client_side_statement_type = ClientSideStatementType.ABORT_BATCH
if RE_PARTITION_QUERY.match(query):
elif RE_RUN_PARTITIONED_QUERY.match(query):
match = re.search(RE_RUN_PARTITIONED_QUERY, query)
client_side_statement_params.append(match.group(4))
client_side_statement_type = ClientSideStatementType.RUN_PARTITIONED_QUERY
elif RE_PARTITION_QUERY.match(query):
match = re.search(RE_PARTITION_QUERY, query)
client_side_statement_params.append(match.group(2))
client_side_statement_type = ClientSideStatementType.PARTITION_QUERY
if RE_RUN_PARTITION.match(query):
elif RE_RUN_PARTITION.match(query):
match = re.search(RE_RUN_PARTITION, query)
client_side_statement_params.append(match.group(3))
client_side_statement_type = ClientSideStatementType.RUN_PARTITION
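
For context, a minimal standalone sketch (not part of this PR) of how the new RE_RUN_PARTITIONED_QUERY pattern matches a statement and why the parser reads group(4); the sample SQL is illustrative only:

import re

RE_RUN_PARTITIONED_QUERY = re.compile(
    r"^\s*(RUN)\s+(PARTITIONED)\s+(QUERY)\s+(.+)", re.IGNORECASE
)

# Groups 1-3 capture the keywords; group 4 captures the SQL that will be partitioned.
statement = "run partitioned query SELECT id, name FROM Singers"
match = RE_RUN_PARTITIONED_QUERY.match(statement)
print(match.group(4))  # -> "SELECT id, name FROM Singers"
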
40 changes: 28 additions & 12 deletions google/cloud/spanner_dbapi/connection.py
@@ -511,15 +511,7 @@ def partition_query(
):
statement = parsed_statement.statement
partitioned_query = parsed_statement.client_side_statement_params[0]
if _get_statement_type(Statement(partitioned_query)) is not StatementType.QUERY:
raise ProgrammingError(
"Only queries can be partitioned. Invalid statement: " + statement.sql
)
if self.read_only is not True and self._client_transaction_started is True:
raise ProgrammingError(
"Partitioned query not supported as the connection is not in "
"read only mode or ReadWrite transaction started"
)
self._partitioned_query_validation(partitioned_query, statement)

batch_snapshot = self._database.batch_snapshot()
partition_ids = []
@@ -531,17 +523,18 @@ def partition_query(
query_options=query_options,
)
)

batch_transaction_id = batch_snapshot.get_batch_transaction_id()
for partition in partitions:
batch_transaction_id = batch_snapshot.get_batch_transaction_id()
partition_ids.append(
partition_helper.encode_to_string(batch_transaction_id, partition)
)
return partition_ids

@check_not_closed
def run_partition(self, batch_transaction_id):
def run_partition(self, encoded_partition_id):
partition_id: PartitionId = partition_helper.decode_from_string(
batch_transaction_id
encoded_partition_id
)
batch_transaction_id = partition_id.batch_transaction_id
batch_snapshot = self._database.batch_snapshot(
@@ -551,6 +544,29 @@ def run_partition(self, batch_transaction_id):
)
return batch_snapshot.process(partition_id.partition_result)

@check_not_closed
def run_partitioned_query(
self,
parsed_statement: ParsedStatement,
):
statement = parsed_statement.statement
partitioned_query = parsed_statement.client_side_statement_params[0]
self._partitioned_query_validation(partitioned_query, statement)
batch_snapshot = self._database.batch_snapshot()
return batch_snapshot.run_partitioned_query(
partitioned_query, statement.params, statement.param_types
)

def _partitioned_query_validation(self, partitioned_query, statement):
if _get_statement_type(Statement(partitioned_query)) is not StatementType.QUERY:
raise ProgrammingError(
"Only queries can be partitioned. Invalid statement: " + statement.sql
)
if self.read_only is not True and self._client_transaction_started is True:
raise ProgrammingError(
"Partitioned query is not supported, because the connection is in a read/write transaction."
)

def __enter__(self):
return self

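
To show how these connection changes surface through the DB API, here is a hedged usage sketch (instance and database names are placeholders, and a read-only connection is assumed so that _partitioned_query_validation passes):

from google.cloud.spanner_dbapi import connect

# Placeholder resource names.
connection = connect("my-instance", "my-database")
# Partitioned queries require read-only mode or no active read/write transaction.
connection.read_only = True

cursor = connection.cursor()
# Parsed as a RUN_PARTITIONED_QUERY client-side statement; the cursor then consumes the
# MergedResultSet returned by Connection.run_partitioned_query().
cursor.execute("RUN PARTITIONED QUERY SELECT id, name FROM Singers")
rows = cursor.fetchall()

The two-step flow earlier in this file (PARTITION <sql>, then RUN PARTITION <token>) instead returns encoded partition tokens, so each partition can be executed separately, possibly from another process.
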
5 changes: 4 additions & 1 deletion google/cloud/spanner_dbapi/cursor.py
@@ -49,6 +49,7 @@
from google.cloud.spanner_dbapi.transaction_helper import CursorStatementType
from google.cloud.spanner_dbapi.utils import PeekIterator
from google.cloud.spanner_dbapi.utils import StreamedManyResultSets
from google.cloud.spanner_v1.merge_result_set import MergedResultSet

ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"])

@@ -248,7 +249,9 @@ def _execute(self, sql, args=None, call_from_execute_many=False):
self, self._parsed_statement
)
if self._result_set is not None:
if isinstance(self._result_set, StreamedManyResultSets):
if isinstance(
self._result_set, StreamedManyResultSets
) or isinstance(self._result_set, MergedResultSet):
self._itr = self._result_set
else:
self._itr = PeekIterator(self._result_set)
1 change: 1 addition & 0 deletions google/cloud/spanner_dbapi/parsed_statement.py
@@ -35,6 +35,7 @@ class ClientSideStatementType(Enum):
ABORT_BATCH = 8
PARTITION_QUERY = 9
RUN_PARTITION = 10
RUN_PARTITIONED_QUERY = 11


@dataclass
71 changes: 71 additions & 0 deletions google/cloud/spanner_v1/database.py
@@ -54,6 +54,7 @@
from google.cloud.spanner_v1.batch import Batch
from google.cloud.spanner_v1.batch import MutationGroups
from google.cloud.spanner_v1.keyset import KeySet
from google.cloud.spanner_v1.merge_result_set import MergedResultSet
from google.cloud.spanner_v1.pool import BurstyPool
from google.cloud.spanner_v1.pool import SessionCheckout
from google.cloud.spanner_v1.session import Session
@@ -1513,6 +1514,76 @@ def process_query_batch(
partition=batch["partition"], **batch["query"], retry=retry, timeout=timeout
)

def run_partitioned_query(
self,
sql,
params=None,
param_types=None,
partition_size_bytes=None,
max_partitions=None,
query_options=None,
data_boost_enabled=False,
):
"""Start a partitioned query operation to get list of partitions and
then executes each partition on a separate thread

:type sql: str
:param sql: SQL query statement

:type params: dict, {str -> column value}
:param params: values for parameter replacement. Keys must match
the names used in ``sql``.

:type param_types: dict[str -> Union[dict, .types.Type]]
:param param_types:
(Optional) maps explicit types for one or more param values;
required if parameters are passed.

:type partition_size_bytes: int
:param partition_size_bytes:
(Optional) desired size for each partition generated. The service
uses this as a hint, the actual partition size may differ.

:type max_partitions: int
:param max_partitions:
(Optional) desired maximum number of partitions generated. The
service uses this as a hint, the actual number of partitions may
differ.

:type query_options:
:class:`~google.cloud.spanner_v1.types.ExecuteSqlRequest.QueryOptions`
or :class:`dict`
:param query_options:
(Optional) Query optimizer configuration to use for the given query.
If a dict is provided, it must be of the same form as the protobuf
message :class:`~google.cloud.spanner_v1.types.QueryOptions`

:type data_boost_enabled: bool
:param data_boost_enabled:
(Optional) If this is for a partitioned query and this field is
set ``true``, the request will be executed via offline access.

:rtype: :class:`~google.cloud.spanner_v1.merge_result_set.MergedResultSet`
:returns: a result set instance which can be used to consume rows.
"""
partitions = list(
self.generate_query_batches(
sql,
params,
param_types,
partition_size_bytes,
max_partitions,
query_options,
data_boost_enabled,
)
)
return MergedResultSet(self, partitions, 0)

def process(self, batch):
"""Process a single, partitioned query or read.

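
As a rough usage sketch of the new BatchSnapshot.run_partitioned_query helper (resource names and the query are placeholders; parameter binding follows the existing generate_query_batches conventions):

from google.cloud import spanner
from google.cloud.spanner_v1 import param_types

client = spanner.Client()
database = client.instance("my-instance").database("my-database")

# run_partitioned_query partitions the SQL, runs each partition on a worker thread,
# and merges the rows into a single iterable result set.
batch_snapshot = database.batch_snapshot()
results = batch_snapshot.run_partitioned_query(
    "SELECT SingerId, FirstName FROM Singers WHERE Active = @active",
    params={"active": True},
    param_types={"active": param_types.BOOL},
)
for row in results:  # MergedResultSet: rows arrive in no guaranteed order
    print(row)
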
133 changes: 133 additions & 0 deletions google/cloud/spanner_v1/merge_result_set.py
@@ -0,0 +1,133 @@
# Copyright 2024 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from queue import Queue
from typing import Any, TYPE_CHECKING
from threading import Lock, Semaphore

if TYPE_CHECKING:
from google.cloud.spanner_v1.database import BatchSnapshot

QUEUE_SIZE_PER_WORKER = 32
MAX_PARALLELISM = 100
METADATA_LOCK = Lock()
Contributor:

This is now a global lock, right? Would it be possible to make it an instance variable of the MergedResultSet class?

Contributor Author:

Is there a reason we would want it to be an instance variable of MergedResultSet? It is not used by the MergedResultSet class itself, only by the static _set_metadata() method.

Contributor:

This means that if you have multiple MergedResultSets open, they will block each other, which is not necessary. A more correct design would be to have the lock as an instance variable and _set_metadata as an instance method. The goal of this lock is to prevent multiple threads from setting/reading the metadata field of a specific MergedResultSet at the same time, not to prevent different MergedResultSets from setting their respective metadata fields.

I agree that it is a bit theoretical, as it is unlikely that a user will have a large number of MergedResultSets open at the same time, but putting the lock where it is actually needed makes the code easier to read and understand. As written, it looks as if it would be a problem for two different MergedResultSets to set their metadata at the same time.

Contributor Author:

SG, changed.



def _set_metadata(merged_result_set, results):
METADATA_LOCK.acquire()
try:
merged_result_set._metadata = results.metadata
finally:
METADATA_LOCK.release()
merged_result_set.metadata_semaphore.release()
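
A rough sketch of the instance-level design the reviewer asks for above (an assumption about the follow-up commit, not the code in this diff):

from threading import Lock, Semaphore


class MergedResultSetSketch:
    """Illustration only: per-instance lock instead of the module-level METADATA_LOCK."""

    def __init__(self):
        self._metadata = None
        self._metadata_lock = Lock()            # owned by this result set only
        self.metadata_semaphore = Semaphore(0)

    def _set_metadata(self, results):
        # Only threads working on this particular MergedResultSet contend on the lock.
        with self._metadata_lock:
            if self._metadata is None:
                self._metadata = results.metadata
        self.metadata_semaphore.release()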


class PartitionExecutor:
"""
Executor that executes a single partition on a separate thread and inserts
the rows into the merged result set's queue.
"""

def __init__(self, batch_snapshot, partition_id, merged_result_set):
self._batch_snapshot: BatchSnapshot = batch_snapshot
self._partition_id = partition_id
self._merged_result_set: MergedResultSet = merged_result_set
self._queue: Queue[PartitionExecutorResult] = merged_result_set._queue

def run(self):
try:
results = self._batch_snapshot.process_query_batch(self._partition_id)
merged_result_set = self._merged_result_set
for row in results:
if merged_result_set._metadata is None:
_set_metadata(merged_result_set, results)
self._queue.put(PartitionExecutorResult(data=row))
# Special case: The result set did not return any rows.
# Push the metadata to the merged result set.
if merged_result_set._metadata is None:
_set_metadata(merged_result_set, results)
except Exception as ex:
self._queue.put(PartitionExecutorResult(exception=ex))
Contributor:

We should also call _set_metadata here (if it has not already been set) to prevent the metadata property from blocking indefinitely if someone calls it after an error has occurred. If, for example, the query fails for all partitions, the user will get a MergedResultSet that returns an error whenever they try to iterate over the rows, but that hangs forever if they call metadata. The latter is always very hard to debug, so whenever possible we should return an error instead of blocking when something goes wrong.

Contributor Author:

Done

finally:
# Emit a special 'is_last' result to ensure that the MergedResultSet
# is not blocked on a queue that never receives any more results.
self._queue.put(PartitionExecutorResult(is_last=True))
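
Tying this to the "Done" above, one way the error path could look so that metadata readers never block (a sketch only, assuming `results = None` is initialized before the try block so the name is always bound):

        except Exception as ex:
            if self._merged_result_set._metadata is None:
                if results is not None:
                    _set_metadata(self._merged_result_set, results)
                else:
                    # No result set to take metadata from; release the semaphore anyway so
                    # MergedResultSet.__init__ (and the metadata property) do not hang.
                    self._merged_result_set.metadata_semaphore.release()
            self._queue.put(PartitionExecutorResult(exception=ex))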


@dataclass
class PartitionExecutorResult:
data: Any = None
exception: Exception = None
is_last: bool = False


class MergedResultSet:
"""
Executes multiple partitions on different threads and then combines the
results from multiple queries using a synchronized queue. The order of the
records in the MergedResultSet is not guaranteed.
"""

def __init__(self, batch_snapshot, partition_ids, max_parallelism):
self._exception = None
self.metadata_semaphore = Semaphore(0)

partition_ids_count = len(partition_ids)
self._finished_count_down_latch = partition_ids_count
parallelism = min(MAX_PARALLELISM, partition_ids_count)
if max_parallelism != 0:
parallelism = min(partition_ids_count, max_parallelism)
self._queue = Queue(maxsize=QUEUE_SIZE_PER_WORKER * parallelism)

partition_executors = []
for partition_id in partition_ids:
partition_executors.append(
PartitionExecutor(batch_snapshot, partition_id, self)
)
executor = ThreadPoolExecutor(max_workers=parallelism)
for partition_executor in partition_executors:
executor.submit(partition_executor.run)
executor.shutdown(False)

self._metadata = None
# This will make sure that _metadata is set
self.metadata_semaphore.acquire()

def __iter__(self):
return self

def __next__(self):
if self._exception is not None:
raise self._exception
while True:
partition_result = self._queue.get()
if partition_result.is_last:
self._finished_count_down_latch -= 1
if self._finished_count_down_latch == 0:
raise StopIteration
elif partition_result.exception is not None:
self._exception = partition_result.exception
raise self._exception
else:
return partition_result.data

@property
def metadata(self):
return self._metadata

@property
def stats(self):
# TODO: Implement
return None
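
The shutdown logic in __next__ is essentially a count-down latch over is_last sentinels, one per partition worker. A small self-contained illustration of that pattern (the names here are made up and independent of the Spanner classes):

from dataclasses import dataclass
from queue import Queue
from typing import Any


@dataclass
class _Item:
    data: Any = None
    is_last: bool = False


def drain(queue, worker_count):
    """Yield data items until every worker has pushed its is_last sentinel."""
    remaining = worker_count
    while remaining > 0:
        item = queue.get()
        if item.is_last:
            remaining -= 1  # one worker finished; stop once all have
        else:
            yield item.data


q = Queue()
for item in (_Item(data=1), _Item(data=2), _Item(is_last=True),
             _Item(data=3), _Item(is_last=True)):
    q.put(item)
print(list(drain(q, worker_count=2)))  # -> [1, 2, 3]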