diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index b05f4f61f6a3e..ba5f71d1ec2a0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -351,6 +351,7 @@ object FunctionRegistry {
     // collection functions
     expression[CreateArray]("array"),
     expression[ArrayContains]("array_contains"),
+    expression[ArrayContainsWithPatternMatch]("array_contains_with_pattern_match"),
     expression[CreateMap]("map"),
     expression[CreateNamedStruct]("named_struct"),
     expression[MapKeys]("map_keys"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index c0200299376ca..dcd16e7eb1e22 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -17,12 +17,14 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import java.util.Comparator
+import java.util.regex.Pattern
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData}
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 /**
  * Given an array or map, returns its size. Returns -1 if null.
@@ -261,3 +263,153 @@ case class ArrayContains(left: Expression, right: Expression)
 
   override def prettyName: String = "array_contains"
 }
+
+/**
+ * Returns true if the array contains the value. For string arrays every non-null element is
+ * compiled as a Java regular expression and the value has to match one of them in full.
+ */
+@ExpressionDescription(
+  usage = """_FUNC_(array, value) - Returns TRUE if the array contains the value or,
+             for string arrays, if the value fully matches any pattern in the array.""",
+  extended = """ > SELECT _FUNC_(array("\\d\\s\\d", "2", "3"), "1 5");
+ true""")
+case class ArrayContainsWithPatternMatch(left: Expression, right: Expression)
+  extends BinaryExpression with ImplicitCastInputTypes {
+
+  override def dataType: DataType = BooleanType
+
+  override def inputTypes: Seq[AbstractDataType] = right.dataType match {
+    case NullType => Seq()
+    case _ => left.dataType match {
+      case n @ ArrayType(element, _) => Seq(n, element)
+      case _ => Seq()
+    }
+  }
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (right.dataType == NullType) {
+      TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as arguments")
+    } else if (!left.dataType.isInstanceOf[ArrayType]
+      || left.dataType.asInstanceOf[ArrayType].elementType != right.dataType) {
+      TypeCheckResult.TypeCheckFailure(
+        "Arguments must be an array followed by a value of same type as the array members")
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
+  override def nullable: Boolean = {
+    left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull
+  }
+
+  // Last pattern array seen; the compiled patterns are rebuilt only when it changes.
+  @transient private var lastRegexArray: ArrayData = _
+  // Compiled form of lastRegexArray (null elements stay null), cached for performance.
+  @transient private var patternArray: Array[Pattern] = _
+
+  override def nullSafeEval(arrAny: Any, value: Any): Any = {
+    val arr = arrAny.asInstanceOf[ArrayData]
+    val numElements = arr.numElements()
+    var hasNull = false
+    var found = false
+    var i = 0
+    if (right.dataType == StringType) {
+      if (!arr.equals(lastRegexArray)) {
+        // Recompile only when the pattern array differs from the cached one.
+        lastRegexArray = arr.copy()
+        patternArray = Array.tabulate(numElements) { idx =>
+          if (lastRegexArray.isNullAt(idx)) {
+            null
+          } else {
+            Pattern.compile(lastRegexArray.getUTF8String(idx).toString)
+          }
+        }
+      }
+      val input = value.asInstanceOf[UTF8String].toString
+      while (!found && i < numElements) {
+        val pattern = patternArray(i)
+        if (pattern == null) {
+          hasNull = true
+        } else if (pattern.matcher(input).matches()) {
+          // matches() demands that the whole input matches, so no ^...$ anchors are needed.
+          found = true
+        }
+        i += 1
+      }
+    } else {
+      while (!found && i < numElements) {
+        if (arr.isNullAt(i)) {
+          hasNull = true
+        } else if (arr.get(i, right.dataType) == value) {
+          found = true
+        }
+        i += 1
+      }
+    }
+
+    // A hit wins over nulls; without a hit, a null element makes the answer unknown.
+    if (found) true else if (hasNull) null else false
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val isPatternMatch = right.dataType == StringType
+    val termLastRegexArray = ctx.freshName("lastRegexArray")
+    val termPatternArray = ctx.freshName("patternArray")
+    val patternClassName = classOf[Pattern].getCanonicalName
+
+    if (isPatternMatch) {
+      // The generated class keeps the same pattern cache as the interpreted path.
+      ctx.addMutableState(classOf[ArrayData].getCanonicalName, termLastRegexArray,
+        s"$termLastRegexArray = null;")
+      ctx.addMutableState(s"$patternClassName[]", termPatternArray,
+        s"$termPatternArray = null;")
+    }
+
+    nullSafeCodeGen(ctx, ev, (arr, value) => {
+      val i = ctx.freshName("i")
+      val element = ctx.getValue(arr, right.dataType, i)
+      val refreshPatterns = if (isPatternMatch) {
+        s"""
+          if (!$arr.equals($termLastRegexArray)) {
+            // The pattern array changed, so recompile the cached patterns.
+            $termPatternArray = new $patternClassName[$arr.numElements()];
+            $termLastRegexArray = $arr.copy();
+            for (int $i = 0; $i < $arr.numElements(); $i ++) {
+              if ($arr.isNullAt($i)) {
+                $termPatternArray[$i] = null;
+              } else {
+                $termPatternArray[$i] = $patternClassName.compile($element.toString());
+              }
+            }
+          }"""
+      } else {
+        ""
+      }
+      // String arrays are probed through the compiled pattern, other types by value.
+      val candidate = if (isPatternMatch) s"$termPatternArray[$i]" else element
+      s"""
+        $refreshPatterns
+        for (int $i = 0; $i < $arr.numElements(); $i ++) {
+          if ($arr.isNullAt($i)) {
+            ${ev.isNull} = true;
+          } else if (${genEqual(ctx, ev, right.dataType, value, candidate)}) {
+            ${ev.isNull} = false;
+            ${ev.value} = true;
+            break;
+          }
+        }"""
+    })
+  }
+
+  /**
+   * Builds the Java test of one candidate against the probed value: a full regex match of
+   * `c1` against the compiled pattern `c2` for strings, plain equality for other types.
+   */
+  def genEqual(ctx: CodegenContext, ev: ExprCode, dataType: DataType,
+      c1: String, c2: String): String = dataType match {
+    case StringType => s"$c2.matcher($c1.toString()).matches()"
+    case _ => ctx.genEqual(dataType, c1, c2)
+  }
+
+  override def prettyName: String = "array_contains_with_pattern_match"
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
index c76dad208ea1e..c3090de116a09 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
@@ -106,4 +106,43 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(ArrayContains(a3, Literal("")), null)
     checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
   }
+
+  test("Array contains with pattern match") {
+    val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
+    val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType))
+    val a2 = Literal.create(Seq(null), ArrayType(LongType))
+    val a3 = Literal.create(null, ArrayType(StringType))
+    val a4 = Literal.create(Seq[String]("\\d\\d\\s-\\s\\d\\d", null, "", "pattern"),
+      ArrayType(StringType))
+    val a5 = Literal.create(Seq[String]("ab|cd"), ArrayType(StringType))
+
+    checkEvaluation(ArrayContainsWithPatternMatch(a0, Literal(0)), false)
+    checkEvaluation(ArrayContainsWithPatternMatch(a0, Literal.create(null, IntegerType)), null)
+
+    checkEvaluation(ArrayContainsWithPatternMatch(a1, Literal("")), true)
+    checkEvaluation(ArrayContainsWithPatternMatch(a1, Literal("a")), null)
+    checkEvaluation(ArrayContainsWithPatternMatch(a1, Literal.create(null, StringType)), null)
+
+    checkEvaluation(ArrayContainsWithPatternMatch(a2, Literal(1L)), null)
+    checkEvaluation(ArrayContainsWithPatternMatch(a2, Literal.create(null, LongType)), null)
+
+    checkEvaluation(ArrayContainsWithPatternMatch(a3, Literal("")), null)
+    checkEvaluation(ArrayContainsWithPatternMatch(a3, Literal.create(null, StringType)), null)
+
+    checkEvaluation(ArrayContainsWithPatternMatch(a4, Literal.create(null, StringType)), null)
+    checkEvaluation(ArrayContainsWithPatternMatch(a4, Literal.create("", StringType)), true)
+    checkEvaluation(ArrayContainsWithPatternMatch(a4, Literal.create("12 - 20", StringType)), true)
+    checkEvaluation(ArrayContainsWithPatternMatch(a4, Literal.create("pat", StringType)), null)
+    checkEvaluation(ArrayContainsWithPatternMatch(a4, Literal.create("pattern", StringType)), true)
+    checkEvaluation(ArrayContainsWithPatternMatch(a4, Literal.create("ab - cd", StringType)), null)
+    checkEvaluation(ArrayContainsWithPatternMatch(a4, Literal.create(" 12 - 20 ", StringType)),
+      null)
+    checkEvaluation(ArrayContainsWithPatternMatch(a4, Literal.create("132 - 20", StringType)), null)
+
+    // Alternations must match the whole value, not merely a prefix or a suffix of it.
+    checkEvaluation(ArrayContainsWithPatternMatch(a5, Literal("cd")), true)
+    checkEvaluation(ArrayContainsWithPatternMatch(a5, Literal("abx")), false)
+    checkEvaluation(ArrayContainsWithPatternMatch(a5, Literal("xcd")), false)
+  }
 }
