diff --git a/vortex-array/src/variants.rs b/vortex-array/src/variants.rs index 4ce126acaf2..220f1e618f9 100644 --- a/vortex-array/src/variants.rs +++ b/vortex-array/src/variants.rs @@ -7,9 +7,13 @@ use std::sync::Arc; use vortex_dtype::{DType, ExtDType, Field, FieldInfo, FieldNames, PType}; use vortex_error::{vortex_panic, VortexError, VortexExpect as _, VortexResult}; +use vortex_scalar::Scalar; +use crate::array::ConstantArray; +use crate::compute::{invert, mask, try_cast, FilterMask}; use crate::encoding::Encoding; -use crate::{ArrayDType, ArrayData, ArrayTrait}; +use crate::validity::LogicalValidity; +use crate::{ArrayDType, ArrayData, ArrayTrait, IntoArrayData as _}; /// An Array encoding must declare which DTypes it can be downcast into. pub trait VariantsVTable { @@ -208,6 +212,161 @@ pub trait StructArrayTrait: ArrayTrait { self.names().len() } + /// Return a field's array by index, masking by the struct's validity. + /// + /// If either this array or the field array is invalid at a position, the result is invalid at + /// that position. Consequently, if either array is nullable, the result is nullable. + /// + /// # Examples + /// + /// The field of a non-nullable struct array is the same whether accessed by + /// [StructArrayTrait::field_by_idx] or [StructArrayTrait::maybe_null_field_by_idx]: + /// + /// ``` + /// use vortex_array::array::{BoolArray, PrimitiveArray, StructArray}; + /// use vortex_array::validity::{ArrayValidity, Validity}; + /// use vortex_array::variants::StructArrayTrait; + /// use vortex_array::{ArrayDType, IntoArrayData}; + /// use vortex_dtype::FieldNames; + /// + /// let original_field = PrimitiveArray::from_option_iter([ + /// Some(1), None, Some(3), None, Some(4), + /// ]).into_array(); + /// let array = StructArray::try_new( + /// FieldNames::from(["a".into()]), + /// vec![original_field], + /// 5, + /// Validity::NonNullable, + /// ).unwrap(); + /// let field = array.field_by_idx(0).unwrap().unwrap(); + /// let maybe_null_field = array.maybe_null_field_by_idx(0).unwrap(); + /// + /// assert_eq!(field.dtype(), maybe_null_field.dtype()); + /// assert!((0..field.len()).all(|i| { + /// field.is_valid(i) == maybe_null_field.is_valid(i) + /// })); + /// ``` + /// + /// When both a struct and its field are nullable, [StructArrayTrait::field_by_idx] returns the + /// intersection of the validity, which is to say: a position is valid if and only if both the + /// struct and the field are valid at that position. + /// + /// ``` + /// use vortex_array::array::{BoolArray, PrimitiveArray, StructArray}; + /// use vortex_array::compute::scalar_at; + /// use vortex_array::validity::{ArrayValidity, Validity}; + /// use vortex_array::variants::StructArrayTrait; + /// use vortex_array::{ArrayDType, IntoArrayData}; + /// use vortex_dtype::FieldNames; + /// use vortex_scalar::Scalar; + /// + /// let original_field = PrimitiveArray::from_option_iter([ + /// Some(1), None, Some(3), None, Some(5), + /// ]).into_array(); + /// let struct_validity = Validity::Array(BoolArray::from_iter([ + /// true, true, false, false, true, + /// ]).into_array()); + /// let array = StructArray::try_new( + /// FieldNames::from(["a".into()]), + /// vec![original_field], + /// 5, + /// struct_validity, + /// ).unwrap(); + /// let field = array.field_by_idx(0).unwrap().unwrap(); + /// + /// assert!(field.dtype().is_nullable()); + /// assert_eq!(scalar_at(&field, 0).unwrap(), Scalar::from(Some(1))); + /// assert!(!field.is_valid(1)); + /// assert!(!field.is_valid(2)); + /// assert!(!field.is_valid(3)); + /// assert_eq!(scalar_at(&field, 4).unwrap(), Scalar::from(Some(5))); + /// ``` + /// + /// When a field is non-nullable, but the struct is nullable, the field receives the struct's + /// validity. + /// + /// ``` + /// use vortex_array::array::{BoolArray, StructArray}; + /// use vortex_array::compute::scalar_at; + /// use vortex_array::validity::{ArrayValidity, Validity}; + /// use vortex_array::variants::StructArrayTrait; + /// use vortex_array::{ArrayDType, IntoArrayData}; + /// use vortex_buffer::buffer; + /// use vortex_dtype::FieldNames; + /// use vortex_scalar::Scalar; + /// + /// let original_field = buffer![1, 2, 3, 4, 5].into_array(); + /// let struct_validity = Validity::Array(BoolArray::from_iter([ + /// true, true, false, false, true, + /// ]).into_array()); + /// let array = StructArray::try_new( + /// FieldNames::from(["a".into()]), + /// vec![original_field], + /// 5, + /// struct_validity, + /// ).unwrap(); + /// let field = array.field_by_idx(0).unwrap().unwrap(); + /// + /// assert!(field.dtype().is_nullable()); + /// assert_eq!(scalar_at(&field, 0).unwrap(), Scalar::from(Some(1))); + /// assert_eq!(scalar_at(&field, 1).unwrap(), Scalar::from(Some(2))); + /// assert!(!field.is_valid(2)); + /// assert!(!field.is_valid(3)); + /// assert_eq!(scalar_at(&field, 4).unwrap(), Scalar::from(Some(5))); + /// ``` + fn field_by_idx(&self, idx: usize) -> VortexResult> { + let Some(maybe_null_field) = self.maybe_null_field_by_idx(idx) else { + return Ok(None); + }; + + if !self.dtype().is_nullable() { + return Ok(Some(maybe_null_field)); + } + + match self.logical_validity() { + LogicalValidity::AllValid(_) => { + let nullable_dtype = maybe_null_field.dtype().as_nullable(); + try_cast(maybe_null_field, &nullable_dtype).map(Some) + } + LogicalValidity::AllInvalid(_) => { + let nullable_dtype = maybe_null_field.dtype().as_nullable(); + + Ok(Some( + ConstantArray::new(Scalar::null(nullable_dtype), maybe_null_field.len()) + .into_array(), + )) + } + LogicalValidity::Array(is_valid) => { + mask(&maybe_null_field, FilterMask::try_from(invert(&is_valid)?)?).map(Some) + } + } + } + + /// Return a field's array by name, masking by the struct's validity. + /// + /// See also [StructArrayTrait::field_by_idx]. + fn field_by_name(&self, name: &str) -> VortexResult> { + let field_idx = self + .names() + .iter() + .position(|field_name| field_name.as_ref() == name); + + match field_idx { + None => Ok(None), + Some(field_idx) => self.field_by_idx(field_idx), + } + } + + /// Return a field's array by name or index, masking by the struct's validity. + /// + /// See also [StructArrayTrait::field_by_idx]. + fn field(&self, field: &Field) -> VortexResult> { + match field { + Field::Index(idx) => self.field_by_idx(*idx), + Field::Name(name) => self.field_by_name(name.as_ref()), + } + } + /// Return a field's array by index, ignoring struct nullability fn maybe_null_field_by_idx(&self, idx: usize) -> Option; @@ -221,6 +380,7 @@ pub trait StructArrayTrait: ArrayTrait { field_idx.and_then(|field_idx| self.maybe_null_field_by_idx(field_idx)) } + /// Return a field's array by name or index, ignoring struct nullability fn maybe_null_field(&self, field: &Field) -> Option { match field { Field::Index(idx) => self.maybe_null_field_by_idx(*idx), @@ -245,3 +405,69 @@ pub trait ExtensionArrayTrait: ArrayTrait { /// Returns the underlying [`ArrayData`], without the [`ExtDType`]. fn storage_data(&self) -> ArrayData; } + +#[cfg(test)] +mod tests { + use vortex_buffer::buffer; + use vortex_dtype::FieldNames; + use vortex_scalar::Scalar; + + use crate::array::{BoolArray, PrimitiveArray, StructArray}; + use crate::compute::scalar_at; + use crate::validity::{ArrayValidity, Validity}; + use crate::variants::StructArrayTrait; + use crate::{ArrayDType, IntoArrayData}; + + #[test] + fn test_field() { + let original_field = + PrimitiveArray::from_option_iter([Some(1), None, Some(3), None, Some(5)]).into_array(); + let array = StructArray::try_new( + FieldNames::from(["a".into()]), + vec![original_field.clone()], + 5, + Validity::NonNullable, + ) + .unwrap(); + let field = array.field_by_idx(0).unwrap().unwrap(); + let maybe_null_field = array.maybe_null_field_by_idx(0).unwrap(); + + assert_eq!(field.dtype(), maybe_null_field.dtype()); + assert!((0..field.len()).all(|i| { field.is_valid(i) == maybe_null_field.is_valid(i) })); + + let struct_validity = + Validity::Array(BoolArray::from_iter([true, true, false, false, true]).into_array()); + let array = StructArray::try_new( + FieldNames::from(["a".into()]), + vec![original_field], + 5, + struct_validity.clone(), + ) + .unwrap(); + let field = array.field_by_idx(0).unwrap().unwrap(); + + assert!(field.dtype().is_nullable()); + assert_eq!(scalar_at(&field, 0).unwrap(), Scalar::from(Some(1))); + assert!(!field.is_valid(1)); + assert!(!field.is_valid(2)); + assert!(!field.is_valid(3)); + assert_eq!(scalar_at(&field, 4).unwrap(), Scalar::from(Some(5))); + + let original_field = buffer![1, 2, 3, 4, 5].into_array(); + let array = StructArray::try_new( + FieldNames::from(["a".into()]), + vec![original_field], + 5, + struct_validity, + ) + .unwrap(); + let field = array.field_by_idx(0).unwrap().unwrap(); + + assert!(field.dtype().is_nullable()); + assert_eq!(scalar_at(&field, 0).unwrap(), Scalar::from(Some(1))); + assert_eq!(scalar_at(&field, 1).unwrap(), Scalar::from(Some(2))); + assert!(!field.is_valid(2)); + assert!(!field.is_valid(3)); + assert_eq!(scalar_at(&field, 4).unwrap(), Scalar::from(Some(5))); + } +}