diff --git a/vortex-layout/src/data.rs b/vortex-layout/src/data.rs
index 1d5b07021e5..d4385cc600b 100644
--- a/vortex-layout/src/data.rs
+++ b/vortex-layout/src/data.rs
@@ -166,7 +166,11 @@ impl LayoutData {
                     .vortex_expect("child bounds already checked")[i]
                     .clone();
                 if child.dtype() != &dtype {
-                    vortex_bail!("child dtype mismatch");
+                    vortex_bail!(
+                        "child dtype mismatch expected {:?} found {:?}",
+                        dtype,
+                        child.dtype()
+                    );
                 }
                 Ok(child)
             }
diff --git a/vortex-layout/src/layouts/struct_/eval_expr.rs b/vortex-layout/src/layouts/struct_/eval_expr.rs
index 63dddc206ba..132d1e03c2f 100644
--- a/vortex-layout/src/layouts/struct_/eval_expr.rs
+++ b/vortex-layout/src/layouts/struct_/eval_expr.rs
@@ -3,14 +3,15 @@ use futures::future::try_join_all;
 use itertools::Itertools;
 use vortex_array::array::StructArray;
 use vortex_array::validity::Validity;
-use vortex_array::{ArrayData, IntoArrayData};
+use vortex_array::{ArrayData, IntoArrayData, IntoArrayVariant};
+use vortex_dtype::Nullability;
 use vortex_error::VortexResult;
 use vortex_expr::transform::partition::partition;
-use vortex_expr::ExprRef;
+use vortex_expr::{ident, ExprRef};
 use vortex_scan::RowMask;
 
 use crate::layouts::struct_::reader::StructReader;
