
Commit 56899e6

kevinjqliu authored and HonahX committed
Cast data to Iceberg Table's pyarrow schema (apache#523)
Backport to 0.6.1
1 parent c0c0e79 commit 56899e6

4 files changed: +70 −9
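In user-facing terms, `append()` and `overwrite()` now accept a PyArrow table whose schema is compatible with the Iceberg table's schema without being byte-identical (for example, `large_string` where the Iceberg schema maps to Arrow's `string`): the data is cast to the table's Arrow schema before writing. A minimal sketch of the behavior this enables, assuming an in-memory `SqlCatalog` along the lines of the test below; the catalog wiring and identifiers are illustrative, not part of this commit:

```python
import pyarrow as pa
from pyiceberg.catalog.sql import SqlCatalog

# Illustrative catalog configuration; any catalog would do.
catalog = SqlCatalog("default", uri="sqlite:///:memory:", warehouse="file:///tmp/warehouse")
catalog.create_namespace("default")

# The Arrow data uses large_string, which normalizes to Iceberg's string type.
df = pa.Table.from_arrays(
    [pa.array(["A", "B"], type=pa.large_string())],
    schema=pa.schema([pa.field("foo", pa.large_string(), nullable=True)]),
)

table = catalog.create_table("default.example", df.schema)
# Before this commit the writer could trip over a compatible-but-unequal
# Arrow schema; now the data is cast to the table's Arrow schema first.
table.overwrite(df)
```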

pyiceberg/io/pyarrow.py

Lines changed: 1 addition & 1 deletion
@@ -1731,7 +1731,7 @@ def write_file(table: Table, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
     parquet_writer_kwargs = _get_parquet_writer_kwargs(table.properties)
 
     file_path = f'{table.location()}/data/{task.generate_data_file_filename("parquet")}'
-    file_schema = schema_to_pyarrow(table.schema())
+    file_schema = table.schema().as_arrow()
 
     fo = table.io.new_output(file_path)
     row_group_size = PropertyUtil.property_as_int(
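The one-line change swaps a direct `schema_to_pyarrow` call for the `Schema.as_arrow()` convenience method, so the file writer and the `append()`/`overwrite()` cast below derive the Arrow schema the same way. A hedged sketch of what `as_arrow()` amounts to, assuming it simply delegates to the existing converter with a deferred import (pyarrow is an optional dependency):

```python
# Sketch of Schema.as_arrow(), assuming it delegates to schema_to_pyarrow.
def as_arrow(self) -> "pa.Schema":
    """Return an Arrow schema equivalent to this Iceberg schema."""
    from pyiceberg.io.pyarrow import schema_to_pyarrow  # deferred: pyarrow is optional

    return schema_to_pyarrow(self)
```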

pyiceberg/table/__init__.py

Lines changed: 17 additions & 3 deletions
@@ -133,7 +133,15 @@
 _JAVA_LONG_MAX = 9223372036854775807
 
 
-def _check_schema(table_schema: Schema, other_schema: "pa.Schema") -> None:
+def _check_schema_compatible(table_schema: Schema, other_schema: "pa.Schema") -> None:
+    """
+    Check if the `table_schema` is compatible with `other_schema`.
+
+    Two schemas are considered compatible when they are equal in terms of the Iceberg Schema type.
+
+    Raises:
+        ValueError: If the schemas are not compatible.
+    """
     from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema
 
     name_mapping = table_schema.name_mapping
@@ -1045,7 +1053,10 @@ def append(self, df: pa.Table) -> None:
         if len(self.spec().fields) > 0:
             raise ValueError("Cannot write to partitioned tables")
 
-        _check_schema(self.schema(), other_schema=df.schema)
+        _check_schema_compatible(self.schema(), other_schema=df.schema)
+        # cast if the two schemas are compatible but not equal
+        if self.schema().as_arrow() != df.schema:
+            df = df.cast(self.schema().as_arrow())
 
         merge = _MergingSnapshotProducer(operation=Operation.APPEND, table=self)
 
@@ -1080,7 +1091,10 @@ def overwrite(self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_T
         if len(self.spec().fields) > 0:
             raise ValueError("Cannot write to partitioned tables")
 
-        _check_schema(self.schema(), other_schema=df.schema)
+        _check_schema_compatible(self.schema(), other_schema=df.schema)
+        # cast if the two schemas are compatible but not equal
+        if self.schema().as_arrow() != df.schema:
+            df = df.cast(self.schema().as_arrow())
 
         merge = _MergingSnapshotProducer(
             operation=Operation.OVERWRITE if self.current_snapshot() is not None else Operation.APPEND,
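The cast added to `append()` and `overwrite()` is a plain `pa.Table.cast`, applied only when the Arrow schemas differ, so the already-matching case stays a no-op. A self-contained sketch of the same check-then-cast pattern using only pyarrow; the column name and types are illustrative:

```python
import pyarrow as pa

# Incoming data uses large_string; the target (table) schema uses string.
df = pa.Table.from_arrays(
    [pa.array(["A", "B", "C"], type=pa.large_string())],
    schema=pa.schema([pa.field("foo", pa.large_string(), nullable=True)]),
)
target_schema = pa.schema([pa.field("foo", pa.string(), nullable=True)])

# Cast only when the schemas are compatible but not equal,
# mirroring the logic added in append()/overwrite().
if df.schema != target_schema:
    df = df.cast(target_schema)

assert df.schema == target_schema
```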

tests/catalog/test_sql.py

Lines changed: 33 additions & 0 deletions
@@ -191,6 +191,39 @@ def test_create_table_with_pyarrow_schema(
     catalog.drop_table(random_identifier)
 
 
+@pytest.mark.parametrize(
+    'catalog',
+    [
+        lazy_fixture('catalog_memory'),
+        # lazy_fixture('catalog_sqlite'),
+    ],
+)
+def test_write_pyarrow_schema(catalog: SqlCatalog, random_identifier: Identifier) -> None:
+    import pyarrow as pa
+
+    pyarrow_table = pa.Table.from_arrays(
+        [
+            pa.array([None, "A", "B", "C"]),  # 'foo' column
+            pa.array([1, 2, 3, 4]),  # 'bar' column
+            pa.array([True, None, False, True]),  # 'baz' column
+            pa.array([None, "A", "B", "C"]),  # 'large' column
+        ],
+        schema=pa.schema([
+            pa.field('foo', pa.string(), nullable=True),
+            pa.field('bar', pa.int32(), nullable=False),
+            pa.field('baz', pa.bool_(), nullable=True),
+            pa.field('large', pa.large_string(), nullable=True),
+        ]),
+    )
+    database_name, _table_name = random_identifier
+    catalog.create_namespace(database_name)
+    table = catalog.create_table(random_identifier, pyarrow_table.schema)
+    print(pyarrow_table.schema)
+    print(table.schema().as_struct())
+    print()
+    table.overwrite(pyarrow_table)
+
+
 @pytest.mark.parametrize(
     'catalog',
     [

tests/table/test_init.py

Lines changed: 19 additions & 5 deletions
@@ -59,7 +59,7 @@
     Table,
     UpdateSchema,
     _apply_table_update,
-    _check_schema,
+    _check_schema_compatible,
     _generate_snapshot_id,
     _match_deletes_to_data_file,
     _TableMetadataUpdateContext,
@@ -1004,7 +1004,7 @@ def test_schema_mismatch_type(table_schema_simple: Schema) -> None:
     """
 
     with pytest.raises(ValueError, match=expected):
-        _check_schema(table_schema_simple, other_schema)
+        _check_schema_compatible(table_schema_simple, other_schema)
 
 
 def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:
@@ -1025,7 +1025,7 @@ def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:
     """
 
     with pytest.raises(ValueError, match=expected):
-        _check_schema(table_schema_simple, other_schema)
+        _check_schema_compatible(table_schema_simple, other_schema)
 
 
 def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
@@ -1045,7 +1045,7 @@ def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
     """
 
     with pytest.raises(ValueError, match=expected):
-        _check_schema(table_schema_simple, other_schema)
+        _check_schema_compatible(table_schema_simple, other_schema)
 
 
 def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
@@ -1059,4 +1059,18 @@ def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
     expected = r"PyArrow table contains more columns: new_field. Update the schema first \(hint, use union_by_name\)."
 
     with pytest.raises(ValueError, match=expected):
-        _check_schema(table_schema_simple, other_schema)
+        _check_schema_compatible(table_schema_simple, other_schema)
+
+
+def test_schema_downcast(table_schema_simple: Schema) -> None:
+    # large_string type is compatible with string type
+    other_schema = pa.schema((
+        pa.field("foo", pa.large_string(), nullable=True),
+        pa.field("bar", pa.int32(), nullable=False),
+        pa.field("baz", pa.bool_(), nullable=True),
+    ))
+
+    try:
+        _check_schema_compatible(table_schema_simple, other_schema)
+    except Exception:
+        pytest.fail("Unexpected Exception raised when calling `_check_schema`")
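`test_schema_downcast` passes because both Arrow string flavours normalize to the same Iceberg type during conversion, which is what 'equal in terms of the Iceberg Schema type' means in the new docstring. A small illustration using the same private helper the diff imports; this relies on internal API and is shown only to make the equivalence concrete:

```python
import pyarrow as pa
from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids

# string and large_string both convert to Iceberg's `string` type,
# so the resulting Iceberg schemas compare equal.
small = _pyarrow_to_schema_without_ids(pa.schema([pa.field("foo", pa.string(), nullable=True)]))
large = _pyarrow_to_schema_without_ids(pa.schema([pa.field("foo", pa.large_string(), nullable=True)]))
assert small == large
```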
