Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ enum-iterator = "2.0.0"
env_logger = "0.11"
fastlanes = "0.5"
flatbuffers = "25.2.10"
fsst-rs = "0.5.5"
fsst-rs = "0.5.11"
futures = { version = "0.3.31", default-features = false }
fuzzy-matcher = "0.3"
get_dir = "0.5.0"
Expand Down Expand Up @@ -226,7 +226,7 @@ reqwest = { version = "0.13.0", features = [
roaring = "0.11.0"
rstest = "0.26.1"
rstest_reuse = "0.7.0"
rustc-hash = "2.1"
rustc-hash = "2.1.1"
serde = "1.0.220"
serde_json = "1.0.138"
serde_test = "1.0.176"
Expand Down
187 changes: 156 additions & 31 deletions encodings/fsst/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::fmt::Display;
use std::fmt::Formatter;
use std::hash::Hasher;
use std::sync::Arc;
use std::sync::LazyLock;
use std::sync::OnceLock;

use fsst::Compressor;
use fsst::Decompressor;
Expand Down Expand Up @@ -81,18 +81,23 @@ impl FSSTMetadata {

impl ArrayHash for FSSTData {
fn array_hash<H: Hasher>(&self, state: &mut H, precision: Precision) {
self.symbols.array_hash(state, precision);
self.symbol_lengths.array_hash(state, precision);
self.symbol_table.symbols.array_hash(state, precision);
self.symbol_table
.symbol_lengths
.array_hash(state, precision);
self.codes_bytes.as_host().array_hash(state, precision);
}
}

impl ArrayEq for FSSTData {
fn array_eq(&self, other: &Self, precision: Precision) -> bool {
self.symbols.array_eq(&other.symbols, precision)
self.symbol_table
.symbols
.array_eq(&other.symbol_table.symbols, precision)
&& self
.symbol_table
.symbol_lengths
.array_eq(&other.symbol_lengths, precision)
.array_eq(&other.symbol_table.symbol_lengths, precision)
&& self
.codes_bytes
.as_host()
Expand Down Expand Up @@ -346,28 +351,29 @@ pub(crate) const SLOT_NAMES: [&str; NUM_SLOTS] =
/// [`FSSTArrayExt::codes()`], combining this buffer with the offsets/validity from slots.
#[derive(Clone)]
pub struct FSSTData {
symbols: Buffer<Symbol>,
symbol_lengths: Buffer<u8>,
symbol_table: Arc<FSSTSymbolTable>,
/// The raw compressed codes bytes, equivalent to `VarBinData::bytes`.
codes_bytes: BufferHandle,
/// Cached length (number of elements).
len: usize,

/// Memoized compressor used for push-down of compute by compressing the RHS.
compressor: Arc<LazyLock<Compressor, Box<dyn Fn() -> Compressor + Send>>>,
}

impl Display for FSSTData {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "len: {}, nsymbols: {}", self.len, self.symbols.len())
write!(
f,
"len: {}, nsymbols: {}",
self.len,
self.symbol_table.symbols.len()
)
}
}

impl Debug for FSSTData {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("FSSTArray")
.field("symbols", &self.symbols)
.field("symbol_lengths", &self.symbol_lengths)
.field("symbols", &self.symbol_table.symbols)
.field("symbol_lengths", &self.symbol_table.symbol_lengths)
.field("codes_bytes_len", &self.codes_bytes.len())
.field("len", &self.len)
.field("uncompressed_lengths", &"<outer slot>")
Expand All @@ -377,6 +383,29 @@ impl Debug for FSSTData {
}
}

/// The symbol table of an FSST encoding together with a lazily built compressor for it.
///
/// Bundling the compressor cache with the buffers it is derived from means every holder of
/// the same table instance also sees the same cached compressor.
pub(crate) struct FSSTSymbolTable {
/// The symbols of the table.
symbols: Buffer<Symbol>,
/// Byte length of each entry in `symbols`. Construction via `new` does not itself check
/// that the two buffers agree in length.
symbol_lengths: Buffer<u8>,
/// Memoized compressor used for push-down of compute by compressing the RHS.
compressor: OnceLock<Compressor>,
}

impl FSSTSymbolTable {
    /// Wraps the given symbol buffers in a table whose compressor has not been built yet.
    ///
    /// No validation is performed here; the buffers are stored as given.
    fn new(symbols: Buffer<Symbol>, symbol_lengths: Buffer<u8>) -> Self {
        let compressor = OnceLock::new();
        Self {
            symbols,
            symbol_lengths,
            compressor,
        }
    }

    /// Returns the compressor for this table, rebuilding it from the stored symbols and
    /// lengths on first use and handing back the cached instance on every later call.
    fn compressor(&self) -> &Compressor {
        let rebuild = || {
            Compressor::rebuild_from(self.symbols.as_slice(), self.symbol_lengths.as_slice())
        };
        self.compressor.get_or_init(rebuild)
    }
}

#[derive(Clone, Debug)]
pub struct FSST;

Expand Down Expand Up @@ -412,6 +441,32 @@ impl FSST {
})
}

pub(crate) fn try_new_with_symbol_table(
dtype: DType,
symbol_table: Arc<FSSTSymbolTable>,
codes: VarBinArray,
uncompressed_lengths: ArrayRef,
ctx: &mut ExecutionCtx,
) -> VortexResult<FSSTArray> {
let len = codes.len();
FSSTData::validate_parts_from_codes(
&symbol_table.symbols,
&symbol_table.symbol_lengths,
&codes,
&uncompressed_lengths,
&dtype,
len,
ctx,
)?;
let slots = FSSTData::make_slots(&codes, &uncompressed_lengths);
let codes_bytes = codes.bytes_handle().clone();
let data =
unsafe { FSSTData::new_unchecked_with_symbol_table(symbol_table, codes_bytes, len) };
Ok(unsafe {
Array::from_parts_unchecked(ArrayParts::new(FSST, dtype, len, data).with_slots(slots))
})
}

/// Legacy deserialization path (2 buffers): the codes were stored as a full
/// `VarBinArray` child. We decompose the VarBinArray into its bytes (stored in
/// FSSTData) and offsets/validity (stored in slots).
Expand Down Expand Up @@ -463,17 +518,17 @@ impl FSST {
Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data).with_slots(slots))
}

