
Add Kimi-K2-Instruct-0905 contrib model (1T MoE on trn2.48xlarge)#131

Open
jimburtoft wants to merge 8 commits into aws-neuron:main from jimburtoft:contrib/kimi-k2-instruct-0905

Conversation

@jimburtoft
Contributor

@jimburtoft jimburtoft commented Apr 17, 2026

Summary

  • Adds a full NxDI contrib implementation of Kimi-K2-Instruct-0905 (moonshotai/Kimi-K2-Instruct-0905), a 1 trillion parameter MoE model based on the DeepSeek-V3 architecture
  • Validated on trn2.48xlarge with blockwise FP8 quantization
  • Recommended config: LNC=2, TP=32, EP=2 (64 cores) — 76% faster than LNC=1/TP=64
  • All 4 integration tests pass: smoke, generation, coherence, and TPOT performance

Model Details

| Parameter | Value |
| --- | --- |
| Total parameters | ~1,000B |
| Active parameters/token | ~32B |
| Architecture | DeepSeek-V3 variant with MLA attention |
| Experts | 384 routed (8 active) + 1 shared per MoE layer |
| Quantization | Blockwise FP8 (e4m3, 128x128 blocks) |
| Attention | Multi-Latent Attention (MLA) with KV LoRA rank=512 |

Benchmark Results (BS=1, seq_len=1024)

LNC=2 vs LNC=1 Comparison

| Config | TP | EP | Cores | TPOT | tok/s | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| LNC=2 (recommended) | 32 | 2 | 64 | 165.5 ms | 6.0 | +76% |
| LNC=1 | 64 | 2 | 128 | 297 ms | 3.4 | baseline |

LNC=2 gives 2x HBM bandwidth per logical core. Since MoE decode is purely bandwidth-bound (192 expert weight loads per step), this translates directly to throughput improvement.
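As a sanity check on the comparison above (plain arithmetic on the reported numbers, not code from this PR): tok/s is just the reciprocal of TPOT, and the speedup is the TPOT ratio.

```python
# Throughput and speedup derived from the reported TPOT values (BS=1).
tpot_ms = {"LNC=2": 165.5, "LNC=1": 297.0}

tok_s = {k: 1000.0 / v for k, v in tpot_ms.items()}  # {'LNC=2': ~6.0, 'LNC=1': ~3.4}
speedup = tpot_ms["LNC=1"] / tpot_ms["LNC=2"] - 1.0  # ~0.79; the table's +76% uses the rounded tok/s (6.0 / 3.4)

print(tok_s, f"+{speedup:.0%}")
```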

Full Benchmark Sweep (LNC=2)

| Output Tokens | TTFT P50 (ms) | TPOT P50 (ms) | tok/s | E2E P50 (ms) |
| --- | --- | --- | --- | --- |
| 16 | 1,420.4 | 166.38 | 6.0 | 3,916.1 |
| 32 | 1,419.8 | 165.58 | 6.0 | 6,553.0 |
| 64 | 1,419.7 | 165.56 | 6.0 | 11,849.8 |
| 128 | 1,419.8 | 165.48 | 6.0 | 22,435.8 |
| 256 | 1,419.9 | 165.42 | 6.0 | 43,604.1 |
| 512 | 1,420.0 | 165.47 | 6.0 | 85,974.4 |
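The sweep is consistent with a simple latency model, E2E ≈ TTFT + (N - 1) * TPOT (first token from context encoding, remaining N - 1 from token generation). A quick check against two rows of the table (plain arithmetic, not code from this PR):

```python
# (output_tokens, TTFT P50 ms, TPOT P50 ms, reported E2E P50 ms) from the table above
rows = [(16, 1420.4, 166.38, 3916.1), (512, 1420.0, 165.47, 85974.4)]

for n, ttft, tpot, e2e in rows:
    predicted = ttft + (n - 1) * tpot   # first token from prefill, the rest at TPOT each
    print(n, round(predicted, 1), e2e)  # 16: 3916.1 vs 3916.1; 512: 85975.2 vs 85974.4
```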

Files Added

```
contrib/models/Kimi-K2-Instruct-0905/
├── README.md                          # Full model info, architecture, benchmarks, usage
├── src/
│   ├── __init__.py                    # Exports NeuronKimiK2ForCausalLM, KimiK2InferenceConfig
│   └── modeling_kimi_k2.py            # Complete model implementation (1548 lines)
└── test/
    ├── __init__.py
    ├── integration/
    │   ├── __init__.py
    │   └── test_model.py              # 4 integration tests
    └── unit/
        └── __init__.py
```

Key Implementation Notes

  • Streaming checkpoint loader: Processes the 62 safetensor shards one at a time so the full checkpoint is never resident at once, avoiding OOM even with 2TB of host RAM (sketched after this list)
  • Custom FP8 handling: Blockwise symmetric quantization with 128x128 blocks; non-expert modules are dequantized to BF16 (also sketched after this list)
  • MLA weight absorption: Avoids KV decompression during decode for efficient inference
  • CPU greedy sampling (no_on_device_sampling): Required because the vocabulary size (163840) exceeds the on-device sampling limits
  • Router with bias: Uses RouterTopK(bias=True) to load e_score_correction_bias as linear_router.bias
  • Selective loading threshold: Requires patching the threshold to 0.0 in neuronx_distributed/modules/moe/model_utils.py
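A minimal sketch of the first two notes above: iterate the 62 shards one at a time with safetensors, and dequantize a blockwise-FP8 weight to BF16 by expanding its 128x128 block scales. Helper names and paths are illustrative; this is not the PR's actual loader code.

```python
import glob
import torch
from safetensors import safe_open

BLOCK = 128  # blockwise FP8 scales cover 128x128 weight tiles

def dequantize_blockwise_fp8(weight_fp8: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Expand per-block scales to element granularity and dequantize to BF16."""
    w = weight_fp8.to(torch.float32)
    s = scale.repeat_interleave(BLOCK, dim=0)[: w.shape[0]]
    s = s.repeat_interleave(BLOCK, dim=1)[:, : w.shape[1]]
    return (w * s).to(torch.bfloat16)

def stream_checkpoint(ckpt_dir: str):
    """Yield (name, tensor) pairs one shard at a time so the checkpoint is never fully resident."""
    for shard in sorted(glob.glob(f"{ckpt_dir}/*.safetensors")):  # 62 shards for Kimi-K2
        with safe_open(shard, framework="pt", device="cpu") as f:
            for name in f.keys():
                yield name, f.get_tensor(name)
```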

Known Limitations

  • EOS logit elevation: The <|im_end|> token (163586) has an elevated logit in context encoding outputs. Mitigated with the min_tokens_before_eos parameter (see the sketch after this list); documented in the README.
  • Batch size: Higher batch sizes provide no throughput improvement because decode is MoE bandwidth-bound (TPOT scales linearly with BS)
  • Instance requirement: Requires trn2.48xlarge; the model is too large for smaller instances
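A hypothetical sketch of the min_tokens_before_eos mitigation mentioned above: suppress the <|im_end|> logit until a minimum number of tokens has been generated. The class and its wiring are illustrative, not the PR's actual implementation.

```python
import torch

IM_END_TOKEN_ID = 163586  # <|im_end|> in the Kimi-K2 tokenizer

class MinTokensBeforeEOS:
    """Mask the EOS logit until at least `min_tokens` tokens have been generated."""

    def __init__(self, min_tokens: int, eos_token_id: int = IM_END_TOKEN_ID):
        self.min_tokens = min_tokens
        self.eos_token_id = eos_token_id

    def __call__(self, generated_len: int, logits: torch.Tensor) -> torch.Tensor:
        if generated_len < self.min_tokens:
            logits = logits.clone()
            logits[..., self.eos_token_id] = float("-inf")
        return logits
```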

Testing

```bash
# On trn2.48xlarge with NEURON_LOGICAL_NC_CONFIG=2:
NEURON_LOGICAL_NC_CONFIG=2 LOCAL_WORLD_SIZE=64 pytest test_model.py -v --capture=tee-sys
```

Onboard moonshotai/Kimi-K2-Instruct-0905 (1T MoE, 384 experts, MLA
attention, DeepSeek-V3 architecture) to NxDI on trn2.48xlarge.

Configuration: TP=64, EP=2, LNC=1, blockwise FP8 (e4m3, 128x128 blocks)
Performance: 3.4 tok/s at BS=1, TPOT=297.5ms, TTFT=1,788ms

Key implementation details:
- Multi-Latent Attention with compressed KV cache (576 bytes/token/layer)
- Blockwise FP8 quantization for routed expert weights (non-experts in BF16)
- Streaming checkpoint loader for 62 safetensor shards (avoids OOM)
- Sigmoid routing with e_score_correction_bias loaded as router bias
- Monkey patches for EP scale sharding and blockwise scale stride

Tested on: trn2.48xlarge, Neuron SDK 2.28, us-east-2
@jimburtoft jimburtoft force-pushed the contrib/kimi-k2-instruct-0905 branch from c806415 to d599eff on April 17, 2026 01:15
LNC=2 gives 2x HBM bandwidth per logical core on trn2.48xlarge.
With TP=32 and EP=2 (64 ranks), NEFF I/O fits at 17.55 GB / 24 GB.
TPOT improves from 297 ms to ~191 ms (purely bandwidth-bound MoE).
…r than LNC=1)

Full sweep across 16-512 output tokens shows rock-stable TPOT.
TTFT P50 = 1,420 ms. Throughput is 1.8x the LNC=1 baseline.
Comment on lines +306 to +311
- **Batching does not improve throughput:** NxDI compiles HLO with per-sequence shapes
(`[1, seq_len]` for CTE, `[1, 1]` for TKG) regardless of `max_batch_size`. Multiple
sequences in a batch are processed sequentially through the same NEFF. Combined with
the bandwidth-bound nature of MoE (192 expert weight loads per decode step), BS>1
provides no aggregate throughput benefit. Verified: BS=2 compile produces identical
NEFF shapes to BS=1.
Contributor

Try setting tkg_batch_size to the same value as max_batch_size to compile TKG with the correct shapes.
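For anyone trying this suggestion, it amounts to compiling token generation (TKG) with the same batch dimension used for serving. A hypothetical config fragment: the field names come from the comment above, while the import path and remaining kwargs are assumptions rather than code from this PR.

```python
from neuronx_distributed_inference.models.config import NeuronConfig

neuron_config = NeuronConfig(
    max_batch_size=2,   # serving batch size
    tkg_batch_size=2,   # compile TKG with matching shapes instead of the default [1, 1]
    # ... remaining parallelism/quantization settings unchanged
)
```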

whn09 added a commit to whn09/neuronx-distributed-inference that referenced this pull request Apr 28, 2026
Record what works and what doesn't on 2026-04-28:
- Compile + load succeed on Trn2 (moe_tp=1/ep=64/BS=48 recipe).
- Prefill produces coherent English but off-topic output (a "100% of the
  time..." loop for an "explain transformer" prompt). Same signature as
  V2-Pro's earlier FP8 failures: per-expert weight distribution too
  narrow for FP8 e4m3 precision.
- The observed token IDs 15/16/4/315/279/882 look suspiciously small
  but are just " of/ the/ time" etc., i.e. top-BPE English subwords. Greedy
  decode is correct; the logit distribution itself is wrong.
- List recipes still to try (moe_tp=16/ep=4, moe_tp=32/ep=2 etc.) and
  NxDI constraints that rule out BS=1 when moe_ep>1.

Points future debuggers at Jim Burtoft's Flash FP8 observation and the
SDK 2.28 recommendation in his Kimi PR aws-neuron#131.

No code changes.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
whn09 added a commit to whn09/neuronx-distributed-inference that referenced this pull request Apr 28, 2026
Earlier wording said "Pro's expert weight std is too small for FP8
precision" in absolute terms. That's misleading — sglang on H100/H200
runs the exact same OCP FP8 checkpoint and produces correct output,
because GPU cutlass/sglang paths dequantize FP8 to BF16 before the
matmul.

The actual issue appears to be Neuron's NKI blockwise FP8 compute
kernel (_bwmm_shard_on_block_nki_call) running FP8 compute directly
on subnormal-leaning tensors. Jim Burtoft's Kimi PR aws-neuron#131 names the
Neuron SDK 2.29 blockwise kernel as producing "depressed logits with
EP=2" and recommends SDK 2.28.

Also noted: V2.5-Pro MoE expert weights are byte-identical to V2-Pro
(measured layer 1 expert 0 gate_proj stats match to 6 decimals), so
all V2-Pro FP8 workarounds remain required — not a new bug.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>