
[contrib] Add MiMo-V2-Flash (Xiaomi, TP=64, EP=64 MoE) #137

Open

whn09 wants to merge 23 commits into aws-neuron:main from whn09:contrib/MiMo-V2-Flash

whn09 commented Apr 22, 2026

Description

Adds MiMo-V2-Flash (Xiaomi) to contrib/. All code lives under contrib/models/MiMo-V2-Flash/; git diff upstream/main..HEAD -- src/ is empty.

Two inference paths are supported, picked via neuron_config.quantized:

  • FP8 (recommended) — the HF FP8 checkpoint is preprocessed directly into Neuron-FP8 (OCP ±448 → Neuron ±240) with per-row attention scales and blockwise (128×128) MoE scales. Requires moe_tp_degree=1, moe_ep_degree=64, batch_size>=32 on Trn2 (see "FP8 Configuration Notes" below for why other configs collapse).
  • BF16 (fallback) — the HF FP8 checkpoint is preprocessed to BF16 once, then loaded normally. Uses the MoE use_torch_block_wise=true fallback because Neuron SDK 2.29 does not yet ship the NKI kernel that the default blockwise path pulls in. ~2× slower than FP8 but serves as a known-good reference.

Model Information

Model Name: MiMo-V2-Flash

Model Architecture:

  • 48 decoder layers; 256 MoE experts (top-8) with sigmoid routing and topk_method=noaux_tc (e_score_correction_bias folded into top-k selection); no shared experts
  • hybrid attention (full + sliding-window per-layer pattern), with an attention sink bias on SWA layers
  • asymmetric Q/K/V head dims (Q/K=192, V=128); partial RoPE (34%); QK RMSNorm before reshape
  • attention_value_scale=0.707 applied to V right after v_proj

Purpose: Text generation.

HuggingFace Checkpoint: XiaomiMiMo/MiMo-V2-Flash

The original checkpoint uses blockwise FP8 (E4M3 OCP ±448), which is incompatible with Neuron FP8 (E4M3 IEEE-754 ±240). Two preprocess scripts are provided (a requantization sketch follows the list):

  • src/conversion_script/preprocess_mimo_v2_flash_fp8.py — FP8 → Neuron-FP8 for the native FP8 inference path. Streams per layer to cap peak RAM at ~24 GB and finishes in ~20 minutes on trn2.48xlarge; produces a ~310 GB model_layer{0..47}.safetensors set plus model_extras.safetensors. o_proj is kept BF16 per HF's quantization_config.ignored_layers.
  • src/conversion_script/preprocess_mimo_v2_fp8.py — FP8 → BF16 for the BF16 inference path.
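
For orientation, here is a minimal sketch (not the shipped script) of the per-row requantization step for attention weights, assuming the weight has already been dequantized to BF16; the real script additionally handles blockwise MoE scales, gate|up fusion, and per-layer streaming:

import torch

NEURON_FP8_MAX = 240.0  # Neuron E4M3 clamps at +/-240; OCP E4M3 allows +/-448

def quantize_per_row_neuron_fp8(w_bf16: torch.Tensor):
    # w_bf16: [out, in], dequantized from the HF OCP-FP8 checkpoint
    amax = w_bf16.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)  # [out, 1]
    scale = (amax / NEURON_FP8_MAX).to(torch.float32)               # per-row scale
    q = (w_bf16.float() / scale).clamp(-NEURON_FP8_MAX, NEURON_FP8_MAX)
    return q.to(torch.float8_e4m3fn), scale  # stored values stay inside Neuron's range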

Quick Start (FP8)

End-to-end recipe — the README has detailed steps, timings, and a sanity-check curl command:

# 1. Download HF FP8
huggingface-cli download XiaomiMiMo/MiMo-V2-Flash --local-dir /opt/dlami/nvme/models/MiMo-V2-Flash

# 2. Preprocess (HF FP8 -> Neuron FP8, ~20 min)
source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate
python contrib/models/MiMo-V2-Flash/src/conversion_script/preprocess_mimo_v2_flash_fp8.py \
    --hf_model_path /opt/dlami/nvme/models/MiMo-V2-Flash \
    --save_path     /opt/dlami/nvme/models/MiMo-V2-Flash-Neuron-FP8 \
    --tp_degree 64

# 3. (Optional) smoke-verify without vLLM
source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16/bin/activate
python contrib/models/MiMo-V2-Flash/perf_test/smoke_compile_mimo_v2_flash.py
python contrib/models/MiMo-V2-Flash/perf_test/smoke_generate_mimo_v2_flash.py

# 4. Install vllm-neuron + contrib patch
bash contrib/models/MiMo-V2-Flash/perf_test/0_setup.sh

# 5. Start vLLM + bench (BS=32/moe_ep=64, BS=128/moe_ep=64)
bash contrib/models/MiMo-V2-Flash/perf_test/bench_mimo_v2_flash.sh

FP8 Configuration Notes

Three non-obvious constraints on Trn2, each of which took a reproduction to find (a combined config sketch follows the list):

  1. moe_tp_degree=1, moe_ep_degree=64 is the only working FP8 ratio. At moe_tp=64 each rank's intermediate slice is 32 rows (<128 blockwise block), and NxDI's _setup_for_scale collapses the per-rank scale to a singleton — losing per-channel FP8 scale granularity. The resulting drift compounds across Flash's 47 MoE layers into output collapse ("helpful helpful helpful...") after ~30 decode tokens. moe_tp=32/ep=2 and moe_tp=16/ep=4 have both been empirically tested and still collapse. Only moe_tp=1/ep=64 keeps each expert's weight + blockwise scale intact on a single rank and produces correct output.

  2. batch_size >= 32 on the FP8 path. NxDI's TKG path refuses Expert Parallelism when batch_size < num_experts / top_k = 256 / 8 = 32; BS=1 latency demos on FP8 are not possible. For single-stream latency use the BF16 checkpoint with moe_tp=64, moe_ep=1, batch_size=1.

  3. Keep outer ep_degree=1. MoENeuronConfig.ep_degree is the full-model expert-parallel factor; it scales world_size to tp_degree * ep_degree. At world_size > 64 on a 64-NC Trn2, the sharded-checkpoint size grows linearly, ranks beyond 63 have no backing hardware, and load fails. MoE EP is controlled exclusively via moe_ep_degree.
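
Putting the three constraints together, a working FP8 configuration looks roughly like this (a sketch only; the field names follow the MoENeuronConfig surface used by the smoke scripts, and the import path is an assumption):

from neuronx_distributed_inference.models.config import MoENeuronConfig

neuron_config = MoENeuronConfig(
    tp_degree=64,      # all 64 logical NCs on trn2.48xlarge (logical_nc_config=2)
    ep_degree=1,       # constraint 3: outer EP stays 1, so world_size == tp_degree
    moe_tp_degree=1,   # constraint 1: each expert and its blockwise scale
    moe_ep_degree=64,  #               stay intact on a single rank
    batch_size=32,     # constraint 2: >= num_experts / top_k = 256 / 8 = 32
    seq_len=1024,
    quantized=True,    # selects the FP8 path
)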

Checklist

Required Components

  • Accuracy Test — contrib/models/MiMo-V2-Flash/test/integration/test_model.py: import, required-attribute, and state-dict-converter tests.
  • README.md
    • Usage Example — flat-import bootstrap (sys.path.insert + from modeling_mimo_v2 import ...), mirroring upstream contrib/Qwen2-Audio-7B
    • Compatibility Matrix — Trn2 (trn2.48xlarge), Neuron SDK 2.29, PyTorch 2.9
    • Example Checkpoints — HuggingFace link + preprocessing instructions (BF16 and FP8)
    • Testing Instructions — pytest contrib/models/MiMo-V2-Flash/test/integration/test_model.py
    • Quick Start — end-to-end reproduction steps with approximate timings
  • Source Code under contrib/models/MiMo-V2-Flash/src/:
    • modeling_mimo_v2.py — full modeling, with FP8-only runtime patches gated on quantized=True
    • conversion_script/preprocess_mimo_v2_fp8.py — FP8 → BF16 preprocessor
    • conversion_script/preprocess_mimo_v2_flash_fp8.py — FP8 → Neuron-FP8 preprocessor (streaming)

Optional Components

  • Performance benchmarks — perf_test/0_setup.sh (install vllm-neuron, fetch weights), perf_test/bench_mimo_v2_flash.sh (BS=32 and BS=128 with moe_tp=1 / moe_ep=64; save_sharded_checkpoint=true to skip re-sharding on restart), perf_test/vllm-neuron-patch.patch (maps the MiMo architecture to the Qwen2 loader in vllm-neuron and plumbs hf_config through).
  • FP8 smoke tests — perf_test/smoke_compile_mimo_v2_flash.py and perf_test/smoke_generate_mimo_v2_flash.py. These bypass vLLM so FP8 bring-up can iterate without paying the vllm-neuron startup cost. STAGE={instantiate,compile,load,all}, DRY_RUN=1, SKIP_WARMUP=1, and MOE_TP/MOE_EP knobs support quick debug loops (usage example below).
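
For example, a compile-only FP8 dry run can be driven like this (illustrative; the knob names come from the scripts above):

import os, subprocess

env = dict(os.environ, STAGE="compile", DRY_RUN="1", MOE_TP="1", MOE_EP="64")
subprocess.run(
    ["python", "contrib/models/MiMo-V2-Flash/perf_test/smoke_compile_mimo_v2_flash.py"],
    env=env, check=True,
)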

Folder Structure

contrib/models/MiMo-V2-Flash/
  README.md
  src/
    __init__.py
    modeling_mimo_v2.py
    conversion_script/
      preprocess_mimo_v2_fp8.py          (FP8 -> BF16)
      preprocess_mimo_v2_flash_fp8.py    (FP8 -> Neuron FP8, streaming)
  test/
    integration/
      test_model.py
    unit/
  perf_test/
    0_setup.sh
    bench_mimo_v2_flash.sh
    sanity_check.sh
    run_bench_single.sh
    smoke_compile_mimo_v2_flash.py
    smoke_generate_mimo_v2_flash.py
    vllm-neuron-patch.patch

Testing

How did you test this change?

Validated on trn2.48xlarge with Neuron SDK 2.29 / PyTorch 2.9 / Python 3.12:

  1. Imports and config — integration tests pass on the NxDI venv (import, required attributes, state-dict converter, MoENeuronConfig class resolution).
  2. BF16 end-to-end serving — vllm-neuron with the included patch boots the model at TP=64/EP=64 and responds to chat-completion requests.
  3. BF16 throughput benchmarking — historical numbers preserved in README.
  4. FP8 smoke (direct NxDI) — TP=64 / moe_tp=1 / moe_ep=64 / BS=32 / SEQ=1024 produces coherent multi-sentence output on Chinese and English prompts, including 40-token chat-template prompts with <|im_start|>system ... <|im_end|> frames. No repetition collapse through 100 tokens of generation.
  5. FP8 end-to-end vLLM serving — same recipe as the smoke tests, served via /v1/chat/completions (example request below). The 40-token Chinese prompt "你好,介绍一下你自己。比如年龄,性别,职业,爱好是什么。" ("Hello, introduce yourself — e.g. your age, gender, occupation, and hobbies.") produces a fluent multi-paragraph Markdown response (verified directly against an H200 sglang reference serving the same HF checkpoint).
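
A minimal version of that sanity request (the port and model-name field are assumptions; adjust to your deployment):

import json, urllib.request

payload = {
    "model": "MiMo-V2-Flash",
    "messages": [{"role": "user", "content": "你好,介绍一下你自己。"}],
    "max_tokens": 128,
}
req = urllib.request.Request(
    "http://localhost:8000/v1/chat/completions",
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
print(json.load(urllib.request.urlopen(req))["choices"][0]["message"]["content"])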

FP8 debugging history (kept in the branch history for future reference):

  • V-pad (symmetric K/V head_dim via preprocess-side padding): attempted, did not help, reverted.
  • attention_value_scale=0.707: real HF behaviour; was previously overridden to 1.0, now matches reference.
  • Router bias init (arange + bf16): guards against XLA constant-folding the + bias op when the init is uniform; matches Jim Burtoft's MiniMax-M2 fix.
  • outer ep_degree = 1: catches a latent world-size overflow the moment anyone sets MOE_EP > 1.
  • MoE scale expansion uses moe_tp_degree: catches a shape mismatch the moment moe_tp_degree != tp_degree.

Compatibility

Tested with:

  • Neuron SDK Version(s): 2.29
  • Instance Type(s): Trn2 (trn2.48xlarge, logical_nc_config=2 → 64 logical cores)
  • PyTorch Version: 2.9
  • Python Version: 3.12

Additional Information

Import convention follows upstream contrib/Qwen2-Audio-7B: tests and examples add this model's src/ directory to sys.path and import modeling files by their flat module name. This keeps the contrib package self-contained (no registration in utils/constants.py) and allows the upstream package to remain untouched.
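
Concretely, the bootstrap amounts to (path illustrative, relative to the repo root):

import sys

sys.path.insert(0, "contrib/models/MiMo-V2-Flash/src")
from modeling_mimo_v2 import NeuronMiMoV2ForCausalLM  # flat module name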

FP8 runtime pieces (all installed as monkey-patches, gated on neuron_config.quantized):

  • _apply_ep_scale_fix — don't EP-shard singleton [1,1,W] scales.
  • _apply_blockwise_scale_stride_fix — force partition_stride=1 for BLOCKWISE_SYMMETRIC to avoid strided-split failures when per-rank weight is smaller than a 128-wide scale block.
  • _apply_2d_per_channel_fix — 2D attention/dense-MLP weights use per-row (out, 1) scales; flip their from_float q_config from BLOCKWISE_SYMMETRIC to PER_CHANNEL_SYMMETRIC at construction.
  • _apply_router_noaux_tc_fix — register e_score_correction_bias on RouterTopK (initialized as torch.arange(num_experts, dtype=torch.bfloat16) so XLA cannot constant-fold the subsequent + bias op and the checkpoint values actually bind) and fold the bias into top-k selection. (Flash's topk_method=noaux_tc; a sketch follows this list.)
  • save_quantized_state_dict override — skip the HF-side re-quantize path (requires CUDA; materializes a ~600 GB BF16 copy) when the preprocess-produced Neuron-FP8 index is already on disk.
  • convert_mimo_v2_hf_to_neuron_state_dict — replicate per-row K/V .scale in lockstep with existing CONVERT_TO_MHA weight replication; expand MoE blockwise gate_up_proj / down_proj .scale along the TP-partitioned dim so per_partition_size == 1 after sharding (expansion uses moe_tp_degree, not tp_degree).
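
To make the router fix concrete, here is a minimal sketch of noaux_tc selection under the two constraints the patch enforces (a hypothetical helper, not the patch itself): the bias steers only which experts win, and the trace-time init uses distinct BF16 values so XLA keeps the add alive.

import torch

def noaux_tc_topk(scores: torch.Tensor, bias: torch.Tensor, top_k: int = 8):
    # Bias influences *which* experts are chosen ...
    _, expert_idx = torch.topk(scores + bias, top_k, dim=-1)
    # ... but the affinities used for expert weighting stay un-biased.
    affinities = torch.gather(scores, -1, expert_idx)
    return expert_idx, affinities

# Trace-time init: distinct per-expert values prevent XLA from constant-folding
# the `+ bias` op; BF16 matches the checkpoint loader's cast target.
e_score_correction_bias = torch.arange(256, dtype=torch.bfloat16)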

Related Issues

Part of a cleanup of my earlier whn09:contrib/llm-models branch, which originally combined MiMo-V2-Flash and MiniMax-M2 and also touched src/neuronx_distributed_inference/utils/constants.py. The branch has been split into two zero-invasion PRs (this one, and a companion MiniMax-M2 PR).

vLLM Integration

  • This model is intended for use with vLLM
  • Documentation includes vLLM registration instructions (see README "vLLM Integration" section and perf_test/vllm-neuron-patch.patch)

By submitting this PR, I confirm that:

  • I have read and followed the contributing guidelines
  • This is a community contribution and may have limited testing compared to officially-supported models
  • The code follows best practices and is well-documented
  • All required components listed above are included

whn09 and others added 23 commits April 22, 2026 11:56
Xiaomi MiMo-V2-Flash in NxDI contrib format. All code lives under
contrib/models/MiMo-V2-Flash/, with zero changes to the upstream
src/ tree.

Architecture:
  48 decoder layers, 256 MoE experts (top-8), hybrid attention
  (full + sliding window), asymmetric Q/K/V dims (Q/K=192, V=128),
  partial RoPE (34%), sigmoid router, no shared experts.

Structure:
  src/modeling_mimo_v2.py              - full modeling code (1333 lines)
  src/conversion_script/               - FP8 -> BF16 preprocessor
  test/integration/test_model.py       - config/state-dict/import tests
  perf_test/0_setup.sh                 - vllm-neuron install + weight fetch
  perf_test/bench_mimo_v2_flash.sh     - vLLM serving benchmark (BS=1/32/128)
  perf_test/vllm-neuron-patch.patch    - maps MiMo architecture to Qwen2
                                          loader + hf_config plumbing

Import pattern: tests/examples add src/ to sys.path and import the
flat module name (e.g. `from modeling_mimo_v2 import ...`), matching
the convention in upstream contrib/Qwen2-Audio-7B.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The previous vllm-neuron-patch.patch was a 129-line copy-of-everything
that targeted a fork branch (whn09/vllm-neuron feature/mimo-support) with
stale history; it could not be applied cleanly to upstream vllm-project/
vllm-neuron. Replaced with a 40-line patch against release-0.5.0 that
adds a `_register_contrib_models()` hook to `vllm_neuron.register()`:

  - If NXDI_CONTRIB_MIMO_V2_FLASH_SRC is set, import NeuronMiMoV2ForCausalLM
    from that directory.
  - Register it into NxDI's MODEL_TYPES under key "mimo_v2_flash"
    (matches the `mimov2flash -> mimo_v2_flash` rewrite that already
    exists in release-0.5.0's _get_neuron_model_cls).
  - Register "MiMoV2FlashForCausalLM" into vLLM's ModelRegistry so vLLM's
    architecture allowlist passes.

This avoids modifying upstream NxDI's `utils/constants.py` (preserves the
contrib zero-invasion property) and avoids modifying upstream vllm-neuron's
model loader (the patch only adds a hook function).

Updated accordingly:
  - perf_test/0_setup.sh now clones release-0.5.0 and `git apply`s the patch.
  - perf_test/bench_mimo_v2_flash.sh exports NXDI_CONTRIB_MIMO_V2_FLASH_SRC
    defaulting to this package's own src/.
  - README serving instructions document the new env var.

Verified on trn2.48xlarge (NxDI 2.29, vLLM 0.16, vllm-neuron 0.5.0):
  NxDI 'mimo_v2_flash': True
  vLLM 'MiMoV2FlashForCausalLM': True
  ModelConfig(model=MiMo-V2-Flash-BF16) creation: OK

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The previous version hooked into vllm_neuron.register(), which only runs
in the parent APIServer process. vLLM V1 spawns EngineCore workers via
multiprocessing.spawn and those child processes start a fresh Python
interpreter; vLLM's plugin discovery does run there, but the
module-level state (in particular the NxDI MODEL_TYPES dict) is a fresh
copy so the parent's registration does not carry over.

Move _register_contrib_models() into the loader itself and call it at
the top of _get_neuron_model_cls(). Every process that tries to look up
an architecture now gets a fresh idempotent registration attempt driven
by NXDI_CONTRIB_MIMO_V2_FLASH_SRC / NXDI_CONTRIB_MINIMAX_M2_SRC.

Also correct the MODEL_TYPES key: release-0.5.0's loader does not have
the mimov2flash->mimo_v2_flash rewrite, so we must register under
"mimov2flash" (matches architecture.lower()) and "minimaxm2".

Verified on trn2.48xlarge:
  _get_neuron_model_cls("MiMoV2FlashForCausalLM") -> NeuronMiMoV2ForCausalLM
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
NxDI's hf_adapter.load_config calls AutoConfig.from_pretrained(path)
without trust_remote_code=True. Contrib models like MiMo-V2-Flash ship
a configuration_*.py in the checkpoint that requires custom code
execution, so this trips:

    ValueError: The repository ... contains custom code which must be
    executed to correctly load the model.

vLLM's top-level --trust-remote-code only affects vLLM's own config load,
not NxDI's re-load via hf_adapter.

Add a _patch_autoconfig_trust_remote_code() helper that wraps
AutoConfig.from_pretrained to default trust_remote_code=True. Called
from _register_contrib_models() alongside the MODEL_TYPES registration
so every process that reaches _get_neuron_model_cls installs the patch
(idempotent via a _nxdi_contrib_patched sentinel on the class).
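
A sketch of that wrapper (reconstructed from the description, not the literal patch):

from transformers import AutoConfig

def _patch_autoconfig_trust_remote_code():
    if getattr(AutoConfig, "_nxdi_contrib_patched", False):
        return  # idempotent across repeated registration calls
    orig = AutoConfig.from_pretrained.__func__

    def patched(cls, *args, **kwargs):
        kwargs.setdefault("trust_remote_code", True)
        return orig(cls, *args, **kwargs)

    AutoConfig.from_pretrained = classmethod(patched)
    AutoConfig._nxdi_contrib_patched = True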

Verified on trn2.48xlarge:
  AutoConfig.from_pretrained('/opt/dlami/nvme/models/MiMo-V2-Flash-BF16')
  now succeeds instead of asking for user input.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Neuron SDK 2.29 ships with neuronx_distributed 0.17 whose moe/blockwise.py
expects blockwise_mm_baseline_shard_hidden at either
neuronxcc.nki._private.blockwise_mm or _pre_prod_kernels.blockwise_mm.
Both import paths resolve in the installed SDK but neither exports the
baseline_shard_hidden variant; the MoE forward reaches
_call_shard_hidden_kernel and raises NotImplementedError.

Setting blockwise_matmul_config.use_torch_block_wise=true makes the
blockwise matmul go through the PyTorch reference implementation,
bypassing the missing NKI kernel. It is slower than the NKI path but
unblocks end-to-end vLLM benchmarking on the current stack. Remove
when the NKI kernel is promoted back to a public path.

Applied to the COMMON_MIMO_CONFIG block and merged into the Config 2/3
blockwise_matmul_config overrides (JSON does not recursively merge nested
dicts — the per-config override wins, so use_torch_block_wise must be
listed there too).
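
In Python terms the merge behaviour is (keys illustrative):

common = {"blockwise_matmul_config": {"use_torch_block_wise": True}}
override = {"blockwise_matmul_config": {"block_size": 128}}  # per-config override
merged = {**common, **override}  # shallow merge: the override's nested dict wins
assert "use_torch_block_wise" not in merged["blockwise_matmul_config"]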

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
First-time compilation of a 256-expert MoE model on trn2.48xlarge takes
30-90 minutes (~3 configs x 3 buckets x 64 TP ranks of neuron-cc work).
The previous 600s timeout aborts the benchmark driver while the background
compile is still running. Bump to 7200s (2h) and emit a progress blip
every minute so the user knows it's alive rather than hung.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Two small companion scripts to bench_mimo_v2_flash.sh, intended for use
after the monolithic bench has already brought a vLLM server up (or when
the bench driver timed out during first-compile and you want to salvage
the still-running server).

sanity_check.sh
  - POSTs a one-shot chat completion against localhost:$PORT.
  - Prints the JSON response and a one-line summary of the model's reply.
  - Health-checks /health first and fails fast if the server isn't up.

run_bench_single.sh
  - Runs one 'vllm bench serve' pass with configurable
    CONCURRENCY / NUM_PROMPTS / INPUT_LEN / OUTPUT_LEN.
  - Does NOT launch or kill the server — you bring your own.
  - Writes the transcript to $RESULTS_DIR/${CONFIG_NAME}_c${CONCURRENCY}.txt,
    matching bench_mimo_v2_flash.sh's output layout.

Typical usage after a long first-compile:
  # terminal 1: start the server via the main bench (it'll fail wait_for_server
  # but the server process stays up and keeps compiling in the background)
  bash bench_mimo_v2_flash.sh

  # terminal 2: once the server prints "Application startup complete.":
  bash sanity_check.sh
  bash run_bench_single.sh
  CONCURRENCY=16 NUM_PROMPTS=128 bash run_bench_single.sh

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The original preprocess_mimo_v2_fp8.py loaded the entire ~290 GB HF FP8
checkpoint into RAM via load_state_dict(), peaking well over 600 GB after
dequant/requant copies. This per-layer streaming rewrite (one
safe_open handle at a time) reduces peak memory to ~24 GB and runs in
~20 minutes, producing a 311 GB Neuron-FP8 checkpoint as
model_layer{0..47}.safetensors plus model_extras.safetensors.

Key points:
- Attention q/k/v: rescale HF OCP FP8 (+/-448) to Neuron FP8 (+/-240)
  with per-row scales.
- Attention o_proj: listed in HF quantization_config.ignored_layers;
  keep as BF16 and DO NOT emit .scale. The Neuron side binds o_proj to
  plain RowParallelLinear (not QuantizedRowParallel), so writing FP8 +
  .scale would be silently reinterpreted as BF16 bytes at load and
  produce garbage outputs.
- MoE experts: keep blockwise scales, fuse gate|up into the packed
  [num_experts, H, 2*IM] layout expected by ExpertFusedRowParallelLinear.
- Layer 0 dense MLP and attention_sink_bias handling matches the
  Flash config (add_swa_attention_sink_bias=True,
  add_full_attention_sink_bias=False).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Wire the runtime pieces needed to run Flash's preprocessed Neuron-FP8
checkpoint. All modifications are gated by neuron_config.quantized, so
the existing BF16 path is untouched.

New pieces:
- Four monkey-patch installers on NeuronMiMoV2ForCausalLM that
  reconcile NxDI's global blockwise_symmetric q_config with the mixed
  3D-blockwise-MoE + 2D-per-row-attn checkpoint layout:
    * _apply_ep_scale_fix: don't EP-shard singleton [1,1,W] scales.
    * _apply_blockwise_scale_stride_fix: force partition_stride=1 for
      BLOCKWISE_SYMMETRIC to avoid strided-split failures when per-rank
      weight is smaller than a 128-wide scale block.
    * _apply_2d_per_channel_fix: 2D attention/dense-MLP weights use
      per-row (out, 1) scales; flip their from_float q_config from
      BLOCKWISE_SYMMETRIC to PER_CHANNEL_SYMMETRIC at construction.
    * _apply_router_noaux_tc_fix: Flash's topk_method=noaux_tc needs
      e_score_correction_bias in the top-k selection; stock RouterTopK
      silently drops this bias.
- compile()/load() overrides call _install_fp8_patches() before super().
- save_quantized_state_dict override: skip the HF-side re-quantize path
  (requires CUDA, materializes a ~600 GB BF16 copy) when the
  preprocess-produced Neuron-FP8 index is already on disk.
- convert_mimo_v2_hf_to_neuron_state_dict additions (FP8-only):
    * Replicate per-row K/V .scale tensors in lockstep with the existing
      CONVERT_TO_MHA weight replication (TP=64 > 4/8 KV heads).
    * Expand MoE blockwise gate_up_proj/down_proj .scale tensors along
      the TP-partitioned dim so per_partition_size == 1 after sharding
      (preserves gate|up boundary by expanding each half independently).

Cleanup: drop the verbose [DEBUG] prints from the BF16-era CONVERT_TO_MHA
block - useful during bring-up, noisy in steady state (48 layers x 4
prints per run).

Verified end-to-end on Trn2 (TP=64, EP=1, SEQ=1024, BS=1):
  preprocess -> smoke_compile (~18.5 min) -> smoke_generate
  Prompt : "Hello! Please introduce yourself in one sentence."
  Output : "**Hi, I'm Alex AI, a virtual AI assistant created by Meta AI
            to help answer questions" (20 tokens in 1.19s, 16.7 tok/s)
Coherent fluent output, no token collapse.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Two minimal scripts that bypass vLLM so FP8 bring-up can iterate on the
preprocessed Neuron-FP8 checkpoint without paying vllm-neuron startup cost:

- smoke_compile_mimo_v2_flash.py: STAGE={instantiate,compile,load,all},
  DRY_RUN=1 for HLO-only, SKIP_WARMUP=1 when HBM is tight. Builds the
  Flash BS=1 recipe (TP=64, EP=1, blockwise_symmetric,
  use_shard_on_block_dynamic_while=True) and calls compile()+load() directly.
- smoke_generate_mimo_v2_flash.py: 20-token generation via
  HuggingFaceGenerationAdapter using the same config (hash matches so the
  NEFF is reused).

bench: add "save_sharded_checkpoint": true to COMMON_MIMO_CONFIG. During
compile this writes per-rank tp{N}_sharded_checkpoint.safetensors under
<compiled-path>/weights/; subsequent load()s read those directly (~55s)
instead of re-sharding the full checkpoint (~10+ min).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
AMIs on Trn2 dev instances periodically wipe /tmp on reboot, which
breaks the editable pip install (finder maps vllm_neuron -> directory
that no longer exists and all subsequent imports fail). Using $HOME
makes the install survive reboots; re-running 0_setup.sh after a wipe
still works thanks to the existing idempotency guards.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The HF MiMoV2Flash modeling code
(modeling_mimo_v2_flash.py:358-360 in the checkpoint) multiplies
value_states by config.attention_value_scale right after the V
projection, before attention softmax*V:

    if self.v_scale is not None:
        value_states = value_states * self.v_scale

Flash config has attention_value_scale = 0.707, so value_states
is consistently scaled down by that factor in every attention
layer. Previously this file explicitly overrode self.value_scale
to 1.0 based on a mistaken reading of the HF source, which made
every attention layer's output too large by the missing 0.707
factor (~1.41x). Short prompts stayed coherent by luck; prompts
>= 20 tokens accumulated enough error for the logits distribution
to collapse, producing repeated single-word gibberish ("sentence
sentence sentence" or "the default value is the default value").

Fix: read attention_value_scale from config (defaulting to 1.0)
and apply it to value_states at the same point HF does. The old
post-attention application point (attn_output *= value_scale) is
mathematically equivalent (a scalar on V commutes with the
softmax-weighted sum), but keeping the application point aligned
with HF makes future parity checks simpler.

Verified on Trn2 TP=64 EP=1 FP8:
  prompt                                    previously      now
  --------                                  ----------      --------
  "Hello! Please introduce yourself..."     ok by luck      ok
  "The quick brown fox...where it lives"    ok              ok
  "The quick brown fox...forest, where"     "the moon dog   "The fox is
                                             purs the deep   a symbol of
                                             dog are..."     cleverness..."
  35-token chat-template prompt             "I I sentence   coherent
                                             sentence..."    think+answer

Note: same bug almost certainly present in MiMo-V2-Pro — Pro also
force-sets self.value_scale = 1.0, but Pro's config has
attention_value_scale = 0.612.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
compile() / load() already call _install_fp8_patches() via our
overrides, but harnesses (e.g. vllm-neuron) may trigger RouterTopK
and Quantized{Column,Row}Parallel construction during model
instantiation — before compile()/load() get a chance to run. By
installing the patches up-front in __init__ (gated on quantized=True)
we guarantee the patched classes are in effect by the time any of
NxDI's layer factories see them, regardless of which harness drives
the model.

The patches themselves are idempotent (guarded by _mimo_v2_*_patched
sentinels), so installing them twice is harmless.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Pre-pad V projection weights from [num_kv_heads*v_head_dim(128), hidden]
to [num_kv_heads*head_dim(192), hidden] by appending 64 zero rows per
head. This lets the Neuron KV cache manager hold K and V in a single
symmetric shape instead of us runtime-padding V on every forward step.
The modeling-side forward() now slices the attention output back to
real_v_head_dim(128) right before the reshape+o_proj path, matching
HF's o_proj weight shape.

Changes:
- preprocess_mimo_v2_flash_fp8.py: pre-pad v_proj weight and scale to
  head_dim=192 per head (zero-fill the tail). No-op if the checkpoint
  already has v_head_dim == head_dim.
- modeling_mimo_v2.py: introduce attn_real_v_head_dim alongside
  attn_v_head_dim (now always == attn_head_dim). Delete the two runtime
  WORKAROUND blocks that used to pad V and repeat/slice KV heads on
  every forward step. Slice attention output to real_v_head_dim before
  reshape+o_proj. o_proj input dim switches to num_heads*real_v_head_dim
  so the HF o_proj weight shape still matches.
- convert_mimo_v2_hf_to_neuron_state_dict: v_proj replication for
  CONVERT_TO_MHA now uses head_dim (V is pre-padded), not v_head_dim.
  v_proj scale replication likewise uses head_dim.

STATUS: does NOT fix the long-decode output collapse that triggered
this refactor. Chinese chat-template prompts at 40 tokens still
degrade into repetition even after this change. Kept on the branch
because the symmetric-KV-cache layout is architecturally cleaner and
matches what the Kimi-K2 contrib model does; future debugging can
build on this instead of having to reason about runtime V pad/slice.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
NxDI's model builder stages HLOs into a global temp workdir controlled
by the BASE_COMPILE_WORK_DIR env var (default "/tmp/nxd_model/"). Two
concurrent compiles with different neuron_config hashes still share that
directory and silently overwrite each other's model.hlo_module.pb files,
which makes neuronx-cc exit 70 when the compiler tries to read what it
just staged and finds a different graph. Setting BASE_COMPILE_WORK_DIR
to a unique per-compile subdir (derived from COMPILED_PATH) lets
FP8/BF16 smoke compiles run in parallel safely.
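
Roughly (COMPILED_PATH being the smoke script's compile-output directory; the subdir name is illustrative):

import os

# Stage HLOs under a per-compile subdir instead of the shared /tmp/nxd_model/
os.environ["BASE_COMPILE_WORK_DIR"] = os.path.join(COMPILED_PATH, "nxd_workdir") + "/"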

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Two subtle issues around e_score_correction_bias tracing:

1. dtype=torch.float32 was wrong. NxDI's checkpoint loader casts router
   bias from FP32 to BF16 at load time (see the "Found torch.float32
   weights ... Will convert to torch.bfloat16" warnings in the smoke log).
   If the traced NEFF still expects FP32 after the cast, the
   LayoutTransformation silently drops the weight and leaves the
   trace-time init values live, so the bias at runtime is whatever we
   init here — not the checkpoint values.

2. torch.zeros(num_experts) was wrong. If every entry is identical, the
   `+ bias` op does not change topk's relative ordering, so XLA's
   constant-folding passes prove the add is a no-op and eliminate it,
   dropping the bias parameter from the HLO entirely. Checkpoint loading
   has nothing to bind to, and the real bias values never reach the
   device.

Use torch.arange(num_experts, dtype=torch.bfloat16) instead: distinct
per-expert values force the compiler to keep the add as a runtime op
with a live parameter, and BF16 matches the loader's cast target.
Also move the un-bias-affinity logic into scores_for_choice to match
the MiMo HF reference and MiniMax-M2's working implementation.

Source: Jim Burtoft's MiniMax-M2 fix
(jimburtoft/neuronx-distributed-inference@49f8e164).

This change alone does NOT fix the long-decode output collapse for
Flash (BF16 produces coherent Chinese for the same 40-token chat
prompt where FP8 collapses to "helpful helpful helpful"), but it is
required for correctness once the underlying FP8 issue is found and
fixed.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Previously the smoke scripts set ep_degree=MOE_EP at the top-level
MoENeuronConfig *and* moe_ep_degree=MOE_EP inside the MoE config.
The outer ep_degree is the full-model expert-parallel factor and
scales world_size to tp_degree*ep_degree. On a 64-NC Trn2 with
tp_degree=64 and ep_degree>1 that blew world_size past 64
(e.g. moe_ep=4 -> world_size=256), which:
  - produced a sharded checkpoint with 4x as many tp{N} files
    (tp0..tp255 instead of tp0..tp63) and 4x the on-disk size;
  - at runtime would try to address ranks beyond the 64 physical
    cores, failing load or OOM'ing.

Pro's working vLLM configs only set moe_ep_degree (no ep_degree in
the override), so NxDI's default ep_degree=1 keeps world_size=64.
Pin ep_degree=1 in the smoke scripts so varying MOE_EP only affects
the MoE-internal split, matching Pro's layout and keeping the
sharded checkpoint sized correctly.

Also generalize the MoE scale-expansion math in the state-dict
converter to use moe_tp_degree (the shard dim for expert weights)
rather than tp_degree. The two are the same at our BS=1 baseline
(moe_tp=64) so the bug was latent, but manifests the moment you try
any other moe_tp (e.g. 16 or 32): expansion produces scale tensors
sized for the full TP instead of just the MoE TP, triggering a
shape mismatch during shard_checkpoint.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
MoE expert weights shard along moe_tp_degree (the MoE-internal TP
factor), which can differ from the top-level tp_degree. The MoE
scale-expansion code in convert_mimo_v2_hf_to_neuron_state_dict was
hard-coded to tp_degree, so it happened to work at the BS=1 baseline
(moe_tp=tp=64) but produced wrongly-sized scale tensors the moment
MOE_TP != TP_DEGREE. Symptom:
  RuntimeError: expected shape torch.Size([4, 32, 32])
  for layers.1.mlp.expert_mlps.mlp_op.gate_up_proj.scale
  but found torch.Size([4, 32, 64])

Read moe_tp_degree from neuron_config (falling back to tp_degree for
non-MoE configs) and use it as the expansion denominator.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
FP8 with moe_tp=64 (all TP ranks splitting MoE outputs) reduces each
rank's expert intermediate slice to 32 rows — below the 128-row
blockwise scale block — which collapses the per-rank scale to a
singleton in NxDI's `_setup_for_scale`. With no per-channel scale
granularity, MoE forward accumulates a per-layer drift that compounds
across 47 MoE layers and lands as output collapse (long decodes
degrade into "helpful helpful helpful" or similar repetition).

Switching the compile/generate smoke defaults to moe_tp=1 / moe_ep=64
keeps every expert intact on a single rank (n_local_experts=4, no
intra-expert TP shard), so the full per-channel FP8 scale survives.
Verified on Trn2 TP=64 FP8: 40-token Chinese chat prompt produces
coherent multi-sentence output instead of collapsing.

Other FP8 ratios still mis-behave: moe_tp=32/ep=2 leaves down-proj
per-rank intermediate at 64 (<128, still collapses), and
moe_tp=16/ep=4 (per-rank gate_up=256/down=128) also gives gibberish
on the same prompt. Only moe_tp=1/ep=64 — the only config that keeps
both dims well above the 128 block boundary — gives correct output.

COMPILED_PATH default also updated to the new directory so reruns
don't accidentally reuse the old (broken) NEFF cache.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The previous bench defaulted to the BF16 checkpoint and defined three
configs (BS=1/32/128). The Config 1 FP8 recipe (moe_tp=64/moe_ep=1,
BS=1) is now known to produce output collapse after ~30 decode tokens —
NxDI's `_setup_for_scale` drops per-channel FP8 blockwise scale when
per-rank weight is below the 128-row block, which for Flash happens at
moe_tp=64. And BS=1 combined with Expert Parallelism (moe_ep>1) hits
NxDI's `BS >= num_experts / top_k = 32` assertion during TKG HLO
generation.

New bench:
- MODEL_PATH defaults to the Neuron-FP8 preprocessed checkpoint.
- COMMON_MIMO_CONFIG carries all FP8 quantization fields inline so every
  config inherits them (quantized=true, blockwise_symmetric, 128x128
  blocks, o_proj + embed/lm_head/norm/router held out).
- Config 1: BS=32, moe_tp=1, moe_ep=64 (smallest BS the FP8 path
  supports).
- Config 2: BS=128, moe_tp=1, moe_ep=64 (throughput-leaning).
- Drops sequence_parallel_enabled=true from COMMON (it interacts badly
  with our attention forward at generation time) and drops
  use_torch_block_wise=true (the FP8 path uses the native
  shard-on-block NKI kernel). Numeric CC token / scratchpad page size
  tweaks that were specific to the Pro benchmark are removed.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The FP8 inference path is now the recommended path on Trn2, but it has
several non-obvious configuration requirements. Capture them here so
users don't have to rediscover them:

- **moe_tp_degree=1, moe_ep_degree=64 is the only working FP8 ratio.**
  moe_tp=64 collapses the per-rank blockwise scale to a singleton
  (per-rank intermediate = 32 rows < 128-row block), which compounds
  into output collapse. moe_tp=32/16 have been empirically verified
  to still produce gibberish.
- **batch_size must be >= num_experts / top_k = 32** on the FP8 path
  (NxDI refuses EP>1 at TKG under that threshold).
- **Outer ep_degree must stay 1** — it multiplies world_size, and
  world_size > tp_degree overflows the physical NC count.

Other updates:
- Recommend the new streaming preprocess script
  (`preprocess_mimo_v2_flash_fp8.py`) as the default, demote the
  FP8->BF16 dequant path to "fallback".
- Update the Python usage example to the current NeuronConfig surface
  (quantized=True + blockwise_symmetric, AutoConfig hf_config plumbing,
  modules_to_not_convert list, etc.).
- Replace the BS=1 vLLM serving example with the working BS=32 FP8
  recipe; mention NEURON_COMPILED_ARTIFACTS for isolating compile dirs.
- Correct the sliding_window value (128, not 32,768) and the expert
  intermediate size (2048, not 1536).
- Leave the existing BF16 performance table in place as reference
  while FP8 benchmark numbers are still being collected.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The existing README had all the pieces (preprocess, smoke, vLLM) but
scattered across different sections, so a new user couldn't tell what
order to run them in. Add a concrete 6-step Quick Start that walks
from a fresh trn2.48xlarge to a working vLLM server — download,
preprocess, smoke-verify, install vllm-neuron, bench — with
approximate timings so the 60-minute first-compile isn't a surprise.

Also spell out the prerequisites (SDK version, which DLAMI venv to
use for which stage, disk requirement) and add a curl snippet for
post-deployment sanity check with the specific symptom ("helpful
helpful helpful ...") that indicates the FP8 recipe is misconfigured.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>