diff --git a/src/databricks/sql/cloudfetch/downloader.py b/src/databricks/sql/cloudfetch/downloader.py
new file mode 100644
index 000000000..d3c4a480f
--- /dev/null
+++ b/src/databricks/sql/cloudfetch/downloader.py
@@ -0,0 +1,178 @@
+import logging
+
+import requests
+import lz4.frame
+import threading
+import time
+
+from databricks.sql.thrift_api.TCLIService.ttypes import TSparkArrowResultLink
+
+logger = logging.getLogger(__name__)
+
+
+class ResultSetDownloadHandler(threading.Thread):
+    """
+    Thread that downloads one cloud fetch result file and optionally lz4-decompresses it.
+
+    The outcome is exposed through the is_* flags and result_file; waiting threads are
+    woken through the is_download_finished event.
+    """
+
+    def __init__(
+        self,
+        downloadable_result_settings,
+        t_spark_arrow_result_link: TSparkArrowResultLink,
+    ):
+        super().__init__()
+        self.settings = downloadable_result_settings
+        self.result_link = t_spark_arrow_result_link
+        self.is_download_scheduled = False
+        # Created once and only cleared/set afterwards, so a thread that started waiting
+        # before run() began is still woken when the download completes.
+        self.is_download_finished = threading.Event()
+        self.is_file_downloaded_successfully = False
+        self.is_link_expired = False
+        self.is_download_timedout = False
+        self.result_file = None
+
+    def is_file_download_successful(self) -> bool:
+        """
+        Check and report if cloud fetch file downloaded successfully.
+
+        This function will block until a file download finishes or until a timeout.
+        An unset or non-positive download_timeout setting means waiting without a limit.
+        """
+        timeout = self._effective_timeout()
+        try:
+            if not self.is_download_finished.wait(timeout=timeout):
+                self.is_download_timedout = True
+                logger.debug(
+                    "Cloud fetch download timed out after {} seconds for link representing rows {} to {}".format(
+                        self.settings.download_timeout,
+                        self.result_link.startRowOffset,
+                        self.result_link.startRowOffset + self.result_link.rowCount,
+                    )
+                )
+                return False
+        except Exception as e:
+            logger.error(e)
+            return False
+        return self.is_file_downloaded_successfully
+
+    def run(self):
+        """
+        Download the file described in the cloud fetch link.
+
+        This function checks if the link has or is expiring, gets the file via a requests session, decompresses the
+        file, and signals to waiting threads that the download is finished and whether it was successful.
+        """
+        self._reset()
+
+        # Check if link is already expired or is expiring
+        if ResultSetDownloadHandler.check_link_expired(
+            self.result_link, self.settings.link_expiry_buffer_secs
+        ):
+            self.is_link_expired = True
+            # Nothing will be downloaded, so wake waiters now rather than leaving them
+            # blocked until their own timeout (or forever when no timeout is configured)
+            self.is_download_finished.set()
+            return
+
+        session = requests.Session()
+
+        try:
+            # Get the file via HTTP request. The timeout has to be passed per request:
+            # requests does not honour a "timeout" attribute assigned on the Session.
+            response = session.get(
+                self.result_link.fileLink, timeout=self._effective_timeout()
+            )
+
+            if not response.ok:
+                self.is_file_downloaded_successfully = False
+                return
+
+            # Save (and decompress if needed) the downloaded file
+            compressed_data = response.content
+            decompressed_data = (
+                ResultSetDownloadHandler.decompress_data(compressed_data)
+                if self.settings.is_lz4_compressed
+                else compressed_data
+            )
+            self.result_file = decompressed_data
+
+            # The size of the downloaded file should match the size specified from TSparkArrowResultLink
+            self.is_file_downloaded_successfully = (
+                len(self.result_file) == self.result_link.bytesNum
+            )
+        except Exception as e:
+            logger.error(e)
+            self.is_file_downloaded_successfully = False
+
+        finally:
+            session.close()
+            # Awaken threads waiting for this to be true which signals the run is complete
+            self.is_download_finished.set()
+
+    def _reset(self):
+        """
+        Reset download-related flags for every retry of run()
+        """
+        self.is_file_downloaded_successfully = False
+        self.is_link_expired = False
+        self.is_download_timedout = False
+        # Clear instead of replacing the event: threads may already be waiting on this object
+        self.is_download_finished.clear()
+
+    def _effective_timeout(self):
+        """
+        Return the configured download timeout in seconds, or None if it is unset or not positive.
+        """
+        timeout = self.settings.download_timeout
+        return timeout if timeout and timeout > 0 else None
+
+    @staticmethod
+    def check_link_expired(
+        link: TSparkArrowResultLink, expiry_buffer_secs: int
+    ) -> bool:
+        """
+        Check if a link has expired or will expire.
+
+        Expiry buffer can be set to avoid downloading files that have not expired yet when the function is called,
+        but may expire before the file has fully downloaded.
+        """
+        current_time = int(time.time())
+        return (
+            link.expiryTime < current_time
+            or link.expiryTime - current_time < expiry_buffer_secs
+        )
+
+    @staticmethod
+    def decompress_data(compressed_data: bytes) -> bytes:
+        """
+        Decompress lz4 frame compressed data.
+
+        Decompresses data that has been lz4 compressed, either via the whole frame or by series of chunks.
+        """
+        uncompressed_data, bytes_read = lz4.frame.decompress(
+            compressed_data, return_bytes_read=True
+        )
+        # The last cloud fetch file of the entire result is commonly punctuated by frequent end-of-frame markers.
+        # Full frame decompression above will short-circuit, so chunking is necessary
+        if bytes_read < len(compressed_data):
+            d_context = lz4.frame.create_decompression_context()
+            start = 0
+            chunks = []
+            while start < len(compressed_data):
+                data, num_bytes, _ = lz4.frame.decompress_chunk(
+                    d_context, compressed_data[start:]
+                )
+                if num_bytes == 0:
+                    # Guard against looping forever if no input can be consumed
+                    raise RuntimeError(
+                        "lz4 chunk decompression made no progress on cloud fetch data"
+                    )
+                chunks.append(data)
+                start += num_bytes
+            # Join once so both code paths return an immutable bytes object
+            uncompressed_data = b"".join(chunks)
+        return uncompressed_data
diff --git a/tests/unit/test_downloader.py b/tests/unit/test_downloader.py
new file mode 100644
index 000000000..cee3a83c7
--- /dev/null
+++ b/tests/unit/test_downloader.py
@@ -0,0 +1,181 @@
+import unittest
+from unittest.mock import Mock, patch, MagicMock
+
+import databricks.sql.cloudfetch.downloader as downloader
+
+
+class DownloaderTests(unittest.TestCase):
+    """
+    Unit tests for checking downloader logic.
+    """
+
+    @patch('time.time', return_value=1000)
+    def test_run_link_expired(self, mock_time):
+        settings = Mock()
+        result_link = Mock()
+        # Already expired
+        result_link.expiryTime = 999
+        d = downloader.ResultSetDownloadHandler(settings, result_link)
+        assert not d.is_link_expired
+        d.run()
+        assert d.is_link_expired
+        # Waiters must be released even though no download was attempted
+        assert d.is_download_finished.is_set()
+        mock_time.assert_called_once()
+
+    @patch('time.time', return_value=1000)
+    def test_run_link_past_expiry_buffer(self, mock_time):
+        settings = Mock(link_expiry_buffer_secs=5)
+        result_link = Mock()
+        # Within the expiry buffer time
+        result_link.expiryTime = 1004
+        d = downloader.ResultSetDownloadHandler(settings, result_link)
+        assert not d.is_link_expired
+        d.run()
+        assert d.is_link_expired
+        assert d.is_download_finished.is_set()
+        mock_time.assert_called_once()
+
+    @patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=False))))
+    @patch('time.time', return_value=1000)
+    def test_run_get_response_not_ok(self, mock_time, mock_session):
+        settings = Mock(link_expiry_buffer_secs=0, download_timeout=0)
+        settings.download_timeout = 0
+        settings.use_proxy = False
+        result_link = Mock(expiryTime=1001)
+
+        d = downloader.ResultSetDownloadHandler(settings, result_link)
+        d.run()
+
+        assert not d.is_file_downloaded_successfully
+        assert d.is_download_finished.is_set()
+
+    @patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=False))))
+    @patch('time.time', return_value=1000)
+    def test_run_passes_download_timeout_to_request(self, mock_time, mock_session):
+        settings = Mock(link_expiry_buffer_secs=0, download_timeout=30)
+        result_link = Mock(expiryTime=1001)
+
+        d = downloader.ResultSetDownloadHandler(settings, result_link)
+        d.run()
+
+        mock_session.return_value.get.assert_called_once_with(result_link.fileLink, timeout=30)
+
+    @patch('requests.Session',
+           return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True, content=b"1234567890" * 9))))
+    @patch('time.time', return_value=1000)
+    def test_run_uncompressed_data_length_incorrect(self, mock_time, mock_session):
+        settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False, is_lz4_compressed=False)
+        result_link = Mock(bytesNum=100, expiryTime=1001)
+
+        d = downloader.ResultSetDownloadHandler(settings, result_link)
+        d.run()
+
+        assert not d.is_file_downloaded_successfully
+        assert d.is_download_finished.is_set()
+
+    @patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True))))
+    @patch('time.time', return_value=1000)
+    def test_run_compressed_data_length_incorrect(self, mock_time, mock_session):
+        settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False)
+        settings.is_lz4_compressed = True
+        result_link = Mock(bytesNum=100, expiryTime=1001)
+        mock_session.return_value.get.return_value.content = \
+            b'\x04"M\x18h@Z\x00\x00\x00\x00\x00\x00\x00\xec\x14\x00\x00\x00\xaf1234567890\n\x008P67890\x00\x00\x00\x00'
+
+        d = downloader.ResultSetDownloadHandler(settings, result_link)
+        d.run()
+
+        assert not d.is_file_downloaded_successfully
+        assert d.is_download_finished.is_set()
+
+    @patch('requests.Session',
+           return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True, content=b"1234567890" * 10))))
+    @patch('time.time', return_value=1000)
+    def test_run_uncompressed_successful(self, mock_time, mock_session):
+        settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False)
+        settings.is_lz4_compressed = False
+        result_link = Mock(bytesNum=100, expiryTime=1001)
+
+        d = downloader.ResultSetDownloadHandler(settings, result_link)
+        d.run()
+
+        assert d.result_file == b"1234567890" * 10
+        assert d.is_file_downloaded_successfully
+        assert d.is_download_finished.is_set()
+
+    @patch('requests.Session', return_value=MagicMock(get=MagicMock(return_value=MagicMock(ok=True))))
+    @patch('time.time', return_value=1000)
+    def test_run_compressed_successful(self, mock_time, mock_session):
+        settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False)
+        settings.is_lz4_compressed = True
+        result_link = Mock(bytesNum=100, expiryTime=1001)
+        mock_session.return_value.get.return_value.content = \
+            b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'
+
+        d = downloader.ResultSetDownloadHandler(settings, result_link)
+        d.run()
+
+        assert d.result_file == b"1234567890" * 10
+        assert d.is_file_downloaded_successfully
+        assert d.is_download_finished.is_set()
+
+    @patch('requests.Session.get', side_effect=ConnectionError('foo'))
+    @patch('time.time', return_value=1000)
+    def test_download_connection_error(self, mock_time, mock_session):
+        settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False, is_lz4_compressed=True)
+        result_link = Mock(bytesNum=100, expiryTime=1001)
+        mock_session.return_value.get.return_value.content = \
+            b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'
+
+        d = downloader.ResultSetDownloadHandler(settings, result_link)
+        d.run()
+
+        assert not d.is_file_downloaded_successfully
+        assert d.is_download_finished.is_set()
+
+    @patch('requests.Session.get', side_effect=TimeoutError('foo'))
+    @patch('time.time', return_value=1000)
+    def test_download_timeout(self, mock_time, mock_session):
+        settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False, is_lz4_compressed=True)
+        result_link = Mock(bytesNum=100, expiryTime=1001)
+        mock_session.return_value.get.return_value.content = \
+            b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'
+
+        d = downloader.ResultSetDownloadHandler(settings, result_link)
+        d.run()
+
+        assert not d.is_file_downloaded_successfully
+        assert d.is_download_finished.is_set()
+
+    @patch("threading.Event.wait", return_value=True)
+    def test_is_file_download_successful_has_finished(self, mock_wait):
+        for timeout in [None, 0, 1]:
+            with self.subTest(timeout=timeout):
+                settings = Mock(download_timeout=timeout)
+                result_link = Mock()
+                handler = downloader.ResultSetDownloadHandler(settings, result_link)
+
+                status = handler.is_file_download_successful()
+                assert status == handler.is_file_downloaded_successfully
+
+    @patch("threading.Event.wait", return_value=False)
+    def test_is_file_download_successful_times_outs(self, mock_wait):
+        settings = Mock(download_timeout=1)
+        result_link = Mock()
+        handler = downloader.ResultSetDownloadHandler(settings, result_link)
+
+        status = handler.is_file_download_successful()
+        assert not status
+        assert handler.is_download_timedout
+
+    def test_reset_keeps_same_finished_event(self):
+        handler = downloader.ResultSetDownloadHandler(Mock(), Mock())
+        event = handler.is_download_finished
+        event.set()
+
+        handler._reset()
+
+        # A waiter holding the original event must still be woken by a later run()
+        assert handler.is_download_finished is event
+        assert not event.is_set()