-use crate::ExprEvaluator;
+use crate::{ExprEvaluator, LayoutReaderExt};
 
 #[async_trait]
 impl ExprEvaluator for StructReader {
@@ -34,6 +35,17 @@ impl ExprEvaluator for StructReader {
         )
         .await?;
 
+        let validity = if self.dtype().nullability() == Nullability::Nullable {
+            let validity: ArrayData = self
+                .validity()?
+                .evaluate_expr(row_mask.clone(), ident())
+                .await?;
+            let validity_bits = validity.into_bool()?;
+            Validity::from(validity_bits.boolean_buffer())
+        } else {
+            Validity::NonNullable
+        };
+
         let row_count = row_mask.true_count();
         debug_assert!(arrays.iter().all(|a| a.len() == row_count));
 
@@ -46,7 +58,7 @@ impl ExprEvaluator for StructReader {
                 .into(),
             arrays,
             row_count,
-            Validity::NonNullable,
+            validity,
         )?
         .into_array();
 
@@ -60,10 +72,14 @@ mod tests {
     use std::sync::Arc;
 
     use futures::executor::block_on;
-    use vortex_array::array::StructArray;
+    use itertools::Itertools;
+    use vortex_array::array::{BoolArray, StructArray};
     use vortex_array::compute::FilterMask;
+    use vortex_array::validity::Validity::NonNullable;
+    use vortex_array::validity::{ArrayValidity, Validity};
     use vortex_array::{IntoArrayData, IntoArrayVariant};
     use vortex_buffer::buffer;
+    use vortex_dtype::DType::Bool;
     use vortex_dtype::PType::I32;
     use vortex_dtype::{DType, Field, Nullability, StructDType};
     use vortex_expr::{get_item, gt, ident, select};
@@ -76,7 +92,7 @@ mod tests {
     use crate::LayoutData;
 
-    /// Create a chunked layout with three chunks of primitive arrays.
-    fn struct_layout() -> (Arc, LayoutData) {
+    /// Create a struct layout with three `i32` fields and the given struct validity.
+    fn struct_layout(validity: Validity) -> (Arc, LayoutData) {
         let mut segments = TestSegments::default();
 
         let layout = StructLayoutWriter::new(
@@ -85,23 +101,26 @@ mod tests {
                     vec!["a".into(), "b".into(), "c".into()].into(),
                     vec![I32.into(), I32.into(), I32.into()],
                 ),
-                Nullability::NonNullable,
+                validity.nullability(),
             ),
             vec![
                 Box::new(FlatLayoutWriter::new(I32.into())),
                 Box::new(FlatLayoutWriter::new(I32.into())),
                 Box::new(FlatLayoutWriter::new(I32.into())),
             ],
+            Box::new(FlatLayoutWriter::new(Bool(Nullability::NonNullable))),
         )
         .push_all(
             &mut segments,
-            [StructArray::from_fields(
-                [
-                    ("a", buffer![7, 2, 3].into_array()),
-                    ("b", buffer![4, 5, 6].into_array()),
-                    ("c", buffer![4, 5, 6].into_array()),
-                ]
-                .as_slice(),
+            [StructArray::try_new(
+                ["a".into(), "b".into(), "c".into()].into(),
+                vec![
+                    buffer![7, 2, 3].into_array(),
+                    buffer![4, 5, 6].into_array(),
+                    buffer![4, 5, 6].into_array(),
+                ],
+                3,
+                validity,
             )
             .map(IntoArrayData::into_array)],
         )
@@ -111,9 +130,26 @@ mod tests {
 
     #[test]
     fn test_struct_layout() {
-        let (segments, layout) = struct_layout();
+        let (segments, layout) = struct_layout(NonNullable);
 
         let reader = layout.reader(segments, Default::default()).unwrap();
+
+        let expr = get_item("a", ident());
+        let result =
+            block_on(reader.evaluate_expr(RowMask::new_valid_between(0, 3), expr)).unwrap();
+        assert_eq!(
+            result.into_primitive().unwrap().as_slice::<i32>(),
+            [7, 2, 3].as_slice()
+        );
+
+        let expr = get_item("b", ident());
+        let result =
+            block_on(reader.evaluate_expr(RowMask::new_valid_between(0, 3), expr)).unwrap();
+        assert_eq!(
+            result.into_primitive().unwrap().as_slice::<i32>(),
+            [4, 5, 6].as_slice()
+        );
+
         let expr = gt(get_item("a", ident()), get_item("b", ident()));
         let result =
             block_on(reader.evaluate_expr(RowMask::new_valid_between(0, 3), expr)).unwrap();
@@ -130,7 +166,7 @@ mod tests {
 
     #[test]
     fn test_struct_layout_row_mask() {
-        let (segments, layout) = struct_layout();
+        let (segments, layout) = struct_layout(NonNullable);
 
         let reader = layout.reader(segments, Default::default()).unwrap();
         let expr = gt(get_item("a", ident()), get_item("b", ident()));
@@ -156,7 +192,7 @@ mod tests {
 
     #[test]
     fn test_struct_layout_select() {
-        let (segments, layout) = struct_layout();
+        let (segments, layout) = struct_layout(NonNullable);
 
         let reader = layout.reader(segments, Default::default()).unwrap();
         let expr = select(vec!["a".into(), "b".into()], ident());
@@ -193,4 +229,39 @@ mod tests {
             [4, 5].as_slice()
         );
     }
+
+    #[test]
+    fn test_struct_nullable() {
+        let (segments, layout) = struct_layout(Validity::Array(
+            BoolArray::from_iter([false, true, true]).into_array(),
+        ));
+
+        let reader = layout.reader(segments, Default::default()).unwrap();
+        let expr = get_item("a", ident());
+        let result = block_on(reader.evaluate_expr(
+            // Select all three rows so the null in row 0 is part of the result
+            RowMask::new(FilterMask::from_iter([true, true, true]), 0),
+            expr,
+        ))
+        .unwrap();
+
+        assert_eq!(result.len(), 3);
+
+        assert_eq!(
+            result
+                .logical_validity()
+                .into_array()
+                .into_bool()
+                .unwrap()
+                .boolean_buffer()
+                .iter()
+                .collect_vec(),
+            vec![false, true, true]
+        );
+
+        assert_eq!(
+            result.into_primitive().unwrap().as_slice::<i32>(),
+            [7, 2, 3].as_slice()
+        );
+    }
 }
diff --git a/vortex-layout/src/layouts/struct_/reader.rs b/vortex-layout/src/layouts/struct_/reader.rs
index dab8390906b..ec48d8056ca 100644
--- a/vortex-layout/src/layouts/struct_/reader.rs
+++ b/vortex-layout/src/layouts/struct_/reader.rs
@@ -2,6 +2,7 @@ use std::sync::{Arc, OnceLock};
 
 use vortex_array::aliases::hash_map::HashMap;
 use vortex_array::ContextRef;
+use vortex_dtype::Nullability::NonNullable;
 use vortex_dtype::{DType, FieldName, StructDType};
 use vortex_error::{vortex_err, vortex_panic, VortexExpect, VortexResult};
 
@@ -17,6 +18,9 @@ pub struct StructReader {
     segments: Arc,
 
     field_readers: Arc<[OnceLock>]>,
+
+    validity_reader: Arc>>,
+
     field_lookup: HashMap,
 }
 
@@ -51,6 +55,7 @@ impl StructReader {
             ctx,
             segments,
             field_readers,
+            validity_reader: Arc::new(OnceLock::new()),
             field_lookup,
         })
     }
@@ -73,7 +78,16 @@ impl StructReader {
         self.field_readers[idx].get_or_try_init(|| {
             let child_layout = self
                 .layout
-                .child(idx, self.struct_dtype().field_dtype(idx)?)?;
+                .child(idx + 1, self.struct_dtype().field_dtype(idx)?)?;
+            child_layout.reader(self.segments.clone(), self.ctx.clone())
+        })
+    }
+
+    /// Returns the lazily-initialized reader for the validity child, which the writer
+    /// stores as child 0, ahead of the field children.
+    pub(crate) fn validity(&self) -> VortexResult<&Arc> {
+        self.validity_reader.get_or_try_init(|| {
+            let child_layout = self.layout.child(0, DType::Bool(NonNullable))?;
             child_layout.reader(self.segments.clone(), self.ctx.clone())
         })
     }
diff --git a/vortex-layout/src/layouts/struct_/writer.rs b/vortex-layout/src/layouts/struct_/writer.rs
index dda5a3cf770..c7b9978ba6c 100644
--- a/vortex-layout/src/layouts/struct_/writer.rs
+++ b/vortex-layout/src/layouts/struct_/writer.rs
@@ -1,5 +1,5 @@
 use itertools::Itertools;
-use vortex_array::ArrayData;
-use vortex_dtype::DType;
+use vortex_array::{ArrayData, IntoArrayData};
+use vortex_dtype::{DType, Nullability};
 use vortex_error::{vortex_bail, vortex_err, vortex_panic, VortexExpect, VortexResult};
 
@@ -11,12 +11,22 @@ use crate::strategies::{LayoutStrategy, LayoutWriter};
 /// A [`LayoutWriter`] that splits a StructArray batch into child layout writers
 pub struct StructLayoutWriter {
     column_strategies: Vec>,
+    validity_strategy: Box,
     dtype: DType,
     row_count: u64,
 }
 
+// TODO: add a LayoutWriterBuilder for
+// - Add a validity child
+// - Validity `() -> LayoutWriter` and
+// - Fields `(Field) -> LayoutWriter`
+
 impl StructLayoutWriter {
-    pub fn new(dtype: DType, column_layout_writers: Vec>) -> Self {
+    pub fn new(
+        dtype: DType,
+        column_layout_writers: Vec>,
+        validity_layout_writer: Box,
+    ) -> Self {
         let struct_dtype = dtype.as_struct().vortex_expect("dtype is not a struct");
         if struct_dtype.dtypes().len() != column_layout_writers.len() {
             vortex_panic!(
@@ -25,6 +35,7 @@ impl StructLayoutWriter {
         }
         Self {
             column_strategies: column_layout_writers,
+            validity_strategy: validity_layout_writer,
             dtype,
             row_count: 0,
         }
@@ -41,6 +52,8 @@ impl StructLayoutWriter {
                 .dtypes()
                 .map(|dtype| factory.new_writer(&dtype))
                 .try_collect()?,
+            // Validity is a non-nullable bool column regardless of the struct's nullability.
+            factory.new_writer(&DType::Bool(Nullability::NonNullable))?,
         ))
     }
 }
@@ -72,11 +85,15 @@ impl LayoutWriter for StructLayoutWriter {
             self.column_strategies[i].push_chunk(segments, column)?;
         }
 
+        self.validity_strategy
+            .push_chunk(segments, struct_array.logical_validity().into_array())?;
+
         Ok(())
     }
 
     fn finish(&mut self, segments: &mut dyn SegmentWriter) -> VortexResult {
-        let mut column_layouts = vec![];
+        // The validity layout is child 0; field layouts follow at `idx + 1`.
+        let mut column_layouts = vec![self.validity_strategy.finish(segments)?];
         for writer in self.column_strategies.iter_mut() {
             column_layouts.push(writer.finish(segments)?);
         }
diff --git a/vortex-layout/src/strategies/struct_of_chunks.rs b/vortex-layout/src/strategies/struct_of_chunks.rs
index 74d2c82a226..a70995bdd28 100644
--- a/vortex-layout/src/strategies/struct_of_chunks.rs
+++ b/vortex-layout/src/strategies/struct_of_chunks.rs
@@ -28,6 +28,8 @@ impl LayoutStrategy for StructOfChunks {
                         .dtypes()
                         .map(|col_dtype| default_column_layout(&col_dtype))
                         .collect(),
+                    // Validity is stored as a non-nullable boolean column.
+                    default_column_layout(&DType::Bool(vortex_dtype::Nullability::NonNullable)),
                 )))
             }
             _ => Ok(default_column_layout(dtype)),