|
16 | 16 | // under the License. |
17 | 17 |
|
18 | 18 | use arrow::{array::MutableArrayData, datatypes::ArrowNativeType, record_batch::RecordBatch}; |
19 | | -use arrow_array::{Array, GenericListArray, Int32Array, OffsetSizeTrait}; |
| 19 | +use arrow_array::{Array, GenericListArray, Int32Array, OffsetSizeTrait, StructArray}; |
20 | 20 | use arrow_schema::{DataType, FieldRef, Schema}; |
21 | 21 | use datafusion::logical_expr::ColumnarValue; |
22 | 22 | use datafusion::physical_expr_common::physical_expr::down_cast_any_ref; |
@@ -275,6 +275,144 @@ impl PartialEq<dyn Any> for ListExtract { |
275 | 275 | } |
276 | 276 | } |
277 | 277 |
|
| 278 | +#[derive(Debug, Hash)] |
| 279 | +pub struct GetArrayStructFields { |
| 280 | + child: Arc<dyn PhysicalExpr>, |
| 281 | + ordinal: usize, |
| 282 | +} |
| 283 | + |
| 284 | +impl GetArrayStructFields { |
| 285 | + pub fn new(child: Arc<dyn PhysicalExpr>, ordinal: usize) -> Self { |
| 286 | + Self { child, ordinal } |
| 287 | + } |
| 288 | + |
| 289 | + fn list_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef> { |
| 290 | + match self.child.data_type(input_schema)? { |
| 291 | + DataType::List(field) | DataType::LargeList(field) => Ok(field), |
| 292 | + data_type => Err(DataFusionError::Internal(format!( |
| 293 | + "Unexpected data type in GetArrayStructFields: {:?}", |
| 294 | + data_type |
| 295 | + ))), |
| 296 | + } |
| 297 | + } |
| 298 | + |
| 299 | + fn child_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef> { |
| 300 | + match self.list_field(input_schema)?.data_type() { |
| 301 | + DataType::Struct(fields) => Ok(Arc::clone(&fields[self.ordinal])), |
| 302 | + data_type => Err(DataFusionError::Internal(format!( |
| 303 | + "Unexpected data type in GetArrayStructFields: {:?}", |
| 304 | + data_type |
| 305 | + ))), |
| 306 | + } |
| 307 | + } |
| 308 | +} |
| 309 | + |
| 310 | +impl PhysicalExpr for GetArrayStructFields { |
| 311 | + fn as_any(&self) -> &dyn Any { |
| 312 | + self |
| 313 | + } |
| 314 | + |
| 315 | + fn data_type(&self, input_schema: &Schema) -> DataFusionResult<DataType> { |
| 316 | + let struct_field = self.child_field(input_schema)?; |
| 317 | + match self.child.data_type(input_schema)? { |
| 318 | + DataType::List(_) => Ok(DataType::List(struct_field)), |
| 319 | + DataType::LargeList(_) => Ok(DataType::LargeList(struct_field)), |
| 320 | + data_type => Err(DataFusionError::Internal(format!( |
| 321 | + "Unexpected data type in GetArrayStructFields: {:?}", |
| 322 | + data_type |
| 323 | + ))), |
| 324 | + } |
| 325 | + } |
| 326 | + |
| 327 | + fn nullable(&self, input_schema: &Schema) -> DataFusionResult<bool> { |
| 328 | + Ok(self.list_field(input_schema)?.is_nullable() |
| 329 | + || self.child_field(input_schema)?.is_nullable()) |
| 330 | + } |
| 331 | + |
| 332 | + fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> { |
| 333 | + let child_value = self.child.evaluate(batch)?.into_array(batch.num_rows())?; |
| 334 | + |
| 335 | + match child_value.data_type() { |
| 336 | + DataType::List(_) => { |
| 337 | + let list_array = as_list_array(&child_value)?; |
| 338 | + |
| 339 | + get_array_struct_fields(list_array, self.ordinal) |
| 340 | + } |
| 341 | + DataType::LargeList(_) => { |
| 342 | + let list_array = as_large_list_array(&child_value)?; |
| 343 | + |
| 344 | + get_array_struct_fields(list_array, self.ordinal) |
| 345 | + } |
| 346 | + data_type => Err(DataFusionError::Internal(format!( |
| 347 | + "Unexpected child type for ListExtract: {:?}", |
| 348 | + data_type |
| 349 | + ))), |
| 350 | + } |
| 351 | + } |
| 352 | + |
| 353 | + fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> { |
| 354 | + vec![&self.child] |
| 355 | + } |
| 356 | + |
| 357 | + fn with_new_children( |
| 358 | + self: Arc<Self>, |
| 359 | + children: Vec<Arc<dyn PhysicalExpr>>, |
| 360 | + ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> { |
| 361 | + match children.len() { |
| 362 | + 1 => Ok(Arc::new(GetArrayStructFields::new( |
| 363 | + Arc::clone(&children[0]), |
| 364 | + self.ordinal, |
| 365 | + ))), |
| 366 | + _ => internal_err!("GetArrayStructFields should have exactly one child"), |
| 367 | + } |
| 368 | + } |
| 369 | + |
| 370 | + fn dyn_hash(&self, state: &mut dyn Hasher) { |
| 371 | + let mut s = state; |
| 372 | + self.child.hash(&mut s); |
| 373 | + self.ordinal.hash(&mut s); |
| 374 | + self.hash(&mut s); |
| 375 | + } |
| 376 | +} |
| 377 | + |
| 378 | +fn get_array_struct_fields<O: OffsetSizeTrait>( |
| 379 | + list_array: &GenericListArray<O>, |
| 380 | + ordinal: usize, |
| 381 | +) -> DataFusionResult<ColumnarValue> { |
| 382 | + let values = list_array |
| 383 | + .values() |
| 384 | + .as_any() |
| 385 | + .downcast_ref::<StructArray>() |
| 386 | + .expect("A struct is expected"); |
| 387 | + |
| 388 | + let column = Arc::clone(values.column(ordinal)); |
| 389 | + let field = Arc::clone(&values.fields()[ordinal]); |
| 390 | + |
| 391 | + let offsets = list_array.offsets(); |
| 392 | + let array = GenericListArray::new(field, offsets.clone(), column, list_array.nulls().cloned()); |
| 393 | + |
| 394 | + Ok(ColumnarValue::Array(Arc::new(array))) |
| 395 | +} |
| 396 | + |
| 397 | +impl Display for GetArrayStructFields { |
| 398 | + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { |
| 399 | + write!( |
| 400 | + f, |
| 401 | + "GetArrayStructFields [child: {:?}, ordinal: {:?}]", |
| 402 | + self.child, self.ordinal |
| 403 | + ) |
| 404 | + } |
| 405 | +} |
| 406 | + |
| 407 | +impl PartialEq<dyn Any> for GetArrayStructFields { |
| 408 | + fn eq(&self, other: &dyn Any) -> bool { |
| 409 | + down_cast_any_ref(other) |
| 410 | + .downcast_ref::<Self>() |
| 411 | + .map(|x| self.child.eq(&x.child) && self.ordinal.eq(&x.ordinal)) |
| 412 | + .unwrap_or(false) |
| 413 | + } |
| 414 | +} |
| 415 | + |
278 | 416 | #[cfg(test)] |
279 | 417 | mod test { |
280 | 418 | use crate::list::{list_extract, zero_based_index}; |
|
0 commit comments