+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 47bf41a2da813..cae69ceb135ac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2779,6 +2779,16 @@ object functions {
     ArrayContains(column.expr, Literal(value))
   }
 
+  /**
+   * Returns true if the array contains `value`. For an array of strings, the elements are
+   * used as regular expressions and `value` has to fully match at least one of them.
+   * @group collection_funcs
+   * @since 2.0
+   */
+  def array_contains_with_pattern_match(column: Column, value: Any): Column = withExpr {
+    ArrayContainsWithPatternMatch(column.expr, Literal(value))
+  }
+
   /**
    * Creates a new row for each element in the given array or map column.
    *
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 45db61515e9b6..210aefcc4a7f8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -406,4 +406,100 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
       Seq(Row(true), Row(true))
     )
   }
+
+  test("array contains with pattern match function") {
+    val df1 = Seq(
+      (Seq[Int](1, 2), "x"),
+      (Seq[Int](), "x")
+    ).toDF("a", "b")
+
+    // Simple test cases
+    checkAnswer(
+      df1.select(array_contains_with_pattern_match(df1("a"), 1)),
+      Seq(Row(true), Row(false))
+    )
+    checkAnswer(
+      df1.selectExpr("array_contains_with_pattern_match(a, 1)"),
+      Seq(Row(true), Row(false))
+    )
+
+    // In hive, this errors because null has no type information
+    intercept[AnalysisException] {
+      df1.select(array_contains_with_pattern_match(df1("a"), null))
+    }
+    intercept[AnalysisException] {
+      df1.selectExpr("array_contains_with_pattern_match(a, null)")
+    }
+    intercept[AnalysisException] {
+      df1.selectExpr("array_contains_with_pattern_match(null, 1)")
+    }
+
+    checkAnswer(
+      df1.selectExpr("array_contains_with_pattern_match(array(array(1), null)[0], 1)"),
+      Seq(Row(true), Row(true))
+    )
+    checkAnswer(
+      df1.selectExpr("array_contains_with_pattern_match(array(1, null), array(1, null)[0])"),
+      Seq(Row(true), Row(true))
+    )
+    val df2 = Seq(
+      (Seq[String]("1", "2"), "x"),
+      (Seq[String](), "x"),
+      (Seq[String]("\\d\\s-\\s\\d", "pattern", ""), "x")
+    ).toDF("a", "b")
+
+    // Simple test cases
+    checkAnswer(
+      df2.select(array_contains_with_pattern_match(df2("a"), "1")),
+      Seq(Row(true), Row(false), Row(false))
+    )
+    checkAnswer(
+      df2.selectExpr("""array_contains_with_pattern_match(a, "1")"""),
+      Seq(Row(true), Row(false), Row(false))
+    )
+    checkAnswer(
+      df2.select(array_contains_with_pattern_match(df2("a"), "1 - 2")),
+      Seq(Row(false), Row(false), Row(true))
+    )
+    checkAnswer(
+      df2.selectExpr("""array_contains_with_pattern_match(a, "3 - 4")"""),
+      Seq(Row(false), Row(false), Row(true))
+    )
+
+    // In hive, this errors because null has no type information
+    intercept[AnalysisException] {
+      df2.select(array_contains_with_pattern_match(df2("a"), null))
+    }
+    intercept[AnalysisException] {
+      df2.selectExpr("array_contains_with_pattern_match(a, null)")
+    }
+    intercept[AnalysisException] {
+      df2.selectExpr("array_contains_with_pattern_match(null, 1)")
+    }
+
+    checkAnswer(
+      df2.selectExpr("array_contains_with_pattern_match(array(array(1), null)[0], 1)"),
+      Seq(Row(true), Row(true), Row(true))
+    )
+
+    checkAnswer(
+      df2.selectExpr("array_contains_with_pattern_match(array(1, null), array(1, null)[0])"),
+      Seq(Row(true), Row(true), Row(true))
+    )
+
+    checkAnswer(
+      df2.selectExpr(
+        """array_contains_with_pattern_match(array(array("\\d\\s\\d"), null)[0],
+          "1 3")""".stripMargin),
+      Seq(Row(true), Row(true), Row(true))
+    )
+
+    checkAnswer(
+      df2.selectExpr(
+        """array_contains_with_pattern_match(array("\\d\\s\\d", null),
+          array("1 3", null)[0])""".stripMargin),
+      Seq(Row(true), Row(true), Row(true))
+    )
+
+  }
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala
index fdd02821dfa29..151e5a5a125dd 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala
@@ -261,6 +261,8 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils {
 
   test("collection functions") {
     checkSqlGeneration("SELECT array_contains(array(2, 9, 8), 9)")
+    checkSqlGeneration("""SELECT array_contains_with_pattern_match(array("\\d\\s\\d", "pattern",
+      "8"), "9 8")""".stripMargin)
     checkSqlGeneration("SELECT size(array('b', 'd', 'c', 'a'))")
     checkSqlGeneration("SELECT sort_array(array('b', 'd', 'c', 'a'))")
   }