Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -142,16 +142,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs]]"
val anyParams = (1 to i).map(_ => "_: Any").mkString(", ")
val version = if (i == 0) "2.3.0" else "1.3.0"
val funcCall = if (i == 0) "() => func" else "func"
val funcCall = if (i == 0) s"() => f$anyCast.call($anyParams)" else s"f$anyCast.call($anyParams)"
println(s"""
|/**
| * Register a deterministic Java UDF$i instance as user-defined function (UDF).
| * @since $version
| */
|def register(name: String, f: UDF$i[$extTypeArgs], returnType: DataType): Unit = {
| val func = f$anyCast.call($anyParams)
| val func = $funcCall
| def builder(e: Seq[Expression]) = if (e.length == $i) {
| ScalaUDF($funcCall, returnType, e, e.map(_ => false), udfName = Some(name))
| ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
| } else {
| throw new AnalysisException("Invalid number of arguments for function " + name +
| ". Expected: $i; Found: " + e.length)
Expand Down Expand Up @@ -717,9 +717,9 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
* @since 2.3.0
*/
def register(name: String, f: UDF0[_], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF0[Any]].call()
val func = () => f.asInstanceOf[UDF0[Any]].call()

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tested this one manually but did not add a test, in case we want to add some kind of optimization in the future. Such an optimization should not be done here but in the optimizer anyway. This seems to have been just a mistake.

def builder(e: Seq[Expression]) = if (e.length == 0) {
ScalaUDF(() => func, returnType, e, e.map(_ => false), udfName = Some(name))
ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 0; Found: " + e.length)
Expand Down
10 changes: 5 additions & 5 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3932,7 +3932,7 @@ object functions {
val anyTypeArgs = (0 to i).map(_ => "Any").mkString(", ")
val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs]]"
val anyParams = (1 to i).map(_ => "_: Any").mkString(", ")
val funcCall = if (i == 0) "() => func" else "func"
val funcCall = if (i == 0) s"() => f$anyCast.call($anyParams)" else s"f$anyCast.call($anyParams)"
println(s"""
|/**
| * Defines a Java UDF$i instance as user-defined function (UDF).
Expand All @@ -3944,8 +3944,8 @@ object functions {
| * @since 2.3.0
| */
|def udf(f: UDF$i[$extTypeArgs], returnType: DataType): UserDefinedFunction = {
| val func = f$anyCast.call($anyParams)
| SparkUserDefinedFunction($funcCall, returnType, inputSchemas = Seq.fill($i)(None))
| val func = $funcCall
| SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill($i)(None))
Comment thread
HyukjinKwon marked this conversation as resolved.
Outdated
|}""".stripMargin)
}

Expand Down Expand Up @@ -4145,8 +4145,8 @@ object functions {
* @since 2.3.0
*/
def udf(f: UDF0[_], returnType: DataType): UserDefinedFunction = {
// NOTE(review): this span comes from a diff view, so each statement appears twice —
// presumably the pre-change line first, then its replacement; confirm against the commit.
// First variant: `call()` is evaluated once, right here, and `func` holds its result.
val func = f.asInstanceOf[UDF0[Any]].call()
// First variant: the closure only returns the value captured above.
SparkUserDefinedFunction(() => func, returnType, inputSchemas = Seq.fill(0)(None))
// Second variant: `func` is itself a closure, so `call()` runs each time `func` is applied.
val func = () => f.asInstanceOf[UDF0[Any]].call()
SparkUserDefinedFunction(func, returnType, inputSchemas = Seq.fill(0)(None))
}

/**
Expand Down
9 changes: 9 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -514,4 +514,13 @@ class UDFSuite extends QueryTest with SharedSQLContext {
assert(df.collect().toSeq === Seq(Row(expected)))
}
}

test("SPARK-28321 0-args Java UDF should not be called only once") {
  // A zero-argument Java UDF that draws a fresh random Int on every call.
  val randomIntSource = new UDF0[Int] {
    override def call(): Int = scala.util.Random.nextInt()
  }
  val randomIntUdf = udf(randomIntSource, IntegerType).asNondeterministic()

  // Apply the UDF to a two-row range and count the distinct values produced;
  // the assertion expects the two rows to yield two different values.
  val distinctResults = spark.range(2).select(randomIntUdf()).distinct().count()
  assert(distinctResults == 2)
}
}