diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 14e59df67..aa628441b 100644 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -17,7 +17,7 @@ logger = logging.getLogger(__name__) -DEFAULT_RESULT_BUFFER_SIZE_BYTES = 10485760 +DEFAULT_RESULT_BUFFER_SIZE_BYTES = 104857600 DEFAULT_ARRAY_SIZE = 100000 @@ -153,6 +153,8 @@ def read(self) -> Optional[OAuthToken]: # _use_arrow_native_timestamps # Databricks runtime will return native Arrow types for timestamps instead of Arrow strings # (True by default) + # use_cloud_fetch + # Enable use of cloud fetch to extract large query results in parallel via cloud storage if access_token: access_token_kv = {"access_token": access_token} @@ -189,6 +191,7 @@ def read(self) -> Optional[OAuthToken]: self._session_handle = self.thrift_backend.open_session( session_configuration, catalog, schema ) + self.use_cloud_fetch = kwargs.get("use_cloud_fetch", False) self.open = True logger.info("Successfully opened session " + str(self.get_session_id_hex())) self._cursors = [] # type: List[Cursor] @@ -497,6 +500,7 @@ def execute( max_bytes=self.buffer_size_bytes, lz4_compression=self.connection.lz4_compression, cursor=self, + use_cloud_fetch=self.connection.use_cloud_fetch, ) self.active_result_set = ResultSet( self.connection, @@ -822,6 +826,7 @@ def __iter__(self): break def _fill_results_buffer(self): + # At initialization or if the server does not have cloud fetch result links available results, has_more_rows = self.thrift_backend.fetch_results( op_handle=self.command_id, max_rows=self.arraysize, diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py index aac3ac336..9a997f393 100644 --- a/src/databricks/sql/cloudfetch/download_manager.py +++ b/src/databricks/sql/cloudfetch/download_manager.py @@ -161,6 +161,6 @@ def _check_if_download_successful(self, handler: ResultSetDownloadHandler): return True def _shutdown_manager(self): - # 
Clear download handlers and shutdown the thread pool to cancel pending futures + # Clear download handlers and shutdown the thread pool self.download_handlers = [] - self.thread_pool.shutdown(wait=False, cancel_futures=True) + self.thread_pool.shutdown(wait=False) diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index 7756c56a1..ef225d1f5 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -5,7 +5,6 @@ import time import uuid import threading -import lz4.frame from ssl import CERT_NONE, CERT_REQUIRED, create_default_context from typing import List, Union @@ -26,11 +25,14 @@ ) from databricks.sql.utils import ( - ArrowQueue, ExecuteResponse, _bound, RequestErrorInfo, NoRetryReason, + ResultSetQueueFactory, + convert_arrow_based_set_to_arrow_table, + convert_decimals_in_arrow_table, + convert_column_based_set_to_arrow_table, ) logger = logging.getLogger(__name__) @@ -67,7 +69,6 @@ class ThriftBackend: CLOSED_OP_STATE = ttypes.TOperationState.CLOSED_STATE ERROR_OP_STATE = ttypes.TOperationState.ERROR_STATE - BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128] def __init__( self, @@ -115,6 +116,8 @@ def __init__( # _socket_timeout # The timeout in seconds for socket send, recv and connect operations. Should be a positive float or integer. # (defaults to 900) + # max_download_threads + # Number of threads for handling cloud fetch downloads. 
Defaults to 10 port = port or 443 if kwargs.get("_connection_uri"): @@ -136,6 +139,9 @@ def __init__( "_use_arrow_native_timestamps", True ) + # Cloud fetch + self.max_download_threads = kwargs.get("max_download_threads", 10) + # Configure tls context ssl_context = create_default_context(cafile=kwargs.get("_tls_trusted_ca_file")) if kwargs.get("_tls_no_verify") is True: @@ -558,108 +564,14 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti ( arrow_table, num_rows, - ) = ThriftBackend._convert_column_based_set_to_arrow_table( - t_row_set.columns, description - ) + ) = convert_column_based_set_to_arrow_table(t_row_set.columns, description) elif t_row_set.arrowBatches is not None: - ( - arrow_table, - num_rows, - ) = ThriftBackend._convert_arrow_based_set_to_arrow_table( + (arrow_table, num_rows,) = convert_arrow_based_set_to_arrow_table( t_row_set.arrowBatches, lz4_compressed, schema_bytes ) else: raise OperationalError("Unsupported TRowSet instance {}".format(t_row_set)) - return self._convert_decimals_in_arrow_table(arrow_table, description), num_rows - - @staticmethod - def _convert_decimals_in_arrow_table(table, description): - for (i, col) in enumerate(table.itercolumns()): - if description[i][1] == "decimal": - decimal_col = col.to_pandas().apply( - lambda v: v if v is None else Decimal(v) - ) - precision, scale = description[i][4], description[i][5] - assert scale is not None - assert precision is not None - # Spark limits decimal to a maximum scale of 38, - # so 128 is guaranteed to be big enough - dtype = pyarrow.decimal128(precision, scale) - col_data = pyarrow.array(decimal_col, type=dtype) - field = table.field(i).with_type(dtype) - table = table.set_column(i, field, col_data) - return table - - @staticmethod - def _convert_arrow_based_set_to_arrow_table( - arrow_batches, lz4_compressed, schema_bytes - ): - ba = bytearray() - ba += schema_bytes - n_rows = 0 - if lz4_compressed: - for arrow_batch in arrow_batches: - n_rows 
+= arrow_batch.rowCount - ba += lz4.frame.decompress(arrow_batch.batch) - else: - for arrow_batch in arrow_batches: - n_rows += arrow_batch.rowCount - ba += arrow_batch.batch - arrow_table = pyarrow.ipc.open_stream(ba).read_all() - return arrow_table, n_rows - - @staticmethod - def _convert_column_based_set_to_arrow_table(columns, description): - arrow_table = pyarrow.Table.from_arrays( - [ThriftBackend._convert_column_to_arrow_array(c) for c in columns], - # Only use the column names from the schema, the types are determined by the - # physical types used in column based set, as they can differ from the - # mapping used in _hive_schema_to_arrow_schema. - names=[c[0] for c in description], - ) - return arrow_table, arrow_table.num_rows - - @staticmethod - def _convert_column_to_arrow_array(t_col): - """ - Return a pyarrow array from the values in a TColumn instance. - Note that ColumnBasedSet has no native support for complex types, so they will be converted - to strings server-side. - """ - field_name_to_arrow_type = { - "boolVal": pyarrow.bool_(), - "byteVal": pyarrow.int8(), - "i16Val": pyarrow.int16(), - "i32Val": pyarrow.int32(), - "i64Val": pyarrow.int64(), - "doubleVal": pyarrow.float64(), - "stringVal": pyarrow.string(), - "binaryVal": pyarrow.binary(), - } - for field in field_name_to_arrow_type.keys(): - wrapper = getattr(t_col, field) - if wrapper: - return ThriftBackend._create_arrow_array( - wrapper, field_name_to_arrow_type[field] - ) - - raise OperationalError("Empty TColumn instance {}".format(t_col)) - - @staticmethod - def _create_arrow_array(t_col_value_wrapper, arrow_type): - result = t_col_value_wrapper.values - nulls = t_col_value_wrapper.nulls # bitfield describing which values are null - assert isinstance(nulls, bytes) - - # The number of bits in nulls can be both larger or smaller than the number of - # elements in result, so take the minimum of both to iterate over. 
- length = min(len(result), len(nulls) * 8) - - for i in range(length): - if nulls[i >> 3] & ThriftBackend.BIT_MASKS[i & 0x7]: - result[i] = None - - return pyarrow.array(result, type=arrow_type) + return convert_decimals_in_arrow_table(arrow_table, description), num_rows def _get_metadata_resp(self, op_handle): req = ttypes.TGetResultSetMetadataReq(operationHandle=op_handle) @@ -752,6 +664,7 @@ def _results_message_to_execute_response(self, resp, operation_state): if t_result_set_metadata_resp.resultFormat not in [ ttypes.TSparkRowSetType.ARROW_BASED_SET, ttypes.TSparkRowSetType.COLUMN_BASED_SET, + ttypes.TSparkRowSetType.URL_BASED_SET, ]: raise OperationalError( "Expected results to be in Arrow or column based format, " @@ -783,13 +696,14 @@ def _results_message_to_execute_response(self, resp, operation_state): assert direct_results.resultSet.results.startRowOffset == 0 assert direct_results.resultSetMetadata - arrow_results, n_rows = self._create_arrow_table( - direct_results.resultSet.results, - lz4_compressed, - schema_bytes, - description, + arrow_queue_opt = ResultSetQueueFactory.build_queue( + row_set_type=t_result_set_metadata_resp.resultFormat, + t_row_set=direct_results.resultSet.results, + arrow_schema_bytes=schema_bytes, + max_download_threads=self.max_download_threads, + lz4_compressed=lz4_compressed, + description=description, ) - arrow_queue_opt = ArrowQueue(arrow_results, n_rows, 0) else: arrow_queue_opt = None return ExecuteResponse( @@ -843,7 +757,14 @@ def _check_direct_results_for_error(t_spark_direct_results): ) def execute_command( - self, operation, session_handle, max_rows, max_bytes, lz4_compression, cursor + self, + operation, + session_handle, + max_rows, + max_bytes, + lz4_compression, + cursor, + use_cloud_fetch=False, ): assert session_handle is not None @@ -864,7 +785,7 @@ def execute_command( ), canReadArrowResult=True, canDecompressLZ4Result=lz4_compression, - canDownloadResult=False, + canDownloadResult=use_cloud_fetch, 
confOverlay={ # We want to receive proper Timestamp arrow types. "spark.thriftserver.arrowBasedRowSet.timestampAsString": "false" @@ -993,6 +914,7 @@ def fetch_results( maxRows=max_rows, maxBytes=max_bytes, orientation=ttypes.TFetchOrientation.FETCH_NEXT, + includeResultSetMetadata=True, ) resp = self.make_request(self._client.FetchResults, req) @@ -1002,12 +924,17 @@ def fetch_results( expected_row_start_offset, resp.results.startRowOffset ) ) - arrow_results, n_rows = self._create_arrow_table( - resp.results, lz4_compressed, arrow_schema_bytes, description + + queue = ResultSetQueueFactory.build_queue( + row_set_type=resp.resultSetMetadata.resultFormat, + t_row_set=resp.results, + arrow_schema_bytes=arrow_schema_bytes, + max_download_threads=self.max_download_threads, + lz4_compressed=lz4_compressed, + description=description, ) - arrow_queue = ArrowQueue(arrow_results, n_rows) - return arrow_queue, resp.hasMoreRows + return queue, resp.hasMoreRows def close_command(self, op_handle): req = ttypes.TCloseOperationReq(operationHandle=op_handle) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index ed5581367..0aefc7a1c 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -1,16 +1,94 @@ +from abc import ABC, abstractmethod from collections import namedtuple, OrderedDict from collections.abc import Iterable -import datetime, decimal +from decimal import Decimal +import datetime +import decimal from enum import Enum -from typing import Dict +import lz4.frame +from typing import Dict, List, Union, Any import pyarrow -from databricks.sql import exc +from databricks.sql import exc, OperationalError +from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager +from databricks.sql.thrift_api.TCLIService.ttypes import ( + TSparkArrowResultLink, + TSparkRowSetType, + TRowSet, +) + +BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128] + + +class ResultSetQueue(ABC): + @abstractmethod + def next_n_rows(self, num_rows: 
class ResultSetQueueFactory(ABC):
    """Chooses and builds the ResultSetQueue that matches a Thrift row set format."""

    @staticmethod
    def build_queue(
        row_set_type: TSparkRowSetType,
        t_row_set: TRowSet,
        arrow_schema_bytes: bytes,
        max_download_threads: int,
        lz4_compressed: bool = True,
        description: List[List[Any]] = None,
    ) -> ResultSetQueue:
        """
        Build the result set queue for one batch of results.

        Args:
            row_set_type (enum): Row set type (Arrow, Column, or URL).
            t_row_set (TRowSet): Result containing arrow batches, columns, or cloud fetch links.
            arrow_schema_bytes (bytes): Bytes representing the arrow schema.
            max_download_threads (int): Maximum number of downloader thread pool threads.
            lz4_compressed (bool): Whether result data has been lz4 compressed.
            description (List[List[Any]]): Hive table schema description.

        Returns:
            ArrowQueue for Arrow and column based sets, CloudFetchQueue for URL based sets.

        Raises:
            AssertionError: If row_set_type is not one of the three formats above.
        """
        # URL based sets carry result links instead of row data, so they go
        # straight to the cloud fetch queue.
        if row_set_type == TSparkRowSetType.URL_BASED_SET:
            return CloudFetchQueue(
                arrow_schema_bytes,
                start_row_offset=t_row_set.startRowOffset,
                result_links=t_row_set.resultLinks,
                lz4_compressed=lz4_compressed,
                description=description,
                max_download_threads=max_download_threads,
            )

        if row_set_type == TSparkRowSetType.ARROW_BASED_SET:
            raw_table, row_total = convert_arrow_based_set_to_arrow_table(
                t_row_set.arrowBatches, lz4_compressed, arrow_schema_bytes
            )
        elif row_set_type == TSparkRowSetType.COLUMN_BASED_SET:
            raw_table, row_total = convert_column_based_set_to_arrow_table(
                t_row_set.columns, description
            )
        else:
            raise AssertionError("Row set type is not valid")

        # Both inline formats get the same decimal post-processing before queueing.
        decimal_safe_table = convert_decimals_in_arrow_table(raw_table, description)
        return ArrowQueue(decimal_safe_table, row_total)
class CloudFetchQueue(ResultSetQueue):
    def __init__(
        self,
        schema_bytes,
        max_download_threads: int,
        start_row_offset: int = 0,
        result_links: List[TSparkArrowResultLink] = None,
        lz4_compressed: bool = True,
        description: List[List[Any]] = None,
    ):
        """
        A queue-like wrapper over CloudFetch arrow batches.

        Attributes:
            schema_bytes (bytes): Table schema in bytes.
            max_download_threads (int): Maximum number of downloader thread pool threads.
            start_row_offset (int): The offset of the first row of the cloud fetch links.
            result_links (List[TSparkArrowResultLink]): Links containing the downloadable URL and metadata.
            lz4_compressed (bool): Whether the files are lz4 compressed.
            description (List[List[Any]]): Hive table schema description.
        """
        self.schema_bytes = schema_bytes
        self.max_download_threads = max_download_threads
        self.start_row_index = start_row_offset
        self.result_links = result_links
        self.lz4_compressed = lz4_compressed
        self.description = description

        self.download_manager = ResultFileDownloadManager(
            self.max_download_threads, self.lz4_compressed
        )
        self.download_manager.add_file_links(result_links)

        # self.table is the Arrow table currently being consumed; None means the
        # queue is exhausted. table_row_index is the read position inside it.
        self.table = self._create_next_table()
        self.table_row_index = 0

    def next_n_rows(self, num_rows: int) -> pyarrow.Table:
        """
        Get up to the next n rows of the cloud fetch Arrow dataframes.

        Args:
            num_rows (int): Number of rows to retrieve.

        Returns:
            pyarrow.Table
        """
        # Compare against None explicitly: a pyarrow.Table with zero rows is falsy,
        # so a truthiness test would mistake an empty file for the end of the queue.
        if self.table is None:
            # Return empty pyarrow table to cause retry of fetch
            return self._create_empty_table()

        # Seed with a zero-row slice so the result carries the schema even when
        # no rows are collected (e.g. num_rows == 0).
        pieces = [self.table.slice(0, 0)]
        while num_rows > 0 and self.table is not None:
            # Get remaining of num_rows or the rest of the current table, whichever is smaller
            length = min(num_rows, self.table.num_rows - self.table_row_index)
            table_slice = self.table.slice(self.table_row_index, length)
            pieces.append(table_slice)
            self.table_row_index += table_slice.num_rows

            # Replace current table with the next table if we are at the end of the current table
            if self.table_row_index == self.table.num_rows:
                self.table = self._create_next_table()
                self.table_row_index = 0
            num_rows -= table_slice.num_rows

        # Concatenate once instead of re-concatenating the running result per slice
        return pyarrow.concat_tables(pieces)

    def remaining_rows(self) -> pyarrow.Table:
        """
        Get all remaining rows of the cloud fetch Arrow dataframes.

        Returns:
            pyarrow.Table
        """
        if self.table is None:
            # Return empty pyarrow table to cause retry of fetch
            return self._create_empty_table()

        pieces = [self.table.slice(0, 0)]
        while self.table is not None:
            pieces.append(
                self.table.slice(
                    self.table_row_index, self.table.num_rows - self.table_row_index
                )
            )
            # The current table is fully consumed; move on and restart the read position
            self.table = self._create_next_table()
            self.table_row_index = 0
        return pyarrow.concat_tables(pieces)

    def _create_next_table(self) -> Union[pyarrow.Table, None]:
        """
        Build the Arrow table for the next downloaded file and advance start_row_index.

        Returns:
            The next pyarrow.Table, or None when the download manager has no file
            for the current start_row_index (signals end of queue).
        """
        downloaded_file = self.download_manager.get_next_downloaded_file(
            self.start_row_index
        )
        if not downloaded_file:
            # None signals no more Arrow tables can be built from the remaining handlers if any remain
            return None
        arrow_table = create_arrow_table_from_arrow_file(
            downloaded_file.file_bytes, self.description
        )

        # The server rarely prepares the exact number of rows requested by the client in cloud fetch.
        # Subsequently, we drop the extraneous rows in the last file if more rows are retrieved than requested
        expected_rows = downloaded_file.row_count
        if arrow_table.num_rows > expected_rows:
            arrow_table = arrow_table.slice(0, expected_rows)

        # At this point, whether the file has extraneous rows or not, the arrow table should have the correct num rows
        assert expected_rows == arrow_table.num_rows
        self.start_row_index += expected_rows
        return arrow_table

    def _create_empty_table(self) -> pyarrow.Table:
        # Create a 0-row table with just the schema bytes
        return create_arrow_table_from_arrow_file(self.schema_bytes, self.description)
def convert_column_based_set_to_arrow_table(columns, description):
    """
    Build an Arrow table from the columns of a Thrift column based set.

    Returns:
        A (table, row_count) tuple, where row_count is the table's num_rows.
    """
    column_arrays = [_convert_column_to_arrow_array(column) for column in columns]
    # Only use the column names from the schema, the types are determined by the
    # physical types used in column based set, as they can differ from the
    # mapping used in _hive_schema_to_arrow_schema.
    column_names = [entry[0] for entry in description]
    table = pyarrow.Table.from_arrays(column_arrays, names=column_names)
    return table, table.num_rows
class CloudFetchQueueSuite(unittest.TestCase):
    # Unit tests for utils.CloudFetchQueue. File downloads and Arrow decoding are
    # mocked, so only the queue's own bookkeeping is exercised.

    def create_result_link(
        self,
        file_link: str = "fileLink",
        start_row_offset: int = 0,
        row_count: int = 8000,
        bytes_num: int = 20971520
    ):
        return TSparkArrowResultLink(file_link, None, start_row_offset, row_count, bytes_num)

    def create_result_links(self, num_files: int, start_row_offset: int = 0):
        # Build consecutive links: each link starts where the previous one ended
        result_links = []
        for i in range(num_files):
            file_link = "fileLink_" + str(i)
            result_link = self.create_result_link(file_link=file_link, start_row_offset=start_row_offset)
            result_links.append(result_link)
            start_row_offset += result_link.rowCount
        return result_links

    @staticmethod
    def make_arrow_table():
        # 4 rows x 3 uint32 columns
        batch = [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]]
        n_cols = len(batch[0]) if batch else 0
        schema = pyarrow.schema({"col%s" % i: pyarrow.uint32() for i in range(n_cols)})
        cols = [[batch[row][col] for row in range(len(batch))] for col in range(n_cols)]
        return pyarrow.Table.from_pydict(dict(zip(schema.names, cols)), schema=schema)

    @staticmethod
    def get_schema_bytes():
        # Serialized Arrow stream that contains a schema and no record batches
        schema = pyarrow.schema({"col%s" % i: pyarrow.uint32() for i in range(4)})
        sink = pyarrow.BufferOutputStream()
        writer = pyarrow.ipc.RecordBatchStreamWriter(sink, schema)
        writer.close()
        return sink.getvalue().to_pybytes()

    # return_value was previously [None, None], which made queue.table a list;
    # None is what _create_next_table actually returns when nothing is available.
    @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=None)
    def test_initializer_adds_links(self, mock_create_next_table):
        schema_bytes = MagicMock()
        result_links = self.create_result_links(10)
        queue = utils.CloudFetchQueue(schema_bytes, result_links=result_links, max_download_threads=10)

        assert len(queue.download_manager.download_handlers) == 10
        mock_create_next_table.assert_called()

    def test_initializer_no_links_to_add(self):
        schema_bytes = MagicMock()
        result_links = []
        queue = utils.CloudFetchQueue(schema_bytes, result_links=result_links, max_download_threads=10)

        assert len(queue.download_manager.download_handlers) == 0
        assert queue.table is None

    @patch("databricks.sql.cloudfetch.download_manager.ResultFileDownloadManager.get_next_downloaded_file", return_value=None)
    def test_create_next_table_no_download(self, mock_get_next_downloaded_file):
        queue = utils.CloudFetchQueue(MagicMock(), result_links=[], max_download_threads=10)

        assert queue._create_next_table() is None
        # `assert mock.called_with(...)` is an auto-created Mock attribute and always
        # truthy; assert_called_with is the call that actually verifies the arguments.
        mock_get_next_downloaded_file.assert_called_with(0)

    @patch("databricks.sql.utils.create_arrow_table_from_arrow_file")
    @patch("databricks.sql.cloudfetch.download_manager.ResultFileDownloadManager.get_next_downloaded_file",
           return_value=MagicMock(file_bytes=b"1234567890", row_count=4))
    def test_initializer_create_next_table_success(self, mock_get_next_downloaded_file, mock_create_arrow_table):
        mock_create_arrow_table.return_value = self.make_arrow_table()
        schema_bytes, description = MagicMock(), MagicMock()
        queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10)
        expected_result = self.make_arrow_table()

        # _create_next_table passes only the file bytes and the description
        mock_create_arrow_table.assert_called_with(b"1234567890", description)
        mock_get_next_downloaded_file.assert_called_with(0)
        assert queue.table == expected_result
        assert queue.table.num_rows == 4
        assert queue.table_row_index == 0
        assert queue.start_row_index == 4

        table = queue._create_next_table()
        # The second file is requested at the offset reached after the first one
        mock_get_next_downloaded_file.assert_called_with(4)
        assert table == expected_result
        assert table.num_rows == 4
        assert queue.start_row_index == 8

    @patch("databricks.sql.utils.CloudFetchQueue._create_next_table")
    def test_next_n_rows_0_rows(self, mock_create_next_table):
        mock_create_next_table.return_value = self.make_arrow_table()
        schema_bytes, description = MagicMock(), MagicMock()
        queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10)
        assert queue.table == self.make_arrow_table()
        assert queue.table.num_rows == 4
        assert queue.table_row_index == 0

        result = queue.next_n_rows(0)
        assert result.num_rows == 0
        assert queue.table_row_index == 0
        assert result == self.make_arrow_table()[0:0]

    @patch("databricks.sql.utils.CloudFetchQueue._create_next_table")
    def test_next_n_rows_partial_table(self, mock_create_next_table):
        mock_create_next_table.return_value = self.make_arrow_table()
        schema_bytes, description = MagicMock(), MagicMock()
        queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10)
        assert queue.table == self.make_arrow_table()
        assert queue.table.num_rows == 4
        assert queue.table_row_index == 0

        result = queue.next_n_rows(3)
        assert result.num_rows == 3
        assert queue.table_row_index == 3
        assert result == self.make_arrow_table()[:3]

    # This test was defined twice with identical bodies; the second definition
    # silently replaced the first, so a single copy is kept.
    @patch("databricks.sql.utils.CloudFetchQueue._create_next_table")
    def test_next_n_rows_more_than_one_table(self, mock_create_next_table):
        mock_create_next_table.return_value = self.make_arrow_table()
        schema_bytes, description = MagicMock(), MagicMock()
        queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10)
        assert queue.table == self.make_arrow_table()
        assert queue.table.num_rows == 4
        assert queue.table_row_index == 0

        result = queue.next_n_rows(7)
        assert result.num_rows == 7
        assert queue.table_row_index == 3
        assert result == pyarrow.concat_tables([self.make_arrow_table(), self.make_arrow_table()])[:7]

    @patch("databricks.sql.utils.CloudFetchQueue._create_next_table")
    def test_next_n_rows_only_one_table_returned(self, mock_create_next_table):
        mock_create_next_table.side_effect = [self.make_arrow_table(), None]
        schema_bytes, description = MagicMock(), MagicMock()
        queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10)
        assert queue.table == self.make_arrow_table()
        assert queue.table.num_rows == 4
        assert queue.table_row_index == 0

        result = queue.next_n_rows(7)
        assert result.num_rows == 4
        assert result == self.make_arrow_table()

    @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=None)
    def test_next_n_rows_empty_table(self, mock_create_next_table):
        schema_bytes = self.get_schema_bytes()
        description = MagicMock()
        queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10)
        assert queue.table is None

        result = queue.next_n_rows(100)
        assert result == pyarrow.ipc.open_stream(bytearray(schema_bytes)).read_all()

    @patch("databricks.sql.utils.CloudFetchQueue._create_next_table")
    def test_remaining_rows_empty_table_fully_returned(self, mock_create_next_table):
        mock_create_next_table.side_effect = [self.make_arrow_table(), None, 0]
        schema_bytes, description = MagicMock(), MagicMock()
        queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10)
        assert queue.table == self.make_arrow_table()
        assert queue.table.num_rows == 4
        # Pretend the whole first table has already been read
        queue.table_row_index = 4

        result = queue.remaining_rows()
        assert result.num_rows == 0
        assert result == self.make_arrow_table()[0:0]

    @patch("databricks.sql.utils.CloudFetchQueue._create_next_table")
    def test_remaining_rows_partial_table_fully_returned(self, mock_create_next_table):
        mock_create_next_table.side_effect = [self.make_arrow_table(), None]
        schema_bytes, description = MagicMock(), MagicMock()
        queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10)
        assert queue.table == self.make_arrow_table()
        assert queue.table.num_rows == 4
        queue.table_row_index = 2

        result = queue.remaining_rows()
        assert result.num_rows == 2
        assert result == self.make_arrow_table()[2:]

    @patch("databricks.sql.utils.CloudFetchQueue._create_next_table")
    def test_remaining_rows_one_table_fully_returned(self, mock_create_next_table):
        mock_create_next_table.side_effect = [self.make_arrow_table(), None]
        schema_bytes, description = MagicMock(), MagicMock()
        queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10)
        assert queue.table == self.make_arrow_table()
        assert queue.table.num_rows == 4
        assert queue.table_row_index == 0

        result = queue.remaining_rows()
        assert result.num_rows == 4
        assert result == self.make_arrow_table()

    @patch("databricks.sql.utils.CloudFetchQueue._create_next_table")
    def test_remaining_rows_multiple_tables_fully_returned(self, mock_create_next_table):
        mock_create_next_table.side_effect = [self.make_arrow_table(), self.make_arrow_table(), None]
        schema_bytes, description = MagicMock(), MagicMock()
        queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10)
        assert queue.table == self.make_arrow_table()
        assert queue.table.num_rows == 4
        queue.table_row_index = 3

        result = queue.remaining_rows()
        assert mock_create_next_table.call_count == 3
        assert result.num_rows == 5
        assert result == pyarrow.concat_tables([self.make_arrow_table(), self.make_arrow_table()])[3:]

    @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=None)
    def test_remaining_rows_empty_table(self, mock_create_next_table):
        schema_bytes = self.get_schema_bytes()
        description = MagicMock()
        queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10)
        assert queue.table is None

        result = queue.remaining_rows()
        assert result == pyarrow.ipc.open_stream(bytearray(schema_bytes)).read_all()
