From 69f63092ddcb04ea10186d71d407b4886dea3245 Mon Sep 17 00:00:00 2001 From: Priyanka Garg Date: Thu, 29 Sep 2016 13:43:08 +0530 Subject: [PATCH 1/4] [SPARK-17619][SQL] To add support for pattern matching in ArrayContains Expression. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? This change adds support for pattern matching in the arrayContains expression for string arrays. For example: a. arrayContains ( Seq ( "\\d\\d\\s-\\s\\d\\d", null, "", "pattern"), "12 - 20" ) returns true b. arrayContains ( Seq ( "\\d\\d\\s-\\s\\d\\d", "", "pattern"), "132 - 20" ) returns false c. arrayContains ( Seq ( "\\d\\d\\s-\\s\\d\\d", null, "", "pattern"), "132 - 20" ) returns null This change is completely backward compatible. ## How was this patch tested? Added more test cases for the pattern-match use case in the following suites: a. CollectionFunctionsSuite.scala b. DataFrameFunctionsSuite.scala c. ExpressionToSQLSuite.scala See the JIRA entry for details: https://issues.apache.org/jira/browse/SPARK-17619 --- .../expressions/collectionOperations.scala | 113 ++++++++++++++---- .../CollectionFunctionsSuite.scala | 12 +- .../org/apache/spark/sql/functions.scala | 3 +- .../spark/sql/DataFrameFunctionsSuite.scala | 76 ++++++++++-- .../sql/catalyst/ExpressionToSQLSuite.scala | 2 + 5 files changed, 173 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c0200299376ca..a4f5fc44800b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Comparator +import java.util.regex.Pattern 
import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Given an array or map, returns its size. Returns -1 if null. @@ -191,11 +193,15 @@ case class SortArray(base: Expression, ascendingOrder: Expression) } /** - * Checks if the array (left) has the element (right) + * Checks if the array (left) has the element (right) and pattern match in + * case left is Array of type string */ + @ExpressionDescription( - usage = "_FUNC_(array, value) - Returns TRUE if the array contains the value.", - extended = " > SELECT _FUNC_(array(1, 2, 3), 2);\n true") + usage = """_FUNC_(array, value) - Returns TRUE if the array contains the value or + for string arrays, if string matches with the any pattern in the array. + This is complete word match""", + extended = """ > SELECT _FUNC_(array("\\d\\s\\d", "2", "3"), "1 5");\n true""") case class ArrayContains(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -225,38 +231,101 @@ case class ArrayContains(left: Expression, right: Expression) left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull } - override def nullSafeEval(arr: Any, value: Any): Any = { + // last regex in string, we will update the pattern iff regexp value changed. 
+ @transient private var lastRegexArray: ArrayData = _ + // last regex pattern, we cache it for performance concern + @transient private var patternArray: Array[Pattern] = _ + + + override def nullSafeEval(arrAny: Any, value: Any): Any = { + val arr = arrAny.asInstanceOf[ArrayData] var hasNull = false - arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => - if (v == null) { + if (right.dataType == StringType) { + if (!arr.equals(lastRegexArray)) { + lastRegexArray = arr.copy() + patternArray = new Array[Pattern](arr.numElements()) + lastRegexArray.foreach(StringType, (i : Int, str : Any) => if (str == null) { + patternArray(i) = null + } else { + patternArray(i) = Pattern.compile("^".concat(str.toString).concat("$")) + }) + } + patternArray.foreach(v => if (v == null) { hasNull = true - } else if (v == value) { + false + } else if (v.matcher(value.asInstanceOf[UTF8String].toString).find()) { return true - } - ) + }) + } else { + arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => + if (v == null) { + hasNull = true + } else if (v == value) { + return true + } + ) + } + if (hasNull) { null } else { false } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + + val termLastRegexArray = ctx.freshName("lastRegexArray") + val termPatternArray = ctx.freshName("patternArray") + val patternClassNamePattern = classOf[Pattern].getCanonicalName.stripSuffix("[]") + val arrayDataClassNamePattern = classOf[ArrayData].getCanonicalName.stripSuffix("[]") + + ctx.addMutableState(s"$arrayDataClassNamePattern", termLastRegexArray, + s"${termLastRegexArray} = null;") + ctx.addMutableState(s"$patternClassNamePattern[]", termPatternArray, + s"${termPatternArray} = null;") + nullSafeCodeGen(ctx, ev, (arr, value) => { val i = ctx.freshName("i") - val getValue = ctx.getValue(arr, right.dataType, i) - s""" - for (int $i = 0; $i < $arr.numElements(); $i ++) { - if ($arr.isNullAt($i)) { - ${ev.isNull} = true; - } else if 
(${ctx.genEqual(right.dataType, value, getValue)}) { - ${ev.isNull} = false; - ${ev.value} = true; - break; + var getValue = ctx.getValue(arr, right.dataType, i) + val code = if (right.dataType == StringType) { + s""" + if (!$arr.equals(${termLastRegexArray})) { + // regex Array value changed + ${termPatternArray} = new ${patternClassNamePattern}[$arr.numElements()]; + ${termLastRegexArray} = $arr.copy(); + for (int $i = 0; $i < $arr.numElements(); $i ++) { + if ($arr.isNullAt($i)) { + ${termPatternArray}[$i] = null; + } else { + ${termPatternArray}[$i] = ${patternClassNamePattern}.compile( + "^".concat(${getValue}.toString()).concat("$$")); + } + } + }""".stripMargin + } else "" + val k = { + if (right.dataType == StringType) { + getValue = s"${termPatternArray}[$i]" } - } - """ - }) + s""" + for (int $i = 0; $i < $arr.numElements(); $i ++) { + if ($arr.isNullAt($i)) { + ${ev.isNull} = true; + } else if (${genEqual(ctx, ev, right.dataType, value, getValue)}) { + ${ev.isNull} = false; + ${ev.value} = true; + break; + } + }""".stripMargin } + code + k + } + ) + } + + def genEqual(ctx: CodegenContext, ev: ExprCode, dataType: DataType, + c1: String, c2: String): String = dataType match { + case StringType => s"${c2}.matcher($c1.toString()).find()".stripMargin + case _ => ctx.genEqual(dataType, c1, c2) } override def prettyName: String = "array_contains" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala index c76dad208ea1e..843ad88126fe7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala @@ -91,8 +91,9 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { val a1 = Literal.create(Seq[String](null, ""), 
ArrayType(StringType)) val a2 = Literal.create(Seq(null), ArrayType(LongType)) val a3 = Literal.create(null, ArrayType(StringType)) + val a4 = Literal.create(Seq[String]("\\d\\d\\s-\\s\\d\\d", null, "", "pattern"), + ArrayType(StringType)) - checkEvaluation(ArrayContains(a0, Literal(1)), true) checkEvaluation(ArrayContains(a0, Literal(0)), false) checkEvaluation(ArrayContains(a0, Literal.create(null, IntegerType)), null) @@ -105,5 +106,14 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(ArrayContains(a3, Literal("")), null) checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) + + checkEvaluation(ArrayContains(a4, Literal.create(null, StringType)), null) + checkEvaluation(ArrayContains(a4, Literal.create("", StringType)), true) + checkEvaluation(ArrayContains(a4, Literal.create("12 - 20", StringType)), true) + checkEvaluation(ArrayContains(a4, Literal.create("pat", StringType)), null) + checkEvaluation(ArrayContains(a4, Literal.create("pattern", StringType)), true) + checkEvaluation(ArrayContains(a4, Literal.create("ab - cd", StringType)), null) + checkEvaluation(ArrayContains(a4, Literal.create(" 12 - 20 ", StringType)), null) + checkEvaluation(ArrayContains(a4, Literal.create("132 - 20", StringType)), null) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 47bf41a2da813..43e61d9b64a77 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2771,7 +2771,8 @@ object functions { ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Returns true if the array contains `value` + * Returns true if the array contains `value` or `value` match any + * of the pattern available in array if array is of type string. 
* @group collection_funcs * @since 1.5.0 */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 45db61515e9b6..0d43f39eaef42 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -371,39 +371,97 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } test("array contains function") { - val df = Seq( + val df1 = Seq( (Seq[Int](1, 2), "x"), (Seq[Int](), "x") ).toDF("a", "b") // Simple test cases checkAnswer( - df.select(array_contains(df("a"), 1)), + df1.select(array_contains(df1("a"), 1)), Seq(Row(true), Row(false)) ) checkAnswer( - df.selectExpr("array_contains(a, 1)"), + df1.selectExpr("array_contains(a, 1)"), Seq(Row(true), Row(false)) ) // In hive, this errors because null has no type information intercept[AnalysisException] { - df.select(array_contains(df("a"), null)) + df1.select(array_contains(df1("a"), null)) } intercept[AnalysisException] { - df.selectExpr("array_contains(a, null)") + df1.selectExpr("array_contains(a, null)") } intercept[AnalysisException] { - df.selectExpr("array_contains(null, 1)") + df1.selectExpr("array_contains(null, 1)") } checkAnswer( - df.selectExpr("array_contains(array(array(1), null)[0], 1)"), + df1.selectExpr("array_contains(array(array(1), null)[0], 1)"), Seq(Row(true), Row(true)) ) checkAnswer( - df.selectExpr("array_contains(array(1, null), array(1, null)[0])"), + df1.selectExpr("array_contains(array(1, null), array(1, null)[0])"), Seq(Row(true), Row(true)) ) + val df2 = Seq( + (Seq[String]("1", "2"), "x"), + (Seq[String](), "x"), + (Seq[String]("\\d\\s-\\s\\d", "pattern", ""), "x") + ).toDF("a", "b") + + // Simple test cases + checkAnswer( + df2.select(array_contains(df2("a"), "1")), + Seq(Row(true), Row(false), Row(false)) + ) + checkAnswer( + 
df2.selectExpr("""array_contains(a, "1")"""), + Seq(Row(true), Row(false), Row(false)) + ) + checkAnswer( + df2.select(array_contains(df2("a"), "1 - 2")), + Seq(Row(false), Row(false), Row(true)) + ) + checkAnswer( + df2.selectExpr("""array_contains(a, "3 - 4")"""), + Seq(Row(false), Row(false), Row(true)) + ) + + // In hive, this errors because null has no type information + intercept[AnalysisException] { + df2.select(array_contains(df2("a"), null)) + } + intercept[AnalysisException] { + df2.selectExpr("array_contains(a, null)") + } + intercept[AnalysisException] { + df2.selectExpr("array_contains(null, 1)") + } + + checkAnswer( + df2.selectExpr("array_contains(array(array(1), null)[0], 1)"), + Seq(Row(true), Row(true), Row(true)) + ) + + checkAnswer( + df2.selectExpr("array_contains(array(1, null), array(1, null)[0])"), + Seq(Row(true), Row(true), Row(true)) + ) + + checkAnswer( + df2.selectExpr( + """array_contains(array(array("\\d\\s\\d"), null)[0], + "1 3")""".stripMargin), + Seq(Row(true), Row(true), Row(true)) + ) + + checkAnswer( + df2.selectExpr( + """array_contains(array("\\d\\s\\d", null), array("1 3", null)[0])""".stripMargin), + Seq(Row(true), Row(true), Row(true)) + ) + } -} +} \ No newline at end of file diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala index fdd02821dfa29..c95e8e475dbc4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala @@ -261,6 +261,8 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils { test("collection functions") { checkSqlGeneration("SELECT array_contains(array(2, 9, 8), 9)") + checkSqlGeneration("""SELECT array_contains(array("\\d\\s\\d", "pattern", + "8"), "9 8")""".stripMargin) checkSqlGeneration("SELECT size(array('b', 'd', 'c', 'a'))") 
checkSqlGeneration("SELECT sort_array(array('b', 'd', 'c', 'a'))") } From 3e849e0217bcd0b1cf7a7a95f3fc64e79ade4ab7 Mon Sep 17 00:00:00 2001 From: prigarg Date: Fri, 30 Sep 2016 12:12:53 +0530 Subject: [PATCH 2/4] Revert "[SPARK-17619][SQL] To add support for pattern matching in ArrayContains Expression." This reverts commit 69f63092ddcb04ea10186d71d407b4886dea3245. --- .../expressions/collectionOperations.scala | 113 ++++-------------- .../CollectionFunctionsSuite.scala | 12 +- .../org/apache/spark/sql/functions.scala | 3 +- .../spark/sql/DataFrameFunctionsSuite.scala | 76 ++---------- .../sql/catalyst/ExpressionToSQLSuite.scala | 2 - 5 files changed, 33 insertions(+), 173 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index a4f5fc44800b3..c0200299376ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -17,14 +17,12 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Comparator -import java.util.regex.Pattern import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String /** * Given an array or map, returns its size. Returns -1 if null. 
@@ -193,15 +191,11 @@ case class SortArray(base: Expression, ascendingOrder: Expression) } /** - * Checks if the array (left) has the element (right) and pattern match in - * case left is Array of type string + * Checks if the array (left) has the element (right) */ - @ExpressionDescription( - usage = """_FUNC_(array, value) - Returns TRUE if the array contains the value or - for string arrays, if string matches with the any pattern in the array. - This is complete word match""", - extended = """ > SELECT _FUNC_(array("\\d\\s\\d", "2", "3"), "1 5");\n true""") + usage = "_FUNC_(array, value) - Returns TRUE if the array contains the value.", + extended = " > SELECT _FUNC_(array(1, 2, 3), 2);\n true") case class ArrayContains(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -231,101 +225,38 @@ case class ArrayContains(left: Expression, right: Expression) left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull } - // last regex in string, we will update the pattern iff regexp value changed. 
- @transient private var lastRegexArray: ArrayData = _ - // last regex pattern, we cache it for performance concern - @transient private var patternArray: Array[Pattern] = _ - - - override def nullSafeEval(arrAny: Any, value: Any): Any = { - val arr = arrAny.asInstanceOf[ArrayData] + override def nullSafeEval(arr: Any, value: Any): Any = { var hasNull = false - if (right.dataType == StringType) { - if (!arr.equals(lastRegexArray)) { - lastRegexArray = arr.copy() - patternArray = new Array[Pattern](arr.numElements()) - lastRegexArray.foreach(StringType, (i : Int, str : Any) => if (str == null) { - patternArray(i) = null - } else { - patternArray(i) = Pattern.compile("^".concat(str.toString).concat("$")) - }) - } - patternArray.foreach(v => if (v == null) { + arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => + if (v == null) { hasNull = true - false - } else if (v.matcher(value.asInstanceOf[UTF8String].toString).find()) { + } else if (v == value) { return true - }) - } else { - arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => - if (v == null) { - hasNull = true - } else if (v == value) { - return true - } - ) - } - + } + ) if (hasNull) { null } else { false } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - - val termLastRegexArray = ctx.freshName("lastRegexArray") - val termPatternArray = ctx.freshName("patternArray") - val patternClassNamePattern = classOf[Pattern].getCanonicalName.stripSuffix("[]") - val arrayDataClassNamePattern = classOf[ArrayData].getCanonicalName.stripSuffix("[]") - - ctx.addMutableState(s"$arrayDataClassNamePattern", termLastRegexArray, - s"${termLastRegexArray} = null;") - ctx.addMutableState(s"$patternClassNamePattern[]", termPatternArray, - s"${termPatternArray} = null;") + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (arr, value) => { val i = ctx.freshName("i") - var getValue = ctx.getValue(arr, right.dataType, i) - val code = if 
(right.dataType == StringType) { - s""" - if (!$arr.equals(${termLastRegexArray})) { - // regex Array value changed - ${termPatternArray} = new ${patternClassNamePattern}[$arr.numElements()]; - ${termLastRegexArray} = $arr.copy(); - for (int $i = 0; $i < $arr.numElements(); $i ++) { - if ($arr.isNullAt($i)) { - ${termPatternArray}[$i] = null; - } else { - ${termPatternArray}[$i] = ${patternClassNamePattern}.compile( - "^".concat(${getValue}.toString()).concat("$$")); - } - } - }""".stripMargin - } else "" - val k = { - if (right.dataType == StringType) { - getValue = s"${termPatternArray}[$i]" + val getValue = ctx.getValue(arr, right.dataType, i) + s""" + for (int $i = 0; $i < $arr.numElements(); $i ++) { + if ($arr.isNullAt($i)) { + ${ev.isNull} = true; + } else if (${ctx.genEqual(right.dataType, value, getValue)}) { + ${ev.isNull} = false; + ${ev.value} = true; + break; } - s""" - for (int $i = 0; $i < $arr.numElements(); $i ++) { - if ($arr.isNullAt($i)) { - ${ev.isNull} = true; - } else if (${genEqual(ctx, ev, right.dataType, value, getValue)}) { - ${ev.isNull} = false; - ${ev.value} = true; - break; - } - }""".stripMargin } - code + k - } - ) - } - - def genEqual(ctx: CodegenContext, ev: ExprCode, dataType: DataType, - c1: String, c2: String): String = dataType match { - case StringType => s"${c2}.matcher($c1.toString()).find()".stripMargin - case _ => ctx.genEqual(dataType, c1, c2) + } + """ + }) } override def prettyName: String = "array_contains" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala index 843ad88126fe7..c76dad208ea1e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala @@ -91,9 +91,8 @@ class 
CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) val a2 = Literal.create(Seq(null), ArrayType(LongType)) val a3 = Literal.create(null, ArrayType(StringType)) - val a4 = Literal.create(Seq[String]("\\d\\d\\s-\\s\\d\\d", null, "", "pattern"), - ArrayType(StringType)) + checkEvaluation(ArrayContains(a0, Literal(1)), true) checkEvaluation(ArrayContains(a0, Literal(0)), false) checkEvaluation(ArrayContains(a0, Literal.create(null, IntegerType)), null) @@ -106,14 +105,5 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(ArrayContains(a3, Literal("")), null) checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) - - checkEvaluation(ArrayContains(a4, Literal.create(null, StringType)), null) - checkEvaluation(ArrayContains(a4, Literal.create("", StringType)), true) - checkEvaluation(ArrayContains(a4, Literal.create("12 - 20", StringType)), true) - checkEvaluation(ArrayContains(a4, Literal.create("pat", StringType)), null) - checkEvaluation(ArrayContains(a4, Literal.create("pattern", StringType)), true) - checkEvaluation(ArrayContains(a4, Literal.create("ab - cd", StringType)), null) - checkEvaluation(ArrayContains(a4, Literal.create(" 12 - 20 ", StringType)), null) - checkEvaluation(ArrayContains(a4, Literal.create("132 - 20", StringType)), null) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 43e61d9b64a77..47bf41a2da813 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2771,8 +2771,7 @@ object functions { ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Returns true if the array contains `value` or `value` match any - * of the pattern available in array if array is 
of type string. + * Returns true if the array contains `value` * @group collection_funcs * @since 1.5.0 */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 0d43f39eaef42..45db61515e9b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -371,97 +371,39 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { } test("array contains function") { - val df1 = Seq( + val df = Seq( (Seq[Int](1, 2), "x"), (Seq[Int](), "x") ).toDF("a", "b") // Simple test cases checkAnswer( - df1.select(array_contains(df1("a"), 1)), + df.select(array_contains(df("a"), 1)), Seq(Row(true), Row(false)) ) checkAnswer( - df1.selectExpr("array_contains(a, 1)"), + df.selectExpr("array_contains(a, 1)"), Seq(Row(true), Row(false)) ) // In hive, this errors because null has no type information intercept[AnalysisException] { - df1.select(array_contains(df1("a"), null)) + df.select(array_contains(df("a"), null)) } intercept[AnalysisException] { - df1.selectExpr("array_contains(a, null)") + df.selectExpr("array_contains(a, null)") } intercept[AnalysisException] { - df1.selectExpr("array_contains(null, 1)") + df.selectExpr("array_contains(null, 1)") } checkAnswer( - df1.selectExpr("array_contains(array(array(1), null)[0], 1)"), + df.selectExpr("array_contains(array(array(1), null)[0], 1)"), Seq(Row(true), Row(true)) ) checkAnswer( - df1.selectExpr("array_contains(array(1, null), array(1, null)[0])"), + df.selectExpr("array_contains(array(1, null), array(1, null)[0])"), Seq(Row(true), Row(true)) ) - val df2 = Seq( - (Seq[String]("1", "2"), "x"), - (Seq[String](), "x"), - (Seq[String]("\\d\\s-\\s\\d", "pattern", ""), "x") - ).toDF("a", "b") - - // Simple test cases - checkAnswer( - df2.select(array_contains(df2("a"), "1")), - Seq(Row(true), 
Row(false), Row(false)) - ) - checkAnswer( - df2.selectExpr("""array_contains(a, "1")"""), - Seq(Row(true), Row(false), Row(false)) - ) - checkAnswer( - df2.select(array_contains(df2("a"), "1 - 2")), - Seq(Row(false), Row(false), Row(true)) - ) - checkAnswer( - df2.selectExpr("""array_contains(a, "3 - 4")"""), - Seq(Row(false), Row(false), Row(true)) - ) - - // In hive, this errors because null has no type information - intercept[AnalysisException] { - df2.select(array_contains(df2("a"), null)) - } - intercept[AnalysisException] { - df2.selectExpr("array_contains(a, null)") - } - intercept[AnalysisException] { - df2.selectExpr("array_contains(null, 1)") - } - - checkAnswer( - df2.selectExpr("array_contains(array(array(1), null)[0], 1)"), - Seq(Row(true), Row(true), Row(true)) - ) - - checkAnswer( - df2.selectExpr("array_contains(array(1, null), array(1, null)[0])"), - Seq(Row(true), Row(true), Row(true)) - ) - - checkAnswer( - df2.selectExpr( - """array_contains(array(array("\\d\\s\\d"), null)[0], - "1 3")""".stripMargin), - Seq(Row(true), Row(true), Row(true)) - ) - - checkAnswer( - df2.selectExpr( - """array_contains(array("\\d\\s\\d", null), array("1 3", null)[0])""".stripMargin), - Seq(Row(true), Row(true), Row(true)) - ) - } -} \ No newline at end of file +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala index c95e8e475dbc4..fdd02821dfa29 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala @@ -261,8 +261,6 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils { test("collection functions") { checkSqlGeneration("SELECT array_contains(array(2, 9, 8), 9)") - checkSqlGeneration("""SELECT array_contains(array("\\d\\s\\d", "pattern", - "8"), "9 8")""".stripMargin) checkSqlGeneration("SELECT 
size(array('b', 'd', 'c', 'a'))") checkSqlGeneration("SELECT sort_array(array('b', 'd', 'c', 'a'))") } From 3f999a51d65d08562c51634718c51d187a7116af Mon Sep 17 00:00:00 2001 From: prigarg Date: Fri, 30 Sep 2016 13:42:34 +0530 Subject: [PATCH 3/4] [SPARK-17619][SQL] To add support for pattern matching in ArrayContains Expression. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? This change adds a new expression, ArrayContainsWithPatternMatch, which performs pattern matching for string types and works in the same way as ArrayContains for all other data types. For example: a. ArrayContainsWithPatternMatch ( Seq ( "\\d\\d\\s-\\s\\d\\d", null, "", "pattern"), "12 - 20" ) returns true b. ArrayContainsWithPatternMatch ( Seq ( "\\d\\d\\s-\\s\\d\\d", "", "pattern"), "132 - 20" ) returns false c. ArrayContainsWithPatternMatch ( Seq ( "\\d\\d\\s-\\s\\d\\d", null, "", "pattern"), "132 - 20" ) returns null This change is completely backward compatible. ## How was this patch tested? Added more test cases for the pattern-match use case in the following suites: a. CollectionFunctionsSuite.scala b. DataFrameFunctionsSuite.scala c. 
ExpressionToSQLSuite.scala jira entry for detail: https://issues.apache.org/jira/browse/SPARK-17619 --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 136 ++++++++++++++++++ .../CollectionFunctionsSuite.scala | 33 +++++ .../org/apache/spark/sql/functions.scala | 10 ++ .../spark/sql/DataFrameFunctionsSuite.scala | 95 ++++++++++++ .../sql/catalyst/ExpressionToSQLSuite.scala | 2 + 6 files changed, 277 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index b05f4f61f6a3e..ba5f71d1ec2a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -351,6 +351,7 @@ object FunctionRegistry { // collection functions expression[CreateArray]("array"), expression[ArrayContains]("array_contains"), + expression[ArrayContainsWithPatternMatch]("array_contains_with_pattern_match"), expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), expression[MapKeys]("map_keys"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c0200299376ca..dcd16e7eb1e22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Comparator +import java.util.regex.Pattern import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import 
org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Given an array or map, returns its size. Returns -1 if null. @@ -261,3 +263,137 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + +@ExpressionDescription( + usage = """_FUNC_(array, value) - Returns TRUE if the array contains the value or + for string arrays, if string matches with the any pattern in the array. + This is complete word match""", + extended = """ > SELECT _FUNC_(array("\\d\\s\\d", "2", "3"), "1 5");\n true""") +case class ArrayContainsWithPatternMatch(left: Expression, right: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = BooleanType + + override def inputTypes: Seq[AbstractDataType] = right.dataType match { + case NullType => Seq() + case _ => left.dataType match { + case n @ ArrayType(element, _) => Seq(n, element) + case _ => Seq() + } + } + + override def checkInputDataTypes(): TypeCheckResult = { + if (right.dataType == NullType) { + TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as arguments") + } else if (!left.dataType.isInstanceOf[ArrayType] + || left.dataType.asInstanceOf[ArrayType].elementType != right.dataType) { + TypeCheckResult.TypeCheckFailure( + "Arguments must be an array followed by a value of same type as the array members") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + override def nullable: Boolean = { + left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull + } + + // last regex in string, we will update the pattern iff regexp value changed. 
+ @transient private var lastRegexArray: ArrayData = _ + // patterns compiled from lastRegexArray (null entries stay null), cached for performance + @transient private var patternArray: Array[Pattern] = _ + + // For string arrays each element is a regex that must match the whole value (Matcher.matches). + override def nullSafeEval(arrAny: Any, value: Any): Any = { + val arr = arrAny.asInstanceOf[ArrayData] + var hasNull = false + if (right.dataType == StringType) { + if (!arr.equals(lastRegexArray)) { + lastRegexArray = arr.copy() + patternArray = new Array[Pattern](arr.numElements()) + lastRegexArray.foreach(StringType, (i : Int, str : Any) => if (str == null) { + patternArray(i) = null + } else { + patternArray(i) = Pattern.compile(str.toString) + }) + } + patternArray.foreach(v => if (v == null) { + hasNull = true + false + } else if (v.matcher(value.asInstanceOf[UTF8String].toString).matches()) { + return true + }) + } else { + arr.foreach(right.dataType, (i, v) => + if (v == null) { + hasNull = true + } else if (v == value) { + return true + } + ) + } + + if (hasNull) { + null + } else { + false + } + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + // Mirrors nullSafeEval; the compiled patterns are kept in generated mutable state. + val termLastRegexArray = ctx.freshName("lastRegexArray") + val termPatternArray = ctx.freshName("patternArray") + val patternClassNamePattern = classOf[Pattern].getCanonicalName.stripSuffix("[]") + val arrayDataClassNamePattern = classOf[ArrayData].getCanonicalName.stripSuffix("[]") + + ctx.addMutableState(s"$arrayDataClassNamePattern", termLastRegexArray, + s"${termLastRegexArray} = null;") + ctx.addMutableState(s"$patternClassNamePattern[]", termPatternArray, + s"${termPatternArray} = null;") + + nullSafeCodeGen(ctx, ev, (arr, value) => { + val i = ctx.freshName("i") + var getValue = ctx.getValue(arr, right.dataType, i) + val code = if (right.dataType == StringType) { + s""" + if (!$arr.equals(${termLastRegexArray})) { + // regex Array value changed + ${termPatternArray} = new ${patternClassNamePattern}[$arr.numElements()]; + ${termLastRegexArray} = $arr.copy();
+ for (int $i = 0; $i < $arr.numElements(); $i ++) { + if ($arr.isNullAt($i)) { + ${termPatternArray}[$i] = null; + } else { + ${termPatternArray}[$i] = ${patternClassNamePattern}.compile( + ${getValue}.toString()); + } + } + }""".stripMargin + } else "" + val k = { + if (right.dataType == StringType) { + getValue = s"${termPatternArray}[$i]" + } + s""" + for (int $i = 0; $i < $arr.numElements(); $i ++) { + if ($arr.isNullAt($i)) { + ${ev.isNull} = true; + } else if (${genEqual(ctx, ev, right.dataType, value, getValue)}) { + ${ev.isNull} = false; + ${ev.value} = true; + break; + } + }""".stripMargin } + code + k + } + ) + } + + def genEqual(ctx: CodegenContext, ev: ExprCode, dataType: DataType, + c1: String, c2: String): String = dataType match { + case StringType => s"${c2}.matcher($c1.toString()).matches()".stripMargin + case _ => ctx.genEqual(dataType, c1, c2) + } + + override def prettyName: String = "array_contains_with_pattern_match" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala index c76dad208ea1e..c3090de116a09 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala @@ -106,4 +106,37 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(ArrayContains(a3, Literal("")), null) checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } + + test("Array contains with pattern match") { + val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) + val a2 = Literal.create(Seq(null), ArrayType(LongType)) + val a3 = Literal.create(null, ArrayType(StringType)) + val a4 =
Literal.create(Seq[String]("\\d\\d\\s-\\s\\d\\d", null, "", "pattern"), + ArrayType(StringType)) + + checkEvaluation(ArrayContainsWithPatternMatch(a0, Literal(0)), false) + checkEvaluation(ArrayContainsWithPatternMatch(a0, Literal.create(null, IntegerType)), null) + + checkEvaluation(ArrayContainsWithPatternMatch(a1, Literal("")), true) + checkEvaluation(ArrayContainsWithPatternMatch(a1, Literal("a")), null) + checkEvaluation(ArrayContainsWithPatternMatch(a1, Literal.create(null, StringType)), null) + + checkEvaluation(ArrayContainsWithPatternMatch(a2, Literal(1L)), null) + checkEvaluation(ArrayContainsWithPatternMatch(a2, Literal.create(null, LongType)), null) + + checkEvaluation(ArrayContainsWithPatternMatch(a3, Literal("")), null) + checkEvaluation(ArrayContainsWithPatternMatch(a3, Literal.create(null, StringType)), null) + + checkEvaluation(ArrayContainsWithPatternMatch(a4, Literal.create(null, StringType)), null) + checkEvaluation(ArrayContainsWithPatternMatch(a4, Literal.create("", StringType)), true) + checkEvaluation(ArrayContainsWithPatternMatch(a4, Literal.create("12 - 20", StringType)), true) + checkEvaluation(ArrayContainsWithPatternMatch(a4, Literal.create("pat", StringType)), null) + checkEvaluation(ArrayContainsWithPatternMatch(a4, Literal.create("pattern", StringType)), true) + checkEvaluation(ArrayContainsWithPatternMatch(a4, Literal.create("ab - cd", StringType)), null) + checkEvaluation(ArrayContainsWithPatternMatch(a4, Literal.create(" 12 - 20 ", StringType)), + null) + checkEvaluation(ArrayContainsWithPatternMatch(a4, Literal.create("132 - 20", StringType)), null) + } } + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 47bf41a2da813..cae69ceb135ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2779,6 +2779,16 @@ object functions { 
ArrayContains(column.expr, Literal(value)) } + /** + * Returns true if the array contains `value`. For arrays of strings, each element is treated + * as a regular expression and the result is true if `value` matches any of them in full. + * @group collection_funcs + * @since 2.0 + */ + def array_contains_with_pattern_match(column: Column, value: Any): Column = withExpr { + ArrayContainsWithPatternMatch(column.expr, Literal(value)) + } + /** * Creates a new row for each element in the given array or map column. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 45db61515e9b6..210aefcc4a7f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -406,4 +406,99 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Seq(Row(true), Row(true)) ) } + + test("array contains with pattern match function") { + val df1 = Seq( + (Seq[Int](1, 2), "x"), + (Seq[Int](), "x") + ).toDF("a", "b") + + // Simple test cases + checkAnswer( + df1.select(array_contains_with_pattern_match(df1("a"), 1)), + Seq(Row(true), Row(false)) + ) + checkAnswer( + df1.selectExpr("array_contains_with_pattern_match(a, 1)"), + Seq(Row(true), Row(false)) + ) + + // In hive, this errors because null has no type information + intercept[AnalysisException] { + df1.select(array_contains_with_pattern_match(df1("a"), null)) + } + intercept[AnalysisException] { + df1.selectExpr("array_contains_with_pattern_match(a, null)") + } + intercept[AnalysisException] { + df1.selectExpr("array_contains_with_pattern_match(null, 1)") + } + + checkAnswer( + df1.selectExpr("array_contains_with_pattern_match(array(array(1), null)[0], 1)"), + Seq(Row(true), Row(true)) + ) + checkAnswer( + df1.selectExpr("array_contains_with_pattern_match(array(1, null), array(1, null)[0])"), + Seq(Row(true), Row(true)) + ) + val df2 = Seq(
(Seq[String]("1", "2"), "x"), + (Seq[String](), "x"), + (Seq[String]("\\d\\s-\\s\\d", "pattern", ""), "x") + ).toDF("a", "b") + + // Simple test cases + checkAnswer( + df2.select(array_contains_with_pattern_match(df2("a"), "1")), + Seq(Row(true), Row(false), Row(false)) + ) + checkAnswer( + df2.selectExpr("""array_contains_with_pattern_match(a, "1")"""), + Seq(Row(true), Row(false), Row(false)) + ) + checkAnswer( + df2.select(array_contains_with_pattern_match(df2("a"), "1 - 2")), + Seq(Row(false), Row(false), Row(true)) + ) + checkAnswer( + df2.selectExpr("""array_contains_with_pattern_match(a, "3 - 4")"""), + Seq(Row(false), Row(false), Row(true)) + ) + + // In hive, this errors because null has no type information + intercept[AnalysisException] { + df2.select(array_contains_with_pattern_match(df2("a"), null)) + } + intercept[AnalysisException] { + df2.selectExpr("array_contains_with_pattern_match(a, null)") + } + intercept[AnalysisException] { + df2.selectExpr("array_contains_with_pattern_match(null, 1)") + } + + checkAnswer( + df2.selectExpr("array_contains_with_pattern_match(array(array(1), null)[0], 1)"), + Seq(Row(true), Row(true), Row(true)) + ) + + checkAnswer( + df2.selectExpr("array_contains_with_pattern_match(array(1, null), array(1, null)[0])"), + Seq(Row(true), Row(true), Row(true)) + ) + + checkAnswer( + df2.selectExpr( + """array_contains_with_pattern_match(array(array("\\d\\s\\d"), null)[0], + "1 3")""".stripMargin), + Seq(Row(true), Row(true), Row(true)) + ) + + checkAnswer( + df2.selectExpr( + """array_contains_with_pattern_match(array("\\d\\s\\d", null), array("1 3", null)[0])""".stripMargin), + Seq(Row(true), Row(true), Row(true)) + ) + + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala index fdd02821dfa29..151e5a5a125dd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala 
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala @@ -261,6 +261,8 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils { test("collection functions") { checkSqlGeneration("SELECT array_contains(array(2, 9, 8), 9)") + checkSqlGeneration("""SELECT array_contains_with_pattern_match(array("\\d\\s\\d", "pattern", + "8"), "9 8")""".stripMargin) checkSqlGeneration("SELECT size(array('b', 'd', 'c', 'a'))") checkSqlGeneration("SELECT sort_array(array('b', 'd', 'c', 'a'))") } From 4d9a42d0b95397a91f78833aa78c5acad0c0c8f7 Mon Sep 17 00:00:00 2001 From: prigarg Date: Fri, 30 Sep 2016 13:42:34 +0530 Subject: [PATCH 4/4] [SPARK-17619][SQL] To add a new expression ArrayContainsWithPatternMatch. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? This change adds a new expression, ArrayContainsWithPatternMatch, which performs regex pattern matching for string arrays and behaves the same way as ArrayContains for all other data types. For example: a. ArrayContainsWithPatternMatch(Seq("\\d\\d\\s-\\s\\d\\d", null, "", "pattern"), "12 - 20") returns true b. ArrayContainsWithPatternMatch(Seq("\\d\\d\\s-\\s\\d\\d", "", "pattern"), "132 - 20") returns false c. ArrayContainsWithPatternMatch(Seq("\\d\\d\\s-\\s\\d\\d", null, "", "pattern"), "132 - 20") returns null (no element matches and the array contains a null) This change is completely backward compatible. ## How was this patch tested? Added test cases for the pattern-match use case to the following suites: a. CollectionFunctionsSuite.scala b. DataFrameFunctionsSuite.scala c.
ExpressionToSQLSuite.scala jira entry for detail: https://issues.apache.org/jira/browse/SPARK-17619 --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/collectionOperations.scala | 136 ++++++++++++++++++ .../CollectionFunctionsSuite.scala | 33 +++++ .../org/apache/spark/sql/functions.scala | 10 ++ .../spark/sql/DataFrameFunctionsSuite.scala | 95 ++++++++++++ .../sql/catalyst/ExpressionToSQLSuite.scala | 2 + 6 files changed, 277 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index b05f4f61f6a3e..ba5f71d1ec2a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -351,6 +351,7 @@ object FunctionRegistry { // collection functions expression[CreateArray]("array"), expression[ArrayContains]("array_contains"), + expression[ArrayContainsWithPatternMatch]("array_contains_with_pattern_match"), expression[CreateMap]("map"), expression[CreateNamedStruct]("named_struct"), expression[MapKeys]("map_keys"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c0200299376ca..dcd16e7eb1e22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Comparator +import java.util.regex.Pattern import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import 
org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Given an array or map, returns its size. Returns -1 if null. @@ -261,3 +263,137 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + +@ExpressionDescription( + usage = """_FUNC_(array, value) - Returns TRUE if the array contains the value or + for string arrays, if string matches with the any pattern in the array. + This is complete word match""", + extended = """ > SELECT _FUNC_(array("\\d\\s\\d", "2", "3"), "1 5");\n true""") +case class ArrayContainsWithPatternMatch(left: Expression, right: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = BooleanType + + override def inputTypes: Seq[AbstractDataType] = right.dataType match { + case NullType => Seq() + case _ => left.dataType match { + case n @ ArrayType(element, _) => Seq(n, element) + case _ => Seq() + } + } + + override def checkInputDataTypes(): TypeCheckResult = { + if (right.dataType == NullType) { + TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as arguments") + } else if (!left.dataType.isInstanceOf[ArrayType] + || left.dataType.asInstanceOf[ArrayType].elementType != right.dataType) { + TypeCheckResult.TypeCheckFailure( + "Arguments must be an array followed by a value of same type as the array members") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + override def nullable: Boolean = { + left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull + } + + // last array of regex strings seen; the patterns are recompiled only when this array changes.
+ @transient private var lastRegexArray: ArrayData = _ + // patterns compiled from lastRegexArray (null entries stay null), cached for performance + @transient private var patternArray: Array[Pattern] = _ + + // For string arrays each element is a regex that must match the whole value (Matcher.matches). + override def nullSafeEval(arrAny: Any, value: Any): Any = { + val arr = arrAny.asInstanceOf[ArrayData] + var hasNull = false + if (right.dataType == StringType) { + if (!arr.equals(lastRegexArray)) { + lastRegexArray = arr.copy() + patternArray = new Array[Pattern](arr.numElements()) + lastRegexArray.foreach(StringType, (i : Int, str : Any) => if (str == null) { + patternArray(i) = null + } else { + patternArray(i) = Pattern.compile(str.toString) + }) + } + patternArray.foreach(v => if (v == null) { + hasNull = true + false + } else if (v.matcher(value.asInstanceOf[UTF8String].toString).matches()) { + return true + }) + } else { + arr.foreach(right.dataType, (i, v) => + if (v == null) { + hasNull = true + } else if (v == value) { + return true + } + ) + } + + if (hasNull) { + null + } else { + false + } + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + // Mirrors nullSafeEval; the compiled patterns are kept in generated mutable state. + val termLastRegexArray = ctx.freshName("lastRegexArray") + val termPatternArray = ctx.freshName("patternArray") + val patternClassNamePattern = classOf[Pattern].getCanonicalName.stripSuffix("[]") + val arrayDataClassNamePattern = classOf[ArrayData].getCanonicalName.stripSuffix("[]") + + ctx.addMutableState(s"$arrayDataClassNamePattern", termLastRegexArray, + s"${termLastRegexArray} = null;") + ctx.addMutableState(s"$patternClassNamePattern[]", termPatternArray, + s"${termPatternArray} = null;") + + nullSafeCodeGen(ctx, ev, (arr, value) => { + val i = ctx.freshName("i") + var getValue = ctx.getValue(arr, right.dataType, i) + val code = if (right.dataType == StringType) { + s""" + if (!$arr.equals(${termLastRegexArray})) { + // regex Array value changed + ${termPatternArray} = new ${patternClassNamePattern}[$arr.numElements()]; + ${termLastRegexArray} = $arr.copy();
+ for (int $i = 0; $i < $arr.numElements(); $i ++) { + if ($arr.isNullAt($i)) { + ${termPatternArray}[$i] = null; + } else { + ${termPatternArray}[$i] = ${patternClassNamePattern}.compile( + ${getValue}.toString()); + } + } + }""".stripMargin + } else "" + val k = { + if (right.dataType == StringType) { + getValue = s"${termPatternArray}[$i]" + } + s""" + for (int $i = 0; $i < $arr.numElements(); $i ++) { + if ($arr.isNullAt($i)) { + ${ev.isNull} = true; + } else if (${genEqual(ctx, ev, right.dataType, value, getValue)}) { + ${ev.isNull} = false; + ${ev.value} = true; + break; + } + }""".stripMargin } + code + k + } + ) + } + + def genEqual(ctx: CodegenContext, ev: ExprCode, dataType: DataType, + c1: String, c2: String): String = dataType match { + case StringType => s"${c2}.matcher($c1.toString()).matches()".stripMargin + case _ => ctx.genEqual(dataType, c1, c2) + } + + override def prettyName: String = "array_contains_with_pattern_match" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala index c76dad208ea1e..c3090de116a09 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala @@ -106,4 +106,37 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(ArrayContains(a3, Literal("")), null) checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } + + test("Array contains with pattern match") { + val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) + val a2 = Literal.create(Seq(null), ArrayType(LongType)) + val a3 = Literal.create(null, ArrayType(StringType)) + val a4 =
Literal.create(Seq[String]("\\d\\d\\s-\\s\\d\\d", null, "", "pattern"), + ArrayType(StringType)) + + checkEvaluation(ArrayContainsWithPatternMatch(a0, Literal(0)), false) + checkEvaluation(ArrayContainsWithPatternMatch(a0, Literal.create(null, IntegerType)), null) + + checkEvaluation(ArrayContainsWithPatternMatch(a1, Literal("")), true) + checkEvaluation(ArrayContainsWithPatternMatch(a1, Literal("a")), null) + checkEvaluation(ArrayContainsWithPatternMatch(a1, Literal.create(null, StringType)), null) + + checkEvaluation(ArrayContainsWithPatternMatch(a2, Literal(1L)), null) + checkEvaluation(ArrayContainsWithPatternMatch(a2, Literal.create(null, LongType)), null) + + checkEvaluation(ArrayContainsWithPatternMatch(a3, Literal("")), null) + checkEvaluation(ArrayContainsWithPatternMatch(a3, Literal.create(null, StringType)), null) + + checkEvaluation(ArrayContainsWithPatternMatch(a4, Literal.create(null, StringType)), null) + checkEvaluation(ArrayContainsWithPatternMatch(a4, Literal.create("", StringType)), true) + checkEvaluation(ArrayContainsWithPatternMatch(a4, Literal.create("12 - 20", StringType)), true) + checkEvaluation(ArrayContainsWithPatternMatch(a4, Literal.create("pat", StringType)), null) + checkEvaluation(ArrayContainsWithPatternMatch(a4, Literal.create("pattern", StringType)), true) + checkEvaluation(ArrayContainsWithPatternMatch(a4, Literal.create("ab - cd", StringType)), null) + checkEvaluation(ArrayContainsWithPatternMatch(a4, Literal.create(" 12 - 20 ", StringType)), + null) + checkEvaluation(ArrayContainsWithPatternMatch(a4, Literal.create("132 - 20", StringType)), null) + } } + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 47bf41a2da813..cae69ceb135ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2779,6 +2779,16 @@ object functions { 
ArrayContains(column.expr, Literal(value)) } + /** + * Returns true if the array contains `value`. For arrays of strings, each element is treated + * as a regular expression and the result is true if `value` matches any of them in full. + * @group collection_funcs + * @since 2.0 + */ + def array_contains_with_pattern_match(column: Column, value: Any): Column = withExpr { + ArrayContainsWithPatternMatch(column.expr, Literal(value)) + } + /** * Creates a new row for each element in the given array or map column. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 45db61515e9b6..210aefcc4a7f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -406,4 +406,99 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Seq(Row(true), Row(true)) ) } + + test("array contains with pattern match function") { + val df1 = Seq( + (Seq[Int](1, 2), "x"), + (Seq[Int](), "x") + ).toDF("a", "b") + + // Simple test cases + checkAnswer( + df1.select(array_contains_with_pattern_match(df1("a"), 1)), + Seq(Row(true), Row(false)) + ) + checkAnswer( + df1.selectExpr("array_contains_with_pattern_match(a, 1)"), + Seq(Row(true), Row(false)) + ) + + // In hive, this errors because null has no type information + intercept[AnalysisException] { + df1.select(array_contains_with_pattern_match(df1("a"), null)) + } + intercept[AnalysisException] { + df1.selectExpr("array_contains_with_pattern_match(a, null)") + } + intercept[AnalysisException] { + df1.selectExpr("array_contains_with_pattern_match(null, 1)") + } + + checkAnswer( + df1.selectExpr("array_contains_with_pattern_match(array(array(1), null)[0], 1)"), + Seq(Row(true), Row(true)) + ) + checkAnswer( + df1.selectExpr("array_contains_with_pattern_match(array(1, null), array(1, null)[0])"), + Seq(Row(true), Row(true)) + ) + val df2 = Seq(
(Seq[String]("1", "2"), "x"), + (Seq[String](), "x"), + (Seq[String]("\\d\\s-\\s\\d", "pattern", ""), "x") + ).toDF("a", "b") + + // Simple test cases + checkAnswer( + df2.select(array_contains_with_pattern_match(df2("a"), "1")), + Seq(Row(true), Row(false), Row(false)) + ) + checkAnswer( + df2.selectExpr("""array_contains_with_pattern_match(a, "1")"""), + Seq(Row(true), Row(false), Row(false)) + ) + checkAnswer( + df2.select(array_contains_with_pattern_match(df2("a"), "1 - 2")), + Seq(Row(false), Row(false), Row(true)) + ) + checkAnswer( + df2.selectExpr("""array_contains_with_pattern_match(a, "3 - 4")"""), + Seq(Row(false), Row(false), Row(true)) + ) + + // In hive, this errors because null has no type information + intercept[AnalysisException] { + df2.select(array_contains_with_pattern_match(df2("a"), null)) + } + intercept[AnalysisException] { + df2.selectExpr("array_contains_with_pattern_match(a, null)") + } + intercept[AnalysisException] { + df2.selectExpr("array_contains_with_pattern_match(null, 1)") + } + + checkAnswer( + df2.selectExpr("array_contains_with_pattern_match(array(array(1), null)[0], 1)"), + Seq(Row(true), Row(true), Row(true)) + ) + + checkAnswer( + df2.selectExpr("array_contains_with_pattern_match(array(1, null), array(1, null)[0])"), + Seq(Row(true), Row(true), Row(true)) + ) + + checkAnswer( + df2.selectExpr( + """array_contains_with_pattern_match(array(array("\\d\\s\\d"), null)[0], + "1 3")""".stripMargin), + Seq(Row(true), Row(true), Row(true)) + ) + + checkAnswer( + df2.selectExpr( + """array_contains_with_pattern_match(array("\\d\\s\\d", null), array("1 3", null)[0])""".stripMargin), + Seq(Row(true), Row(true), Row(true)) + ) + + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala index fdd02821dfa29..151e5a5a125dd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala 
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala @@ -261,6 +261,8 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils { test("collection functions") { checkSqlGeneration("SELECT array_contains(array(2, 9, 8), 9)") + checkSqlGeneration("""SELECT array_contains_with_pattern_match(array("\\d\\s\\d", "pattern", + "8"), "9 8")""".stripMargin) checkSqlGeneration("SELECT size(array('b', 'd', 'c', 'a'))") checkSqlGeneration("SELECT sort_array(array('b', 'd', 'c', 'a'))") }