From 7abfee903267e0c62f72df0fe25e6fe6aff572c4 Mon Sep 17 00:00:00 2001 From: Koen Vossen Date: Wed, 19 Mar 2025 13:49:31 +0100 Subject: [PATCH 1/8] Move actual implementation of upsert from Table to Transaction --- pyiceberg/table/__init__.py | 179 +++++++++++++++++++++++------------- 1 file changed, 113 insertions(+), 66 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 9e9de52dee..d57b6463af 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -695,6 +695,115 @@ def delete( if not delete_snapshot.files_affected and not delete_snapshot.rewrites_needed: warnings.warn("Delete operation did not match any records") + def upsert( + self, + df: pa.Table, + join_cols: Optional[List[str]] = None, + when_matched_update_all: bool = True, + when_not_matched_insert_all: bool = True, + case_sensitive: bool = True, + ) -> UpsertResult: + """Shorthand API for performing an upsert to an iceberg table. + + Args: + + df: The input dataframe to upsert with the table's data. + join_cols: Columns to join on, if not provided, it will use the identifier-field-ids. + when_matched_update_all: Bool indicating to update rows that are matched but require an update due to a value in a non-key column changing + when_not_matched_insert_all: Bool indicating new rows to be inserted that do not match any existing rows in the table + case_sensitive: Bool indicating if the match should be case-sensitive + + To learn more about the identifier-field-ids: https://iceberg.apache.org/spec/#identifier-field-ids + + Example Use Cases: + Case 1: Both Parameters = True (Full Upsert) + Existing row found → Update it + New row found → Insert it + + Case 2: when_matched_update_all = False, when_not_matched_insert_all = True + Existing row found → Do nothing (no updates) + New row found → Insert it + + Case 3: when_matched_update_all = True, when_not_matched_insert_all = False + Existing row found → Update it + New row found → Do nothing (no inserts) + + Case 4: Both Parameters = False (No Merge Effect) + Existing row found → Do nothing + New row found → Do nothing + (Function effectively does nothing) + + + Returns: + An UpsertResult class (contains details of rows updated and inserted) + """ + try: + import pyarrow as pa # noqa: F401 + except ModuleNotFoundError as e: + raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e + + from pyiceberg.io.pyarrow import expression_to_pyarrow + from pyiceberg.table import upsert_util + + if join_cols is None: + join_cols = [] + for field_id in df.schema.identifier_field_ids: + col = df.schema.find_column_name(field_id) + if col is not None: + join_cols.append(col) + else: + raise ValueError(f"Field-ID could not be found: {join_cols}") + + if len(join_cols) == 0: + raise ValueError("Join columns could not be found, please set identifier-field-ids or pass in explicitly.") + + if not when_matched_update_all and not when_not_matched_insert_all: + raise ValueError("no upsert options selected...exiting") + + if upsert_util.has_duplicate_rows(df, join_cols): + raise ValueError("Duplicate rows found in source dataset based on the key columns. 
No upsert executed") + + from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible + + downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False + _check_pyarrow_schema_compatible( + df.schema, provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us + ) + + # get list of rows that exist so we don't have to load the entire target table + matched_predicate = upsert_util.create_match_filter(df, join_cols) + matched_iceberg_table = df.scan(row_filter=matched_predicate, case_sensitive=case_sensitive).to_arrow() + + update_row_cnt = 0 + insert_row_cnt = 0 + + if when_matched_update_all: + # function get_rows_to_update is doing a check on non-key columns to see if any of the values have actually changed + # we don't want to do just a blanket overwrite for matched rows if the actual non-key column data hasn't changed + # this extra step avoids unnecessary IO and writes + rows_to_update = upsert_util.get_rows_to_update(df, matched_iceberg_table, join_cols) + + update_row_cnt = len(rows_to_update) + + if len(rows_to_update) > 0: + # build the match predicate filter + overwrite_mask_predicate = upsert_util.create_match_filter(rows_to_update, join_cols) + + self.overwrite(rows_to_update, overwrite_filter=overwrite_mask_predicate) + + if when_not_matched_insert_all: + expr_match = upsert_util.create_match_filter(matched_iceberg_table, join_cols) + expr_match_bound = bind(df.schema, expr_match, case_sensitive=case_sensitive) + expr_match_arrow = expression_to_pyarrow(expr_match_bound) + rows_to_insert = df.filter(~expr_match_arrow) + + insert_row_cnt = len(rows_to_insert) + + if insert_row_cnt > 0: + self.append(rows_to_insert) + + return UpsertResult(rows_updated=update_row_cnt, rows_inserted=insert_row_cnt) + def add_files( self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT, check_duplicate_files: bool = True ) -> None: @@ -1159,73 +1268,11 @@ def upsert( Returns: An UpsertResult class (contains details of rows updated and inserted) """ - try: - import pyarrow as pa # noqa: F401 - except ModuleNotFoundError as e: - raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e - - from pyiceberg.io.pyarrow import expression_to_pyarrow - from pyiceberg.table import upsert_util - - if join_cols is None: - join_cols = [] - for field_id in self.schema().identifier_field_ids: - col = self.schema().find_column_name(field_id) - if col is not None: - join_cols.append(col) - else: - raise ValueError(f"Field-ID could not be found: {join_cols}") - - if len(join_cols) == 0: - raise ValueError("Join columns could not be found, please set identifier-field-ids or pass in explicitly.") - - if not when_matched_update_all and not when_not_matched_insert_all: - raise ValueError("no upsert options selected...exiting") - - if upsert_util.has_duplicate_rows(df, join_cols): - raise ValueError("Duplicate rows found in source dataset based on the key columns. 
No upsert executed") - - from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible - - downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False - _check_pyarrow_schema_compatible( - self.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us - ) - - # get list of rows that exist so we don't have to load the entire target table - matched_predicate = upsert_util.create_match_filter(df, join_cols) - matched_iceberg_table = self.scan(row_filter=matched_predicate, case_sensitive=case_sensitive).to_arrow() - - update_row_cnt = 0 - insert_row_cnt = 0 - with self.transaction() as tx: - if when_matched_update_all: - # function get_rows_to_update is doing a check on non-key columns to see if any of the values have actually changed - # we don't want to do just a blanket overwrite for matched rows if the actual non-key column data hasn't changed - # this extra step avoids unnecessary IO and writes - rows_to_update = upsert_util.get_rows_to_update(df, matched_iceberg_table, join_cols) - - update_row_cnt = len(rows_to_update) - - if len(rows_to_update) > 0: - # build the match predicate filter - overwrite_mask_predicate = upsert_util.create_match_filter(rows_to_update, join_cols) - - tx.overwrite(rows_to_update, overwrite_filter=overwrite_mask_predicate) - - if when_not_matched_insert_all: - expr_match = upsert_util.create_match_filter(matched_iceberg_table, join_cols) - expr_match_bound = bind(self.schema(), expr_match, case_sensitive=case_sensitive) - expr_match_arrow = expression_to_pyarrow(expr_match_bound) - rows_to_insert = df.filter(~expr_match_arrow) - - insert_row_cnt = len(rows_to_insert) - - if insert_row_cnt > 0: - tx.append(rows_to_insert) - - return UpsertResult(rows_updated=update_row_cnt, rows_inserted=insert_row_cnt) + return tx.upsert( + df=df, join_cols=join_cols, when_matched_update_all=when_matched_update_all, when_not_matched_insert_all=when_not_matched_insert_all, + case_sensitive=case_sensitive + ) def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None: """ From db334ae4bfeefff1ab7373ea8c5e55f18fdace84 Mon Sep 17 00:00:00 2001 From: Koen Vossen Date: Wed, 19 Mar 2025 14:32:19 +0100 Subject: [PATCH 2/8] Fix some incorrect usage of schema --- pyiceberg/table/__init__.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index d57b6463af..99f2fee388 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -747,8 +747,8 @@ def upsert( if join_cols is None: join_cols = [] - for field_id in df.schema.identifier_field_ids: - col = df.schema.find_column_name(field_id) + for field_id in self.table_metadata.schema().identifier_field_ids: + col = self.table_metadata.schema().find_column_name(field_id) if col is not None: join_cols.append(col) else: @@ -767,12 +767,12 @@ def upsert( downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False _check_pyarrow_schema_compatible( - df.schema, provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us + self.table_metadata.schema(), provided_schema=df.schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us ) # get list of rows that exist so we don't have to load the entire target table matched_predicate = upsert_util.create_match_filter(df, join_cols) - matched_iceberg_table = df.scan(row_filter=matched_predicate, case_sensitive=case_sensitive).to_arrow() 
+ matched_iceberg_table = self._table.scan(row_filter=matched_predicate, case_sensitive=case_sensitive).to_arrow() update_row_cnt = 0 insert_row_cnt = 0 @@ -793,7 +793,7 @@ def upsert( if when_not_matched_insert_all: expr_match = upsert_util.create_match_filter(matched_iceberg_table, join_cols) - expr_match_bound = bind(df.schema, expr_match, case_sensitive=case_sensitive) + expr_match_bound = bind(self.table_metadata.schema(), expr_match, case_sensitive=case_sensitive) expr_match_arrow = expression_to_pyarrow(expr_match_bound) rows_to_insert = df.filter(~expr_match_arrow) @@ -1270,8 +1270,11 @@ def upsert( """ with self.transaction() as tx: return tx.upsert( - df=df, join_cols=join_cols, when_matched_update_all=when_matched_update_all, when_not_matched_insert_all=when_not_matched_insert_all, - case_sensitive=case_sensitive + df=df, + join_cols=join_cols, + when_matched_update_all=when_matched_update_all, + when_not_matched_insert_all=when_not_matched_insert_all, + case_sensitive=case_sensitive, ) def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None: From cebfda373efd0ea17460ff73f419062027d093da Mon Sep 17 00:00:00 2001 From: Koen Vossen Date: Tue, 25 Mar 2025 15:41:37 +0100 Subject: [PATCH 3/8] Write a test for upsert transaction --- tests/table/test_upsert.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index 70203fd162..429c78091c 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -23,7 +23,7 @@ from pyiceberg.catalog import Catalog from pyiceberg.exceptions import NoSuchTableError -from pyiceberg.expressions import And, EqualTo, Reference +from pyiceberg.expressions import AlwaysTrue, And, EqualTo, Reference from pyiceberg.expressions.literals import LongLiteral from pyiceberg.io.pyarrow import schema_to_pyarrow from pyiceberg.schema import Schema @@ -709,3 +709,26 @@ def test_upsert_with_nulls(catalog: Catalog) -> None: ], schema=schema, ) + + +def test_transaction(catalog: Catalog) -> None: + """Test the upsert within a Transaction. Make sure that if something fails the entire Transaction is + rolled back.""" + identifier = "default.test_merge_source_dups" + _drop_table(catalog, identifier) + + ctx = SessionContext() + + table = gen_target_iceberg_table(1, 10, False, ctx, catalog, identifier) + df_before_transaction = table.scan().to_arrow() + + source_df = gen_source_dataset(5, 15, False, True, ctx) + + with pytest.raises(Exception, match="Duplicate rows found in source dataset based on the key columns. 
No upsert executed"): + with table.transaction() as tx: + tx.delete(delete_filter=AlwaysTrue()) + tx.upsert(df=source_df, join_cols=["order_id"]) + + df = table.scan().to_arrow() + + assert df_before_transaction == df From 52fd35eebb81dd351be5923bd6f2aeb8341d6671 Mon Sep 17 00:00:00 2001 From: Koen Vossen Date: Thu, 27 Mar 2025 21:14:46 +0100 Subject: [PATCH 4/8] Add failing test for multiple upserts in same transaction --- tests/table/test_upsert.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index 429c78091c..553b1ef5b3 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -732,3 +732,39 @@ def test_transaction(catalog: Catalog) -> None: df = table.scan().to_arrow() assert df_before_transaction == df + + +def test_transaction_multiple_upserts(catalog: Catalog) -> None: + identifier = "default.test_multi_upsert" + _drop_table(catalog, identifier) + + schema = Schema( + NestedField(1, "id", IntegerType(), required=True), + NestedField(2, "name", StringType(), required=True), + identifier_field_ids=[1], + ) + + tbl = catalog.create_table(identifier, schema=schema) + + # Define exact schema: required int32 and required string + arrow_schema = pa.schema([ + pa.field("id", pa.int32(), nullable=False), + pa.field("name", pa.string(), nullable=False), + ]) + + tbl.append(pa.Table.from_pylist([{"id": 1, "name": "Alice"}], schema=arrow_schema)) + + df = pa.Table.from_pylist([{"id": 2, "name": "Bob"}, {"id": 1, "name": "Alicia"}], schema=arrow_schema) + + with tbl.transaction() as txn: + # This should read the uncommitted changes? + txn.upsert(df, join_cols=["id"]) + + txn.upsert(df, join_cols=["id"]) + + result = tbl.scan().to_arrow().to_pylist() + assert sorted(result, key=lambda x: x["id"]) == [ + {"id": 1, "name": "Alicia"}, + {"id": 2, "name": "Bob"}, + ] + From f336c0b6b92a2516f34b6f52cb332fda1602c6a7 Mon Sep 17 00:00:00 2001 From: Koen Vossen Date: Wed, 2 Apr 2025 22:02:32 +0200 Subject: [PATCH 5/8] Fix test --- tests/table/test_upsert.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index 553b1ef5b3..2e65c97259 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -734,6 +734,7 @@ def test_transaction(catalog: Catalog) -> None: assert df_before_transaction == df +@pytest.mark.skip("This test is just for reference. Multiple upserts or delete+upsert doesn't work in a transaction") def test_transaction_multiple_upserts(catalog: Catalog) -> None: identifier = "default.test_multi_upsert" _drop_table(catalog, identifier) @@ -747,24 +748,28 @@ def test_transaction_multiple_upserts(catalog: Catalog) -> None: tbl = catalog.create_table(identifier, schema=schema) # Define exact schema: required int32 and required string - arrow_schema = pa.schema([ - pa.field("id", pa.int32(), nullable=False), - pa.field("name", pa.string(), nullable=False), - ]) + arrow_schema = pa.schema( + [ + pa.field("id", pa.int32(), nullable=False), + pa.field("name", pa.string(), nullable=False), + ] + ) tbl.append(pa.Table.from_pylist([{"id": 1, "name": "Alice"}], schema=arrow_schema)) df = pa.Table.from_pylist([{"id": 2, "name": "Bob"}, {"id": 1, "name": "Alicia"}], schema=arrow_schema) with tbl.transaction() as txn: + txn.append(df) + txn.delete(delete_filter="id = 1") + txn.append(df) # This should read the uncommitted changes? 
txn.upsert(df, join_cols=["id"]) - txn.upsert(df, join_cols=["id"]) + # txn.upsert(df, join_cols=["id"]) result = tbl.scan().to_arrow().to_pylist() assert sorted(result, key=lambda x: x["id"]) == [ {"id": 1, "name": "Alicia"}, {"id": 2, "name": "Bob"}, ] - From 07890ac4d73cc051dd92e27cff7981f3e134534d Mon Sep 17 00:00:00 2001 From: Koen Vossen Date: Tue, 13 May 2025 10:47:34 +0200 Subject: [PATCH 6/8] Add failing test --- tests/table/test_upsert.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index 2e65c97259..85315a81db 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -734,7 +734,7 @@ def test_transaction(catalog: Catalog) -> None: assert df_before_transaction == df -@pytest.mark.skip("This test is just for reference. Multiple upserts or delete+upsert doesn't work in a transaction") +# @pytest.mark.skip("This test is just for reference. Multiple upserts or delete+upsert doesn't work in a transaction") def test_transaction_multiple_upserts(catalog: Catalog) -> None: identifier = "default.test_multi_upsert" _drop_table(catalog, identifier) @@ -760,13 +760,12 @@ def test_transaction_multiple_upserts(catalog: Catalog) -> None: df = pa.Table.from_pylist([{"id": 2, "name": "Bob"}, {"id": 1, "name": "Alicia"}], schema=arrow_schema) with tbl.transaction() as txn: - txn.append(df) txn.delete(delete_filter="id = 1") txn.append(df) - # This should read the uncommitted changes? - txn.upsert(df, join_cols=["id"]) - # txn.upsert(df, join_cols=["id"]) + # This should read the uncommitted changes + # TODO: currently fails because it only reads {"id": 1, "name": "Alice"} + txn.upsert(df, join_cols=["id"]) result = tbl.scan().to_arrow().to_pylist() assert sorted(result, key=lambda x: x["id"]) == [ From ae0e60fa4ca725d42a8b1832cdbb4fec936ddc60 Mon Sep 17 00:00:00 2001 From: Koen Vossen Date: Tue, 13 May 2025 10:56:08 +0200 Subject: [PATCH 7/8] Use Transaction.table_metadata when doing the data scan in upsert --- pyiceberg/table/__init__.py | 9 ++++++++- tests/table/test_upsert.py | 1 - 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 99f2fee388..78676a774a 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -772,7 +772,14 @@ def upsert( # get list of rows that exist so we don't have to load the entire target table matched_predicate = upsert_util.create_match_filter(df, join_cols) - matched_iceberg_table = self._table.scan(row_filter=matched_predicate, case_sensitive=case_sensitive).to_arrow() + + # We must use Transaction.table_metadata for the scan. This includes all uncommitted - but relevant - changes. 
+ matched_iceberg_table = DataScan( + table_metadata=self.table_metadata, + io=self._table.io, + row_filter=matched_predicate, + case_sensitive=case_sensitive, + ).to_arrow() update_row_cnt = 0 insert_row_cnt = 0 diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index 85315a81db..10593ea62e 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -764,7 +764,6 @@ def test_transaction_multiple_upserts(catalog: Catalog) -> None: txn.append(df) # This should read the uncommitted changes - # TODO: currently fails because it only reads {"id": 1, "name": "Alice"} txn.upsert(df, join_cols=["id"]) result = tbl.scan().to_arrow().to_pylist() From ce8d9efc72110e29330355d84e011ba6707d5802 Mon Sep 17 00:00:00 2001 From: Koen Vossen Date: Tue, 13 May 2025 10:58:45 +0200 Subject: [PATCH 8/8] Remove as it's resolved --- tests/table/test_upsert.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py index 10593ea62e..9fecbbb7bb 100644 --- a/tests/table/test_upsert.py +++ b/tests/table/test_upsert.py @@ -734,7 +734,6 @@ def test_transaction(catalog: Catalog) -> None: assert df_before_transaction == df -# @pytest.mark.skip("This test is just for reference. Multiple upserts or delete+upsert doesn't work in a transaction") def test_transaction_multiple_upserts(catalog: Catalog) -> None: identifier = "default.test_multi_upsert" _drop_table(catalog, identifier)
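
A minimal usage sketch of what this series enables (setup names are illustrative and taken from the tests above, not part of the patches themselves): because upsert now lives on Transaction and scans Transaction.table_metadata, it can follow other uncommitted writes in the same transaction and commit with them atomically, as test_transaction_multiple_upserts exercises.

    import pyarrow as pa

    # assumes an existing catalog fixture and the table created in the tests above
    tbl = catalog.load_table("default.test_multi_upsert")

    arrow_schema = pa.schema(
        [
            pa.field("id", pa.int32(), nullable=False),
            pa.field("name", pa.string(), nullable=False),
        ]
    )
    df = pa.Table.from_pylist(
        [{"id": 2, "name": "Bob"}, {"id": 1, "name": "Alicia"}],
        schema=arrow_schema,
    )

    with tbl.transaction() as txn:
        txn.delete(delete_filter="id = 1")  # uncommitted delete
        txn.append(df)                      # uncommitted append
        txn.upsert(df, join_cols=["id"])    # sees both uncommitted changes via the transaction's metadata

    # All three operations commit together; if any step raises (for example duplicate
    # keys in the source dataset), the whole transaction is rolled back, as test_transaction checks.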