From 78489c6260b87e3bb151e70b525795cef2a30fac Mon Sep 17 00:00:00 2001 From: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> Date: Wed, 21 Jun 2023 09:58:19 -0700 Subject: [PATCH 01/11] Cloud fetch queue and integration Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> --- src/databricks/sql/client.py | 9 + src/databricks/sql/thrift_backend.py | 136 ++++----------- src/databricks/sql/utils.py | 242 ++++++++++++++++++++++++++- tests/unit/test_cloud_fetch_queue.py | 223 ++++++++++++++++++++++++ tests/unit/test_thrift_backend.py | 33 ++-- 5 files changed, 518 insertions(+), 125 deletions(-) create mode 100644 tests/unit/test_cloud_fetch_queue.py diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 14e59df67..e347e36a8 100644 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -153,6 +153,10 @@ def read(self) -> Optional[OAuthToken]: # _use_arrow_native_timestamps # Databricks runtime will return native Arrow types for timestamps instead of Arrow strings # (True by default) + # use_cloud_fetch + # Enable use of cloud fetch to extract large query results in parallel via cloud storage + # max_download_threads + # Number of threads for handling cloud fetch downloads. Defaults to 10 if access_token: access_token_kv = {"access_token": access_token} @@ -189,6 +193,8 @@ def read(self) -> Optional[OAuthToken]: self._session_handle = self.thrift_backend.open_session( session_configuration, catalog, schema ) + self.use_cloud_fetch = kwargs.get("use_cloud_fetch", False) + self.max_download_threads = kwargs.get("max_download_threads", 10) self.open = True logger.info("Successfully opened session " + str(self.get_session_id_hex())) self._cursors = [] # type: List[Cursor] @@ -497,6 +503,7 @@ def execute( max_bytes=self.buffer_size_bytes, lz4_compression=self.connection.lz4_compression, cursor=self, + use_cloud_fetch=self.connection.use_cloud_fetch, ) self.active_result_set = ResultSet( self.connection, @@ -804,6 +811,7 @@ def __init__( self.description = execute_response.description self._arrow_schema_bytes = execute_response.arrow_schema_bytes self._next_row_index = 0 + self.results = None if execute_response.arrow_queue: # In this case the server has taken the fast path and returned an initial batch of @@ -822,6 +830,7 @@ def __iter__(self): break def _fill_results_buffer(self): + # At initialization or if the server does not have cloud fetch result links available results, has_more_rows = self.thrift_backend.fetch_results( op_handle=self.command_id, max_rows=self.arraysize, diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index 7756c56a1..4829b0e68 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -5,7 +5,6 @@ import time import uuid import threading -import lz4.frame from ssl import CERT_NONE, CERT_REQUIRED, create_default_context from typing import List, Union @@ -31,6 +30,10 @@ _bound, RequestErrorInfo, NoRetryReason, + ResultSetQueueFactory, + convert_arrow_based_set_to_arrow_table, + convert_decimals_in_arrow_table, + convert_column_based_set_to_arrow_table, ) logger = logging.getLogger(__name__) @@ -67,7 +70,6 @@ class ThriftBackend: CLOSED_OP_STATE = ttypes.TOperationState.CLOSED_STATE ERROR_OP_STATE = ttypes.TOperationState.ERROR_STATE - BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128] def __init__( self, @@ -558,108 +560,19 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti ( arrow_table, num_rows, - ) = ThriftBackend._convert_column_based_set_to_arrow_table( + ) = convert_column_based_set_to_arrow_table( t_row_set.columns, description ) elif t_row_set.arrowBatches is not None: ( arrow_table, num_rows, - ) = ThriftBackend._convert_arrow_based_set_to_arrow_table( + ) = convert_arrow_based_set_to_arrow_table( t_row_set.arrowBatches, lz4_compressed, schema_bytes ) else: raise OperationalError("Unsupported TRowSet instance {}".format(t_row_set)) - return self._convert_decimals_in_arrow_table(arrow_table, description), num_rows - - @staticmethod - def _convert_decimals_in_arrow_table(table, description): - for (i, col) in enumerate(table.itercolumns()): - if description[i][1] == "decimal": - decimal_col = col.to_pandas().apply( - lambda v: v if v is None else Decimal(v) - ) - precision, scale = description[i][4], description[i][5] - assert scale is not None - assert precision is not None - # Spark limits decimal to a maximum scale of 38, - # so 128 is guaranteed to be big enough - dtype = pyarrow.decimal128(precision, scale) - col_data = pyarrow.array(decimal_col, type=dtype) - field = table.field(i).with_type(dtype) - table = table.set_column(i, field, col_data) - return table - - @staticmethod - def _convert_arrow_based_set_to_arrow_table( - arrow_batches, lz4_compressed, schema_bytes - ): - ba = bytearray() - ba += schema_bytes - n_rows = 0 - if lz4_compressed: - for arrow_batch in arrow_batches: - n_rows += arrow_batch.rowCount - ba += lz4.frame.decompress(arrow_batch.batch) - else: - for arrow_batch in arrow_batches: - n_rows += arrow_batch.rowCount - ba += arrow_batch.batch - arrow_table = pyarrow.ipc.open_stream(ba).read_all() - return arrow_table, n_rows - - @staticmethod - def _convert_column_based_set_to_arrow_table(columns, description): - arrow_table = pyarrow.Table.from_arrays( - [ThriftBackend._convert_column_to_arrow_array(c) for c in columns], - # Only use the column names from the schema, the types are determined by the - # physical types used in column based set, as they can differ from the - # mapping used in _hive_schema_to_arrow_schema. - names=[c[0] for c in description], - ) - return arrow_table, arrow_table.num_rows - - @staticmethod - def _convert_column_to_arrow_array(t_col): - """ - Return a pyarrow array from the values in a TColumn instance. - Note that ColumnBasedSet has no native support for complex types, so they will be converted - to strings server-side. - """ - field_name_to_arrow_type = { - "boolVal": pyarrow.bool_(), - "byteVal": pyarrow.int8(), - "i16Val": pyarrow.int16(), - "i32Val": pyarrow.int32(), - "i64Val": pyarrow.int64(), - "doubleVal": pyarrow.float64(), - "stringVal": pyarrow.string(), - "binaryVal": pyarrow.binary(), - } - for field in field_name_to_arrow_type.keys(): - wrapper = getattr(t_col, field) - if wrapper: - return ThriftBackend._create_arrow_array( - wrapper, field_name_to_arrow_type[field] - ) - - raise OperationalError("Empty TColumn instance {}".format(t_col)) - - @staticmethod - def _create_arrow_array(t_col_value_wrapper, arrow_type): - result = t_col_value_wrapper.values - nulls = t_col_value_wrapper.nulls # bitfield describing which values are null - assert isinstance(nulls, bytes) - - # The number of bits in nulls can be both larger or smaller than the number of - # elements in result, so take the minimum of both to iterate over. - length = min(len(result), len(nulls) * 8) - - for i in range(length): - if nulls[i >> 3] & ThriftBackend.BIT_MASKS[i & 0x7]: - result[i] = None - - return pyarrow.array(result, type=arrow_type) + return convert_decimals_in_arrow_table(arrow_table, description), num_rows def _get_metadata_resp(self, op_handle): req = ttypes.TGetResultSetMetadataReq(operationHandle=op_handle) @@ -752,6 +665,7 @@ def _results_message_to_execute_response(self, resp, operation_state): if t_result_set_metadata_resp.resultFormat not in [ ttypes.TSparkRowSetType.ARROW_BASED_SET, ttypes.TSparkRowSetType.COLUMN_BASED_SET, + ttypes.TSparkRowSetType.URL_BASED_SET, ]: raise OperationalError( "Expected results to be in Arrow or column based format, " @@ -783,13 +697,16 @@ def _results_message_to_execute_response(self, resp, operation_state): assert direct_results.resultSet.results.startRowOffset == 0 assert direct_results.resultSetMetadata - arrow_results, n_rows = self._create_arrow_table( - direct_results.resultSet.results, - lz4_compressed, - schema_bytes, - description, - ) - arrow_queue_opt = ArrowQueue(arrow_results, n_rows, 0) + if direct_results.resultSet.results.resultLinks is None: + arrow_results, n_rows = self._create_arrow_table( + direct_results.resultSet.results, + lz4_compressed, + schema_bytes, + description, + ) + arrow_queue_opt = ArrowQueue(arrow_results, n_rows, 0) + else: + arrow_queue_opt = None else: arrow_queue_opt = None return ExecuteResponse( @@ -843,7 +760,7 @@ def _check_direct_results_for_error(t_spark_direct_results): ) def execute_command( - self, operation, session_handle, max_rows, max_bytes, lz4_compression, cursor + self, operation, session_handle, max_rows, max_bytes, lz4_compression, cursor, use_cloud_fetch=False ): assert session_handle is not None @@ -864,7 +781,7 @@ def execute_command( ), canReadArrowResult=True, canDecompressLZ4Result=lz4_compression, - canDownloadResult=False, + canDownloadResult=use_cloud_fetch, confOverlay={ # We want to receive proper Timestamp arrow types. "spark.thriftserver.arrowBasedRowSet.timestampAsString": "false" @@ -993,6 +910,7 @@ def fetch_results( maxRows=max_rows, maxBytes=max_bytes, orientation=ttypes.TFetchOrientation.FETCH_NEXT, + includeResultSetMetadata=True, ) resp = self.make_request(self._client.FetchResults, req) @@ -1002,12 +920,16 @@ def fetch_results( expected_row_start_offset, resp.results.startRowOffset ) ) - arrow_results, n_rows = self._create_arrow_table( - resp.results, lz4_compressed, arrow_schema_bytes, description + + queue = ResultSetQueueFactory.build_queue( + row_set_type=resp.resultSetMetadata.resultFormat, + t_row_set=resp.results, + arrow_schema_bytes=arrow_schema_bytes, + lz4_compressed=lz4_compressed, + description=description, ) - arrow_queue = ArrowQueue(arrow_results, n_rows) - return arrow_queue, resp.hasMoreRows + return queue, resp.hasMoreRows def close_command(self, op_handle): req = ttypes.TCloseOperationReq(operationHandle=op_handle) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index ed5581367..d70f1273b 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -1,16 +1,71 @@ +from abc import ABC, abstractmethod from collections import namedtuple, OrderedDict from collections.abc import Iterable -import datetime, decimal +from decimal import Decimal +import datetime +import decimal from enum import Enum -from typing import Dict +import lz4.frame +from typing import Dict, List import pyarrow -from databricks.sql import exc +from databricks.sql import exc, OperationalError +from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager +from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink, TSparkRowSetType, TRowSet +DEFAULT_MAX_DOWNLOAD_THREADS = 10 +BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128] -class ArrowQueue: + +class ResultSetQueue(ABC): + @abstractmethod + def next_n_rows(self, num_rows: int) -> pyarrow.Table: + pass + + @abstractmethod + def remaining_rows(self) -> pyarrow.Table: + pass + + +class ResultSetQueueFactory(ABC): + @staticmethod + def build_queue( + row_set_type: TSparkRowSetType, + t_row_set: TRowSet, + arrow_schema_bytes, + lz4_compressed: bool = True, + description: str = None, + ) -> ResultSetQueue: + if row_set_type == TSparkRowSetType.ARROW_BASED_SET: + arrow_table, n_valid_rows = convert_arrow_based_set_to_arrow_table( + t_row_set.arrowBatches, lz4_compressed, arrow_schema_bytes + ) + converted_arrow_table = convert_decimals_in_arrow_table(arrow_table, description) + return ArrowQueue(converted_arrow_table, n_valid_rows) + elif row_set_type == TSparkRowSetType.COLUMN_BASED_SET: + arrow_table, n_valid_rows = convert_column_based_set_to_arrow_table( + t_row_set.columns, description + ) + converted_arrow_table = convert_decimals_in_arrow_table(arrow_table, description) + return ArrowQueue(converted_arrow_table, n_valid_rows) + elif row_set_type == TSparkRowSetType.URL_BASED_SET: + return CloudFetchQueue( + arrow_schema_bytes, + start_row_index=t_row_set.startRowOffset, + result_links=t_row_set.resultLinks, + lz4_compressed=lz4_compressed, + description=description + ) + else: + raise AssertionError("Row set type is not valid") + + +class ArrowQueue(ResultSetQueue): def __init__( - self, arrow_table: pyarrow.Table, n_valid_rows: int, start_row_index: int = 0 + self, + arrow_table: pyarrow.Table, + n_valid_rows: int, + start_row_index: int = 0, ): """ A queue-like wrapper over an Arrow table @@ -40,6 +95,80 @@ def remaining_rows(self) -> pyarrow.Table: return slice +class CloudFetchQueue(ResultSetQueue): + def __init__( + self, + schema_bytes, + start_row_index: int = 0, + result_links: List[TSparkArrowResultLink] = None, + lz4_compressed: bool = True, + description: str = None, + ): + """ + A queue-like wrapper over CloudFetch arrow batches + """ + self.schema_bytes = schema_bytes + self.start_row_index = start_row_index + self.result_links = result_links + self.lz4_compressed = lz4_compressed + self.description = description + self.max_download_threads = DEFAULT_MAX_DOWNLOAD_THREADS + + self.download_manager = ResultFileDownloadManager(self.max_download_threads, self.lz4_compressed) + self.download_manager.add_file_links(result_links, start_row_index) + + self.table, self.table_num_rows = self._create_next_table() + self.table_row_index = 0 + + def next_n_rows(self, num_rows: int) -> pyarrow.Table: + if not self.table: + # Return empty pyarrow table to cause retry of fetch + return self._create_empty_table() + results = self.table.slice(0, 0) + while num_rows > 0 and self.table: + length = min(num_rows, self.table_num_rows - self.table_row_index) + table_slice = self.table.slice(self.table_row_index, length) + results = pyarrow.concat_tables([results, table_slice]) + self.table_row_index += table_slice.num_rows + if self.table_row_index == self.table_num_rows: + self.table, self.table_num_rows = self._create_next_table() + self.table_row_index = 0 + num_rows -= table_slice.num_rows + return results + + def remaining_rows(self) -> pyarrow.Table: + if not self.table: + # Return empty pyarrow table to cause retry of fetch + return self._create_empty_table() + results = self.table.slice(0, 0) + while self.table: + table_slice = self.table.slice( + self.table_row_index, self.table_num_rows - self.table_row_index + ) + results = pyarrow.concat_tables([results, table_slice]) + self.table_row_index += table_slice.num_rows + self.table, self.table_num_rows = self._create_next_table() + self.table_row_index = 0 + return results + + def _create_next_table(self): + # TODO: add retry logic from _fill_results_buffer_cloudfetch + downloaded_file = self.download_manager.get_next_downloaded_file(self.start_row_index) + if not downloaded_file: + return None, 0 + arrow_table = create_arrow_table_from_arrow_file( + downloaded_file.file_bytes, self.description) + if arrow_table.num_rows > downloaded_file.row_count: + self.start_row_index += downloaded_file.row_count + return arrow_table.slice(0, downloaded_file.row_count), downloaded_file.row_count + assert downloaded_file.row_count == arrow_table.num_rows + self.start_row_index += arrow_table.num_rows + return arrow_table, arrow_table.num_rows + + def _create_empty_table(self): + return create_arrow_table_from_arrow_file(self.schema_bytes, self.description) + + ExecuteResponse = namedtuple( "ExecuteResponse", "status has_been_closed_server_side has_more_rows description lz4_compressed is_staging_operation " @@ -183,3 +312,106 @@ def escape_item(self, item): def inject_parameters(operation: str, parameters: Dict[str, str]): return operation % parameters + + +def create_arrow_table_from_arrow_file( + arrow_batches, description +) -> (pyarrow.Table, int): + arrow_table = convert_arrow_based_file_to_arrow_table(arrow_batches) + return convert_decimals_in_arrow_table(arrow_table, description) + + +def convert_arrow_based_file_to_arrow_table(arrow_bytes): + try: + return pyarrow.ipc.open_stream(arrow_bytes).read_all() + except Exception as e: + raise RuntimeError("Failure to convert arrow based file to arrow table", e) + + +def convert_arrow_based_set_to_arrow_table( + arrow_batches, lz4_compressed, schema_bytes +): + ba = bytearray() + ba += schema_bytes + n_rows = 0 + if lz4_compressed: + for arrow_batch in arrow_batches: + n_rows += arrow_batch.rowCount + ba += lz4.frame.decompress(arrow_batch.batch) + else: + for arrow_batch in arrow_batches: + n_rows += arrow_batch.rowCount + ba += arrow_batch.batch + arrow_table = pyarrow.ipc.open_stream(ba).read_all() + return arrow_table, n_rows + + +def convert_decimals_in_arrow_table(table, description) -> pyarrow.Table: + for (i, col) in enumerate(table.itercolumns()): + if description[i][1] == "decimal": + decimal_col = col.to_pandas().apply( + lambda v: v if v is None else Decimal(v) + ) + precision, scale = description[i][4], description[i][5] + assert scale is not None + assert precision is not None + # Spark limits decimal to a maximum scale of 38, + # so 128 is guaranteed to be big enough + dtype = pyarrow.decimal128(precision, scale) + col_data = pyarrow.array(decimal_col, type=dtype) + field = table.field(i).with_type(dtype) + table = table.set_column(i, field, col_data) + return table + + +def convert_column_based_set_to_arrow_table(columns, description): + arrow_table = pyarrow.Table.from_arrays( + [_convert_column_to_arrow_array(c) for c in columns], + # Only use the column names from the schema, the types are determined by the + # physical types used in column based set, as they can differ from the + # mapping used in _hive_schema_to_arrow_schema. + names=[c[0] for c in description], + ) + return arrow_table, arrow_table.num_rows + + +def _convert_column_to_arrow_array(t_col): + """ + Return a pyarrow array from the values in a TColumn instance. + Note that ColumnBasedSet has no native support for complex types, so they will be converted + to strings server-side. + """ + field_name_to_arrow_type = { + "boolVal": pyarrow.bool_(), + "byteVal": pyarrow.int8(), + "i16Val": pyarrow.int16(), + "i32Val": pyarrow.int32(), + "i64Val": pyarrow.int64(), + "doubleVal": pyarrow.float64(), + "stringVal": pyarrow.string(), + "binaryVal": pyarrow.binary(), + } + for field in field_name_to_arrow_type.keys(): + wrapper = getattr(t_col, field) + if wrapper: + return _create_arrow_array( + wrapper, field_name_to_arrow_type[field] + ) + + raise OperationalError("Empty TColumn instance {}".format(t_col)) + + +def _create_arrow_array(t_col_value_wrapper, arrow_type): + result = t_col_value_wrapper.values + nulls = t_col_value_wrapper.nulls # bitfield describing which values are null + assert isinstance(nulls, bytes) + + # The number of bits in nulls can be both larger or smaller than the number of + # elements in result, so take the minimum of both to iterate over. + length = min(len(result), len(nulls) * 8) + + for i in range(length): + if nulls[i >> 3] & BIT_MASKS[i & 0x7]: + result[i] = None + + return pyarrow.array(result, type=arrow_type) diff --git a/tests/unit/test_cloud_fetch_queue.py b/tests/unit/test_cloud_fetch_queue.py new file mode 100644 index 000000000..be64b85d1 --- /dev/null +++ b/tests/unit/test_cloud_fetch_queue.py @@ -0,0 +1,223 @@ +import pyarrow +import unittest +from unittest.mock import MagicMock, patch + +from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink +import databricks.sql.utils as utils + + +class CloudFetchQueueSuite(unittest.TestCase): + + def create_result_link( + self, + file_link: str = "fileLink", + start_row_offset: int = 0, + row_count: int = 8000, + bytes_num: int = 20971520 + ): + return TSparkArrowResultLink(file_link, None, start_row_offset, row_count, bytes_num) + + def create_result_links(self, num_files: int, start_row_offset: int = 0): + result_links = [] + for i in range(num_files): + file_link = "fileLink_" + str(i) + result_link = self.create_result_link(file_link=file_link, start_row_offset=start_row_offset) + result_links.append(result_link) + start_row_offset += result_link.rowCount + return result_links + + @staticmethod + def make_arrow_table(): + batch = [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]] + n_cols = len(batch[0]) if batch else 0 + schema = pyarrow.schema({"col%s" % i: pyarrow.uint32() for i in range(n_cols)}) + cols = [[batch[row][col] for row in range(len(batch))] for col in range(n_cols)] + return pyarrow.Table.from_pydict(dict(zip(schema.names, cols)), schema=schema) + + @staticmethod + def get_schema_bytes(): + schema = pyarrow.schema({"col%s" % i: pyarrow.uint32() for i in range(4)}) + sink = pyarrow.BufferOutputStream() + writer = pyarrow.ipc.RecordBatchStreamWriter(sink, schema) + writer.close() + return sink.getvalue().to_pybytes() + + + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=[None, None]) + def test_initializer_adds_links(self, mock_create_next_table): + schema_bytes = MagicMock() + result_links = self.create_result_links(10) + queue = utils.CloudFetchQueue(schema_bytes, result_links=result_links) + + assert len(queue.download_manager.download_handlers) == 10 + mock_create_next_table.assert_called() + + def test_initializer_no_links_to_add(self): + schema_bytes = MagicMock() + result_links = [] + queue = utils.CloudFetchQueue(schema_bytes, result_links=result_links) + + assert len(queue.download_manager.download_handlers) == 0 + assert queue.table is None + assert queue.table_num_rows == 0 + + @patch("databricks.sql.cloudfetch.download_manager.ResultFileDownloadManager.get_next_downloaded_file", return_value=None) + def test_create_next_table_no_download(self, mock_get_next_downloaded_file): + queue = utils.CloudFetchQueue(MagicMock(), result_links=[]) + + assert queue._create_next_table() == (None, 0) + assert mock_get_next_downloaded_file.called_with(0) + + @patch("databricks.sql.utils.create_arrow_table_from_arrow_file", return_value=make_arrow_table()) + @patch("databricks.sql.cloudfetch.download_manager.ResultFileDownloadManager.get_next_downloaded_file", + return_value=MagicMock(file_bytes=b"1234567890", row_count=4)) + def test_initializer_create_next_table_success(self, mock_get_next_downloaded_file, mock_create_arrow_table): + schema_bytes, description = MagicMock(), MagicMock() + queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description) + expected_result = self.make_arrow_table() + + assert mock_create_arrow_table.called_with(b"1234567890", True, schema_bytes, description) + assert mock_get_next_downloaded_file.called_with(0) + assert queue.table == expected_result + assert queue.table_num_rows == 4 + assert queue.table_row_index == 0 + assert queue.start_row_index == 4 + + table, row_count = queue._create_next_table() + assert table == expected_result + assert row_count == 4 + assert queue.start_row_index == 8 + + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=(make_arrow_table(), 4)) + def test_next_n_rows_0_rows(self, mock_create_next_table): + schema_bytes, description = MagicMock(), MagicMock() + queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description) + assert queue.table == self.make_arrow_table() + assert queue.table_num_rows == 4 + assert queue.table_row_index == 0 + + result = queue.next_n_rows(0) + assert result.num_rows == 0 + assert queue.table_row_index == 0 + assert result == self.make_arrow_table()[0:0] + + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=(make_arrow_table(), 4)) + def test_next_n_rows_partial_table(self, mock_create_next_table): + schema_bytes, description = MagicMock(), MagicMock() + queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description) + assert queue.table == self.make_arrow_table() + assert queue.table_num_rows == 4 + assert queue.table_row_index == 0 + + result = queue.next_n_rows(3) + assert result.num_rows == 3 + assert queue.table_row_index == 3 + assert result == self.make_arrow_table()[:3] + + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=(make_arrow_table(), 4)) + def test_next_n_rows_more_than_one_table(self, mock_create_next_table): + schema_bytes, description = MagicMock(), MagicMock() + queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description) + assert queue.table == self.make_arrow_table() + assert queue.table_num_rows == 4 + assert queue.table_row_index == 0 + + result = queue.next_n_rows(7) + assert result.num_rows == 7 + assert queue.table_row_index == 3 + assert result == pyarrow.concat_tables([self.make_arrow_table(), self.make_arrow_table()])[:7] + + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=(make_arrow_table(), 4)) + def test_next_n_rows_more_than_one_table(self, mock_create_next_table): + schema_bytes, description = MagicMock(), MagicMock() + queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description) + assert queue.table == self.make_arrow_table() + assert queue.table_num_rows == 4 + assert queue.table_row_index == 0 + + result = queue.next_n_rows(7) + assert result.num_rows == 7 + assert queue.table_row_index == 3 + assert result == pyarrow.concat_tables([self.make_arrow_table(), self.make_arrow_table()])[:7] + + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", side_effect=[(make_arrow_table(), 4), (None, 0)]) + def test_next_n_rows_only_one_table_returned(self, mock_create_next_table): + schema_bytes, description = MagicMock(), MagicMock() + queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description) + assert queue.table == self.make_arrow_table() + assert queue.table_num_rows == 4 + assert queue.table_row_index == 0 + + result = queue.next_n_rows(7) + assert result.num_rows == 4 + assert result == self.make_arrow_table() + + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=(None, 0)) + def test_next_n_rows_empty_table(self, mock_create_next_table): + schema_bytes = self.get_schema_bytes() + description = MagicMock() + queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description) + assert queue.table is None + + result = queue.next_n_rows(100) + assert result == pyarrow.ipc.open_stream(bytearray(schema_bytes)).read_all() + + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", side_effect=[(make_arrow_table(), 4), (None, 0)]) + def test_remaining_rows_empty_table_fully_returned(self, mock_create_next_table): + schema_bytes, description = MagicMock(), MagicMock() + queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description) + assert queue.table == self.make_arrow_table() + assert queue.table_num_rows == 4 + queue.table_row_index = 4 + + result = queue.remaining_rows() + assert result.num_rows == 0 + assert result == self.make_arrow_table()[0:0] + + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", side_effect=[(make_arrow_table(), 4), (None, 0)]) + def test_remaining_rows_partial_table_fully_returned(self, mock_create_next_table): + schema_bytes, description = MagicMock(), MagicMock() + queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description) + assert queue.table == self.make_arrow_table() + assert queue.table_num_rows == 4 + queue.table_row_index = 2 + + result = queue.remaining_rows() + assert result.num_rows == 2 + assert result == self.make_arrow_table()[2:] + + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", side_effect=[(make_arrow_table(), 4), (None, 0)]) + def test_remaining_rows_one_table_fully_returned(self, mock_create_next_table): + schema_bytes, description = MagicMock(), MagicMock() + queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description) + assert queue.table == self.make_arrow_table() + assert queue.table_num_rows == 4 + assert queue.table_row_index == 0 + + result = queue.remaining_rows() + assert result.num_rows == 4 + assert result == self.make_arrow_table() + + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", + side_effect=[(make_arrow_table(), 4), (make_arrow_table(), 4), (None, 0)]) + def test_remaining_rows_multiple_tables_fully_returned(self, mock_create_next_table): + schema_bytes, description = MagicMock(), MagicMock() + queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description) + assert queue.table == self.make_arrow_table() + assert queue.table_num_rows == 4 + queue.table_row_index = 3 + + result = queue.remaining_rows() + assert mock_create_next_table.call_count == 3 + assert result.num_rows == 5 + assert result == pyarrow.concat_tables([self.make_arrow_table(), self.make_arrow_table()])[3:] + + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=(None, 0)) + def test_remaining_rows_empty_table(self, mock_create_next_table): + schema_bytes = self.get_schema_bytes() + description = MagicMock() + queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description) + assert queue.table is None + + result = queue.remaining_rows() + assert result == pyarrow.ipc.open_stream(bytearray(schema_bytes)).read_all() diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 7ef0fa2ce..b44057e99 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -6,9 +6,9 @@ from ssl import CERT_NONE, CERT_REQUIRED import pyarrow -import urllib3 import databricks.sql +from databricks.sql import utils from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql import * from databricks.sql.auth.authenticators import AuthProvider @@ -641,6 +641,9 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response( status=self.okay_status, hasMoreRows=has_more_rows, results=results_mock, + resultSetMetadata=ttypes.TGetResultSetMetadataResp( + resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET + ) ) operation_status_resp = ttypes.TGetOperationStatusResp( @@ -677,7 +680,12 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class): rows=[], arrowBatches=[ ttypes.TSparkArrowBatch(batch=bytearray(), rowCount=15) for _ in range(10) - ])) + ] + ), + resultSetMetadata=ttypes.TGetResultSetMetadataResp( + resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET + ) + ) tcli_service_instance.FetchResults.return_value = t_fetch_results_resp schema = pyarrow.schema([ pyarrow.field("column1", pyarrow.int32()), @@ -875,8 +883,8 @@ def test_create_arrow_table_raises_error_for_unsupported_type(self): with self.assertRaises(OperationalError): thrift_backend._create_arrow_table(t_row_set, Mock(), None, Mock()) - @patch.object(ThriftBackend, "_convert_arrow_based_set_to_arrow_table") - @patch.object(ThriftBackend, "_convert_column_based_set_to_arrow_table") + @patch("databricks.sql.thrift_backend.convert_arrow_based_set_to_arrow_table") + @patch("databricks.sql.thrift_backend.convert_column_based_set_to_arrow_table") def test_create_arrow_table_calls_correct_conversion_method(self, convert_col_mock, convert_arrow_mock): thrift_backend = ThriftBackend("foobar", 443, "path", [], auth_provider=AuthProvider()) @@ -910,12 +918,11 @@ def test_convert_arrow_based_set_to_arrow_table(self, open_stream_mock, lz4_deco ]).serialize().to_pybytes() arrow_batches = [ttypes.TSparkArrowBatch(batch=bytearray('Testing','utf-8'), rowCount=1) for _ in range(10)] - thrift_backend._convert_arrow_based_set_to_arrow_table(arrow_batches, False, schema) + utils.convert_arrow_based_set_to_arrow_table(arrow_batches, False, schema) lz4_decompress_mock.assert_not_called() - thrift_backend._convert_arrow_based_set_to_arrow_table(arrow_batches, True, schema) + utils.convert_arrow_based_set_to_arrow_table(arrow_batches, True, schema) lz4_decompress_mock.assert_called() - def test_convert_column_based_set_to_arrow_table_without_nulls(self): # Deliberately duplicate the column name to check that dups work @@ -931,7 +938,7 @@ def test_convert_column_based_set_to_arrow_table_without_nulls(self): binaryVal=ttypes.TBinaryColumn(values=[b'\x11', b'\x22', b'\x33'], nulls=bytes(1))) ] - arrow_table, n_rows = ThriftBackend._convert_column_based_set_to_arrow_table( + arrow_table, n_rows = utils.convert_column_based_set_to_arrow_table( t_cols, description) self.assertEqual(n_rows, 3) @@ -967,7 +974,7 @@ def test_convert_column_based_set_to_arrow_table_with_nulls(self): values=[b'\x11', b'\x22', b'\x33'], nulls=bytes([3]))) ] - arrow_table, n_rows = ThriftBackend._convert_column_based_set_to_arrow_table( + arrow_table, n_rows = utils.convert_column_based_set_to_arrow_table( t_cols, description) self.assertEqual(n_rows, 3) @@ -990,7 +997,7 @@ def test_convert_column_based_set_to_arrow_table_uses_types_from_col_set(self): binaryVal=ttypes.TBinaryColumn(values=[b'\x11', b'\x22', b'\x33'], nulls=bytes(1))) ] - arrow_table, n_rows = ThriftBackend._convert_column_based_set_to_arrow_table( + arrow_table, n_rows = utils.convert_column_based_set_to_arrow_table( t_cols, description) self.assertEqual(n_rows, 3) @@ -1094,7 +1101,7 @@ def test_make_request_will_retry_GetOperationStatus( @patch("databricks.sql.thrift_backend._retry_policy", new_callable=retry_policy_factory) def test_make_request_will_retry_GetOperationStatus_for_http_error( self, mock_retry_policy, mock_gos): - + import urllib3.exceptions mock_gos.side_effect = urllib3.exceptions.HTTPError("Read timed out") @@ -1133,7 +1140,7 @@ def test_make_request_will_retry_GetOperationStatus_for_http_error( self.assertEqual(NoRetryReason.OUT_OF_ATTEMPTS.value, cm.exception.context["no-retry-reason"]) self.assertEqual(f'{EXPECTED_RETRIES}/{EXPECTED_RETRIES}', cm.exception.context["attempt"]) - + @patch("thrift.transport.THttpClient.THttpClient") @@ -1252,7 +1259,7 @@ def test_arrow_decimal_conversion(self): table, description = self.make_table_and_desc(height, n_decimal_cols, width, precision, scale, int_constant, decimal_constant) - decimal_converted_table = ThriftBackend._convert_decimals_in_arrow_table( + decimal_converted_table = utils.convert_decimals_in_arrow_table( table, description) for i in range(width): From 1f8813d7601a3a22610890533dc9085487511000 Mon Sep 17 00:00:00 2001 From: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> Date: Thu, 22 Jun 2023 16:43:17 -0700 Subject: [PATCH 02/11] Enable cloudfetch with direct results Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> --- src/databricks/sql/thrift_backend.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index 4829b0e68..49f6643ee 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -697,16 +697,13 @@ def _results_message_to_execute_response(self, resp, operation_state): assert direct_results.resultSet.results.startRowOffset == 0 assert direct_results.resultSetMetadata - if direct_results.resultSet.results.resultLinks is None: - arrow_results, n_rows = self._create_arrow_table( - direct_results.resultSet.results, - lz4_compressed, - schema_bytes, - description, - ) - arrow_queue_opt = ArrowQueue(arrow_results, n_rows, 0) - else: - arrow_queue_opt = None + arrow_queue_opt = ResultSetQueueFactory.build_queue( + row_set_type=t_result_set_metadata_resp.resultFormat, + t_row_set=direct_results.resultSet.results, + arrow_schema_bytes=schema_bytes, + lz4_compressed=lz4_compressed, + description=description, + ) else: arrow_queue_opt = None return ExecuteResponse( From f0f720fc78932b766e58ac1bcd77a5ace62f9885 Mon Sep 17 00:00:00 2001 From: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> Date: Mon, 26 Jun 2023 04:23:47 -0700 Subject: [PATCH 03/11] Typing and style changes Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> --- src/databricks/sql/utils.py | 48 ++++++++++++--------------- tests/unit/test_cloud_fetch_queue.py | 49 ++++++++++++++-------------- 2 files changed, 44 insertions(+), 53 deletions(-) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index d70f1273b..986a7a9c0 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -6,7 +6,7 @@ import decimal from enum import Enum import lz4.frame -from typing import Dict, List +from typing import Dict, List, Union import pyarrow from databricks.sql import exc, OperationalError @@ -117,7 +117,7 @@ def __init__( self.download_manager = ResultFileDownloadManager(self.max_download_threads, self.lz4_compressed) self.download_manager.add_file_links(result_links, start_row_index) - self.table, self.table_num_rows = self._create_next_table() + self.table = self._create_next_table() self.table_row_index = 0 def next_n_rows(self, num_rows: int) -> pyarrow.Table: @@ -126,12 +126,12 @@ def next_n_rows(self, num_rows: int) -> pyarrow.Table: return self._create_empty_table() results = self.table.slice(0, 0) while num_rows > 0 and self.table: - length = min(num_rows, self.table_num_rows - self.table_row_index) + length = min(num_rows, self.table.num_rows - self.table_row_index) table_slice = self.table.slice(self.table_row_index, length) results = pyarrow.concat_tables([results, table_slice]) self.table_row_index += table_slice.num_rows - if self.table_row_index == self.table_num_rows: - self.table, self.table_num_rows = self._create_next_table() + if self.table_row_index == self.table.num_rows: + self.table = self._create_next_table() self.table_row_index = 0 num_rows -= table_slice.num_rows return results @@ -143,29 +143,28 @@ def remaining_rows(self) -> pyarrow.Table: results = self.table.slice(0, 0) while self.table: table_slice = self.table.slice( - self.table_row_index, self.table_num_rows - self.table_row_index + self.table_row_index, self.table.num_rows - self.table_row_index ) results = pyarrow.concat_tables([results, table_slice]) self.table_row_index += table_slice.num_rows - self.table, self.table_num_rows = self._create_next_table() + self.table = self._create_next_table() self.table_row_index = 0 return results - def _create_next_table(self): + def _create_next_table(self) -> Union[pyarrow.Table, None]: # TODO: add retry logic from _fill_results_buffer_cloudfetch downloaded_file = self.download_manager.get_next_downloaded_file(self.start_row_index) if not downloaded_file: - return None, 0 - arrow_table = create_arrow_table_from_arrow_file( - downloaded_file.file_bytes, self.description) + return None + arrow_table = create_arrow_table_from_arrow_file(downloaded_file.file_bytes, self.description) if arrow_table.num_rows > downloaded_file.row_count: self.start_row_index += downloaded_file.row_count - return arrow_table.slice(0, downloaded_file.row_count), downloaded_file.row_count + return arrow_table.slice(0, downloaded_file.row_count) assert downloaded_file.row_count == arrow_table.num_rows self.start_row_index += arrow_table.num_rows - return arrow_table, arrow_table.num_rows + return arrow_table - def _create_empty_table(self): + def _create_empty_table(self) -> pyarrow.Table: return create_arrow_table_from_arrow_file(self.schema_bytes, self.description) @@ -314,16 +313,14 @@ def inject_parameters(operation: str, parameters: Dict[str, str]): return operation % parameters -def create_arrow_table_from_arrow_file( - arrow_batches, description -) -> (pyarrow.Table, int): - arrow_table = convert_arrow_based_file_to_arrow_table(arrow_batches) +def create_arrow_table_from_arrow_file(file_bytes: bytes, description) -> (pyarrow.Table, int): + arrow_table = convert_arrow_based_file_to_arrow_table(file_bytes) return convert_decimals_in_arrow_table(arrow_table, description) -def convert_arrow_based_file_to_arrow_table(arrow_bytes): +def convert_arrow_based_file_to_arrow_table(file_bytes: bytes): try: - return pyarrow.ipc.open_stream(arrow_bytes).read_all() + return pyarrow.ipc.open_stream(file_bytes).read_all() except Exception as e: raise RuntimeError("Failure to convert arrow based file to arrow table", e) @@ -334,14 +331,9 @@ def convert_arrow_based_set_to_arrow_table( ba = bytearray() ba += schema_bytes n_rows = 0 - if lz4_compressed: - for arrow_batch in arrow_batches: - n_rows += arrow_batch.rowCount - ba += lz4.frame.decompress(arrow_batch.batch) - else: - for arrow_batch in arrow_batches: - n_rows += arrow_batch.rowCount - ba += arrow_batch.batch + for arrow_batch in arrow_batches: + n_rows += arrow_batch.rowCount + ba += lz4.frame.decompress(arrow_batch.batch) if lz4_compressed else arrow_batch.batch arrow_table = pyarrow.ipc.open_stream(ba).read_all() return arrow_table, n_rows diff --git a/tests/unit/test_cloud_fetch_queue.py b/tests/unit/test_cloud_fetch_queue.py index be64b85d1..4ffa8e13a 100644 --- a/tests/unit/test_cloud_fetch_queue.py +++ b/tests/unit/test_cloud_fetch_queue.py @@ -59,13 +59,12 @@ def test_initializer_no_links_to_add(self): assert len(queue.download_manager.download_handlers) == 0 assert queue.table is None - assert queue.table_num_rows == 0 @patch("databricks.sql.cloudfetch.download_manager.ResultFileDownloadManager.get_next_downloaded_file", return_value=None) def test_create_next_table_no_download(self, mock_get_next_downloaded_file): queue = utils.CloudFetchQueue(MagicMock(), result_links=[]) - assert queue._create_next_table() == (None, 0) + assert queue._create_next_table() is None assert mock_get_next_downloaded_file.called_with(0) @patch("databricks.sql.utils.create_arrow_table_from_arrow_file", return_value=make_arrow_table()) @@ -79,21 +78,21 @@ def test_initializer_create_next_table_success(self, mock_get_next_downloaded_fi assert mock_create_arrow_table.called_with(b"1234567890", True, schema_bytes, description) assert mock_get_next_downloaded_file.called_with(0) assert queue.table == expected_result - assert queue.table_num_rows == 4 + assert queue.table.num_rows == 4 assert queue.table_row_index == 0 assert queue.start_row_index == 4 - table, row_count = queue._create_next_table() + table = queue._create_next_table() assert table == expected_result - assert row_count == 4 + assert table.num_rows == 4 assert queue.start_row_index == 8 - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=(make_arrow_table(), 4)) + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=make_arrow_table()) def test_next_n_rows_0_rows(self, mock_create_next_table): schema_bytes, description = MagicMock(), MagicMock() queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description) assert queue.table == self.make_arrow_table() - assert queue.table_num_rows == 4 + assert queue.table.num_rows == 4 assert queue.table_row_index == 0 result = queue.next_n_rows(0) @@ -101,12 +100,12 @@ def test_next_n_rows_0_rows(self, mock_create_next_table): assert queue.table_row_index == 0 assert result == self.make_arrow_table()[0:0] - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=(make_arrow_table(), 4)) + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=make_arrow_table()) def test_next_n_rows_partial_table(self, mock_create_next_table): schema_bytes, description = MagicMock(), MagicMock() queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description) assert queue.table == self.make_arrow_table() - assert queue.table_num_rows == 4 + assert queue.table.num_rows == 4 assert queue.table_row_index == 0 result = queue.next_n_rows(3) @@ -114,12 +113,12 @@ def test_next_n_rows_partial_table(self, mock_create_next_table): assert queue.table_row_index == 3 assert result == self.make_arrow_table()[:3] - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=(make_arrow_table(), 4)) + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=make_arrow_table()) def test_next_n_rows_more_than_one_table(self, mock_create_next_table): schema_bytes, description = MagicMock(), MagicMock() queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description) assert queue.table == self.make_arrow_table() - assert queue.table_num_rows == 4 + assert queue.table.num_rows == 4 assert queue.table_row_index == 0 result = queue.next_n_rows(7) @@ -127,12 +126,12 @@ def test_next_n_rows_more_than_one_table(self, mock_create_next_table): assert queue.table_row_index == 3 assert result == pyarrow.concat_tables([self.make_arrow_table(), self.make_arrow_table()])[:7] - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=(make_arrow_table(), 4)) + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=make_arrow_table()) def test_next_n_rows_more_than_one_table(self, mock_create_next_table): schema_bytes, description = MagicMock(), MagicMock() queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description) assert queue.table == self.make_arrow_table() - assert queue.table_num_rows == 4 + assert queue.table.num_rows == 4 assert queue.table_row_index == 0 result = queue.next_n_rows(7) @@ -140,19 +139,19 @@ def test_next_n_rows_more_than_one_table(self, mock_create_next_table): assert queue.table_row_index == 3 assert result == pyarrow.concat_tables([self.make_arrow_table(), self.make_arrow_table()])[:7] - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", side_effect=[(make_arrow_table(), 4), (None, 0)]) + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", side_effect=[make_arrow_table(), None]) def test_next_n_rows_only_one_table_returned(self, mock_create_next_table): schema_bytes, description = MagicMock(), MagicMock() queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description) assert queue.table == self.make_arrow_table() - assert queue.table_num_rows == 4 + assert queue.table.num_rows == 4 assert queue.table_row_index == 0 result = queue.next_n_rows(7) assert result.num_rows == 4 assert result == self.make_arrow_table() - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=(None, 0)) + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=None) def test_next_n_rows_empty_table(self, mock_create_next_table): schema_bytes = self.get_schema_bytes() description = MagicMock() @@ -162,36 +161,36 @@ def test_next_n_rows_empty_table(self, mock_create_next_table): result = queue.next_n_rows(100) assert result == pyarrow.ipc.open_stream(bytearray(schema_bytes)).read_all() - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", side_effect=[(make_arrow_table(), 4), (None, 0)]) + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", side_effect=[make_arrow_table(), None, 0]) def test_remaining_rows_empty_table_fully_returned(self, mock_create_next_table): schema_bytes, description = MagicMock(), MagicMock() queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description) assert queue.table == self.make_arrow_table() - assert queue.table_num_rows == 4 + assert queue.table.num_rows == 4 queue.table_row_index = 4 result = queue.remaining_rows() assert result.num_rows == 0 assert result == self.make_arrow_table()[0:0] - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", side_effect=[(make_arrow_table(), 4), (None, 0)]) + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", side_effect=[make_arrow_table(), None]) def test_remaining_rows_partial_table_fully_returned(self, mock_create_next_table): schema_bytes, description = MagicMock(), MagicMock() queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description) assert queue.table == self.make_arrow_table() - assert queue.table_num_rows == 4 + assert queue.table.num_rows == 4 queue.table_row_index = 2 result = queue.remaining_rows() assert result.num_rows == 2 assert result == self.make_arrow_table()[2:] - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", side_effect=[(make_arrow_table(), 4), (None, 0)]) + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", side_effect=[make_arrow_table(), None]) def test_remaining_rows_one_table_fully_returned(self, mock_create_next_table): schema_bytes, description = MagicMock(), MagicMock() queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description) assert queue.table == self.make_arrow_table() - assert queue.table_num_rows == 4 + assert queue.table.num_rows == 4 assert queue.table_row_index == 0 result = queue.remaining_rows() @@ -199,12 +198,12 @@ def test_remaining_rows_one_table_fully_returned(self, mock_create_next_table): assert result == self.make_arrow_table() @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", - side_effect=[(make_arrow_table(), 4), (make_arrow_table(), 4), (None, 0)]) + side_effect=[make_arrow_table(), make_arrow_table(), None]) def test_remaining_rows_multiple_tables_fully_returned(self, mock_create_next_table): schema_bytes, description = MagicMock(), MagicMock() queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description) assert queue.table == self.make_arrow_table() - assert queue.table_num_rows == 4 + assert queue.table.num_rows == 4 queue.table_row_index = 3 result = queue.remaining_rows() @@ -212,7 +211,7 @@ def test_remaining_rows_multiple_tables_fully_returned(self, mock_create_next_ta assert result.num_rows == 5 assert result == pyarrow.concat_tables([self.make_arrow_table(), self.make_arrow_table()])[3:] - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=(None, 0)) + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=None) def test_remaining_rows_empty_table(self, mock_create_next_table): schema_bytes = self.get_schema_bytes() description = MagicMock() From 2b50597abdf119a37af65def84d28095cdfb3c02 Mon Sep 17 00:00:00 2001 From: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> Date: Mon, 26 Jun 2023 06:11:13 -0700 Subject: [PATCH 04/11] Client-settable max_download_threads Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> --- src/databricks/sql/client.py | 1 + src/databricks/sql/thrift_backend.py | 9 +++++++-- src/databricks/sql/utils.py | 8 +++++--- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index e347e36a8..f9f8738b3 100644 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -839,6 +839,7 @@ def _fill_results_buffer(self): lz4_compressed=self.lz4_compressed, arrow_schema_bytes=self._arrow_schema_bytes, description=self.description, + max_download_threads=self.connection.max_download_threads, ) self.results = results self.has_more_rows = has_more_rows diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index 49f6643ee..5304d4e40 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -656,7 +656,7 @@ def _hive_schema_to_description(t_table_schema): ThriftBackend._col_to_description(col) for col in t_table_schema.columns ] - def _results_message_to_execute_response(self, resp, operation_state): + def _results_message_to_execute_response(self, resp, operation_state, max_download_threads): if resp.directResults and resp.directResults.resultSetMetadata: t_result_set_metadata_resp = resp.directResults.resultSetMetadata else: @@ -703,6 +703,7 @@ def _results_message_to_execute_response(self, resp, operation_state): arrow_schema_bytes=schema_bytes, lz4_compressed=lz4_compressed, description=description, + max_download_threads=max_download_threads, ) else: arrow_queue_opt = None @@ -883,7 +884,9 @@ def _handle_execute_response(self, resp, cursor): resp.directResults and resp.directResults.operationStatus, ) - return self._results_message_to_execute_response(resp, final_operation_state) + max_download_threads = cursor.connection.max_download_threads + + return self._results_message_to_execute_response(resp, final_operation_state, max_download_threads) def fetch_results( self, @@ -894,6 +897,7 @@ def fetch_results( lz4_compressed, arrow_schema_bytes, description, + max_download_threads, ): assert op_handle is not None @@ -924,6 +928,7 @@ def fetch_results( arrow_schema_bytes=arrow_schema_bytes, lz4_compressed=lz4_compressed, description=description, + max_download_threads=max_download_threads, ) return queue, resp.hasMoreRows diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 986a7a9c0..3848a1377 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -35,6 +35,7 @@ def build_queue( arrow_schema_bytes, lz4_compressed: bool = True, description: str = None, + max_download_threads: int = DEFAULT_MAX_DOWNLOAD_THREADS, ) -> ResultSetQueue: if row_set_type == TSparkRowSetType.ARROW_BASED_SET: arrow_table, n_valid_rows = convert_arrow_based_set_to_arrow_table( @@ -54,7 +55,8 @@ def build_queue( start_row_index=t_row_set.startRowOffset, result_links=t_row_set.resultLinks, lz4_compressed=lz4_compressed, - description=description + description=description, + max_download_threads=max_download_threads, ) else: raise AssertionError("Row set type is not valid") @@ -99,6 +101,7 @@ class CloudFetchQueue(ResultSetQueue): def __init__( self, schema_bytes, + max_download_threads: int, start_row_index: int = 0, result_links: List[TSparkArrowResultLink] = None, lz4_compressed: bool = True, @@ -108,11 +111,11 @@ def __init__( A queue-like wrapper over CloudFetch arrow batches """ self.schema_bytes = schema_bytes + self.max_download_threads = max_download_threads self.start_row_index = start_row_index self.result_links = result_links self.lz4_compressed = lz4_compressed self.description = description - self.max_download_threads = DEFAULT_MAX_DOWNLOAD_THREADS self.download_manager = ResultFileDownloadManager(self.max_download_threads, self.lz4_compressed) self.download_manager.add_file_links(result_links, start_row_index) @@ -152,7 +155,6 @@ def remaining_rows(self) -> pyarrow.Table: return results def _create_next_table(self) -> Union[pyarrow.Table, None]: - # TODO: add retry logic from _fill_results_buffer_cloudfetch downloaded_file = self.download_manager.get_next_downloaded_file(self.start_row_index) if not downloaded_file: return None From 504b0008ae8b995e4da9e754a11355377701ac00 Mon Sep 17 00:00:00 2001 From: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> Date: Mon, 26 Jun 2023 20:37:48 -0700 Subject: [PATCH 05/11] Docstrings and comments Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> --- src/databricks/sql/thrift_backend.py | 27 ++++--- src/databricks/sql/utils.py | 108 +++++++++++++++++++++------ tests/unit/test_cloud_fetch_queue.py | 30 ++++---- 3 files changed, 117 insertions(+), 48 deletions(-) diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index 5304d4e40..fd2ad7c20 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -25,7 +25,6 @@ ) from databricks.sql.utils import ( - ArrowQueue, ExecuteResponse, _bound, RequestErrorInfo, @@ -560,14 +559,9 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, descripti ( arrow_table, num_rows, - ) = convert_column_based_set_to_arrow_table( - t_row_set.columns, description - ) + ) = convert_column_based_set_to_arrow_table(t_row_set.columns, description) elif t_row_set.arrowBatches is not None: - ( - arrow_table, - num_rows, - ) = convert_arrow_based_set_to_arrow_table( + (arrow_table, num_rows,) = convert_arrow_based_set_to_arrow_table( t_row_set.arrowBatches, lz4_compressed, schema_bytes ) else: @@ -656,7 +650,9 @@ def _hive_schema_to_description(t_table_schema): ThriftBackend._col_to_description(col) for col in t_table_schema.columns ] - def _results_message_to_execute_response(self, resp, operation_state, max_download_threads): + def _results_message_to_execute_response( + self, resp, operation_state, max_download_threads + ): if resp.directResults and resp.directResults.resultSetMetadata: t_result_set_metadata_resp = resp.directResults.resultSetMetadata else: @@ -758,7 +754,14 @@ def _check_direct_results_for_error(t_spark_direct_results): ) def execute_command( - self, operation, session_handle, max_rows, max_bytes, lz4_compression, cursor, use_cloud_fetch=False + self, + operation, + session_handle, + max_rows, + max_bytes, + lz4_compression, + cursor, + use_cloud_fetch=False, ): assert session_handle is not None @@ -886,7 +889,9 @@ def _handle_execute_response(self, resp, cursor): max_download_threads = cursor.connection.max_download_threads - return self._results_message_to_execute_response(resp, final_operation_state, max_download_threads) + return self._results_message_to_execute_response( + resp, final_operation_state, max_download_threads + ) def fetch_results( self, diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 3848a1377..255eb7f4e 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -11,7 +11,11 @@ from databricks.sql import exc, OperationalError from databricks.sql.cloudfetch.download_manager import ResultFileDownloadManager -from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink, TSparkRowSetType, TRowSet +from databricks.sql.thrift_api.TCLIService.ttypes import ( + TSparkArrowResultLink, + TSparkRowSetType, + TRowSet, +) DEFAULT_MAX_DOWNLOAD_THREADS = 10 BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128] @@ -32,27 +36,45 @@ class ResultSetQueueFactory(ABC): def build_queue( row_set_type: TSparkRowSetType, t_row_set: TRowSet, - arrow_schema_bytes, + arrow_schema_bytes: bytes, lz4_compressed: bool = True, - description: str = None, + description: List[List[any]] = None, max_download_threads: int = DEFAULT_MAX_DOWNLOAD_THREADS, ) -> ResultSetQueue: + """ + Factory method to build a result set queue. + + Args: + row_set_type (enum): Row set type (Arrow, Column, or URL). + t_row_set (TRowSet): Result containing arrow batches, columns, or cloud fetch links. + arrow_schema_bytes (bytes): Bytes representing the arrow schema. + lz4_compressed (bool): Whether result data has been lz4 compressed. + description (List[List[any]]): Hive table schema description. + max_download_threads (int): Maximum number of downloader thread pool threads. + + Returns: + ResultSetQueue + """ if row_set_type == TSparkRowSetType.ARROW_BASED_SET: arrow_table, n_valid_rows = convert_arrow_based_set_to_arrow_table( t_row_set.arrowBatches, lz4_compressed, arrow_schema_bytes ) - converted_arrow_table = convert_decimals_in_arrow_table(arrow_table, description) + converted_arrow_table = convert_decimals_in_arrow_table( + arrow_table, description + ) return ArrowQueue(converted_arrow_table, n_valid_rows) elif row_set_type == TSparkRowSetType.COLUMN_BASED_SET: arrow_table, n_valid_rows = convert_column_based_set_to_arrow_table( t_row_set.columns, description ) - converted_arrow_table = convert_decimals_in_arrow_table(arrow_table, description) + converted_arrow_table = convert_decimals_in_arrow_table( + arrow_table, description + ) return ArrowQueue(converted_arrow_table, n_valid_rows) elif row_set_type == TSparkRowSetType.URL_BASED_SET: return CloudFetchQueue( arrow_schema_bytes, - start_row_index=t_row_set.startRowOffset, + start_row_offset=t_row_set.startRowOffset, result_links=t_row_set.resultLinks, lz4_compressed=lz4_compressed, description=description, @@ -102,37 +124,59 @@ def __init__( self, schema_bytes, max_download_threads: int, - start_row_index: int = 0, + start_row_offset: int = 0, result_links: List[TSparkArrowResultLink] = None, lz4_compressed: bool = True, - description: str = None, + description: List[List[any]] = None, ): """ - A queue-like wrapper over CloudFetch arrow batches + A queue-like wrapper over CloudFetch arrow batches. + + Attributes: + schema_bytes (bytes): Table schema in bytes. + max_download_threads (int): Maximum number of downloader thread pool threads. + start_row_offset (int): The offset of the first row of the cloud fetch links. + result_links (List[TSparkArrowResultLink]): Links containing the downloadable URL and metadata. + lz4_compressed (bool): Whether the files are lz4 compressed. + description (List[List[any]]): Hive table schema description. """ self.schema_bytes = schema_bytes self.max_download_threads = max_download_threads - self.start_row_index = start_row_index + self.start_row_index = start_row_offset self.result_links = result_links self.lz4_compressed = lz4_compressed self.description = description - self.download_manager = ResultFileDownloadManager(self.max_download_threads, self.lz4_compressed) - self.download_manager.add_file_links(result_links, start_row_index) + self.download_manager = ResultFileDownloadManager( + self.max_download_threads, self.lz4_compressed + ) + self.download_manager.add_file_links(result_links) self.table = self._create_next_table() self.table_row_index = 0 def next_n_rows(self, num_rows: int) -> pyarrow.Table: + """ + Get up to the next n rows of the cloud fetch Arrow dataframes. + + Args: + num_rows (int): Number of rows to retrieve. + + Returns: + pyarrow.Table + """ if not self.table: # Return empty pyarrow table to cause retry of fetch return self._create_empty_table() results = self.table.slice(0, 0) while num_rows > 0 and self.table: + # Get remaining of num_rows or the rest of the current table, whichever is smaller length = min(num_rows, self.table.num_rows - self.table_row_index) table_slice = self.table.slice(self.table_row_index, length) results = pyarrow.concat_tables([results, table_slice]) self.table_row_index += table_slice.num_rows + + # Replace current table with the next table if we are at the end of the current table if self.table_row_index == self.table.num_rows: self.table = self._create_next_table() self.table_row_index = 0 @@ -140,6 +184,12 @@ def next_n_rows(self, num_rows: int) -> pyarrow.Table: return results def remaining_rows(self) -> pyarrow.Table: + """ + Get all remaining rows of the cloud fetch Arrow dataframes. + + Returns: + pyarrow.Table + """ if not self.table: # Return empty pyarrow table to cause retry of fetch return self._create_empty_table() @@ -155,18 +205,30 @@ def remaining_rows(self) -> pyarrow.Table: return results def _create_next_table(self) -> Union[pyarrow.Table, None]: - downloaded_file = self.download_manager.get_next_downloaded_file(self.start_row_index) + # Create next table by retrieving the logical next downloaded file, or return None to signal end of queue + downloaded_file = self.download_manager.get_next_downloaded_file( + self.start_row_index + ) if not downloaded_file: + # None signals no more Arrow tables can be built from the remaining handlers if any remain return None - arrow_table = create_arrow_table_from_arrow_file(downloaded_file.file_bytes, self.description) + arrow_table = create_arrow_table_from_arrow_file( + downloaded_file.file_bytes, self.description + ) + + # The server rarely prepares the exact number of rows requested by the client in cloud fetch. + # Subsequently, we drop the extraneous rows in the last file if more rows are retrieved than requested if arrow_table.num_rows > downloaded_file.row_count: self.start_row_index += downloaded_file.row_count return arrow_table.slice(0, downloaded_file.row_count) + + # At this point, whether the file has extraneous rows or not, the arrow table should have the correct num rows assert downloaded_file.row_count == arrow_table.num_rows self.start_row_index += arrow_table.num_rows return arrow_table def _create_empty_table(self) -> pyarrow.Table: + # Create a 0-row table with just the schema bytes return create_arrow_table_from_arrow_file(self.schema_bytes, self.description) @@ -315,7 +377,9 @@ def inject_parameters(operation: str, parameters: Dict[str, str]): return operation % parameters -def create_arrow_table_from_arrow_file(file_bytes: bytes, description) -> (pyarrow.Table, int): +def create_arrow_table_from_arrow_file( + file_bytes: bytes, description +) -> (pyarrow.Table, int): arrow_table = convert_arrow_based_file_to_arrow_table(file_bytes) return convert_decimals_in_arrow_table(arrow_table, description) @@ -327,15 +391,17 @@ def convert_arrow_based_file_to_arrow_table(file_bytes: bytes): raise RuntimeError("Failure to convert arrow based file to arrow table", e) -def convert_arrow_based_set_to_arrow_table( - arrow_batches, lz4_compressed, schema_bytes -): +def convert_arrow_based_set_to_arrow_table(arrow_batches, lz4_compressed, schema_bytes): ba = bytearray() ba += schema_bytes n_rows = 0 for arrow_batch in arrow_batches: n_rows += arrow_batch.rowCount - ba += lz4.frame.decompress(arrow_batch.batch) if lz4_compressed else arrow_batch.batch + ba += ( + lz4.frame.decompress(arrow_batch.batch) + if lz4_compressed + else arrow_batch.batch + ) arrow_table = pyarrow.ipc.open_stream(ba).read_all() return arrow_table, n_rows @@ -388,9 +454,7 @@ def _convert_column_to_arrow_array(t_col): for field in field_name_to_arrow_type.keys(): wrapper = getattr(t_col, field) if wrapper: - return _create_arrow_array( - wrapper, field_name_to_arrow_type[field] - ) + return _create_arrow_array(wrapper, field_name_to_arrow_type[field]) raise OperationalError("Empty TColumn instance {}".format(t_col)) diff --git a/tests/unit/test_cloud_fetch_queue.py b/tests/unit/test_cloud_fetch_queue.py index 4ffa8e13a..7c8e4bf40 100644 --- a/tests/unit/test_cloud_fetch_queue.py +++ b/tests/unit/test_cloud_fetch_queue.py @@ -47,7 +47,7 @@ def get_schema_bytes(): def test_initializer_adds_links(self, mock_create_next_table): schema_bytes = MagicMock() result_links = self.create_result_links(10) - queue = utils.CloudFetchQueue(schema_bytes, result_links=result_links) + queue = utils.CloudFetchQueue(schema_bytes, result_links=result_links, max_download_threads=10) assert len(queue.download_manager.download_handlers) == 10 mock_create_next_table.assert_called() @@ -55,14 +55,14 @@ def test_initializer_adds_links(self, mock_create_next_table): def test_initializer_no_links_to_add(self): schema_bytes = MagicMock() result_links = [] - queue = utils.CloudFetchQueue(schema_bytes, result_links=result_links) + queue = utils.CloudFetchQueue(schema_bytes, result_links=result_links, max_download_threads=10) assert len(queue.download_manager.download_handlers) == 0 assert queue.table is None @patch("databricks.sql.cloudfetch.download_manager.ResultFileDownloadManager.get_next_downloaded_file", return_value=None) def test_create_next_table_no_download(self, mock_get_next_downloaded_file): - queue = utils.CloudFetchQueue(MagicMock(), result_links=[]) + queue = utils.CloudFetchQueue(MagicMock(), result_links=[], max_download_threads=10) assert queue._create_next_table() is None assert mock_get_next_downloaded_file.called_with(0) @@ -72,7 +72,7 @@ def test_create_next_table_no_download(self, mock_get_next_downloaded_file): return_value=MagicMock(file_bytes=b"1234567890", row_count=4)) def test_initializer_create_next_table_success(self, mock_get_next_downloaded_file, mock_create_arrow_table): schema_bytes, description = MagicMock(), MagicMock() - queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description) + queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10) expected_result = self.make_arrow_table() assert mock_create_arrow_table.called_with(b"1234567890", True, schema_bytes, description) @@ -90,7 +90,7 @@ def test_initializer_create_next_table_success(self, mock_get_next_downloaded_fi @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=make_arrow_table()) def test_next_n_rows_0_rows(self, mock_create_next_table): schema_bytes, description = MagicMock(), MagicMock() - queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description) + queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 assert queue.table_row_index == 0 @@ -103,7 +103,7 @@ def test_next_n_rows_0_rows(self, mock_create_next_table): @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=make_arrow_table()) def test_next_n_rows_partial_table(self, mock_create_next_table): schema_bytes, description = MagicMock(), MagicMock() - queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description) + queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 assert queue.table_row_index == 0 @@ -116,7 +116,7 @@ def test_next_n_rows_partial_table(self, mock_create_next_table): @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=make_arrow_table()) def test_next_n_rows_more_than_one_table(self, mock_create_next_table): schema_bytes, description = MagicMock(), MagicMock() - queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description) + queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 assert queue.table_row_index == 0 @@ -129,7 +129,7 @@ def test_next_n_rows_more_than_one_table(self, mock_create_next_table): @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=make_arrow_table()) def test_next_n_rows_more_than_one_table(self, mock_create_next_table): schema_bytes, description = MagicMock(), MagicMock() - queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description) + queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 assert queue.table_row_index == 0 @@ -142,7 +142,7 @@ def test_next_n_rows_more_than_one_table(self, mock_create_next_table): @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", side_effect=[make_arrow_table(), None]) def test_next_n_rows_only_one_table_returned(self, mock_create_next_table): schema_bytes, description = MagicMock(), MagicMock() - queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description) + queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 assert queue.table_row_index == 0 @@ -155,7 +155,7 @@ def test_next_n_rows_only_one_table_returned(self, mock_create_next_table): def test_next_n_rows_empty_table(self, mock_create_next_table): schema_bytes = self.get_schema_bytes() description = MagicMock() - queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description) + queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10) assert queue.table is None result = queue.next_n_rows(100) @@ -164,7 +164,7 @@ def test_next_n_rows_empty_table(self, mock_create_next_table): @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", side_effect=[make_arrow_table(), None, 0]) def test_remaining_rows_empty_table_fully_returned(self, mock_create_next_table): schema_bytes, description = MagicMock(), MagicMock() - queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description) + queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 queue.table_row_index = 4 @@ -176,7 +176,7 @@ def test_remaining_rows_empty_table_fully_returned(self, mock_create_next_table) @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", side_effect=[make_arrow_table(), None]) def test_remaining_rows_partial_table_fully_returned(self, mock_create_next_table): schema_bytes, description = MagicMock(), MagicMock() - queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description) + queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 queue.table_row_index = 2 @@ -188,7 +188,7 @@ def test_remaining_rows_partial_table_fully_returned(self, mock_create_next_tabl @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", side_effect=[make_arrow_table(), None]) def test_remaining_rows_one_table_fully_returned(self, mock_create_next_table): schema_bytes, description = MagicMock(), MagicMock() - queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description) + queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 assert queue.table_row_index == 0 @@ -201,7 +201,7 @@ def test_remaining_rows_one_table_fully_returned(self, mock_create_next_table): side_effect=[make_arrow_table(), make_arrow_table(), None]) def test_remaining_rows_multiple_tables_fully_returned(self, mock_create_next_table): schema_bytes, description = MagicMock(), MagicMock() - queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description) + queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10) assert queue.table == self.make_arrow_table() assert queue.table.num_rows == 4 queue.table_row_index = 3 @@ -215,7 +215,7 @@ def test_remaining_rows_multiple_tables_fully_returned(self, mock_create_next_ta def test_remaining_rows_empty_table(self, mock_create_next_table): schema_bytes = self.get_schema_bytes() description = MagicMock() - queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description) + queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10) assert queue.table is None result = queue.remaining_rows() From 0ccb63f40aab13b30045ca85f2361b573dd508ee Mon Sep 17 00:00:00 2001 From: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> Date: Mon, 3 Jul 2023 11:21:15 -0700 Subject: [PATCH 06/11] Increase default buffer size bytes to 104857600 Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> --- src/databricks/sql/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index f9f8738b3..a8b1a6780 100644 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -17,7 +17,7 @@ logger = logging.getLogger(__name__) -DEFAULT_RESULT_BUFFER_SIZE_BYTES = 10485760 +DEFAULT_RESULT_BUFFER_SIZE_BYTES = 104857600 DEFAULT_ARRAY_SIZE = 100000 From de99ec86f68448769e494d531e9b1f9cb265f002 Mon Sep 17 00:00:00 2001 From: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> Date: Mon, 3 Jul 2023 12:17:02 -0700 Subject: [PATCH 07/11] Move max_download_threads to kwargs of ThriftBackend, fix unit tests Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> --- src/databricks/sql/client.py | 5 ----- src/databricks/sql/thrift_backend.py | 16 +++++++++------- src/databricks/sql/utils.py | 3 +-- tests/unit/test_thrift_backend.py | 9 ++++++--- 4 files changed, 16 insertions(+), 17 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index a8b1a6780..aa628441b 100644 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -155,8 +155,6 @@ def read(self) -> Optional[OAuthToken]: # (True by default) # use_cloud_fetch # Enable use of cloud fetch to extract large query results in parallel via cloud storage - # max_download_threads - # Number of threads for handling cloud fetch downloads. Defaults to 10 if access_token: access_token_kv = {"access_token": access_token} @@ -194,7 +192,6 @@ def read(self) -> Optional[OAuthToken]: session_configuration, catalog, schema ) self.use_cloud_fetch = kwargs.get("use_cloud_fetch", False) - self.max_download_threads = kwargs.get("max_download_threads", 10) self.open = True logger.info("Successfully opened session " + str(self.get_session_id_hex())) self._cursors = [] # type: List[Cursor] @@ -811,7 +808,6 @@ def __init__( self.description = execute_response.description self._arrow_schema_bytes = execute_response.arrow_schema_bytes self._next_row_index = 0 - self.results = None if execute_response.arrow_queue: # In this case the server has taken the fast path and returned an initial batch of @@ -839,7 +835,6 @@ def _fill_results_buffer(self): lz4_compressed=self.lz4_compressed, arrow_schema_bytes=self._arrow_schema_bytes, description=self.description, - max_download_threads=self.connection.max_download_threads, ) self.results = results self.has_more_rows = has_more_rows diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index fd2ad7c20..7d57008cc 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -116,6 +116,8 @@ def __init__( # _socket_timeout # The timeout in seconds for socket send, recv and connect operations. Should be a positive float or integer. # (defaults to 900) + # max_download_threads + # Number of threads for handling cloud fetch downloads. Defaults to 10 port = port or 443 if kwargs.get("_connection_uri"): @@ -137,6 +139,9 @@ def __init__( "_use_arrow_native_timestamps", True ) + # Cloud fetch + self.max_download_threads = kwargs.get("max_download_threads", 10) + # Configure tls context ssl_context = create_default_context(cafile=kwargs.get("_tls_trusted_ca_file")) if kwargs.get("_tls_no_verify") is True: @@ -651,7 +656,7 @@ def _hive_schema_to_description(t_table_schema): ] def _results_message_to_execute_response( - self, resp, operation_state, max_download_threads + self, resp, operation_state ): if resp.directResults and resp.directResults.resultSetMetadata: t_result_set_metadata_resp = resp.directResults.resultSetMetadata @@ -697,9 +702,9 @@ def _results_message_to_execute_response( row_set_type=t_result_set_metadata_resp.resultFormat, t_row_set=direct_results.resultSet.results, arrow_schema_bytes=schema_bytes, + max_download_threads=self.max_download_threads, lz4_compressed=lz4_compressed, description=description, - max_download_threads=max_download_threads, ) else: arrow_queue_opt = None @@ -887,10 +892,8 @@ def _handle_execute_response(self, resp, cursor): resp.directResults and resp.directResults.operationStatus, ) - max_download_threads = cursor.connection.max_download_threads - return self._results_message_to_execute_response( - resp, final_operation_state, max_download_threads + resp, final_operation_state ) def fetch_results( @@ -902,7 +905,6 @@ def fetch_results( lz4_compressed, arrow_schema_bytes, description, - max_download_threads, ): assert op_handle is not None @@ -931,9 +933,9 @@ def fetch_results( row_set_type=resp.resultSetMetadata.resultFormat, t_row_set=resp.results, arrow_schema_bytes=arrow_schema_bytes, + max_download_threads=self.max_download_threads, lz4_compressed=lz4_compressed, description=description, - max_download_threads=max_download_threads, ) return queue, resp.hasMoreRows diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index 255eb7f4e..c03c32ca4 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -17,7 +17,6 @@ TRowSet, ) -DEFAULT_MAX_DOWNLOAD_THREADS = 10 BIT_MASKS = [1, 2, 4, 8, 16, 32, 64, 128] @@ -37,9 +36,9 @@ def build_queue( row_set_type: TSparkRowSetType, t_row_set: TRowSet, arrow_schema_bytes: bytes, + max_download_threads: int, lz4_compressed: bool = True, description: List[List[any]] = None, - max_download_threads: int = DEFAULT_MAX_DOWNLOAD_THREADS, ) -> ResultSetQueue: """ Factory method to build a result set queue. diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index b44057e99..0a18c39a4 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -327,7 +327,8 @@ def test_handle_execute_response_checks_operation_state_in_direct_results(self): thrift_backend._handle_execute_response(t_execute_resp, Mock()) self.assertIn("some information about the error", str(cm.exception)) - def test_handle_execute_response_sets_compression_in_direct_results(self): + @patch("databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()) + def test_handle_execute_response_sets_compression_in_direct_results(self, build_queue): for resp_type in self.execute_response_types: lz4Compressed=Mock() resultSet=MagicMock() @@ -589,9 +590,10 @@ def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class): self.assertEqual(hive_schema_mock, thrift_backend._hive_schema_to_arrow_schema.call_args[0][0]) + @patch("databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()) @patch("databricks.sql.thrift_backend.TCLIService.Client") def test_handle_execute_response_reads_has_more_rows_in_direct_results( - self, tcli_service_class): + self, tcli_service_class, build_queue): for has_more_rows, resp_type in itertools.product([True, False], self.execute_response_types): with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type): @@ -622,9 +624,10 @@ def test_handle_execute_response_reads_has_more_rows_in_direct_results( self.assertEqual(has_more_rows, execute_response.has_more_rows) + @patch("databricks.sql.utils.ResultSetQueueFactory.build_queue", return_value=Mock()) @patch("databricks.sql.thrift_backend.TCLIService.Client") def test_handle_execute_response_reads_has_more_rows_in_result_response( - self, tcli_service_class): + self, tcli_service_class, build_queue): for has_more_rows, resp_type in itertools.product([True, False], self.execute_response_types): with self.subTest(has_more_rows=has_more_rows, resp_type=resp_type): From 6f868d0439abcce81b5686dda35b036086067df3 Mon Sep 17 00:00:00 2001 From: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> Date: Mon, 3 Jul 2023 12:29:46 -0700 Subject: [PATCH 08/11] Fix tests: staticmethod make_arrow_table mock not callable Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> --- tests/unit/test_cloud_fetch_queue.py | 31 ++++++++++++++++++---------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/tests/unit/test_cloud_fetch_queue.py b/tests/unit/test_cloud_fetch_queue.py index 7c8e4bf40..e5611ce62 100644 --- a/tests/unit/test_cloud_fetch_queue.py +++ b/tests/unit/test_cloud_fetch_queue.py @@ -67,10 +67,11 @@ def test_create_next_table_no_download(self, mock_get_next_downloaded_file): assert queue._create_next_table() is None assert mock_get_next_downloaded_file.called_with(0) - @patch("databricks.sql.utils.create_arrow_table_from_arrow_file", return_value=make_arrow_table()) + @patch("databricks.sql.utils.create_arrow_table_from_arrow_file") @patch("databricks.sql.cloudfetch.download_manager.ResultFileDownloadManager.get_next_downloaded_file", return_value=MagicMock(file_bytes=b"1234567890", row_count=4)) def test_initializer_create_next_table_success(self, mock_get_next_downloaded_file, mock_create_arrow_table): + mock_create_arrow_table.return_value = self.make_arrow_table() schema_bytes, description = MagicMock(), MagicMock() queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10) expected_result = self.make_arrow_table() @@ -87,8 +88,9 @@ def test_initializer_create_next_table_success(self, mock_get_next_downloaded_fi assert table.num_rows == 4 assert queue.start_row_index == 8 - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=make_arrow_table()) + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") def test_next_n_rows_0_rows(self, mock_create_next_table): + mock_create_next_table.return_value = self.make_arrow_table() schema_bytes, description = MagicMock(), MagicMock() queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10) assert queue.table == self.make_arrow_table() @@ -100,8 +102,9 @@ def test_next_n_rows_0_rows(self, mock_create_next_table): assert queue.table_row_index == 0 assert result == self.make_arrow_table()[0:0] - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=make_arrow_table()) + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") def test_next_n_rows_partial_table(self, mock_create_next_table): + mock_create_next_table.return_value = self.make_arrow_table() schema_bytes, description = MagicMock(), MagicMock() queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10) assert queue.table == self.make_arrow_table() @@ -113,8 +116,9 @@ def test_next_n_rows_partial_table(self, mock_create_next_table): assert queue.table_row_index == 3 assert result == self.make_arrow_table()[:3] - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=make_arrow_table()) + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") def test_next_n_rows_more_than_one_table(self, mock_create_next_table): + mock_create_next_table.return_value = self.make_arrow_table() schema_bytes, description = MagicMock(), MagicMock() queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10) assert queue.table == self.make_arrow_table() @@ -126,8 +130,9 @@ def test_next_n_rows_more_than_one_table(self, mock_create_next_table): assert queue.table_row_index == 3 assert result == pyarrow.concat_tables([self.make_arrow_table(), self.make_arrow_table()])[:7] - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", return_value=make_arrow_table()) + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") def test_next_n_rows_more_than_one_table(self, mock_create_next_table): + mock_create_next_table.return_value = self.make_arrow_table() schema_bytes, description = MagicMock(), MagicMock() queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10) assert queue.table == self.make_arrow_table() @@ -139,8 +144,9 @@ def test_next_n_rows_more_than_one_table(self, mock_create_next_table): assert queue.table_row_index == 3 assert result == pyarrow.concat_tables([self.make_arrow_table(), self.make_arrow_table()])[:7] - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", side_effect=[make_arrow_table(), None]) + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") def test_next_n_rows_only_one_table_returned(self, mock_create_next_table): + mock_create_next_table.side_effect = [self.make_arrow_table(), None] schema_bytes, description = MagicMock(), MagicMock() queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10) assert queue.table == self.make_arrow_table() @@ -161,8 +167,9 @@ def test_next_n_rows_empty_table(self, mock_create_next_table): result = queue.next_n_rows(100) assert result == pyarrow.ipc.open_stream(bytearray(schema_bytes)).read_all() - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", side_effect=[make_arrow_table(), None, 0]) + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") def test_remaining_rows_empty_table_fully_returned(self, mock_create_next_table): + mock_create_next_table.side_effect = [self.make_arrow_table(), None, 0] schema_bytes, description = MagicMock(), MagicMock() queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10) assert queue.table == self.make_arrow_table() @@ -173,8 +180,9 @@ def test_remaining_rows_empty_table_fully_returned(self, mock_create_next_table) assert result.num_rows == 0 assert result == self.make_arrow_table()[0:0] - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", side_effect=[make_arrow_table(), None]) + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") def test_remaining_rows_partial_table_fully_returned(self, mock_create_next_table): + mock_create_next_table.side_effect = [self.make_arrow_table(), None] schema_bytes, description = MagicMock(), MagicMock() queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10) assert queue.table == self.make_arrow_table() @@ -185,8 +193,9 @@ def test_remaining_rows_partial_table_fully_returned(self, mock_create_next_tabl assert result.num_rows == 2 assert result == self.make_arrow_table()[2:] - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", side_effect=[make_arrow_table(), None]) + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") def test_remaining_rows_one_table_fully_returned(self, mock_create_next_table): + mock_create_next_table.side_effect = [self.make_arrow_table(), None] schema_bytes, description = MagicMock(), MagicMock() queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10) assert queue.table == self.make_arrow_table() @@ -197,9 +206,9 @@ def test_remaining_rows_one_table_fully_returned(self, mock_create_next_table): assert result.num_rows == 4 assert result == self.make_arrow_table() - @patch("databricks.sql.utils.CloudFetchQueue._create_next_table", - side_effect=[make_arrow_table(), make_arrow_table(), None]) + @patch("databricks.sql.utils.CloudFetchQueue._create_next_table") def test_remaining_rows_multiple_tables_fully_returned(self, mock_create_next_table): + mock_create_next_table.side_effect = [self.make_arrow_table(), self.make_arrow_table(), None] schema_bytes, description = MagicMock(), MagicMock() queue = utils.CloudFetchQueue(schema_bytes, result_links=[], description=description, max_download_threads=10) assert queue.table == self.make_arrow_table() From 9096ccd31512014444c8a658e9936e15817fca0e Mon Sep 17 00:00:00 2001 From: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> Date: Mon, 3 Jul 2023 12:48:57 -0700 Subject: [PATCH 09/11] cancel_futures in shutdown() only available in python >=3.9.0 Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> --- src/databricks/sql/cloudfetch/download_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/cloudfetch/download_manager.py b/src/databricks/sql/cloudfetch/download_manager.py index aac3ac336..9a997f393 100644 --- a/src/databricks/sql/cloudfetch/download_manager.py +++ b/src/databricks/sql/cloudfetch/download_manager.py @@ -161,6 +161,6 @@ def _check_if_download_successful(self, handler: ResultSetDownloadHandler): return True def _shutdown_manager(self): - # Clear download handlers and shutdown the thread pool to cancel pending futures + # Clear download handlers and shutdown the thread pool self.download_handlers = [] - self.thread_pool.shutdown(wait=False, cancel_futures=True) + self.thread_pool.shutdown(wait=False) From edcf826732b7aa115f5ac95f2145fee912bd84df Mon Sep 17 00:00:00 2001 From: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> Date: Mon, 3 Jul 2023 12:56:55 -0700 Subject: [PATCH 10/11] Black linting Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> --- src/databricks/sql/thrift_backend.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index 7d57008cc..ef225d1f5 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -655,9 +655,7 @@ def _hive_schema_to_description(t_table_schema): ThriftBackend._col_to_description(col) for col in t_table_schema.columns ] - def _results_message_to_execute_response( - self, resp, operation_state - ): + def _results_message_to_execute_response(self, resp, operation_state): if resp.directResults and resp.directResults.resultSetMetadata: t_result_set_metadata_resp = resp.directResults.resultSetMetadata else: @@ -892,9 +890,7 @@ def _handle_execute_response(self, resp, cursor): resp.directResults and resp.directResults.operationStatus, ) - return self._results_message_to_execute_response( - resp, final_operation_state - ) + return self._results_message_to_execute_response(resp, final_operation_state) def fetch_results( self, From 19e6a66047fae0f2b1f2496aebef0de968714339 Mon Sep 17 00:00:00 2001 From: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> Date: Mon, 3 Jul 2023 13:02:52 -0700 Subject: [PATCH 11/11] Fix typing errors Signed-off-by: Matthew Kim <11141331+mattdeekay@users.noreply.github.com> --- src/databricks/sql/utils.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/databricks/sql/utils.py b/src/databricks/sql/utils.py index c03c32ca4..0aefc7a1c 100644 --- a/src/databricks/sql/utils.py +++ b/src/databricks/sql/utils.py @@ -6,7 +6,7 @@ import decimal from enum import Enum import lz4.frame -from typing import Dict, List, Union +from typing import Dict, List, Union, Any import pyarrow from databricks.sql import exc, OperationalError @@ -38,7 +38,7 @@ def build_queue( arrow_schema_bytes: bytes, max_download_threads: int, lz4_compressed: bool = True, - description: List[List[any]] = None, + description: List[List[Any]] = None, ) -> ResultSetQueue: """ Factory method to build a result set queue. @@ -48,7 +48,7 @@ def build_queue( t_row_set (TRowSet): Result containing arrow batches, columns, or cloud fetch links. arrow_schema_bytes (bytes): Bytes representing the arrow schema. lz4_compressed (bool): Whether result data has been lz4 compressed. - description (List[List[any]]): Hive table schema description. + description (List[List[Any]]): Hive table schema description. max_download_threads (int): Maximum number of downloader thread pool threads. Returns: @@ -126,7 +126,7 @@ def __init__( start_row_offset: int = 0, result_links: List[TSparkArrowResultLink] = None, lz4_compressed: bool = True, - description: List[List[any]] = None, + description: List[List[Any]] = None, ): """ A queue-like wrapper over CloudFetch arrow batches. @@ -137,7 +137,7 @@ def __init__( start_row_offset (int): The offset of the first row of the cloud fetch links. result_links (List[TSparkArrowResultLink]): Links containing the downloadable URL and metadata. lz4_compressed (bool): Whether the files are lz4 compressed. - description (List[List[any]]): Hive table schema description. + description (List[List[Any]]): Hive table schema description. """ self.schema_bytes = schema_bytes self.max_download_threads = max_download_threads @@ -376,9 +376,7 @@ def inject_parameters(operation: str, parameters: Dict[str, str]): return operation % parameters -def create_arrow_table_from_arrow_file( - file_bytes: bytes, description -) -> (pyarrow.Table, int): +def create_arrow_table_from_arrow_file(file_bytes: bytes, description) -> pyarrow.Table: arrow_table = convert_arrow_based_file_to_arrow_table(file_bytes) return convert_decimals_in_arrow_table(arrow_table, description)