Skip to content

Commit b131cc3

Browse files
authored
feat: Support GetArrayStructFields expression (#993)
* Start working on GetArrayStructFIelds * Almost have working * Working * Add another test * Remove unused * Remove unused sql conf
1 parent 3413397 commit b131cc3

File tree

6 files changed

+198
-4
lines changed

6 files changed

+198
-4
lines changed

native/core/src/execution/datafusion/planner.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ use datafusion_comet_proto::{
9696
spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning},
9797
};
9898
use datafusion_comet_spark_expr::{
99-
Cast, CreateNamedStruct, DateTruncExpr, GetStructField, HourExpr, IfExpr, ListExtract,
100-
MinuteExpr, RLike, SecondExpr, TimestampTruncExpr, ToJson,
99+
Cast, CreateNamedStruct, DateTruncExpr, GetArrayStructFields, GetStructField, HourExpr, IfExpr,
100+
ListExtract, MinuteExpr, RLike, SecondExpr, TimestampTruncExpr, ToJson,
101101
};
102102
use datafusion_common::scalar::ScalarStructBuilder;
103103
use datafusion_common::{
@@ -680,6 +680,15 @@ impl PhysicalPlanner {
680680
expr.fail_on_error,
681681
)))
682682
}
683+
ExprStruct::GetArrayStructFields(expr) => {
684+
let child =
685+
self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?;
686+
687+
Ok(Arc::new(GetArrayStructFields::new(
688+
child,
689+
expr.ordinal as usize,
690+
)))
691+
}
683692
expr => Err(ExecutionError::GeneralError(format!(
684693
"Not implemented: {:?}",
685694
expr

native/proto/src/proto/expr.proto

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ message Expr {
8181
GetStructField get_struct_field = 54;
8282
ToJson to_json = 55;
8383
ListExtract list_extract = 56;
84+
GetArrayStructFields get_array_struct_fields = 57;
8485
}
8586
}
8687

@@ -517,6 +518,11 @@ message ListExtract {
517518
bool fail_on_error = 5;
518519
}
519520

521+
message GetArrayStructFields {
522+
Expr child = 1;
523+
int32 ordinal = 2;
524+
}
525+
520526
enum SortDirection {
521527
Ascending = 0;
522528
Descending = 1;

native/spark-expr/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ mod xxhash64;
3838
pub use cast::{spark_cast, Cast};
3939
pub use error::{SparkError, SparkResult};
4040
pub use if_expr::IfExpr;
41-
pub use list::ListExtract;
41+
pub use list::{GetArrayStructFields, ListExtract};
4242
pub use regexp::RLike;
4343
pub use structs::{CreateNamedStruct, GetStructField};
4444
pub use temporal::{DateTruncExpr, HourExpr, MinuteExpr, SecondExpr, TimestampTruncExpr};

native/spark-expr/src/list.rs

Lines changed: 139 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
// under the License.
1717

1818
use arrow::{array::MutableArrayData, datatypes::ArrowNativeType, record_batch::RecordBatch};
19-
use arrow_array::{Array, GenericListArray, Int32Array, OffsetSizeTrait};
19+
use arrow_array::{Array, GenericListArray, Int32Array, OffsetSizeTrait, StructArray};
2020
use arrow_schema::{DataType, FieldRef, Schema};
2121
use datafusion::logical_expr::ColumnarValue;
2222
use datafusion::physical_expr_common::physical_expr::down_cast_any_ref;
@@ -275,6 +275,144 @@ impl PartialEq<dyn Any> for ListExtract {
275275
}
276276
}
277277

278+
#[derive(Debug, Hash)]
279+
pub struct GetArrayStructFields {
280+
child: Arc<dyn PhysicalExpr>,
281+
ordinal: usize,
282+
}
283+
284+
impl GetArrayStructFields {
285+
pub fn new(child: Arc<dyn PhysicalExpr>, ordinal: usize) -> Self {
286+
Self { child, ordinal }
287+
}
288+
289+
fn list_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef> {
290+
match self.child.data_type(input_schema)? {
291+
DataType::List(field) | DataType::LargeList(field) => Ok(field),
292+
data_type => Err(DataFusionError::Internal(format!(
293+
"Unexpected data type in GetArrayStructFields: {:?}",
294+
data_type
295+
))),
296+
}
297+
}
298+
299+
fn child_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef> {
300+
match self.list_field(input_schema)?.data_type() {
301+
DataType::Struct(fields) => Ok(Arc::clone(&fields[self.ordinal])),
302+
data_type => Err(DataFusionError::Internal(format!(
303+
"Unexpected data type in GetArrayStructFields: {:?}",
304+
data_type
305+
))),
306+
}
307+
}
308+
}
309+
310+
impl PhysicalExpr for GetArrayStructFields {
311+
fn as_any(&self) -> &dyn Any {
312+
self
313+
}
314+
315+
fn data_type(&self, input_schema: &Schema) -> DataFusionResult<DataType> {
316+
let struct_field = self.child_field(input_schema)?;
317+
match self.child.data_type(input_schema)? {
318+
DataType::List(_) => Ok(DataType::List(struct_field)),
319+
DataType::LargeList(_) => Ok(DataType::LargeList(struct_field)),
320+
data_type => Err(DataFusionError::Internal(format!(
321+
"Unexpected data type in GetArrayStructFields: {:?}",
322+
data_type
323+
))),
324+
}
325+
}
326+
327+
fn nullable(&self, input_schema: &Schema) -> DataFusionResult<bool> {
328+
Ok(self.list_field(input_schema)?.is_nullable()
329+
|| self.child_field(input_schema)?.is_nullable())
330+
}
331+
332+
fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
333+
let child_value = self.child.evaluate(batch)?.into_array(batch.num_rows())?;
334+
335+
match child_value.data_type() {
336+
DataType::List(_) => {
337+
let list_array = as_list_array(&child_value)?;
338+
339+
get_array_struct_fields(list_array, self.ordinal)
340+
}
341+
DataType::LargeList(_) => {
342+
let list_array = as_large_list_array(&child_value)?;
343+
344+
get_array_struct_fields(list_array, self.ordinal)
345+
}
346+
data_type => Err(DataFusionError::Internal(format!(
347+
"Unexpected child type for ListExtract: {:?}",
348+
data_type
349+
))),
350+
}
351+
}
352+
353+
fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
354+
vec![&self.child]
355+
}
356+
357+
fn with_new_children(
358+
self: Arc<Self>,
359+
children: Vec<Arc<dyn PhysicalExpr>>,
360+
) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
361+
match children.len() {
362+
1 => Ok(Arc::new(GetArrayStructFields::new(
363+
Arc::clone(&children[0]),
364+
self.ordinal,
365+
))),
366+
_ => internal_err!("GetArrayStructFields should have exactly one child"),
367+
}
368+
}
369+
370+
fn dyn_hash(&self, state: &mut dyn Hasher) {
371+
let mut s = state;
372+
self.child.hash(&mut s);
373+
self.ordinal.hash(&mut s);
374+
self.hash(&mut s);
375+
}
376+
}
377+
378+
fn get_array_struct_fields<O: OffsetSizeTrait>(
379+
list_array: &GenericListArray<O>,
380+
ordinal: usize,
381+
) -> DataFusionResult<ColumnarValue> {
382+
let values = list_array
383+
.values()
384+
.as_any()
385+
.downcast_ref::<StructArray>()
386+
.expect("A struct is expected");
387+
388+
let column = Arc::clone(values.column(ordinal));
389+
let field = Arc::clone(&values.fields()[ordinal]);
390+
391+
let offsets = list_array.offsets();
392+
let array = GenericListArray::new(field, offsets.clone(), column, list_array.nulls().cloned());
393+
394+
Ok(ColumnarValue::Array(Arc::new(array)))
395+
}
396+
397+
impl Display for GetArrayStructFields {
398+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
399+
write!(
400+
f,
401+
"GetArrayStructFields [child: {:?}, ordinal: {:?}]",
402+
self.child, self.ordinal
403+
)
404+
}
405+
}
406+
407+
impl PartialEq<dyn Any> for GetArrayStructFields {
408+
fn eq(&self, other: &dyn Any) -> bool {
409+
down_cast_any_ref(other)
410+
.downcast_ref::<Self>()
411+
.map(|x| self.child.eq(&x.child) && self.ordinal.eq(&x.ordinal))
412+
.unwrap_or(false)
413+
}
414+
}
415+
278416
#[cfg(test)]
279417
mod test {
280418
use crate::list::{list_extract, zero_based_index};

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2542,6 +2542,25 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
25422542
None
25432543
}
25442544

