diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index f7e3c7c082..bf16ec5ec3 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -2466,36 +2466,43 @@ def _check_pyarrow_schema_compatible(
 
 
 def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, file_paths: Iterator[str]) -> Iterator[DataFile]:
     for file_path in file_paths:
-        input_file = io.new_input(file_path)
-        with input_file.open() as input_stream:
-            parquet_metadata = pq.read_metadata(input_stream)
+        data_file = parquet_file_to_data_file(io=io, table_metadata=table_metadata, file_path=file_path)
+        yield data_file
 
-        if visit_pyarrow(parquet_metadata.schema.to_arrow_schema(), _HasIds()):
-            raise NotImplementedError(
-                f"Cannot add file {file_path} because it has field IDs. `add_files` only supports addition of files without field_ids"
-            )
-        schema = table_metadata.schema()
-        _check_pyarrow_schema_compatible(schema, parquet_metadata.schema.to_arrow_schema())
-
-        statistics = data_file_statistics_from_parquet_metadata(
-            parquet_metadata=parquet_metadata,
-            stats_columns=compute_statistics_plan(schema, table_metadata.properties),
-            parquet_column_mapping=parquet_path_to_id_mapping(schema),
-        )
-        data_file = DataFile(
-            content=DataFileContent.DATA,
-            file_path=file_path,
-            file_format=FileFormat.PARQUET,
-            partition=statistics.partition(table_metadata.spec(), table_metadata.schema()),
-            file_size_in_bytes=len(input_file),
-            sort_order_id=None,
-            spec_id=table_metadata.default_spec_id,
-            equality_ids=None,
-            key_metadata=None,
-            **statistics.to_serialized_dict(),
+def parquet_file_to_data_file(io: FileIO, table_metadata: TableMetadata, file_path: str) -> DataFile:
+    input_file = io.new_input(file_path)
+    with input_file.open() as input_stream:
+        parquet_metadata = pq.read_metadata(input_stream)
+
+    arrow_schema = parquet_metadata.schema.to_arrow_schema()
+    if visit_pyarrow(arrow_schema, _HasIds()):
+        raise NotImplementedError(
+            f"Cannot add file {file_path} because it has field IDs. `add_files` only supports addition of files without field_ids"
         )
-        yield data_file
+
+    schema = table_metadata.schema()
+    _check_pyarrow_schema_compatible(schema, arrow_schema)
+
+    statistics = data_file_statistics_from_parquet_metadata(
+        parquet_metadata=parquet_metadata,
+        stats_columns=compute_statistics_plan(schema, table_metadata.properties),
+        parquet_column_mapping=parquet_path_to_id_mapping(schema),
+    )
+    data_file = DataFile(
+        content=DataFileContent.DATA,
+        file_path=file_path,
+        file_format=FileFormat.PARQUET,
+        partition=statistics.partition(table_metadata.spec(), table_metadata.schema()),
+        file_size_in_bytes=len(input_file),
+        sort_order_id=None,
+        spec_id=table_metadata.default_spec_id,
+        equality_ids=None,
+        key_metadata=None,
+        **statistics.to_serialized_dict(),
+    )
+
+    return data_file
 
 
 ICEBERG_UNCOMPRESSED_CODEC = "uncompressed"
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index e625b848b2..45620bce0d 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -1891,6 +1891,9 @@ def _parquet_files_to_data_files(table_metadata: TableMetadata, file_paths: List
     Returns:
         An iterable that supplies DataFiles that describe the parquet files.
     """
-    from pyiceberg.io.pyarrow import parquet_files_to_data_files
+    from pyiceberg.io.pyarrow import parquet_file_to_data_file
 
-    yield from parquet_files_to_data_files(io=io, table_metadata=table_metadata, file_paths=iter(file_paths))
+    executor = ExecutorFactory.get_or_create()
+    futures = [executor.submit(parquet_file_to_data_file, io, table_metadata, file_path) for file_path in file_paths]
+
+    return [f.result() for f in futures if f.result()]
diff --git a/tests/integration/test_add_files.py b/tests/integration/test_add_files.py
index 8713615218..bfbc8db668 100644
--- a/tests/integration/test_add_files.py
+++ b/tests/integration/test_add_files.py
@@ -16,10 +16,13 @@
 # under the License.
 # pylint:disable=redefined-outer-name
 
+import multiprocessing
 import os
 import re
+import threading
 from datetime import date
 from typing import Iterator
+from unittest import mock
 
 import pyarrow as pa
 import pyarrow.parquet as pq
@@ -31,9 +34,11 @@
 from pyiceberg.exceptions import NoSuchTableError
 from pyiceberg.io import FileIO
 from pyiceberg.io.pyarrow import UnsupportedPyArrowTypeException, _pyarrow_schema_ensure_large_types
+from pyiceberg.manifest import DataFile
 from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionField, PartitionSpec
 from pyiceberg.schema import Schema
 from pyiceberg.table import Table
+from pyiceberg.table.metadata import TableMetadata
 from pyiceberg.transforms import BucketTransform, IdentityTransform, MonthTransform
 from pyiceberg.types import (
     BooleanType,
@@ -229,6 +234,54 @@ def test_add_files_to_unpartitioned_table_raises_has_field_ids(
         tbl.add_files(file_paths=file_paths)
 
 
+@pytest.mark.integration
+def test_add_files_parallelized(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
+    from pyiceberg.io.pyarrow import parquet_file_to_data_file
+
+    real_parquet_file_to_data_file = parquet_file_to_data_file
+
+    lock = threading.Lock()
+    unique_threads_seen = set()
+    cpu_count = multiprocessing.cpu_count()
+
+    # patch the function parquet_file_to_data_file so we can track how many unique thread IDs
+    # it was executed from
+    with mock.patch("pyiceberg.io.pyarrow.parquet_file_to_data_file") as patch_func:
+
+        def mock_parquet_file_to_data_file(io: FileIO, table_metadata: TableMetadata, file_path: str) -> DataFile:
+            lock.acquire()
+            thread_id = threading.get_ident()  # the current thread ID
+            unique_threads_seen.add(thread_id)
+            lock.release()
+            return real_parquet_file_to_data_file(io=io, table_metadata=table_metadata, file_path=file_path)
+
+        patch_func.side_effect = mock_parquet_file_to_data_file
+
+        identifier = f"default.unpartitioned_table_schema_updates_v{format_version}"
+        tbl = _create_table(session_catalog, identifier, format_version)
+
+        file_paths = [
+            f"s3://warehouse/default/add_files_parallel/v{format_version}/test-{i}.parquet" for i in range(cpu_count * 2)
+        ]
+        # write parquet files
+        for file_path in file_paths:
+            fo = tbl.io.new_output(file_path)
+            with fo.create(overwrite=True) as fos:
+                with pq.ParquetWriter(fos, schema=ARROW_SCHEMA) as writer:
+                    writer.write_table(ARROW_TABLE)
+
+        tbl.add_files(file_paths=file_paths)
+
+    # during creation of the thread pool, when max_workers is not
+    # specified, Python uses min(32, cpu_count + 4) as the number of
+    # threads in the pool:
+    # https://github.com/python/cpython/blob/e06bebb87e1b33f7251196e1ddb566f528c3fc98/Lib/concurrent/futures/thread.py#L173-L181
+    # we check that we have seen at least cpu_count unique threads; we don't
+    # specify max_workers for the pool, so we can't check the exact count
+    # without accessing private attributes of ThreadPoolExecutor
+    assert len(unique_threads_seen) >= cpu_count
+
+
 @pytest.mark.integration
 def test_add_files_to_unpartitioned_table_with_schema_updates(
     spark: SparkSession, session_catalog: Catalog, format_version: int
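
The `pyiceberg/table/__init__.py` change above replaces the lazy `yield from` with a fan-out over pyiceberg's shared worker pool: one `parquet_file_to_data_file` task is submitted per path, and the resulting `DataFile`s are collected from the futures. The sketch below shows that pattern in isolation and assumes nothing from pyiceberg: a plain `ThreadPoolExecutor` stands in for `ExecutorFactory.get_or_create()`, and `parse_file` is a hypothetical stand-in for `parquet_file_to_data_file`.

```python
# Minimal, self-contained sketch (not pyiceberg code) of the fan-out pattern
# the diff introduces: submit one footer-reading task per file to a shared
# thread pool, then collect the results in submission order.
import threading
from concurrent.futures import ThreadPoolExecutor


def parse_file(file_path: str) -> dict:
    # Pretend to read the Parquet footer and build a DataFile-like record,
    # tagging which worker thread handled the file.
    return {"file_path": file_path, "thread": threading.get_ident()}


def files_to_records(file_paths: list[str]) -> list[dict]:
    # With no max_workers argument, the pool defaults to min(32, cpu_count + 4)
    # threads, which is the behaviour the integration test above relies on.
    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(parse_file, path) for path in file_paths]
        # Future.result() caches its value, so repeated calls are cheap; any
        # exception raised inside a worker is re-raised here.
        return [future.result() for future in futures]


if __name__ == "__main__":
    records = files_to_records([f"test-{i}.parquet" for i in range(8)])
    print(len(records), "records handled by", len({r["thread"] for r in records}), "threads")
```

Because each task is dominated by I/O (reading a Parquet footer, typically from object storage), the worker threads spend most of their time waiting on the network rather than contending for the GIL, which is presumably why a thread pool rather than a process pool is used here.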