From 3b266ba981d9031ba56a3c5699cde1fd6cf0ebea Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 14 Jun 2015 01:20:09 -0700 Subject: [PATCH 1/2] [SPARK-8362] Add unit tests for +, -, *, /. --- .../sql/catalyst/expressions/arithmetic.scala | 16 +-- .../ArithmeticExpressionSuite.scala | 132 ++++++++++-------- 2 files changed, 85 insertions(+), 63 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 18ddac1b598e6..73d4ad0826c56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.util.TypeUtils @@ -52,8 +51,8 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic { private lazy val numeric = TypeUtils.getNumeric(dataType) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { - case dt: DecimalType => defineCodeGen(ctx, ev, c => s"c.unary_$$minus()") - case dt: NumericType => defineCodeGen(ctx, ev, c => s"-($c)") + case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()") + case dt: NumericType => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(-($c))") } protected override def evalInternal(evalE: Any) = numeric.negate(evalE) @@ -144,8 +143,8 @@ abstract class BinaryArithmetic extends BinaryExpression { defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") // byte and short are casted into int when add, minus, times or divide case ByteType | ShortType => - defineCodeGen(ctx, ev, (eval1, eval2) => - s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") + defineCodeGen(ctx, ev, + (eval1, eval2) => s"(${ctx.javaType(dataType)})($eval1 $symbol $eval2)") case _ => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1 $symbol $eval2") } @@ -205,7 +204,7 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "/" - override def decimalMethod: String = "$divide" + override def decimalMethod: String = "$div" override def nullable: Boolean = true @@ -257,7 +256,8 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic if (${eval1.isNull} || ${eval2.isNull} || $test) { ${ev.isNull} = true; } else { - ${ev.primitive} = ${eval1.primitive}$method(${eval2.primitive}); + ${ev.primitive} = + (${ctx.javaType(left.dataType)})(${eval1.primitive}$method(${eval2.primitive})); } """ } @@ -265,7 +265,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { override def symbol: String = "%" - override def decimalMethod: String = "reminder" + override def decimalMethod: String = "remainder" override def nullable: Boolean = true diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 5ff1bca260b24..a56a5ae70a95f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import org.scalatest.Matchers._ - import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types.{Decimal, DoubleType, IntegerType} @@ -26,68 +24,92 @@ import org.apache.spark.sql.types.{Decimal, DoubleType, IntegerType} class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { - test("arithmetic") { - val row = create_row(1, 2, 3, null) - val c1 = 'a.int.at(0) - val c2 = 'a.int.at(1) - val c3 = 'a.int.at(2) - val c4 = 'a.int.at(3) - - checkEvaluation(UnaryMinus(c1), -1, row) - checkEvaluation(UnaryMinus(Literal.create(100, IntegerType)), -100) - - checkEvaluation(Add(c1, c4), null, row) - checkEvaluation(Add(c1, c2), 3, row) - checkEvaluation(Add(c1, Literal.create(null, IntegerType)), null, row) - checkEvaluation(Add(Literal.create(null, IntegerType), c2), null, row) - checkEvaluation( - Add(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), null, row) - - checkEvaluation(-c1, -1, row) - checkEvaluation(c1 + c2, 3, row) - checkEvaluation(c1 - c2, -1, row) - checkEvaluation(c1 * c2, 2, row) - checkEvaluation(c1 / c2, 0, row) - checkEvaluation(c1 % c2, 1, row) + /** + * Runs through the testFunc for all numeric data types. + * + * @param testFunc a test function that accepts a conversion function to convert an integer + * into another data type. + */ + private def testNumericDataTypes(testFunc: (Int => Any) => Unit): Unit = { + testFunc(_.toByte) + testFunc(_.toShort) + testFunc(identity) + testFunc(_.toLong) + testFunc(_.toFloat) + testFunc(_.toDouble) + testFunc(Decimal(_)) + } + + test("+ (Add)") { + testNumericDataTypes { convert => + val left = Literal(convert(1)) + val right = Literal(convert(2)) + checkEvaluation(Add(left, right), convert(3)) + checkEvaluation(Add(Literal.create(null, left.dataType), right), null) + checkEvaluation(Add(left, Literal.create(null, right.dataType)), null) + } + } + + test("- (UnaryMinus)") { + testNumericDataTypes { convert => + val input = Literal(convert(1)) + val dataType = input.dataType + checkEvaluation(UnaryMinus(input), convert(-1)) + checkEvaluation(UnaryMinus(Literal.create(null, dataType)), null) + } + } + + test("- (Minus)") { + testNumericDataTypes { convert => + val left = Literal(convert(1)) + val right = Literal(convert(2)) + checkEvaluation(Subtract(left, right), convert(-1)) + checkEvaluation(Subtract(Literal.create(null, left.dataType), right), null) + checkEvaluation(Subtract(left, Literal.create(null, right.dataType)), null) + } + } + + test("* (Multiply)") { + testNumericDataTypes { convert => + val left = Literal(convert(1)) + val right = Literal(convert(2)) + checkEvaluation(Multiply(left, right), convert(2)) + checkEvaluation(Multiply(Literal.create(null, left.dataType), right), null) + checkEvaluation(Multiply(left, Literal.create(null, right.dataType)), null) + } + } + + test("/ (Divide) basic") { + testNumericDataTypes { convert => + val left = Literal(convert(2)) + val right = Literal(convert(1)) + val dataType = left.dataType + checkEvaluation(Divide(left, right), convert(2)) + checkEvaluation(Divide(Literal.create(null, dataType), right), null) + checkEvaluation(Divide(left, Literal.create(null, right.dataType)), null) + checkEvaluation(Divide(left, Literal(convert(0))), null) // divide by zero + } } - test("fractional arithmetic") { - val row = create_row(1.1, 2.0, 3.1, null) - val c1 = 'a.double.at(0) - val c2 = 'a.double.at(1) - val c3 = 'a.double.at(2) - val c4 = 'a.double.at(3) - - checkEvaluation(UnaryMinus(c1), -1.1, row) - checkEvaluation(UnaryMinus(Literal.create(100.0, DoubleType)), -100.0) - checkEvaluation(Add(c1, c4), null, row) - checkEvaluation(Add(c1, c2), 3.1, row) - checkEvaluation(Add(c1, Literal.create(null, DoubleType)), null, row) - checkEvaluation(Add(Literal.create(null, DoubleType), c2), null, row) - checkEvaluation( - Add(Literal.create(null, DoubleType), Literal.create(null, DoubleType)), null, row) - - checkEvaluation(-c1, -1.1, row) - checkEvaluation(c1 + c2, 3.1, row) - checkDoubleEvaluation(c1 - c2, (-0.9 +- 0.001), row) - checkDoubleEvaluation(c1 * c2, (2.2 +- 0.001), row) - checkDoubleEvaluation(c1 / c2, (0.55 +- 0.001), row) - checkDoubleEvaluation(c3 % c2, (1.1 +- 0.001), row) + test("/ (Divide) for integral type") { + checkEvaluation(Divide(Literal(1.toByte), Literal(2.toByte)), 0.toByte) + checkEvaluation(Divide(Literal(1.toShort), Literal(2.toShort)), 0.toShort) + checkEvaluation(Divide(Literal(1), Literal(2)), 0) + checkEvaluation(Divide(Literal(1.toLong), Literal(2.toLong)), 0.toLong) + } + + test("/ (Divide) for floating point") { + checkEvaluation(Divide(Literal(1.0f), Literal(2.0f)), 0.5f) + checkEvaluation(Divide(Literal(1.0), Literal(2.0)), 0.5) + checkEvaluation(Divide(Literal(Decimal(1.0)), Literal(Decimal(2.0))), Decimal(0.5)) } test("Abs") { - def testAbs(convert: (Int) => Any): Unit = { + testNumericDataTypes { convert => checkEvaluation(Abs(Literal(convert(0))), convert(0)) checkEvaluation(Abs(Literal(convert(1))), convert(1)) checkEvaluation(Abs(Literal(convert(-1))), convert(1)) } - testAbs(_.toByte) - testAbs(_.toShort) - testAbs(identity) - testAbs(_.toLong) - testAbs(_.toFloat) - testAbs(_.toDouble) - testAbs(Decimal(_)) } test("Divide") { From fb3fe629492ae35804e3a41db9817ae28a81e439 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 14 Jun 2015 01:25:23 -0700 Subject: [PATCH 2/2] Added Remainder. --- .../sql/catalyst/expressions/arithmetic.scala | 19 +++----- .../ArithmeticExpressionSuite.scala | 43 +++++-------------- 2 files changed, 17 insertions(+), 45 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 73d4ad0826c56..9d1e96572a26d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -244,11 +244,8 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } else { s"${eval2.primitive} == 0" } - val method = if (left.dataType.isInstanceOf[DecimalType]) { - s".$decimalMethod" - } else { - s"$symbol" - } + val method = if (left.dataType.isInstanceOf[DecimalType]) s".$decimalMethod" else s" $symbol " + val javaType = ctx.javaType(left.dataType) eval1.code + eval2.code + s""" boolean ${ev.isNull} = false; @@ -256,8 +253,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic if (${eval1.isNull} || ${eval2.isNull} || $test) { ${ev.isNull} = true; } else { - ${ev.primitive} = - (${ctx.javaType(left.dataType)})(${eval1.primitive}$method(${eval2.primitive})); + ${ev.primitive} = ($javaType) (${eval1.primitive}$method(${eval2.primitive})); } """ } @@ -305,11 +301,8 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet } else { s"${eval2.primitive} == 0" } - val method = if (left.dataType.isInstanceOf[DecimalType]) { - s".$decimalMethod" - } else { - s"$symbol" - } + val method = if (left.dataType.isInstanceOf[DecimalType]) s".$decimalMethod" else s" $symbol " + val javaType = ctx.javaType(left.dataType) eval1.code + eval2.code + s""" boolean ${ev.isNull} = false; @@ -317,7 +310,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet if (${eval1.isNull} || ${eval2.isNull} || $test) { ${ev.isNull} = true; } else { - ${ev.primitive} = ${eval1.primitive}$method(${eval2.primitive}); + ${ev.primitive} = ($javaType) (${eval1.primitive}$method(${eval2.primitive})); } """ } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index a56a5ae70a95f..3f4843259e80b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -104,6 +104,17 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Divide(Literal(Decimal(1.0)), Literal(Decimal(2.0))), Decimal(0.5)) } + test("% (Remainder)") { + testNumericDataTypes { convert => + val left = Literal(convert(1)) + val right = Literal(convert(2)) + checkEvaluation(Remainder(left, right), convert(1)) + checkEvaluation(Remainder(Literal.create(null, left.dataType), right), null) + checkEvaluation(Remainder(left, Literal.create(null, right.dataType)), null) + checkEvaluation(Remainder(left, Literal(convert(0))), null) // mod by 0 + } + } + test("Abs") { testNumericDataTypes { convert => checkEvaluation(Abs(Literal(convert(0))), convert(0)) @@ -112,38 +123,6 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper } } - test("Divide") { - checkEvaluation(Divide(Literal(2), Literal(1)), 2) - checkEvaluation(Divide(Literal(1.0), Literal(2.0)), 0.5) - checkEvaluation(Divide(Literal(1), Literal(2)), 0) - checkEvaluation(Divide(Literal(1), Literal(0)), null) - checkEvaluation(Divide(Literal(1.0), Literal(0.0)), null) - checkEvaluation(Divide(Literal(0.0), Literal(0.0)), null) - checkEvaluation(Divide(Literal(0), Literal.create(null, IntegerType)), null) - checkEvaluation(Divide(Literal(1), Literal.create(null, IntegerType)), null) - checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(0)), null) - checkEvaluation(Divide(Literal.create(null, DoubleType), Literal(0.0)), null) - checkEvaluation(Divide(Literal.create(null, IntegerType), Literal(1)), null) - checkEvaluation(Divide(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), - null) - } - - test("Remainder") { - checkEvaluation(Remainder(Literal(2), Literal(1)), 0) - checkEvaluation(Remainder(Literal(1.0), Literal(2.0)), 1.0) - checkEvaluation(Remainder(Literal(1), Literal(2)), 1) - checkEvaluation(Remainder(Literal(1), Literal(0)), null) - checkEvaluation(Remainder(Literal(1.0), Literal(0.0)), null) - checkEvaluation(Remainder(Literal(0.0), Literal(0.0)), null) - checkEvaluation(Remainder(Literal(0), Literal.create(null, IntegerType)), null) - checkEvaluation(Remainder(Literal(1), Literal.create(null, IntegerType)), null) - checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(0)), null) - checkEvaluation(Remainder(Literal.create(null, DoubleType), Literal(0.0)), null) - checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal(1)), null) - checkEvaluation(Remainder(Literal.create(null, IntegerType), Literal.create(null, IntegerType)), - null) - } - test("MaxOf") { checkEvaluation(MaxOf(1, 2), 2) checkEvaluation(MaxOf(2, 1), 2)