diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala index 2ebe724f399a7..409be67f7af4c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFEvaluators.scala @@ -129,7 +129,11 @@ class HiveGenericUDFEvaluator( override def returnType: DataType = inspectorToDataType(returnInspector) def setArg(index: Int, arg: Any): Unit = - deferredObjects(index).asInstanceOf[DeferredObjectAdapter].set(arg) + deferredObjects(index).asInstanceOf[DeferredObjectAdapter].set(() => arg) + + def setException(index: Int, exp: Throwable): Unit = { + deferredObjects(index).asInstanceOf[DeferredObjectAdapter].set(() => throw exp) + } override def doEvaluate(): Any = unwrapper(function.evaluate(deferredObjects)) } @@ -139,10 +143,10 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataTyp extends DeferredObject with HiveInspectors { private val wrapper = wrapperFor(oi, dataType) - private var func: Any = _ - def set(func: Any): Unit = { + private var func: () => Any = _ + def set(func: () => Any): Unit = { this.func = func } override def prepare(i: Int): Unit = {} - override def get(): AnyRef = wrapper(func).asInstanceOf[AnyRef] + override def get(): AnyRef = wrapper(func()).asInstanceOf[AnyRef] } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 6efdb676ccbdc..227c6a618e3d4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -136,7 +136,13 @@ private[hive] case class HiveGenericUDF( override def eval(input: InternalRow): Any = { children.zipWithIndex.foreach { - case (child, idx) => evaluator.setArg(idx, child.eval(input)) + case (child, idx) => + try { + evaluator.setArg(idx, child.eval(input)) + } catch { + case t: Throwable => + evaluator.setException(idx, t) + } } evaluator.evaluate() } @@ -157,10 +163,15 @@ private[hive] case class HiveGenericUDF( val setValues = evals.zipWithIndex.map { case (eval, i) => s""" - |if (${eval.isNull}) { - | $refEvaluator.setArg($i, null); - |} else { - | $refEvaluator.setArg($i, ${eval.value}); + |try { + | ${eval.code} + | if (${eval.isNull}) { + | $refEvaluator.setArg($i, null); + | } else { + | $refEvaluator.setArg($i, ${eval.value}); + | } + |} catch (Throwable t) { + | $refEvaluator.setException($i, t); |} |""".stripMargin } @@ -169,7 +180,6 @@ private[hive] case class HiveGenericUDF( val resultTerm = ctx.freshName("result") ev.copy(code = code""" - |${evals.map(_.code).mkString("\n")} |${setValues.mkString("\n")} |$resultType $resultTerm = ($resultType) $refEvaluator.evaluate(); |boolean ${ev.isNull} = $resultTerm == null; diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFCatchException.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFCatchException.java new file mode 100644 index 0000000000000..242dbeaa63c94 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFCatchException.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution; + +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; + +public class UDFCatchException extends GenericUDF { + + @Override + public ObjectInspector initialize(ObjectInspector[] args) throws UDFArgumentException { + if (args.length != 1) { + throw new UDFArgumentException("Exactly one argument is expected."); + } + return PrimitiveObjectInspectorFactory.javaStringObjectInspector; + } + + @Override + public Object evaluate(GenericUDF.DeferredObject[] args) { + if (args == null) { + return null; + } + try { + return args[0].get(); + } catch (Exception e) { + return null; + } + } + + @Override + public String getDisplayString(String[] children) { + return null; + } +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFThrowException.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFThrowException.java new file mode 100644 index 0000000000000..5d6ff6ca40ae5 --- /dev/null +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/execution/UDFThrowException.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution; + +import org.apache.hadoop.hive.ql.exec.UDF; + +public class UDFThrowException extends UDF { + public String evaluate(String data) { + return Integer.valueOf(data).toString(); + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index d73f2be3b3f50..2e88b13f0963d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -35,6 +35,7 @@ import org.apache.hadoop.io.{LongWritable, Writable} import org.apache.spark.{SparkException, SparkFiles, TestUtils} import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.functions.{call_function, max} @@ -801,6 +802,28 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { } } } + + test("SPARK-48845: GenericUDF catch exceptions from child UDFs") { + withTable("test_catch_exception") { + withUserDefinedFunction("udf_throw_exception" -> true, "udf_catch_exception" -> true) { + Seq("9", "9-1").toDF("a").write.saveAsTable("test_catch_exception") + sql("CREATE TEMPORARY FUNCTION udf_throw_exception AS " + + s"'${classOf[UDFThrowException].getName}'") + sql("CREATE TEMPORARY FUNCTION udf_catch_exception AS " + + s"'${classOf[UDFCatchException].getName}'") + Seq( + CodegenObjectFactoryMode.FALLBACK.toString, + CodegenObjectFactoryMode.NO_CODEGEN.toString + ).foreach { codegenMode => + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode) { + val df = sql( + "SELECT udf_catch_exception(udf_throw_exception(a)) FROM test_catch_exception") + checkAnswer(df, Seq(Row("9"), Row(null))) + } + } + } + } + } } class TestPair(x: Int, y: Int) extends Writable with Serializable {