1717
1818use crate :: ScalarFunctionExpr ;
1919use arrow:: array:: { make_array, MutableArrayData , RecordBatch } ;
20- use arrow:: datatypes:: { DataType , Field , Schema } ;
20+ use arrow:: datatypes:: { DataType , Field , FieldRef , Schema } ;
2121use datafusion_common:: config:: ConfigOptions ;
2222use datafusion_common:: Result ;
2323use datafusion_common:: { internal_err, not_impl_err} ;
24- use datafusion_expr:: async_udf:: { AsyncScalarFunctionArgs , AsyncScalarUDF } ;
24+ use datafusion_expr:: async_udf:: AsyncScalarUDF ;
25+ use datafusion_expr:: ScalarFunctionArgs ;
2526use datafusion_expr_common:: columnar_value:: ColumnarValue ;
2627use datafusion_physical_expr_common:: physical_expr:: PhysicalExpr ;
2728use std:: any:: Any ;
@@ -36,6 +37,8 @@ pub struct AsyncFuncExpr {
3637 pub name : String ,
3738 /// The actual function (always `ScalarFunctionExpr`)
3839 pub func : Arc < dyn PhysicalExpr > ,
40+ /// The field that this function will return
41+ return_field : FieldRef ,
3942}
4043
4144impl Display for AsyncFuncExpr {
@@ -59,17 +62,23 @@ impl Hash for AsyncFuncExpr {
5962
6063impl AsyncFuncExpr {
6164 /// create a new AsyncFuncExpr
62- pub fn try_new ( name : impl Into < String > , func : Arc < dyn PhysicalExpr > ) -> Result < Self > {
65+ pub fn try_new (
66+ name : impl Into < String > ,
67+ func : Arc < dyn PhysicalExpr > ,
68+ schema : & Schema ,
69+ ) -> Result < Self > {
6370 let Some ( _) = func. as_any ( ) . downcast_ref :: < ScalarFunctionExpr > ( ) else {
6471 return internal_err ! (
6572 "unexpected function type, expected ScalarFunctionExpr, got: {:?}" ,
6673 func
6774 ) ;
6875 } ;
6976
77+ let return_field = func. return_field ( schema) ?;
7078 Ok ( Self {
7179 name : name. into ( ) ,
7280 func,
81+ return_field,
7382 } )
7483 }
7584
@@ -128,6 +137,12 @@ impl AsyncFuncExpr {
128137 ) ;
129138 } ;
130139
140+ let arg_fields = scalar_function_expr
141+ . args ( )
142+ . iter ( )
143+ . map ( |e| e. return_field ( batch. schema_ref ( ) ) )
144+ . collect :: < Result < Vec < _ > > > ( ) ?;
145+
131146 let mut result_batches = vec ! [ ] ;
132147 if let Some ( ideal_batch_size) = self . ideal_batch_size ( ) ? {
133148 let mut remainder = batch. clone ( ) ;
@@ -148,10 +163,11 @@ impl AsyncFuncExpr {
148163 result_batches. push (
149164 async_udf
150165 . invoke_async_with_args (
151- AsyncScalarFunctionArgs {
152- args : args. to_vec ( ) ,
166+ ScalarFunctionArgs {
167+ args,
168+ arg_fields : arg_fields. clone ( ) ,
153169 number_rows : current_batch. num_rows ( ) ,
154- schema : current_batch . schema ( ) ,
170+ return_field : Arc :: clone ( & self . return_field ) ,
155171 } ,
156172 option,
157173 )
@@ -168,10 +184,11 @@ impl AsyncFuncExpr {
168184 result_batches. push (
169185 async_udf
170186 . invoke_async_with_args (
171- AsyncScalarFunctionArgs {
187+ ScalarFunctionArgs {
172188 args : args. to_vec ( ) ,
189+ arg_fields,
173190 number_rows : batch. num_rows ( ) ,
174- schema : batch . schema ( ) ,
191+ return_field : Arc :: clone ( & self . return_field ) ,
175192 } ,
176193 option,
177194 )
@@ -223,6 +240,7 @@ impl PhysicalExpr for AsyncFuncExpr {
223240 Ok ( Arc :: new ( AsyncFuncExpr {
224241 name : self . name . clone ( ) ,
225242 func : new_func,
243+ return_field : Arc :: clone ( & self . return_field ) ,
226244 } ) )
227245 }
228246
0 commit comments