Skip to content

Commit 1adff26

Browse files
authored
Make View.sql_for case-insensitive and return None for unknown dialect (#3407)
# Rationale for this change Iceberg Java compares the dialect in a case-insensitive manner, and returns null if the dialect is unknown. https://github.com/apache/iceberg/blob/29ba0a14dc6667db0683fdfcd520639b3da77774/core/src/main/java/org/apache/iceberg/view/BaseView.java#L113-L130 ## Are these changes tested? Yes ## Are there any user-facing changes? No - the method isn't released yet <!-- In the case of user-facing changes, please add the changelog label. -->
1 parent 2be1827 commit 1adff26

2 files changed

Lines changed: 14 additions & 5 deletions

File tree

pyiceberg/view/__init__.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,12 @@ def uuid(self) -> UUID:
8080
"""Return the view's UUID."""
8181
return UUID(self.metadata.view_uuid)
8282

83-
def sql_for(self, dialect: str) -> SQLViewRepresentation:
84-
"""Return the view representation for the sql dialect."""
85-
return next(repr.root for repr in self.current_version().representations if repr.root.dialect == dialect)
83+
def sql_for(self, dialect: str) -> SQLViewRepresentation | None:
84+
"""Return the view representation for the sql dialect, or None if no representation could be resolved."""
85+
return next(
86+
(repr.root for repr in self.current_version().representations if repr.root.dialect.casefold() == dialect.casefold()),
87+
None,
88+
)
8689

8790
def __eq__(self, other: Any) -> bool:
8891
"""Return the equality of two instances of the View class."""

tests/test_view.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,13 @@ def test_view_sql_for_dialect(view: View) -> None:
9696
assert repr.sql == "SELECT * FROM prod.db.table"
9797

9898

99+
def test_view_sql_for_dialect_ignore_case(view: View) -> None:
100+
repr = view.sql_for("Spark")
101+
assert isinstance(repr, SQLViewRepresentation)
102+
assert repr.dialect == "spark"
103+
assert repr.sql == "SELECT * FROM prod.db.table"
104+
105+
99106
def test_view_schemas_multiple(example_view_metadata_v1_multiple_versions: dict[str, Any]) -> None:
100107
view = View(("default", "test_view"), ViewMetadata.model_validate(example_view_metadata_v1_multiple_versions))
101108
schemas = view.schemas()
@@ -117,5 +124,4 @@ def test_view_version_unknown_id(view: View) -> None:
117124

118125

119126
def test_view_sql_for_unknown_dialect(view: View) -> None:
120-
with pytest.raises(StopIteration):
121-
view.sql_for("trino")
127+
assert not view.sql_for("trino")

0 commit comments

Comments
 (0)