Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ object FunctionRegistry {
// collection functions
expression[CreateArray]("array"),
expression[ArrayContains]("array_contains"),
expression[ArrayContainsWithPatternMatch]("array_contains_with_pattern_match"),
expression[CreateMap]("map"),
expression[CreateNamedStruct]("named_struct"),
expression[MapKeys]("map_keys"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
package org.apache.spark.sql.catalyst.expressions

import java.util.Comparator
import java.util.regex.Pattern

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

/**
* Given an array or map, returns its size. Returns -1 if null.
Expand Down Expand Up @@ -261,3 +263,137 @@ case class ArrayContains(left: Expression, right: Expression)

override def prettyName: String = "array_contains"
}

@ExpressionDescription(
usage = """_FUNC_(array, value) - Returns TRUE if the array contains the value or

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem is that this changes the behavior of this method and even makes it a little surprising. Before, ..array("Mr.X"), "Mrox".. didn't match but now it does. Accepting strings and regexes in the same place is inherently ambiguous. I don't know if we'd change the meaning of an existing function like this.

@priyankagargnitk priyankagargnitk Sep 29, 2016

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, in that case wecan add one more expression , something like ArrayContainsWithPatternMatch? whats your thought about this?

for string arrays, if string matches with the any pattern in the array.
This is complete word match""",
extended = """ > SELECT _FUNC_(array("\\d\\s\\d", "2", "3"), "1 5");\n true""")
case class ArrayContainsWithPatternMatch(left: Expression, right: Expression)
extends BinaryExpression with ImplicitCastInputTypes {

override def dataType: DataType = BooleanType

override def inputTypes: Seq[AbstractDataType] = right.dataType match {
case NullType => Seq()
case _ => left.dataType match {
case n @ ArrayType(element, _) => Seq(n, element)
case _ => Seq()
}
}

override def checkInputDataTypes(): TypeCheckResult = {
if (right.dataType == NullType) {
TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as arguments")
} else if (!left.dataType.isInstanceOf[ArrayType]
|| left.dataType.asInstanceOf[ArrayType].elementType != right.dataType) {
TypeCheckResult.TypeCheckFailure(
"Arguments must be an array followed by a value of same type as the array members")
} else {
TypeCheckResult.TypeCheckSuccess
}
}

override def nullable: Boolean = {
left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull
}

// last regex in string, we will update the pattern iff regexp value changed.
@transient private var lastRegexArray: ArrayData = _
// last regex pattern, we cache it for performance concern
@transient private var patternArray: Array[Pattern] = _


override def nullSafeEval(arrAny: Any, value: Any): Any = {
val arr = arrAny.asInstanceOf[ArrayData]
var hasNull = false
if (right.dataType == StringType) {
if (!arr.equals(lastRegexArray)) {
lastRegexArray = arr.copy()
patternArray = new Array[Pattern](arr.numElements())
lastRegexArray.foreach(StringType, (i : Int, str : Any) => if (str == null) {
patternArray(i) = null
} else {
patternArray(i) = Pattern.compile("^".concat(str.toString).concat("$"))
})
}
patternArray.foreach(v => if (v == null) {
hasNull = true
false
} else if (v.matcher(value.asInstanceOf[UTF8String].toString).find()) {
return true
})
} else {
arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) =>
if (v == null) {
hasNull = true
} else if (v == value) {
return true
}
)
}

if (hasNull) {
null
} else {
false
}
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {

val termLastRegexArray = ctx.freshName("lastRegexArray")
val termPatternArray = ctx.freshName("patternArray")
val patternClassNamePattern = classOf[Pattern].getCanonicalName.stripSuffix("[]")
val arrayDataClassNamePattern = classOf[ArrayData].getCanonicalName.stripSuffix("[]")

ctx.addMutableState(s"$arrayDataClassNamePattern", termLastRegexArray,
s"${termLastRegexArray} = null;")
ctx.addMutableState(s"$patternClassNamePattern[]", termPatternArray,
s"${termPatternArray} = null;")

nullSafeCodeGen(ctx, ev, (arr, value) => {
val i = ctx.freshName("i")
var getValue = ctx.getValue(arr, right.dataType, i)
val code = if (right.dataType == StringType) {
s"""
if (!$arr.equals(${termLastRegexArray})) {
// regex Array value changed
${termPatternArray} = new ${patternClassNamePattern}[$arr.numElements()];
${termLastRegexArray} = $arr.copy();
for (int $i = 0; $i < $arr.numElements(); $i ++) {
if ($arr.isNullAt($i)) {
${termPatternArray}[$i] = null;
} else {
${termPatternArray}[$i] = ${patternClassNamePattern}.compile(
"^".concat(${getValue}.toString()).concat("$$"));
}
}
}""".stripMargin
} else ""
val k = {
if (right.dataType == StringType) {
getValue = s"${termPatternArray}[$i]"
}
s"""
for (int $i = 0; $i < $arr.numElements(); $i ++) {
if ($arr.isNullAt($i)) {
${ev.isNull} = true;
} else if (${genEqual(ctx, ev, right.dataType, value, getValue)}) {
${ev.isNull} = false;
${ev.value} = true;
break;
}
}""".stripMargin }
code + k
}
)
}

def genEqual(ctx: CodegenContext, ev: ExprCode, dataType: DataType,
c1: String, c2: String): String = dataType match {
case StringType => s"${c2}.matcher($c1.toString()).find()".stripMargin
case _ => ctx.genEqual(dataType, c1, c2)
}

override def prettyName: String = "array_contains_with_pattern_match"
}
Original file line number Diff line number Diff line change
Expand Up @@ -106,4 +106,37 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(ArrayContains(a3, Literal("")), null)
checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
}

test("Array contains with pattern match") {
val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType))
val a2 = Literal.create(Seq(null), ArrayType(LongType))
val a3 = Literal.create(null, ArrayType(StringType))
val a4 = Literal.create(Seq[String]("\\d\\d\\s-\\s\\d\\d", null, "", "pattern"),
ArrayType(StringType))

checkEvaluation(ArrayContainsWithPatternMatch(a0, Literal(0)), false)
checkEvaluation(ArrayContainsWithPatternMatch(a0, Literal.create(null, IntegerType)), null)

checkEvaluation(ArrayContainsWithPatternMatch(a1, Literal("")), true)
checkEvaluation(ArrayContainsWithPatternMatch(a1, Literal("a")), null)
checkEvaluation(ArrayContainsWithPatternMatch(a1, Literal.create(null, StringType)), null)

checkEvaluation(ArrayContainsWithPatternMatch(a2, Literal(1L)), null)
checkEvaluation(ArrayContainsWithPatternMatch(a2, Literal.create(null, LongType)), null)

checkEvaluation(ArrayContainsWithPatternMatch(a3, Literal("")), null)
checkEvaluation(ArrayContainsWithPatternMatch(a3, Literal.create(null, StringType)), null)

checkEvaluation(ArrayContainsWithPatternMatch(a4, Literal.create(null, StringType)), null)
checkEvaluation(ArrayContainsWithPatternMatch(a4, Literal.create("", StringType)), true)
checkEvaluation(ArrayContainsWithPatternMatch(a4, Literal.create("12 - 20", StringType)), true)
checkEvaluation(ArrayContainsWithPatternMatch(a4, Literal.create("pat", StringType)), null)
checkEvaluation(ArrayContainsWithPatternMatch(a4, Literal.create("pattern", StringType)), true)
checkEvaluation(ArrayContainsWithPatternMatch(a4, Literal.create("ab - cd", StringType)), null)
checkEvaluation(ArrayContainsWithPatternMatch(a4, Literal.create(" 12 - 20 ", StringType)),
null)
checkEvaluation(ArrayContainsWithPatternMatch(a4, Literal.create("132 - 20", StringType)), null)
}
}

10 changes: 10 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2779,6 +2779,16 @@ object functions {
ArrayContains(column.expr, Literal(value))
}

/**
* Returns true if the array contains `value` or `value` match any
* of the pattern available in array if array is of type string.
* @group collection_funcs
* @since 2.0
*/
def array_contains_with_pattern_match(column: Column, value: Any): Column = withExpr {
ArrayContainsWithPatternMatch(column.expr, Literal(value))
}

/**
* Creates a new row for each element in the given array or map column.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -406,4 +406,99 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
Seq(Row(true), Row(true))
)
}

test("array contains with pattern match function") {
val df1 = Seq(
(Seq[Int](1, 2), "x"),
(Seq[Int](), "x")
).toDF("a", "b")

// Simple test cases
checkAnswer(
df1.select(array_contains_with_pattern_match(df1("a"), 1)),
Seq(Row(true), Row(false))
)
checkAnswer(
df1.selectExpr("array_contains_with_pattern_match(a, 1)"),
Seq(Row(true), Row(false))
)

// In hive, this errors because null has no type information
intercept[AnalysisException] {
df1.select(array_contains_with_pattern_match(df1("a"), null))
}
intercept[AnalysisException] {
df1.selectExpr("array_contains_with_pattern_match(a, null)")
}
intercept[AnalysisException] {
df1.selectExpr("array_contains_with_pattern_match(null, 1)")
}

checkAnswer(
df1.selectExpr("array_contains_with_pattern_match(array(array(1), null)[0], 1)"),
Seq(Row(true), Row(true))
)
checkAnswer(
df1.selectExpr("array_contains_with_pattern_match(array(1, null), array(1, null)[0])"),
Seq(Row(true), Row(true))
)
val df2 = Seq(
(Seq[String]("1", "2"), "x"),
(Seq[String](), "x"),
(Seq[String]("\\d\\s-\\s\\d", "pattern", ""), "x")
).toDF("a", "b")

// Simple test cases
checkAnswer(
df2.select(array_contains_with_pattern_match(df2("a"), "1")),
Seq(Row(true), Row(false), Row(false))
)
checkAnswer(
df2.selectExpr("""array_contains_with_pattern_match(a, "1")"""),
Seq(Row(true), Row(false), Row(false))
)
checkAnswer(
df2.select(array_contains_with_pattern_match(df2("a"), "1 - 2")),
Seq(Row(false), Row(false), Row(true))
)
checkAnswer(
df2.selectExpr("""array_contains_with_pattern_match(a, "3 - 4")"""),
Seq(Row(false), Row(false), Row(true))
)

// In hive, this errors because null has no type information
intercept[AnalysisException] {
df2.select(array_contains_with_pattern_match(df2("a"), null))
}
intercept[AnalysisException] {
df2.selectExpr("array_contains_with_pattern_match(a, null)")
}
intercept[AnalysisException] {
df2.selectExpr("array_contains_with_pattern_match(null, 1)")
}

checkAnswer(
df2.selectExpr("array_contains_with_pattern_match(array(array(1), null)[0], 1)"),
Seq(Row(true), Row(true), Row(true))
)

checkAnswer(
df2.selectExpr("array_contains_with_pattern_match(array(1, null), array(1, null)[0])"),
Seq(Row(true), Row(true), Row(true))
)

checkAnswer(
df2.selectExpr(
"""array_contains_with_pattern_match(array(array("\\d\\s\\d"), null)[0],
"1 3")""".stripMargin),
Seq(Row(true), Row(true), Row(true))
)

checkAnswer(
df2.selectExpr(
"""array_contains_with_pattern_match(array("\\d\\s\\d", null), array("1 3", null)[0])""".stripMargin),
Seq(Row(true), Row(true), Row(true))
)

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,8 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils {

test("collection functions") {
checkSqlGeneration("SELECT array_contains(array(2, 9, 8), 9)")
checkSqlGeneration("""SELECT array_contains_with_pattern_match(array("\\d\\s\\d", "pattern",
"8"), "9 8")""".stripMargin)
checkSqlGeneration("SELECT size(array('b', 'd', 'c', 'a'))")
checkSqlGeneration("SELECT sort_array(array('b', 'd', 'c', 'a'))")
}
Expand Down