Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions encodings/zstd/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
//! ```

pub use array::*;
use vortex_array::dtype::proto::dtype as pb;
#[cfg(feature = "unstable_encodings")]
pub use zstd_buffers::*;

Expand Down Expand Up @@ -73,4 +74,12 @@ pub struct ZstdBuffersMetadata {
/// Alignment of each buffer in bytes (must be a power of two).
#[prost(uint32, repeated, tag = "4")]
pub buffer_alignments: Vec<u32>,
/// DType of child arrays. Children belong to inner encodings, and their
/// dtypes don't persist after serialization, so we need to retrieve them
/// from metadata.
#[prost(message, repeated, tag = "5")]
pub child_dtypes: Vec<pb::DType>,
/// Length of each child array, in the same order as `child_dtypes`.
#[prost(uint64, repeated, tag = "6")]
pub child_lens: Vec<u64>,
}
61 changes: 59 additions & 2 deletions encodings/zstd/src/zstd_buffers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -419,12 +419,21 @@ impl VTable for ZstdBuffers {
array: ArrayView<'_, Self>,
_session: &VortexSession,
) -> VortexResult<Option<Vec<u8>>> {
let children: Vec<&ArrayRef> = array.slots().iter().flatten().collect();
let child_dtypes = children
.iter()
.map(|child| child.dtype().try_into())
.collect::<VortexResult<Vec<_>>>()?;
let child_lens = children.iter().map(|child| child.len() as u64).collect();

Ok(Some(
ZstdBuffersMetadata {
inner_encoding_id: array.inner_encoding_id.to_string(),
inner_metadata: array.inner_metadata.clone(),
uncompressed_sizes: array.uncompressed_sizes.clone(),
buffer_alignments: array.buffer_alignments.clone(),
child_dtypes,
child_lens,
}
.encode_to_vec(),
))
Expand All @@ -437,13 +446,23 @@ impl VTable for ZstdBuffers {
metadata: &[u8],
buffers: &[BufferHandle],
children: &dyn ArrayChildren,
_session: &VortexSession,
session: &VortexSession,
) -> VortexResult<ArrayParts<Self>> {
let metadata = ZstdBuffersMetadata::decode(metadata)?;
let compressed_buffers: Vec<BufferHandle> = buffers.to_vec();

// Children belong to the inner encoding, and serialization does not
// preserve their dtypes, so dtypes and lengths are read back from metadata.
// First verify that metadata holds exactly one dtype and one length per child.
vortex_ensure_eq!(metadata.child_dtypes.len(), children.len());
vortex_ensure_eq!(metadata.child_lens.len(), children.len());

let slots: ArraySlots = (0..children.len())
.map(|i| children.get(i, dtype, len).map(Some))
.map(|i| {
let child_dtype = DType::from_proto(&metadata.child_dtypes[i], session)?;
let child_len = usize::try_from(metadata.child_lens[i])?;
children.get(i, &child_dtype, child_len).map(Some)
})
.collect::<VortexResult<Vec<_>>>()?
.into();

Expand Down Expand Up @@ -506,6 +525,7 @@ impl ValidityVTable<ZstdBuffers> for ZstdBuffers {
#[cfg(test)]
mod tests {
use rstest::rstest;
use vortex_array::ArrayContext;
use vortex_array::ArrayRef;
use vortex_array::IntoArray;
use vortex_array::VortexSessionExecute;
Expand All @@ -516,7 +536,12 @@ mod tests {
use vortex_array::expr::stats::Precision;
use vortex_array::expr::stats::Stat;
use vortex_array::expr::stats::StatsProvider;
use vortex_array::serde::SerializeOptions;
use vortex_array::serde::SerializedArray;
use vortex_array::session::ArraySessionExt;
use vortex_buffer::ByteBufferMut;
use vortex_error::VortexResult;
use vortex_session::registry::ReadContext;

use super::*;

Expand Down Expand Up @@ -572,6 +597,38 @@ mod tests {
Ok(())
}

// Serde round-trip: compress `input` with `ZstdBuffers`, serialize it, glue the
// serialized buffers into one contiguous byte buffer, decode that buffer back
// into an array, execute it, and check the result equals the original input.
#[rstest]
#[case::primitive(make_primitive_array())]
#[case::varbinview(make_varbinview_array())]
#[case::nullable_primitive(make_nullable_primitive_array())]
#[case::nullable_varbinview(make_nullable_varbinview_array())]
#[case::empty_primitive(make_empty_primitive_array())]
#[case::inlined_varbinview(make_inlined_varbinview_array())]
fn test_serde_roundtrip(#[case] input: ArrayRef) -> VortexResult<()> {
    // The encoding has to be registered on the session before it is used below.
    let session = array_session();
    session.arrays().register(ZstdBuffers);

    let zstd_array = ZstdBuffers::compress(&input, 3, &session)?.into_array();
    // Capture the outer dtype and length up front; `decode` needs them later.
    let expected_len = zstd_array.len();
    let expected_dtype = zstd_array.dtype().clone();

    let array_ctx = ArrayContext::empty();
    let serialize_opts = SerializeOptions::default();
    let segments = zstd_array.serialize(&array_ctx, &session, &serialize_opts)?;

    // Flatten every serialized segment into a single byte buffer.
    let mut flat_bytes = ByteBufferMut::empty();
    for segment in segments {
        flat_bytes.extend_from_slice(segment.as_ref());
    }

    let serialized_array = SerializedArray::try_from(flat_bytes.freeze())?;
    let read_ctx = ReadContext::new(array_ctx.to_ids());
    let roundtripped =
        serialized_array.decode(&expected_dtype, expected_len, &read_ctx, &session)?;

    let mut exec_ctx = session.create_execution_ctx();
    let roundtripped = roundtripped.execute::<ArrayRef>(&mut exec_ctx)?;
    assert_arrays_eq!(input, roundtripped, &mut exec_ctx);
    Ok(())
}

#[test]
fn test_compress_inherits_stats() -> VortexResult<()> {
let input = make_primitive_array();
Expand Down
Loading