From 74cf2ddd5524a9572193803bf9d753b8c888af7c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 19 Sep 2020 18:42:29 -0700 Subject: [PATCH 1/5] Optimize WithFields expression chain. --- .../sql/catalyst/optimizer/ComplexTypes.scala | 6 ++- .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../sql/catalyst/optimizer/WithFields.scala | 19 ++++++-- ...te.scala => OptimizeWithFieldsSuite.scala} | 43 +++++++++++++++++-- 4 files changed, 62 insertions(+), 8 deletions(-) rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/{CombineWithFieldsSuite.scala => OptimizeWithFieldsSuite.scala} (64%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala index 2aba4bae397c7..a47523b6fa023 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala @@ -48,7 +48,11 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] { // `$"struct_col".withField("b", lit(1)).withField("b", lit(2)).getField("b")` // we want to return `lit(2)` (and not `lit(1)`). val expr = matches.last._2 - If(IsNull(struct), Literal(null, expr.dataType), expr) + if (struct.nullable) { + If(IsNull(struct), Literal(null, expr.dataType), expr) + } else { + expr + } } else { GetStructField(struct, ordinal, maybeName) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 6033c01a60f47..c57082e54e270 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -109,7 +109,7 @@ abstract class Optimizer(catalogManager: CatalogManager) RemoveRedundantAliases, UnwrapCastInBinaryComparison, RemoveNoopOperators, - CombineWithFields, + OptimizeWithFields, SimplifyExtractValueOps, CombineConcats) ++ extendedOperatorOptimizationRules diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala index 05c90864e4bb0..435d9e503e169 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala @@ -17,16 +17,29 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.expressions.WithFields +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, WithFields} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule /** - * Combines all adjacent [[WithFields]] expression into a single [[WithFields]] expression. + * Optimizes [[WithFields]] expression chains. */ -object CombineWithFields extends Rule[LogicalPlan] { +object OptimizeWithFields extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case WithFields(structExpr, names, values) if names.distinct.length != names.length => + val newNames = mutable.ArrayBuffer.empty[String] + val newValues = mutable.ArrayBuffer.empty[Expression] + names.zip(values).reverse.foreach { case (name, value) => + if (!newNames.contains(name)) { + newNames += name + newValues += value + } + } + WithFields(structExpr, names = newNames.reverse.toSeq, valExprs = newValues.reverse.toSeq) + case WithFields(WithFields(struct, names1, valExprs1), names2, valExprs2) => WithFields(struct, names1 ++ names2, valExprs1 ++ valExprs2) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineWithFieldsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWithFieldsSuite.scala similarity index 64% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineWithFieldsSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWithFieldsSuite.scala index a3e0bbc57e639..70298ce301b9d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineWithFieldsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWithFieldsSuite.scala @@ -19,19 +19,21 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, WithFields} +import org.apache.spark.sql.catalyst.expressions.{Alias, GetStructField, Literal, WithFields} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -class CombineWithFieldsSuite extends PlanTest { +class OptimizeWithFieldsSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { - val batches = Batch("CombineWithFields", FixedPoint(10), CombineWithFields) :: Nil + val batches = Batch("OptimizeWithFields", FixedPoint(10), + OptimizeWithFields, SimplifyExtractValueOps) :: Nil } private val testRelation = LocalRelation('a.struct('a1.int)) + private val testRelation2 = LocalRelation('a.struct('a1.int).notNull) test("combines two WithFields") { val originalQuery = testRelation @@ -73,4 +75,39 @@ class CombineWithFieldsSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("SPARK-32941: optimize WithFields followed by GetStructField") { + val originalQuery = testRelation2 + .select(Alias( + GetStructField(WithFields( + 'a, + Seq("b1"), + Seq(Literal(4))), 1), "out")()) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelation2 + .select(Alias(Literal(4), "out")()) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("SPARK-32941: optimize WithFields chain") { + val originalQuery = testRelation + .select(Alias( + WithFields( + WithFields( + 'a, + Seq("b1"), + Seq(Literal(4))), + Seq("b1"), + Seq(Literal(5))), "out")()) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelation + .select(Alias(WithFields('a, Seq("b1"), Seq(Literal(5))), "out")()) + .analyze + + comparePlans(optimized, correctAnswer) + } } From 00acff9c5bacfee8fb86a5e95ac97bf8a1f3cd76 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 21 Sep 2020 18:05:53 -0700 Subject: [PATCH 2/5] Use resolver. --- .../apache/spark/sql/catalyst/optimizer/WithFields.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala index 435d9e503e169..4eafb56f7d55f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala @@ -19,21 +19,24 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, WithFields} +import org.apache.spark.sql.catalyst.expressions.{Expression, WithFields} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf /** * Optimizes [[WithFields]] expression chains. */ object OptimizeWithFields extends Rule[LogicalPlan] { + lazy val resolver = SQLConf.get.resolver + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case WithFields(structExpr, names, values) if names.distinct.length != names.length => val newNames = mutable.ArrayBuffer.empty[String] val newValues = mutable.ArrayBuffer.empty[Expression] names.zip(values).reverse.foreach { case (name, value) => - if (!newNames.contains(name)) { + if (newNames.find(resolver(_, name)).isEmpty) { newNames += name newValues += value } From cb8872c3e203cf2ae5025f836b1eb53004b0e7f8 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 22 Sep 2020 15:39:43 -0700 Subject: [PATCH 3/5] Address comment. --- .../sql/catalyst/optimizer/WithFields.scala | 31 +++++++++--- .../optimizer/OptimizeWithFieldsSuite.scala | 49 +++++++++++++++++-- 2 files changed, 69 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala index 4eafb56f7d55f..55886310e8c1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.optimizer +import java.util.Locale + import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.{Expression, WithFields} @@ -29,18 +31,33 @@ import org.apache.spark.sql.internal.SQLConf * Optimizes [[WithFields]] expression chains. */ object OptimizeWithFields extends Rule[LogicalPlan] { - lazy val resolver = SQLConf.get.resolver - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case WithFields(structExpr, names, values) if names.distinct.length != names.length => + case WithFields(structExpr, names, values) + if names.map(_.toLowerCase(Locale.ROOT)).distinct.length != names.length => + val caseSensitive = SQLConf.get.caseSensitiveAnalysis + val newNames = mutable.ArrayBuffer.empty[String] val newValues = mutable.ArrayBuffer.empty[Expression] - names.zip(values).reverse.foreach { case (name, value) => - if (newNames.find(resolver(_, name)).isEmpty) { - newNames += name - newValues += value + + if (caseSensitive) { + names.zip(values).reverse.foreach { case (name, value) => + if (!newNames.contains(name)) { + newNames += name + newValues += value + } + } + } else { + val nameSet = mutable.HashSet.empty[String] + names.zip(values).reverse.foreach { case (name, value) => + val lowercaseName = name.toLowerCase(Locale.ROOT) + if (!nameSet.contains(lowercaseName)) { + newNames += name + newValues += value + nameSet += lowercaseName + } } } + WithFields(structExpr, names = newNames.reverse.toSeq, valExprs = newValues.reverse.toSeq) case WithFields(WithFields(struct, names1, valExprs1), names2, valExprs2) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWithFieldsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWithFieldsSuite.scala index 70298ce301b9d..6cbbe4283b5df 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWithFieldsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWithFieldsSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, GetStructField, Literal import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ - +import org.apache.spark.sql.internal.SQLConf class OptimizeWithFieldsSuite extends PlanTest { @@ -92,7 +92,7 @@ class OptimizeWithFieldsSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("SPARK-32941: optimize WithFields chain") { + test("SPARK-32941: optimize WithFields chain - case insensitive") { val originalQuery = testRelation .select(Alias( WithFields( @@ -101,13 +101,54 @@ class OptimizeWithFieldsSuite extends PlanTest { Seq("b1"), Seq(Literal(4))), Seq("b1"), - Seq(Literal(5))), "out")()) + Seq(Literal(5))), "out1")(), + Alias( + WithFields( + WithFields( + 'a, + Seq("b1"), + Seq(Literal(4))), + Seq("B1"), + Seq(Literal(5))), "out2")()) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select(Alias(WithFields('a, Seq("b1"), Seq(Literal(5))), "out")()) + .select( + Alias(WithFields('a, Seq("b1"), Seq(Literal(5))), "out1")(), + Alias(WithFields('a, Seq("B1"), Seq(Literal(5))), "out2")()) .analyze comparePlans(optimized, correctAnswer) } + + test("SPARK-32941: optimize WithFields chain - case sensitive") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val originalQuery = testRelation + .select(Alias( + WithFields( + WithFields( + 'a, + Seq("b1"), + Seq(Literal(4))), + Seq("b1"), + Seq(Literal(5))), "out1")(), + Alias( + WithFields( + WithFields( + 'a, + Seq("b1"), + Seq(Literal(4))), + Seq("B1"), + Seq(Literal(5))), "out2")()) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelation + .select( + Alias(WithFields('a, Seq("b1"), Seq(Literal(5))), "out1")(), + Alias(WithFields('a, Seq("b1", "B1"), Seq(Literal(4), Literal(5))), "out2")()) + .analyze + + comparePlans(optimized, correctAnswer) + } + } } From 38bdefde56e039c0380a290b939a82f206487d80 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 18 Oct 2020 12:50:36 -0700 Subject: [PATCH 4/5] Simplify UpdateFields in analysis too. --- .../apache/spark/sql/catalyst/analysis/Analyzer.scala | 6 ++++++ .../spark/sql/catalyst/analysis/ResolveUnion.scala | 11 ++--------- .../spark/sql/catalyst/optimizer/UpdateFields.scala | 10 ++++++---- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 0ba150ec1efb4..4264627e0d9bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects._ +import org.apache.spark.sql.catalyst.optimizer.OptimizeUpdateFields import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -207,6 +208,11 @@ class Analyzer( lazy val batches: Seq[Batch] = Seq( Batch("Substitution", fixedPoint, + // This rule optimizes `UpdateFields` expression chains so looks more like optimization rule. + // However, when manipulating deeply nested schema, `UpdateFields` expression tree could be + // very complex and make analysis impossible. Thus we need to optimize `UpdateFields` early + // at the beginning of analysis. + OptimizeUpdateFields, CTESubstitution, WindowsSubstitution, EliminateUnions, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala index c1a9c9d3d9bab..b08e116642ece 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.optimizer.CombineUnions +import org.apache.spark.sql.catalyst.optimizer.{CombineUnions, OptimizeUpdateFields} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.internal.SQLConf @@ -88,13 +88,6 @@ object ResolveUnion extends Rule[LogicalPlan] { } } - def simplifyWithFields(expr: Expression): Expression = { - expr.transformUp { - case UpdateFields(UpdateFields(struct, fieldOps1), fieldOps2) => - UpdateFields(struct, fieldOps1 ++ fieldOps2) - } - } - /** * Adds missing fields recursively into given `col` expression, based on the target `StructType`. * This is called by `compareAndAddFields` when we find two struct columns with same name but @@ -119,7 +112,7 @@ object ResolveUnion extends Rule[LogicalPlan] { missingFieldsOpt.map { s => val struct = addFieldsInto(col, s.fields) // Combines `WithFields`s to reduce expression tree. - val reducedStruct = simplifyWithFields(struct) + val reducedStruct = struct.transformUp(OptimizeUpdateFields.optimizeUpdateFields) val sorted = sortStructFieldsInWithFields(reducedStruct) sorted }.get diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala index accce49582b0a..af9e1ee89767d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala @@ -31,11 +31,11 @@ import org.apache.spark.sql.internal.SQLConf * Optimizes [[UpdateFields]] expression chains. */ object OptimizeUpdateFields extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + val optimizeUpdateFields: PartialFunction[Expression, Expression] = { case UpdateFields(structExpr, fieldOps) - if fieldOps.forall(_.isInstanceOf[WithField]) && - fieldOps.map(_.asInstanceOf[WithField].name.toLowerCase(Locale.ROOT)).distinct.length != - fieldOps.length => + if fieldOps.forall(_.isInstanceOf[WithField]) && + fieldOps.map(_.asInstanceOf[WithField].name.toLowerCase(Locale.ROOT)).distinct.length != + fieldOps.length => val caseSensitive = SQLConf.get.caseSensitiveAnalysis val withFields = fieldOps.map(_.asInstanceOf[WithField]) @@ -70,6 +70,8 @@ object OptimizeUpdateFields extends Rule[LogicalPlan] { case UpdateFields(UpdateFields(struct, fieldOps1), fieldOps2) => UpdateFields(struct, fieldOps1 ++ fieldOps2) } + + def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions(optimizeUpdateFields) } /** From f41900cb3855d56dcb0ba7d5f448482baef1c3fa Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 18 Oct 2020 18:29:46 -0700 Subject: [PATCH 5/5] Skip the rule if possible. --- .../spark/sql/catalyst/optimizer/UpdateFields.scala | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala index af9e1ee89767d..465d2efe2775c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala @@ -31,11 +31,18 @@ import org.apache.spark.sql.internal.SQLConf * Optimizes [[UpdateFields]] expression chains. */ object OptimizeUpdateFields extends Rule[LogicalPlan] { + private def canOptimize(names: Seq[String]): Boolean = { + if (SQLConf.get.caseSensitiveAnalysis) { + names.distinct.length != names.length + } else { + names.map(_.toLowerCase(Locale.ROOT)).distinct.length != names.length + } + } + val optimizeUpdateFields: PartialFunction[Expression, Expression] = { case UpdateFields(structExpr, fieldOps) if fieldOps.forall(_.isInstanceOf[WithField]) && - fieldOps.map(_.asInstanceOf[WithField].name.toLowerCase(Locale.ROOT)).distinct.length != - fieldOps.length => + canOptimize(fieldOps.map(_.asInstanceOf[WithField].name)) => val caseSensitive = SQLConf.get.caseSensitiveAnalysis val withFields = fieldOps.map(_.asInstanceOf[WithField])