Skip to content

Commit 41aed0e

Browse files
LordwormsWeijun-H
authored andcommitted
Add contains function, and support in datafusion substrait consumer (apache#10879)
* adding new function contains * adding substrait test * adding doc * adding doc * Update docs/source/user-guide/sql/scalar_functions.md Co-authored-by: Alex Huang <huangweijun1001@gmail.com> * adding entry --------- Co-authored-by: Alex Huang <huangweijun1001@gmail.com>
1 parent cf05f72 commit 41aed0e

File tree

7 files changed

+373
-2
lines changed

7 files changed

+373
-2
lines changed
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use crate::utils::make_scalar_function;
19+
use arrow::array::{ArrayRef, OffsetSizeTrait};
20+
use arrow::datatypes::DataType;
21+
use arrow::datatypes::DataType::Boolean;
22+
use datafusion_common::cast::as_generic_string_array;
23+
use datafusion_common::DataFusionError;
24+
use datafusion_common::Result;
25+
use datafusion_common::{arrow_datafusion_err, exec_err};
26+
use datafusion_expr::ScalarUDFImpl;
27+
use datafusion_expr::TypeSignature::Exact;
28+
use datafusion_expr::{ColumnarValue, Signature, Volatility};
29+
use std::any::Any;
30+
use std::sync::Arc;
31+
#[derive(Debug)]
32+
pub struct ContainsFunc {
33+
signature: Signature,
34+
}
35+
36+
impl Default for ContainsFunc {
37+
fn default() -> Self {
38+
ContainsFunc::new()
39+
}
40+
}
41+
42+
impl ContainsFunc {
43+
pub fn new() -> Self {
44+
use DataType::*;
45+
Self {
46+
signature: Signature::one_of(
47+
vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])],
48+
Volatility::Immutable,
49+
),
50+
}
51+
}
52+
}
53+
54+
impl ScalarUDFImpl for ContainsFunc {
55+
fn as_any(&self) -> &dyn Any {
56+
self
57+
}
58+
59+
fn name(&self) -> &str {
60+
"contains"
61+
}
62+
63+
fn signature(&self) -> &Signature {
64+
&self.signature
65+
}
66+
67+
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
68+
Ok(Boolean)
69+
}
70+
71+
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
72+
match args[0].data_type() {
73+
DataType::Utf8 => make_scalar_function(contains::<i32>, vec![])(args),
74+
DataType::LargeUtf8 => make_scalar_function(contains::<i64>, vec![])(args),
75+
other => {
76+
exec_err!("unsupported data type {other:?} for function contains")
77+
}
78+
}
79+
}
80+
}
81+
82+
/// use regexp_is_match_utf8_scalar to do the calculation for contains
83+
pub fn contains<T: OffsetSizeTrait>(
84+
args: &[ArrayRef],
85+
) -> Result<ArrayRef, DataFusionError> {
86+
let mod_str = as_generic_string_array::<T>(&args[0])?;
87+
let match_str = as_generic_string_array::<T>(&args[1])?;
88+
let res = arrow::compute::kernels::comparison::regexp_is_match_utf8(
89+
mod_str, match_str, None,
90+
)
91+
.map_err(|e| arrow_datafusion_err!(e))?;
92+
93+
Ok(Arc::new(res) as ArrayRef)
94+
}
95+
96+
#[cfg(test)]
97+
mod tests {
98+
use crate::string::contains::ContainsFunc;
99+
use crate::utils::test::test_function;
100+
use arrow::array::Array;
101+
use arrow::{array::BooleanArray, datatypes::DataType::Boolean};
102+
use datafusion_common::Result;
103+
use datafusion_common::ScalarValue;
104+
use datafusion_expr::ColumnarValue;
105+
use datafusion_expr::ScalarUDFImpl;
106+
#[test]
107+
fn test_functions() -> Result<()> {
108+
test_function!(
109+
ContainsFunc::new(),
110+
&[
111+
ColumnarValue::Scalar(ScalarValue::from("alphabet")),
112+
ColumnarValue::Scalar(ScalarValue::from("alph")),
113+
],
114+
Ok(Some(true)),
115+
bool,
116+
Boolean,
117+
BooleanArray
118+
);
119+
test_function!(
120+
ContainsFunc::new(),
121+
&[
122+
ColumnarValue::Scalar(ScalarValue::from("alphabet")),
123+
ColumnarValue::Scalar(ScalarValue::from("dddddd")),
124+
],
125+
Ok(Some(false)),
126+
bool,
127+
Boolean,
128+
BooleanArray
129+
);
130+
test_function!(
131+
ContainsFunc::new(),
132+
&[
133+
ColumnarValue::Scalar(ScalarValue::from("alphabet")),
134+
ColumnarValue::Scalar(ScalarValue::from("pha")),
135+
],
136+
Ok(Some(true)),
137+
bool,
138+
Boolean,
139+
BooleanArray
140+
);
141+
Ok(())
142+
}
143+
}

