From 343177838e4382fee2b9b102b1ead335a3fb24cb Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Tue, 24 Oct 2017 13:21:50 -0700 Subject: [PATCH 1/2] SPARK-22345: Fix sort-merge joins with conditions and codegen. Code for the condition was generated to depend on the right row instead of the joined row. --- .../execution/joins/SortMergeJoinExec.scala | 6 ++ .../sql/execution/joins/InnerJoinSuite.scala | 56 ++++++++++++++++++- 2 files changed, 60 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 4e02803552e82..8d272ac313e2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -585,21 +585,26 @@ case class SortMergeJoinExec( val iterator = ctx.freshName("iterator") val numOutput = metricTerm(ctx, "numOutputRows") + val joinedRow = ctx.freshName("joined") val (beforeLoop, condCheck) = if (condition.isDefined) { // Split the code of creating variables based on whether it's used by condition or not. val loaded = ctx.freshName("loaded") val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars) val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars) // Generate code for condition + // set INPUT_ROW to the joined row because it is the data for the condition + ctx.INPUT_ROW = joinedRow ctx.currentVars = leftVars ++ rightVars val cond = BindReferences.bindReference(condition.get, output).genCode(ctx) // evaluate the columns those used by condition before loop val before = s""" |boolean $loaded = false; + |$joinedRow.withLeft($leftRow); |$leftBefore """.stripMargin val checking = s""" + |$joinedRow.withRight($rightRow); |$rightBefore |${cond.code} |if (${cond.isNull} || !${cond.value}) continue; @@ -615,6 +620,7 @@ case class SortMergeJoinExec( } s""" + |JoinedRow $joinedRow = new JoinedRow(); |while (findNextInnerJoinRows($leftInput, $rightInput)) { | ${beforeLoop.trim} | scala.collection.Iterator $iterator = $matches.generateIterator(); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 4408ece112258..4a4dcd2d9c11d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{And, BinaryExpression, Expression, Predicate} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.Join @@ -124,7 +125,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { rightPlan: SparkPlan) = { val sortMergeJoin = joins.SortMergeJoinExec(leftKeys, rightKeys, Inner, boundCondition, leftPlan, rightPlan) - EnsureRequirements(spark.sessionState.conf).apply(sortMergeJoin) + EnsureRequirements(spark.sessionState.conf) + .apply(ProjectExec(sortMergeJoin.output, sortMergeJoin)) } test(s"$testName using BroadcastHashJoin (build=left)") { @@ -228,6 +230,49 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { ) ) + testInnerJoin( + "inner join without codegen", + myUpperCaseData, + myLowerCaseData, + () => { + // test equality with a UDF that will not generate code to exercise CodegenFallback + // val udfEquals = org.apache.spark.sql.functions.udf((a: String, b: String) => + // a != null && a.toLowerCase(Locale.ENGLISH) == b.toLowerCase(Locale.ENGLISH)) + And( + (myUpperCaseData.col("N") === myLowerCaseData.col("n")).expr, + EqNoCodegen( + org.apache.spark.sql.functions.lower(myUpperCaseData.col("L")).expr, + myLowerCaseData.col("l").expr)) + }, + Seq( + (1, "A", 1, "a"), + (2, "B", 2, "b"), + (3, "C", 3, "c"), + (4, "D", 4, "d") + ) + ) + + testInnerJoin( + "inner join with CodegenFallback filter", + myUpperCaseData, + myLowerCaseData, + () => { + // add a second equality check that is implemented with a CodegenFallback + // this expression is in the test so that no one implements codegen for it + And( + (myUpperCaseData.col("N") === myLowerCaseData.col("n")).expr, + EqNoCodegen( + org.apache.spark.sql.functions.lower(myUpperCaseData.col("L")).expr, + myLowerCaseData.col("l").expr)) + }, + Seq( + (1, "A", 1, "a"), + (2, "B", 2, "b"), + (3, "C", 3, "c"), + (4, "D", 4, "d") + ) + ) + { lazy val left = myTestData1.where("a = 1") lazy val right = myTestData2.where("a = 1") @@ -287,3 +332,10 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { (Row(2, 2), "L2", Row(2, 2), "R2"))) } } + +case class EqNoCodegen(left: Expression, right: Expression) extends BinaryExpression + with CodegenFallback with Serializable with Predicate { + override protected def nullSafeEval(left: Any, right: Any): Boolean = { + left == right + } +} From 146d7918b38415980370617a9a97a5aaf657d2e8 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Thu, 26 Oct 2017 11:08:25 -0700 Subject: [PATCH 2/2] SPARK-22345: Remove duplicate test. --- .../sql/execution/joins/InnerJoinSuite.scala | 22 ------------------- 1 file changed, 22 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 4a4dcd2d9c11d..e69271fa07aff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -230,28 +230,6 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { ) ) - testInnerJoin( - "inner join without codegen", - myUpperCaseData, - myLowerCaseData, - () => { - // test equality with a UDF that will not generate code to exercise CodegenFallback - // val udfEquals = org.apache.spark.sql.functions.udf((a: String, b: String) => - // a != null && a.toLowerCase(Locale.ENGLISH) == b.toLowerCase(Locale.ENGLISH)) - And( - (myUpperCaseData.col("N") === myLowerCaseData.col("n")).expr, - EqNoCodegen( - org.apache.spark.sql.functions.lower(myUpperCaseData.col("L")).expr, - myLowerCaseData.col("l").expr)) - }, - Seq( - (1, "A", 1, "a"), - (2, "B", 2, "b"), - (3, "C", 3, "c"), - (4, "D", 4, "d") - ) - ) - testInnerJoin( "inner join with CodegenFallback filter", myUpperCaseData,