Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 53 additions & 9 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.hive.HiveShim._
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -120,19 +121,18 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataTyp
extends DeferredObject with HiveInspectors {

private val wrapper = wrapperFor(oi, dataType)
private var func: () => Any = _

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yaooqinn @dongjoon-hyun This change removes deferred evaluation and means it is no longer possible to implement short-circuiting in Hive generic UDFs. I filed https://issues.apache.org/jira/browse/SPARK-44616 for this.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

During the upgrade from Spark 3.3.1 to 3.5.1, we ran into a behavior change introduced by this PR. DeferredObject now holds an already-evaluated value instead of a function, so a GenericUDF can no longer catch exceptions thrown while its arguments are evaluated, which results in semantic differences.

Here is an example case we encountered. Originally, the semantics were that str_to_map_udf would throw an exception due to issues with the input string, while merge_map_udf could catch the exception and return a null value. However, currently, any exception encountered by str_to_map_udf will cause the program to fail.

select merge_map_udf(str_to_map_udf(col1), parse_map_udf(col2), map("key", "value")) from table

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yaooqinn is it easy to fix? If not we should probably revert it as this is not a critical perf improvement.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the late reply; my network has been glitchy recently. Thanks for reporting this issue. It’s easy to fix.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yaooqinn, is this already underway? I tried a fix locally in #47193.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@panbingkun thank you

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to fix it another way in #47268. @yaooqinn, would you please take a look?

def set(func: () => Any): Unit = {
private var func: Any = _
def set(func: Any): Unit = {
this.func = func
}
override def prepare(i: Int): Unit = {}
override def get(): AnyRef = wrapper(func()).asInstanceOf[AnyRef]
override def get(): AnyRef = wrapper(func).asInstanceOf[AnyRef]
}

private[hive] case class HiveGenericUDF(
name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
extends Expression
with HiveInspectors
with CodegenFallback
with Logging
with UserDefinedExpression {

Expand All @@ -154,18 +154,20 @@ private[hive] case class HiveGenericUDF(
function.initializeAndFoldConstants(argumentInspectors.toArray)
}

// Visible for codegen
@transient
private lazy val unwrapper = unwrapperFor(returnInspector)
lazy val unwrapper: Any => Any = unwrapperFor(returnInspector)

// True only when the UDF class carries a `HiveUDFType` annotation that reports
// `deterministic()` and does not report `stateful()`; a missing annotation yields false.
@transient
private lazy val isUDFDeterministic = {
  Option(function.getClass.getAnnotation(classOf[HiveUDFType]))
    .exists(annotation => annotation.deterministic() && !annotation.stateful())
}

// Visible for codegen
@transient
private lazy val deferredObjects = argumentInspectors.zip(children).map { case (inspect, child) =>
new DeferredObjectAdapter(inspect, child.dataType)
lazy val deferredObjects: Array[DeferredObject] = argumentInspectors.zip(children).map {
case (inspect, child) => new DeferredObjectAdapter(inspect, child.dataType)
}.toArray[DeferredObject]

override lazy val dataType: DataType = inspectorToDataType(returnInspector)
Expand All @@ -178,7 +180,7 @@ private[hive] case class HiveGenericUDF(
while (i < length) {
val idx = i
deferredObjects(i).asInstanceOf[DeferredObjectAdapter]
.set(() => children(idx).eval(input))
.set(children(idx).eval(input))
i += 1
}
unwrapper(function.evaluate(deferredObjects))
Expand All @@ -192,6 +194,48 @@ private[hive] case class HiveGenericUDF(

// Rebuilds this expression around `newChildren`; `name` and `funcWrapper` are carried over
// unchanged by the case-class `copy`.
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
copy(children = newChildren)
/**
 * Generates Java code that evaluates this Hive generic UDF.
 *
 * The emitted code runs every child's generated code, stores each child's value (or `null`
 * when the child is null) into the matching `DeferredObjectAdapter` slot of `deferredObjects`,
 * then calls `function.evaluate` on those slots and passes the result through `unwrapper`.
 *
 * Note: the children's code is emitted before the `try` block, so arguments are evaluated
 * eagerly. An exception thrown while evaluating a child is therefore not wrapped by
 * `failedExecuteUserDefinedFunctionError`, and the UDF itself never gets a chance to handle it.
 *
 * @param ctx codegen context used to register `this` as a reference object and to pick a
 *            fresh name for the boxed result variable
 * @param ev  expression code whose `isNull` and `value` variables the emitted code declares
 * @return `ev` with its `code` replaced by the generated block
 */
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
  val refTerm = ctx.addReferenceObj("this", this)
  val childrenEvals = children.map(_.genCode(ctx))
  // Identical for every argument, so resolve it once instead of once per child.
  val deferredObjectAdapterClz = classOf[DeferredObjectAdapter].getCanonicalName

  // One Java snippet per child that copies the child's result into its deferred-object slot.
  val setDeferredObjects = childrenEvals.zipWithIndex.map {
    case (eval, i) =>
      s"""
         |if (${eval.isNull}) {
         |  (($deferredObjectAdapterClz) $refTerm.deferredObjects()[$i]).set(null);
         |} else {
         |  (($deferredObjectAdapterClz) $refTerm.deferredObjects()[$i]).set(${eval.value});
         |}
         |""".stripMargin
  }

  // The UDF result is held boxed first so that a null return can be detected before it is
  // assigned to the (possibly primitive) `ev.value` variable.
  val resultType = CodeGenerator.boxedType(dataType)
  val resultTerm = ctx.freshName("result")
  ev.copy(code =
    code"""
       |${childrenEvals.map(_.code).mkString("\n")}
       |${setDeferredObjects.mkString("\n")}
       |$resultType $resultTerm = null;
       |boolean ${ev.isNull} = false;
       |try {
       |  $resultTerm = ($resultType) $refTerm.unwrapper().apply(
       |    $refTerm.function().evaluate($refTerm.deferredObjects()));
       |  ${ev.isNull} = $resultTerm == null;
       |} catch (Throwable e) {
       |  throw QueryExecutionErrors.failedExecuteUserDefinedFunctionError(
       |    "${funcWrapper.functionClassName}",
       |    "${children.map(_.dataType.catalogString).mkString(", ")}",
       |    "${dataType.catalogString}",
       |    e);
       |}
       |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
       |if (!${ev.isNull}) {
       |  ${ev.value} = $resultTerm;
       |}
       |""".stripMargin
  )
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@ import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectIns
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.hadoop.io.{LongWritable, Writable}

import org.apache.spark.{SparkFiles, TestUtils}
import org.apache.spark.{SparkException, SparkFiles, TestUtils}
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.functions.max
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -711,6 +712,37 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
}
}
}

// Registers a Hive GenericUDF, checks that the query over it is planned under whole-stage
// codegen, and checks the value it returns.
test("SPARK-42051: HiveGenericUDF Codegen Support") {
  val udfClassName = classOf[GenericUDFMaskHash].getName
  withUserDefinedFunction("CodeGenHiveGenericUDF" -> false) {
    sql(s"CREATE FUNCTION CodeGenHiveGenericUDF AS '$udfClassName'")
    withTable("HiveGenericUDFTable") {
      sql("create table HiveGenericUDFTable as select 'Spark SQL' as v")
      val masked = sql("SELECT CodeGenHiveGenericUDF(v) from HiveGenericUDFTable")
      val executedPlan = masked.queryExecution.executedPlan
      assert(executedPlan.isInstanceOf[WholeStageCodegenExec])
      checkAnswer(masked, Seq(Row("14ab8df5135825bc9f5ff7c30609f02f")))
    }
  }
}

// Runs a Hive GenericUDF that fails on its input and checks that the cause of the thrown
// SparkException is a FAILED_EXECUTE_UDF error carrying the expected parameters.
test("SPARK-42051: HiveGenericUDF Codegen Support w/ execution failure") {
  val udfClassName = classOf[GenericUDFAssertTrue].getName
  withUserDefinedFunction("CodeGenHiveGenericUDF" -> false) {
    sql(s"CREATE FUNCTION CodeGenHiveGenericUDF AS '$udfClassName'")
    withTable("HiveGenericUDFTable") {
      sql("create table HiveGenericUDFTable as select false as v")
      val failing = sql("SELECT CodeGenHiveGenericUDF(v) from HiveGenericUDFTable")
      val thrown = intercept[SparkException] {
        failing.collect()
      }
      val udfError = thrown.getCause.asInstanceOf[SparkException]
      checkError(
        udfError,
        "FAILED_EXECUTE_UDF",
        parameters = Map(
          "functionName" -> udfClassName,
          "signature" -> "boolean",
          "result" -> "void"))
    }
  }
}
}

class TestPair(x: Int, y: Int) extends Writable with Serializable {
Expand Down