From 02e11079f80b84c64f2ed03f1a47de1989f13777 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Thu, 11 Jun 2026 17:34:18 -0400 Subject: [PATCH] remove all turboquant stuff Signed-off-by: Connor Tsui --- Cargo.lock | 48 - Cargo.toml | 3 - benchmarks/vector-search-bench/Cargo.toml | 41 - .../scripts/plot-turboquant-distortion.py | 596 ---------- .../vector-search-bench/src/compression.rs | 110 -- benchmarks/vector-search-bench/src/display.rs | 133 --- .../vector-search-bench/src/distortion.rs | 370 ------ .../vector-search-bench/src/expression.rs | 97 -- benchmarks/vector-search-bench/src/ingest.rs | 221 ---- benchmarks/vector-search-bench/src/lib.rs | 67 -- benchmarks/vector-search-bench/src/main.rs | 284 ----- benchmarks/vector-search-bench/src/prepare.rs | 227 ---- benchmarks/vector-search-bench/src/query.rs | 120 -- benchmarks/vector-search-bench/src/scan.rs | 177 --- vortex-btrblocks/Cargo.toml | 7 +- vortex-btrblocks/src/builder.rs | 16 - vortex-tensor/src/encodings/mod.rs | 1 - .../src/encodings/turboquant/centroids.rs | 361 ------ .../src/encodings/turboquant/compress.rs | 270 ----- vortex-tensor/src/encodings/turboquant/mod.rs | 182 --- .../src/encodings/turboquant/scheme.rs | 221 ---- .../src/encodings/turboquant/tests/compute.rs | 216 ---- .../src/encodings/turboquant/tests/mod.rs | 163 --- .../encodings/turboquant/tests/nullable.rs | 178 --- .../encodings/turboquant/tests/roundtrip.rs | 313 ----- .../encodings/turboquant/tests/structural.rs | 334 ------ vortex-tensor/src/lib.rs | 7 +- .../src/scalar_fns/cosine_similarity.rs | 6 +- vortex-tensor/src/scalar_fns/inner_product.rs | 1054 ----------------- vortex-tensor/src/scalar_fns/l2_denorm.rs | 2 +- vortex-tensor/src/scalar_fns/l2_norm.rs | 4 +- vortex-tensor/src/scalar_fns/mod.rs | 1 - .../src/scalar_fns/sorf_transform/mod.rs | 146 --- .../src/scalar_fns/sorf_transform/rotation.rs | 559 --------- .../scalar_fns/sorf_transform/splitmix64.rs | 73 -- .../src/scalar_fns/sorf_transform/tests.rs | 493 -------- .../src/scalar_fns/sorf_transform/vtable.rs | 336 ------ vortex-tensor/src/utils.rs | 36 - vortex-tensor/src/vector_search.rs | 50 +- vortex-turboquant/Cargo.toml | 41 - vortex-turboquant/benches/encode_decode.rs | 147 --- vortex-turboquant/src/centroids.rs | 361 ------ vortex-turboquant/src/config.rs | 84 -- vortex-turboquant/src/lib.rs | 81 -- vortex-turboquant/src/scalar_fns/decode.rs | 319 ----- vortex-turboquant/src/scalar_fns/encode.rs | 226 ---- vortex-turboquant/src/scalar_fns/metadata.rs | 47 - vortex-turboquant/src/scalar_fns/mod.rs | 11 - vortex-turboquant/src/sorf/mod.rs | 7 - vortex-turboquant/src/sorf/splitmix64.rs | 78 -- vortex-turboquant/src/sorf/transform.rs | 419 ------- vortex-turboquant/src/tests/encode_decode.rs | 254 ---- vortex-turboquant/src/tests/file.rs | 73 -- vortex-turboquant/src/tests/malformed.rs | 189 --- vortex-turboquant/src/tests/metadata.rs | 173 --- vortex-turboquant/src/tests/mod.rs | 141 --- vortex-turboquant/src/tests/parity.rs | 38 - vortex-turboquant/src/tests/scalar_fns.rs | 61 - vortex-turboquant/src/vector/mod.rs | 27 - vortex-turboquant/src/vector/normalize.rs | 236 ---- vortex-turboquant/src/vector/quantize.rs | 181 --- vortex-turboquant/src/vector/storage.rs | 164 --- vortex-turboquant/src/vtable.rs | 241 ---- vortex/Cargo.toml | 4 - vortex/benches/single_encoding_throughput.rs | 133 --- vortex/examples/turboquant_vector_search.rs | 399 ------- 66 files changed, 10 insertions(+), 11648 deletions(-) delete mode 100644 benchmarks/vector-search-bench/Cargo.toml delete mode 100644 benchmarks/vector-search-bench/scripts/plot-turboquant-distortion.py delete mode 100644 benchmarks/vector-search-bench/src/compression.rs delete mode 100644 benchmarks/vector-search-bench/src/display.rs delete mode 100644 benchmarks/vector-search-bench/src/distortion.rs delete mode 100644 benchmarks/vector-search-bench/src/expression.rs delete mode 100644 benchmarks/vector-search-bench/src/ingest.rs delete mode 100644 benchmarks/vector-search-bench/src/lib.rs delete mode 100644 benchmarks/vector-search-bench/src/main.rs delete mode 100644 benchmarks/vector-search-bench/src/prepare.rs delete mode 100644 benchmarks/vector-search-bench/src/query.rs delete mode 100644 benchmarks/vector-search-bench/src/scan.rs delete mode 100644 vortex-tensor/src/encodings/turboquant/centroids.rs delete mode 100644 vortex-tensor/src/encodings/turboquant/compress.rs delete mode 100644 vortex-tensor/src/encodings/turboquant/mod.rs delete mode 100644 vortex-tensor/src/encodings/turboquant/scheme.rs delete mode 100644 vortex-tensor/src/encodings/turboquant/tests/compute.rs delete mode 100644 vortex-tensor/src/encodings/turboquant/tests/mod.rs delete mode 100644 vortex-tensor/src/encodings/turboquant/tests/nullable.rs delete mode 100644 vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs delete mode 100644 vortex-tensor/src/encodings/turboquant/tests/structural.rs delete mode 100644 vortex-tensor/src/scalar_fns/sorf_transform/mod.rs delete mode 100644 vortex-tensor/src/scalar_fns/sorf_transform/rotation.rs delete mode 100644 vortex-tensor/src/scalar_fns/sorf_transform/splitmix64.rs delete mode 100644 vortex-tensor/src/scalar_fns/sorf_transform/tests.rs delete mode 100644 vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs delete mode 100644 vortex-turboquant/Cargo.toml delete mode 100644 vortex-turboquant/benches/encode_decode.rs delete mode 100644 vortex-turboquant/src/centroids.rs delete mode 100644 vortex-turboquant/src/config.rs delete mode 100644 vortex-turboquant/src/lib.rs delete mode 100644 vortex-turboquant/src/scalar_fns/decode.rs delete mode 100644 vortex-turboquant/src/scalar_fns/encode.rs delete mode 100644 vortex-turboquant/src/scalar_fns/metadata.rs delete mode 100644 vortex-turboquant/src/scalar_fns/mod.rs delete mode 100644 vortex-turboquant/src/sorf/mod.rs delete mode 100644 vortex-turboquant/src/sorf/splitmix64.rs delete mode 100644 vortex-turboquant/src/sorf/transform.rs delete mode 100644 vortex-turboquant/src/tests/encode_decode.rs delete mode 100644 vortex-turboquant/src/tests/file.rs delete mode 100644 vortex-turboquant/src/tests/malformed.rs delete mode 100644 vortex-turboquant/src/tests/metadata.rs delete mode 100644 vortex-turboquant/src/tests/mod.rs delete mode 100644 vortex-turboquant/src/tests/parity.rs delete mode 100644 vortex-turboquant/src/tests/scalar_fns.rs delete mode 100644 vortex-turboquant/src/vector/mod.rs delete mode 100644 vortex-turboquant/src/vector/normalize.rs delete mode 100644 vortex-turboquant/src/vector/quantize.rs delete mode 100644 vortex-turboquant/src/vector/storage.rs delete mode 100644 vortex-turboquant/src/vtable.rs delete mode 100644 vortex/examples/turboquant_vector_search.rs diff --git a/Cargo.lock b/Cargo.lock index b6f67682c4a..c45406a8c6c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9536,31 +9536,6 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" -[[package]] -name = "vector-search-bench" -version = "0.1.0" -dependencies = [ - "anyhow", - "arrow-array", - "arrow-buffer", - "arrow-schema", - "clap", - "futures", - "indicatif", - "parquet", - "rand 0.10.1", - "rand_distr 0.6.0", - "serde", - "tabled", - "tempfile", - "tokio", - "tracing", - "vortex", - "vortex-bench", - "vortex-btrblocks", - "vortex-tensor", -] - [[package]] name = "version_check" version = "0.9.5" @@ -9841,7 +9816,6 @@ dependencies = [ "vortex-sequence", "vortex-session", "vortex-sparse", - "vortex-tensor", "vortex-utils", "vortex-zigzag", "vortex-zstd", @@ -10707,28 +10681,6 @@ dependencies = [ "web-sys", ] -[[package]] -name = "vortex-turboquant" -version = "0.1.0" -dependencies = [ - "codspeed-divan-compat", - "half", - "num-traits", - "prost 0.14.4", - "rand 0.10.1", - "rstest", - "vortex-array", - "vortex-buffer", - "vortex-error", - "vortex-file", - "vortex-io", - "vortex-layout", - "vortex-mask", - "vortex-session", - "vortex-tensor", - "vortex-utils", -] - [[package]] name = "vortex-utils" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index b8ef2179541..0fea9f6b125 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,6 @@ members = [ "vortex-row", "vortex-tensor", "vortex-json", - "vortex-turboquant", "vortex-compressor", "vortex-btrblocks", "vortex-layout", @@ -65,7 +64,6 @@ members = [ "benchmarks/datafusion-bench", "benchmarks/duckdb-bench", "benchmarks/random-access-bench", - "benchmarks/vector-search-bench", # Benchmarks website v3 (alpha) - leaf binary, not part of vortex-* API "benchmarks-website/server", "benchmarks-website/migrate", @@ -313,7 +311,6 @@ vortex-sequence = { version = "0.1.0", path = "encodings/sequence", default-feat vortex-session = { version = "0.1.0", path = "./vortex-session", default-features = false } vortex-sparse = { version = "0.1.0", path = "./encodings/sparse", default-features = false } vortex-tensor = { version = "0.1.0", path = "./vortex-tensor", default-features = false } -vortex-turboquant = { version = "0.1.0", path = "./vortex-turboquant", default-features = false } vortex-utils = { version = "0.1.0", path = "./vortex-utils", default-features = false } vortex-zigzag = { version = "0.1.0", path = "./encodings/zigzag", default-features = false } vortex-zstd = { version = "0.1.0", path = "./encodings/zstd", default-features = false } diff --git a/benchmarks/vector-search-bench/Cargo.toml b/benchmarks/vector-search-bench/Cargo.toml deleted file mode 100644 index 02528526080..00000000000 --- a/benchmarks/vector-search-bench/Cargo.toml +++ /dev/null @@ -1,41 +0,0 @@ -[package] -name = "vector-search-bench" -description = "Vector similarity search benchmarks for Vortex on public embedding datasets" -authors.workspace = true -categories.workspace = true -edition.workspace = true -homepage.workspace = true -include.workspace = true -keywords.workspace = true -license.workspace = true -readme.workspace = true -repository.workspace = true -rust-version.workspace = true -version.workspace = true -publish = false - -[dependencies] -anyhow = { workspace = true } -arrow-array = { workspace = true } -arrow-buffer = { workspace = true } -arrow-schema = { workspace = true } -clap = { workspace = true, features = ["derive"] } -futures = { workspace = true } -indicatif = { workspace = true } -parquet = { workspace = true, features = ["async"] } -rand = { workspace = true } -rand_distr = { workspace = true } -serde = { workspace = true, features = ["derive"] } -tabled = { workspace = true, features = ["std"] } -tokio = { workspace = true, features = ["full"] } -tracing = { workspace = true } -vortex = { workspace = true, features = ["files", "tokio", "unstable_encodings"] } -vortex-bench = { workspace = true, features = ["unstable_encodings"] } -vortex-btrblocks = { workspace = true, features = ["unstable_encodings"] } -vortex-tensor = { workspace = true } - -[dev-dependencies] -tempfile = { workspace = true } - -[lints] -workspace = true diff --git a/benchmarks/vector-search-bench/scripts/plot-turboquant-distortion.py b/benchmarks/vector-search-bench/scripts/plot-turboquant-distortion.py deleted file mode 100644 index 6d6b1f22b4f..00000000000 --- a/benchmarks/vector-search-bench/scripts/plot-turboquant-distortion.py +++ /dev/null @@ -1,596 +0,0 @@ -# /// script -# requires-python = ">=3.11" -# dependencies = [ -# "matplotlib", -# ] -# /// - -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright the Vortex contributors - -"""Sweep bits-vs-distortion for TurboQuant and plot the curves. - -Calls `vector-search-bench distortion` for each (dataset, bits) combination, parses the -table from stdout, and plots reconstruction NMSE and squared cosine-error curves with -mean/median/max shown on a log-scaled y-axis. - -Each `--dataset` value may optionally pin a train layout with a colon, e.g. -`--dataset cohere-small-100k:single`, for datasets that host more than one layout. - -Usage: - uv run benchmarks/vector-search-bench/scripts/plot-turboquant-distortion.py \\ - --dataset sift-small-500k - uv run benchmarks/vector-search-bench/scripts/plot-turboquant-distortion.py \\ - --dataset sift-small-500k --dataset glove-small-100k --samples 8192 - uv run benchmarks/vector-search-bench/scripts/plot-turboquant-distortion.py \\ - --dataset cohere-small-100k:single --bits 1 2 3 4 5 6 7 8 \\ - --output /tmp/distortion.png -""" - -import argparse -import math -import re -import subprocess -import sys -from dataclasses import dataclass -from pathlib import Path - -import matplotlib.pyplot as plt -from matplotlib.lines import Line2D -from matplotlib.ticker import MaxNLocator, NullLocator - -REPO_ROOT = Path(__file__).resolve().parents[3] -DEFAULT_BINARY = REPO_ROOT / "target" / "release" / "vector-search-bench" - -METRIC_NAMES = [ - "reconstruction NMSE mean", - "reconstruction NMSE median", - "reconstruction NMSE max", - "decoded cosine sqerr mean", - "decoded cosine sqerr median", - "decoded cosine sqerr max", -] - - -@dataclass(frozen=True) -class DatasetTarget: - """One dataset to sweep, with the layout the bench should use for it.""" - - name: str - layout: str | None # `None` means let the bench auto-pick. - - -@dataclass -class Run: - target: DatasetTarget - dim: int - bits: int - values: dict[str, float] - - @property - def dataset(self) -> str: - return self.target.name - - -DIM_RE = re.compile(r"dim=(\d+)") - - -def parse_dataset_arg(spec: str, default_layout: str | None) -> DatasetTarget: - """Split a `name[:layout]` CLI value. `default_layout` fills in when no `:` is given.""" - if ":" in spec: - name, layout = spec.split(":", 1) - return DatasetTarget(name=name, layout=layout or None) - return DatasetTarget(name=spec, layout=default_layout) - - -def parse_dim(stdout: str) -> int: - """Pull `dim=N` out of the `## ...` header line.""" - match = DIM_RE.search(stdout) - if not match: - raise RuntimeError(f"could not find dim=N in header:\n{stdout}") - return int(match.group(1)) - - -def parse_table(stdout: str) -> dict[str, float]: - """Pull `metric -> value` rows out of the tabled stdout.""" - row_re = re.compile(r"│\s*(.+?)\s*│\s*([^│]+?)\s*│") - values: dict[str, float] = {} - for line in stdout.splitlines(): - match = row_re.match(line) - if not match: - continue - metric, value = match.group(1).strip(), match.group(2).strip() - if metric in METRIC_NAMES: - values[metric] = float(value) - missing = [m for m in METRIC_NAMES if m not in values] - if missing: - raise RuntimeError(f"could not parse metrics {missing} from:\n{stdout}") - return values - - -def run_one( - binary: Path, - target: DatasetTarget, - bits: int, - samples: int, - seed: int, - rounds: int, -) -> Run: - cmd = [ - str(binary), - "distortion", - "--dataset", - target.name, - "--bits", - str(bits), - "--samples", - str(samples), - "--seed", - str(seed), - "--rounds", - str(rounds), - ] - if target.layout: - cmd.extend(["--layout", target.layout]) - layout_tag = f" layout={target.layout}" if target.layout else "" - print(f" running {target.name}{layout_tag} @ bits={bits} ...", file=sys.stderr) - result = subprocess.run(cmd, capture_output=True, text=True, check=True) - return Run( - target=target, - dim=parse_dim(result.stdout), - bits=bits, - values=parse_table(result.stdout), - ) - - -# Refined small-b values from `main.tex` line 273-274 ("for b = 1, 2, 3, 4 we have -# D_mse approx 0.36, 0.117, 0.03, 0.009"). Tighter than the general sqrt(3)*pi/2 * 4^(-b) -# upper bound, which is what we fall back to for b >= 5. -_NMSE_UPPER_REFINED = {1: 0.36, 2: 0.117, 3: 0.03, 4: 0.009} - - -def nmse_bound_stage1(bits: int) -> float: - """Paper's Stage-1 unit-norm reconstruction upper bound for TurboQuant_mse. - - From the Stage 1 theorem (`main.tex`, line 272): for a unit-norm vector `x` quantized - to `b` bits per coordinate, `E[||x - x'||^2] <= (sqrt(3)*pi/2) * 4^(-b)`. TurboQuant - internally normalizes each input before quantizing, so the bound applies to per-row - NMSE = `||x - x'||^2 / ||x||^2 = ||unit(x) - unit(x')||^2` directly. For small `b` - (1..=4) the paper gives tighter refined values; we splice those in. - """ - if bits in _NMSE_UPPER_REFINED: - return _NMSE_UPPER_REFINED[bits] - return (math.sqrt(3.0) * math.pi / 2.0) / (4.0**bits) - - -def nmse_lower_bound(bits: int) -> float: - """Paper's Shannon lower bound on Stage-1 unit-norm reconstruction. - - From `main.tex` line 297: `D_mse(Q) >= 1/4^b` for any randomized `b`-bit quantizer. - Independent of dimension; applies to NMSE for the same reason as the upper bound. - """ - return 1.0 / (4.0**bits) - - -def compression_ratio(bits: int, dim: int) -> float: - """Theoretical TurboQuant compression ratio vs f32 storage. - - Per the `vortex_tensor::encodings::turboquant` module docs, each vector is stored - as `padded_dim * bits / 8` bytes of quantized codes plus one f32 stored norm - (4 bytes), where `padded_dim` is the next power of two at least `dim`. The ratio is - nonlinear in `bits` because of POT padding and the per-vector norm overhead. - """ - padded_dim = 1 << (dim - 1).bit_length() if dim > 1 else 1 - per_vector_bytes = padded_dim * bits / 8.0 + 4.0 - original_bytes = dim * 4.0 - return original_bytes / per_vector_bytes - - -def cosine_sqerr_lower_bound(bits: int, dim: int) -> float: - """Paper's Shannon lower bound on Stage-2 squared inner-product distortion. - - From `main.tex` line 298: `D_prod(Q) >= ||y||^2 / d * 1/4^b` for any randomized - `b`-bit quantizer. With unit probes (`||y||^2 = 1`) this is `1 / (d * 4^b)`. - """ - return 1.0 / (dim * (4.0**bits)) - - -DATASET_PALETTE = [ - "#1f77b4", # blue - "#d62728", # red - "#2ca02c", # green - "#9467bd", # purple - "#ff7f0e", # orange - "#17becf", # teal - "#e377c2", # pink - "#8c564b", # brown - "#7f7f7f", # grey - "#bcbd22", # olive -] - -STAT_STYLES = [ - # (metric_suffix, label, linestyle, linewidth, marker) - ("mean", "mean", "-", 2.4, "o"), - ("max", "max", ":", 1.4, None), -] - - -def plot(runs: list[Run], args: argparse.Namespace) -> None: - by_dataset: dict[str, list[Run]] = {} - for r in runs: - by_dataset.setdefault(r.dataset, []).append(r) - for ds_runs in by_dataset.values(): - ds_runs.sort(key=lambda r: r.bits) - - plt.rcParams.update( - { - "font.size": 11, - "axes.titlesize": 13, - "axes.titleweight": "semibold", - "axes.labelsize": 11, - "axes.spines.top": False, - "axes.spines.right": False, - "axes.grid": True, - "grid.alpha": 0.25, - "grid.linewidth": 0.6, - "legend.frameon": False, - } - ) - - # GridSpec with a dedicated bottom strip for the caption so the long text gets a real - # subplot rect: no clipping by `bbox_inches`, no overlap with axis labels, no reliance - # on matplotlib's `wrap=True` heuristic. Plot row gets the lion's share so the bottom - # caption strip doesn't dominate visually; legends are anchored above the axes via - # `bbox_to_anchor` (see `add_legends`), and constrained_layout reserves space for them - # inside the plot row. - fig = plt.figure(figsize=(22, 9.5), constrained_layout=True) - gs = fig.add_gridspec(2, 3, height_ratios=[12, 1]) - axes = [fig.add_subplot(gs[0, i]) for i in range(3)] - caption_ax = fig.add_subplot(gs[1, :]) - caption_ax.axis("off") - fig.suptitle( - f"TurboQuant distortion vs bits per coordinate" - f" (samples={args.samples:,}, seed={args.seed}, rounds={args.rounds})", - fontsize=14, - fontweight="semibold", - ) - - dataset_colors = {ds: DATASET_PALETTE[i % len(DATASET_PALETTE)] for i, ds in enumerate(by_dataset)} - dataset_dims = {ds: ds_runs[0].dim for ds, ds_runs in by_dataset.items()} - - plot_panel( - axes[0], - by_dataset, - dataset_colors, - metric_prefix="reconstruction NMSE", - title=r"Reconstruction NMSE (per vector, $\|x - x^\prime\|^2 / \|x\|^2$)", - ylabel=r"$\|x - x^\prime\|^2 / \|x\|^2$", - ) - bits_axis = sorted({r.bits for r in runs}) - axes[0].plot( - bits_axis, - [nmse_bound_stage1(b) for b in bits_axis], - color="#222222", - linestyle=(0, (4, 2, 1, 2)), - linewidth=1.6, - zorder=0, - ) - axes[0].plot( - bits_axis, - [nmse_lower_bound(b) for b in bits_axis], - color="#222222", - linestyle=(0, (1, 2)), - linewidth=1.4, - zorder=0, - ) - - plot_panel( - axes[1], - by_dataset, - dataset_colors, - metric_prefix="decoded cosine sqerr", - title=r"Squared cosine error $(\cos(y_i, x_i) - \cos(y_i, x_i^\prime))^2$", - ylabel="squared error", - ) - for dataset, ds_runs in by_dataset.items(): - color = dataset_colors[dataset] - d = ds_runs[0].dim - bits = sorted({r.bits for r in ds_runs}) - axes[1].plot( - bits, - [cosine_sqerr_lower_bound(b, d) for b in bits], - color=color, - linestyle=(0, (1, 2)), - linewidth=1.0, - alpha=0.5, - zorder=0, - ) - - plot_compression_panel(axes[2], by_dataset, dataset_colors) - - add_legends(fig, axes, dataset_colors, dataset_dims) - caption_ax.text( - 0.5, - 1.0, - "NMSE upper bound uses the paper's refined small-b values for b<=4 and the " - "smooth sqrt(3)*pi/2 * 4^(-b) general formula for b>=5. Lower bounds are the " - "Shannon information-theoretic floor for any randomized b-bit quantizer. " - "Vortex ships TurboQuant Stage 1 only, so no Stage-2 inner-product upper " - "bound is drawn on the cosine panel. Probe vectors y_i are sampled iid " - "uniform on the unit sphere. Compression ratio is theoretical " - "(padded_dim * bits / 8 + 4 bytes per vector), excludes per-shard centroid " - "tables and file metadata.", - ha="center", - va="top", - fontsize=9, - color="#555555", - wrap=True, - transform=caption_ax.transAxes, - ) - - if args.output: - fig.savefig(args.output, dpi=140, bbox_inches="tight") - print(f"saved {args.output}", file=sys.stderr) - else: - plt.show() - - -def plot_panel( - ax, - by_dataset: dict[str, list[Run]], - dataset_colors: dict[str, str], - metric_prefix: str, - title: str, - ylabel: str, -) -> None: - for dataset, ds_runs in by_dataset.items(): - color = dataset_colors[dataset] - bits = [r.bits for r in ds_runs] - for stat_key, _label, linestyle, linewidth, marker in STAT_STYLES: - metric = f"{metric_prefix} {stat_key}" - ys = [r.values[metric] for r in ds_runs] - ax.plot( - bits, - ys, - color=color, - linestyle=linestyle, - linewidth=linewidth, - marker=marker, - markersize=6, - markerfacecolor=color, - markeredgecolor="white", - markeredgewidth=0.8, - alpha=0.95 if marker else 0.75, - ) - ax.set_yscale("log") - ax.set_xlabel("bits per coordinate") - ax.set_ylabel(ylabel) - ax.set_title(title) - ax.xaxis.set_major_locator(MaxNLocator(integer=True)) - ax.grid(True, which="major", linewidth=0.7, alpha=0.45) - ax.grid(True, which="minor", linewidth=0.4, alpha=0.22) - ax.minorticks_on() - # Only the integer bit-widths should get an x-axis line; suppress the in-between - # minor ticks that `minorticks_on()` adds (the y-axis minors stay - they're useful - # on the log scale). - ax.xaxis.set_minor_locator(NullLocator()) - - -def plot_compression_panel( - ax, - by_dataset: dict[str, list[Run]], - dataset_colors: dict[str, str], -) -> None: - bits_axis = sorted({r.bits for runs in by_dataset.values() for r in runs}) - for dataset, ds_runs in by_dataset.items(): - color = dataset_colors[dataset] - d = ds_runs[0].dim - padded = 1 << (d - 1).bit_length() if d > 1 else 1 - suffix = f" (padded {padded})" if padded != d else " (no padding)" - ax.plot( - bits_axis, - [compression_ratio(b, d) for b in bits_axis], - color=color, - linestyle="-", - linewidth=2.4, - marker="o", - markersize=6, - markerfacecolor=color, - markeredgecolor="white", - markeredgewidth=0.8, - label=f"{dataset}{suffix}", - ) - ax.set_xlabel("bits per coordinate") - ax.set_ylabel(r"ratio vs f32 (= $4d \,/\, (\mathrm{padded}\!\cdot\! b/8 + 4)$)") - ax.set_title("Compression ratio (theoretical)") - ax.xaxis.set_major_locator(MaxNLocator(integer=True)) - ax.grid(True, which="major", linewidth=0.7, alpha=0.45) - ax.grid(True, which="minor", linewidth=0.4, alpha=0.22) - ax.minorticks_on() - ax.xaxis.set_minor_locator(NullLocator()) - ax.legend( - title="dataset", - loc="lower center", - bbox_to_anchor=(0.5, 1.02), - ncol=2, - fontsize=9, - title_fontsize=10, - ) - - -def add_legends( - fig, - axes, - dataset_colors: dict[str, str], - dataset_dims: dict[str, int], -) -> None: - dataset_handles = [ - Line2D( - [], - [], - color=color, - linewidth=2.4, - marker="o", - markersize=6, - markerfacecolor=color, - markeredgecolor="white", - markeredgewidth=0.8, - label=f"{dataset} (d = {dataset_dims[dataset]})", - ) - for dataset, color in dataset_colors.items() - ] - stat_handles = [ - Line2D( - [], - [], - color="#333333", - linestyle=linestyle, - linewidth=linewidth, - marker=marker, - markersize=6 if marker else 0, - markerfacecolor="#333333", - markeredgecolor="white", - markeredgewidth=0.8, - label=label, - ) - for _, label, linestyle, linewidth, marker in STAT_STYLES - ] - nmse_upper_handle = Line2D( - [], - [], - color="#222222", - linestyle=(0, (4, 2, 1, 2)), - linewidth=1.6, - label=( - r"upper bound: " - r"$D_{\mathrm{mse}} \leq \frac{\sqrt{3}\,\pi}{2}\, 4^{-b}$ (refined for $b\!\leq\!4$)" - ), - ) - nmse_lower_handle = Line2D( - [], - [], - color="#222222", - linestyle=(0, (1, 2)), - linewidth=1.4, - label=r"lower bound: $D_{\mathrm{mse}} \geq 4^{-b}$", - ) - cosine_lower_handle = Line2D( - [], - [], - color="#444444", - linestyle=(0, (1, 2)), - linewidth=1.0, - alpha=0.5, - label=r"lower bound: $D_{\mathrm{prod}} \geq \frac{1}{d}\, 4^{-b}$", - ) - - axes[0].legend( - handles=dataset_handles + [nmse_upper_handle, nmse_lower_handle], - title="dataset / bound", - loc="lower center", - bbox_to_anchor=(0.5, 1.02), - ncol=2, - fontsize=10, - title_fontsize=10, - ) - axes[1].legend( - handles=stat_handles + [cosine_lower_handle], - title="statistic / bound", - loc="lower center", - bbox_to_anchor=(0.5, 1.02), - ncol=3, - fontsize=10, - title_fontsize=10, - ) - - -def main() -> None: - parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) - parser.add_argument( - "--dataset", - action="append", - required=True, - help=( - "Dataset to sweep (repeat to compare multiple). Optionally suffix " - "`:layout` to pin a specific train layout for that dataset, e.g. " - "`--dataset cohere-small-100k:single`. If omitted, the bench picks " - "the dataset's only layout, or errors if there are several." - ), - ) - parser.add_argument( - "--layout", - default=None, - help=("Default train layout applied to any `--dataset` entry that doesn't pin its own with `:layout`."), - ) - parser.add_argument("--samples", type=int, default=65536) - parser.add_argument( - "--bits", - type=int, - nargs="+", - default=[1, 2, 3, 4, 5, 6, 7, 8], - help="Bit widths to sweep (default: 1..=8).", - ) - parser.add_argument("--seed", type=int, default=42) - parser.add_argument("--rounds", type=int, default=3) - parser.add_argument( - "--binary", - type=Path, - default=DEFAULT_BINARY, - help=f"Path to vector-search-bench (default: {DEFAULT_BINARY}).", - ) - parser.add_argument( - "--output", - type=Path, - default=None, - help="If set, save the chart to this path instead of opening a window.", - ) - args = parser.parse_args() - - print("building vector-search-bench (release) ...", file=sys.stderr) - subprocess.run( - ["cargo", "build", "-p", "vector-search-bench", "--release"], - cwd=REPO_ROOT, - check=True, - ) - - if not args.binary.exists(): - sys.exit(f"binary not found at {args.binary} after build") - - targets = [parse_dataset_arg(spec, args.layout) for spec in args.dataset] - - runs: list[Run] = [] - for target in targets: - layout_tag = f" (layout={target.layout})" if target.layout else "" - print( - f"sweeping {target.name}{layout_tag} over bits {args.bits} ...", - file=sys.stderr, - ) - for bits in args.bits: - runs.append( - run_one( - args.binary, - target, - bits, - args.samples, - args.seed, - args.rounds, - ) - ) - - print_summary(runs) - plot(runs, args) - - -def print_summary(runs: list[Run]) -> None: - print() - print("Summary (one row per (dataset, bits)):") - header = ["dataset", "dim", "bits"] + METRIC_NAMES - widths = [max(len(h), 14) for h in header] - print(" " + " ".join(h.ljust(w) for h, w in zip(header, widths))) - for r in runs: - cells = [r.dataset, str(r.dim), str(r.bits)] + [f"{r.values[m]:.3e}" for m in METRIC_NAMES] - print(" " + " ".join(c.ljust(w) for c, w in zip(cells, widths))) - - -if __name__ == "__main__": - main() diff --git a/benchmarks/vector-search-bench/src/compression.rs b/benchmarks/vector-search-bench/src/compression.rs deleted file mode 100644 index 9b40cb15544..00000000000 --- a/benchmarks/vector-search-bench/src/compression.rs +++ /dev/null @@ -1,110 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! Vector compression flavors exercised by the benchmark. -//! -//! Each [`VectorFlavor`] variant maps to a [`vortex::file::WriteStrategyBuilder`] configuration -//! applied to the same input data. -//! -//! The benchmark writes one `.vortex` file per flavor per data file, then scans them all with the -//! same query so the comparison is apples-to-apples with the Parquet files. - -use clap::ValueEnum; -use vortex::array::ArrayId; -use vortex::array::scalar_fn::ScalarFnVTable; -use vortex::file::ALLOWED_ENCODINGS; -use vortex::file::VortexWriteOptions; -use vortex::file::WriteOptionsSessionExt; -use vortex::file::WriteStrategyBuilder; -use vortex::session::VortexSession; -use vortex::utils::aliases::hash_set::HashSet; -use vortex_bench::Format; -use vortex_btrblocks::BtrBlocksCompressorBuilder; -use vortex_tensor::encodings::l2_denorm::L2DenormScheme; -use vortex_tensor::scalar_fns::l2_denorm::L2Denorm; -use vortex_tensor::scalar_fns::sorf_transform::SorfTransform; - -/// Every [`VectorFlavor`] variant in CLI-help order. -pub const ALL_VECTOR_FLAVORS: &[VectorFlavor] = - &[VectorFlavor::Uncompressed, VectorFlavor::TurboQuant]; - -/// One write-side compression configuration we measure. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, ValueEnum)] -pub enum VectorFlavor { - /// `BtrBlocksCompressorBuilder::empty()` - #[clap(name = "vortex-uncompressed")] - Uncompressed, - /// `BtrBlocksCompressorBuilder::default().with_turboquant()`. - #[clap(name = "vortex-turboquant")] - TurboQuant, - // TODO(connor): We will want to add `Default` here which is just the default compressor. -} - -impl VectorFlavor { - /// Stable kebab-cased label used in CLI args and metric names. - pub fn label(&self) -> &'static str { - match self { - VectorFlavor::Uncompressed => "vortex-uncompressed", - VectorFlavor::TurboQuant => "vortex-turboquant", - } - } - - /// The `target.format` value emitted on measurements for this flavor. Both flavors produce - /// `.vortex` files, so the compression label carries the flavor split. - pub fn as_format(&self) -> Format { - match self { - VectorFlavor::Uncompressed => Format::OnDiskVortex, - VectorFlavor::TurboQuant => Format::OnDiskVortex, - } - } - - /// Subdirectory name under the per-dataset cache root used to store this flavor's `.vortex` - /// files. - pub fn dir_name(&self) -> &'static str { - match self { - VectorFlavor::Uncompressed => "vortex-uncompressed", - VectorFlavor::TurboQuant => "vortex-turboquant", - } - } - - /// Build the [`vortex::file::WriteStrategyBuilder`]-backed write options for this flavor. - /// - /// TurboQuant produces `L2Denorm(SorfTransform(...))` which the default file - /// `ALLOWED_ENCODINGS` set rejects on normalization — we extend the allow-list with the two - /// scalar-fn array IDs the scheme actually emits. - pub fn create_write_options(&self, session: &VortexSession) -> VortexWriteOptions { - let strategy = match self { - VectorFlavor::Uncompressed => { - // Even though this is uncompressed, we still want to denormalize the data first so - // that the results are fair. - let compressor = BtrBlocksCompressorBuilder::empty() - .with_new_scheme(&L2DenormScheme) - .build(); - - let mut allowed: HashSet = ALLOWED_ENCODINGS.clone(); - allowed.insert(L2Denorm.id()); - - WriteStrategyBuilder::default() - .with_compressor(compressor) - .with_allow_encodings(allowed) - .build() - } - VectorFlavor::TurboQuant => { - let compressor = BtrBlocksCompressorBuilder::default() - .with_turboquant() - .build(); - - let mut allowed: HashSet = ALLOWED_ENCODINGS.clone(); - allowed.insert(L2Denorm.id()); - allowed.insert(SorfTransform.id()); - - WriteStrategyBuilder::default() - .with_compressor(compressor) - .with_allow_encodings(allowed) - .build() - } - }; - - session.write_options().with_strategy(strategy) - } -} diff --git a/benchmarks/vector-search-bench/src/display.rs b/benchmarks/vector-search-bench/src/display.rs deleted file mode 100644 index a74eec46b01..00000000000 --- a/benchmarks/vector-search-bench/src/display.rs +++ /dev/null @@ -1,133 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! Local table renderer for the vector-search benchmark. -//! -//! Groups columns by **flavor** (`vortex-uncompressed`, `vortex-turboquant`) rather than by -//! [`vortex_bench::Format`], because the two Vortex flavors share a single -//! `Format::OnDiskVortex`/`Format::VortexLossy` pair and the generic -//! [`vortex_bench::display::render_table`] groups by Format. Local renderer keeps the -//! column-per-flavor invariant intact without introducing a new global Format value. -//! -//! Output rows: -//! -//! ```text -//! Metric | vortex-uncompressed | vortex-turboquant -//! ------------------ + ------------------- + ----------------- -//! scan wall (mean) | 485 ms | 212 ms -//! scan wall (median) | 490 ms | 215 ms -//! matches | 42 | 39 -//! rows scanned | 10,000,000 | 10,000,000 -//! bytes scanned | 30.5 GB | 7.62 GB -//! rows / sec | 5.2e6 | 1.2e7 -//! ``` - -use std::io::Write; - -use anyhow::Result; -use tabled::settings::Style; - -use crate::compression::VectorFlavor; -use crate::prepare::CompressedVortexDataset; -use crate::scan::ScanTiming; - -/// Final column-per-flavor row set for one dataset. -pub struct DatasetReport<'a> { - pub dataset_name: &'a str, - pub vortex_results: &'a [(VectorFlavor, &'a CompressedVortexDataset, &'a ScanTiming)], -} - -/// Render the full report into the given writer as a tabled table. -pub fn render(report: &DatasetReport<'_>, writer: &mut dyn Write) -> Result<()> { - let mut headers: Vec = vec!["metric".to_owned()]; - for &(flavor, ..) in report.vortex_results { - headers.push(flavor.label().to_owned()); - } - - let rows: Vec> = vec![ - make_row("scan wall (mean)", report, |_, _, scan| { - format_duration(scan.mean) - }), - make_row("scan wall (median)", report, |_, _, scan| { - format_duration(scan.median) - }), - make_row("matches", report, |_, _, scan| scan.matches.to_string()), - make_row("rows scanned", report, |_, _, scan| { - scan.rows_scanned.to_string() - }), - make_row("bytes scanned", report, |_, _, scan| { - format_bytes(scan.bytes_scanned) - }), - make_row("rows / sec", report, |_, _, scan| { - format_throughput_rows(scan.rows_scanned, scan.mean) - }), - ]; - - writeln!(writer, "## {}", report.dataset_name)?; - let mut builder = tabled::builder::Builder::new(); - builder.push_record(headers); - for row in rows { - builder.push_record(row); - } - let mut table = builder.build(); - table.with(Style::modern()); - writeln!(writer, "{table}")?; - Ok(()) -} - -fn make_row(metric: &str, report: &DatasetReport<'_>, vortex_cell: F) -> Vec -where - F: Fn(VectorFlavor, &CompressedVortexDataset, &ScanTiming) -> String, -{ - let mut row = vec![metric.to_owned()]; - for &(flavor, prep, scan) in report.vortex_results { - row.push(vortex_cell(flavor, prep, scan)); - } - row -} - -fn format_duration(d: std::time::Duration) -> String { - let secs = d.as_secs_f64(); - if secs >= 1.0 { - format!("{secs:.2} s") - } else if secs >= 1e-3 { - format!("{:.1} ms", secs * 1e3) - } else { - format!("{:.1} µs", secs * 1e6) - } -} - -fn format_bytes(bytes: u64) -> String { - const UNITS: &[&str] = &["B", "KiB", "MiB", "GiB", "TiB"]; - let mut value = bytes as f64; - let mut unit = UNITS[0]; - for next in &UNITS[1..] { - if value < 1024.0 { - break; - } - value /= 1024.0; - unit = next; - } - if unit == "B" { - format!("{bytes} B") - } else { - format!("{value:.2} {unit}") - } -} - -fn format_throughput_rows(rows: u64, wall: std::time::Duration) -> String { - let secs = wall.as_secs_f64(); - if secs <= 0.0 { - return "—".to_owned(); - } - let rps = rows as f64 / secs; - if rps >= 1e9 { - format!("{:.2}G", rps / 1e9) - } else if rps >= 1e6 { - format!("{:.2}M", rps / 1e6) - } else if rps >= 1e3 { - format!("{:.2}K", rps / 1e3) - } else { - format!("{rps:.0}") - } -} diff --git a/benchmarks/vector-search-bench/src/distortion.rs b/benchmarks/vector-search-bench/src/distortion.rs deleted file mode 100644 index 7a05410522e..00000000000 --- a/benchmarks/vector-search-bench/src/distortion.rs +++ /dev/null @@ -1,370 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! TurboQuant distortion measurement on real vector datasets. -//! -//! Reports the normalized mean square error (`||x - x'||^2 / ||x||^2`) and the squared -//! cosine-similarity error (`(cos(y_i, x_i) - cos(y_i, x'_i))^2`) against a set of independently -//! sampled unit-norm probe vectors `y_i`, after a full encode and decode roundtrip through the -//! [`vortex_tensor::encodings::turboquant`] scheme. -//! -//! NMSE rather than raw SSE because TurboQuant internally normalizes each input to unit -//! norm before quantizing (storing `||x||` separately), so the paper's Stage-1 bound -//! `E[||unit(x) - unit(x')||^2] <= (sqrt(3) * pi / 2) * 4^(-b)` applies to NMSE directly; -//! raw `||x - x'||^2` sits at `||x||^2` times that bound and isn't comparable across rows. - -use std::io::Write; - -use anyhow::Context; -use anyhow::Result; -use anyhow::bail; -use rand::SeedableRng; -use rand::rngs::StdRng; -use rand_distr::Distribution; -use rand_distr::Normal; -use tabled::settings::Style; -use vortex::array::ArrayRef; -use vortex::array::ExecutionCtx; -use vortex::array::IntoArray; -use vortex::array::VortexSessionExecute; -use vortex::array::arrays::ExtensionArray; -use vortex::array::arrays::FixedSizeListArray; -use vortex::array::arrays::PrimitiveArray; -use vortex::array::arrays::Struct; -use vortex::array::arrays::StructArray; -use vortex::array::arrays::extension::ExtensionArrayExt; -use vortex::array::arrays::fixed_size_list::FixedSizeListArrayExt; -use vortex::array::arrays::struct_::StructArrayExt; -use vortex::error::VortexExpect; -use vortex_bench::conversions::parquet_to_vortex_chunks; -use vortex_bench::vector_dataset; -use vortex_bench::vector_dataset::TrainLayout; -use vortex_bench::vector_dataset::VectorDataset; -use vortex_tensor::encodings::turboquant::TurboQuantConfig; -use vortex_tensor::encodings::turboquant::turboquant_encode; - -use crate::SESSION; -use crate::ingest::transform_chunk; - -/// Inputs to a distortion run. -#[derive(Debug, Clone)] -pub struct DistortionConfig { - /// Dataset to load vectors from. - pub dataset: VectorDataset, - /// Train-split layout (used to locate the local parquet shards). - pub layout: TrainLayout, - /// Bits per quantized coordinate. - pub bits: u8, - /// Seed for the SORF rotation. - pub seed: u64, - /// Number of sign-diagonal plus Walsh-Hadamard rounds in the SORF transform. - pub rounds: u8, - /// Number of base vectors to sample from the first train shard. - pub samples: usize, -} - -/// Mean, median, and max of a sample of distortion measurements. -#[derive(Debug, Clone)] -pub struct DistortionStats { - /// Arithmetic mean. - pub mean: f32, - /// Median (mid element after a partial sort). - pub median: f32, - /// Maximum. - pub max: f32, -} - -/// Per-dataset distortion report ready to render as markdown. -#[derive(Debug, Clone)] -pub struct DistortionReport { - /// Dataset the vectors came from. - pub dataset: VectorDataset, - /// Train-split layout used to locate the shard. - pub layout: TrainLayout, - /// Vector dimensionality. - pub dim: u32, - /// Bits per quantized coordinate. - pub bits: u8, - /// Seed for the SORF rotation. - pub seed: u64, - /// Number of SORF rounds. - pub rounds: u8, - /// Number of base vectors sampled. - pub samples: usize, - /// Normalized squared reconstruction error per row, `||x - x'||^2 / ||x||^2`. - pub reconstruction: DistortionStats, - /// Squared cosine-similarity error per row against a random unit-norm probe `y_i`, - /// `(cos(y_i, x_i) - cos(y_i, x'_i))^2`. - pub decoded_cosine: DistortionStats, -} - -/// Compute reconstruction error and cosine-similarity error for a TurboQuant roundtrip. -pub async fn run_distortion(config: &DistortionConfig) -> Result { - let dataset = config.dataset; - let layout = config.layout; - - let paths = vector_dataset::download(dataset, layout) - .await - .with_context(|| format!("download {}", dataset.name()))?; - let train_path = paths - .train_files - .first() - .with_context(|| format!("dataset {} has no train shards", dataset.name()))? - .clone(); - - let mut ctx = SESSION.create_execution_ctx(); - - let chunked = parquet_to_vortex_chunks(train_path).await?; - let struct_array: StructArray = chunked.into_array().execute(&mut ctx)?; - let transformed = transform_chunk(struct_array.into_array(), &mut ctx)?; - let emb_full = transformed - .as_opt::() - .with_context(|| { - format!( - "transform_chunk did not return a Struct, got {}", - transformed.dtype() - ) - })? - .unmasked_field_by_name("emb") - .context("transformed chunk missing `emb` field")? - .clone(); - - let n = config.samples.min(emb_full.len()); - if n == 0 { - bail!( - "distortion: need at least one sampled vector, got 0 (dataset {})", - dataset.name(), - ); - } - let emb = emb_full.slice(0..n)?; - - let original = extract_flat_f32(&emb, &mut ctx)?; - let dim = pairs_per_row(&original, n)?; - - let tq_config = TurboQuantConfig { - bit_width: config.bits, - seed: config.seed, - num_rounds: config.rounds, - }; - let encoded = turboquant_encode(emb, &tq_config, &mut ctx)?; - let decoded_ext: ExtensionArray = encoded.execute(&mut ctx)?; - let decoded = decoded_ext.into_array(); - let decoded_flat = extract_flat_f32(&decoded, &mut ctx)?; - - let reconstruction = stats(&reconstruction_nmse(&original, &decoded_flat, dim, n)); - - // Sample independent unit-norm probe vectors `y_i` (one per row). The TurboQuant Stage-2 - // bound `E[( - )^2] <= sqrt(3) * pi^2 / d * 4^(-b)` holds for any fixed `y`, - // so drawing `y` from the unit sphere is a reasonable empirical sweep. - let probes = random_unit_vectors(n, dim, config.seed)?; - let decoded_cosine = stats(&squared_cosine_errors( - &original, - &decoded_flat, - &probes, - dim, - n, - )); - - Ok(DistortionReport { - dataset, - layout, - dim: u32::try_from(dim).context("dim must fit in u32")?, - bits: config.bits, - seed: config.seed, - rounds: config.rounds, - samples: n, - reconstruction, - decoded_cosine, - }) -} - -/// Extract a flat `f32` slice from a `Vector` extension array. -fn extract_flat_f32(array: &ArrayRef, ctx: &mut ExecutionCtx) -> Result> { - let ext: ExtensionArray = array.clone().execute(ctx)?; - let fsl: FixedSizeListArray = ext.storage_array().clone().execute(ctx)?; - let elements: PrimitiveArray = fsl.elements().clone().execute(ctx)?; - Ok(elements.as_slice::().to_vec()) -} - -fn pairs_per_row(flat: &[f32], num_rows: usize) -> Result { - if num_rows == 0 { - bail!("distortion: cannot derive dim from zero rows"); - } - if !flat.len().is_multiple_of(num_rows) { - bail!( - "distortion: flat element count {} not divisible by row count {num_rows}", - flat.len(), - ); - } - Ok(flat.len() / num_rows) -} - -/// Normalized squared reconstruction error per row, `||x - x'||^2 / ||x||^2`. Zero-norm -/// rows are dropped from the sample because NMSE is undefined when `||x|| = 0`, and our -/// vector datasets are not expected to contain zero vectors. -fn reconstruction_nmse( - original: &[f32], - reconstructed: &[f32], - dim: usize, - num_rows: usize, -) -> Vec { - (0..num_rows) - .filter_map(|row| { - let start = row * dim; - let end = start + dim; - let orig = &original[start..end]; - let recon = &reconstructed[start..end]; - let norm_sq: f32 = orig.iter().map(|&v| v * v).sum(); - if norm_sq == 0.0 { - return None; - } - let err_sq: f32 = orig - .iter() - .zip(recon.iter()) - .map(|(&a, &b)| (a - b) * (a - b)) - .sum(); - Some(err_sq / norm_sq) - }) - .collect() -} - -/// Sample `num_rows` independent `dim`-D vectors with standard-normal entries and normalize each -/// row to unit L2 norm. Used as probe vectors `y_i` for the squared cosine-similarity error. -fn random_unit_vectors(num_rows: usize, dim: usize, seed: u64) -> Result> { - let mut rng = StdRng::seed_from_u64(seed); - let normal = Normal::new(0.0_f32, 1.0).context("constructing Normal(0, 1)")?; - let mut buf = vec![0.0_f32; num_rows * dim]; - for row in 0..num_rows { - let start = row * dim; - let end = start + dim; - for v in &mut buf[start..end] { - *v = normal.sample(&mut rng); - } - let norm = buf[start..end].iter().map(|&v| v * v).sum::().sqrt(); - if norm > 0.0 { - for v in &mut buf[start..end] { - *v /= norm; - } - } - } - Ok(buf) -} - -/// Cosine similarity of two equal-length vectors, returning `0.0` if either has zero norm. -/// A zero-norm decoded vector represents genuine quantizer failure, so the caller still -/// gets a defined per-row error that reflects the lost direction. -fn cosine(a: &[f32], b: &[f32]) -> f32 { - let dot: f32 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum(); - let norm_a: f32 = a.iter().map(|&v| v * v).sum::().sqrt(); - let norm_b: f32 = b.iter().map(|&v| v * v).sum::().sqrt(); - let denom = norm_a * norm_b; - if denom == 0.0 { 0.0 } else { dot / denom } -} - -/// Per-row squared cosine-similarity error against probe `y_i`, -/// `(cos(y_i, x_i) - cos(y_i, x'_i))^2`. Rows whose original `x_i` has zero norm are -/// dropped, matching [`reconstruction_nmse`]. -fn squared_cosine_errors( - original: &[f32], - reconstructed: &[f32], - probes: &[f32], - dim: usize, - num_rows: usize, -) -> Vec { - (0..num_rows) - .filter_map(|row| { - let start = row * dim; - let end = start + dim; - let xi = &original[start..end]; - let xi_dec = &reconstructed[start..end]; - let yi = &probes[start..end]; - if xi.iter().map(|&v| v * v).sum::() == 0.0 { - return None; - } - let diff = cosine(yi, xi) - cosine(yi, xi_dec); - Some(diff * diff) - }) - .collect() -} - -fn stats(samples: &[f32]) -> DistortionStats { - if samples.is_empty() { - return DistortionStats { - mean: f32::NAN, - median: f32::NAN, - max: f32::NAN, - }; - } - - let sum: f64 = samples.iter().map(|&v| f64::from(v)).sum(); - #[expect( - clippy::cast_possible_truncation, - reason = "casting an f64 mean back to f32 is intentional and matches the input precision" - )] - let mean = (sum / samples.len() as f64) as f32; - - let mut sorted = samples.to_vec(); - sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); - let mid = sorted.len() / 2; - let median = if sorted.len() % 2 == 1 { - sorted[mid] - } else { - 0.5 * (sorted[mid - 1] + sorted[mid]) - }; - - let max = samples - .iter() - .copied() - .reduce(f32::max) - .vortex_expect("samples is non-empty per the early return above"); - - DistortionStats { mean, median, max } -} - -impl DistortionReport { - /// Render the report as a markdown header line followed by a tabled table. - pub fn render(&self, writer: &mut dyn Write) -> Result<()> { - writeln!( - writer, - "## {} | dim={} | layout={} | bits={} | samples={} | seed={} | rounds={}", - self.dataset.name(), - self.dim, - self.layout.label(), - self.bits, - self.samples, - self.seed, - self.rounds, - )?; - - let rows: &[(&str, f32)] = &[ - ("reconstruction NMSE mean", self.reconstruction.mean), - ("reconstruction NMSE median", self.reconstruction.median), - ("reconstruction NMSE max", self.reconstruction.max), - ("decoded cosine sqerr mean", self.decoded_cosine.mean), - ("decoded cosine sqerr median", self.decoded_cosine.median), - ("decoded cosine sqerr max", self.decoded_cosine.max), - ]; - - let mut builder = tabled::builder::Builder::new(); - builder.push_record(["metric", "value"]); - for &(metric, value) in rows { - builder.push_record([metric.to_owned(), format_metric(value)]); - } - let mut table = builder.build(); - table.with(Style::modern()); - writeln!(writer, "{table}")?; - Ok(()) - } -} - -fn format_metric(value: f32) -> String { - if value.is_nan() { - "nan".to_owned() - } else if value == 0.0 { - "0".to_owned() - } else if value.abs() < 1e-3 || value.abs() >= 1e4 { - format!("{value:.3e}") - } else { - format!("{value:.6}") - } -} diff --git a/benchmarks/vector-search-bench/src/expression.rs b/benchmarks/vector-search-bench/src/expression.rs deleted file mode 100644 index 1f548072be1..00000000000 --- a/benchmarks/vector-search-bench/src/expression.rs +++ /dev/null @@ -1,97 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! Cosine-similarity filter [`Expression`]s used by the file-scan path. -//! -//! We can easily build a cosine similarity filter by hand: -//! -//! ```text -//! gt( -//! cosine_similarity(col("emb"), lit(query_scalar)), -//! lit(threshold), -//! ) -//! ``` -//! -//! The query is wrapped as `Scalar::extension::(Scalar::fixed_size_list(F32, ...))` so -//! [`CosineSimilarity`] can treat it as a single-row `Vector` value during evaluation. -//! -//! At scan time the literal expands into a `ConstantArray` whose row count matches the chunk batch -//! size. - -use anyhow::Result; -use vortex::array::EmptyMetadata; -use vortex::array::expr::Expression; -use vortex::array::expr::col; -use vortex::array::expr::gt; -use vortex::array::expr::lit; -use vortex::array::scalar::Scalar; -use vortex::array::scalar_fn::EmptyOptions; -use vortex::array::scalar_fn::ScalarFnVTableExt; -use vortex::dtype::DType; -use vortex::dtype::Nullability; -use vortex::dtype::PType; -use vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity; -use vortex_tensor::vector::Vector; - -/// Build the filter `cosine_similarity(emb, query) > threshold`. -pub fn similarity_filter(query: &[f32], threshold: f32) -> Result { - // Empty queries short-circuit to a literal `false`, so scans return no rows instead of trying - // to evaluate cosine similarity on a zero-dimensional vector. - if query.is_empty() { - return Ok(lit(false)); - } - - let query_lit = lit(query_scalar(query)?); - let cosine = CosineSimilarity.new_expr(EmptyOptions, [col("emb"), query_lit]); - Ok(gt(cosine, lit(threshold))) -} - -/// Wrap a query vector as `Scalar::extension::(Scalar::fixed_size_list(F32, ...))`. -pub fn query_scalar(query: &[f32]) -> Result { - let children: Vec = query - .iter() - .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) - .collect(); - - let element_dtype = DType::Primitive(PType::F32, Nullability::NonNullable); - let fsl = Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable); - - Ok(Scalar::extension::(EmptyMetadata, fsl)) -} - -/// Project just the `emb` column. Used by the throughput-only scan path. -pub fn emb_projection() -> Expression { - col("emb") -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn query_scalar_accepts_empty_query() { - let scalar = query_scalar(&[]).unwrap(); - match scalar.dtype() { - DType::Extension(_) => {} - other => panic!("expected Extension, got {other}"), - } - } - - #[test] - fn query_scalar_builds_extension_dtype() { - let scalar = query_scalar(&[1.0, 0.0, 0.0]).unwrap(); - match scalar.dtype() { - DType::Extension(_) => {} - other => panic!("expected Extension, got {other}"), - } - } - - #[test] - fn similarity_filter_uses_gt_operator() { - let expr = similarity_filter(&[1.0, 0.0, 0.0], 0.5).unwrap(); - // Quick sanity check: the printed form contains the operator and the threshold so - // future refactors that change the structure get caught here. - let printed = format!("{expr:?}"); - assert!(printed.contains("Gt") || printed.contains(">"), "{printed}"); - } -} diff --git a/benchmarks/vector-search-bench/src/ingest.rs b/benchmarks/vector-search-bench/src/ingest.rs deleted file mode 100644 index 0e3071483a1..00000000000 --- a/benchmarks/vector-search-bench/src/ingest.rs +++ /dev/null @@ -1,221 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! Per-chunk ingest transform. -//! -//! Bridges the parquet record-batch stream and the Vortex file writer: -//! -//! 1. Project the `emb` column out of each struct chunk. -//! 2. Rewrap the `emb` column as `Extension>` via -//! [`vortex_bench::vector_dataset::list_to_vector_ext`]. -//! 3. Detect the FSL element ptype at runtime and cast `f64` -> `f32` when needed. Detection is -//! from the arrow schema rather than a catalog declaration so upstream parquets whose actual -//! precision disagrees with the catalog still ingest correctly. After this point all -//! downstream code (compression, scan, recall) is f32-only. -//! 4. Optionally project the `scalar_labels` column through unchanged so future filtered-search -//! benchmarks have it without re-ingest. -//! 5. Repackage as `Struct { id: i64, emb: Vector, scalar_labels: ??? }`. - -use anyhow::Context; -use anyhow::Result; -use anyhow::bail; -use anyhow::ensure; -use vortex::array::ArrayRef; -use vortex::array::EmptyMetadata; -use vortex::array::ExecutionCtx; -use vortex::array::IntoArray; -use vortex::array::arrays::ExtensionArray; -use vortex::array::arrays::FixedSizeListArray; -use vortex::array::arrays::PrimitiveArray; -use vortex::array::arrays::Struct; -use vortex::array::arrays::StructArray; -use vortex::array::arrays::extension::ExtensionArrayExt; -use vortex::array::arrays::fixed_size_list::FixedSizeListArrayExt; -use vortex::array::arrays::struct_::StructArrayExt; -use vortex::array::validity::Validity; -use vortex::buffer::Buffer; -use vortex::dtype::DType; -use vortex::dtype::PType; -use vortex::dtype::extension::ExtDType; -use vortex_bench::vector_dataset::list_to_vector_ext; -use vortex_tensor::vector::AnyVector; -use vortex_tensor::vector::Vector; - -/// Apply the transform to a single struct chunk and return the rebuilt chunk. -/// -/// `chunk` must be a non-chunked `Struct { id: i64, emb: List }`, where all of the list -/// elements are -/// -/// The returned array is always a `Struct { id: i64, emb: Vector }`. -pub fn transform_chunk(chunk: ArrayRef, ctx: &mut ExecutionCtx) -> Result { - let struct_view = chunk - .as_opt::() - .with_context(|| format!("ingest: expected struct chunk, got dtype {}", chunk.dtype()))?; - - let id = struct_view - .unmasked_field_by_name("id") - .context("ingest: chunk missing `id` column")? - .clone(); - let emb = struct_view - .unmasked_field_by_name("emb") - .context("ingest: chunk missing `emb` column")? - .clone(); - - let emb_ext: ExtensionArray = list_to_vector_ext(emb)?.execute(ctx)?; - - // Detect the actual FSL element ptype from the extension storage dtype. The dataset catalog - // cannot be trusted here: at least one upstream parquet (`sift-medium-5m`) ships f64 - // embeddings despite the catalog advertising f32. - let element_ptype = { - let storage_dtype = emb_ext.storage_array().dtype(); - match storage_dtype { - DType::FixedSizeList(elem, ..) => match elem.as_ref() { - DType::Primitive(ptype, _) => *ptype, - other => bail!("ingest: expected primitive FSL element dtype, got {other}"), - }, - other => bail!("ingest: expected FSL storage dtype, got {other}"), - } - }; - - let f32_vector_array = match element_ptype { - PType::F32 => emb_ext.into_array(), - PType::F64 => convert_f64_to_f32_vectors(&emb_ext, ctx)?, - other => bail!("ingest: unsupported emb element ptype {other}, expected f32 or f64"), - }; - - let fields = [("id", id), ("emb", f32_vector_array)]; - Ok(StructArray::from_fields(&fields)?.into_array()) -} - -/// Convert a `Vector` extension array down to `Vector`. -/// -/// This conversion is lossy, but we are generally ok with this because most vector search -/// operations do not demand a high amount of precision. -fn convert_f64_to_f32_vectors(ext: &ExtensionArray, ctx: &mut ExecutionCtx) -> Result { - ensure!(ext.ext_dtype().is::()); - - let fsl: FixedSizeListArray = ext.storage_array().clone().execute(ctx)?; - let validity = fsl.validity()?; - let elements: PrimitiveArray = fsl.elements().clone().execute(ctx)?; - ensure!(elements.ptype() == PType::F64); - - let dim = match fsl.dtype() { - DType::FixedSizeList(_, dim, _) => *dim, - other => bail!("cast_vector_ext_to_f32: expected FSL dtype, got {other}"), - }; - - let f64_slice = elements.as_slice::(); - - #[expect( - clippy::cast_possible_truncation, - reason = "this is intentionally lossy" - )] - let f32_buf: Buffer = f64_slice - .iter() - .copied() - .map(|double| double as f32) - .collect(); - - let f32_elements = PrimitiveArray::new::(f32_buf, Validity::NonNullable).into_array(); - let new_fsl = FixedSizeListArray::try_new(f32_elements, dim, validity, fsl.len())?; - let ext_dtype = ExtDType::::try_new(EmptyMetadata, new_fsl.dtype().clone())?.erased(); - - Ok(ExtensionArray::new(ext_dtype, new_fsl.into_array()).into_array()) -} - -#[cfg(test)] -mod tests { - use vortex::VortexSessionDefault; - use vortex::array::VortexSessionExecute; - use vortex::array::arrays::List; - use vortex::buffer::BufferMut; - use vortex::dtype::Nullability; - use vortex::session::VortexSession; - - use super::*; - - fn list_chunk_f64(rows: &[&[f64]]) -> ArrayRef { - let mut elements = BufferMut::::with_capacity(rows.iter().map(|r| r.len()).sum()); - let mut offsets = BufferMut::::with_capacity(rows.len() + 1); - offsets.push(0); - for row in rows { - for &v in row.iter() { - elements.push(v); - } - offsets.push(i32::try_from(elements.len()).unwrap()); - } - let elements_array = - PrimitiveArray::new::(elements.freeze(), Validity::NonNullable).into_array(); - let offsets_array = - PrimitiveArray::new::(offsets.freeze(), Validity::NonNullable).into_array(); - vortex::array::Array::::new(elements_array, offsets_array, Validity::NonNullable) - .into_array() - } - - fn id_array(ids: &[i64]) -> ArrayRef { - PrimitiveArray::new::( - BufferMut::from_iter(ids.iter().copied()).freeze(), - Validity::NonNullable, - ) - .into_array() - } - - #[test] - fn f64_chunk_is_cast_to_f32() -> Result<()> { - let session = VortexSession::default(); - let mut ctx = session.create_execution_ctx(); - - let emb = list_chunk_f64(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]); - let chunk = - StructArray::from_fields(&[("id", id_array(&[0, 1])), ("emb", emb)])?.into_array(); - let out = transform_chunk(chunk, &mut ctx)?; - let out_struct = out - .as_opt::() - .context("transform_chunk should return a Struct array")?; - let out_emb = out_struct - .unmasked_field_by_name("emb") - .context("transform_chunk output should contain an emb field")? - .clone(); - let DType::Extension(ext) = out_emb.dtype() else { - panic!("expected extension dtype, got {}", out_emb.dtype()); - }; - match ext.storage_dtype() { - DType::FixedSizeList(elem, 3, Nullability::NonNullable) => { - assert_eq!( - **elem, - DType::Primitive(PType::F32, Nullability::NonNullable) - ); - } - other => panic!("unexpected storage dtype {other}"), - } - Ok(()) - } - - #[test] - fn f32_chunk_passes_through() -> Result<()> { - let session = VortexSession::default(); - let mut ctx = session.create_execution_ctx(); - - let mut elements = BufferMut::::with_capacity(6); - for v in [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0] { - elements.push(v); - } - let mut offsets = BufferMut::::with_capacity(3); - offsets.push(0); - offsets.push(3); - offsets.push(6); - let emb = vortex::array::Array::::new( - PrimitiveArray::new::(elements.freeze(), Validity::NonNullable).into_array(), - PrimitiveArray::new::(offsets.freeze(), Validity::NonNullable).into_array(), - Validity::NonNullable, - ) - .into_array(); - let chunk = - StructArray::from_fields(&[("id", id_array(&[0, 1])), ("emb", emb)])?.into_array(); - - let out = transform_chunk(chunk, &mut ctx)?; - let out_struct = out.as_opt::().expect("returns Struct"); - assert_eq!(out_struct.len(), 2); - Ok(()) - } -} diff --git a/benchmarks/vector-search-bench/src/lib.rs b/benchmarks/vector-search-bench/src/lib.rs deleted file mode 100644 index 76b24390d09..00000000000 --- a/benchmarks/vector-search-bench/src/lib.rs +++ /dev/null @@ -1,67 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! `vector-search-bench` vector similarity-search benchmark over several datasets. - -pub mod compression; -pub mod display; -pub mod distortion; -pub mod expression; -pub mod ingest; -pub mod prepare; -pub mod query; -pub mod scan; - -use std::sync::LazyLock; - -use anyhow::Result; -use vortex::VortexSessionDefault; -use vortex::io::session::RuntimeSessionExt; -use vortex::session::VortexSession; -use vortex_bench::vector_dataset::TrainLayout; -use vortex_bench::vector_dataset::VectorDataset; - -pub static SESSION: LazyLock = LazyLock::new(|| { - // SAFETY: called from inside the LazyLock initializer, before any other access to - // `SESSION`. The first thread to dereference SESSION runs this once. - unsafe { std::env::set_var(vortex_tensor::SCALAR_FN_ARRAY_TENSOR_PLUGIN_ENV, "1") }; - - let session = VortexSession::default().with_tokio(); - vortex_tensor::initialize(&session); - session -}); - -/// Resolve a dataset's [`TrainLayout`]. -/// -/// Every benchmark has different sets of possible dataset layouts available. The user **must** -/// provide one if there are multiple layouts. But if a dataset only has 1 layout, we can choose -/// that for them as the default. -pub fn resolve_layout( - dataset: VectorDataset, - requested: Option, -) -> Result { - let layouts = dataset.layouts(); - - match requested { - Some(layout) => { - dataset.validate_layout(layout)?; - Ok(layout) - } - None => { - if layouts.len() == 1 { - Ok(layouts[0].layout()) - } else { - let allowed = layouts - .iter() - .map(|s| s.layout().label()) - .collect::>() - .join(", "); - anyhow::bail!( - "dataset {} hosts multiple layouts ([{}]): pass --layout to pick one", - dataset.name(), - allowed, - ); - } - } - } -} diff --git a/benchmarks/vector-search-bench/src/main.rs b/benchmarks/vector-search-bench/src/main.rs deleted file mode 100644 index 307c050d221..00000000000 --- a/benchmarks/vector-search-bench/src/main.rs +++ /dev/null @@ -1,284 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! `vector-search-bench` benchmarks for cosine-similarity scan and TurboQuant distortion. -//! -//! ```sh -//! cargo run -p vector-search-bench --release -- search \ -//! --dataset cohere-large-10m \ -//! --layout partitioned \ -//! --flavors vortex-uncompressed,vortex-turboquant \ -//! --iterations 3 \ -//! --threshold 0.8 -//! -//! cargo run -p vector-search-bench --release -- distortion \ -//! --dataset sift-small-500k \ -//! --bits 4 \ -//! --samples 4096 -//! ``` - -use std::path::PathBuf; - -use anyhow::Context; -use anyhow::Result; -use clap::Parser; -use clap::Subcommand; -use vector_search_bench::compression::ALL_VECTOR_FLAVORS; -use vector_search_bench::compression::VectorFlavor; -use vector_search_bench::display::DatasetReport; -use vector_search_bench::display::render; -use vector_search_bench::distortion::DistortionConfig; -use vector_search_bench::distortion::run_distortion; -use vector_search_bench::prepare::CompressedVortexDataset; -use vector_search_bench::prepare::prepare_all; -use vector_search_bench::query::get_random_query_vector; -use vector_search_bench::resolve_layout; -use vector_search_bench::scan::ScanConfig; -use vector_search_bench::scan::ScanTiming; -use vector_search_bench::scan::run_search_scan; -use vortex_bench::setup_logging_and_tracing; -use vortex_bench::v3; -use vortex_bench::vector_dataset; -use vortex_bench::vector_dataset::TrainLayout; -use vortex_bench::vector_dataset::VectorDataset; - -#[derive(Parser, Debug)] -#[command(version, about, long_about = None)] -struct Cli { - #[command(subcommand)] - command: Command, -} - -#[derive(Subcommand, Debug)] -enum Command { - /// On-disk cosine-similarity scan latency benchmark. - Search(SearchArgs), - /// TurboQuant distortion measurement: reconstruction error and cosine error. - Distortion(DistortionArgs), -} - -#[derive(Parser, Debug)] -struct SearchArgs { - /// Dataset to benchmark. Single dataset per CLI invocation by design — large datasets - /// are intentionally babysat one at a time. - #[arg(long, value_enum)] - dataset: VectorDataset, - - /// Train-split layout. Required when the dataset publishes more than one layout. - /// Defaults to the catalog's first hosted layout when omitted. - #[arg(long, value_enum)] - layout: Option, - - /// Comma-separated list of flavors to run. Each Vortex flavor produces one `.vortex` file per - /// train shard. - #[arg( - long, - value_delimiter = ',', - value_enum, - default_values_t = ALL_VECTOR_FLAVORS.to_vec(), - )] - flavors: Vec, - - /// Number of timed scan iterations per flavor. Mean and median are reported. - #[arg(long, default_value_t = 5)] - iterations: usize, - - /// Cosine threshold passed to the filter expression. - #[arg(long, default_value_t = 0.85)] - threshold: f32, - - /// Seed for the test-parquet query sampler. - #[arg(long, default_value_t = 42)] - query_seed: u64, - - /// Optional path to write the rendered table to instead of stdout. - #[arg(long)] - output_path: Option, - - /// Additionally write v3 JSONL records to this path. See - /// `benchmarks-website/planning/02-contracts.md`. - #[arg(long)] - gh_json_v3: Option, - - /// Emit verbose tracing. - #[arg(short, long)] - verbose: bool, - - /// Enable perfetto tracing output. - #[arg(long)] - tracing: bool, -} - -#[derive(Parser, Debug)] -struct DistortionArgs { - /// Dataset to load vectors from. One dataset per invocation. - #[arg(long, value_enum)] - dataset: VectorDataset, - - /// Train-split layout. Required when the dataset publishes more than one layout. - #[arg(long, value_enum)] - layout: Option, - - /// Bits per quantized coordinate. - #[arg(long, default_value_t = 4)] - bits: u8, - - /// Seed for the SORF rotation. - #[arg(long, default_value_t = 42)] - seed: u64, - - /// Number of sign-diagonal plus Walsh-Hadamard rounds in the SORF transform. - #[arg(long, default_value_t = 3)] - rounds: u8, - - /// Number of base vectors to sample from the first train shard (first N rows). - #[arg(long, default_value_t = 65536)] - samples: usize, - - /// Optional path to write the rendered table to instead of stdout. - #[arg(long)] - output_path: Option, - - /// Emit verbose tracing. - #[arg(short, long)] - verbose: bool, - - /// Enable perfetto tracing output. - #[arg(long)] - tracing: bool, -} - -#[tokio::main] -async fn main() -> Result<()> { - let cli = Cli::parse(); - match cli.command { - Command::Search(args) => run_search(args).await, - Command::Distortion(args) => run_distortion_cmd(args).await, - } -} - -async fn run_search(args: SearchArgs) -> Result<()> { - setup_logging_and_tracing(args.verbose, args.tracing)?; - - let dataset = args.dataset; - let layout = resolve_layout(dataset, args.layout)?; - tracing::info!( - "running {} on layout {} ({} dims, {} rows)", - dataset.name(), - layout, - dataset.dim(), - dataset.num_rows() - ); - - if args.flavors.is_empty() { - anyhow::bail!("no flavors selected, please pass at least one to --flavors"); - } - - let datasets_paths = vector_dataset::download(dataset, layout) - .await - .with_context(|| format!("download {}", dataset.name()))?; - - let prepared = prepare_all(dataset, layout, &datasets_paths, &args.flavors).await?; - - let query_vector = get_random_query_vector( - &datasets_paths.test, - dataset.dim(), - dataset.element_ptype(), - args.query_seed, - ) - .await?; - tracing::info!( - "sampled query id {} (dim={})", - query_vector.id, - query_vector.query.len() - ); - - let scan_config = ScanConfig { - iterations: args.iterations, - threshold: args.threshold, - }; - - let mut scan_timings: Vec = Vec::with_capacity(prepared.len()); - for prep in &prepared { - let timing = run_search_scan(prep, &query_vector.query, &scan_config).await?; - scan_timings.push(timing); - } - - let pairs: Vec<(VectorFlavor, &CompressedVortexDataset, &ScanTiming)> = prepared - .iter() - .zip(scan_timings.iter()) - .map(|(prep, scan)| (prep.flavor, prep, scan)) - .collect(); - let report = DatasetReport { - dataset_name: dataset.name(), - vortex_results: &pairs, - }; - - if let Some(path) = args.gh_json_v3.as_ref() { - let records: Vec = scan_timings - .iter() - .map(|scan| { - let all_runs_ns: Vec = scan - .all_runs - .iter() - .map(|d| u64::try_from(d.as_nanos()).unwrap_or(u64::MAX)) - .collect(); - let median_ns = u64::try_from(scan.median.as_nanos()).unwrap_or(u64::MAX); - v3::vector_search_record( - v3::VectorSearchDims { - dataset: dataset.name(), - layout: layout.label(), - flavor: scan.flavor.label(), - threshold: f64::from(args.threshold), - }, - median_ns, - all_runs_ns, - scan.matches, - scan.rows_scanned, - scan.bytes_scanned, - ) - }) - .collect(); - v3::write_jsonl_to_path(path, &records)?; - } - - if let Some(path) = args.output_path { - let mut file = - std::fs::File::create(&path).with_context(|| format!("create {}", path.display()))?; - render(&report, &mut file)?; - } else { - let stdout = std::io::stdout(); - let mut handle = stdout.lock(); - render(&report, &mut handle)?; - } - - Ok(()) -} - -async fn run_distortion_cmd(args: DistortionArgs) -> Result<()> { - setup_logging_and_tracing(args.verbose, args.tracing)?; - - let layout = resolve_layout(args.dataset, args.layout)?; - let config = DistortionConfig { - dataset: args.dataset, - layout, - bits: args.bits, - seed: args.seed, - rounds: args.rounds, - samples: args.samples, - }; - - let report = run_distortion(&config).await?; - - if let Some(path) = args.output_path { - let mut file = - std::fs::File::create(&path).with_context(|| format!("create {}", path.display()))?; - report.render(&mut file)?; - } else { - let stdout = std::io::stdout(); - let mut handle = stdout.lock(); - report.render(&mut handle)?; - } - - Ok(()) -} diff --git a/benchmarks/vector-search-bench/src/prepare.rs b/benchmarks/vector-search-bench/src/prepare.rs deleted file mode 100644 index 8cf9d9860ed..00000000000 --- a/benchmarks/vector-search-bench/src/prepare.rs +++ /dev/null @@ -1,227 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! Per-flavor on-disk ingest. -//! -//! For each `(dataset, layout, flavor)` triple, [`prepare_flavor`] streams every parquet shard -//! and writes one `.vortex` file per shard. The pipeline is idempotent (existing `.vortex` files -//! are skipped) and reports end-to-end wall-clock time, summed input parquet bytes, and total -//! output bytes. - -use std::path::Path; -use std::path::PathBuf; - -use anyhow::Context; -use anyhow::Result; -use futures::StreamExt; -use parquet::arrow::ParquetRecordBatchStreamBuilder; -use tokio::fs::File; -use tokio::io::AsyncWriteExt; -use tracing::info; -use tracing::warn; -use vortex::array::ArrayRef; -use vortex::array::ExecutionCtx; -use vortex::array::VortexSessionExecute; -use vortex::array::stream::ArrayStreamAdapter; -use vortex::array::stream::ArrayStreamExt; -use vortex::error::VortexResult; -use vortex::error::vortex_err; -use vortex_bench::conversions::parquet_to_vortex_stream; -use vortex_bench::data_dir; -use vortex_bench::utils::file::idempotent_async; -use vortex_bench::vector_dataset::DatasetPaths; -use vortex_bench::vector_dataset::TrainLayout; -use vortex_bench::vector_dataset::VectorDataset; - -use crate::SESSION; -use crate::compression::VectorFlavor; -use crate::ingest::transform_chunk; - -/// The paths of the vortex files that result from preparing one `(dataset, layout, flavor)` triple. -#[derive(Debug, Clone)] -pub struct CompressedVortexDataset { - pub dataset: VectorDataset, - pub layout: TrainLayout, - pub flavor: VectorFlavor, - pub vortex_files: Vec, -} - -/// Drive [`prepare_flavor`] across a list of flavors, returning a [`CompressedVortexDataset`] per -/// flavor in input order. -pub async fn prepare_all( - dataset: VectorDataset, - layout: TrainLayout, - paths_for_dataset: &DatasetPaths, - flavors: &[VectorFlavor], -) -> Result> { - let mut results = Vec::with_capacity(flavors.len()); - - for &flavor in flavors { - let r = prepare_flavor(dataset, layout, paths_for_dataset, flavor).await?; - results.push(r); - } - - Ok(results) -} - -/// Prepare one flavor of one dataset by writing one `.vortex` file per train shard. -/// -/// This function is sequential (for now). -pub async fn prepare_flavor( - dataset: VectorDataset, - layout: TrainLayout, - paths_for_dataset: &DatasetPaths, - flavor: VectorFlavor, -) -> Result { - let mut vortex_files = Vec::with_capacity(paths_for_dataset.train_files.len()); - - for parquet_path in &paths_for_dataset.train_files { - let parquet_path = parquet_path.clone(); - let vortex_path = parquet_to_vortex_path(&parquet_path, dataset, layout, flavor)?; - - let already_cached = vortex_path.exists(); - if already_cached { - warn!( - "skipping cached vortex shard {} ({} flavor)", - vortex_path.display(), - flavor.label() - ); - } else { - info!( - "ingesting {} -> {} ({} flavor)", - parquet_path.display(), - vortex_path.display(), - flavor.label(), - ); - } - - let written_path = idempotent_async(vortex_path.as_path(), |tmp| async move { - write_shard_streaming(&parquet_path, &tmp, flavor).await - }) - .await?; - - vortex_files.push(written_path); - } - - Ok(CompressedVortexDataset { - dataset, - layout, - flavor, - vortex_files, - }) -} - -/// Stream one parquet shard through the chunk transform into a Vortex file. -/// -/// The output dtype is derived once from the first transformed chunk so the [`ArrayStreamAdapter`] -/// can declare it ahead of time. -async fn write_shard_streaming( - parquet_path: &Path, - vortex_path: &Path, - flavor: VectorFlavor, -) -> Result<()> { - let file = File::open(parquet_path).await?; - let builder = ParquetRecordBatchStreamBuilder::new(file).await?; - let mut array_stream = parquet_to_vortex_stream(builder.build()?); - - let mut ctx = SESSION.create_execution_ctx(); - - // We need to get the first chunk so that we know what the dtype of the file is. - let first = match array_stream.next().await { - Some(chunk) => transform_chunk_with_error(chunk, &mut ctx, parquet_path, 1)?, - None => { - return Err(vortex_err!( - "ingest: parquet shard {} produced no chunks", - parquet_path.display(), - ) - .into()); - } - }; - let dtype = first.dtype().clone(); - let shard_path = parquet_path.to_path_buf(); - - let transformed = - futures::stream::iter(std::iter::once(Ok(first))).chain(array_stream.enumerate().map( - move |(chunk_offset, chunk_or_err)| { - let mut local_ctx = SESSION.create_execution_ctx(); - transform_chunk_with_error( - chunk_or_err, - &mut local_ctx, - &shard_path, - chunk_offset + 2, - ) - }, - )); - - let stream = ArrayStreamExt::boxed(ArrayStreamAdapter::new(dtype, transformed)); - - let mut output = tokio::fs::OpenOptions::new() - .write(true) - .truncate(true) - .create(true) - .open(vortex_path) - .await?; - - flavor - .create_write_options(&SESSION) - .write(&mut output, stream) - .await?; - output.flush().await?; - - Ok(()) -} - -fn transform_chunk_with_error( - chunk_or_err: VortexResult, - ctx: &mut ExecutionCtx, - parquet_path: &Path, - chunk_idx: usize, -) -> VortexResult { - let chunk = chunk_or_err.map_err(|err| { - vortex_err!( - "ingest: failed to read chunk {} from {}: {err:#}", - chunk_idx, - parquet_path.display(), - ) - })?; - - transform_chunk(chunk, ctx).map_err(|err| { - vortex_err!( - "ingest: failed to transform chunk {} from {}: {err:#}", - chunk_idx, - parquet_path.display(), - ) - }) -} - -/// Translate a parquet shard path to its `.vortex` companion under the flavor directory. -/// -/// Just swaps the file extension and rebases the file name into the per-[`VectorFlavor`] flavor -/// directory. The shard stem is preserved so a directory listing pairs `00-of-10.parquet` with -/// `00-of-10.vortex`. -pub fn parquet_to_vortex_path( - parquet: &Path, - dataset: VectorDataset, - layout: TrainLayout, - flavor: VectorFlavor, -) -> Result { - let stem = parquet - .file_stem() - .with_context(|| format!("parquet path {} has no file stem", parquet.display()))? - .to_owned(); - - // TODO(connor): Is there a better way to do this? - let mut name = stem; - name.push(".vortex"); - - Ok(flavor_dir(dataset, layout, flavor).join(name)) -} - -/// `vortex-bench/data/vector-search////`. -fn flavor_dir(ds: VectorDataset, layout: TrainLayout, flavor: VectorFlavor) -> PathBuf { - data_dir() - .join("vector-search") - .join(ds.name()) - .join(layout.label()) - .join(flavor.dir_name()) -} diff --git a/benchmarks/vector-search-bench/src/query.rs b/benchmarks/vector-search-bench/src/query.rs deleted file mode 100644 index cbd9d3bf38e..00000000000 --- a/benchmarks/vector-search-bench/src/query.rs +++ /dev/null @@ -1,120 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! Sample one query vector from `test.parquet`. -//! -//! The vector datasets ship a `test.parquet` alongside the train split: these are the query vectors -//! meant to be issued against the index. -//! -//! The benchmark picks a single random row (seeded for reproducibility) and uses it as the query -//! for the scan. - -use std::path::Path; - -use anyhow::Context; -use anyhow::Result; -use anyhow::bail; -use anyhow::ensure; -use rand::RngExt; -use rand::SeedableRng; -use rand::rngs::StdRng; -use vortex::array::IntoArray; -use vortex::array::VortexSessionExecute; -use vortex::array::arrays::StructArray; -use vortex::array::arrays::struct_::StructArrayExt; -use vortex::dtype::PType; -use vortex::error::VortexExpect; -use vortex::error::vortex_err; -use vortex_bench::conversions::parquet_to_vortex_chunks; - -use crate::SESSION; - -/// One query vector sampled from `test.parquet`. -#[derive(Debug, Clone)] -pub struct QuerySample { - /// The ID of the vector. - pub id: i64, - /// f32 query values, length `dim`. - pub query: Vec, -} - -/// Sample one query row from `test.parquet`. -/// -/// The cast to f32 happens here when the source is f64 (matching the prepare-side cast), so that -/// all downstream code is uniformly f32. -pub async fn get_random_query_vector( - test_parquet: &Path, - expected_dim: u32, - src_ptype: PType, - seed: u64, -) -> Result { - let mut ctx = SESSION.create_execution_ctx(); - - let chunked = parquet_to_vortex_chunks(test_parquet.to_path_buf()) - .await - .with_context(|| format!("read test parquet {}", test_parquet.display()))?; - // The `test.parquet` files are generally small enough that this is not a big deal. - let struct_array: StructArray = chunked.into_array().execute(&mut ctx)?; - - let id = struct_array - .unmasked_field_by_name("id") - .context("test parquet missing `id` column")? - .clone(); - let emb = struct_array - .unmasked_field_by_name("emb") - .context("test parquet missing `emb` column")? - .clone(); - - let mut rng = StdRng::seed_from_u64(seed); - let query_row_idx = rng.random_range(0..id.len()); - - let id_scalar = id.execute_scalar(query_row_idx, &mut ctx)?; - let emb_scalar = emb.execute_scalar(query_row_idx, &mut ctx)?; - - ensure!(emb_scalar.as_list().len() == expected_dim as usize); - - let id = id_scalar - .as_primitive() - .as_::() - .ok_or_else(|| vortex_err!("vector ID was not a i64"))?; - - let query_vector = match src_ptype { - PType::F32 => emb_scalar - .as_list() - .elements() - .vortex_expect("somehow had a null test vector") - .iter() - .map(|element| { - element - .as_primitive() - .as_::() - .vortex_expect("value was not a f32") - }) - .collect(), - PType::F64 => - { - #[expect( - clippy::cast_possible_truncation, - reason = "this is intentionally lossy" - )] - emb_scalar - .as_list() - .elements() - .vortex_expect("somehow had a null test vector") - .iter() - .map(|element| { - element - .as_primitive() - .as_::() - .vortex_expect("value was not a f64") as f32 - }) - .collect() - } - ptype => bail!("source ptype {ptype} was somehow not f32 or f64"), - }; - - Ok(QuerySample { - query: query_vector, - id, - }) -} diff --git a/benchmarks/vector-search-bench/src/scan.rs b/benchmarks/vector-search-bench/src/scan.rs deleted file mode 100644 index 81a054f74b3..00000000000 --- a/benchmarks/vector-search-bench/src/scan.rs +++ /dev/null @@ -1,177 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! Per-iteration scan driver. -//! -//! Each iteration re-opens every `.vortex` shard fresh (so the segment cache is re-primed -//! per run), pushes the cosine-similarity filter through the scan, and drains the resulting -//! [`vortex::array::stream::ArrayStream`]. The wall-clock around the entire per-iteration -//! pass is the headline number; we track the mean and median across iterations. - -use std::path::Path; -use std::path::PathBuf; -use std::time::Duration; -use std::time::Instant; - -use anyhow::Context; -use anyhow::Result; -use futures::TryStreamExt; -use vortex::array::ArrayRef; -use vortex::file::OpenOptionsSessionExt; - -use crate::SESSION; -use crate::compression::VectorFlavor; -use crate::expression::similarity_filter; -use crate::prepare::CompressedVortexDataset; - -/// Inputs to a scan run. -#[derive(Debug, Clone)] -pub struct ScanConfig { - /// Number of timed iterations (best-of-N). - pub iterations: usize, - /// Cosine threshold passed to the filter expression. - pub threshold: f32, -} - -/// Aggregate timing + counters for one `(flavor)` scan. -#[derive(Debug, Clone)] -pub struct ScanTiming { - /// Which compression flavor's `.vortex` files were scanned. - pub flavor: VectorFlavor, - /// Arithmetic mean of the per-iteration wall times. - pub mean: Duration, - /// Median of the per-iteration wall times. - pub median: Duration, - /// Per-iteration wall times in run order. - pub all_runs: Vec, - /// Number of rows that survived the filter (constant across iterations because the - /// filter is deterministic). - pub matches: u64, - /// Total rows scanned (sum of file row counts) as a sanity check that the iteration - /// actually walked the files. - pub rows_scanned: u64, - /// Total on-disk size of the scanned `.vortex` files, in bytes. - pub bytes_scanned: u64, -} - -/// Scan every shard in a [`CompressedVortexDataset`] under the given config. -pub async fn run_search_scan( - dataset: &CompressedVortexDataset, - query: &[f32], - config: &ScanConfig, -) -> Result { - anyhow::ensure!( - config.iterations > 0, - "scan iterations must be >= 1, got {}", - config.iterations - ); - - let bytes_scanned = total_file_size(&dataset.vortex_files)?; - - let mut all_runs = Vec::with_capacity(config.iterations); - let mut matches = 0u64; - let mut rows_scanned = 0u64; - - for iter_idx in 0..config.iterations { - let (wall, iter_matches, iter_rows) = - run_one_iteration(&dataset.vortex_files, query, config.threshold).await?; - tracing::debug!( - "{} iter {} -> {:?} ({} matches, {} rows)", - dataset.flavor.label(), - iter_idx, - wall, - iter_matches, - iter_rows, - ); - // Matches and row counts are deterministic across iterations; reset rather than - // accumulate so the reported value matches a single pass. - matches = iter_matches; - rows_scanned = iter_rows; - all_runs.push(wall); - } - - Ok(ScanTiming { - flavor: dataset.flavor, - mean: mean(&all_runs), - median: median(&all_runs), - all_runs, - matches, - rows_scanned, - bytes_scanned, - }) -} - -/// Sum the on-disk sizes of the given files. -fn total_file_size(paths: &[PathBuf]) -> Result { - let mut total = 0u64; - for path in paths { - let meta = - std::fs::metadata(path).with_context(|| format!("stat {} for size", path.display()))?; - total = total.saturating_add(meta.len()); - } - Ok(total) -} - -async fn run_one_iteration( - vortex_files: &[PathBuf], - query: &[f32], - threshold: f32, -) -> Result<(Duration, u64, u64)> { - let mut matches = 0u64; - let mut rows_scanned = 0u64; - - let started = Instant::now(); - for path in vortex_files { - let (m, r) = scan_one_file(path, query, threshold).await?; - matches = matches.saturating_add(m); - rows_scanned = rows_scanned.saturating_add(r); - } - - Ok((started.elapsed(), matches, rows_scanned)) -} - -async fn scan_one_file(path: &Path, query: &[f32], threshold: f32) -> Result<(u64, u64)> { - let file = SESSION - .open_options() - .open_path(path) - .await - .with_context(|| format!("open {}", path.display()))?; - - let total_rows = file.row_count(); - let filter = similarity_filter(query, threshold)?; - let chunks: Vec = file - .scan()? - .with_filter(filter) - .into_array_stream()? - .try_collect() - .await?; - - let matches: u64 = chunks.iter().map(|c| c.len() as u64).sum(); - Ok((matches, total_rows)) -} - -/// Arithmetic mean of a list of [`Duration`]s. Empty lists return [`Duration::ZERO`]. -pub fn mean(runs: &[Duration]) -> Duration { - if runs.is_empty() { - return Duration::ZERO; - } - let total_nanos: u128 = runs.iter().map(|d| d.as_nanos()).sum(); - let avg_nanos = total_nanos / runs.len() as u128; - Duration::from_nanos(u64::try_from(avg_nanos).unwrap_or(u64::MAX)) -} - -/// Median of a list of [`Duration`]s. Empty lists return [`Duration::ZERO`]. -pub fn median(runs: &[Duration]) -> Duration { - if runs.is_empty() { - return Duration::ZERO; - } - let mut sorted = runs.to_vec(); - sorted.sort(); - let mid = sorted.len() / 2; - if sorted.len() % 2 == 1 { - sorted[mid] - } else { - let total_nanos = sorted[mid - 1].as_nanos() + sorted[mid].as_nanos(); - Duration::from_nanos(u64::try_from(total_nanos / 2).unwrap_or(u64::MAX)) - } -} diff --git a/vortex-btrblocks/Cargo.toml b/vortex-btrblocks/Cargo.toml index 1adb6508828..3991eccb8c7 100644 --- a/vortex-btrblocks/Cargo.toml +++ b/vortex-btrblocks/Cargo.toml @@ -35,7 +35,6 @@ vortex-pco = { workspace = true, optional = true } vortex-runend = { workspace = true } vortex-sequence = { workspace = true } vortex-sparse = { workspace = true } -vortex-tensor = { workspace = true, optional = true } vortex-utils = { workspace = true } vortex-zigzag = { workspace = true } vortex-zstd = { workspace = true, optional = true } @@ -49,11 +48,7 @@ vortex-session = { workspace = true } [features] # This feature enabled unstable encodings for which we don't guarantee stability. -unstable_encodings = [ - "dep:vortex-tensor", - "dep:vortex-onpair", - "vortex-zstd?/unstable_encodings", -] +unstable_encodings = ["dep:vortex-onpair", "vortex-zstd?/unstable_encodings"] pco = ["dep:pco", "dep:vortex-pco"] zstd = ["dep:vortex-zstd"] diff --git a/vortex-btrblocks/src/builder.rs b/vortex-btrblocks/src/builder.rs index c0a0eaeb5eb..43df788473a 100644 --- a/vortex-btrblocks/src/builder.rs +++ b/vortex-btrblocks/src/builder.rs @@ -158,22 +158,6 @@ impl BtrBlocksCompressorBuilder { builder } - /// Adds the TurboQuant lossy vector quantization scheme. - /// - /// When enabled, [`Vector`] extension arrays are compressed using the TurboQuant algorithm - /// with MSE-optimal scalar quantization. - /// - /// # Panics - /// - /// Panics if the TurboQuant scheme is already present. - /// - /// [`Vector`]: vortex_tensor::vector::Vector - #[cfg(feature = "unstable_encodings")] - pub fn with_turboquant(self) -> Self { - use vortex_tensor::encodings::turboquant::TurboQuantScheme; - self.with_new_scheme(&TurboQuantScheme) - } - /// Excludes schemes without CUDA kernel support and adds Zstd for string and binary compression. /// /// With the `unstable_encodings` feature, buffer-level Zstd compression is used which diff --git a/vortex-tensor/src/encodings/mod.rs b/vortex-tensor/src/encodings/mod.rs index 22e57763171..e42a8605096 100644 --- a/vortex-tensor/src/encodings/mod.rs +++ b/vortex-tensor/src/encodings/mod.rs @@ -7,4 +7,3 @@ // pub mod spherical; // Spherical transform on unit-normalized vectors. pub mod l2_denorm; -pub mod turboquant; diff --git a/vortex-tensor/src/encodings/turboquant/centroids.rs b/vortex-tensor/src/encodings/turboquant/centroids.rs deleted file mode 100644 index ab653de3c35..00000000000 --- a/vortex-tensor/src/encodings/turboquant/centroids.rs +++ /dev/null @@ -1,361 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! Max-Lloyd centroid computation for TurboQuant scalar quantizers. -//! -//! Pre-computes and caches optimal scalar quantizer centroids for the marginal distribution of -//! coordinates after a random orthogonal transform of a unit-norm vector. -//! -//! In high dimensions, each coordinate of a randomly transformed unit vector follows a -//! distribution proportional to `(1 - x^2)^((d-3)/2)` on `[-1, 1]`, which converges to -//! `N(0, 1/d)`. -//! -//! The Max-Lloyd algorithm finds optimal quantization centroids that minimize MSE for this -//! distribution. -//! -//! Centroids are not stored in TurboQuant arrays. They are deterministically derived from -//! `(padded_dim, bit_width)` and cached process-locally. -//! -//! The centroid model follows the random orthogonal transform marginal used by the TurboQuant -//! paper. This encoder applies a SORF-style structured transform instead of a dense random Gaussian -//! or orthogonal matrix, so paper-level error bounds should not be treated as verified for this -//! implementation without separate empirical validation. - -use std::sync::LazyLock; - -use vortex_buffer::Buffer; -use vortex_error::VortexResult; -use vortex_error::vortex_ensure; -use vortex_utils::aliases::dash_map::DashMap; - -use crate::encodings::turboquant::MAX_BIT_WIDTH; -use crate::encodings::turboquant::MIN_DIMENSION; - -// NB: All of these constants were chosen arbitrarily. - -/// The maximum iterations for Max-Lloyd algorithm when computing centroids. -const MAX_ITERATIONS: usize = 200; - -/// The Max-Lloyd convergence threshold for stopping early when computing centroids. -const CONVERGENCE_EPSILON: f64 = 1e-12; - -/// Number of trapezoids used for numerical integration when computing conditional expectations. -/// -/// The trapezoidal rule evaluates the integrand at `INTEGRATION_TRAPEZOIDS + 1` points. -const INTEGRATION_TRAPEZOIDS: usize = 1000; - -/// Global centroid cache keyed by (dimension, bit_width). -static CENTROID_CACHE: LazyLock>> = LazyLock::new(DashMap::default); - -/// Get or compute cached centroids for the given dimension and bit width. -/// -/// Returns `2^bit_width` centroids sorted in ascending order, representing optimal scalar -/// quantization levels for the coordinate distribution after a random orthogonal transform in -/// `dimension`-dimensional space. -pub(crate) fn compute_or_get_centroids(dimension: u32, bit_width: u8) -> VortexResult> { - vortex_ensure!( - (1..=MAX_BIT_WIDTH).contains(&bit_width), - "TurboQuant bit_width must be 1-{}, got {bit_width}", - MAX_BIT_WIDTH - ); - vortex_ensure!( - dimension >= MIN_DIMENSION, - "TurboQuant dimension must be >= {}, got {dimension}", - MIN_DIMENSION - ); - - if let Some(centroids) = CENTROID_CACHE.get(&(dimension, bit_width)) { - return Ok(centroids.clone()); - } - - let centroids = max_lloyd_centroids(dimension, bit_width); - CENTROID_CACHE.insert((dimension, bit_width), centroids.clone()); - - Ok(centroids) -} - -// TODO(connor): It would potentially be more performant if this was modelled as const generic -// parameters to functions. -/// Half-integer exponent: represents `int_part + (if has_half { 0.5 } else { 0.0 })`. -/// -/// The marginal distribution exponent `(d-3)/2` is always an integer (when `d` is odd) or a -/// half-integer (when `d` is even). -/// -/// This type makes that invariant explicit and avoids floating-point comparison in the hot path. -#[derive(Clone, Copy, Debug)] -struct HalfIntExponent { - int_part: i32, - has_half: bool, -} - -impl HalfIntExponent { - /// Compute `(numerator) / 2` as a half-integer exponent. - /// - /// `numerator` is `d - 3` where `d` is the dimension (>= 2), so it can be negative. - fn from_numerator(numerator: i32) -> Self { - // Use Euclidean division to get floor division toward negative infinity. - let int_part = numerator.div_euclid(2); - let has_half = numerator.rem_euclid(2) != 0; - Self { int_part, has_half } - } -} - -/// Compute optimal centroids via the Max-Lloyd (Lloyd-Max) algorithm. -/// -/// Operates on the marginal distribution of a single coordinate of a randomly transformed unit -/// vector in d dimensions. -/// -/// The probability distribution function is: -/// `f(x) = C_d * (1 - x^2)^((d-3)/2)` on `[-1, 1]` -/// where `C_d` is the normalizing constant. -/// -/// Centroids are seeded uniformly on `[±sqrt(bit_width) * sigma]` (where `sigma` is the standard -/// deviation of the normal distribution that hypershere dimension values take, and specifically -/// `sigma = 1/sqrt(dimension)`) rather than across the full `[-1, 1]`, which strands most of the -/// centroids in the near-zero-mass tails. -/// -/// Note that the `sqrt(bit_width)` is mostly empirically derived, we do not have a theoretical -/// basis for choosing this other than the fact that it seems to produce good results. -fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Buffer { - debug_assert!((1..=MAX_BIT_WIDTH).contains(&bit_width)); - let num_centroids = 1usize << bit_width; - - // For the marginal distribution on [-1, 1], we use the exponent (d-3)/2. - let exponent = HalfIntExponent::from_numerator(dimension as i32 - 3); - - // The coordinate marginal concentrates around 0 with this standard deviation. - let sigma = 1.0 / f64::from(dimension).sqrt(); - let init_half = (f64::from(bit_width).sqrt() * sigma).min(1.0); - - // Initialize centroids uniformly on [-init_half, init_half], where the mass lives, so no cell - // starts in a zero-mass region and freezes. - let mut centroids: Vec = (0..num_centroids) - .map(|idx| -init_half + (2.0 * (idx as f64) + 1.0) * init_half / (num_centroids as f64)) - .collect(); - - let mut boundaries: Vec = vec![0.0; num_centroids + 1]; - for _ in 0..MAX_ITERATIONS { - // Compute decision boundaries (midpoints between adjacent centroids). - boundaries[0] = -1.0; - for idx in 0..num_centroids - 1 { - boundaries[idx + 1] = (centroids[idx] + centroids[idx + 1]) / 2.0; - } - boundaries[num_centroids] = 1.0; - - // Update each centroid to the conditional mean within its Voronoi cell. - let mut max_change = 0.0f64; - for idx in 0..num_centroids { - let lo = boundaries[idx]; - let hi = boundaries[idx + 1]; - let new_centroid = mean_between_centroids(lo, hi, exponent); - max_change = max_change.max((new_centroid - centroids[idx]).abs()); - centroids[idx] = new_centroid; - } - - if max_change < CONVERGENCE_EPSILON { - break; - } - } - - #[expect( - clippy::cast_possible_truncation, - reason = "all values are in [-1, 1] so this just loses precision" - )] - centroids.into_iter().map(|val| val as f32).collect() -} - -/// Compute the conditional mean of the coordinate distribution on interval [lo, hi]. -/// -/// Returns `E[X | lo <= X <= hi]` where X has PDF proportional to `(1 - x^2)^exponent` on [-1, 1]. -/// -/// Since there is no closed form for the integrals, we compute this numerically. -fn mean_between_centroids(lo: f64, hi: f64, exponent: HalfIntExponent) -> f64 { - if (hi - lo).abs() < 1e-15 { - return (lo + hi) / 2.0; - } - - let dx = (hi - lo) / INTEGRATION_TRAPEZOIDS as f64; - - let mut numerator = 0.0; - let mut denominator = 0.0; - - for step in 0..=INTEGRATION_TRAPEZOIDS { - let x_val = lo + (step as f64) * dx; - let weight = pdf_unnormalized(x_val, exponent); - - let trap_weight = if step == 0 || step == INTEGRATION_TRAPEZOIDS { - 0.5 - } else { - 1.0 - }; - - numerator += trap_weight * x_val * weight; - denominator += trap_weight * weight; - } - - if denominator.abs() < 1e-30 { - (lo + hi) / 2.0 - } else { - numerator / denominator - } -} - -/// Unnormalized PDF of the coordinate distribution: `(1 - x^2)^exponent`. -/// -/// Uses `powi` + `sqrt` instead of `powf` for the half-integer exponents that arise from `(d-3)/2`. -/// This is significantly faster than the general `powf` which goes through -/// `exp(exponent * ln(base))`. -fn pdf_unnormalized(x_val: f64, exponent: HalfIntExponent) -> f64 { - let base = (1.0 - x_val * x_val).max(0.0); - - if exponent.has_half { - // Half-integer exponent: base^(int_part) * sqrt(base). - base.powi(exponent.int_part) * base.sqrt() - } else { - // Integer exponent: use powi directly. - base.powi(exponent.int_part) - } -} - -/// Precompute decision boundaries (midpoints between adjacent centroids). -/// -/// For `k` centroids, returns `k-1` boundaries. A value below `boundaries[0]` maps to centroid 0, a -/// value in `[boundaries[i-1], boundaries[i])` maps to centroid `i`, and a -/// value `>= boundaries[k-2]` maps to centroid `k-1`. -pub(crate) fn compute_centroid_boundaries(centroids: &[f32]) -> Vec { - centroids.windows(2).map(|w| (w[0] + w[1]) * 0.5).collect() -} - -/// Find the index of the nearest centroid using precomputed decision boundaries. -/// -/// `boundaries` must be the output of [`compute_centroid_boundaries`] for the corresponding -/// centroids. Uses binary search on the midpoints, avoiding distance comparisons -/// in the inner loop. -#[inline] -pub(crate) fn find_nearest_centroid(value: f32, boundaries: &[f32]) -> u8 { - debug_assert!( - boundaries.windows(2).all(|w| w[0] <= w[1]), - "boundaries must be sorted" - ); - debug_assert!( - boundaries.len() <= 256, // 1 << 8 - "too many boundaries" - ); - - #[expect( - clippy::cast_possible_truncation, - reason = "num_centroids <= 256 and partition_point will return at most 255" - )] - (boundaries.partition_point(|&b| b < value) as u8) -} - -#[cfg(test)] -mod tests { - use rstest::rstest; - use vortex_error::VortexResult; - - use super::*; - - #[rstest] - #[case(128, 1, 2)] - #[case(128, 2, 4)] - #[case(128, 3, 8)] - #[case(128, 4, 16)] - #[case(768, 2, 4)] - #[case(1536, 3, 8)] - fn centroids_have_correct_count( - #[case] dim: u32, - #[case] bits: u8, - #[case] expected: usize, - ) -> VortexResult<()> { - let centroids = compute_or_get_centroids(dim, bits)?; - assert_eq!(centroids.len(), expected); - Ok(()) - } - - #[rstest] - #[case(128, 1)] - #[case(128, 2)] - #[case(128, 3)] - #[case(128, 4)] - #[case(768, 2)] - fn centroids_are_sorted(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> { - let centroids = compute_or_get_centroids(dim, bits)?; - for window in centroids.windows(2) { - assert!( - window[0] < window[1], - "centroids not sorted: {:?}", - centroids - ); - } - Ok(()) - } - - #[rstest] - #[case(128, 1)] - #[case(128, 2)] - #[case(256, 2)] - #[case(768, 2)] - fn centroids_are_symmetric(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> { - let centroids = compute_or_get_centroids(dim, bits)?; - let count = centroids.len(); - for idx in 0..count / 2 { - let diff = (centroids[idx] + centroids[count - 1 - idx]).abs(); - assert!( - diff < 1e-5, - "centroids not symmetric: c[{idx}]={}, c[{}]={}", - centroids[idx], - count - 1 - idx, - centroids[count - 1 - idx] - ); - } - Ok(()) - } - - #[rstest] - #[case(128, 1)] - #[case(128, 4)] - fn centroids_within_bounds(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> { - let centroids = compute_or_get_centroids(dim, bits)?; - for &val in centroids.iter() { - assert!( - (-1.0..=1.0).contains(&val), - "centroid out of [-1, 1]: {val}", - ); - } - Ok(()) - } - - #[test] - fn centroids_cached() -> VortexResult<()> { - let c1 = compute_or_get_centroids(128, 2)?; - let c2 = compute_or_get_centroids(128, 2)?; - assert_eq!(c1, c2); - Ok(()) - } - - #[test] - fn find_nearest_basic() -> VortexResult<()> { - let centroids = compute_or_get_centroids(128, 2)?; - let boundaries = compute_centroid_boundaries(¢roids); - assert_eq!(find_nearest_centroid(-1.0, &boundaries), 0); - - #[expect(clippy::cast_possible_truncation)] - let last_idx = (centroids.len() - 1) as u8; - assert_eq!(find_nearest_centroid(1.0, &boundaries), last_idx); - for (idx, &cv) in centroids.iter().enumerate() { - #[expect(clippy::cast_possible_truncation)] - let expected = idx as u8; - assert_eq!(find_nearest_centroid(cv, &boundaries), expected); - } - Ok(()) - } - - #[test] - fn rejects_invalid_params() { - assert!(compute_or_get_centroids(128, 0).is_err()); - assert!(compute_or_get_centroids(128, 9).is_err()); - assert!(compute_or_get_centroids(1, 2).is_err()); - assert!(compute_or_get_centroids(127, 2).is_err()); - } -} diff --git a/vortex-tensor/src/encodings/turboquant/compress.rs b/vortex-tensor/src/encodings/turboquant/compress.rs deleted file mode 100644 index e656ba18822..00000000000 --- a/vortex-tensor/src/encodings/turboquant/compress.rs +++ /dev/null @@ -1,270 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! TurboQuant encoding (quantization) logic. -//! -//! The input to [`turboquant_encode`] must be a non-nullable [`Vector`](crate::vector::Vector) -//! extension array whose rows are already L2-normalized (unit norm). Normalization is handled -//! externally by [`normalize_as_l2_denorm`](crate::scalar_fns::l2_denorm::normalize_as_l2_denorm), -//! which the [`TurboQuantScheme`] calls before invoking this function. -//! -//! [`TurboQuantScheme`]: crate::encodings::turboquant::TurboQuantScheme - -use vortex_array::ArrayRef; -use vortex_array::ArrayView; -use vortex_array::ExecutionCtx; -use vortex_array::IntoArray; -use vortex_array::arrays::Extension; -use vortex_array::arrays::FixedSizeListArray; -use vortex_array::arrays::PrimitiveArray; -use vortex_array::arrays::dict::DictArray; -use vortex_array::arrays::extension::ExtensionArrayExt; -use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; -use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; -use vortex_array::dtype::Nullability; -use vortex_array::validity::Validity; -use vortex_buffer::Buffer; -use vortex_buffer::BufferMut; -use vortex_error::VortexExpect; -use vortex_error::VortexResult; -use vortex_error::vortex_ensure; - -use crate::encodings::turboquant::MAX_BIT_WIDTH; -use crate::encodings::turboquant::MIN_DIMENSION; -use crate::encodings::turboquant::centroids::compute_centroid_boundaries; -use crate::encodings::turboquant::centroids::compute_or_get_centroids; -use crate::encodings::turboquant::centroids::find_nearest_centroid; -use crate::scalar_fns::l2_denorm::L2Denorm; -use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; -use crate::scalar_fns::sorf_transform::SorfMatrix; -use crate::scalar_fns::sorf_transform::SorfOptions; -use crate::scalar_fns::sorf_transform::SorfTransform; -use crate::types::vector::AnyVector; -use crate::types::vector::Vector; -use crate::utils::cast_to_f32; - -/// Configuration for TurboQuant encoding. -#[derive(Clone, Debug)] -pub struct TurboQuantConfig { - /// Bits per coordinate (1-8). - pub bit_width: u8, - /// Seed for the rotation matrix. - pub seed: u64, - /// Number of sign-diagonal + WHT rounds in the structured rotation (default 3). - pub num_rounds: u8, -} - -impl Default for TurboQuantConfig { - fn default() -> Self { - Self { - bit_width: MAX_BIT_WIDTH, - seed: 42, - num_rounds: 3, - } - } -} - -/// Apply the full TurboQuant compression pipeline to a [`Vector`](crate::vector::Vector) -/// extension array: normalize the rows via [`normalize_as_l2_denorm`], quantize the normalized -/// child via [`turboquant_encode_unchecked`], and reattach the stored norms as the outer -/// [`L2Denorm`] wrapper. -/// -/// The returned array has the canonical TurboQuant shape: -/// -/// ```text -/// ScalarFnArray(L2Denorm, [ -/// ScalarFnArray(SorfTransform, [FSL(Dict(codes, centroids))]), -/// norms, -/// ]) -/// ``` -/// -/// # Errors -/// -/// Returns an error if `input` is not a tensor-like extension array, if normalization fails, or -/// if [`turboquant_encode_unchecked`] rejects the input shape. -pub fn turboquant_encode( - input: ArrayRef, - config: &TurboQuantConfig, - ctx: &mut ExecutionCtx, -) -> VortexResult { - // We must normalize the array before we can encode it with TurboQuant. - let l2_denorm = normalize_as_l2_denorm(input, ctx)?; - let normalized = l2_denorm.child_at(0).clone(); - let norms = l2_denorm.child_at(1).clone(); - let num_rows = l2_denorm.len(); - - let normalized_ext = normalized - .as_opt::() - .vortex_expect("normalize_as_l2_denorm always produces an Extension array child"); - - // SAFETY: `normalize_as_l2_denorm` guarantees every row is unit-norm (or zero for null rows). - let tq = unsafe { turboquant_encode_unchecked(normalized_ext, config, ctx) }?; - - // SAFETY: TurboQuant is a lossy approximation of the normalized child, so we intentionally - // bypass the strict normalized-row validation when reattaching the stored norms. - Ok(unsafe { L2Denorm::new_array_unchecked(tq, norms, num_rows) }?.into_array()) -} - -/// Encode a non-nullable, L2-normalized [`Vector`](crate::vector::Vector) extension array into a -/// `ScalarFnArray(SorfTransform, [FSL(Dict(codes, centroids))])`, without validating the unit-norm -/// precondition. -/// -/// # Safety -/// -/// The caller must ensure: -/// -/// - The input dtype is non-nullable. -/// - Every row is L2-normalized (unit norm) or is a zero vector. -/// -/// Passing non-unit-norm vectors will not cause memory unsafety, but will produce silently -/// incorrect quantization results. -pub unsafe fn turboquant_encode_unchecked( - ext: ArrayView, - config: &TurboQuantConfig, - ctx: &mut ExecutionCtx, -) -> VortexResult { - let ext_dtype = ext.dtype().clone(); - let storage = ext.storage_array(); - let fsl = storage.clone().execute::(ctx)?; - - vortex_ensure!( - config.bit_width >= 1 && config.bit_width <= MAX_BIT_WIDTH, - "bit_width must be 1-{MAX_BIT_WIDTH}, got {}", - config.bit_width - ); - let dimension = fsl.list_size(); - vortex_ensure!( - dimension >= MIN_DIMENSION, - "TurboQuant requires dimension >= {MIN_DIMENSION}, got {dimension}", - ); - - let vector_metadata = ext_dtype.as_extension().metadata::(); - let element_ptype = vector_metadata.element_ptype(); - - let seed = config.seed; - let num_rows = fsl.len(); - - if fsl.is_empty() { - let padded_dim = dimension.next_power_of_two(); - let empty_codes = PrimitiveArray::empty::(Nullability::NonNullable); - let empty_centroids = PrimitiveArray::empty::(Nullability::NonNullable); - let empty_dict = - DictArray::try_new(empty_codes.into_array(), empty_centroids.into_array())?; - let empty_fsl = FixedSizeListArray::try_new( - empty_dict.into_array(), - padded_dim, - Validity::NonNullable, - 0, - )?; - let empty_padded_vector = Vector::try_new_vector_array(empty_fsl.into_array())?; - - let sorf_options = SorfOptions { - seed, - num_rounds: config.num_rounds, - dimensions: dimension, - element_ptype, - }; - return Ok( - SorfTransform::try_new_array(&sorf_options, empty_padded_vector, 0)?.into_array(), - ); - } - - let core = turboquant_quantize_core(&fsl, seed, config.bit_width, config.num_rounds, ctx)?; - let quantized_fsl = - build_quantized_fsl(num_rows, core.all_indices, core.centroids, core.padded_dim)?; - let padded_vector = Vector::try_new_vector_array(quantized_fsl)?; - - let sorf_options = SorfOptions { - seed, - num_rounds: config.num_rounds, - dimensions: dimension, - element_ptype, - }; - Ok(SorfTransform::try_new_array(&sorf_options, padded_vector, num_rows)?.into_array()) -} - -/// Shared intermediate results from the quantization loop. -struct QuantizationResult { - centroids: Buffer, - all_indices: Buffer, - padded_dim: usize, -} - -/// Core quantization: rotate and quantize already-normalized rows. -/// -/// The input `fsl` must contain non-nullable, unit-norm vectors (already L2-normalized). Null -/// vectors are not supported and must be zeroed out before reaching this function. The rotation -/// and centroid lookup happen in f32. -fn turboquant_quantize_core( - fsl: &FixedSizeListArray, - seed: u64, - bit_width: u8, - num_rounds: u8, - ctx: &mut ExecutionCtx, -) -> VortexResult { - let dimension = fsl.list_size() as usize; - let num_rows = fsl.len(); - - let padded_dim = dimension.next_power_of_two(); - let rotation = SorfMatrix::try_new_padded(padded_dim, num_rounds as usize, seed)?; - let padded_dim_u32 = - u32::try_from(padded_dim).vortex_expect("padded_dim stays representable as u32"); - - let elements_prim: PrimitiveArray = fsl.elements().clone().execute(ctx)?; - let f32_elements = cast_to_f32(elements_prim)?; - - let centroids = compute_or_get_centroids(padded_dim_u32, bit_width)?; - let boundaries = compute_centroid_boundaries(¢roids); - - let mut all_indices = BufferMut::::with_capacity(num_rows * padded_dim); - let mut padded = vec![0.0f32; padded_dim]; - let mut rotated = vec![0.0f32; padded_dim]; - - let f32_slice = f32_elements.as_slice(); - for row in 0..num_rows { - let x = &f32_slice[row * dimension..(row + 1) * dimension]; - - // Zero-pad to the next power of 2. - padded[..dimension].copy_from_slice(x); - padded[dimension..].fill(0.0); - - rotation.rotate(&padded, &mut rotated); - - for j in 0..padded_dim { - all_indices.push(find_nearest_centroid(rotated[j], &boundaries)); - } - } - - Ok(QuantizationResult { - centroids, - all_indices: all_indices.freeze(), - padded_dim, - }) -} - -/// Build a quantized representation: `FSL(DictArray(codes, centroids), padded_dim)`. -/// -/// This is a Dict-encoded FixedSizeList where each row of `padded_dim` u8 codes indexes into the -/// centroid codebook. The Dict can be independently sliced, taken, or executed (dequantized) -/// without knowledge of the rotation. -fn build_quantized_fsl( - num_rows: usize, - all_indices: Buffer, - centroids: Buffer, - padded_dim: usize, -) -> VortexResult { - let codes = PrimitiveArray::new::(all_indices, Validity::NonNullable); - let centroids_array = PrimitiveArray::new::(centroids, Validity::NonNullable); - - let dict = DictArray::try_new(codes.into_array(), centroids_array.into_array())?; - - let padded_dim_u32 = - u32::try_from(padded_dim).vortex_expect("padded_dim stays representable as u32"); - Ok(FixedSizeListArray::try_new( - dict.into_array(), - padded_dim_u32, - Validity::NonNullable, - num_rows, - )? - .into_array()) -} diff --git a/vortex-tensor/src/encodings/turboquant/mod.rs b/vortex-tensor/src/encodings/turboquant/mod.rs deleted file mode 100644 index ca774ca65f6..00000000000 --- a/vortex-tensor/src/encodings/turboquant/mod.rs +++ /dev/null @@ -1,182 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! TurboQuant vector quantization encoding for Vortex. -//! -//! Implements a Stage 1 TurboQuant encoding ([arXiv:2504.19874], [RFC 0033]) for lossy -//! compression of high-dimensional vector data. The encoding operates on [`Vector`] extension -//! arrays, compressing their `FixedSizeList` storage into quantized codes after a structured -//! orthogonal surrogate rotation. -//! -//! [arXiv:2504.19874]: https://arxiv.org/abs/2504.19874 -//! [RFC 0033]: https://vortex-data.github.io/rfcs/rfc/0033.html -//! [`Vector`]: crate::vector::Vector -//! -//! # Overview -//! -//! TurboQuant minimizes mean-squared reconstruction error (1-8 bits per coordinate) -//! using MSE-optimal scalar quantization on coordinates of a rotated unit vector. -//! -//! The encoding is decomposed into independently swappable layers: -//! -//! - **Normalization**: [`L2Denorm`] stores per-vector norms and wraps the compressed child. -//! - **Orthogonal transform**: [`SorfTransform`] records the SORF structured orthogonal -//! transform and applies the inverse at decode time. -//! - **Quantization**: `DictArray(codes, centroids)` wrapped in `FixedSizeListArray` stores -//! the per-coordinate codebook indices. -//! -//! The full encoded tree is: -//! -//! ```text -//! ScalarFnArray(L2Denorm, [ -//! ScalarFnArray(SorfTransform, [FSL(Dict(codes, centroids))]), -//! norms -//! ]) -//! ``` -//! -//! When executed, the tree automatically decompresses: Dict dequantizes codes → SorfTransform -//! inverse-rotates → L2Denorm re-applies norms → original vectors (approximately). -//! -//! [`L2Denorm`]: crate::scalar_fns::l2_denorm::L2Denorm -//! [`SorfTransform`]: crate::scalar_fns::sorf_transform::SorfTransform -//! -//! The TurboQuant paper analyzes a full random orthogonal rotation. The current Vortex -//! implementation instead uses a fixed 3-round Walsh-Hadamard-based structured transform with -//! random sign diagonals generated by Vortex's frozen local SplitMix64 stream. This is a practical -//! approximation chosen for encode/decode efficiency, and should be understood as an -//! implementation choice rather than the exact construction used in the paper's proofs. -//! -//! The current encoding is also intentionally MSE-only. It does not yet implement the paper's QJL -//! residual correction for unbiased inner-product estimation, and it still uses internal -//! power-of-2 padding rather than the block decomposition proposed in RFC 0033. -//! -//! # Theoretical error bounds -//! -//! For unit-norm vectors quantized at `b` bits per coordinate, the paper's Theorem 1 -//! guarantees normalized MSE distortion: -//! -//! > `E[||x - x_hat||^2 / ||x||^2] <= (sqrt(3) * pi / 2) / 4^b` -//! -//! | Bits | MSE bound | Quality | -//! |------|------------|-------------------| -//! | 1 | 6.80e-01 | Poor | -//! | 2 | 1.70e-01 | Usable for ANN | -//! | 3 | 4.25e-02 | Good | -//! | 4 | 1.06e-02 | Very good | -//! | 5 | 2.66e-03 | Excellent | -//! | 6 | 6.64e-04 | Near-lossless | -//! | 7 | 1.66e-04 | Near-lossless | -//! | 8 | 4.15e-05 | Near-lossless | -//! -//! # Compression ratios -//! -//! Each vector is stored as `padded_dim * bit_width / 8` bytes of quantized codes plus one stored -//! norm (in the [`L2Denorm`] wrapper). In the current implementation, that norm uses the vector's -//! element float type, not a separate fixed storage precision. Non-power-of-2 dimensions are -//! padded to the next power of 2 for the structured rotation, which reduces the effective ratio -//! for those dimensions. -//! -//! The table below assumes f32 input, so the stored norm is 4 bytes. -//! -//! | dim | padded | bits | f32 bytes | TQ bytes | ratio | -//! |------|--------|------|-----------|----------|--------| -//! | 768 | 1024 | 2 | 3072 | 260 | 11.8x | -//! | 1024 | 1024 | 2 | 4096 | 260 | 15.8x | -//! | 768 | 1024 | 4 | 3072 | 516 | 6.0x | -//! | 1024 | 1024 | 4 | 4096 | 516 | 7.9x | -//! | 768 | 1024 | 8 | 3072 | 1028 | 3.0x | -//! | 1024 | 1024 | 8 | 4096 | 1028 | 4.0x | -//! -//! # Example -//! -//! ``` -//! use vortex_array::IntoArray; -//! use vortex_array::VortexSessionExecute; -//! use vortex_array::arrays::ExtensionArray; -//! use vortex_array::arrays::FixedSizeListArray; -//! use vortex_array::arrays::PrimitiveArray; -//! use vortex_array::EmptyMetadata; -//! use vortex_array::session::ArraySession; -//! use vortex_array::validity::Validity; -//! use vortex_buffer::BufferMut; -//! use vortex_session::VortexSession; -//! use vortex_tensor::encodings::turboquant::{TurboQuantConfig, turboquant_encode}; -//! use vortex_tensor::vector::Vector; -//! -//! // Create a Vector extension array of 100 random 128-d vectors. -//! let num_rows = 100; -//! let dim = 128u32; -//! let mut buf = BufferMut::::with_capacity(num_rows * dim as usize); -//! for i in 0..(num_rows * dim as usize) { -//! buf.push((i as f32 * 0.001).sin()); -//! } -//! let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); -//! let fsl = FixedSizeListArray::try_new( -//! elements.into_array(), dim, Validity::NonNullable, num_rows, -//! ).unwrap(); -//! let vector = ExtensionArray::try_new_from_vtable(Vector, EmptyMetadata, fsl.into_array()) -//! .map(|ext| ext.into_array()) -//! .unwrap(); -//! -//! // Normalize and quantize at 2 bits per coordinate in one pass. -//! let session = VortexSession::empty().with::(); -//! let mut ctx = session.create_execution_ctx(); -//! let config = TurboQuantConfig { bit_width: 2, seed: 42, num_rounds: 3 }; -//! let tq = turboquant_encode(vector, &config, &mut ctx).unwrap(); -//! -//! // Verify compression: 100 vectors x 128 dims x 4 bytes = 51200 bytes input. -//! assert!(tq.nbytes() < 51200); -//! ``` - -pub(crate) mod centroids; -pub(crate) mod compress; - -mod scheme; -pub use compress::TurboQuantConfig; -pub use compress::turboquant_encode; -pub use compress::turboquant_encode_unchecked; -pub use scheme::TurboQuantScheme; - -/// Minimum vector dimension for TurboQuant encoding. -/// -/// Note that this is not a theoretical minimum, it is mostly a practical one to limit the total -/// amount of distortion. -pub const MIN_DIMENSION: u32 = 128; - -/// Maximum supported number of bits per quantized coordinate. -pub const MAX_BIT_WIDTH: u8 = 8; - -/// Maximum supported number of centroids in the scalar quantizer codebook. -pub const MAX_CENTROIDS: usize = 1usize << (MAX_BIT_WIDTH as usize); - -use vortex_array::dtype::DType; -use vortex_error::VortexResult; -use vortex_error::vortex_ensure; -use vortex_error::vortex_err; - -use crate::types::vector::AnyVector; -use crate::types::vector::VectorMatcherMetadata; - -/// Validates that `dtype` is a [`Vector`](crate::vector::Vector) extension type with -/// dimension >= [`MIN_DIMENSION`]. -/// -/// Returns the validated vector metadata on success. -pub fn tq_validate_vector_dtype(dtype: &DType) -> VortexResult { - let vector_metadata = dtype - .as_extension_opt() - .and_then(|ext| ext.metadata_opt::()) - .ok_or_else(|| { - vortex_err!("TurboQuant dtype must be a Vector extension type, got {dtype}") - })?; - - let dimensions = vector_metadata.dimensions(); - vortex_ensure!( - dimensions >= MIN_DIMENSION, - "TurboQuant requires dimension >= {MIN_DIMENSION}, got {dimensions}", - ); - - Ok(vector_metadata) -} - -#[cfg(test)] -mod tests; diff --git a/vortex-tensor/src/encodings/turboquant/scheme.rs b/vortex-tensor/src/encodings/turboquant/scheme.rs deleted file mode 100644 index d4362096bd2..00000000000 --- a/vortex-tensor/src/encodings/turboquant/scheme.rs +++ /dev/null @@ -1,221 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! TurboQuant compression scheme. -//! -//! The scheme is a thin [`Scheme`] adapter over [`turboquant_encode`], which produces: -//! -//! ```text -//! ScalarFnArray(L2Denorm, [ -//! ScalarFnArray( -//! SorfTransform, -//! FSL(Dict(codes, centroids)) -//! ), -//! norms -//! ]) -//! ``` -//! -//! Decompression is automatic: executing the outer array walks the ScalarFn tree. -//! -//! [`turboquant_encode`]: crate::encodings::turboquant::turboquant_encode - -use vortex_array::ArrayRef; -use vortex_array::Canonical; -use vortex_array::ExecutionCtx; -use vortex_compressor::CascadingCompressor; -use vortex_compressor::ctx::CompressorContext; -use vortex_compressor::estimate::CompressionEstimate; -use vortex_compressor::estimate::EstimateVerdict; -use vortex_compressor::scheme::Scheme; -use vortex_compressor::stats::ArrayAndStats; -use vortex_error::VortexExpect; -use vortex_error::VortexResult; - -use crate::encodings::turboquant::MAX_CENTROIDS; -use crate::encodings::turboquant::TurboQuantConfig; -use crate::encodings::turboquant::tq_validate_vector_dtype; -use crate::encodings::turboquant::turboquant_encode; - -/// TurboQuant compression scheme for [`Vector`] extension types. -/// -/// Applies lossy vector quantization to [`Vector`] extension arrays using the TurboQuant algorithm -/// with MSE-optimal encoding. -/// -/// Register this scheme with the compressor builder via `with_scheme`: -/// -/// ```ignore -/// use vortex_btrblocks::BtrBlocksCompressorBuilder; -/// use vortex_tensor::encodings::turboquant::TurboQuantScheme; -/// -/// let compressor = BtrBlocksCompressorBuilder::default() -/// .with_new_scheme(&TurboQuantScheme) -/// .build(); -/// ``` -/// -/// [`Vector`]: crate::vector::Vector -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -pub struct TurboQuantScheme; - -impl Scheme for TurboQuantScheme { - fn scheme_name(&self) -> &'static str { - "vortex.tensor.turboquant" - } - - fn matches(&self, canonical: &Canonical) -> bool { - let Canonical::Extension(ext) = canonical else { - return false; - }; - - tq_validate_vector_dtype(ext.dtype()).is_ok() - } - - fn expected_compression_ratio( - &self, - data: &ArrayAndStats, - _compress_ctx: CompressorContext, - _exec_ctx: &mut ExecutionCtx, - ) -> CompressionEstimate { - let len = data.array().len(); - let dtype = data.array().dtype(); - - let vector_metadata = - tq_validate_vector_dtype(dtype).vortex_expect("invalid dtype for TurboQuant"); - let element_ptype = vector_metadata.element_ptype(); - let element_bit_width: u8 = element_ptype - .bit_width() - .try_into() - .vortex_expect("invalid bit width for TurboQuant"); - let dimension = vector_metadata.dimensions(); - - CompressionEstimate::Verdict(EstimateVerdict::Ratio(estimate_compression_ratio( - element_bit_width, - dimension, - len, - ))) - } - - fn compress( - &self, - _compressor: &CascadingCompressor, - data: &ArrayAndStats, - _compress_ctx: CompressorContext, - exec_ctx: &mut ExecutionCtx, - ) -> VortexResult { - turboquant_encode(data.array().clone(), &TurboQuantConfig::default(), exec_ctx) - } -} - -// TODO(connor): If we ever add scheme vtables with metadata, we would need to pass in the config as -// a parameter here. -/// Estimate the compression ratio for TurboQuant MSE encoding with the default config. -fn estimate_compression_ratio(element_bit_width: u8, dimensions: u32, num_vectors: usize) -> f64 { - let config = TurboQuantConfig::default(); - let padded_dim = dimensions.next_power_of_two() as usize; - let element_bits = usize::from(element_bit_width); - - // Get the size of the fully uncompressed vector data. - let uncompressed_size_bits = element_bits * dimensions as usize * num_vectors; - - // Per-vector: MSE codes per padded coordinate, plus one stored norm in the input element float - // width. - let norm_bits = element_bits; - let compressed_bits_per_vector = usize::from(config.bit_width) * padded_dim; - let total_bits_per_vector = norm_bits + compressed_bits_per_vector; - - // Shared overhead: codebook centroids (2^bit_width f32 values). - let num_centroids = 1usize << config.bit_width; - debug_assert!(num_centroids <= MAX_CENTROIDS); - let overhead_bits = num_centroids * 32; // centroids are always f32 - - // This includes the quantized vectors, norms, and centroid codebook. - let compressed_size_bits = total_bits_per_vector * num_vectors + overhead_bits; - - uncompressed_size_bits as f64 / compressed_size_bits as f64 -} - -#[cfg(test)] -mod tests { - use rstest::rstest; - - use super::*; - - /// Verify compression ratio for typical embedding dimensions. - /// - /// f32 input at 768-d (padded to 1024) with 1000 vectors should give ~3x. - /// f32 input at 1024-d (no padding) should give ~4x since no padding waste. - #[rstest] - #[case::f32_768d(32, 768, 1000, 2.5, 4.5)] - #[case::f32_1024d(32, 1024, 1000, 3.5, 5.0)] - #[case::f32_1536d(32, 1536, 1000, 2.5, 4.5)] - #[case::f32_128d(32, 128, 1000, 3.0, 5.0)] - #[case::f64_768d(64, 768, 1000, 5.0, 9.0)] - #[case::f16_768d(16, 768, 1000, 1.2, 2.5)] - fn compression_ratio_in_expected_range( - #[case] element_bit_width: u8, - #[case] dim: u32, - #[case] num_vectors: usize, - #[case] min_ratio: f64, - #[case] max_ratio: f64, - ) { - let ratio = estimate_compression_ratio(element_bit_width, dim, num_vectors); - assert!( - ratio > min_ratio && ratio < max_ratio, - "ratio {ratio:.2} not in [{min_ratio}, {max_ratio}] for \ - {element_bit_width}-bit elements, dim={dim}, n={num_vectors}" - ); - } - - /// Compression ratio must always be > 1 for reasonable inputs, - /// otherwise TurboQuant makes things bigger and should not be selected. - #[rstest] - #[case(32, 128, 100)] - #[case(32, 768, 10)] - #[case(64, 256, 50)] - fn ratio_always_greater_than_one( - #[case] element_bit_width: u8, - #[case] dim: u32, - #[case] num_vectors: usize, - ) { - let ratio = estimate_compression_ratio(element_bit_width, dim, num_vectors); - assert!( - ratio > 1.0, - "ratio {ratio:.4} <= 1.0 for {element_bit_width}-bit, dim={dim}, n={num_vectors}" - ); - } - - #[rstest] - #[case(16)] - #[case(32)] - #[case(64)] - fn ratio_accounts_for_norm_storage_width(#[case] element_bit_width: u8) { - let dim = 128u32; - let num_vectors = 1usize; - let padded_dim = dim.next_power_of_two() as usize; - let config = TurboQuantConfig::default(); - let num_centroids = 1usize << config.bit_width; - - let expected_compressed_bits = usize::from(element_bit_width) - + usize::from(config.bit_width) * padded_dim - + num_centroids * 32; - let expected_uncompressed_bits = - usize::from(element_bit_width) * dim as usize * num_vectors; - let expected = expected_uncompressed_bits as f64 / expected_compressed_bits as f64; - - assert_eq!( - estimate_compression_ratio(element_bit_width, dim, num_vectors), - expected - ); - } - - /// Power-of-2 dimensions should have better ratios than their non-power-of-2 - /// predecessors due to no padding waste. - #[test] - fn power_of_two_has_better_ratio() { - let ratio_768 = estimate_compression_ratio(32, 768, 1000); - let ratio_1024 = estimate_compression_ratio(32, 1024, 1000); - assert!( - ratio_1024 > ratio_768, - "1024-d ratio ({ratio_1024:.2}) should exceed 768-d ({ratio_768:.2})" - ); - } -} diff --git a/vortex-tensor/src/encodings/turboquant/tests/compute.rs b/vortex-tensor/src/encodings/turboquant/tests/compute.rs deleted file mode 100644 index 4d670695eaf..00000000000 --- a/vortex-tensor/src/encodings/turboquant/tests/compute.rs +++ /dev/null @@ -1,216 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use vortex_array::ArrayRef; -use vortex_array::IntoArray; -use vortex_array::LEGACY_SESSION; -use vortex_array::VortexSessionExecute; -use vortex_array::arrays::ExtensionArray; -use vortex_array::arrays::FixedSizeListArray; -use vortex_array::arrays::PrimitiveArray; -use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; -use vortex_error::VortexResult; - -use super::*; -use crate::scalar_fns::cosine_similarity::CosineSimilarity; -use crate::scalar_fns::l2_norm::L2Norm; - -fn execute_l2_norm( - input: ArrayRef, - len: usize, - ctx: &mut vortex_array::ExecutionCtx, -) -> VortexResult { - L2Norm::try_new_array(input, len)?.into_array().execute(ctx) -} - -fn execute_cosine_similarity( - lhs: ArrayRef, - rhs: ArrayRef, - len: usize, - ctx: &mut vortex_array::ExecutionCtx, -) -> VortexResult { - CosineSimilarity::try_new_array(lhs, rhs, len)? - .into_array() - .execute(ctx) -} - -#[test] -fn slice_preserves_data() -> VortexResult<()> { - let fsl = make_fsl(20, 128, 42); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 3, - seed: 123, - num_rounds: 4, - }; - let mut ctx = SESSION.create_execution_ctx(); - let encoded = turboquant_encode(ext, &config, &mut ctx)?; - - // Full decompress then slice. - let mut ctx = SESSION.create_execution_ctx(); - let full_decoded = encoded.clone().execute::(&mut ctx)?; - let full_fsl = full_decoded - .storage_array() - .clone() - .execute::(&mut ctx)?; - let expected = full_fsl.slice(5..10)?; - let expected_fsl = expected.execute::(&mut ctx)?; - let expected_elements = expected_fsl - .elements() - .clone() - .execute::(&mut ctx)?; - - // Slice then decompress. - let sliced = encoded.slice(5..10)?; - let sliced_decoded = sliced.execute::(&mut ctx)?; - let sliced_fsl = sliced_decoded - .storage_array() - .clone() - .execute::(&mut ctx)?; - let actual_elements = sliced_fsl - .elements() - .clone() - .execute::(&mut ctx)?; - - assert_eq!( - expected_elements.as_slice::(), - actual_elements.as_slice::() - ); - Ok(()) -} - -#[test] -fn scalar_at_matches_decompress() -> VortexResult<()> { - let fsl = make_fsl(10, 128, 42); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 3, - seed: 123, - num_rounds: 2, - }; - let mut ctx = SESSION.create_execution_ctx(); - let encoded = turboquant_encode(ext, &config, &mut ctx)?; - - let full_decoded = encoded.clone().execute::(&mut ctx)?; - - for i in [0, 1, 5, 9] { - let expected = - full_decoded.execute_scalar(i, &mut LEGACY_SESSION.create_execution_ctx())?; - let actual = encoded.execute_scalar(i, &mut LEGACY_SESSION.create_execution_ctx())?; - assert_eq!(expected, actual, "scalar_at mismatch at index {i}"); - } - Ok(()) -} - -#[test] -fn l2_norm_readthrough() -> VortexResult<()> { - let fsl = make_fsl(10, 128, 42); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 3, - seed: 123, - num_rounds: 5, - }; - let mut ctx = SESSION.create_execution_ctx(); - let encoded = turboquant_encode(ext, &config, &mut ctx)?; - let (_sorf_child, norms_child) = unwrap_l2denorm(&encoded); - - // Stored norms should match the actual L2 norms of the input. - let norms_prim = norms_child.execute::(&mut ctx)?; - let stored_norms = norms_prim.as_slice::(); - - let input_prim = fsl.elements().clone().execute::(&mut ctx)?; - let input_f32 = input_prim.as_slice::(); - for row in 0..10 { - let vec = &input_f32[row * 128..(row + 1) * 128]; - let actual_norm: f32 = vec.iter().map(|&v| v * v).sum::().sqrt(); - assert!( - (stored_norms[row] - actual_norm).abs() < 1e-5, - "norm mismatch at row {row}: stored={}, actual={}", - stored_norms[row], - actual_norm - ); - } - - // Also verify L2Norm readthrough shortcut works. - let norms = execute_l2_norm(encoded, 10, &mut ctx)?; - assert_eq!(norms.as_slice::(), stored_norms); - assert_eq!(norms.len(), 10); - Ok(()) -} - -#[test] -fn l2_norm_readthrough_is_authoritative_for_lossy_storage() -> VortexResult<()> { - let num_rows = 12; - let fsl = make_fsl(num_rows, 128, 7); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 1, - seed: 123, - num_rounds: 3, - }; - let mut ctx = SESSION.create_execution_ctx(); - let encoded = turboquant_encode(ext, &config, &mut ctx)?; - let (_sorf_child, norms_child) = unwrap_l2denorm(&encoded); - - let stored_norms: PrimitiveArray = norms_child.execute(&mut ctx)?; - let encoded_norms = execute_l2_norm(encoded.clone(), num_rows, &mut ctx)?; - assert_eq!( - encoded_norms.as_slice::(), - stored_norms.as_slice::() - ); - - let decoded = encoded.execute::(&mut ctx)?.into_array(); - let decoded_norms = execute_l2_norm(decoded, num_rows, &mut ctx)?; - let max_gap = stored_norms - .as_slice::() - .iter() - .zip(decoded_norms.as_slice::().iter()) - .map(|(&stored, &decoded)| (stored - decoded).abs()) - .fold(0.0f32, f32::max); - - assert!( - max_gap > 1e-3, - "expected at least one decoded norm to drift from the authoritative stored norms, got max gap {max_gap:.6}", - ); - Ok(()) -} - -#[test] -fn cosine_similarity_readthrough_is_authoritative_for_lossy_storage() -> VortexResult<()> { - let num_rows = 12; - let fsl = make_fsl(num_rows, 128, 11); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 1, - seed: 123, - num_rounds: 3, - }; - let mut ctx = SESSION.create_execution_ctx(); - let encoded = turboquant_encode(ext, &config, &mut ctx)?; - - let encoded_cos = - execute_cosine_similarity(encoded.clone(), encoded.clone(), num_rows, &mut ctx)?; - let decoded = encoded.execute::(&mut ctx)?.into_array(); - let decoded_cos = execute_cosine_similarity(decoded.clone(), decoded, num_rows, &mut ctx)?; - - let decoded_values = decoded_cos.as_slice::(); - assert!( - decoded_values - .iter() - .all(|&value| (value - 1.0).abs() < 1e-5), - "decoded cosine(x, x) should stay at 1.0", - ); - - let max_gap = encoded_cos - .as_slice::() - .iter() - .zip(decoded_values.iter()) - .map(|(&encoded, &decoded)| (encoded - decoded).abs()) - .fold(0.0f32, f32::max); - assert!( - max_gap > 1e-3, - "expected encoded cosine readthrough to differ from decoded recomputation, got max gap {max_gap:.6}", - ); - Ok(()) -} diff --git a/vortex-tensor/src/encodings/turboquant/tests/mod.rs b/vortex-tensor/src/encodings/turboquant/tests/mod.rs deleted file mode 100644 index a9667f6d665..00000000000 --- a/vortex-tensor/src/encodings/turboquant/tests/mod.rs +++ /dev/null @@ -1,163 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! Tests for TurboQuant encoding with decomposed SorfTransform + DictArray tree. - -mod compute; -mod nullable; -mod roundtrip; -mod structural; - -use std::f32; - -use rand::SeedableRng; -use rand::rngs::StdRng; -use rand_distr::Distribution; -use rand_distr::Normal; -use vortex_array::ArrayRef; -use vortex_array::IntoArray; -use vortex_array::VortexSessionExecute; -use vortex_array::arrays::Dict; -use vortex_array::arrays::ExtensionArray; -use vortex_array::arrays::FixedSizeListArray; -use vortex_array::arrays::PrimitiveArray; -use vortex_array::arrays::ScalarFn; -use vortex_array::arrays::dict::DictArraySlotsExt; -use vortex_array::arrays::extension::ExtensionArrayExt; -use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; -use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; -use vortex_array::validity::Validity; -use vortex_buffer::BufferMut; -use vortex_error::VortexExpect; -use vortex_error::VortexResult; - -use crate::encodings::turboquant::TurboQuantConfig; -use crate::encodings::turboquant::turboquant_encode; -use crate::tests::SESSION; -use crate::types::vector::Vector; - -/// Create a FixedSizeListArray of random f32 vectors with the given validity. -fn make_fsl_with_validity( - num_rows: usize, - dim: usize, - seed: u64, - validity: Validity, -) -> FixedSizeListArray { - let mut rng = StdRng::seed_from_u64(seed); - let normal = Normal::new(0.0f32, 1.0).unwrap(); - - let mut buf = BufferMut::::with_capacity(num_rows * dim); - for _ in 0..(num_rows * dim) { - buf.push(normal.sample(&mut rng)); - } - - let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); - FixedSizeListArray::try_new( - elements.into_array(), - dim.try_into() - .expect("somehow got dimension greater than u32::MAX"), - validity, - num_rows, - ) - .unwrap() -} - -/// Create a non-nullable FixedSizeListArray of random f32 vectors. -fn make_fsl(num_rows: usize, dim: usize, seed: u64) -> FixedSizeListArray { - make_fsl_with_validity(num_rows, dim, seed, Validity::NonNullable) -} - -/// Wrap a `FixedSizeListArray` in a `Vector` extension array. -fn make_vector_ext(fsl: &FixedSizeListArray) -> ArrayRef { - Vector::try_new_vector_array(fsl.clone().into_array()) - .vortex_expect("test FSL satisfies Vector storage constraints") -} - -/// Unwrap an L2Denorm ScalarFnArray into (sorf_child, norms_child). -fn unwrap_l2denorm(encoded: &ArrayRef) -> (ArrayRef, ArrayRef) { - let sfn = encoded - .as_opt::() - .expect("expected ScalarFnArray (L2Denorm)"); - (sfn.child_at(0).clone(), sfn.child_at(1).clone()) -} - -/// Navigate the full tree to get (codes, centroids, norms) as flat arrays. -fn unwrap_codes_centroids_norms( - encoded: &ArrayRef, - ctx: &mut vortex_array::ExecutionCtx, -) -> VortexResult<(PrimitiveArray, PrimitiveArray, PrimitiveArray)> { - let (sorf_child, norms_child) = unwrap_l2denorm(encoded); - let padded_vector_child = sorf_child - .as_opt::() - .expect("expected SorfTransform ScalarFnArray") - .child_at(0) - .clone(); - - // Vector wrapping FSL(Dict(codes, centroids)) - let padded_vector: ExtensionArray = padded_vector_child.execute(ctx)?; - let fsl: FixedSizeListArray = padded_vector.storage_array().clone().execute(ctx)?; - let dict = fsl - .elements() - .as_opt::() - .vortex_expect("FSL elements should be a DictArray"); - let codes: PrimitiveArray = dict.codes().clone().execute(ctx)?; - let centroids: PrimitiveArray = dict.values().clone().execute(ctx)?; - let norms: PrimitiveArray = norms_child.execute(ctx)?; - - Ok((codes, centroids, norms)) -} - -fn theoretical_mse_bound(bit_width: u8) -> f32 { - let sqrt3_pi_over_2 = (3.0f32).sqrt() * f32::consts::PI / 2.0; - sqrt3_pi_over_2 / (4.0f32).powi(bit_width as i32) -} - -fn per_vector_normalized_mse( - original: &[f32], - reconstructed: &[f32], - dim: usize, - num_rows: usize, -) -> f32 { - let mut total = 0.0f32; - for row in 0..num_rows { - let orig = &original[row * dim..(row + 1) * dim]; - let recon = &reconstructed[row * dim..(row + 1) * dim]; - let norm_sq: f32 = orig.iter().map(|&v| v * v).sum(); - if norm_sq < 1e-10 { - continue; - } - let err_sq: f32 = orig - .iter() - .zip(recon.iter()) - .map(|(&a, &b)| (a - b) * (a - b)) - .sum(); - total += err_sq / norm_sq; - } - total / num_rows as f32 -} - -/// Normalize, encode, and decode, returning (original, decoded) flat f32 slices. -fn encode_decode( - fsl: &FixedSizeListArray, - config: &TurboQuantConfig, -) -> VortexResult<(Vec, Vec)> { - let mut ctx = SESSION.create_execution_ctx(); - let original: Vec = { - let prim = fsl.elements().clone().execute::(&mut ctx)?; - prim.as_slice::().to_vec() - }; - let encoded = turboquant_encode(make_vector_ext(fsl), config, &mut ctx)?; - let decoded_ext = encoded.execute::(&mut ctx)?; - let decoded_fsl = decoded_ext - .storage_array() - .clone() - .execute::(&mut ctx)?; - let decoded_elements: Vec = { - let prim = decoded_fsl - .elements() - .clone() - .execute::(&mut ctx)?; - prim.as_slice::().to_vec() - }; - Ok((original, decoded_elements)) -} diff --git a/vortex-tensor/src/encodings/turboquant/tests/nullable.rs b/vortex-tensor/src/encodings/turboquant/tests/nullable.rs deleted file mode 100644 index 92eb0e7b152..00000000000 --- a/vortex-tensor/src/encodings/turboquant/tests/nullable.rs +++ /dev/null @@ -1,178 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use vortex_array::VortexSessionExecute; -use vortex_array::arrays::ExtensionArray; -use vortex_array::arrays::FixedSizeListArray; -use vortex_array::arrays::PrimitiveArray; -use vortex_array::arrays::extension::ExtensionArrayExt; -use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; -use vortex_array::validity::Validity; -use vortex_error::VortexResult; - -use super::*; - -/// Encode a nullable Vector array and verify roundtrip preserves validity and non-null values. -#[test] -fn nullable_vectors_roundtrip() -> VortexResult<()> { - let validity = Validity::from_iter([ - true, true, false, true, true, false, true, false, true, true, - ]); - let fsl = make_fsl_with_validity(10, 128, 42, validity); - let ext = make_vector_ext(&fsl); - - let config = TurboQuantConfig { - bit_width: 3, - seed: 123, - num_rounds: 4, - }; - let mut ctx = SESSION.create_execution_ctx(); - let encoded = turboquant_encode(ext, &config, &mut ctx)?; - - assert_eq!(encoded.len(), 10); - assert!(encoded.dtype().is_nullable()); - - let encoded_validity = encoded.validity()?; - for i in 0..10 { - let expected = ![2, 5, 7].contains(&i); - assert_eq!( - encoded_validity.execute_is_valid(i, &mut ctx)?, - expected, - "validity mismatch at row {i}" - ); - } - - let decoded_ext = encoded.execute::(&mut ctx)?; - assert_eq!(decoded_ext.len(), 10); - - let decoded_fsl = decoded_ext - .storage_array() - .clone() - .execute::(&mut ctx)?; - let decoded_prim = decoded_fsl - .elements() - .clone() - .execute::(&mut ctx)?; - let decoded_f32 = decoded_prim.as_slice::(); - - let orig_prim = fsl.elements().clone().execute::(&mut ctx)?; - let orig_f32 = orig_prim.as_slice::(); - - for row in [0, 1, 3, 4, 6, 8, 9] { - let orig_vec = &orig_f32[row * 128..(row + 1) * 128]; - let dec_vec = &decoded_f32[row * 128..(row + 1) * 128]; - let norm_sq: f32 = orig_vec.iter().map(|&v| v * v).sum(); - let err_sq: f32 = orig_vec - .iter() - .zip(dec_vec.iter()) - .map(|(&a, &b)| (a - b) * (a - b)) - .sum(); - assert!( - err_sq / norm_sq < 0.1, - "non-null row {row} has excessive reconstruction error" - ); - } - Ok(()) -} - -/// Verify that norms carry the validity: null vectors have null norms. -#[test] -fn nullable_norms_match_validity() -> VortexResult<()> { - let validity = Validity::from_iter([true, false, true, false, true]); - let fsl = make_fsl_with_validity(5, 128, 42, validity); - let ext = make_vector_ext(&fsl); - - let config = TurboQuantConfig { - bit_width: 2, - seed: 123, - num_rounds: 3, - }; - let mut ctx = SESSION.create_execution_ctx(); - let encoded = turboquant_encode(ext, &config, &mut ctx)?; - let (_sorf_child, norms_child) = unwrap_l2denorm(&encoded); - - let norms_validity = norms_child.validity()?; - for i in 0..5 { - let expected = i % 2 == 0; - assert_eq!( - norms_validity.execute_is_valid(i, &mut ctx)?, - expected, - "norms validity mismatch at row {i}" - ); - } - Ok(()) -} - -/// Verify that L2Norm readthrough works correctly on nullable TurboQuant arrays. -#[test] -fn nullable_l2_norm_readthrough() -> VortexResult<()> { - use crate::scalar_fns::l2_norm::L2Norm; - - let validity = Validity::from_iter([true, false, true, false, true]); - let fsl = make_fsl_with_validity(5, 128, 42, validity); - let ext = make_vector_ext(&fsl); - - let config = TurboQuantConfig { - bit_width: 3, - seed: 123, - num_rounds: 3, - }; - let mut ctx = SESSION.create_execution_ctx(); - let encoded = turboquant_encode(ext, &config, &mut ctx)?; - - let norm_sfn = L2Norm::try_new_array(encoded, 5)?; - let norms: PrimitiveArray = norm_sfn.into_array().execute(&mut ctx)?; - - let orig_prim = fsl.elements().clone().execute::(&mut ctx)?; - let orig_f32 = orig_prim.as_slice::(); - for row in 0..5 { - if row % 2 == 0 { - assert!(norms.is_valid(row, &mut ctx)?, "row {row} should be valid"); - let expected: f32 = orig_f32[row * 128..(row + 1) * 128] - .iter() - .map(|&v| v * v) - .sum::() - .sqrt(); - let actual = norms.as_slice::()[row]; - assert!( - (actual - expected).abs() < 1e-5, - "norm mismatch at valid row {row}: actual={actual}, expected={expected}" - ); - } else { - assert!(!norms.is_valid(row, &mut ctx)?, "row {row} should be null"); - } - } - Ok(()) -} - -/// Verify that slicing a nullable TurboQuant array preserves validity. -#[test] -fn nullable_slice_preserves_validity() -> VortexResult<()> { - let validity = Validity::from_iter([ - true, true, false, true, true, false, true, false, true, true, - ]); - let fsl = make_fsl_with_validity(10, 128, 42, validity); - let ext = make_vector_ext(&fsl); - - let config = TurboQuantConfig { - bit_width: 3, - seed: 123, - num_rounds: 2, - }; - let mut ctx = SESSION.create_execution_ctx(); - let encoded = turboquant_encode(ext, &config, &mut ctx)?; - - let sliced = encoded.slice(1..6)?; - assert_eq!(sliced.len(), 5); - - let sliced_validity = sliced.validity()?; - let expected = [true, false, true, true, false]; - for (i, &exp) in expected.iter().enumerate() { - assert_eq!( - sliced_validity.execute_is_valid(i, &mut ctx)?, - exp, - "sliced validity mismatch at index {i}" - ); - } - Ok(()) -} diff --git a/vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs b/vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs deleted file mode 100644 index d82be3cf714..00000000000 --- a/vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs +++ /dev/null @@ -1,313 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use rstest::rstest; -use vortex_array::IntoArray; -use vortex_array::VortexSessionExecute; -use vortex_array::arrays::Extension; -use vortex_array::arrays::ExtensionArray; -use vortex_array::arrays::FixedSizeListArray; -use vortex_array::arrays::PrimitiveArray; -use vortex_array::validity::Validity; -use vortex_buffer::BufferMut; -use vortex_error::VortexResult; -use vortex_error::vortex_err; - -use super::*; -use crate::encodings::turboquant::turboquant_encode_unchecked; -use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; - -#[rstest] -#[case(128, 1)] -#[case(128, 2)] -#[case(128, 3)] -#[case(128, 4)] -#[case(128, 6)] -#[case(128, 8)] -#[case(256, 2)] -fn roundtrip(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { - let fsl = make_fsl(10, dim, 42); - let config = TurboQuantConfig { - bit_width, - seed: 123, - num_rounds: 3, - }; - let (original, decoded) = encode_decode(&fsl, &config)?; - assert_eq!(decoded.len(), original.len()); - Ok(()) -} - -#[rstest] -#[case(128, 1)] -#[case(128, 2)] -#[case(128, 3)] -#[case(128, 4)] -#[case(256, 2)] -#[case(256, 4)] -fn mse_within_theoretical_bound(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { - let num_rows = 200; - let fsl = make_fsl(num_rows, dim, 42); - let config = TurboQuantConfig { - bit_width, - seed: 123, - num_rounds: 3, - }; - let (original, decoded) = encode_decode(&fsl, &config)?; - - let normalized_mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); - let bound = theoretical_mse_bound(bit_width); - - assert!( - normalized_mse < bound, - "Normalized MSE {normalized_mse:.6} exceeds bound {bound:.6} \ - for dim={dim}, bits={bit_width}", - ); - Ok(()) -} - -#[rstest] -#[case(128, 6)] -#[case(128, 8)] -#[case(256, 6)] -#[case(256, 8)] -fn high_bitwidth_mse_is_small(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { - let num_rows = 200; - let fsl = make_fsl(num_rows, dim, 42); - - let config_4bit = TurboQuantConfig { - bit_width: 4, - seed: 123, - num_rounds: 3, - }; - let (original_4, decoded_4) = encode_decode(&fsl, &config_4bit)?; - let mse_4bit = per_vector_normalized_mse(&original_4, &decoded_4, dim, num_rows); - - let config = TurboQuantConfig { - bit_width, - seed: 123, - num_rounds: 3, - }; - let (original, decoded) = encode_decode(&fsl, &config)?; - let mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); - - assert!( - mse < mse_4bit, - "{bit_width}-bit MSE ({mse:.6}) should be < 4-bit MSE ({mse_4bit:.6})" - ); - assert!(mse < 0.01, "{bit_width}-bit MSE ({mse:.6}) should be < 1%"); - Ok(()) -} - -#[test] -fn mse_decreases_with_bits() -> VortexResult<()> { - let dim = 128; - let num_rows = 50; - let fsl = make_fsl(num_rows, dim, 99); - - let mut prev_mse = f32::MAX; - for bit_width in 1..=8u8 { - let config = TurboQuantConfig { - bit_width, - seed: 123, - num_rounds: 3, - }; - let (original, decoded) = encode_decode(&fsl, &config)?; - let mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); - assert!( - mse <= prev_mse * 1.01, - "MSE should decrease: {bit_width}-bit={mse:.6} > prev={prev_mse:.6}" - ); - prev_mse = mse; - } - Ok(()) -} - -#[rstest] -#[case(0)] -#[case(1)] -fn roundtrip_edge_cases(#[case] num_rows: usize) -> VortexResult<()> { - let fsl = make_fsl(num_rows, 128, 42); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 2, - seed: 123, - num_rounds: 3, - }; - let mut ctx = SESSION.create_execution_ctx(); - let encoded = turboquant_encode(ext, &config, &mut ctx)?; - let decoded = encoded.execute::(&mut ctx)?; - assert_eq!(decoded.len(), num_rows); - Ok(()) -} - -#[rstest] -#[case(1)] -#[case(64)] -#[case(127)] -fn rejects_dimension_below_128(#[case] dim: usize) { - let elements = PrimitiveArray::new::( - BufferMut::from_iter((0..dim).map(|i| i as f32 + 1.0)).freeze(), - Validity::NonNullable, - ); - let fsl = FixedSizeListArray::try_new( - elements.into_array(), - dim.try_into().expect("dim fits u32"), - Validity::NonNullable, - 1, - ) - .unwrap(); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 2, - seed: 0, - num_rounds: 3, - }; - let mut ctx = SESSION.create_execution_ctx(); - assert!(turboquant_encode(ext, &config, &mut ctx).is_err()); -} - -#[rstest] -#[case(0)] -#[case(9)] -fn rejects_invalid_bit_width(#[case] bit_width: u8) { - let fsl = make_fsl(10, 128, 42); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width, - seed: 0, - num_rounds: 3, - }; - let mut ctx = SESSION.create_execution_ctx(); - let normalized = normalize_as_l2_denorm(ext, &mut ctx) - .unwrap() - .child_at(0) - .clone(); - let normalized_ext = normalized - .as_opt::() - .expect("normalized child should be Extension"); - assert!(unsafe { turboquant_encode_unchecked(normalized_ext, &config, &mut ctx) }.is_err()); -} - -#[test] -fn all_zero_vectors_roundtrip() -> VortexResult<()> { - let num_rows = 10; - let dim = 128; - let buf = BufferMut::::full(0.0f32, num_rows * dim); - let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); - let fsl = FixedSizeListArray::try_new( - elements.into_array(), - dim.try_into().map_err(|e| vortex_err!("{e}"))?, - Validity::NonNullable, - num_rows, - )?; - - let config = TurboQuantConfig { - bit_width: 3, - seed: 42, - num_rounds: 3, - }; - let (original, decoded) = encode_decode(&fsl, &config)?; - for (i, (&o, &d)) in original.iter().zip(decoded.iter()).enumerate() { - assert_eq!(o, 0.0, "original[{i}] not zero"); - assert_eq!(d, 0.0, "decoded[{i}] not zero for all-zero input"); - } - Ok(()) -} - -/// Roundtrip at large embedding dimensions. -#[rstest] -#[case(768, 4)] -#[case(1024, 5)] -fn large_dimension_roundtrip(#[case] dim: usize, #[case] bit_width: u8) -> VortexResult<()> { - let num_rows = 10; - let fsl = make_fsl(num_rows, dim, 42); - let config = TurboQuantConfig { - bit_width, - seed: 123, - num_rounds: 3, - }; - let (original, decoded) = encode_decode(&fsl, &config)?; - assert_eq!(decoded.len(), original.len()); - - let normalized_mse = per_vector_normalized_mse(&original, &decoded, dim, num_rows); - // 2x slack for the SRHT-vs-Haar gap. - let bound = 2.0 * theoretical_mse_bound(bit_width); - assert!( - normalized_mse < bound, - "Normalized MSE {normalized_mse:.6} exceeds 2x bound {bound:.6} for dim={dim}, bits={bit_width}", - ); - Ok(()) -} - -/// Verify that f64 input is accepted and encoded. -#[test] -fn f64_input_encodes_successfully() -> VortexResult<()> { - let num_rows = 10; - let dim = 128; - let mut rng = StdRng::seed_from_u64(99); - let normal = Normal::new(0.0f64, 1.0).map_err(|e| vortex_err!("{e}"))?; - - let mut buf = BufferMut::::with_capacity(num_rows * dim); - for _ in 0..(num_rows * dim) { - buf.push(normal.sample(&mut rng)); - } - let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); - let fsl = FixedSizeListArray::try_new( - elements.into_array(), - dim.try_into().map_err(|e| vortex_err!("{e}"))?, - Validity::NonNullable, - num_rows, - )?; - - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 3, - seed: 42, - num_rounds: 3, - }; - let mut ctx = SESSION.create_execution_ctx(); - let encoded = turboquant_encode(ext, &config, &mut ctx)?; - let (_sorf_child, norms_child) = unwrap_l2denorm(&encoded); - assert_eq!(norms_child.len(), num_rows); - Ok(()) -} - -/// Verify that f16 input is accepted and encoded. -#[test] -fn f16_input_encodes_successfully() -> VortexResult<()> { - let num_rows = 10; - let dim = 128; - let mut rng = StdRng::seed_from_u64(99); - let normal = Normal::new(0.0f32, 1.0).map_err(|e| vortex_err!("{e}"))?; - - let mut buf = BufferMut::::with_capacity(num_rows * dim); - for _ in 0..(num_rows * dim) { - buf.push(half::f16::from_f32(normal.sample(&mut rng))); - } - let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); - let fsl = FixedSizeListArray::try_new( - elements.into_array(), - dim.try_into().map_err(|e| vortex_err!("{e}"))?, - Validity::NonNullable, - num_rows, - )?; - - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 3, - seed: 42, - num_rounds: 3, - }; - let mut ctx = SESSION.create_execution_ctx(); - let encoded = turboquant_encode(ext, &config, &mut ctx)?; - let (_sorf_child, norms_child) = unwrap_l2denorm(&encoded); - assert_eq!(norms_child.len(), num_rows); - - let decoded_ext = encoded.execute::(&mut ctx)?; - let decoded_fsl = decoded_ext - .storage_array() - .clone() - .execute::(&mut ctx)?; - assert_eq!(decoded_fsl.len(), num_rows); - Ok(()) -} diff --git a/vortex-tensor/src/encodings/turboquant/tests/structural.rs b/vortex-tensor/src/encodings/turboquant/tests/structural.rs deleted file mode 100644 index bc9a5e207f1..00000000000 --- a/vortex-tensor/src/encodings/turboquant/tests/structural.rs +++ /dev/null @@ -1,334 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! Tests that verify the internal structure of the encoded tree. - -use vortex_array::VortexSessionExecute; -use vortex_array::arrays::ExtensionArray; -use vortex_array::arrays::FixedSizeListArray; -use vortex_array::arrays::PrimitiveArray; -use vortex_array::arrays::extension::ExtensionArrayExt; -use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; -use vortex_error::VortexResult; - -use super::*; -use crate::encodings::turboquant::centroids::compute_or_get_centroids; - -/// Verify that the centroids stored in the DictArray match what `compute_or_get_centroids()` computes. -#[test] -fn stored_centroids_match_computed() -> VortexResult<()> { - let fsl = make_fsl(10, 128, 42); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 3, - seed: 123, - num_rounds: 3, - }; - let mut ctx = SESSION.create_execution_ctx(); - let encoded = turboquant_encode(ext, &config, &mut ctx)?; - - let (_codes, centroids, _norms) = unwrap_codes_centroids_norms(&encoded, &mut ctx)?; - let stored = centroids.as_slice::(); - - // padded_dim for dim=128 is 128. - let computed = compute_or_get_centroids(128, 3)?; - - assert_eq!(stored.len(), computed.len()); - for i in 0..stored.len() { - assert_eq!(stored[i], computed[i], "Centroid mismatch at {i}"); - } - Ok(()) -} - -/// Verify that the rotation is deterministic from seed by checking decode output. -#[test] -fn seed_deterministic_rotation_produces_correct_decode() -> VortexResult<()> { - let fsl = make_fsl(20, 128, 42); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 3, - seed: 123, - num_rounds: 4, - }; - - // Encode twice with the same seed → should produce identical results. - let mut ctx = SESSION.create_execution_ctx(); - let encoded1 = turboquant_encode(ext.clone(), &config, &mut ctx)?; - let decoded1 = encoded1.execute::(&mut ctx)?; - let fsl1 = decoded1 - .storage_array() - .clone() - .execute::(&mut ctx)?; - let elems1 = fsl1 - .elements() - .clone() - .execute::(&mut ctx)?; - - let mut ctx = SESSION.create_execution_ctx(); - let encoded2 = turboquant_encode(ext, &config, &mut ctx)?; - let decoded2 = encoded2.execute::(&mut ctx)?; - let fsl2 = decoded2 - .storage_array() - .clone() - .execute::(&mut ctx)?; - let elems2 = fsl2 - .elements() - .clone() - .execute::(&mut ctx)?; - - assert_eq!( - elems1.as_slice::(), - elems2.as_slice::(), - "Two encodes with same seed should produce identical decode output" - ); - Ok(()) -} - -/// Verify that the encoded array's dtype is a Vector extension type. -#[test] -fn encoded_dtype_is_vector_extension() -> VortexResult<()> { - let fsl = make_fsl(10, 128, 42); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 3, - seed: 123, - num_rounds: 2, - }; - let mut ctx = SESSION.create_execution_ctx(); - let encoded = turboquant_encode(ext, &config, &mut ctx)?; - - assert!( - encoded.dtype().is_extension(), - "TurboQuant dtype should be an extension type, got {}", - encoded.dtype() - ); - assert!( - encoded.dtype().as_extension().is::(), - "TurboQuant dtype should be a Vector extension type" - ); - Ok(()) -} - -/// Verify approximate cosine similarity in the quantized domain. -#[test] -fn cosine_similarity_quantized_accuracy() -> VortexResult<()> { - let fsl = make_fsl(20, 128, 42); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 4, - seed: 123, - num_rounds: 3, - }; - let mut ctx = SESSION.create_execution_ctx(); - let encoded = turboquant_encode(ext, &config, &mut ctx)?; - - let input_prim = fsl.elements().clone().execute::(&mut ctx)?; - let input_f32 = input_prim.as_slice::(); - - // Navigate tree to get codes, centroids, norms. - let (codes_prim, centroids_prim, norms_prim) = - unwrap_codes_centroids_norms(&encoded, &mut ctx)?; - let all_codes = codes_prim.as_slice::(); - let centroid_vals = centroids_prim.as_slice::(); - let norms = norms_prim.as_slice::(); - - // padded_dim for dim=128. - let pd = 128usize; - - for (row_a, row_b) in [(0, 1), (5, 10), (0, 19)] { - let vec_a = &input_f32[row_a * 128..(row_a + 1) * 128]; - let vec_b = &input_f32[row_b * 128..(row_b + 1) * 128]; - - let dot: f32 = vec_a.iter().zip(vec_b.iter()).map(|(&x, &y)| x * y).sum(); - let norm_a: f32 = vec_a.iter().map(|&v| v * v).sum::().sqrt(); - let norm_b: f32 = vec_b.iter().map(|&v| v * v).sum::().sqrt(); - let exact_cos = dot / (norm_a * norm_b); - - let approx_cos = if norms[row_a] == 0.0 || norms[row_b] == 0.0 { - 0.0 - } else { - let codes_a = &all_codes[row_a * pd..(row_a + 1) * pd]; - let codes_b = &all_codes[row_b * pd..(row_b + 1) * pd]; - codes_a - .iter() - .zip(codes_b.iter()) - .map(|(&ca, &cb)| centroid_vals[ca as usize] * centroid_vals[cb as usize]) - .sum::() - }; - - let error = (exact_cos - approx_cos).abs(); - assert!( - error < 0.15, - "cosine similarity error too large for ({row_a}, {row_b}): \ - exact={exact_cos:.4}, approx={approx_cos:.4}, error={error:.4}" - ); - } - Ok(()) -} - -/// Verify approximate dot product in the quantized domain. -#[test] -fn dot_product_quantized_accuracy() -> VortexResult<()> { - let fsl = make_fsl(20, 128, 42); - let ext = make_vector_ext(&fsl); - let config = TurboQuantConfig { - bit_width: 8, - seed: 123, - num_rounds: 3, - }; - let mut ctx = SESSION.create_execution_ctx(); - let encoded = turboquant_encode(ext, &config, &mut ctx)?; - - let input_prim = fsl.elements().clone().execute::(&mut ctx)?; - let input_f32 = input_prim.as_slice::(); - - let (codes_prim, centroids_prim, norms_prim) = - unwrap_codes_centroids_norms(&encoded, &mut ctx)?; - let all_codes = codes_prim.as_slice::(); - let centroid_vals = centroids_prim.as_slice::(); - let norms = norms_prim.as_slice::(); - - let pd = 128usize; - - for (row_a, row_b) in [(0, 1), (5, 10), (0, 19)] { - let vec_a = &input_f32[row_a * 128..(row_a + 1) * 128]; - let vec_b = &input_f32[row_b * 128..(row_b + 1) * 128]; - - let exact_dot: f32 = vec_a.iter().zip(vec_b.iter()).map(|(&x, &y)| x * y).sum(); - - let codes_a = &all_codes[row_a * pd..(row_a + 1) * pd]; - let codes_b = &all_codes[row_b * pd..(row_b + 1) * pd]; - let unit_dot: f32 = codes_a - .iter() - .zip(codes_b.iter()) - .map(|(&ca, &cb)| centroid_vals[ca as usize] * centroid_vals[cb as usize]) - .sum(); - let approx_dot = norms[row_a] * norms[row_b] * unit_dot; - - let scale = exact_dot.abs().max(1.0); - let rel_error = (exact_dot - approx_dot).abs() / scale; - assert!( - rel_error < 0.15, - "dot product error too large for ({row_a}, {row_b}): \ - exact={exact_dot:.4}, approx={approx_dot:.4}, rel_error={rel_error:.4}" - ); - } - Ok(()) -} - -/// Verify SorfTransform in isolation: manually forward-rotate known data, wrap in -/// FSL(Dict), execute SorfTransform, and check inverse rotation recovers the original. -#[test] -#[expect( - clippy::cast_possible_truncation, - reason = "test uses known small dimensions" -)] -fn sorf_transform_roundtrip_isolation() -> VortexResult<()> { - use vortex_array::EmptyMetadata; - use vortex_array::IntoArray; - use vortex_array::arrays::dict::DictArray; - use vortex_array::dtype::extension::ExtDType; - use vortex_array::validity::Validity; - use vortex_buffer::BufferMut; - - use crate::encodings::turboquant::centroids::compute_centroid_boundaries; - use crate::encodings::turboquant::centroids::compute_or_get_centroids; - use crate::encodings::turboquant::centroids::find_nearest_centroid; - use crate::scalar_fns::sorf_transform::SorfMatrix; - use crate::scalar_fns::sorf_transform::SorfOptions; - use crate::scalar_fns::sorf_transform::SorfTransform; - use crate::types::vector::Vector; - - let dim = 128usize; - let seed = 99u64; - let num_rounds = 3u8; - let num_rows = 5; - - // Build a known input: simple increasing values, then normalize each row to unit norm. - let mut input_f32 = vec![0.0f32; num_rows * dim]; - for row in 0..num_rows { - let mut norm_sq = 0.0f32; - for i in 0..dim { - let val = ((row * dim + i) as f32 + 1.0) * 0.01; - input_f32[row * dim + i] = val; - norm_sq += val * val; - } - let norm = norm_sq.sqrt(); - for i in 0..dim { - input_f32[row * dim + i] /= norm; - } - } - - // Forward transform + quantize (mimicking what turboquant_quantize_core does). - let padded_dim = dim.next_power_of_two(); - let rotation = SorfMatrix::try_new_padded(padded_dim, num_rounds as usize, seed)?; - let centroids = compute_or_get_centroids(padded_dim as u32, 8)?; - let boundaries = compute_centroid_boundaries(¢roids); - - let mut all_indices = BufferMut::::with_capacity(num_rows * padded_dim); - let mut padded = vec![0.0f32; padded_dim]; - let mut rotated = vec![0.0f32; padded_dim]; - - for row in 0..num_rows { - padded[..dim].copy_from_slice(&input_f32[row * dim..(row + 1) * dim]); - padded[dim..].fill(0.0); - rotation.rotate(&padded, &mut rotated); - for j in 0..padded_dim { - all_indices.push(find_nearest_centroid(rotated[j], &boundaries)); - } - } - - // Build FSL(Dict(codes, centroids)). - let codes = PrimitiveArray::new::(all_indices.freeze(), Validity::NonNullable); - let mut centroids_buf = BufferMut::::with_capacity(centroids.len()); - centroids_buf.extend_from_slice(¢roids); - let centroids_arr = PrimitiveArray::new::(centroids_buf.freeze(), Validity::NonNullable); - let dict = DictArray::try_new(codes.into_array(), centroids_arr.into_array())?; - let fsl = FixedSizeListArray::try_new( - dict.into_array(), - padded_dim as u32, - Validity::NonNullable, - num_rows, - )?; - - // Wrap the padded FSL in a Vector extension so it can be the SorfTransform child. - let padded_vector_dtype = - ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); - let padded_vector = ExtensionArray::new(padded_vector_dtype, fsl.into_array()); - - // Wrap in SorfTransform and execute. - let sorf_options = SorfOptions { - seed, - num_rounds, - dimensions: dim as u32, - element_ptype: vortex_array::dtype::PType::F32, - }; - let sorf_array = - SorfTransform::try_new_array(&sorf_options, padded_vector.into_array(), num_rows)?; - - let mut ctx = SESSION.create_execution_ctx(); - let result: ExtensionArray = sorf_array.into_array().execute(&mut ctx)?; - let result_fsl: FixedSizeListArray = result.storage_array().clone().execute(&mut ctx)?; - let result_prim: PrimitiveArray = result_fsl.elements().clone().execute(&mut ctx)?; - let result_f32 = result_prim.as_slice::(); - - assert_eq!(result_f32.len(), num_rows * dim); - - // At 8-bit quantization, reconstruction should be very close to input. - for row in 0..num_rows { - let orig = &input_f32[row * dim..(row + 1) * dim]; - let recon = &result_f32[row * dim..(row + 1) * dim]; - let err_sq: f32 = orig - .iter() - .zip(recon) - .map(|(&a, &b)| (a - b) * (a - b)) - .sum(); - let norm_sq: f32 = orig.iter().map(|&v| v * v).sum(); - assert!( - err_sq / norm_sq < 1e-3, - "SorfTransform isolation: row {row} MSE too high: {:.6}", - err_sq / norm_sq - ); - } - Ok(()) -} diff --git a/vortex-tensor/src/lib.rs b/vortex-tensor/src/lib.rs index 92d5422c200..0a0b3a8a2ee 100644 --- a/vortex-tensor/src/lib.rs +++ b/vortex-tensor/src/lib.rs @@ -23,7 +23,6 @@ use crate::scalar_fns::cosine_similarity::CosineSimilarity; use crate::scalar_fns::inner_product::InnerProduct; use crate::scalar_fns::l2_denorm::L2Denorm; use crate::scalar_fns::l2_norm::L2Norm; -use crate::scalar_fns::sorf_transform::SorfTransform; use crate::types::fixed_shape_tensor::FixedShapeTensor; use crate::types::vector::Vector; @@ -42,8 +41,8 @@ pub mod vector_search; mod utils; /// Environment variable that gates registration of the tensor scalar-fn array plugins (the array -/// encodings that let [`CosineSimilarity`], [`InnerProduct`], [`L2Denorm`], [`L2Norm`], and -/// [`SorfTransform`] persist in a Vortex file). When unset, only the scalar functions themselves +/// encodings that let [`CosineSimilarity`], [`InnerProduct`], [`L2Denorm`], and [`L2Norm`] +/// persist in a Vortex file). When unset, only the scalar functions themselves /// are registered; readers of files containing serialized tensor scalar-fn arrays will fail to /// deserialize. Opt-in by setting the variable to any non-empty value. pub const SCALAR_FN_ARRAY_TENSOR_PLUGIN_ENV: &str = "VX_SCALAR_FN_ARRAY_TENSOR_PLUGIN"; @@ -63,7 +62,6 @@ pub fn initialize(session: &VortexSession) { session_fns.register(InnerProduct); session_fns.register(L2Denorm); session_fns.register(L2Norm); - session_fns.register(SorfTransform); // Registering the scalar-fn array plugins lets the tensor scalar fns be serialized as array // encodings inside Vortex files. Gate this on an env var so applications that do not intend @@ -76,7 +74,6 @@ pub fn initialize(session: &VortexSession) { session_arrays.register(ScalarFnArrayPlugin::new(InnerProduct)); session_arrays.register(ScalarFnArrayPlugin::new(L2Denorm)); session_arrays.register(ScalarFnArrayPlugin::new(L2Norm)); - session_arrays.register(ScalarFnArrayPlugin::new(SorfTransform)); } } diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index 0b3176915fd..2b267671c8c 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -48,7 +48,7 @@ use crate::utils::validate_binary_tensor_float_inputs; /// same dtype and a float element type. The output is a float column of the same float type. /// /// When either input is wrapped in [`L2Denorm`], this operator treats the stored norms and -/// normalized children as authoritative. For lossy encodings such as TurboQuant, that means the +/// normalized children as authoritative. For lossy encodings, that means the /// optimized readthrough path may intentionally differ slightly from decoding both sides to dense /// coordinates and recomputing cosine from scratch. /// @@ -603,7 +603,7 @@ mod tests { #[test] fn both_denorm_lossy_zero_stored_norm_returns_zero() -> VortexResult<()> { - // Mimics a lossy encoding (e.g. TurboQuant) where the stored norm is authoritative but + // Mimics a lossy encoding where the stored norm is authoritative but // the decoded normalized child is physically nonzero. With a stored norm of `0.0`, cosine // similarity for that row must be `0.0` even though the dot product of the normalized // children is nonzero. @@ -626,7 +626,7 @@ mod tests { #[test] fn one_side_denorm_lossy_zero_stored_norm_returns_zero() -> VortexResult<()> { - // Mimics a lossy encoding (e.g. TurboQuant) where the stored norm is authoritative but + // Mimics a lossy encoding where the stored norm is authoritative but // the decoded normalized child is physically nonzero. The plain side is a normal nonzero // tensor with positive norm. cosine similarity must still be `0.0` because the // authoritative stored norm on the denorm side is `0.0`. diff --git a/vortex-tensor/src/scalar_fns/inner_product.rs b/vortex-tensor/src/scalar_fns/inner_product.rs index b3ba7a7b557..8f6b67ce11b 100644 --- a/vortex-tensor/src/scalar_fns/inner_product.rs +++ b/vortex-tensor/src/scalar_fns/inner_product.rs @@ -7,25 +7,16 @@ use num_traits::Float; use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; -use vortex_array::arrays::Constant; -use vortex_array::arrays::ConstantArray; -use vortex_array::arrays::Dict; -use vortex_array::arrays::Extension; use vortex_array::arrays::ExtensionArray; -use vortex_array::arrays::FixedSizeList; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::ScalarFnArray; -use vortex_array::arrays::dict::DictArraySlotsExt; use vortex_array::arrays::extension::ExtensionArrayExt; -use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; -use vortex_array::arrays::scalar_fn::ExactScalarFn; use vortex_array::arrays::scalar_fn::ScalarFnArrayView; use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts; use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable; use vortex_array::dtype::DType; use vortex_array::dtype::NativePType; use vortex_array::dtype::Nullability; -use vortex_array::dtype::PType; use vortex_array::expr::Expression; use vortex_array::expr::and; use vortex_array::match_each_float_ptype; @@ -38,7 +29,6 @@ use vortex_array::scalar_fn::ScalarFnVTable; use vortex_array::scalar_fn::TypedScalarFnInstance; use vortex_array::serde::ArrayChildren; use vortex_buffer::Buffer; -use vortex_buffer::BufferMut; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_session::VortexSession; @@ -46,11 +36,7 @@ use vortex_session::registry::CachedId; use crate::matcher::AnyTensor; use crate::scalar_fns::l2_denorm::DenormOrientation; -use crate::scalar_fns::sorf_transform::SorfMatrix; -use crate::scalar_fns::sorf_transform::SorfTransform; -use crate::types::vector::Vector; use crate::utils::BinaryTensorOpMetadata; -use crate::utils::extract_constant_flat_row; use crate::utils::extract_flat_elements; use crate::utils::extract_l2_denorm_children; use crate::utils::validate_binary_tensor_float_inputs; @@ -139,20 +125,6 @@ impl ScalarFnVTable for InnerProduct { DenormOrientation::Neither => {} } - // Reduction case 1: `InnerProduct(SorfTransform(x), const)` rewrites to - // `InnerProduct(x, forward_rotate(zero_pad(const)))`. Re-executes recursively so - // case 2 can fire on the rewritten tree. - if let Some(rewritten) = self.try_execute_sorf_constant(&lhs_ref, &rhs_ref, len, ctx)? { - return Ok(rewritten); - } - - // Reduction case 2: `InnerProduct(Vector[FSL(Dict(u8, f32))], const)` is computed by - // gather-summing `q[j] * values[codes[j] as usize]` per row, reading the codebook - // directly instead of decoding the column into dense vectors. - if let Some(result) = self.try_execute_dict_constant(&lhs_ref, &rhs_ref, len, ctx)? { - return Ok(result); - } - // Compute combined validity. let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?; @@ -293,206 +265,6 @@ impl InnerProduct { Ok(unsafe { PrimitiveArray::new_unchecked(buffer, validity) }.into_array()) }) } - - /// Fast path when one side is `ExactScalarFn` and the other side is a - /// constant-backed tensor-like extension. Rewrites to - /// `InnerProduct(sorf_child, forward_rotate(zero_pad(const_query)))` because SORF is - /// orthogonal, so ` = ` where `T` is the truncation from - /// `padded_dim` to `dim` applied by `SorfTransform` and `R` is the SORF forward matrix. See the - /// proof in the crate-level docs and in the plan file. - /// - /// Returns `Ok(None)` if neither side matches, when the operand element type is not `F32`, or - /// when the constant side is not a constant-backed tensor extension. The caller is expected to - /// fall through to the standard path in that case. - /// - /// # F32-only - /// - /// TODO(connor): this rewrite is only sound for `PType::F32` because `SorfTransform` applies an - /// `f32 -> element_ptype` cast at the end of its `execute`. For `F16`/`F64` the cast changes - /// the inner product's rounding and the rewrite would not be semantically equivalent. Until we - /// push the cast through `InnerProduct`, both the SorfTransform output ptype and the - /// constant-side element ptype must be `F32` here. - fn try_execute_sorf_constant( - &self, - lhs_ref: &ArrayRef, - rhs_ref: &ArrayRef, - len: usize, - ctx: &mut ExecutionCtx, - ) -> VortexResult> { - // Identify which side is the SorfTransform, if any. - let (sorf_view, const_ref) = - if let Some(view) = lhs_ref.as_opt::>() { - (view, rhs_ref) - } else if let Some(view) = rhs_ref.as_opt::>() { - (view, lhs_ref) - } else { - return Ok(None); - }; - - if sorf_view.options.element_ptype != PType::F32 { - return Ok(None); - } - - // The other side must be a constant tensor. - let Some(const_storage) = constant_tensor_storage(const_ref) else { - return Ok(None); - }; - - let dim = sorf_view.options.dimensions as usize; - let num_rounds = sorf_view.options.num_rounds as usize; - let seed = sorf_view.options.seed; - let padded_dim = dim.next_power_of_two(); - - // Extract the single stored row of the constant. - let flat = extract_constant_flat_row(&const_storage, ctx)?; - if flat.ptype() != PType::F32 { - return Ok(None); - } - - // Zero-pad the query from `dim` to `padded_dim` and forward-rotate. - let mut padded_query = vec![0.0f32; padded_dim]; - padded_query[..dim].copy_from_slice(flat.as_slice::()); - - let rotation = SorfMatrix::try_new_padded(padded_dim, num_rounds, seed)?; - let mut rotated_query = vec![0.0f32; padded_dim]; - rotation.rotate(&padded_query, &mut rotated_query); - - // Wrap the rotated query as a `Vector` constant broadcast to `len` - // rows. The new extension dtype has `padded_dim` instead of `dim`, matching the - // SorfTransform child we are about to dot it with. - let new_constant = Vector::constant_array(&rotated_query, len)?; - - // Extract the SorfTransform child (the already-padded Vector). - let sorf_child = sorf_view - .nth_child(0) - .vortex_expect("SorfTransform must have exactly one child"); - - // Recursively execute the rewritten inner product. This allows case 2 to fire on - // the rewritten tree if the sorf child is `Vector[FSL(Dict)]`. Termination is - // guaranteed because the rewrite strictly removes a `SorfTransform` scalar-fn node - // from the tree and SORFs cannot be nested. - let rewritten = InnerProduct::try_new_array(sorf_child, new_constant, len)? - .into_array() - .execute(ctx)?; - Ok(Some(rewritten)) - } - - /// Fast path when one side is an extension whose storage is `FSL(Dict(u8, f32))` and - /// the other side is a constant-backed tensor extension with an F32 element ptype. - /// - /// Computes each row's inner product as - /// `out[i] = sum_{j in 0..padded_dim} q[j] * values[codes[i * padded_dim + j] as usize]` - /// using a direct codebook lookup in the hot loop. An explicit product table - /// `P[j, k] = q[j] * values[k]` (size `padded_dim * num_centroids * 4B`, ~1 MiB for the - /// common 1024/256 case) was tried and measured ~10% *slower* on the - /// `similarity_search` bench because the 1 KiB `values` table stays in L1 across all - /// rows, while the 1 MiB product table does not. - /// - /// Returns `Ok(None)` when the pattern doesn't match; the caller should fall through to - /// the standard path. - fn try_execute_dict_constant( - &self, - lhs_ref: &ArrayRef, - rhs_ref: &ArrayRef, - len: usize, - ctx: &mut ExecutionCtx, - ) -> VortexResult> { - // Try each orientation. The oriented helper navigates each side exactly once, so - // the only redundant work here is the failed navigation of the first side when the - // dict happens to be on the right. - if let Some(result) = self.try_execute_dict_constant_oriented(lhs_ref, rhs_ref, len, ctx)? { - return Ok(Some(result)); - } - self.try_execute_dict_constant_oriented(rhs_ref, lhs_ref, len, ctx) - } - - /// Orientation-specific helper for [`Self::try_execute_dict_constant`]. `dict_candidate` - /// is tried as `Extension[FSL[Dict]]`; `const_candidate` is tried as a constant-backed - /// tensor extension. Returns `Ok(None)` if either navigation fails or any gate rejects. - fn try_execute_dict_constant_oriented( - &self, - dict_candidate: &ArrayRef, - const_candidate: &ArrayRef, - len: usize, - ctx: &mut ExecutionCtx, - ) -> VortexResult> { - // Navigate the dict side. - let Some(dict_ext) = dict_candidate.as_opt::() else { - return Ok(None); - }; - let Some(fsl) = dict_ext.storage_array().as_opt::() else { - return Ok(None); - }; - let Some(dict) = fsl.elements().as_opt::() else { - return Ok(None); - }; - - // Navigate the constant side and require its scalar be non-null. - let Some(const_storage) = constant_tensor_storage(const_candidate) else { - return Ok(None); - }; - - // Canonicalize codes and values. Codes may be e.g. BitPacked; executing is cheaper - // than falling through to the standard path (which would also canonicalize). - let codes_prim: PrimitiveArray = dict.codes().clone().execute(ctx)?; - let values_prim: PrimitiveArray = dict.values().clone().execute(ctx)?; - - // Gate: u8 codes and f32 centroids. - if codes_prim.ptype() != PType::U8 { - // TODO(connor): Should we support wider codes? - return Ok(None); - } - if values_prim.ptype() != PType::F32 { - // TODO(connor): direct-lookup path only supports f32 centroids. SorfTransform - // forces f32 anyway, so this is the only shape we need for now. - return Ok(None); - } - - let padded_dim = usize::try_from(fsl.list_size()).vortex_expect("fsl list_size fits usize"); - - let flat = extract_constant_flat_row(&const_storage, ctx)?; - if flat.ptype() != PType::F32 { - // TODO(connor): case 2 is f32-only. For f16/f64 we fall through to the standard - // path, which computes the inner product with the correct element type. - return Ok(None); - } - - // Combine the input validities up front; the per-row arithmetic may write garbage - // into null rows but the validity mask hides it (matching the standard path). - let validity = dict_candidate - .validity()? - .and(const_candidate.validity()?)?; - - // Fast path for the empty case: skip allocating and touching the codes buffer. - if len == 0 { - let empty = PrimitiveArray::empty::(validity.nullability()); - return Ok(Some(empty.into_array())); - } - - let q: &[f32] = flat.as_slice::(); - debug_assert_eq!(q.len(), padded_dim); - let codes: &[u8] = codes_prim.as_slice::(); - let values: &[f32] = values_prim.as_slice::(); - debug_assert_eq!(codes.len(), len * padded_dim); - - // The hot loop is extracted into [`execute_dict_constant_inner_product`] so the compiler - // can prove the chunked indices stay in bounds and vectorize the inner gather-accumulate. - let out = execute_dict_constant_inner_product(q, values, codes, len, padded_dim); - - // SAFETY: the buffer length equals `len`, which matches the validity length. - let result = unsafe { PrimitiveArray::new_unchecked(out.freeze(), validity) }.into_array(); - Ok(Some(result)) - } -} - -/// Return the storage constant for a canonical tensor-like constant query. -fn constant_tensor_storage(array: &ArrayRef) -> Option { - let constant = array.as_opt::()?; - if constant.scalar().is_null() { - return None; - } - let ext_scalar = constant.scalar().as_extension_opt()?; - Some(ConstantArray::new(ext_scalar.to_storage_scalar(), array.len()).into_array()) } /// Computes the inner product (dot product) of two equal-length float slices. @@ -505,49 +277,6 @@ fn inner_product_row(a: &[T], b: &[T]) -> T { .fold(T::zero(), |acc, v| acc + v) } -/// Compute inner products between a constant query vector and dictionary-encoded rows. -/// -/// For each row, computes `sum(q[j] * values[codes[row * dim + j]])` using the codebook `values` -/// directly instead of decoding the dictionary into dense vectors. -/// -/// The inner loop uses `PARTIAL_SUMS` independent accumulators so the CPU can pipeline FP additions -/// instead of waiting for each `fadd` to retire before starting the next. -fn execute_dict_constant_inner_product( - q: &[f32], - values: &[f32], - codes: &[u8], - num_rows: usize, - dim: usize, -) -> BufferMut { - let mut out = BufferMut::::with_capacity(num_rows); - - const PARTIAL_SUMS: usize = 8; - - for row_codes in codes.chunks_exact(dim) { - let mut acc = [0.0f32; PARTIAL_SUMS]; - - let code_chunks = row_codes.chunks_exact(PARTIAL_SUMS); - let q_chunks = q.chunks_exact(PARTIAL_SUMS); - let code_rem = code_chunks.remainder(); - let q_rem = q_chunks.remainder(); - - for (cc, qd) in code_chunks.zip(q_chunks) { - for i in 0..PARTIAL_SUMS { - acc[i] = qd[i].mul_add(values[cc[i] as usize], acc[i]); - } - } - - for (&code, &q_val) in code_rem.iter().zip(q_rem.iter()) { - acc[0] = q_val.mul_add(values[code as usize], acc[0]); - } - - // SAFETY: we reserved `num_rows` slots and push exactly once per row. - unsafe { out.push_unchecked(acc.iter().sum::()) }; - } - - out -} - #[cfg(test)] mod tests { @@ -801,787 +530,4 @@ mod tests { fn inner_product_tensor_rhs() -> ArrayRef { tensor_array(&[2], &[5.0, 6.0, 7.0, 8.0]).expect("valid tensor array") } - - // ---- Tests for the `SorfTransform + constant` and `Dict + constant` fast paths ---- - - #[allow( - clippy::cast_possible_truncation, - reason = "tests build small fixtures with deterministic in-range indices" - )] - mod constant_query_optimizations { - use rstest::rstest; - use vortex_array::ArrayRef; - use vortex_array::IntoArray; - use vortex_array::VortexSessionExecute; - use vortex_array::arrays::Constant; - use vortex_array::arrays::FixedSizeListArray; - use vortex_array::arrays::PrimitiveArray; - use vortex_array::arrays::ScalarFnArray; - use vortex_array::arrays::dict::DictArray; - use vortex_array::dtype::DType; - use vortex_array::dtype::Nullability; - use vortex_array::dtype::PType; - use vortex_array::validity::Validity; - use vortex_buffer::Buffer; - use vortex_error::VortexResult; - - use crate::scalar_fns::inner_product::InnerProduct; - use crate::scalar_fns::inner_product::constant_tensor_storage; - use crate::scalar_fns::sorf_transform::SorfMatrix; - use crate::scalar_fns::sorf_transform::SorfOptions; - use crate::scalar_fns::sorf_transform::SorfTransform; - use crate::tests::SESSION; - use crate::types::vector::Vector; - use crate::utils::extract_flat_elements; - use crate::utils::test_helpers::literal_vector_array; - use crate::utils::test_helpers::vector_array; - - /// Build a `Vector` whose storage is `FSL(DictArray(codes: u8, values: - /// f32))`. This mirrors the shape that TurboQuant produces as the SorfTransform child. - fn dict_vector_f32(list_size: u32, codes: &[u8], values: &[f32]) -> VortexResult { - let num_rows = codes.len() / list_size as usize; - let codes_arr = - PrimitiveArray::new::(Buffer::copy_from(codes), Validity::NonNullable) - .into_array(); - let values_arr = - PrimitiveArray::new::(Buffer::copy_from(values), Validity::NonNullable) - .into_array(); - let dict = DictArray::try_new(codes_arr, values_arr)?; - let fsl = FixedSizeListArray::try_new( - dict.into_array(), - list_size, - Validity::NonNullable, - num_rows, - )?; - Vector::try_new_vector_array(fsl.into_array()) - } - - /// Execute an inner product and return the flat `f32` results. - fn eval_ip_f32(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult> { - let scalar_fn = InnerProduct::new().erased(); - let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], len)?; - let mut ctx = SESSION.create_execution_ctx(); - let prim: PrimitiveArray = result.into_array().execute(&mut ctx)?; - Ok(prim.as_slice::().to_vec()) - } - - fn assert_close_f32(actual: &[f32], expected: &[f32], tol: f32) { - assert_eq!(actual.len(), expected.len(), "length mismatch"); - for (i, (a, e)) in actual.iter().zip(expected).enumerate() { - assert!( - (a - e).abs() < tol, - "row {i}: got {a}, expected {e} (diff = {})", - (a - e).abs() - ); - } - } - - /// Build a SorfTransform ScalarFnArray whose child is a `Vector` - /// wrapping `FSL(Dict(codes, values))`. Returns `(sorf_array, codes, values, - /// padded_dim)`. - fn build_sorf_with_dict_child( - dim: u32, - num_rows: usize, - seed: u64, - num_rounds: u8, - ) -> VortexResult<(ArrayRef, Vec, Vec, usize)> { - let padded_dim = (dim as usize).next_power_of_two(); - // Small hand-picked codebook of 8 f32 centroids. - let values: Vec = vec![-1.5, -1.0, -0.5, -0.1, 0.1, 0.5, 1.0, 1.5]; - // Deterministic codes in 0..values.len() covering every position. - let codes: Vec = (0..num_rows * padded_dim) - .map(|i| (i as u8) % (values.len() as u8)) - .collect(); - - let padded_vector = dict_vector_f32(padded_dim as u32, &codes, &values)?; - let sorf_options = SorfOptions { - seed, - num_rounds, - dimensions: dim, - element_ptype: PType::F32, - }; - let sorf = - SorfTransform::try_new_array(&sorf_options, padded_vector, num_rows)?.into_array(); - Ok((sorf, codes, values, padded_dim)) - } - - /// Decode a SorfTransform-wrapped dict-vector to a flat `Vec` of `num_rows * - /// dim` post-rotation, post-truncation values. This is the ground truth against - /// which we compare the fast-path result. - fn decode_sorf_dict( - codes: &[u8], - values: &[f32], - padded_dim: usize, - dim: usize, - num_rows: usize, - seed: u64, - num_rounds: u8, - ) -> VortexResult> { - let rotation = SorfMatrix::try_new_padded(padded_dim, num_rounds as usize, seed)?; - let mut padded = vec![0.0f32; padded_dim]; - let mut rotated = vec![0.0f32; padded_dim]; - let mut out = Vec::with_capacity(num_rows * dim); - for row in 0..num_rows { - for j in 0..padded_dim { - padded[j] = values[codes[row * padded_dim + j] as usize]; - } - rotation.inverse_rotate(&padded, &mut rotated); - out.extend_from_slice(&rotated[..dim]); - } - Ok(out) - } - - fn naive_dot(a: &[f32], b: &[f32]) -> f32 { - a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum() - } - - // ---- Case 1: SorfTransform + Constant pull-through ---- - - #[test] - fn constant_tensor_storage_accepts_extension_scalar_literal() -> VortexResult<()> { - let literal = literal_vector_array(&[1.0f32, 2.0, 3.0], 5); - let storage = - constant_tensor_storage(&literal).expect("literal vector should be recognized"); - - assert_eq!(storage.len(), 5); - let const_storage = storage - .as_opt::() - .expect("storage should remain constant-backed"); - assert!(matches!( - const_storage.scalar().dtype(), - DType::FixedSizeList(_, 3, Nullability::NonNullable) - )); - - let mut ctx = SESSION.create_execution_ctx(); - let flat = extract_flat_elements(&storage, 3, &mut ctx)?; - assert_eq!(flat.row::(0), &[1.0, 2.0, 3.0]); - Ok(()) - } - - /// Case 1: SorfTransform on LHS, constant query on RHS, with `dim < padded_dim` - /// so the zero-padding branch is exercised. - #[test] - fn case1_sorf_lhs_constant_rhs_padded_gt_dim() -> VortexResult<()> { - let dim: u32 = 100; - let num_rows = 7usize; - let seed = 42u64; - let num_rounds = 3u8; - let padded_dim = (dim as usize).next_power_of_two(); - assert!(padded_dim > dim as usize, "test must exercise padding"); - - let (sorf_lhs, codes, values, padded_dim_computed) = - build_sorf_with_dict_child(dim, num_rows, seed, num_rounds)?; - assert_eq!(padded_dim_computed, padded_dim); - - // Query has `dim` elements. - let query_elems: Vec = (0..dim).map(|i| (i as f32 * 0.1).sin()).collect(); - let const_rhs = Vector::constant_array(&query_elems, num_rows)?; - - // Ground truth: decode LHS to plain f32 vectors, dot each with the query. - let decoded = decode_sorf_dict( - &codes, - &values, - padded_dim, - dim as usize, - num_rows, - seed, - num_rounds, - )?; - let expected: Vec = (0..num_rows) - .map(|i| { - naive_dot( - &decoded[i * dim as usize..(i + 1) * dim as usize], - &query_elems, - ) - }) - .collect(); - - let actual = eval_ip_f32(sorf_lhs, const_rhs, num_rows)?; - assert_close_f32(&actual, &expected, 1e-3); - Ok(()) - } - - /// Case 1: SorfTransform on RHS, constant query on LHS (mirrored). - #[test] - fn case1_constant_lhs_sorf_rhs_mirrored() -> VortexResult<()> { - let dim: u32 = 100; - let num_rows = 5usize; - let seed = 7u64; - let num_rounds = 3u8; - - let (sorf, codes, values, padded_dim) = - build_sorf_with_dict_child(dim, num_rows, seed, num_rounds)?; - - let query_elems: Vec = (0..dim).map(|i| (i as f32 * 0.2).cos()).collect(); - let const_lhs = Vector::constant_array(&query_elems, num_rows)?; - - let decoded = decode_sorf_dict( - &codes, - &values, - padded_dim, - dim as usize, - num_rows, - seed, - num_rounds, - )?; - let expected: Vec = (0..num_rows) - .map(|i| { - naive_dot( - &decoded[i * dim as usize..(i + 1) * dim as usize], - &query_elems, - ) - }) - .collect(); - - let actual = eval_ip_f32(const_lhs, sorf, num_rows)?; - assert_close_f32(&actual, &expected, 1e-3); - Ok(()) - } - - /// Case 1: `dim == padded_dim` (power-of-two, no zero padding). - #[test] - fn case1_padded_equals_dim() -> VortexResult<()> { - let dim: u32 = 128; - let num_rows = 4usize; - let seed = 11u64; - let num_rounds = 3u8; - - let (sorf, codes, values, padded_dim) = - build_sorf_with_dict_child(dim, num_rows, seed, num_rounds)?; - assert_eq!(padded_dim, dim as usize); - - let query_elems: Vec = (0..dim).map(|i| i as f32 * 0.01 - 0.5).collect(); - let const_rhs = Vector::constant_array(&query_elems, num_rows)?; - - let decoded = decode_sorf_dict( - &codes, - &values, - padded_dim, - dim as usize, - num_rows, - seed, - num_rounds, - )?; - let expected: Vec = (0..num_rows) - .map(|i| { - naive_dot( - &decoded[i * dim as usize..(i + 1) * dim as usize], - &query_elems, - ) - }) - .collect(); - - let actual = eval_ip_f32(sorf, const_rhs, num_rows)?; - assert_close_f32(&actual, &expected, 1e-3); - Ok(()) - } - - /// Case 1: empty `len == 0`. The fast path should handle this without exploding. - #[test] - fn case1_empty_len_zero() -> VortexResult<()> { - let dim: u32 = 100; - let num_rows = 0usize; - let seed = 42u64; - let num_rounds = 3u8; - - let (sorf, _codes, _values, _padded_dim) = - build_sorf_with_dict_child(dim, num_rows, seed, num_rounds)?; - - let query_elems: Vec = vec![0.0; dim as usize]; - let const_rhs = Vector::constant_array(&query_elems, num_rows)?; - - let actual = eval_ip_f32(sorf, const_rhs, num_rows)?; - assert_eq!(actual.len(), 0); - Ok(()) - } - - // ---- Case 2: Dict + Constant direct-lookup path ---- - - /// Case 2: Vector[FSL[Dict(u8, f32)]] on LHS, constant query on RHS. - #[test] - fn case2_dict_lhs_constant_rhs_matches_naive() -> VortexResult<()> { - let list_size: u32 = 8; - let num_rows = 10usize; - // 8 centroids, tiny table. - let values: Vec = vec![-1.0, -0.5, -0.25, -0.1, 0.1, 0.25, 0.5, 1.0]; - // Deterministic codes. - let codes: Vec = (0..num_rows * list_size as usize) - .map(|i| (i as u8) % (values.len() as u8)) - .collect(); - let dict_lhs = dict_vector_f32(list_size, &codes, &values)?; - - let query: Vec = (0..list_size).map(|i| (i as f32 + 1.0) * 0.3).collect(); - let const_rhs = Vector::constant_array(&query, num_rows)?; - - let expected: Vec = (0..num_rows) - .map(|row| { - let mut acc = 0.0f32; - for j in 0..list_size as usize { - let k = codes[row * list_size as usize + j] as usize; - acc += query[j] * values[k]; - } - acc - }) - .collect(); - - let actual = eval_ip_f32(dict_lhs, const_rhs, num_rows)?; - assert_close_f32(&actual, &expected, 1e-5); - Ok(()) - } - - /// Case 2: constant query on LHS, dict column on RHS (mirrored). - #[test] - fn case2_constant_lhs_dict_rhs_mirrored() -> VortexResult<()> { - let list_size: u32 = 4; - let num_rows = 6usize; - let values: Vec = vec![0.1, 0.4, 0.7, 1.0]; - let codes: Vec = (0..num_rows * list_size as usize) - .map(|i| ((i * 3) as u8) % (values.len() as u8)) - .collect(); - let dict_rhs = dict_vector_f32(list_size, &codes, &values)?; - - let query: Vec = vec![0.5, -1.0, 2.5, -0.25]; - let const_lhs = Vector::constant_array(&query, num_rows)?; - - let expected: Vec = (0..num_rows) - .map(|row| { - let mut acc = 0.0f32; - for j in 0..list_size as usize { - let k = codes[row * list_size as usize + j] as usize; - acc += query[j] * values[k]; - } - acc - }) - .collect(); - - let actual = eval_ip_f32(const_lhs, dict_rhs, num_rows)?; - assert_close_f32(&actual, &expected, 1e-5); - Ok(()) - } - - /// Case 2: dict with `u16` codes (and hence more than 256 values) falls through to - /// the standard path but still produces the correct result. The direct-lookup path - /// only handles `u8` codes today. - #[test] - fn case2_u16_codes_falls_through() -> VortexResult<()> { - let list_size: u32 = 4; - let num_rows = 3usize; - let num_values = 300usize; - let values: Vec = (0..num_values).map(|i| i as f32 * 0.01).collect(); - // Codes must be u16 because 300 > 255. dict_vector_f32 only supports u8 so we - // build the dict by hand here. - let codes_u16: Vec = (0..(num_rows * 4)) - .map(|i| (i % num_values) as u16) - .collect(); - let codes_arr = - PrimitiveArray::new::(Buffer::copy_from(codes_u16), Validity::NonNullable) - .into_array(); - let values_arr = - PrimitiveArray::new::(Buffer::copy_from(&values), Validity::NonNullable) - .into_array(); - let dict = DictArray::try_new(codes_arr, values_arr)?; - let fsl = FixedSizeListArray::try_new( - dict.into_array(), - list_size, - Validity::NonNullable, - num_rows, - )?; - let dict_lhs = Vector::try_new_vector_array(fsl.into_array())?; - - let query: Vec = vec![1.0, 2.0, 3.0, 4.0]; - let const_rhs = Vector::constant_array(&query, num_rows)?; - - // Build expected by decoding by hand. - let expected: Vec = (0..num_rows) - .map(|row| { - let mut acc = 0.0f32; - for j in 0..4 { - let code = (row * 4 + j) % num_values; - acc += query[j] * values[code]; - } - acc - }) - .collect(); - - let actual = eval_ip_f32(dict_lhs, const_rhs, num_rows)?; - assert_close_f32(&actual, &expected, 1e-5); - Ok(()) - } - - /// Case 2: plain (non-dict) FSL with a constant RHS falls through to the standard - /// path and produces the correct result. - #[test] - fn case2_plain_fsl_falls_through() -> VortexResult<()> { - let dim: u32 = 4; - let num_rows = 3usize; - let lhs_elems: Vec = (0..num_rows * dim as usize) - .map(|i| i as f32 * 0.25) - .collect(); - let plain_lhs = vector_array(dim, &lhs_elems)?; - - let query: Vec = vec![1.0, 2.0, 3.0, 4.0]; - let const_rhs = Vector::constant_array(&query, num_rows)?; - - let expected: Vec = (0..num_rows) - .map(|row| { - naive_dot( - &lhs_elems[row * dim as usize..(row + 1) * dim as usize], - &query, - ) - }) - .collect(); - - let actual = eval_ip_f32(plain_lhs, const_rhs, num_rows)?; - assert_close_f32(&actual, &expected, 1e-5); - Ok(()) - } - - /// Case 2: empty `len == 0` fast path returns an empty primitive array without - /// touching the codes buffer. - #[test] - fn case2_empty_len_zero() -> VortexResult<()> { - let list_size: u32 = 4; - let num_rows = 0usize; - let values: Vec = vec![0.0, 1.0, 2.0, 3.0]; - let codes: Vec = Vec::new(); - let dict_lhs = dict_vector_f32(list_size, &codes, &values)?; - - let query: Vec = vec![0.0; 4]; - let const_rhs = Vector::constant_array(&query, num_rows)?; - - let actual = eval_ip_f32(dict_lhs, const_rhs, num_rows)?; - assert_eq!(actual.len(), 0); - Ok(()) - } - - /// Case 1 + Case 2 end-to-end: the SorfTransform-wrapped dict column hits Case 1 - /// then Case 2 via recursive execution. - #[test] - fn end_to_end_sorf_plus_dict_cosine_path() -> VortexResult<()> { - let dim: u32 = 100; - let num_rows = 9usize; - let seed = 99u64; - let num_rounds = 3u8; - - let (sorf, codes, values, padded_dim) = - build_sorf_with_dict_child(dim, num_rows, seed, num_rounds)?; - - let query_elems: Vec = (0..dim).map(|i| ((i as f32) * 0.15).sin() * 0.4).collect(); - let const_rhs = Vector::constant_array(&query_elems, num_rows)?; - - // Ground truth via full decode + naive dot. - let decoded = decode_sorf_dict( - &codes, - &values, - padded_dim, - dim as usize, - num_rows, - seed, - num_rounds, - )?; - let expected: Vec = (0..num_rows) - .map(|i| { - naive_dot( - &decoded[i * dim as usize..(i + 1) * dim as usize], - &query_elems, - ) - }) - .collect(); - - let actual = eval_ip_f32(sorf, const_rhs, num_rows)?; - assert_close_f32(&actual, &expected, 1e-3); - Ok(()) - } - - // ---- Additional correctness / stress tests (all with loose tolerances) ---- - - /// A tiny in-place xorshift64 PRNG so these tests don't depend on `rand`. Producing - /// deterministic pseudo-random f32 values lets the correctness checks exercise - /// realistic data instead of smooth sin/cos patterns. - struct XorShift64(u64); - - impl XorShift64 { - fn new(seed: u64) -> Self { - // Any nonzero seed is fine; xorshift fixed-points at 0. - Self(seed.wrapping_add(0x9E37_79B9_7F4A_7C15)) - } - - fn next_u64(&mut self) -> u64 { - let mut x = self.0; - x ^= x << 13; - x ^= x >> 7; - x ^= x << 17; - self.0 = x; - x - } - - /// Uniform f32 in `[-1.0, 1.0)`. - fn next_f32(&mut self) -> f32 { - // Top 24 bits -> mantissa in [0, 1), then shift to [-1, 1). - let bits = (self.next_u64() >> 40) as u32; // 24 bits - (bits as f32) / (1u32 << 24) as f32 * 2.0 - 1.0 - } - } - - /// Case 2 stress: u8-coded dict with 200 centroids (formerly blocked by the - /// `values.len() <= 256` gate). The direct-lookup path must now handle it. - #[test] - fn case2_large_u8_codebook_direct_lookup() -> VortexResult<()> { - let list_size: u32 = 16; - let num_rows = 20usize; - let num_centroids = 200usize; - assert!(num_centroids > 8 && num_centroids <= 256); - - let mut rng = XorShift64::new(0xDEAD_BEEF); - let values: Vec = (0..num_centroids).map(|_| rng.next_f32()).collect(); - let codes: Vec = (0..num_rows * list_size as usize) - .map(|_| (rng.next_u64() % num_centroids as u64) as u8) - .collect(); - - let dict_lhs = dict_vector_f32(list_size, &codes, &values)?; - let query: Vec = (0..list_size).map(|_| rng.next_f32()).collect(); - let const_rhs = Vector::constant_array(&query, num_rows)?; - - let expected: Vec = (0..num_rows) - .map(|row| { - let mut acc = 0.0f32; - for j in 0..list_size as usize { - let k = codes[row * list_size as usize + j] as usize; - acc += query[j] * values[k]; - } - acc - }) - .collect(); - - let actual = eval_ip_f32(dict_lhs, const_rhs, num_rows)?; - assert_close_f32(&actual, &expected, 1e-4); - Ok(()) - } - - /// Parameterized sweep over the full `InnerProduct(SorfTransform(Vector[FSL(Dict)]), - /// ConstantArray)` tree, exercising the case 1 + case 2 chain for a realistic mix - /// of dimensions, row counts, seeds, and number of SORF rounds. Tolerance is - /// deliberately loose because the rewrite introduces an f32-domain rotation that - /// accumulates a small numerical drift versus a naive decode. - #[rstest] - #[case::small_no_pad(128, 11, 1, 1)] - #[case::small_no_pad_rounds3(128, 23, 1_234, 3)] - #[case::small_padded(100, 17, 42, 3)] - #[case::mid_padded(200, 13, 2024, 3)] - #[case::mid_power_of_two(256, 31, 7, 3)] - #[case::larger_padded(300, 9, 99, 3)] - #[case::max_rounds(128, 5, 31_415, 5)] - fn case1_sorf_random_sweep( - #[case] dim: u32, - #[case] num_rows: usize, - #[case] seed: u64, - #[case] num_rounds: u8, - ) -> VortexResult<()> { - let (sorf, codes, values, padded_dim) = - build_sorf_with_dict_child(dim, num_rows, seed, num_rounds)?; - - // Use a pseudo-random query with both positive and negative entries so the sum - // has cancellation. - let mut rng = XorShift64::new(seed ^ 0xABCD_1234); - let query: Vec = (0..dim).map(|_| rng.next_f32()).collect(); - let const_rhs = Vector::constant_array(&query, num_rows)?; - - let decoded = decode_sorf_dict( - &codes, - &values, - padded_dim, - dim as usize, - num_rows, - seed, - num_rounds, - )?; - let expected: Vec = (0..num_rows) - .map(|i| naive_dot(&decoded[i * dim as usize..(i + 1) * dim as usize], &query)) - .collect(); - - // Loose tolerance: the sorf transform works in f32 with a k-round butterfly, so - // the rewrite path and the decoded path accumulate slightly different rounding - // even though the math is equivalent in exact arithmetic. - let actual = eval_ip_f32(sorf, const_rhs, num_rows)?; - assert_close_f32(&actual, &expected, 1e-2); - Ok(()) - } - - /// Parameterized sweep over plain `Vector[FSL(Dict(u8, f32))]` + constant query, - /// without SorfTransform in the mix. This directly exercises case 2 across a - /// variety of list sizes, num_rows, and codebook sizes including large ones that - /// the old `<= 256` gate would have rejected. - #[rstest] - #[case::small(4, 7, 8)] - #[case::medium(16, 50, 64)] - #[case::larger(32, 100, 150)] - #[case::very_large_codebook(8, 25, 250)] - fn case2_random_sweep( - #[case] list_size: u32, - #[case] num_rows: usize, - #[case] num_centroids: usize, - ) -> VortexResult<()> { - let mut rng = XorShift64::new((list_size as u64) * 31 + num_rows as u64); - let values: Vec = (0..num_centroids).map(|_| rng.next_f32()).collect(); - assert!(num_centroids <= 256, "u8 codes cap at 256 centroids"); - let codes: Vec = (0..num_rows * list_size as usize) - .map(|_| (rng.next_u64() % num_centroids as u64) as u8) - .collect(); - - let dict_lhs = dict_vector_f32(list_size, &codes, &values)?; - let query: Vec = (0..list_size).map(|_| rng.next_f32()).collect(); - let const_rhs = Vector::constant_array(&query, num_rows)?; - - let expected: Vec = (0..num_rows) - .map(|row| { - let mut acc = 0.0f32; - for j in 0..list_size as usize { - let k = codes[row * list_size as usize + j] as usize; - acc += query[j] * values[k]; - } - acc - }) - .collect(); - - // Tight tolerance here because no SorfTransform rotation is involved — the - // arithmetic should agree bit-for-bit up to float reassociation. - let actual = eval_ip_f32(dict_lhs, const_rhs, num_rows)?; - assert_close_f32(&actual, &expected, 1e-4); - Ok(()) - } - - /// End-to-end regression: for a plausible vector-search configuration (SORF rounds - /// = 3, dim = 128, num_rows = 64, u8 codes, 64 centroids), the fast-path result - /// must track a fully naive computation within 1e-2. - #[test] - fn end_to_end_dim128_rows64_bit6_regression() -> VortexResult<()> { - let dim: u32 = 128; - let num_rows = 64usize; - let seed = 0xFACE_F00D; - let num_rounds = 3u8; - - // Use 64 centroids (6 bits), a typical TurboQuant configuration. - let num_centroids = 64usize; - let padded_dim = (dim as usize).next_power_of_two(); - let mut rng = XorShift64::new(seed); - let values: Vec = (0..num_centroids).map(|_| rng.next_f32()).collect(); - let codes: Vec = (0..num_rows * padded_dim) - .map(|_| (rng.next_u64() % num_centroids as u64) as u8) - .collect(); - - let padded_vector = dict_vector_f32(padded_dim as u32, &codes, &values)?; - let sorf_options = SorfOptions { - seed, - num_rounds, - dimensions: dim, - element_ptype: PType::F32, - }; - let sorf = - SorfTransform::try_new_array(&sorf_options, padded_vector, num_rows)?.into_array(); - - let query: Vec = (0..dim).map(|_| rng.next_f32()).collect(); - let const_rhs = Vector::constant_array(&query, num_rows)?; - - let decoded = decode_sorf_dict( - &codes, - &values, - padded_dim, - dim as usize, - num_rows, - seed, - num_rounds, - )?; - let expected: Vec = (0..num_rows) - .map(|i| naive_dot(&decoded[i * dim as usize..(i + 1) * dim as usize], &query)) - .collect(); - - let actual = eval_ip_f32(sorf, const_rhs, num_rows)?; - assert_close_f32(&actual, &expected, 1e-2); - - // Also verify the max relative error is small. The SORF rotation does not - // amplify error, so both measures should be bounded. - for (i, (a, e)) in actual.iter().zip(expected.iter()).enumerate() { - let denom = e.abs().max(1.0); - let rel = (a - e).abs() / denom; - assert!( - rel < 1e-3, - "row {i}: rel err {rel} too large (a={a}, e={e})" - ); - } - Ok(()) - } - - /// Case 1 + Case 2 end-to-end with varying `num_rounds`. The rotation becomes - /// progressively more chaotic as rounds increase, so this catches any off-by-one - /// bug in the round-indexing that would not show up in the 3-round default. - #[rstest] - #[case(1)] - #[case(2)] - #[case(3)] - #[case(4)] - #[case(5)] - fn case1_various_num_rounds(#[case] num_rounds: u8) -> VortexResult<()> { - let dim: u32 = 128; - let num_rows = 8usize; - let seed = 0x1234_5678; - - let (sorf, codes, values, padded_dim) = - build_sorf_with_dict_child(dim, num_rows, seed, num_rounds)?; - - let mut rng = XorShift64::new(seed ^ (num_rounds as u64)); - let query: Vec = (0..dim).map(|_| rng.next_f32()).collect(); - let const_rhs = Vector::constant_array(&query, num_rows)?; - - let decoded = decode_sorf_dict( - &codes, - &values, - padded_dim, - dim as usize, - num_rows, - seed, - num_rounds, - )?; - let expected: Vec = (0..num_rows) - .map(|i| naive_dot(&decoded[i * dim as usize..(i + 1) * dim as usize], &query)) - .collect(); - - let actual = eval_ip_f32(sorf, const_rhs, num_rows)?; - assert_close_f32(&actual, &expected, 1e-2); - Ok(()) - } - - /// Swap LHS and RHS on the full tree to prove the side-detection and the scalar - /// argument-order handling are symmetric for both cases simultaneously. - #[test] - fn end_to_end_constant_lhs_sorf_rhs_mirrored() -> VortexResult<()> { - let dim: u32 = 256; - let num_rows = 12usize; - let seed = 0xBEEF_CAFE; - let num_rounds = 3u8; - - let (sorf, codes, values, padded_dim) = - build_sorf_with_dict_child(dim, num_rows, seed, num_rounds)?; - - let mut rng = XorShift64::new(seed); - let query: Vec = (0..dim).map(|_| rng.next_f32()).collect(); - let const_lhs = Vector::constant_array(&query, num_rows)?; - - let decoded = decode_sorf_dict( - &codes, - &values, - padded_dim, - dim as usize, - num_rows, - seed, - num_rounds, - )?; - let expected: Vec = (0..num_rows) - .map(|i| naive_dot(&decoded[i * dim as usize..(i + 1) * dim as usize], &query)) - .collect(); - - let actual = eval_ip_f32(const_lhs, sorf, num_rows)?; - assert_close_f32(&actual, &expected, 1e-2); - Ok(()) - } - } } diff --git a/vortex-tensor/src/scalar_fns/l2_denorm.rs b/vortex-tensor/src/scalar_fns/l2_denorm.rs index a814af33f33..d8ca5fc41ed 100644 --- a/vortex-tensor/src/scalar_fns/l2_denorm.rs +++ b/vortex-tensor/src/scalar_fns/l2_denorm.rs @@ -406,7 +406,7 @@ fn execute_l2_denorm_constant_norms( /// /// Rows that are null in the original input are **zeroed out** in the normalized output. This is /// necessary because null rows may have undefined (garbage) physical storage values, and we do not -/// want to let those propagate into downstream encodings (like TurboQuant). +/// want to let those propagate into downstream lossy encodings. /// /// # Nullability /// diff --git a/vortex-tensor/src/scalar_fns/l2_norm.rs b/vortex-tensor/src/scalar_fns/l2_norm.rs index 59d49fc8e1a..7e8a85cc500 100644 --- a/vortex-tensor/src/scalar_fns/l2_norm.rs +++ b/vortex-tensor/src/scalar_fns/l2_norm.rs @@ -57,7 +57,7 @@ use crate::utils::validate_tensor_float_input; /// column of the same float type. /// /// When the input is wrapped in [`L2Denorm`], this operator treats the stored norms as -/// authoritative. For lossy encodings such as TurboQuant, that means `L2Norm` may intentionally +/// authoritative. For lossy encodings, that means `L2Norm` may intentionally /// read the stored norms instead of re-deriving them from fully decoded coordinates. That behavior /// is part of the lossy storage contract, not a separate lossy-compute mode. #[derive(Clone)] @@ -127,7 +127,7 @@ impl ScalarFnVTable for L2Norm { let norm_dtype = DType::Primitive(element_ptype, ext.nullability()); // L2Norm(L2Denorm(normalized, norms)) is defined to read back the authoritative stored - // norms. Exact callers of lossy encodings like TurboQuant opt into that storage semantics + // norms. Exact callers of lossy encodings opt into that storage semantics // instead of forcing a decode-and-recompute path here. if input_ref.is::>() { let (_, norms) = extract_l2_denorm_children(&input_ref); diff --git a/vortex-tensor/src/scalar_fns/mod.rs b/vortex-tensor/src/scalar_fns/mod.rs index 1d1a362c8af..9acf59bda01 100644 --- a/vortex-tensor/src/scalar_fns/mod.rs +++ b/vortex-tensor/src/scalar_fns/mod.rs @@ -7,4 +7,3 @@ pub mod cosine_similarity; pub mod inner_product; pub mod l2_denorm; pub mod l2_norm; -pub mod sorf_transform; diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/mod.rs b/vortex-tensor/src/scalar_fns/sorf_transform/mod.rs deleted file mode 100644 index 26d38e87a1e..00000000000 --- a/vortex-tensor/src/scalar_fns/sorf_transform/mod.rs +++ /dev/null @@ -1,146 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! SORF inverse transform scalar function. -//! -//! SORF (Structured Orthogonal Random Features, [Yu et al. 2016][sorf-paper]) is a fast structured -//! approximation to a random orthogonal matrix. It composes random sign diagonals with the -//! Walsh-Hadamard transform to achieve O(d log d) matrix-vector products instead of the O(d^2) cost -//! of a dense orthogonal matrix. -//! -//! This module wraps a [`Vector`] extension array whose dimension is the padded SORF dimension -//! (e.g. a `Vector` wrapping `FSL(Dict(codes, centroids))`) and applies the inverse SORF transform -//! at execution time, producing a [`Vector`] extension array with the original (pre-padding) -//! dimensionality. -//! -//! The transform parameters are stored as a deterministic seed in [`SorfOptions`], so the -//! [`SorfMatrix`] is reconstructed cheaply at decode time. Sign diagonals are defined by Vortex's -//! frozen local SplitMix64 stream contract rather than by an external RNG crate. -//! -//! # Input element type: `f32` only (TODO(connor): for now...) -//! -//! The child [`Vector`] **must** have `f32` storage elements. This is a hard constraint that is -//! enforced by `SorfTransform`'s `return_dtype` check. Callers with `f16` or `f64` source data need -//! to cast to `f32` before wrapping in a [`Vector`] and handing it to SorfTransform. -//! -//! The reason for this constraint is that TurboQuant (the only production caller today) stores its -//! dictionary centroids as `f32`, and the SORF transform itself operates internally in `f32`. -//! -//! Supporting other float storage types would require an implicit up-/down-cast that we do not yet -//! want to bake into SorfTransform. This restriction is intentional and may be relaxed in the -//! future, but today it is load-bearing. -//! -//! # Output element type -//! -//! The output [`Vector`]'s element type is whatever [`SorfOptions::element_ptype`] is set to. It -//! does **not** have to match the child's `f32` storage: we apply an explicit `f32 -> T` cast -//! while materializing the output. This lets SorfTransform hand its result directly to a -//! downstream consumer (e.g. [`L2Denorm`](crate::scalar_fns::l2_denorm::L2Denorm)) whose -//! element-type expectation may differ from the `f32` the transform operated on internally. -//! -//! [sorf-paper]: https://proceedings.neurips.cc/paper_files/paper/2016/file/53adaf494dc89ef7196d73636eb2451b-Paper.pdf -//! [`Vector`]: crate::vector::Vector - -use std::fmt; -use std::fmt::Formatter; - -use vortex_array::ArrayRef; -use vortex_array::arrays::ScalarFnArray; -use vortex_array::dtype::PType; -use vortex_array::scalar_fn::TypedScalarFnInstance; -use vortex_error::VortexResult; -use vortex_error::vortex_ensure; - -mod rotation; -mod splitmix64; -pub use rotation::SorfMatrix; - -mod vtable; - -/// Inverse SORF orthogonal transform scalar function. -/// -/// Takes a [`Vector`](crate::vector::Vector) extension child at the padded dimension with `f32` -/// storage, applies the inverse structured Walsh-Hadamard orthogonal transform, truncates to the -/// original (pre-padding) dimension, casts element-wise to [`SorfOptions::element_ptype`], and -/// wraps the result in a new [`Vector`](crate::vector::Vector) extension array. -/// -/// See the [module-level docs](crate::scalar_fns::sorf_transform) for the rationale behind the -/// `f32`-only input constraint. -#[derive(Clone)] -pub struct SorfTransform; - -/// Options for the SORF inverse transform scalar function. -/// -/// Stored in the [`ScalarFnArray`] and used to deterministically reconstruct the -/// [`SorfMatrix`] at decode time. -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub struct SorfOptions { - /// Seed used to generate the structured sign diagonals via Vortex's frozen SplitMix64 stream. - pub seed: u64, - /// Number of sign-diagonal + WHT rounds in the structured orthogonal transform. - pub num_rounds: u8, - /// Original vector dimension (before power-of-2 padding). The output - /// [`Vector`](crate::vector::Vector) has this dimension. - pub dimensions: u32, - /// Element type of the output [`Vector`](crate::vector::Vector). The child input must always - /// be `f32`, but the output can be any float type (`F16`, `F32`, `F64`); the final - /// `f32 -> element_ptype` cast happens while building the output. - pub element_ptype: PType, -} - -impl SorfTransform { - /// Creates a new [`TypedScalarFnInstance`] wrapping the SORF inverse transform with the given options. - pub fn new(options: &SorfOptions) -> TypedScalarFnInstance { - TypedScalarFnInstance::new(SorfTransform, options.clone()) - } - - /// Constructs a validated [`ScalarFnArray`] that lazily applies the inverse SORF transform. - /// - /// The `child` must be a [`Vector`] extension array (or an array that executes to one) with: - /// - /// - dimension equal to `padded_dim` (i.e. `options.dimension.next_power_of_two()`), and - /// - `f32` storage elements. This is a hard requirement today; see the - /// [module-level docs](crate::scalar_fns::sorf_transform) for the rationale. - /// - /// The output [`Vector`] has dimension `options.dimension` and element type - /// `options.element_ptype`. - /// - /// [`Vector`]: crate::vector::Vector - pub fn try_new_array( - options: &SorfOptions, - child: ArrayRef, - len: usize, - ) -> VortexResult { - validate_sorf_options(options)?; - - ScalarFnArray::try_new(SorfTransform::new(options).erased(), vec![child], len) - } -} - -/// Checks that the SORF configuration is valid. -pub(crate) fn validate_sorf_options(options: &SorfOptions) -> VortexResult<()> { - vortex_ensure!( - options.num_rounds >= 1, - "SorfTransform num_rounds must be >= 1, got {}", - options.num_rounds - ); - vortex_ensure!( - options.element_ptype.is_float(), - "SorfTransform element_ptype must be a float type, got {}", - options.element_ptype - ); - Ok(()) -} - -impl fmt::Display for SorfOptions { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!( - f, - "SorfOptions(seed={}, rounds={}, dim={}, ptype={})", - self.seed, self.num_rounds, self.dimensions, self.element_ptype - ) - } -} - -#[cfg(test)] -mod tests; diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/rotation.rs b/vortex-tensor/src/scalar_fns/sorf_transform/rotation.rs deleted file mode 100644 index b416eb4a7e6..00000000000 --- a/vortex-tensor/src/scalar_fns/sorf_transform/rotation.rs +++ /dev/null @@ -1,559 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! SORF (Structured Orthogonal Random Features) orthogonal transform. -//! -//! Implements the SORF construction from [Yu et al. 2016][sorf-paper]: a fast structured -//! approximation to a random orthogonal matrix using random sign diagonals interleaved with the -//! Fast Walsh-Hadamard Transform (FWHT). -//! -//! For `k` rounds, the transform is `norm * H * D_k * ... * H * D_1 * x`, where `D_1` is the -//! first sign diagonal applied. The number of rounds is configurable (typically 3). Each round -//! applies a random sign diagonal `D_i` and then the Hadamard matrix `H`, giving O(d log d) cost -//! per matrix-vector product instead of the O(d^2) cost of a dense orthogonal matrix. -//! -//! Vortex defines those sign diagonals using a frozen local SplitMix64 stream rather than an -//! external RNG crate. The contract is: -//! -//! - state is a single `u64` seed, -//! - each `next_u64()` call uses the SplitMix64 reference algorithm with wrapping `u64` -//! arithmetic, -//! - signs are generated in round-major, block-major order, -//! - each generated `u64` contributes 64 signs in least-significant-bit-first order, -//! - bit `1` means `+1` and bit `0` means `-1`. -//! -//! This makes SORF sign generation stable as a Vortex format contract even if external RNG -//! implementations change. -//! -//! [sorf-paper]: https://proceedings.neurips.cc/paper_files/paper/2016/file/53adaf494dc89ef7196d73636eb2451b-Paper.pdf -//! -//! The FWHT exploits the Kronecker product structure of the Hadamard matrix (`H_n = H_2 (x) H_2 -//! (x) ... (x) H_2`, with `log2(n)` factors) to compute the matrix-vector product in O(n log n) -//! time using only in-place 2-element butterfly operations. No row of the full n x n Hadamard -//! matrix is ever materialized. -//! -//! For dimensions that are not powers of 2, the input is zero-padded to the next power of 2 before -//! the transform and truncated afterward. -//! -//! # Sign representation -//! -//! Signs are stored internally as `u32` XOR masks: `0x00000000` for +1 (no-op) and `0x80000000` for -//! -1 (flip IEEE 754 sign bit). The sign application function uses integer XOR instead of -//! floating-point multiply, which avoids FP dependency chains and auto-vectorizes into -//! `vpxor`/`veor`. - -use vortex_error::VortexResult; -use vortex_error::vortex_ensure; - -use super::splitmix64::SplitMix64; - -/// IEEE 754 sign bit mask for f32. -const F32_SIGN_BIT: u32 = 0x8000_0000; - -/// A Walsh-Hadamard-based structured orthogonal transform matrix. -/// -/// All computation is done in f32. The sign diagonals are stored as IEEE 754 XOR masks on -/// f32 bit patterns, and the Walsh-Hadamard butterfly operates on `&mut [f32]` slices. -pub struct SorfMatrix { - /// Flat XOR masks for all `num_rounds` diagonal matrices, total length - /// `num_rounds * padded_dim`. - /// - /// Indexed as `round * padded_dim + i`. `0x00000000` = multiply by +1 (no-op), `0x80000000` = - /// multiply by -1 (flip sign bit). - sign_masks: Vec, - /// The number of sign-diagonal + WHT rounds. - num_rounds: usize, - /// The padded dimension (next power of 2 >= dimension). - padded_dim: usize, - /// Normalization factor: `padded_dim^(-num_rounds/2)`, applied once at the end. - norm_factor: f32, -} - -impl SorfMatrix { - /// Create a new structured Walsh-Hadamard-based orthogonal transform from a deterministic - /// seed. - /// - /// The seed is expanded using Vortex's frozen local SplitMix64 stream. Signs are generated in - /// round-major, block-major order, with each `u64` contributing 64 sign bits in - /// least-significant-bit-first order. - pub fn try_new(seed: u64, dimensions: usize, num_rounds: usize) -> VortexResult { - Self::try_new_padded(dimensions.next_power_of_two(), num_rounds, seed) - } - - /// Create a new structured Walsh-Hadamard-based orthogonal transform for a padded dimension. - /// - /// `padded_dimensions` must already be a power of two. Callers that start from an unpadded - /// logical dimension should call [`Self::try_new`] instead. - pub(crate) fn try_new_padded( - padded_dimensions: usize, - num_rounds: usize, - seed: u64, - ) -> VortexResult { - vortex_ensure!(num_rounds >= 1, "num_rounds must be >= 1, got {num_rounds}"); - vortex_ensure!( - padded_dimensions.is_power_of_two(), - "padded_dimensions must be a power of two, got {padded_dimensions}" - ); - - let padded_dim = padded_dimensions; - let sign_masks = gen_sign_masks_from_seed(seed, padded_dim, num_rounds); - - // Compute in f64 for precision, then store as f32 since the WHT operates on f32 buffers. - // The result is always in (0, 1] for any valid padded_dim >= 2 and num_rounds >= 1, so - // the f64 -> f32 cast is a precision loss only (it cannot overflow to infinity). - #[expect( - clippy::cast_possible_truncation, - reason = "the norm factor is in (0, 1] so the f64 -> f32 cast cannot overflow" - )] - let norm_factor = (padded_dim as f64).powf(-(num_rounds as f64) / 2.0) as f32; - - Ok(Self { - sign_masks, - num_rounds, - padded_dim, - norm_factor, - }) - } - - /// Returns the padded dimension (next power of 2 >= dim). - /// - /// All `rotate`/`inverse_rotate` buffers must be this length. - pub fn padded_dim(&self) -> usize { - self.padded_dim - } - - /// Apply the forward orthogonal transform: `output = R(input)`. - /// - /// Both `input` and `output` must have length [`padded_dim()`](Self::padded_dim). The caller is - /// responsible for zero-padding input beyond `dim` positions. - pub fn rotate(&self, input: &[f32], output: &mut [f32]) { - debug_assert_eq!(input.len(), self.padded_dim); - debug_assert_eq!(output.len(), self.padded_dim); - - output.copy_from_slice(input); - self.apply_srht(output); - } - - /// Apply the inverse orthogonal transform: `output = R⁻¹(input)`. - /// - /// Both `input` and `output` must have length `padded_dim()`. - pub fn inverse_rotate(&self, input: &[f32], output: &mut [f32]) { - debug_assert_eq!(input.len(), self.padded_dim); - debug_assert_eq!(output.len(), self.padded_dim); - - output.copy_from_slice(input); - self.apply_inverse_srht(output); - } - - /// Apply the forward structured transform: `norm · H · D_k · ... · H · D₁ · x`. - fn apply_srht(&self, buf: &mut [f32]) { - for round in 0..self.num_rounds { - self.apply_signs_xor(buf, round); - walsh_hadamard_transform(buf); - } - - let norm = self.norm_factor; - buf.iter_mut().for_each(|val| *val *= norm); - } - - /// Apply the inverse structured transform. - /// - /// Forward is: `norm · H · D_k · ... · H · D₁`. - /// Inverse is: `norm · D₁ · H · ... · D_k · H`. - fn apply_inverse_srht(&self, buf: &mut [f32]) { - for round in (0..self.num_rounds).rev() { - walsh_hadamard_transform(buf); - self.apply_signs_xor(buf, round); - } - - let norm = self.norm_factor; - buf.iter_mut().for_each(|val| *val *= norm); - } - - /// Apply one round's sign masks via XOR on the IEEE 754 sign bit. - /// - /// This is branchless and auto-vectorizes into `vpxor` (x86) / `veor` (ARM). Equivalent to - /// multiplying each element by +/-1.0, but avoids FP dependency chains. - fn apply_signs_xor(&self, buf: &mut [f32], round: usize) { - let masks = &self.sign_masks[round * self.padded_dim..][..self.padded_dim]; - for (val, &mask) in buf.iter_mut().zip(masks.iter()) { - *val = f32::from_bits(val.to_bits() ^ mask); - } - } - - /// Export the sign vectors as a flat `Vec` of 0/1 values in inverse application order - /// `[D_k | ... | D₁]`. - /// - /// Convention: `1` = positive (+1), `0` = negative (-1). The output has length - /// `num_rounds * padded_dim` and is suitable for bitpacking via FastLanes - /// `bitpack_encode(..., 1, None)`. - #[cfg(test)] - pub fn export_inverse_signs_u8(&self) -> Vec { - let total = self.num_rounds * self.padded_dim; - let mut out = Vec::with_capacity(total); - - // Store in inverse order: round k-1 first, then k-2, ..., then 0. - for round in (0..self.num_rounds).rev() { - let offset = round * self.padded_dim; - for &mask in &self.sign_masks[offset..offset + self.padded_dim] { - out.push(if mask == 0 { 1u8 } else { 0u8 }); - } - } - out - } - - /// Reconstruct a [`SorfMatrix`] from unpacked `u8` 0/1 values. - /// - /// The input must have length `num_rounds * padded_dim` with signs in inverse application - /// order `[D_k | ... | D₁]` (as produced by [`export_inverse_signs_u8`]). Convention: - /// `1` = positive, `0` = negative. - /// - /// This is the decode-time reconstruction path: FastLanes SIMD-unpacks the stored - /// [`BitPackedArray`] into `&[u8]`, which is passed here. - #[cfg(test)] - pub fn from_u8_slice( - signs_u8: &[u8], - dimension: usize, - num_rounds: usize, - ) -> VortexResult { - vortex_ensure!(num_rounds >= 1, "num_rounds must be >= 1, got {num_rounds}"); - let padded_dim = dimension.next_power_of_two(); - vortex_ensure!( - signs_u8.len() == num_rounds * padded_dim, - "Expected {} sign bytes, got {}", - num_rounds * padded_dim, - signs_u8.len() - ); - - // The storage is in inverse application order: round k-1 first, then k-2, ..., 0. - // We reconstruct into forward order (round 0 at the start of the flat vec). - let mut sign_masks = vec![0u32; num_rounds * padded_dim]; - for storage_idx in 0..num_rounds { - let round = num_rounds - 1 - storage_idx; - let src_offset = storage_idx * padded_dim; - let dst_offset = round * padded_dim; - for i in 0..padded_dim { - sign_masks[dst_offset + i] = if signs_u8[src_offset + i] != 0 { - 0u32 - } else { - F32_SIGN_BIT - }; - } - } - - // Same norm factor computation as `try_new`. See the comment there for why this cast - // cannot overflow. - #[expect( - clippy::cast_possible_truncation, - reason = "the norm factor is in (0, 1] so the f64 -> f32 cast cannot overflow" - )] - let norm_factor = (padded_dim as f64).powf(-(num_rounds as f64) / 2.0) as f32; - - Ok(Self { - sign_masks, - num_rounds, - padded_dim, - norm_factor, - }) - } -} - -/// Generate XOR sign masks from the frozen local SplitMix64 stream. -/// -/// Signs are produced in round-major, block-major order. For each block we call -/// [`SplitMix64::next_u64`] exactly once and unpack its bits from least significant to most -/// significant. Bit `1` means positive sign / `0x00000000`; bit `0` means negative sign / -/// [`F32_SIGN_BIT`]. -fn gen_sign_masks_from_seed(seed: u64, padded_dim: usize, num_rounds: usize) -> Vec { - let mut rng = SplitMix64::new(seed); - let mut sign_masks = Vec::with_capacity(num_rounds * padded_dim); - - for _round in 0..num_rounds { - for base_idx in (0..padded_dim).step_by(64) { - let word = rng.next_u64(); - let bits_in_block = (padded_dim - base_idx).min(64); - sign_masks.extend((0..bits_in_block).map(|bit_idx| sign_mask_from_word(word, bit_idx))); - } - } - - sign_masks -} - -/// Convert one bit from a SplitMix64 output word into an XOR sign mask. -fn sign_mask_from_word(word: u64, bit_idx: usize) -> u32 { - if ((word >> bit_idx) & 1) != 0 { - 0u32 - } else { - F32_SIGN_BIT - } -} - -/// In-place Fast Walsh-Hadamard Transform (FWHT), unnormalized and iterative. -/// -/// Input length must be a power of 2. Runs in O(n log n) via `log2(n)` stages of `n / 2` -/// [`butterfly`] operations each. See the [module-level docs](self) for why this avoids -/// materializing the full Hadamard matrix. -/// -/// The chunk-based iteration gives LLVM enough structure to auto-vectorize each butterfly call -/// into NEON/AVX SIMD instructions. -fn walsh_hadamard_transform(buf: &mut [f32]) { - let len = buf.len(); - debug_assert!(len.is_power_of_two()); - - let mut half = 1; - while half < len { - let stride = half * 2; - // Process in chunks of `stride` elements. Within each chunk, - // split into non-overlapping (lo, hi) halves for the butterfly. - for chunk in buf.chunks_exact_mut(stride) { - let (lo, hi) = chunk.split_at_mut(half); - butterfly(lo, hi); - } - half *= 2; - } -} - -/// Butterfly: `(lo[i], hi[i]) -> (lo[i] + hi[i], lo[i] - hi[i])`. -/// -/// This is multiplication by the 2x2 Hadamard kernel `H_2 = [[1, 1], [1, -1]]` on each element -/// pair. Factored into a separate function so LLVM can see the slice lengths match and -/// auto-vectorize. -fn butterfly(lo: &mut [f32], hi: &mut [f32]) { - debug_assert_eq!(lo.len(), hi.len()); - for (a, b) in lo.iter_mut().zip(hi.iter_mut()) { - let sum = *a + *b; - let diff = *a - *b; - *a = sum; - *b = diff; - } -} - -#[cfg(test)] -mod tests { - use rstest::rstest; - use vortex_error::VortexResult; - - use super::*; - use crate::scalar_fns::sorf_transform::splitmix64::SplitMix64; - - fn unpack_sign_bits(word: u64, count: usize) -> Vec { - (0..count) - .map(|bit_idx| u8::from(((word >> bit_idx) & 1) != 0)) - .collect() - } - - fn dim_to_usize(dim: u32) -> usize { - usize::try_from(dim).unwrap() - } - - fn rounds_to_usize(num_rounds: u8) -> usize { - usize::from(num_rounds) - } - - #[test] - fn deterministic_from_seed() -> VortexResult<()> { - let dim = dim_to_usize(64u32); - let num_rounds = rounds_to_usize(3u8); - let r1 = SorfMatrix::try_new(42u64, dim, num_rounds)?; - let r2 = SorfMatrix::try_new(42u64, dim, num_rounds)?; - let pd = r1.padded_dim(); - - let mut input = vec![0.0f32; pd]; - for i in 0..dim { - input[i] = i as f32; - } - let mut out1 = vec![0.0f32; pd]; - let mut out2 = vec![0.0f32; pd]; - - r1.rotate(&input, &mut out1); - r2.rotate(&input, &mut out2); - - assert_eq!(out1, out2); - Ok(()) - } - - #[test] - fn export_inverse_signs_matches_golden_words() -> VortexResult<()> { - let dim = dim_to_usize(64u32); - let num_rounds = rounds_to_usize(2u8); - let seed = 42u64; - let rot = SorfMatrix::try_new(seed, dim, num_rounds)?; - let padded_dim = rot.padded_dim(); - let actual = rot.export_inverse_signs_u8(); - let mut rng = SplitMix64::new(seed); - let round0_word = rng.next_u64(); - let round1_word = rng.next_u64(); - - let mut expected = Vec::with_capacity(num_rounds * padded_dim); - expected.extend(unpack_sign_bits(round1_word, padded_dim)); - expected.extend(unpack_sign_bits(round0_word, padded_dim)); - - assert_eq!(actual, expected); - Ok(()) - } - - #[test] - fn one_word_generates_64_signs_lsb_first() { - let seed = 42u64; - let padded_dim = dim_to_usize(64u32); - let num_rounds = rounds_to_usize(1u8); - let masks = gen_sign_masks_from_seed(seed, padded_dim, num_rounds); - assert_eq!(masks.len(), padded_dim); - - let mut rng = SplitMix64::new(seed); - let word = rng.next_u64(); - let expected: Vec<_> = (0..padded_dim) - .map(|bit_idx| sign_mask_from_word(word, bit_idx)) - .collect(); - assert_eq!(masks, expected); - } - - #[test] - fn accepts_non_power_of_two_dimensions() -> VortexResult<()> { - let rot = SorfMatrix::try_new(42u64, dim_to_usize(100u32), rounds_to_usize(3u8))?; - assert_eq!(rot.padded_dim(), 128); - Ok(()) - } - - #[test] - fn tail_block_uses_only_required_bits() { - let seed = 42u64; - let padded_dim = dim_to_usize(32u32); - let num_rounds = rounds_to_usize(1u8); - let masks = gen_sign_masks_from_seed(seed, padded_dim, num_rounds); - assert_eq!(masks.len(), padded_dim); - - let mut rng = SplitMix64::new(seed); - let word = rng.next_u64(); - let expected: Vec<_> = (0..padded_dim) - .map(|bit_idx| sign_mask_from_word(word, bit_idx)) - .collect(); - assert_eq!(masks, expected); - } - - /// Verify roundtrip is exact to f32 precision across many dimensions and round counts, - /// including non-power-of-two dimensions that require padding. - #[rstest] - #[case(32u32, 3u8)] - #[case(64u32, 3u8)] - #[case(100u32, 3u8)] - #[case(128u32, 1u8)] - #[case(128u32, 2u8)] - #[case(128u32, 3u8)] - #[case(128u32, 5u8)] - #[case(256u32, 3u8)] - #[case(512u32, 3u8)] - #[case(768u32, 3u8)] - #[case(1024u32, 3u8)] - fn roundtrip_exact(#[case] dim: u32, #[case] num_rounds: u8) -> VortexResult<()> { - let dim = dim_to_usize(dim); - let num_rounds = rounds_to_usize(num_rounds); - let rot = SorfMatrix::try_new(42u64, dim, num_rounds)?; - let padded_dim = rot.padded_dim(); - - let mut input = vec![0.0f32; padded_dim]; - for i in 0..dim { - input[i] = (i as f32 + 1.0) * 0.01; - } - let mut rotated = vec![0.0f32; padded_dim]; - let mut recovered = vec![0.0f32; padded_dim]; - - rot.rotate(&input, &mut rotated); - rot.inverse_rotate(&rotated, &mut recovered); - - let max_err: f32 = input - .iter() - .zip(recovered.iter()) - .map(|(a, b)| (a - b).abs()) - .fold(0.0f32, f32::max); - let max_val: f32 = input.iter().map(|x| x.abs()).fold(0.0f32, f32::max); - let rel_err = max_err / max_val; - - // SRHT roundtrip should be exact up to f32 precision (~1e-6). - assert!( - rel_err < 1e-5, - "roundtrip relative error too large for dim={dim}, rounds={num_rounds}: {rel_err:.2e}" - ); - Ok(()) - } - - /// Verify norm preservation across dimensions and round counts. - #[rstest] - #[case(128u32, 1u8)] - #[case(128u32, 3u8)] - #[case(128u32, 5u8)] - #[case(768u32, 3u8)] - fn preserves_norm(#[case] dim: u32, #[case] num_rounds: u8) -> VortexResult<()> { - let dim = dim_to_usize(dim); - let num_rounds = rounds_to_usize(num_rounds); - let rot = SorfMatrix::try_new(42u64, dim, num_rounds)?; - let padded_dim = rot.padded_dim(); - - let mut input = vec![0.0f32; padded_dim]; - for i in 0..dim { - input[i] = (i as f32) * 0.01; - } - let input_norm: f32 = input.iter().map(|x| x * x).sum::().sqrt(); - - let mut rotated = vec![0.0f32; padded_dim]; - rot.rotate(&input, &mut rotated); - let rotated_norm: f32 = rotated.iter().map(|x| x * x).sum::().sqrt(); - - assert!( - (input_norm - rotated_norm).abs() / input_norm < 1e-5, - "norm not preserved for dim={dim}: {} vs {} (rel err: {:.2e})", - input_norm, - rotated_norm, - (input_norm - rotated_norm).abs() / input_norm - ); - Ok(()) - } - - /// Verify that export -> [`from_u8_slice`] produces identical transform output. - #[rstest] - #[case(64u32, 3u8)] - #[case(128u32, 1u8)] - #[case(128u32, 3u8)] - #[case(128u32, 5u8)] - #[case(768u32, 3u8)] - fn sign_export_import_roundtrip(#[case] dim: u32, #[case] num_rounds: u8) -> VortexResult<()> { - let dim = dim_to_usize(dim); - let num_rounds = rounds_to_usize(num_rounds); - let rot = SorfMatrix::try_new(42u64, dim, num_rounds)?; - let padded_dim = rot.padded_dim(); - - let signs_u8 = rot.export_inverse_signs_u8(); - let rot2 = SorfMatrix::from_u8_slice(&signs_u8, dim, num_rounds)?; - - let mut input = vec![0.0f32; padded_dim]; - for i in 0..dim { - input[i] = (i as f32 + 1.0) * 0.01; - } - - let mut out1 = vec![0.0f32; padded_dim]; - let mut out2 = vec![0.0f32; padded_dim]; - rot.rotate(&input, &mut out1); - rot2.rotate(&input, &mut out2); - assert_eq!(out1, out2, "Forward transform mismatch after export/import"); - - rot.inverse_rotate(&out1, &mut out2); - let mut out3 = vec![0.0f32; padded_dim]; - rot2.inverse_rotate(&out1, &mut out3); - assert_eq!(out2, out3, "Inverse transform mismatch after export/import"); - - Ok(()) - } - - #[test] - fn wht_basic() { - // WHT of [1, 0, 0, 0] should be [1, 1, 1, 1] - let mut buf = vec![1.0f32, 0.0, 0.0, 0.0]; - walsh_hadamard_transform(&mut buf); - assert_eq!(buf, vec![1.0, 1.0, 1.0, 1.0]); - - // WHT is self-inverse (up to scaling by n) - walsh_hadamard_transform(&mut buf); - // After two WHTs: each element multiplied by n=4 - assert_eq!(buf, vec![4.0, 0.0, 0.0, 0.0]); - } -} diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/splitmix64.rs b/vortex-tensor/src/scalar_fns/sorf_transform/splitmix64.rs deleted file mode 100644 index 23345cbcb7d..00000000000 --- a/vortex-tensor/src/scalar_fns/sorf_transform/splitmix64.rs +++ /dev/null @@ -1,73 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! Frozen local SplitMix64 stream used to define SORF sign diagonals. -//! -//! This is a direct translation of the `splitmix64.c` reference implementation. The state is a -//! single `u64`, and `next_u64()` first adds [`SPLITMIX64_INCREMENT`] with wrapping arithmetic, -//! then applies the two reference mixing steps and final xor-shift. - -/// SplitMix64 additive constant from the reference implementation. -const SPLITMIX64_INCREMENT: u64 = 0x9E37_79B9_7F4A_7C15; - -/// First SplitMix64 mixing multiplier from the reference implementation. -const SPLITMIX64_MUL1: u64 = 0xBF58_476D_1CE4_E5B9; - -/// Second SplitMix64 mixing multiplier from the reference implementation. -const SPLITMIX64_MUL2: u64 = 0x94D0_49BB_1331_11EB; - -/// Frozen local SplitMix64 stream used to define SORF sign diagonals. -pub(crate) struct SplitMix64 { - state: u64, -} - -impl SplitMix64 { - pub(crate) fn new(seed: u64) -> Self { - Self { state: seed } - } - - pub(crate) fn next_u64(&mut self) -> u64 { - self.state = self.state.wrapping_add(SPLITMIX64_INCREMENT); - let mut z = self.state; - z = (z ^ (z >> 30)).wrapping_mul(SPLITMIX64_MUL1); - z = (z ^ (z >> 27)).wrapping_mul(SPLITMIX64_MUL2); - z ^ (z >> 31) - } -} - -#[cfg(test)] -mod tests { - use super::SplitMix64; - - const SPLITMIX64_SEED0_GOLDEN: [u64; 4] = [ - 0xE220_A839_7B1D_CDAF, - 0x6E78_9E6A_A1B9_65F4, - 0x06C4_5D18_8009_454F, - 0xF88B_B8A8_724C_81EC, - ]; - - const SPLITMIX64_SEED42_GOLDEN: [u64; 4] = [ - 0xBDD7_3226_2FEB_6E95, - 0x28EF_E333_B266_F103, - 0x4752_6757_130F_9F52, - 0x581C_E1FF_0E4A_E394, - ]; - - #[test] - fn splitmix64_seed0_matches_golden_outputs() { - let mut rng = SplitMix64::new(0); - let actual: Vec<_> = (0..SPLITMIX64_SEED0_GOLDEN.len()) - .map(|_| rng.next_u64()) - .collect(); - assert_eq!(actual, SPLITMIX64_SEED0_GOLDEN); - } - - #[test] - fn splitmix64_seed42_matches_golden_outputs() { - let mut rng = SplitMix64::new(42); - let actual: Vec<_> = (0..SPLITMIX64_SEED42_GOLDEN.len()) - .map(|_| rng.next_u64()) - .collect(); - assert_eq!(actual, SPLITMIX64_SEED42_GOLDEN); - } -} diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs b/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs deleted file mode 100644 index 5efe1436ed6..00000000000 --- a/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs +++ /dev/null @@ -1,493 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! Unit tests for the [`SorfTransform`] scalar function. - -#![allow(clippy::cast_possible_truncation)] - -use std::sync::Arc; - -use vortex_array::ArrayPlugin; -use vortex_array::ArrayRef; -use vortex_array::EmptyMetadata; -use vortex_array::IntoArray; -use vortex_array::VortexSessionExecute; -use vortex_array::arrays::ExtensionArray; -use vortex_array::arrays::FixedSizeListArray; -use vortex_array::arrays::PrimitiveArray; -use vortex_array::arrays::dict::DictArray; -use vortex_array::arrays::extension::ExtensionArrayExt; -use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; -use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayPlugin; -use vortex_array::dtype::DType; -use vortex_array::dtype::Nullability; -use vortex_array::dtype::PType; -use vortex_array::dtype::extension::ExtDType; -use vortex_array::validity::Validity; -use vortex_buffer::Buffer; -use vortex_buffer::BufferMut; -use vortex_error::VortexExpect; -use vortex_error::VortexResult; - -use super::SorfOptions; -use super::SorfTransform; -use super::rotation::SorfMatrix; -use crate::encodings::turboquant::centroids::compute_centroid_boundaries; -use crate::encodings::turboquant::centroids::compute_or_get_centroids; -use crate::encodings::turboquant::centroids::find_nearest_centroid; -use crate::tests::SESSION; -use crate::types::vector::AnyVector; -use crate::types::vector::Vector; - -/// Build a unit-normalized input vector array and forward-transform + quantize it, returning -/// `(input_f32, Vector(FSL(Dict(codes, centroids))), padded_dim)`. -/// -/// This mimics what the TurboQuant compression pipeline does, but directly, so we can test -/// `SorfTransform` in isolation. -fn forward_rotate_and_quantize( - dim: usize, - num_rows: usize, - seed: u64, - num_rounds: usize, - bit_width: u8, -) -> VortexResult<(Vec, ExtensionArray, usize)> { - // Build simple unit-normalized input vectors. - let mut input_f32 = vec![0.0f32; num_rows * dim]; - for row in 0..num_rows { - let mut norm_sq = 0.0f32; - for i in 0..dim { - let val = ((row * dim + i) as f32 + 1.0) * 0.01; - input_f32[row * dim + i] = val; - norm_sq += val * val; - } - let norm = norm_sq.sqrt(); - for i in 0..dim { - input_f32[row * dim + i] /= norm; - } - } - - let padded_dim = dim.next_power_of_two(); - let rotation = SorfMatrix::try_new_padded(padded_dim, num_rounds, seed)?; - let centroids = compute_or_get_centroids(padded_dim as u32, bit_width)?; - let boundaries = compute_centroid_boundaries(¢roids); - - let mut all_indices = BufferMut::::with_capacity(num_rows * padded_dim); - let mut padded = vec![0.0f32; padded_dim]; - let mut rotated = vec![0.0f32; padded_dim]; - - for row in 0..num_rows { - padded[..dim].copy_from_slice(&input_f32[row * dim..(row + 1) * dim]); - padded[dim..].fill(0.0); - rotation.rotate(&padded, &mut rotated); - for j in 0..padded_dim { - all_indices.push(find_nearest_centroid(rotated[j], &boundaries)); - } - } - - let codes = PrimitiveArray::new::(all_indices.freeze(), Validity::NonNullable); - let mut centroids_buf = BufferMut::::with_capacity(centroids.len()); - centroids_buf.extend_from_slice(¢roids); - let centroids_arr = PrimitiveArray::new::(centroids_buf.freeze(), Validity::NonNullable); - let dict = DictArray::try_new(codes.into_array(), centroids_arr.into_array())?; - let fsl = FixedSizeListArray::try_new( - dict.into_array(), - padded_dim as u32, - Validity::NonNullable, - num_rows, - )?; - let padded_vector = wrap_as_vector(fsl, Validity::NonNullable)?; - - Ok((input_f32, padded_vector, padded_dim)) -} - -/// Wrap an FSL in a Vector extension, optionally re-tagging its validity. This is used by tests -/// that need to adjust top-level nullability of a padded vector child. -fn wrap_as_vector(fsl: FixedSizeListArray, validity: Validity) -> VortexResult { - let list_size = fsl.list_size(); - let num_rows = fsl.len(); - let elements = fsl.elements().clone(); - let fsl = FixedSizeListArray::try_new(elements, list_size, validity, num_rows)?; - let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); - Ok(ExtensionArray::new(ext_dtype, fsl.into_array())) -} - -/// Helper to build `SorfOptions` with common defaults. -fn default_options(dim: u32, seed: u64) -> SorfOptions { - SorfOptions { - seed, - num_rounds: 3, - dimensions: dim, - element_ptype: PType::F32, - } -} - -/// Execute a `SorfTransform` array and return the decoded flat f32 elements. -fn execute_sorf( - options: &SorfOptions, - child: ExtensionArray, - num_rows: usize, -) -> VortexResult> { - let sorf = SorfTransform::try_new_array(options, child.into_array(), num_rows)?; - let mut ctx = SESSION.create_execution_ctx(); - let result: ExtensionArray = sorf.into_array().execute(&mut ctx)?; - let result_fsl: FixedSizeListArray = result.storage_array().clone().execute(&mut ctx)?; - let result_prim: PrimitiveArray = result_fsl.elements().clone().execute(&mut ctx)?; - Ok(result_prim.as_slice::().to_vec()) -} - -/// Build an empty `Vector` extension array wrapping an empty FSL. -fn empty_padded_vector(padded_dim: u32, validity: Validity) -> VortexResult { - let elements = PrimitiveArray::empty::(Nullability::NonNullable); - let fsl = FixedSizeListArray::try_new(elements.into_array(), padded_dim, validity, 0)?; - let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); - Ok(ExtensionArray::new(ext_dtype, fsl.into_array())) -} - -#[test] -fn roundtrip_recovery() -> VortexResult<()> { - let dim = 128; - let num_rows = 10; - let seed = 42u64; - let (input_f32, padded_vector, _) = forward_rotate_and_quantize(dim, num_rows, seed, 3, 8)?; - let options = default_options(dim as u32, seed); - let result = execute_sorf(&options, padded_vector, num_rows)?; - - assert_eq!(result.len(), num_rows * dim); - - // At 8-bit quantization, the reconstruction should be very close to the input. - for row in 0..num_rows { - let orig = &input_f32[row * dim..(row + 1) * dim]; - let recon = &result[row * dim..(row + 1) * dim]; - let err_sq: f32 = orig - .iter() - .zip(recon) - .map(|(&a, &b)| (a - b) * (a - b)) - .sum(); - let norm_sq: f32 = orig.iter().map(|&v| v * v).sum(); - assert!( - err_sq / norm_sq < 1e-3, - "row {row} MSE too high: {:.6}", - err_sq / norm_sq - ); - } - Ok(()) -} - -#[test] -fn empty_array_non_nullable() -> VortexResult<()> { - let dim = 128u32; - let padded_dim = dim.next_power_of_two(); - let options = default_options(dim, 42); - - // Build an empty Vector child. - let child = empty_padded_vector(padded_dim, Validity::NonNullable)?; - - let sorf = SorfTransform::try_new_array(&options, child.into_array(), 0)?; - let mut ctx = SESSION.create_execution_ctx(); - let result: ExtensionArray = sorf.into_array().execute(&mut ctx)?; - - assert_eq!(result.len(), 0); - - // Output should be non-nullable. - let result_fsl: FixedSizeListArray = result.storage_array().clone().execute(&mut ctx)?; - assert!(!result_fsl.dtype().is_nullable()); - - Ok(()) -} - -#[test] -fn empty_array_nullable() -> VortexResult<()> { - let dim = 128u32; - let padded_dim = dim.next_power_of_two(); - let options = default_options(dim, 42); - - // Build an empty but nullable Vector child. - let child = empty_padded_vector(padded_dim, Validity::from(Nullability::Nullable))?; - - let sorf = SorfTransform::try_new_array(&options, child.into_array(), 0)?; - let mut ctx = SESSION.create_execution_ctx(); - let result: ExtensionArray = sorf.into_array().execute(&mut ctx)?; - - assert_eq!(result.len(), 0); - - // Output should be nullable (matching the child). - let result_fsl: FixedSizeListArray = result.storage_array().clone().execute(&mut ctx)?; - assert!(result_fsl.dtype().is_nullable()); - - Ok(()) -} - -#[test] -fn nullable_validity_propagation() -> VortexResult<()> { - let dim = 128; - let num_rows = 4; - let seed = 42u64; - let (_, non_nullable_vector, padded_dim) = - forward_rotate_and_quantize(dim, num_rows, seed, 3, 8)?; - - // Re-wrap the underlying FSL with a validity mask: rows 0 and 2 are valid, rows 1 and 3 - // are null. - let validity = Validity::from_iter([true, false, true, false]); - let fsl_non_nullable: FixedSizeListArray = non_nullable_vector - .storage_array() - .clone() - .execute(&mut SESSION.create_execution_ctx())?; - let fsl_nullable = FixedSizeListArray::try_new( - fsl_non_nullable.elements().clone(), - padded_dim as u32, - validity.clone(), - num_rows, - )?; - let nullable_vector = wrap_as_vector(fsl_nullable, validity.clone())?; - - let options = default_options(dim as u32, seed); - let sorf = SorfTransform::try_new_array(&options, nullable_vector.into_array(), num_rows)?; - let mut ctx = SESSION.create_execution_ctx(); - let result: ExtensionArray = sorf.into_array().execute(&mut ctx)?; - let result_fsl: FixedSizeListArray = result.storage_array().clone().execute(&mut ctx)?; - - // The output FSL validity should match the input. - let output_validity = result_fsl.validity()?; - for row in 0..num_rows { - assert_eq!( - output_validity.execute_is_valid(row, &mut ctx)?, - validity.execute_is_valid(row, &mut ctx)?, - "validity mismatch at row {row}" - ); - } - - Ok(()) -} - -#[test] -fn dimension_truncation() -> VortexResult<()> { - // Use a non-power-of-2 dimension (padded 200 -> 256). - let dim = 200; - let num_rows = 3; - let seed = 42u64; - let (_, padded_vector, padded_dim) = forward_rotate_and_quantize(dim, num_rows, seed, 3, 8)?; - - assert_eq!(padded_dim, 256, "200 should pad to 256"); - - let options = default_options(dim as u32, seed); - let result = execute_sorf(&options, padded_vector, num_rows)?; - - // Output should have original dimension, not padded. - assert_eq!(result.len(), num_rows * dim); - - Ok(()) -} - -#[test] -fn return_dtype_is_vector_extension() -> VortexResult<()> { - let dim = 128u32; - let padded_dim = dim.next_power_of_two(); - let options = default_options(dim, 42); - - // Input must be a Vector extension dtype. - let child_elem_dtype = DType::Primitive(PType::F32, Nullability::NonNullable); - let child_storage_dtype = DType::FixedSizeList( - Arc::new(child_elem_dtype), - padded_dim, - Nullability::NonNullable, - ); - let child_ext_dtype = ExtDType::::try_new(EmptyMetadata, child_storage_dtype)?.erased(); - let child_dtype = DType::Extension(child_ext_dtype); - - use vortex_array::scalar_fn::ScalarFnVTable; - let return_dtype = SorfTransform.return_dtype(&options, &[child_dtype])?; - - // Should be a Vector extension type. - let ext = return_dtype - .as_extension_opt() - .expect("return dtype should be an extension type"); - assert!(ext.metadata_opt::().is_some()); - - // Inner FSL should have the original (unpadded) dimension. - let DType::FixedSizeList(_, inner_dim, _) = ext.storage_dtype() else { - panic!("expected storage dtype to be FSL"); - }; - assert_eq!(*inner_dim, dim); - - Ok(()) -} - -#[test] -fn rejects_zero_rounds_at_construction() { - let options = SorfOptions { - seed: 42, - num_rounds: 0, - dimensions: 128, - element_ptype: PType::F32, - }; - let elements = PrimitiveArray::from_iter([0.0f32; 128]).into_array(); - let child = FixedSizeListArray::try_new(elements, 128, Validity::NonNullable, 1) - .expect("test child should be valid"); - - let err = SorfTransform::try_new_array(&options, child.into_array(), 1) - .expect_err("zero rounds should be rejected at construction time"); - assert!(err.to_string().contains("num_rounds")); -} - -#[test] -fn rejects_non_float_output_ptype_at_construction() { - let options = SorfOptions { - seed: 42, - num_rounds: 3, - dimensions: 128, - element_ptype: PType::U8, - }; - let elements = PrimitiveArray::from_iter([0.0f32; 128]).into_array(); - let child = FixedSizeListArray::try_new(elements, 128, Validity::NonNullable, 1) - .expect("test child should be valid"); - - let err = SorfTransform::try_new_array(&options, child.into_array(), 1) - .expect_err("non-float output ptypes should be rejected at construction time"); - assert!(err.to_string().contains("element_ptype")); -} - -#[test] -fn rejects_non_vector_extension_child_at_construction() { - let options = default_options(128, 42); - // A bare FSL child (not wrapped in a Vector extension) should be rejected. - let elements = PrimitiveArray::from_iter([0.0f32; 128]).into_array(); - let child = FixedSizeListArray::try_new(elements, 128, Validity::NonNullable, 1) - .expect("test child should be valid"); - - let err = SorfTransform::try_new_array(&options, child.into_array(), 1) - .expect_err("non-Vector-extension children should be rejected at construction time"); - assert!(err.to_string().contains("Vector extension")); -} - -#[test] -fn rejects_wrong_padded_dimension_at_construction() { - // Options say dimension=128 so padded_dim should be 128. Pass a Vector<256> instead. - let options = default_options(128, 42); - let elements = PrimitiveArray::from_iter([0.0f32; 256]).into_array(); - let fsl = FixedSizeListArray::try_new(elements, 256, Validity::NonNullable, 1) - .expect("test child should be valid"); - let child = wrap_as_vector(fsl, Validity::NonNullable).expect("wrap should succeed"); - - let err = SorfTransform::try_new_array(&options, child.into_array(), 1) - .expect_err("mismatched padded dimension should be rejected at construction time"); - assert!(err.to_string().contains("dimension")); -} - -#[test] -fn rejects_non_f32_child_storage_at_construction() { - // Options are valid and target f32 output. Pass a Vector<128> whose storage is f16 instead - // of f32 -- SorfTransform's f32-only input constraint should reject this. - let options = default_options(128, 42); - let elements = PrimitiveArray::from_iter([half::f16::from_f32(0.0); 128]).into_array(); - let fsl = FixedSizeListArray::try_new(elements, 128, Validity::NonNullable, 1) - .expect("test child should be valid"); - let child = wrap_as_vector(fsl, Validity::NonNullable).expect("wrap should succeed"); - - let err = SorfTransform::try_new_array(&options, child.into_array(), 1) - .expect_err("non-f32 Vector storage should be rejected at construction time"); - assert!(err.to_string().contains("f32")); -} - -#[test] -fn f16_output_type() -> VortexResult<()> { - let dim = 128; - let num_rows = 3; - let seed = 42u64; - let (_, padded_vector, _) = forward_rotate_and_quantize(dim, num_rows, seed, 3, 8)?; - - let options = SorfOptions { - seed, - num_rounds: 3, - dimensions: dim as u32, - element_ptype: PType::F16, - }; - let sorf = SorfTransform::try_new_array(&options, padded_vector.into_array(), num_rows)?; - let mut ctx = SESSION.create_execution_ctx(); - let result: ExtensionArray = sorf.into_array().execute(&mut ctx)?; - let result_fsl: FixedSizeListArray = result.storage_array().clone().execute(&mut ctx)?; - let result_prim: PrimitiveArray = result_fsl.elements().clone().execute(&mut ctx)?; - - assert_eq!(result_prim.ptype(), PType::F16); - assert_eq!(result_prim.as_slice::().len(), num_rows * dim); - - Ok(()) -} - -#[test] -fn f64_output_type() -> VortexResult<()> { - let dim = 128; - let num_rows = 3; - let seed = 42u64; - let (_, padded_vector, _) = forward_rotate_and_quantize(dim, num_rows, seed, 3, 8)?; - - let options = SorfOptions { - seed, - num_rounds: 3, - dimensions: dim as u32, - element_ptype: PType::F64, - }; - let sorf = SorfTransform::try_new_array(&options, padded_vector.into_array(), num_rows)?; - let mut ctx = SESSION.create_execution_ctx(); - let result: ExtensionArray = sorf.into_array().execute(&mut ctx)?; - let result_fsl: FixedSizeListArray = result.storage_array().clone().execute(&mut ctx)?; - let result_prim: PrimitiveArray = result_fsl.elements().clone().execute(&mut ctx)?; - - assert_eq!(result_prim.ptype(), PType::F64); - assert_eq!(result_prim.as_slice::().len(), num_rows * dim); - - Ok(()) -} - -/// Build a trivial `Vector>` child populated with zeroes. The values -/// are irrelevant for the serde round-trip test; only the dtype shape matters. -fn trivial_padded_vector(padded_dim: u32, num_rows: usize, validity: Validity) -> ArrayRef { - let elements = PrimitiveArray::new( - Buffer::::zeroed(num_rows * padded_dim as usize), - Validity::NonNullable, - ); - let fsl = FixedSizeListArray::try_new(elements.into_array(), padded_dim, validity, num_rows) - .vortex_expect("fsl must build"); - let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone()) - .vortex_expect("ext dtype must build") - .erased(); - ExtensionArray::new(ext_dtype, fsl.into_array()).into_array() -} - -#[rstest::rstest] -// Non-power-of-two dimension to exercise `padded_dim = dim.next_power_of_two()`. -#[case::power_of_two_dim(128, Validity::NonNullable)] -#[case::non_power_of_two_dim(100, Validity::NonNullable)] -// Nullable top-level Vector to verify child nullability is reconstructed from the parent output. -#[case::nullable_child(100, Validity::AllValid)] -fn serde_round_trip(#[case] dimensions: u32, #[case] validity: Validity) -> VortexResult<()> { - let padded_dim = dimensions.next_power_of_two(); - let num_rows = 4; - let options = SorfOptions { - seed: 42, - num_rounds: 3, - dimensions, - element_ptype: PType::F32, - }; - let child = trivial_padded_vector(padded_dim, num_rows, validity); - let original = SorfTransform::try_new_array(&options, child.clone(), num_rows)?.into_array(); - - let plugin = ScalarFnArrayPlugin::new(SorfTransform); - let metadata = plugin - .serialize(&original, &SESSION)? - .expect("SorfTransform serialize must produce metadata"); - - let children = vec![child]; - let recovered = plugin.deserialize( - original.dtype(), - original.len(), - &metadata, - &[], - &children, - &SESSION, - )?; - - assert_eq!(recovered.dtype(), original.dtype()); - assert_eq!(recovered.len(), original.len()); - assert_eq!(recovered.encoding_id(), original.encoding_id()); - Ok(()) -} diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs b/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs deleted file mode 100644 index 4e93d722e61..00000000000 --- a/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs +++ /dev/null @@ -1,336 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! [`ScalarFnVTable`] implementation for [`SorfTransform`]. - -use std::sync::Arc; - -use num_traits::Float; -use num_traits::FromPrimitive; -use prost::Message; -use vortex_array::ArrayRef; -use vortex_array::EmptyMetadata; -use vortex_array::ExecutionCtx; -use vortex_array::IntoArray; -use vortex_array::arrays::ExtensionArray; -use vortex_array::arrays::FixedSizeListArray; -use vortex_array::arrays::PrimitiveArray; -use vortex_array::arrays::ScalarFn as ScalarFnArrayEncoding; -use vortex_array::arrays::extension::ExtensionArrayExt; -use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; -use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; -use vortex_array::arrays::scalar_fn::ScalarFnArrayView; -use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts; -use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable; -use vortex_array::dtype::DType; -use vortex_array::dtype::NativePType; -use vortex_array::dtype::Nullability; -use vortex_array::dtype::PType; -use vortex_array::dtype::extension::ExtDType; -use vortex_array::dtype::proto::dtype as pb; -use vortex_array::expr::Expression; -use vortex_array::match_each_float_ptype; -use vortex_array::scalar_fn::Arity; -use vortex_array::scalar_fn::ChildName; -use vortex_array::scalar_fn::ExecutionArgs; -use vortex_array::scalar_fn::ScalarFnId; -use vortex_array::scalar_fn::ScalarFnVTable; -use vortex_array::serde::ArrayChildren; -use vortex_array::validity::Validity; -use vortex_buffer::BufferMut; -use vortex_error::VortexExpect; -use vortex_error::VortexResult; -use vortex_error::vortex_ensure_eq; -use vortex_error::vortex_err; -use vortex_session::VortexSession; -use vortex_session::registry::CachedId; - -use super::SorfOptions; -use super::SorfTransform; -use super::rotation::SorfMatrix; -use super::validate_sorf_options; -use crate::types::vector::AnyVector; -use crate::types::vector::Vector; - -impl ScalarFnVTable for SorfTransform { - type Options = SorfOptions; - - fn id(&self) -> ScalarFnId { - static ID: CachedId = CachedId::new("vortex.tensor.sorf_transform"); - *ID - } - - fn arity(&self, _options: &Self::Options) -> Arity { - Arity::Exact(1) - } - - fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName { - match child_idx { - 0 => ChildName::from("rotated"), - _ => unreachable!("SorfTransform must have exactly one child"), - } - } - - fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { - validate_sorf_options(options)?; - - let child_dtype = &arg_dtypes[0]; - let vector_metadata = child_dtype - .as_extension_opt() - .and_then(|ext| ext.metadata_opt::()) - .ok_or_else(|| { - vortex_err!("SorfTransform child must be a Vector extension, got {child_dtype}") - })?; - - let expected_padded = options.dimensions.next_power_of_two(); - vortex_ensure_eq!( - vector_metadata.dimensions(), - expected_padded, - "SorfTransform child Vector must have dimension {expected_padded} (next power of two \ - for dimension {})", - options.dimensions, - ); - - // For now, the child Vector storage must be f32. TurboQuant stores its centroids as f32, - // and the SORF transform itself operates in f32, so any other input type would require an - // implicit cast that we do not yet support. The output element type is independently - // specified via `options.element_ptype` and is built below. - vortex_ensure_eq!( - vector_metadata.element_ptype(), - PType::F32, - "SorfTransform child Vector storage must be f32 (for now), got {}", - vector_metadata.element_ptype(), - ); - - let output_elem_dtype = DType::Primitive(options.element_ptype, Nullability::NonNullable); - let storage_dtype = DType::FixedSizeList( - Arc::new(output_elem_dtype), - options.dimensions, - child_dtype.nullability(), - ); - - let _ = vector_metadata; - let ext_dtype = ExtDType::::try_new(EmptyMetadata, storage_dtype)?.erased(); - - Ok(DType::Extension(ext_dtype)) - } - - fn execute( - &self, - options: &Self::Options, - args: &dyn ExecutionArgs, - ctx: &mut ExecutionCtx, - ) -> VortexResult { - let dim = options.dimensions as usize; - let num_rows = args.row_count(); - - if num_rows == 0 { - let child_dtype = args.get(0)?.dtype().clone(); - let validity = Validity::from(child_dtype.nullability()); - - return match_each_float_ptype!(options.element_ptype, |T| { - let elements = PrimitiveArray::empty::(Nullability::NonNullable); - let fsl = FixedSizeListArray::try_new( - elements.into_array(), - options.dimensions, - validity, - 0, - )?; - Vector::try_new_vector_array(fsl.into_array()) - }); - } - - // Execute the child to get the Vector extension wrapping an FSL of f32 coordinates. The - // `return_dtype` check guarantees the child is a `Vector`, so the - // materialized FSL elements are always f32. - let child_ext: ExtensionArray = args.get(0)?.execute(ctx)?; - let child_validity = child_ext.as_ref().validity()?; - let child_fsl: FixedSizeListArray = child_ext.storage_array().clone().execute(ctx)?; - let padded_dim = - usize::try_from(child_fsl.list_size()).vortex_expect("list_size fits usize"); - - let elements_prim: PrimitiveArray = child_fsl.elements().clone().execute(ctx)?; - let f32_elements = elements_prim.into_buffer::(); - - // Reconstruct the orthogonal transform matrix from the seed. - let rotation = - SorfMatrix::try_new_padded(padded_dim, options.num_rounds as usize, options.seed)?; - - // Inverse transform each row, truncate to original dimension, cast to target type. - match_each_float_ptype!(options.element_ptype, |T| { - inverse_rotate_typed::( - &f32_elements, - &rotation, - dim, - padded_dim, - num_rows, - child_validity, - ) - }) - } - - fn validity( - &self, - _options: &Self::Options, - expression: &Expression, - ) -> VortexResult> { - Ok(Some(expression.child(0).validity()?)) - } - - fn is_null_sensitive(&self, _options: &Self::Options) -> bool { - false - } - - fn is_fallible(&self, _options: &Self::Options) -> bool { - false - } -} - -/// Metadata for a serialized [`SorfTransform`] array. -/// -/// Stores the full [`SorfOptions`] inline along with the child [`DType`]. Older metadata omitted -/// this field; deserialization derives the legacy plain-`Vector` child dtype from the parent dtype -/// in that case. -#[derive(Clone, prost::Message)] -pub(super) struct SorfTransformMetadata { - #[prost(uint64, tag = "1")] - seed: u64, - /// Rust `u8` widened to `u32` for protobuf (no `u8` on the wire). - #[prost(uint32, tag = "2")] - num_rounds: u32, - #[prost(uint32, tag = "3")] - dimension: u32, - #[prost(enumeration = "PType", tag = "4")] - element_ptype: i32, - #[prost(message, optional, tag = "5")] - child_dtype: Option, -} - -impl ScalarFnArrayVTable for SorfTransform { - fn serialize( - &self, - view: &ScalarFnArrayView, - _session: &VortexSession, - ) -> VortexResult>> { - let scalar_fn_array = view.as_::(); - let child_dtype = Some(scalar_fn_array.child_at(0).dtype().try_into()?); - let metadata = SorfTransformMetadata { - child_dtype, - ..SorfTransformMetadata::from(view.options) - }; - Ok(Some(metadata.encode_to_vec())) - } - - fn deserialize( - &self, - dtype: &DType, - len: usize, - metadata: &[u8], - children: &dyn ArrayChildren, - session: &VortexSession, - ) -> VortexResult> { - let metadata = SorfTransformMetadata::decode(metadata) - .map_err(|e| vortex_err!("Failed to decode SorfTransformMetadata: {e}"))?; - let options = metadata.to_options()?; - - // `return_dtype` sets the output FSL's nullability to the child's nullability (see - // `return_dtype` above), so we read the child nullability back from the parent dtype. - let child_nullability = dtype - .as_extension_opt() - .ok_or_else(|| { - vortex_err!("SorfTransform parent dtype must be a Vector extension, got {dtype}") - })? - .storage_dtype() - .nullability(); - let padded_dim = options.dimensions.next_power_of_two(); - let child_storage = DType::FixedSizeList( - Arc::new(DType::Primitive(PType::F32, Nullability::NonNullable)), - padded_dim, - child_nullability, - ); - let child_dtype = match metadata.child_dtype.as_ref() { - Some(dtype) => DType::from_proto(dtype, session)?, - None => { - let child_ext = ExtDType::::try_new(EmptyMetadata, child_storage)?.erased(); - DType::Extension(child_ext) - } - }; - let child = children.get(0, &child_dtype, len)?; - - Ok(ScalarFnArrayParts { - options, - children: vec![child], - }) - } -} - -/// Convert an f32 value to a float type `T`. -/// -/// `FromPrimitive::from_f32` is infallible for all Vortex float types: f16 saturates via the -/// inherent `f16::from_f32()`, f32 is identity, f64 is lossless widening. -fn float_from_f32(v: f32) -> T { - FromPrimitive::from_f32(v).vortex_expect("f32-to-float conversion is infallible") -} - -/// Apply the inverse SORF transform on f32 data, truncate to the original dimension, cast each -/// element to `T`, and build a plain [`Vector`](crate::vector::Vector) extension array. -fn inverse_rotate_typed( - f32_elements: &[f32], - rotation: &SorfMatrix, - dim: usize, - padded_dim: usize, - num_rows: usize, - validity: Validity, -) -> VortexResult { - let dim_u32 = u32::try_from(dim).vortex_expect("dimension fits u32"); - let mut output = BufferMut::::with_capacity(num_rows * dim); - let mut unrotated = vec![0.0f32; padded_dim]; - - for row in 0..num_rows { - let row_data = &f32_elements[row * padded_dim..(row + 1) * padded_dim]; - - rotation.inverse_rotate(row_data, &mut unrotated); - - for idx in 0..dim { - // SAFETY: We allocated enough memory above. - unsafe { output.push_unchecked(float_from_f32::(unrotated[idx])) }; - } - } - - let elements = PrimitiveArray::new::(output.freeze(), Validity::NonNullable); - let fsl = FixedSizeListArray::try_new(elements.into_array(), dim_u32, validity, num_rows)?; - Vector::try_new_vector_array(fsl.into_array()) -} - -impl From<&SorfOptions> for SorfTransformMetadata { - fn from(options: &SorfOptions) -> Self { - Self { - seed: options.seed, - num_rounds: u32::from(options.num_rounds), - dimension: options.dimensions, - element_ptype: options.element_ptype as i32, - child_dtype: None, - } - } -} - -impl SorfTransformMetadata { - /// Rebuild the [`SorfOptions`] this metadata was serialized from, validating that the wire - /// values are in range. - fn to_options(&self) -> VortexResult { - let num_rounds = u8::try_from(self.num_rounds).map_err(|_| { - vortex_err!( - "SorfTransform num_rounds {} does not fit in u8", - self.num_rounds - ) - })?; - let options = SorfOptions { - seed: self.seed, - num_rounds, - dimensions: self.dimension, - element_ptype: self.element_ptype(), - }; - validate_sorf_options(&options)?; - Ok(options) - } -} diff --git a/vortex-tensor/src/utils.rs b/vortex-tensor/src/utils.rs index 88845f4b767..f9558a07373 100644 --- a/vortex-tensor/src/utils.rs +++ b/vortex-tensor/src/utils.rs @@ -21,10 +21,8 @@ use vortex_array::dtype::NativePType; use vortex_array::dtype::PType; use vortex_array::dtype::proto::dtype as pb; use vortex_array::scalar_fn::ScalarFnVTable; -use vortex_buffer::Buffer; use vortex_error::VortexExpect; use vortex_error::VortexResult; -use vortex_error::vortex_bail; use vortex_error::vortex_ensure; use vortex_error::vortex_err; use vortex_session::VortexSession; @@ -106,40 +104,6 @@ pub fn validate_binary_tensor_float_inputs<'a>( validate_tensor_float_input(lhs) } -/// Cast a float [`PrimitiveArray`] to a `Buffer`. -/// -/// Several operations in this crate (SORF transform, TurboQuant quantization) work exclusively -/// in f32. This function handles the cast from any float ptype: -/// -/// - f16: losslessly widened to f32. -/// - f32: zero-copy buffer extraction. -/// - f64: truncated to f32 precision. Values outside f32 range become +/- infinity. This is -/// acceptable because callers of this function operate in f32 and document this constraint. -pub fn cast_to_f32(prim: PrimitiveArray) -> VortexResult> { - match prim.ptype() { - PType::F16 => Ok(prim - .as_slice::() - .iter() - .map(|&v| f32::from(v)) - .collect()), - PType::F32 => Ok(prim.into_buffer()), - PType::F64 => Ok(prim - .as_slice::() - .iter() - .map(|&v| { - #[expect( - clippy::cast_possible_truncation, - reason = "f64 values outside f32 range become infinity, which is acceptable \ - because callers operate in f32 and document this constraint" - )] - let v = v as f32; - v - }) - .collect()), - other => vortex_bail!("expected float elements, got {other:?}"), - } -} - /// The flat primitive elements of a tensor storage array, with typed row access. /// /// This struct hides the stride detail that arises from the [`ConstantArray`] optimization: a diff --git a/vortex-tensor/src/vector_search.rs b/vortex-tensor/src/vector_search.rs index 2a036accdfe..2c053ccf0a2 100644 --- a/vortex-tensor/src/vector_search.rs +++ b/vortex-tensor/src/vector_search.rs @@ -7,8 +7,7 @@ //! [`build_similarity_search_tree`] broadcasts the query into the shape expected by //! [`CosineSimilarity`] via `Vector::constant_array` and returns a lazy //! `Binary(Gt, [CosineSimilarity(data, query), threshold])` expression. The caller is responsible -//! for preparing `data` (e.g. by running it through [`turboquant_encode`]); this builder does not -//! compress. +//! for preparing `data` (e.g. by compressing it beforehand); this builder does not compress. //! //! Executing the tree into a [`BoolArray`] yields one boolean per row indicating whether that row's //! cosine similarity to the query exceeds `threshold`. @@ -19,12 +18,10 @@ //! use vortex_array::{ArrayRef, VortexSessionExecute}; //! use vortex_array::arrays::BoolArray; //! use vortex_session::VortexSession; -//! use vortex_tensor::encodings::turboquant::{TurboQuantConfig, turboquant_encode}; //! use vortex_tensor::vector_search::build_similarity_search_tree; //! //! fn run(session: &VortexSession, data: ArrayRef, query: &[f32]) -> anyhow::Result<()> { //! let mut ctx = session.create_execution_ctx(); -//! let data = turboquant_encode(data, &TurboQuantConfig::default(), &mut ctx)?; //! let tree = build_similarity_search_tree(data, query, 0.8)?; //! let _matches: BoolArray = tree.execute(&mut ctx)?; //! Ok(()) @@ -33,7 +30,6 @@ //! //! [`Vector`]: crate::vector::Vector //! [`CosineSimilarity`]: crate::scalar_fns::cosine_similarity::CosineSimilarity -//! [`turboquant_encode`]: crate::encodings::turboquant::turboquant_encode //! [`BoolArray`]: vortex_array::arrays::BoolArray use vortex_array::ArrayRef; @@ -99,8 +95,6 @@ mod tests { use vortex_error::VortexResult; use super::build_similarity_search_tree; - use crate::encodings::turboquant::TurboQuantConfig; - use crate::encodings::turboquant::turboquant_encode; use crate::tests::SESSION; use crate::utils::test_helpers::vector_array; @@ -130,46 +124,4 @@ mod tests { assert!(bits.value(3)); Ok(()) } - - #[test] - fn turboquant_roundtrip_preserves_ranking() -> VortexResult<()> { - // Build 6 rows of 128-dim vectors where row 0 is highly correlated with the query. - // TurboQuant should preserve the "row 0 is the best match" ordering. - const DIM: u32 = 128; - const NUM_ROWS: usize = 6; - - let mut values = Vec::::with_capacity(NUM_ROWS * DIM as usize); - let query: Vec = (0..DIM as usize) - .map(|i| ((i as f32) * 0.017).sin()) - .collect(); - - // Row 0: identical to query (cosine=1.0) - values.extend_from_slice(&query); - // Row 1: query + noise - for (i, q) in query.iter().enumerate() { - values.push(q + 0.05 * ((i as f32) * 0.03).cos()); - } - // Rows 2..6: unrelated patterns - for row in 2..NUM_ROWS { - for i in 0..DIM as usize { - values.push(((row as f32 * 1.3 + i as f32) * 0.07).sin()); - } - } - - let data = vector_array(DIM, &values)?; - let mut ctx = SESSION.create_execution_ctx(); - let compressed = turboquant_encode(data, &TurboQuantConfig::default(), &mut ctx)?; - assert_eq!(compressed.len(), NUM_ROWS); - - // Build a tree with a low threshold so row 0 (cosine=1.0 exact) matches. - let tree = build_similarity_search_tree(compressed, &query, 0.95)?; - let result: BoolArray = tree.execute(&mut ctx)?; - let bits = result.to_bit_buffer(); - assert_eq!(bits.len(), NUM_ROWS); - assert!( - bits.value(0), - "row 0 (identical to query) must match at threshold 0.95 even after TurboQuant" - ); - Ok(()) - } } diff --git a/vortex-turboquant/Cargo.toml b/vortex-turboquant/Cargo.toml deleted file mode 100644 index ab3f63583d3..00000000000 --- a/vortex-turboquant/Cargo.toml +++ /dev/null @@ -1,41 +0,0 @@ -[package] -name = "vortex-turboquant" -authors = { workspace = true } -categories = { workspace = true } -description = "TurboQuant vector extension type" -edition = { workspace = true } -homepage = { workspace = true } -include = { workspace = true } -keywords = { workspace = true } -license = { workspace = true } -readme = { workspace = true } -repository = { workspace = true } -rust-version = { workspace = true } -version = { workspace = true } - -[lints] -workspace = true - -[dependencies] -half = { workspace = true } -num-traits = { workspace = true } -prost = { workspace = true } -vortex-array = { workspace = true } -vortex-buffer = { workspace = true } -vortex-error = { workspace = true } -vortex-mask = { workspace = true } -vortex-session = { workspace = true } -vortex-tensor = { workspace = true } -vortex-utils = { workspace = true, features = ["dashmap"] } - -[dev-dependencies] -divan = { workspace = true } -rand = { workspace = true } -rstest = { workspace = true } -vortex-file = { workspace = true } -vortex-io = { workspace = true } -vortex-layout = { workspace = true } - -[[bench]] -name = "encode_decode" -harness = false diff --git a/vortex-turboquant/benches/encode_decode.rs b/vortex-turboquant/benches/encode_decode.rs deleted file mode 100644 index 6adcd5523b1..00000000000 --- a/vortex-turboquant/benches/encode_decode.rs +++ /dev/null @@ -1,147 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! Benchmarks for `turboquant_encode` and `turboquant_decode` across different validity-mask -//! shapes. -//! -//! The four mask shapes (`AllTrue`, `AllFalse`, dense `Values`, sparse `Values`) exercise the -//! variant-specialized paths added in the mask refactor in `vector/normalize.rs`, -//! `vector/quantize.rs`, and `scalar_fns/decode.rs`. - -#![expect(clippy::unwrap_used)] - -use std::sync::LazyLock; - -use divan::Bencher; -use rand::RngExt; -use rand::SeedableRng as _; -use rand::rngs::StdRng; -use vortex_array::ArrayRef; -use vortex_array::EmptyMetadata; -use vortex_array::ExecutionCtx; -use vortex_array::IntoArray; -use vortex_array::VortexSessionExecute; -use vortex_array::arrays::ExtensionArray; -use vortex_array::arrays::FixedSizeListArray; -use vortex_array::arrays::PrimitiveArray; -use vortex_array::session::ArraySession; -use vortex_array::validity::Validity; -use vortex_buffer::Buffer; -use vortex_session::VortexSession; -use vortex_tensor::vector::Vector; -use vortex_turboquant::TQDecode; -use vortex_turboquant::TQEncode; -use vortex_turboquant::TurboQuantConfig; - -fn main() { - divan::main(); -} - -static SESSION: LazyLock = LazyLock::new(|| { - let session = VortexSession::empty().with::(); - vortex_turboquant::initialize(&session); - session -}); - -/// Shape of the validity mask used to drive the variant-specialized paths. -#[derive(Copy, Clone)] -enum MaskShape { - AllValid, - AllInvalid, - DenseValues, - SparseValues, -} - -impl std::fmt::Debug for MaskShape { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(match self { - MaskShape::AllValid => "all_valid", - MaskShape::AllInvalid => "all_invalid", - MaskShape::DenseValues => "dense_95pct", - MaskShape::SparseValues => "sparse_5pct", - }) - } -} - -impl MaskShape { - fn build(self, rows: usize, rng: &mut StdRng) -> Validity { - match self { - MaskShape::AllValid => Validity::NonNullable, - MaskShape::AllInvalid => Validity::AllInvalid, - MaskShape::DenseValues => Validity::from_iter((0..rows).map(|_| rng.random_bool(0.95))), - MaskShape::SparseValues => { - Validity::from_iter((0..rows).map(|_| rng.random_bool(0.05))) - } - } - } -} - -const MASK_SHAPES: &[MaskShape] = &[ - MaskShape::AllValid, - MaskShape::AllInvalid, - MaskShape::DenseValues, - MaskShape::SparseValues, -]; - -const ROWS: usize = 4096; -const DIMENSIONS: u32 = 128; - -fn build_vector_array(shape: MaskShape) -> ArrayRef { - let mut rng = StdRng::seed_from_u64(0xC0FFEE); - let dim = DIMENSIONS as usize; - let values: Buffer = (0..ROWS * dim).map(|_| rng.random::()).collect(); - let elements = PrimitiveArray::new::(values, Validity::NonNullable); - let validity = shape.build(ROWS, &mut rng); - let fsl = - FixedSizeListArray::try_new(elements.into_array(), DIMENSIONS, validity, ROWS).unwrap(); - - ExtensionArray::try_new_from_vtable(Vector, EmptyMetadata, fsl.into_array()) - .unwrap() - .into_array() -} - -fn encode(vec: ArrayRef, config: &TurboQuantConfig, ctx: &mut ExecutionCtx) -> ArrayRef { - TQEncode::try_new_array(vec, config) - .unwrap() - .into_array() - .execute(ctx) - .unwrap() -} - -fn decode(encoded: ArrayRef, ctx: &mut ExecutionCtx) -> ArrayRef { - TQDecode::try_new_array(encoded) - .unwrap() - .into_array() - .execute(ctx) - .unwrap() -} - -fn config() -> TurboQuantConfig { - // 4 bits, 4 SORF rounds, fixed seed: representative defaults from the test fixtures. - TurboQuantConfig::try_new(4, 0xDEADBEEF, 4).unwrap() -} - -#[divan::bench(args = MASK_SHAPES)] -fn turboquant_encode(bencher: Bencher, shape: &MaskShape) { - let shape = *shape; - let cfg = config(); - bencher - .with_inputs(|| (build_vector_array(shape), SESSION.create_execution_ctx())) - .input_counter(|_| divan::counter::ItemsCount::new(ROWS)) - .bench_values(|(arr, mut ctx)| encode(arr, &cfg, &mut ctx)) -} - -#[divan::bench(args = MASK_SHAPES)] -fn turboquant_decode(bencher: Bencher, shape: &MaskShape) { - let shape = *shape; - let cfg = config(); - bencher - .with_inputs(|| { - let arr = build_vector_array(shape); - let mut ctx = SESSION.create_execution_ctx(); - let encoded = encode(arr, &cfg, &mut ctx); - (encoded, SESSION.create_execution_ctx()) - }) - .input_counter(|_| divan::counter::ItemsCount::new(ROWS)) - .bench_values(|(encoded, mut ctx)| decode(encoded, &mut ctx)) -} diff --git a/vortex-turboquant/src/centroids.rs b/vortex-turboquant/src/centroids.rs deleted file mode 100644 index 2be60c0ed4e..00000000000 --- a/vortex-turboquant/src/centroids.rs +++ /dev/null @@ -1,361 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! Max-Lloyd centroid computation for TurboQuant scalar quantizers. -//! -//! Pre-computes and caches optimal scalar quantizer centroids for the marginal distribution of -//! coordinates after a random orthogonal transform of a unit-norm vector. -//! -//! In high dimensions, each coordinate of a randomly transformed unit vector follows a -//! distribution proportional to `(1 - x^2)^((d-3)/2)` on `[-1, 1]`, which converges to -//! `N(0, 1/d)`. -//! -//! The Max-Lloyd algorithm finds optimal quantization centroids that minimize MSE for this -//! distribution. -//! -//! Centroids are not stored in TurboQuant arrays. They are deterministically derived from -//! `(padded_dim, bit_width)` and cached process-locally. -//! -//! The centroid model follows the random orthogonal transform marginal used by the TurboQuant -//! paper. This encoder applies a SORF-style structured transform instead of a dense random Gaussian -//! or orthogonal matrix, so paper-level error bounds should not be treated as verified for this -//! implementation without separate empirical validation. - -use std::sync::LazyLock; - -use vortex_buffer::Buffer; -use vortex_error::VortexResult; -use vortex_error::vortex_ensure; -use vortex_utils::aliases::dash_map::DashMap; - -use crate::config::MAX_BIT_WIDTH; -use crate::config::MIN_DIMENSION; - -// NB: All of these constants were chosen arbitrarily. - -/// The maximum iterations for Max-Lloyd algorithm when computing centroids. -const MAX_ITERATIONS: usize = 200; - -/// The Max-Lloyd convergence threshold for stopping early when computing centroids. -const CONVERGENCE_EPSILON: f64 = 1e-12; - -/// Number of trapezoids used for numerical integration when computing conditional expectations. -/// -/// The trapezoidal rule evaluates the integrand at `INTEGRATION_TRAPEZOIDS + 1` points. -const INTEGRATION_TRAPEZOIDS: usize = 1000; - -/// Global centroid cache keyed by (dimension, bit_width). -static CENTROID_CACHE: LazyLock>> = LazyLock::new(DashMap::default); - -/// Get or compute cached centroids for the given dimension and bit width. -/// -/// Returns `2^bit_width` centroids sorted in ascending order, representing optimal scalar -/// quantization levels for the coordinate distribution after a random orthogonal transform in -/// `dimension`-dimensional space. -pub(crate) fn compute_or_get_centroids(dimension: u32, bit_width: u8) -> VortexResult> { - vortex_ensure!( - (1..=MAX_BIT_WIDTH).contains(&bit_width), - "TurboQuant bit_width must be 1-{}, got {bit_width}", - MAX_BIT_WIDTH - ); - vortex_ensure!( - dimension >= MIN_DIMENSION, - "TurboQuant dimension must be >= {}, got {dimension}", - MIN_DIMENSION - ); - - if let Some(centroids) = CENTROID_CACHE.get(&(dimension, bit_width)) { - return Ok(centroids.clone()); - } - - let centroids = max_lloyd_centroids(dimension, bit_width); - CENTROID_CACHE.insert((dimension, bit_width), centroids.clone()); - - Ok(centroids) -} - -// TODO(connor): It would potentially be more performant if this was modelled as const generic -// parameters to functions. -/// Half-integer exponent: represents `int_part + (if has_half { 0.5 } else { 0.0 })`. -/// -/// The marginal distribution exponent `(d-3)/2` is always an integer (when `d` is odd) or a -/// half-integer (when `d` is even). -/// -/// This type makes that invariant explicit and avoids floating-point comparison in the hot path. -#[derive(Clone, Copy, Debug)] -struct HalfIntExponent { - int_part: i32, - has_half: bool, -} - -impl HalfIntExponent { - /// Compute `(numerator) / 2` as a half-integer exponent. - /// - /// `numerator` is `d - 3` where `d` is the dimension (>= 2), so it can be negative. - fn from_numerator(numerator: i32) -> Self { - // Use Euclidean division to get floor division toward negative infinity. - let int_part = numerator.div_euclid(2); - let has_half = numerator.rem_euclid(2) != 0; - Self { int_part, has_half } - } -} - -/// Compute optimal centroids via the Max-Lloyd (Lloyd-Max) algorithm. -/// -/// Operates on the marginal distribution of a single coordinate of a randomly transformed unit -/// vector in d dimensions. -/// -/// The probability distribution function is: -/// `f(x) = C_d * (1 - x^2)^((d-3)/2)` on `[-1, 1]` -/// where `C_d` is the normalizing constant. -/// -/// Centroids are seeded uniformly on `[±sqrt(bit_width) * sigma]` (where `sigma` is the standard -/// deviation of the normal distribution that hypershere dimension values take, and specifically -/// `sigma = 1/sqrt(dimension)`) rather than across the full `[-1, 1]`, which strands most of the -/// centroids in the near-zero-mass tails. -/// -/// Note that the `sqrt(bit_width)` is mostly empirically derived, we do not have a theoretical -/// basis for choosing this other than the fact that it seems to produce good results. -fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Buffer { - debug_assert!((1..=MAX_BIT_WIDTH).contains(&bit_width)); - let num_centroids = 1usize << bit_width; - - // For the marginal distribution on [-1, 1], we use the exponent (d-3)/2. - let exponent = HalfIntExponent::from_numerator(dimension as i32 - 3); - - // The coordinate marginal concentrates around 0 with this standard deviation. - let sigma = 1.0 / f64::from(dimension).sqrt(); - let init_half = (f64::from(bit_width).sqrt() * sigma).min(1.0); - - // Initialize centroids uniformly on [-init_half, init_half], where the mass lives, so no cell - // starts in a zero-mass region and freezes. - let mut centroids: Vec = (0..num_centroids) - .map(|idx| -init_half + (2.0 * (idx as f64) + 1.0) * init_half / (num_centroids as f64)) - .collect(); - - let mut boundaries: Vec = vec![0.0; num_centroids + 1]; - for _ in 0..MAX_ITERATIONS { - // Compute decision boundaries (midpoints between adjacent centroids). - boundaries[0] = -1.0; - for idx in 0..num_centroids - 1 { - boundaries[idx + 1] = (centroids[idx] + centroids[idx + 1]) / 2.0; - } - boundaries[num_centroids] = 1.0; - - // Update each centroid to the conditional mean within its Voronoi cell. - let mut max_change = 0.0f64; - for idx in 0..num_centroids { - let lo = boundaries[idx]; - let hi = boundaries[idx + 1]; - let new_centroid = mean_between_centroids(lo, hi, exponent); - max_change = max_change.max((new_centroid - centroids[idx]).abs()); - centroids[idx] = new_centroid; - } - - if max_change < CONVERGENCE_EPSILON { - break; - } - } - - #[expect( - clippy::cast_possible_truncation, - reason = "all values are in [-1, 1] so this just loses precision" - )] - centroids.into_iter().map(|val| val as f32).collect() -} - -/// Compute the conditional mean of the coordinate distribution on interval [lo, hi]. -/// -/// Returns `E[X | lo <= X <= hi]` where X has PDF proportional to `(1 - x^2)^exponent` on [-1, 1]. -/// -/// Since there is no closed form for the integrals, we compute this numerically. -fn mean_between_centroids(lo: f64, hi: f64, exponent: HalfIntExponent) -> f64 { - if (hi - lo).abs() < 1e-15 { - return (lo + hi) / 2.0; - } - - let dx = (hi - lo) / INTEGRATION_TRAPEZOIDS as f64; - - let mut numerator = 0.0; - let mut denominator = 0.0; - - for step in 0..=INTEGRATION_TRAPEZOIDS { - let x_val = lo + (step as f64) * dx; - let weight = pdf_unnormalized(x_val, exponent); - - let trap_weight = if step == 0 || step == INTEGRATION_TRAPEZOIDS { - 0.5 - } else { - 1.0 - }; - - numerator += trap_weight * x_val * weight; - denominator += trap_weight * weight; - } - - if denominator.abs() < 1e-30 { - (lo + hi) / 2.0 - } else { - numerator / denominator - } -} - -/// Unnormalized PDF of the coordinate distribution: `(1 - x^2)^exponent`. -/// -/// Uses `powi` + `sqrt` instead of `powf` for the half-integer exponents that arise from `(d-3)/2`. -/// This is significantly faster than the general `powf` which goes through -/// `exp(exponent * ln(base))`. -fn pdf_unnormalized(x_val: f64, exponent: HalfIntExponent) -> f64 { - let base = (1.0 - x_val * x_val).max(0.0); - - if exponent.has_half { - // Half-integer exponent: base^(int_part) * sqrt(base). - base.powi(exponent.int_part) * base.sqrt() - } else { - // Integer exponent: use powi directly. - base.powi(exponent.int_part) - } -} - -/// Precompute decision boundaries (midpoints between adjacent centroids). -/// -/// For `k` centroids, returns `k-1` boundaries. A value below `boundaries[0]` maps to centroid 0, a -/// value in `[boundaries[i-1], boundaries[i])` maps to centroid `i`, and a -/// value `>= boundaries[k-2]` maps to centroid `k-1`. -pub(crate) fn compute_centroid_boundaries(centroids: &[f32]) -> Vec { - centroids.windows(2).map(|w| (w[0] + w[1]) * 0.5).collect() -} - -/// Find the index of the nearest centroid using precomputed decision boundaries. -/// -/// `boundaries` must be the output of [`compute_centroid_boundaries`] for the corresponding -/// centroids. Uses binary search on the midpoints, avoiding distance comparisons -/// in the inner loop. -#[inline] -pub(crate) fn find_nearest_centroid(value: f32, boundaries: &[f32]) -> u8 { - debug_assert!( - boundaries.windows(2).all(|w| w[0] <= w[1]), - "boundaries must be sorted" - ); - debug_assert!( - boundaries.len() <= 256, // 1 << 8 - "too many boundaries" - ); - - #[expect( - clippy::cast_possible_truncation, - reason = "num_centroids <= 256 and partition_point will return at most 255" - )] - (boundaries.partition_point(|&b| b < value) as u8) -} - -#[cfg(test)] -mod tests { - use rstest::rstest; - use vortex_error::VortexResult; - - use super::*; - - #[rstest] - #[case(128, 1, 2)] - #[case(128, 2, 4)] - #[case(128, 3, 8)] - #[case(128, 4, 16)] - #[case(768, 2, 4)] - #[case(1536, 3, 8)] - fn centroids_have_correct_count( - #[case] dim: u32, - #[case] bits: u8, - #[case] expected: usize, - ) -> VortexResult<()> { - let centroids = compute_or_get_centroids(dim, bits)?; - assert_eq!(centroids.len(), expected); - Ok(()) - } - - #[rstest] - #[case(128, 1)] - #[case(128, 2)] - #[case(128, 3)] - #[case(128, 4)] - #[case(768, 2)] - fn centroids_are_sorted(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> { - let centroids = compute_or_get_centroids(dim, bits)?; - for window in centroids.windows(2) { - assert!( - window[0] < window[1], - "centroids not sorted: {:?}", - centroids - ); - } - Ok(()) - } - - #[rstest] - #[case(128, 1)] - #[case(128, 2)] - #[case(256, 2)] - #[case(768, 2)] - fn centroids_are_symmetric(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> { - let centroids = compute_or_get_centroids(dim, bits)?; - let count = centroids.len(); - for idx in 0..count / 2 { - let diff = (centroids[idx] + centroids[count - 1 - idx]).abs(); - assert!( - diff < 1e-5, - "centroids not symmetric: c[{idx}]={}, c[{}]={}", - centroids[idx], - count - 1 - idx, - centroids[count - 1 - idx] - ); - } - Ok(()) - } - - #[rstest] - #[case(128, 1)] - #[case(128, 4)] - fn centroids_within_bounds(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> { - let centroids = compute_or_get_centroids(dim, bits)?; - for &val in centroids.iter() { - assert!( - (-1.0..=1.0).contains(&val), - "centroid out of [-1, 1]: {val}", - ); - } - Ok(()) - } - - #[test] - fn centroids_cached() -> VortexResult<()> { - let c1 = compute_or_get_centroids(128, 2)?; - let c2 = compute_or_get_centroids(128, 2)?; - assert_eq!(c1, c2); - Ok(()) - } - - #[test] - fn find_nearest_basic() -> VortexResult<()> { - let centroids = compute_or_get_centroids(128, 2)?; - let boundaries = compute_centroid_boundaries(¢roids); - assert_eq!(find_nearest_centroid(-1.0, &boundaries), 0); - - #[expect(clippy::cast_possible_truncation)] - let last_idx = (centroids.len() - 1) as u8; - assert_eq!(find_nearest_centroid(1.0, &boundaries), last_idx); - for (idx, &cv) in centroids.iter().enumerate() { - #[expect(clippy::cast_possible_truncation)] - let expected = idx as u8; - assert_eq!(find_nearest_centroid(cv, &boundaries), expected); - } - Ok(()) - } - - #[test] - fn rejects_invalid_params() { - assert!(compute_or_get_centroids(128, 0).is_err()); - assert!(compute_or_get_centroids(128, 9).is_err()); - assert!(compute_or_get_centroids(1, 2).is_err()); - assert!(compute_or_get_centroids(127, 2).is_err()); - } -} diff --git a/vortex-turboquant/src/config.rs b/vortex-turboquant/src/config.rs deleted file mode 100644 index 57cd8b1e94b..00000000000 --- a/vortex-turboquant/src/config.rs +++ /dev/null @@ -1,84 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use std::fmt; - -use vortex_error::VortexResult; -use vortex_error::vortex_ensure; - -/// Minimum vector dimension for TurboQuant encoding. -/// -/// Note that this is not a theoretical minimum, it is mostly a practical one to limit the total -/// amount of distortion. -pub(crate) const MIN_DIMENSION: u32 = 128; - -/// Maximum supported number of bits per quantized coordinate. -pub(crate) const MAX_BIT_WIDTH: u8 = 8; - -/// Configuration for lossy TurboQuant encoding. -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub struct TurboQuantConfig { - bit_width: u8, - seed: u64, - num_rounds: u8, -} - -impl TurboQuantConfig { - /// Build a TurboQuant configuration. - /// - /// # Errors - /// - /// Returns an error if `bit_width` is outside `1..=8` or `num_rounds` is zero. - pub fn try_new(bit_width: u8, seed: u64, num_rounds: u8) -> VortexResult { - vortex_ensure!( - (1..=MAX_BIT_WIDTH).contains(&bit_width), - "TurboQuant bit_width must be 1-{MAX_BIT_WIDTH}, got {bit_width}", - ); - vortex_ensure!( - num_rounds > 0, - "TurboQuant num_rounds must be > 0, got {num_rounds}" - ); - - Ok(Self { - bit_width, - seed, - num_rounds, - }) - } - - /// Bits per coordinate in the scalar quantizer codebook. - pub fn bit_width(&self) -> u8 { - self.bit_width - } - - /// Seed used to derive the deterministic SORF transform. - pub fn seed(&self) -> u64 { - self.seed - } - - /// Number of sign-diagonal plus Walsh-Hadamard rounds in the SORF transform. - pub fn num_rounds(&self) -> u8 { - self.num_rounds - } -} - -impl Default for TurboQuantConfig { - /// Defaults to 8 bits per coordinate, seed 42, and 3 SORF rounds. - fn default() -> Self { - Self { - bit_width: MAX_BIT_WIDTH, - seed: 42, - num_rounds: 3, - } - } -} - -impl fmt::Display for TurboQuantConfig { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "bit_width: {}, seed: {}, num_rounds: {}", - self.bit_width, self.seed, self.num_rounds - ) - } -} diff --git a/vortex-turboquant/src/lib.rs b/vortex-turboquant/src/lib.rs deleted file mode 100644 index 7aeb60368dd..00000000000 --- a/vortex-turboquant/src/lib.rs +++ /dev/null @@ -1,81 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! TurboQuant vector quantization extension type for Vortex. -//! -//! Implements a Stage 1 TurboQuant encoding ([arXiv:2504.19874], [RFC 0033]) for lossy compression -//! of high-dimensional vector data. The extension operates on -//! [`Vector`](vortex_tensor::vector::Vector) extension arrays, encoding their `FixedSizeList` -//! storage into quantized codes after a structured orthogonal surrogate transform. -//! -//! [arXiv:2504.19874]: https://arxiv.org/abs/2504.19874 -//! [RFC 0033]: https://vortex-data.github.io/rfcs/rfc/0033.html -//! -//! # Overview -//! -//! TurboQuant minimizes mean-squared reconstruction error (1-8 bits per coordinate) -//! using MSE-optimal scalar quantization on coordinates of a transformed unit vector. -//! -//! The [`TQEncode`] scalar function first computes and stores the original L2 norm for each vector -//! row, then normalizes each valid nonzero row internally before SORF transform and scalar -//! quantization. The [`TQDecode`] scalar function dequantizes through deterministic centroids, -//! applies the inverse SORF transform, truncates back to the original dimension, and re-applies the -//! stored norm. -//! -//! The encoded storage is a row-aligned extension tree: -//! -//! ```text -//! Extension( -//! Struct { -//! norms: Primitive, -//! codes: FixedSizeList, padded_dim, vector_validity>, -//! } -//! ) -//! ``` -//! -//! Stored norms are authoritative for future TurboQuant-aware scalar functions. Decoded quantized -//! directions are not guaranteed to have unit norm after scalar quantization and inverse transform. -//! -//! # Source map -//! -//! Implementation details are documented next to the code that owns them: -//! -//! - `vector/storage.rs`: physical storage shape, full-length child arrays, and field-level -//! validity for null vectors. -//! - `vector/normalize.rs`: TurboQuant-local normalization and how it differs from the tensor -//! crate's null-row zeroing helper. -//! - `vector/quantize.rs`: SORF transform, centroid lookup, and why invalid rows are skipped rather -//! than quantized. -//! - `centroids.rs`: deterministic Max-Lloyd centroid computation and process-local caching. -//! - `sorf/`: the Walsh-Hadamard-based structured transform and the stable SplitMix64 sign stream. -//! -//! The current encoding is intentionally MSE-only. It does not yet implement the paper's QJL -//! residual correction for unbiased inner-product estimation, and it still uses internal -//! power-of-2 padding rather than the block decomposition proposed in RFC 0033. - -mod centroids; -mod config; -mod scalar_fns; -mod sorf; -mod vector; -mod vtable; - -pub use config::TurboQuantConfig; -pub use scalar_fns::TQDecode; -pub use scalar_fns::TQEncode; -pub use vtable::TurboQuant; -pub use vtable::TurboQuantMetadata; - -// TODO(connor): We need to somehow make sure that callers call `vortex_tensor::initialize` first. -/// Register the TurboQuant extension type with a Vortex session. -pub fn initialize(session: &vortex_session::VortexSession) { - use vortex_array::dtype::session::DTypeSessionExt; - use vortex_array::scalar_fn::session::ScalarFnSessionExt; - session.dtypes().register(TurboQuant); - - session.scalar_fns().register(TQEncode); - session.scalar_fns().register(TQDecode); -} - -#[cfg(test)] -mod tests; diff --git a/vortex-turboquant/src/scalar_fns/decode.rs b/vortex-turboquant/src/scalar_fns/decode.rs deleted file mode 100644 index 9d4d465dbe8..00000000000 --- a/vortex-turboquant/src/scalar_fns/decode.rs +++ /dev/null @@ -1,319 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! TurboQuant decode scalar function. - -use std::fmt; -use std::fmt::Formatter; -use std::sync::Arc; - -use num_traits::Float; -use num_traits::FromPrimitive; -use vortex_array::ArrayRef; -use vortex_array::EmptyMetadata; -use vortex_array::ExecutionCtx; -use vortex_array::IntoArray; -use vortex_array::arrays::FixedSizeListArray; -use vortex_array::arrays::PrimitiveArray; -use vortex_array::arrays::ScalarFnArray; -use vortex_array::dtype::DType; -use vortex_array::dtype::NativePType; -use vortex_array::dtype::Nullability; -use vortex_array::dtype::extension::ExtDType; -use vortex_array::expr::Expression; -use vortex_array::match_each_float_ptype; -use vortex_array::scalar_fn::Arity; -use vortex_array::scalar_fn::ChildName; -use vortex_array::scalar_fn::ExecutionArgs; -use vortex_array::scalar_fn::ScalarFnId; -use vortex_array::scalar_fn::ScalarFnVTable; -use vortex_array::scalar_fn::TypedScalarFnInstance; -use vortex_array::validity::Validity; -use vortex_buffer::BufferMut; -use vortex_error::VortexExpect; -use vortex_error::VortexResult; -use vortex_error::vortex_ensure; -use vortex_error::vortex_err; -use vortex_mask::Mask; -use vortex_session::VortexSession; -use vortex_session::registry::CachedId; -use vortex_tensor::vector::Vector; - -use crate::centroids::compute_or_get_centroids; -use crate::sorf::SorfMatrix; -use crate::vector::storage::parse_storage; -use crate::vector::tq_padded_dim; -use crate::vtable::TurboQuantMetadata; -use crate::vtable::tq_metadata; - -/// Lazy TurboQuant vector decode scalar function. -#[derive(Clone)] -pub struct TQDecode; - -impl TQDecode { - /// Creates a new [`TypedScalarFnInstance`] wrapping TurboQuant decoding. - pub fn new() -> TypedScalarFnInstance { - TypedScalarFnInstance::new(TQDecode, EmptyMetadata) - } - - /// Constructs a [`ScalarFnArray`] that lazily decodes a `TurboQuant` child into a `Vector`. - pub fn try_new_array(child: ArrayRef) -> VortexResult { - let len = child.len(); - ScalarFnArray::try_new(TQDecode::new().erased(), vec![child], len) - } -} - -impl ScalarFnVTable for TQDecode { - type Options = EmptyMetadata; - - fn id(&self) -> ScalarFnId { - static ID: CachedId = CachedId::new("vortex.turboquant.decode"); - *ID - } - - fn serialize(&self, _options: &Self::Options) -> VortexResult>> { - Ok(Some(vec![])) - } - - fn deserialize( - &self, - metadata: &[u8], - _session: &VortexSession, - ) -> VortexResult { - vortex_ensure!( - metadata.is_empty(), - "TQDecode options metadata must be empty" - ); - - Ok(EmptyMetadata) - } - - fn arity(&self, _options: &Self::Options) -> Arity { - Arity::Exact(1) - } - - fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName { - match child_idx { - 0 => ChildName::from("turboquant"), - _ => unreachable!("TQDecode must have exactly one child"), - } - } - - fn fmt_sql( - &self, - _options: &Self::Options, - expr: &Expression, - f: &mut Formatter<'_>, - ) -> fmt::Result { - write!(f, "tq_decode(")?; - expr.child(0).fmt_sql(f)?; - write!(f, ")") - } - - fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { - let child_dtype = &arg_dtypes[0]; - let metadata = tq_metadata(child_dtype)?; - - let storage_dtype = DType::FixedSizeList( - Arc::new(DType::Primitive( - metadata.element_ptype, - Nullability::NonNullable, - )), - metadata.dimensions, - child_dtype.nullability(), - ); - let ext_dtype = ExtDType::::try_new(EmptyMetadata, storage_dtype)?.erased(); - - Ok(DType::Extension(ext_dtype)) - } - - fn execute( - &self, - _options: &Self::Options, - args: &dyn ExecutionArgs, - ctx: &mut ExecutionCtx, - ) -> VortexResult { - decode_vector(args.get(0)?, ctx) - } - - fn validity( - &self, - _options: &Self::Options, - expression: &Expression, - ) -> VortexResult> { - Ok(Some(expression.child(0).validity()?)) - } - - fn is_null_sensitive(&self, _options: &Self::Options) -> bool { - false - } - - fn is_fallible(&self, _options: &Self::Options) -> bool { - false - } -} - -/// Decode a `TurboQuant` extension array back into a `Vector` extension array. -/// -/// The decoded directions are inverse-transformed, truncated to the original dimension, and -/// multiplied by the stored row norms. The conversion is lossy and does not roundtrip with -/// [`TQEncode`](crate::TQEncode). -pub(crate) fn decode_vector(input: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { - let parsed = parse_storage(input, ctx)?; - let metadata = parsed.metadata; - if parsed.len == 0 { - return build_empty_vector(metadata, parsed.vector_validity); - } - - let padded_dim = tq_padded_dim(metadata.dimensions)?; - let transform = SorfMatrix::try_new(padded_dim, metadata.num_rounds as usize, metadata.seed)?; - let padded_dim = u32::try_from(padded_dim) - .map_err(|_| vortex_err!("TurboQuant padded dimension does not fit u32"))?; - - let centroids = compute_or_get_centroids(padded_dim, metadata.bit_width)?; - - match_each_float_ptype!(metadata.element_ptype, |T| { - decode_typed::( - DecodeInputs { - metadata: &metadata, - sorf_matrix: &transform, - centroids: ¢roids, - norms: &parsed.norms, - codes: &parsed.codes, - }, - parsed.vector_validity, - parsed.len, - ctx, - ) - }) -} - -fn build_empty_vector( - metadata: TurboQuantMetadata, - vector_validity: Validity, -) -> VortexResult { - match_each_float_ptype!(metadata.element_ptype, |T| { - let elements = PrimitiveArray::empty::(Nullability::NonNullable); - let fsl = FixedSizeListArray::try_new( - elements.into_array(), - metadata.dimensions, - vector_validity, - 0, - )?; - - Vector::try_new_vector_array(fsl.into_array()) - }) -} - -/// Borrowed bundle of the per-array decode inputs passed to the typed inner loop. -/// -/// Packaged as a struct rather than positional arguments because `decode_typed` runs through -/// [`vortex_array::match_each_float_ptype!`] which expands once per supported element ptype. -/// Each expansion takes the same set of inputs, and the struct keeps the call site short. -struct DecodeInputs<'a> { - /// TurboQuant metadata recovered from the input extension dtype. - metadata: &'a TurboQuantMetadata, - /// SORF transform reconstructed from `metadata.seed` and `metadata.num_rounds`. - sorf_matrix: &'a SorfMatrix, - /// Centroid codebook for `(padded_dim, bit_width)`, in f32. - centroids: &'a [f32], - /// Per-row stored L2 norm of the original input vector, in the element ptype. - norms: &'a PrimitiveArray, - /// Flat per-row centroid indices, `num_vectors * padded_dim` bytes. - codes: &'a PrimitiveArray, -} - -fn decode_typed( - decode: DecodeInputs<'_>, - vector_validity: Validity, - num_vectors: usize, - ctx: &mut ExecutionCtx, -) -> VortexResult -where - T: NativePType + Float + FromPrimitive, -{ - let metadata = decode.metadata; - let dimensions = usize::try_from(metadata.dimensions) - .vortex_expect("dimensions stays representable as usize"); - let padded_dim = decode.sorf_matrix.padded_dim(); - let centroids = decode.centroids; - let norms = decode.norms.as_slice::(); - let codes = decode.codes.as_slice::(); - let mask = vector_validity.execute_mask(num_vectors, ctx)?; - - let output_len = num_vectors - .checked_mul(dimensions) - .ok_or_else(|| vortex_err!("TurboQuant decoded vector length overflow"))?; - let mut output = BufferMut::::with_capacity(output_len); - - let mut decoded = vec![0.0f32; padded_dim]; - let mut inverse = vec![0.0f32; padded_dim]; - - let mut decode_row = |output: &mut BufferMut, i: usize| { - let code_row = &codes[i * padded_dim..][..padded_dim]; - - for (dst, &code) in decoded.iter_mut().zip(code_row.iter()) { - *dst = *centroids - .get(usize::from(code)) - .vortex_expect("TurboQuant code exceeds centroid count"); - } - - decode.sorf_matrix.inverse_transform(&decoded, &mut inverse); - - let norm = norms[i]; - for &value in inverse.iter().take(dimensions) { - // `T::from_f32` is infallible for the supported float ptypes (`f16`, `f32`, - // `f64`): values outside `f16` range saturate to `±inf` rather than returning - // `None`. - let value = T::from_f32(value) - .vortex_expect("from_f32 is infallible for supported float types"); - - // SAFETY: total pushes across all match arms equal `output_len`. - unsafe { output.push_unchecked(value * norm) }; - } - }; - - match &mask { - Mask::AllFalse(_) => { - // SAFETY: `output` was allocated with capacity `output_len`, and this push writes - // exactly `output_len` zero placeholders. - unsafe { output.push_n_unchecked(T::zero(), output_len) }; - } - Mask::AllTrue(_) => { - for i in 0..num_vectors { - decode_row(&mut output, i); - } - } - Mask::Values(values_mask) => { - let mut cursor = 0; - - for &(start, end) in values_mask.slices() { - if start > cursor { - // SAFETY: total pushes across all arms equal `output_len`. - unsafe { output.push_n_unchecked(T::zero(), (start - cursor) * dimensions) }; - } - - for i in start..end { - decode_row(&mut output, i); - } - - cursor = end; - } - - if cursor < num_vectors { - // SAFETY: total pushes across all arms equal `output_len`. - unsafe { output.push_n_unchecked(T::zero(), (num_vectors - cursor) * dimensions) }; - } - } - } - - let elements = PrimitiveArray::new::(output.freeze(), Validity::NonNullable); - let fsl = FixedSizeListArray::try_new( - elements.into_array(), - metadata.dimensions, - vector_validity, - num_vectors, - )?; - - Vector::try_new_vector_array(fsl.into_array()) -} diff --git a/vortex-turboquant/src/scalar_fns/encode.rs b/vortex-turboquant/src/scalar_fns/encode.rs deleted file mode 100644 index 6dd16e4bb66..00000000000 --- a/vortex-turboquant/src/scalar_fns/encode.rs +++ /dev/null @@ -1,226 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! TurboQuant encode scalar function. - -use std::fmt; -use std::fmt::Formatter; - -use vortex_array::ArrayRef; -use vortex_array::ExecutionCtx; -use vortex_array::IntoArray; -use vortex_array::arrays::Extension; -use vortex_array::arrays::ExtensionArray; -use vortex_array::arrays::FixedSizeListArray; -use vortex_array::arrays::ScalarFnArray; -use vortex_array::arrays::extension::ExtensionArrayExt; -use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; -use vortex_array::dtype::DType; -use vortex_array::dtype::extension::ExtDType; -use vortex_array::expr::Expression; -use vortex_array::scalar_fn::Arity; -use vortex_array::scalar_fn::ChildName; -use vortex_array::scalar_fn::ExecutionArgs; -use vortex_array::scalar_fn::ScalarFnId; -use vortex_array::scalar_fn::ScalarFnVTable; -use vortex_array::scalar_fn::TypedScalarFnInstance; -use vortex_error::VortexResult; -use vortex_error::vortex_ensure; -use vortex_error::vortex_err; -use vortex_session::VortexSession; -use vortex_session::registry::CachedId; -use vortex_tensor::vector::AnyVector; - -use super::metadata::deserialize_config; -use super::metadata::serialize_config; -use crate::TurboQuantConfig; -use crate::config::MIN_DIMENSION; -use crate::vector::normalize::tq_normalize_as_l2_denorm; -use crate::vector::quantize::empty_quantization; -use crate::vector::quantize::turboquant_quantize_core; -use crate::vector::storage::build_codes_child; -use crate::vector::storage::build_storage; -use crate::vector::tq_padded_dim; -use crate::vtable::TurboQuant; -use crate::vtable::TurboQuantMetadata; -use crate::vtable::tq_storage_dtype; - -/// TurboQuant vector encode scalar function. -/// -/// `TQEncode` itself is a `ScalarFnVTable` and so its options round-trip through expression -/// serialization. -/// -/// Unlike `TQDecode`, it deliberately does **not** implement `ScalarFnArrayVTable` since the -/// persisted artifact would be the original vector array, not the TurboQuant-quantized array. -#[derive(Clone)] -pub struct TQEncode; - -impl TQEncode { - /// Creates a new [`TypedScalarFnInstance`] wrapping TurboQuant encoding. - pub fn new(config: &TurboQuantConfig) -> TypedScalarFnInstance { - TypedScalarFnInstance::new(TQEncode, config.clone()) - } - - /// Constructs a [`ScalarFnArray`] that lazily encodes a `Vector` child into `TurboQuant`. - pub fn try_new_array( - child: ArrayRef, - config: &TurboQuantConfig, - ) -> VortexResult { - let len = child.len(); - ScalarFnArray::try_new(TQEncode::new(config).erased(), vec![child], len) - } -} - -impl ScalarFnVTable for TQEncode { - type Options = TurboQuantConfig; - - fn id(&self) -> ScalarFnId { - static ID: CachedId = CachedId::new("vortex.turboquant.encode"); - *ID - } - - fn serialize(&self, options: &Self::Options) -> VortexResult>> { - Ok(Some(serialize_config(options))) - } - - fn deserialize( - &self, - metadata: &[u8], - _session: &VortexSession, - ) -> VortexResult { - deserialize_config(metadata) - } - - fn arity(&self, _options: &Self::Options) -> Arity { - Arity::Exact(1) - } - - fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName { - match child_idx { - 0 => ChildName::from("vector"), - _ => unreachable!("TQEncode must have exactly one child"), - } - } - - fn fmt_sql( - &self, - options: &Self::Options, - expr: &Expression, - f: &mut Formatter<'_>, - ) -> fmt::Result { - write!(f, "tq_encode(")?; - expr.child(0).fmt_sql(f)?; - write!(f, ", {options})") - } - - fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { - let input_dtype = &arg_dtypes[0]; - let vector_metadata = input_dtype - .as_extension_opt() - .and_then(|ext_dtype| ext_dtype.metadata_opt::()) - .ok_or_else(|| { - vortex_err!("TQEncode expects a Vector extension array, got {input_dtype}") - })?; - - let dimensions = vector_metadata.dimensions(); - vortex_ensure!( - dimensions >= MIN_DIMENSION, - "TurboQuant requires dimension >= {MIN_DIMENSION}, got {dimensions}", - ); - tq_padded_dim(dimensions)?; - - let metadata = TurboQuantMetadata { - element_ptype: vector_metadata.element_ptype(), - dimensions, - bit_width: options.bit_width(), - seed: options.seed(), - num_rounds: options.num_rounds(), - }; - let storage_dtype = tq_storage_dtype(&metadata, input_dtype.nullability())?; - let ext_dtype = ExtDType::::try_new(metadata, storage_dtype)?.erased(); - - Ok(DType::Extension(ext_dtype)) - } - - fn execute( - &self, - options: &Self::Options, - args: &dyn ExecutionArgs, - ctx: &mut ExecutionCtx, - ) -> VortexResult { - encode_vector(args.get(0)?, options, ctx) - } - - fn validity( - &self, - _options: &Self::Options, - expression: &Expression, - ) -> VortexResult> { - Ok(Some(expression.child(0).validity()?)) - } - - fn is_null_sensitive(&self, _options: &Self::Options) -> bool { - false - } - - fn is_fallible(&self, _options: &Self::Options) -> bool { - false - } -} - -/// Lossily encode a `Vector` extension array into a `TurboQuant` extension array. -/// -/// Valid rows are normalized internally before SORF transform and scalar quantization. The original -/// row norms are stored explicitly, and original vector nulls are preserved on the storage struct -/// and both row-aligned child arrays. -pub(crate) fn encode_vector( - input: ArrayRef, - config: &TurboQuantConfig, - ctx: &mut ExecutionCtx, -) -> VortexResult { - let num_vectors = input.len(); - let vector_metadata = input - .dtype() - .as_extension_opt() - .and_then(|ext_dtype| ext_dtype.metadata_opt::()) - .ok_or_else(|| vortex_err!("TurboQuant encode expects a Vector extension array"))?; - - let element_ptype = vector_metadata.element_ptype(); - - let dimensions = vector_metadata.dimensions(); - vortex_ensure!( - dimensions >= MIN_DIMENSION, - "TurboQuant requires dimension >= {MIN_DIMENSION}, got {dimensions}", - ); - let padded_dim = tq_padded_dim(dimensions)?; - - let vector_validity = input.validity()?; - - let l2_denorm = tq_normalize_as_l2_denorm(input, ctx)?; - let normalized = l2_denorm.child_at(0).clone(); - let norms = l2_denorm.child_at(1).clone(); - - let normalized_ext = normalized - .as_opt::() - .ok_or_else(|| vortex_err!("normalized TurboQuant input must be a Vector extension"))?; - let normalized_fsl: FixedSizeListArray = normalized_ext.storage_array().clone().execute(ctx)?; - - let core = if normalized_fsl.is_empty() { - empty_quantization(padded_dim) - } else { - // SAFETY: `tq_normalize_as_l2_denorm` returned this normalized Vector child. - unsafe { turboquant_quantize_core(&normalized_fsl, config, ctx)? } - }; - let codes = build_codes_child(num_vectors, core, vector_validity.clone())?; - - let metadata = TurboQuantMetadata { - element_ptype, - dimensions, - bit_width: config.bit_width(), - seed: config.seed(), - num_rounds: config.num_rounds(), - }; - let storage = build_storage(norms, codes, num_vectors, vector_validity)?; - - Ok(ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage)?.into_array()) -} diff --git a/vortex-turboquant/src/scalar_fns/metadata.rs b/vortex-turboquant/src/scalar_fns/metadata.rs deleted file mode 100644 index f5eddfe51d9..00000000000 --- a/vortex-turboquant/src/scalar_fns/metadata.rs +++ /dev/null @@ -1,47 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use prost::Message; -use vortex_error::VortexResult; -use vortex_error::vortex_err; - -use crate::TurboQuantConfig; - -#[derive(Clone, PartialEq, Message)] -pub(super) struct TQScalarFnMetadata { - #[prost(uint32, tag = "1")] - bit_width: u32, - #[prost(uint64, tag = "2")] - seed: u64, - #[prost(uint32, tag = "3")] - num_rounds: u32, -} - -impl TQScalarFnMetadata { - pub(super) fn from_config(config: &TurboQuantConfig) -> Self { - Self { - bit_width: u32::from(config.bit_width()), - seed: config.seed(), - num_rounds: u32::from(config.num_rounds()), - } - } - - pub(super) fn to_config(&self) -> VortexResult { - let bit_width = u8::try_from(self.bit_width) - .map_err(|_| vortex_err!("TurboQuant bit_width does not fit u8"))?; - let num_rounds = u8::try_from(self.num_rounds) - .map_err(|_| vortex_err!("TurboQuant num_rounds does not fit u8"))?; - - TurboQuantConfig::try_new(bit_width, self.seed, num_rounds) - } -} - -pub(super) fn serialize_config(config: &TurboQuantConfig) -> Vec { - TQScalarFnMetadata::from_config(config).encode_to_vec() -} - -pub(super) fn deserialize_config(metadata: &[u8]) -> VortexResult { - TQScalarFnMetadata::decode(metadata) - .map_err(|e| vortex_err!("Failed to decode TurboQuant scalar function metadata: {e}"))? - .to_config() -} diff --git a/vortex-turboquant/src/scalar_fns/mod.rs b/vortex-turboquant/src/scalar_fns/mod.rs deleted file mode 100644 index 1acea9f70f3..00000000000 --- a/vortex-turboquant/src/scalar_fns/mod.rs +++ /dev/null @@ -1,11 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! Scalar functions for lazy TurboQuant vector encode and decode operations. - -mod decode; -mod encode; -mod metadata; - -pub use decode::TQDecode; -pub use encode::TQEncode; diff --git a/vortex-turboquant/src/sorf/mod.rs b/vortex-turboquant/src/sorf/mod.rs deleted file mode 100644 index cce477aa906..00000000000 --- a/vortex-turboquant/src/sorf/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -mod splitmix64; -mod transform; - -pub(crate) use transform::SorfMatrix; diff --git a/vortex-turboquant/src/sorf/splitmix64.rs b/vortex-turboquant/src/sorf/splitmix64.rs deleted file mode 100644 index fc3f9073ced..00000000000 --- a/vortex-turboquant/src/sorf/splitmix64.rs +++ /dev/null @@ -1,78 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! Frozen local SplitMix64 stream used to define SORF sign diagonals. -//! -//! This is a direct translation of the `splitmix64.c` [reference implementation][impl]. -//! -//! The state is a single `u64`, and `next_u64()` first adds [`SPLITMIX64_INCREMENT`] with wrapping -//! arithmetic, then applies the two reference mixing steps and final xor-shift. -//! -//! [impl]: https://prng.di.unimi.it/splitmix64.c - -/// SplitMix64 additive constant from the reference implementation. -const SPLITMIX64_INCREMENT: u64 = 0x9E37_79B9_7F4A_7C15; - -/// First SplitMix64 mixing multiplier from the reference implementation. -const SPLITMIX64_MUL1: u64 = 0xBF58_476D_1CE4_E5B9; - -/// Second SplitMix64 mixing multiplier from the reference implementation. -const SPLITMIX64_MUL2: u64 = 0x94D0_49BB_1331_11EB; - -/// Frozen local SplitMix64 stream used to define SORF sign diagonals. Bit-identical to the -/// reference implementation linked at the module top, which makes the sign stream part of the -/// encoding's wire contract. -pub(crate) struct SplitMix64 { - state: u64, -} - -impl SplitMix64 { - pub(crate) fn new(seed: u64) -> Self { - Self { state: seed } - } - - pub(crate) fn next_u64(&mut self) -> u64 { - self.state = self.state.wrapping_add(SPLITMIX64_INCREMENT); - let mut z = self.state; - z = (z ^ (z >> 30)).wrapping_mul(SPLITMIX64_MUL1); - z = (z ^ (z >> 27)).wrapping_mul(SPLITMIX64_MUL2); - z ^ (z >> 31) - } -} - -#[cfg(test)] -mod tests { - use super::SplitMix64; - - const SPLITMIX64_SEED0_GOLDEN: [u64; 4] = [ - 0xE220_A839_7B1D_CDAF, - 0x6E78_9E6A_A1B9_65F4, - 0x06C4_5D18_8009_454F, - 0xF88B_B8A8_724C_81EC, - ]; - - const SPLITMIX64_SEED42_GOLDEN: [u64; 4] = [ - 0xBDD7_3226_2FEB_6E95, - 0x28EF_E333_B266_F103, - 0x4752_6757_130F_9F52, - 0x581C_E1FF_0E4A_E394, - ]; - - #[test] - fn splitmix64_seed0_matches_golden_outputs() { - let mut rng = SplitMix64::new(0); - let actual: Vec<_> = (0..SPLITMIX64_SEED0_GOLDEN.len()) - .map(|_| rng.next_u64()) - .collect(); - assert_eq!(actual, SPLITMIX64_SEED0_GOLDEN); - } - - #[test] - fn splitmix64_seed42_matches_golden_outputs() { - let mut rng = SplitMix64::new(42); - let actual: Vec<_> = (0..SPLITMIX64_SEED42_GOLDEN.len()) - .map(|_| rng.next_u64()) - .collect(); - assert_eq!(actual, SPLITMIX64_SEED42_GOLDEN); - } -} diff --git a/vortex-turboquant/src/sorf/transform.rs b/vortex-turboquant/src/sorf/transform.rs deleted file mode 100644 index 3fa221fe03a..00000000000 --- a/vortex-turboquant/src/sorf/transform.rs +++ /dev/null @@ -1,419 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! SORF (Structured Orthogonal Random Features) orthogonal transform. -//! -//! Implements the SORF construction from [Yu et al. 2016][sorf-paper]: a fast structured -//! approximation to a random orthogonal matrix using random sign diagonals interleaved with the -//! Fast Walsh-Hadamard Transform (FWHT). -//! -//! [sorf-paper]: https://proceedings.neurips.cc/paper_files/paper/2016/file/53adaf494dc89ef7196d73636eb2451b-Paper.pdf -//! -//! For `k` rounds, the transform is `norm * H * D_k * ... * H * D_1 * x`, where `D_1` is the -//! first sign diagonal applied. The number of rounds is configurable (typically 3). Each round -//! applies a random sign diagonal `D_i` and then the Hadamard matrix `H`, giving O(d log d) cost -//! per matrix-vector product instead of the O(d^2) cost of a dense orthogonal matrix. -//! -//! This implementation defines those sign diagonals using a frozen local SplitMix64 stream rather -//! than an -//! external RNG crate. The contract is: -//! -//! - state is a single `u64` seed, -//! - each `next_u64()` call uses the SplitMix64 reference algorithm with wrapping `u64` -//! arithmetic, -//! - signs are generated in round-major, block-major order, -//! - each generated `u64` contributes 64 signs in least-significant-bit-first order, -//! - bit `1` means `+1` and bit `0` means `-1`. -//! -//! This makes SORF sign generation stable as an extension format contract even if external RNG -//! implementations change. -//! -//! This transform is the crate's practical structured transform choice for TurboQuant. It is not -//! the dense random Gaussian or orthogonal matrix used by some theoretical analyses, so theoretical -//! bounds from those models need separate validation before being presented as implementation -//! guarantees. -//! -//! The FWHT exploits the Kronecker product structure of the Hadamard matrix (`H_n = H_2 (x) H_2 -//! (x) ... (x) H_2`, with `log2(n)` factors) to compute the matrix-vector product in O(n log n) -//! time using only in-place 2-element butterfly operations. No row of the full n x n Hadamard -//! matrix is ever materialized. -//! -//! For dimensions that are not powers of 2, the input is zero-padded to the next power of 2 before -//! the transform and truncated afterward. -//! -//! # Sign representation -//! -//! Signs are stored internally as `u32` XOR masks: `0x00000000` for +1 (no-op) and `0x80000000` for -//! -1 (flip IEEE 754 sign bit). The sign application function uses integer XOR instead of -//! floating-point multiply, which avoids FP dependency chains and auto-vectorizes into -//! `vpxor`/`veor`. - -use vortex_error::VortexResult; -use vortex_error::vortex_ensure; - -use super::splitmix64::SplitMix64; - -/// IEEE 754 sign bit mask for f32. -const F32_SIGN_BIT: u32 = 0x8000_0000; - -/// A Walsh-Hadamard-based structured orthogonal transform matrix. -/// -/// All computation is done in f32. The sign diagonals are stored as IEEE 754 XOR masks on -/// f32 bit patterns, and the Walsh-Hadamard butterfly operates on `&mut [f32]` slices. -pub(crate) struct SorfMatrix { - /// Flat XOR masks for all `num_rounds` diagonal matrices, total length - /// `num_rounds * padded_dim`. - /// - /// Indexed as `round * padded_dim + i`. `0x00000000` = multiply by +1 (no-op), `0x80000000` = - /// multiply by -1 (flip sign bit). - sign_masks: Vec, - - /// The number of sign-diagonal + WHT rounds. - num_rounds: usize, - - /// The padded dimension (next power of 2 >= dimension). - padded_dim: usize, - - /// Normalization factor: `padded_dim^(-num_rounds/2)`, applied once at the end. - /// - /// This is stored for convenience. - norm_factor: f32, -} - -impl SorfMatrix { - /// Create a new structured Walsh-Hadamard-based orthogonal transform for a padded dimension. - /// - /// `padded_dimensions` must already be a power of two. Callers that start from an unpadded - /// logical dimension are responsible for padding it before constructing the matrix. - pub(crate) fn try_new( - padded_dimensions: usize, - num_rounds: usize, - seed: u64, - ) -> VortexResult { - vortex_ensure!(num_rounds >= 1, "num_rounds must be >= 1, got {num_rounds}"); - vortex_ensure!( - padded_dimensions.is_power_of_two(), - "padded_dimensions must be a power of two, got {padded_dimensions}" - ); - - let padded_dim = padded_dimensions; - let sign_masks = gen_sign_masks_from_seed(seed, padded_dim, num_rounds); - - // Compute in f64 for precision, then store as f32 since the WHT operates on f32 buffers. - // The result is always in (0, 1] for any valid padded_dim >= 2 and num_rounds >= 1, so - // the f64 -> f32 cast is a precision loss only (it cannot overflow to infinity). - #[expect( - clippy::cast_possible_truncation, - reason = "the norm factor is in (0, 1] so the f64 -> f32 cast cannot overflow" - )] - let norm_factor = (padded_dim as f64).powf(-(num_rounds as f64) / 2.0) as f32; - - Ok(Self { - sign_masks, - num_rounds, - padded_dim, - norm_factor, - }) - } - - /// Returns the padded dimension (next power of 2 >= dim). - /// - /// All `transform`/`inverse_transform` buffers must be this length. - pub(crate) fn padded_dim(&self) -> usize { - self.padded_dim - } - - /// Apply the forward orthogonal transform: `output = R(input)`. - /// - /// Both `input` and `output` must have length [`padded_dim()`](Self::padded_dim). The caller is - /// responsible for zero-padding input beyond `dim` positions. - pub(crate) fn transform(&self, input: &[f32], output: &mut [f32]) { - debug_assert_eq!(input.len(), self.padded_dim); - debug_assert_eq!(output.len(), self.padded_dim); - - output.copy_from_slice(input); - self.apply_srht(output); - } - - /// Apply the inverse orthogonal transform: `output = R⁻¹(input)`. - /// - /// Both `input` and `output` must have length `padded_dim()`. - pub(crate) fn inverse_transform(&self, input: &[f32], output: &mut [f32]) { - debug_assert_eq!(input.len(), self.padded_dim); - debug_assert_eq!(output.len(), self.padded_dim); - - output.copy_from_slice(input); - self.apply_inverse_srht(output); - } - - /// Apply the forward structured transform: `norm · H · D_k · ... · H · D₁ · x`. - fn apply_srht(&self, buf: &mut [f32]) { - for round in 0..self.num_rounds { - self.apply_signs_xor(buf, round); - walsh_hadamard_transform(buf); - } - - buf.iter_mut().for_each(|val| *val *= self.norm_factor); - } - - /// Apply the inverse structured transform. - /// - /// Forward is: `norm · H · D_k · ... · H · D₁`. - /// Inverse is: `norm · D₁ · H · ... · D_k · H`. - fn apply_inverse_srht(&self, buf: &mut [f32]) { - for round in (0..self.num_rounds).rev() { - walsh_hadamard_transform(buf); - self.apply_signs_xor(buf, round); - } - - buf.iter_mut().for_each(|val| *val *= self.norm_factor); - } - - /// Apply one round's sign masks via XOR on the IEEE 754 sign bit. - /// - /// This is branchless and auto-vectorizes into `vpxor` (x86) / `veor` (ARM). Equivalent to - /// multiplying each element by +/-1.0, but avoids FP dependency chains. - fn apply_signs_xor(&self, buf: &mut [f32], round: usize) { - let masks = &self.sign_masks[round * self.padded_dim..][..self.padded_dim]; - for (val, &mask) in buf.iter_mut().zip(masks.iter()) { - *val = f32::from_bits(val.to_bits() ^ mask); - } - } -} - -/// Generate XOR sign masks from the frozen local SplitMix64 stream. -/// -/// Signs are produced in round-major, block-major order. For each block we call -/// [`SplitMix64::next_u64`] exactly once and unpack its bits from least significant to most -/// significant. Bit `1` means positive sign / `0x00000000`; bit `0` means negative sign / -/// [`F32_SIGN_BIT`]. -fn gen_sign_masks_from_seed(seed: u64, padded_dim: usize, num_rounds: usize) -> Vec { - let mut rng = SplitMix64::new(seed); - let mut sign_masks = Vec::with_capacity(num_rounds * padded_dim); - - for _round in 0..num_rounds { - for base_idx in (0..padded_dim).step_by(64) { - let word = rng.next_u64(); - let bits_in_block = (padded_dim - base_idx).min(64); - sign_masks.extend((0..bits_in_block).map(|bit_idx| sign_mask_from_word(word, bit_idx))); - } - } - - sign_masks -} - -/// Convert one bit from a SplitMix64 output word into an XOR sign mask. -fn sign_mask_from_word(word: u64, bit_idx: usize) -> u32 { - if ((word >> bit_idx) & 1) != 0 { - 0u32 - } else { - F32_SIGN_BIT - } -} - -/// In-place Fast Walsh-Hadamard Transform (FWHT), unnormalized and iterative. -/// -/// Input length must be a power of 2. Runs in O(n log n) via `log2(n)` stages of `n / 2` -/// [`butterfly`] operations each. See the [module-level docs](self) for why this avoids -/// materializing the full Hadamard matrix. -/// -/// The chunk-based iteration gives LLVM enough structure to auto-vectorize each butterfly call -/// into NEON/AVX SIMD instructions. -fn walsh_hadamard_transform(buf: &mut [f32]) { - let len = buf.len(); - debug_assert!(len.is_power_of_two()); - - let mut half = 1; - while half < len { - let stride = half * 2; - // Process in chunks of `stride` elements. Within each chunk, - // split into non-overlapping (lo, hi) halves for the butterfly. - for chunk in buf.chunks_exact_mut(stride) { - let (lo, hi) = chunk.split_at_mut(half); - butterfly(lo, hi); - } - half *= 2; - } -} - -/// Butterfly: `(lo[i], hi[i]) -> (lo[i] + hi[i], lo[i] - hi[i])`. -/// -/// This is multiplication by the 2x2 Hadamard kernel `H_2 = [[1, 1], [1, -1]]` on each element -/// pair. Factored into a separate function so LLVM can see the slice lengths match and -/// auto-vectorize. -fn butterfly(lo: &mut [f32], hi: &mut [f32]) { - debug_assert_eq!(lo.len(), hi.len()); - for (a, b) in lo.iter_mut().zip(hi.iter_mut()) { - let sum = *a + *b; - let diff = *a - *b; - *a = sum; - *b = diff; - } -} - -#[cfg(test)] -mod tests { - use rstest::rstest; - use vortex_error::VortexResult; - - use super::*; - - fn dim_to_usize(dim: u32) -> usize { - usize::try_from(dim).unwrap() - } - - fn rounds_to_usize(num_rounds: u8) -> usize { - usize::from(num_rounds) - } - - #[test] - fn deterministic_from_seed() -> VortexResult<()> { - let padded_dim = dim_to_usize(64u32); - let num_rounds = rounds_to_usize(3u8); - let seed = 42u64; - let transform1 = SorfMatrix::try_new(padded_dim, num_rounds, seed)?; - let transform2 = SorfMatrix::try_new(padded_dim, num_rounds, seed)?; - let pd = transform1.padded_dim(); - - let mut input = vec![0.0f32; pd]; - for i in 0..padded_dim { - input[i] = i as f32; - } - let mut out1 = vec![0.0f32; pd]; - let mut out2 = vec![0.0f32; pd]; - - transform1.transform(&input, &mut out1); - transform2.transform(&input, &mut out2); - - assert_eq!(out1, out2); - Ok(()) - } - - #[test] - fn one_word_generates_64_signs_lsb_first() { - let seed = 42u64; - let padded_dim = dim_to_usize(64u32); - let num_rounds = rounds_to_usize(1u8); - let masks = gen_sign_masks_from_seed(seed, padded_dim, num_rounds); - assert_eq!(masks.len(), padded_dim); - - let mut rng = SplitMix64::new(seed); - let word = rng.next_u64(); - let expected: Vec<_> = (0..padded_dim) - .map(|bit_idx| sign_mask_from_word(word, bit_idx)) - .collect(); - assert_eq!(masks, expected); - } - - #[test] - fn rejects_non_power_of_two_padded_dimensions() { - assert!(SorfMatrix::try_new(dim_to_usize(100u32), rounds_to_usize(3u8), 42u64).is_err()); - } - - #[test] - fn tail_block_uses_only_required_bits() { - let seed = 42u64; - let padded_dim = dim_to_usize(32u32); - let num_rounds = rounds_to_usize(1u8); - let masks = gen_sign_masks_from_seed(seed, padded_dim, num_rounds); - assert_eq!(masks.len(), padded_dim); - - let mut rng = SplitMix64::new(seed); - let word = rng.next_u64(); - let expected: Vec<_> = (0..padded_dim) - .map(|bit_idx| sign_mask_from_word(word, bit_idx)) - .collect(); - assert_eq!(masks, expected); - } - - /// Verify roundtrip is exact to f32 precision across many dimensions and round counts, - /// including non-power-of-two dimensions that require padding. - #[rstest] - #[case(32u32, 3u8)] - #[case(64u32, 3u8)] - #[case(100u32, 3u8)] - #[case(128u32, 1u8)] - #[case(128u32, 2u8)] - #[case(128u32, 3u8)] - #[case(128u32, 5u8)] - #[case(256u32, 3u8)] - #[case(512u32, 3u8)] - #[case(768u32, 3u8)] - #[case(1024u32, 3u8)] - fn roundtrip_exact(#[case] dim: u32, #[case] num_rounds: u8) -> VortexResult<()> { - let dim = dim_to_usize(dim); - let num_rounds = rounds_to_usize(num_rounds); - let transform = SorfMatrix::try_new(dim.next_power_of_two(), num_rounds, 42u64)?; - let padded_dim = transform.padded_dim(); - - let mut input = vec![0.0f32; padded_dim]; - for i in 0..dim { - input[i] = (i as f32 + 1.0) * 0.01; - } - let mut transformed = vec![0.0f32; padded_dim]; - let mut recovered = vec![0.0f32; padded_dim]; - - transform.transform(&input, &mut transformed); - transform.inverse_transform(&transformed, &mut recovered); - - let max_err: f32 = input - .iter() - .zip(recovered.iter()) - .map(|(a, b)| (a - b).abs()) - .fold(0.0f32, f32::max); - let max_val: f32 = input.iter().map(|x| x.abs()).fold(0.0f32, f32::max); - let rel_err = max_err / max_val; - - // SRHT roundtrip should be exact up to f32 precision (~1e-6). - assert!( - rel_err < 1e-5, - "roundtrip relative error too large for dim={dim}, rounds={num_rounds}: {rel_err:.2e}" - ); - Ok(()) - } - - /// Verify norm preservation across dimensions and round counts. - #[rstest] - #[case(128u32, 1u8)] - #[case(128u32, 3u8)] - #[case(128u32, 5u8)] - #[case(768u32, 3u8)] - fn preserves_norm(#[case] dim: u32, #[case] num_rounds: u8) -> VortexResult<()> { - let dim = dim_to_usize(dim); - let num_rounds = rounds_to_usize(num_rounds); - let transform = SorfMatrix::try_new(dim.next_power_of_two(), num_rounds, 7u64)?; - let padded_dim = transform.padded_dim(); - - let mut input = vec![0.0f32; padded_dim]; - for i in 0..dim { - input[i] = (i as f32) * 0.01; - } - let input_norm: f32 = input.iter().map(|x| x * x).sum::().sqrt(); - - let mut transformed = vec![0.0f32; padded_dim]; - transform.transform(&input, &mut transformed); - let transformed_norm: f32 = transformed.iter().map(|x| x * x).sum::().sqrt(); - - assert!( - (input_norm - transformed_norm).abs() / input_norm < 1e-5, - "norm not preserved for dim={dim}: {} vs {} (rel err: {:.2e})", - input_norm, - transformed_norm, - (input_norm - transformed_norm).abs() / input_norm - ); - Ok(()) - } - - #[test] - fn wht_basic() { - // WHT of [1, 0, 0, 0] should be [1, 1, 1, 1] - let mut buf = vec![1.0f32, 0.0, 0.0, 0.0]; - walsh_hadamard_transform(&mut buf); - assert_eq!(buf, vec![1.0, 1.0, 1.0, 1.0]); - - // WHT is self-inverse (up to scaling by n) - walsh_hadamard_transform(&mut buf); - // After two WHTs: each element multiplied by n=4 - assert_eq!(buf, vec![4.0, 0.0, 0.0, 0.0]); - } -} diff --git a/vortex-turboquant/src/tests/encode_decode.rs b/vortex-turboquant/src/tests/encode_decode.rs deleted file mode 100644 index ed5aab190aa..00000000000 --- a/vortex-turboquant/src/tests/encode_decode.rs +++ /dev/null @@ -1,254 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use rstest::rstest; -use vortex_array::IntoArray; -use vortex_array::VortexSessionExecute; -use vortex_array::arrays::ExtensionArray; -use vortex_array::arrays::FixedSizeListArray; -use vortex_array::arrays::PrimitiveArray; -use vortex_array::arrays::extension::ExtensionArrayExt; -use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; -use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; -use vortex_array::arrays::struct_::StructArrayExt; -use vortex_array::dtype::Nullability; -use vortex_array::dtype::PType; -use vortex_array::validity::Validity; -use vortex_buffer::Buffer; -use vortex_error::VortexResult; - -use super::execute_tq_decode; -use super::execute_tq_encode; -use super::f32_vector_array; -use super::test_session; -use super::turboquant_storage; -use super::vector_array; -use super::vector_element_ptype; -use super::vector_validity; -use super::vector_values_f32; -use crate::TurboQuantConfig; -use crate::centroids::compute_or_get_centroids; -use crate::vector::normalize::tq_normalize_as_l2_denorm; - -#[rstest] -#[case::zero_bits(0, 42, 3)] -#[case::too_many_bits(9, 42, 3)] -#[case::zero_rounds(2, 42, 0)] -fn config_rejects_invalid_values(#[case] bit_width: u8, #[case] seed: u64, #[case] num_rounds: u8) { - assert!(TurboQuantConfig::try_new(bit_width, seed, num_rounds).is_err()); -} - -#[test] -fn encode_rejects_non_vector_input() { - let session = test_session(); - let mut ctx = session.create_execution_ctx(); - let input = PrimitiveArray::new::(Buffer::copy_from([1.0, 2.0]), Validity::NonNullable) - .into_array(); - - assert!(execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx).is_err()); -} - -#[test] -fn encode_rejects_small_dimensions() -> VortexResult<()> { - let session = test_session(); - let mut ctx = session.create_execution_ctx(); - let input = f32_vector_array(127, 1, 1.0, Validity::NonNullable)?; - - assert!(execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx).is_err()); - Ok(()) -} - -#[test] -fn encode_rejects_padded_dimension_overflow() -> VortexResult<()> { - let session = test_session(); - let mut ctx = session.create_execution_ctx(); - let input = vector_array::(2_147_483_649, &[], Validity::NonNullable)?; - - assert!(execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx).is_err()); - Ok(()) -} - -#[test] -fn centroid_cache_is_deterministic() -> VortexResult<()> { - let first = compute_or_get_centroids(128, 3)?; - let second = compute_or_get_centroids(128, 3)?; - - assert_eq!(first.as_slice(), second.as_slice()); - Ok(()) -} - -#[test] -fn encode_decode_empty_vectors() -> VortexResult<()> { - let session = test_session(); - let mut ctx = session.create_execution_ctx(); - let input = vector_array::(128, &[], Validity::NonNullable)?; - - let encoded = execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx)?; - let decoded = execute_tq_decode(encoded, &mut ctx)?; - - assert!(decoded.is_empty()); - Ok(()) -} - -#[test] -fn encode_stores_norms_and_struct_validity() -> VortexResult<()> { - let session = test_session(); - let mut ctx = session.create_execution_ctx(); - let validity = Validity::from_iter([true, false, true]); - let input = f32_vector_array(128, 3, 0.25, validity)?; - - let config = TurboQuantConfig::try_new(3, 1, 2)?; - let encoded = execute_tq_encode(input, &config, &mut ctx)?; - let storage = turboquant_storage(encoded, &mut ctx)?; - let mask = storage.struct_validity().execute_mask(3, &mut ctx)?; - let norms: PrimitiveArray = storage - .unmasked_field_by_name("norms")? - .clone() - .execute(&mut ctx)?; - let codes: FixedSizeListArray = storage - .unmasked_field_by_name("codes")? - .clone() - .execute(&mut ctx)?; - - assert!(mask.value(0)); - assert!(!mask.value(1)); - assert!(mask.value(2)); - assert_eq!(norms.validity()?.nullability(), Nullability::Nullable); - assert_eq!(codes.validity()?.nullability(), Nullability::Nullable); - - let norms_validity = norms.validity()?.execute_mask(3, &mut ctx)?; - let codes_validity = codes.validity()?.execute_mask(3, &mut ctx)?; - assert!(norms_validity.value(0)); - assert!(!norms_validity.value(1)); - assert!(norms_validity.value(2)); - assert!(codes_validity.value(0)); - assert!(!codes_validity.value(1)); - assert!(codes_validity.value(2)); - - let codes_values: PrimitiveArray = codes.elements().clone().execute(&mut ctx)?; - assert!( - codes_values.as_slice::()[128..256] - .iter() - .all(|&code| code == 0) - ); - Ok(()) -} - -#[test] -fn normalize_as_l2_denorm_preserves_child_validity() -> VortexResult<()> { - let session = test_session(); - let mut ctx = session.create_execution_ctx(); - let mut values = vec![0.0f32; 3 * 128]; - values[0] = 3.0; - values[1] = 4.0; - values[128..256].fill(13.0); - values[256] = 1.0; - let input = vector_array(128, &values, Validity::from_iter([true, false, true]))?; - - let l2_denorm = tq_normalize_as_l2_denorm(input, &mut ctx)?; - let normalized = l2_denorm.child_at(0).clone(); - let norms = l2_denorm.child_at(1).clone(); - - let normalized_ext: ExtensionArray = normalized.execute(&mut ctx)?; - let normalized_fsl: FixedSizeListArray = - normalized_ext.storage_array().clone().execute(&mut ctx)?; - let normalized_values: PrimitiveArray = normalized_fsl.elements().clone().execute(&mut ctx)?; - let norms: PrimitiveArray = norms.execute(&mut ctx)?; - let normalized_validity = normalized_fsl.validity()?.execute_mask(3, &mut ctx)?; - let norms_validity = norms.validity()?.execute_mask(3, &mut ctx)?; - - assert!(normalized_validity.value(0)); - assert!(!normalized_validity.value(1)); - assert!(normalized_validity.value(2)); - assert!(norms_validity.value(0)); - assert!(!norms_validity.value(1)); - assert!(norms_validity.value(2)); - assert_eq!(norms.validity()?.nullability(), Nullability::Nullable); - assert_eq!(norms.as_slice::()[0], 5.0); - assert!( - normalized_values.as_slice::()[128..256] - .iter() - .all(|&value| value == 0.0) - ); - Ok(()) -} - -#[test] -fn encode_decode_preserves_nulls_and_zero_norm_rows() -> VortexResult<()> { - let session = test_session(); - let mut ctx = session.create_execution_ctx(); - let mut values = vec![0.0f32; 3 * 128]; - values[0] = 3.0; - values[1] = 4.0; - values[256] = 1.0; - values[257] = -1.0; - let input = vector_array(128, &values, Validity::from_iter([true, true, false]))?; - - let encoded = execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx)?; - let decoded = execute_tq_decode(encoded, &mut ctx)?; - let output = vector_values_f32(decoded.clone(), &mut ctx)?; - let validity = vector_validity(decoded, &mut ctx)?.execute_mask(3, &mut ctx)?; - - assert!(validity.value(0)); - assert!(validity.value(1)); - assert!(!validity.value(2)); - assert!(output[128..256].iter().all(|&v| v == 0.0)); - Ok(()) -} - -#[rstest] -#[case::f16(PType::F16)] -#[case::f64(PType::F64)] -fn encode_decode_supports_non_f32_inputs(#[case] ptype: PType) -> VortexResult<()> { - let session = test_session(); - let mut ctx = session.create_execution_ctx(); - let config = TurboQuantConfig::try_new(3, 42, 3)?; - - match ptype { - PType::F16 => { - let values = (0..2 * 128) - .map(|i| half::f16::from_f32(((i % 17) as f32 - 8.0) * 0.25)) - .collect::>(); - let input = vector_array(128, &values, Validity::NonNullable)?; - let encoded = execute_tq_encode(input, &config, &mut ctx)?; - let decoded = execute_tq_decode(encoded, &mut ctx)?; - let ext: ExtensionArray = decoded.execute(&mut ctx)?; - assert_eq!(vector_element_ptype(&ext)?, PType::F16); - } - PType::F64 => { - let values = (0..2 * 128) - .map(|i| ((i % 17) as f64 - 8.0) * 0.25) - .collect::>(); - let input = vector_array(128, &values, Validity::NonNullable)?; - let encoded = execute_tq_encode(input, &config, &mut ctx)?; - let decoded = execute_tq_decode(encoded, &mut ctx)?; - let ext: ExtensionArray = decoded.execute(&mut ctx)?; - assert_eq!(vector_element_ptype(&ext)?, PType::F64); - } - _ => unreachable!("test only passes f16/f64"), - } - Ok(()) -} - -#[test] -fn decode_scales_by_stored_norms() -> VortexResult<()> { - let session = test_session(); - let mut ctx = session.create_execution_ctx(); - let base = f32_vector_array(128, 1, 0.5, Validity::NonNullable)?; - let scaled = f32_vector_array(128, 1, 1.0, Validity::NonNullable)?; - let config = TurboQuantConfig::try_new(2, 99, 3)?; - - let base_values = vector_values_f32( - execute_tq_decode(execute_tq_encode(base, &config, &mut ctx)?, &mut ctx)?, - &mut ctx, - )?; - let scaled_values = vector_values_f32( - execute_tq_decode(execute_tq_encode(scaled, &config, &mut ctx)?, &mut ctx)?, - &mut ctx, - )?; - - for (base, scaled) in base_values.iter().zip(scaled_values.iter()) { - assert!((*scaled - 2.0 * *base).abs() < 1e-5); - } - Ok(()) -} diff --git a/vortex-turboquant/src/tests/file.rs b/vortex-turboquant/src/tests/file.rs deleted file mode 100644 index e59b7a95c75..00000000000 --- a/vortex-turboquant/src/tests/file.rs +++ /dev/null @@ -1,73 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use vortex_array::IntoArray; -use vortex_array::VortexSessionExecute; -use vortex_array::stream::ArrayStreamExt; -use vortex_array::validity::Validity; -use vortex_error::VortexResult; -use vortex_file::OpenOptionsSessionExt; -use vortex_file::VortexWriteOptions; -use vortex_io::runtime::BlockingRuntime; -use vortex_io::runtime::single::SingleThreadRuntime; -use vortex_tensor::vector::Vector; - -use super::execute_tq_decode_from_metadata; -use super::execute_tq_encode; -use super::f32_vector_array; -use super::file_session; -use super::vector_validity; -use crate::TQDecode; -use crate::TurboQuantConfig; -use crate::vtable::tq_metadata; - -#[test] -fn file_roundtrip_with_initialize_session() -> VortexResult<()> { - let runtime = SingleThreadRuntime::default(); - let session = file_session(&runtime); - let mut ctx = session.create_execution_ctx(); - let input = f32_vector_array(128, 2, 0.25, Validity::from_iter([true, false]))?; - let encoded = execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx)?; - - let mut file_bytes = Vec::new(); - VortexWriteOptions::new(session.clone()) - .blocking(&runtime) - .write(&mut file_bytes, encoded.to_array_iterator())?; - - let file = session.open_options().open_buffer(file_bytes)?; - let read = runtime.block_on(async { file.scan()?.into_array_stream()?.read_all().await })?; - - let metadata = tq_metadata(read.dtype())?; - assert_eq!(metadata.dimensions, 128); - let decoded = execute_tq_decode_from_metadata(read, &mut ctx)?; - let validity = vector_validity(decoded, &mut ctx)?.execute_mask(2, &mut ctx)?; - assert!(validity.value(0)); - assert!(!validity.value(1)); - Ok(()) -} - -#[test] -fn file_roundtrip_lazy_decode_scalar_fn_with_initialize_session() -> VortexResult<()> { - let runtime = SingleThreadRuntime::default(); - let session = file_session(&runtime); - let mut ctx = session.create_execution_ctx(); - let input = f32_vector_array(128, 2, 0.25, Validity::from_iter([true, false]))?; - let config = TurboQuantConfig::try_new(3, 42, 3)?; - let encoded = execute_tq_encode(input, &config, &mut ctx)?; - let decoded = TQDecode::try_new_array(encoded)?.into_array(); - - let mut file_bytes = Vec::new(); - VortexWriteOptions::new(session.clone()) - .blocking(&runtime) - .write(&mut file_bytes, decoded.to_array_iterator())?; - - let file = session.open_options().open_buffer(file_bytes)?; - let read = runtime.block_on(async { file.scan()?.into_array_stream()?.read_all().await })?; - - assert!(read.dtype().as_extension().is::()); - - let validity = vector_validity(read, &mut ctx)?.execute_mask(2, &mut ctx)?; - assert!(validity.value(0)); - assert!(!validity.value(1)); - Ok(()) -} diff --git a/vortex-turboquant/src/tests/malformed.rs b/vortex-turboquant/src/tests/malformed.rs deleted file mode 100644 index f99f0ee5105..00000000000 --- a/vortex-turboquant/src/tests/malformed.rs +++ /dev/null @@ -1,189 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use rstest::rstest; -use vortex_array::IntoArray; -use vortex_array::VortexSessionExecute; -use vortex_array::arrays::ExtensionArray; -use vortex_array::arrays::FixedSizeListArray; -use vortex_array::arrays::PrimitiveArray; -use vortex_array::arrays::StructArray; -use vortex_array::dtype::FieldNames; -use vortex_array::dtype::Nullability; -use vortex_array::dtype::PType; -use vortex_array::validity::Validity; -use vortex_buffer::Buffer; -use vortex_error::VortexResult; - -use super::execute_tq_decode_from_metadata; -use super::test_session; -use super::vector_validity; -use crate::TurboQuant; -use crate::TurboQuantMetadata; - -#[rstest] -#[case::nullable_norms_under_nonnullable_struct( - Nullability::NonNullable, - Nullability::Nullable, - Nullability::NonNullable -)] -#[case::nullable_codes_under_nonnullable_struct( - Nullability::NonNullable, - Nullability::NonNullable, - Nullability::Nullable -)] -#[case::nonnullable_norms_under_nullable_struct( - Nullability::Nullable, - Nullability::NonNullable, - Nullability::Nullable -)] -#[case::nonnullable_codes_under_nullable_struct( - Nullability::Nullable, - Nullability::Nullable, - Nullability::NonNullable -)] -fn decode_accepts_child_nullability_that_covers_struct_validity( - #[case] struct_nullability: Nullability, - #[case] norms_nullability: Nullability, - #[case] codes_nullability: Nullability, -) -> VortexResult<()> { - let session = test_session(); - let mut ctx = session.create_execution_ctx(); - let metadata = TurboQuantMetadata { - element_ptype: PType::F32, - dimensions: 128, - bit_width: 1, - seed: 42, - num_rounds: 3, - }; - let norms = - PrimitiveArray::new::(Buffer::copy_from([1.0]), Validity::from(norms_nullability)) - .into_array(); - let codes = PrimitiveArray::new::(vec![0u8; 128], Validity::NonNullable); - let codes = FixedSizeListArray::try_new( - codes.into_array(), - 128, - Validity::from(codes_nullability), - 1, - ) - .unwrap() - .into_array(); - let storage = StructArray::try_new( - FieldNames::from(["norms", "codes"]), - vec![norms, codes], - 1, - Validity::from(struct_nullability), - ) - .unwrap(); - let tq = ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage.into_array()) - .unwrap() - .into_array(); - - execute_tq_decode_from_metadata(tq, &mut ctx)?; - Ok(()) -} - -#[test] -fn decode_accepts_struct_mask_with_all_valid_children() -> VortexResult<()> { - let session = test_session(); - let mut ctx = session.create_execution_ctx(); - let metadata = TurboQuantMetadata { - element_ptype: PType::F32, - dimensions: 128, - bit_width: 1, - seed: 42, - num_rounds: 3, - }; - let norms = - PrimitiveArray::new::(Buffer::copy_from([1.0, 1.0, 1.0]), Validity::NonNullable) - .into_array(); - let codes = PrimitiveArray::new::(vec![0u8; 3 * 128], Validity::NonNullable); - let codes = FixedSizeListArray::try_new(codes.into_array(), 128, Validity::NonNullable, 3)? - .into_array(); - let storage = StructArray::try_new( - FieldNames::from(["norms", "codes"]), - vec![norms, codes], - 3, - Validity::from_iter([true, false, true]), - )?; - let tq = ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage.into_array())? - .into_array(); - - let decoded = execute_tq_decode_from_metadata(tq, &mut ctx)?; - let validity = vector_validity(decoded, &mut ctx)?.execute_mask(3, &mut ctx)?; - assert!(validity.value(0)); - assert!(!validity.value(1)); - assert!(validity.value(2)); - Ok(()) -} - -#[test] -fn decode_rejects_child_masks_that_disagree_with_struct_validity() -> VortexResult<()> { - let session = test_session(); - let mut ctx = session.create_execution_ctx(); - let metadata = TurboQuantMetadata { - element_ptype: PType::F32, - dimensions: 128, - bit_width: 1, - seed: 42, - num_rounds: 3, - }; - let norms = PrimitiveArray::new::( - Buffer::copy_from([1.0, 1.0, 1.0]), - Validity::from_iter([true, true, false]), - ) - .into_array(); - let codes = PrimitiveArray::new::(vec![0u8; 3 * 128], Validity::NonNullable); - let codes = FixedSizeListArray::try_new( - codes.into_array(), - 128, - Validity::from_iter([true, false, true]), - 3, - )? - .into_array(); - let storage = StructArray::try_new( - FieldNames::from(["norms", "codes"]), - vec![norms, codes], - 3, - Validity::from_iter([true, false, true]), - )?; - let tq = ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage.into_array())? - .into_array(); - - assert!(execute_tq_decode_from_metadata(tq, &mut ctx).is_err()); - Ok(()) -} - -#[test] -#[should_panic(expected = "TurboQuant code exceeds centroid count")] -fn decode_panics_on_codes_outside_centroid_table() { - let session = test_session(); - let mut ctx = session.create_execution_ctx(); - let metadata = TurboQuantMetadata { - element_ptype: PType::F32, - dimensions: 128, - bit_width: 1, - seed: 42, - num_rounds: 3, - }; - let norms = - PrimitiveArray::new::(Buffer::copy_from([1.0]), Validity::NonNullable).into_array(); - let mut codes = vec![0u8; 128]; - codes[0] = 2; - let codes = PrimitiveArray::new::(codes, Validity::NonNullable); - let codes = FixedSizeListArray::try_new(codes.into_array(), 128, Validity::NonNullable, 1) - .unwrap() - .into_array(); - let storage = StructArray::try_new( - FieldNames::from(["norms", "codes"]), - vec![norms, codes], - 1, - Validity::NonNullable, - ) - .unwrap(); - let tq = ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage.into_array()) - .unwrap() - .into_array(); - - drop(execute_tq_decode_from_metadata(tq, &mut ctx)); -} diff --git a/vortex-turboquant/src/tests/metadata.rs b/vortex-turboquant/src/tests/metadata.rs deleted file mode 100644 index e0d1042f02f..00000000000 --- a/vortex-turboquant/src/tests/metadata.rs +++ /dev/null @@ -1,173 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use std::sync::Arc; - -use prost::Message; -use rstest::rstest; -use vortex_array::dtype::DType; -use vortex_array::dtype::FieldNames; -use vortex_array::dtype::Nullability; -use vortex_array::dtype::PType; -use vortex_array::dtype::StructFields; -use vortex_array::dtype::extension::ExtDType; -use vortex_array::dtype::extension::ExtVTable; -use vortex_error::VortexResult; -use vortex_error::vortex_err; - -use crate::TurboQuant; -use crate::TurboQuantMetadata; -use crate::vector::storage::CODES_FIELD; -use crate::vector::storage::NORMS_FIELD; -use crate::vector::tq_padded_dim; - -#[derive(Clone, PartialEq, Message)] -struct MetadataWire { - #[prost(enumeration = "PType", tag = "1")] - element_ptype: i32, - #[prost(uint32, tag = "2")] - dimensions: u32, - #[prost(uint32, tag = "3")] - bit_width: u32, - #[prost(uint64, tag = "4")] - seed: u64, - #[prost(uint32, tag = "5")] - num_rounds: u32, -} - -fn tq_storage_dtype( - metadata: &TurboQuantMetadata, - row_nullability: Nullability, -) -> VortexResult { - let padded_dim = u32::try_from(tq_padded_dim(metadata.dimensions)?) - .map_err(|_| vortex_err!("TurboQuant padded dimension does not fit u32"))?; - Ok(DType::Struct( - StructFields::new( - FieldNames::from([NORMS_FIELD, CODES_FIELD]), - vec![ - DType::Primitive(metadata.element_ptype, row_nullability), - DType::FixedSizeList( - Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)), - padded_dim, - row_nullability, - ), - ], - ), - row_nullability, - )) -} - -#[rstest] -#[case::f16(PType::F16)] -#[case::f32(PType::F32)] -#[case::f64(PType::F64)] -fn metadata_serialization_roundtrips(#[case] element_ptype: PType) -> VortexResult<()> { - let metadata = TurboQuantMetadata { - element_ptype, - dimensions: 128, - bit_width: 4, - seed: 7, - num_rounds: 3, - }; - - let encoded = TurboQuant.serialize_metadata(&metadata)?; - let decoded = TurboQuant.deserialize_metadata(&encoded)?; - - assert_eq!(decoded, metadata); - Ok(()) -} - -#[test] -fn metadata_serialization_uses_ptype_discriminants() -> VortexResult<()> { - let metadata = TurboQuantMetadata { - element_ptype: PType::F32, - dimensions: 128, - bit_width: 4, - seed: 7, - num_rounds: 3, - }; - - let encoded = TurboQuant.serialize_metadata(&metadata)?; - let wire = MetadataWire::decode(encoded.as_slice())?; - - assert_eq!(wire.element_ptype, PType::F32 as i32); - assert_eq!(wire.dimensions, 128); - Ok(()) -} - -#[test] -fn metadata_display_matches_field_order() { - let metadata = TurboQuantMetadata { - element_ptype: PType::F32, - dimensions: 128, - bit_width: 4, - seed: 7, - num_rounds: 3, - }; - - assert_eq!( - metadata.to_string(), - "element_ptype: f32, dimensions: 128, bit_width: 4, seed: 7, num_rounds: 3" - ); -} - -#[test] -fn dtype_validation_accepts_expected_storage() -> VortexResult<()> { - let metadata = TurboQuantMetadata { - element_ptype: PType::F32, - dimensions: 129, - bit_width: 2, - seed: 42, - num_rounds: 3, - }; - - ExtDType::::try_new( - metadata, - tq_storage_dtype(&metadata, Nullability::Nullable)?, - )?; - Ok(()) -} - -#[test] -fn dtype_validation_accepts_nonnullable_storage() -> VortexResult<()> { - let metadata = TurboQuantMetadata { - element_ptype: PType::F32, - dimensions: 129, - bit_width: 2, - seed: 42, - num_rounds: 3, - }; - - ExtDType::::try_new( - metadata, - tq_storage_dtype(&metadata, Nullability::NonNullable)?, - )?; - Ok(()) -} - -#[test] -fn dtype_validation_rejects_malformed_storage() { - let metadata = TurboQuantMetadata { - element_ptype: PType::F32, - dimensions: 128, - bit_width: 2, - seed: 42, - num_rounds: 3, - }; - let storage = DType::Struct( - StructFields::new( - FieldNames::from(["norms", "codes"]), - vec![ - DType::Primitive(PType::F32, Nullability::Nullable), - DType::FixedSizeList( - DType::Primitive(PType::U8, Nullability::Nullable).into(), - 128, - Nullability::NonNullable, - ), - ], - ), - Nullability::Nullable, - ); - - assert!(ExtDType::::try_new(metadata, storage).is_err()); -} diff --git a/vortex-turboquant/src/tests/mod.rs b/vortex-turboquant/src/tests/mod.rs deleted file mode 100644 index 3163706a002..00000000000 --- a/vortex-turboquant/src/tests/mod.rs +++ /dev/null @@ -1,141 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -#![cfg_attr( - test, - allow(clippy::unwrap_used, clippy::expect_used, clippy::unwrap_in_result) -)] - -use vortex_array::ArrayRef; -use vortex_array::EmptyMetadata; -use vortex_array::ExecutionCtx; -use vortex_array::IntoArray; -use vortex_array::arrays::ExtensionArray; -use vortex_array::arrays::FixedSizeListArray; -use vortex_array::arrays::PrimitiveArray; -use vortex_array::arrays::StructArray; -use vortex_array::arrays::extension::ExtensionArrayExt; -use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; -use vortex_array::dtype::NativePType; -use vortex_array::dtype::PType; -use vortex_array::memory::MemorySession; -use vortex_array::session::ArraySession; -use vortex_array::validity::Validity; -use vortex_buffer::Buffer; -use vortex_error::VortexResult; -use vortex_error::vortex_err; -use vortex_io::runtime::BlockingRuntime; -use vortex_io::runtime::single::SingleThreadRuntime; -use vortex_io::session::RuntimeSession; -use vortex_io::session::RuntimeSessionExt; -use vortex_layout::session::LayoutSession; -use vortex_session::VortexSession; -use vortex_tensor::vector::Vector; - -use crate::TQDecode; -use crate::TQEncode; -use crate::TurboQuantConfig; -use crate::initialize; - -mod encode_decode; -mod file; -mod malformed; -mod metadata; -mod parity; -mod scalar_fns; - -fn test_session() -> VortexSession { - let session = VortexSession::empty().with::(); - initialize(&session); - session -} - -fn file_session(runtime: &SingleThreadRuntime) -> VortexSession { - let session = VortexSession::empty() - .with::() - .with::() - .with::() - .with::() - .with_handle(runtime.handle()); - vortex_file::register_default_encodings(&session); - vortex_tensor::initialize(&session); - initialize(&session); - session -} - -fn vector_array( - dimensions: u32, - values: &[T], - validity: Validity, -) -> VortexResult { - assert!(dimensions > 0, "dimensions must be > 0"); - let row_count = values.len() / dimensions as usize; - - let elements = PrimitiveArray::new::( - values.iter().copied().collect::>(), - Validity::NonNullable, - ); - let fsl = FixedSizeListArray::try_new(elements.into_array(), dimensions, validity, row_count)?; - - Ok(ExtensionArray::try_new_from_vtable(Vector, EmptyMetadata, fsl.into_array())?.into_array()) -} - -fn f32_vector_array( - dimensions: u32, - rows: usize, - scale: f32, - validity: Validity, -) -> VortexResult { - let values = (0..rows * dimensions as usize) - .map(|i| ((i % 17) as f32 - 8.0) * scale) - .collect::>(); - vector_array(dimensions, &values, validity) -} - -fn vector_values_f32(array: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult> { - let ext: ExtensionArray = array.execute(ctx)?; - let fsl: FixedSizeListArray = ext.storage_array().clone().execute(ctx)?; - let elements: PrimitiveArray = fsl.elements().clone().execute(ctx)?; - Ok(elements.as_slice::().to_vec()) -} - -fn vector_validity(array: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { - let ext: ExtensionArray = array.execute(ctx)?; - let fsl: FixedSizeListArray = ext.storage_array().clone().execute(ctx)?; - fsl.validity() -} - -fn vector_element_ptype(array: &ExtensionArray) -> VortexResult { - Ok(array - .storage_array() - .dtype() - .as_fixed_size_list_element_opt() - .ok_or_else(|| vortex_err!("expected FixedSizeList vector storage"))? - .as_ptype()) -} - -fn turboquant_storage(array: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { - let ext: ExtensionArray = array.execute(ctx)?; - ext.storage_array().clone().execute(ctx) -} - -fn execute_tq_encode( - input: ArrayRef, - config: &TurboQuantConfig, - ctx: &mut ExecutionCtx, -) -> VortexResult { - TQEncode::try_new_array(input, config)? - .into_array() - .execute(ctx) -} - -fn execute_tq_decode(input: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { - TQDecode::try_new_array(input)?.into_array().execute(ctx) -} - -fn execute_tq_decode_from_metadata( - input: ArrayRef, - ctx: &mut ExecutionCtx, -) -> VortexResult { - execute_tq_decode(input, ctx) -} diff --git a/vortex-turboquant/src/tests/parity.rs b/vortex-turboquant/src/tests/parity.rs deleted file mode 100644 index 4360d90849d..00000000000 --- a/vortex-turboquant/src/tests/parity.rs +++ /dev/null @@ -1,38 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use vortex_array::VortexSessionExecute; -use vortex_array::validity::Validity; -use vortex_error::VortexResult; -use vortex_tensor::encodings::turboquant::TurboQuantConfig as OldTurboQuantConfig; -use vortex_tensor::encodings::turboquant::turboquant_encode; - -use super::execute_tq_decode; -use super::execute_tq_encode; -use super::f32_vector_array; -use super::test_session; -use super::vector_values_f32; -use crate::TurboQuantConfig; - -#[test] -fn encode_decode_matches_old_turboquant_decode() -> VortexResult<()> { - let session = test_session(); - let mut ctx = session.create_execution_ctx(); - let input = f32_vector_array(128, 2, 0.125, Validity::NonNullable)?; - let config = TurboQuantConfig::try_new(3, 42, 3)?; - - let new_encoded = execute_tq_encode(input.clone(), &config, &mut ctx)?; - let new_decoded = execute_tq_decode(new_encoded, &mut ctx)?; - let old_config = OldTurboQuantConfig { - bit_width: config.bit_width(), - seed: config.seed(), - num_rounds: config.num_rounds(), - }; - let old_decoded = turboquant_encode(input, &old_config, &mut ctx)?.execute(&mut ctx)?; - - let new_values = vector_values_f32(new_decoded, &mut ctx)?; - let old_values = vector_values_f32(old_decoded, &mut ctx)?; - - assert_eq!(new_values, old_values); - Ok(()) -} diff --git a/vortex-turboquant/src/tests/scalar_fns.rs b/vortex-turboquant/src/tests/scalar_fns.rs deleted file mode 100644 index 31406686ee0..00000000000 --- a/vortex-turboquant/src/tests/scalar_fns.rs +++ /dev/null @@ -1,61 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use vortex_array::EmptyMetadata; -use vortex_array::IntoArray; -use vortex_array::VortexSessionExecute; -use vortex_array::scalar_fn::ScalarFnVTable; -use vortex_array::validity::Validity; -use vortex_error::VortexResult; - -use super::f32_vector_array; -use super::test_session; -use super::vector_validity; -use crate::TQDecode; -use crate::TQEncode; -use crate::TurboQuant; -use crate::TurboQuantConfig; -use crate::vtable::tq_metadata; - -#[test] -fn scalar_fn_ids_and_options_roundtrip() -> VortexResult<()> { - let session = test_session(); - let config = TurboQuantConfig::try_new(4, 7, 2)?; - - assert_eq!(TQEncode.id().as_ref(), "vortex.turboquant.encode"); - assert_eq!(TQDecode.id().as_ref(), "vortex.turboquant.decode"); - - let encode_metadata = TQEncode.serialize(&config)?.unwrap(); - let decode_metadata = TQDecode.serialize(&EmptyMetadata)?.unwrap(); - - assert_eq!(TQEncode.deserialize(&encode_metadata, &session)?, config); - assert!(decode_metadata.is_empty()); - assert_eq!( - TQDecode.deserialize(&decode_metadata, &session)?, - EmptyMetadata - ); - Ok(()) -} - -#[test] -fn scalar_fn_arrays_encode_and_decode_vectors() -> VortexResult<()> { - let session = test_session(); - let mut ctx = session.create_execution_ctx(); - let input = f32_vector_array(128, 2, 0.25, Validity::from_iter([true, false]))?; - let config = TurboQuantConfig::try_new(3, 42, 3)?; - - let encoded_lazy = TQEncode::try_new_array(input, &config)?; - let encoded_metadata = tq_metadata(encoded_lazy.dtype())?; - assert_eq!(encoded_metadata.dimensions, 128); - assert_eq!(encoded_metadata.bit_width, config.bit_width()); - assert!(encoded_lazy.dtype().as_extension().is::()); - - let encoded = encoded_lazy.into_array().execute(&mut ctx)?; - let decoded_lazy = TQDecode::try_new_array(encoded)?; - let decoded = decoded_lazy.into_array().execute(&mut ctx)?; - let validity = vector_validity(decoded, &mut ctx)?.execute_mask(2, &mut ctx)?; - - assert!(validity.value(0)); - assert!(!validity.value(1)); - Ok(()) -} diff --git a/vortex-turboquant/src/vector/mod.rs b/vortex-turboquant/src/vector/mod.rs deleted file mode 100644 index f4fe8726103..00000000000 --- a/vortex-turboquant/src/vector/mod.rs +++ /dev/null @@ -1,27 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! Vector-side helpers: normalization, quantization, and physical storage layout. - -pub(crate) mod normalize; -pub(crate) mod quantize; -pub(crate) mod storage; - -use vortex_error::VortexResult; -use vortex_error::vortex_err; - -/// Compute the padded SORF dimension for an original vector dimension. -/// -/// The SORF transform requires a power-of-two width, so non-power-of-two input dimensions are -/// padded with zeros up to the next power of two. The padded dimension is stored implicitly via -/// [`TurboQuantMetadata::dimensions`](crate::TurboQuantMetadata) plus the codes child's -/// `FixedSizeList` width and recovered at decode time via this function. Returns an error when -/// the next power of two overflows the input integer type. -pub(crate) fn tq_padded_dim(dimensions: u32) -> VortexResult { - let padded_dim = dimensions - .checked_next_power_of_two() - .ok_or_else(|| vortex_err!("TurboQuant padded dimension overflow for {dimensions}"))?; - - usize::try_from(padded_dim) - .map_err(|_| vortex_err!("TurboQuant padded dimension does not fit usize")) -} diff --git a/vortex-turboquant/src/vector/normalize.rs b/vortex-turboquant/src/vector/normalize.rs deleted file mode 100644 index c0a5c9f6f06..00000000000 --- a/vortex-turboquant/src/vector/normalize.rs +++ /dev/null @@ -1,236 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! TurboQuant-local vector normalization. - -// TODO(connor): Remove this comment once we delete the other version in `vortex-tensor`. -// The tensor crate also has a `normalize_as_l2_denorm` helper, but TurboQuant needs different -// validity semantics: a null vector is not a zero vector, so invalid rows keep their row validity -// on both `L2Denorm` children and downstream quantization skips them. - -use num_traits::Float; -use vortex_array::ArrayRef; -use vortex_array::EmptyMetadata; -use vortex_array::ExecutionCtx; -use vortex_array::IntoArray; -use vortex_array::arrays::ExtensionArray; -use vortex_array::arrays::FixedSizeListArray; -use vortex_array::arrays::PrimitiveArray; -use vortex_array::arrays::ScalarFnArray; -use vortex_array::arrays::extension::ExtensionArrayExt; -use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; -use vortex_array::dtype::NativePType; -use vortex_array::match_each_float_ptype; -use vortex_array::validity::Validity; -use vortex_buffer::BufferMut; -use vortex_error::VortexResult; -use vortex_error::vortex_ensure_eq; -use vortex_error::vortex_err; -use vortex_mask::Mask; -use vortex_mask::MaskValues; -use vortex_tensor::scalar_fns::l2_denorm::L2Denorm; -use vortex_tensor::scalar_fns::l2_norm::L2Norm; -use vortex_tensor::vector::AnyVector; -use vortex_tensor::vector::Vector; - -/// Normalize a `Vector` array and wrap it with its original row norms with [`L2Denorm`]. -/// -/// This preserves input row validity on both [`L2Denorm`] children. Or in other words, validity is -/// propagated down to the children so that TurboQuant can skip quantizing those vectors (as it does -/// not have a good way to represent 0 vectors in its quantized domain). -pub(crate) fn tq_normalize_as_l2_denorm( - input: ArrayRef, - ctx: &mut ExecutionCtx, -) -> VortexResult { - let row_count = input.len(); - let vector_metadata = input - .dtype() - .as_extension_opt() - .and_then(|ext_dtype| ext_dtype.metadata_opt::()) - .ok_or_else(|| vortex_err!("TurboQuant normalization expects a Vector extension array"))?; - let dimensions = vector_metadata.dimensions() as usize; - let vector_validity = input.validity()?; - - // Use `L2Norm` to calculate the normals for each vector. - let norms: ArrayRef = L2Norm::try_new_array(input.clone(), row_count)? - .into_array() - .execute(ctx)?; - let primitive_norms: PrimitiveArray = norms.clone().execute(ctx)?; - - let input: ExtensionArray = input.execute(ctx)?; - let storage: FixedSizeListArray = input.storage_array().clone().execute(ctx)?; - vortex_ensure_eq!( - storage.list_size() as usize, - dimensions, - "Vector storage dimension must be {dimensions}, got {}", - storage.list_size() - ); - let elements: PrimitiveArray = storage.elements().clone().execute(ctx)?; - - let mask = vector_validity.execute_mask(row_count, ctx)?; - - let normalized = match_each_float_ptype!(elements.ptype(), |T| { - normalize_vectors::( - &elements, - &primitive_norms, - &mask, - dimensions, - vector_validity.clone(), - ) - })?; - - // SAFETY: matches the lossy-encoding relaxation documented on - // [`L2Denorm::new_array_unchecked`]. Norms come from `L2Norm` over the same input, so they - // match the vector element type and row count. Valid nonzero rows are divided by their stored - // norm and are unit-norm. Valid zero-norm rows and invalid rows use physical zero placeholders; - // invalid rows remain guarded by row-level invalid validity. - unsafe { L2Denorm::new_array_unchecked(normalized, norms, row_count) } -} - -fn normalize_vectors( - elements: &PrimitiveArray, - norms: &PrimitiveArray, - mask: &Mask, - dimensions: usize, - vector_validity: Validity, -) -> VortexResult -where - T: Float + NativePType, -{ - let num_vectors = norms.len(); - - let values = elements.as_slice::(); - let norm_values = norms.as_slice::(); - - let output_len = num_vectors - .checked_mul(dimensions) - .ok_or_else(|| vortex_err!("TurboQuant normalized vector length overflow"))?; - let mut output = BufferMut::::with_capacity(output_len); - - // The total number of pushes is always exactly `num_vectors * dimensions == output_len` - // across every arm below, which is the invariant the per-row `unsafe` blocks rely on. - match mask { - Mask::AllFalse(_) => { - // Every row is invalid: bulk-fill the output with zero placeholders. - // - // SAFETY: `output` was allocated with capacity `output_len`, and this push writes - // exactly `output_len` zero placeholders. - unsafe { output.push_n_unchecked(T::zero(), output_len) }; - } - Mask::AllTrue(_) => { - for i in 0..num_vectors { - // SAFETY: `output` was allocated with capacity `output_len = num_vectors * - // dimensions`. This loop runs `num_vectors` times and each call pushes exactly - // `dimensions` elements, so capacity for `dimensions` more elements always - // remains. - unsafe { normalize_one_row::(&mut output, values, norm_values, dimensions, i) }; - } - } - Mask::Values(values_mask) => { - // SAFETY: `output` was allocated with capacity `output_len = num_vectors * - // dimensions`, which is the bound the helper requires. - unsafe { - normalize_vectors_with_mask::( - &mut output, - values, - norm_values, - dimensions, - num_vectors, - values_mask, - ) - }; - } - } - - // Vector elements are always non-nullable. - let elements = PrimitiveArray::new::(output.freeze(), Validity::NonNullable); - - #[expect( - clippy::cast_possible_truncation, - reason = "this initially came from a u32" - )] - let storage = FixedSizeListArray::try_new( - elements.into_array(), - dimensions as u32, - vector_validity, - num_vectors, - )?; - - Ok( - ExtensionArray::try_new_from_vtable(Vector, EmptyMetadata, storage.into_array())? - .into_array(), - ) -} - -/// Normalize a single valid row, or push `dimensions` zero placeholders if the row's L2 norm -/// is zero. -/// -/// A valid vector with L2 norm zero is all zeros, so dividing through it would be undefined. -/// Treating it the same as an invalid row preserves the original semantics. -/// -/// # Safety -/// -/// `output` must have capacity for at least `dimensions` more elements before this call. -unsafe fn normalize_one_row( - output: &mut BufferMut, - values: &[T], - norm_values: &[T], - dimensions: usize, - i: usize, -) where - T: Float + NativePType, -{ - let norm = norm_values[i]; - - if norm == T::zero() { - // SAFETY: caller guarantees capacity for `dimensions` more elements. - unsafe { output.push_n_unchecked(T::zero(), dimensions) }; - } else { - let row_values = &values[i * dimensions..][..dimensions]; - - for &value in row_values { - // SAFETY: caller guarantees capacity for `dimensions` more elements. - unsafe { output.push_unchecked(value / norm) }; - } - } -} - -/// Walk the pre-cached run boundaries of a `Values` mask, bulk-pushing zero placeholders for -/// invalid runs and normalizing valid runs row by row. -/// -/// # Safety -/// -/// `output` must have capacity for at least `num_vectors * dimensions` more elements before -/// this call. -unsafe fn normalize_vectors_with_mask( - output: &mut BufferMut, - values: &[T], - norm_values: &[T], - dimensions: usize, - num_vectors: usize, - values_mask: &MaskValues, -) where - T: Float + NativePType, -{ - let mut cursor = 0; - - for &(start, end) in values_mask.slices() { - if start > cursor { - // SAFETY: capacity invariant from caller. - unsafe { output.push_n_unchecked(T::zero(), (start - cursor) * dimensions) }; - } - - for i in start..end { - // SAFETY: capacity invariant from caller — each call pushes `dimensions` and the - // total number of valid rows in the mask is bounded by `num_vectors`. - unsafe { normalize_one_row::(output, values, norm_values, dimensions, i) }; - } - - cursor = end; - } - - if cursor < num_vectors { - // SAFETY: capacity invariant from caller. - unsafe { output.push_n_unchecked(T::zero(), (num_vectors - cursor) * dimensions) }; - } -} diff --git a/vortex-turboquant/src/vector/quantize.rs b/vortex-turboquant/src/vector/quantize.rs deleted file mode 100644 index 0861b9f6805..00000000000 --- a/vortex-turboquant/src/vector/quantize.rs +++ /dev/null @@ -1,181 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! Core TurboQuant quantization helpers. -//! -//! Quantization consumes the TurboQuant-local normalized `Vector` child. Valid rows are transformed -//! and mapped to scalar centroid indices. Invalid rows remain in the full-length output but are -//! skipped: their physical code bytes are placeholders guarded by the `codes` row validity. -//! -//! This matters because TurboQuant's scalar codebook is optimized for coordinates of transformed -//! unit-norm vectors. The codebook does not generally contain an exact zero centroid, and a -//! physical code byte of `0` means "centroid 0", not "zero coordinate". Null vectors therefore -//! should not be converted to zero vectors and fed through the quantizer. - -use half::f16; -use vortex_array::ExecutionCtx; -use vortex_array::arrays::FixedSizeListArray; -use vortex_array::arrays::PrimitiveArray; -use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; -use vortex_array::dtype::PType; -use vortex_buffer::Buffer; -use vortex_buffer::BufferMut; -use vortex_error::VortexResult; -use vortex_error::vortex_bail; -use vortex_error::vortex_err; -use vortex_mask::Mask; - -use super::tq_padded_dim; -use crate::TurboQuantConfig; -use crate::centroids::compute_centroid_boundaries; -use crate::centroids::compute_or_get_centroids; -use crate::centroids::find_nearest_centroid; -use crate::sorf::SorfMatrix; - -/// Shared intermediate results from the quantization loop. -pub(crate) struct QuantizationResult { - pub(crate) all_indices: Buffer, - pub(crate) padded_dim: usize, -} - -pub(crate) fn empty_quantization(padded_dim: usize) -> QuantizationResult { - QuantizationResult { - all_indices: Buffer::empty(), - padded_dim, - } -} - -/// Core quantization: transform and quantize already-normalized rows. -/// -/// # Safety -/// -/// The input `fsl` must contain unit-norm vectors (already L2-normalized) for every valid row. -/// Invalid rows are left row-aligned in the output but are not transformed or quantized. The -/// transform and centroid lookup happen in f32. -pub(crate) unsafe fn turboquant_quantize_core( - fsl: &FixedSizeListArray, - config: &TurboQuantConfig, - ctx: &mut ExecutionCtx, -) -> VortexResult { - let dimension = fsl.list_size(); - let num_vectors = fsl.len(); - let padded_dim = tq_padded_dim(dimension)?; - - let sorf_transform = - SorfMatrix::try_new(padded_dim, config.num_rounds() as usize, config.seed())?; - debug_assert_eq!(sorf_transform.padded_dim(), padded_dim); - let padded_dim_u32 = u32::try_from(padded_dim) - .map_err(|_| vortex_err!("TurboQuant padded dimension does not fit u32"))?; - - let elements_prim: PrimitiveArray = fsl.elements().clone().execute(ctx)?; - let f32_elements = cast_to_f32(elements_prim)?; - let validity = fsl.validity()?; - let mask = validity.execute_mask(num_vectors, ctx)?; - - let centroids = compute_or_get_centroids(padded_dim_u32, config.bit_width())?; - let boundaries = compute_centroid_boundaries(¢roids); - - let codes_len = num_vectors - .checked_mul(padded_dim) - .ok_or_else(|| vortex_err!("TurboQuant codes length overflow"))?; - let mut all_indices = BufferMut::::with_capacity(codes_len); - - let mut padded = vec![0.0f32; padded_dim]; - let mut transformed = vec![0.0f32; padded_dim]; - - // Pad, SORF-transform, and quantize a single row, pushing `padded_dim` codes into - // `all_indices`. Captures the read-only inputs and the scratch buffers so each call site - // only needs to pass `all_indices` and the row index. - // - // NB: `all_indices` cannot be captured here: the `Values` arm interleaves the closure call - // with direct `all_indices.push_n_unchecked` calls. - let f32_slice = f32_elements.as_slice(); - let dimension = dimension as usize; - let mut quantize_row = |all_indices: &mut BufferMut, row: usize| { - // Reuse `padded` and `transformed` from the outer scope. - padded[..dimension].copy_from_slice(&f32_slice[row * dimension..][..dimension]); - padded[dimension..].fill(0.0); - sorf_transform.transform(&padded, &mut transformed); - - for &value in &transformed { - // SAFETY: total pushes across all match arms equal `codes_len`. - unsafe { all_indices.push_unchecked(find_nearest_centroid(value, &boundaries)) }; - } - }; - - // The total number of pushes is always exactly `num_vectors * padded_dim == codes_len` - // across every arm below, which is the invariant the per-row `unsafe` blocks rely on. - match &mask { - Mask::AllFalse(_) => { - // Every row is invalid: bulk-fill placeholder zero codes. - // - // SAFETY: `all_indices` was allocated with capacity `codes_len`, and this push - // writes exactly `codes_len` zero codes. - unsafe { all_indices.push_n_unchecked(0, codes_len) }; - } - Mask::AllTrue(_) => { - for row in 0..num_vectors { - quantize_row(&mut all_indices, row); - } - } - Mask::Values(values_mask) => { - let mut cursor = 0; - - for &(start, end) in values_mask.slices() { - if start > cursor { - // SAFETY: total pushes across all arms equal `codes_len`. - unsafe { all_indices.push_n_unchecked(0, (start - cursor) * padded_dim) }; - } - - for row in start..end { - quantize_row(&mut all_indices, row); - } - - cursor = end; - } - - if cursor < num_vectors { - // SAFETY: total pushes across all arms equal `codes_len`. - unsafe { all_indices.push_n_unchecked(0, (num_vectors - cursor) * padded_dim) }; - } - } - } - - Ok(QuantizationResult { - all_indices: all_indices.freeze(), - padded_dim, - }) -} - -/// Cast a float [`PrimitiveArray`] to a `Buffer`. -/// -/// Several operations in this crate (SORF transform, TurboQuant quantization) work exclusively -/// in f32. This function handles the cast from any float ptype: -/// -/// - f16: losslessly widened to f32. -/// - f32: zero-copy buffer extraction. -/// - f64: truncated to f32 precision. Values outside f32 range become +/- infinity. This is -/// acceptable because callers of this function operate in f32 and document this constraint. -fn cast_to_f32(prim: PrimitiveArray) -> VortexResult> { - match prim.ptype() { - PType::F16 => Ok(prim - .as_slice::() - .iter() - .map(|&v| f32::from(v)) - .collect()), - PType::F32 => Ok(prim.into_buffer()), - PType::F64 => Ok(prim - .as_slice::() - .iter() - .map(|&v| { - #[expect( - clippy::cast_possible_truncation, - reason = "f64 values outside f32 range become infinity, matching tensor TQ" - )] - let v = v as f32; - v - }) - .collect()), - other => vortex_bail!("expected float elements, got {other:?}"), - } -} diff --git a/vortex-turboquant/src/vector/storage.rs b/vortex-turboquant/src/vector/storage.rs deleted file mode 100644 index d1b4f06cc05..00000000000 --- a/vortex-turboquant/src/vector/storage.rs +++ /dev/null @@ -1,164 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! TurboQuant physical storage helpers. -//! -//! TurboQuant storage is row-aligned and full length: -//! -//! ```text -//! Struct { -//! norms: Primitive, -//! codes: FixedSizeList, padded_dim, vector_validity>, -//! } -//! ``` -//! -//! Row nullability is carried on the outer struct and on the `norms` and `codes` field arrays. -//! This is deliberate duplication: null vectors remain null throughout encode/decode instead of being -//! converted into zero vectors. The code bytes for invalid rows are physical placeholders only; the -//! field-level validity records that those rows were not quantized. -//! -//! Parsing treats the outer struct validity as authoritative. Child validity may be wider than -//! the struct validity (for example after a generic mask only updates the struct validity), but -//! each child must be valid wherever the struct row is valid. - -use vortex_array::ArrayRef; -use vortex_array::ExecutionCtx; -use vortex_array::IntoArray; -use vortex_array::arrays::ExtensionArray; -use vortex_array::arrays::FixedSizeListArray; -use vortex_array::arrays::PrimitiveArray; -use vortex_array::arrays::StructArray; -use vortex_array::arrays::extension::ExtensionArrayExt; -use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; -use vortex_array::arrays::struct_::StructArrayExt; -use vortex_array::dtype::FieldNames; -use vortex_array::validity::Validity; -use vortex_error::VortexResult; -use vortex_error::vortex_ensure; -use vortex_error::vortex_err; -use vortex_mask::Mask; - -use super::quantize::QuantizationResult; -use crate::vtable::TurboQuantMetadata; -use crate::vtable::tq_metadata; - -/// Name of the stored row-norm child. -pub(crate) const NORMS_FIELD: &str = "norms"; - -/// Name of the stored quantized-code child. -pub(crate) const CODES_FIELD: &str = "codes"; - -/// Executed storage children of a TurboQuant extension array plus the authoritative outer -/// struct validity. Every child is row-aligned to `len` and every child's validity covers -/// `vector_validity`. -pub(crate) struct TurboQuantParsedStorage { - /// Metadata recovered from the input extension dtype. - pub(crate) metadata: TurboQuantMetadata, - /// Authoritative row validity for the quantized vectors, taken from the outer struct. - pub(crate) vector_validity: Validity, - /// Per-row stored L2 norm of the original input vector, in `metadata.element_ptype`. - pub(crate) norms: PrimitiveArray, - /// Flat `u8` per-row centroid indices, `num_vectors * padded_dim` entries long. - pub(crate) codes: PrimitiveArray, - /// Row count. - pub(crate) len: usize, -} - -/// Build the `codes: FixedSizeList, padded_dim>` storage child. -/// -/// Each row of `padded_dim` u8 codes indexes into the deterministic centroid codebook derived -/// from `(padded_dim, bit_width)`. The centroid values are intentionally not stored in the array. -pub(crate) fn build_codes_child( - num_vectors: usize, - quantization: QuantizationResult, - vector_validity: Validity, -) -> VortexResult { - let codes = PrimitiveArray::new::(quantization.all_indices, Validity::NonNullable); - let padded_dim_u32 = u32::try_from(quantization.padded_dim) - .map_err(|_| vortex_err!("TurboQuant padded dimension does not fit u32"))?; - - Ok(FixedSizeListArray::try_new( - codes.into_array(), - padded_dim_u32, - vector_validity, - num_vectors, - )? - .into_array()) -} - -/// Build the TurboQuant `Struct { norms, codes }` storage array. -pub(crate) fn build_storage( - norms: ArrayRef, - codes: ArrayRef, - len: usize, - vector_validity: Validity, -) -> VortexResult { - Ok(StructArray::try_new( - FieldNames::from([NORMS_FIELD, CODES_FIELD]), - vec![norms, codes], - len, - vector_validity, - )? - .into_array()) -} - -/// Parse a TurboQuant extension array into executed storage children. -pub(crate) fn parse_storage( - input: ArrayRef, - ctx: &mut ExecutionCtx, -) -> VortexResult { - let metadata = tq_metadata(input.dtype())?; - let ext: ExtensionArray = input.execute(ctx)?; - let storage: StructArray = ext.storage_array().clone().execute(ctx)?; - - let norms: PrimitiveArray = storage - .unmasked_field_by_name(NORMS_FIELD)? - .clone() - .execute(ctx)?; - - let codes_fsl: FixedSizeListArray = storage - .unmasked_field_by_name(CODES_FIELD)? - .clone() - .execute(ctx)?; - let codes: PrimitiveArray = codes_fsl.elements().clone().execute(ctx)?; - - let len = storage.len(); - let struct_validity = storage.struct_validity(); - let norms_validity = norms.validity()?; - let codes_validity = codes_fsl.validity()?; - - let struct_mask = struct_validity.execute_mask(len, ctx)?; - let norms_mask = norms_validity.execute_mask(len, ctx)?; - let codes_mask = codes_validity.execute_mask(len, ctx)?; - validate_child_validity_covers_struct(&struct_mask, &norms_mask, &codes_mask)?; - - Ok(TurboQuantParsedStorage { - metadata, - vector_validity: struct_validity, - norms, - codes, - len, - }) -} - -/// Validate that both child masks cover the struct mask: every row that the struct considers -/// valid must also be valid in the `norms` and `codes` children. -/// -/// `struct_mask & !child_mask` selects rows where the struct is valid but the child is not. If -/// no such row exists, the child covers the struct. [`Mask::bitand_not`] is variant-specialized, -/// so this short-circuits in `O(1)` when either mask is `AllTrue` or `AllFalse`. -fn validate_child_validity_covers_struct( - struct_mask: &Mask, - norms_mask: &Mask, - codes_mask: &Mask, -) -> VortexResult<()> { - vortex_ensure!( - struct_mask.clone().bitand_not(norms_mask).all_false(), - "TurboQuant {NORMS_FIELD} row validity must cover storage validity" - ); - vortex_ensure!( - struct_mask.clone().bitand_not(codes_mask).all_false(), - "TurboQuant {CODES_FIELD} row validity must cover storage validity" - ); - Ok(()) -} diff --git a/vortex-turboquant/src/vtable.rs b/vortex-turboquant/src/vtable.rs deleted file mode 100644 index 854bcee6c70..00000000000 --- a/vortex-turboquant/src/vtable.rs +++ /dev/null @@ -1,241 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -use std::fmt; -use std::sync::Arc; - -use prost::Message; -use vortex_array::dtype::DType; -use vortex_array::dtype::FieldNames; -use vortex_array::dtype::Nullability; -use vortex_array::dtype::PType; -use vortex_array::dtype::StructFields; -use vortex_array::dtype::extension::ExtDType; -use vortex_array::dtype::extension::ExtId; -use vortex_array::dtype::extension::ExtVTable; -use vortex_array::scalar::ScalarValue; -use vortex_error::VortexResult; -use vortex_error::vortex_bail; -use vortex_error::vortex_ensure; -use vortex_error::vortex_ensure_eq; -use vortex_error::vortex_err; - -use crate::TurboQuantConfig; -use crate::config::MIN_DIMENSION; -use crate::vector::storage::CODES_FIELD; -use crate::vector::storage::NORMS_FIELD; -use crate::vector::tq_padded_dim; - -/// TurboQuant logical extension type. Per-array configuration lives in [`TurboQuantMetadata`]. -#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] -pub struct TurboQuant; - -/// Serialized metadata for a TurboQuant extension array. The fields together suffice to -/// reconstruct the SORF transform and centroid codebook at decode time. -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] -pub struct TurboQuantMetadata { - /// Original vector element ptype and stored row-norm ptype. Restricted to `f16` / `f32` / - /// `f64`. - pub element_ptype: PType, - /// Original vector dimension before SORF padding to the next power of two. - pub dimensions: u32, - /// Bits per coordinate in the scalar quantizer codebook (`1..=8`). - pub bit_width: u8, - /// Seed used to derive the deterministic SORF transform. - pub seed: u64, - /// Number of sign-diagonal plus Walsh-Hadamard rounds in the SORF transform. - pub num_rounds: u8, -} - -impl ExtVTable for TurboQuant { - type Metadata = TurboQuantMetadata; - type NativeValue<'a> = &'a ScalarValue; - - fn id(&self) -> ExtId { - ExtId::new("vortex.turboquant") - } - - fn serialize_metadata(&self, metadata: &Self::Metadata) -> VortexResult> { - validate_tq_metadata(metadata)?; - - let proto = TurboQuantMetadataProto { - element_ptype: metadata.element_ptype as i32, - dimensions: metadata.dimensions, - bit_width: u32::from(metadata.bit_width), - seed: metadata.seed, - num_rounds: u32::from(metadata.num_rounds), - }; - - Ok(proto.encode_to_vec()) - } - - fn deserialize_metadata(&self, metadata: &[u8]) -> VortexResult { - let proto = TurboQuantMetadataProto::decode(metadata) - .map_err(|e| vortex_err!("Failed to decode TurboQuantMetadata: {e}"))?; - let bit_width = u8::try_from(proto.bit_width) - .map_err(|_| vortex_err!("TurboQuant bit_width does not fit u8"))?; - let num_rounds = u8::try_from(proto.num_rounds) - .map_err(|_| vortex_err!("TurboQuant num_rounds does not fit u8"))?; - let element_ptype = PType::try_from(proto.element_ptype).map_err(|e| { - vortex_err!( - "invalid TurboQuant metadata element_ptype code {}: {e}", - proto.element_ptype - ) - })?; - - let metadata = TurboQuantMetadata { - element_ptype, - dimensions: proto.dimensions, - bit_width, - seed: proto.seed, - num_rounds, - }; - validate_tq_metadata(&metadata)?; - - Ok(metadata) - } - - fn validate_dtype(ext_dtype: &ExtDType) -> VortexResult<()> { - validate_tq_metadata(ext_dtype.metadata())?; - validate_tq_storage_dtype(ext_dtype.metadata(), ext_dtype.storage_dtype()) - } - - fn unpack_native<'a>( - _ext_dtype: &'a ExtDType, - storage_value: &'a ScalarValue, - ) -> VortexResult> { - Ok(storage_value) - } -} - -/// Wire-format representation of [`TurboQuantMetadata`]. Field tags MUST NOT change once -/// shipped; new fields must use unused tags and remain optional. -#[derive(Clone, PartialEq, Message)] -struct TurboQuantMetadataProto { - #[prost(enumeration = "PType", tag = "1")] - element_ptype: i32, - #[prost(uint32, tag = "2")] - dimensions: u32, - #[prost(uint32, tag = "3")] - bit_width: u32, - #[prost(uint64, tag = "4")] - seed: u64, - #[prost(uint32, tag = "5")] - num_rounds: u32, -} - -/// Extract TurboQuant metadata from a dtype. -/// -/// Returns an error when the dtype is not the TurboQuant extension type. -pub(crate) fn tq_metadata(dtype: &DType) -> VortexResult { - let ext_dtype = dtype - .as_extension_opt() - .ok_or_else(|| vortex_err!("expected a TurboQuant extension array, got {dtype}"))?; - - let metadata = ext_dtype - .metadata_opt::() - .ok_or_else(|| vortex_err!("expected a TurboQuant extension array, got {dtype}"))?; - - Ok(*metadata) -} - -pub(crate) fn tq_storage_dtype( - metadata: &TurboQuantMetadata, - row_nullability: Nullability, -) -> VortexResult { - let padded_dim = u32::try_from(tq_padded_dim(metadata.dimensions)?) - .map_err(|_| vortex_err!("TurboQuant padded dimension does not fit u32"))?; - - Ok(DType::Struct( - StructFields::new( - FieldNames::from([NORMS_FIELD, CODES_FIELD]), - vec![ - DType::Primitive(metadata.element_ptype, row_nullability), - DType::FixedSizeList( - Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)), - padded_dim, - row_nullability, - ), - ], - ), - row_nullability, - )) -} - -/// Validate [`TurboQuantMetadata`] invariants. Called on both serialize and deserialize so a -/// corrupted on-disk metadata block errors out rather than decoding into nonsense. -fn validate_tq_metadata(metadata: &TurboQuantMetadata) -> VortexResult<()> { - vortex_ensure!( - metadata.dimensions >= MIN_DIMENSION, - "TurboQuant dimensions must be >= {MIN_DIMENSION}, got {}", - metadata.dimensions - ); - vortex_ensure!( - metadata.element_ptype.is_float(), - "TurboQuant element_ptype must be a float, got {:?}", - metadata.element_ptype - ); - - tq_padded_dim(metadata.dimensions)?; - - TurboQuantConfig::try_new(metadata.bit_width, metadata.seed, metadata.num_rounds).map(|_| ()) -} - -/// Validate that `dtype` matches the storage shape produced by [`tq_storage_dtype`] for -/// `metadata`. Called from [`TurboQuant::validate_dtype`]. -fn validate_tq_storage_dtype(metadata: &TurboQuantMetadata, dtype: &DType) -> VortexResult<()> { - let DType::Struct(fields, _) = dtype else { - vortex_bail!("TurboQuant storage dtype must be a Struct, got {dtype}"); - }; - let expected_names = FieldNames::from([NORMS_FIELD, CODES_FIELD]); - vortex_ensure_eq!( - fields.names(), - &expected_names, - "TurboQuant storage fields must be {expected_names}, got {}", - fields.names() - ); - - let Some(norms_dtype) = fields.field(NORMS_FIELD) else { - vortex_bail!("TurboQuant storage missing {NORMS_FIELD} field"); - }; - let DType::Primitive(norms_ptype, _) = norms_dtype else { - vortex_bail!("TurboQuant {NORMS_FIELD} field must be primitive, got {norms_dtype}"); - }; - vortex_ensure_eq!( - norms_ptype, - metadata.element_ptype, - "TurboQuant {NORMS_FIELD} ptype must be {}, got {norms_ptype}", - metadata.element_ptype - ); - - let Some(codes_dtype) = fields.field(CODES_FIELD) else { - vortex_bail!("TurboQuant storage missing {CODES_FIELD} field"); - }; - let DType::FixedSizeList(element_dtype, list_size, _) = codes_dtype else { - vortex_bail!("TurboQuant {CODES_FIELD} field must be fixed-size-list, got {codes_dtype}"); - }; - let padded_dim = u32::try_from(tq_padded_dim(metadata.dimensions)?) - .map_err(|_| vortex_err!("TurboQuant padded dimension does not fit u32"))?; - vortex_ensure_eq!( - list_size, - padded_dim, - "TurboQuant {CODES_FIELD} list size must be {padded_dim}, got {list_size}" - ); - vortex_ensure_eq!( - element_dtype.as_ref(), - &DType::Primitive(PType::U8, Nullability::NonNullable), - "TurboQuant {CODES_FIELD} elements must be non-nullable u8, got {element_dtype}" - ); - - Ok(()) -} - -impl fmt::Display for TurboQuantMetadata { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "element_ptype: {}, dimensions: {}, bit_width: {}, seed: {}, num_rounds: {}", - self.element_ptype, self.dimensions, self.bit_width, self.seed, self.num_rounds - ) - } -} diff --git a/vortex/Cargo.toml b/vortex/Cargo.toml index 982127a4035..6923f307917 100644 --- a/vortex/Cargo.toml +++ b/vortex/Cargo.toml @@ -85,10 +85,6 @@ unstable_encodings = [ "vortex-zstd?/unstable_encodings", ] -[[example]] -name = "turboquant_vector_search" -required-features = ["files", "tokio", "unstable_encodings"] - [[bench]] name = "single_encoding_throughput" harness = false diff --git a/vortex/benches/single_encoding_throughput.rs b/vortex/benches/single_encoding_throughput.rs index ad94adf8341..2db4ebaf0ba 100644 --- a/vortex/benches/single_encoding_throughput.rs +++ b/vortex/benches/single_encoding_throughput.rs @@ -476,136 +476,3 @@ fn bench_zstd_decompress_string(bencher: Bencher) { .with_inputs(|| (&compressed, SESSION.create_execution_ctx())) .bench_refs(|(a, ctx)| canonicalize((**a).clone(), ctx)); } - -// TODO(connor): Remove this. -// TurboQuant vector quantization benchmarks. -#[cfg(feature = "unstable_encodings")] -mod turboquant_benches { - use divan::Bencher; - use paste::paste; - use rand::SeedableRng; - use rand::rngs::StdRng; - use vortex::array::EmptyMetadata; - use vortex::array::IntoArray; - use vortex::array::arrays::Extension; - use vortex::array::arrays::ExtensionArray; - use vortex::array::arrays::FixedSizeListArray; - use vortex::array::arrays::PrimitiveArray; - use vortex::array::arrays::scalar_fn::ScalarFnArrayExt; - use vortex::array::dtype::extension::ExtDType; - use vortex::array::validity::Validity; - use vortex_array::VortexSessionExecute; - use vortex_buffer::BufferMut; - use vortex_tensor::encodings::turboquant::TurboQuantConfig; - use vortex_tensor::encodings::turboquant::turboquant_encode_unchecked; - use vortex_tensor::scalar_fns::l2_denorm::normalize_as_l2_denorm; - use vortex_tensor::vector::Vector; - - use super::SESSION; - use super::with_byte_counter; - - const NUM_VECTORS: usize = 1_000; - - /// Generate `num_vectors` random f32 Vector extension arrays of the given dimension - /// using i.i.d. standard normal components. This is a conservative test distribution: - /// real neural network embeddings typically have structure (clustered, anisotropic) - /// that the SRHT exploits for better quantization, so Gaussian i.i.d. is a - /// worst-case baseline for TurboQuant. - fn setup_vector_ext(dim: usize) -> ExtensionArray { - let mut rng = StdRng::seed_from_u64(42); - let normal = rand_distr::Normal::new(0.0f32, 1.0).unwrap(); - - let mut buf = BufferMut::::with_capacity(NUM_VECTORS * dim); - for _ in 0..(NUM_VECTORS * dim) { - buf.push(rand_distr::Distribution::sample(&normal, &mut rng)); - } - - let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); - let fsl = FixedSizeListArray::try_new( - elements.into_array(), - dim as u32, - Validity::NonNullable, - NUM_VECTORS, - ) - .unwrap(); - let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone()) - .unwrap() - .erased(); - ExtensionArray::new(ext_dtype, fsl.into_array()) - } - - fn turboquant_config(bit_width: u8) -> TurboQuantConfig { - TurboQuantConfig { - bit_width, - seed: 123, - num_rounds: 3, - } - } - - fn setup_normalized_vector_ext(dim: usize) -> ExtensionArray { - let ext = setup_vector_ext(dim); - let mut ctx = SESSION.create_execution_ctx(); - let normalized = normalize_as_l2_denorm(ext.into_array(), &mut ctx) - .unwrap() - .child_at(0) - .clone(); - normalized.execute::(&mut ctx).unwrap() - } - - macro_rules! turboquant_bench { - (compress, $dim:literal, $bits:literal, $name:ident) => { - paste! { - #[divan::bench(name = concat!("turboquant_encode_dim", stringify!($dim), "_", stringify!($bits), "bit"))] - fn $name(bencher: Bencher) { - let normalized_ext = setup_normalized_vector_ext($dim); - let config = turboquant_config($bits); - with_byte_counter(bencher, (NUM_VECTORS * $dim * 4) as u64) - .with_inputs(|| (normalized_ext.clone(), SESSION.create_execution_ctx())) - .bench_refs(|(a, ctx)| { - let normalized = a - .as_ref() - .as_opt::() - .expect("normalized benchmark input should be an Extension array"); - // SAFETY: Benchmark inputs are normalized once up front so the timed - // region measures only TurboQuant encoding. - unsafe { turboquant_encode_unchecked(normalized, &config, ctx) } - .unwrap() - }); - } - } - }; - (decompress, $dim:literal, $bits:literal, $name:ident) => { - paste! { - #[divan::bench(name = concat!("turboquant_decompress_dim", stringify!($dim), "_", stringify!($bits), "bit"))] - fn $name(bencher: Bencher) { - let normalized_ext = setup_normalized_vector_ext($dim); - let config = turboquant_config($bits); - let mut ctx = SESSION.create_execution_ctx(); - let compressed = unsafe { - turboquant_encode_unchecked(normalized_ext.as_view(), &config, &mut ctx) - } - .unwrap(); - with_byte_counter(bencher, (NUM_VECTORS * $dim * 4) as u64) - .with_inputs(|| (&compressed, SESSION.create_execution_ctx())) - .bench_refs(|(a, ctx)| { - (**a).clone() - .into_array() - .execute::(ctx) - .unwrap() - }); - } - } - }; - } - - turboquant_bench!(compress, 128, 4, bench_tq_compress_128_4); - turboquant_bench!(decompress, 128, 4, bench_tq_decompress_128_4); - turboquant_bench!(compress, 768, 4, bench_tq_compress_768_4); - turboquant_bench!(decompress, 768, 4, bench_tq_decompress_768_4); - turboquant_bench!(compress, 1024, 2, bench_tq_compress_1024_2); - turboquant_bench!(decompress, 1024, 2, bench_tq_decompress_1024_2); - turboquant_bench!(compress, 1024, 4, bench_tq_compress_1024_4); - turboquant_bench!(decompress, 1024, 4, bench_tq_decompress_1024_4); - turboquant_bench!(compress, 1024, 8, bench_tq_compress_1024_8); - turboquant_bench!(decompress, 1024, 8, bench_tq_decompress_1024_8); -} diff --git a/vortex/examples/turboquant_vector_search.rs b/vortex/examples/turboquant_vector_search.rs deleted file mode 100644 index 4c1c0252528..00000000000 --- a/vortex/examples/turboquant_vector_search.rs +++ /dev/null @@ -1,399 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// SPDX-FileCopyrightText: Copyright the Vortex contributors - -//! TurboQuant vector-search roundtrip on a vector-embedding dataset. -//! -//! Load a parquet dataset (cohere-small), wrap the `emb` column as a `Vector` -//! extension, compress with BtrBlocks + TurboQuant, write to an in-memory Vortex file, then read -//! the file back twice: -//! -//! 1. plain scan — decode to canonical `FixedSizeList` and verify the per-element diff -//! against the original. TurboQuant is lossy, so we only check the reconstructed values are -//! within a tolerance. -//! 2. scan with a pushed-down cosine-similarity filter `cosine_similarity(emb, query) > thresh`. -//! The `CosineSimilarity` scalar fn is expressed directly as a filter `Expression`, so row -//! selection happens inside the scan rather than after materialization. -//! -//! The parquet file is cached under `vortex-bench/data//` after the first download. Run -//! with: -//! -//! ```sh -//! cargo run --example turboquant_vector_search \ -//! -p vortex --features unstable_encodings --release -//! ``` - -use std::path::PathBuf; -use std::time::Instant; - -use anyhow::Result; -use anyhow::bail; -use anyhow::ensure; -use futures::TryStreamExt; -use vortex::VortexSessionDefault; -use vortex::array::ArrayRef; -use vortex::array::EmptyMetadata; -use vortex::array::IntoArray; -use vortex::array::VortexSessionExecute; -use vortex::array::arrays::ChunkedArray; -use vortex::array::arrays::ExtensionArray; -use vortex::array::arrays::FixedSizeListArray; -use vortex::array::arrays::PrimitiveArray; -use vortex::array::arrays::StructArray; -use vortex::array::arrays::extension::ExtensionArrayExt; -use vortex::array::arrays::fixed_size_list::FixedSizeListArrayExt; -use vortex::array::arrays::struct_::StructArrayExt; -use vortex::array::expr::col; -use vortex::array::expr::gt; -use vortex::array::expr::lit; -use vortex::array::scalar::Scalar; -use vortex::array::scalar_fn::EmptyOptions; -use vortex::array::scalar_fn::ScalarFnVTable; -use vortex::array::scalar_fn::ScalarFnVTableExt; -use vortex::buffer::ByteBuffer; -use vortex::buffer::ByteBufferMut; -use vortex::dtype::DType; -use vortex::dtype::Nullability; -use vortex::dtype::PType; -use vortex::file::ALLOWED_ENCODINGS; -use vortex::file::OpenOptionsSessionExt; -use vortex::file::WriteOptionsSessionExt; -use vortex::file::WriteStrategyBuilder; -use vortex::io::session::RuntimeSessionExt; -use vortex::session::VortexSession; -use vortex_array::ExecutionCtx; -use vortex_array::builtins::ArrayBuiltins; -use vortex_bench::conversions::parquet_to_vortex_chunks; -use vortex_bench::vector_dataset; -use vortex_bench::vector_dataset::TrainLayout; -use vortex_bench::vector_dataset::VectorDataset; -use vortex_bench::vector_dataset::list_to_vector_ext; -use vortex_btrblocks::BtrBlocksCompressorBuilder; -use vortex_error::VortexExpect; -use vortex_tensor::scalar_fns::cosine_similarity::CosineSimilarity; -use vortex_tensor::scalar_fns::l2_denorm::L2Denorm; -use vortex_tensor::scalar_fns::sorf_transform::SorfTransform; -use vortex_tensor::vector::AnyVector; -use vortex_tensor::vector::Vector; - -/// Cosine threshold for the demo filter. The query comes from the test split, so it may or may not -/// have nearby rows in the train split. -const COSINE_THRESHOLD: f32 = 0.85; - -/// Slack for checking decoded rows against a predicate that was evaluated on TurboQuant's lossy -/// readthrough representation. -const COSINE_THRESHOLD_TOL: f32 = 0.02; - -/// Regression ceiling on the decoded vs original max-abs-diff for 8-bit TurboQuant on 768-dim f32 -/// embeddings. Observed on cohere-small: ~0.10. Pinned with slack so the check catches large -/// quality regressions without flapping on normal run-to-run variation. -const MAX_ABS_DIFF_TOL: f32 = 0.2; - -#[tokio::main] -async fn main() -> Result<()> { - // Opt in to registering the tensor scalar-fn array plugins before building the session. - // Without this, the TurboQuant-compressed `emb` column cannot be serialized into the Vortex - // file or deserialized on read. - // - // SAFETY: single-threaded setup before any other thread exists. - unsafe { std::env::set_var(vortex_tensor::SCALAR_FN_ARRAY_TENSOR_PLUGIN_ENV, "1") }; - - let session = VortexSession::default().with_tokio(); - vortex_tensor::initialize(&session); - let mut ctx = session.create_execution_ctx(); - println!("session initialized with tensor plugins"); - - let dataset = VectorDataset::CohereSmall100k; // This is one of the smaller datasets. - - // Download the source parquet files. - let dataset_paths = vector_dataset::download(dataset, TrainLayout::Single).await?; - let (_id, query_vector) = get_query_vector(dataset_paths.test, &mut ctx).await?; - println!( - "query vector selected (id = {_id}, dim = {})", - query_vector.len() - ); - - // Bring the parquet file into memory so that we can write it as a vortex file (after prep). - let single_train_file = dataset_paths - .train_files - .first() - .vortex_expect("we know that there must be a file here") - .clone(); - - println!("reading parquet into chunked array..."); - let chunked_table = parquet_to_vortex_chunks(single_train_file) - .await? - .into_array(); - let len = chunked_table.len(); - println!("parquet loaded: {len} rows"); - - let id = chunked_table.get_item("id")?; - let emb = chunked_table.get_item("emb")?; - - println!("converting emb column to Vector extension type..."); - let vector_array = list_to_vector_ext(emb)?; - - let fields = [("id", id), ("emb", vector_array)]; - let struct_array = StructArray::from_fields(&fields)?.into_array(); - - println!("compressing with TurboQuant and writing to in-memory Vortex file..."); - let bytes = write_turboquant(&session, struct_array.clone().into_array()).await?; - println!("vortex file written: {} bytes", bytes.len()); - - println!("verifying roundtrip fidelity..."); - verify_roundtrip(&session, &bytes, struct_array.clone()).await?; - - println!("verifying filter pushdown with cosine similarity..."); - verify_filter_pushdown(&session, &bytes, &query_vector, struct_array, &mut ctx).await?; - - println!("all checks passed!"); - Ok(()) -} - -async fn get_query_vector( - query_vectors_path: PathBuf, - ctx: &mut ExecutionCtx, -) -> Result<(usize, Vec)> { - let test_vectors = parquet_to_vortex_chunks(query_vectors_path).await?; - - // Get a random query vector. - let idx = rand::random_range(0..test_vectors.len()); - let struct_scalar = test_vectors.execute_scalar(idx, ctx)?; - let id_scalar = struct_scalar - .as_struct() - .field("id") - .vortex_expect("test parquet file missing `id` field"); - - ensure!( - id_scalar - .as_primitive() - .as_::() - .vortex_expect("id was not a i64") - == idx as i64 - ); - - let emb_scalar = struct_scalar - .as_struct() - .field("emb") - .vortex_expect("test parquet file missing `emb` field"); - - // Pack into a `Vec`. - let query_vector: Vec = emb_scalar - .as_list() - .elements() - .vortex_expect("somehow had a null test vector") - .iter() - .map(|element| { - element - .as_primitive() - .as_::() - .vortex_expect("value was not a f32") - }) - .collect(); - - Ok((idx, query_vector)) -} - -async fn write_turboquant(session: &VortexSession, array: ArrayRef) -> Result { - let compressor = BtrBlocksCompressorBuilder::default() - .with_turboquant() - .build(); - - // TurboQuant produces `L2Denorm(SorfTransform(FSL(Dict(...))), norms)`. The default write - // allow-list only covers canonical/compressed array encodings, so the tensor scalar-fn - // encodings it emits get rejected during normalization. Extend the set with the two encoding - // IDs this scheme actually uses. - let mut allowed = ALLOWED_ENCODINGS.clone(); - allowed.insert(L2Denorm.id()); - allowed.insert(SorfTransform.id()); - - let strategy = WriteStrategyBuilder::default() - .with_compressor(compressor) - .with_allow_encodings(allowed) - .build(); - - let mut buf = ByteBufferMut::empty(); - session - .write_options() - .with_strategy(strategy) - .write(&mut buf, array.to_array_stream()) - .await?; - Ok(buf.freeze()) -} - -async fn verify_roundtrip( - session: &VortexSession, - bytes: &ByteBuffer, - original: ArrayRef, -) -> Result<()> { - let chunks: Vec = session - .open_options() - .open_buffer(bytes.clone())? - .scan()? - .into_array_stream()? - .try_collect() - .await?; - - let mut ctx = session.create_execution_ctx(); - - let read: StructArray = ChunkedArray::try_new(chunks, original.dtype().clone())? - .into_array() - .execute(&mut ctx)?; - let original: StructArray = original.execute(&mut ctx)?; - ensure!(read.len() == original.len()); - - let read_emb = read.unmasked_field_by_name("emb")?.clone(); - let original_emb = original.unmasked_field_by_name("emb")?.clone(); - - let decoded = flatten_vector_column(read_emb, &mut ctx)?; - let original_decoded = flatten_vector_column(original_emb, &mut ctx)?; - - let (max_abs, mean_abs) = diff_stats(&original_decoded, &decoded); - println!( - "roundtrip fidelity: max_abs_diff = {max_abs:.6}, mean_abs_diff = {mean_abs:.6} \ - (tol = {MAX_ABS_DIFF_TOL})" - ); - if max_abs > MAX_ABS_DIFF_TOL { - bail!("TurboQuant max_abs_diff {max_abs} exceeds tolerance {MAX_ABS_DIFF_TOL}"); - } - - Ok(()) -} - -async fn verify_filter_pushdown( - session: &VortexSession, - bytes: &ByteBuffer, - query: &[f32], - original: ArrayRef, - ctx: &mut ExecutionCtx, -) -> Result<()> { - // Build the filter as `cosine_similarity(emb, ) > threshold`. The RHS of - // `CosineSimilarity` is a `lit(...)` wrapping a `Vector` scalar; during scan - // evaluation the Literal expands to a ConstantArray whose row count matches the current batch, - // satisfying `CosineSimilarity`'s same-length requirement. The entire expression is pushed - // through `with_filter`, so row selection happens inside the scan rather than after the whole - // column is materialized. - println!("query: {}", preview_vector(query)); - - let query_scalar = build_query_vector_scalar(query)?; - let cosine_expr = CosineSimilarity.new_expr(EmptyOptions, [col("emb"), lit(query_scalar)]); - let filter = gt(cosine_expr, lit(COSINE_THRESHOLD)); - - let scan_start = Instant::now(); - let chunks: Vec = session - .open_options() - .open_buffer(bytes.clone())? - .scan()? - .with_filter(filter) - .into_array_stream()? - .try_collect() - .await?; - let scan_ms = scan_start.elapsed().as_secs_f64() * 1e3; - - let hits: usize = chunks.iter().map(|c| c.len()).sum(); - println!( - "pushed down `cosine_similarity(emb, query) > {COSINE_THRESHOLD}`: {hits} rows survived \ - in {scan_ms:.2} ms" - ); - if hits == 0 { - println!(" no rows survived the filter for this random query"); - return Ok(()); - } - - // Materialize the matching rows and dump each `emb` vector so the reader can see what the - // pushed-down filter actually selected. Vectors are truncated to the first few elements since - // DIM is typically large. - let filtered: StructArray = ChunkedArray::try_new(chunks, original.dtype().clone())? - .into_array() - .execute(ctx)?; - - let emb = filtered.unmasked_field_by_name("emb")?.clone(); - let flat = flatten_vector_column(emb, ctx)?; - - let dim = query.len(); - for (i, row) in flat.chunks_exact(dim).enumerate() { - let cos = cosine_similarity(query, row); - ensure!( - cos >= COSINE_THRESHOLD - COSINE_THRESHOLD_TOL, - "filtered row {i} had decoded cosine {cos:+.6}, below threshold {COSINE_THRESHOLD} \ - by more than tolerance {COSINE_THRESHOLD_TOL}" - ); - println!(" match {i}: cos = {cos:+.6} {}", preview_vector(row)); - } - - Ok(()) -} - -/// Plain `dot(a, b) / (||a|| * ||b||)` over two equal-length f32 slices. Used purely for reporting -/// — the actual row selection is done inside the scan by the pushed-down `CosineSimilarity` -/// expression. This lets the reader cross-check that the surviving rows really do clear the -/// threshold once decoded. -fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 { - assert_eq!(a.len(), b.len()); - let mut dot = 0.0f32; - let mut a_sq = 0.0f32; - let mut b_sq = 0.0f32; - for (&x, &y) in a.iter().zip(b) { - dot += x * y; - a_sq += x * x; - b_sq += y * y; - } - dot / (a_sq.sqrt() * b_sq.sqrt()) -} - -/// Render a vector as `[v0, v1, ..., vN-1, vN]` with the first 4 and last 1 elements at 4-decimal -/// precision. Keeps the output compact for high-dim embeddings while still giving the reader -/// something concrete to eyeball. -fn preview_vector(row: &[f32]) -> String { - let dim = row.len(); - if dim <= 5 { - return format!("[{}] (dim = {dim})", fmt_slice(row)); - } - format!( - "[{}, ..., {}] (dim = {dim})", - fmt_slice(&row[..4]), - fmt_slice(&row[dim - 1..]) - ) -} - -fn fmt_slice(s: &[f32]) -> String { - s.iter() - .map(|v| format!("{v:+.4}")) - .collect::>() - .join(", ") -} - -/// Wrap a query vector in a `Vector` extension scalar suitable for use as the RHS of a -/// `CosineSimilarity` filter expression via `lit(...)`. -fn build_query_vector_scalar(query: &[f32]) -> Result { - let element_dtype = DType::Primitive(PType::F32, Nullability::NonNullable); - let children: Vec = query - .iter() - .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) - .collect(); - let fsl_scalar = Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable); - Ok(Scalar::extension::(EmptyMetadata, fsl_scalar)) -} - -/// Decode a `Vector` extension array's storage down to its flat f32 buffer. -fn flatten_vector_column(emb: ArrayRef, ctx: &mut ExecutionCtx) -> Result> { - let ext: ExtensionArray = emb.execute(ctx)?; - ensure!(ext.ext_dtype().is::()); - - let fsl: FixedSizeListArray = ext.storage_array().clone().execute(ctx)?; - let elements: PrimitiveArray = fsl.elements().clone().execute(ctx)?; - Ok(elements.as_slice::().to_vec()) -} - -fn diff_stats(original: &[f32], decoded: &[f32]) -> (f32, f32) { - assert_eq!(original.len(), decoded.len()); - let (sum_abs, max_abs) = - original - .iter() - .zip(decoded) - .fold((0.0f32, 0.0f32), |(sum, peak), (&orig, &dec)| { - let diff = (orig - dec).abs(); - (sum + diff, peak.max(diff)) - }); - let mean_abs = sum_abs / original.len() as f32; - (max_abs, mean_abs) -}