Skip to content

feat: add additional TurboQuant kernel templates for enhanced flash a…#2

Merged
Vect0rM merged 1 commit into
feature/turboquant-kv-cachefrom
feature/sdf
Apr 2, 2026
Merged

feat: add additional TurboQuant kernel templates for enhanced flash a…#2
Vect0rM merged 1 commit into
feature/turboquant-kv-cachefrom
feature/sdf

Conversation

@Ooooze
Copy link
Copy Markdown

@Ooooze Ooooze commented Apr 2, 2026

…ttention support

Added new kernel templates for 512x512 dimensions across TurboQuant configurations (turbo2, turbo3, turbo4) to improve flash attention capabilities. This enhancement allows for better performance and flexibility in handling larger input sizes.
@Vect0rM Vect0rM merged commit 7c01058 into feature/turboquant-kv-cache Apr 2, 2026
11 of 45 checks passed
Vect0rM pushed a commit that referenced this pull request Apr 21, 2026
Codex post-commit review found:
1. TURBO_D was QK_TURBO3 (now 32) — broke turbo4 C array sizes
2. SET_ROWS kernel turbo3-specific but instantiated for turbo4
3. Tail block drop for non-128 head dims

Fixed #3 (TURBO_D). #1 and #2 don't affect turbo3+dk128 path.

Co-Authored-By: tturney@psyguard.ai
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Vect0rM pushed a commit that referenced this pull request Apr 21, 2026
Complete experiment log:
  #1  4-mag LUT:           15.1 at 8K (BEST, +38%)
  #2  Batched extract:     13.7 (+25%)
  #3  Inline FA block:     13.5 (I-cache pressure)
  #4  Deferred norm:       12.9 (loses ILP)
  #5  2-pair half2:        12.0 (ternary overhead)
  #6  Select chain:        11.9 (branches kill)
  #7  Bit-arithmetic:      11.6 (ALU too heavy)
  #8  FMA branchless:      11.4 (ALU still too heavy)
  #9  Named-reg ternary:   10.3 (branches worst)
  #10 Main (8-LUT):        10.95 (baseline)
  #11 Non-vec FA:          10.2 (wrong kernel)
  Ceiling:                 24.5 (no dequant)

Apple8 hardware truth:
  1 divergent constant read < 7 ALU ops (even with fma)
  Branches cost MORE than divergent constant reads
  Array indexing ALWAYS spills on Metal
  4 constant addresses is the sweet spot

The 4-mag LUT is the dequant-level ceiling on Apple Silicon.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: tturney@psyguard.ai
Vect0rM pushed a commit that referenced this pull request Apr 21, 2026
* vulkan: add TQ4_1S weight compression support

Adds Vulkan shader support for TQ4_1S (4-bit WHT-rotated weight
compression with 16 Lloyd-Max centroids, 32-element blocks).

Shaders:
- dequant_tq4_1s.comp: standalone dequant with WHT inverse via
  subgroupShuffleXor (32-thread workgroup, 5-stage butterfly)
- mul_mat_vec_tq4_1s.comp: specialized MUL_MAT_VEC with inline
  activation pre-rotation (forward RHT on activation, centroid*scale
  dequant without inverse RHT)
- copy_from_quant.comp: TQ4_1S dequant path with full WHT inverse
- copy_to_quant.comp: TQ4_1S SET_ROWS quantization path with forward
  RHT, dual half-block RMS scales, 16-centroid quantization
- types.glsl: block_tq4_1s struct (d0, d1, qs[16])
- dequant_funcs.glsl: TQ4_1S centroid*scale dequant (no RHT)

Pipeline wiring (ggml-vulkan.cpp):
- MUL_MAT, SET_ROWS, CPY supports_op
- pipeline_dequant, pipeline_set_rows, pipeline_cpy_quant_f32
- Specialized MUL_MAT_VEC with forced subgroup workgroup size

Tests:
- test_set_rows_tq4_1s: SET_ROWS round-trip validation

* vulkan: add fused mul_mat_vec kernel for TQ4_1S

Adds a specialised MUL_MAT_VEC shader for TQ4_1S weights so the
per-decode-step matrix-vector product no longer has to dequant the
full weight tensor to f16 and then go through the generic matmul
path.  The kernel pre-rotates the activation via a forward
Walsh-Hadamard Transform in shared memory and dot-products against
the raw centroid*scale stored weights, folding the inverse-WHT on
the weight side into the activation by the symmetry H = H^T.

Math:
  w[k] = sign[k] * INV_SQRT32 * (H @ stored)[k]
  sum_k w[k] * a[k] = INV_SQRT32 * sum_j stored[j] * (H @ (sign * a))[j]

Portability choices:

- Workgroup size is pinned to 32 threads regardless of the
  DMMV_WG_SIZE bucket the rest of the mul_mat_vec family picks for
  the current architecture.  The butterfly operates on 32-element
  blocks with one element per thread; that contract is fixed by the
  quantization format, not by the GPU.  Earlier revisions used
  `gl_WorkGroupSize.x` as the stride unit, which silently skipped
  half the work on Intel drivers that force the subgroup to 16
  (tests passed via NMSE tolerance while real inference output was
  garbage).

- Butterfly implementation is shared memory only.  A subgroup-shuffle
  variant (`subgroupShuffleXor`) was prototyped and measured on Intel
  Arc A380 with Mesa Xe HPG: it ran ~60-85 %% slower than the
  explicit shared-memory butterfly, because Mesa emulates subgroup
  shuffles via LDS and ends up doing the same LDS traffic with extra
  driver overhead.  The shared-memory butterfly is correct on every
  device regardless of subgroup-op support, is the fastest path on
  every device we can actually measure, and leaves the
  `pipeline_dequant_mul_mat_vec_f32_f32[w][TQ4_1S]` slot uniform
  across all DMMV_WG_SIZE buckets.

- Reduction is the shared-memory tree reduction (no subgroupAdd), for
  the same reason: on Intel Arc the subgroupAdd is also LDS-backed
  and the hybrid reduction path was measurably slower.  Future
  vendor-specific heuristics can switch to the hybrid or pure-subgroup
  reduction variants on NVIDIA / AMD RDNA if hardware subgroup ops
  turn out to beat the LDS roundtrip there; the existing reduction
  modes in `mul_mat_vec_base.glsl` already provide the necessary
  variants.

- NUM_ROWS is 8 so the butterfly cost amortises across 8 output rows
  per workgroup.  Each thread holds one position of each of the 8
  weight blocks and pairs them with the shared rotated activation.

- `mul_mm` and `flash_attn_cm2` shader generation is skipped for
  TQ4_1S because it is a weight-only format that never reaches the
  coopmat2 matmul or the KV cache flash-attention paths.

Tests:

- `test-backend-ops` MUL_MAT tolerance tightened from 2.0 to 0.01
  NMSE so real defects can't hide behind a loose check.
- Added Gemma-4 E2B, Qwen, Phi and Llama dimensional coverage
  (k in {1536, 2048, 2304, 3072, 4096}, m in {256, 1152, 1536,
  2048, 5120, 6144}, n in {1..8, 16, 64, 256}).  148 MUL_MAT test
  cases total.

Verification (Intel Arc A380, 6 GB VRAM, Vulkan ANV / Mesa Xe HPG,
`llama-bench -p 512 -n 128 -r 3` and `llama-perplexity -c 512
--chunks 20 wiki.test.raw`):

| Model         | Config  |     Size  | Reduction | PPL Δ  | pp512/Q8 | tg128/Q8 |
|---------------|---------|----------:|----------:|-------:|---------:|---------:|
| Qwen2.5-1.5B  | I       | 1570→1082 |   -31.1%  | +4.66% |    53.9% |   107.5% |
| Phi-3.5-mini  | I       | 3873→2839 |   -26.7%  | +5.36% |    57.6% |    52.8% |
| Llama-3.2-3B  | hybrid  | 3263→2147 |   -34.2%  | +2.03% |    82.4% |    84.2% |
| Llama-3.2-3B  | premium | 3263→2577 |   -21.0%  | +0.98% |    71.3% |    67.3% |

Qwen2.5-1.5B is faster than its own Q8_0 baseline with Config I:
the compressed model fits in less VRAM, and on a small model the
TQ4_1S compute cost is offset by the reduced memory traffic.

All four models produce coherent output end-to-end and the
reductions line up with the TurboQuant paper's validation matrix
(§5.8).  The remaining gap to Q8_0 on the bigger models is
compute-bound on the A380; it closes further on GPUs with more raw
throughput.

* vulkan: restructure TQ4_1S inner loop for cross-row smem reuse

Splits the dequant+accumulate phase into two sub-loops:

  1. Pre-compute w_vals[n] for all NUM_ROWS rows (centroid lookup +
     scale multiply, reads from weight buffer only).
  2. Read the rotated activation from shared memory ONCE per column,
     then FMA across all rows in a tight register loop.

This is the Vulkan analogue of the 'hot loop load dedup' from the
CUDA kernel (PR TheTom#57 optimisation #2).  It makes the shared memory
read explicitly loop-invariant across rows, which helps compilers
that don't auto-hoist LDS loads out of unrolled loops.

Measured effect on Intel Arc A380 (Llama-3.2-3B premium,
llama-bench tg128, r=5): 15.50 -> 15.78 t/s (+1.8%, within noise
but not a regression).  The structure is cleaner regardless and
should benefit architectures with higher LDS latency.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants