1818use std:: any:: Any ;
1919use std:: sync:: Arc ;
2020
21- use arrow:: array:: { ArrayRef , GenericStringArray , OffsetSizeTrait } ;
21+ use arrow:: array:: { ArrayRef , GenericStringArray , OffsetSizeTrait , StringArray } ;
2222use arrow:: datatypes:: DataType ;
2323
24- use datafusion_common:: cast:: { as_generic_string_array, as_int64_array} ;
24+ use datafusion_common:: cast:: {
25+ as_generic_string_array, as_int64_array, as_string_view_array,
26+ } ;
2527use datafusion_common:: { exec_err, Result } ;
2628use datafusion_expr:: TypeSignature :: * ;
2729use datafusion_expr:: { ColumnarValue , Volatility } ;
@@ -45,7 +47,14 @@ impl RepeatFunc {
4547 use DataType :: * ;
4648 Self {
4749 signature : Signature :: one_of (
48- vec ! [ Exact ( vec![ Utf8 , Int64 ] ) , Exact ( vec![ LargeUtf8 , Int64 ] ) ] ,
50+ vec ! [
51+ // Planner attempts coercion to the target type starting with the most preferred candidate.
52+ // For example, given input `(Utf8View, Int64)`, it first tries coercing to `(Utf8View, Int64)`.
53+ // If that fails, it proceeds to `(Utf8, Int64)`.
54+ Exact ( vec![ Utf8View , Int64 ] ) ,
55+ Exact ( vec![ Utf8 , Int64 ] ) ,
56+ Exact ( vec![ LargeUtf8 , Int64 ] ) ,
57+ ] ,
4958 Volatility :: Immutable ,
5059 ) ,
5160 }
@@ -71,9 +80,10 @@ impl ScalarUDFImpl for RepeatFunc {
7180
7281 fn invoke ( & self , args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
7382 match args[ 0 ] . data_type ( ) {
83+ DataType :: Utf8View => make_scalar_function ( repeat_utf8view, vec ! [ ] ) ( args) ,
7484 DataType :: Utf8 => make_scalar_function ( repeat :: < i32 > , vec ! [ ] ) ( args) ,
7585 DataType :: LargeUtf8 => make_scalar_function ( repeat :: < i64 > , vec ! [ ] ) ( args) ,
76- other => exec_err ! ( "Unsupported data type {other:?} for function repeat" ) ,
86+ other => exec_err ! ( "Unsupported data type {other:?} for function repeat. Expected Utf8, Utf8View or LargeUtf8 " ) ,
7787 }
7888 }
7989}
@@ -87,18 +97,35 @@ fn repeat<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
8797 let result = string_array
8898 . iter ( )
8999 . zip ( number_array. iter ( ) )
90- . map ( |( string, number) | match ( string, number) {
91- ( Some ( string) , Some ( number) ) if number >= 0 => {
92- Some ( string. repeat ( number as usize ) )
93- }
94- ( Some ( _) , Some ( _) ) => Some ( "" . to_string ( ) ) ,
95- _ => None ,
96- } )
100+ . map ( |( string, number) | repeat_common ( string, number) )
97101 . collect :: < GenericStringArray < T > > ( ) ;
98102
99103 Ok ( Arc :: new ( result) as ArrayRef )
100104}
101105
106+ fn repeat_utf8view ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
107+ let string_view_array = as_string_view_array ( & args[ 0 ] ) ?;
108+ let number_array = as_int64_array ( & args[ 1 ] ) ?;
109+
110+ let result = string_view_array
111+ . iter ( )
112+ . zip ( number_array. iter ( ) )
113+ . map ( |( string, number) | repeat_common ( string, number) )
114+ . collect :: < StringArray > ( ) ;
115+
116+ Ok ( Arc :: new ( result) as ArrayRef )
117+ }
118+
119+ fn repeat_common ( string : Option < & str > , number : Option < i64 > ) -> Option < String > {
120+ match ( string, number) {
121+ ( Some ( string) , Some ( number) ) if number >= 0 => {
122+ Some ( string. repeat ( number as usize ) )
123+ }
124+ ( Some ( _) , Some ( _) ) => Some ( "" . to_string ( ) ) ,
125+ _ => None ,
126+ }
127+ }
128+
102129#[ cfg( test) ]
103130mod tests {
104131 use arrow:: array:: { Array , StringArray } ;
@@ -124,7 +151,6 @@ mod tests {
124151 Utf8 ,
125152 StringArray
126153 ) ;
127-
128154 test_function ! (
129155 RepeatFunc :: new( ) ,
130156 & [
@@ -148,6 +174,40 @@ mod tests {
148174 StringArray
149175 ) ;
150176
177+ test_function ! (
178+ RepeatFunc :: new( ) ,
179+ & [
180+ ColumnarValue :: Scalar ( ScalarValue :: Utf8View ( Some ( String :: from( "Pg" ) ) ) ) ,
181+ ColumnarValue :: Scalar ( ScalarValue :: Int64 ( Some ( 4 ) ) ) ,
182+ ] ,
183+ Ok ( Some ( "PgPgPgPg" ) ) ,
184+ & str ,
185+ Utf8 ,
186+ StringArray
187+ ) ;
188+ test_function ! (
189+ RepeatFunc :: new( ) ,
190+ & [
191+ ColumnarValue :: Scalar ( ScalarValue :: Utf8View ( None ) ) ,
192+ ColumnarValue :: Scalar ( ScalarValue :: Int64 ( Some ( 4 ) ) ) ,
193+ ] ,
194+ Ok ( None ) ,
195+ & str ,
196+ Utf8 ,
197+ StringArray
198+ ) ;
199+ test_function ! (
200+ RepeatFunc :: new( ) ,
201+ & [
202+ ColumnarValue :: Scalar ( ScalarValue :: Utf8View ( Some ( String :: from( "Pg" ) ) ) ) ,
203+ ColumnarValue :: Scalar ( ScalarValue :: Int64 ( None ) ) ,
204+ ] ,
205+ Ok ( None ) ,
206+ & str ,
207+ Utf8 ,
208+ StringArray
209+ ) ;
210+
151211 Ok ( ( ) )
152212 }
153213}
0 commit comments