Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 227 additions & 1 deletion vortex-array/src/variants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@ use std::sync::Arc;

use vortex_dtype::{DType, ExtDType, Field, FieldInfo, FieldNames, PType};
use vortex_error::{vortex_panic, VortexError, VortexExpect as _, VortexResult};
use vortex_scalar::Scalar;

use crate::array::ConstantArray;
use crate::compute::{invert, mask, try_cast, FilterMask};
use crate::encoding::Encoding;
use crate::{ArrayDType, ArrayData, ArrayTrait};
use crate::validity::LogicalValidity;
use crate::{ArrayDType, ArrayData, ArrayTrait, IntoArrayData as _};

/// An Array encoding must declare which DTypes it can be downcast into.
pub trait VariantsVTable<Array> {
Expand Down Expand Up @@ -208,6 +212,161 @@ pub trait StructArrayTrait: ArrayTrait {
self.names().len()
}

/// Return a field's array by index, masking by the struct's validity.
///
/// If either this array or the field array is invalid at a position, the result is invalid at
/// that position. Consequently, if either array is nullable, the result is nullable.
///
/// # Examples
///
/// The field of a non-nullable struct array is the same whether accessed by
/// [StructArrayTrait::field_by_idx] or [StructArrayTrait::maybe_null_field_by_idx]:
///
/// ```
/// use vortex_array::array::{BoolArray, PrimitiveArray, StructArray};
/// use vortex_array::validity::{ArrayValidity, Validity};
/// use vortex_array::variants::StructArrayTrait;
/// use vortex_array::{ArrayDType, IntoArrayData};
/// use vortex_dtype::FieldNames;
///
/// let original_field = PrimitiveArray::from_option_iter([
/// Some(1), None, Some(3), None, Some(4),
/// ]).into_array();
/// let array = StructArray::try_new(
/// FieldNames::from(["a".into()]),
/// vec![original_field],
/// 5,
/// Validity::NonNullable,
/// ).unwrap();
/// let field = array.field_by_idx(0).unwrap().unwrap();
/// let maybe_null_field = array.maybe_null_field_by_idx(0).unwrap();
///
/// assert_eq!(field.dtype(), maybe_null_field.dtype());
/// assert!((0..field.len()).all(|i| {
/// field.is_valid(i) == maybe_null_field.is_valid(i)
/// }));
/// ```
///
/// When both a struct and its field are nullable, [StructArrayTrait::field_by_idx] returns the
/// intersection of the validity, which is to say: a position is valid if and only if both the
/// struct and the field are valid at that position.
///
/// ```
/// use vortex_array::array::{BoolArray, PrimitiveArray, StructArray};
/// use vortex_array::compute::scalar_at;
/// use vortex_array::validity::{ArrayValidity, Validity};
/// use vortex_array::variants::StructArrayTrait;
/// use vortex_array::{ArrayDType, IntoArrayData};
/// use vortex_dtype::FieldNames;
/// use vortex_scalar::Scalar;
///
/// let original_field = PrimitiveArray::from_option_iter([
/// Some(1), None, Some(3), None, Some(5),
/// ]).into_array();
/// let struct_validity = Validity::Array(BoolArray::from_iter([
/// true, true, false, false, true,
/// ]).into_array());
/// let array = StructArray::try_new(
/// FieldNames::from(["a".into()]),
/// vec![original_field],
/// 5,
/// struct_validity,
/// ).unwrap();
/// let field = array.field_by_idx(0).unwrap().unwrap();
///
/// assert!(field.dtype().is_nullable());
/// assert_eq!(scalar_at(&field, 0).unwrap(), Scalar::from(Some(1)));
/// assert!(!field.is_valid(1));
/// assert!(!field.is_valid(2));
/// assert!(!field.is_valid(3));
/// assert_eq!(scalar_at(&field, 4).unwrap(), Scalar::from(Some(5)));
/// ```
///
/// When a field is non-nullable, but the struct is nullable, the field receives the struct's
/// validity.
///
/// ```
/// use vortex_array::array::{BoolArray, StructArray};
/// use vortex_array::compute::scalar_at;
/// use vortex_array::validity::{ArrayValidity, Validity};
/// use vortex_array::variants::StructArrayTrait;
/// use vortex_array::{ArrayDType, IntoArrayData};
/// use vortex_buffer::buffer;
/// use vortex_dtype::FieldNames;
/// use vortex_scalar::Scalar;
///
/// let original_field = buffer![1, 2, 3, 4, 5].into_array();
/// let struct_validity = Validity::Array(BoolArray::from_iter([
/// true, true, false, false, true,
/// ]).into_array());
/// let array = StructArray::try_new(
/// FieldNames::from(["a".into()]),
/// vec![original_field],
/// 5,
/// struct_validity,
/// ).unwrap();
/// let field = array.field_by_idx(0).unwrap().unwrap();
///
/// assert!(field.dtype().is_nullable());
/// assert_eq!(scalar_at(&field, 0).unwrap(), Scalar::from(Some(1)));
/// assert_eq!(scalar_at(&field, 1).unwrap(), Scalar::from(Some(2)));
/// assert!(!field.is_valid(2));
/// assert!(!field.is_valid(3));
/// assert_eq!(scalar_at(&field, 4).unwrap(), Scalar::from(Some(5)));
/// ```
fn field_by_idx(&self, idx: usize) -> VortexResult<Option<ArrayData>> {
let Some(maybe_null_field) = self.maybe_null_field_by_idx(idx) else {
return Ok(None);
};

if !self.dtype().is_nullable() {
return Ok(Some(maybe_null_field));
}

match self.logical_validity() {
LogicalValidity::AllValid(_) => {
let nullable_dtype = maybe_null_field.dtype().as_nullable();
try_cast(maybe_null_field, &nullable_dtype).map(Some)
}
LogicalValidity::AllInvalid(_) => {
let nullable_dtype = maybe_null_field.dtype().as_nullable();

Ok(Some(
ConstantArray::new(Scalar::null(nullable_dtype), maybe_null_field.len())
.into_array(),
))
}
LogicalValidity::Array(is_valid) => {
mask(&maybe_null_field, FilterMask::try_from(invert(&is_valid)?)?).map(Some)
}
}
}

/// Return a field's array by name, masking by the struct's validity.
///
/// See also [StructArrayTrait::field_by_idx].
fn field_by_name(&self, name: &str) -> VortexResult<Option<ArrayData>> {
let field_idx = self
.names()
.iter()
.position(|field_name| field_name.as_ref() == name);

match field_idx {
None => Ok(None),
Some(field_idx) => self.field_by_idx(field_idx),
}
}

/// Return a field's array by name or index, masking by the struct's validity.
///
/// See also [StructArrayTrait::field_by_idx].
fn field(&self, field: &Field) -> VortexResult<Option<ArrayData>> {
match field {
Field::Index(idx) => self.field_by_idx(*idx),
Field::Name(name) => self.field_by_name(name.as_ref()),
}
}

/// Return a field's array by index, ignoring struct nullability
fn maybe_null_field_by_idx(&self, idx: usize) -> Option<ArrayData>;

Expand All @@ -221,6 +380,7 @@ pub trait StructArrayTrait: ArrayTrait {
field_idx.and_then(|field_idx| self.maybe_null_field_by_idx(field_idx))
}

/// Return a field's array by name or index, ignoring struct nullability
fn maybe_null_field(&self, field: &Field) -> Option<ArrayData> {
match field {
Field::Index(idx) => self.maybe_null_field_by_idx(*idx),
Expand All @@ -245,3 +405,69 @@ pub trait ExtensionArrayTrait: ArrayTrait {
/// Returns the underlying [`ArrayData`], without the [`ExtDType`].
fn storage_data(&self) -> ArrayData;
}

#[cfg(test)]
mod tests {
use vortex_buffer::buffer;
use vortex_dtype::FieldNames;
use vortex_scalar::Scalar;

use crate::array::{BoolArray, PrimitiveArray, StructArray};
use crate::compute::scalar_at;
use crate::validity::{ArrayValidity, Validity};
use crate::variants::StructArrayTrait;
use crate::{ArrayDType, IntoArrayData};

#[test]
fn test_field() {
let original_field =
PrimitiveArray::from_option_iter([Some(1), None, Some(3), None, Some(5)]).into_array();
let array = StructArray::try_new(
FieldNames::from(["a".into()]),
vec![original_field.clone()],
5,
Validity::NonNullable,
)
.unwrap();
let field = array.field_by_idx(0).unwrap().unwrap();
let maybe_null_field = array.maybe_null_field_by_idx(0).unwrap();

assert_eq!(field.dtype(), maybe_null_field.dtype());
assert!((0..field.len()).all(|i| { field.is_valid(i) == maybe_null_field.is_valid(i) }));

let struct_validity =
Validity::Array(BoolArray::from_iter([true, true, false, false, true]).into_array());
let array = StructArray::try_new(
FieldNames::from(["a".into()]),
vec![original_field],
5,
struct_validity.clone(),
)
.unwrap();
let field = array.field_by_idx(0).unwrap().unwrap();

assert!(field.dtype().is_nullable());
assert_eq!(scalar_at(&field, 0).unwrap(), Scalar::from(Some(1)));
assert!(!field.is_valid(1));
assert!(!field.is_valid(2));
assert!(!field.is_valid(3));
assert_eq!(scalar_at(&field, 4).unwrap(), Scalar::from(Some(5)));

let original_field = buffer![1, 2, 3, 4, 5].into_array();
let array = StructArray::try_new(
FieldNames::from(["a".into()]),
vec![original_field],
5,
struct_validity,
)
.unwrap();
let field = array.field_by_idx(0).unwrap().unwrap();

assert!(field.dtype().is_nullable());
assert_eq!(scalar_at(&field, 0).unwrap(), Scalar::from(Some(1)));
assert_eq!(scalar_at(&field, 1).unwrap(), Scalar::from(Some(2)));
assert!(!field.is_valid(2));
assert!(!field.is_valid(3));
assert_eq!(scalar_at(&field, 4).unwrap(), Scalar::from(Some(5)));
}
}