From 02c9ccb7b32f047b39aa4a54f87379a95a89258e Mon Sep 17 00:00:00 2001
From: Kevin Liu
Date: Sun, 11 Aug 2024 16:20:33 -0700
Subject: [PATCH 1/3] fix project_batches with limit

---
 pyiceberg/io/pyarrow.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index b99c3b1702..719d289717 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -1455,6 +1455,9 @@ def project_batches(
     total_row_count = 0
 
     for task in tasks:
+        # stop early if limit is satisfied
+        if limit is not None and total_row_count >= limit:
+            break
         batches = _task_to_record_batches(
             fs,
             task,
@@ -1468,9 +1471,10 @@ def project_batches(
         )
         for batch in batches:
             if limit is not None:
-                if total_row_count + len(batch) >= limit:
-                    yield batch.slice(0, limit - total_row_count)
+                if total_row_count >= limit:
                     break
+                elif total_row_count + len(batch) >= limit:
+                    batch = batch.slice(0, limit - total_row_count)
             yield batch
             total_row_count += len(batch)
 

From 18fb8e6281f6334732ae6655a1570d7048259a47 Mon Sep 17 00:00:00 2001
From: Kevin Liu
Date: Sun, 11 Aug 2024 16:21:02 -0700
Subject: [PATCH 2/3] add test

---
 tests/integration/test_reads.py | 45 +++++++++++++++++++++++++++++++++
 1 file changed, 45 insertions(+)

diff --git a/tests/integration/test_reads.py b/tests/integration/test_reads.py
index 15f284be1f..0cf8634056 100644
--- a/tests/integration/test_reads.py
+++ b/tests/integration/test_reads.py
@@ -244,6 +244,51 @@ def test_pyarrow_limit(catalog: Catalog) -> None:
     full_result = table_test_limit.scan(selected_fields=("idx",), limit=999).to_arrow()
     assert len(full_result) == 10
 
+    # test `to_arrow_batch_reader`
+    limited_result = table_test_limit.scan(selected_fields=("idx",), limit=1).to_arrow_batch_reader().read_all()
+    assert len(limited_result) == 1
+
+    empty_result = table_test_limit.scan(selected_fields=("idx",), limit=0).to_arrow_batch_reader().read_all()
+    assert len(empty_result) == 0
+
+    full_result = table_test_limit.scan(selected_fields=("idx",), limit=999).to_arrow_batch_reader().read_all()
+    assert len(full_result) == 10
+
+@pytest.mark.integration
+@pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
+def test_pyarrow_limit_with_multiple_files(catalog: Catalog) -> None:
+    table_name = "default.test_pyarrow_limit_with_multiple_files"
+    try:
+        catalog.drop_table(table_name)
+    except:
+        pass
+    reference_table = catalog.load_table("default.test_limit")
+    data = reference_table.scan().to_arrow()
+    table_test_limit = catalog.create_table(table_name, schema=reference_table.schema())
+    table_test_limit.append(data)
+    table_test_limit.append(data)
+    assert len(table_test_limit.inspect.files()) == 2
+
+    # # test with multiple files
+    limited_result = table_test_limit.scan(selected_fields=("idx",), limit=1).to_arrow()
+    assert len(limited_result) == 1
+
+    empty_result = table_test_limit.scan(selected_fields=("idx",), limit=0).to_arrow()
+    assert len(empty_result) == 0
+
+    full_result = table_test_limit.scan(selected_fields=("idx",), limit=999).to_arrow()
+    assert len(full_result) == 10 * 2
+
+    # test `to_arrow_batch_reader`
+    limited_result = table_test_limit.scan(selected_fields=("idx",), limit=1).to_arrow_batch_reader().read_all()
+    assert len(limited_result) == 1
+
+    empty_result = table_test_limit.scan(selected_fields=("idx",), limit=0).to_arrow_batch_reader().read_all()
+    assert len(empty_result) == 0
+
+    full_result = table_test_limit.scan(selected_fields=("idx",), limit=999).to_arrow_batch_reader().read_all()
+    assert len(full_result) == 10 * 2
+
 
 @pytest.mark.integration
 @pytest.mark.filterwarnings("ignore")

From fad689f0d048d2c563c70e45b53aee15d15e6b2d Mon Sep 17 00:00:00 2001
From: Kevin Liu
Date: Sun, 11 Aug 2024 16:32:09 -0700
Subject: [PATCH 3/3] lint + readability

---
 tests/integration/test_reads.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/tests/integration/test_reads.py b/tests/integration/test_reads.py
index 0cf8634056..cbfd64e194 100644
--- a/tests/integration/test_reads.py
+++ b/tests/integration/test_reads.py
@@ -254,22 +254,25 @@ def test_pyarrow_limit(catalog: Catalog) -> None:
     full_result = table_test_limit.scan(selected_fields=("idx",), limit=999).to_arrow_batch_reader().read_all()
     assert len(full_result) == 10
 
+
 @pytest.mark.integration
 @pytest.mark.parametrize("catalog", [pytest.lazy_fixture("session_catalog_hive"), pytest.lazy_fixture("session_catalog")])
 def test_pyarrow_limit_with_multiple_files(catalog: Catalog) -> None:
     table_name = "default.test_pyarrow_limit_with_multiple_files"
     try:
         catalog.drop_table(table_name)
-    except:
+    except NoSuchTableError:
         pass
     reference_table = catalog.load_table("default.test_limit")
     data = reference_table.scan().to_arrow()
     table_test_limit = catalog.create_table(table_name, schema=reference_table.schema())
-    table_test_limit.append(data)
-    table_test_limit.append(data)
-    assert len(table_test_limit.inspect.files()) == 2
 
-    # # test with multiple files
+    n_files = 2
+    for _ in range(n_files):
+        table_test_limit.append(data)
+    assert len(table_test_limit.inspect.files()) == n_files
+
+    # test with multiple files
     limited_result = table_test_limit.scan(selected_fields=("idx",), limit=1).to_arrow()
     assert len(limited_result) == 1
 
@@ -277,7 +280,7 @@ def test_pyarrow_limit_with_multiple_files(catalog: Catalog) -> None:
     empty_result = table_test_limit.scan(selected_fields=("idx",), limit=0).to_arrow()
     assert len(empty_result) == 0
 
     full_result = table_test_limit.scan(selected_fields=("idx",), limit=999).to_arrow()
-    assert len(full_result) == 10 * 2
+    assert len(full_result) == 10 * n_files
 
     # test `to_arrow_batch_reader`
     limited_result = table_test_limit.scan(selected_fields=("idx",), limit=1).to_arrow_batch_reader().read_all()
@@ -287,7 +290,7 @@ def test_pyarrow_limit_with_multiple_files(catalog: Catalog) -> None:
     assert len(empty_result) == 0
 
     full_result = table_test_limit.scan(selected_fields=("idx",), limit=999).to_arrow_batch_reader().read_all()
-    assert len(full_result) == 10 * 2
+    assert len(full_result) == 10 * n_files
 
 
 @pytest.mark.integration
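
For review, a standalone illustration of the limit handling that PATCH 1/3 introduces: the sketch below reproduces the slice-then-stop pattern that project_batches now applies per record batch. It is a minimal sketch, not part of the patch series; the take_up_to helper and the in-memory batches are hypothetical stand-ins for project_batches and the output of _task_to_record_batches, and only the pyarrow calls (pa.table, Table.to_batches, RecordBatch.slice) are real APIs.

    from typing import Iterable, Iterator, Optional

    import pyarrow as pa


    def take_up_to(batches: Iterable[pa.RecordBatch], limit: Optional[int]) -> Iterator[pa.RecordBatch]:
        # Mirrors the per-batch logic from PATCH 1/3: once the limit is
        # reached, stop; the batch that crosses the limit is sliced so the
        # total row count lands exactly on the limit.
        total_row_count = 0
        for batch in batches:
            if limit is not None:
                if total_row_count >= limit:
                    break
                elif total_row_count + len(batch) >= limit:
                    batch = batch.slice(0, limit - total_row_count)
            yield batch
            total_row_count += len(batch)


    # Three 4-row batches (12 rows total); limit=10 should yield 4 + 4 + 2 rows.
    table = pa.table({"idx": list(range(12))})
    out = list(take_up_to(table.to_batches(max_chunksize=4), limit=10))
    assert [len(b) for b in out] == [4, 4, 2]
    assert sum(len(b) for b in out) == 10

The early break that the patch also adds at the top of the task loop serves the same purpose one level up: once the limit is satisfied, remaining scan tasks are skipped entirely instead of being opened and decoded.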