Add Kimi-K2-Instruct-0905 contrib model (1T MoE on trn2.48xlarge)#131
Open
jimburtoft wants to merge 8 commits into aws-neuron:main from
Conversation
Onboard moonshotai/Kimi-K2-Instruct-0905 (1T MoE, 384 experts, MLA attention, DeepSeek-V3 architecture) to NxDI on trn2.48xlarge.

Configuration: TP=64, EP=2, LNC=1, blockwise FP8 (e4m3, 128x128 blocks)
Performance: 3.4 tok/s at BS=1, TPOT=297.5 ms, TTFT=1,788 ms

Key implementation details:
- Multi-Latent Attention with compressed KV cache (576 bytes/token/layer)
- Blockwise FP8 quantization for routed expert weights (non-experts in BF16)
- Streaming checkpoint loader for 62 safetensor shards (avoids OOM; sketched below)
- Sigmoid routing with e_score_correction_bias loaded as router bias
- Monkey patches for EP scale sharding and blockwise scale stride

Tested on: trn2.48xlarge, Neuron SDK 2.28, us-east-2
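The streaming loader is the detail most likely to matter for anyone reproducing this: rather than loading all 62 shards into host memory at once, each shard is opened and only the needed tensors are materialized. A minimal sketch of that idea, assuming safetensors shards on local disk (the path pattern and the `wanted` predicate are illustrative, not this PR's actual code):

```python
import glob
from safetensors import safe_open

def stream_checkpoint(shard_glob, wanted):
    """Yield (name, tensor) pairs one at a time across all shards."""
    for shard in sorted(glob.glob(shard_glob)):
        # safe_open memory-maps the shard; nothing is read until get_tensor.
        with safe_open(shard, framework="pt", device="cpu") as f:
            for name in f.keys():
                if wanted(name):
                    yield name, f.get_tensor(name)  # materializes just this tensor
```

Because each shard is memory-mapped, peak host memory stays near the size of a single tensor instead of the full checkpoint.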
jimburtoft force-pushed from c806415 to d599eff
LNC=2 gives 2x HBM bandwidth per logical core on trn2.48xlarge. With TP=32 and EP=2 (64 ranks), NEFF I/O fits at 17.55 GB / 24 GB. TPOT improves from 297 ms to ~191 ms (purely bandwidth-bound MoE).
…r than LNC=1) Full sweep across 16-512 output tokens shows rock-stable TPOT. TTFT P50 = 1,420 ms. Throughput is 1.8x the LNC=1 baseline.
…S>1 provides no throughput benefit
Comment on lines +306 to +311:
> - **Batching does not improve throughput:** NxDI compiles HLO with per-sequence shapes (`[1, seq_len]` for CTE, `[1, 1]` for TKG) regardless of `max_batch_size`. Multiple sequences in a batch are processed sequentially through the same NEFF. Combined with the bandwidth-bound nature of MoE (192 expert weight loads per decode step), BS>1 provides no aggregate throughput benefit. Verified: BS=2 compile produces identical NEFF shapes to BS=1.
Contributor
Try setting tkg_batch_size to the same value as max_batch_size to compile TKG with the correct shapes.
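A minimal sketch of that suggestion (only `tkg_batch_size` and `max_batch_size` come from the comment; the import path and the remaining kwargs are assumptions about the NxDI config API):

```python
# Assumed import path; adjust to the installed NxDI version.
from neuronx_distributed_inference.models.config import NeuronConfig

neuron_config = NeuronConfig(
    tp_degree=64,       # illustrative, matching this PR's TP=64 setup
    batch_size=2,       # assumed companion field
    max_batch_size=2,
    tkg_batch_size=2,   # match max_batch_size so TKG compiles [2, 1] instead of [1, 1]
)
```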
…essed logits, EP workaround verified correct
…t CTE kernels produce wrong output with EP=2
…piler BIR verifier bug on SDK 2.29
whn09 added a commit to whn09/neuronx-distributed-inference that referenced this pull request on Apr 28, 2026
Record what works and what doesn't on 2026-04-28:
- Compile + load succeed on Trn2 (moe_tp=1/ep=64/BS=48 recipe).
- Prefill produces coherent English but off-topic output ("100% of the
  time..." loop for an "explain transformer" prompt). Same signature as
  V2-Pro's earlier FP8 failures — per-expert weight distribution too
  narrow for FP8 e4m3 precision.
- Note: observed token IDs 15/16/4/315/279/882 look suspiciously small
  but are just " of/ the/ time" etc. — top-BPE English subwords. Greedy
  decode is working correctly; the logit distribution itself is wrong.
- List recipes still to try (moe_tp=16/ep=4, moe_tp=32/ep=2 etc.) and
NxDI constraints that rule out BS=1 when moe_ep>1.
Points future debuggers at Jim Burtoft's Flash FP8 observation and the
SDK 2.28 recommendation in his Kimi PR aws-neuron#131.
No code changes.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
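For anyone wanting to test the "too narrow for e4m3" hypothesis from the commit above offline, a rough diagnostic could look like this (not from this PR's code; `2**-6` is float8_e4m3fn's smallest normal magnitude):

```python
import torch

def e4m3_roundtrip_report(w: torch.Tensor):
    """Relative cast error and the fraction of values below e4m3's normal range."""
    q = w.to(torch.float8_e4m3fn).to(torch.float32)
    rel_err = (q - w.float()).abs().mean() / w.float().abs().mean()
    below_normal = (w.float().abs() < 2.0**-6).float().mean()
    return rel_err.item(), below_normal.item()

# A distribution as narrow as the commit describes sits almost entirely in
# e4m3's subnormal range, and much of it flushes to zero:
print(e4m3_roundtrip_report(torch.randn(128, 128) * 1e-3))
```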
whn09 added a commit to whn09/neuronx-distributed-inference that referenced this pull request on Apr 28, 2026
Earlier wording said "Pro's expert weight std is too small for FP8 precision" in absolute terms. That's misleading — sglang on H100/H200 runs the exact same OCP FP8 checkpoint and produces correct output, because GPU cutlass/sglang paths dequantize FP8 to BF16 before the matmul.

The actual issue appears to be Neuron's NKI blockwise FP8 compute kernel (_bwmm_shard_on_block_nki_call) running FP8 compute directly on subnormal-leaning tensors. Jim Burtoft's Kimi PR aws-neuron#131 names the Neuron SDK 2.29 blockwise kernel as producing "depressed logits with EP=2" and recommends SDK 2.28.

Also noted: V2.5-Pro MoE expert weights are byte-identical to V2-Pro (measured layer 1 expert 0 gate_proj stats match to 6 decimals), so all V2-Pro FP8 workarounds remain required — not a new bug.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
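The dequantize-first GPU path the message describes can be illustrated in a few lines (shapes and the toy per-block scales are illustrative; the real kernels are cutlass/sglang on GPU and the NKI kernel on Neuron):

```python
import torch

def dequant_then_matmul(x, w_fp8, scales, block=128):
    """Dequantize blockwise-FP8 weights to BF16, then matmul in BF16,
    so the matmul itself never touches subnormal FP8 values."""
    # Broadcast one scale per 128x128 block back over the full weight matrix.
    s = scales.repeat_interleave(block, dim=0).repeat_interleave(block, dim=1)
    w = (w_fp8.to(torch.float32) * s).to(torch.bfloat16)
    return x.to(torch.bfloat16) @ w.T

x = torch.randn(1, 256)
w_fp8 = (torch.randn(512, 256) * 1e-3).to(torch.float8_e4m3fn)
scales = torch.full((4, 2), 100.0)   # (512/128) x (256/128) per-block scales
print(dequant_then_matmul(x, w_fp8, scales).shape)  # torch.Size([1, 512])
```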
Summary
Model Details
Benchmark Results (BS=1, seq_len=1024)
LNC=2 vs LNC=1 Comparison
LNC=2 gives 2x HBM bandwidth per logical core. Since MoE decode is purely bandwidth-bound (192 expert weight loads per step), this translates directly to throughput improvement.
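To make "translates directly" concrete, a back-of-envelope model (every number below is a placeholder except the 192 loads/step, which comes from this PR):

```python
def tpot_ms(expert_loads_per_step, bytes_per_load, hbm_bytes_per_s):
    """Bandwidth-bound decode: TPOT ~ weight bytes streamed / HBM bandwidth."""
    return expert_loads_per_step * bytes_per_load / hbm_bytes_per_s * 1e3

# Doubling effective bandwidth per core (LNC=2) halves the streaming term:
print(tpot_ms(192, 60e6, 4e13))  # placeholder load size and bandwidth
print(tpot_ms(192, 60e6, 8e13))  # half the TPOT of the line above
```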
Full Benchmark Sweep (LNC=2)
Files Added
Key Implementation Notes
- On-device sampling disabled (`no_on_device_sampling`): required because the vocabulary size (163840) exceeds on-device sampling limits
- `RouterTopK(bias=True)` to load `e_score_correction_bias` as `linear_router.bias` (see the remapping sketch below)
- Monkey patch in `neuronx_distributed/modules/moe/model_utils.py`
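A sketch of the `e_score_correction_bias` remapping mentioned above (the exact checkpoint key layout is an assumption; only the two identifiers named in the notes are from this PR):

```python
def remap_router_bias(state_dict):
    """Move each expert-score correction bias onto the router's bias slot."""
    for key in list(state_dict):
        if key.endswith("e_score_correction_bias"):
            parent = key.rsplit(".", 1)[0]          # e.g. "...mlp.gate"
            state_dict[parent + ".linear_router.bias"] = state_dict.pop(key)
    return state_dict
```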
Known Limitations

- The `<|im_end|>` token (163586) has an elevated logit in context encoding outputs. Mitigated with the `min_tokens_before_eos` parameter (see the sketch below). Documented in the README.
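What the `min_tokens_before_eos` mitigation amounts to, as a standalone sketch (the parameter name and token id 163586 are from the note above; the hook itself and the `min_tokens` default are illustrative, not the PR's implementation):

```python
def suppress_early_eos(logits, tokens_generated, eos_id=163586, min_tokens=16):
    """Mask the EOS logit until at least min_tokens have been generated."""
    if tokens_generated < min_tokens:
        logits[..., eos_id] = float("-inf")
    return logits
```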
Testing

```bash
# On trn2.48xlarge with NEURON_LOGICAL_NC_CONFIG=2:
NEURON_LOGICAL_NC_CONFIG=2 LOCAL_WORLD_SIZE=64 pytest test_model.py -v --capture=tee-sys
```