Add Qwen3.5-2B contrib model#141
Open
jimburtoft wants to merge 1 commit intoaws-neuron:mainfrom
Open
Conversation
Hybrid DeltaNet + GQA decoder with custom NKI kernels for Neuron. - 24 layers: 18 DeltaNet (linear recurrent) + 6 standard GQA - Custom NKI fused kernel for context encoding (CTE) - Custom NKI per-token kernel for token generation (TKG) - First-token logit validation against CPU BF16 reference - 42 unit tests (CPU) + 9 integration tests (Neuron) - Validated on trn2.3xlarge TP=4 LNC=2 SDK 2.29 - BS=1-8, seq_len=128-4096, all configurations pass - 114.5 tok/s BS=1, up to 409.5 tok/s BS=8
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Note: The below template includes items meant for model contributions only. For other contributions such as bug fixes, features, etc., only fill out the relevant portions of the form.
Description
Adds Qwen3.5-2B, a 2B parameter dense hybrid DeltaNet + GQA decoder from Alibaba Cloud, to the contrib directory. This model features 18 DeltaNet linear recurrent attention layers and 6 standard GQA layers in a [3 DeltaNet + 1 GQA] x 6 pattern, requiring custom NKI kernels for the DeltaNet forward passes on Neuron.
Key implementation details:
tie_word_embeddings=true)Model Information
Model Name: Qwen3.5-2B
Model Architecture: Decoder-only hybrid DeltaNet/GQA transformer (24 layers: 18 DeltaNet + 6 GQA), dense SwiGLU MLP, 2048 hidden size, 248K vocabulary
Purpose: Text generation (chat model with
<|im_start|>/<|im_end|>format)Checklist
Please ensure your PR includes the following items. Refer to the contrib/CONTRIBUTING.md for detailed guidelines.
Required Components
Accuracy Test (ex.
test/integration/test_model.py)README.md with the following sections:
Source Code (
src/)modeling_qwen35.py— Main text decoder with NKI DeltaNet kernelsmodeling_qwen35_vision.py— Vision encoder (for future VL support)modeling_qwen35_vl.py— VL orchestrator (for future VL support)nki_kernels/— DeltaNet NKI kernel implementations (fused CTE, per-token TKG, chunked)Optional Components
test/unit/directoryFolder Structure
Confirm your contribution follows this structure:
Testing
How did you test this change?
All tests were run on a trn2.3xlarge instance with TP=4, LNC=2, SDK 2.29 (NKI 0.3.0, PyTorch 2.9, neuronx-distributed-inference). The model was compiled from Qwen/Qwen3.5-2B HuggingFace weights.
Test Results:
Benchmark results (BF16, seq_len=128):
Note on logit validation approach: DeltaNet layers (18 of 24) use NKI linear recurrent kernels that produce higher BF16 numerical divergence than standard GQA. Autoregressive sequences diverge after the first generated token, making multi-token
logit_validation()inapplicable. The first-token logits are validated where CPU and Neuron process identical input prefixes. The model outputs TP-sharded logits (vocab/tp_degree) becauseModelWrapperdoes not call_gather_along_dim, so comparison uses the TP shard 0 slice.Compatibility
Tested with:
Additional Information
_chunk_forwardpath creates 5D tensors that trigger a neuronx-cc codegen crash (NCC_INLA001); the fused NKI kernel is the default and required CTE pathtokenizer.apply_chat_template()is required for quality outputqwen3_5model type requirestransformers>=5.0for CPU reference generation, but the NxDI-pinnedtransformers==4.57.*works for Neuron inference since the model is loaded via manualconfig.jsonparsingRelated Issues
N/A
vLLM Integration
By submitting this PR, I confirm that: