diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckViews.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckViews.scala index 4debc4d343a0..685b85a0d75f 100644 --- a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckViews.scala +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckViews.scala @@ -20,8 +20,12 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.catalyst.plans.logical.AlterViewAs import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias +import org.apache.spark.sql.catalyst.plans.logical.View import org.apache.spark.sql.catalyst.plans.logical.views.CreateIcebergView import org.apache.spark.sql.catalyst.plans.logical.views.ResolvedV2View import org.apache.spark.sql.connector.catalog.ViewCatalog @@ -30,12 +34,18 @@ import org.apache.spark.sql.util.SchemaUtils object CheckViews extends (LogicalPlan => Unit) { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + override def apply(plan: LogicalPlan): Unit = { plan foreach { case CreateIcebergView(resolvedIdent@ResolvedIdentifier(_: ViewCatalog, _), _, query, columnAliases, _, - _, _, _, _, _, _) => + _, _, _, _, replace, _) => verifyColumnCount(resolvedIdent, columnAliases, query) SchemaUtils.checkColumnNameDuplication(query.schema.fieldNames, SQLConf.get.resolver) + if (replace) { + val viewIdent: Seq[String] = resolvedIdent.catalog.name() +: resolvedIdent.identifier.asMultipartIdentifier + checkCyclicViewReference(viewIdent, query, Seq(viewIdent)) + } case AlterViewAs(ResolvedV2View(_, _), _, _) => throw new AnalysisException("ALTER VIEW AS is not supported. Use CREATE OR REPLACE VIEW instead") @@ -59,4 +69,44 @@ object CheckViews extends (LogicalPlan => Unit) { } } } + + private def checkCyclicViewReference( + viewIdent: Seq[String], + plan: LogicalPlan, + cyclePath: Seq[Seq[String]]): Unit = { + plan match { + case sub@SubqueryAlias(_, Project(_, _)) => + val currentViewIdent: Seq[String] = sub.identifier.qualifier :+ sub.identifier.name + checkIfRecursiveView(viewIdent, currentViewIdent, cyclePath, sub.children) + case v1View: View => + val currentViewIdent: Seq[String] = v1View.desc.identifier.nameParts + checkIfRecursiveView(viewIdent, currentViewIdent, cyclePath, v1View.children) + case _ => + plan.children.foreach(child => checkCyclicViewReference(viewIdent, child, cyclePath)) + } + + plan.expressions.flatMap(_.flatMap { + case e: SubqueryExpression => + checkCyclicViewReference(viewIdent, e.plan, cyclePath) + None + case _ => None + }) + } + + private def checkIfRecursiveView( + viewIdent: Seq[String], + currentViewIdent: Seq[String], + cyclePath: Seq[Seq[String]], + children: Seq[LogicalPlan] + ): Unit = { + val newCyclePath = cyclePath :+ currentViewIdent + if (currentViewIdent == viewIdent) { + throw new AnalysisException(String.format("Recursive cycle in view detected: %s (cycle: %s)", + viewIdent.asIdentifier, newCyclePath.map(p => p.mkString(".")).mkString(" -> "))) + } else { + children.foreach { c => + checkCyclicViewReference(viewIdent, c, newCyclePath) + } + } + } } diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestViews.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestViews.java index 5d1cb2db612b..624b4e354937 100644 --- a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestViews.java +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestViews.java @@ -1867,6 +1867,93 @@ public void replacingViewWithDialectDropAllowed() { .isEqualTo(ImmutableSQLViewRepresentation.builder().dialect("spark").sql(sql).build()); } + @Test + public void createViewWithRecursiveCycle() { + String viewOne = viewName("viewOne"); + String viewTwo = viewName("viewTwo"); + + sql("CREATE VIEW %s AS SELECT * FROM %s", viewOne, tableName); + // viewTwo points to viewOne + sql("CREATE VIEW %s AS SELECT * FROM %s", viewTwo, viewOne); + + // viewOne points to viewTwo points to viewOne, creating a recursive cycle + String view1 = String.format("%s.%s.%s", catalogName, NAMESPACE, viewOne); + String view2 = String.format("%s.%s.%s", catalogName, NAMESPACE, viewTwo); + String cycle = String.format("%s -> %s -> %s", view1, view2, view1); + assertThatThrownBy(() -> sql("CREATE OR REPLACE VIEW %s AS SELECT * FROM %s", viewOne, view2)) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + String.format("Recursive cycle in view detected: %s (cycle: %s)", view1, cycle)); + } + + @Test + public void createViewWithRecursiveCycleToV1View() { + String viewOne = viewName("view_one"); + String viewTwo = viewName("view_two"); + + sql("CREATE VIEW %s AS SELECT * FROM %s", viewOne, tableName); + // viewTwo points to viewOne + sql("USE spark_catalog"); + sql("CREATE VIEW %s AS SELECT * FROM %s.%s.%s", viewTwo, catalogName, NAMESPACE, viewOne); + + sql("USE %s", catalogName); + // viewOne points to viewTwo points to viewOne, creating a recursive cycle + String view1 = String.format("%s.%s.%s", catalogName, NAMESPACE, viewOne); + String view2 = String.format("%s.%s.%s", "spark_catalog", NAMESPACE, viewTwo); + String cycle = String.format("%s -> %s -> %s", view1, view2, view1); + assertThatThrownBy(() -> sql("CREATE OR REPLACE VIEW %s AS SELECT * FROM %s", viewOne, view2)) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + String.format("Recursive cycle in view detected: %s (cycle: %s)", view1, cycle)); + } + + @Test + public void createViewWithRecursiveCycleInCTE() { + String viewOne = viewName("viewOne"); + String viewTwo = viewName("viewTwo"); + + sql("CREATE VIEW %s AS SELECT * FROM %s", viewOne, tableName); + // viewTwo points to viewOne + sql("CREATE VIEW %s AS SELECT * FROM %s", viewTwo, viewOne); + + // CTE points to viewTwo + String sql = + String.format( + "WITH max_by_data AS (SELECT max(id) as max FROM %s) " + + "SELECT max, count(1) AS count FROM max_by_data GROUP BY max", + viewTwo); + + // viewOne points to CTE, creating a recursive cycle + String view1 = String.format("%s.%s.%s", catalogName, NAMESPACE, viewOne); + String cycle = String.format("%s -> %s -> %s", view1, viewTwo, view1); + assertThatThrownBy(() -> sql("CREATE OR REPLACE VIEW %s AS %s", viewOne, sql)) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + String.format("Recursive cycle in view detected: %s (cycle: %s)", view1, cycle)); + } + + @Test + public void createViewWithRecursiveCycleInSubqueryExpression() { + String viewOne = viewName("viewOne"); + String viewTwo = viewName("viewTwo"); + + sql("CREATE VIEW %s AS SELECT * FROM %s", viewOne, tableName); + // viewTwo points to viewOne + sql("CREATE VIEW %s AS SELECT * FROM %s", viewTwo, viewOne); + + // subquery expression points to viewTwo + String sql = + String.format("SELECT * FROM %s WHERE id = (SELECT id FROM %s)", tableName, viewTwo); + + // viewOne points to subquery expression, creating a recursive cycle + String view1 = String.format("%s.%s.%s", catalogName, NAMESPACE, viewOne); + String cycle = String.format("%s -> %s -> %s", view1, viewTwo, view1); + assertThatThrownBy(() -> sql("CREATE OR REPLACE VIEW %s AS %s", viewOne, sql)) + .isInstanceOf(AnalysisException.class) + .hasMessageStartingWith( + String.format("Recursive cycle in view detected: %s (cycle: %s)", view1, cycle)); + } + private void insertRows(int numRows) throws NoSuchTableException { List records = Lists.newArrayListWithCapacity(numRows); for (int i = 1; i <= numRows; i++) {