From 7aae81c73873b37494e6f6d393717aad068a4a9f Mon Sep 17 00:00:00 2001 From: Sunitha Kambhampati Date: Tue, 11 Feb 2020 13:29:29 -0800 Subject: [PATCH] Fix incorrect results from aggregate sum on decimals when there is overflow: throw an exception, making the behavior consistent with when whole-stage codegen is disabled. Also fix the affected test from SPARK-28224 --- .../catalyst/expressions/aggregate/Sum.scala | 16 ++++++++++-- .../org/apache/spark/sql/DataFrameSuite.scala | 16 +++++------- .../org/apache/spark/sql/SQLQuerySuite.scala | 26 +++++++++++++++++++ 3 files changed, 46 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 87f1a4f02e4fc..82b184dadebd9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -57,6 +57,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast case _ => DoubleType } + private lazy val wrongResultDueToOverflow = false private lazy val sumDataType = resultType private lazy val sum = AttributeReference("sum", sumDataType)() @@ -73,7 +74,13 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast if (child.nullable) { Seq( /* sum = */ - coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) + resultType match { + case d: DecimalType => coalesce( + CheckOverflow( + coalesce(sum, zero) + child.cast(sumDataType), d, wrongResultDueToOverflow), + sum) + case _ => coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum) + } ) } else { Seq( @@ -86,7 +93,12 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast override lazy val mergeExpressions: Seq[Expression] = { Seq( /* sum = */ - coalesce(coalesce(sum.left, zero) + sum.right, 
sum.left) + resultType match { + case d: DecimalType => coalesce( + CheckOverflow(coalesce(sum.left, zero) + sum.right, d, wrongResultDueToOverflow), + sum.left) + case _ => coalesce(coalesce(sum.left, zero) + sum.right, sum.left) + } ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index d2d58a83ded5d..c81fef0a691a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -169,18 +169,14 @@ class DataFrameSuite extends QueryTest DecimalData(BigDecimal("1"* 20 + ".123"), BigDecimal("1"* 20 + ".123")) :: DecimalData(BigDecimal("9"* 20 + ".123"), BigDecimal("9"* 20 + ".123")) :: Nil).toDF() - Seq(true, false).foreach { ansiEnabled => - withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) { + Seq("true", "false").foreach { codegenEnabled => + withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, codegenEnabled)) { val structDf = largeDecimals.select("a").agg(sum("a")) - if (!ansiEnabled) { - checkAnswer(structDf, Row(null)) - } else { - val e = intercept[SparkException] { - structDf.collect - } - assert(e.getCause.getClass.equals(classOf[ArithmeticException])) - assert(e.getCause.getMessage.contains("cannot be represented as Decimal")) + val e = intercept[SparkException] { + structDf.collect } + assert(e.getCause.getClass.equals(classOf[ArithmeticException])) + assert(e.getCause.getMessage.contains("cannot be represented as Decimal")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 11f9724e587f2..981f411940900 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -3319,6 +3319,32 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark 
} } + test("SPARK-28067 - Aggregate sum should not return wrong results") { + Seq("true", "false").foreach { wholeStageEnabled => + withSQLConf((SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStageEnabled)) { + val df = Seq( + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("10000000000000000000"), 1), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2), + (BigDecimal("10000000000000000000"), 2)).toDF("decNum", "intNum") + val df2 = df.withColumnRenamed("decNum", "decNum2").join(df, "intNum").agg(sum("decNum")) + val e = intercept[SparkException] { + df2.collect() + } + assert(e.getCause.getClass.equals(classOf[ArithmeticException])) + assert(e.getCause.getMessage.contains("cannot be represented as Decimal")) + } + } + } + test("SPARK-29213: FilterExec should not throw NPE") { withTempView("t1", "t2", "t3") { sql("SELECT ''").as[String].map(identity).toDF("x").createOrReplaceTempView("t1")