From 34c116471c08719cab98426c913f4ec9d8f1c1a0 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 7 Oct 2020 22:51:45 -0700 Subject: [PATCH 1/5] Support subexpression elimination in ProjectExec. --- .../sql/execution/basicPhysicalOperators.scala | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 1f70fde3f7654..7e19965eef0e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -66,10 +66,22 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { val exprs = bindReferences[Expression](projectList, child.output) - val resultVars = exprs.map(_.genCode(ctx)) + val (subExprsCode, resultVars) = if (SQLConf.get.subexpressionEliminationEnabled) { + // subexpression elimination + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs) + val genVars = ctx.withSubExprEliminationExprs(subExprs.states) { + exprs.map(_.genCode(ctx)) + } + (subExprs.codes.mkString("\n"), genVars) + } else { + ("", exprs.map(_.genCode(ctx))) + } + // Evaluation of non-deterministic expressions can't be deferred. val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute) s""" + |// common sub-expressions + |$subExprsCode |${evaluateRequiredVariables(output, resultVars, AttributeSet(nonDeterministicAttrs))} |${consume(ctx, resultVars)} """.stripMargin From 59e3ee373b112f71c0a0d77c1f6fa13a83d8c6f2 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 7 Oct 2020 23:09:41 -0700 Subject: [PATCH 2/5] Use conf directly. 
--- .../org/apache/spark/sql/execution/basicPhysicalOperators.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 7e19965eef0e4..ff5ba6703e767 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -66,7 +66,7 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { val exprs = bindReferences[Expression](projectList, child.output) - val (subExprsCode, resultVars) = if (SQLConf.get.subexpressionEliminationEnabled) { + val (subExprsCode, resultVars) = if (conf.subexpressionEliminationEnabled) { // subexpression elimination val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs) val genVars = ctx.withSubExprEliminationExprs(subExprs.states) { From 0d958e90745c26cd632325fb288e8e7d4631fe44 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 8 Oct 2020 17:46:36 -0700 Subject: [PATCH 3/5] Evaluate input variables. 
--- .../expressions/codegen/CodeGenerator.scala | 40 +++++++++++++------ .../aggregate/HashAggregateExec.scala | 2 +- .../execution/basicPhysicalOperators.scala | 7 ++-- 3 files changed, 33 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 71d36733464f6..3f1f2cdea7738 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -90,8 +90,13 @@ case class SubExprEliminationState(isNull: ExprValue, value: ExprValue) * @param codes Strings representing the codes that evaluate common subexpressions. * @param states Foreach expression that is participating in subexpression elimination, * the state to use. + * @param exprCodesNeedEvaluate Some expression codes that need to be evaluate before + * calling common subexpressions. */ -case class SubExprCodes(codes: Seq[String], states: Map[Expression, SubExprEliminationState]) +case class SubExprCodes( + codes: Seq[String], + states: Map[Expression, SubExprEliminationState], + exprCodesNeedEvaluate: Seq[ExprCode]) /** * The main information about a new added function. @@ -1044,7 +1049,7 @@ class CodegenContext extends Logging { // Get all the expressions that appear at least twice and set up the state for subexpression // elimination. 
val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1) - val commonExprVals = commonExprs.map(_.head.genCode(this)) + lazy val commonExprVals = commonExprs.map(_.head.genCode(this)) lazy val nonSplitExprCode = { commonExprs.zip(commonExprVals).map { case (exprs, eval) => @@ -1055,10 +1060,13 @@ class CodegenContext extends Logging { } } - val codes = if (commonExprVals.map(_.code.length).sum > SQLConf.get.methodSplitThreshold) { - val inputVarsForAllFuncs = commonExprs.map { expr => - getLocalInputVariableValues(this, expr.head).toSeq - } + val (inputVarsForAllFuncs, exprCodesNeedEvaluate) = commonExprs.map { expr => + val (inputVars, exprCodes) = getLocalInputVariableValues(this, expr.head) + (inputVars.toSeq, exprCodes.toSeq) + }.unzip + + val splitThreshold = SQLConf.get.methodSplitThreshold + val codes = if (commonExprVals.map(_.code.length).sum > splitThreshold) { if (inputVarsForAllFuncs.map(calculateParamLengthFromExprValues).forall(isValidParamLength)) { commonExprs.zipWithIndex.map { case (exprs, i) => val expr = exprs.head @@ -1109,7 +1117,7 @@ class CodegenContext extends Logging { } else { nonSplitExprCode } - SubExprCodes(codes, localSubExprEliminationExprs.toMap) + SubExprCodes(codes, localSubExprEliminationExprs.toMap, exprCodesNeedEvaluate.flatten) } /** @@ -1739,8 +1747,11 @@ object CodeGenerator extends Logging { def getLocalInputVariableValues( ctx: CodegenContext, expr: Expression, - subExprs: Map[Expression, SubExprEliminationState] = Map.empty): Set[VariableValue] = { + subExprs: Map[Expression, SubExprEliminationState] = Map.empty): + (Set[VariableValue], Set[ExprCode]) = { val argSet = mutable.Set[VariableValue]() + val exprCodesNeedEvaluate = mutable.Set[ExprCode]() + if (ctx.INPUT_ROW != null) { argSet += JavaCode.variable(ctx.INPUT_ROW, classOf[InternalRow]) } @@ -1761,16 +1772,21 @@ object CodeGenerator extends Logging { case ref: BoundReference if ctx.currentVars != null && ctx.currentVars(ref.ordinal) != null 
=> - val ExprCode(_, isNull, value) = ctx.currentVars(ref.ordinal) - collectLocalVariable(value) - collectLocalVariable(isNull) + val exprCode = ctx.currentVars(ref.ordinal) + // If the referred variable is not evaluated yet. + if (exprCode.code != EmptyBlock) { + exprCodesNeedEvaluate += exprCode.copy() + exprCode.code = EmptyBlock + } + collectLocalVariable(exprCode.value) + collectLocalVariable(exprCode.isNull) case e => stack.pushAll(e.children) } } - argSet.toSet + (argSet.toSet, exprCodesNeedEvaluate.toSet) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index dcb465707a0ed..52d0450afb181 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -263,7 +263,7 @@ case class HashAggregateExec( } else { val inputVars = aggBufferUpdatingExprs.map { aggExprsForOneFunc => val inputVarsForOneFunc = aggExprsForOneFunc.map( - CodeGenerator.getLocalInputVariableValues(ctx, _, subExprs)).reduce(_ ++ _).toSeq + CodeGenerator.getLocalInputVariableValues(ctx, _, subExprs)._1).reduce(_ ++ _).toSeq val paramLength = CodeGenerator.calculateParamLengthFromExprValues(inputVarsForOneFunc) // Checks if a parameter length for the `aggExprsForOneFunc` does not go over the JVM limit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index ff5ba6703e767..7334ea1e27284 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -66,21 +66,22 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) override def 
doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { val exprs = bindReferences[Expression](projectList, child.output) - val (subExprsCode, resultVars) = if (conf.subexpressionEliminationEnabled) { + val (subExprsCode, resultVars, localValInputs) = if (conf.subexpressionEliminationEnabled) { // subexpression elimination val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs) val genVars = ctx.withSubExprEliminationExprs(subExprs.states) { exprs.map(_.genCode(ctx)) } - (subExprs.codes.mkString("\n"), genVars) + (subExprs.codes.mkString("\n"), genVars, subExprs.exprCodesNeedEvaluate) } else { - ("", exprs.map(_.genCode(ctx))) + ("", exprs.map(_.genCode(ctx)), Seq.empty) } // Evaluation of non-deterministic expressions can't be deferred. val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute) s""" |// common sub-expressions + |${evaluateVariables(localValInputs)} |$subExprsCode |${evaluateRequiredVariables(output, resultVars, AttributeSet(nonDeterministicAttrs))} |${consume(ctx, resultVars)} From fe92bfe066d0090c302082e16e9757be27aaef3e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 8 Oct 2020 22:10:49 -0700 Subject: [PATCH 4/5] Fix test. --- .../org/apache/spark/sql/connector/DataSourceV2Suite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index a9c521eb46499..ec1ac00d08bf8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -268,7 +268,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS } } // this input data will fail to read middle way. 
- val input = spark.range(10).select(failingUdf('id).as('i)).select('i, -'i as 'j) + val input = spark.range(15).select(failingUdf('id).as('i)).select('i, -'i as 'j) val e3 = intercept[SparkException] { input.write.format(cls.getName).option("path", path).mode("overwrite").save() } From 2414bb041f47c00de28428014372ab6c55435e56 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 11 Oct 2020 18:39:42 -0700 Subject: [PATCH 5/5] Address some comments. --- .../expressions/codegen/CodeGenerator.scala | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 3f1f2cdea7738..9a26c388f59af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -90,7 +90,7 @@ case class SubExprEliminationState(isNull: ExprValue, value: ExprValue) * @param codes Strings representing the codes that evaluate common subexpressions. * @param states Foreach expression that is participating in subexpression elimination, * the state to use. - * @param exprCodesNeedEvaluate Some expression codes that need to be evaluate before + * @param exprCodesNeedEvaluate Some expression codes that need to be evaluated before * calling common subexpressions. */ case class SubExprCodes( @@ -1060,6 +1060,10 @@ class CodegenContext extends Logging { } } + // Some operators do not require all of their child's outputs to be evaluated in advance. + // Instead they only early evaluate part of the outputs, for example, `ProjectExec` only early + // evaluates the outputs used more than twice. So we need to extract the variables used by + // subexpressions and evaluate them before the subexpressions.
val (inputVarsForAllFuncs, exprCodesNeedEvaluate) = commonExprs.map { expr => val (inputVars, exprCodes) = getLocalInputVariableValues(this, expr.head) (inputVars.toSeq, exprCodes.toSeq) @@ -1740,15 +1744,20 @@ object CodeGenerator extends Logging { } /** - * Extracts all the input variables from references and subexpression elimination states - * for a given `expr`. This result will be used to split the generated code of - * expressions into multiple functions. + * This method returns two values in a tuple. + * + * First value: Extracts all the input variables from references and subexpression + * elimination states for a given `expr`. This result will be used to split the + * generated code of expressions into multiple functions. + * + * Second value: Returns the set of `ExprCode`s whose code must be evaluated before + * evaluating subexpressions. */ def getLocalInputVariableValues( ctx: CodegenContext, expr: Expression, - subExprs: Map[Expression, SubExprEliminationState] = Map.empty): - (Set[VariableValue], Set[ExprCode]) = { + subExprs: Map[Expression, SubExprEliminationState] = Map.empty) + : (Set[VariableValue], Set[ExprCode]) = { val argSet = mutable.Set[VariableValue]() val exprCodesNeedEvaluate = mutable.Set[ExprCode]()