pub(crate) unsafe fn new_unchecked(
pub(crate) unsafe fn new_unchecked_with_symbol_table(
dtype: DType,
symbols: Buffer<Symbol>,
symbol_lengths: Buffer<u8>,
symbol_table: Arc<FSSTSymbolTable>,
codes: VarBinArray,
uncompressed_lengths: ArrayRef,
) -> FSSTArray {
let len = codes.len();
let slots = FSSTData::make_slots(&codes, &uncompressed_lengths);
let codes_bytes = codes.bytes_handle().clone();
let data = unsafe { FSSTData::new_unchecked(symbols, symbol_lengths, codes_bytes, len) };
let data =
unsafe { FSSTData::new_unchecked_with_symbol_table(symbol_table, codes_bytes, len) };
unsafe {
Array::from_parts_unchecked(ArrayParts::new(FSST, dtype, len, data).with_slots(slots))
}
Expand Down Expand Up @@ -534,8 +589,8 @@ impl FSSTData {
.as_ref()
.vortex_expect("FSSTArray codes_offsets slot");
Self::validate_parts(
&self.symbols,
&self.symbol_lengths,
&self.symbol_table.symbols,
&self.symbol_table.symbol_lengths,
&self.codes_bytes,
codes_offsets,
dtype.nullability(),
Expand Down Expand Up @@ -567,10 +622,13 @@ impl FSSTData {
if symbols.len() > 255 {
vortex_bail!(InvalidArgument: "symbols array must have length <= 255");
}

if symbols.len() != symbol_lengths.len() {
vortex_bail!(InvalidArgument: "symbols and symbol_lengths arrays must have same length");
}

Self::validate_symbol_lengths(symbol_lengths.as_slice())?;

// codes_offsets.len() - 1 == number of elements
let codes_len = codes_offsets.len().saturating_sub(1);
if codes_len != len {
Expand Down Expand Up @@ -612,6 +670,32 @@ impl FSSTData {
Ok(())
}

/// Checks that `symbol_lengths` follows the accepted symbol-table layout.
///
/// Every length must lie in `1..=8`. Lengths of two or more must be non-decreasing, and
/// once a length of one appears every remaining length must also be one.
///
/// # Errors
///
/// Bails with `InvalidArgument` naming the first offending index.
fn validate_symbol_lengths(symbol_lengths: &[u8]) -> VortexResult<()> {
    // Smallest length still acceptable in the multi-byte prefix of the table.
    let mut min_multi_byte: u8 = 2;
    // Set once the first one-byte symbol has been seen.
    let mut in_one_byte_tail = false;

    for (idx, &len) in symbol_lengths.iter().enumerate() {
        if len > 8 || len == 0 {
            vortex_bail!(InvalidArgument: "symbol length at index {idx} must be between 1 and 8, found {len}");
        }

        if in_one_byte_tail {
            if len != 1 {
                vortex_bail!(InvalidArgument: "symbol length at index {idx} must be 1 after one-byte symbols begin, found {len}");
            }
            continue;
        }

        if len == 1 {
            in_one_byte_tail = true;
            continue;
        }

        if len < min_multi_byte {
            vortex_bail!(InvalidArgument: "symbol length at index {idx} violates FSST symbol table ordering");
        }
        min_multi_byte = len;
    }

    Ok(())
}

/// Validate using a VarBinArray for the codes (convenience for construction paths).
fn validate_parts_from_codes(
symbols: &Buffer<Symbol>,
Expand Down Expand Up @@ -641,18 +725,19 @@ impl FSSTData {
codes_bytes: BufferHandle,
len: usize,
) -> Self {
let symbols2 = symbols.clone();
let symbol_lengths2 = symbol_lengths.clone();
let compressor = Arc::new(LazyLock::new(Box::new(move || {
Compressor::rebuild_from(symbols2.as_slice(), symbol_lengths2.as_slice())
})
as Box<dyn Fn() -> Compressor + Send>));
let symbol_table = Arc::new(FSSTSymbolTable::new(symbols, symbol_lengths));
unsafe { Self::new_unchecked_with_symbol_table(symbol_table, codes_bytes, len) }
}

pub(crate) unsafe fn new_unchecked_with_symbol_table(
symbol_table: Arc<FSSTSymbolTable>,
codes_bytes: BufferHandle,
len: usize,
) -> Self {
Self {
symbols,
symbol_lengths,
symbol_table,
codes_bytes,
len,
compressor,
}
}

Expand All @@ -668,12 +753,16 @@ impl FSSTData {

/// Access the symbol table array.
pub fn symbols(&self) -> &Buffer<Symbol> {
&self.symbols
&self.symbol_table.symbols
}

/// Access the symbol lengths array.
pub fn symbol_lengths(&self) -> &Buffer<u8> {
&self.symbol_lengths
&self.symbol_table.symbol_lengths
}

/// Returns another `Arc` handle to this array's symbol table, so a derived array can point
/// at the same table instance rather than cloning the symbol buffers individually.
pub(crate) fn symbol_table(&self) -> Arc<FSSTSymbolTable> {
Arc::clone(&self.symbol_table)
}

/// Access the compressed codes bytes buffer handle (may be on host or device).
Expand All @@ -694,7 +783,7 @@ impl FSSTData {

/// Retrieves the FSST compressor.
pub fn compressor(&self) -> &Compressor {
self.compressor.as_ref()
self.symbol_table.compressor()
}
}

Expand Down Expand Up @@ -771,12 +860,48 @@ mod test {
use vortex_array::test_harness::check_metadata;
use vortex_buffer::Buffer;
use vortex_error::VortexError;
use vortex_error::VortexResult;
use vortex_error::vortex_err;

use crate::FSST;
use crate::array::FSSTArrayExt;
use crate::array::FSSTMetadata;
use crate::fsst_compress_iter;

/// A slice of an FSST array must hand back the very same compressor instance as the array
/// it was sliced from, once that compressor has been initialized on the parent.
#[test]
fn slice_reuses_initialized_compressor() -> VortexResult<()> {
    let mut ctx = LEGACY_SESSION.create_execution_ctx();

    // Hand-built two-symbol table: a 3-byte symbol followed by an 8-byte symbol.
    let symbol_lengths = Buffer::<u8>::copy_from([3, 8]);
    let symbols = Buffer::<Symbol>::copy_from([
        Symbol::from_slice(b"abc00000"),
        Symbol::from_slice(b"defghijk"),
    ]);
    let compressor = Compressor::rebuild_from(symbols.as_slice(), symbol_lengths.as_slice());

    let values = [
        Some(b"abcabcab".as_ref()),
        Some(b"defghijk".as_ref()),
        Some(b"abcxyz".as_ref()),
    ];
    let fsst_array = fsst_compress_iter(
        values.into_iter(),
        3,
        DType::Utf8(Nullability::NonNullable),
        &compressor,
        &mut ctx,
    );

    // Touch the parent's compressor first so the cached instance exists before slicing.
    let parent_ptr: *const Compressor = fsst_array.compressor();

    let sliced = fsst_array
        .slice(1..3)?
        .try_downcast::<FSST>()
        .map_err(|_| vortex_err!("slice must return an FSST array"))?;
    let sliced_ptr: *const Compressor = sliced.compressor();

    assert_eq!(parent_ptr, sliced_ptr);
    Ok(())
}

#[cfg_attr(miri, ignore)]
#[test]
fn test_fsst_metadata() {
Expand Down
5 changes: 2 additions & 3 deletions encodings/fsst/src/compute/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,9 @@ fn build_with_codes_validity(
)?;

Ok(unsafe {
FSST::new_unchecked(
FSST::new_unchecked_with_symbol_table(
dtype.clone(),
array.symbols().clone(),
array.symbol_lengths().clone(),
array.symbol_table(),
new_codes,
array.uncompressed_lengths().clone(),
)
Expand Down
5 changes: 2 additions & 3 deletions encodings/fsst/src/compute/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,9 @@ impl FilterKernel for FSST {
.vortex_expect("must be VarBin");

Ok(Some(
FSST::try_new(
FSST::try_new_with_symbol_table(
array.dtype().clone(),
array.symbols().clone(),
array.symbol_lengths().clone(),
array.symbol_table(),
filtered_codes,
array.uncompressed_lengths().filter(mask.clone())?,
ctx,
Expand Down
5 changes: 2 additions & 3 deletions encodings/fsst/src/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,12 @@ impl TakeExecute for FSST {
ctx: &mut ExecutionCtx,
) -> VortexResult<Option<ArrayRef>> {
Ok(Some(
FSST::try_new(
FSST::try_new_with_symbol_table(
array
.dtype()
.clone()
.union_nullability(indices.dtype().nullability()),
array.symbols().clone(),
array.symbol_lengths().clone(),
array.symbol_table(),
{
let codes = array.codes();
let codes = codes.as_view();
Expand Down
Loading
Loading