@@ -38,11 +38,52 @@ use crate::utils::{parse_volatility, validate_pycapsule};
3838struct RustAccumulator {
3939 accum : Py < PyAny > ,
4040 return_type : DataType ,
41+ pyarrow_array_type : Option < Py < PyType > > ,
42+ pyarrow_chunked_array_type : Option < Py < PyType > > ,
4143}
4244
4345impl RustAccumulator {
4446 fn new ( accum : Py < PyAny > , return_type : DataType ) -> Self {
45- Self { accum, return_type }
47+ Self {
48+ accum,
49+ return_type,
50+ pyarrow_array_type : None ,
51+ pyarrow_chunked_array_type : None ,
52+ }
53+ }
54+
55+ fn ensure_pyarrow_types (
56+ & mut self ,
57+ py : Python < ' _ > ,
58+ ) -> PyResult < ( Bound < ' _ , PyType > , Bound < ' _ , PyType > ) > {
59+ if self . pyarrow_array_type . is_none ( ) || self . pyarrow_chunked_array_type . is_none ( ) {
60+ let pyarrow = PyModule :: import ( py, "pyarrow" ) ?;
61+ let array_attr = pyarrow. getattr ( "Array" ) ?;
62+ let array_type = array_attr. downcast :: < PyType > ( ) ?;
63+ let chunked_array_attr = pyarrow. getattr ( "ChunkedArray" ) ?;
64+ let chunked_array_type = chunked_array_attr. downcast :: < PyType > ( ) ?;
65+ self . pyarrow_array_type = Some ( array_type. unbind ( ) ) ;
66+ self . pyarrow_chunked_array_type = Some ( chunked_array_type. unbind ( ) ) ;
67+ }
68+ Ok ( (
69+ self . pyarrow_array_type
70+ . as_ref ( )
71+ . expect ( "array type set" )
72+ . bind ( py) ,
73+ self . pyarrow_chunked_array_type
74+ . as_ref ( )
75+ . expect ( "chunked array type set" )
76+ . bind ( py) ,
77+ ) )
78+ }
79+
80+ fn is_pyarrow_array_like (
81+ & mut self ,
82+ py : Python < ' _ > ,
83+ value : & Bound < ' _ , PyAny > ,
84+ ) -> PyResult < bool > {
85+ let ( array_type, chunked_array_type) = self . ensure_pyarrow_types ( py) ?;
86+ Ok ( value. is_instance ( & array_type) ? || value. is_instance ( & chunked_array_type) ?)
4687 }
4788}
4889
@@ -65,7 +106,7 @@ impl Accumulator for RustAccumulator {
65106 self . return_type,
66107 DataType :: List ( _) | DataType :: LargeList ( _) | DataType :: FixedSizeList ( _, _)
67108 ) ;
68- if is_list_type && is_pyarrow_array_like ( py, & value) ? {
109+ if is_list_type && self . is_pyarrow_array_like ( py, & value) ? {
69110 let pyarrow = PyModule :: import ( py, "pyarrow" ) ?;
70111 let list_value = value. call_method0 ( "to_pylist" ) ?;
71112 let py_type = self . return_type . to_pyarrow ( py) ?;
@@ -171,15 +212,6 @@ pub fn to_rust_accumulator(accum: Py<PyAny>) -> AccumulatorFactoryFunction {
171212 } )
172213}
173214
174- fn is_pyarrow_array_like ( py : Python < ' _ > , value : & Bound < ' _ , PyAny > ) -> PyResult < bool > {
175- let pyarrow = PyModule :: import ( py, "pyarrow" ) ?;
176- let array_attr = pyarrow. getattr ( "Array" ) ?;
177- let array_type = array_attr. downcast :: < PyType > ( ) ?;
178- let chunked_array_attr = pyarrow. getattr ( "ChunkedArray" ) ?;
179- let chunked_array_type = chunked_array_attr. downcast :: < PyType > ( ) ?;
180- Ok ( value. is_instance ( array_type) ? || value. is_instance ( chunked_array_type) ?)
181- }
182-
183215fn aggregate_udf_from_capsule ( capsule : & Bound < ' _ , PyCapsule > ) -> PyDataFusionResult < AggregateUDF > {
184216 validate_pycapsule ( capsule, "datafusion_aggregate_udf" ) ?;
185217
0 commit comments