diff --git a/Cargo.lock b/Cargo.lock index ec89bf1161f..efc977abf08 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3272,9 +3272,9 @@ dependencies = [ [[package]] name = "fsst-rs" -version = "0.5.10" +version = "0.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3bf53d7c403a2b76873d4d66ba7d79c54bde2784cdaba6083f223d6e33270708" +checksum = "9b13ac798afc0d9194eb4efefef8b9332efbd80b43f302a968cb8cb23b9d5360" dependencies = [ "rustc-hash", ] diff --git a/Cargo.toml b/Cargo.toml index 9700d8d78ed..e3c3cbae67e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -154,7 +154,7 @@ enum-iterator = "2.0.0" env_logger = "0.11" fastlanes = "0.5" flatbuffers = "25.2.10" -fsst-rs = "0.5.5" +fsst-rs = "0.5.11" futures = { version = "0.3.31", default-features = false } fuzzy-matcher = "0.3" get_dir = "0.5.0" @@ -226,7 +226,7 @@ reqwest = { version = "0.13.0", features = [ roaring = "0.11.0" rstest = "0.26.1" rstest_reuse = "0.7.0" -rustc-hash = "2.1" +rustc-hash = "2.1.1" serde = "1.0.220" serde_json = "1.0.138" serde_test = "1.0.176" diff --git a/encodings/fsst/src/array.rs b/encodings/fsst/src/array.rs index 617908e94dd..2019a6c73d4 100644 --- a/encodings/fsst/src/array.rs +++ b/encodings/fsst/src/array.rs @@ -6,7 +6,7 @@ use std::fmt::Display; use std::fmt::Formatter; use std::hash::Hasher; use std::sync::Arc; -use std::sync::LazyLock; +use std::sync::OnceLock; use fsst::Compressor; use fsst::Decompressor; @@ -81,18 +81,23 @@ impl FSSTMetadata { impl ArrayHash for FSSTData { fn array_hash(&self, state: &mut H, precision: Precision) { - self.symbols.array_hash(state, precision); - self.symbol_lengths.array_hash(state, precision); + self.symbol_table.symbols.array_hash(state, precision); + self.symbol_table + .symbol_lengths + .array_hash(state, precision); self.codes_bytes.as_host().array_hash(state, precision); } } impl ArrayEq for FSSTData { fn array_eq(&self, other: &Self, precision: Precision) -> bool { - self.symbols.array_eq(&other.symbols, precision) + self.symbol_table + .symbols + .array_eq(&other.symbol_table.symbols, precision) && self + .symbol_table .symbol_lengths - .array_eq(&other.symbol_lengths, precision) + .array_eq(&other.symbol_table.symbol_lengths, precision) && self .codes_bytes .as_host() @@ -346,28 +351,29 @@ pub(crate) const SLOT_NAMES: [&str; NUM_SLOTS] = /// [`FSSTArrayExt::codes()`], combining this buffer with the offsets/validity from slots. #[derive(Clone)] pub struct FSSTData { - symbols: Buffer, - symbol_lengths: Buffer, + symbol_table: Arc, /// The raw compressed codes bytes, equivalent to `VarBinData::bytes`. codes_bytes: BufferHandle, /// Cached length (number of elements). len: usize, - - /// Memoized compressor used for push-down of compute by compressing the RHS. - compressor: Arc Compressor + Send>>>, } impl Display for FSSTData { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "len: {}, nsymbols: {}", self.len, self.symbols.len()) + write!( + f, + "len: {}, nsymbols: {}", + self.len, + self.symbol_table.symbols.len() + ) } } impl Debug for FSSTData { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_struct("FSSTArray") - .field("symbols", &self.symbols) - .field("symbol_lengths", &self.symbol_lengths) + .field("symbols", &self.symbol_table.symbols) + .field("symbol_lengths", &self.symbol_table.symbol_lengths) .field("codes_bytes_len", &self.codes_bytes.len()) .field("len", &self.len) .field("uncompressed_lengths", &"") @@ -377,6 +383,29 @@ impl Debug for FSSTData { } } +pub(crate) struct FSSTSymbolTable { + symbols: Buffer, + symbol_lengths: Buffer, + /// Memoized compressor used for push-down of compute by compressing the RHS. + compressor: OnceLock, +} + +impl FSSTSymbolTable { + fn new(symbols: Buffer, symbol_lengths: Buffer) -> Self { + Self { + symbols, + symbol_lengths, + compressor: OnceLock::new(), + } + } + + fn compressor(&self) -> &Compressor { + self.compressor.get_or_init(|| { + Compressor::rebuild_from(self.symbols.as_slice(), self.symbol_lengths.as_slice()) + }) + } +} + #[derive(Clone, Debug)] pub struct FSST; @@ -412,6 +441,32 @@ impl FSST { }) } + pub(crate) fn try_new_with_symbol_table( + dtype: DType, + symbol_table: Arc, + codes: VarBinArray, + uncompressed_lengths: ArrayRef, + ctx: &mut ExecutionCtx, + ) -> VortexResult { + let len = codes.len(); + FSSTData::validate_parts_from_codes( + &symbol_table.symbols, + &symbol_table.symbol_lengths, + &codes, + &uncompressed_lengths, + &dtype, + len, + ctx, + )?; + let slots = FSSTData::make_slots(&codes, &uncompressed_lengths); + let codes_bytes = codes.bytes_handle().clone(); + let data = + unsafe { FSSTData::new_unchecked_with_symbol_table(symbol_table, codes_bytes, len) }; + Ok(unsafe { + Array::from_parts_unchecked(ArrayParts::new(FSST, dtype, len, data).with_slots(slots)) + }) + } + /// Legacy deserialization path (2 buffers): the codes were stored as a full /// `VarBinArray` child. We decompose the VarBinArray into its bytes (stored in /// FSSTData) and offsets/validity (stored in slots). @@ -463,17 +518,17 @@ impl FSST { Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data).with_slots(slots)) } - pub(crate) unsafe fn new_unchecked( + pub(crate) unsafe fn new_unchecked_with_symbol_table( dtype: DType, - symbols: Buffer, - symbol_lengths: Buffer, + symbol_table: Arc, codes: VarBinArray, uncompressed_lengths: ArrayRef, ) -> FSSTArray { let len = codes.len(); let slots = FSSTData::make_slots(&codes, &uncompressed_lengths); let codes_bytes = codes.bytes_handle().clone(); - let data = unsafe { FSSTData::new_unchecked(symbols, symbol_lengths, codes_bytes, len) }; + let data = + unsafe { FSSTData::new_unchecked_with_symbol_table(symbol_table, codes_bytes, len) }; unsafe { Array::from_parts_unchecked(ArrayParts::new(FSST, dtype, len, data).with_slots(slots)) } @@ -534,8 +589,8 @@ impl FSSTData { .as_ref() .vortex_expect("FSSTArray codes_offsets slot"); Self::validate_parts( - &self.symbols, - &self.symbol_lengths, + &self.symbol_table.symbols, + &self.symbol_table.symbol_lengths, &self.codes_bytes, codes_offsets, dtype.nullability(), @@ -567,10 +622,13 @@ impl FSSTData { if symbols.len() > 255 { vortex_bail!(InvalidArgument: "symbols array must have length <= 255"); } + if symbols.len() != symbol_lengths.len() { vortex_bail!(InvalidArgument: "symbols and symbol_lengths arrays must have same length"); } + Self::validate_symbol_lengths(symbol_lengths.as_slice())?; + // codes_offsets.len() - 1 == number of elements let codes_len = codes_offsets.len().saturating_sub(1); if codes_len != len { @@ -612,6 +670,32 @@ impl FSSTData { Ok(()) } + fn validate_symbol_lengths(symbol_lengths: &[u8]) -> VortexResult<()> { + let mut expected = 2; + for (idx, &len) in symbol_lengths.iter().enumerate() { + if len > 8 || len == 0 { + vortex_bail!(InvalidArgument: "symbol length at index {idx} must be between 1 and 8, found {len}"); + } + + if expected == 1 { + if len != 1 { + vortex_bail!(InvalidArgument: "symbol length at index {idx} must be 1 after one-byte symbols begin, found {len}"); + } + } else { + if len == 1 { + expected = 1; + } + + if len < expected { + vortex_bail!(InvalidArgument: "symbol length at index {idx} violates FSST symbol table ordering"); + } + expected = len; + } + } + + Ok(()) + } + /// Validate using a VarBinArray for the codes (convenience for construction paths). fn validate_parts_from_codes( symbols: &Buffer, @@ -641,18 +725,19 @@ impl FSSTData { codes_bytes: BufferHandle, len: usize, ) -> Self { - let symbols2 = symbols.clone(); - let symbol_lengths2 = symbol_lengths.clone(); - let compressor = Arc::new(LazyLock::new(Box::new(move || { - Compressor::rebuild_from(symbols2.as_slice(), symbol_lengths2.as_slice()) - }) - as Box Compressor + Send>)); + let symbol_table = Arc::new(FSSTSymbolTable::new(symbols, symbol_lengths)); + unsafe { Self::new_unchecked_with_symbol_table(symbol_table, codes_bytes, len) } + } + + pub(crate) unsafe fn new_unchecked_with_symbol_table( + symbol_table: Arc, + codes_bytes: BufferHandle, + len: usize, + ) -> Self { Self { - symbols, - symbol_lengths, + symbol_table, codes_bytes, len, - compressor, } } @@ -668,12 +753,16 @@ impl FSSTData { /// Access the symbol table array. pub fn symbols(&self) -> &Buffer { - &self.symbols + &self.symbol_table.symbols } /// Access the symbol lengths array. pub fn symbol_lengths(&self) -> &Buffer { - &self.symbol_lengths + &self.symbol_table.symbol_lengths + } + + pub(crate) fn symbol_table(&self) -> Arc { + Arc::clone(&self.symbol_table) } /// Access the compressed codes bytes buffer handle (may be on host or device). @@ -694,7 +783,7 @@ impl FSSTData { /// Retrieves the FSST compressor. pub fn compressor(&self) -> &Compressor { - self.compressor.as_ref() + self.symbol_table.compressor() } } @@ -771,12 +860,48 @@ mod test { use vortex_array::test_harness::check_metadata; use vortex_buffer::Buffer; use vortex_error::VortexError; + use vortex_error::VortexResult; + use vortex_error::vortex_err; use crate::FSST; use crate::array::FSSTArrayExt; use crate::array::FSSTMetadata; use crate::fsst_compress_iter; + #[test] + fn slice_reuses_initialized_compressor() -> VortexResult<()> { + let symbols = Buffer::::copy_from([ + Symbol::from_slice(b"abc00000"), + Symbol::from_slice(b"defghijk"), + ]); + let symbol_lengths = Buffer::::copy_from([3, 8]); + + let compressor = Compressor::rebuild_from(symbols.as_slice(), symbol_lengths.as_slice()); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let fsst_array = fsst_compress_iter( + [ + Some(b"abcabcab".as_ref()), + Some(b"defghijk".as_ref()), + Some(b"abcxyz".as_ref()), + ] + .into_iter(), + 3, + DType::Utf8(Nullability::NonNullable), + &compressor, + &mut ctx, + ); + + let compressor_ptr = fsst_array.compressor() as *const Compressor; + let sliced = fsst_array + .slice(1..3)? + .try_downcast::() + .map_err(|_| vortex_err!("slice must return an FSST array"))?; + let sliced_compressor_ptr = sliced.compressor() as *const Compressor; + + assert_eq!(compressor_ptr, sliced_compressor_ptr); + Ok(()) + } + #[cfg_attr(miri, ignore)] #[test] fn test_fsst_metadata() { diff --git a/encodings/fsst/src/compute/cast.rs b/encodings/fsst/src/compute/cast.rs index a1c96363ba0..47c324fc2a4 100644 --- a/encodings/fsst/src/compute/cast.rs +++ b/encodings/fsst/src/compute/cast.rs @@ -30,10 +30,9 @@ fn build_with_codes_validity( )?; Ok(unsafe { - FSST::new_unchecked( + FSST::new_unchecked_with_symbol_table( dtype.clone(), - array.symbols().clone(), - array.symbol_lengths().clone(), + array.symbol_table(), new_codes, array.uncompressed_lengths().clone(), ) diff --git a/encodings/fsst/src/compute/filter.rs b/encodings/fsst/src/compute/filter.rs index 74e32cfbd02..eae72cac8dd 100644 --- a/encodings/fsst/src/compute/filter.rs +++ b/encodings/fsst/src/compute/filter.rs @@ -31,10 +31,9 @@ impl FilterKernel for FSST { .vortex_expect("must be VarBin"); Ok(Some( - FSST::try_new( + FSST::try_new_with_symbol_table( array.dtype().clone(), - array.symbols().clone(), - array.symbol_lengths().clone(), + array.symbol_table(), filtered_codes, array.uncompressed_lengths().filter(mask.clone())?, ctx, diff --git a/encodings/fsst/src/compute/mod.rs b/encodings/fsst/src/compute/mod.rs index 02efdf7febc..04ce6db4eda 100644 --- a/encodings/fsst/src/compute/mod.rs +++ b/encodings/fsst/src/compute/mod.rs @@ -28,13 +28,12 @@ impl TakeExecute for FSST { ctx: &mut ExecutionCtx, ) -> VortexResult> { Ok(Some( - FSST::try_new( + FSST::try_new_with_symbol_table( array .dtype() .clone() .union_nullability(indices.dtype().nullability()), - array.symbols().clone(), - array.symbol_lengths().clone(), + array.symbol_table(), { let codes = array.codes(); let codes = codes.as_view(); diff --git a/encodings/fsst/src/slice.rs b/encodings/fsst/src/slice.rs index 9b7b2833d85..b98f5fa03b3 100644 --- a/encodings/fsst/src/slice.rs +++ b/encodings/fsst/src/slice.rs @@ -19,10 +19,9 @@ impl SliceReduce for FSST { // SAFETY: slicing the `codes` leaves the symbol table intact Ok(Some( unsafe { - FSST::new_unchecked( + FSST::new_unchecked_with_symbol_table( array.dtype().clone(), - array.symbols().clone(), - array.symbol_lengths().clone(), + array.symbol_table(), array .codes() .slice(range.clone())?