diff --git a/encodings/zstd/src/lib.rs b/encodings/zstd/src/lib.rs index 679e2c371fe..a178fdec142 100644 --- a/encodings/zstd/src/lib.rs +++ b/encodings/zstd/src/lib.rs @@ -22,6 +22,7 @@ //! ``` pub use array::*; +use vortex_array::dtype::proto::dtype as pb; #[cfg(feature = "unstable_encodings")] pub use zstd_buffers::*; @@ -73,4 +74,12 @@ pub struct ZstdBuffersMetadata { /// Alignment of each buffer in bytes (must be a power of two). #[prost(uint32, repeated, tag = "4")] pub buffer_alignments: Vec, + /// DType of child arrays. Children belong to inner encodings, and their + /// dtypes don't persist after serialization, so we need to retrieve them + /// from metadata. + #[prost(message, repeated, tag = "5")] + pub child_dtypes: Vec, + /// Length of each child array, ordered as "child_dtypes" + #[prost(uint64, repeated, tag = "6")] + pub child_lens: Vec, } diff --git a/encodings/zstd/src/zstd_buffers.rs b/encodings/zstd/src/zstd_buffers.rs index ed25b984974..f6cb9af586c 100644 --- a/encodings/zstd/src/zstd_buffers.rs +++ b/encodings/zstd/src/zstd_buffers.rs @@ -419,12 +419,21 @@ impl VTable for ZstdBuffers { array: ArrayView<'_, Self>, _session: &VortexSession, ) -> VortexResult>> { + let children: Vec<&ArrayRef> = array.slots().iter().flatten().collect(); + let child_dtypes = children + .iter() + .map(|child| child.dtype().try_into()) + .collect::>>()?; + let child_lens = children.iter().map(|child| child.len() as u64).collect(); + Ok(Some( ZstdBuffersMetadata { inner_encoding_id: array.inner_encoding_id.to_string(), inner_metadata: array.inner_metadata.clone(), uncompressed_sizes: array.uncompressed_sizes.clone(), buffer_alignments: array.buffer_alignments.clone(), + child_dtypes, + child_lens, } .encode_to_vec(), )) @@ -437,13 +446,23 @@ impl VTable for ZstdBuffers { metadata: &[u8], buffers: &[BufferHandle], children: &dyn ArrayChildren, - _session: &VortexSession, + session: &VortexSession, ) -> VortexResult> { let metadata = ZstdBuffersMetadata::decode(metadata)?; let compressed_buffers: Vec = buffers.to_vec(); + // Children belong to inner encodings, and serialization doesn't + // preserve their dtypes and values. Check dtypes are recovered from + // metadata. + vortex_ensure_eq!(metadata.child_dtypes.len(), children.len()); + vortex_ensure_eq!(metadata.child_lens.len(), children.len()); + let slots: ArraySlots = (0..children.len()) - .map(|i| children.get(i, dtype, len).map(Some)) + .map(|i| { + let child_dtype = DType::from_proto(&metadata.child_dtypes[i], session)?; + let child_len = usize::try_from(metadata.child_lens[i])?; + children.get(i, &child_dtype, child_len).map(Some) + }) .collect::>>()? .into(); @@ -506,6 +525,7 @@ impl ValidityVTable for ZstdBuffers { #[cfg(test)] mod tests { use rstest::rstest; + use vortex_array::ArrayContext; use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; @@ -516,7 +536,12 @@ mod tests { use vortex_array::expr::stats::Precision; use vortex_array::expr::stats::Stat; use vortex_array::expr::stats::StatsProvider; + use vortex_array::serde::SerializeOptions; + use vortex_array::serde::SerializedArray; + use vortex_array::session::ArraySessionExt; + use vortex_buffer::ByteBufferMut; use vortex_error::VortexResult; + use vortex_session::registry::ReadContext; use super::*; @@ -572,6 +597,38 @@ mod tests { Ok(()) } + #[rstest] + #[case::primitive(make_primitive_array())] + #[case::varbinview(make_varbinview_array())] + #[case::nullable_primitive(make_nullable_primitive_array())] + #[case::nullable_varbinview(make_nullable_varbinview_array())] + #[case::empty_primitive(make_empty_primitive_array())] + #[case::inlined_varbinview(make_inlined_varbinview_array())] + fn test_serde_roundtrip(#[case] input: ArrayRef) -> VortexResult<()> { + let session = array_session(); + session.arrays().register(ZstdBuffers); + + let compressed = ZstdBuffers::compress(&input, 3, &session)?.into_array(); + let dtype = compressed.dtype().clone(); + let len = compressed.len(); + + let array_ctx = ArrayContext::empty(); + let serialized = + compressed.serialize(&array_ctx, &session, &SerializeOptions::default())?; + + let mut concat = ByteBufferMut::empty(); + for buf in serialized { + concat.extend_from_slice(buf.as_ref()); + } + let parts = SerializedArray::try_from(concat.freeze())?; + let decoded = parts.decode(&dtype, len, &ReadContext::new(array_ctx.to_ids()), &session)?; + + let mut ctx = session.create_execution_ctx(); + let decoded = decoded.execute::(&mut ctx)?; + assert_arrays_eq!(input, decoded, &mut ctx); + Ok(()) + } + #[test] fn test_compress_inherits_stats() -> VortexResult<()> { let input = make_primitive_array();