contrib: add MiMo-V2.5 (FP8 on Trn2) #148
whn09 wants to merge 15 commits into aws-neuron:main from
XiaomiMiMo/MiMo-V2.5 supersedes MiMo-V2-Flash: same architecture
(48-layer MoE, 256 experts, hybrid full+SWA attention, partial RoPE,
sink bias, sigmoid + noaux_tc routing, attention_value_scale=0.707)
with new tokenizer id, larger vocab, and multimodal heads (vision +
audio) that the NxDI path does not use.
Copies the Flash tree and renames:
- contrib/models/MiMo-V2-Flash -> contrib/models/MiMo-V2.5
- preprocess_mimo_v2_flash_fp8.py -> preprocess_mimo_v2_5_fp8.py
- bench_mimo_v2_flash.sh -> bench_mimo_v2_5.sh
- smoke_{compile,generate}_mimo_v2_flash.py -> ..._mimo_v2_5.py
- MiMoV2FlashForCausalLM -> MiMoV2ForCausalLM (HF arch name in V2.5)
- NXDI_CONTRIB_MIMO_V2_FLASH_SRC -> NXDI_CONTRIB_MIMO_V2_5_SRC
- MODEL_TYPES key "mimov2flash" -> "mimov2"
The unused legacy preprocess_mimo_v2_fp8.py (Jim's first version,
superseded by the streaming variant) is dropped.
Preprocess adjustments for V2.5's published FP8 checkpoint layout:
- LazyWeightMap aliases legacy `model_N-00001-of-00002.safetensors`
filenames referenced by safetensors.index.json to the actual shard
names on disk (`model_pp0_epN_shardM.safetensors`). V2.5 ships both
naming conventions inconsistently: HF Hub stores the latter while
the index still references the former.
Setup script:
- 0_setup.sh downloads from HuggingFace directly (V2.5 is a public
repo), drops the S3 fallback and the stale "BF16" path.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
MiMo-V2.5's published model.safetensors.index.json references legacy shard filenames like `model_N-00001-of-00002.safetensors`, but the LFS objects on HuggingFace Hub (and therefore on disk after download) are named `model_pp0_epN_shardM.safetensors`. The N values between the two namings are not aligned either, so a mechanical legacy->new rewrite doesn't work: for example, model.layers.0.input_layernorm is mapped to `model_1-00002-of-00002` in the index but actually lives in `model_pp0_ep0_shard1`.

Rather than reverse-engineer the ep-index permutation, scan the on-disk shards once and rebuild weight_map directly from each safetensors file's manifest. This is a one-time O(num_shards) open at startup and avoids any heuristic filename mapping.

Preserves the fast-path for pre-V2.5 checkpoints (where the index filenames match the on-disk names): if any overlap is detected, the provided weight_map is used as-is.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
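For illustration, a minimal sketch of the rebuild idea described in this commit. This is not the PR's `LazyWeightMap` code; the function name, signature, and file handling here are hypothetical, but the fast-path check and the O(num_shards) header scan follow the commit message.

```python
# Hypothetical sketch: rebuild weight_map from the shards that actually exist
# on disk instead of trusting model.safetensors.index.json.
import glob
import json
import os
from safetensors import safe_open

def rebuild_weight_map(ckpt_dir: str) -> dict[str, str]:
    with open(os.path.join(ckpt_dir, "model.safetensors.index.json")) as f:
        index_map = json.load(f)["weight_map"]

    shard_names = {os.path.basename(p)
                   for p in glob.glob(os.path.join(ckpt_dir, "*.safetensors"))}

    # Fast path (pre-V2.5 checkpoints): the index already points at files
    # that exist on disk, so use it as-is.
    if any(fname in shard_names for fname in index_map.values()):
        return index_map

    # Slow path: one pass over the real shards, reading only each file's
    # header to list the tensors it contains.
    weight_map: dict[str, str] = {}
    for fname in sorted(shard_names):
        with safe_open(os.path.join(ckpt_dir, fname), framework="pt") as shard:
            for tensor_name in shard.keys():
                weight_map[tensor_name] = fname
    return weight_map
```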
MiMo-V2.5's HF checkpoint stores attention as a fused self_attn.qkv_proj.weight tensor (shape [Q_dim+K_dim+V_dim, hidden]) even though its safetensors.index.json advertises separate q_proj/k_proj/v_proj keys. The actual LFS objects on the Hub carry only the fused form; HF's modeling code slices on the fly.

NxDI's MiMoV2Attention hard-codes separate q_proj/k_proj/v_proj ColumnParallelLinear modules and actively deletes any qkv_proj attribute inherited from the base class, so the preprocess must produce split tensors. Slice the fused tensor along the output dim into Q / K / V chunks using the config's per-head dims (swa vs full), then run each through the same per-row FP8 rescale used for non-fused checkpoints.

Slices the blockwise (128×128) scale along the output dim the same way; all Q/K/V output dims on V2.5 are multiples of 128, so the block boundaries line up. Any trailing block rows beyond Q+K+V (HF pads full-attention layers' scale to 108 blocks but the weight only has 106 blocks of content) are dropped with the unused weight rows.

Falls back to the pre-V2.5 split-qkv path when qkv_proj.weight is absent, so Flash checkpoints still preprocess correctly.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
MiMo-V2.5's fused qkv_proj.weight is NOT a simple
[all_Q | all_K | all_V] concatenation — the first naive slicing
approach produced garbled outputs because Q/K/V rows are physically
interleaved in num_groups=4 per-layer groups (here R = fused rows per
group; qg, kg = Q and K rows per group; hpg, kpg = Q and KV heads per
group):
group g (g = 0..3):
rows [g*R : g*R + qg] = Q heads [g*hpg : (g+1)*hpg]
rows [g*R + qg : g*R + qg + kg] = K heads [g*kpg : (g+1)*kpg]
rows [g*R + qg + kg : g*R + R] = V heads [g*kpg : (g+1)*kpg]
The group count (4) is a model-level constant equal to the full-
attention num_key_value_heads. SWA layers with num_kv_heads=8 pack
kpg=2 K/V heads per group, which is why their fused weight row count
is 14848 (= 4 * (16*192 + 2*192 + 2*128)) rather than the ~27136 one
would expect from an 8-group layout. Full-attention layers with
num_kv_heads=4 pack kpg=1 K/V head per group, giving 13568 rows.
Scale rows also follow the per-group layout with phantom padding:
full attention's K has kg=192 rows but consumes 2 scale blocks
(the last half of the second block is unused), giving 4*(24+2+1)=108
total scale rows against 106 real blocks.
Implementation ported from MiMo-V2.5-Pro's split_qkv_fused with the
num_kv_heads/num_groups axis decoupled so it works for V2.5's
asymmetric (num_kv_heads=4 full / 8 swa) config. Verified empirically
by Q/K/V scale-magnitude probes — Q/K/V bands have distinct scale
distributions that match the claimed slice boundaries on both full
and SWA layers.
Falls back to the pre-V2.5 per-proj path when qkv_proj.weight is
absent, so Flash/other split-qkv checkpoints still preprocess.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
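For reference, a minimal sketch of the per-group slicing this commit describes, written against the layout above. The function name, signature, and defaults are hypothetical (the real logic is the split_qkv_fused port inside preprocess_mimo_v2_5_fp8.py, which also slices the blockwise FP8 scale and drops the phantom rows).

```python
# Hypothetical sketch of the interleaved-group Q/K/V split described above.
import torch

def split_fused_qkv(fused: torch.Tensor,
                    num_groups: int = 4,    # model-level constant (full-attn num_kv_heads)
                    q_heads: int = 64,      # total Q heads
                    kv_heads: int = 8,      # 8 for SWA layers, 4 for full-attn layers
                    qk_head_dim: int = 192,
                    v_head_dim: int = 128):
    """Undo the interleaved [Q|K|V] x num_groups packing of qkv_proj.weight."""
    hpg = q_heads // num_groups        # Q heads per group (16)
    kpg = kv_heads // num_groups       # KV heads per group (1 full / 2 swa)
    qg = hpg * qk_head_dim             # Q rows per group
    kg = kpg * qk_head_dim             # K rows per group
    vg = kpg * v_head_dim              # V rows per group
    R = qg + kg + vg                   # fused rows per group (3392 full / 3712 swa)
    assert fused.shape[0] == num_groups * R

    q_parts, k_parts, v_parts = [], [], []
    for g in range(num_groups):
        base = g * R
        q_parts.append(fused[base: base + qg])
        k_parts.append(fused[base + qg: base + qg + kg])
        v_parts.append(fused[base + qg + kg: base + R])
    return torch.cat(q_parts), torch.cat(k_parts), torch.cat(v_parts)
```

With these defaults a SWA layer's 14848 fused rows split into 12288 Q, 1536 K, and 1024 V rows; a full-attention layer (kv_heads=4) gives 12288 / 768 / 512 from its 13568 rows, matching the numbers in the commit message.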
README rewritten to match what actually works today:
- Fixed stale Flash-era facts (vocab 151,936 -> 152,576, param counts removed, wrong "fused_qkv not supported" claim updated).
- New "Fused QKV on disk, split on Neuron" key feature note and a dedicated "V2.5-specific: fused qkv_proj split into 4 interleaved groups" subsection explaining the per-group layout and scale phantom-row handling that the preprocess implements.
- "weight_map rebuild" note: V2.5's index.json references legacy shard filenames that don't exist on disk, and the preprocess scans actual files instead.
- Dropped the "FP8 -> BF16 fallback" doc paragraph; that script never existed on this branch.
- Mount instructions in Prerequisites: the DLAMI creates /dev/md0 as a 6.9 TB RAID0 but does not add it to /etc/fstab, so after a reboot /opt/dlami/nvme is empty until remounted. Document the `sudo mount /dev/md0 /opt/dlami/nvme` fix.
- Updated timing numbers (preprocess 16 min / 15 GB peak RAM, first compile 30 min dominated by 27 min shard_checkpoint).
- Dropped the stale BF16 benchmark numbers; FP8 numbers pending.

Scratch locations off /tmp:
- smoke_compile / smoke_generate: default BASE_COMPILE_WORK_DIR from /tmp/nxd_model/ to /opt/dlami/nvme/tmp/nxd_model/, so HLO/NEFF staging survives the nightly Trn2 reboot.
- bench_mimo_v2_5.sh, run_bench_single.sh: RESULTS_DIR default from /tmp/bench_results/mimo_v2_5 to /opt/dlami/nvme/logs/bench_results/mimo_v2_5.

Why: the Trn2 instance reboots daily around 00:07 UTC and /tmp is wiped on reboot. A long-running compile that straddles the reboot loses all its intermediate files under /tmp.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
- Step-by-step curl probes for both a short sanity check and a longer generation, calling out what the outputs should look like when FP8 is working and what collapse symptoms to watch for.
- Note that request-level temperature is ignored because on_device_sampling_config is baked into the NEFF at compile time.
- Fix Prerequisites: trn2.48xlarge has 128 physical NeuronCores (not 32); with logical_nc_config=2 they appear as 64 logical cores.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
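A Python equivalent of the short sanity probe, assuming the vLLM OpenAI-compatible server is listening on its default port 8000 on localhost; the served model name and prompt are placeholders, not the PR's exact values.

```python
# Hypothetical Python version of the README's short curl sanity probe.
# The "model" value must match whatever --model / --served-model-name the
# server was started with.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "MiMo-V2.5-Neuron-FP8",      # placeholder served model name
        "prompt": "Briefly introduce yourself.",
        "max_tokens": 64,
        # Request-level temperature is ignored: on_device_sampling_config is
        # baked into the NEFF at compile time (see the README note above).
        "temperature": 0.0,
    },
    timeout=300,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])
# A healthy FP8 deployment returns coherent text; repetition collapse after a
# few dozen tokens is the failure symptom to watch for.
```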
The CONTRIB_SRC lookup used $(cd "$(dirname "$0")/.." && pwd), which
only works when $0 is an absolute path or dirname resolves from the
current working directory. But by the time CONTRIB_SRC was computed,
the script had already cd'd into $HOME/vllm-neuron, so a relative
$0 like "contrib/models/MiMo-V2.5/perf_test/0_setup.sh" could not
find the parent directory and the script failed with:
cd: contrib/models/MiMo-V2.5/perf_test/..: No such file or directory
Resolve SCRIPT_DIR, PATCH_FILE, and CONTRIB_SRC at the top of the
script (before any cd), and reuse SCRIPT_DIR.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
bench_mimo_v2_5.sh and run_bench_single.sh sourced /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference, which doesn't have vllm installed. The rest of the Quick Start (preprocess, smoke, 0_setup.sh) already uses /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16, so align these two.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
vLLM 0.16 validates ModelConfig.architectures against its builtin supported-archs list before plugins get a chance to register new classes. That list already contains MiMoV2FlashForCausalLM and MiMoV2ProForCausalLM (Xiaomi upstream PRs), but not the new V2.5 arch name MiMoV2ForCausalLM, so serving V2.5 via vLLM would fail the pydantic check at APIServer startup.

Since the V2.5 and Flash NxDI modeling code are the same (modeling_mimo_v2.NeuronMiMoV2ForCausalLM), reuse the Flash arch name to piggyback on the existing vLLM support instead of trying to register a brand new arch from a plugin:
- preprocess rewrites `architectures: ["MiMoV2ForCausalLM"]` in the copied config.json to `["MiMoV2FlashForCausalLM"]`. auto_map still points at the V2.5 configuration_mimo_v2 / modeling_mimo_v2 modules, so trust_remote_code loads V2.5 classes as expected.
- vllm-neuron-patch.patch is replaced with the Flash-branch patch verbatim (registers mimov2flash in MODEL_TYPES and registers MiMoV2FlashForCausalLM in vllm's ModelRegistry via the worker loader hook). Exactly the same payload as Flash uses.
- bench_mimo_v2_5.sh aliases NXDI_CONTRIB_MIMO_V2_FLASH_SRC to the V2.5 src so the Flash-keyed registration hook picks up our V2.5 modeling code.

No new __init__.py surgery, no architecture spoofing at runtime; just one config.json rewrite during preprocess and one env var alias at serve time.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
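A minimal sketch of the config.json rewrite described above. The function name and structure are illustrative; the real rewrite lives inside preprocess_mimo_v2_5_fp8.py.

```python
# Hypothetical sketch: rewrite "architectures" in the copied config.json so
# vLLM 0.16's builtin arch validator accepts the checkpoint.
import json
from pathlib import Path

def rewrite_architectures(config_path: str) -> None:
    cfg_file = Path(config_path)
    cfg = json.loads(cfg_file.read_text())
    if cfg.get("architectures") == ["MiMoV2ForCausalLM"]:
        # Reuse the Flash arch name already known to vLLM; auto_map is left
        # untouched so trust_remote_code still loads the V2.5 modules.
        cfg["architectures"] = ["MiMoV2FlashForCausalLM"]
        cfg_file.write_text(json.dumps(cfg, indent=2) + "\n")
```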
Centralize the four env vars used across smoke / bench / manual vLLM launches:
- 0_setup.sh: clearer "Next steps" output that prints all four exports (required + optional) with explanations. Replaces the two-line hint the previous version ended with.
- bench_mimo_v2_5.sh: adds defaults for NEURON_COMPILED_ARTIFACTS (/opt/dlami/nvme/compiled/mimo_v2_5_bs32_moetp1_ep64_fp8_vllm, same layout as other contrib models on this instance) and BASE_COMPILE_WORK_DIR (/opt/dlami/nvme/tmp/nxd_model/<basename>, so NxDI's HLO/NEFF staging survives the nightly Trn2 reboot and parallel compiles can't clobber each other).
- README: new "Environment variables" subsection under Quick Start tabulating required vs optional vars, defaults, and why each matters.

Without NEURON_COMPILED_ARTIFACTS set, vllm-neuron falls back to <checkpoint>/neuron-compiled-artifacts/<hash>/, which buries the output inside the checkpoint dir and isn't what we want when iterating. Without BASE_COMPILE_WORK_DIR set, NxDI's /tmp/nxd_model/ default gets wiped by the reboot mid-compile.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
BS=32 is the smallest batch size the FP8 path supports (num_experts/top_k = 32 requirement for EP>1 in the TKG graph), and it's already the target recipe for serving. Running BS=128 in the same bench script doubled compile time for no additional signal and produced a second NEFF + sharded-weights tree that we don't use. Also update the README description of the bench script.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Both scripts defaulted MODEL_PATH to the BF16 directory path, which is leftover from the Flash-era bench (Flash had a BF16 serving recipe at BS=1 alongside the FP8 recipe at BS=32). On V2.5 only FP8 is supported, so default to the -Neuron-FP8 directory instead.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…n wrapper

Previously bench_mimo_v2_5.sh inlined server launch + sanity + three bench runs + teardown (and repeated env-var setup + the full additional-config JSON twice, once per Config). sanity_check.sh and run_bench_single.sh existed as standalone tools but there was no matching "just start the server" script, so the only way to get a running server was to invoke the bench driver. Users who wanted to keep a server up to iterate on prompt or concurrency choices had to either copy-paste bench's launch block or kill bench mid-run.

Extract start_vllm_server.sh as the single place that:
- sources the vllm venv
- exports NXDI_CONTRIB_MIMO_V2_5_SRC, NXDI_CONTRIB_MIMO_V2_FLASH_SRC, NEURON_COMPILED_ARTIFACTS, BASE_COMPILE_WORK_DIR (with defaults)
- execs `python3 -m vllm.entrypoints.openai.api_server` with the recipe

bench_mimo_v2_5.sh is now a thin orchestrator: backgrounds start_vllm_server.sh, waits for readiness, invokes sanity_check.sh and run_bench_single.sh at c=1,16,32, tears down on exit. 205 lines -> 87.

0_setup.sh "Next steps" and the README now document both the one-shot path and the long-running-server + ad-hoc probe path.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Replace the "numbers pending" placeholder with the real vLLM serving numbers from trn2.48xlarge: output throughput, TPOT/TTFT medians and P99, plus a short analysis note explaining the 58 ms ITL floor (cost of one BS=32 TKG NEFF forward), the 576 tok/s peak at c=32, and why TPOT and TTFT degrade with concurrency under `enable_chunked_prefill=false`. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Description
Adds `contrib/models/MiMo-V2.5/`, a NeuronX Distributed Inference port of XiaomiMiMo/MiMo-V2.5 with an FP8 serving recipe on trn2.48xlarge.

MiMo-V2.5 is the successor to the earlier MiMo-V2-Flash release. The decoder-only MoE architecture is identical; the main delta on the NxDI side is that V2.5 ships the attention projection fused on disk (`attention_projection_layout="fused_qkv"`) in a per-layer interleaved-group layout, so the preprocess script has to reconstruct per-proj `q/k/v` tensors. The multimodal (vision + audio) heads in the HF checkpoint are not used by the NxDI language path.

Model Information
Model Name: MiMo-V2.5
Model Architecture: Decoder-only MoE transformer with hybrid attention (9 full + 39 sliding-window layers), 256 routed experts × 8 top-k, asymmetric Q/K (192) vs V (128) head dims, partial RoPE (rotary_dim=64 of head_dim=192), sigmoid + noaux_tc router.
Purpose: Text generation.
Checklist
Required Components
- Accuracy Test (`test/integration/test_model.py`): covers `MiMoV2InferenceConfig` and `NeuronMiMoV2ForCausalLM` and asserts the required-attributes contract.
- Smoke / perf test (`perf_test/smoke_compile_mimo_v2_5.py` + `smoke_generate_mimo_v2_5.py`): runs on Trn2 and produces a coherent MiMo self-introduction.
- README.md sections: usage example (`MoENeuronConfig` and running `generate()`), model source (`XiaomiMiMo/MiMo-V2.5` on HF Hub), and test instructions (`pytest contrib/models/MiMo-V2.5/test/integration/test_model.py -v`).
- Source Code (`src/`):
  - `src/modeling_mimo_v2.py`: NxDI-compatible modeling (`NeuronMiMoV2Attention`, `NeuronMiMoV2ForCausalLM`, `MiMoV2InferenceConfig`).
  - `src/conversion_script/preprocess_mimo_v2_5_fp8.py`: streaming OCP-FP8 → Neuron-FP8 rescale with V2.5-specific fused-qkv splitting.

Optional Components
- `start_vllm_server.sh`, `sanity_check.sh`, `run_bench_single.sh`, `bench_mimo_v2_5.sh`, `0_setup.sh` for vLLM serving + benchmarking; `smoke_compile_mimo_v2_5.py` / `smoke_generate_mimo_v2_5.py` for direct NxDI.

Folder Structure
Testing
How did you test this change?
- Downloaded `XiaomiMiMo/MiMo-V2.5` from HF Hub (~295 GB FP8 blockwise).
- Ran `preprocess_mimo_v2_5_fp8.py` to produce the Neuron-FP8 checkpoint (~311 GB, ~16 min, ~15 GB peak RAM).
- `smoke_compile_mimo_v2_5.py` STAGE=all on trn2.48xlarge: compile OK (NEFF cached in the neuronx-cc cache; shard_checkpoint for 64 ranks took ~29 min and populated `weights/tp{0..63}_sharded_checkpoint.safetensors`). `Removing redundant keys from checkpoint: []`, i.e. no state_dict drops.
- `smoke_generate_mimo_v2_5.py` with `apply_chat_template`: produced a fluent MiMo self-introduction ("Hi there! I'm MiMo, a large language model developed by Xiaomi's LLM team...").
- `bench_mimo_v2_5.sh` end-to-end (server launch + sanity + 3 bench runs at c=1/16/32). Results in the README "Performance" section.

Test Results:
vLLM serving on trn2.48xlarge, FP8, BS=32, TP=64 / moe_ep=64, continuous batching + bucketing, 900/90 random I/O:
Median inter-token latency stays at ~58 ms across all concurrencies (the cost of one BS=32 TKG NEFF forward), which matches expectations for this fixed-shape graph. Roughly, 32 concurrent streams each receiving a token every ~58 ms gives about 32 / 0.058 ≈ 550 output tok/s, in line with the ~576 tok/s peak measured at c=32.
Compatibility
Tested with:
Additional Information
Known constraints on the FP8 serving path (detailed in README.md#fp8-configuration-notes):
- `moe_tp_degree=1, moe_ep_degree=64` is the only supported MoE ratio. `moe_tp_degree=64` collapses the per-rank blockwise FP8 scale to a singleton because intermediate = 2048 / 64 = 32 < block_size = 128; NxDI's `_setup_for_scale` then drops per-channel granularity and the model produces repetition collapse after ~30 decode tokens.
- `batch_size >= 32` required by NxDI's TKG path: batch_size >= num_experts / top_k = 256/8 = 32 when Expert Parallelism is enabled. Single-stream BS=1 FP8 latency demos are not currently possible on V2.5.
- `ep_degree = 1`. The MoE-internal EP factor is controlled only by `moe_ep_degree`; setting the outer `ep_degree > 1` multiplies world_size past the 64-NC cap.

Checkpoint preparation has two V2.5-specific oddities handled by the preprocess script:
- Fused QKV on disk: the checkpoint stores `self_attn.qkv_proj.weight` as 4 num_kv-groups × [16 Q heads, 1-2 K heads, 1-2 V heads] per group (full layers get 1 KV/group, SWA layers get 2). The preprocess slices this back out into per-proj `q_proj / k_proj / v_proj` tensors to match NxDI's hard-coded ColumnParallelLinear layout. Naive [Q|K|V] concat slicing produces garbled output; we verified the group layout empirically by probing per-group scale magnitudes.
- Stale `model.safetensors.index.json`: V2.5's index references legacy `model_N-00001-of-00002.safetensors` filenames that don't match the `model_pp0_epN_shardM.safetensors` LFS objects actually on the Hub. `LazyWeightMap` rebuilds weight_map directly from the on-disk shards at startup rather than trusting the index.

Related Issues
None.
vLLM Integration
(README.md#vllm-integration and perf_test/vllm-neuron-patch.patch)

The vLLM integration piggybacks on upstream vLLM's builtin `MiMoV2FlashForCausalLM` arch support (Xiaomi's upstream PR): the preprocess rewrites the checkpoint's `architectures` to the Flash name so vLLM 0.16's pydantic arch validator accepts it without requiring a vLLM-side PR to add a new arch. `auto_map` still points at the V2.5 configuration / modeling modules and `trust_remote_code=True` loads V2.5 classes.

By submitting this PR, I confirm that: