Add sarvam-m contrib model (Mistral head_dim fix) #139
jimburtoft wants to merge 1 commit into aws-neuron:main
Note: The template below includes items meant for model contributions only. For other contributions (bug fixes, features, etc.), fill out only the relevant portions of the form.
Description
Adds contrib support for sarvam-m, a 24B Mistral-architecture decoder-only LLM optimized for Indian languages and English.
This model exposes a general issue in `NeuronMixtralAttention`: it hardcodes `head_dim = config.hidden_size // config.num_attention_heads`, ignoring the explicit `head_dim` in the config. For sarvam-m, this computes head_dim=160 when the actual value is 128, causing XLA shape mismatches. The fix applies the same `getattr(config, "head_dim", ...)` pattern already used by NeuronLlamaAttention, NeuronQwen3Attention, NeuronGemma3Attention, and others; a minimal sketch of the pattern is shown below.
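For illustration, this is the shape of the change (a sketch of the pattern, not the literal `modeling_mixtral.py` diff):

```python
# Before: head_dim is always derived, ignoring an explicit config value.
# For sarvam-m: 5120 // 32 = 160, which is wrong.
head_dim = config.hidden_size // config.num_attention_heads

# After: prefer the explicit config value, fall back to the derived one.
# For sarvam-m this yields the correct head_dim = 128.
head_dim = getattr(
    config, "head_dim",
    config.hidden_size // config.num_attention_heads,
)
```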
The contrib includes:

- `src/setup_patches.py`: Applies head_dim fix to modeling_mixtral.py + NKI eps guards for QKV CTE kernels
- `test/integration/test_model.py`: Comprehensive integration tests (English/Hindi generation, greedy determinism, throughput)
- `README.md`: Full documentation with benchmarks, patch explanations, and usage instructions

Model Information
Model Name: sarvam-m (sarvamai/sarvam-m)
Model Architecture: Decoder-only transformer (MistralForCausalLM) — GQA 32Q/8KV, head_dim=128, hidden_size=5120, 40 layers, vocab 131K, 32K context
Purpose: Text generation in English and Indian languages (Hindi, Tamil, Telugu, etc.)
Checklist
Required Components
- Accuracy Test (`test/integration/test_model.py`)
- README.md with the following sections:
- Source Code (`src/`)
  - `setup_patches.py`: Applies 3 patches (Mistral head_dim, nkilib eps, neuronxcc eps)

Optional Components
Folder Structure
Testing
How did you test this change?
Tested on trn2.3xlarge with SDK 2.29 (DLAMI 20260410). Applied patches, launched vLLM with TP=8 (LNC=1) and TP=4 (LNC=2), validated English and Hindi generation, greedy determinism, and throughput benchmarks across 4 workloads (128/128, 128/512, 2048/128, 2048/512 input/output lengths) at concurrency 1 and 4.
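For reference, the serving path is plain vLLM; a minimal sketch of the kind of invocation used (the prompt, `max_model_len`, and sampling settings here are illustrative, not the actual benchmark harness):

```python
from vllm import LLM, SamplingParams

# Illustrative only: the model ID is real, but these settings are
# placeholders rather than the exact benchmark configuration.
llm = LLM(
    model="sarvamai/sarvam-m",
    tensor_parallel_size=8,  # TP=8 run (LNC=1); TP=4 was used for LNC=2
    max_model_len=4096,
)

# temperature=0.0 gives greedy decoding, matching the determinism check.
params = SamplingParams(temperature=0.0, max_tokens=128)
outputs = llm.generate(["भारत की राजधानी क्या है?"], params)
print(outputs[0].outputs[0].text)
```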
Test Results:
Compatibility
Tested with:
Additional Information
The head_dim fix is a general improvement that benefits any Mistral-family model where `head_dim != hidden_size // num_attention_heads`. Consider upstreaming this fix to `NeuronMixtralAttention` directly (1-line change to use `getattr`), which would eliminate the need for this patch.

The NKI eps guards are also model-agnostic and could be upstreamed to nkilib/neuronxcc.
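To make the eps-guard idea concrete (purely illustrative; the actual nkilib/neuronxcc patches are not reproduced here, and the names below are hypothetical):

```python
# Hypothetical sketch of an epsilon guard: clamp a normalization epsilon
# before it reaches a fused kernel, since eps=0 (or a missing value) can
# produce NaN/Inf inside the kernel.
def guard_eps(eps, min_eps=1e-6):
    if eps is None or eps < min_eps:
        return min_eps
    return eps
```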
Related Issues
None
vLLM Integration
The model works with standard vLLM serving (no custom registration needed); it uses the existing `NeuronMixtralForCausalLM` code path after the patches are applied.
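Since serving is stock vLLM, querying the model looks like any OpenAI-compatible deployment; a sketch (the endpoint/port and prompt are assumptions):

```python
# Assumes a vLLM server is already running locally on the default port
# with the patched model; adjust base_url for your deployment.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="sarvamai/sarvam-m",
    messages=[{"role": "user", "content": "नमस्ते! अपना परिचय दीजिए।"}],
    max_tokens=128,
)
print(resp.choices[0].message.content)
```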
By submitting this PR, I confirm that: