
Add Gemma 4 E2B/E4B support (text-only) #18695

Open

Phineas1500 wants to merge 1 commit into pytorch:main from Phineas1500:codex/gemma4-e2b-e4b-support

Conversation

Phineas1500 (Contributor) commented Apr 3, 2026

Summary

Add native text-only Gemma 4 support for google/gemma-4-E2B and google/gemma-4-E4B in the ExecuTorch LLM export path.

Why

Gemma 4 E2B/E4B do not fit the existing Llama/Qwen config-only path. Supporting them required new model/runtime behavior plus a checkpoint conversion path, not just new repo IDs and JSON configs.

What Changed

  • Register gemma4_e2b and gemma4_e4b as first-class export targets.
  • Add a new examples/models/gemma4 package with configs, converter, BUCK target, and README.
  • Extend the native text runtime for Gemma 4-specific behavior, including:
    • layer-type-aware sliding/full attention
    • dual RoPE behavior
    • shared-KV reuse
    • per-layer input embeddings / scaling
    • GELU-tanh MLP support
    • post-attention and post-FFN norms
    • layer scaling and final logit softcapping (the softcap and GELU-tanh formulas are sketched after this list)
  • Carry Gemma 4 attention scaling through the custom-SDPA export path.
  • Add focused regression coverage for Gemma 4 support.
  • Add two small supporting fixes discovered during validation:
    • source-tree import cleanup in examples/models/model_factory.py
    • source-tree flatbuffer schema fallback in exir/_serialize/_flatbuffer.py
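
As referenced in the list above, the GELU-tanh activation and final logit softcapping follow standard Gemma-family formulas; a minimal PyTorch sketch (function names are illustrative, and the gated-FFN shape is an assumption carried over from earlier Gemma generations, not the PR's actual code):

import torch
import torch.nn.functional as F

def softcap_logits(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # Final logit softcapping: squashes logits smoothly into (-cap, cap).
    return cap * torch.tanh(logits / cap)

def gelu_tanh_mlp(x, w_gate, w_up, w_down):
    # Gated MLP whose gate activation is the tanh approximation of GELU.
    return (F.gelu(x @ w_gate, approximate="tanh") * (x @ w_up)) @ w_down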

Validation

Ran:

conda activate et_pt211_clean
export PYTHONNOUSERSITE=1
export PYTHONPATH=..
python -m unittest \
  executorch.examples.models.test.test_model_factory \
  executorch.exir._serialize.test.test_flatbuffer \
  executorch.examples.models.llama.tests.test_gemma4_support \
  executorch.examples.models.qwen3_5.tests.test_convert_weights

Result: Ran 31 tests ... OK

Also validated with real HF checkpoint conversion/export/runtime smoke tests for both google/gemma-4-E2B and google/gemma-4-E4B, including broad greedy-decoding parity checks against HF.

Prompt benchmark summary:

  • E4B: exact match on 11/12 prompts, first-token match on 12/12 prompts
  • E2B: exact match on 8/12 prompts, first-token match on 10/12 prompts

The remaining E2B drift was concentrated in open-ended near-tie generations rather than structural export failures.
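
Concretely, the two metrics above can be computed from greedy-decoded token ids on each side; a small sketch (names are illustrative, and both tensors are assumed precomputed by the HF and ExecuTorch runners respectively):

import torch

def parity_metrics(ref_ids: torch.Tensor, et_ids: torch.Tensor) -> tuple[bool, bool]:
    # ref_ids: HF greedy-decoded token ids; et_ids: ExecuTorch greedy-decoded token ids.
    exact = ref_ids.shape == et_ids.shape and torch.equal(ref_ids, et_ids)
    first = bool(ref_ids.numel() and et_ids.numel() and ref_ids[0] == et_ids[0])
    return exact, first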

Not Included In This PR

  • Gemma 4 multimodal support
  • Qualcomm/QNN or other backend-specific bring-up
  • A dedicated Gemma 4 runner / example app beyond the native text export path
  • CI end-to-end export coverage with real HF weights
  • Performance or memory tuning work beyond correctness bring-up

cc @mergennachin @iseeyuan @lucylq @helunwencser @tarun292 @kimishpatel @jackzhxng

pytorch-bot Bot commented Apr 3, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18695

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 3 New Failures, 2 Cancelled Jobs

As of commit 8fa1faa with merge base 28f3cf3:

NEW FAILURES - The following jobs have failed:

CANCELLED JOBS - The following jobs were cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla Bot added the CLA Signed label Apr 3, 2026
Phineas1500 (Contributor, Author) commented:

@pytorchbot label "release notes: examples"

pytorch-bot Bot added the release notes: examples label Apr 3, 2026
Phineas1500 changed the title from "[codex] Add Gemma 4 E2B/E4B support" to "Add Gemma 4 E2B/E4B support (text-only)" Apr 4, 2026
Phineas1500 marked this pull request as ready for review April 4, 2026 03:27
Copilot AI review requested due to automatic review settings April 4, 2026 03:27
Copilot AI (Contributor) left a comment

Pull request overview

Adds native text-only export/runtime support for Gemma 4 E2B/E4B models to ExecuTorch’s LLM (Llama-style) export path, including checkpoint conversion, runtime behavior updates, and regression tests.

Changes:

  • Registers gemma4_e2b / gemma4_e4b as first-class export targets and wires them into the Llama export loader/converter selection.
  • Extends the Llama native text runtime to support Gemma 4 specifics (new attention impl, per-layer embeddings/scaling, dual RoPE tables, additional norms, logit softcapping).
  • Adds a new examples/models/gemma4 package (configs, converter, docs) plus targeted unit/regression tests and a small flatbuffer schema fallback fix for source-tree usage.

Reviewed changes

Copilot reviewed 23 out of 23 changed files in this pull request and generated 1 comment.

Summary per file:

  • extension/llm/export/config/llm_config.py: Adds Gemma4 model types to the export config enum.
  • exir/_serialize/_flatbuffer.py: Adds a fallback to load flatbuffer schemas from the repo schema/ dir when package resources are missing (editable/source-tree installs).
  • exir/_serialize/test/test_flatbuffer.py: Adds coverage verifying the schema fallback behavior.
  • examples/models/model_factory.py: Simplifies source-tree vs package-root model imports using __package__ instead of cwd heuristics.
  • examples/models/test/test_model_factory.py: Adds unit tests validating EagerModelFactory.create_model import behavior and toy model load.
  • examples/models/test/BUCK: Adds Buck unittest target for test_model_factory.
  • examples/models/llama/tests/test_gemma4_support.py: Adds focused Gemma4 runtime/export/convert regression tests.
  • examples/models/llama/tests/BUCK: Registers Buck unittest target for Gemma4 support tests.
  • examples/models/llama/source_transformation/sdpa.py: Threads SDPA scale through the custom/quantized custom SDPA export path.
  • examples/models/llama/rope.py: Adds proportional RoPE precompute plus layer-type-specific RoPE tables and dtype-preserving HF RoPE application.
  • examples/models/llama/norm.py: Extends RMSNorm to support a "no scale parameter" mode (with_scale=False) used by Gemma4.
  • examples/models/llama/model_args.py: Adds Gemma4-related args (layer-type rope params, per-layer embedding dims, global head dims, etc.).
  • examples/models/llama/llama_transformer.py: Adds Gemma4 features in transformer blocks (post norms, per-layer inputs, layer scaling, KV donor selection, logit softcapping).
  • examples/models/llama/feed_forward.py: Allows a configurable activation function in the FFN (needed for the GELU-tanh variant).
  • examples/models/llama/export_llama_lib.py: Registers Gemma4 model IDs, routes converter import, and loads Gemma4Model when appropriate.
  • examples/models/llama/attention.py: Adds SDPA scaling support and introduces AttentionGemma4MHA with Gemma4 KV sharing and sliding/full attention behavior.
  • examples/models/gemma4/README.md: Documents Gemma4 export usage and supported models.
  • examples/models/gemma4/convert_weights.py: Adds Gemma4 checkpoint loading (pt/safetensors) and key mapping into ExecuTorch meta format.
  • examples/models/gemma4/config/e4b_config.json: Adds Gemma4 E4B export config (layer types, rope params, scaling constants).
  • examples/models/gemma4/config/e2b_config.json: Adds Gemma4 E2B export config (layer types, rope params, scaling constants).
  • examples/models/gemma4/BUCK: Adds Buck library target for the Gemma4 package (incl. resources and safetensors dep).
  • examples/models/gemma4/__init__.py: Adds the Gemma4 model entrypoint (lazy Gemma4Model wrapper) and exports the converter.
  • examples/models/BUCK: Adds the Gemma4 package to the aggregated models target list.


Comment on lines 158 to 160
model_architecture: Optional[str] = (
    None  # Architecture of model. For HF models, please refer to the HF model.config.architectures. This is used in QNN backend only for now.
)
Copilot AI commented Apr 4, 2026

ModelArgs defines model_architecture twice (earlier as a required str with default, and again here as Optional[str]). In a dataclass the later field overwrites the earlier one, which silently changes the default from e.g. "LlamaForCausalLM" to None and can break any code relying on the original default. Remove the duplicate and keep a single model_architecture definition (or rename the new field if it’s meant to be different metadata).

Suggested change

model_architecture: Optional[str] = (
    None  # Architecture of model. For HF models, please refer to the HF model.config.architectures. This is used in QNN backend only for now.
)

nil-is-all added the module: examples label Apr 7, 2026
nil-is-all (Contributor) commented:

Hi @Phineas1500, thanks for the PR. We are working on Gemma4 support, which is currently a WIP. Will keep you posted.
cc @mergennachin

digantdesai (Contributor) commented:

@claude Review this PR

claude Bot commented Apr 22, 2026

Claude finished @digantdesai's task in 2m 16s.


PR Review: Add Gemma 4 E2B/E4B support (text-only)

  • Read all changed files and gather context
  • Review new Gemma4 package (configs, converter, init, BUCK)
  • Review changes to Llama infrastructure (attention, transformer, rope, norms, feed_forward, model_args)
  • Review export path and SDPA changes
  • Review test coverage
  • Review supporting fixes (model_factory, flatbuffer)
  • Post comprehensive review

Overall Assessment

This is a substantial and well-structured PR that adds Gemma 4 E2B/E4B text-only support by extending the existing Llama infrastructure rather than forking it. The approach of registering a new gemma4_mha attention type and threading Gemma 4-specific behavior through ModelArgs fields is clean. The test coverage is focused and the export/runtime parity validation is strong.

There are a few issues worth addressing, ranging from a confirmed bug to design considerations.


Issues

1. Bug: Duplicate model_architecture field in ModelArgs dataclass

File: examples/models/llama/model_args.py:60-62 and examples/models/llama/model_args.py:158-160

The ModelArgs dataclass defines model_architecture twice:

# Line 60
model_architecture: str = "LlamaForCausalLM"

# Line 158 (added by this PR's reordering)
model_architecture: Optional[str] = None

In a Python dataclass, a later field with the same name overwrites the earlier one. The second definition silently changes the type from str to Optional[str] and the default from "LlamaForCausalLM" to None. This could break any code that relies on the original default (e.g., QNN backend code that checks model_architecture).
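
A standalone repro of the later-field-wins behavior (illustrative, not the PR's code):

from dataclasses import dataclass, fields
from typing import Optional

@dataclass
class Demo:
    model_architecture: str = "LlamaForCausalLM"
    model_architecture: Optional[str] = None  # same name: replaces the annotation above

print(Demo().model_architecture)  # None, not "LlamaForCausalLM"
print(len(fields(Demo)))          # 1 -- only a single field survives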

Copilot already flagged this. The duplicate must be resolved — keep only the second (Optional[str]) definition and remove the first, or consolidate into one with the correct default.


2. _get_kv_donor_layer_idx could fail silently with no matching donor type

File: examples/models/llama/llama_transformer.py:54-73

When layer_types are provided, the function searches backwards for a donor with the same layer type. If no match is found, it returns None:

for donor_idx in range(first_shared - 1, -1, -1):
    if layer_types[donor_idx] == target_type:
        return donor_idx
return None

If a shared layer has a type that appears in no earlier non-shared layer, the function returns None. The shared_kv[donor_idx] lookup in _forward_layers then silently misses (line 366: if donor_idx in shared_kv is False when donor_idx is None), so that shared layer proceeds without any KV input at all and falls through to the self.wk path, where self.wk is None because the layer is marked as shared, ultimately tripping an assertion. The assertion at line 852 offers some protection, but only at runtime. Consider raising an explicit error in _get_kv_donor_layer_idx when layer_types are provided but no donor is found, to surface config errors earlier; see the sketch below.
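
A sketch of that suggestion, reusing the names from the snippet above:

for donor_idx in range(first_shared - 1, -1, -1):
    if layer_types[donor_idx] == target_type:
        return donor_idx
# Fail fast instead of returning None: a shared layer with no matching donor
# indicates a misconfigured layer_types list.
raise ValueError(
    f"No non-shared donor layer of type {target_type!r} precedes the first "
    f"shared layer (index {first_shared}); check layer_types in the config."
)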

3. AttentionGemma4MHA.__init__ unconditionally creates k_norm_fn for shared layers

File: examples/models/llama/attention.py:727-742

Shared layers (where has_kv_weights=False) don't do K/V projection, so the k_norm_fn and v_norm_fn parameters allocated at lines 732-742 will never be used for those layers. This is wasted memory (small amount per layer) and could be confusing. Consider guarding the K/V norm allocation similar to how wk/wv are guarded:

if self.has_kv_weights:
    self.k_norm_fn = RMSNorm(...)
    self.v_norm_fn = RMSNorm(...)

4. store_full_length_kv logic is complex and could benefit from a helper

File: examples/models/llama/attention.py:685-694

The donor-layer detection logic in AttentionGemma4MHA.__init__ is duplicated between AttentionMHA (lines 405-411) and AttentionGemma4MHA (lines 681-694), with the Gemma4 version adding layer-type awareness. This is a maintenance risk. Consider extracting a shared helper.

5. hf_apply_rotary_emb now casts cos/sin to input dtype — verify this is safe with float32 freqs

File: examples/models/llama/rope.py:231-238

The change from cos.unsqueeze(unsqueeze_dim) to cos.unsqueeze(unsqueeze_dim).to(dtype=q.dtype) means that precomputed float32 cos/sin tables will be cast down to bf16/fp16 when the model uses those dtypes. This is correct for Gemma 4 (preserving input precision), but verify there's no existing model that relies on the RoPE computation happening in float32 for numerical stability. The original code didn't cast, which could mean some models assumed float32 RoPE.

The equivalent change to hf_apply_rotary_emb_to_k at lines 272-273 has the same consideration.
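
Side by side, the change under discussion (reconstructed from the description above):

# before: cos/sin stay float32, so q * cos promotes the product to float32
cos = cos.unsqueeze(unsqueeze_dim)
# after (this PR): RoPE runs in the activations' dtype, e.g. bf16/fp16
cos = cos.unsqueeze(unsqueeze_dim).to(dtype=q.dtype)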

6. AttentionGemma4MHA.forward deletes the freqs_cos/freqs_sin args then recreates them

File: examples/models/llama/attention.py:832

del freqs_cos, freqs_sin
freqs_cos, freqs_sin = self.rope.get_freqs_for_layer_type(...)

The del statement is unnecessary and slightly misleading since the values are immediately reassigned. It was presumably meant to signal that the passed-in frequencies are intentionally ignored, but plain reassignment conveys the same thing. Not blocking, but worth noting.

7. Potential issue with self.kv_cache access on shared layers in AttentionMHA.forward

File: examples/models/llama/attention.py:597

if getattr(self.kv_cache, "is_ring_buffer", False):

When shared_kv is not None, the previous block is skipped (lines 590-595), but line 597 still unconditionally accesses self.kv_cache. For shared layers, self.kv_cache is None, so getattr(None, "is_ring_buffer", False) returns False — which happens to be correct behavior. But this relies on getattr working safely with None, which is fragile. Consider guarding this with if self.kv_cache is not None and ....
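
The suggested guard, as a one-line sketch:

if self.kv_cache is not None and getattr(self.kv_cache, "is_ring_buffer", False):
    ...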

8. __getattr__ lazy import pattern in gemma4/__init__.py

File: examples/models/gemma4/__init__.py:9-18

The module-level __getattr__ that lazily creates Gemma4Model on first access is a clever way to avoid circular imports, but:

  • globals()["Gemma4Model"] = Gemma4Model caches on the module for subsequent accesses
  • The class itself is defined inside the __getattr__ function, so a fresh class would be created on each call if the global cache were ever invalidated
  • The Gemma4Model class is just an empty wrapper around Llama2Model — the __init__ only calls super().__init__(**kwargs) and adds no behavior

This pattern works but is unusual. If the only motivation is avoiding a circular import, a simpler alternative would be a lazy import at the call site in export_llama_lib.py.
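
For readers unfamiliar with the pattern, a minimal PEP 562 sketch of what the review describes (the Llama2Model import path is an assumption, not verified against the PR):

def __getattr__(name):
    if name != "Gemma4Model":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    # Deferred to first attribute access to avoid a circular import at module load.
    from executorch.examples.models.llama.model import Llama2Model  # assumed path

    class Gemma4Model(Llama2Model):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)

    globals()["Gemma4Model"] = Gemma4Model  # cache so later lookups bypass __getattr__
    return Gemma4Model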


Design Considerations (non-blocking)

  1. Config-driven vs. code-driven: The Gemma4 configs are JSON files with many Gemma4-specific fields (layer_types, rope_parameters, sliding_window, etc.) that are already supported by ModelArgs. This is a clean approach. However, the layer_types arrays run to 35/42 entries (line 28 in each config), which is long and error-prone to maintain. A compressed representation (e.g., a pattern spec like "4 sliding, 1 full") would be less fragile; see the sketch after this list.

  2. KV sharing dispatch: The _forward_layers method (llama_transformer.py:334-400) has grown complex with the YOCO/shared KV dispatch logic. The shared_kv dict now stores either tuples or dicts depending on whether ring buffers are in use (attention.py:783-811), adding a runtime type dispatch. This could be simplified with a shared KVPayload dataclass.

  3. Attention scale passthrough: The scale parameter is now threaded through SDPA → SDPACustom → QuantizedSDPA (sdpa.py). This is a nice improvement that fixes what was presumably a latent bug where custom SDPA didn't respect the model's attention scale. The fix is correct and well-contained.
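
As an illustration of the compressed representation suggested in point 1, a hypothetical expander (the pattern spec, type names, and function name are invented here, not taken from the PR):

import itertools

def expand_layer_types(pattern: list[str], n_layers: int) -> list[str]:
    # Repeat e.g. ["sliding_attention"] * 4 + ["full_attention"] until n_layers entries.
    return list(itertools.islice(itertools.cycle(pattern), n_layers))

# Hypothetical: replaces a hand-written 35-entry layer_types array in the config.
layer_types = expand_layer_types(["sliding_attention"] * 4 + ["full_attention"], 35)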


Test Coverage Assessment

The test suite (test_gemma4_support.py) is thorough for a correctness bring-up. It covers:

  • RoPE dtype preservation
  • Dual RoPE table dimensions
  • Shared layer donor selection
  • Ring vs. full KV cache assignment
  • Attention scale propagation
  • Transformer execution with shared layers + softcapping
  • Per-layer token embeddings
  • Layer scalar behavior
  • Export-to-ExecuTorch runtime parity

The weight converter tests cover both multimodal checkpoint stripping and unknown key detection. The model factory and flatbuffer tests are good additions for the supporting fixes.


Summary

This is a well-executed PR that correctly extends the Llama infrastructure for Gemma 4. The main actionable item is the duplicate model_architecture field (issue #1), which is a real bug. The other items are improvements and considerations of varying priority.



Labels

  • CLA Signed: This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
  • module: examples: Issues related to demos under examples/
  • release notes: examples: Changes to any of our example LLMs integrations, such as Llama3 and Llava