datafusion/functions/src/string/mod.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ pub mod chr;
2828
pub mod common;
2929
pub mod concat;
3030
pub mod concat_ws;
31+
pub mod contains;
3132
pub mod ends_with;
3233
pub mod initcap;
3334
pub mod levenshtein;
@@ -43,7 +44,6 @@ pub mod starts_with;
4344
pub mod to_hex;
4445
pub mod upper;
4546
pub mod uuid;
46-
4747
// create UDFs
4848
make_udf_function!(ascii::AsciiFunc, ASCII, ascii);
4949
make_udf_function!(bit_length::BitLengthFunc, BIT_LENGTH, bit_length);
@@ -66,7 +66,7 @@ make_udf_function!(split_part::SplitPartFunc, SPLIT_PART, split_part);
6666
make_udf_function!(to_hex::ToHexFunc, TO_HEX, to_hex);
6767
make_udf_function!(upper::UpperFunc, UPPER, upper);
6868
make_udf_function!(uuid::UuidFunc, UUID, uuid);
69-
69+
make_udf_function!(contains::ContainsFunc, CONTAINS, contains);
7070
pub mod expr_fn {
7171
use datafusion_expr::Expr;
7272

@@ -149,6 +149,9 @@ pub mod expr_fn {
149149
),(
150150
uuid,
151151
"returns uuid v4 as a string value",
152+
), (
153+
contains,
154+
"Return true if search_string is found within string. treated it like a reglike",
152155
));
153156

154157
#[doc = "Removes all characters, spaces by default, from both sides of a string"]
@@ -188,5 +191,6 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
188191
to_hex(),
189192
upper(),
190193
uuid(),
194+
contains(),
191195
]
192196
}

datafusion/sqllogictest/test_files/functions.slt

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1158,3 +1158,21 @@ drop table uuid_table
11581158

11591159
statement ok
11601160
drop table t
1161+
1162+
1163+
# test for contains
1164+
1165+
query B
1166+
select contains('alphabet', 'pha');
1167+
----
1168+
true
1169+
1170+
query B
1171+
select contains('alphabet', 'dddd');
1172+
----
1173+
false
1174+
1175+
query B
1176+
select contains('', '');
1177+
----
1178+
true
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! Tests for Function Compatibility
19+
20+
#[cfg(test)]
21+
mod tests {
22+
use datafusion::common::Result;
23+
use datafusion::prelude::{CsvReadOptions, SessionContext};
24+
use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
25+
use std::fs::File;
26+
use std::io::BufReader;
27+
use substrait::proto::Plan;
28+
29+
#[tokio::test]
30+
async fn contains_function_test() -> Result<()> {
31+
let ctx = create_context().await?;
32+
33+
let path = "tests/testdata/contains_plan.substrait.json";
34+
let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
35+
File::open(path).expect("file not found"),
36+
))
37+
.expect("failed to parse json");
38+
39+
let plan = from_substrait_plan(&ctx, &proto).await?;
40+
41+
let plan_str = format!("{:?}", plan);
42+
43+
assert_eq!(
44+
plan_str,
45+
"Projection: nation.b AS n_name\
46+
\n Filter: contains(nation.b, Utf8(\"IA\"))\
47+
\n TableScan: nation projection=[a, b, c, d, e, f]"
48+
);
49+
Ok(())
50+
}
51+
52+
async fn create_context() -> datafusion::common::Result<SessionContext> {
53+
let ctx = SessionContext::new();
54+
ctx.register_csv("nation", "tests/testdata/data.csv", CsvReadOptions::new())
55+
.await?;
56+
Ok(ctx)
57+
}
58+
}

datafusion/substrait/tests/cases/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
// under the License.
1717

1818
mod consumer_integration;
19+
mod function_test;
1920
mod logical_plans;
2021
mod roundtrip_logical_plan;
2122
mod roundtrip_physical_plan;

0 commit comments

Comments
 (0)