diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index a40198ea36f..cf0bc2bcfc0 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -489,11 +489,25 @@ jobs: name: "gemma3-1b" use-custom: [false, true] qconfig: ["4w", "nvfp4"] + runner: ["macos-14-xlarge"] + include: + - model: + id: "google/gemma-4-E2B-it" + name: "gemma4-e2b" + use-custom: true + qconfig: "4w" + runner: "macos-15-xlarge" + - model: + id: "google/gemma-4-E2B-it" + name: "gemma4-e2b" + use-custom: false + qconfig: "4w" + runner: "macos-15-xlarge" uses: pytorch/test-infra/.github/workflows/macos_job.yml@main secrets: inherit with: job-name: test-mlx-llm-${{ matrix.model.name }}${{ matrix.use-custom && '-custom' || '' }}-${{ matrix.qconfig }} - runner: macos-14-xlarge + runner: ${{ matrix.runner }} python-version: "3.12" submodules: recursive ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} @@ -506,12 +520,16 @@ jobs: MODEL_NAME="${{ matrix.model.name }}" USE_CUSTOM="${{ matrix.use-custom }}" QCONFIG="${{ matrix.qconfig }}" - CUSTOM_ARGS="" if [ "${USE_CUSTOM}" = "true" ]; then CUSTOM_ARGS="--use-custom-sdpa --use-custom-kv-cache" fi + QEMBEDDING_ARGS="--qembedding ${QCONFIG}" + if [ "${MODEL_ID}" = "google/gemma-4-E2B-it" ]; then + QEMBEDDING_ARGS="" + fi + echo "::group::Install ExecuTorch and configure MLX build" ${CONDA_RUN} python install_executorch.py > /dev/null ${CONDA_RUN} cmake --preset mlx-release @@ -522,6 +540,13 @@ jobs: ${CONDA_RUN} huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt) ${CONDA_RUN} pip install transformers "optimum-executorch @ git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION}" + if [ "${MODEL_ID}" = "google/gemma-4-E2B-it" ]; then + # Gemma 4 requires a newer Transformers build than the CI-wide + # optimum-executorch pin currently brings in. Keep this pinned to the + # locally validated commit instead of floating on Transformers HEAD. + GEMMA4_TRANSFORMERS_COMMIT=61461a7bcb458db7cf6eeea49678b9ab776a7821 + ${CONDA_RUN} pip install -U "transformers @ git+https://github.com/huggingface/transformers.git@${GEMMA4_TRANSFORMERS_COMMIT}" + fi echo "::endgroup::" ${CONDA_RUN} pip list @@ -531,7 +556,7 @@ jobs: --model-id "${MODEL_ID}" \ --output /tmp/${MODEL_NAME}.pte \ --qlinear ${QCONFIG} \ - --qembedding ${QCONFIG} \ + ${QEMBEDDING_ARGS} \ ${CUSTOM_ARGS} echo "::endgroup::" diff --git a/backends/mlx/examples/llm/README.md b/backends/mlx/examples/llm/README.md index f860c4f1ce0..738bbfb8c14 100644 --- a/backends/mlx/examples/llm/README.md +++ b/backends/mlx/examples/llm/README.md @@ -9,6 +9,7 @@ This example demonstrates how to export and run LLMs using the MLX delegate for - **KV Cache**: Efficient KV cache implementation for autoregressive generation - **Custom Ops**: Uses `mlx::custom_sdpa` and `mlx::kv_cache_update` for optimal execution on MLX - **Pybindings**: Run inference using ExecuTorch Python bindings +- **Gemma 4**: Text-only export and run flow supports processor-backed checkpoints such as `google/gemma-4-E2B-it` ## Requirements @@ -52,6 +53,24 @@ python -m executorch.backends.mlx.examples.llm.export_llm_hf \ --use-custom-kv-cache \ --qlinear 4w \ --qembedding 4w + +# Gemma 4 text-only export +python -m executorch.backends.mlx.examples.llm.export_llm_hf \ + --model-id "google/gemma-4-E2B-it" \ + --output gemma4_hf_int4.pte \ + --use-custom-sdpa \ + --use-custom-kv-cache \ + --qlinear 4w +``` + +Gemma 4 support is currently validated for the text-only path using +`--use-custom-sdpa --use-custom-kv-cache --qlinear 4w`. + +Validated with `transformers` commit +`61461a7bcb458db7cf6eeea49678b9ab776a7821`: + +```bash +pip install -U "transformers @ git+https://github.com/huggingface/transformers.git@61461a7bcb458db7cf6eeea49678b9ab776a7821" ``` ### Options @@ -81,12 +100,24 @@ python -m executorch.backends.mlx.examples.llm.run_llm_hf \ --prompt "Explain quantum computing in simple terms" ``` +Gemma 4 checkpoints may use `AutoProcessor` instead of `AutoTokenizer`; `run_llm_hf` now supports both paths automatically for text-only prompts. + +Validated Gemma 4 run command: + +```bash +python -m executorch.backends.mlx.examples.llm.run_llm_hf \ + --pte gemma4_hf_int4.pte \ + --model-id google/gemma-4-E2B-it \ + --prompt "What is the capital of France?" \ + --max-new-tokens 50 +``` + ### Options | Option | Default | Description | |--------|---------|-------------| | `--pte` | `llama_hf.pte` | Path to .pte file | -| `--model-id` | `unsloth/Llama-3.2-1B-Instruct` | HuggingFace model ID (for tokenizer) | +| `--model-id` | `unsloth/Llama-3.2-1B-Instruct` | HuggingFace model ID (for tokenizer or processor) | | `--prompt` | `The quick brown fox` | Input prompt | | `--max-new-tokens` | `50` | Maximum tokens to generate | diff --git a/backends/mlx/examples/llm/export_llm_hf.py b/backends/mlx/examples/llm/export_llm_hf.py index 39f13e434be..fe6b8094f6b 100644 --- a/backends/mlx/examples/llm/export_llm_hf.py +++ b/backends/mlx/examples/llm/export_llm_hf.py @@ -50,6 +50,7 @@ def _export_with_optimum( model_id: str, + revision: Optional[str], output_path: str, max_seq_len: int, dtype: str, @@ -73,6 +74,7 @@ def _export_with_optimum( logger.info(f"Loading model using optimum-executorch: {model_id}") exportable = load_causal_lm_model( model_id, + revision=revision, dtype=dtype_str, max_seq_len=max_seq_len, ) @@ -124,6 +126,7 @@ def _export_with_optimum( def _export_with_custom_components( model_id: str, + revision: Optional[str], output_path: str, max_seq_len: int, dtype: str, @@ -166,20 +169,21 @@ def _export_with_custom_components( attn_implementation = "mlx" if use_custom_sdpa else None - # Detect sliding window models (e.g., gemma) - sliding_window = None - logger.info(f"Loading HuggingFace model: {model_id}") load_kwargs = { "torch_dtype": torch_dtype, "low_cpu_mem_usage": True, } + if revision is not None: + load_kwargs["revision"] = revision if attn_implementation: load_kwargs["attn_implementation"] = attn_implementation model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs) - # Check if model uses sliding window attention - sliding_window = getattr(model.config, "sliding_window", None) + # Check if model uses sliding window attention. Multimodal configs like + # Gemma 4 keep transformer attributes under text_config. + text_config = model.config.get_text_config() + sliding_window = getattr(text_config, "sliding_window", None) if sliding_window is not None: logger.info(f"Model has sliding_window={sliding_window}") # Cap max_seq_len to sliding window size for cache allocation @@ -188,11 +192,16 @@ def _export_with_custom_components( else: effective_cache_len = max_seq_len + # The HF ExecuTorch cache wrappers validate both generation_config.use_cache + # and the text config's use_cache flag before constructing static caches. + model.generation_config.use_cache = True model.generation_config.cache_implementation = "static" model.generation_config.cache_config = { "batch_size": 1, "max_cache_len": effective_cache_len, } + text_config = model.config.get_text_config() + text_config.use_cache = True model.eval() # Use HybridCache wrapper for sliding window models (stores cache as .cache), @@ -219,52 +228,26 @@ def _export_with_custom_components( ) if use_custom_kv_cache: - if sliding_window is not None: - # Use ring buffer cache for sliding window models - from executorch.backends.mlx.llm.source_transformation import ( - replace_hf_cache_with_mlx_ring_buffer, - ) + from executorch.backends.mlx.llm.source_transformation import ( + replace_hf_cache_with_mlx, + ) + if sliding_window is not None: logger.info( - f"Replacing StaticCache with RingBuffer KV cache " - f"(window_size={effective_cache_len})..." + "Replacing HuggingFace StaticCache with HFStaticCache " + f"(capped to sliding window: {effective_cache_len})..." ) - replace_hf_cache_with_mlx_ring_buffer( - exportable, - model.config, - max_batch_size=1, - window_size=effective_cache_len, - dtype=torch_dtype, - ) - - if use_custom_sdpa: - # Re-register attention with sliding window closure - from executorch.backends.mlx.llm.hf_attention import ( - register_mlx_sliding_window_attention, - ) - - register_mlx_sliding_window_attention(exportable) - model.config._attn_implementation = "mlx_sliding_window" - logger.info( - " Registered sliding window attention (mlx_sliding_window)" - ) - - logger.info(" RingBuffer KV cache installed successfully") else: - # Use standard linear cache for non-sliding-window models - from executorch.backends.mlx.llm.source_transformation import ( - replace_hf_cache_with_mlx, - ) - logger.info("Replacing HuggingFace StaticCache with HFStaticCache...") - replace_hf_cache_with_mlx( - exportable, - model.config, - max_batch_size=1, - max_cache_len=effective_cache_len, - dtype=torch_dtype, - ) - logger.info(" HFStaticCache installed successfully") + + replace_hf_cache_with_mlx( + exportable, + model.config, + max_batch_size=1, + max_cache_len=effective_cache_len, + dtype=torch_dtype, + ) + logger.info(" HFStaticCache installed successfully") from executorch.backends.mlx.llm.quantization import quantize_model_ @@ -341,6 +324,7 @@ def _save_program(executorch_program, output_path: str) -> None: def export_llama_hf( model_id: str, + revision: Optional[str], output_path: str, max_seq_len: int = 1024, dtype: str = "bf16", @@ -372,6 +356,7 @@ def export_llama_hf( ) _export_with_custom_components( model_id=model_id, + revision=revision, output_path=output_path, max_seq_len=max_seq_len, dtype=dtype, @@ -387,6 +372,7 @@ def export_llama_hf( logger.info("Using optimum-executorch pipeline (no custom components)") _export_with_optimum( model_id=model_id, + revision=revision, output_path=output_path, max_seq_len=max_seq_len, dtype=dtype, @@ -408,6 +394,12 @@ def main(): default="unsloth/Llama-3.2-1B-Instruct", help="HuggingFace model ID", ) + parser.add_argument( + "--revision", + type=str, + default=None, + help="Optional HuggingFace model revision/commit to pin", + ) parser.add_argument( "--output", type=str, @@ -447,6 +439,7 @@ def main(): export_llama_hf( model_id=args.model_id, + revision=args.revision, output_path=args.output, max_seq_len=args.max_seq_len, dtype=args.dtype, diff --git a/backends/mlx/examples/llm/run_llm_hf.py b/backends/mlx/examples/llm/run_llm_hf.py index ca3d0468114..c15bcd89c46 100644 --- a/backends/mlx/examples/llm/run_llm_hf.py +++ b/backends/mlx/examples/llm/run_llm_hf.py @@ -7,10 +7,11 @@ # LICENSE file in the root directory of this source tree. """ -Run exported Llama model (from HuggingFace) using ExecuTorch pybindings. +Run exported HuggingFace LLM using ExecuTorch pybindings. This script runs models exported using export_llm_hf.py. It loads the tokenizer -directly from HuggingFace using the same model ID used during export. +or processor directly from HuggingFace using the same model ID used during +export. Usage: python -m executorch.backends.mlx.examples.llm.run_llm_hf \ @@ -25,7 +26,7 @@ import torch from executorch.runtime import Runtime, Verification -from transformers import AutoTokenizer +from transformers import AutoProcessor, AutoTokenizer FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) @@ -46,15 +47,66 @@ def _get_max_input_seq_len(program) -> int: return sizes[1] if len(sizes) >= 2 else 1 +def _load_text_processor(model_id: str, revision: str | None): + """ + Load a text processor for the model. + + Prefer AutoTokenizer for text-only prompting, even for checkpoints that + also ship an AutoProcessor. Some hybrid checkpoints (for example Gemma 4) + expose both, but the tokenizer path is the more stable interface for the + plain text generation flow exercised by this runner. + """ + logger.info(f"Loading tokenizer from HuggingFace: {model_id}...") + try: + tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision) + return tokenizer, False + except Exception as exc: + logger.info(f"AutoTokenizer unavailable for {model_id}: {exc}") + + try: + processor = AutoProcessor.from_pretrained(model_id, revision=revision) + if hasattr(processor, "apply_chat_template") and hasattr(processor, "decode"): + logger.info(f"Loaded processor from HuggingFace: {model_id}") + return processor, True + except Exception as exc: + logger.info(f"AutoProcessor unavailable for {model_id}: {exc}") + + raise RuntimeError(f"Could not load tokenizer or processor for {model_id}") + + +def _apply_chat_template(text_processor, messages) -> str: + try: + return text_processor.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + except TypeError: + return text_processor.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + + +def _get_eos_token_id(text_processor): + eos_token_id = getattr(text_processor, "eos_token_id", None) + if eos_token_id is not None: + return eos_token_id + tokenizer = getattr(text_processor, "tokenizer", None) + return getattr(tokenizer, "eos_token_id", None) + + def run_inference( pte_path: str, model_id: str, + revision: str | None, prompt: str, max_new_tokens: int = 50, ) -> str: """Run inference on the exported HuggingFace model.""" - logger.info(f"Loading tokenizer from HuggingFace: {model_id}...") - tokenizer = AutoTokenizer.from_pretrained(model_id) + text_processor, uses_processor = _load_text_processor(model_id, revision) logger.info(f"Loading model from {pte_path}...") et_runtime = Runtime.get() @@ -67,14 +119,18 @@ def run_inference( logger.info(f"Encoding prompt: {prompt!r}") messages = [{"role": "user", "content": prompt}] - formatted_prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ) - input_ids = tokenizer.encode(formatted_prompt, return_tensors="pt") + formatted_prompt = _apply_chat_template(text_processor, messages) + if uses_processor: + input_ids = text_processor(text=formatted_prompt, return_tensors="pt")[ + "input_ids" + ] + else: + input_ids = text_processor.encode(formatted_prompt, return_tensors="pt") logger.info(f"Input shape: {input_ids.shape}") generated_tokens = input_ids[0].tolist() seq_len = input_ids.shape[1] + eos_token_id = _get_eos_token_id(text_processor) start_time = time.time() @@ -120,7 +176,7 @@ def run_inference( next_token = torch.argmax(next_token_logits).item() generated_tokens.append(next_token) - if next_token == tokenizer.eos_token_id: + if eos_token_id is not None and next_token == eos_token_id: logger.info(f"EOS token reached at position {i + 1}") break @@ -135,12 +191,12 @@ def run_inference( # Decode only the newly generated tokens (not the input prompt) new_tokens = generated_tokens[seq_len:] - generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True) + generated_text = text_processor.decode(new_tokens, skip_special_tokens=True) return generated_text def main(): - parser = argparse.ArgumentParser(description="Run exported HuggingFace Llama model") + parser = argparse.ArgumentParser(description="Run exported HuggingFace LLM") parser.add_argument( "--pte", type=str, @@ -151,7 +207,13 @@ def main(): "--model-id", type=str, default="unsloth/Llama-3.2-1B-Instruct", - help="HuggingFace model ID (used to load tokenizer)", + help="HuggingFace model ID (used to load tokenizer or processor)", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + help="Optional HuggingFace model revision/commit to pin", ) parser.add_argument( "--prompt", @@ -171,6 +233,7 @@ def main(): generated_text = run_inference( pte_path=args.pte, model_id=args.model_id, + revision=args.revision, prompt=args.prompt, max_new_tokens=args.max_new_tokens, ) diff --git a/backends/mlx/llm/cache.py b/backends/mlx/llm/cache.py index 9709980689b..2890e823499 100644 --- a/backends/mlx/llm/cache.py +++ b/backends/mlx/llm/cache.py @@ -12,6 +12,7 @@ Provides reusable KV cache implementations optimized for the MLX backend: """ +import inspect from typing import Tuple import torch @@ -21,6 +22,62 @@ from executorch.backends.mlx import custom_ops as _mlx_custom_ops # noqa: F401 +def resolve_hf_text_config(config): + """Return the text config for multimodal HF models, or the config itself.""" + if hasattr(config, "get_text_config"): + return config.get_text_config() + return getattr(config, "text_config", config) + + +def resolve_hf_cache_layout(config): + """ + Return per-cache-layer metadata for HuggingFace hybrid/static caches. + + Some models such as Gemma 4 use different KV geometries depending on the + attention layer type. Match the upstream `transformers` hybrid cache layout + so our replacement cache allocates the same number of layers with the same + `(num_heads, head_dim)` for each backing cache entry. + """ + text_config = resolve_hf_text_config(config) + layer_types = getattr(text_config, "layer_types", None) + + if layer_types is None: + if getattr(text_config, "sliding_window", None) is not None: + layer_types = ["sliding_attention" for _ in range(text_config.num_hidden_layers)] + else: + layer_types = ["full_attention" for _ in range(text_config.num_hidden_layers)] + else: + layer_types = list(layer_types) + + if hasattr(text_config, "num_kv_shared_layers"): + layer_types = layer_types[: -text_config.num_kv_shared_layers] + + if hasattr(text_config, "global_head_dim"): + head_dims = [ + text_config.global_head_dim if layer_type == "full_attention" else text_config.head_dim + for layer_type in layer_types + ] + num_heads = [ + text_config.num_global_key_value_heads + if layer_type == "full_attention" and getattr(text_config, "attention_k_eq_v", False) + else text_config.num_key_value_heads + for layer_type in layer_types + ] + else: + head_dim = getattr( + text_config, + "head_dim", + text_config.hidden_size // text_config.num_attention_heads, + ) + num_head = getattr( + text_config, "num_key_value_heads", text_config.num_attention_heads + ) + head_dims = [head_dim for _ in layer_types] + num_heads = [num_head for _ in layer_types] + + return layer_types, num_heads, head_dims + + class KVCache(nn.Module): """ MLX-optimized KV cache with ExecutorTorch llama KVCache interface. @@ -326,14 +383,13 @@ def __init__( device: Device for cache tensors (default: None = CPU) dtype: Data type for cache tensors (default: torch.float32) """ - # Resolve dimensions from config BEFORE calling parent - num_layers = config.num_hidden_layers - num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) - head_dim = getattr( - config, "head_dim", config.hidden_size // config.num_attention_heads - ) + # Resolve dimensions from the text config before calling parent. Multimodal + # configs like Gemma 4 expose transformer dims under text_config. + text_config = resolve_hf_text_config(config) + layer_types, num_heads, head_dims = resolve_hf_cache_layout(config) + num_model_layers = text_config.num_hidden_layers actual_max_cache_len = max_cache_len or getattr( - config, "max_position_embeddings", 2048 + text_config, "max_position_embeddings", 2048 ) # Initialize parent StaticCache with required arguments @@ -344,19 +400,50 @@ def __init__( device=device, dtype=dtype, ) - # Call early_initialization to ensure parent's layers are fully initialized - self.early_initialization( - batch_size=max_batch_size, - num_heads=num_heads, - head_dim=head_dim, - dtype=dtype, - device=device, - ) + # Newer HF cache implementations already support per-layer layouts in + # early_initialization(). Keep that path for Gemma 4, and only fall + # back to manual layer initialization for the older CI-pinned API. + try: + self.early_initialization( + batch_size=max_batch_size, + num_heads=num_heads, + head_dim=head_dims, + dtype=dtype, + device=device, + ) + except TypeError: + for layer, layer_num_heads, layer_head_dim in zip( + self.layers, num_heads, head_dims + ): + fake_keys_tensor = torch.zeros( + (max_batch_size, layer_num_heads, 0, layer_head_dim), + dtype=dtype, + device=device, + ) + lazy_init_sig = inspect.signature(layer.lazy_initialization) + # Older pinned HF caches take a single fake tensor, while newer + # versions expect both key_states and value_states separately. + if len(lazy_init_sig.parameters) == 1: + layer.lazy_initialization(fake_keys_tensor) + else: + fake_values_tensor = torch.zeros( + (max_batch_size, layer_num_heads, 0, layer_head_dim), + dtype=dtype, + device=device, + ) + layer.lazy_initialization(fake_keys_tensor, fake_values_tensor) + + # Some models (for example Gemma 4) only allocate cache entries for the + # non-shared KV layers. Mirror the parent StaticCache layout exactly so + # layer_idx values passed to update() line up with our backing cache. + num_cache_layers = len(self.layers) # Store dimensions as instance attributes - self.num_layers = num_layers + self.num_model_layers = num_model_layers + self.num_layers = num_cache_layers + self.layer_types = layer_types self.num_heads = num_heads - self.head_dim = head_dim + self.head_dim = head_dims # Create KVCache wrappers for each layer - these use mlx::kv_cache_update # Named 'kv_cache' to match optimum-executorch's ETCustomStaticCache pattern @@ -365,12 +452,12 @@ def __init__( KVCache( max_batch_size=max_batch_size, max_context_length=actual_max_cache_len, - n_heads=num_heads, - head_dim=head_dim, + n_heads=layer_num_heads, + head_dim=layer_head_dim, enable_dynamic_shape=True, dtype=dtype, ) - for _ in range(num_layers) + for layer_num_heads, layer_head_dim in zip(num_heads, head_dims) ] ) @@ -394,18 +481,31 @@ def update( key_states: New key states [batch_size, num_heads, seq_len, head_dim] value_states: New value states [batch_size, num_heads, seq_len, head_dim] layer_idx: Index of the layer to update - cache_kwargs: Dictionary containing 'cache_position' tensor with start position + cache_kwargs: Optional dictionary containing 'cache_position' tensor + with start position. Newer HF StaticCache callers seed + `self.layers[layer_idx].cumulative_length` directly and do not + pass cache_kwargs. Returns: Tuple of (key_cache, value_cache) for the full cache after update """ - assert ( - cache_kwargs is not None - ), "cache_kwargs must be provided with 'cache_position'" - cache_position = cache_kwargs.get("cache_position") - assert ( - cache_position is not None - ), "cache_position must be provided in cache_kwargs" + if cache_kwargs is not None: + cache_position = cache_kwargs.get("cache_position") + else: + cache_position = None + + if cache_position is None: + # Current HF ExecuTorch wrappers copy the requested cache position + # into each StaticCache layer's cumulative_length before forward(). + if hasattr(self.layers[layer_idx], "cumulative_length"): + cache_position = self.layers[layer_idx].cumulative_length + else: + raise RuntimeError( + "cache_position was not provided and the pinned " + "transformers StaticCache layer does not expose " + "cumulative_length" + ) + assert isinstance( cache_position, torch.Tensor ), "cache_position must be a tensor" diff --git a/backends/mlx/llm/hf_attention.py b/backends/mlx/llm/hf_attention.py index 9e3c864dce6..f2a01c9e653 100644 --- a/backends/mlx/llm/hf_attention.py +++ b/backends/mlx/llm/hf_attention.py @@ -89,8 +89,10 @@ def mlx_sdpa_with_start_pos_forward( def sdpa_mask_passthrough( batch_size: int, - cache_position: torch.Tensor, - kv_length: int, + cache_position: Optional[torch.Tensor] = None, + q_length: Optional[int] = None, + kv_length: Optional[int] = None, + q_offset: Optional[Union[int, torch.Tensor]] = None, kv_offset: int = 0, mask_function: Optional[Callable] = None, attention_mask: Optional[torch.Tensor] = None, @@ -139,6 +141,27 @@ def get_mlx_sliding_window_sdpa(exportable_module) -> Callable: Attention function compatible with HuggingFace's attention interface. """ + def _resolve_cache_layer_idx(module: torch.nn.Module, cache) -> Optional[int]: + """ + Map a transformer layer index to the backing cache slot index. + + Hybrid/shared-KV models like Gemma 4 only allocate cache entries for the + non-shared KV layers. Shared layers expose `kv_shared_layer_index`, which + points at the earlier cache-producing layer they reuse. + """ + layer_idx = getattr(module, "layer_idx", None) + if layer_idx is None: + return None + + if layer_idx < len(cache.kv_cache): + return layer_idx + + shared_layer_idx = getattr(module, "kv_shared_layer_index", None) + if shared_layer_idx is not None and shared_layer_idx < len(cache.kv_cache): + return shared_layer_idx + + return None + def _sliding_window_sdpa_forward( module: torch.nn.Module, query: torch.Tensor, # [B, num_heads, seq_len, head_dim] - BHSD @@ -165,6 +188,7 @@ def _sliding_window_sdpa_forward( attn_mask = None start_pos = 0 + layer_cache = None if layer_idx is not None and position_ids is not None: start_pos = position_ids[0][0].item() @@ -173,7 +197,9 @@ def _sliding_window_sdpa_forward( cache = getattr(exportable_module, "cache", None) if cache is not None: - layer_cache = cache.kv_cache[layer_idx] + cache_layer_idx = _resolve_cache_layer_idx(module, cache) + if cache_layer_idx is not None: + layer_cache = cache.kv_cache[cache_layer_idx] if isinstance(layer_cache, RingBufferKVCache): attn_mask = layer_cache.create_sliding_window_mask( start_pos, seq_len @@ -182,11 +208,19 @@ def _sliding_window_sdpa_forward( # stop_pos = start_pos + seq_len = buffer_size start_pos = layer_cache.buffer_size - seq_len + # Hybrid models use one global HF attention implementation. Sliding + # layers need the ring-buffer mask path, while full-attention layers + # should keep the regular causal SDPA path even under the same hook. if attn_mask is None: - raise RuntimeError( - f"Sliding window attention at layer {layer_idx} requires a " - f"RingBufferKVCache, but none was found. Ensure the model's " - f"cache is set up with RingBufferKVCache for sliding window layers." + return mlx_sdpa_with_start_pos_forward( + module, + query, + key, + value, + attention_mask, + position_ids=position_ids, + scaling=scaling, + **kwargs, ) output = torch.ops.mlx.custom_sdpa( diff --git a/backends/mlx/llm/source_transformation.py b/backends/mlx/llm/source_transformation.py index d90073c633e..06a45b9e22b 100644 --- a/backends/mlx/llm/source_transformation.py +++ b/backends/mlx/llm/source_transformation.py @@ -19,7 +19,13 @@ import torch import torch.nn as nn -from executorch.backends.mlx.llm.cache import HFStaticCache, KVCache, RingBufferKVCache +from executorch.backends.mlx.llm.cache import ( + HFStaticCache, + KVCache, + RingBufferKVCache, + resolve_hf_cache_layout, + resolve_hf_text_config, +) logger = logging.getLogger(__name__) @@ -123,9 +129,17 @@ def replace_hf_cache_with_mlx( def _install_cache(attr_name): setattr(module, attr_name, mlx_cache) - for i, layer_cache in enumerate(mlx_cache.kv_cache): + for i, (cache_layer, layer_cache) in enumerate( + zip(mlx_cache.layers, mlx_cache.kv_cache) + ): setattr(module, f"key_cache_{i}", layer_cache.k_cache) setattr(module, f"value_cache_{i}", layer_cache.v_cache) + if hasattr(cache_layer, "cumulative_length"): + setattr( + module, + f"cumulative_length_{i}", + cache_layer.cumulative_length, + ) if hasattr(module, "static_cache"): assert isinstance( @@ -171,12 +185,6 @@ def replace_hf_cache_with_mlx_ring_buffer( """ from transformers.cache_utils import StaticCache - num_layers = config.num_hidden_layers - num_kv_heads = getattr(config, "num_key_value_heads", config.num_attention_heads) - head_dim = getattr( - config, "head_dim", config.hidden_size // config.num_attention_heads - ) - # Create HFStaticCache with ring buffer layers mlx_cache = HFStaticCache( config=config, @@ -185,22 +193,39 @@ def replace_hf_cache_with_mlx_ring_buffer( dtype=dtype, ) - # Replace each layer's KVCache with RingBufferKVCache - for i in range(num_layers): - ring_cache = RingBufferKVCache( + # Replace only the sliding-window cache entries with ring buffers, while + # preserving full-attention entries as linear caches. Hybrid models like + # Gemma 4 mix both layouts and can also vary head_dim per cache layer. + layer_types, num_heads, head_dims = resolve_hf_cache_layout(config) + num_cache_layers = len(mlx_cache.layers) + num_ring_layers = 0 + for i, (layer_type, layer_num_heads, layer_head_dim) in enumerate( + zip(layer_types, num_heads, head_dims) + ): + if layer_type != "sliding_attention": + continue + mlx_cache.kv_cache[i] = RingBufferKVCache( max_batch_size=max_batch_size, max_context_length=window_size, - n_heads=num_kv_heads, - head_dim=head_dim, + n_heads=layer_num_heads, + head_dim=layer_head_dim, dtype=dtype, ) - mlx_cache.kv_cache[i] = ring_cache + num_ring_layers += 1 def _install_cache(attr_name): setattr(module, attr_name, mlx_cache) - for i, layer_cache in enumerate(mlx_cache.kv_cache): + for i, (cache_layer, layer_cache) in enumerate( + zip(mlx_cache.layers, mlx_cache.kv_cache) + ): setattr(module, f"key_cache_{i}", layer_cache.k_cache) setattr(module, f"value_cache_{i}", layer_cache.v_cache) + if hasattr(cache_layer, "cumulative_length"): + setattr( + module, + f"cumulative_length_{i}", + cache_layer.cumulative_length, + ) if hasattr(module, "static_cache"): assert isinstance( @@ -218,8 +243,8 @@ def _install_cache(attr_name): raise ValueError("Module must have 'static_cache' or 'cache' attribute") logger.info( - f"Installed RingBufferKVCache: {num_layers} layers, " - f"window_size={window_size}, heads={num_kv_heads}, head_dim={head_dim}" + f"Installed hybrid MLX cache: {num_ring_layers} ring-buffer layers / " + f"{num_cache_layers} total cache layers, window_size={window_size}" ) return module diff --git a/backends/mlx/runtime/MLXBackend.cpp b/backends/mlx/runtime/MLXBackend.cpp index 99e20114ea7..5bd3bf263d1 100644 --- a/backends/mlx/runtime/MLXBackend.cpp +++ b/backends/mlx/runtime/MLXBackend.cpp @@ -19,7 +19,6 @@ #include #include -#include #include #include #include @@ -285,6 +284,13 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface { processed->Free(); } return Error::InvalidProgram; + } catch (...) { + ET_LOG(Error, "Failed to load MLX program: unknown non-std exception"); + handle->~MLXHandle(); + if (processed != nullptr) { + processed->Free(); + } + return Error::InvalidProgram; } return handle; @@ -416,6 +422,9 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface { } catch (const std::exception& e) { ET_LOG(Error, "MLX execute failed: %s", e.what()); return Error::Internal; + } catch (...) { + ET_LOG(Error, "MLX execute failed: unknown non-std exception"); + return Error::Internal; } }