Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,12 @@
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.catalyst.plans.logical.AlterViewAs
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias
import org.apache.spark.sql.catalyst.plans.logical.View
import org.apache.spark.sql.catalyst.plans.logical.views.CreateIcebergView
import org.apache.spark.sql.catalyst.plans.logical.views.ResolvedV2View
import org.apache.spark.sql.connector.catalog.ViewCatalog
Expand All @@ -30,12 +34,18 @@ import org.apache.spark.sql.util.SchemaUtils

object CheckViews extends (LogicalPlan => Unit) {

import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._

override def apply(plan: LogicalPlan): Unit = {
plan foreach {
case CreateIcebergView(resolvedIdent@ResolvedIdentifier(_: ViewCatalog, _), _, query, columnAliases, _,
_, _, _, _, _, _) =>
_, _, _, _, replace, _) =>
verifyColumnCount(resolvedIdent, columnAliases, query)
SchemaUtils.checkColumnNameDuplication(query.schema.fieldNames, SQLConf.get.resolver)
if (replace) {
val viewIdent: Seq[String] = resolvedIdent.catalog.name() +: resolvedIdent.identifier.asMultipartIdentifier
checkCyclicViewReference(viewIdent, query, Seq(viewIdent))
}

case AlterViewAs(ResolvedV2View(_, _), _, _) =>
throw new AnalysisException("ALTER VIEW <viewName> AS is not supported. Use CREATE OR REPLACE VIEW instead")
Expand All @@ -59,4 +69,44 @@ object CheckViews extends (LogicalPlan => Unit) {
}
}
}

private def checkCyclicViewReference(
viewIdent: Seq[String],
plan: LogicalPlan,
cyclePath: Seq[Seq[String]]): Unit = {
plan match {
case sub@SubqueryAlias(_, Project(_, _)) =>
val currentViewIdent: Seq[String] = sub.identifier.qualifier :+ sub.identifier.name
checkIfRecursiveView(viewIdent, currentViewIdent, cyclePath, sub.children)
case v1View: View =>
val currentViewIdent: Seq[String] = v1View.desc.identifier.nameParts
checkIfRecursiveView(viewIdent, currentViewIdent, cyclePath, v1View.children)
case _ =>
plan.children.foreach(child => checkCyclicViewReference(viewIdent, child, cyclePath))
}

plan.expressions.flatMap(_.flatMap {
case e: SubqueryExpression =>
checkCyclicViewReference(viewIdent, e.plan, cyclePath)
None
case _ => None
})
}

