From 1567b6c63e251c55ccb503ceae298a4db79bae52 Mon Sep 17 00:00:00 2001
From: HyukjinKwon
Date: Thu, 11 Jul 2019 14:42:41 +0900
Subject: [PATCH] 0-args Java UDF should not be called only once

---
 .../scala/org/apache/spark/sql/UDFRegistration.scala   | 10 +++++-----
 .../main/scala/org/apache/spark/sql/functions.scala    | 10 +++++-----
 .../src/test/scala/org/apache/spark/sql/UDFSuite.scala |  9 +++++++++
 3 files changed, 19 insertions(+), 10 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index f0ef6e19b0aa0..bb05c76cfee6d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -142,16 +142,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
       val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs]]"
       val anyParams = (1 to i).map(_ => "_: Any").mkString(", ")
       val version = if (i == 0) "2.3.0" else "1.3.0"
-      val funcCall = if (i == 0) "() => func" else "func"
+      val funcCall = if (i == 0) s"() => f$anyCast.call($anyParams)" else s"f$anyCast.call($anyParams)"
       println(s"""
         |/**
         | * Register a deterministic Java UDF$i instance as user-defined function (UDF).
         | * @since $version
         | */
         |def register(name: String, f: UDF$i[$extTypeArgs], returnType: DataType): Unit = {
-        |  val func = f$anyCast.call($anyParams)
+        |  val func = $funcCall
         |  def builder(e: Seq[Expression]) = if (e.length == $i) {
-        |    ScalaUDF($funcCall, returnType, e, e.map(_ => false), udfName = Some(name))
+        |    ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
         |  } else {
         |    throw new AnalysisException("Invalid number of arguments for function " + name +
         |      ". Expected: $i; Found: " + e.length)
@@ -717,9 +717,9 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    * @since 2.3.0
    */
   def register(name: String, f: UDF0[_], returnType: DataType): Unit = {
-    val func = f.asInstanceOf[UDF0[Any]].call()
+    val func = () => f.asInstanceOf[UDF0[Any]].call()
     def builder(e: Seq[Expression]) = if (e.length == 0) {
-      ScalaUDF(() => func, returnType, e, e.map(_ => false), udfName = Some(name))
+      ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
     } else {
       throw new AnalysisException("Invalid number of arguments for function " + name +
         ". Expected: 0; Found: " + e.length)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 5fa3fd0a37a65..72a197bdbcfc5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3932,7 +3932,7 @@ object functions {
       val anyTypeArgs = (0 to i).map(_ => "Any").mkString(", ")
       val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs]]"
       val anyParams = (1 to i).map(_ => "_: Any").mkString(", ")
-      val funcCall = if (i == 0) "() => func" else "func"
+      val funcCall = if (i == 0) s"() => f$anyCast.call($anyParams)" else s"f$anyCast.call($anyParams)"
       println(s"""
         |/**
         | * Defines a Java UDF$i instance as user-defined function (UDF).
@@ -3944,8 +3944,8 @@ object functions {
         | * @since 2.3.0
         | */
         |def udf(f: UDF$i[$extTypeArgs], returnType: DataType): UserDefinedFunction = {
-        |  val func = f$anyCast.call($anyParams)
-        |  SparkUserDefinedFunction($funcCall, returnType, inputSchemas = Seq.fill($i)(None))
+        |  val func = $funcCall
+        |  SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill($i)(None))
         |}""".stripMargin)
     }
 
@@ -4145,8 +4145,8 @@ object functions {
    * @since 2.3.0
    */
   def udf(f: UDF0[_], returnType: DataType): UserDefinedFunction = {
-    val func = f.asInstanceOf[UDF0[Any]].call()
-    SparkUserDefinedFunction(() => func, returnType, inputSchemas = Seq.fill(0)(None))
+    val func = () => f.asInstanceOf[UDF0[Any]].call()
+    SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(0)(None))
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index f155b5dc80cf1..058c5ba7e50b9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -514,4 +514,13 @@ class UDFSuite extends QueryTest with SharedSQLContext {
       assert(df.collect().toSeq === Seq(Row(expected)))
     }
   }
+
+  test("SPARK-28321 0-args Java UDF should not be called only once") {
+    val nonDeterministicJavaUDF = udf(
+      new UDF0[Int] {
+        override def call(): Int = scala.util.Random.nextInt()
+      }, IntegerType).asNondeterministic()
+
+    assert(spark.range(2).select(nonDeterministicJavaUDF()).distinct().count() == 2)
+  }
 }