
Add Ministral-3-14B-Instruct-2512 (Leanstral) contrib model#134

Open
jimburtoft wants to merge 1 commit into aws-neuron:main from jimburtoft:contrib/leanstral-clean

Conversation

Contributor

@jimburtoft jimburtoft commented Apr 21, 2026

Summary

NxDI (NeuronX Distributed Inference) contrib for Ministral-3-14B-Instruct-2512 (Leanstral) — a 14B dense GQA text decoder running on trn2.3xlarge (TP=4, LNC=2) with SDK 2.29.

Key results:

  • 252 tok/s aggregate throughput at BS=4 (1.8x an H100 FP8 baseline)
  • Multi-KV-head token-generation (TKG) NKI kernel with a virtual-batch approach — matches baseline TPOT at BS=4, adds only 5-9% overhead at BS=8
  • FP8 E4M3 → BF16 text-only checkpoint extraction from the VL model
  • Full vLLM 0.16 benchmark suite (4 workloads × 2 concurrency levels × 2 batch sizes)
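The virtual-batch idea can be illustrated outside the kernel: fold the KV-head axis into the batch axis so that a single-KV-head token-generation kernel can process each (sequence, KV head) pair independently. A rough shape-level sketch in PyTorch — the actual NKI kernel operates on on-chip tiles, and the function name and tensor layouts here are assumptions:

```python
import torch

def to_virtual_batch(q, k_cache, v_cache, num_kv_heads):
    """Fold the KV-head axis into the batch axis ("virtual batch").

    Assumed layouts (hypothetical, for illustration only):
      q:       [B, num_q_heads, d]       -- one decode token per sequence
      k_cache: [B, num_kv_heads, S, d]
      v_cache: [B, num_kv_heads, S, d]
    Returns tensors shaped as if batch = B * num_kv_heads with a single
    KV head each, so a kv_heads=1 TKG kernel could process them.
    """
    B, num_q_heads, d = q.shape
    group = num_q_heads // num_kv_heads  # query heads per KV head (GQA group)
    # Group query heads by the KV head they attend with, then merge
    # the (batch, kv_head) axes into one virtual batch axis.
    vq = q.view(B, num_kv_heads, group, d).reshape(B * num_kv_heads, group, d)
    vk = k_cache.reshape(B * num_kv_heads, 1, k_cache.shape[2], d)
    vv = v_cache.reshape(B * num_kv_heads, 1, v_cache.shape[2], d)
    return vq, vk, vv
```

With BS=4 and 2 KV heads per rank, the kernel then sees a virtual batch of 8 single-KV-head problems.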

What's included

  • src/extract_text_model.py — FP8→BF16 text extraction, strips vision keys
  • src/setup_patches.py — 6 runtime patches for SDK 2.29 compatibility
  • src/attention_block_tkg_multi_kv.py — Multi-KV-head TKG kernel (NKI 0.3.0)
  • src/multi_kv_adapter.py — Adapter for attention_base.py
  • src/fix_nki030.py — NKI 0.3.0 compatibility
  • bench.py — Async streaming benchmark (TTFT, TPOT, tok/s)
  • Integration test
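For illustration, the kind of text-only extraction src/extract_text_model.py performs boils down to dropping vision-encoder keys and upcasting the remaining weights to BF16. This is a hypothetical minimal sketch — the key prefixes are assumptions, and the real script additionally deals with the FP8 E4M3 source format:

```python
import torch

def extract_text_state_dict(state_dict,
                            vision_prefixes=("vision_tower.",
                                             "multi_modal_projector.")):
    """Keep only text-decoder weights and upcast them to BF16.

    Sketch only: prefix names are assumptions, and the actual
    extract_text_model.py also handles FP8 E4M3 inputs.
    """
    out = {}
    for name, tensor in state_dict.items():
        if name.startswith(vision_prefixes):
            continue  # strip vision-encoder weights
        out[name] = tensor.to(torch.bfloat16)
    return out
```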

Architecture

  • 40 layers, hidden=5120, 32 Q / 8 KV heads, d_head=128
  • At TP=4: q_heads_per_rank=8, kv_heads_per_rank=2
  • Uses LlamaForCausalLM code path via --hf-overrides (vLLM 0.16 auto-promotes Mistral to Pixtral)
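The per-rank head counts above follow directly from dividing each head axis by the TP degree; a minimal sketch:

```python
def heads_per_rank(num_q_heads=32, num_kv_heads=8, tp=4):
    """Per-rank head counts under tensor parallelism.

    Assumes both head counts divide evenly by the TP degree,
    as they do for this model at TP=4.
    """
    assert num_q_heads % tp == 0 and num_kv_heads % tp == 0
    return num_q_heads // tp, num_kv_heads // tp

# At TP=4: 8 Q heads and 2 KV heads per rank; each KV head serves
# a GQA group of 32 / 8 = 4 query heads.
```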

Performance (trn2.3xlarge, TP=4, LNC=2, SDK 2.29)

BS=4 — TKG kernel matches baseline

| Workload | Baseline TPOT | TKG TPOT | Overhead |
| --- | --- | --- | --- |
| short-short (128/128) | 15.8 ms | 15.7 ms | -0.6% |
| long-long (2048/512) | 17.3 ms | 17.0 ms | -1.7% |

BS=8 — TKG kernel with modest overhead

| Workload | Baseline TPOT | TKG TPOT | Overhead |
| --- | --- | --- | --- |
| short-short (128/128) | 16.7 ms | 17.5 ms | +4.8% |
| long-long (2048/512) | 18.7 ms | 20.1 ms | +7.5% |
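For reference, the TTFT and TPOT figures an async streaming benchmark like bench.py reports reduce to simple timestamp arithmetic. A sketch of the metric definitions (not bench.py itself, which streams tokens asynchronously from the server):

```python
def ttft_tpot(token_times, request_start):
    """Compute TTFT and TPOT from per-token arrival timestamps.

    TTFT: delay from request submission to the first streamed token.
    TPOT: mean inter-token gap over the remaining decode tokens.
    """
    ttft = token_times[0] - request_start
    tpot = (token_times[-1] - token_times[0]) / (len(token_times) - 1)
    return ttft, tpot
```

At 15.8 ms TPOT, a single sequence decodes at roughly 63 tok/s; BS=4 aggregates to the ~252 tok/s quoted above.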

NxDI gaps identified

  1. TKG kernel hardcodes kv_heads=1 — this contrib adds multi-KV-head support
  2. convert_state_dict_to_fused_qkv assumes standard Llama head ratios
  3. RMS norm epsilon not passed from config to model base
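Gap 2 can be illustrated with a fusion helper that sizes the fused matrix from the actual Q:KV head counts rather than assuming Llama's ratios. A hypothetical sketch, not the NxDI convert_state_dict_to_fused_qkv implementation:

```python
import torch

def fuse_qkv_weights(q_w, k_w, v_w, num_q_heads, num_kv_heads, head_dim):
    """Fuse per-rank Q/K/V projection weights into one matrix.

    The fused output dimension is derived from the real head counts:
    (num_q_heads + 2 * num_kv_heads) * head_dim, e.g.
    (8 + 2 * 2) * 128 per rank for this model at TP=4.
    Sketch only; real fusion must also respect the shard layout.
    """
    assert q_w.shape[0] == num_q_heads * head_dim
    assert k_w.shape[0] == num_kv_heads * head_dim
    assert v_w.shape[0] == num_kv_heads * head_dim
    return torch.cat([q_w, k_w, v_w], dim=0)
```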

@jimburtoft jimburtoft force-pushed the contrib/leanstral-clean branch 3 times, most recently from 553d114 to 5f2f28e on April 26, 2026 at 05:29
14B dense GQA text decoder on trn2.3xlarge (TP=4, LNC=2, SDK 2.29).
Includes FP8-to-BF16 text extraction, multi-KV-head TKG NKI kernel with
virtual-batch approach, and full vLLM 0.16 benchmark suite.

At BS=4: 252 tok/s aggregate (1.8x H100 FP8). TKG kernel matches baseline
TPOT at BS=4, adds 5-9% overhead at BS=8.
@jimburtoft jimburtoft force-pushed the contrib/leanstral-clean branch from 5f2f28e to bc8f36c on April 26, 2026 at 08:10