Add Gemma 4 MLX install-path support #19065
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19065

Note: Links to docs will display an error until the docs builds have been completed.

❌ 11 Awaiting Approval, 2 New Failures (as of commit 719d2e8 with merge base d0b7934)

AWAITING APPROVAL - The following workflows need approval before CI can run:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hi @zeel2104! Thank you for your pull request and welcome to our community.

**Action Required:** In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

**Process:** In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with `CLA signed`.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!
This PR needs a `release notes:` label.
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
```diff
-# Check if model uses sliding window attention
-sliding_window = getattr(model.config, "sliding_window", None)
+# Check if model uses sliding window attention. Multimodal configs like
```
Does this regress gemma3?
I don't expect this to regress Gemma 3. The change is just switching the sliding-window lookup to `model.config.get_text_config()`, which also covers the plain text config case and is needed for Gemma 4, where those attrs live under `text_config`. I scoped the logic to the same attribute lookup, not a Gemma-4-specific branch. I can also rerun a Gemma 3 smoke test and report back.
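For reference, a minimal sketch of the lookup under discussion (not the exact diff); the ungated model id is just an example from this thread:

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("unsloth/gemma-3-1b-it")
# get_text_config() returns the nested text_config for multimodal configs
# and the config itself for plain text-only models, so one lookup covers both.
text_config = config.get_text_config()
sliding_window = getattr(text_config, "sliding_window", None)
```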
Yeah, it would be great to try gemma3 as a smoke test.
If you are unable to access the version from Google, try the unsloth version `unsloth/gemma-3-1b-it` (https://github.com/pytorch/executorch/blob/main/.github/workflows/mlx.yml#L469C18-L469C39).
```python
logger = logging.getLogger(__name__)


def _iter_mlx_backend_candidates():
```
This code should not be needed. Did you run `python install_executorch.py --editable` on a Mac machine with Xcode installed? If so, in the install logs, did you see a comment about MLX installation being skipped for some reason?
```cpp
}

try {
  std::cerr << "MLX init: constructing handle" << std::endl;
```
Yes, this was debug-only while I was chasing the install/runtime registration issue. I’ll remove the std::cerr logging before merge.
Looks fantastic! A couple questions:
```bash
QEMBEDDING_ARGS="--qembedding ${QCONFIG}"
if [ "${MODEL_ID}" = "google/gemma-4-E2B-it" ]; then
  QEMBEDDING_ARGS=""
```
```diff
 logger.info(f"Loading model from {pte_path}...")
-et_runtime = Runtime.get()
+et_runtime = _ensure_mlx_backend_registered()
```
This shouldn't be needed, see comment on the install process.
That’s fair. I added this while debugging the installed-package path locally because MLXBackend was not being registered from the installed package, and I wanted a way to keep validating the runtime path. Since the install-path issue is now fixed, I’ll remove it and rely on the normal install flow.
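For posterity, a hypothetical sketch of what a workaround like this could look like; the helper name mirrors the diff, but the import path and the registry check are assumptions:

```python
from executorch.runtime import Runtime


def _ensure_mlx_backend_registered() -> Runtime:
    rt = Runtime.get()
    if "MLXBackend" not in rt.backend_registry.registered_backend_names:
        # Importing the backend package is assumed to trigger its
        # registration side effect (import path is an assumption).
        import executorch.backends.mlx  # noqa: F401

        rt = Runtime.get()
    return rt
```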
```diff
 # Decode only the newly generated tokens (not the input prompt)
 new_tokens = generated_tokens[seq_len:]
-generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
+generated_text = text_processor.decode(new_tokens, skip_special_tokens=True)
```
Does this break the path where uses_processor=False?
Can we unify these two paths somehow?
I ended up unifying this path. `text_processor` is now either an `AutoProcessor` or an `AutoTokenizer`, and both decode through `text_processor.decode(...)`, so the `uses_processor=False` case should still work. The remaining split is only at encode time, where `AutoProcessor` needs `processor(text=..., return_tensors="pt")` and `AutoTokenizer` still uses `encode(...)`.
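A rough sketch of the unified flow, assuming illustrative names (`uses_processor`, `prompt`) and the model ids mentioned in this thread (the gemma-4 repo may be gated):

```python
from transformers import AutoProcessor, AutoTokenizer

uses_processor = False
prompt = "What is the capital of France?"

if uses_processor:
    # AutoProcessor path: encode via the processor call signature.
    text_processor = AutoProcessor.from_pretrained("google/gemma-4-E2B-it")
    input_ids = text_processor(text=prompt, return_tensors="pt")["input_ids"]
else:
    # AutoTokenizer path: encode via encode().
    text_processor = AutoTokenizer.from_pretrained("unsloth/gemma-3-1b-it")
    input_ids = text_processor.encode(prompt, return_tensors="pt")

# Decode is unified: both AutoProcessor and AutoTokenizer expose decode().
new_tokens = input_ids[0]  # stand-in for tokens generated past the prompt
text = text_processor.decode(new_tokens, skip_special_tokens=True)
```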
I don't expect a Gemma 3 regression from these changes. I did not get a non-custom Gemma 4 path to a validated state here; the issues I hit were around Gemma 4's hybrid/shared-KV cache layout and mixed sliding/full-attention behavior, so I focused on the custom SDPA + custom KV cache path. I also did not land `--qembedding` for Gemma 4, since I could not get it working reliably. That's why docs and CI are limited to that exact configuration in this PR.
@metascroy I ran `python install_executorch.py --editable` as suggested.
I tried to rerun a Gemma 3 smoke test locally, but I'm currently blocked by Hugging Face access on the gated `google/gemma-3-1b-it` repo; the request fails at model download with a gated-repo access error.

So I wasn't able to complete a Gemma 3 end-to-end rerun in this environment. I still don't expect this change to regress Gemma 3, since the relevant change here is switching the sliding-window lookup to `model.config.get_text_config()`.

Let me know if you'd like me to dig further into the non-custom path or embedding quant in a follow-up.
For gemma3 verification, you can use the unsloth version `model_id="unsloth/gemma-3-1b-it"`, which isn't gated. This is what we use in CI: https://github.com/pytorch/executorch/blob/main/.github/workflows/mlx.yml#L469
Can you say a bit more about what you mean by reliably? Did it fail to lower or run? Or did you run into model quality issues with quantized embeddings?

On the non-custom path: I think it is fine to leave as a follow-up, I was just curious about the specific errors you saw.
```cpp
  ET_LOG(Error, "MLX execute failed: %s", e.what());
  return Error::Internal;
} catch (...) {
  ET_LOG(Error, "MLX execute failed: unknown non-std exception");
```
No, I did not specifically hit those C++ catch-all paths.
The failures I was debugging were earlier in the flow:
- Python/export-time Gemma 4 compatibility issues in the HF export path
- installed/editable install issues around getting the MLX path working cleanly
- the `DEBUG=release` editable install failure in `setup.py`
So those catch blocks were not the source of the Gemma 4 bring-up work here.
```cpp
  }
  return Error::InvalidProgram;
} catch (...) {
  ET_LOG(Error, "Failed to load MLX program: unknown non-std exception");
```
```diff
 # is Release.
 def get_build_type(is_debug=None) -> str:
-    debug = int(os.environ.get("DEBUG", 0) or 0) if is_debug is None else is_debug
+    if is_debug is None:
```
Were these changes for debugging only?
No, these were not debug-only.

This came from the editable install path failing in my environment because `DEBUG=release`, while the existing code assumed `DEBUG` was always integer-like. The `get_build_type()` change makes that handling robust for string values like `release` / `debug` / `true` / `false`, which unblocked `python install_executorch.py --editable` for me.

I re-ran the editable install after this change and verified that `MLXBackend` registers correctly there.
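A hedged sketch of what the string-tolerant handling could look like (not the exact `setup.py` diff):

```python
import os


def get_build_type(is_debug=None) -> str:
    if is_debug is None:
        raw = str(os.environ.get("DEBUG", "0")).strip().lower()
        # Tolerate string values such as DEBUG=release / false / 0 as well
        # as DEBUG=debug / true / 1, instead of assuming an integer.
        is_debug = raw in ("1", "true", "debug")
    return "Debug" if is_debug else "Release"
```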
I'd rather not touch setup.py for this task, unless it is actually needed.
If things work with "python install_executorch.py --editable", then let's leave these setup improvements for another PR.
I reran the Gemma 3 smoke test locally using the ungated CI model `unsloth/gemma-3-1b-it`, with the same custom-path flags used for Gemma 4 (`--use-custom-sdpa --use-custom-kv-cache`). The failure happens during export, before lowering/runtime.

So at least in my local environment, this does look like a Gemma 3 regression in the custom KV-cache path rather than just a Gemma 4-only issue.

Happy to investigate that Gemma 3 custom-cache regression further if you want that covered before merge, or I can keep this PR scoped strictly to the Gemma 4 path that was validated.
Let's see what CI says. You can keep the change scoped to gemma 4, but we cannot have gemma 3 regressing because of your change. |
@zeel2104 CI for gemma3 (custom path only) is failing, whereas it was previously passing. I suspect there is a breaking change in HF interfaces, and your changes for the custom path are implicitly depending on this breaking change. Can you make sure your changes work against the pin in https://github.com/pytorch/executorch/blob/main/.ci/docker/ci_commit_pins/optimum-executorch.txt (which is what we run in CI)? See the tests in https://github.com/pytorch/executorch/blob/main/.github/workflows/mlx.yml for setup (and the transformers version we pin against). Let me know if this isn't possible to do for Gemma4.
Makes sense.
Thanks, I tracked this down to an HF cache API compatibility issue. My custom cache replacement had started assuming newer HF cache-layer behavior than the version pinned in CI. I updated it to handle both the older pinned interface and the newer Gemma 4-capable interface. After the fix, Gemma 3 works against the CI pin again and Gemma 4 continues to work with the newer interface.

So this should avoid the Gemma 3 regression while keeping the Gemma 4 support intact.
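To illustrate, a hedged sketch of dual-interface cache access; the attribute names follow how HF `DynamicCache` evolved (parallel `key_cache` lists vs per-layer `layers[i].keys`), but the exact pinned versions involved are assumptions:

```python
def _layer_keys(cache, layer_idx):
    # Newer HF interface: per-layer cache objects with .keys / .values.
    layers = getattr(cache, "layers", None)
    if layers is not None:
        return layers[layer_idx].keys
    # Older pinned interface: parallel lists of key/value tensors.
    return cache.key_cache[layer_idx]
```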
Re-running CI |
Gemma3 is working again, but it looks like gemma4 failed in CI :( |
```bash
if [ "${MODEL_ID}" = "google/gemma-4-E2B-it" ]; then
  # Gemma 4 requires a newer Transformers build than the CI-wide
  # optimum-executorch pin currently brings in.
  ${CONDA_RUN} pip install -U "transformers @ git+https://github.com/huggingface/transformers.git"
```
Can we pin on something specific? Whatever version you pin on, add it to the README under the gemma4 section.
The Gemma 4 failure changed after the last CI fix: export and runtime now work, but output quality regressed on a floating `transformers` install from git main. I pinned Gemma 4 to a specific `transformers` revision instead, and added the same pin to the README. I haven't touched the Qwen35-MoE threshold yet since that still looks separate.
@zeel2104 it looks like the gemma4 test is failing in CI.
I pushed one more Gemma 4 follow-up. CI is getting through export and runtime now, so I updated the Gemma 4 output validation accordingly. I left the Qwen35-MoE threshold unchanged since that still looks separate.
Yeah, you can ignore the Qwen3.5-MOE. It is unrelated. Re-running CI with your latest changes. |
Again the gemma4 CI failed; it was due to moving dependencies in the validation path. First, CI was using a floating model revision from the Hub rather than a fixed one. To make the Gemma 4 path reproducible, I pinned both the `transformers` version and the `google/gemma-4-E2B-it` model revision.

I also wired the model revision through the export/run scripts and added the same pins to the README.
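As a sketch of what the revision pinning could look like on the Python side (the plumbing and the placeholder SHA are assumptions, not the actual pinned value):

```python
from transformers import AutoConfig

PINNED_REVISION = "<pinned-commit-sha>"  # assumption: an exact Hub commit

config = AutoConfig.from_pretrained(
    "google/gemma-4-E2B-it",
    revision=PINNED_REVISION,  # resolve against the pin instead of main
)
```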
```python
else:
    raise NotImplementedError(f"Support for input {arg} is not implemented")


placeholder_nodes = {
```
I don't follow this change. Why is gemma4 sensitive to this?
I got here by diffing a previously working Gemma 4 `.pte` against a fresh export. What changed there was the slot assignment for the two rotary constants used by sliding-window vs full attention. This change was just to make that assignment deterministic instead of depending on raw placeholder traversal order. Gemma 4 is where I noticed it because that model exercises both constants in the same path; a minimal sketch of the idea is below.
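A minimal sketch of the determinism idea, assuming a `torch.fx.GraphModule` from the export (illustrative helper, not the exact diff):

```python
import torch


def collect_placeholders(graph_module: torch.fx.GraphModule) -> dict:
    # Key placeholders by node name, sorted, so slot assignment does not
    # depend on raw graph traversal order.
    return {
        node.name: node
        for node in sorted(
            (n for n in graph_module.graph.nodes if n.op == "placeholder"),
            key=lambda n: n.name,
        )
    }
```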
If you'd prefer, I can drop this change from the PR.
@metascroy At this point it looks more like a Gemma 4 `--qlinear 4w` quantization quality issue than an export/runtime failure.
It would be good to get 4w working. Let me try checking out your PR today to see if I notice anything. |
## Summary
Enable Gemma 4 on the MLX backend through the HuggingFace export/run path.
This PR:

- adds Gemma 4 support to `backends/mlx/examples/llm/export_llm_hf.py`
- adds Gemma 4 support to `backends/mlx/examples/llm/run_llm_hf.py`
- makes the export/run flow work from the installed package without `PYTHONPATH` workarounds

This PR does not add Gemma 4 support to the internal `export_llm` path (`examples/models/gemma4/`).

## Test plan
Manual validation on Apple Silicon macOS using the installed package from `.venv/site-packages`:

```bash
python -m executorch.backends.mlx.examples.llm.export_llm_hf \
  --model-id google/gemma-4-E2B-it \
  --output /tmp/gemma4_custom_qlinear_only_installed.pte \
  --qlinear 4w \
  --use-custom-sdpa \
  --use-custom-kv-cache

python -m executorch.backends.mlx.examples.llm.run_llm_hf \
  --pte /tmp/gemma4_custom_qlinear_only_installed.pte \
  --model-id google/gemma-4-E2B-it \
  --prompt "What is the capital of France?" \
  --max-new-tokens 50
```

### Validation
- installed import path resolves from `.venv/lib/python3.12/site-packages/executorch/...`
- `MLXBackend` is registered from the installed package (see the check sketched after this list)
- export succeeds for `google/gemma-4-E2B-it` with `--qlinear 4w --use-custom-sdpa --use-custom-kv-cache`
- runtime succeeds without `PYTHONPATH`
- generated output contains `Paris`
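A quick registration check along these lines can confirm the installed-package path; `backend_registry.registered_backend_names` is assumed to be the Python runtime's registry listing:

```python
from executorch.runtime import Runtime

# Run from outside the repo checkout so the installed package resolves.
rt = Runtime.get()
assert "MLXBackend" in rt.backend_registry.registered_backend_names
```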