private def checkIfRecursiveView(
viewIdent: Seq[String],
currentViewIdent: Seq[String],
cyclePath: Seq[Seq[String]],
children: Seq[LogicalPlan]
): Unit = {
val newCyclePath = cyclePath :+ currentViewIdent
if (currentViewIdent == viewIdent) {
throw new AnalysisException(String.format("Recursive cycle in view detected: %s (cycle: %s)",
viewIdent.asIdentifier, newCyclePath.map(p => p.mkString(".")).mkString(" -> ")))
} else {
children.foreach { c =>
checkCyclicViewReference(viewIdent, c, newCyclePath)
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1867,6 +1867,93 @@ public void replacingViewWithDialectDropAllowed() {
.isEqualTo(ImmutableSQLViewRepresentation.builder().dialect("spark").sql(sql).build());
}

@Test
public void createViewWithRecursiveCycle() {
String viewOne = viewName("viewOne");
String viewTwo = viewName("viewTwo");

sql("CREATE VIEW %s AS SELECT * FROM %s", viewOne, tableName);
// viewTwo points to viewOne
sql("CREATE VIEW %s AS SELECT * FROM %s", viewTwo, viewOne);

// viewOne points to viewTwo points to viewOne, creating a recursive cycle
String view1 = String.format("%s.%s.%s", catalogName, NAMESPACE, viewOne);
String view2 = String.format("%s.%s.%s", catalogName, NAMESPACE, viewTwo);
String cycle = String.format("%s -> %s -> %s", view1, view2, view1);
assertThatThrownBy(() -> sql("CREATE OR REPLACE VIEW %s AS SELECT * FROM %s", viewOne, view2))
.isInstanceOf(AnalysisException.class)
.hasMessageStartingWith(
String.format("Recursive cycle in view detected: %s (cycle: %s)", view1, cycle));
}

@Test
public void createViewWithRecursiveCycleToV1View() {
String viewOne = viewName("view_one");
String viewTwo = viewName("view_two");

sql("CREATE VIEW %s AS SELECT * FROM %s", viewOne, tableName);
// viewTwo points to viewOne
sql("USE spark_catalog");
sql("CREATE VIEW %s AS SELECT * FROM %s.%s.%s", viewTwo, catalogName, NAMESPACE, viewOne);

sql("USE %s", catalogName);
// viewOne points to viewTwo points to viewOne, creating a recursive cycle
String view1 = String.format("%s.%s.%s", catalogName, NAMESPACE, viewOne);
String view2 = String.format("%s.%s.%s", "spark_catalog", NAMESPACE, viewTwo);
String cycle = String.format("%s -> %s -> %s", view1, view2, view1);
assertThatThrownBy(() -> sql("CREATE OR REPLACE VIEW %s AS SELECT * FROM %s", viewOne, view2))
.isInstanceOf(AnalysisException.class)
.hasMessageStartingWith(
String.format("Recursive cycle in view detected: %s (cycle: %s)", view1, cycle));
}

@Test
public void createViewWithRecursiveCycleInCTE() {
String viewOne = viewName("viewOne");
String viewTwo = viewName("viewTwo");

sql("CREATE VIEW %s AS SELECT * FROM %s", viewOne, tableName);
// viewTwo points to viewOne
sql("CREATE VIEW %s AS SELECT * FROM %s", viewTwo, viewOne);

// CTE points to viewTwo
String sql =
String.format(
"WITH max_by_data AS (SELECT max(id) as max FROM %s) "
+ "SELECT max, count(1) AS count FROM max_by_data GROUP BY max",
viewTwo);

// viewOne points to CTE, creating a recursive cycle
String view1 = String.format("%s.%s.%s", catalogName, NAMESPACE, viewOne);
String cycle = String.format("%s -> %s -> %s", view1, viewTwo, view1);
assertThatThrownBy(() -> sql("CREATE OR REPLACE VIEW %s AS %s", viewOne, sql))
.isInstanceOf(AnalysisException.class)
.hasMessageStartingWith(
String.format("Recursive cycle in view detected: %s (cycle: %s)", view1, cycle));
}

@Test
public void createViewWithRecursiveCycleInSubqueryExpression() {
String viewOne = viewName("viewOne");
String viewTwo = viewName("viewTwo");

sql("CREATE VIEW %s AS SELECT * FROM %s", viewOne, tableName);
// viewTwo points to viewOne
sql("CREATE VIEW %s AS SELECT * FROM %s", viewTwo, viewOne);

// subquery expression points to viewTwo
String sql =
String.format("SELECT * FROM %s WHERE id = (SELECT id FROM %s)", tableName, viewTwo);

// viewOne points to subquery expression, creating a recursive cycle
String view1 = String.format("%s.%s.%s", catalogName, NAMESPACE, viewOne);
String cycle = String.format("%s -> %s -> %s", view1, viewTwo, view1);
assertThatThrownBy(() -> sql("CREATE OR REPLACE VIEW %s AS %s", viewOne, sql))
.isInstanceOf(AnalysisException.class)
.hasMessageStartingWith(
String.format("Recursive cycle in view detected: %s (cycle: %s)", view1, cycle));
}

private void insertRows(int numRows) throws NoSuchTableException {
List<SimpleRecord> records = Lists.newArrayListWithCapacity(numRows);
for (int i = 1; i <= numRows; i++) {
Expand Down