diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala index 81564a44011..7901374f6bf 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala @@ -723,30 +723,29 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { val trimParaSepStr = "\u2029" // Needs to be trimmed for casting to float/double/decimal val trimSpaceStr = ('\u0000' to '\u0020').toList.mkString + // ISOControl characters, refer java.lang.Character.isISOControl(int) + val isoControlStr = (('\u0000' to '\u001F') ++ ('\u007F' to '\u009F')).toList.mkString // scalastyle:on nonascii - c.dataType match { - case BinaryType | _: ArrayType | _: MapType | _: StructType | _: UserDefinedType[_] => - c - case FloatType | DoubleType | _: DecimalType => - c.child.dataType match { - case StringType if GlutenConfig.getConf.castFromVarcharAddTrimNode => - val trimNode = StringTrim(c.child, Some(Literal(trimSpaceStr))) - c.withNewChildren(Seq(trimNode)).asInstanceOf[Cast] - case _ => - c - } - case _ => - c.child.dataType match { - case StringType if GlutenConfig.getConf.castFromVarcharAddTrimNode => - val trimNode = StringTrim( - c.child, - Some( - Literal(trimWhitespaceStr + - trimSpaceSepStr + trimLineSepStr + trimParaSepStr))) - c.withNewChildren(Seq(trimNode)).asInstanceOf[Cast] - case _ => - c + if (GlutenConfig.getConf.castFromVarcharAddTrimNode && c.child.dataType == StringType) { + val trimStr = c.dataType match { + case BinaryType | _: ArrayType | _: MapType | _: StructType | _: UserDefinedType[_] => + None + case FloatType | DoubleType | _: DecimalType => + Some(trimSpaceStr) + case _ => + Some( + (trimWhitespaceStr + trimSpaceSepStr + trimLineSepStr + + trimParaSepStr + isoControlStr).toSet.mkString + ) + } + trimStr + .map { + trim => + c.withNewChildren(Seq(StringTrim(c.child, Some(Literal(trim))))).asInstanceOf[Cast] } + .getOrElse(c) + } else { + c } } diff --git a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenDataFrameSuite.scala b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenDataFrameSuite.scala index 4008f862e17..3b2db7117f4 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenDataFrameSuite.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/spark/sql/GlutenDataFrameSuite.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.sql +import org.apache.gluten.GlutenConfig import org.apache.gluten.execution.{ProjectExecTransformer, WholeStageTransformer} import org.apache.spark.SparkException @@ -323,41 +324,52 @@ class GlutenDataFrameSuite extends DataFrameSuite with GlutenSQLTestsTrait { } testGluten("Allow leading/trailing whitespace in string before casting") { - def checkResult(df: DataFrame, expectedResult: Seq[Row]): Unit = { - checkAnswer(df, expectedResult) - assert(find(df.queryExecution.executedPlan)(_.isInstanceOf[ProjectExecTransformer]).isDefined) - } + withSQLConf(GlutenConfig.CAST_FROM_VARCHAR_ADD_TRIM_NODE.key -> "true") { + def checkResult(df: DataFrame, expectedResult: Seq[Row]): Unit = { + checkAnswer(df, expectedResult) + assert( + find(df.queryExecution.executedPlan)(_.isInstanceOf[ProjectExecTransformer]).isDefined) + } - // scalastyle:off nonascii - Seq(" 123", "123 ", " 123 ", "\u2000123\n\n\n", "123\r\r\r", "123\f\f\f", "123\u000C") - .toDF("col1") - .createOrReplaceTempView("t1") - // scalastyle:on nonascii - val expectedIntResult = Row(123) :: Row(123) :: - Row(123) :: Row(123) :: Row(123) :: Row(123) :: Row(123) :: Nil - var df = spark.sql("select cast(col1 as int) from t1") - checkResult(df, expectedIntResult) - df = spark.sql("select cast(col1 as long) from t1") - checkResult(df, expectedIntResult) - - Seq(" 123.5", "123.5 ", " 123.5 ", "123.5\n\n\n", "123.5\r\r\r", "123.5\f\f\f", "123.5\u000C") - .toDF("col1") - .createOrReplaceTempView("t1") - val expectedFloatResult = Row(123.5) :: Row(123.5) :: - Row(123.5) :: Row(123.5) :: Row(123.5) :: Row(123.5) :: Row(123.5) :: Nil - df = spark.sql("select cast(col1 as float) from t1") - checkResult(df, expectedFloatResult) - df = spark.sql("select cast(col1 as double) from t1") - checkResult(df, expectedFloatResult) - - // scalastyle:off nonascii - val rawData = - Seq(" abc", "abc ", " abc ", "\u2000abc\n\n\n", "abc\r\r\r", "abc\f\f\f", "abc\u000C") - // scalastyle:on nonascii - rawData.toDF("col1").createOrReplaceTempView("t1") - val expectedBinaryResult = rawData.map(d => Row(d.getBytes(StandardCharsets.UTF_8))).seq - df = spark.sql("select cast(col1 as binary) from t1") - checkResult(df, expectedBinaryResult) + // scalastyle:off nonascii + Seq( + " 123", + "123 ", + " 123 ", + "\u2000123\n\n\n", + "123\r\r\r", + "123\f\f\f", + "123\u000C", + "123\u0000") + .toDF("col1") + .createOrReplaceTempView("t1") + // scalastyle:on nonascii + val expectedIntResult = Row(123) :: Row(123) :: + Row(123) :: Row(123) :: Row(123) :: Row(123) :: Row(123) :: Row(123) :: Nil + var df = spark.sql("select cast(col1 as int) from t1") + checkResult(df, expectedIntResult) + df = spark.sql("select cast(col1 as long) from t1") + checkResult(df, expectedIntResult) + + Seq(" 123.5", "123.5 ", " 123.5 ", "123.5\n\n\n", "123.5\r\r\r", "123.5\f\f\f", "123.5\u000C") + .toDF("col1") + .createOrReplaceTempView("t1") + val expectedFloatResult = Row(123.5) :: Row(123.5) :: + Row(123.5) :: Row(123.5) :: Row(123.5) :: Row(123.5) :: Row(123.5) :: Nil + df = spark.sql("select cast(col1 as float) from t1") + checkResult(df, expectedFloatResult) + df = spark.sql("select cast(col1 as double) from t1") + checkResult(df, expectedFloatResult) + + // scalastyle:off nonascii + val rawData = + Seq(" abc", "abc ", " abc ", "\u2000abc\n\n\n", "abc\r\r\r", "abc\f\f\f", "abc\u000C") + // scalastyle:on nonascii + rawData.toDF("col1").createOrReplaceTempView("t1") + val expectedBinaryResult = rawData.map(d => Row(d.getBytes(StandardCharsets.UTF_8))).seq + df = spark.sql("select cast(col1 as binary) from t1") + checkResult(df, expectedBinaryResult) + } } private def withExpr(newExpr: Expression): Column = new Column(newExpr)