2545+
case GetArrayStructFields(child, _, ordinal, _, _) =>
2546+
val childExpr = exprToProto(child, inputs, binding)
2547+
2548+
if (childExpr.isDefined) {
2549+
val arrayStructFieldsBuilder = ExprOuterClass.GetArrayStructFields
2550+
.newBuilder()
2551+
.setChild(childExpr.get)
2552+
.setOrdinal(ordinal)
2553+
2554+
Some(
2555+
ExprOuterClass.Expr
2556+
.newBuilder()
2557+
.setGetArrayStructFields(arrayStructFieldsBuilder)
2558+
.build())
2559+
} else {
2560+
withInfo(expr, "unsupported arguments for GetArrayStructFields", child)
2561+
None
2562+
}
2563+
25452564
case _ =>
25462565
withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
25472566
None

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2271,4 +2271,26 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
22712271
}
22722272
}
22732273
}
2274+
2275+
test("GetArrayStructFields") {
2276+
Seq(true, false).foreach { dictionaryEnabled =>
2277+
withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> SimplifyExtractValueOps.ruleName) {
2278+
withTempDir { dir =>
2279+
val path = new Path(dir.toURI.toString, "test.parquet")
2280+
makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000)
2281+
val df = spark.read
2282+
.parquet(path.toString)
2283+
.select(
2284+
array(struct(col("_2"), col("_3"), col("_4"), col("_8")), lit(null)).alias("arr"))
2285+
checkSparkAnswerAndOperator(df.select("arr._2", "arr._3", "arr._4"))
2286+
2287+
val complex = spark.read
2288+
.parquet(path.toString)
2289+
.select(array(struct(struct(col("_4"), col("_8")).alias("nested"))).alias("arr"))
2290+
2291+
checkSparkAnswerAndOperator(complex.select(col("arr.nested._4")))
2292+
}
2293+
}
2294+
}
2295+
}
22742296
}

0 commit comments

Comments
 (0)