Skip to content

Commit 7f363a7

Browse files
committed
Refactor RustAccumulator to support pyarrow array types and improve type checking for list types
1 parent 21906bb commit 7f363a7

File tree

2 files changed

+45
-13
lines changed

2 files changed

+45
-13
lines changed

python/tests/test_udaf.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -75,8 +75,8 @@ def merge(self, states: list[pa.Array]) -> None:
7575
if state is not None:
7676
self._values.extend(state)
7777

78-
def evaluate(self) -> pa.Array:
79-
return pa.array(self._values, type=pa.timestamp("ns"))
78+
def evaluate(self) -> pa.Scalar:
79+
return pa.scalar(self._values, type=pa.list_(pa.timestamp("ns")))
8080

8181

8282
@pytest.fixture

src/udaf.rs

Lines changed: 43 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -38,11 +38,52 @@ use crate::utils::{parse_volatility, validate_pycapsule};
3838
struct RustAccumulator {
3939
accum: Py<PyAny>,
4040
return_type: DataType,
41+
pyarrow_array_type: Option<Py<PyType>>,
42+
pyarrow_chunked_array_type: Option<Py<PyType>>,
4143
}
4244

4345
impl RustAccumulator {
4446
fn new(accum: Py<PyAny>, return_type: DataType) -> Self {
45-
Self { accum, return_type }
47+
Self {
48+
accum,
49+
return_type,
50+
pyarrow_array_type: None,
51+
pyarrow_chunked_array_type: None,
52+
}
53+
}
54+
55+
fn ensure_pyarrow_types(
56+
&mut self,
57+
py: Python<'_>,
58+
) -> PyResult<(Bound<'_, PyType>, Bound<'_, PyType>)> {
59+
if self.pyarrow_array_type.is_none() || self.pyarrow_chunked_array_type.is_none() {
60+
let pyarrow = PyModule::import(py, "pyarrow")?;
61+
let array_attr = pyarrow.getattr("Array")?;
62+
let array_type = array_attr.downcast::<PyType>()?;
63+
let chunked_array_attr = pyarrow.getattr("ChunkedArray")?;
64+
let chunked_array_type = chunked_array_attr.downcast::<PyType>()?;
65+
self.pyarrow_array_type = Some(array_type.unbind());
66+
self.pyarrow_chunked_array_type = Some(chunked_array_type.unbind());
67+
}
68+
Ok((
69+
self.pyarrow_array_type
70+
.as_ref()
71+
.expect("array type set")
72+
.bind(py),
73+
self.pyarrow_chunked_array_type
74+
.as_ref()
75+
.expect("chunked array type set")
76+
.bind(py),
77+
))
78+
}
79+
80+
fn is_pyarrow_array_like(
81+
&mut self,
82+
py: Python<'_>,
83+
value: &Bound<'_, PyAny>,
84+
) -> PyResult<bool> {
85+
let (array_type, chunked_array_type) = self.ensure_pyarrow_types(py)?;
86+
Ok(value.is_instance(&array_type)? || value.is_instance(&chunked_array_type)?)
4687
}
4788
}
4889

@@ -65,7 +106,7 @@ impl Accumulator for RustAccumulator {
65106
self.return_type,
66107
DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _)
67108
);
68-
if is_list_type && is_pyarrow_array_like(py, &value)? {
109+
if is_list_type && self.is_pyarrow_array_like(py, &value)? {
69110
let pyarrow = PyModule::import(py, "pyarrow")?;
70111
let list_value = value.call_method0("to_pylist")?;
71112
let py_type = self.return_type.to_pyarrow(py)?;
@@ -171,15 +212,6 @@ pub fn to_rust_accumulator(accum: Py<PyAny>) -> AccumulatorFactoryFunction {
171212
})
172213
}
173214

174-
fn is_pyarrow_array_like(py: Python<'_>, value: &Bound<'_, PyAny>) -> PyResult<bool> {
175-
let pyarrow = PyModule::import(py, "pyarrow")?;
176-
let array_attr = pyarrow.getattr("Array")?;
177-
let array_type = array_attr.downcast::<PyType>()?;
178-
let chunked_array_attr = pyarrow.getattr("ChunkedArray")?;
179-
let chunked_array_type = chunked_array_attr.downcast::<PyType>()?;
180-
Ok(value.is_instance(array_type)? || value.is_instance(chunked_array_type)?)
181-
}
182-
183215
fn aggregate_udf_from_capsule(capsule: &Bound<'_, PyCapsule>) -> PyDataFusionResult<AggregateUDF> {
184216
validate_pycapsule(capsule, "datafusion_aggregate_udf")?;
185217

0 commit comments

Comments
 (0)