From 78c6c13cc88a9f7fdf347e0178834f3293713526 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Fri, 20 Sep 2019 12:54:00 +0500 Subject: [PATCH 1/3] Handle null field --- .../expressions/datetimeExpressions.scala | 57 +++++++++++-------- .../apache/spark/sql/DateFunctionsSuite.scala | 10 ++++ 2 files changed, 43 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 89a6d23b1d73d..d2b7d357b1434 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -1983,29 +1983,33 @@ object DatePart { def parseExtractField( extractField: String, source: Expression, - errorHandleFunc: => Nothing): Expression = extractField.toUpperCase(Locale.ROOT) match { - case "MILLENNIUM" | "MILLENNIA" | "MIL" | "MILS" => Millennium(source) - case "CENTURY" | "CENTURIES" | "C" | "CENT" => Century(source) - case "DECADE" | "DECADES" | "DEC" | "DECS" => Decade(source) - case "YEAR" | "Y" | "YEARS" | "YR" | "YRS" => Year(source) - case "ISOYEAR" => IsoYear(source) - case "QUARTER" | "QTR" => Quarter(source) - case "MONTH" | "MON" | "MONS" | "MONTHS" => Month(source) - case "WEEK" | "W" | "WEEKS" => WeekOfYear(source) - case "DAY" | "D" | "DAYS" => DayOfMonth(source) - case "DAYOFWEEK" => DayOfWeek(source) - case "DOW" => Subtract(DayOfWeek(source), Literal(1)) - case "ISODOW" => Add(WeekDay(source), Literal(1)) - case "DOY" => DayOfYear(source) - case "HOUR" | "H" | "HOURS" | "HR" | "HRS" => Hour(source) - case "MINUTE" | "M" | "MIN" | "MINS" | "MINUTES" => Minute(source) - case "SECOND" | "S" | "SEC" | "SECONDS" | "SECS" => Second(source) - case "MILLISECONDS" | "MSEC" | "MSECS" | "MILLISECON" | "MSECONDS" | "MS" => - Milliseconds(source) - case "MICROSECONDS" | "USEC" | "USECS" | "USECONDS" | "MICROSECON" | "US" => - Microseconds(source) - case "EPOCH" => Epoch(source) - case _ => errorHandleFunc + errorHandleFunc: => Nothing): Expression = { + if (extractField == null) { + Literal(null, DoubleType) + } else extractField.toUpperCase(Locale.ROOT) match { + case "MILLENNIUM" | "MILLENNIA" | "MIL" | "MILS" => Millennium(source) + case "CENTURY" | "CENTURIES" | "C" | "CENT" => Century(source) + case "DECADE" | "DECADES" | "DEC" | "DECS" => Decade(source) + case "YEAR" | "Y" | "YEARS" | "YR" | "YRS" => Year(source) + case "ISOYEAR" => IsoYear(source) + case "QUARTER" | "QTR" => Quarter(source) + case "MONTH" | "MON" | "MONS" | "MONTHS" => Month(source) + case "WEEK" | "W" | "WEEKS" => WeekOfYear(source) + case "DAY" | "D" | "DAYS" => DayOfMonth(source) + case "DAYOFWEEK" => DayOfWeek(source) + case "DOW" => Subtract(DayOfWeek(source), Literal(1)) + case "ISODOW" => Add(WeekDay(source), Literal(1)) + case "DOY" => DayOfYear(source) + case "HOUR" | "H" | "HOURS" | "HR" | "HRS" => Hour(source) + case "MINUTE" | "M" | "MIN" | "MINS" | "MINUTES" => Minute(source) + case "SECOND" | "S" | "SEC" | "SECONDS" | "SECS" => Second(source) + case "MILLISECONDS" | "MSEC" | "MSECS" | "MILLISECON" | "MSECONDS" | "MS" => + Milliseconds(source) + case "MICROSECONDS" | "USEC" | "USECS" | "USECONDS" | "MICROSECON" | "US" => + Microseconds(source) + case "EPOCH" => Epoch(source) + case _ => errorHandleFunc + } } } @@ -2053,7 +2057,12 @@ case class DatePart(field: Expression, source: Expression, child: Expression) if (!field.foldable) { throw new AnalysisException("The field parameter needs to be a foldable string value.") } - val fieldStr = field.eval().asInstanceOf[UTF8String].toString + val fieldEval = field.eval() + val fieldStr = if (fieldEval != null) { + fieldEval.asInstanceOf[UTF8String].toString + } else { + fieldEval.asInstanceOf[String] + } DatePart.parseExtractField(fieldStr, source, { throw new AnalysisException(s"Literals of type '$fieldStr' are currently not supported.") }) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index 2fef05f97e57c..99189a96b2995 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{DoubleType, StructField, StructType} import org.apache.spark.unsafe.types.CalendarInterval class DateFunctionsSuite extends QueryTest with SharedSparkSession { @@ -796,4 +797,13 @@ class DateFunctionsSuite extends QueryTest with SharedSparkSession { Seq(Row(Instant.parse(timestamp)))) } } + + test("handling null field by date_part") { + val input = Seq(Date.valueOf("2019-09-20")).toDF("d") + Seq("date_part(null, d)", "date_part(null, date'2019-09-20')").foreach { expr => + val df = input.selectExpr(expr) + assert(df.schema.headOption.get.dataType == DoubleType) + checkAnswer(df, Row(null)) + } + } } From b111cf21b0c8428c6b8e26fed9f095cdf6b17f24 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Fri, 20 Sep 2019 13:00:00 +0500 Subject: [PATCH 2/3] Handle null field --- .../expressions/datetimeExpressions.scala | 62 +++++++++---------- 1 file changed, 29 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index d2b7d357b1434..592b9de83d9a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -1983,33 +1983,29 @@ object DatePart { def parseExtractField( extractField: String, source: Expression, - errorHandleFunc: => Nothing): Expression = { - if (extractField == null) { - Literal(null, DoubleType) - } else extractField.toUpperCase(Locale.ROOT) match { - case "MILLENNIUM" | "MILLENNIA" | "MIL" | "MILS" => Millennium(source) - case "CENTURY" | "CENTURIES" | "C" | "CENT" => Century(source) - case "DECADE" | "DECADES" | "DEC" | "DECS" => Decade(source) - case "YEAR" | "Y" | "YEARS" | "YR" | "YRS" => Year(source) - case "ISOYEAR" => IsoYear(source) - case "QUARTER" | "QTR" => Quarter(source) - case "MONTH" | "MON" | "MONS" | "MONTHS" => Month(source) - case "WEEK" | "W" | "WEEKS" => WeekOfYear(source) - case "DAY" | "D" | "DAYS" => DayOfMonth(source) - case "DAYOFWEEK" => DayOfWeek(source) - case "DOW" => Subtract(DayOfWeek(source), Literal(1)) - case "ISODOW" => Add(WeekDay(source), Literal(1)) - case "DOY" => DayOfYear(source) - case "HOUR" | "H" | "HOURS" | "HR" | "HRS" => Hour(source) - case "MINUTE" | "M" | "MIN" | "MINS" | "MINUTES" => Minute(source) - case "SECOND" | "S" | "SEC" | "SECONDS" | "SECS" => Second(source) - case "MILLISECONDS" | "MSEC" | "MSECS" | "MILLISECON" | "MSECONDS" | "MS" => - Milliseconds(source) - case "MICROSECONDS" | "USEC" | "USECS" | "USECONDS" | "MICROSECON" | "US" => - Microseconds(source) - case "EPOCH" => Epoch(source) - case _ => errorHandleFunc - } + errorHandleFunc: => Nothing): Expression = extractField.toUpperCase(Locale.ROOT) match { + case "MILLENNIUM" | "MILLENNIA" | "MIL" | "MILS" => Millennium(source) + case "CENTURY" | "CENTURIES" | "C" | "CENT" => Century(source) + case "DECADE" | "DECADES" | "DEC" | "DECS" => Decade(source) + case "YEAR" | "Y" | "YEARS" | "YR" | "YRS" => Year(source) + case "ISOYEAR" => IsoYear(source) + case "QUARTER" | "QTR" => Quarter(source) + case "MONTH" | "MON" | "MONS" | "MONTHS" => Month(source) + case "WEEK" | "W" | "WEEKS" => WeekOfYear(source) + case "DAY" | "D" | "DAYS" => DayOfMonth(source) + case "DAYOFWEEK" => DayOfWeek(source) + case "DOW" => Subtract(DayOfWeek(source), Literal(1)) + case "ISODOW" => Add(WeekDay(source), Literal(1)) + case "DOY" => DayOfYear(source) + case "HOUR" | "H" | "HOURS" | "HR" | "HRS" => Hour(source) + case "MINUTE" | "M" | "MIN" | "MINS" | "MINUTES" => Minute(source) + case "SECOND" | "S" | "SEC" | "SECONDS" | "SECS" => Second(source) + case "MILLISECONDS" | "MSEC" | "MSECS" | "MILLISECON" | "MSECONDS" | "MS" => + Milliseconds(source) + case "MICROSECONDS" | "USEC" | "USECS" | "USECONDS" | "MICROSECON" | "US" => + Microseconds(source) + case "EPOCH" => Epoch(source) + case _ => errorHandleFunc } } @@ -2058,14 +2054,14 @@ case class DatePart(field: Expression, source: Expression, child: Expression) throw new AnalysisException("The field parameter needs to be a foldable string value.") } val fieldEval = field.eval() - val fieldStr = if (fieldEval != null) { - fieldEval.asInstanceOf[UTF8String].toString + if (fieldEval == null) { + Literal(null, DoubleType) } else { - fieldEval.asInstanceOf[String] + val fieldStr = fieldEval.asInstanceOf[UTF8String].toString + DatePart.parseExtractField(fieldStr, source, { + throw new AnalysisException(s"Literals of type '$fieldStr' are currently not supported.") + }) } - DatePart.parseExtractField(fieldStr, source, { - throw new AnalysisException(s"Literals of type '$fieldStr' are currently not supported.") - }) }) } From b39adaaea65a8772f2307eb5858449e0eb8c8358 Mon Sep 17 00:00:00 2001 From: Maxim Gekk Date: Fri, 20 Sep 2019 13:48:57 +0500 Subject: [PATCH 3/3] Add a test to date_part.sql --- .../src/test/resources/sql-tests/inputs/date_part.sql | 2 ++ .../test/resources/sql-tests/results/date_part.sql.out | 10 +++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/date_part.sql b/sql/core/src/test/resources/sql-tests/inputs/date_part.sql index cb3d966281009..fd0fb50f71460 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/date_part.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/date_part.sql @@ -66,3 +66,5 @@ select date_part('secs', c) from t; select date_part('not_supported', c) from t; select date_part(c, c) from t; + +select date_part(null, c) from t; diff --git a/sql/core/src/test/resources/sql-tests/results/date_part.sql.out b/sql/core/src/test/resources/sql-tests/results/date_part.sql.out index c59dfdbd3da34..776786850e9da 100644 --- a/sql/core/src/test/resources/sql-tests/results/date_part.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/date_part.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 51 +-- Number of queries: 52 -- !query 0 @@ -410,3 +410,11 @@ struct<> -- !query 50 output org.apache.spark.sql.AnalysisException The field parameter needs to be a foldable string value.;; line 1 pos 7 + + +-- !query 51 +select date_part(null, c) from t +-- !query 51 schema +struct +-- !query 51 output +NULL