Merged
31 commits
0df9f41
Port v1.3 codec (randomized PCA + --dump-decoded) + HF-based PPL harness
cursoragent Apr 21, 2026
953a8bf
Add vLLM-based v1.3 PPL validation harness
cursoragent Apr 21, 2026
a40755f
e2e_ppl_validation_vllm: match vLLM 0.7.3 Attention.forward signature
cursoragent Apr 21, 2026
58b61cd
run_v1_3_ppl_vllm.sh: auto-select /venv/main/bin/python when present
cursoragent Apr 21, 2026
31abaf1
v1.3 PPL on vLLM: smoke result on H200 — REJECT (+292% ppl)
cursoragent Apr 21, 2026
83220dd
Drop V0 bare-v1.3 datapoint from this PR
cursoragent Apr 21, 2026
5873f11
Port Rust codec guardrails from PR #13 (521e97b state)
cursoragent Apr 21, 2026
7c2d96e
Copy calibration artifacts + Python helpers from PR #13
cursoragent Apr 21, 2026
aec2284
Extend vLLM harness with full v1.3 PPL production recipe
cursoragent Apr 21, 2026
4fc0f36
v1.3 PPL full recipe on vLLM: DS-Distill, 4 passages, H200 result
cursoragent Apr 21, 2026
7d705e0
vLLM PPL ablation harness: isolate Q-precond/FA mismatch vs noise
cursoragent Apr 21, 2026
04190bc
vLLM ablation: H2 (noise) ruled out; H1 (post-RoPE Q-precond) wrong d…
cursoragent Apr 21, 2026
f92ffd9
Refit Σ_q + Lloyd-Max centroids from vLLM prefill snapshots
cursoragent Apr 21, 2026
cdccf72
vllm_calibration_refit: avoid HF-only imports when loading LM helpers
cursoragent Apr 21, 2026
f2065ff
vllm_calibration_refit: set __file__ in the synthetic namespace
cursoragent Apr 21, 2026
5d4809d
vLLM-recalibrated artifacts for DS-Distill (H200, 8 train passages)
cursoragent Apr 21, 2026
dda2e4e
Re-run ablation with vLLM-calibrated artifacts: H3 ruled out
cursoragent Apr 21, 2026
d616cf4
Remove 4-cell ablation harness and its reports
cursoragent Apr 21, 2026
642863d
Consolidate to one standing cell: codec-pre_qp HF-cal vs vLLM-cal
cursoragent Apr 21, 2026
0310a96
run_v1_3_ppl_full_vllm.sh: allow ATTN_BACKEND override
cursoragent Apr 21, 2026
3feaad2
H4 & H5 ablations: XFORMERS backend + prefix-only codec mode
cursoragent Apr 21, 2026
bcfa933
H5 result: prefix-only codec mode gives the same Δppl — fal…
cursoragent Apr 21, 2026
bb9df76
run_v1_3_ppl_full_vllm.sh: allow K_CENTROIDS='' to skip flag
cursoragent Apr 21, 2026
cdbd7c1
Full H3/H4/H5 ablation + K/V rate strategy sweep — V-side is the…
cursoragent Apr 21, 2026
867803e
Keep only the codec-pre_qp production cell; add --compress-stream K/V…
cursoragent Apr 21, 2026
b5780e9
Per-channel attribution on the production cell: K carries 2/3 of …
cursoragent Apr 21, 2026
3186b5f
Add V outlier-compensation knob for symmetric V guardrails
cursoragent Apr 21, 2026
546d51b
V-only with outlier T=2.0: Δppl +11.10 → +7.04 (−4 pp)
cursoragent Apr 21, 2026
301bcd6
Snapshot-mode vLLM harness — scenario A (compress post-prefill KV)
cursoragent Apr 22, 2026
c612063
Snapshot-mode vLLM result: Δppl +29.07% / top1 74.22% (was +35.3…
cursoragent Apr 22, 2026
bb6b823
Compare PR #17 (snapshot-mode) against vLLM upstream PR #38479 (Turbo…
cursoragent Apr 22, 2026
464 changes: 464 additions & 0 deletions benchmarks/e2e_ppl_validation.py

Large diffs are not rendered by default.

637 changes: 637 additions & 0 deletions benchmarks/e2e_ppl_validation_vllm.py

Large diffs are not rendered by default.

690 changes: 690 additions & 0 deletions benchmarks/e2e_ppl_validation_vllm_full.py

Large diffs are not rendered by default.

555 changes: 555 additions & 0 deletions benchmarks/e2e_ppl_validation_vllm_snapshot.py

Large diffs are not rendered by default.

355 changes: 355 additions & 0 deletions benchmarks/lloyd_max_calibration.py

Large diffs are not rendered by default.

239 changes: 239 additions & 0 deletions benchmarks/q_calibration.py
@@ -0,0 +1,239 @@
#!/usr/bin/env python3
"""Q calibration for Q-preconditioned PCA.

For every full-attention layer l and every kv-head h_k we compute

    Sigma_q^{(l, h_k)} = sum_{h_q in group(h_k)} E[ q^{(l, h_q)} (q^{(l, h_q)})^T ]

(the code additionally divides by the group size, which only rescales
Sigma_q), where q is the *pre-RoPE* query (right after q_proj, no rotary).
That is what the codec's k_pre sees when attention is factored as

    q_post^T k_post = q_pre^T R_m^T R_n k_pre = q_pre^T R(n - m) k_pre

for a query at position m and a key at position n. The coupling with k_pre
under marginalised relative position is the pre-RoPE Gram matrix (to
leading order: exact at zero position lag, orthogonally mixed at nonzero
lags, which is why we collect q_pre over many positions and many prompts,
not just one).

We then Cholesky-factor each Sigma_q and store L (lower triangular).
The downstream codec pipeline whitens K via `K_tilde = K @ L` before
encode and unwhitens via `K_hat = K_hat_tilde @ L^{-1}` after decode.
This is mathematically identical to using a Sigma_q-weighted distortion
in PCA / K-means / Lloyd-Max, but requires zero Rust codec change.

Math. With E = K - K_hat and e_i its i-th row, we want to minimise

    sum_i e_i^T Sigma_q e_i = tr(E Sigma_q E^T) = || E L ||_F^2
                            = || K L - K_hat L ||_F^2

    Whiten before codec:    K_tilde = K @ L
    Un-whiten after decode: K_hat   = K_hat_tilde @ L^{-1}

where L is the lower-triangular Cholesky factor of Sigma_q (L L^T = Sigma_q).

Output file format (safetensors):

layer_<l>_chol : [n_kv, D, D] fp32 (lower triangular L)
layer_<l>_inv_chol : [n_kv, D, D] fp32 (lower triangular L^{-1})
layer_<l>_sigma : [n_kv, D, D] fp32 (for diagnostics only)

plus a `config.json` sidecar with shapes and axis meanings.
"""
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

import numpy as np
import torch

REPO = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(REPO))

from transformers import AutoModelForCausalLM, AutoTokenizer

import benchmarks.pre_rope_cache as prc
from benchmarks.e2e_ppl_pre_rope import load_wikitext_passages


def _kv_group_ranges(num_q_heads: int, num_kv_heads: int):
    """For each kv-head, return the list of query-head indices that share it."""
    assert num_q_heads % num_kv_heads == 0, "non-evenly-divisible GQA unsupported"
    group_size = num_q_heads // num_kv_heads
    return [list(range(h * group_size, (h + 1) * group_size))
            for h in range(num_kv_heads)]


@torch.inference_mode()
def calibrate(model_path: str, out_path: Path, *,
              n_passages: int, ctx_len: int, prefill_chunk: int,
              ridge: float):
    print(f"loading {model_path}…", flush=True)
    tok = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, dtype=torch.bfloat16, attn_implementation="eager"
    )
    model.eval()
    info = prc.install(model)
    print(f" patched {info['patched_layers']} attention layers", flush=True)

    cfg = model.config.get_text_config(decoder=True)
    # Qwen3 and similar families set an explicit head_dim that differs
    # from hidden_size // num_attention_heads. Honor the explicit value
    # when present; fall back to the classical formula otherwise.
    head_dim = getattr(cfg, "head_dim", None) or (cfg.hidden_size // cfg.num_attention_heads)
    n_q = cfg.num_attention_heads
    n_kv = cfg.num_key_value_heads
    groups = _kv_group_ranges(n_q, n_kv)
    n_layers = cfg.num_hidden_layers
    layer_types = getattr(cfg, "layer_types", None) or (
        ["full_attention"] * n_layers
    )
    print(f" D={head_dim}, n_q={n_q}, n_kv={n_kv}, groups per kv = {len(groups[0])}")

    passages = load_wikitext_passages(tok, ctx_len, n_passages)
    print(f" using {len(passages)} WikiText passages of >= {ctx_len} tokens", flush=True)

    # Accumulate gram sums, not per-passage tensors: cheap and bounded memory.
    # gram[l][h_kv] has shape [D, D] and holds sum_{all tokens in group} q q^T.
    gram = [[np.zeros((head_dim, head_dim), dtype=np.float64)
             for _ in range(n_kv)]
            for _ in range(n_layers)]
    count = [0 for _ in range(n_layers)]

    for i, passage in enumerate(passages):
        ids = tok(passage, return_tensors="pt")["input_ids"][:, :ctx_len]
        if ids.shape[-1] < ctx_len:
            print(f" passage {i+1}: SKIP (too short: {ids.shape[-1]})")
            continue

        # Reset recorder for this passage so we can consume incrementally
        # without letting memory pile up at long context.
        cfg._q_recorder = {}

        # Prefill the full passage in one shot. use_cache=False means no KV
        # cache is kept (we only want Q stats), but attention still needs the
        # full context, so `prefill_chunk` does not change the forward pass
        # (bf16 CPU is fine at ctx=1024-2048).
        _ = model(input_ids=ids, use_cache=False)

        # Consume recorder: accumulate into grams, then drop refs.
        for l_idx, q_list in cfg._q_recorder.items():
            if layer_types[l_idx] != "full_attention":
                continue
            # q_list entries: [bsz=1, n_q, seq, D]
            for q in q_list:
                q = q[0]  # [n_q, seq, D]
                for h_kv in range(n_kv):
                    # Sum over all q heads in this kv group
                    group = groups[h_kv]
                    q_group = q[group]  # [g, seq, D]
                    # Cast to fp32 before .numpy(): NumPy has no bfloat16
                    # dtype, in case the recorder keeps the model dtype.
                    q_flat = q_group.reshape(-1, head_dim).float().numpy()  # [g*seq, D]
                    gram[l_idx][h_kv] += q_flat.T @ q_flat
                count[l_idx] += q.shape[1]  # seq, same for every layer
        cfg._q_recorder = None
        print(f" passage {i+1}/{len(passages)}: processed", flush=True)

    # Finalize: Sigma = gram / (group_size * total_tokens_seen),
    # then L = chol(Sigma + ridge * mean_diag * I) and its inverse L^{-1}.
    print(f"\nfactoring Sigma_q for {n_layers} layers × {n_kv} kv-heads…",
          flush=True)
    out_tensors = {}
    diagnostics = []
    for l in range(n_layers):
        if layer_types[l] != "full_attention":
            continue
        if count[l] == 0:
            continue
        chol_stack = np.zeros((n_kv, head_dim, head_dim), dtype=np.float32)
        inv_chol_stack = np.zeros_like(chol_stack)
        sigma_stack = np.zeros_like(chol_stack)
        group_size = len(groups[0])
        n_tokens_per_head = group_size * count[l]
        for h_kv in range(n_kv):
            sigma = gram[l][h_kv] / n_tokens_per_head
            # Symmetrize against fp drift
            sigma = 0.5 * (sigma + sigma.T)
            # Ridge for numerical stability: ridge * mean_diag
            mean_diag = float(np.mean(np.diag(sigma)))
            sigma_reg = sigma + ridge * mean_diag * np.eye(head_dim)
            # Lower triangular, L L^T = sigma_reg (Sigma plus the ridge term)
            L = np.linalg.cholesky(sigma_reg)
            # Inverse of L (also lower triangular) for the unwhitening post-decode step.
            L_inv = np.linalg.solve(L, np.eye(head_dim))
            chol_stack[h_kv] = L.astype(np.float32)
            inv_chol_stack[h_kv] = L_inv.astype(np.float32)
            sigma_stack[h_kv] = sigma.astype(np.float32)

            evals = np.linalg.eigvalsh(sigma_reg)
            diagnostics.append({
                "layer": l, "kv_head": h_kv,
                "sigma_trace": float(np.trace(sigma)),
                "eig_min": float(evals.min()),
                "eig_max": float(evals.max()),
                "condition": float(evals.max() / max(evals.min(), 1e-30)),
                "diag_mean": mean_diag,
                "off_diag_max_abs": float(np.abs(sigma - np.diag(np.diag(sigma))).max()),
            })
        out_tensors[f"layer_{l}_chol"] = torch.from_numpy(chol_stack)
        out_tensors[f"layer_{l}_inv_chol"] = torch.from_numpy(inv_chol_stack)
        out_tensors[f"layer_{l}_sigma"] = torch.from_numpy(sigma_stack)

    out_path.parent.mkdir(parents=True, exist_ok=True)
    from safetensors.torch import save_file
    save_file(out_tensors, str(out_path))

    cfg_sidecar = out_path.with_suffix(".json")
    cfg_sidecar.write_text(json.dumps({
        "model_path": model_path,
        "head_dim": head_dim,
        "num_q_heads": n_q,
        "num_kv_heads": n_kv,
        "num_layers": n_layers,
        "layer_types": layer_types,
        # count[l] holds tokens per layer and every used passage contributes
        # exactly ctx_len tokens, so this recovers the number of passages.
        # (Counting nonzero entries of `count` would count layers instead.)
        "n_passages_used": max(count, default=0) // ctx_len,
        "ctx_len": ctx_len,
        "ridge": ridge,
        "diagnostics": diagnostics,
    }, indent=2))
    print(f"wrote {out_path} + {cfg_sidecar}", flush=True)

    # Summary stats: is Sigma_q anisotropic, or close to isotropic?
    conds = [d["condition"] for d in diagnostics]
    off_ratios = [
        d["off_diag_max_abs"] / max(d["diag_mean"], 1e-30)
        for d in diagnostics
    ]
    print("\nSigma_q anisotropy summary (across all (layer, kv_head) pairs):")
    print(f" condition number: min={min(conds):.2f} "
          f"median={np.median(conds):.2f} max={max(conds):.2f}")
    print(f" max(|off-diag|)/mean_diag: min={min(off_ratios):.3f} "
          f"median={np.median(off_ratios):.3f} max={max(off_ratios):.3f}")
    print(" (condition ≫ 1 or off/diag ≫ 0 ⇒ Sigma_q is anisotropic, "
          "so Q-precondition has something to do)")


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--model-path", required=True)
    ap.add_argument("--out-path", type=Path, required=True,
                    help=".safetensors file path for the calibration")
    ap.add_argument("--n-passages", type=int, default=8)
    ap.add_argument("--ctx-len", type=int, default=2048)
    ap.add_argument("--prefill-chunk", type=int, default=0)
    ap.add_argument("--ridge", type=float, default=1e-3,
                    help="Cholesky ridge = ridge * mean_diag(Sigma) for numerical stability")
    args = ap.parse_args()
    calibrate(
        args.model_path, args.out_path,
        n_passages=args.n_passages, ctx_len=args.ctx_len,
        prefill_chunk=args.prefill_chunk, ridge=args.ridge,
    )


if __name__ == "__main__":
    main()
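
The identity in the module docstring of `benchmarks/q_calibration.py` (Sigma_q-weighted distortion on K equals plain squared error on `K @ L`) can be checked numerically. The following is a minimal sketch on synthetic data, not part of this PR: the sizes, the random "queries", and the additive-noise stand-in for the codec are all assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
D, n_tokens = 8, 64  # illustrative sizes, not taken from any model

# Synthetic anisotropic "queries" stand in for recorded q_pre vectors.
Q = rng.normal(size=(4096, D)) * np.linspace(0.5, 3.0, D)
sigma_q = Q.T @ Q / Q.shape[0]      # empirical Gram matrix E[q q^T]
L = np.linalg.cholesky(sigma_q)     # lower triangular, L @ L.T == sigma_q

K = rng.normal(size=(n_tokens, D))
K_hat = K + 0.05 * rng.normal(size=K.shape)  # stand-in for a lossy round trip
E = K - K_hat

# sum_i e_i^T Sigma_q e_i, written as a trace
weighted = np.trace(E @ sigma_q @ E.T)
# plain squared error measured in the whitened space
whitened_mse = np.sum((K @ L - K_hat @ L) ** 2)
# mean squared logit error (q . e_i)^2 over the sample queries, summed over tokens
logit_err = np.mean((Q @ E.T) ** 2, axis=0).sum()

print(weighted, whitened_mse, logit_err)
```

All three quantities agree up to floating-point round-off, which is why minimising ordinary MSE on the whitened keys targets the query-weighted inner-product error.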
141 changes: 141 additions & 0 deletions benchmarks/q_precondition.py
@@ -0,0 +1,141 @@
"""Q-preconditioning utility for the K-stream codec.

This module owns the math

K_tilde = K @ L (whiten)
K_hat = K_hat_tilde @ L^{-1} (un-whiten)

where L is the lower-triangular Cholesky factor of the per-(layer, kv_head)
query Gram matrix Sigma_q produced by `benchmarks/q_calibration.py`.

Minimising standard MSE on K_tilde is mathematically equivalent to
minimising the Sigma_q-weighted distortion on K, which is exactly the
"InnerProduct on K" failure metric that v1.3 paper §2 claims but the
current codec (per-coordinate MSE on K) does not actually enforce.

The Rust codec is not touched. Whitening is done in Python, right
before the tensor is serialised to the KKTV format and sent into the
bench binary. Unwhitening is done right after the decoded tensor is
read back.
"""
from __future__ import annotations

import json
from pathlib import Path
from typing import Optional

import numpy as np
import torch


class QPrecond:
    """Per-layer, per-kv-head Cholesky factors L and L^{-1} held in fp32 cpu.

    Shapes:
        chol[l]     : [n_kv, D, D] lower triangular
        inv_chol[l] : [n_kv, D, D] lower triangular (L^{-1})

    Layers not present in the dict (e.g. skipped via `skip_layers`) use
    an identity transform (no-op), so the codec falls back to its plain
    Euclidean behaviour on those layers.
    """

    def __init__(self, path: str | Path, skip_layers: list[int] | None = None):
        path = Path(path)
        cfg = json.loads(path.with_suffix(".json").read_text())
        self.head_dim = cfg["head_dim"]
        self.n_kv = cfg["num_kv_heads"]
        self.n_layers = cfg["num_layers"]
        self.layer_types = cfg["layer_types"]
        self.skip_layers = set(skip_layers or [])
        from safetensors.torch import load_file
        tensors = load_file(str(path))
        self.chol: dict[int, np.ndarray] = {}
        self.inv_chol: dict[int, np.ndarray] = {}
        for l in range(self.n_layers):
            if l in self.skip_layers:
                continue
            k_chol = f"layer_{l}_chol"
            k_inv = f"layer_{l}_inv_chol"
            if k_chol in tensors:
                self.chol[l] = tensors[k_chol].numpy().astype(np.float32)
                self.inv_chol[l] = tensors[k_inv].numpy().astype(np.float32)

    @property
    def n_calibrated_layers(self) -> int:
        return len(self.chol)

    def is_active(self, layer: int) -> bool:
        """Layer has a calibrated Cholesky; whiten/unwhiten are non-trivial."""
        return layer in self.chol

    def whiten(self, k_per_head: np.ndarray, layer: int) -> np.ndarray:
        """Input: K with shape [seq, n_kv, D] (pre-RoPE, one passage, one layer).
        Output: same shape, whitened per kv-head. No-op for layers not
        in the calibration (e.g. layer 0 when skip_layers=[0])."""
        assert k_per_head.ndim == 3
        assert k_per_head.shape[1] == self.n_kv
        assert k_per_head.shape[2] == self.head_dim
        if layer not in self.chol:
            return k_per_head.astype(np.float32, copy=False)
        L = self.chol[layer]  # [n_kv, D, D]
        # K[t, h, :] @ L[h, :, :] → K_tilde[t, h, :]
        return np.einsum("thj,hjk->thk", k_per_head, L, optimize=True).astype(
            np.float32, copy=False
        )

    def unwhiten(self, k_tilde_per_head: np.ndarray, layer: int) -> np.ndarray:
        """Inverse of `whiten`. Applies L^{-1} on the right per kv-head.
        No-op for layers not in the calibration."""
        assert k_tilde_per_head.ndim == 3
        if layer not in self.inv_chol:
            return k_tilde_per_head.astype(np.float32, copy=False)
        Linv = self.inv_chol[layer]
        return np.einsum("thj,hjk->thk", k_tilde_per_head, Linv, optimize=True).astype(
            np.float32, copy=False
        )


def sanity_check(qp: QPrecond) -> dict:
    """Verify unwhiten(whiten(x)) ≈ x on random data for every calibrated layer.

    Reports max absolute error and max relative error (per layer). A
    correctly-built QPrecond will have errors ≲ 1e-5 (fp32 round-off).
    """
    rng = np.random.default_rng(0)
    results = []
    for l in qp.chol:
        x = rng.normal(size=(128, qp.n_kv, qp.head_dim)).astype(np.float32)
        x_tilde = qp.whiten(x, l)
        x_back = qp.unwhiten(x_tilde, l)
        abs_err = float(np.abs(x - x_back).max())
        rel_err = float(np.abs(x - x_back).max() / max(np.abs(x).max(), 1e-30))
        results.append({"layer": l, "max_abs_err": abs_err, "max_rel_err": rel_err})
    # default=0.0 keeps this well-defined when no layer is calibrated
    # (e.g. every layer was listed in skip_layers).
    return {
        "max_abs_err": max((r["max_abs_err"] for r in results), default=0.0),
        "max_rel_err": max((r["max_rel_err"] for r in results), default=0.0),
        "per_layer": results,
    }


def load(path: Optional[str | Path],
         skip_layers: list[int] | None = None) -> Optional[QPrecond]:
    if path is None:
        return None
    return QPrecond(path, skip_layers=skip_layers)


if __name__ == "__main__":
    import argparse
    ap = argparse.ArgumentParser()
    ap.add_argument("--calib", required=True)
    args = ap.parse_args()
    qp = QPrecond(args.calib)
    print(f"loaded {qp.n_calibrated_layers} layers, n_kv={qp.n_kv}, D={qp.head_dim}")
    san = sanity_check(qp)
    print(f"sanity: max_abs_err={san['max_abs_err']:.3e} "
          f"max_rel_err={san['max_rel_err']:.3e}")
    # Also show anisotropy of L directly (not Sigma)
    for l in sorted(qp.chol)[:3]:
        L = qp.chol[l][0]  # kv-head 0
        print(f" layer {l} kv0 L diag[:5]: {np.diag(L)[:5]}")
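
The per-head einsum used by `QPrecond.whiten` and `QPrecond.unwhiten` can be exercised without a calibration file. The sketch below is illustrative only: the random SPD matrices are a hypothetical stand-in for calibrated Sigma_q, and the shapes are arbitrary. It checks that `"thj,hjk->thk"` matches an explicit per-head `K[:, h, :] @ L[h]` and that the round trip recovers K.

```python
import numpy as np

rng = np.random.default_rng(0)
seq, n_kv, D = 16, 2, 8  # illustrative shapes only

# Hypothetical per-head SPD matrices standing in for calibrated Sigma_q.
A = rng.normal(size=(n_kv, D, D))
sigma = A @ A.transpose(0, 2, 1) / D + 1e-3 * np.eye(D)
chol = np.linalg.cholesky(sigma)   # batched: [n_kv, D, D], lower triangular
inv_chol = np.linalg.inv(chol)     # batched inverse, lower triangular up to round-off

K = rng.normal(size=(seq, n_kv, D)).astype(np.float32)

# Same contraction as QPrecond.whiten / QPrecond.unwhiten.
K_tilde = np.einsum("thj,hjk->thk", K, chol)
K_back = np.einsum("thj,hjk->thk", K_tilde, inv_chol)

# Reference: explicit per-head matrix product.
per_head = np.stack([K[:, h, :] @ chol[h] for h in range(n_kv)], axis=1)
max_err = float(np.abs(K - K_back).max())
print(f"round-trip max abs error: {max_err:.3e}")
```

Storing both `chol` and `inv_chol`, as the calibration file does, avoids a triangular solve on every decode at the cost of one extra `[n_kv, D, D]` tensor per layer.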