a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -6,9 +6,9 @@ from ssl import CERT_NONE, CERT_REQUIRED import pyarrow -import urllib3 import databricks.sql +from databricks.sql import utils from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql import * from databricks.sql.auth.authenticators import AuthProvider @@ -327,7 +327,8 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self): thrift_backend._handle_execute_response(t_execute_resp, Mock()) self.assertIn("some information about the error", str(cm.exception)) - def test_handle_execute_response_sets_compression_in_direct_results(self): + @patch("databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()) + def test_handle_execute_response_sets_compression_in_direct_results(self, build_queue): for resp_type in self.execute_response_types: lz4Compressed=Mock() resultSet=MagicMock() @@ -589,9 +590,10 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): self.assertEqual(hive_schema_mock, thrift_backend._hive_schema_to_arrow_schema.call_args[0][0]) + @patch("databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()) @patch("databricks.sql.thrift_backend.TCLIService.Client") def test_handle_execute_response_reads_has_more_rows_in_direct_results( - self, tcli_service_class): + self, tcli_service_class, build_queue): for has_more_rows, resp_type in itertools.product([True, False], self.execute_response_types): with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type): @@ -622,9 +624,10 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( self.assertEqual(has_more_rows, execute_response.has_more_rows) + @patch("databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()) @patch("databricks.sql.thrift_backend.TCLIService.Client") def test_handle_execute_response_reads_has_more_rows_in_result_response( - self, 
tcli_service_class): + self, tcli_service_class, build_queue): for has_more_rows, resp_type in itertools.product([True, False], self.execute_response_types): with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type): @@ -641,6 +644,9 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( status=self.okay_status, hasMoreRows=has_more_rows, results=results_mock, + resultSetMetadata=ttypes.TGetResultSetMetadataResp( + resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET + ) ) operation_status_resp = ttypes.TGetOperationStatusResp( @@ -677,7 +683,12 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): rows=[], arrowBatches=[ ttypes.TSparkArrowBatch(batch=bytearray(), rowCount=15) for _ in range(10) - ])) + ] + ), + resultSetMetadata=ttypes.TGetResultSetMetadataResp( + resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET + ) + ) tcli_service_instance.FetchResults.return_value = t_fetch_results_resp schema = pyarrow.schema([ pyarrow.field("column1", pyarrow.int32()), @@ -875,8 +886,8 @@ def test_create_arrow_table_raises_error_for_unsupported_type(self): with self.assertRaises(OperationalError): thrift_backend._create_arrow_table(t_row_set, Mock(), None, Mock()) - @patch.object(ThriftBackend, "_convert_arrow_based_set_to_arrow_table") - @patch.object(ThriftBackend, "_convert_column_based_set_to_arrow_table") + @patch("databricks.sql.thrift_backend.convert_arrow_based_set_to_arrow_table") + @patch("databricks.sql.thrift_backend.convert_column_based_set_to_arrow_table") def test_create_arrow_table_calls_correct_conversion_method(self, convert_col_mock, convert_arrow_mock): thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) @@ -910,12 +921,11 @@ def test_convert_arrow_based_set_to_arrow_table(self, open_stream_mock, lz4_deco ]).serialize().to_pybytes() arrow_batches = [ttypes.TSparkArrowBatch(batch=bytearray('Testing','utf-8'), rowCount=1) for _ in range(10)] - 
thrift_backend._convert_arrow_based_set_to_arrow_table(arrow_batches, False, schema) + utils.convert_arrow_based_set_to_arrow_table(arrow_batches, False, schema) lz4_decompress_mock.assert_not_called() - thrift_backend._convert_arrow_based_set_to_arrow_table(arrow_batches, True, schema) + utils.convert_arrow_based_set_to_arrow_table(arrow_batches, True, schema) lz4_decompress_mock.assert_called() - def test_convert_column_based_set_to_arrow_table_without_nulls(self): # Deliberately duplicate the column name to check that dups work @@ -931,7 +941,7 @@ def test_convert_column_based_set_to_arrow_table_without_nulls(self): binaryVal=ttypes.TBinaryColumn(values=[b'\x11', b'\x22', b'\x33'], nulls=bytes(1))) ] - arrow_table, n_rows = ThriftBackend._convert_column_based_set_to_arrow_table( + arrow_table, n_rows = utils.convert_column_based_set_to_arrow_table( t_cols, description) self.assertEqual(n_rows, 3) @@ -967,7 +977,7 @@ def test_convert_column_based_set_to_arrow_table_with_nulls(self): values=[b'\x11', b'\x22', b'\x33'], nulls=bytes([3]))) ] - arrow_table, n_rows = ThriftBackend._convert_column_based_set_to_arrow_table( + arrow_table, n_rows = utils.convert_column_based_set_to_arrow_table( t_cols, description) self.assertEqual(n_rows, 3) @@ -990,7 +1000,7 @@ def test_convert_column_based_set_to_arrow_table_uses_types_from_col_set(self): binaryVal=ttypes.TBinaryColumn(values=[b'\x11', b'\x22', b'\x33'], nulls=bytes(1))) ] - arrow_table, n_rows = ThriftBackend._convert_column_based_set_to_arrow_table( + arrow_table, n_rows = utils.convert_column_based_set_to_arrow_table( t_cols, description) self.assertEqual(n_rows, 3) @@ -1094,7 +1104,7 @@ def test_make_request_will_retry_GetOperationStatus( @patch("databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory) def test_make_request_will_retry_GetOperationStatus_for_http_error( self, mock_retry_policy, mock_gos): - + import urllib3.exceptions mock_gos.side_effect = urllib3.exceptions.HTTPError("Read 
timed out") @@ -1133,7 +1143,7 @@ def test_make_request_will_retry_GetOperationStatus_for_http_error( self.assertEqual(NoRetryReason.OUT_OF_ATTEMPTS.value, cm.exception.context["no-retry-reason"]) self.assertEqual(f'{EXPECTED_RETRIES}/{EXPECTED_RETRIES}', cm.exception.context["attempt"]) - + @patch("thrift.transport.THttpClient.THttpClient") @@ -1252,7 +1262,7 @@ def test_arrow_decimal_conversion(self): table, description = self.make_table_and_desc(height, n_decimal_cols, width, precision, scale, int_constant, decimal_constant) - decimal_converted_table = ThriftBackend._convert_decimals_in_arrow_table( + decimal_converted_table = utils.convert_decimals_in_arrow_table( table, description) for i in range(width):