106 changes: 13 additions & 93 deletions examples/models/llama/attention.py
@@ -1,4 +1,3 @@
-import logging
 from abc import ABC, abstractmethod
 from enum import Enum
 from typing import Any, Dict, Optional, Tuple, Type, TypedDict
@@ -53,8 +52,6 @@ def forward(
 
 
 ATTENTION_REGISTRY: Dict[str, Type[Attention]] = {}
-_RECURRENT_GATED_DELTA_RULE_OP = None
-_TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = False
 
 
 def register_attention(name: str):
@@ -67,38 +64,6 @@ def decorator(cls: Type[Attention]):
     return decorator
 
 
-def _get_recurrent_gated_delta_rule_op():
-    global _RECURRENT_GATED_DELTA_RULE_OP
-    global _TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP
-
-    if _TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP:
-        return _RECURRENT_GATED_DELTA_RULE_OP
-
-    _TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = True
-    try:
-        _RECURRENT_GATED_DELTA_RULE_OP = (
-            torch.ops.llama.recurrent_gated_delta_rule.default
-        )
-        return _RECURRENT_GATED_DELTA_RULE_OP
-    except (AttributeError, RuntimeError):
-        pass
-
-    try:
-        from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
-    except (ImportError, OSError, RuntimeError):
-        logging.debug("Failed to import custom ops library", exc_info=True)
-        return None
-
-    try:
-        _RECURRENT_GATED_DELTA_RULE_OP = (
-            torch.ops.llama.recurrent_gated_delta_rule.default
-        )
-    except (AttributeError, RuntimeError):
-        _RECURRENT_GATED_DELTA_RULE_OP = None
-
-    return _RECURRENT_GATED_DELTA_RULE_OP
-
-
 class KVCache(nn.Module):
     def __init__(
         self,
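For context on the registry this file keeps: register_attention stores each attention class in ATTENTION_REGISTRY under a string key, and callers construct implementations by name (the tests below use ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope)). A minimal sketch of that pattern, with hypothetical stand-in names (REGISTRY, register, BaseAttention), since the decorator body is collapsed in this hunk:

from typing import Dict, Type

class BaseAttention:  # hypothetical stand-in for the real Attention ABC
    ...

REGISTRY: Dict[str, Type[BaseAttention]] = {}

def register(name: str):
    # Store the class under its key and return it unchanged, so a class
    # definition doubles as its registration.
    def decorator(cls: Type[BaseAttention]):
        REGISTRY[name] = cls
        return cls
    return decorator

@register("gated_deltanet")
class GatedDeltaNet(BaseAttention):
    def __init__(self, args, layer_id, rope):
        self.args, self.layer_id, self.rope = args, layer_id, rope

attn = REGISTRY["gated_deltanet"](args=None, layer_id=0, rope=None)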
@@ -760,43 +725,28 @@ def _apply_causal_conv(self, mixed_qkv: torch.Tensor) -> torch.Tensor:
         out = F.silu(out[:, :, -seq_len:]).to(mixed_qkv.dtype)
         return out.transpose(1, 2).contiguous()
 
-    def _gated_delta_rule_op(
+    def _recurrent_gated_delta_rule(
         self,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         g: torch.Tensor,
         beta: torch.Tensor,
     ) -> torch.Tensor:
-        batch_size = query.shape[0]
-        recurrent_gated_delta_rule_op = _get_recurrent_gated_delta_rule_op()
-        if recurrent_gated_delta_rule_op is not None:
-            return recurrent_gated_delta_rule_op(
-                query,
-                key,
-                value,
-                g,
-                beta,
-                self.recurrent_state[:batch_size],
-            )
-        return self._naive_gated_delta_rule_op(
-            query,
-            key,
-            value,
-            g,
-            beta,
-        )
+        # query/key/value: (batch, seq_len, num_heads, head_dim)
+        # g/beta: (batch, seq_len, num_heads)
+        initial_dtype = query.dtype
+        query = _l2norm(query, dim=-1, eps=1e-6)
+        key = _l2norm(key, dim=-1, eps=1e-6)
+        query, key, value, beta, g = [
+            x.transpose(1, 2).contiguous().to(torch.float32)
+            for x in (query, key, value, beta, g)
+        ]
 
-    def _naive_gated_delta_rule_op(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        g: torch.Tensor,
-        beta: torch.Tensor,
-    ) -> torch.Tensor:
-        batch_size, num_heads, sequence_length, _ = key.shape
+        batch_size, num_heads, sequence_length, k_head_dim = key.shape
         v_head_dim = value.shape[-1]
+        scale = 1.0 / (query.shape[-1] ** 0.5)
+        query = query * scale
 
         core_attn_out = torch.zeros(
             batch_size,
@@ -830,36 +780,6 @@ def _naive_gated_delta_rule_op(
             last_recurrent_state.to(self.recurrent_state.dtype)
         )
 
-        return core_attn_out
-
-    def _recurrent_gated_delta_rule(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        g: torch.Tensor,
-        beta: torch.Tensor,
-    ) -> torch.Tensor:
-        # query/key/value: (batch, seq_len, num_heads, head_dim)
-        # g/beta: (batch, seq_len, num_heads)
-        initial_dtype = query.dtype
-        query = _l2norm(query, dim=-1, eps=1e-6)
-        key = _l2norm(key, dim=-1, eps=1e-6)
-        query, key, value, beta, g = [
-            x.transpose(1, 2).contiguous().to(torch.float32)
-            for x in (query, key, value, beta, g)
-        ]
-
-        scale = 1.0 / (query.shape[-1] ** 0.5)
-        query = query * scale
-
-        core_attn_out = self._gated_delta_rule_op(
-            query,
-            key,
-            value,
-            g,
-            beta,
-        )
         return core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
 
     def forward(
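The naive loop whose tail survives in the last hunk implements a gated delta-rule recurrence: at each step the running key-to-value state is decayed by exp(g_t), corrected toward the new (k_t, v_t) pair with strength beta_t, and read out against q_t. A self-contained sketch under those assumptions (the standard gated delta rule formulation plus the shape comments in the diff; not the module's exact code):

import torch

def gated_delta_rule_reference(query, key, value, g, beta, state):
    # query/key: (batch, heads, seq, k_dim); value: (batch, heads, seq, v_dim)
    # g/beta: (batch, heads, seq); state: (batch, heads, k_dim, v_dim)
    outputs = []
    for t in range(query.shape[2]):
        q_t, k_t, v_t = query[:, :, t], key[:, :, t], value[:, :, t]
        # Decay the running state, then apply a beta-weighted delta-rule
        # correction toward the current key/value pair.
        state = state * g[:, :, t, None, None].exp()
        error = v_t - torch.einsum("bhk,bhkv->bhv", k_t, state)
        state = state + beta[:, :, t, None, None] * torch.einsum(
            "bhk,bhv->bhkv", k_t, error
        )
        # Read the updated state out against this step's query.
        outputs.append(torch.einsum("bhk,bhkv->bhv", q_t, state))
    return torch.stack(outputs, dim=2), state

Because the state is carried across calls (the module stores it in self.recurrent_state), prefill can be chunked: feeding a sequence in pieces with the right input_pos offsets reproduces the full-sequence output, which is exactly what the removed chunked-prefill test below asserted.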
72 changes: 0 additions & 72 deletions examples/models/llama/tests/test_export_llama_lib.py
@@ -5,10 +5,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-import json
-import tempfile
 import unittest
-from pathlib import Path
 
 from executorch.devtools.backend_debug import get_delegation_info
 
@@ -28,7 +25,6 @@
 
 from executorch.examples.models.llama.export_llama_lib import (
     _export_llama,
-    _prepare_for_llama_export,
     build_args_parser,
     get_quantizer_and_quant_params,
 )
@@ -41,39 +37,6 @@
 
 
 class ExportLlamaLibTest(unittest.TestCase):
-    def _make_tiny_qwen35_params(self) -> dict:
-        return {
-            "dim": 64,
-            "hidden_dim": 128,
-            "n_heads": 4,
-            "head_dim": 16,
-            "n_kv_heads": 2,
-            "n_layers": 4,
-            "norm_eps": 1e-6,
-            "rope_theta": 10000000.0,
-            "use_scaled_rope": False,
-            "vocab_size": 256,
-            "use_hf_rope": True,
-            "partial_rotary_factor": 0.25,
-            "attention_qkv_bias": False,
-            "use_qk_norm": True,
-            "qk_norm_before_rope": True,
-            "attention_type": "mha",
-            "use_q_gate": True,
-            "rms_norm_add_unit_offset": True,
-            "linear_conv_kernel_dim": 4,
-            "linear_key_head_dim": 8,
-            "linear_value_head_dim": 8,
-            "linear_num_key_heads": 4,
-            "linear_num_value_heads": 4,
-            "layer_types": [
-                "linear_attention",
-                "full_attention",
-                "linear_attention",
-                "full_attention",
-            ],
-        }
-
     def test_has_expected_ops_and_op_counts(self):
         """
         Checks the presence of unwanted expensive ops.
@@ -103,41 +66,6 @@ def test_has_expected_ops_and_op_counts(self):
         for op, _op_info in delegation_info.delegation_by_operator.items():
             self.assertTrue(op not in UNWANTED_OPS)
 
-    def test_tiny_qwen35_export_uses_recurrent_gated_delta_rule(self):
-        with tempfile.TemporaryDirectory() as temp_dir:
-            params_path = Path(temp_dir) / "tiny_qwen35.json"
-            params_path.write_text(json.dumps(self._make_tiny_qwen35_params()))
-
-            parser = build_args_parser()
-            args = parser.parse_args(
-                [
-                    "--model",
-                    "qwen3_5_0_8b",
-                    "--params",
-                    str(params_path),
-                    "--use_kv_cache",
-                    "--disable_dynamic_shape",
-                    "--max_seq_length",
-                    "8",
-                    "--max_context_length",
-                    "8",
-                ]
-            )
-
-            llm_config = LlmConfig.from_args(args)
-            builder = _prepare_for_llama_export(llm_config).export()
-            assert builder.pre_autograd_graph_module is not None
-
-            recurrent_nodes = [
-                node
-                for node in builder.pre_autograd_graph_module.graph.nodes
-                if "auto_functionalized_v2" in str(node.target)
-                and node.args
-                and "llama.recurrent_gated_delta_rule" in str(node.args[0])
-            ]
-
-            self.assertEqual(len(recurrent_nodes), 2)
-
     @unittest.skipUnless(HAS_ARM_BACKEND, "ARM backend not available")
     def test_get_quantizer_and_quant_params_returns_tosa_quantizer(self):
         llm_config = LlmConfig()
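The removed export test above locates the custom op by string-matching node targets: after functionalization, mutable custom ops surface as auto_functionalized_v2 higher-order calls whose first argument names the wrapped op. A small stand-alone helper in that spirit (a hedged sketch against an FX-style graph module, not an API this repo necessarily exposes):

def count_custom_op_nodes(graph_module, op_name: str) -> int:
    # Count direct calls to the op plus functionalized wrappers of it.
    count = 0
    for node in graph_module.graph.nodes:
        target = str(node.target)
        if op_name in target:
            count += 1
        elif "auto_functionalized" in target and node.args:
            # auto_functionalized(_v2) carries the wrapped op as its
            # first argument.
            if op_name in str(node.args[0]):
                count += 1
    return count

# Usage mirroring the removed assertion (builder as in the deleted test):
# count_custom_op_nodes(builder.pre_autograd_graph_module,
#                       "llama.recurrent_gated_delta_rule") == 2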
105 changes: 0 additions & 105 deletions examples/models/llama/tests/test_qwen3_5_attention.py
@@ -6,9 +6,7 @@
 
 import unittest
 
-import executorch.examples.models.llama.attention as attention_module
 import torch
-
 from executorch.examples.models.llama.attention import ATTENTION_REGISTRY
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.norm import RMSNorm
@@ -125,109 +123,6 @@ def test_gated_deltanet_no_input_pos_does_not_leak_state(self):
             torch.allclose(state_after_first, state_after_second, atol=1e-5)
         )
 
-    def test_gated_deltanet_chunked_prefill_matches_full_sequence(self):
-        torch.manual_seed(0)
-        args = self._make_args(
-            use_kv_cache=True,
-            use_q_gate=True,
-            linear_conv_kernel_dim=4,
-            linear_key_head_dim=4,
-            linear_value_head_dim=4,
-            linear_num_key_heads=2,
-            linear_num_value_heads=4,
-        )
-        rope = Rope(args)
-        attn_full = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope)
-        attn_chunked = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope)
-        attn_chunked.load_state_dict(attn_full.state_dict())
-
-        x = torch.randn(1, 5, args.dim)
-        dummy_freq = torch.zeros(1, 1)
-
-        full_output, _ = attn_full(
-            x,
-            dummy_freq,
-            dummy_freq,
-            input_pos=torch.tensor([0], dtype=torch.long),
-        )
-
-        chunk_outputs = []
-        for start, end in ((0, 3), (3, 4), (4, 5)):
-            output, _ = attn_chunked(
-                x[:, start:end],
-                dummy_freq,
-                dummy_freq,
-                input_pos=torch.tensor([start], dtype=torch.long),
-            )
-            chunk_outputs.append(output)
-
-        chunked_output = torch.cat(chunk_outputs, dim=1)
-
-        self.assertTrue(torch.allclose(chunked_output, full_output, atol=1e-5))
-        self.assertTrue(
-            torch.allclose(
-                attn_chunked.recurrent_state, attn_full.recurrent_state, atol=1e-5
-            )
-        )
-        self.assertTrue(
-            torch.allclose(attn_chunked.conv_state, attn_full.conv_state, atol=1e-5)
-        )
-
-    def test_gated_deltanet_custom_op_matches_fallback(self):
-        recurrent_op = attention_module._get_recurrent_gated_delta_rule_op()
-        if recurrent_op is None:
-            self.skipTest("llama::recurrent_gated_delta_rule is not available")
-
-        torch.manual_seed(0)
-        args = self._make_args(
-            use_kv_cache=True,
-            use_q_gate=True,
-            linear_conv_kernel_dim=4,
-            linear_key_head_dim=4,
-            linear_value_head_dim=4,
-            linear_num_key_heads=2,
-            linear_num_value_heads=4,
-        )
-        rope = Rope(args)
-        attn_custom = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope)
-        attn_fallback = ATTENTION_REGISTRY["gated_deltanet"](args, 0, rope)
-        attn_fallback.load_state_dict(attn_custom.state_dict())
-
-        query = torch.randn(1, 3, attn_custom.num_v_heads, attn_custom.head_k_dim)
-        key = torch.randn(1, 3, attn_custom.num_v_heads, attn_custom.head_k_dim)
-        value = torch.randn(1, 3, attn_custom.num_v_heads, attn_custom.head_v_dim)
-        g = torch.randn(1, 3, attn_custom.num_v_heads)
-        beta = torch.sigmoid(torch.randn(1, 3, attn_custom.num_v_heads))
-
-        original_op = attention_module._RECURRENT_GATED_DELTA_RULE_OP
-        original_tried_loading = (
-            attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP
-        )
-        try:
-            attention_module._RECURRENT_GATED_DELTA_RULE_OP = recurrent_op
-            attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = True
-            custom_output = attn_custom._recurrent_gated_delta_rule(
-                query, key, value, g, beta
-            )
-
-            attention_module._RECURRENT_GATED_DELTA_RULE_OP = None
-            attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = True
-            fallback_output = attn_fallback._recurrent_gated_delta_rule(
-                query, key, value, g, beta
-            )
-        finally:
-            attention_module._RECURRENT_GATED_DELTA_RULE_OP = original_op
-            attention_module._TRIED_LOADING_RECURRENT_GATED_DELTA_RULE_OP = (
-                original_tried_loading
-            )
-
-        self.assertTrue(torch.allclose(custom_output, fallback_output, atol=1e-5))
-        self.assertTrue(
-            torch.allclose(
-                attn_custom.recurrent_state, attn_fallback.recurrent_state, atol=1e-5
-            )
-        )
-
 
 if __name__ == "__main__":
     unittest.main()
3 changes: 1 addition & 2 deletions extension/aten_util/make_aten_functor_from_et_functor.h
@@ -15,8 +15,7 @@
 #pragma once
 #include <type_traits>
 #include <vector>
-#if (defined(_MSC_VER) && (!defined(_MSVC_LANG) || _MSVC_LANG < 201703L)) || \
-    (!defined(_MSC_VER) && __cplusplus < 201703L)
+#if __cplusplus < 201703L
 #error "This header requires C++17"
 #endif
 #include <ATen/native/Resize.h>