Skip to content

Commit aefee03

Browse files
authored
Replace macro with function for array_repeat (#8071)
* General array repeat Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * cleanup Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * cleanup Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * cleanup Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * add test Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * add test Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * done Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * remove test Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * add comment Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * fm Signed-off-by: jayzhan211 <jayzhan211@gmail.com> --------- Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
1 parent 3446382 commit aefee03

File tree

2 files changed

+169
-241
lines changed

2 files changed

+169
-241
lines changed

datafusion/physical-expr/src/array_expressions.rs

Lines changed: 126 additions & 186 deletions
Original file line numberDiff line numberDiff line change
@@ -841,125 +841,6 @@ pub fn array_concat(args: &[ArrayRef]) -> Result<ArrayRef> {
841841
concat_internal(new_args.as_slice())
842842
}
843843

844-
macro_rules! general_repeat {
845-
($ELEMENT:expr, $COUNT:expr, $ARRAY_TYPE:ident) => {{
846-
let mut offsets: Vec<i32> = vec![0];
847-
let mut values =
848-
downcast_arg!(new_empty_array($ELEMENT.data_type()), $ARRAY_TYPE).clone();
849-
850-
let element_array = downcast_arg!($ELEMENT, $ARRAY_TYPE);
851-
for (el, c) in element_array.iter().zip($COUNT.iter()) {
852-
let last_offset: i32 = offsets.last().copied().ok_or_else(|| {
853-
DataFusionError::Internal(format!("offsets should not be empty"))
854-
})?;
855-
match el {
856-
Some(el) => {
857-
let c = if c < Some(0) { 0 } else { c.unwrap() } as usize;
858-
let repeated_array =
859-
[Some(el.clone())].repeat(c).iter().collect::<$ARRAY_TYPE>();
860-
861-
values = downcast_arg!(
862-
compute::concat(&[&values, &repeated_array])?.clone(),
863-
$ARRAY_TYPE
864-
)
865-
.clone();
866-
offsets.push(last_offset + repeated_array.len() as i32);
867-
}
868-
None => {
869-
offsets.push(last_offset);
870-
}
871-
}
872-
}
873-
874-
let field = Arc::new(Field::new("item", $ELEMENT.data_type().clone(), true));
875-
876-
Arc::new(ListArray::try_new(
877-
field,
878-
OffsetBuffer::new(offsets.into()),
879-
Arc::new(values),
880-
None,
881-
)?)
882-
}};
883-
}
884-
885-
macro_rules! general_repeat_list {
886-
($ELEMENT:expr, $COUNT:expr, $ARRAY_TYPE:ident) => {{
887-
let mut offsets: Vec<i32> = vec![0];
888-
let mut values =
889-
downcast_arg!(new_empty_array($ELEMENT.data_type()), ListArray).clone();
890-
891-
let element_array = downcast_arg!($ELEMENT, ListArray);
892-
for (el, c) in element_array.iter().zip($COUNT.iter()) {
893-
let last_offset: i32 = offsets.last().copied().ok_or_else(|| {
894-
DataFusionError::Internal(format!("offsets should not be empty"))
895-
})?;
896-
match el {
897-
Some(el) => {
898-
let c = if c < Some(0) { 0 } else { c.unwrap() } as usize;
899-
let repeated_vec = vec![el; c];
900-
901-
let mut i: i32 = 0;
902-
let mut repeated_offsets = vec![i];
903-
repeated_offsets.extend(
904-
repeated_vec
905-
.clone()
906-
.into_iter()
907-
.map(|a| {
908-
i += a.len() as i32;
909-
i
910-
})
911-
.collect::<Vec<_>>(),
912-
);
913-
914-
let mut repeated_values = downcast_arg!(
915-
new_empty_array(&element_array.value_type()),
916-
$ARRAY_TYPE
917-
)
918-
.clone();
919-
for repeated_list in repeated_vec {
920-
repeated_values = downcast_arg!(
921-
compute::concat(&[&repeated_values, &repeated_list])?,
922-
$ARRAY_TYPE
923-
)
924-
.clone();
925-
}
926-
927-
let field = Arc::new(Field::new(
928-
"item",
929-
element_array.value_type().clone(),
930-
true,
931-
));
932-
let repeated_array = ListArray::try_new(
933-
field,
934-
OffsetBuffer::new(repeated_offsets.clone().into()),
935-
Arc::new(repeated_values),
936-
None,
937-
)?;
938-
939-
values = downcast_arg!(
940-
compute::concat(&[&values, &repeated_array,])?.clone(),
941-
ListArray
942-
)
943-
.clone();
944-
offsets.push(last_offset + repeated_array.len() as i32);
945-
}
946-
None => {
947-
offsets.push(last_offset);
948-
}
949-
}
950-
}
951-
952-
let field = Arc::new(Field::new("item", $ELEMENT.data_type().clone(), true));
953-
954-
Arc::new(ListArray::try_new(
955-
field,
956-
OffsetBuffer::new(offsets.into()),
957-
Arc::new(values),
958-
None,
959-
)?)
960-
}};
961-
}
962-
963844
/// Array_empty SQL function
964845
pub fn array_empty(args: &[ArrayRef]) -> Result<ArrayRef> {
965846
if args[0].as_any().downcast_ref::<NullArray>().is_some() {
@@ -978,28 +859,136 @@ pub fn array_empty(args: &[ArrayRef]) -> Result<ArrayRef> {
978859
/// Array_repeat SQL function
979860
pub fn array_repeat(args: &[ArrayRef]) -> Result<ArrayRef> {
980861
let element = &args[0];
981-
let count = as_int64_array(&args[1])?;
862+
let count_array = as_int64_array(&args[1])?;
982863

983-
let res = match element.data_type() {
984-
DataType::List(field) => {
985-
macro_rules! array_function {
986-
($ARRAY_TYPE:ident) => {
987-
general_repeat_list!(element, count, $ARRAY_TYPE)
988-
};
989-
}
990-
call_array_function!(field.data_type(), true)
864+
match element.data_type() {
865+
DataType::List(_) => {
866+
let list_array = as_list_array(element)?;
867+
general_list_repeat(list_array, count_array)
991868
}
992-
data_type => {
993-
macro_rules! array_function {
994-
($ARRAY_TYPE:ident) => {
995-
general_repeat!(element, count, $ARRAY_TYPE)
996-
};
869+
_ => general_repeat(element, count_array),
870+
}
871+
}
872+
873+
/// For each element of `array[i]` repeat `count_array[i]` times.
874+
///
875+
/// Assumption for the input:
876+
/// 1. `count[i] >= 0`
877+
/// 2. `array.len() == count_array.len()`
878+
///
879+
/// For example,
880+
/// ```text
881+
/// array_repeat(
882+
/// [1, 2, 3], [2, 0, 1] => [[1, 1], [], [3]]
883+
/// )
884+
/// ```
885+
fn general_repeat(array: &ArrayRef, count_array: &Int64Array) -> Result<ArrayRef> {
886+
let data_type = array.data_type();
887+
let mut new_values = vec![];
888+
889+
let count_vec = count_array
890+
.values()
891+
.to_vec()
892+
.iter()
893+
.map(|x| *x as usize)
894+
.collect::<Vec<_>>();
895+
896+
for (row_index, &count) in count_vec.iter().enumerate() {
897+
let repeated_array = if array.is_null(row_index) {
898+
new_null_array(data_type, count)
899+
} else {
900+
let original_data = array.to_data();
901+
let capacity = Capacities::Array(count);
902+
let mut mutable =
903+
MutableArrayData::with_capacities(vec![&original_data], false, capacity);
904+
905+
for _ in 0..count {
906+
mutable.extend(0, row_index, row_index + 1);
997907
}
998-
call_array_function!(data_type, false)
999-
}
1000-
};
1001908

1002-
Ok(res)
909+
let data = mutable.freeze();
910+
arrow_array::make_array(data)
911+
};
912+
new_values.push(repeated_array);
913+
}
914+
915+
let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect();
916+
let values = arrow::compute::concat(&new_values)?;
917+
918+
Ok(Arc::new(ListArray::try_new(
919+
Arc::new(Field::new("item", data_type.to_owned(), true)),
920+
OffsetBuffer::from_lengths(count_vec),
921+
values,
922+
None,
923+
)?))
924+
}
925+
926+
/// Handle List version of `general_repeat`
927+
///
928+
/// For each element of `list_array[i]` repeat `count_array[i]` times.
929+
///
930+
/// For example,
931+
/// ```text
932+
/// array_repeat(
933+
/// [[1, 2, 3], [4, 5], [6]], [2, 0, 1] => [[[1, 2, 3], [1, 2, 3]], [], [[6]]]
934+
/// )
935+
/// ```
936+
fn general_list_repeat(
937+
list_array: &ListArray,
938+
count_array: &Int64Array,
939+
) -> Result<ArrayRef> {
940+
let data_type = list_array.data_type();
941+
let value_type = list_array.value_type();
942+
let mut new_values = vec![];
943+
944+
let count_vec = count_array
945+
.values()
946+
.to_vec()
947+
.iter()
948+
.map(|x| *x as usize)
949+
.collect::<Vec<_>>();
950+
951+
for (list_array_row, &count) in list_array.iter().zip(count_vec.iter()) {
952+
let list_arr = match list_array_row {
953+
Some(list_array_row) => {
954+
let original_data = list_array_row.to_data();
955+
let capacity = Capacities::Array(original_data.len() * count);
956+
let mut mutable = MutableArrayData::with_capacities(
957+
vec![&original_data],
958+
false,
959+
capacity,
960+
);
961+
962+
for _ in 0..count {
963+
mutable.extend(0, 0, original_data.len());
964+
}
965+
966+
let data = mutable.freeze();
967+
let repeated_array = arrow_array::make_array(data);
968+
969+
let list_arr = ListArray::try_new(
970+
Arc::new(Field::new("item", value_type.clone(), true)),
971+
OffsetBuffer::from_lengths(vec![original_data.len(); count]),
972+
repeated_array,
973+
None,
974+
)?;
975+
Arc::new(list_arr) as ArrayRef
976+
}
977+
None => new_null_array(data_type, count),
978+
};
979+
new_values.push(list_arr);
980+
}
981+
982+
let lengths = new_values.iter().map(|a| a.len()).collect::<Vec<_>>();
983+
let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect();
984+
let values = arrow::compute::concat(&new_values)?;
985+
986+
Ok(Arc::new(ListArray::try_new(
987+
Arc::new(Field::new("item", data_type.to_owned(), true)),
988+
OffsetBuffer::from_lengths(lengths),
989+
values,
990+
None,
991+
)?))
1003992
}
1004993

1005994
macro_rules! position {
@@ -2925,55 +2914,6 @@ mod tests {
29252914
);
29262915
}
29272916

2928-
#[test]
2929-
fn test_array_repeat() {
2930-
// array_repeat(3, 5) = [3, 3, 3, 3, 3]
2931-
let array = array_repeat(&[
2932-
Arc::new(Int64Array::from_value(3, 1)),
2933-
Arc::new(Int64Array::from_value(5, 1)),
2934-
])
2935-
.expect("failed to initialize function array_repeat");
2936-
let result =
2937-
as_list_array(&array).expect("failed to initialize function array_repeat");
2938-
2939-
assert_eq!(result.len(), 1);
2940-
assert_eq!(
2941-
&[3, 3, 3, 3, 3],
2942-
result
2943-
.value(0)
2944-
.as_any()
2945-
.downcast_ref::<Int64Array>()
2946-
.unwrap()
2947-
.values()
2948-
);
2949-
}
2950-
2951-
#[test]
2952-
fn test_nested_array_repeat() {
2953-
// array_repeat([1, 2, 3, 4], 3) = [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]]
2954-
let element = return_array();
2955-
let array = array_repeat(&[element, Arc::new(Int64Array::from_value(3, 1))])
2956-
.expect("failed to initialize function array_repeat");
2957-
let result =
2958-
as_list_array(&array).expect("failed to initialize function array_repeat");
2959-
2960-
assert_eq!(result.len(), 1);
2961-
let data = vec![
2962-
Some(vec![Some(1), Some(2), Some(3), Some(4)]),
2963-
Some(vec![Some(1), Some(2), Some(3), Some(4)]),
2964-
Some(vec![Some(1), Some(2), Some(3), Some(4)]),
2965-
];
2966-
let expected = ListArray::from_iter_primitive::<Int64Type, _, _>(data);
2967-
assert_eq!(
2968-
expected,
2969-
result
2970-
.value(0)
2971-
.as_any()
2972-
.downcast_ref::<ListArray>()
2973-
.unwrap()
2974-
.clone()
2975-
);
2976-
}
29772917
#[test]
29782918
fn test_array_to_string() {
29792919
// array_to_string([1, 2, 3, 4], ',') = 1,2,3,4

0 commit comments

Comments
 (0)