From 8db4b6875c861aa9a2534f219f6592f2669cbd32 Mon Sep 17 00:00:00 2001 From: mprammer Date: Sun, 24 May 2026 18:50:47 -0500 Subject: [PATCH] Fix TurboQuant centroid initialization wasting most codes at high dimensions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Seed Lloyd-Max centroids on ±sqrt(bit_width)·sigma instead of the full support [-1, 1], so they start where the rotated-coordinate marginal has mass and no cell freezes in the zero-mass tails. The same change lands in both centroid implementations (vortex-tensor and vortex-turboquant), kept in sync by the cross-crate parity test, with a regression test and an ignored sweep harness. Signed-off-by: mprammer Co-Authored-By: Claude Opus 4.7 --- .../src/encodings/turboquant/centroids.rs | 168 +++++++++++++++++- vortex-turboquant/src/centroids.rs | 15 +- 2 files changed, 178 insertions(+), 5 deletions(-) diff --git a/vortex-tensor/src/encodings/turboquant/centroids.rs b/vortex-tensor/src/encodings/turboquant/centroids.rs index 1af86c79d85..896be32bb0c 100644 --- a/vortex-tensor/src/encodings/turboquant/centroids.rs +++ b/vortex-tensor/src/encodings/turboquant/centroids.rs @@ -84,7 +84,41 @@ impl HalfIntExponent { } } -/// Compute optimal centroids via the Max-Lloyd (Lloyd-Max) algorithm. +/// How far to spread the initial centroids, as a multiple of the coordinate standard deviation +/// `sigma = 1 / sqrt(dimension)`. +/// +/// Seeding centroids across the full support `[-1, 1]` strands most of them in the near-zero-mass +/// tails, where the zero-denominator guard in [`mean_between_centroids`] freezes them for every +/// iteration; scaling the seed by `sigma` keeps every cell on live probability mass. +#[derive(Clone, Copy, Debug)] +enum InitSpread { + /// A constant multiple of `sigma`, independent of bit width. Only the sweep test constructs + /// this; production uses [`InitSpread::SqrtRate`]. + #[cfg_attr(not(test), allow(dead_code))] + Fixed(f64), + /// `coeff * sqrt(bit_width)` multiples of `sigma`. A codebook with more levels needs a wider + /// seed to keep its outermost cells on live probability mass, so the spread grows with the bit + /// width — mirroring how a quantizer's optimal loading factor grows with rate. + SqrtRate(f64), +} + +impl InitSpread { + /// The seed half-width, in multiples of `sigma`, for the given bit width. + fn sigmas(self, bit_width: u8) -> f64 { + match self { + InitSpread::Fixed(sigmas) => sigmas, + InitSpread::SqrtRate(coeff) => coeff * f64::from(bit_width).sqrt(), + } + } +} + +/// Default centroid initialization. The seed half-width grows as `sqrt(bit_width)` standard +/// deviations, tracking the bit-width-dependent optimum and beating every fixed multiple in +/// `sweep_centroid_init` (including vLLM's `3.5 sigma`). +const DEFAULT_INIT_SPREAD: InitSpread = InitSpread::SqrtRate(1.0); + +/// Compute optimal centroids via the Max-Lloyd (Lloyd-Max) algorithm with the +/// [default initialization](DEFAULT_INIT_SPREAD). /// /// Operates on the marginal distribution of a single coordinate of a randomly rotated unit vector /// in d dimensions. @@ -93,15 +127,26 @@ impl HalfIntExponent { /// `f(x) = C_d * (1 - x^2)^((d-3)/2)` on `[-1, 1]` /// where `C_d` is the normalizing constant. fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Buffer { + max_lloyd_centroids_with(dimension, bit_width, DEFAULT_INIT_SPREAD) +} + +/// Compute Max-Lloyd centroids for an explicit [`InitSpread`]. Production code calls +/// [`max_lloyd_centroids`]; the sweep test explores alternatives through this entry point. +fn max_lloyd_centroids_with(dimension: u32, bit_width: u8, init: InitSpread) -> Buffer { debug_assert!((1..=MAX_BIT_WIDTH).contains(&bit_width)); let num_centroids = 1usize << bit_width; // For the marginal distribution on [-1, 1], we use the exponent (d-3)/2. let exponent = HalfIntExponent::from_numerator(dimension as i32 - 3); - // Initialize centroids uniformly on [-1, 1]. + // The coordinate marginal concentrates around 0 with this standard deviation. + let sigma = 1.0 / f64::from(dimension).sqrt(); + let init_half = (init.sigmas(bit_width) * sigma).min(1.0); + + // Initialize centroids uniformly on [-init_half, init_half], where the mass lives, so no cell + // starts in a zero-mass region and freezes. let mut centroids: Vec = (0..num_centroids) - .map(|idx| -1.0 + (2.0 * (idx as f64) + 1.0) / (num_centroids as f64)) + .map(|idx| -init_half + (2.0 * (idx as f64) + 1.0) * init_half / (num_centroids as f64)) .collect(); let mut boundaries: Vec = vec![0.0; num_centroids + 1]; @@ -222,6 +267,8 @@ pub fn find_nearest_centroid(value: f32, boundaries: &[f32]) -> u8 { #[cfg(test)] mod tests { + use std::f64::consts::PI; + use rstest::rstest; use vortex_error::VortexResult; @@ -329,4 +376,119 @@ mod tests { assert!(compute_or_get_centroids(1, 2).is_err()); assert!(compute_or_get_centroids(127, 2).is_err()); } + + /// Fine-grained reference measurement of a codebook's quality on the coordinate marginal, + /// computed independently of the solver's own (coarser) integration grid. + struct QuantizerQuality { + /// Implied normalized reconstruction error `E[||x - x_hat||^2 / ||x||^2]` under an ideal + /// orthogonal rotation: `dimension * E[(X - q(X))^2]`. + normalized_mse: f64, + /// `normalized_mse` divided by the Theorem 1 high-rate bound `sqrt(3) * pi / 2 / 4^b`. + ratio_to_bound: f64, + /// Number of centroids whose decision cell carries less than 1e-6 of the total mass, i.e. + /// codes that are wasted because the solver froze them in a near-zero-mass region. + wasted: usize, + } + + /// Measure how well `centroids` quantize the coordinate marginal for `dimension`. + #[expect( + clippy::cast_possible_truncation, + reason = "integration samples are cast f64 -> f32 only to drive find_nearest_centroid" + )] + fn measure_quantizer(dimension: u32, bit_width: u8, centroids: &[f32]) -> QuantizerQuality { + const POINTS: usize = 100_000; + let exponent = HalfIntExponent::from_numerator(dimension as i32 - 3); + let boundaries = compute_centroid_boundaries(centroids); + let count = centroids.len(); + let mut mass = vec![0.0f64; count]; + let mut distortion = vec![0.0f64; count]; + let mut total = 0.0f64; + let dx = 2.0 / POINTS as f64; + for step in 0..=POINTS { + let x = -1.0 + step as f64 * dx; + let trapezoid = if step == 0 || step == POINTS { + 0.5 + } else { + 1.0 + }; + let weight = trapezoid * pdf_unnormalized(x, exponent); + let idx = usize::from(find_nearest_centroid(x as f32, &boundaries)); + let delta = x - f64::from(centroids[idx]); + mass[idx] += weight; + distortion[idx] += weight * delta * delta; + total += weight; + } + let per_coord_mse = distortion.iter().sum::() / total; + let normalized_mse = f64::from(dimension) * per_coord_mse; + let bound = 3.0f64.sqrt() * PI / 2.0 / 4.0f64.powi(i32::from(bit_width)); + let wasted = mass.iter().filter(|&&m| m / total < 1e-6).count(); + QuantizerQuality { + normalized_mse, + ratio_to_bound: normalized_mse / bound, + wasted, + } + } + + /// Every code in the production codebook must land on live probability mass. This is the + /// invariant the legacy `[-1, 1]` initialization violated for `dimension >= 256`, where most + /// cells froze in the zero-mass tails and wasted their codes. + #[rstest] + #[case(128)] + #[case(256)] + #[case(1024)] + #[case(2048)] + fn production_centroids_have_no_wasted_cells(#[case] dimension: u32) -> VortexResult<()> { + for bit_width in 1..=MAX_BIT_WIDTH { + let centroids = compute_or_get_centroids(dimension, bit_width)?; + let quality = measure_quantizer(dimension, bit_width, ¢roids); + assert_eq!( + quality.wasted, 0, + "dim={dimension} bits={bit_width}: {} codes landed on zero-mass cells", + quality.wasted + ); + } + Ok(()) + } + + /// Exploratory sweep over centroid-init and outer-edge configurations. Not a pass/fail gate; + /// run with `cargo test -p vortex-tensor centroids::tests::sweep -- --ignored --nocapture` to + /// compare distortion and wasted-code counts when revisiting the default configuration. + #[test] + #[ignore = "exploratory sweep; run with --ignored --nocapture"] + fn sweep_centroid_init() { + // `1e9` saturates the seed spread past 1.0, reproducing the legacy `[-1, 1]` choice. + let configs: &[(&str, InitSpread)] = &[ + ("legacy [-1,1]", InitSpread::Fixed(1e9)), + ("fixed 2.5s", InitSpread::Fixed(2.5)), + ("fixed 3.0s", InitSpread::Fixed(3.0)), + ("fixed 3.5s (vLLM)", InitSpread::Fixed(3.5)), + ("sqrt 1.00*sqrt(b) [default]", DEFAULT_INIT_SPREAD), + ("sqrt 1.05*sqrt(b)", InitSpread::SqrtRate(1.05)), + ("sqrt 1.10*sqrt(b)", InitSpread::SqrtRate(1.10)), + ( + "sqrt 1.18*sqrt(b) [sqrt(2lnN)]", + InitSpread::SqrtRate(1.1774), + ), + ]; + let dims = [128u32, 1024, 2048]; + let bits_list = [4u8, 5, 6, 7, 8]; + + for &(name, init) in configs { + println!("\n=== {name} ==="); + println!( + "{:>6} {:>5} {:>12} {:>9} {:>7}", + "dim", "bits", "norm_mse", "x_bound", "wasted" + ); + for &dimension in &dims { + for &bit_width in &bits_list { + let centroids = max_lloyd_centroids_with(dimension, bit_width, init); + let q = measure_quantizer(dimension, bit_width, ¢roids); + println!( + "{dimension:>6} {bit_width:>5} {:>12.3e} {:>9.2} {:>7}", + q.normalized_mse, q.ratio_to_bound, q.wasted + ); + } + } + } + } } diff --git a/vortex-turboquant/src/centroids.rs b/vortex-turboquant/src/centroids.rs index 8499dfc397a..919948d1c00 100644 --- a/vortex-turboquant/src/centroids.rs +++ b/vortex-turboquant/src/centroids.rs @@ -106,6 +106,12 @@ impl HalfIntExponent { /// The probability distribution function is: /// `f(x) = C_d * (1 - x^2)^((d-3)/2)` on `[-1, 1]` /// where `C_d` is the normalizing constant. +/// +/// Centroids are seeded uniformly on `±sqrt(bit_width) * sigma` (`sigma = 1/sqrt(dimension)`) +/// rather than across the full support `[-1, 1]`, which would strand most of them in the +/// near-zero-mass tails where the zero-denominator guard in [`mean_between_centroids`] freezes them. +/// This must stay identical to `vortex-tensor`'s canonical centroid code (which carries the +/// supporting sweep); the cross-crate parity test enforces it. fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Buffer { debug_assert!((1..=MAX_BIT_WIDTH).contains(&bit_width)); let num_centroids = 1usize << bit_width; @@ -113,9 +119,14 @@ fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Buffer { // For the marginal distribution on [-1, 1], we use the exponent (d-3)/2. let exponent = HalfIntExponent::from_numerator(dimension as i32 - 3); - // Initialize centroids uniformly on [-1, 1]. + // The coordinate marginal concentrates around 0 with this standard deviation. + let sigma = 1.0 / f64::from(dimension).sqrt(); + let init_half = (f64::from(bit_width).sqrt() * sigma).min(1.0); + + // Initialize centroids uniformly on [-init_half, init_half], where the mass lives, so no cell + // starts in a zero-mass region and freezes. let mut centroids: Vec = (0..num_centroids) - .map(|idx| -1.0 + (2.0 * (idx as f64) + 1.0) / (num_centroids as f64)) + .map(|idx| -init_half + (2.0 * (idx as f64) + 1.0) * init_half / (num_centroids as f64)) .collect(); let mut boundaries: Vec = vec![0.0; num_centroids + 1];