From 173f799c033ebe71f1045edb7d4d312c58bdac8c Mon Sep 17 00:00:00 2001 From: shewu-quic Date: Wed, 5 Mar 2025 16:19:40 +0800 Subject: [PATCH 1/4] Qualcomm AI Engine Direct - Optimize the performance for AR-N model Summary: - Fix the bug of rms norm builder - Use HuggingFace version RoPE to improve the performance due to stride = 1 in StrideSlice Op - Modify the axis order of the conv in qkv, feedforward and output - Original (AR:128, CL:2048): QNN_RmsNorm (1,1,128,2048) -> QNN_Reshape (1,128,2048,1)->QNN_Transpose (1,128,1,2048)->self.output-> QNN_Transpose(1,128,2048,1) -> QNN_Reshape (1,1,128,2048) - New: QNN_RmsNorm (1,1,128,2048) -> QNN_Reshape (1,128,1,2048)->QNN_Transpose (1,1,128,2048)->self.output-> QNN_Transpose(1,128,1,2048) -> QNN_Reshape (1,1,128,2048) --- .../_passes/fuse_consecutive_transpose.py | 41 ++++++++----------- .../qualcomm/_passes/recompose_rms_norm.py | 16 +++++--- backends/qualcomm/builders/op_rms_norm.py | 17 ++++---- examples/qualcomm/oss_scripts/llama/llama.py | 22 ++++++++++ .../oss_scripts/llama/model/static_llama.py | 28 ++++++++----- 5 files changed, 73 insertions(+), 51 deletions(-) diff --git a/backends/qualcomm/_passes/fuse_consecutive_transpose.py b/backends/qualcomm/_passes/fuse_consecutive_transpose.py index 16ce3803076..58ebc83962e 100644 --- a/backends/qualcomm/_passes/fuse_consecutive_transpose.py +++ b/backends/qualcomm/_passes/fuse_consecutive_transpose.py @@ -55,12 +55,6 @@ def _clone_transpose( clone_permute_node.meta = n.meta users[i].replace_input_with(n, clone_permute_node) - def _is_dispensable(self, axis_order): - for index, value in enumerate(axis_order): - if index != value: - return False - return True - def _traverse(self, node): if node in self.visited or node.target not in self.op_map: return @@ -87,25 +81,22 @@ def _fuse(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: axis_order = torch.arange(len(input_shape)).tolist() for node in self.nodes: axis_order = [axis_order[i] for i in 
node.args[1]] - # If axis order is just [0,1,2,3], we ignore permute node - if self._is_dispensable(axis_order): - for user in output_node.users.copy(): - user.replace_input_with(output_node, n.args[0]) - else: - with graph.inserting_after(input_node): - permute_op = exir_ops.edge.aten.permute_copy.default - permute_node = graph.create_node( - "call_function", permute_op, (input_node, axis_order) - ) - users = output_node.users.copy() - for user in users: - user.replace_input_with(output_node, permute_node) - - # copy metadata - permute_node.meta = output_node.meta - # Without "qnn_permute", we might obtain wrong input shape - if [pn.meta.get(QCOM_INSERTED_PERMUTE) for pn in self.nodes]: - permute_node.meta[QCOM_INSERTED_PERMUTE] = True + + # Reserve [0,1,2,3] permute node to ensure the next node get the right axis order. + with graph.inserting_after(input_node): + permute_op = exir_ops.edge.aten.permute_copy.default + permute_node = graph.create_node( + "call_function", permute_op, (input_node, axis_order) + ) + users = output_node.users.copy() + for user in users: + user.replace_input_with(output_node, permute_node) + + # copy metadata + permute_node.meta = output_node.meta + # Without "qnn_permute", we might obtain wrong input shape + if [pn.meta.get(QCOM_INSERTED_PERMUTE) for pn in self.nodes]: + permute_node.meta[QCOM_INSERTED_PERMUTE] = True # clear current stack self.nodes = [] diff --git a/backends/qualcomm/_passes/recompose_rms_norm.py b/backends/qualcomm/_passes/recompose_rms_norm.py index bfaddfc47b5..77feecf9c1f 100644 --- a/backends/qualcomm/_passes/recompose_rms_norm.py +++ b/backends/qualcomm/_passes/recompose_rms_norm.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
import torch +from executorch.backends.qualcomm.builders.utils import get_parameter, is_parameter from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from torch.fx.passes.utils.source_matcher_utils import get_source_partitions @@ -16,8 +17,9 @@ class RecomposeRmsNorm(ExportPass): Merge decomposed operators back to one super node. """ - def __init__(self): - super().__init__() + def __init__(self, edge_program: torch.export.ExportedProgram): + super(RecomposeRmsNorm, self).__init__() + self.edge_program = edge_program def _get_eps_node(self, nodes): # eps: one of inputs of add node @@ -47,11 +49,15 @@ def call(self, graph_module: torch.fx.GraphModule): input_node = inp_0 if len(inp_0.users) == 2 else inp_1 else: raise RuntimeError( - f"Found a edge case of rms_node partitoin {src_partition}, which has {input_len} inputs" + f"Found a edge case of rms_node partition {src_partition}, which has {input_len} inputs" ) output_node = src_partition.output_nodes[0] - eps_node = self._get_eps_node(src_partition.nodes) + eps = self._get_eps_node(src_partition.nodes) + if isinstance(eps, torch.fx.Node) and is_parameter( + eps, self.edge_program + ): + eps = get_parameter(eps, self.edge_program).item() gamma_node = self._get_gamma_node(output_node) with graph.inserting_before(output_node): @@ -64,7 +70,7 @@ def call(self, graph_module: torch.fx.GraphModule): input_node, list(gamma_node.meta["val"].shape), gamma_node, - eps_node, + eps, ), ) users = output_node.users.copy() diff --git a/backends/qualcomm/builders/op_rms_norm.py b/backends/qualcomm/builders/op_rms_norm.py index e5b4778312e..d224e34feb5 100644 --- a/backends/qualcomm/builders/op_rms_norm.py +++ b/backends/qualcomm/builders/op_rms_norm.py @@ -12,7 +12,11 @@ import torch from executorch.backends.qualcomm.builders.utils import get_parameter -from executorch.backends.qualcomm.utils.constants import QCOM_DATA, QCOM_QUANT_ATTRS +from 
executorch.backends.qualcomm.utils.constants import ( + QCOM_DATA, + QCOM_QUANT_ATTRS, + QCOM_ZERO_POINT, +) from executorch.exir.dialects._ops import ops as exir_ops from .node_visitor import NodeVisitor, register_node_visitor @@ -66,7 +70,7 @@ def define_node( nodes_to_wrappers, ) - # Fake node, nn module seems to be inconsistant with document + # Fake node, nn module seems to be inconsistent with document bias_tensor = torch.zeros(weight_tensor.shape) bias_node = torch.fx.Node( node.graph, @@ -78,6 +82,7 @@ def define_node( ) if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS): bias_node.meta[QCOM_QUANT_ATTRS] = quant_attrs + bias_node.meta[QCOM_QUANT_ATTRS][QCOM_ZERO_POINT] = 0 bias_tensor_wrapper = self.define_tensor( bias_node, node, @@ -87,14 +92,6 @@ def define_node( ) epsilon = node.args[3] - if isinstance(epsilon, torch.fx.Node): - epsilon = get_parameter(epsilon, self.edge_program) - epsilon = ( - epsilon - if isinstance(epsilon, float) - else torch.finfo(epsilon.dtype).eps - ) - output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( node, diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index 0829d99d57a..a999270c15b 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -539,6 +539,28 @@ def compile(args, pte_filename, tokenizer): if "model" in state_dict: state_dict = state_dict["model"] + # Change to HuggingFace weight to improve the performance of RoPE in HTP backend. 
+ def permute(w, heads): + dim_0 = w.size(0) + dim_1 = w.size(1) + return ( + w.view(heads, dim_0 // heads // 2, 2, dim_1) + .transpose(1, 2) + .reshape(dim_0, dim_1) + ) + + n_heads = llama_instance_list[0].n_heads + n_kv_heads = llama_instance_list[0].n_kv_heads + n_layers = llama_instance_list[0].n_layers + + for layer_i in range(n_layers): + state_dict[f"layers.{layer_i}.attention.wq.weight"] = permute( + state_dict[f"layers.{layer_i}.attention.wq.weight"], n_heads + ) + state_dict[f"layers.{layer_i}.attention.wk.weight"] = permute( + state_dict[f"layers.{layer_i}.attention.wk.weight"], n_kv_heads + ) + for llama_instance in llama_instance_list: llama_instance.load_state_dict( state_dict, diff --git a/examples/qualcomm/oss_scripts/llama/model/static_llama.py b/examples/qualcomm/oss_scripts/llama/model/static_llama.py index ea8e2f5d319..dbb41cdb743 100755 --- a/examples/qualcomm/oss_scripts/llama/model/static_llama.py +++ b/examples/qualcomm/oss_scripts/llama/model/static_llama.py @@ -19,12 +19,12 @@ def apply_rotary_emb_single( x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor ) -> torch.Tensor: - x_r, x_i = x[..., ::2], x[..., 1::2] - + # Change to RoPE of huggingface version + x_r, x_i = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] # brodcast for batch_prefill mode input x if x.dim() == 4: - freqs_cos = freqs_cos[None, :, None, :] - freqs_sin = freqs_sin[None, :, None, :] + freqs_cos = freqs_cos[None, None, :, :] + freqs_sin = freqs_sin[None, None, :, :] x_out_r = x_r * freqs_cos - x_i * freqs_sin x_out_i = x_r * freqs_sin + x_i * freqs_cos @@ -108,21 +108,27 @@ def forward_sha( hidden_states, (bsz, seq_len, 1, self.dim) ).transpose(1, 3) q = [ - wq_sha(hidden_states).reshape(bsz, self.head_dim, seq_len).transpose(1, 2) + wq_sha(hidden_states) + .permute(0, 2, 3, 1) + .reshape(bsz, seq_len, self.head_dim) for wq_sha in self.wq_sha ] k = [ - wk_sha(hidden_states).reshape(bsz, self.head_dim, seq_len).transpose(1, 2) + 
wk_sha(hidden_states) + .permute(0, 2, 3, 1) + .reshape(bsz, seq_len, self.head_dim) for wk_sha in self.wk_sha ] v = [ - wv_sha(hidden_states).reshape(bsz, self.head_dim, seq_len).transpose(1, 2) + wv_sha(hidden_states) + .permute(0, 2, 3, 1) + .reshape(bsz, seq_len, self.head_dim) for wv_sha in self.wv_sha ] for i in range(len(q)): q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin) for i in range(len(k)): - k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin).permute(0, 2, 1) + k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin).transpose(1, 2) output_y = [] kh, vh = [], [] @@ -249,10 +255,10 @@ def prepare_feedfoward_conv(self): def forward_feedfoward_conv(self, x): bsz, _, _ = x.size() - x = torch.reshape(x, (bsz, -1, self.dim, 1)) - x = x.transpose(1, 2) # Transpose right before and after Conv + x = torch.reshape(x, (bsz, -1, 1, self.dim)) + x = x.transpose(1, 3) # Transpose right before and after Conv x = self.w2_conv(F.silu(self.w1_conv(x)) * self.w3_conv(x)) - x = x.transpose(1, 2) + x = x.transpose(1, 3) x = torch.reshape(x, (bsz, -1, self.dim)) return x From 6d12ceb3e28541dcd8961792ea5b37a95c9589f9 Mon Sep 17 00:00:00 2001 From: shewu-quic Date: Sun, 9 Mar 2025 23:19:52 -0700 Subject: [PATCH 2/4] linting --- backends/qualcomm/_passes/fuse_consecutive_transpose.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/backends/qualcomm/_passes/fuse_consecutive_transpose.py b/backends/qualcomm/_passes/fuse_consecutive_transpose.py index 58ebc83962e..04d96462c9f 100644 --- a/backends/qualcomm/_passes/fuse_consecutive_transpose.py +++ b/backends/qualcomm/_passes/fuse_consecutive_transpose.py @@ -81,8 +81,8 @@ def _fuse(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: axis_order = torch.arange(len(input_shape)).tolist() for node in self.nodes: axis_order = [axis_order[i] for i in node.args[1]] - - # Reserve [0,1,2,3] permute node to ensure the next node get the right axis order. 
+ + # Reserve [0,1,2,3] permute node to ensure the next node get the right axis order. with graph.inserting_after(input_node): permute_op = exir_ops.edge.aten.permute_copy.default permute_node = graph.create_node( From c5c149c94db1828f155715074fcc4a1903028435 Mon Sep 17 00:00:00 2001 From: shewu-quic Date: Wed, 12 Mar 2025 11:51:11 +0800 Subject: [PATCH 3/4] address review item and linting --- examples/qualcomm/oss_scripts/llama/model/static_llama.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/qualcomm/oss_scripts/llama/model/static_llama.py b/examples/qualcomm/oss_scripts/llama/model/static_llama.py index dbb41cdb743..93326fa172b 100755 --- a/examples/qualcomm/oss_scripts/llama/model/static_llama.py +++ b/examples/qualcomm/oss_scripts/llama/model/static_llama.py @@ -19,9 +19,11 @@ def apply_rotary_emb_single( x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor ) -> torch.Tensor: - # Change to RoPE of huggingface version + # The implementation of RoPE in HuggingFace processes query and key with two half instead of interleaved way. + # The main difference is stride in StrideSlice op. For interleaved way, stride is two which is not friendly for HTP backend. 
+ # Ref: https://github.com/huggingface/transformers/issues/25199 x_r, x_i = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] - # brodcast for batch_prefill mode input x + # broadcast for batch_prefill mode input x if x.dim() == 4: freqs_cos = freqs_cos[None, None, :, :] freqs_sin = freqs_sin[None, None, :, :] From c94c0bd6f8f14d0dfb7158afc901bdf013c9d945 Mon Sep 17 00:00:00 2001 From: shewu-quic Date: Thu, 13 Mar 2025 09:35:08 +0800 Subject: [PATCH 4/4] Add the comment for axis order change --- examples/qualcomm/oss_scripts/llama/model/static_llama.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/qualcomm/oss_scripts/llama/model/static_llama.py b/examples/qualcomm/oss_scripts/llama/model/static_llama.py index 93326fa172b..f7893792e00 100755 --- a/examples/qualcomm/oss_scripts/llama/model/static_llama.py +++ b/examples/qualcomm/oss_scripts/llama/model/static_llama.py @@ -106,6 +106,8 @@ def forward_sha( v_caches: Optional[List[torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: bsz, seq_len, _ = hidden_states.shape + # In the HTP backend, the input axis order for the convolution operation is + # more efficient with [1, 1, seq_len, dim] compared to [1, seq_len, 1, dim]. hidden_states = torch.reshape( hidden_states, (bsz, seq_len, 1, self.dim) ).transpose(1, 3)