
[contrib] Add Qwen2.5-Omni-7B with full Neuron speech pipeline (Thinker+Talker+Token2Wav, TP=4)#135

Open
whn09 wants to merge 14 commits into aws-neuron:main from whn09:contrib/Qwen2.5-Omni-7B

Conversation

@whn09 whn09 commented Apr 22, 2026

Description

Adds the Qwen2.5-Omni-7B multimodal model to contrib/ with full Neuron support for the entire speech pipeline: text generation, image understanding, audio understanding, and text-to-speech — all running on Neuron at TP=4.

Supersedes #122, refactored to a 100% zero-invasion contrib layout (no changes under src/neuronx_distributed_inference/).

Full Neuron Speech Pipeline

Steady-state per-run latency on trn2.48xlarge with all three Neuron models resident in the same process (TP=4, NeuronCores 0–3), averaged over 5 runs of the default prompt:

| Component | Runtime | TP | Params | Per-run latency |
| --- | --- | --- | --- | --- |
| Thinker (text) | Neuron | 4 | 7B | 0.34s |
| Vision encoder | Neuron | 4 | 670M | — (on-demand, image input) |
| Audio encoder | CPU + Neuron | 4 | 620M | 20–34ms CPU pre/post (on-demand) |
| HF CPU forward (Phase 2 hidden states) | CPU bf16 | — | 7B | 1.16s |
| Talker prep (CPU projection) | CPU | — | — | 0.18s |
| Talker (codec tokens) | Neuron | 4 | 690M | 2.28s |
| Token2Wav DiT | Neuron | — | 85M | (part of 14.14s) |
| Token2Wav BigVGAN | CPU | — | ~15M | (part of 14.14s) |
| Pipeline total | | | | 18.10s |

Audio: ~12s of speech per request. RTF: 1.51× (≈18.10s of compute per ~12s of audio).

One-time load (inside the same process, amortized across every subsequent run):

| Phase | Time |
| --- | --- |
| Thinker load + warmup | 13.1s |
| HF CPU bf16 load | 0.3s |
| Talker load | 2.0s |
| Token2Wav DiT load | 19.4s |
| Total cold start | ~37s |

So wall time for a single-prompt invocation is ~55s (≈37s cold start + 18s run); for 5 back-to-back prompts it is ~129s (≈37s + 5×18s). Once the process is up, every subsequent prompt takes ~18s regardless of how many requests came before.

Zero-Invasion Design

All code lives under contrib/models/Qwen2.5-Omni-7B/. git diff upstream/main..HEAD -- src/ is empty. Three things that previously required modifying the core tree have been handled locally:

  1. Talker per-step thinker state injection — The Talker needs encode_vision_to_input to run during both context encoding and token generation. Previously this required adding an apply_vision_during_token_gen gate in model_base.py. Now it's a get_model_output override on the contrib Talker subclass (src/modeling_qwen25_omni_talker.py), which pre-injects into inputs_embeds before calling super().get_model_output() (see the sketch after this list).
  2. Model registration — The model registers itself via flat imports (from modeling_qwen25_omni import ...) after a sys.path.insert, the same pattern as upstream's contrib/Qwen2-Audio-7B. No entries are added to utils/constants.py or inference_demo.py.
  3. hf_adapter.py NameError — Upstream HuggingFaceGenerationAdapter.prepare_inputs_for_generation references an undefined tensor_capture_hook local, which breaks adapter.generate() for the Talker. A runtime shim in src/_upstream_compat.py patches this at import time. The underlying bug is addressed by companion PR #136 (Fix NameError in HuggingFaceGenerationAdapter.prepare_inputs_for_generation); once #136 merges, the shim can be removed.
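
A minimal sketch of the item-1 override (class and base names, signatures, and the encode_vision_to_input call shape here are assumptions, not the actual contrib or upstream API):

```python
class _BaseNeuronModel:
    """Stand-in for the upstream Neuron base model; the real base class differs."""

    def get_model_output(self, *args, inputs_embeds=None, **kwargs):
        ...  # upstream forward logic lives here


class NeuronTalkerModel(_BaseNeuronModel):
    def encode_vision_to_input(self, inputs_embeds, **kwargs):
        ...  # injects the per-step thinker/vision state into the embeddings
        return inputs_embeds

    def get_model_output(self, *args, inputs_embeds=None, **kwargs):
        # Pre-inject the thinker state on every call, so the injection runs
        # during both context encoding and token generation; no
        # apply_vision_during_token_gen gate in model_base.py is needed.
        inputs_embeds = self.encode_vision_to_input(inputs_embeds, **kwargs)
        return super().get_model_output(*args, inputs_embeds=inputs_embeds, **kwargs)
```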

Single-Process Pipeline + NeuronCore Pinning

generate_qwen25_omni_speech.py loads all three Neuron-compiled models into the same Python process (Thinker TP=4, Talker TP=4, Token2Wav DiT single-device) and reuses them across runs. The earlier subprocess-per-component layout paid ~60s of cold start per request because each child forked, loaded the model, inferred, and exited.

Two details matter for keeping the DiT fast once it co-exists with the Thinker and Talker:

  • NeuronCore pinning: the script sets NEURON_RT_VISIBLE_CORES=0-3 before any Neuron module is imported. Without this, the runtime places the single-device DiT NEFF on a different core group than the TP=4 Thinker/Talker, and every DiT forward pays a cross-group scheduling penalty. This is documented in the README; users embedding the pipeline in their own entrypoint must set the env var the same way (see the sketch after this list).
  • HF CPU model in bf16: Phase 2 (hidden-state extraction for the Talker) loads the HuggingFace model in bfloat16 rather than float32. Downstream consumers round back to bf16 anyway, so float32 here is pure overhead — switching cut Phase 2 from ~10s to ~1s.
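
For example, a custom entrypoint embedding the pipeline would pin the cores before importing anything Neuron-related (a minimal sketch; only the env-var handling mirrors the script, the commented import is illustrative):

```python
import os

# Must be set before any Neuron module is imported so the single-device DiT
# NEFF lands on the same core group (0-3) as the TP=4 Thinker/Talker NEFFs.
# setdefault keeps it overridable for users who need a different core range.
os.environ.setdefault("NEURON_RT_VISIBLE_CORES", "0-3")

# Only import Neuron-backed modules after the env var is in place, e.g.:
# import modeling_qwen25_omni  # contrib flat import (see the bootstrap sketch further below)
```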

Default model path

generate_qwen25_omni*.py and test_model.py all resolve the model path via huggingface_hub.snapshot_download("Qwen/Qwen2.5-Omni-7B"), so a user who has never touched the checkpoint can run the examples directly and the weights auto-download on first use. Users with a pre-downloaded checkpoint can point QWEN25_OMNI_MODEL_PATH at a directory containing config.json to skip the HF cache entirely.
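
The resolution order, roughly (a sketch of the behavior described above; the actual helper lives in src/_model_path.py and may differ in detail):

```python
import os
from pathlib import Path

from huggingface_hub import snapshot_download


def resolve_model_path() -> str:
    # 1. Explicit override: only honored if it actually contains config.json.
    override = os.environ.get("QWEN25_OMNI_MODEL_PATH")
    if override and (Path(override) / "config.json").is_file():
        return override
    # 2. Fall back to the HF cache: a no-op if already downloaded, and in
    #    either case returns the real snapshot directory (with commit hash).
    return snapshot_download("Qwen/Qwen2.5-Omni-7B")
```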

Model Information

Model Name: Qwen2.5-Omni-7B

Model Architecture: Multimodal encoder–decoder — Thinker (Qwen2 text) + Vision (SwiGLU ViT) + Audio (CPU frontend + Neuron transformer) + Talker (Qwen2-style with fused embedding) + Token2Wav (DiT + BigVGAN)

Purpose: Text generation, image-to-text, audio-to-text, and text-to-speech (full omni-modal)

HuggingFace Checkpoint: Qwen/Qwen2.5-Omni-7B

Checklist

Required Components

  • Accuracy Test — contrib/models/Qwen2.5-Omni-7B/test/integration/test_model.py: 8-test suite (imports, config, state dict, audio CPU components, Talker CPU, text-only Thinker compile+generate, image understanding, audio understanding).
  • README.md with:
    • Usage Example — text-only and full multimodal code snippets using the flat-import bootstrap (a minimal bootstrap sketch follows this list).
    • Compatibility Matrix — Neuron SDK 2.29, PyTorch 2.9, Python 3.12, Trn2 (trn2.48xlarge).
    • Example Checkpoints — HuggingFace link + snapshot_download auto-fetch behavior.
    • Testing Instructions — how to run test_model.py and generate_qwen25_omni_speech.py against compiled artifacts.
  • Source Code under contrib/models/Qwen2.5-Omni-7B/src/:
    • modeling_qwen25_omni.py — Text-only and multimodal orchestration, config, state dict conversion.
    • modeling_qwen25_omni_vision.py — Vision encoder (SwiGLU, RMSNorm, PatchMerger) on Neuron.
    • modeling_qwen25_omni_audio.py — Audio encoder with CPU frontend + Neuron transformer (chunked attention).
    • modeling_qwen25_omni_talker.py — Talker (Neuron, fused 8448→896 embedding, per-step thinker injection via get_model_output override).
    • modeling_qwen25_omni_token2wav.py — DiT transformer on Neuron + BigVGAN on CPU.
    • _model_path.py — resolve_model_path() helper; honors QWEN25_OMNI_MODEL_PATH, falls back to snapshot_download.
    • _upstream_compat.py — Runtime shim for the upstream hf_adapter.py tensor_capture_hook NameError (see companion PR #136).
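
A minimal sketch of the flat-import bootstrap mentioned above (the directory math, import order, and the idea that importing _upstream_compat applies the shim are illustrative; the contrib examples and tests contain the authoritative version):

```python
import sys
from pathlib import Path

# Make the contrib src/ directory importable as flat modules, mirroring the
# contrib/Qwen2-Audio-7B pattern. The parents[...] depth depends on where the
# calling script lives inside the contrib tree (examples/, test/integration/, ...).
CONTRIB_SRC = Path(__file__).resolve().parents[1] / "src"
sys.path.insert(0, str(CONTRIB_SRC))

import _upstream_compat        # noqa: F401  applies the hf_adapter shim at import time
import modeling_qwen25_omni    # flat import now resolves; no registry entries needed
```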

Optional Components

  • Integration tests on Neuron — under test/integration/: test_model.py (8-test suite), test_talker_neuron.py, test_e2e_qwen25_omni.py.
  • End-to-end speech pipeline — verified producing real speech audio on trn2.48xlarge; steady-state 18s/request.
  • Examples — examples/generate_qwen25_omni.py (text/image/audio → text) and examples/generate_qwen25_omni_speech.py (full speech synthesis, three Neuron models resident in one process).
  • Performance benchmarks — perf_test/3_bench_qwen25_omni_7b.sh (vLLM BS=1/4, TP=4) and perf_test/apply_vllm_neuron_patch_qwen25omni.py.

Folder Structure

contrib/models/Qwen2.5-Omni-7B/
  README.md
  src/
    __init__.py
    _model_path.py                   # snapshot_download-backed default path resolver
    _upstream_compat.py              # Runtime shim for hf_adapter.py NameError
    modeling_qwen25_omni.py          # Text-only + multimodal orchestration
    modeling_qwen25_omni_vision.py   # Vision encoder (Neuron, TP=4)
    modeling_qwen25_omni_audio.py    # Audio encoder (CPU + Neuron hybrid)
    modeling_qwen25_omni_talker.py   # Talker with get_model_output override
    modeling_qwen25_omni_token2wav.py # DiT (Neuron) + BigVGAN (CPU)
  examples/
    generate_qwen25_omni.py          # text / image / audio / full
    generate_qwen25_omni_speech.py   # end-to-end speech (single-process, core 0-3 pinned)
  test/
    integration/
      test_model.py                  # 8-test suite
      test_talker_neuron.py          # Talker-specific Neuron tests
      test_e2e_qwen25_omni.py        # CPU reference end-to-end
  perf_test/
    3_bench_qwen25_omni_7b.sh
    apply_vllm_neuron_patch_qwen25omni.py

Testing

How did you test this change?

All 8 tests pass on trn2.48xlarge with real Qwen2.5-Omni-7B weights (Neuron SDK 2.29, PyTorch 2.9, Python 3.12):

  1. Imports — all module groups import correctly through the contrib flat-import bootstrap.
  2. Config — TP=4 head divisibility verified (Thinker 7Q/1KV, Audio 5, Vision 4 per rank).
  3. State dict — all 2448 keys convert correctly (text=339, audio=489, vision=518, talker=293, token2wav=809).
  4. Audio CPU — frontend + postprocessor latency 20–34ms for 1–30s audio.
  5. Talker Neuron — compile + load + generate codec tokens on Neuron.
  6. Text generation — compile + load + generate on Neuron, outputs match HF reference.
  7. Image understanding — vision encoder + Thinker, correct captioning.
  8. Full speech pipeline — Thinker → Talker → Token2Wav all on Neuron, producing real speech audio.

End-to-end speech pipeline benchmark (generate_qwen25_omni_speech.py --num-runs 5 on trn2.48xlarge):

| Phase | Avg / run |
| --- | --- |
| Thinker | 0.34s |
| HF CPU hidden states | 1.16s |
| Talker input prep | 0.18s |
| Talker | 2.28s |
| Token2Wav | 14.14s |
| Pipeline | 18.10s |

Wall time for 5 runs: 129s (load + 5×18s). Audio duration: 12s; RTF 1.51×.

Verified behaviors:

  • The _upstream_compat shim successfully patches HuggingFaceGenerationAdapter.prepare_inputs_for_generation at import time (one possible patching approach is sketched after this list).
  • NeuronTalkerModel.get_model_output override replaces the old apply_vision_during_token_gen class attribute (which used to require a model_base.py edit).
  • All three Neuron-compiled models co-reside on NeuronCores 0–3 when NEURON_RT_VISIBLE_CORES=0-3 is set (verified via neuron-top: NC0/NC1/NC3 at ~5.6 GB each for the TP=4 Thinker+Talker shards, NC2 at ~8.2 GB with the extra DiT NEFF).
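
Illustrative only: if the missing tensor_capture_hook name resolves as a module-level global in hf_adapter, an import-time shim could be as small as the following; the real _upstream_compat.py may instead replace the method wholesale, and the module path here is an assumption:

```python
# Assumed module path; verify against the installed neuronx_distributed_inference.
import neuronx_distributed_inference.utils.hf_adapter as hf_adapter

# Provide a harmless default so prepare_inputs_for_generation no longer hits
# NameError when it looks up the missing name. If the name is a true local
# (assigned on some branch), the shim must patch the method itself instead.
if not hasattr(hf_adapter, "tensor_capture_hook"):
    hf_adapter.tensor_capture_hook = None
```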

Compatibility

Tested with:

  • Neuron SDK Version(s): 2.29
  • Instance Type(s): Trn2 (trn2.48xlarge — primary; trn2.3xlarge supports the same 4-core TP=4 configuration but was not benchmarked in this PR).
  • PyTorch Version: 2.9
  • Python Version: 3.12

Additional Information

Known limitations:

  • Token2Wav BigVGAN remains on CPU (6 upsample stages); only the DiT transformer core is compiled to Neuron. This dominates Token2Wav's 14s per-run time and is the largest remaining optimization target.
  • _upstream_compat.py is a stopgap for an hf_adapter.py bug addressed by companion PR #136; once that merges, the shim can be deleted.
  • Examples and tests resolve paths via Path(__file__).resolve().parents[N], so the contrib directory layout must be preserved.

Key technical contributions:

  1. NeuronTalkerModel.get_model_output override replaces the previous apply_vision_during_token_gen flag-on-model_base approach, keeping the change fully inside contrib.
  2. Fused embedding (8448→3584→896 collapsed to 8448→896) eliminates the thinker-to-talker projection at inference time (see the sketch after this list).
  3. Vision embeddings auto-pad to max_context_length buckets in set_vision_embeddings() for bucket compatibility.
  4. Token2Wav CPU fallback correctly orders the mel_len overflow check before input_embed doubles the batch for classifier-free guidance.
  5. Three-model single-process pipeline with NeuronCore 0-3 pinning (NEURON_RT_VISIBLE_CORES=0-3) + Phase 2 loaded in bf16 — the combination keeps per-request latency at ~18s even after the cold start, vs. ~90s when each component was in its own subprocess.
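
A toy illustration of contribution 2 (shapes taken from the description; bias handling and the actual fusion code in the contrib Talker may differ):

```python
import torch

# Hypothetical weights: an 8448-entry codec embedding into the 3584-dim thinker
# space, followed by a 3584 -> 896 projection into the talker space.
embed_weight = torch.randn(8448, 3584)
proj_weight = torch.randn(896, 3584)        # nn.Linear convention: (out, in)

# Fold the projection into the embedding table offline, so inference performs a
# single 8448 -> 896 lookup instead of a lookup plus a matmul per step.
fused_weight = embed_weight @ proj_weight.T  # (8448, 896)

token_ids = torch.tensor([3, 17, 4242])
talker_embeds = fused_weight[token_ids]      # direct 896-dim lookup per token
```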

Related Issues / PRs

vLLM Integration

  • This model is intended for use with vLLM (text-only).
  • vLLM-neuron patch included at perf_test/apply_vllm_neuron_patch_qwen25omni.py.

By submitting this PR, I confirm that:

  • I have read and followed the contributing guidelines.
  • This is a community contribution and may have limited testing compared to officially-supported models.
  • The code follows best practices and is well-documented.
  • All required components listed above are included.

All Qwen2.5-Omni code now lives under contrib/models/Qwen2.5-Omni-7B/:
  src/       - modeling files (text, vision, audio, talker, token2wav)
  examples/  - end-to-end generation + speech pipeline
  test/      - integration tests
  perf_test/ - vLLM benchmarks

Key refactors vs feature/qwen25-omni-support:
  - Removed src/neuronx_distributed_inference/models/qwen25_omni/ entirely;
    contrib files import each other flat via sys.path.insert + _upstream_compat
    bootstrap, mirroring upstream's contrib/Qwen2-Audio-7B convention.
  - NeuronTalkerModel previously required a one-line src/model_base.py patch
    (apply_vision_during_token_gen gate). Replaced by a get_model_output
    override inside the contrib Talker class — no src/ change needed.
  - Worked around an upstream hf_adapter.py bug (prepare_inputs_for_generation
    references an undefined tensor_capture_hook local) via a local
    _upstream_compat shim applied at contrib import time. Also leaves room
    for a separate upstream PR with the 1-line fix.
  - Removed the qwen2_5_omni entries from utils/constants.py and
    inference_demo.py that were unnecessary for contrib usage.

Result: git diff upstream/main..HEAD -- src/ is empty; all changes are
additive under contrib/models/Qwen2.5-Omni-7B/.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
whn09 and others added 13 commits April 22, 2026 11:53
The previous defaults were brittle: one pointed at /opt/dlami/nvme/models/
which only exists on specific DLAMI instances, and another pointed at the
parent snapshots/ directory (missing the commit hash sub-folder, so
transformers could never find config.json).

Replace them with a _model_path.resolve_model_path() helper that:
  1. Honors QWEN25_OMNI_MODEL_PATH if it points at a dir with config.json.
  2. Otherwise calls huggingface_hub.snapshot_download(HF_REPO_ID), which
     is a no-op if the model is already cached and returns the real
     snapshot directory (including commit hash) in either case.

This lets developers run the examples and tests with zero setup beyond
the NxDI venv -- if weights are not cached, they auto-download.

Also removes the now-redundant _resolve_model_path() helper in
generate_qwen25_omni_speech.py that manually reached into the snapshots/
directory to guess the commit subdir.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Two Python packages outside the NxDI venv are needed by the examples:

  - soundfile: used in generate_qwen25_omni_speech.py Phase 5 to write
    the output WAV. Previously imported lazily so the error only surfaced
    after ~100 seconds of pipeline execution; now imported up front with
    a clear install hint.
  - qwen-omni-utils[decord]: used in generate_qwen25_omni.py for
    multimodal preprocessing. Already documented in the script's
    docstring, now also in the README.

Added a "Python dependencies" subsection under Prerequisites listing
both pip-install commands.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Phase 2 was loading Qwen2.5-Omni in float32 even though the downstream
Talker consumes the projected states as bfloat16 — the upcast was pure
overhead (~28GB RAM, ~6s load + ~4s forward on Trn2).

Load the HF model in bfloat16 (the checkpoint's native dtype) with
low_cpu_mem_usage=True, and make sure ThinkerToTalkerProjection is
cast to match the weight dtype so the Linear runs in bf16 end-to-end.

This shaves roughly half the Phase 2 wall time with no accuracy change
(everything downstream already rounds back to bf16).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…rocess

The previous implementation spawned a separate Python subprocess for each
Neuron-compiled component (Thinker, Talker, Token2Wav DiT) based on the
assumption that each model needed exclusive NeuronCore access. That wasn't
actually true: since all three share TP=4 (NeuronCores 0-3), they can all
live in the same process and swap NEFFs on the fly, exactly the pattern
used by upstream contrib/Qwen-Image-Edit (PR aws-neuron#117).

Concretely:

- Removed `_run_subprocess`, `_SUBPROCESS_BOOTSTRAP`, and all three
  embedded f-string scripts (one per component).
- `run_thinker` / `run_talker` / `run_token2wav` are now regular Python
  functions that take in-memory arguments and return in-memory results.
  No more `torch.save` / `torch.load` round-trips through a temp dir.
- `load_thinker` / `load_talker` / `load_token2wav` / `load_hf_cpu` are
  invoked once in `main()`; the `for i in range(num_runs)` loop now
  calls the per-run phase functions on the already-loaded models.
- `compile_all` was also de-subprocessed; compilation holds the Neuron
  compiler (not the runtime) so there was never a core-conflict risk.

Performance impact (relative to the ~14s warm pipeline on trn2.48xlarge):

- First run still pays the full ~60s one-shot load.
- All subsequent runs skip the ~60s load — previously they paid it every
  time because each subprocess re-loaded and then exited.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…nker/Talker

Previously the Token2Wav DiT core was compiled with torch_neuronx.trace()
(single-device). When the whole speech pipeline runs in one Python process
(after the subprocess removal), the Neuron runtime places a single-device
NEFF on a different core group than the TP=4 Thinker/Talker. The result is
a cross-core-group scheduling penalty: DiT per-call jumps from ~10.7s (when
DiT had its own subprocess and core group) to ~14-18s.

Switch DiT compilation to neuronx_distributed.trace.parallel_model_trace
with tp_degree=4 in *replicated* mode:

  - Linears inside _NeuronDiTCore are NOT sharded; DiT has only ~85M params,
    so there's no memory win from sharding and the code change would be much
    larger.
  - The win is pure co-location: all three Neuron models now live on the
    same core group (0..TP-1) and the runtime schedules their NEFFs as peers.
  - parallel_model_trace takes a no-arg builder callable; captured the
    state dict and _block_mask_idx list so the builder can rehydrate the
    same module on the XLA device for each rank.

Artifact layout change:

  - New: <compiled_path>/dit_core/dit_core_parallel/  (directory, via
    parallel_model_save; reloaded with parallel_model_load).
  - Legacy: <compiled_path>/dit_core/dit_core_neuron.pt (single-file
    torch.jit). load_dit() still accepts this for backwards compat and
    prints a warning recommending recompile.
  - compile_all() and _check_compiled() both recognize either artifact.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
The previous attempt defined the parallel_model_trace builder as a closure
inside compile_dit, which fails under start_method='spawn' (the child can't
pickle a local function). Switched to a module-level, importable builder
_build_dit_core_for_trace.

But spawn'd children don't inherit globals either, so the builder can't
read the DiT module/state_dict from a module-level variable. Use an env
var to point the child at a torch.save()'d temp file written right before
parallel_model_trace. The builder loads it back on first call and then the
file is cleaned up (and env var unset) in a finally block.

The temp file is ~350MB (weights + dit reference). It lives next to the
compiled artifacts and is removed after compilation.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…e module

The previous attempt torch.save()'d the DiT nn.Module into a stash file for
the spawn'd child to reload. That fails because the Neuron SDK runtime-patches
torch classes like nn.Embedding, so the module tree contains class objects
whose identity doesn't match torch.nn.modules.sparse.Embedding — the pickler
raises:

    _pickle.PicklingError: Can't pickle <class 'torch.nn.modules.sparse.Embedding'>:
    it's not the same object as torch.nn.modules.sparse.Embedding

Rework the builder so each spawn'd child rebuilds the DiT itself from the
HuggingFace checkpoint on disk. We stash only plain tensors / ints (the
transformer-core state_dict + per-block mask indices) and pass the model
path and stash path via env vars (which spawn children DO inherit).

Additional env-var plumbing for spawn:
  - PYTHONPATH is extended with this module's directory so the child can
    import `modeling_qwen25_omni_token2wav` (and therefore find
    `_build_dit_core_for_trace`). The parent state is restored in the
    finally block.
  - _QWEN25_OMNI_DIT_MODEL_PATH points children at the HF checkpoint.
  - _QWEN25_OMNI_DIT_STASH points children at the tensor stash.

compile_dit() now requires `model_path=...`; the caller in
generate_qwen25_omni_speech.py already knows it.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
parallel_model_trace's worker unpacks the builder result as
``model, input_output_alias = func(**func_kwargs)``. We were returning just
the model, producing ``TypeError: cannot unpack non-iterable _NeuronDiTCore``.
DiT has no input/output weight aliasing, so the alias dict is empty.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…CORES

After reverting the DiT TP=4 experiment back to the single-device
torch_neuronx.trace path, Token2Wav was still slower than the subprocess
baseline (~14s vs ~10.7s) because the Neuron runtime was placing the
single-device DiT NEFF on a different core group than the TP=4
Thinker/Talker, paying a cross-group scheduling penalty on every DiT
forward.

Fix: set NEURON_RT_VISIBLE_CORES=0-3 so all three NEFFs share the same
four NeuronCores. Done in two places:
  - generate_qwen25_omni_speech.py: os.environ.setdefault before any
    Neuron module is imported. setdefault so users can still override
    if they need a different core range.
  - README "Prerequisites" section: explicit instruction + rationale
    for anyone embedding the pipeline in their own entrypoint.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>