Skip to content

Commit f1ddf76

Browse files
committed
Refactor align_array_dimensions
Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
1 parent 4881b5d commit f1ddf76

File tree

1 file changed

+65
-20
lines changed

1 file changed

+65
-20
lines changed

datafusion/physical-expr/src/array_expressions.rs

Lines changed: 65 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -743,35 +743,39 @@ pub fn array_prepend(args: &[ArrayRef]) -> Result<ArrayRef> {
743743
}
744744

745745
fn align_array_dimensions(args: Vec<ArrayRef>) -> Result<Vec<ArrayRef>> {
746-
// Find the maximum number of dimensions
747-
let max_ndim: u64 = (*args
748-
.iter()
749-
.map(|arr| compute_array_ndims(Some(arr.clone())))
750-
.collect::<Result<Vec<Option<u64>>>>()?
751-
.iter()
752-
.max()
753-
.unwrap())
754-
.unwrap();
746+
let mut args_ndim = vec![];
747+
for arg in args.iter() {
748+
let ndim = compute_array_ndims(Some(arg.to_owned()))?;
749+
if let Some(ndim) = ndim {
750+
args_ndim.push(ndim);
751+
} else {
752+
return internal_err!("args should not be empty");
753+
}
754+
}
755+
756+
let max_ndim = args_ndim.iter().max();
757+
let max_ndim = if let Some(max_ndim) = max_ndim {
758+
max_ndim
759+
} else {
760+
return internal_err!("args_ndim should not be empty");
761+
};
755762

756763
// Align the dimensions of the arrays
757764
let aligned_args: Result<Vec<ArrayRef>> = args
758765
.into_iter()
759-
.map(|array| {
760-
let ndim = compute_array_ndims(Some(array.clone()))?.unwrap();
766+
.zip(args_ndim.iter())
767+
.map(|(array, ndim)| {
761768
if ndim < max_ndim {
762769
let mut aligned_array = array.clone();
763770
for _ in 0..(max_ndim - ndim) {
764-
let data_type = aligned_array.as_ref().data_type().clone();
765-
let offsets: Vec<i32> =
766-
(0..downcast_arg!(aligned_array, ListArray).offsets().len())
767-
.map(|i| i as i32)
768-
.collect();
769-
let field = Arc::new(Field::new("item", data_type, true));
771+
let data_type = aligned_array.data_type().to_owned();
772+
let array_lengths = vec![1; aligned_array.len()];
773+
let offsets = OffsetBuffer::<i32>::from_lengths(array_lengths);
770774

771775
aligned_array = Arc::new(ListArray::try_new(
772-
field,
773-
OffsetBuffer::new(offsets.into()),
774-
Arc::new(aligned_array.clone()),
776+
Arc::new(Field::new("item", data_type, true)),
777+
offsets,
778+
aligned_array,
775779
None,
776780
)?)
777781
}
@@ -1946,6 +1950,47 @@ mod tests {
19461950
};
19471951
use datafusion_common::scalar::ScalarValue;
19481952

1953+
// Exercises align_array_dimensions with arguments of differing nesting depth:
// the shallower argument must come back reporting the same number of
// dimensions (per compute_array_ndims) as a reference array of the deeper depth.
#[test]
1954+
fn test_align_array_dimensions() {
1955+
// "1d" fixture with two rows: [1, 2, 3] and [4, 5].
let array1d_1 =
1956+
Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
1957+
Some(vec![Some(1), Some(2), Some(3)]),
1958+
Some(vec![Some(4), Some(5)]),
1959+
]));
1960+
// "1d" fixture with a single row: [6, 7, 8].
let array1d_2 =
1961+
Arc::new(ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
1962+
Some(vec![Some(6), Some(7), Some(8)]),
1963+
]));
1964+
1965+
// "2d" fixtures: each "1d" fixture passed once through wrap_into_list_array
// (helper defined elsewhere in the file; its exact wrapping is not shown here).
let array2d_1 = Arc::new(wrap_into_list_array(array1d_1.clone())) as ArrayRef;
1966+
let array2d_2 = Arc::new(wrap_into_list_array(array1d_2.clone())) as ArrayRef;
1967+
1968+
// Case 1: mix a "1d" and a "2d" argument; res[0] is the aligned form of array1d_1.
let res =
1969+
align_array_dimensions(vec![array1d_1.to_owned(), array2d_2.to_owned()])
1970+
.expect("should not error");
1971+
1972+
let expected = as_list_array(&array2d_1).unwrap();
1973+
let expected_dim = compute_array_ndims(Some(array2d_1.to_owned())).unwrap();
// The aligned array is asserted to be *different* from the wrapped fixture.
// NOTE(review): presumably because align_array_dimensions wraps every row in
// its own single-element list (offsets built from lengths of 1) while
// wrap_into_list_array lays the rows out differently — confirm against the helper.
1974+
assert_ne!(as_list_array(&res[0]).unwrap(), expected);
// ...but it must report the same number of dimensions as the fixture.
1975+
assert_eq!(
1976+
compute_array_ndims(Some(res[0].clone())).unwrap(),
1977+
expected_dim
1978+
);
1979+
1980+
// Case 2: same check two levels apart — a "1d" argument alongside a "3d" one
// (the "2d" fixtures wrapped once more).
let array3d_1 = Arc::new(wrap_into_list_array(array2d_1)) as ArrayRef;
1981+
let array3d_2 = wrap_into_list_array(array2d_2.to_owned());
1982+
let res = align_array_dimensions(vec![array1d_1, Arc::new(array3d_2.clone())])
1983+
.expect("should not error");
1984+
1985+
let expected = as_list_array(&array3d_1).unwrap();
1986+
let expected_dim = compute_array_ndims(Some(array3d_1.to_owned())).unwrap();
// As in case 1: contents differ from the wrapped fixture, dimension count matches.
1987+
assert_ne!(as_list_array(&res[0]).unwrap(), expected);
1988+
assert_eq!(
1989+
compute_array_ndims(Some(res[0].clone())).unwrap(),
1990+
expected_dim
1991+
);
1992+
}
1993+
19491994
#[test]
19501995
fn test_array() {
19511996
// make_array(1, 2, 3) = [1, 2, 3]

0 commit comments

Comments
 (0)