diff --git a/backends/qualcomm/aot/wrappers/TensorWrapper.h b/backends/qualcomm/aot/wrappers/TensorWrapper.h index c973196e9d5..4aec5f71b7e 100644 --- a/backends/qualcomm/aot/wrappers/TensorWrapper.h +++ b/backends/qualcomm/aot/wrappers/TensorWrapper.h @@ -83,6 +83,10 @@ class TensorWrapper { return QNN_VER_PTR(tensor_)->rank; }; + std::uint32_t GetBytes() const { + return bytes_; + }; + const void* GetStaticTensorData() const { return QNN_VER_PTR(tensor_)->clientBuf.data; }; diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py index f5c2f1352be..c4fbdeae14b 100644 --- a/backends/qualcomm/builders/__init__.py +++ b/backends/qualcomm/builders/__init__.py @@ -10,7 +10,6 @@ op_avg_pool2d, op_batch_norm, op_bmm, - op_cast, op_cat, op_ceil, op_clamp, @@ -50,6 +49,7 @@ op_sub, op_sum_int_list, op_tanh, + op_to, op_transpose, op_unsqueeze, op_upsample_bilinear2d, @@ -62,7 +62,6 @@ op_avg_pool2d, op_batch_norm, op_bmm, - op_cast, op_cat, op_ceil, op_clamp, @@ -102,6 +101,7 @@ op_sub, op_sum_int_list, op_tanh, + op_to, op_transpose, op_unsqueeze, op_upsample_bilinear2d, diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py index 6a882f8583b..96e3b6f97f9 100644 --- a/backends/qualcomm/builders/node_visitor.py +++ b/backends/qualcomm/builders/node_visitor.py @@ -14,7 +14,13 @@ from executorch.exir.dialects._ops import ops as exir_ops -from .utils import get_parameter, is_graph_input, is_graph_output, is_parameter +from .utils import ( + deduce_dtype, + get_parameter, + is_graph_input, + is_graph_output, + is_parameter, +) QNN_QUANT_TYPE_MAP = { @@ -217,21 +223,7 @@ def get_data_type( quant_config: Dict, ) -> PyQnnWrapper.Qnn_TensorType_t: if quant_config: - quant_range = quant_config["quant_max"] - quant_config["quant_min"] - unsigned = quant_config["quant_min"] >= 0 - if quant_range <= torch.iinfo(torch.int8).max - torch.iinfo(torch.int8).min: - if unsigned: - quant_config["dtype"] = torch.uint8 - else: - quant_config["dtype"] = torch.int8 - elif ( - quant_range - <= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min - ): - if unsigned: - quant_config["dtype"] = torch.uint16 - else: - quant_config["dtype"] = torch.int16 + quant_config["dtype"] = deduce_dtype(tensor, quant_config) return QNN_QUANT_TYPE_MAP[quant_config["dtype"]] return QNN_TENSOR_TYPE_MAP[tensor.dtype] @@ -277,7 +269,6 @@ def define_tensor( nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]], is_input_tensor: bool, node_name: str = None, - is_tensor: bool = True, wrapper_idx: int = 0, ) -> PyQnnWrapper.TensorWrapper: """ @@ -296,7 +287,10 @@ def define_tensor( if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None): return cached - tensor_name = node.name + + tensor_name = f"{node.name}_{wrapper_idx}" + if is_graph_input(node, self.edge_program): + tensor_name = "input_" + str(self.external_ids[node]) + "_" + tensor_name if is_graph_output(node): tensor_name = "output_" + tensor_name dims = [1] if len(tensor.size()) == 0 else tensor.size() diff --git a/backends/qualcomm/builders/op_cast.py b/backends/qualcomm/builders/op_cast.py deleted file mode 100644 index d3096ee27cf..00000000000 --- a/backends/qualcomm/builders/op_cast.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. -# All rights reserved -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
-from typing import Dict - -import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper - -import torch - -from .node_visitor import NodeVisitor, register_node_visitor -from .qnn_constants import OpCast, QNN_OP_PACKAGE_NAME_QTI_AISW - - -@register_node_visitor -class Cast(NodeVisitor): - target = ["aten._to_copy.default"] - - def __init__(self, *args) -> None: - super().__init__(*args) - - def define_node( - self, - node: torch.fx.Node, - nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], - ) -> PyQnnWrapper.PyQnnOpWrapper: - input_node = node.args[0] - input_tensor = self.get_tensor(input_node, node) - - input_tensor_wrapper = self.define_tensor( - input_node, - input_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, - nodes_to_wrappers, - is_input_tensor=True, - ) - - output_tensor = self.get_tensor(node, node) - - output_tensor_wrapper = self.define_tensor( - node, - output_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, - nodes_to_wrappers, - is_input_tensor=False, - ) - - cast_op = PyQnnWrapper.PyQnnOpWrapper( - node.name, - QNN_OP_PACKAGE_NAME_QTI_AISW, - OpCast.op_name, - ) - cast_op.AddInputTensors([input_tensor_wrapper]) - cast_op.AddOutputTensors([output_tensor_wrapper]) - - return cast_op diff --git a/backends/qualcomm/builders/op_embedding.py b/backends/qualcomm/builders/op_embedding.py index 905578790c0..a5d6aae1702 100644 --- a/backends/qualcomm/builders/op_embedding.py +++ b/backends/qualcomm/builders/op_embedding.py @@ -34,7 +34,7 @@ def define_node( weight_tensor, PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, nodes_to_wrappers, - is_input_tensor=False, + is_input_tensor=True, ) indices_node = node.args[1] diff --git a/backends/qualcomm/builders/op_pow.py b/backends/qualcomm/builders/op_pow.py index 0c02bcdd78e..f30ffbf5286 100644 --- a/backends/qualcomm/builders/op_pow.py +++ b/backends/qualcomm/builders/op_pow.py @@ -53,14 +53,14 @@ def define_node( # scalar input scalar = node.args[1] - scalar_tensor = torch.full(input_tensor.size(), scalar).to(torch.float32) + scalar_tensor = torch.tensor(scalar).to(torch.float32) # 'graph', 'name', 'op', 'target', 'args', and 'kwargs' scalar_node = torch.fx.Node( node.graph, node.name + "_runtime_scalar", "call_function", - exir_ops.edge.aten.full.default, + exir_ops.edge.aten.scalar_tensor.default, (), # args {}, # kwargs ) diff --git a/backends/qualcomm/builders/op_slice_copy.py b/backends/qualcomm/builders/op_slice_copy.py index 7972fb3dd92..3a294e35486 100644 --- a/backends/qualcomm/builders/op_slice_copy.py +++ b/backends/qualcomm/builders/op_slice_copy.py @@ -61,7 +61,9 @@ def define_node( ranges = [] for i in range(input_tensor_rank): if i == dim: - ranges.extend([start, end, 1]) + # find step + step = node.args[4] if len(node.args) > 4 else 1 + ranges.extend([start, end, step]) else: ranges.extend([0, input_tensor.shape[i], 1]) diff --git a/backends/qualcomm/builders/op_split_with_sizes.py b/backends/qualcomm/builders/op_split_with_sizes.py index 015cd937318..03d19b1a5ac 100644 --- a/backends/qualcomm/builders/op_split_with_sizes.py +++ b/backends/qualcomm/builders/op_split_with_sizes.py @@ -59,7 +59,6 @@ def define_node( # Edge represents chunks by specifying the size of each chunk # QNN represents chunks by specifying the index to split chunks for index, _value in enumerate(chunks[:-1]): - sum = sum + chunks[index] split_indices.append(sum) diff --git a/backends/qualcomm/builders/op_to.py b/backends/qualcomm/builders/op_to.py new file mode 100644 index 
00000000000..8f3c0276cc0 --- /dev/null +++ b/backends/qualcomm/builders/op_to.py @@ -0,0 +1,104 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import Dict + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper + +import torch + +from .node_visitor import NodeVisitor, register_node_visitor +from .qnn_constants import OpCast, OpConvert, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class To(NodeVisitor): + target = ["aten._to_copy.default"] + sufixed_8_offset_diff = 128 + sufixed_16_offset_diff = 32768 + epsilon = 1e-6 + sufixed_8 = { + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_8, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8, + } + sufixed_16 = { + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_16, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16, + } + + def __init__(self, *args) -> None: + super().__init__(*args) + + def is_cast_node(self, node): + input_node = node.args[0] + + # Not a case which has two quant node, no need to consider the convert op + if not all([input_node.meta.get("quant_attrs"), node.meta.get("quant_attrs")]): + return True + + input_tensor = self.get_tensor(input_node, node) + _, inp_qconfs = self.get_quant_encoding_conf(input_node, False) + inp_dtype = self.get_data_type(input_tensor, inp_qconfs) + + output_tensor = self.get_tensor(node, node) + _, out_qconfs = self.get_quant_encoding_conf(node, False) + out_dtype = self.get_data_type(output_tensor, out_qconfs) + is_qparam_castable = ( + lambda o1, o2, s1, s2, diff: abs(s1 - s2) < self.epsilon + and abs(o1 - o2) == diff + ) + + if {inp_dtype, out_dtype} == self.sufixed_8: + return is_qparam_castable( + inp_qconfs["offset"], + out_qconfs["offset"], + inp_qconfs["scale"], + out_qconfs["scale"], + self.sufixed_8_offset_diff, + ) + elif {inp_dtype, out_dtype} == self.sufixed_16: + return is_qparam_castable( + inp_qconfs["offset"], + out_qconfs["offset"], + inp_qconfs["scale"], + out_qconfs["scale"], + self.sufixed_16_offset_diff, + ) + return False + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + input_node = node.args[0] + input_tensor = self.get_tensor(input_node, node) + + input_tensor_wrapper = self.define_tensor( + input_node, + input_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + is_input_tensor=True, + ) + + output_tensor = self.get_tensor(node, node) + + output_tensor_wrapper = self.define_tensor( + node, + output_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + is_input_tensor=False, + ) + + qnn_op = OpCast if self.is_cast_node(node) else OpConvert + op = PyQnnWrapper.PyQnnOpWrapper( + node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, qnn_op.op_name + ) + op.AddInputTensors([input_tensor_wrapper]) + op.AddOutputTensors([output_tensor_wrapper]) + + return op diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py index 3d207dfc7a4..f36b0b64c29 100644 --- a/backends/qualcomm/builders/qnn_constants.py +++ b/backends/qualcomm/builders/qnn_constants.py @@ -39,6 +39,11 @@ class OpConv2d: param_dilation: str = "dilation" +@dataclass(init=False, frozen=True) +class OpConvert: + op_name: str = "Convert" + + @dataclass(init=False, frozen=True) 
class OpDepthToSpace: op_name: str = "DepthToSpace" diff --git a/backends/qualcomm/builders/utils.py b/backends/qualcomm/builders/utils.py index ca2f48158f5..38e3b676d32 100755 --- a/backends/qualcomm/builders/utils.py +++ b/backends/qualcomm/builders/utils.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Dict, Optional + import torch from torch._export.utils import get_buffer, get_param, is_buffer, is_param @@ -100,3 +102,20 @@ def is_constant( return tensor.meta["val"].constant is not None return False + + +def deduce_dtype( + tensor: torch.Tensor, quant_infos: Optional[Dict] = None +) -> torch.dtype: + if quant_infos: + quant_range = quant_infos["quant_max"] - quant_infos["quant_min"] + unsigned = quant_infos["quant_min"] >= 0 + if quant_range <= torch.iinfo(torch.int8).max - torch.iinfo(torch.int8).min: + return torch.uint8 if unsigned else torch.int8 + + elif quant_range <= torch.iinfo(torch.int16).max - torch.iinfo(torch.int16).min: + return torch.uint16 if unsigned else torch.int16 + + return quant_infos["dtype"] + + return tensor.dtype diff --git a/backends/qualcomm/partition/common_defs.py b/backends/qualcomm/partition/common_defs.py index 36a2986f09a..61935cf3536 100644 --- a/backends/qualcomm/partition/common_defs.py +++ b/backends/qualcomm/partition/common_defs.py @@ -11,9 +11,9 @@ not_supported_operator = [ exir_ops.edge.aten.arange.start_step, exir_ops.edge.aten.clone.default, - exir_ops.edge.aten.index.Tensor, exir_ops.edge.aten.full.default, exir_ops.edge.aten.slice_scatter.default, + exir_ops.edge.aten.index.Tensor, exir_ops.edge.aten.index_put.default, ] diff --git a/backends/qualcomm/partition/qnn_partitioner.py b/backends/qualcomm/partition/qnn_partitioner.py index f0c2c6eea5f..0c5b25284eb 100644 --- a/backends/qualcomm/partition/qnn_partitioner.py +++ b/backends/qualcomm/partition/qnn_partitioner.py @@ -50,7 +50,7 @@ def __init__( ) self.skip_node_id_set = skip_node_id_set - self.nodes_to_wrappers = self.nodes_to_wrappers = defaultdict(dict) + self.nodes_to_wrappers = defaultdict(dict) self.qnn_manager = PyQnnManager.QnnManager( generate_qnn_executorch_option(compiler_specs) ) @@ -96,6 +96,9 @@ def is_node_supported(self, _, node: torch.fx.Node) -> bool: print(f"[QNN Partitioner Op Support]: {node.target.__name__} | {supported}") return supported + def __del__(self): + self.qnn_manager.Destroy() + class QnnPartitioner(Partitioner): def __init__( @@ -145,6 +148,7 @@ def partition(self, edge_program: torch.export.ExportedProgram) -> PartitionResu # pop certain keys in meta for not affecting the passes in compilation # TODO: need to put property name in common definitions node.meta.pop("axis_order", "") + del self.op_support_checker return PartitionResult( tagged_exported_program=edge_program, partition_tags=self.partition_tags ) diff --git a/backends/qualcomm/passes/build_quant_io.py b/backends/qualcomm/passes/build_quant_io.py new file mode 100644 index 00000000000..7a5556fcdda --- /dev/null +++ b/backends/qualcomm/passes/build_quant_io.py @@ -0,0 +1,55 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+import torch + +from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.tensor import TensorSpec + +from .utils import q_io_key + + +class BuildQuantIo(ExportPass): + """ + To make the lowering process correct, this pass assigns the correct quantized dtype to the spec of call_delegate. + """ + + def __init__(self): + super(BuildQuantIo, self).__init__() + + def _make_spec(self, x): + if isinstance(x, torch.Tensor): + return TensorSpec.from_tensor(x) + elif isinstance(x, (int, bool, float)): + return x + else: + return None + + def _build(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: + # forcibly update the delegate node's meta['spec'] to get the correct output + # tensor size at runtime + call_delegate = [ + node + for node in graph_module.graph.nodes + if node.op == "call_function" and node.name == "executorch_call_delegate" + ] + assert len(call_delegate) == 1 + spec = [] + for n in graph_module.graph.nodes: + if q_io_key in n.meta: + n.meta["val"] = n.meta["val"].to(dtype=n.meta[q_io_key]) + if n.op == "call_function" and "getitem" in n.name: + fake_tensor = n.meta["val"] + if q_io_key in n.meta: + fake_tensor = fake_tensor.to(dtype=n.meta[q_io_key]) + spec.append(self._make_spec(fake_tensor)) + + call_delegate[0].meta["spec"] = tuple(spec) + + def call(self, graph_module: torch.fx.GraphModule): + self._build(graph_module) + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/passes/decompose_silu.py b/backends/qualcomm/passes/decompose_silu.py index 225dfefef78..d40f13a5923 100644 --- a/backends/qualcomm/passes/decompose_silu.py +++ b/backends/qualcomm/passes/decompose_silu.py @@ -3,6 +3,8 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Dict + import torch from executorch.exir.pass_base import ExportPass, PassResult from torch.fx.passes.utils.source_matcher_utils import get_source_partitions @@ -12,6 +14,12 @@ class DecomposeSilu(ExportPass): def __init__(self): super(DecomposeSilu, self).__init__() + def _copy_meta(self, meta: Dict): + copied = {} + for k, v in meta.items(): + copied[k] = v + return copied + def call(self, graph_module: torch.fx.GraphModule): graph = graph_module.graph partitions = get_source_partitions(graph, [torch.nn.functional.silu]) @@ -24,14 +32,14 @@ def call(self, graph_module: torch.fx.GraphModule): sigmoid_node = graph.create_node( "call_function", torch.ops.aten.sigmoid, (inputs[0],) ) - sigmoid_node.meta = silu_node.meta + sigmoid_node.meta = self._copy_meta(silu_node.meta) with graph_module.graph.inserting_after(sigmoid_node): mul_node = graph.create_node( "call_function", torch.ops.aten.mul, (inputs[0], sigmoid_node), ) - mul_node.meta = silu_node.meta + mul_node.meta = self._copy_meta(silu_node.meta) for user in silu_node.users.copy(): user.replace_input_with(silu_node, mul_node) diff --git a/backends/qualcomm/passes/fuse_consecutive_transpose.py b/backends/qualcomm/passes/fuse_consecutive_transpose.py new file mode 100644 index 00000000000..b2351fe9e8a --- /dev/null +++ b/backends/qualcomm/passes/fuse_consecutive_transpose.py @@ -0,0 +1,84 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree.
+ + +import torch +from executorch.backends.qualcomm.passes.layout_transform import LayoutTransform + +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.passes import dead_code_elimination_pass + + +class FuseConsecutiveTranspose(ExportPass): + """ + This pass fuses consecutive transpose / permute into one to reduce runtime + overhead + """ + + def __init__(self): + super().__init__() + self.op_map = { + exir_ops.edge.aten.permute_copy.default, + } + self.visited = set() + self.nodes = [] + + def _traverse(self, node): + if node in self.visited or node.target not in self.op_map: + return + + self.nodes.append(node) + self.visited.add(node) + next_users = [n for n in list(node.users) if n.target in self.op_map] + if not next_users: + return + + if len(next_users) == 1: + self._traverse(list(node.users)[0]) + else: + raise NotImplementedError( + f"Check the node {node}, which encounters the multiple permute output case" + ) + + def _fuse(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: + graph = graph_module.graph + for n in graph_module.graph.nodes: + self._traverse(n) + if len(self.nodes) > 1: + permute_order = [] + input_node, output_node = self.nodes[0].args[0], self.nodes[-1] + input_shape = input_node.meta["val"].shape + axis_order = torch.arange(len(input_shape)).tolist() + for node in self.nodes: + permute_order.append(node.args[1]) + axis_order = [axis_order[i] for i in node.args[1]] + with graph.inserting_after(input_node): + permute_op = exir_ops.edge.aten.permute_copy.default + permute_node = graph.create_node( + "call_function", permute_op, (input_node, axis_order) + ) + users = output_node.users.copy() + for user in users: + user.replace_input_with(output_node, permute_node) + + # copy metadata + permute_node.meta = output_node.meta + # Without inserted_permute_tag, we might obtain the wrong input shape + if [ + pn.meta.get(LayoutTransform.inserted_permute_tag) + for pn in self.nodes + ]: + permute_node.meta[LayoutTransform.inserted_permute_tag] = True + + # clear current stack + self.nodes = [] + + def call(self, graph_module: torch.fx.GraphModule): + self._fuse(graph_module) + graph_module.recompile() + dead_code_elimination_pass(graph_module) + return PassResult(graph_module, True) diff --git a/backends/qualcomm/passes/insert_io_qdq.py b/backends/qualcomm/passes/insert_io_qdq.py index a384f35b659..0bec89088d4 100644 --- a/backends/qualcomm/passes/insert_io_qdq.py +++ b/backends/qualcomm/passes/insert_io_qdq.py @@ -11,7 +11,7 @@ from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult -from .utils import dq_ops, q_ops +from .utils import q_io_key, q_ops class InsertIOQDQ(ExportPass): @@ -47,7 +47,7 @@ def _ceate_args(self, target: torch.fx.node.Target, quant_attrs: Dict): if name == "out_dtype": continue value = quant_attrs[name] - if isinstance(arg_schema.type, torch.tensor) and ( + if isinstance(arg_schema.type, torch.Tensor) and ( isinstance(value, int) or isinstance(value, float) ): value = torch.tensor(value) @@ -109,41 +109,18 @@ def _insert_dequant_node( if user.op == "output": user.replace_input_with(node, inserted_node) - # When having requantization dq/q nodes at the input, - # such as the case: input1 -> dq_node1 -> q_node1 -> node1, - # we should fold the dq_node1 and connect input -> q_node1 -> node1.
- def _fold_mix_quantization_dq_node(self, graph_module, input_node): - input_users = list(input_node.users.keys()) - for input_user in input_users: - if input_user.target not in dq_ops: - continue - dq_users = list(input_user.users.keys()) - for dq_user in dq_users: - dq_user.replace_input_with(input_user, input_node) - - # When having requantization dq/q nodes at the output, - # such as the case: node(int32) -> dq(int32) -> q(uint8) -> output(int32), - # we should fold the q node and connect node(int32) -> dq(int32) -> output(int32). - def _fold_mix_quantization_q_node(self, graph_module, node, users): - for user in users: - if user.op == "output": - output_node = user - break - - dq_node = node.args[0] - for out_node in output_node.meta["val"]: - if dq_node.meta["val"].dtype == out_node.dtype: - user.replace_input_with(node, dq_node) - def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: for n in graph_module.graph.nodes: + # do nothing when a node is expected to output a quant tensor + if n.meta.get(q_io_key): + continue + # insert q after input or fold mix_quantization dq if applicable if ( n.op == "placeholder" and n.meta.get("quant_attrs") and not is_parameter(n, self.edge_program) ): - self._fold_mix_quantization_dq_node(graph_module, n) self._insert_quant_node( graph_module, n, n.meta["quant_attrs"]["encoding"] ) @@ -151,10 +128,6 @@ def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: # insert dq before output or fold mix_quantization q if applicable users = list(n.users.keys()) if n.meta.get("quant_attrs") and any(user.op == "output" for user in users): - if n.target in q_ops: - self._fold_mix_quantization_q_node(graph_module, n, users) - # If q_node is fold, it will have no users, - # so it won't insert dequant node in following function. self._insert_dequant_node( graph_module, n, diff --git a/backends/qualcomm/passes/insert_requantize.py b/backends/qualcomm/passes/insert_requantize.py index d0169ebe357..4e79a4bda60 100644 --- a/backends/qualcomm/passes/insert_requantize.py +++ b/backends/qualcomm/passes/insert_requantize.py @@ -6,14 +6,17 @@ import torch -from executorch.backends.qualcomm.passes.insert_io_qdq import InsertIOQDQ from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from .utils import q_io_key -class InsertRequantize(InsertIOQDQ): + +class InsertRequantize(ExportPass): """ - This pass inserts dq/q nodes for non-arithmetic operators which have - different quantization specs in input and activation + This pass inserts a convert op for operators which have + different quantization specs between input and activation. + The convert op is a QNN-specific op that handles requantization in the QNN backend. """ # Storing ops that has multi output but run _single_output_annotation logic @@ -26,10 +29,9 @@ class InsertRequantize(InsertIOQDQ): def __init__( self, edge_program: torch.export.ExportedProgram, - insert_requantize: bool = False, ): - super().__init__(edge_program) - self.insert_requantize = insert_requantize + super(InsertRequantize, self).__init__() + self.edge_program = edge_program # TODO: Implement this function when we have an op with # multiple outputs that requires quant attributes.
@@ -39,23 +41,36 @@ def _multi_output_annotation(self) -> None: def _single_output_annotation( self, gm: torch.fx.GraphModule, n: torch.fx.node ) -> None: - dq_attr = n.meta["quant_attrs"] - q_attr = n.meta["requantize"] - # insert dq with given quantization attribute in input node - dq = self._insert_quant_node( - gm, n, InsertIOQDQ.q_dq_map[q_attr["encoding"]], dq_attr - ) - dq.meta["quant_attrs"] = dq_attr - # insert q with given quantization attribute in current node - q = self._insert_quant_node(gm, dq, q_attr["encoding"], q_attr) - q.meta["quant_attrs"] = q_attr + with gm.graph.inserting_after(n): + users = list(n.users.keys()) + inserted_n = gm.graph.create_node( + "call_function", + exir_ops.edge.aten._to_copy.default, + (n,), + ) + + inserted_n.meta["val"] = n.meta["val"] + inserted_n.meta["quant_attrs"] = n.meta.pop("requantize") + if n.meta.get(q_io_key): + inserted_n.meta[q_io_key] = n.meta[q_io_key] + + for user in users: + user.replace_input_with(n, inserted_n) def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: for n in graph_module.graph.nodes: if "requantize" in n.meta: ( self._single_output_annotation(graph_module, n) - if len(n.meta["val"]) == 1 + if isinstance( + n.meta["val"], torch._subclasses.fake_tensor.FakeTensor + ) or n.target in self.multi_output_op_ignore_set else self._multi_output_annotation() ) + + def call(self, graph_module: torch.fx.GraphModule): + self._insert(graph_module) + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/passes/remove_clone.py b/backends/qualcomm/passes/remove_redundancy.py similarity index 83% rename from backends/qualcomm/passes/remove_clone.py rename to backends/qualcomm/passes/remove_redundancy.py index 5512690fc3c..c54596f6583 100644 --- a/backends/qualcomm/passes/remove_clone.py +++ b/backends/qualcomm/passes/remove_redundancy.py @@ -9,23 +9,25 @@ from executorch.exir.passes import dead_code_elimination_pass -class RemoveClone(ExportPass): +class RemoveRedundancy(ExportPass): """ Trim the 'identity' operators to reduce the unnecessary copy overhead. 
""" - clone_ops = { + redundant_ops = { torch.clone, torch.ops.aten.clone.default, exir_ops.edge.aten.clone.default, + torch.ops.aten.alias.default, + exir_ops.edge.aten.alias.default, } def __init__(self): - super(RemoveClone, self).__init__() + super(RemoveRedundancy, self).__init__() def _remove(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: for n in graph_module.graph.nodes: - if n.target not in self.clone_ops: + if n.target not in self.redundant_ops: continue to_be_remove = n diff --git a/backends/qualcomm/passes/utils.py b/backends/qualcomm/passes/utils.py index 92fed14894c..c97f2b8f53e 100755 --- a/backends/qualcomm/passes/utils.py +++ b/backends/qualcomm/passes/utils.py @@ -9,6 +9,9 @@ from executorch.exir.dialects._ops import ops as exir_ops +# TODO, Move all Qualcomm specific keys to here, like "quant_attrs" +q_io_key = "q_tensor_io" + q_ops = { exir_ops.edge.quantized_decomposed.quantize_per_channel.default, exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, diff --git a/backends/qualcomm/qnn_preprocess.py b/backends/qualcomm/qnn_preprocess.py index 653ed92fc64..b3979afc587 100644 --- a/backends/qualcomm/qnn_preprocess.py +++ b/backends/qualcomm/qnn_preprocess.py @@ -12,6 +12,9 @@ from executorch.backends.qualcomm.builders.node_visitor import get_node_visitors from executorch.backends.qualcomm.passes.convert_to_linear import ConvertToLinear +from executorch.backends.qualcomm.passes.fuse_consecutive_transpose import ( + FuseConsecutiveTranspose, +) from executorch.backends.qualcomm.passes.insert_io_qdq import InsertIOQDQ from executorch.backends.qualcomm.passes.insert_requantize import InsertRequantize from executorch.backends.qualcomm.passes.layout_transform import LayoutTransform @@ -48,6 +51,7 @@ def preprocess( InsertRequantize(edge_program), InsertIOQDQ(edge_program), LayoutTransform(edge_program, insert_permute=True), + FuseConsecutiveTranspose(), ] ) diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index dc669de6adf..cc2ce008a7c 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
from enum import IntEnum, unique -from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple +from typing import Callable, Dict, Optional, Sequence, Set import torch from executorch.backends.qualcomm.passes.decompose_scaled_dot_product_attention import ( @@ -15,31 +15,28 @@ RecomposePixelUnshuffle, ) from executorch.backends.qualcomm.passes.reduce_dynamic_range import ReduceDynamicRange -from executorch.backends.qualcomm.passes.remove_clone import RemoveClone +from executorch.backends.qualcomm.passes.remove_redundancy import RemoveRedundancy from executorch.backends.qualcomm.passes.replace_inf_buffer import ReplaceInfBuffer -from torch import Tensor from torch._ops import OpOverload -from torch.ao.quantization.observer import ( - HistogramObserver, - MinMaxObserver, - MovingAverageMinMaxObserver, - PerChannelMinMaxObserver, +from torch.ao.quantization.quantizer import Quantizer +from torch.fx import GraphModule + +from .utils import ( + get_16a4w_qnn_ptq_config, + get_16a8w_qnn_ptq_config, + get_default_16bit_qnn_ptq_config, + get_default_8bit_qnn_ptq_config, + get_ptq_per_channel_weight_config, + OP_ANNOTATOR, + QuantizationConfig, ) -from torch.ao.quantization.quantizer import ( - DerivedQuantizationSpec, - QuantizationSpec, - Quantizer, -) - -from torch.fx import GraphModule, Node - -from .utils import OP_ANNOTATOR, QuantizationConfig __all__ = [ "QnnQuantizer", "QuantDtype", "get_16a4w_qnn_ptq_config", + "get_16a8w_qnn_ptq_config", "get_default_16bit_qnn_ptq_config", "get_default_8bit_qnn_ptq_config", ] @@ -56,205 +53,6 @@ class QuantDtype(IntEnum): use_8a8w = 2 -def _derived_bias_quant_spec(node: Node) -> DerivedQuantizationSpec: - def _derive_bias_qparams_fn( - obs_or_fqs: List, - ) -> Tuple[Tensor, Tensor]: - assert ( - len(obs_or_fqs) == 2 - ), f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}" - act_obs_or_fq = obs_or_fqs[0] - weight_obs_or_fq = obs_or_fqs[1] - weight_scale, weight_zp = weight_obs_or_fq.calculate_qparams() - act_scale, act_zp = act_obs_or_fq.calculate_qparams() - (broadcast_act_scale, broadcast_weight_scale) = torch.broadcast_tensors( - act_scale, weight_scale - ) - derived_scale = (broadcast_act_scale * broadcast_weight_scale).to(torch.float32) - derived_zero = torch.zeros(derived_scale.size()).to(torch.int32) - return (derived_scale, derived_zero) - - input_act = node.args[0] - assert isinstance(input_act, Node) - weight = node.args[1] - assert isinstance(weight, Node) - - return DerivedQuantizationSpec( - derived_from=[(input_act, node), (weight, node)], - derive_qparams_fn=_derive_bias_qparams_fn, - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - ch_axis=0, - qscheme=torch.per_channel_symmetric, - ) - - -def get_default_8bit_qnn_ptq_config() -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-12} - - act_quantization_spec = QuantizationSpec( - dtype=torch.uint8, - quant_min=0, - quant_max=torch.iinfo(torch.uint8).max, - qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args), - ) - - weight_quantization_spec = QuantizationSpec( - dtype=torch.int8, - quant_min=torch.iinfo(torch.int8).min + 1, - quant_max=torch.iinfo(torch.int8).max, - qscheme=torch.per_tensor_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - bias_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - 
quant_max=torch.iinfo(torch.int32).max, - qscheme=torch.per_tensor_symmetric, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - -# 4 bits quantization only supports specific ops. -def get_16a4w_qnn_ptq_config() -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-20} - act_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.uint16).min, - quant_max=torch.iinfo(torch.uint16).max, - qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args), - ) - - weight_quantization_spec = QuantizationSpec( - dtype=torch.int8, - quant_min=-7, - quant_max=7, - qscheme=torch.per_tensor_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - bias_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - qscheme=torch.per_tensor_symmetric, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - -def get_default_16bit_qnn_ptq_config() -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-20} - act_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.uint16).min, - quant_max=torch.iinfo(torch.uint16).max, - qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args), - ) - - weight_quantization_spec = QuantizationSpec( - dtype=torch.int16, - quant_min=torch.iinfo(torch.int16).min + 1, - quant_max=torch.iinfo(torch.int16).max, - qscheme=torch.per_tensor_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - # torch does not support uint16 quantization, use int32 to bypass - bias_quantization_spec = QuantizationSpec( - dtype=torch.int32, - quant_min=torch.iinfo(torch.int32).min, - quant_max=torch.iinfo(torch.int32).max, - qscheme=torch.per_tensor_symmetric, - observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), - ) - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - -def get_ptq_per_channel_weight_config( - act_dtype=torch.uint8, weight_dtype=torch.int8 -) -> QuantizationConfig: - extra_args: Dict[str, Any] = {"eps": 2**-12} - - supported_act_types = { - torch.uint8, - torch.uint16, - torch.int8, - torch.int16, - } - # TODO accept "int4" temporally. 
Remove "int4" when torch support torch.int4 dtype - supported_weight_dtypes = {"int4", torch.int8, torch.int16} - assert ( - act_dtype in supported_act_types - ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}" - - assert ( - weight_dtype in supported_weight_dtypes - ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}" - - # torch do not support uint16 quantization, use int32 to bypass - act_quantization_spec = QuantizationSpec( - dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, - quant_min=torch.iinfo(act_dtype).min, - quant_max=torch.iinfo(act_dtype).max, - qscheme=torch.per_tensor_affine, - observer_or_fake_quant_ctr=HistogramObserver.with_args(**extra_args), - ) - - weight_quantization_spec = QuantizationSpec( - dtype=torch.int8 if weight_dtype == "int4" else weight_dtype, - quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1, - quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max, - qscheme=torch.per_channel_symmetric, - ch_axis=0, - observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args), - ) - - bias_quantization_spec = _derived_bias_quant_spec - - quantization_config = QuantizationConfig( - input_activation=act_quantization_spec, - output_activation=act_quantization_spec, - weight=weight_quantization_spec, - bias=bias_quantization_spec, - ) - - return quantization_config - - class QnnQuantizer(Quantizer): SUPPORTED_OPS: Set = set(OP_ANNOTATOR.keys()) @@ -384,7 +182,7 @@ def set_per_channel_linear_quant(self, enable: bool) -> None: self._update_per_channel_weight_quant_ops(linear_ops, enable) def transform_for_annotation(self, model: GraphModule) -> GraphModule: - model = RemoveClone()(model).graph_module + model = RemoveRedundancy()(model).graph_module model = ReduceDynamicRange()(model).graph_module model = RecomposePixelUnshuffle(quantization_capture=True)(model).graph_module model = DecomposeScaledDotProductAttention()(model).graph_module diff --git a/backends/qualcomm/quantizer/utils.py b/backends/qualcomm/quantizer/utils.py index 351de70dce8..f2265daf325 100644 --- a/backends/qualcomm/quantizer/utils.py +++ b/backends/qualcomm/quantizer/utils.py @@ -3,29 +3,36 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+import numbers from dataclasses import dataclass -from typing import Callable, Dict, List, Optional, Sequence +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple import torch +from torch import Tensor from torch._ops import OpOverload from torch._subclasses import FakeTensor +from torch.ao.quantization.observer import ( + FixedQParamsObserver, + MinMaxObserver, + MovingAverageMinMaxObserver, + PerChannelMinMaxObserver, +) + from torch.ao.quantization.quantizer import ( + DerivedQuantizationSpec, QuantizationAnnotation, QuantizationSpec, SharedQuantizationSpec, ) - from torch.ao.quantization.quantizer.utils import ( _annotate_input_qspec_map, _annotate_output_qspec, ) from torch.fx import Node -QUANT_ANNOTATION_KEY = "quantization_annotation" -OP_ANNOTATOR: Dict[OpOverload, Callable] = {} - @dataclass(eq=True, frozen=True) class QuantizationConfig: @@ -35,6 +42,253 @@ class QuantizationConfig: bias: Optional[QuantizationSpec | Callable] +def _derived_bias_quant_spec(node: Node) -> DerivedQuantizationSpec: + def _derive_bias_qparams_fn( + obs_or_fqs: List, + ) -> Tuple[Tensor, Tensor]: + assert ( + len(obs_or_fqs) == 2 + ), f"Expecting two obs/fqs, one for activation and one for weight, got: {len(obs_or_fqs)}" + act_obs_or_fq = obs_or_fqs[0] + weight_obs_or_fq = obs_or_fqs[1] + weight_scale, weight_zp = weight_obs_or_fq.calculate_qparams() + act_scale, act_zp = act_obs_or_fq.calculate_qparams() + (broadcast_act_scale, broadcast_weight_scale) = torch.broadcast_tensors( + act_scale, weight_scale + ) + derived_scale = (broadcast_act_scale * broadcast_weight_scale).to(torch.float32) + derived_zero = torch.zeros(derived_scale.size()).to(torch.int32) + return (derived_scale, derived_zero) + + input_act = node.args[0] + assert isinstance(input_act, Node) + weight = node.args[1] + assert isinstance(weight, Node) + + return DerivedQuantizationSpec( + derived_from=[(input_act, node), (weight, node)], + derive_qparams_fn=_derive_bias_qparams_fn, + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + ch_axis=0, + qscheme=torch.per_channel_symmetric, + ) + + +def get_default_8bit_qnn_ptq_config( + act_symmetric: bool = False, act_observer=MinMaxObserver +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-12} + + act_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + qscheme=( + torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine + ), + ch_axis=0, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=torch.iinfo(torch.int8).min + 1, + quant_max=torch.iinfo(torch.int8).max, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +# 4 bits quantization only supports specific ops. 
+def get_16a4w_qnn_ptq_config( + act_observer=MovingAverageMinMaxObserver, +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-20} + act_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.uint16).min, + quant_max=torch.iinfo(torch.uint16).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8, + quant_min=-7, + quant_max=7, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +def get_16a8w_qnn_ptq_config( + act_observer=MovingAverageMinMaxObserver, +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-20} + act_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.uint16).min, + quant_max=torch.iinfo(torch.uint16).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +def get_default_16bit_qnn_ptq_config( + act_observer=MovingAverageMinMaxObserver, +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-20} + act_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.uint16).min, + quant_max=torch.iinfo(torch.uint16).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int16, + quant_min=torch.iinfo(torch.int16).min + 1, + quant_max=torch.iinfo(torch.int16).max, + qscheme=torch.per_tensor_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + # torch does not support uint16 quantization, use int32 to bypass + bias_quantization_spec = QuantizationSpec( + dtype=torch.int32, + quant_min=torch.iinfo(torch.int32).min, + quant_max=torch.iinfo(torch.int32).max, + qscheme=torch.per_tensor_symmetric, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +def get_ptq_per_channel_weight_config( + act_dtype=torch.uint8, 
weight_dtype=torch.int8 +) -> QuantizationConfig: + extra_args: Dict[str, Any] = {"eps": 2**-12} + + supported_act_types = { + torch.uint8, + torch.uint16, + torch.int8, + torch.int16, + } + # TODO accept "int4" temporally. Remove "int4" when torch support torch.int4 dtype + supported_weight_dtypes = {"int4", torch.int8, torch.int16} + assert ( + act_dtype in supported_act_types + ), f"act_dtype, {act_dtype} is not one of supported types, {supported_act_types}" + + assert ( + weight_dtype in supported_weight_dtypes + ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}" + + # torch do not support uint16 quantization, use int32 to bypass + act_quantization_spec = QuantizationSpec( + dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype, + quant_min=torch.iinfo(act_dtype).min, + quant_max=torch.iinfo(act_dtype).max, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), + ) + + weight_quantization_spec = QuantizationSpec( + dtype=torch.int8 if weight_dtype == "int4" else weight_dtype, + quant_min=-7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).min + 1, + quant_max=7 if weight_dtype == "int4" else torch.iinfo(weight_dtype).max, + qscheme=torch.per_channel_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args), + ) + + bias_quantization_spec = _derived_bias_quant_spec + + quantization_config = QuantizationConfig( + input_activation=act_quantization_spec, + output_activation=act_quantization_spec, + weight=weight_quantization_spec, + bias=bias_quantization_spec, + ) + + return quantization_config + + +QUANT_ANNOTATION_KEY = "quantization_annotation" +OP_ANNOTATOR: Dict[OpOverload, Callable] = {} + + def register_annotator(ops: List[OpOverload]): def decorator(annotator: Callable): for op in ops: @@ -43,19 +297,6 @@ def decorator(annotator: Callable): return decorator -def _is_input_float_tensor(node: Node): - """Check if the input is not a float tensor, so that we can skip quantization for the node - since observers only works with float Tensors - """ - if ( - not isinstance(node, Node) - or "val" not in node.meta - or not isinstance(node.meta["val"], FakeTensor) - ): - return False - return node.meta["val"].dtype == torch.float32 - - def _is_annotated(nodes: List[Node]): """ Given a list of nodes (that represents an operator pattern), @@ -71,6 +312,19 @@ def _is_annotated(nodes: List[Node]): return annotated +def _is_input_float_tensor(node: Node): + """Check if the input is not a float tensor, so that we can skip quantization for the node + since observers only works with float Tensors + """ + if ( + not isinstance(node, Node) + or "val" not in node.meta + or not isinstance(node.meta["val"], FakeTensor) + ): + return False + return node.meta["val"].dtype == torch.float32 + + def _mark_nodes_as_annotated(nodes: List[Node]): for node in nodes: if QUANT_ANNOTATION_KEY not in node.meta: @@ -117,15 +371,12 @@ def annotate_single_in_single_out( assert isinstance(input_act, Node) input_qspec_map[input_act] = quantization_config.input_activation - node_tensor = node.meta.get("val") - if torch.is_tensor(node_tensor) and node_tensor.dtype != torch.float32: - return - - node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( - input_qspec_map=input_qspec_map, - output_qspec=quantization_config.output_activation, - _annotated=True, - ) + if _is_input_float_tensor(node): + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + 
input_qspec_map=input_qspec_map, + output_qspec=quantization_config.output_activation, + _annotated=True, + ) def annotate_binary(node: Node, quantization_config: QuantizationConfig) -> None: @@ -133,7 +384,9 @@ def annotate_binary(node: Node, quantization_config: QuantizationConfig) -> None return input_act_qspec = quantization_config.input_activation - output_act_qspec = quantization_config.output_activation + output_act_qspec = ( + quantization_config.output_activation if _is_input_float_tensor(node) else None + ) input_qspec_map = {} input_act0 = node.args[0] @@ -151,24 +404,88 @@ def annotate_binary(node: Node, quantization_config: QuantizationConfig) -> None ) -@register_annotator([torch.ops.aten.add.Tensor]) +@register_annotator([torch.ops.aten.add, torch.ops.aten.add.Tensor]) def annotate_add(node: Node, quantization_config: QuantizationConfig) -> None: annotate_binary(node, quantization_config) -@register_annotator([torch.ops.aten.sub.Tensor]) +@register_annotator([torch.ops.aten.sub, torch.ops.aten.sub.Tensor]) def annotate_sub(node: Node, quantization_config: QuantizationConfig) -> None: annotate_binary(node, quantization_config) -@register_annotator([torch.ops.aten.mul.Tensor, torch.ops.aten.mul.Scalar]) +@register_annotator( + [torch.ops.aten.mul, torch.ops.aten.mul.Tensor, torch.ops.aten.mul.Scalar] +) def annotate_mul(node: Node, quantization_config: QuantizationConfig) -> None: annotate_binary(node, quantization_config) -@register_annotator([torch.ops.aten.div.Tensor, torch.ops.aten.divide.Tensor]) +@register_annotator( + [torch.ops.aten.div, torch.ops.aten.div.Tensor, torch.ops.aten.divide.Tensor] +) def annotate_div(node: Node, quantization_config: QuantizationConfig) -> None: - annotate_binary(node, quantization_config) + def _derived_inp1_const_div_quant_spec( + node: torch.fx.Node, output_qspec: QuantizationSpec + ) -> DerivedQuantizationSpec: + def _derive_div_qparams_fn( + obs_or_fqs: List, + const_val: float, + ) -> Tuple[torch.Tensor, torch.Tensor]: + inp_0_obs_or_fq = obs_or_fqs[0] + inp_0_scale, inp_0_zp = inp_0_obs_or_fq.calculate_qparams() + derived_scale = inp_0_scale / const_val + return (derived_scale, inp_0_zp) + + inp_0 = node.args[0] + const_inp_1 = node.args[1] + _derive_div_qparams_with_const_fn = partial( + _derive_div_qparams_fn, const_val=const_inp_1 + ) + + q_min = ( + torch.iinfo(output_qspec.dtype).min + if output_qspec.quant_min is None + else output_qspec.quant_min + ) + q_max = ( + torch.iinfo(output_qspec.dtype).max + if output_qspec.quant_max is None + else output_qspec.quant_max + ) + return DerivedQuantizationSpec( + derived_from=[(inp_0, node)], + derive_qparams_fn=_derive_div_qparams_with_const_fn, + dtype=output_qspec.dtype, + quant_min=q_min, + quant_max=q_max, + ch_axis=0, + qscheme=output_qspec.qscheme, + ) + + if [a for a in node.args if isinstance(a, Node)]: + annotate_binary(node, quantization_config) + # special constant divisor case + elif isinstance(node.args[0], Node) and isinstance(node.args[1], numbers.Number): + if _is_annotated([node]): + return + + input_act_qspec = quantization_config.input_activation + output_act_qspec = _derived_inp1_const_div_quant_spec( + node, quantization_config.output_activation + ) + input_qspec_map = {} + input_act0 = node.args[0] + if _is_input_float_tensor(input_act0): + input_qspec_map[input_act0] = input_act_qspec + + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=output_act_qspec, + _annotated=True, + ) + else: + raise 
NotImplementedError(f"No quant annotation is implemented for {node}.") @register_annotator([torch.ops.aten.rsub.Scalar]) @@ -251,7 +568,9 @@ def annotate_avgpool2d(node: Node, quantization_config: QuantizationConfig) -> N @register_annotator([torch.ops.aten.permute.default]) def annotate_permute(node: Node, quantization_config: QuantizationConfig) -> None: - annotate_single_in_single_out(node, quantization_config) + annotate_in_out_obs_sharing_op(node, quantization_config) + if not _is_annotated([node]): + annotate_single_in_single_out(node, quantization_config) @register_annotator( @@ -267,7 +586,9 @@ def annotate_prelu(node: Node, quantization_config: QuantizationConfig) -> None: @register_annotator([torch.ops.aten.view.default]) def annotate_view(node: Node, quantization_config: QuantizationConfig) -> None: - annotate_single_in_single_out(node, quantization_config) + annotate_in_out_obs_sharing_op(node, quantization_config) + if not _is_annotated([node]): + annotate_single_in_single_out(node, quantization_config) @register_annotator([torch.ops.aten.pixel_shuffle.default]) @@ -370,7 +691,51 @@ def annotate_rsqrt(node: Node, quantization_config: QuantizationConfig) -> None: @register_annotator([torch.ops.aten.sigmoid, torch.ops.aten.sigmoid.default]) def annotate_sigmoid(node: Node, quantization_config: QuantizationConfig) -> None: - annotate_single_in_single_out(node, quantization_config) + if _is_annotated([node]): + return + + input_qspec_map = {} + input_act = node.args[0] + input_qspec_map[input_act] = quantization_config.input_activation + + assert isinstance(input_act, Node) + out_qconf = quantization_config.output_activation + + q_max = ( + torch.iinfo(out_qconf.dtype).max + if out_qconf.quant_max is None + else out_qconf.quant_max + ) + q_min = ( + torch.iinfo(out_qconf.dtype).min + if out_qconf.quant_min is None + else out_qconf.quant_min + ) + + scale = 1 / (q_max - q_min + 1) + + # make sigmoid map to the range between 0~1 + out_act_quantization_spec = QuantizationSpec( + dtype=quantization_config.output_activation.dtype, + quant_max=q_max, + quant_min=q_min, + observer_or_fake_quant_ctr=FixedQParamsObserver.with_args( + scale=scale, + zero_point=0, + dtype=quantization_config.output_activation.dtype, + qscheme=torch.torch.per_tensor_affine, + quant_max=q_max, + quant_min=q_min, + ), + qscheme=torch.torch.per_tensor_affine, + ) + + if _is_input_float_tensor(node): + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=out_act_quantization_spec, + _annotated=True, + ) @register_annotator([torch.ops.aten.pow.Tensor_Scalar]) @@ -468,10 +833,8 @@ def annotate_matmul(node: Node, quantization_config: QuantizationConfig) -> None if isinstance(input_act1, Node): # In matmul, QNN_DATATYPE_SFIXED_POINT_16 Input1 must have QNN_DATATYPE_UFIXED_POINT_16 Input0 and must be symmetric quantized. 
if input_act_qspec.dtype == torch.int32: - input_qspec_map[input_act1] = quantization_config.weight - quantization_annotation = input_act1.meta.get(QUANT_ANNOTATION_KEY, None) - if quantization_annotation: - quantization_annotation.output_qspec = quantization_config.weight + # we should use int16 for mm / bmm instead of int4 + input_qspec_map[input_act1] = get_default_16bit_qnn_ptq_config().weight else: input_qspec_map[input_act1] = input_act_qspec @@ -499,10 +862,8 @@ def annotate_bmm(node: Node, quantization_config: QuantizationConfig) -> None: if isinstance(input_act1, Node): # In bmm, QNN_DATATYPE_SFIXED_POINT_16 Input1 must have QNN_DATATYPE_UFIXED_POINT_16 Input0 and must be symmetric quantized. if input_act_qspec.dtype == torch.int32: - input_qspec_map[input_act1] = quantization_config.weight - quantization_annotation = input_act1.meta.get(QUANT_ANNOTATION_KEY, None) - if quantization_annotation: - quantization_annotation.output_qspec = quantization_config.weight + # we should use int16 for mm / bmm instead of int4 + input_qspec_map[input_act1] = get_default_16bit_qnn_ptq_config().weight else: input_qspec_map[input_act1] = input_act_qspec @@ -664,7 +1025,7 @@ def annotate_unbind(node: Node, quantization_config: QuantizationConfig) -> None ) -@register_annotator([torch.ops.aten.chunk.default]) +@register_annotator([torch.ops.aten.split.Tensor, torch.ops.aten.chunk.default]) def annotate_chunk(node: Node, quantization_config: QuantizationConfig) -> None: if _is_annotated([node]): return diff --git a/backends/qualcomm/runtime/QnnExecuTorch.h b/backends/qualcomm/runtime/QnnExecuTorch.h index d54de1059d7..45525726ca7 100644 --- a/backends/qualcomm/runtime/QnnExecuTorch.h +++ b/backends/qualcomm/runtime/QnnExecuTorch.h @@ -7,6 +7,7 @@ */ #pragma once +#include #ifdef __cplusplus #include #include @@ -33,6 +34,19 @@ typedef struct { } // clang-format on +/// Allocate memory in different ways; check the QNN documentation for more details. +enum QnnMemDescriptor { kIon, kCustom }; + +struct CustomMemTensorInfo { + void* custom_mem; + void* tensor_addr; + size_t pos; + size_t tensor_bytes; + uint32_t* shape; + uint32_t rank; + torch::executor::ScalarType dtype; +}; + /// Allocate specific tensors (usually graph inputs and outputs) on shared /// memory. Users are responsible to allocate "enough" tensor bytes, and set /// alignment as MemoryAllocator::kDefaultAlignment. @@ -40,6 +54,14 @@ typedef struct { /// if allocation is successful. void* QnnExecuTorchAllocCustomMem(size_t bytes, size_t alignment); +/// Add a tensor to custom memory with a custom type descriptor. Creates the memory +/// handle for the tensor wrapper during execution. +void QnnExecuTorchAddCustomMemTensorAddr(void* tensor_addr, void* custom_mem); + +/// Add custom mem tensor info. Helps move memHandle creation from execution +/// time to initialization time. +void QnnExecuTorchAddCustomMemTensorInfo(const CustomMemTensorInfo& info); + /// Free the allocated shared memory.
void QnnExecuTorchFreeCustomMem(void* buffer_ptr); diff --git a/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp b/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp index 77449703c5f..feccfff9fa8 100644 --- a/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp +++ b/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp @@ -11,15 +11,12 @@ #include #include #include - -#include namespace torch { namespace executor { // ========== Public method implementations ========================= using namespace qnn; using namespace qnn_delegate; constexpr const char* QNN_COMPILE_SPEC = "qnn_compile_spec"; - Result QnnExecuTorchBackend::init( BackendInitContext& context, FreeableBuffer* processed, diff --git a/backends/qualcomm/runtime/QnnManager.cpp b/backends/qualcomm/runtime/QnnManager.cpp index dc3217fc1c8..a77ec1a557e 100644 --- a/backends/qualcomm/runtime/QnnManager.cpp +++ b/backends/qualcomm/runtime/QnnManager.cpp @@ -9,12 +9,26 @@ #include #include #include +#include #include #include #include +#include + namespace torch { namespace executor { namespace qnn { + +bool CompareExportedInput( + const std::shared_ptr& a, + const std::shared_ptr& b) { + // Using the order of the nodes as external_id in AOT + // to extract the right arg from *args at runtime + int numA = std::stoi(a->GetName().substr(a->GetName().find('_') + 1)); + int numB = std::stoi(b->GetName().substr(b->GetName().find('_') + 1)); + return numA < numB; +} + QnnManager::~QnnManager() { backend_params_ptr_.reset(new BackendConfigParameters()); logger_.reset(); @@ -84,6 +98,52 @@ Error QnnManager::LoadQnnLibrary() { return ret; } +Error QnnManager::PreRegisterMem() { + SharedBuffer& shared_buffer_manager = SharedBuffer::GetSharedBufferManager(); + for (const auto info : shared_buffer_manager.GetCustomMemTensorInfoSet()) { + void* unaligned_custom_mem_base = + shared_buffer_manager.GetUnAlignedAddr(info.custom_mem); + + size_t tensor_offset = (static_cast(info.custom_mem) - + static_cast(unaligned_custom_mem_base)) + + info.pos; + size_t total_custom_mem_size = + shared_buffer_manager.GetAllocatedSize(info.custom_mem); + + int32_t mem_fd = shared_buffer_manager.MemToFd(unaligned_custom_mem_base); + if (mem_fd == -1) { + QNN_EXECUTORCH_LOG_WARN( + "PreRegisterMem failed to get file descriptor.", + "custom_mem: %p", + "tensor_addr: %p", + "pos: %uz", + "tensor_bytes: %uz", + "shape: %p", + "rank: %zu", + "qnn_dtype: %X", + info.custom_mem, + info.tensor_addr, + info.pos, + info.tensor_bytes, + info.shape, + info.rank, + info.dtype); + return Error::Internal; + } + + ET_CHECK_OR_RETURN_ERROR( + backend_params_ptr_->qnn_mem_manager_ptr_->PreRegisterCustomMemHandle( + mem_fd, + unaligned_custom_mem_base, + total_custom_mem_size, + tensor_offset, + info) == Error::Ok, + Internal, + "Fail to register to shared memory."); + } + return Error::Ok; +} + Error QnnManager::RegisterMem( void* data_ptr, const std::shared_ptr& tensor_wrapper) { @@ -100,6 +160,17 @@ Error QnnManager::RegisterMem( return Error::Internal; } + void* custom_mem_base = shared_buffer_manager.GetCustomMemBase(data_ptr); + if (custom_mem_base != nullptr) { + return RegisterCustomMem(data_ptr, custom_mem_base, tensor_wrapper); + } + return RegisterIonMem(data_ptr, tensor_wrapper); +} + +Error QnnManager::RegisterIonMem( + void* data_ptr, + const std::shared_ptr& tensor_wrapper) { + SharedBuffer& shared_buffer_manager = SharedBuffer::GetSharedBufferManager(); if (!shared_buffer_manager.IsAllocated(data_ptr)) { // It means two scenarios here: // 1. 
the input and output partitioned graph @@ -107,7 +178,7 @@ Error QnnManager::RegisterMem( // QnnExecuTorchAllocCustomMem API return Error::Internal; } else if (backend_params_ptr_->qnn_mem_manager_ptr_->IsRegistered( - tensor_wrapper->GetMemHandle())) { + tensor_wrapper->GetMemHandle(), data_ptr)) { if (options_->log_level() >= QnnExecuTorchLogLevel::kLogLevelInfo) QNN_EXECUTORCH_LOG_INFO( "Tensor name %s has been registered shared memory.", @@ -115,7 +186,7 @@ Error QnnManager::RegisterMem( return Error::Ok; } - int32_t mem_fd = SharedBuffer::GetSharedBufferManager().MemToFd(data_ptr); + int32_t mem_fd = shared_buffer_manager.MemToFd(data_ptr); if (mem_fd == -1) { QNN_EXECUTORCH_LOG_WARN( "Tensor name %s is failed to get file descriptor.", @@ -123,8 +194,74 @@ Error QnnManager::RegisterMem( return Error::Internal; } ET_CHECK_OR_RETURN_ERROR( - backend_params_ptr_->qnn_mem_manager_ptr_->RegisterMem( - tensor_wrapper, mem_fd) == Error::Ok, + backend_params_ptr_->qnn_mem_manager_ptr_->RegisterIonMem( + tensor_wrapper, mem_fd, data_ptr) == Error::Ok, + Internal, + "Fail to register to shared memory."); + + return Error::Ok; +} + +Error QnnManager::RegisterCustomMem( + void* data_ptr, + void* custom_mem_base, + const std::shared_ptr& tensor_wrapper) { + if (backend_params_ptr_->qnn_mem_manager_ptr_->IsRegistered( + tensor_wrapper->GetMemHandle(), data_ptr)) { + if (options_->log_level() >= QnnExecuTorchLogLevel::kLogLevelInfo) + QNN_EXECUTORCH_LOG_INFO( + "Tensor name %s has been registered shared memory.", + tensor_wrapper->GetName().c_str()); + return Error::Ok; + } + + CustomMemTensorInfo info{ + custom_mem_base, + data_ptr, + static_cast( + static_cast(data_ptr) - static_cast(custom_mem_base)), + tensor_wrapper->GetBytes(), + tensor_wrapper->GetDims(), + tensor_wrapper->GetRank(), + qnn_dtype_to_scalar_type_[tensor_wrapper->GetDataType()]}; + + Qnn_MemHandle_t pre_registered_handle = + backend_params_ptr_->qnn_mem_manager_ptr_->GetPreRegisteredHandle(info); + if (pre_registered_handle != nullptr) { + if (options_->log_level() >= QnnExecuTorchLogLevel::kLogLevelInfo) { + QNN_EXECUTORCH_LOG_INFO( + "Tensor name %s found a pre-registered memHandle.", + tensor_wrapper->GetName().c_str()); + } + return backend_params_ptr_->qnn_mem_manager_ptr_->SetMemHandle( + tensor_wrapper, data_ptr, pre_registered_handle); + } + + SharedBuffer& shared_buffer_manager = SharedBuffer::GetSharedBufferManager(); + void* unaligned_custom_mem_base = + shared_buffer_manager.GetUnAlignedAddr(custom_mem_base); + + size_t tensor_offset = static_cast(custom_mem_base) - + static_cast(unaligned_custom_mem_base) + info.pos; + size_t total_custom_mem_size = + shared_buffer_manager.GetAllocatedSize(custom_mem_base); + + int32_t mem_fd = shared_buffer_manager.MemToFd(unaligned_custom_mem_base); + if (mem_fd == -1) { + QNN_EXECUTORCH_LOG_WARN( + "Tensor name %s failed to get file descriptor.", + tensor_wrapper->GetName().c_str()); + return Error::Internal; + } + + ET_CHECK_OR_RETURN_ERROR( + backend_params_ptr_->qnn_mem_manager_ptr_->RegisterCustomMem( + tensor_wrapper, + mem_fd, + data_ptr, + unaligned_custom_mem_base, + total_custom_mem_size, + tensor_offset) == Error::Ok, Internal, "Fail to register to shared memory."); @@ -164,6 +301,12 @@ Error QnnManager::Init() { BackendInitializeState::INITIALIZED; } +#if defined(__aarch64__) + ET_CHECK_OR_RETURN_ERROR( + PreRegisterMem() == Error::Ok, + Internal, + "Fail to pre register custom memory handle"); +#endif return Error::Ok; } @@ -178,7 +321,10 @@ Error 
QnnManager::AllocateTensor() { tensor_wrapper->UpdateQnnTensorMeta(tensor); input_tensors_.emplace_back(std::move(tensor_wrapper)); } - + if (!options_->is_from_context_binary()) { + std::sort( + input_tensors_.begin(), input_tensors_.end(), CompareExportedInput); + } for (auto& tensor : output_tensors) { std::shared_ptr tensor_wrapper = CreateTensorWrapper(tensor); tensor_wrapper->UpdateQnnTensorMeta(tensor); @@ -199,6 +345,10 @@ Error QnnManager::AllocateTensor( output_tensor->AllocateDataBuffer(); } } + if (!options_->is_from_context_binary()) { + std::sort( + input_tensors_.begin(), input_tensors_.end(), CompareExportedInput); + } output_tensors_ = std::move(outputs); return Error::Ok; } @@ -371,13 +521,23 @@ Error QnnManager::Compile( } // namespace executor } // namespace torch void* QnnExecuTorchAllocCustomMem(size_t bytes, size_t alignment) { - using torch::executor::qnn::SharedBuffer; void* buffer_ptr = - SharedBuffer::GetSharedBufferManager().AllocMem(bytes, alignment); + torch::executor::qnn::SharedBuffer::GetSharedBufferManager().AllocMem( + bytes, alignment); return buffer_ptr; } void QnnExecuTorchFreeCustomMem(void* buffer_ptr) { - using torch::executor::qnn::SharedBuffer; - SharedBuffer::GetSharedBufferManager().FreeMem(buffer_ptr); + torch::executor::qnn::SharedBuffer::GetSharedBufferManager().FreeMem( + buffer_ptr); +} + +void QnnExecuTorchAddCustomMemTensorAddr(void* tensor_addr, void* custom_mem) { + torch::executor::qnn::SharedBuffer::GetSharedBufferManager() + .AddCusomMemTensorAddr(tensor_addr, custom_mem); +} + +void QnnExecuTorchAddCustomMemTensorInfo(const CustomMemTensorInfo& info) { + torch::executor::qnn::SharedBuffer::GetSharedBufferManager() + .AddCusomMemTensorInfo(info); } diff --git a/backends/qualcomm/runtime/QnnManager.h b/backends/qualcomm/runtime/QnnManager.h index 639d3534de4..5190f6768b7 100644 --- a/backends/qualcomm/runtime/QnnManager.h +++ b/backends/qualcomm/runtime/QnnManager.h @@ -16,6 +16,7 @@ #include #include +#include namespace torch { namespace executor { @@ -65,6 +66,9 @@ class QnnManager { void* data_ptr, const std::shared_ptr& tensor_wrapper); + // Pre-register custom memory handle from the SharedBuffer before execution + Error PreRegisterMem(); + std::vector> GetGraphInputs() { return input_tensors_; } @@ -86,6 +90,21 @@ class QnnManager { const QnnExecuTorchOptions* options_; std::vector> input_tensors_; std::vector> output_tensors_; + Error RegisterIonMem( + void* data_ptr, + const std::shared_ptr& tensor_wrapper); + Error RegisterCustomMem( + void* data_ptr, + void* custom_mem_base, + const std::shared_ptr& tensor_wrapper); + std::unordered_map qnn_dtype_to_scalar_type_ = { + {Qnn_DataType_t::QNN_DATATYPE_INT_32, ScalarType::Int}, + {Qnn_DataType_t::QNN_DATATYPE_FLOAT_32, ScalarType::Float}, + {Qnn_DataType_t::QNN_DATATYPE_SFIXED_POINT_8, ScalarType::Char}, + {Qnn_DataType_t::QNN_DATATYPE_SFIXED_POINT_16, ScalarType::Short}, + {Qnn_DataType_t::QNN_DATATYPE_UFIXED_POINT_8, ScalarType::Byte}, + {Qnn_DataType_t::QNN_DATATYPE_UFIXED_POINT_16, ScalarType::Bits16}, + }; }; } // namespace qnn } // namespace executor diff --git a/backends/qualcomm/runtime/SharedBuffer.cpp b/backends/qualcomm/runtime/SharedBuffer.cpp index 423c5d63723..430c8f757a6 100644 --- a/backends/qualcomm/runtime/SharedBuffer.cpp +++ b/backends/qualcomm/runtime/SharedBuffer.cpp @@ -14,6 +14,34 @@ constexpr uint8_t RPCMEM_HEAP_ID_SYSTEM = 25; constexpr uint8_t RPCMEM_DEFAULT_FLAGS = 1; +std::size_t std::hash::operator()( + const CustomMemTensorInfo& info) const noexcept 
{ + size_t hash_val = 0; + hash_val ^= std::hash()(info.tensor_addr); + hash_val ^= std::hash()(info.custom_mem); + hash_val ^= std::hash()(info.pos); + hash_val ^= std::hash()(info.tensor_bytes); + for (int i = 0; i < info.rank; ++i) { + hash_val ^= info.shape[i]; + } + hash_val ^= std::hash()(info.rank); + hash_val ^= std::hash()(info.dtype); + return hash_val; +} + +bool operator==( + const CustomMemTensorInfo& lhs, + const CustomMemTensorInfo& rhs) { + bool is_same = + (lhs.tensor_addr == rhs.tensor_addr && lhs.custom_mem == rhs.custom_mem && + lhs.pos == rhs.pos && lhs.tensor_bytes == rhs.tensor_bytes && + lhs.rank == rhs.rank && lhs.dtype == rhs.dtype); + for (int i = 0; i < lhs.rank; ++i) { + is_same &= lhs.shape[i] == rhs.shape[i]; + } + return is_same; +} + namespace torch { namespace executor { namespace qnn { @@ -31,6 +59,30 @@ intptr_t alignTo(size_t alignment, intptr_t offset) { std::mutex SharedBuffer::init_mutex_; +void* SharedBuffer::GetCustomMemBase(void* buf) { + auto it = tensor_addr_to_custom_mem_.find(buf); + if (it == tensor_addr_to_custom_mem_.end()) { + return nullptr; + } + return it->second; +} + +void* SharedBuffer::GetUnAlignedAddr(void* buf) { + auto it = restore_map_.find(buf); + if (it == restore_map_.end()) { + return nullptr; + } + return it->second; +} + +size_t SharedBuffer::GetAllocatedSize(void* buf) { + auto it = allocated_size_map_.find(buf); + if (it == allocated_size_map_.end()) { + return 0; + } + return it->second; +} + SharedBuffer& SharedBuffer::GetSharedBufferManager() { std::lock_guard lk(init_mutex_); static SharedBuffer shared_buffer_manager; @@ -62,10 +114,10 @@ void* SharedBuffer::AllocMem(size_t bytes, size_t alignment) { QNN_EXECUTORCH_LOG_WARN("Failed to allocate the tensor by RPC memory."); return nullptr; } + allocated_size_map_.insert({buf, allocate_bytes}); auto aligned_buf = reinterpret_cast( alignTo(alignment, reinterpret_cast(buf))); - bool status = - restore_map_.insert(std::pair(aligned_buf, buf)).second; + bool status = restore_map_.insert({aligned_buf, buf}).second; if (!status) { QNN_EXECUTORCH_LOG_ERROR("Failed to allocate the tensor by RPC memory."); rpc_mem_free_(buf); @@ -123,6 +175,15 @@ Error SharedBuffer::Load() { return Error::Ok; } +void SharedBuffer::AddCusomMemTensorAddr(void* tensor_addr, void* custom_mem) { + tensor_addr_to_custom_mem_.insert({tensor_addr, custom_mem}); +}; + +void SharedBuffer::AddCusomMemTensorInfo(const CustomMemTensorInfo& info) { + custom_mem_tensor_info_set_.insert(info); + tensor_addr_to_custom_mem_.insert({info.tensor_addr, info.custom_mem}); +} + Error SharedBuffer::UnLoad() { if (dlclose(lib_cdsp_rpc_) != 0) { QNN_EXECUTORCH_LOG_ERROR( diff --git a/backends/qualcomm/runtime/SharedBuffer.h b/backends/qualcomm/runtime/SharedBuffer.h index 1803e8af879..9d01e67c8e2 100644 --- a/backends/qualcomm/runtime/SharedBuffer.h +++ b/backends/qualcomm/runtime/SharedBuffer.h @@ -6,20 +6,31 @@ * LICENSE file in the root directory of this source tree. 
*/ #pragma once +#include +#include #include #include #include #include #include #include +#include using RpcMemAllocFn_t = void* (*)(int, uint32_t, int); using RpcMemFreeFn_t = void (*)(void*); using RpcMemToFdFn_t = int (*)(void*); +// TODO Finad a better file to place CustomMemTensorInfo +bool operator==(const CustomMemTensorInfo& lhs, const CustomMemTensorInfo& rhs); +template <> +struct std::hash { + std::size_t operator()(const CustomMemTensorInfo& info) const noexcept; +}; + namespace torch { namespace executor { namespace qnn { + class SharedBuffer final { public: SharedBuffer(const SharedBuffer&) = delete; @@ -45,6 +56,22 @@ class SharedBuffer final { initialize_ = initialize; } + // memory handle is registered during execution + void AddCusomMemTensorAddr(void* tensor_addr, void* custom_mem); + + // memory handle can be registered before execution + void AddCusomMemTensorInfo(const CustomMemTensorInfo& info); + + size_t GetAllocatedSize(void* buf); + + void* GetCustomMemBase(void* buf); + + void* GetUnAlignedAddr(void* buf); + + const std::unordered_set& GetCustomMemTensorInfoSet() { + return custom_mem_tensor_info_set_; + }; + private: SharedBuffer() = default; @@ -63,6 +90,10 @@ class SharedBuffer final { // Function pointer to rpcmem_to_fd RpcMemToFdFn_t rpc_mem_to_fd_; std::unordered_map restore_map_; + std::unordered_map allocated_size_map_; + // Maps for the custom memory + std::unordered_map tensor_addr_to_custom_mem_; + std::unordered_set custom_mem_tensor_info_set_; std::atomic_bool initialize_{false}; static std::mutex init_mutex_; }; diff --git a/backends/qualcomm/runtime/backends/CMakeLists.txt b/backends/qualcomm/runtime/backends/CMakeLists.txt index e173f08af08..ed61d7545a9 100644 --- a/backends/qualcomm/runtime/backends/CMakeLists.txt +++ b/backends/qualcomm/runtime/backends/CMakeLists.txt @@ -126,6 +126,7 @@ set(qnn_header_basenames HTP/QnnHtpCommon.h HTP/QnnHtpDevice.h HTP/QnnHtpGraph.h + HTP/QnnHtpMem.h HTP/QnnHtpPerfInfrastructure.h HTP/QnnHtpProfile.h HTP/QnnHtpProperty.h diff --git a/backends/qualcomm/runtime/backends/QnnBackendCache.cpp b/backends/qualcomm/runtime/backends/QnnBackendCache.cpp index 0c569ae5ab6..8c7639460fb 100644 --- a/backends/qualcomm/runtime/backends/QnnBackendCache.cpp +++ b/backends/qualcomm/runtime/backends/QnnBackendCache.cpp @@ -96,6 +96,7 @@ QnnBackendCache::QnnBackendCache( if (qcir::VerifyGraphBuffer(verifier)) { state_ = ONLINE_PREPARE; + QNN_EXECUTORCH_LOG_INFO("Verify context blob came from flatbuffer."); return; } } diff --git a/backends/qualcomm/runtime/backends/QnnMemManager.cpp b/backends/qualcomm/runtime/backends/QnnMemManager.cpp index 8f8317e0136..9b5cb4bdc03 100644 --- a/backends/qualcomm/runtime/backends/QnnMemManager.cpp +++ b/backends/qualcomm/runtime/backends/QnnMemManager.cpp @@ -11,13 +11,18 @@ namespace torch { namespace executor { namespace qnn { -bool QnnMemManager::IsRegistered(Qnn_MemHandle_t handle) { - return registered_set_.count(handle) != 0U; +bool QnnMemManager::IsRegistered(Qnn_MemHandle_t handle, void* mem_ptr) { + auto it = registered_map_.find(handle); + if (it != registered_map_.end()) { + return it->second == mem_ptr; + } + return false; } -Error QnnMemManager::RegisterMem( +Error QnnMemManager::RegisterIonMem( const std::shared_ptr& tensor_wrapper, - int32_t mem_fd) { + int32_t mem_fd, + void* mem_ptr) { const QnnInterface& qnn_interface = implementation_.GetQnnInterface(); Qnn_MemDescriptor_t descriptor = { {tensor_wrapper->GetRank(), tensor_wrapper->GetDims(), nullptr}, @@ -39,26 +44,128 @@ 
Error QnnMemManager::RegisterMem( return Error::Internal; } tensor_wrapper->SetMemHandle(handle); - registered_set_.insert(handle); + registered_map_.insert({handle, mem_ptr}); + QNN_EXECUTORCH_LOG_INFO( + "Tensor %s is successfully registered to ION shared memory.", + tensor_wrapper->GetName().c_str()); + return Error::Ok; +} + +Error QnnMemManager::RegisterCustomMem( + const std::shared_ptr& tensor_wrapper, + int32_t mem_fd, + void* mem_ptr, + void* unaligned_custom_mem_base, + size_t total_custom_mem_size, + size_t tensor_offset) { + const QnnInterface& qnn_interface = implementation_.GetQnnInterface(); + Qnn_MemDescriptor_t descriptor = { + {tensor_wrapper->GetRank(), tensor_wrapper->GetDims(), nullptr}, + tensor_wrapper->GetDataType(), + QNN_MEM_TYPE_CUSTOM, + {{mem_fd}}}; + Qnn_MemHandle_t handle = nullptr; + Qnn_ErrorHandle_t error = QNN_SUCCESS; + + QnnMemHtp_Descriptor_t htp_descriptor; + htp_descriptor.type = QNN_HTP_MEM_SHARED_BUFFER; + htp_descriptor.size = total_custom_mem_size; + + QnnHtpMem_SharedBufferConfig_t htpSharedBuffConfig = {mem_fd, tensor_offset}; + htp_descriptor.sharedBufferConfig = htpSharedBuffConfig; + + descriptor.customInfo = &htp_descriptor; + + error = qnn_interface.qnn_mem_register( + context_->GetHandle(), + &descriptor, + /*numDescriptors=*/1, + &handle); + if (error != QNN_SUCCESS) { + QNN_EXECUTORCH_LOG_WARN( + "Tensor %s is failed to register shared memory. Error %d", + tensor_wrapper->GetName().c_str(), + QNN_GET_ERROR_CODE(error)); + return Error::Internal; + } + tensor_wrapper->SetMemHandle(handle); + registered_map_.insert({handle, mem_ptr}); QNN_EXECUTORCH_LOG_INFO( - "Tensor %s is successfully registered to shared memory.", + "Tensor %s is successfully registered to custom shared memory.", tensor_wrapper->GetName().c_str()); return Error::Ok; } +Error QnnMemManager::PreRegisterCustomMemHandle( + int32_t mem_fd, + void* unaligned_custom_mem_base, + size_t total_custom_mem_size, + size_t tensor_offset, + const CustomMemTensorInfo& info) { + const QnnInterface& qnn_interface = implementation_.GetQnnInterface(); + Qnn_MemDescriptor_t descriptor = { + {info.rank, info.shape, nullptr}, + scalar_type_to_qnn_dtype_[info.dtype], + QNN_MEM_TYPE_CUSTOM, + {{mem_fd}}}; + Qnn_MemHandle_t handle = nullptr; + Qnn_ErrorHandle_t error = QNN_SUCCESS; + + QnnMemHtp_Descriptor_t htp_descriptor; + htp_descriptor.type = QNN_HTP_MEM_SHARED_BUFFER; + htp_descriptor.size = total_custom_mem_size; + + QnnHtpMem_SharedBufferConfig_t htpSharedBuffConfig = {mem_fd, tensor_offset}; + htp_descriptor.sharedBufferConfig = htpSharedBuffConfig; + + descriptor.customInfo = &htp_descriptor; + + error = qnn_interface.qnn_mem_register( + context_->GetHandle(), + &descriptor, + /*numDescriptors=*/1, + &handle); + if (error != QNN_SUCCESS) { + QNN_EXECUTORCH_LOG_WARN( + "PreRegisterCustomMemHandle fail", QNN_GET_ERROR_CODE(error)); + return Error::Internal; + } + + pre_registered_handles_.insert({info, handle}); + registered_map_.insert({handle, nullptr}); + return Error::Ok; +} + +void* QnnMemManager::GetPreRegisteredHandle(const CustomMemTensorInfo& info) { + auto it = pre_registered_handles_.find(info); + if (it == pre_registered_handles_.end()) { + return nullptr; + } + return it->second; +} + +Error QnnMemManager::SetMemHandle( + const std::shared_ptr& tensor_wrapper, + void* mem_ptr, + Qnn_MemHandle_t handle) { + tensor_wrapper->SetMemHandle(handle); + registered_map_.insert({handle, mem_ptr}); + return Error::Ok; +} + void QnnMemManager::DeRegisterMem() { const QnnInterface& 
qnn_interface = implementation_.GetQnnInterface(); Qnn_ErrorHandle_t error = QNN_SUCCESS; - for (auto& mem_handle : registered_set_) { - error = qnn_interface.qnn_mem_de_register(&mem_handle, /*numHandles=*/1); + for (auto& it : registered_map_) { + error = qnn_interface.qnn_mem_de_register(&it.first, /*numHandles=*/1); if (error != QNN_SUCCESS) { QNN_EXECUTORCH_LOG_WARN( "Failed to de-register shared memory. Error %d", QNN_GET_ERROR_CODE(error)); } } - registered_set_.clear(); + registered_map_.clear(); } } // namespace qnn diff --git a/backends/qualcomm/runtime/backends/QnnMemManager.h b/backends/qualcomm/runtime/backends/QnnMemManager.h index 9d5949db16a..ea79c492e60 100644 --- a/backends/qualcomm/runtime/backends/QnnMemManager.h +++ b/backends/qualcomm/runtime/backends/QnnMemManager.h @@ -7,9 +7,11 @@ */ #pragma once #include +#include #include #include -#include +#include +#include "HTP/QnnHtpMem.h" namespace torch { namespace executor { @@ -25,18 +27,52 @@ class QnnMemManager { DeRegisterMem(); } - Error RegisterMem( + Error RegisterIonMem( const std::shared_ptr& tensor_wrapper, - int32_t mem_fd); + int32_t mem_fd, + void* mem_ptr); - bool IsRegistered(Qnn_MemHandle_t handle); + Error RegisterCustomMem( + const std::shared_ptr& tensor_wrapper, + int32_t mem_fd, + void* mem_ptr, + void* unaligned_custom_mem_base, + size_t total_custom_mem_size, + size_t tensor_offset); + + // Pre-register custom mem handle from SharedBuffer. Bring forward the + // memHandle creating time from execution to initialization. + Error PreRegisterCustomMemHandle( + int32_t mem_fd, + void* unaligned_custom_mem_base, + size_t total_custom_mem_size, + size_t tensor_offset, + const CustomMemTensorInfo& info); + + bool IsRegistered(Qnn_MemHandle_t handle, void* mem_ptr); + + void* GetPreRegisteredHandle(const CustomMemTensorInfo& info); + + Error SetMemHandle( + const std::shared_ptr& tensor_wrapper, + void* mem_ptr, + Qnn_MemHandle_t handle); private: void DeRegisterMem(); const QnnImplementation& implementation_; QnnContext* context_; - std::unordered_set registered_set_; + std::unordered_map registered_map_; + std::unordered_map pre_registered_handles_; + std::unordered_map scalar_type_to_qnn_dtype_ = { + {ScalarType::Int, Qnn_DataType_t::QNN_DATATYPE_INT_32}, + {ScalarType::Float, Qnn_DataType_t::QNN_DATATYPE_FLOAT_32}, + {ScalarType::Char, Qnn_DataType_t::QNN_DATATYPE_SFIXED_POINT_8}, + {ScalarType::Short, Qnn_DataType_t::QNN_DATATYPE_SFIXED_POINT_16}, + {ScalarType::Byte, Qnn_DataType_t::QNN_DATATYPE_UFIXED_POINT_8}, + {ScalarType::Bits16, Qnn_DataType_t::QNN_DATATYPE_UFIXED_POINT_16}, + }; }; } // namespace qnn } // namespace executor diff --git a/backends/qualcomm/scripts/build.sh b/backends/qualcomm/scripts/build.sh index c8379cf0b7a..b2c8e0d61ca 100755 --- a/backends/qualcomm/scripts/build.sh +++ b/backends/qualcomm/scripts/build.sh @@ -71,6 +71,7 @@ if [ "$BUILD_AARCH64" = true ]; then -DCMAKE_INSTALL_PREFIX=$BUILD_ROOT \ -DEXECUTORCH_BUILD_QNN=ON \ -DEXECUTORCH_BUILD_SDK=ON \ + -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ -DEXECUTORCH_ENABLE_EVENT_TRACER=ON \ -DQNN_SDK_ROOT=$QNN_SDK_ROOT \ -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ diff --git a/backends/qualcomm/serialization/qnn_compile_spec_schema.py b/backends/qualcomm/serialization/qnn_compile_spec_schema.py index 0f926fc0975..338f61997ea 100644 --- a/backends/qualcomm/serialization/qnn_compile_spec_schema.py +++ b/backends/qualcomm/serialization/qnn_compile_spec_schema.py @@ -132,3 +132,4 @@ class QnnExecuTorchOptions: 
tensor_dump_output_path: str = "" profile_level: QnnExecuTorchProfileLevel = QnnExecuTorchProfileLevel.kProfileOff shared_buffer: bool = False + is_from_context_binary: bool = False diff --git a/backends/qualcomm/serialization/schema.fbs b/backends/qualcomm/serialization/schema.fbs index 8c4d23172f0..4288c83b130 100644 --- a/backends/qualcomm/serialization/schema.fbs +++ b/backends/qualcomm/serialization/schema.fbs @@ -175,6 +175,9 @@ table QnnExecuTorchOptions { /// Enables usage of shared buffer between application and backend for graph I/O. shared_buffer:bool; + + /// Is model from qnn context binary + is_from_context_binary: bool; } root_type QnnExecuTorchOptions; diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index b1f06d6e871..35fcb6bfc63 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -29,7 +29,7 @@ def __init__(self): super().__init__() def forward(self, x): - return 10.0 + x + return 10 + x class Arange(torch.nn.Module): @@ -788,6 +788,20 @@ def forward(self, x, y): return x[:, :seq_length] + self.position_ids[:, :seq_length] +class SliceCopyWithStep(torch.nn.Module): + def __init__(self): + super().__init__() + self.position_ids = torch.randn([1, 512]) + self.step = 2 + + def forward(self, x, y): + seq_length = y.size()[1] + return ( + x[:, : seq_length : self.step] + + self.position_ids[:, : seq_length : self.step] + ) + + class Softmax(torch.nn.Module): def __init__(self): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index de4cfbdf049..98deb8e11fc 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -312,6 +312,8 @@ def test_qnn_backend_mha(self): sample_input = (torch.randn(1, 197, 96),) self.lower_module_and_test_output(module, sample_input) + # fp16 pad op might hit corner case in runtime + @unittest.expectedFailure def test_qnn_backend_pad(self): module = Pad() # noqa: F405 sample_input = (torch.randn([1, 8, 128]),) @@ -391,12 +393,13 @@ def test_qnn_backend_select_copy(self): self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_slice_copy(self): - module = SliceCopy() # noqa: F405 + modules = [SliceCopy(), SliceCopyWithStep()] # noqa: F405 sample_input = ( torch.randn([1, 512]), torch.randn([1, 8]), ) - self.lower_module_and_test_output(module, sample_input) + for module in modules: + self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_stack(self): module = Stack() # noqa: F405 @@ -423,7 +426,6 @@ def test_qnn_backend_tanh(self): sample_input = (torch.randn(2, 5, 1, 3),) self.lower_module_and_test_output(module, sample_input) - @unittest.expectedFailure def test_qnn_backend_unbind(self): module = Unbind() # noqa: F405 sample_input = (torch.randn([3, 3]),) @@ -979,13 +981,14 @@ def test_qnn_backend_sigmoid(self): self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_slice_copy(self): - module = SliceCopy() # noqa: F405 + modules = [SliceCopy(), SliceCopyWithStep()] # noqa: F405 sample_input = ( torch.randn([1, 512]), torch.randn([1, 8]), ) - module = self.get_qdq_module(module, sample_input) - self.lower_module_and_test_output(module, sample_input) + for module in modules: + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_softmax(self): module = Softmax() # noqa: F405 @@ -1020,7 +1023,6 @@ def 
test_qnn_backend_tanh(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) - @unittest.expectedFailure def test_qnn_backend_unbind(self): module = Unbind() # noqa: F405 sample_input = (torch.randn([3, 3]),) @@ -1944,16 +1946,13 @@ def test_deeplab_v3(self): self.assertGreaterEqual(msg["MPA"], 0.70) self.assertGreaterEqual(msg["MIoU"], 0.55) - def test_dummy_llama2(self): - self.skipTest( - "The module of llama is changing frequently. Reopen it when it's stable" - ) + def test_stories_single_llama(self): if not self.required_envs(): self.skipTest("missing required envs") cmds = [ "python", - f"{self.executorch_root}/examples/qualcomm/scripts/dummy_llama2.py", + f"{self.executorch_root}/examples/qualcomm/llama2/llama.py", "--artifact", self.artifact_dir, "--build_folder", @@ -1962,59 +1961,36 @@ def test_dummy_llama2(self): self.device, "--model", self.model, + "--checkpoint", + f"{self.artifact_dir}/stories110M.pt", + "--params", + f"{self.artifact_dir}/params.json", + "--tokenizer_model", + f"{self.artifact_dir}/tokenizer.model", + "--tokenizer_bin", + f"{self.artifact_dir}/tokenizer.bin", "--ip", self.ip, "--port", str(self.port), - "--use_fp16", + "--prompt", + "Once", + "--ptq", + "16a4w", + "--temperature", + "0", ] if self.host: cmds.extend(["--host", self.host]) - if self.shared_buffer: - cmds.extend(["--shared_buffer"]) - - p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) - with Listener((self.ip, self.port)) as listener: - conn = listener.accept() - p.communicate() - msg = json.loads(conn.recv()) - self.assertTrue(msg["is_close"]) - - @unittest.expectedFailure - def test_ptq_dummy_llama2(self): - self.skipTest( - "The module of llama is changing frequently. Reopen it when it's stable" - ) - if not self.required_envs(): - self.skipTest("missing required envs") - - cmds = [ - "python", - f"{self.executorch_root}/examples/qualcomm/scripts/dummy_llama2.py", - "--artifact", - self.artifact_dir, - "--build_folder", - self.build_folder, - "--device", - self.device, - "--model", - self.model, - "--ip", - self.ip, - "--port", - str(self.port), - ] - if self.host: - cmds.extend(["--host", self.host]) - if self.shared_buffer: - cmds.extend(["--shared_buffer"]) + golden_start_with = "Once upon a time," p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) with Listener((self.ip, self.port)) as listener: conn = listener.accept() p.communicate() msg = json.loads(conn.recv()) - self.assertTrue(all(msg["is_close"])) + model_out = msg["result"][0] + self.assertTrue(model_out.startswith(golden_start_with)) def test_mobilebert(self): if not self.required_envs([self.pretrained_weight]): diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py index 608aa0bced5..295033e572f 100644 --- a/backends/qualcomm/tests/utils.py +++ b/backends/qualcomm/tests/utils.py @@ -148,7 +148,7 @@ def validate_profile(): adb = SimpleADB( qnn_sdk=os.getenv("QNN_SDK_ROOT"), - artifact_path=self.build_folder, + build_path=self.build_folder, pte_path=pte_fname, workspace="/data/local/tmp/qnn_executorch_test", device_id=self.device, diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index ac55611ecfc..dde852135bf 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -30,7 +30,7 @@ from executorch.backends.qualcomm.passes.recompose_pixel_unshuffle import ( RecomposePixelUnshuffle, ) -from executorch.backends.qualcomm.passes.remove_clone import RemoveClone +from 
executorch.backends.qualcomm.passes.remove_redundancy import RemoveRedundancy from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import ( _soc_info_table, QcomChipset, @@ -67,6 +67,48 @@ def qnn_edge_config() -> exir.EdgeCompileConfig: ) +def convert_linear_to_conv2d(module: torch.nn.Module): + class Conv2D(torch.nn.Module): + def __init__(self, weight, bias=None): + super().__init__() + use_bias = bias is not None + self.conv = torch.nn.Conv2d( + in_channels=weight.shape[0], + out_channels=weight.shape[1], + kernel_size=1, + padding=0, + bias=use_bias, + ) + self.conv.weight = torch.nn.Parameter(weight.reshape(*weight.shape, 1, 1)) + if use_bias: + self.conv.bias = torch.nn.Parameter(bias) + + def forward(self, x): + rank = x.dim() + x = x.unsqueeze(-1) if rank == 3 else x.reshape(1, *x.shape, 1) + x = torch.transpose(x, 1, 2) + res = self.conv(x) + res = torch.transpose(res, 1, 2) + res = res.squeeze(-1) if rank == 3 else res.reshape(*res.shape[1:3]) + return res + + def replace_linear(module: torch.nn.Module): + attr_strs = dir(module) + if isinstance(module, torch.nn.ModuleList): + attr_strs += [str(i) for i in range(len(module))] + + for attr_str in attr_strs: + target_attr = getattr(module, attr_str) + if isinstance(target_attr, torch.nn.Linear): + setattr(module, attr_str, Conv2D(target_attr.weight, target_attr.bias)) + + for _, sub_module in module.named_children(): + sub_module = replace_linear(sub_module) + return module + + return replace_linear(module) + + def canonicalize_program(prog: ExportedProgram): # check if user specifies to use multi_contexts # this is a generic approach in case there exists multiple backends @@ -109,7 +151,7 @@ def _transform(edge_program: ExportedProgram) -> None: # changes of input number which was caused by FoldQDQ # apply passes one by one here to avoid IR capture failure graph_module = edge_program.graph_module - RemoveClone()(graph_module) + RemoveRedundancy()(graph_module) RecomposePixelUnshuffle()(graph_module) ConvertToLinear()(graph_module) ConvertPReLU(edge_program)(graph_module) @@ -138,6 +180,7 @@ def capture_program( core_ep.transform(ConvertBinaryOpsWithScalar()) edge_ep = core_ep.to_edge(qnn_edge_config()) _transform(edge_ep.exported_program) + return edge_ep diff --git a/examples/qualcomm/CMakeLists.txt b/examples/qualcomm/CMakeLists.txt index c9a1c111193..7234632a122 100644 --- a/examples/qualcomm/CMakeLists.txt +++ b/examples/qualcomm/CMakeLists.txt @@ -4,7 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +set(CMAKE_CXX_STANDARD 17) # qnn_executor_runner: Like executor_runner but with QNN + if(NOT ${ANDROID}) message(FATAL_ERROR "Not building Android, quitting...") endif() @@ -51,6 +53,7 @@ get_filename_component( EXECUTORCH_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}/../.." 
ABSOLUTE ) set(_qnn_executor_runner__srcs ${_executor_runner__srcs}) +set(_qnn_llama_runner__srcs ${_llama_runner__srcs}) # portable_ops_lib gen_selected_ops(LIB_NAME "full_portable_ops_lib" INCLUDE_ALL_OPS "ON") @@ -68,19 +71,73 @@ target_include_directories( full_portable_ops_lib PUBLIC ${_common_include_directories} ) -list(TRANSFORM _qnn_executor_runner__srcs PREPEND "${EXECUTORCH_SOURCE_DIR}/") -list(FILTER _qnn_executor_runner__srcs EXCLUDE REGEX ".*executor_runner.cpp$") -list(PREPEND _qnn_executor_runner__srcs - ${CMAKE_CURRENT_LIST_DIR}/executor_runner/qnn_executor_runner.cpp +# prerpocess executor runner src files +list( + TRANSFORM + _qnn_executor_runner__srcs + PREPEND + "${EXECUTORCH_SOURCE_DIR}/" +) +list( + FILTER + _qnn_executor_runner__srcs + EXCLUDE REGEX + ".*executor_runner.cpp$" +) +list( + PREPEND + _qnn_executor_runner__srcs + ${CMAKE_CURRENT_LIST_DIR}/executor_runner/qnn_executor_runner.cpp ) -add_executable(qnn_executor_runner ${_qnn_executor_runner__srcs}) +# preprocess llama runner src files +list( + TRANSFORM + _qnn_llama_runner__srcs + PREPEND + "${EXECUTORCH_SOURCE_DIR}/" +) +list( + FILTER + _qnn_llama_runner__srcs + EXCLUDE REGEX + ".*runner.cpp$" +) +list( + PREPEND + _qnn_llama_runner__srcs + ${CMAKE_CURRENT_LIST_DIR}/executor_runner/qnn_llama_runner.cpp + ${CMAKE_CURRENT_LIST_DIR}/llama2/runner/runner.cpp + ${CMAKE_CURRENT_LIST_DIR}/llama2/runner/runner.h +) -target_include_directories( - qnn_executor_runner PUBLIC ${_common_include_directories} +# build executor runner +add_executable(qnn_executor_runner ${_qnn_executor_runner__srcs}) +target_include_directories(qnn_executor_runner + PUBLIC + ${_common_include_directories} ) target_link_libraries( qnn_executor_runner qnn_executorch_backend full_portable_ops_lib etdump ${FLATCCRT_LIB} gflags ) -target_compile_options(qnn_executor_runner PUBLIC ${_common_compile_options}) +target_link_options( + qnn_executor_runner PUBLIC -fsanitize=undefined) + +# build llama runner +add_executable(qnn_llama_runner ${_qnn_llama_runner__srcs}) +target_include_directories(qnn_llama_runner + PUBLIC + ${_common_include_directories} +) +target_link_libraries(qnn_llama_runner + qnn_executorch_backend + full_portable_ops_lib + extension_data_loader + extension_module + gflags +) +target_compile_options(qnn_llama_runner + PUBLIC + ${_common_compile_options} +) diff --git a/examples/qualcomm/executor_runner/qnn_executor_runner.cpp b/examples/qualcomm/executor_runner/qnn_executor_runner.cpp index 8998ee634e0..e5ced476ac9 100644 --- a/examples/qualcomm/executor_runner/qnn_executor_runner.cpp +++ b/examples/qualcomm/executor_runner/qnn_executor_runner.cpp @@ -13,7 +13,7 @@ * This tool can run ExecuTorch model files with Qualcomm AI Engine Direct * and the portable kernels. * - * User could specify arguments like desired input data, iterfations, etc. + * User could specify arguments like desired input data, iterations, etc. * Currently we assume that the outputs are all fp32 tensors. */ diff --git a/examples/qualcomm/executor_runner/qnn_llama_runner.cpp b/examples/qualcomm/executor_runner/qnn_llama_runner.cpp new file mode 100644 index 00000000000..0d654e68363 --- /dev/null +++ b/examples/qualcomm/executor_runner/qnn_llama_runner.cpp @@ -0,0 +1,95 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +/** + * @file + * + * This tool can run ExecuTorch model files with Qualcomm AI Engine Direct + * and the portable kernels. + * + * Users can specify arguments such as the prompt, sequence length, and temperature. + */ + +#include +#include +#include +#include + +#include + +#include +#include + +using torch::executor::MemoryAllocator; + +DEFINE_string( + model_path, + "qnn_llama2.pte", + "Model serialized in flatbuffer format."); + +DEFINE_string( + output_folder_path, + "outputs", + "Executorch inference data output path."); + +DEFINE_string(tokenizer_path, "tokenizer.bin", "Path to the tokenizer binary."); + +DEFINE_string(prompt, "The answer to the ultimate question is", "Prompt."); + +DEFINE_double( + temperature, + 0.8f, + "Temperature; Default is 0.8f. 0 = greedy argmax sampling (deterministic). Lower temperature = more deterministic"); + +DEFINE_int32( + seq_len, + 128, + "Total number of tokens to generate (prompt + output). Defaults to max_seq_len. If the number of input tokens + seq_len > max_seq_len, the output will be truncated to max_seq_len tokens."); + +int main(int argc, char** argv) { + using namespace torch::executor; + + gflags::ParseCommandLineFlags(&argc, &argv, true); + + const char* tokenizer_path = FLAGS_tokenizer_path.c_str(); + const char* prompt = FLAGS_prompt.c_str(); + double temperature = FLAGS_temperature; + int32_t seq_len = FLAGS_seq_len; + + // create llama runner + Runner runner(FLAGS_model_path, tokenizer_path, temperature); + ET_CHECK_MSG(runner.load() == Error::Ok, "Runner failed to load method"); + + // MethodMeta describes the memory requirements of the method. + Result method_meta = runner.get_method_meta(); + ET_CHECK_MSG( + method_meta.ok(), + "Failed to get method_meta 0x%x", + (unsigned int)method_meta.error()); + ET_CHECK_MSG( + runner.mem_alloc(MemoryAllocator::kDefaultAlignment, seq_len) == + Error::Ok, + "Runner failed to allocate memory"); + + // generate tokens + std::string inference_output; + // the prompt is determined by command line arguments + // pos_ids and atten_mask are inferred inside the runner + runner.generate(prompt, seq_len, [&](const std::string& piece) { + inference_output += piece; + }); + + size_t inference_index = 0; + auto output_file_name = FLAGS_output_folder_path + "/output_" + + std::to_string(inference_index++) + "_0.raw"; + std::ofstream fout(output_file_name.c_str()); + fout << inference_output; + fout.close(); + + return 0; +} diff --git a/examples/qualcomm/llama2/README.md b/examples/qualcomm/llama2/README.md new file mode 100644 index 00000000000..5f463a8aaa9 --- /dev/null +++ b/examples/qualcomm/llama2/README.md @@ -0,0 +1,40 @@ +# Summary + +## Overview +This file provides instructions to run LLAMA2 with different parameters via the Qualcomm HTP backend. The following settings are supported: +1. Stories 110M + +Please check the corresponding section for more information. + +## Stories 110M +This example demonstrates how to run a smaller LLAMA2 model, stories110M, on mobile via the Qualcomm HTP backend. The model architecture is fine-tuned specifically for HTP to accelerate performance. Weights are quantized via PTQ to fit the model on a phone. + +## Instructions +### Step 1: Set up +1. Follow the [tutorial](https://pytorch.org/executorch/main/getting-started-setup) to set up ExecuTorch. +2. Follow the [tutorial](https://pytorch.org/executorch/stable/build-run-qualcomm-ai-engine-direct-backend.html) to build Qualcomm AI Engine Direct Backend.
+ +### Step 2: Prepare Model +Download and prepare the stories110M model. + +```bash +# tokenizer.model & stories110M.pt: +wget "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.pt" +wget "https://raw.githubusercontent.com/karpathy/llama2.c/master/tokenizer.model" + +# tokenizer.bin: +python -m examples.models.llama2.tokenizer.tokenizer -t tokenizer.model -o tokenizer.bin + +# params.json: +echo '{"dim": 768, "multiple_of": 32, "n_heads": 12, "n_layers": 12, "norm_eps": 1e-05, "vocab_size": 32000}' > params.json +``` + +### Step 3: Run default examples +The default example generates a story based on the given prompt, "Once". +```bash +# 16a4w quant: +python examples/qualcomm/llama2/llama.py -a ${ARTIFACTS} -b build_android -s ${SERIAL_NUM} -H ${HOST_NAME} -m ${SOC_MODEL} --ptq 16a4w --checkpoint stories110M --params params.json --tokenizer_model tokenizer.model --tokenizer_bin tokenizer.bin --prompt "Once" +``` + +### (Note) Customized PTQ data set +User prompts are used as PTQ calibration data. In the example above, the word "Once" is the only word used for calibration. If you want to cover more data during calibration, add more prompts to the `--prompt` argument. diff --git a/examples/qualcomm/llama2/llama.py b/examples/qualcomm/llama2/llama.py new file mode 100644 index 00000000000..a1a939cb60e --- /dev/null +++ b/examples/qualcomm/llama2/llama.py @@ -0,0 +1,603 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import codecs +import getpass +import json +import os +import shutil +import stat +import time +from multiprocessing.connection import Client + +import torch + +from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner +from executorch.backends.qualcomm.passes.build_quant_io import BuildQuantIo +from executorch.backends.qualcomm.passes.utils import q_io_key + +from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype +from executorch.backends.qualcomm.quantizer.utils import get_16a4w_qnn_ptq_config +from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import ( + QcomChipset, +) +from executorch.backends.qualcomm.utils.utils import ( + capture_program, + convert_linear_to_conv2d, + generate_htp_compiler_spec, + generate_qnn_executorch_compiler_spec, +) +from executorch.examples.models.llama2.builder import DType +from executorch.examples.qualcomm.llama2.model.static_llama import LlamaModel, ModelArgs +from executorch.examples.qualcomm.scripts.utils import ( + make_output_dir, + setup_common_args_and_variables, + SimpleADB, +) +from executorch.exir import EdgeCompileConfig, EdgeProgramManager +from executorch.exir.capture._config import ExecutorchBackendConfig +from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass +from executorch.exir.program._program import _get_updated_graph_signature + +from sentencepiece import SentencePieceProcessor +from torch.ao.quantization.observer import MinMaxObserver +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e + + +soc_to_chipset_map = { + "SM8650": QcomChipset.SM8650, + "SM8550": QcomChipset.SM8550, + "SM8475": QcomChipset.SM8475, + "SM8450": QcomChipset.SM8450, +} + + +pte_filename = "llama2_qnn" + + +def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None: + """ + This function is specific for matmul op 16a8w.
+ """ + from typing import Sequence + + from executorch.backends.qualcomm.quantizer.quantizer import ( + get_16a8w_qnn_ptq_config, + get_default_8bit_qnn_ptq_config, + QuantizationConfig, + ) + from executorch.backends.qualcomm.quantizer.utils import QUANT_ANNOTATION_KEY + from torch.ao.quantization.quantizer import ( + QuantizationAnnotation, + SharedQuantizationSpec, + ) + from torch.fx import Node + + def annotate_matmul(node: Node, quantization_config: QuantizationConfig): + input_qspec_map = {} + input_act = node.args[0] + input_spec = quantization_config.input_activation + input_qspec_map[input_act] = input_spec + + input_act1 = node.args[1] + input_spec1 = quantization_config.weight + input_qspec_map[input_act1] = input_spec1 + + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=quantization_config.output_activation, + _annotated=True, + ) + + def annotate_cat(node: Node, quantization_config: QuantizationConfig): + input_nodes = node.args[0] + + first_input_node = input_nodes[0] + input_qspec_map = {} + input_qspec_map[first_input_node] = quantization_config.input_activation + share_qparams_with_input_act0_qspec = SharedQuantizationSpec( + (first_input_node, node) + ) + + for input_node in input_nodes[1:]: + if input_node not in input_qspec_map: + input_qspec_map[input_node] = share_qparams_with_input_act0_qspec + + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=share_qparams_with_input_act0_qspec, + _annotated=True, + ) + + def annotate_single_in_single_out( + node: Node, quantization_config: QuantizationConfig + ) -> None: + + input_qspec_map = {} + input_act = node.args[0] + input_qspec_map[input_act] = quantization_config.input_activation + + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + output_qspec=quantization_config.output_activation, + _annotated=True, + ) + + def annotate_matmul_input1(node: Node): + quantization_config_8a8w = get_default_8bit_qnn_ptq_config(act_symmetric=True) + while isinstance(node, Node) and node.op == "call_function": + if node.target in [ + torch.ops.aten.permute.default, + torch.ops.aten.transpose.int, + ]: + annotate_single_in_single_out(node, quantization_config_8a8w) + node = node.args[0] + elif node.target == torch.ops.aten.cat.default: + annotate_cat(node, quantization_config_8a8w) + node = node.args[0][0] + else: + node = node.args[0] + + quantization_config_16a8w = get_16a8w_qnn_ptq_config() + + for node in gm.graph.nodes: + if node.op == "call_function" and node.target == torch.ops.aten.matmul.default: + annotate_matmul(node, quantization_config_16a8w) + annotate_matmul_input1(node.args[1]) + + +def annotate_linear_16a8w_in_affine_layer(gm: torch.fx.GraphModule) -> None: + from executorch.backends.qualcomm.quantizer.quantizer import ( + get_ptq_per_channel_weight_config, + QuantizationConfig, + ) + from executorch.backends.qualcomm.quantizer.utils import QUANT_ANNOTATION_KEY + from torch.ao.quantization.quantizer import QuantizationAnnotation + from torch.fx import Node + + def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None: + input_qspec_map = {} + input_act = node.args[0] + input_spec = quantization_config.input_activation + input_qspec_map[input_act] = input_spec + + weight = node.args[1] + input_qspec_map[weight] = quantization_config.weight + + node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map=input_qspec_map, + 
output_qspec=quantization_config.output_activation, + _annotated=True, + ) + + quantization_config_16a8w_per_channel = get_ptq_per_channel_weight_config( + torch.uint16, weight_dtype=torch.int8 + ) + for node in gm.graph.nodes: + if node.op == "call_function" and node.target == torch.ops.aten.conv2d.default: + if "nn_module_stack" in node.meta: + module_values_list = list(node.meta["nn_module_stack"].values()) + full_qualified_name = module_values_list[0][0] + if full_qualified_name == "L['self'].llama.output": + annotate_conv2d( + node, quantization_config=quantization_config_16a8w_per_channel + ) + + +def calibrate( + example_inputs, + user_prompts, + module: torch.fx.GraphModule, + tokenizer_model_path="tokenizer.model", +): + sp_model = SentencePieceProcessor(model_file=tokenizer_model_path) + _, _, atten_mask, k_caches, v_caches = example_inputs + + # TODO: change criteria & support batch inputs if necessary + pos = torch.tensor(0, dtype=torch.int32) + token_list = [sp_model.bos_id()] + for prompt in user_prompts.split(): + token_list += sp_model.encode(prompt) + + def sample_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor: + probs_sort, probs_indices = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > top_p + probs_sort[mask] = 0 + probs_sort /= probs_sort.sum(dim=-1, keepdim=True) + next_token = torch.multinomial(probs_sort, num_samples=1) + return probs_indices.gather(dim=-1, index=next_token) + + with torch.no_grad(): + while token_list[-1] != sp_model.eos_id() and pos < 128: + logits, new_k_caches, new_v_caches = module( + torch.full((1, 1), token_list[pos]), + torch.full((1, 1), pos), + atten_mask, + *k_caches, + *v_caches, + ) + k_caches = [ + torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1) + for i, k_cache in enumerate(k_caches) + ] + v_caches = [ + torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1) + for i, v_cache in enumerate(v_caches) + ] + + pos += 1 + atten_mask[0][-pos - 1] = 0 + if pos >= len(token_list): + probs = torch.softmax(logits[:, -1] / 0.8, dim=-1) + token_list.append(sample_top_p(probs, 0.9).item()) + + print(f"calibration data:\n{sp_model.decode(token_list)}") + + +class SingleLlama: + def __init__(self, llama_model) -> None: + super().__init__() + self.llama_model = llama_model + self.quant_dtype = None + self.llama_meta = self.llama_model.get_metadata() + self.has_quant_io = False + tokens, pos_ids, atten_mask, k_caches, v_caches = self.get_example_inputs() + self.inputs = (tokens, pos_ids, atten_mask, *k_caches, *v_caches) + + def _tag_kv_ios(self, gm: torch.fx.GraphModule, kv_type): + if not self.has_quant_io: + return + + # shape of k caches and v caches + input_cache_shape = { + (self.llama_meta["get_head_dim"], self.llama_meta["get_max_seq_len"]), + (self.llama_meta["get_max_seq_len"], self.llama_meta["get_head_dim"]), + } + for n in gm.graph.nodes: + if ( + n.op == "placeholder" + and len(users := list(n.users)) == 1 + and users[0].meta["val"].size()[-2:] in input_cache_shape + ): + n.meta[q_io_key] = kv_type + elif n.op == "output": + for a in n.args[0]: + if ( + a.meta["val"].flatten().size()[0] + == self.llama_meta["get_head_dim"] + ): + a.meta[q_io_key] = kv_type + + def quantize(self, quant_dtype, custom_annotations=()): + self.quant_dtype = quant_dtype + quantizer = QnnQuantizer() + quantizer.set_per_channel_linear_quant(True) + quantizer.set_per_channel_conv_quant(True) + + if quant_dtype == QuantDtype.use_8a8w: + pass # default setting + elif 
quant_dtype == QuantDtype.use_16a4w: + quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS) + quantizer.set_bit16_op_quant_config( + get_16a4w_qnn_ptq_config(act_observer=MinMaxObserver) + ) + quantizer.set_per_channel_weight_dtype(weight_dtype_for_16bit_act="int4") + else: + raise AssertionError(f"No support for QuantDtype {quant_dtype}.") + quantizer.add_custom_quant_annotations(custom_annotations) + + self.has_quant_io = True + fx_graph_module = None + + with torch.no_grad(): + fx_graph_module = torch._export.capture_pre_autograd_graph( + self.llama_model, self.inputs + ) + fx_graph_module = prepare_pt2e(fx_graph_module, quantizer) + print("Quantizing the model...") + calibrate( + self.get_example_inputs(), + args.prompt, + fx_graph_module, + tokenizer_model_path=args.tokenizer_model, + ) + + self.llama_model = convert_pt2e(fx_graph_module) + + def lowering_modules( + self, work_space, kv_type=torch.uint8, soc_model=QcomChipset.SM8650 + ): + executorch_config = ExecutorchBackendConfig( + passes=[ + BuildQuantIo(), + ], + extract_constant_segment=False, + # For shared buffer, user must pass the memory address + # which is allocated by RPC memory to executor runner. + # Therefore, won't want to pre-allocate + # by memory manager in runtime. + memory_planning_pass=MemoryPlanningPass( + memory_planning_algo="greedy", + alloc_graph_input=False, + alloc_graph_output=False, + ), + extract_delegate_segments=True, + ) + with torch.no_grad(): + # backend option + backend_options = generate_htp_compiler_spec(use_fp16=False) + compiler_specs = generate_qnn_executorch_compiler_spec( + soc_model=soc_model, + backend_options=backend_options, + shared_buffer=True, + ) + partitioner = QnnPartitioner(compiler_specs) + edge_prog = capture_program(self.llama_model, self.inputs) + self._tag_kv_ios(edge_prog.exported_program.graph_module, kv_type=kv_type) + edge_prog_mgr = EdgeProgramManager( + edge_programs={"forward": edge_prog.exported_program}, + constant_methods=self.llama_meta, + compile_config=EdgeCompileConfig(_check_ir_validity=False), + ) + + setattr( + edge_prog_mgr.exported_program(), + "_graph_signature", + _get_updated_graph_signature( + edge_prog_mgr.exported_program().graph_signature, + edge_prog_mgr.exported_program().graph_module, + ), + ) + + edge_prog_mgr = edge_prog_mgr.to_backend(partitioner) + exec_prog_mgr = edge_prog_mgr.to_executorch(config=executorch_config) + with open(f"{work_space}/{pte_filename}.pte", "wb") as file: + exec_prog_mgr.write_to_file(file) + + def get_example_inputs(self): + return self.llama_model.get_example_inputs() + + def get_export_inputs(self): + return self.llama_model.get_export_inputs() + + +def compile(args): + os.makedirs(args.artifact, exist_ok=True) + start_ts = time.time() + with open(args.params) as f: + config = ModelArgs(**json.load(f)) + # TODO: support batch inputs if necessary + config.max_batch_size = 1 + config.max_seq_len = 1024 + state_dict = torch.load( + args.checkpoint, weights_only=True, map_location="cpu", mmap=True + ) + end_load_ts = time.time() + print("torch.load checkpoint", end_load_ts - start_ts) + llama_instance = None + with torch.device("meta"): + llama_instance = LlamaModel(config, output_new_cache_only=True) + if "model" in state_dict: + state_dict = state_dict["model"] + llama_instance.load_state_dict( + state_dict, + strict=False, + assign=True, + ) + end_load_state_dict_ts = time.time() + print("instance.load_state_dict", end_load_state_dict_ts - end_load_ts) + + for layer in llama_instance.layers: + if 
getattr(layer.attention, "prepare_sha", None): + layer.attention.prepare_sha() + kv_type = torch.uint8 + if args.ptq == "8a8w": + quant_dtype = QuantDtype.use_8a8w + elif args.ptq == "16a4w": + quant_dtype = QuantDtype.use_16a4w + else: + raise AssertionError( + f"No support for quant type {args.ptq}. Support 8a8w and 16a4w." + ) + + assert args.tokenizer_model is not None, "Need tokenizer model for calibration" + + if args.dtype_override is not None: + dtype_override = DType[args.dtype_override] + llama_instance = llama_instance.to(dtype_override.to_torch_dtype()) + + llama_instance = convert_linear_to_conv2d(llama_instance) + single_llama = SingleLlama(llama_instance.eval()) + + start_quantize_ts = time.time() + single_llama.quantize( + quant_dtype, + custom_annotations=( + annotate_matmul_16a8w, + annotate_linear_16a8w_in_affine_layer, + ), + ) + end_quantize_ts = time.time() + print("single_llama.quantize(quant_dtype)", end_quantize_ts - start_quantize_ts) + single_llama.lowering_modules( + args.artifact, kv_type=kv_type, soc_model=soc_to_chipset_map[args.model] + ) + end_lowering_ts = time.time() + print("Complete Compile", end_lowering_ts - end_quantize_ts) + + +def inference(args, pre_gen_pte=""): + workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/single_llama" + + runner_args = " ".join( + [ + f"--model_path {pte_filename}.pte", + "--output_folder_path outputs", + f"--tokenizer_path {os.path.basename(args.tokenizer_bin)}", + f'--prompt "{args.prompt}"', + f"--seq_len {args.seq_len}", + f"--temperature {args.temperature}", + ] + ) + runner_cmd = " ".join( + [ + f"cd {workspace} &&", + "export ADSP_LIBRARY_PATH=. &&", + "export LD_LIBRARY_PATH=. &&", + f"./qnn_llama_runner {runner_args}", + ] + ) + + pte_path = ( + f"{pre_gen_pte}/{pte_filename}.pte" + if pre_gen_pte + else f"{args.artifact}/{pte_filename}.pte" + ) + adb = SimpleADB( + qnn_sdk=os.getenv("QNN_SDK_ROOT"), + build_path=f"{args.build_folder}", + pte_path=pte_path, + workspace=workspace, + device_id=args.device, + host_id=args.host, + soc_model=args.model, + shared_buffer=args.shared_buffer, + runner="examples/qualcomm/qnn_llama_runner", + ) + # No pregen inputs, input_list is not required + adb.push(inputs=[], input_list="", files=[args.tokenizer_bin]) + adb.execute(custom_runner_cmd=runner_cmd) + + # collect output data + output_data_folder = f"{args.artifact}/outputs" + make_output_dir(output_data_folder) + outputs = [] + + def post_process(): + for f in sorted( + os.listdir(output_data_folder), key=lambda f: int(f.split("_")[1]) + ): + with codecs.open( + os.path.join(output_data_folder, f), + "r", + encoding="utf-8", + errors="replace", + ) as fdata: + outputs.append(fdata.read()) + + adb.pull(output_path=args.artifact, callback=post_process) + + if args.ip and args.port != -1: + with Client((args.ip, args.port)) as conn: + conn.send( + json.dumps( + { + "result": outputs, + } + ) + ) + else: + for idx, output in enumerate(outputs): + print(f"Results[{idx}]:\n{output}") + + +# flake8: noqa: C901 +if __name__ == "__main__": + parser = setup_common_args_and_variables() + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts and output by this example. Default ./llama2_qnn", + default="./llama2_qnn", + type=str, + ) + + parser.add_argument( + "-P", + "--ptq", + help="If specified, will do PTQ quantization. default is 16bits activation and 4bits weight. 
Support 8a8w and 16a4w.", + required=True, + default="16a4w", + ) + + parser.add_argument( + "--checkpoint", + help="Pass llama2 checkpoint.", + required=True, + type=str, + ) + + parser.add_argument( + "--params", + help="Pass llama2 params json file.", + required=True, + type=str, + ) + + parser.add_argument( + "--tokenizer_bin", + help="Pass llama2 tokenizer binary.", + required=True, + type=str, + ) + + parser.add_argument( + "--tokenizer_model", + help="Pass llama2 tokenizer model.", + type=str, + default=None, + ) + + parser.add_argument( + "--prompt", + help="User prompts for llama2.", + required=True, + type=str, + ) + + parser.add_argument( + "--seq_len", + help="Ouput sequence length for llama2.", + default=128, + type=int, + ) + + parser.add_argument( + "--temperature", + help="Sampling temperature for llama2.", + default=0.8, + type=float, + ) + + parser.add_argument( + "-d", + "--dtype-override", + default="fp32", + type=str, + choices=["fp32", "fp16"], + help="Override the dtype of the model (default is the checkpoint dtype). Options: fp32", + ) + + parser.add_argument( + "--pre_gen_pte", + help="Run the Pre-generated llama2 in the given directory", + type=str, + ) + + args = parser.parse_args() + if args.compile_only and args.pre_gen_pte: + exit("Cannot set both compile_only and pre_gen_pte as true") + + if args.pre_gen_pte: + inference(args, args.pre_gen_pte) + exit(f"Finish the running pre_gen_pte from {args.pre_gen_pte}") + + compile(args) + if args.compile_only: + exit(f"Finish compile_only and save to {args.artifact}") + + inference(args) diff --git a/examples/qualcomm/llama2/model/static_llama.py b/examples/qualcomm/llama2/model/static_llama.py new file mode 100644 index 00000000000..85f018e71f5 --- /dev/null +++ b/examples/qualcomm/llama2/model/static_llama.py @@ -0,0 +1,350 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import List, Tuple + +import torch +import torch.nn as nn + +from executorch.examples.models.llama2.llama_transformer import ( + FeedForward, + ModelArgs, + precompute_freqs_cis, + RMSNorm, +) + + +def apply_rotary_emb_single( + x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor +) -> torch.Tensor: + x_r, x_i = x[..., ::2], x[..., 1::2] + + x_out_r = x_r * freqs_cos - x_i * freqs_sin + x_out_i = x_r * freqs_sin + x_i * freqs_cos + + x_out = torch.cat([x_out_r, x_out_i], dim=-1) + return x_out + + +class LlamaAttention(nn.Module): + def __init__(self, config: ModelArgs, output_new_cache_only=False): + super().__init__() + self.dim = config.dim + self.n_heads = config.n_heads + self.head_dim = config.dim // config.n_heads + self.n_kv_heads = config.n_kv_heads + self.num_key_value_groups = config.n_heads // self.n_kv_heads + self.max_seq_len = config.max_seq_len + self.output_new_cache_only = output_new_cache_only + + self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False) + self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False) + + self.attn_softmax = torch.nn.Softmax(dim=-1) + + self.scale = float(self.head_dim) ** 0.5 + + def prepare_sha(self): + self.wq_sha = nn.ModuleList( + [ + nn.Linear(self.dim, self.head_dim, bias=False) + for _ in range(self.n_heads) + ] + ) + self.wk_sha = nn.ModuleList( + [ + nn.Linear(self.dim, self.head_dim, bias=False) + for _ in range(self.n_heads) + ] + ) + self.wv_sha = nn.ModuleList( + [ + nn.Linear(self.dim, self.head_dim, bias=False) + for _ in range(self.n_heads) + ] + ) + + self.forward_mha = self.forward + self.forward = self.forward_sha + + for i in range(self.n_heads): + self.wq_sha[i].weight.data.copy_( + self.wq.weight[i * self.head_dim : (i + 1) * self.head_dim] + ) + self.wk_sha[i].weight.data.copy_( + self.wk.weight[i * self.head_dim : (i + 1) * self.head_dim] + ) + self.wv_sha[i].weight.data.copy_( + self.wv.weight[i * self.head_dim : (i + 1) * self.head_dim] + ) + + def forward_sha( + self, + hidden_states: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + atten_mask: torch.Tensor, + k_caches: List[torch.Tensor], + v_caches: List[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q = [wq_sha(hidden_states) for wq_sha in self.wq_sha] + k = [wk_sha(hidden_states) for wk_sha in self.wk_sha] + v = [wv_sha(hidden_states) for wv_sha in self.wv_sha] + for i in range(len(q)): + q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin) + k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin).permute(0, 2, 1) + + output_kh, output_vh, output_y = [], [], [] + for i, _ in enumerate(k_caches): + # cat at the seq dim + kh = torch.cat([k_caches[i], k[i]], dim=-1) + vh = torch.cat([v_caches[i], v[i]], dim=1) + + attn = q[i] @ kh + attn = attn / self.scale + atten_mask + attn = self.attn_softmax(attn) + y = attn @ vh + + if self.output_new_cache_only: + output_kh.append(k[i]) + output_vh.append(v[i]) + else: + output_kh.append(kh) + output_vh.append(vh) + output_y.append(y) + + y = torch.concat(output_y, dim=-1) + y = self.wo(y) + return y, output_kh, output_vh + + def forward( + self, + hidden_states: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + atten_mask: torch.Tensor, + k_caches: List[torch.Tensor], + v_caches: List[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, 
torch.Tensor]: + bsz, seqlen, _ = hidden_states.shape + + q, k, v = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states) + q = q.view(bsz, seqlen, self.n_heads, self.head_dim) + k = k.view(bsz, seqlen, self.n_kv_heads, self.head_dim) + v = v.view(bsz, seqlen, self.n_kv_heads, self.head_dim) + + q = apply_rotary_emb_single(q, freqs_cos, freqs_sin) + k = apply_rotary_emb_single(k, freqs_cos, freqs_sin).permute(0, 2, 3, 1) + + output_kh, output_vh, output_y = [], [], [] + + for i, _ in enumerate(k_caches): + # cat at the seq dim + kh = torch.cat([k_caches[i], k[:, i, :, :]], dim=-1) + vh = torch.cat([v_caches[i], v[:, :, i, :]], dim=1) + + attn = q[:, :, i, :] @ kh + attn = attn / self.scale + atten_mask + attn = self.attn_softmax(attn) + y = attn @ vh + + if self.output_new_cache_only: + output_kh.append(k[:, i, :, :]) + output_vh.append(v[:, :, i, :]) + else: + output_kh.append(kh) + output_vh.append(vh) + output_y.append(y) + + y = torch.concat(output_y, dim=-1) + y = self.wo(y) + + return y, output_kh, output_vh + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: ModelArgs, output_new_cache_only=False): + super().__init__() + self.dim = config.dim + self.attention = LlamaAttention( + config=config, output_new_cache_only=output_new_cache_only + ) + self.feed_forward = FeedForward(config) + self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps) + self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps) + + def forward( + self, + x: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + atten_mask: torch.Tensor, + k_caches: List[torch.Tensor], + v_caches: List[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + h, k_cache, v_cache = self.attention( + hidden_states=self.attention_norm(x), + freqs_cos=freqs_cos, + freqs_sin=freqs_sin, + atten_mask=atten_mask, + k_caches=k_caches, + v_caches=v_caches, + ) + h = x + h + output = h + self.feed_forward(self.ffn_norm(h)) + return output, k_cache, v_cache + + +class LlamaModel(nn.Module): + def __init__(self, config: ModelArgs, output_new_cache_only=True): + super().__init__() + self.dim = config.dim + self.head_dim = config.dim // config.n_heads + self.max_batch_size = config.max_batch_size + self.max_seq_len = config.max_seq_len + self.n_heads = config.n_heads + self.n_kv_heads = config.n_kv_heads + self.n_layers = config.n_layers + self.vocab_size = config.vocab_size + self.rope_freq_base = config.rope_freq_base + self.output_new_cache_only = output_new_cache_only + + self.layers = nn.ModuleList( + [ + LlamaDecoderLayer(config, self.output_new_cache_only) + for _ in range(config.n_layers) + ] + ) + self.norm = RMSNorm(config.dim, eps=config.norm_eps) + self.output = nn.Linear(config.dim, config.vocab_size, bias=False) + self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) + freqs_cos, freqs_sin = precompute_freqs_cis( + config.dim // config.n_heads, + config.max_seq_len, + config.rope_freq_base, + ) + self.register_buffer("freqs_cos", freqs_cos, persistent=False) + self.register_buffer("freqs_sin", freqs_sin, persistent=False) + + def forward( + self, + tokens: torch.Tensor, + input_pos: torch.Tensor, + atten_mask: torch.Tensor, + *args, + ) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]: + output_k_cache = [] + output_v_cache = [] + # following tensors should be invariant across batches + freqs_cos = self.freqs_cos[input_pos][0] + freqs_sin = self.freqs_sin[input_pos][0] + + hidden_states = self.tok_embeddings(tokens) + for ind, 
decoder_layer in enumerate(self.layers): + offset_k = ind * self.n_heads + offset_v = self.n_layers * self.n_heads + offset_k + k_caches = args[offset_k : offset_k + self.n_heads] + v_caches = args[offset_v : offset_v + self.n_heads] + hidden_states, k, v = decoder_layer( + hidden_states, + freqs_cos=freqs_cos, + freqs_sin=freqs_sin, + atten_mask=atten_mask, + k_caches=k_caches, + v_caches=v_caches, + ) + output_k_cache.extend(k) + output_v_cache.extend(v) + + hidden_states = self.norm(hidden_states) + logits = self.output(hidden_states) + + return logits, output_k_cache, output_v_cache + + def get_example_inputs(self): + tokens = torch.randint( + self.vocab_size, (self.max_batch_size, 1), dtype=torch.int32 + ) + pos_ids = torch.zeros((self.max_batch_size, 1), dtype=torch.int32) + k_cache, v_cache = [], [] + atten_mask = torch.full((self.max_batch_size, self.max_seq_len), -255.0) + atten_mask[:, -1] = 0 + for _ in range(self.n_layers): + for _ in range(self.n_heads): + # transpose first to decrease the runtime efforts + k_cache.append( + torch.zeros( + self.max_batch_size, + self.head_dim, + self.max_seq_len - 1, + ) + ) + v_cache.append( + torch.zeros( + self.max_batch_size, + self.max_seq_len - 1, + self.head_dim, + ) + ) + return ( + tokens, + pos_ids, + atten_mask, + k_cache, + v_cache, + ) + + def get_export_inputs(self): + tokens = torch.randint( + self.vocab_size, (self.max_batch_size, 1), dtype=torch.int32 + ) + pos_ids = torch.zeros((self.max_batch_size, 1), dtype=torch.int32) + # this is important for torch.export not to take it as dummy input + k_cache, v_cache = [], [] + atten_mask = torch.full((self.max_batch_size, self.max_seq_len), -255.0) + atten_mask[:, -1] = 0 + for _ in range(self.n_layers): + for _ in range(self.n_heads): + # transpose first to decrease the runtime efforts + k_cache.append( + torch.randn( + self.max_batch_size, + self.head_dim, + self.max_seq_len - 1, + ) + ) + v_cache.append( + torch.randn( + self.max_batch_size, + self.max_seq_len - 1, + self.head_dim, + ) + ) + return ( + tokens, + pos_ids, + atten_mask, + k_cache, + v_cache, + ) + + def get_metadata(self): + # TODO: modify this when enabling LLAMA 7B + return { + "get_bos_id": 1, + "get_eos_id": 2, + "get_dim": self.dim, + "get_head_dim": self.dim // self.n_heads, + "get_max_batch_size": self.max_batch_size, + "get_max_seq_len": self.max_seq_len, + "get_n_bos": 1, + "get_n_eos": 1, + "get_n_kv_heads": self.n_heads, + "get_n_layers": self.n_layers, + "get_vocab_size": self.vocab_size, + } diff --git a/examples/qualcomm/llama2/runner/runner.cpp b/examples/qualcomm/llama2/runner/runner.cpp new file mode 100644 index 00000000000..691a0bf7390 --- /dev/null +++ b/examples/qualcomm/llama2/runner/runner.cpp @@ -0,0 +1,667 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// A simple llama2 runner that includes preprocessing and post processing logic. +// The module takes in a string as input and emits a string as output. 
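Side note before the runner sources: the metadata dictionary returned by get_metadata() above is presumably what feeds constant_methods in lowering_modules() (self.llama_meta), so each key becomes a constant method baked into the .pte, which the C++ runner below reads back at load time via getMetadataHelper(). A minimal Python sketch of that wiring, under the assumption that llama_meta is built directly from get_metadata() (the actual construction of self.llama_meta is not shown in this diff):

    # Sketch only: expose the model's metadata as constant methods of the exported program.
    llama_meta = llama_instance.get_metadata()  # e.g. {"get_vocab_size": ..., "get_bos_id": 1, ...}
    edge_prog_mgr = EdgeProgramManager(
        edge_programs={"forward": edge_prog.exported_program},
        constant_methods=llama_meta,  # each key becomes a queryable method in the .pte
        compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )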
+ +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +namespace torch { +namespace executor { + +namespace { +static constexpr auto kTopp = 0.9f; +void printReport(const Runner::Stats& stats); +std::string statsToJsonString(const Runner::Stats& stats); +} // namespace + +Runner::Runner( + const std::string& model_path, + const std::string& tokenizer_path, + const float temperature) + : module_(std::make_unique( + model_path, + Module::MlockConfig::UseMlockIgnoreErrors)), + tokenizer_path_(tokenizer_path), + model_path_(model_path), + temperature_(temperature) { + ET_LOG( + Info, + "Creating LLaMa runner: model_path=%s, tokenizer_path=%s", + model_path.c_str(), + tokenizer_path.c_str()); +} + +bool Runner::is_loaded() const { + return module_->is_loaded() && tokenizer_ && sampler_; +} + +Error Runner::load() { + if (is_loaded()) { + return Error::Ok; + } + stats_.model_load_start_ms = util::time_in_ms(); + ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward")); + + // Read out metadata from the model + ET_LOG(Info, "Reading metadata from model"); + const auto method_names = module_->method_names(); + ET_CHECK_MSG(method_names.ok(), "Failed to read method names from model"); + model_methods_ = method_names.get(); + vocab_size_ = getMetadataHelper("get_vocab_size", 32000); + bos_id_ = getMetadataHelper("get_bos_id", 1); + eos_id_ = getMetadataHelper("get_eos_id", 2); + n_bos_ = getMetadataHelper("get_n_bos", 1); + n_eos_ = getMetadataHelper("get_n_eos", 1); + max_seq_len_ = getMetadataHelper("get_max_seq_len", 128); + head_dim_ = getMetadataHelper("get_head_dim", 32); + dim_ = getMetadataHelper("get_dim", 4096); + + // Load tokenizer + tokenizer_ = std::make_unique(); + tokenizer_->load(tokenizer_path_); + if (tokenizer_->bos_tok() != bos_id_) { + ET_LOG( + Error, + "Tokenizer's BOS id %lu does not match model's BOS id %d, will override tokenizer's BOS.", + tokenizer_->bos_tok(), + bos_id_); + } + if (tokenizer_->eos_tok() != eos_id_) { + ET_LOG( + Error, + "Tokenizer's EOS id %lu does not match model's EOS id %d, will override tokenizer's EOS.", + tokenizer_->eos_tok(), + eos_id_); + } + // Create sampler + sampler_ = std::make_unique( + vocab_size_, + temperature_, + kTopp, + static_cast(std::time(nullptr))); + stats_.model_load_end_ms = util::time_in_ms(); + + return Error::Ok; +} + +template +T Runner::getMetadataHelper(std::string method_name, T default_val) { + T res = default_val; + if (model_methods_.count(method_name)) { + Result> outputs = module_->execute(method_name); + if (outputs.ok()) { + std::vector outs = outputs.get(); + if (outs.size() > 0) { + res = outs[0].to(); + } + } + } else { + ET_LOG( + Info, + "The model does not contain %s method, using default value %lld", + method_name.c_str(), + (long long)default_val); + } + ET_LOG(Info, "%s: %lld", method_name.c_str(), (long long)res); + return res; +} + +template +int32_t Runner::logitsToToken(const exec_aten::Tensor& logits_tensor) { + T* logits = logits_tensor.mutable_data_ptr(); + + // Since the logits are for all tokens, get the last token probabilities + T* logits_last = logits; + return sampler_->sample(logits_last); +} + +// Given an input token. Set up the inputs for the model and execute a single +// step. Returning the logits tensor. 
+Result Runner::run_model_step( + int64_t input_token, + Tensor& token, + Tensor& start_pos, + Tensor& atten_mask, + std::vector& kv_tensors, + std::vector& kv_outputs) { + token.mutable_data_ptr()[0] = input_token; + + // inputs:[tokens, start_pos, atten_mask, k_cache, v_cache] + std::vector inputs = {token, start_pos, atten_mask}; + inputs.insert(inputs.end(), kv_tensors.begin(), kv_tensors.end()); + Result> outputs_res = module_->forward(inputs); + ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error()); + + // TODO: need to handle batch size != 1 + size_t v_offset = kv_outputs[0].nbytes(); + size_t el_size = kv_outputs[0].element_size(); + size_t k_input_step = (max_seq_len_ - 1) * el_size; + int k_tensors_end = kv_tensors.size() / 2; + // update k caches + for (int j = 0; j < k_tensors_end; ++j) { + uint8_t* input_addr = + static_cast(kv_tensors[j].mutable_data_ptr()); + uint8_t* output_addr = + static_cast(kv_outputs[j].mutable_data_ptr()); + // fill the output k values back + for (int src = 0, dst = k_input_step; src < kv_outputs[j].nbytes(); + src += el_size, dst += k_input_step) { + input_addr[dst] = output_addr[src]; + } + char* new_inp_addr = io_mem_mgr_.update_k_caches_read(j, el_size); + // inputs + ET_CHECK_MSG( + internal::set_tensor_data( + kv_tensors[j], new_inp_addr, kv_tensors[j].nbytes()) == Error::Ok, + "Failed to set input tensor when updating k_cache"); + } + // update v caches + for (int j = k_tensors_end, v_idx = 0; j < kv_tensors.size(); ++j, ++v_idx) { + // inputs + char* new_inp_addr = io_mem_mgr_.update_v_caches_read(v_idx, v_offset); + + ET_CHECK_MSG( + internal::set_tensor_data( + kv_tensors[j], new_inp_addr, kv_tensors[j].nbytes()) == Error::Ok, + "Failed to set input tensor when updating v_cache"); + // outputs + char* new_out_addr = io_mem_mgr_.update_v_caches_write(v_idx, v_offset); + ET_CHECK_MSG( + internal::set_tensor_data( + kv_outputs[j], new_out_addr, kv_outputs[j].nbytes()) == Error::Ok, + "Failed to set output tensor when updating v_cache"); + ET_CHECK_MSG( + module_->set_output_data_ptr(kv_outputs[j], j + 1) == Error::Ok, + "Failed to set llama output data pointer"); + } + + // Bump start_pos by 1 + start_pos.mutable_data_ptr()[0]++; + + // update atten_mask + atten_mask.mutable_data_ptr() + [atten_mask.numel() - 1 - start_pos.const_data_ptr()[0]] = 0; + return outputs_res.get()[0].toTensor(); +} +// TODO: add overloaded method for on-device tokenize +Error Runner::generate( + const std::string& prompt, + int32_t seq_len, + std::function token_callback, + std::function stats_callback) { + ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null"); + ET_CHECK_MSG(is_loaded(), "Please invoke load method first"); + + // First token time only measures the time it takes to encode the prompt and + // return a response token. + stats_.inference_start_ms = util::time_in_ms(); + shouldStop_ = false; + + // Set the sequence length to the max seq length if not provided + seq_len = (seq_len > 0 && seq_len <= max_seq_len_) ? 
seq_len : max_seq_len_; + + Result> encode_res = + tokenizer_->encode(prompt, n_bos_, 0); + + ET_CHECK_OK_OR_RETURN_ERROR( + encode_res.error(), "Failed to encode prompt %s", prompt.c_str()); + + // encode the (string) prompt into tokens sequence + std::vector prompt_tokens = encode_res.get(); + int num_prompt_tokens = prompt_tokens.size(); + + ET_CHECK_MSG( + num_prompt_tokens < max_seq_len_, + "Max seq length exceeded - please increase max seq len value in static_llama.py"); + + ET_CHECK_MSG( + num_prompt_tokens < seq_len, + "Sequence length exceeded - please increase the seq_len value passed to generate()"); + + int32_t pos = 0, prev_token, cur_token = prompt_tokens[0]; + std::vector token_shape = {1, 1}; + + io_mem_mgr_.get_input_token_ptr()[0] = 0; + std::vector start_pos_shape = {1, 1}; + + float* atten_mask_ptr = + reinterpret_cast(io_mem_mgr_.get_atten_mask_ptr()); + std::fill(atten_mask_ptr, atten_mask_ptr + max_seq_len_, -255); + atten_mask_ptr[max_seq_len_ - 1] = 0; + + std::vector atten_mask_shape = {1, max_seq_len_}; + + std::vector logits_data_shape = {1, vocab_size_}; + + std::vector hidden_states_data_shape = {1, 1, dim_}; + + // initialize tensor wrappers + ManagedTensor managed_token( + io_mem_mgr_.get_input_token_ptr(), token_shape, ScalarType::Int); + ManagedTensor managed_pos_id( + io_mem_mgr_.get_pos_idx_ptr(), start_pos_shape, ScalarType::Int); + ManagedTensor managed_atten_mask( + io_mem_mgr_.get_atten_mask_ptr(), atten_mask_shape, ScalarType::Float); + + Tensor token = managed_token.get_aliasing_tensor(); + Tensor atten_mask = managed_atten_mask.get_aliasing_tensor(); + Tensor start_pos = managed_pos_id.get_aliasing_tensor(); + + std::vector managed_kv_inputs, managed_kv_outputs; + std::vector kv_tensors, kv_outputs; + + Result method_meta = get_method_meta(); + size_t num_inputs = method_meta->num_inputs(); + int k_caches_num = (num_inputs - 3) / 2; + + // TODO: need to handle batch size != 1 + // k caches init + for (int input_index = 3, i = 0; input_index < k_caches_num + 3; + ++input_index, ++i) { + // inputs + Result tensor_meta = + method_meta->input_tensor_meta(input_index); + + auto tensor_shape = tensor_meta->sizes(); + std::vector sizes( + tensor_shape.data(), tensor_shape.data() + tensor_shape.size()); + managed_kv_inputs.emplace_back(ManagedTensor( + io_mem_mgr_.get_k_caches_read_ptr(i), + sizes, + tensor_meta->scalar_type())); + kv_tensors.emplace_back(managed_kv_inputs.back().get_aliasing_tensor()); + + // outpus + Result out_tensor_meta = method_meta->output_tensor_meta(i + 1); + tensor_shape = out_tensor_meta->sizes(); + sizes = std::vector{ + tensor_shape.data(), tensor_shape.data() + tensor_shape.size()}; + managed_kv_outputs.emplace_back(ManagedTensor( + io_mem_mgr_.get_k_caches_write_ptr(i), + sizes, + kv_tensors.back().scalar_type())); + kv_outputs.emplace_back(managed_kv_outputs.back().get_aliasing_tensor()); + ET_CHECK_MSG( + module_->set_output_data_ptr(kv_outputs.back(), i + 1) == Error::Ok, + "Failed to set output tensor for kv cache"); + } + + // v caches init + for (int i = 0, input_index = k_caches_num + 3; input_index < num_inputs; + ++input_index, ++i) { + int output_index = i + k_caches_num + 1; + // inputs + Result tensor_meta = + method_meta->input_tensor_meta(input_index); + auto tensor_shape = tensor_meta->sizes(); + std::vector sizes( + tensor_shape.data(), tensor_shape.data() + tensor_shape.size()); + + managed_kv_inputs.emplace_back(ManagedTensor( + io_mem_mgr_.get_v_caches_read_ptr(i), + sizes, + tensor_meta->scalar_type())); 
+ kv_tensors.push_back(managed_kv_inputs.back().get_aliasing_tensor()); + + // outputs + Result out_tensor_meta = + method_meta->output_tensor_meta(output_index); + tensor_shape = out_tensor_meta->sizes(); + sizes = std::vector{ + tensor_shape.data(), tensor_shape.data() + tensor_shape.size()}; + + managed_kv_outputs.push_back(ManagedTensor( + io_mem_mgr_.get_v_caches_write_ptr(i), + sizes, + kv_tensors.back().scalar_type())); + kv_outputs.push_back(managed_kv_outputs.back().get_aliasing_tensor()); + ET_CHECK_MSG( + module_->set_output_data_ptr(kv_outputs.back(), output_index) == + Error::Ok, + "Failed to set output tensor for llama block"); + } + + ManagedTensor affine_managed_logits( + reinterpret_cast(io_mem_mgr_.get_logit_ptr()), + logits_data_shape, + ScalarType::Float); + Tensor affine_logits = affine_managed_logits.get_aliasing_tensor(); + ET_CHECK_MSG( + module_->set_output_data_ptr(affine_logits, 0) == Error::Ok, + "Failed to set output tensor for affine module - logits"); + + // Start consuming user's prompts and generating new tokens + std::string final_output; + while (pos < seq_len - 1) { + // Run the model + Result logits_res = run_model_step( + cur_token, token, start_pos, atten_mask, kv_tensors, kv_outputs); + if (pos == num_prompt_tokens) { + stats_.first_token_ms = util::time_in_ms(); + } else if (pos == num_prompt_tokens - 1) { + stats_.prompt_eval_end_ms = util::time_in_ms(); + } + + ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error()); + exec_aten::Tensor& logits_tensor = logits_res.get(); + prev_token = cur_token; + long sample_start_time_ms = util::time_in_ms(); + + cur_token = logitsToToken(logits_tensor); + stats_.aggregate_sampling_time_ms += + util::time_in_ms() - sample_start_time_ms; + + // advance the state machine + if (pos < num_prompt_tokens - 1) { + // prefill, force the next token to be the next prompt token + cur_token = prompt_tokens[pos + 1]; + } + pos++; + + // print the token as string, decode it with the Tokenizer object + auto piece_res = tokenizer_->decode(prev_token, cur_token); + ET_CHECK(piece_res.ok()); + + if (token_callback) { + token_callback(piece_res.get()); + } + + if (shouldStop_) { + break; + } + + // data-dependent terminating condition: we have n_eos_ number of EOS + if (pos >= num_prompt_tokens && cur_token == eos_id_) { + ET_LOG(Info, "Reached to the end of generation"); + break; + } + } + stats_.inference_end_ms = util::time_in_ms(); + + if (pos == seq_len) { + ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len); + } + + stats_.num_prompt_tokens = num_prompt_tokens; + stats_.num_generated_tokens = pos - num_prompt_tokens; + printReport(stats_); + if (stats_callback) { + stats_callback(stats_); + } + + return Error::Ok; +} + +namespace { +void printReport(const Runner::Stats& stats) { + printf("PyTorchObserver %s\n", statsToJsonString(stats).c_str()); + + ET_LOG( + Info, + "\tPrompt Tokens: %" PRIu64 " Generated Tokens: %" PRIu64, + stats.num_prompt_tokens, + stats.num_generated_tokens); + + ET_LOG( + Info, + "\tModel Load Time:\t\t%f (seconds)", + ((double)(stats.model_load_end_ms - stats.model_load_start_ms) / + stats.SCALING_FACTOR_UNITS_PER_SECOND)); + double inference_time_ms = + (double)(stats.inference_end_ms - stats.inference_start_ms); + ET_LOG( + Info, + "\tTotal inference time:\t\t%f (seconds)\t\t Rate: \t%f (tokens/second)", + inference_time_ms / stats.SCALING_FACTOR_UNITS_PER_SECOND, + + (stats.num_generated_tokens) / + (double)(stats.inference_end_ms - stats.inference_start_ms) * + 
stats.SCALING_FACTOR_UNITS_PER_SECOND); + double prompt_eval_time = + (double)(stats.prompt_eval_end_ms - stats.inference_start_ms); + ET_LOG( + Info, + "\t\tPrompt evaluation:\t%f (seconds)\t\t Rate: \t%f (tokens/second)", + prompt_eval_time / stats.SCALING_FACTOR_UNITS_PER_SECOND, + (stats.num_prompt_tokens) / prompt_eval_time * + stats.SCALING_FACTOR_UNITS_PER_SECOND); + + double eval_time = + (double)(stats.inference_end_ms - stats.prompt_eval_end_ms); + ET_LOG( + Info, + "\t\tGenerated %" PRIu64 + " tokens:\t%f (seconds)\t\t Rate: \t%f (tokens/second)", + stats.num_generated_tokens, + eval_time / stats.SCALING_FACTOR_UNITS_PER_SECOND, + stats.num_generated_tokens / eval_time * + stats.SCALING_FACTOR_UNITS_PER_SECOND); + + // Time to first token is measured from the start of inference, excluding + // model load time. + ET_LOG( + Info, + "\tTime to first generated token:\t%f (seconds)", + ((double)(stats.first_token_ms - stats.inference_start_ms) / + stats.SCALING_FACTOR_UNITS_PER_SECOND)); + + ET_LOG( + Info, + "\tSampling time over %" PRIu64 " tokens:\t%f (seconds)", + stats.num_prompt_tokens + stats.num_generated_tokens, + (double)stats.aggregate_sampling_time_ms / + stats.SCALING_FACTOR_UNITS_PER_SECOND); +} + +std::string statsToJsonString(const Runner::Stats& stats) { + std::stringstream ss; + ss << "{\"prompt_tokens\":" << stats.num_prompt_tokens << "," + << "\"generated_tokens\":" << stats.num_generated_tokens << "," + << "\"model_load_start_ms\":" << stats.model_load_start_ms << "," + << "\"model_load_end_ms\":" << stats.model_load_end_ms << "," + << "\"inference_start_ms\":" << stats.inference_start_ms << "," + << "\"inference_end_ms\":" << stats.inference_end_ms << "," + << "\"prompt_eval_end_ms\":" << stats.prompt_eval_end_ms << "," + << "\"first_token_ms\":" << stats.first_token_ms << "," + << "\"aggregate_sampling_time_ms\":" << stats.aggregate_sampling_time_ms + << "," << "\"SCALING_FACTOR_UNITS_PER_SECOND\":" + << stats.SCALING_FACTOR_UNITS_PER_SECOND << "}"; + return ss.str(); +} +} // namespace + +IoMemMgr::IoMemMgr(MethodMeta method_meta) { + method_meta_ = std::make_unique(method_meta); + init_io_info(); + compute_total_nbytes(); +} + +void IoMemMgr::init_io_info() { + set_tensor_meta(); + for (auto info : io_info_.tensor_info) { + info->size = info->tensor_meta->nbytes(); + info->rank = info->tensor_meta->sizes().size(); + info->shape.resize(info->rank); + for (int i = 0; i < info->rank; i++) { + info->shape[i] = + static_cast(info->tensor_meta->sizes().data()[i]); + } + info->dtype = info->tensor_meta->scalar_type(); + info->element_size = scalar_type_to_size[info->tensor_meta->scalar_type()]; + } +}; + +void IoMemMgr::set_tensor_meta() { + io_info_.input_token.tensor_meta = + std::make_unique(method_meta_->input_tensor_meta(0).get()); + io_info_.pos_idx.tensor_meta = + std::make_unique(method_meta_->input_tensor_meta(1).get()); + io_info_.atten_mask.tensor_meta = + std::make_unique(method_meta_->input_tensor_meta(2).get()); + + io_info_.k_caches_read.tensor_meta = + std::make_unique(method_meta_->input_tensor_meta(3).get()); + io_info_.k_caches_write.tensor_meta = + std::make_unique(method_meta_->output_tensor_meta(1).get()); + + io_info_.v_caches_read.tensor_meta = std::make_unique( + method_meta_->input_tensor_meta(method_meta_->num_inputs() - 1).get()); + io_info_.v_caches_write.tensor_meta = std::make_unique( + method_meta_->output_tensor_meta(method_meta_->num_outputs() - 1).get()); + + io_info_.logit.tensor_meta = + 
std::make_unique(method_meta_->output_tensor_meta(0).get()); +} + +void IoMemMgr::compute_total_nbytes() { + total_nbytes_ = io_info_.input_token.size + io_info_.pos_idx.size + + io_info_.atten_mask.size + io_info_.logit.size; + size_t num_heads = (method_meta_->num_inputs() - 3) / 2; + + // To update v cache via shifting pointer, v caches need a buffer with size + // of (max_seq_len_ - 1) * head_dim_. It is equivalent to one more cache. + size_t num_v_cache = num_heads + 1; + // To update k cache via shifting pointer, the k buffer needs a size of + // max_seq_len - 1 + size_t k_buffer = io_info_.k_caches_read.size / io_info_.k_caches_write.size; + + // k_caches_read needs a buffer with the size of head_dim_ + total_nbytes_ += num_heads * io_info_.k_caches_read.size + k_buffer; + total_nbytes_ += num_heads * io_info_.k_caches_write.size; + total_nbytes_ += num_v_cache * io_info_.v_caches_read.size; + // Add a head dim size for the convenience of shifting ptr from the last + // unused v cache write + total_nbytes_ += io_info_.v_caches_write.size; +} + +bool IoMemMgr::init_tensors() { + size_t cur_pos = input_token_pos_; + pos_idx_pos_ = cur_pos += io_info_.input_token.size; + atten_mask_pos_ = cur_pos += io_info_.pos_idx.size; + logit_pos_ = cur_pos += io_info_.atten_mask.size; + set_input_token_ptr(); + set_pos_idx_ptr(); + set_atten_mask_ptr(); + set_logit_ptr(); + + // set start point of kv caches + cur_pos += io_info_.logit.size; + + size_t num_heads = (method_meta_->num_inputs() - 3) / 2; + k_caches_read_pos_.resize(num_heads); + k_caches_write_pos_.resize(num_heads); + v_caches_read_pos_.resize(num_heads); + v_caches_write_pos_.resize(num_heads); + + for (int i = 0; i < num_heads; i++) { + set_k_caches_read(i, cur_pos); + cur_pos += io_info_.k_caches_read.size; + } + // add the size of the k caches buffer + cur_pos += io_info_.k_caches_read.size / io_info_.k_caches_write.size; + for (int i = 0; i < num_heads; i++) { + set_k_caches_write(i, cur_pos); + cur_pos += io_info_.k_caches_write.size; + } + + for (int i = 0; i < num_heads; i++) { + set_v_caches_read(i, cur_pos); + set_v_caches_write(i, cur_pos + io_info_.v_caches_read.size); + cur_pos += io_info_.v_caches_read.size; + } + // add one more cache as the v caches buffer + cur_pos += io_info_.v_caches_read.size; + return cur_pos <= total_nbytes_; +} + +void IoMemMgr::set_all_shifted_ptrs(size_t seq_len) { + auto iter_setter = [&](std::vector& cache, + size_t shift_size, + InfoAttrs& tensor_info) { + for (int i = 0; i < cache.size(); ++i) { + size_t pos = cache[i] + shift_size; + CustomMemTensorInfo info = { + ptr_, + ptr_ + pos, + pos, + tensor_info.size, + tensor_info.shape.data(), + tensor_info.rank, + tensor_info.dtype}; + QnnExecuTorchAddCustomMemTensorInfo(info); + } + }; + for (int i = 0; i < seq_len; ++i) { + iter_setter( + k_caches_read_pos_, + i * io_info_.k_caches_read.element_size, + io_info_.k_caches_read); + iter_setter( + v_caches_read_pos_, + i * io_info_.v_caches_write.size, + io_info_.v_caches_read); + iter_setter( + v_caches_write_pos_, + i * io_info_.v_caches_write.size, + io_info_.v_caches_write); + } +} + +void Runner::stop() { + shouldStop_ = true; +} + +Result Runner::get_method_meta() { + return module_->method_meta("forward"); +} + +Error Runner::mem_alloc(size_t alignment, size_t seq_len) { + Result method_meta_result = get_method_meta(); + io_mem_mgr_ = IoMemMgr(method_meta_result.get()); + ET_CHECK_MSG( + io_mem_mgr_.allocate(alignment), + "IoMemMgr failed to allocate custom memory"); + + ET_CHECK_MSG( + 
io_mem_mgr_.init_tensors(), + "IoMemMgr required more bytes than allocated bytes"); + + io_mem_mgr_.set_all_shifted_ptrs(seq_len); + // To register rpc_mem_handle from SharedBuffer, + // reset and re-initialize the module to trigger the registered function + module_.reset(); + module_ = std::make_unique( + model_path_, Module::MlockConfig::UseMlockIgnoreErrors); + ET_CHECK_MSG(load() == Error::Ok, "Runner failed to load method"); + + return Error::Ok; +} + +// explicit instantiation of template methods +template int64_t Runner::getMetadataHelper( + std::string method_name, + int64_t default_val); +template bool Runner::getMetadataHelper( + std::string method_name, + bool default_val); + +} // namespace executor +} // namespace torch diff --git a/examples/qualcomm/llama2/runner/runner.h b/examples/qualcomm/llama2/runner/runner.h new file mode 100644 index 00000000000..5128c365cb8 --- /dev/null +++ b/examples/qualcomm/llama2/runner/runner.h @@ -0,0 +1,280 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// A simple llama2 runner that includes preprocessing and post processing logic. +// The module takes in a string as input and emits a string as output. + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +class RpcMemAllocator { + public: + RpcMemAllocator(QnnMemDescriptor shared_buffer_type) + : shared_buffer_type_(shared_buffer_type){}; + bool allocate(size_t bytes, size_t alignment) { + ptr_ = QnnExecuTorchAllocCustomMem(bytes, alignment); + if (ptr_ == nullptr) { + ET_LOG( + Info, + "Allocating RPC mem failed, falling back to normal ptr: bytes=%zu, alignment=%zu", + bytes, + alignment); + input_data_.resize(bytes); + ptr_ = input_data_.data(); + } + return ptr_ != nullptr; + } + + ~RpcMemAllocator() { + if (shared_buffer_type_ == QnnMemDescriptor::kIon || + shared_buffer_type_ == QnnMemDescriptor::kCustom) { + if (ptr_ != nullptr) { + QnnExecuTorchFreeCustomMem(ptr_); + } + } + } + + void* GetPtr() { + return ptr_; + } + + private: + QnnMemDescriptor shared_buffer_type_; + void* ptr_{nullptr}; + std::vector input_data_; + std::vector tensor_base_addrs_; +}; + +#define DEFINE_IOMEMMGR_ACCESSOR(name) \ + size_t get_##name##_pos() const { \ + return name##_pos_; \ + } \ + char* get_##name##_ptr() const { \ + return reinterpret_cast(ptr_) + name##_pos_; \ + } \ + char* set_##name##_ptr() { \ + CustomMemTensorInfo info = { \ + ptr_, \ + ptr_ + name##_pos_, \ + name##_pos_, \ + io_info_.name.size, \ + io_info_.name.shape.data(), \ + io_info_.name.rank, \ + io_info_.name.dtype}; \ + QnnExecuTorchAddCustomMemTensorInfo(info); \ + return reinterpret_cast(ptr_) + name##_pos_; \ + } + +#define DEFINE_IOMEMMGR_VEC_ACCESSOR(name) \ + const std::vector& get_##name##_pos_vec() const { \ + return name##_pos_; \ + } \ + char* get_##name##_ptr(int idx) { \ + return ptr_ + name##_pos_[idx]; \ + } \ + char* set_##name(int idx, size_t pos) { \ + name##_pos_[idx] = pos; \ + CustomMemTensorInfo info = { \ + ptr_, \ + ptr_ + name##_pos_[idx], \ + name##_pos_[idx], \ + io_info_.name.size, \ + io_info_.name.shape.data(), \ + io_info_.name.rank, \ + io_info_.name.dtype}; \ + QnnExecuTorchAddCustomMemTensorInfo(info); \ + return reinterpret_cast(ptr_) + pos; \ + } \ + char* update_##name(int idx, size_t shift_size) { \ + name##_pos_[idx] += shift_size; \ + return reinterpret_cast(ptr_) + 
name##_pos_[idx]; \ + } + +namespace torch { +namespace executor { +class IoMemMgr { + public: + // Allocate a big memory which is capable to contain all IO of all modules + IoMemMgr(){}; + IoMemMgr(MethodMeta method_meta); + + struct InfoAttrs { + std::unique_ptr tensor_meta; + size_t size = 0; + std::vector shape; + uint32_t rank; + size_t element_size; + torch::executor::ScalarType dtype; + }; + + struct IoInfo { + InfoAttrs input_token; + InfoAttrs pos_idx; + InfoAttrs atten_mask; + InfoAttrs k_caches_read; + InfoAttrs k_caches_write; + InfoAttrs v_caches_read; + InfoAttrs v_caches_write; + InfoAttrs logit; + std::vector tensor_info{ + &input_token, + &pos_idx, + &atten_mask, + &k_caches_read, + &k_caches_write, + &v_caches_read, + &v_caches_write, + &logit, + }; + }; + + bool allocate(size_t alignment) { + bool ret = rpc_mem_allocator.allocate(total_nbytes_, alignment); + ptr_ = reinterpret_cast(rpc_mem_allocator.GetPtr()); + return ret; + } + bool init_tensors(); + + char* get_custom_mem_ptr() { + return ptr_; + } + + // Pointers of k cache read, v cache read and write are shifted every step. + // Set them first to register mem handle during qnn delegation init. + void set_all_shifted_ptrs(size_t max_seq_len); + + DEFINE_IOMEMMGR_ACCESSOR(atten_mask); + DEFINE_IOMEMMGR_ACCESSOR(input_token); + DEFINE_IOMEMMGR_ACCESSOR(pos_idx); + DEFINE_IOMEMMGR_ACCESSOR(logit); + + DEFINE_IOMEMMGR_VEC_ACCESSOR(k_caches_read); + DEFINE_IOMEMMGR_VEC_ACCESSOR(k_caches_write); + DEFINE_IOMEMMGR_VEC_ACCESSOR(v_caches_read); + DEFINE_IOMEMMGR_VEC_ACCESSOR(v_caches_write); + + private: + size_t total_nbytes_{0}; + char* ptr_{nullptr}; + void compute_total_nbytes(); + void set_tensor_meta(); + void init_io_info(); + + size_t atten_mask_pos_; + size_t input_token_pos_{0}; + size_t logit_pos_; + size_t pos_idx_pos_; + std::vector k_caches_read_pos_; + std::vector k_caches_write_pos_; + std::vector v_caches_read_pos_; + std::vector v_caches_write_pos_; + + IoInfo io_info_; + std::unique_ptr method_meta_; + RpcMemAllocator rpc_mem_allocator{QnnMemDescriptor::kCustom}; + std::unordered_map scalar_type_to_size = { + {ScalarType::Int, sizeof(int32_t)}, + {ScalarType::Float, sizeof(float)}, + {ScalarType::Char, sizeof(int8_t)}, + {ScalarType::Short, sizeof(int16_t)}, + {ScalarType::Byte, sizeof(uint8_t)}, + {ScalarType::Bits16, sizeof(uint16_t)}, + }; +}; + +class Runner { + public: + explicit Runner( + const std::string& model_path, + const std::string& tokenizer_path, + const float temperature = 0.8f); + + struct Stats { + // Scaling factor for timestamps - in this case, we use ms. + const long SCALING_FACTOR_UNITS_PER_SECOND = 1000; + // Time stamps for the different stages of the execution + // model_load_start_ms: Start of model loading. + long model_load_start_ms; + // model_load_end_ms: End of model loading. + long model_load_end_ms; + // inference_start_ms: Immediately after the model is loaded (or we check + // for model load), measure the inference time. + long inference_start_ms; + // prompt_eval_end_ms: Prompt array allocation and tokenization. Ends right + // before the inference loop starts + long prompt_eval_end_ms; + // first_token: Timestamp when the first generated token is emitted + long first_token_ms; + // inference_end_ms: End of inference/generation. + long inference_end_ms; + // Keep a running total of the time spent in sampling. 
+ long aggregate_sampling_time_ms; + // Token count from prompt + int64_t num_prompt_tokens; + // Token count from generated (total - prompt) + int64_t num_generated_tokens; + }; + + bool is_loaded() const; + Error load(); + Error mem_alloc(size_t alignment, size_t seq_len); + Error generate( + const std::string& prompt, + int32_t seq_len, + std::function token_callback = {}, + std::function stats_callback = {}); + void stop(); + Result get_method_meta(); + + private: + // metadata + template + T getMetadataHelper(std::string method_name, T default_val); + template + int32_t logitsToToken(const exec_aten::Tensor& logits_tensor); + Result run_model_step( + int64_t input_token, + Tensor& token, + Tensor& start_pos, + Tensor& atten_mask, + std::vector& kv_tensors, + std::vector& kv_outputs); + // metadata + int32_t vocab_size_; + int64_t bos_id_; + int64_t eos_id_; + int32_t n_bos_; + int32_t n_eos_; + int32_t max_seq_len_; + int32_t head_dim_; + int32_t dim_; + std::unordered_set model_methods_; + std::unique_ptr module_; + std::string tokenizer_path_; + std::string model_path_; + float temperature_; + std::unique_ptr tokenizer_; + std::unique_ptr sampler_; + bool shouldStop_{false}; + Stats stats_; + IoMemMgr io_mem_mgr_; +}; + +} // namespace executor +} // namespace torch diff --git a/examples/qualcomm/oss_scripts/dino_v2.py b/examples/qualcomm/oss_scripts/dino_v2.py index d4fb3d757e7..e4d4c6af252 100644 --- a/examples/qualcomm/oss_scripts/dino_v2.py +++ b/examples/qualcomm/oss_scripts/dino_v2.py @@ -131,13 +131,13 @@ def get_instance(): # setup required paths accordingly # qnn_sdk : QNN SDK path setup in environment variable - # artifact_path : path where artifacts were built + # build_path : path where QNN delegate artifacts were built # pte_path : path where executorch binary was stored # device_id : serial number of android device # workspace : folder for storing artifacts on android device adb = SimpleADB( qnn_sdk=os.getenv("QNN_SDK_ROOT"), - artifact_path=f"{args.build_folder}", + build_path=f"{args.build_folder}", pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", device_id=args.device, diff --git a/examples/qualcomm/oss_scripts/esrgan.py b/examples/qualcomm/oss_scripts/esrgan.py index ad11eb760dd..50dc59cf0cc 100644 --- a/examples/qualcomm/oss_scripts/esrgan.py +++ b/examples/qualcomm/oss_scripts/esrgan.py @@ -119,13 +119,13 @@ def get_instance(repo: str): # setup required paths accordingly # qnn_sdk : QNN SDK path setup in environment variable - # artifact_path : path where artifacts were built + # build_path : path where QNN delegate artifacts were built # pte_path : path where executorch binary was stored # device_id : serial number of android device # workspace : folder for storing artifacts on android device adb = SimpleADB( qnn_sdk=os.getenv("QNN_SDK_ROOT"), - artifact_path=f"{args.build_folder}", + build_path=f"{args.build_folder}", pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", device_id=args.device, diff --git a/examples/qualcomm/oss_scripts/fbnet.py b/examples/qualcomm/oss_scripts/fbnet.py index 5996566a178..d62c4a78b15 100755 --- a/examples/qualcomm/oss_scripts/fbnet.py +++ b/examples/qualcomm/oss_scripts/fbnet.py @@ -80,7 +80,7 @@ adb = SimpleADB( qnn_sdk=os.getenv("QNN_SDK_ROOT"), - artifact_path=f"{args.build_folder}", + build_path=f"{args.build_folder}", pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", 
device_id=args.device, diff --git a/examples/qualcomm/oss_scripts/gMLP_image_classification.py b/examples/qualcomm/oss_scripts/gMLP_image_classification.py index d4c77531158..3d98f55a7da 100644 --- a/examples/qualcomm/oss_scripts/gMLP_image_classification.py +++ b/examples/qualcomm/oss_scripts/gMLP_image_classification.py @@ -121,13 +121,13 @@ def get_data_loader(): # setup required paths accordingly # qnn_sdk : QNN SDK path setup in environment variable - # artifact_path : path where artifacts were built - # pte_path : path where executorch binary was stored + # build_path : path where artifacts were built + # pte_path : path where QNN delegate executorch binary was stored # device_id : serial number of android device # workspace : folder for storing artifacts on android device adb = SimpleADB( qnn_sdk=os.getenv("QNN_SDK_ROOT"), - artifact_path=f"{args.build_folder}", + build_path=f"{args.build_folder}", pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", device_id=args.device, diff --git a/examples/qualcomm/oss_scripts/squeezenet.py b/examples/qualcomm/oss_scripts/squeezenet.py index db5556ead93..53edb98b91b 100644 --- a/examples/qualcomm/oss_scripts/squeezenet.py +++ b/examples/qualcomm/oss_scripts/squeezenet.py @@ -119,13 +119,13 @@ def get_data_loader(): # setup required paths accordingly # qnn_sdk : QNN SDK path setup in environment variable - # artifact_path : path where artifacts were built + # build_path : path where QNN delegate artifacts were built # pte_path : path where executorch binary was stored # device_id : serial number of android device # workspace : folder for storing artifacts on android device adb = SimpleADB( qnn_sdk=os.getenv("QNN_SDK_ROOT"), - artifact_path=f"{args.build_folder}", + build_path=f"{args.build_folder}", pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", device_id=args.device, diff --git a/examples/qualcomm/oss_scripts/ssd300_vgg16.py b/examples/qualcomm/oss_scripts/ssd300_vgg16.py index 936db49d0a1..cd4eb8764f0 100644 --- a/examples/qualcomm/oss_scripts/ssd300_vgg16.py +++ b/examples/qualcomm/oss_scripts/ssd300_vgg16.py @@ -201,13 +201,13 @@ def SSD300VGG16(pretrained_weight_model): # setup required paths accordingly # qnn_sdk : QNN SDK path setup in environment variable - # artifact_path : path where artifacts were built + # build_path : path where QNN delegate artifacts were built # pte_path : path where executorch binary was stored # device_id : serial number of android device # workspace : folder for storing artifacts on android device adb = SimpleADB( qnn_sdk=os.getenv("QNN_SDK_ROOT"), - artifact_path=f"{args.build_folder}", + build_path=f"{args.build_folder}", pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", device_id=args.device, diff --git a/examples/qualcomm/scripts/deeplab_v3.py b/examples/qualcomm/scripts/deeplab_v3.py index 4e08ab078c2..ff1f53c1807 100755 --- a/examples/qualcomm/scripts/deeplab_v3.py +++ b/examples/qualcomm/scripts/deeplab_v3.py @@ -117,13 +117,13 @@ def get_dataset(data_size, dataset_dir, download): # setup required paths accordingly # qnn_sdk : QNN SDK path setup in environment variable - # artifact_path : path where artifacts were built + # build_path : path where QNN delegate artifacts were built # pte_path : path where executorch binary was stored # device_id : serial number of android device # workspace : folder for storing artifacts on android device adb = 
SimpleADB( qnn_sdk=os.getenv("QNN_SDK_ROOT"), - artifact_path=f"{args.build_folder}", + build_path=f"{args.build_folder}", pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", device_id=args.device, diff --git a/examples/qualcomm/scripts/dummy_llama2.py b/examples/qualcomm/scripts/dummy_llama2.py deleted file mode 100755 index 8bff578abba..00000000000 --- a/examples/qualcomm/scripts/dummy_llama2.py +++ /dev/null @@ -1,188 +0,0 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. -# All rights reserved -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import json -import os -import re -import sys -from multiprocessing.connection import Client - -import numpy as np -import torch -from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype -from executorch.examples.models.llama2 import Llama2Model -from executorch.examples.qualcomm.scripts.utils import ( - build_executorch_binary, - make_output_dir, - setup_common_args_and_variables, - SimpleADB, -) - - -def create_device_inputs(example_inputs, use_kv_cache): - inputs = [inp.to(torch.int32) for inp in example_inputs] - input_list = "" - if use_kv_cache: - for i, d in enumerate(inputs[0]): - if isinstance(d, list): - d = torch.stack(d) - d.numpy().tofile(f"{args.artifact}/input_0_0.raw") - input_list = f"input_0_{i}.raw " - else: - inputs[0].numpy().tofile(f"{args.artifact}/input_0_0.raw") - input_list = "input_0_0.raw" - input_list += "\n" - return tuple(inputs), input_list - - -if __name__ == "__main__": - print( - "[WARNING] The module of llama is changing frequently. This script might not work" - ) - parser = setup_common_args_and_variables() - parser.add_argument( - "-a", - "--artifact", - help="path for storing generated artifacts by this example. Default ./dummy_llama2", - default="./dummy_llama2", - type=str, - ) - - # TODO kv cache is not yet enabled - parser.add_argument( - "-kv", - "--use_kv_cache", - default=False, - action="store_true", - help="Whether or not to export a model using kv cache", - ) - - parser.add_argument( - "-F", - "--use_fp16", - help="If specified, will run in fp16 precision and discard ptq setting", - action="store_true", - default=False, - ) - - parser.add_argument( - "-P", - "--ptq", - help="If specified, will do PTQ quantization. default is 8bits activation and 8bits weight. Support 8a8w, 16a16w and 16a4w.", - default="8a8w", - ) - - parser.add_argument( - "--checkpoint", - help="Pass llama2 checkpoint.", - default=False, - ) - - parser.add_argument( - "--params", - help="Pass llama2 params json file.", - default=False, - ) - - args = parser.parse_args() - - # ensure the working directory exist. - os.makedirs(args.artifact, exist_ok=True) - - if args.params and args.checkpoint: - instance = Llama2Model( - use_kv_cache=args.use_kv_cache, - checkpoint=args.checkpoint, - params=args.params, - ) - else: - instance = Llama2Model( - use_kv_cache=args.use_kv_cache, - ) - - inputs, input_list = create_device_inputs( - instance.get_example_inputs(), args.use_kv_cache - ) - - pte_filename = "dummy_llama2_qnn" - - if args.ptq == "8a8w": - quant_dtype = QuantDtype.use_8a8w - elif args.ptq == "16a16w": - quant_dtype = QuantDtype.use_16a16w - elif args.ptq == "16a4w": - quant_dtype = QuantDtype.use_16a4w - else: - raise AssertionError( - f"No support for quant type {args.ptq}. Support 8a8w, 16a16w and 16a4w." 
- ) - - if args.use_fp16: - quant_dtype = None - - build_executorch_binary( - instance.get_eager_model().eval(), - inputs, - args.model, - f"{args.artifact}/{pte_filename}", - inputs, - custom_annotations=(), - quant_dtype=quant_dtype, - shared_buffer=args.shared_buffer, - ) - - if args.compile_only: - sys.exit(0) - - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - artifact_path=f"{args.build_folder}", - pte_path=f"{args.artifact}/{pte_filename}.pte", - workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, - ) - adb.push(inputs=inputs, input_list=input_list) - adb.execute() - - # collect output data - output_data_folder = f"{args.artifact}/outputs" - make_output_dir(output_data_folder) - - output_raws = [] - - def post_process(): - for f in sorted( - os.listdir(output_data_folder), key=lambda f: int(f.split("_")[1]) - ): - filename = os.path.join(output_data_folder, f) - if re.match(r"^output_[0-9]+_[1-9].raw$", f): - os.remove(filename) - else: - output = np.fromfile(filename, dtype=np.float32) - output_raws.append(output) - - adb.pull(output_path=args.artifact, callback=post_process) - - x86_golden = instance.get_eager_model().eval()(inputs[0]) - device_output = torch.from_numpy(output_raws[0]).reshape(x86_golden.size()) - result = torch.all(torch.isclose(x86_golden, device_output, atol=1e-2)).tolist() - - if args.ip and args.port != -1: - with Client((args.ip, args.port)) as conn: - conn.send( - json.dumps( - { - "is_close": result, - } - ) - ) - else: - print(f"is_close? {result}") - print(f"x86_golden {x86_golden}") - print(f"device_out {device_output}") diff --git a/examples/qualcomm/scripts/edsr.py b/examples/qualcomm/scripts/edsr.py index d09bb1e0dbd..54cc8bff196 100755 --- a/examples/qualcomm/scripts/edsr.py +++ b/examples/qualcomm/scripts/edsr.py @@ -164,13 +164,13 @@ def get_dataset(hr_dir: str, lr_dir: str, default_dataset: str, dataset_dir: str # setup required paths accordingly # qnn_sdk : QNN SDK path setup in environment variable - # artifact_path : path where artifacts were built + # build_path : path where QNN delegate artifacts were built # pte_path : path where executorch binary was stored # device_id : serial number of android device # workspace : folder for storing artifacts on android device adb = SimpleADB( qnn_sdk=os.getenv("QNN_SDK_ROOT"), - artifact_path=f"{args.build_folder}", + build_path=f"{args.build_folder}", pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", device_id=args.device, diff --git a/examples/qualcomm/scripts/inception_v3.py b/examples/qualcomm/scripts/inception_v3.py index a3b5c41923d..94aa618c720 100755 --- a/examples/qualcomm/scripts/inception_v3.py +++ b/examples/qualcomm/scripts/inception_v3.py @@ -119,13 +119,13 @@ def get_data_loader(): # setup required paths accordingly # qnn_sdk : QNN SDK path setup in environment variable - # artifact_path : path where artifacts were built + # build_path : path where QNN delegate artifacts were built # pte_path : path where executorch binary was stored # device_id : serial number of android device # workspace : folder for storing artifacts on android device adb = SimpleADB( qnn_sdk=os.getenv("QNN_SDK_ROOT"), - artifact_path=f"{args.build_folder}", + build_path=f"{args.build_folder}", pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", device_id=args.device, diff --git 
a/examples/qualcomm/scripts/inception_v4.py b/examples/qualcomm/scripts/inception_v4.py index 06b8047a18c..e457fef0f7c 100755 --- a/examples/qualcomm/scripts/inception_v4.py +++ b/examples/qualcomm/scripts/inception_v4.py @@ -118,13 +118,13 @@ def get_data_loader(): # setup required paths accordingly # qnn_sdk : QNN SDK path setup in environment variable - # artifact_path : path where artifacts were built + # build_path : path where QNN delegate artifacts were built # pte_path : path where executorch binary was stored # device_id : serial number of android device # workspace : folder for storing artifacts on android device adb = SimpleADB( qnn_sdk=os.getenv("QNN_SDK_ROOT"), - artifact_path=f"{args.build_folder}", + build_path=f"{args.build_folder}", pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", device_id=args.device, diff --git a/examples/qualcomm/scripts/mobilebert_fine_tune.py b/examples/qualcomm/scripts/mobilebert_fine_tune.py index cb067690f94..85aafe7cae7 100755 --- a/examples/qualcomm/scripts/mobilebert_fine_tune.py +++ b/examples/qualcomm/scripts/mobilebert_fine_tune.py @@ -305,13 +305,13 @@ def get_fine_tuned_mobilebert(artifacts_dir, pretrained_weight, batch_size): # setup required paths accordingly # qnn_sdk : QNN SDK path setup in environment variable - # artifact_path : path where artifacts were built + # build_path : path where QNN delegate artifacts were built # pte_path : path where executorch binary was stored # device_id : serial number of android device # workspace : folder for storing artifacts on android device adb = SimpleADB( qnn_sdk=os.getenv("QNN_SDK_ROOT"), - artifact_path=f"{args.build_folder}", + build_path=f"{args.build_folder}", pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", device_id=args.device, diff --git a/examples/qualcomm/scripts/mobilenet_v2.py b/examples/qualcomm/scripts/mobilenet_v2.py index e389c00b3ec..f642e0172c1 100755 --- a/examples/qualcomm/scripts/mobilenet_v2.py +++ b/examples/qualcomm/scripts/mobilenet_v2.py @@ -119,13 +119,13 @@ def get_data_loader(): # setup required paths accordingly # qnn_sdk : QNN SDK path setup in environment variable - # artifact_path : path where artifacts were built + # build_path : path where QNN delegate artifacts were built # pte_path : path where executorch binary was stored # device_id : serial number of android device # workspace : folder for storing artifacts on android device adb = SimpleADB( qnn_sdk=os.getenv("QNN_SDK_ROOT"), - artifact_path=f"{args.build_folder}", + build_path=f"{args.build_folder}", pte_path=f"{args.artifact}/{pte_filename}.pte", workspace=f"/data/local/tmp/executorch/{pte_filename}", device_id=args.device, diff --git a/examples/qualcomm/scripts/mobilenet_v3.py b/examples/qualcomm/scripts/mobilenet_v3.py index 8f83ae8d7e9..d15827160a8 100644 --- a/examples/qualcomm/scripts/mobilenet_v3.py +++ b/examples/qualcomm/scripts/mobilenet_v3.py @@ -117,13 +117,13 @@ def get_data_loader(): # setup required paths accordingly # qnn_sdk : QNN SDK path setup in environment variable - # artifact_path : path where artifacts were built + # build_path : path where QNN delegate artifacts were built # pte_path : path where executorch binary was stored # device_id : serial number of android device # workspace : folder for storing artifacts on android device adb = SimpleADB( qnn_sdk=os.getenv("QNN_SDK_ROOT"), - artifact_path=f"{args.build_folder}", + build_path=f"{args.build_folder}", 
         pte_path=f"{args.artifact}/{pte_filename}.pte",
         workspace=f"/data/local/tmp/executorch/{pte_filename}",
         device_id=args.device,
diff --git a/examples/qualcomm/scripts/torchvision_vit.py b/examples/qualcomm/scripts/torchvision_vit.py
index 694610cbe42..cd5463c8a2a 100755
--- a/examples/qualcomm/scripts/torchvision_vit.py
+++ b/examples/qualcomm/scripts/torchvision_vit.py
@@ -100,13 +100,13 @@ def get_data_loader():
     )
     # setup required paths accordingly
     # qnn_sdk : QNN SDK path setup in environment variable
-    # artifact_path : path where artifacts were built
+    # build_path : path where QNN delegate artifacts were built
     # pte_path : path where executorch binary was stored
     # device_id : serial number of android device
     # workspace : folder for storing artifacts on android device
     adb = SimpleADB(
         qnn_sdk=os.getenv("QNN_SDK_ROOT"),
-        artifact_path=f"{args.build_folder}",
+        build_path=f"{args.build_folder}",
         pte_path=f"{args.artifact}/{pte_filename}.pte",
         workspace=f"/data/local/tmp/executorch/{pte_filename}",
         device_id=args.device,
diff --git a/examples/qualcomm/scripts/utils.py b/examples/qualcomm/scripts/utils.py
index affe79ece05..274b4b39b78 100755
--- a/examples/qualcomm/scripts/utils.py
+++ b/examples/qualcomm/scripts/utils.py
@@ -10,7 +10,7 @@ import sys
 from pathlib import Path
-from typing import Optional
+from typing import Callable, List, Optional

 import numpy as np

@@ -30,9 +30,11 @@
     generate_htp_compiler_spec,
     generate_qnn_executorch_compiler_spec,
 )
+from executorch.exir import EdgeCompileConfig, EdgeProgramManager
 from executorch.exir.backend.backend_api import to_backend
 from executorch.exir.capture._config import ExecutorchBackendConfig
 from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
+from torch.ao.quantization.observer import MovingAverageMinMaxObserver
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

@@ -40,7 +42,7 @@ class SimpleADB:
     def __init__(
         self,
         qnn_sdk,
-        artifact_path,
+        build_path,
         pte_path,
         workspace,
         device_id,
@@ -48,9 +50,10 @@ def __init__(
         host_id=None,
         error_only=False,
         shared_buffer=False,
+        runner="examples/qualcomm/qnn_executor_runner",
     ):
         self.qnn_sdk = qnn_sdk
-        self.artifact_path = artifact_path
+        self.build_path = build_path
         self.pte_path = pte_path
         self.workspace = workspace
         self.device_id = device_id
@@ -68,6 +71,7 @@ def __init__(
         self.soc_model = arch_table[soc_model]
         self.error_only = error_only
         self.shared_buffer = shared_buffer
+        self.runner = runner

     def _adb(self, cmd):
         if not self.host_id:
@@ -80,7 +84,7 @@ def _adb(self, cmd):
             cmds, stdout=subprocess.DEVNULL if self.error_only else sys.stdout
         )

-    def push(self, inputs, input_list):
+    def push(self, inputs, input_list, files=None):
         self._adb(["shell", f"rm -rf {self.workspace}"])
         self._adb(["shell", f"mkdir -p {self.workspace}"])

@@ -104,8 +108,8 @@ def push(self, inputs, input_list):
             ),
             f"{self.qnn_sdk}/lib/aarch64-android/libQnnHtpPrepare.so",
             f"{self.qnn_sdk}/lib/aarch64-android/libQnnSystem.so",
-            f"{self.artifact_path}/examples/qualcomm/qnn_executor_runner",
-            f"{self.artifact_path}/backends/qualcomm/libqnn_executorch_backend.so",
+            f"{self.build_path}/{self.runner}",
+            f"{self.build_path}/backends/qualcomm/libqnn_executorch_backend.so",
             input_list_file,
         ]:
             self._adb(["push", artifact, self.workspace])
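With the constructor and push() changes above, the runner pushed to the device is resolved as f"{self.build_path}/{self.runner}", so a caller can substitute its own runner binary and ship extra files next to the model; the next hunk additionally lets execute() take a fully custom command line. A minimal sketch; the runner path and tokenizer file below are illustrative assumptions, not part of this patch:

    import os

    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path="build_android",
        pte_path="artifacts/model.pte",
        workspace="/data/local/tmp/executorch/model",
        device_id="<device-serial>",
        soc_model="SM8550",
        runner="examples/qualcomm/custom_runner",  # hypothetical target under build_path
    )
    # inputs / input_list prepared as in the example scripts; tokenizer.bin is a placeholder extra file.
    adb.push(inputs=inputs, input_list=input_list, files=["artifacts/tokenizer.bin"])
    adb.execute()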
@@ -117,26 +121,35 @@ def push(self, inputs, input_list):
                 d.detach().numpy().tofile(file_name)
                 self._adb(["push", file_name, self.workspace])

-    def execute(self):
+        # extra files
+        if files is not None:
+            for f in files:
+                self._adb(["push", f, self.workspace])
+
+    def execute(self, custom_runner_cmd=None):
         self._adb(["shell", f"mkdir -p {self.output_folder}"])
         # run the delegation
-        qnn_executor_runner_args = " ".join(
-            [
-                f"--model_path {os.path.basename(self.pte_path)}",
-                f"--output_folder_path {self.output_folder}",
-                f"--input_list_path {self.input_list_filename}",
-                f"--etdump_path {self.etdump_path}",
-                "--shared_buffer" if self.shared_buffer else "",
-            ]
-        )
-        qnn_executor_runner_cmds = " ".join(
-            [
-                f"cd {self.workspace} &&",
-                "export ADSP_LIBRARY_PATH=. &&",
-                "export LD_LIBRARY_PATH=. &&",
-                f"./qnn_executor_runner {qnn_executor_runner_args}",
-            ]
-        )
+        if custom_runner_cmd is None:
+            qnn_executor_runner_args = " ".join(
+                [
+                    f"--model_path {os.path.basename(self.pte_path)}",
+                    f"--output_folder_path {self.output_folder}",
+                    f"--input_list_path {self.input_list_filename}",
+                    f"--etdump_path {self.etdump_path}",
+                    "--shared_buffer" if self.shared_buffer else "",
+                ]
+            )
+            qnn_executor_runner_cmds = " ".join(
+                [
+                    f"cd {self.workspace} &&",
+                    "export ADSP_LIBRARY_PATH=. &&",
+                    "export LD_LIBRARY_PATH=. &&",
+                    f"./qnn_executor_runner {qnn_executor_runner_args}",
+                ]
+            )
+        else:
+            qnn_executor_runner_cmds = custom_runner_cmd
+
         self._adb(["shell", f"{qnn_executor_runner_cmds}"])

     def pull(self, output_path, callback=None):
@@ -156,25 +169,33 @@ def build_executorch_binary(
     inputs,  # noqa: B006
     soc_model,
     file_name,
-    dataset,
+    dataset: List[torch.Tensor] | Callable[[torch.fx.GraphModule], None],
     custom_annotations=(),
     skip_node_id_set=None,
     skip_node_op_set=None,
     quant_dtype: Optional[QuantDtype] = None,
+    per_channel_linear=False,  # TODO: remove this once QNN fully supports linear
     shared_buffer=False,
+    metadata=None,
+    act_observer=MovingAverageMinMaxObserver,
 ):
     if quant_dtype is not None:
         quantizer = QnnQuantizer()
         quantizer.add_custom_quant_annotations(custom_annotations)
+        quantizer.set_per_channel_linear_quant(per_channel_linear)

         if quant_dtype == QuantDtype.use_8a8w:
             pass  # default setting
         elif quant_dtype == QuantDtype.use_16a16w:
             quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS)
-            quantizer.set_bit16_op_quant_config(get_default_16bit_qnn_ptq_config())
+            quantizer.set_bit16_op_quant_config(
+                get_default_16bit_qnn_ptq_config(act_observer=act_observer)
+            )
         elif quant_dtype == QuantDtype.use_16a4w:
             quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS)
-            quantizer.set_bit16_op_quant_config(get_16a4w_qnn_ptq_config())
+            quantizer.set_bit16_op_quant_config(
+                get_16a4w_qnn_ptq_config(act_observer=act_observer)
+            )
             quantizer.set_per_channel_weight_dtype(weight_dtype_for_16bit_act="int4")
         else:
             raise AssertionError(f"No support for QuantDtype {quant_dtype}.")
@@ -183,8 +204,11 @@ def build_executorch_binary(
         annotated_model = prepare_pt2e(captured_model, quantizer)
         print("Quantizing the model...")
         # calibration
-        for data in dataset:
-            annotated_model(*data)
+        if callable(dataset):
+            dataset(annotated_model)
+        else:
+            for data in dataset:
+                annotated_model(*data)
         quantized_model = convert_pt2e(annotated_model)

         edge_prog = capture_program(quantized_model, inputs)
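The signature and calibration changes above let callers drive PTQ calibration with their own loop and attach constant metadata methods to the exported program (handled by the EdgeProgramManager branch in the next hunk). A minimal sketch of the callable-dataset form; the model, calibration loader, and metadata values are placeholders, not part of this patch:

    import torch
    from torch.ao.quantization.observer import MovingAverageMinMaxObserver

    # calib_loader yields tuples of tensors shaped like example_inputs (placeholder).
    def calibrate(annotated_gm: torch.fx.GraphModule):
        for batch in calib_loader:
            annotated_gm(*batch)

    build_executorch_binary(
        model.eval(),                      # placeholder nn.Module
        example_inputs,                    # tuple of example tensors
        "SM8550",                          # soc_model
        "artifacts/model",                 # writes artifacts/model.pte
        dataset=calibrate,                 # callable form; a List[torch.Tensor] still works
        quant_dtype=QuantDtype.use_16a4w,
        act_observer=MovingAverageMinMaxObserver,
        metadata={"get_n_layers": 4},      # hypothetical constant method
    )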
@@ -208,29 +232,45 @@
             debug=False,
             saver=False,
             shared_buffer=shared_buffer,
+            profile=False,
         ),
         skip_node_id_set,
         skip_node_op_set,
     )
-    edge_prog.exported_program = to_backend(edge_prog.exported_program, qnn_partitioner)
-    edge_prog.exported_program.graph_module.graph.print_tabular()
-    exec_prog = edge_prog.to_executorch(
-        config=ExecutorchBackendConfig(
-            extract_constant_segment=False,
-            # For shared buffer, user must pass the memory address
-            # which is allocated by RPC memory to executor runner.
-            # Therefore, won't want to pre-allocate
-            # by memory manager in runtime.
-            memory_planning_pass=MemoryPlanningPass(
-                memory_planning_algo="greedy",
-                alloc_graph_input=not shared_buffer,
-                alloc_graph_output=not shared_buffer,
-            ),
-            extract_delegate_segments=True,
-        )
-    )
-    with open(f"{file_name}.pte", "wb") as file:
-        file.write(exec_prog.buffer)
+
+    executorch_config = ExecutorchBackendConfig(
+        extract_constant_segment=False,
+        # For shared buffer, user must pass the memory address
+        # which is allocated by RPC memory to executor runner.
+        # Therefore, won't want to pre-allocate
+        # by memory manager in runtime.
+        memory_planning_pass=MemoryPlanningPass(
+            memory_planning_algo="greedy",
+            alloc_graph_input=not shared_buffer,
+            alloc_graph_output=not shared_buffer,
+        ),
+        extract_delegate_segments=True,
+    )
+
+    if metadata is None:
+        edge_prog.exported_program = to_backend(
+            edge_prog.exported_program, qnn_partitioner
+        )
+        edge_prog.exported_program.graph_module.graph.print_tabular()
+        exec_prog = edge_prog.to_executorch(config=executorch_config)
+        with open(f"{file_name}.pte", "wb") as file:
+            file.write(exec_prog.buffer)
+    else:
+        edge_prog_mgr = EdgeProgramManager(
+            edge_programs={"forward": edge_prog.exported_program},
+            constant_methods=metadata,
+            compile_config=EdgeCompileConfig(_check_ir_validity=False),
+        )
+
+        edge_prog_mgr = edge_prog_mgr.to_backend(qnn_partitioner)
+        exec_prog_mgr = edge_prog_mgr.to_executorch(config=executorch_config)
+        with open(f"{file_name}.pte", "wb") as file:
+            file.write(exec_prog_mgr.buffer)


 def make_output_dir(path: str):
diff --git a/exir/scalar_type.py b/exir/scalar_type.py
index d527e4a9833..b789a09f3a8 100644
--- a/exir/scalar_type.py
+++ b/exir/scalar_type.py
@@ -26,3 +26,4 @@ class ScalarType(IntEnum):
     BFLOAT16 = 15
     QUINT4x2 = 16
     QUINT2x4 = 17
+    Bits16 = 22
diff --git a/exir/tensor.py b/exir/tensor.py
index ee074cf7119..452bd6ab8a5 100644
--- a/exir/tensor.py
+++ b/exir/tensor.py
@@ -258,6 +258,7 @@ def memory_format_enum(memory_format: torch.memory_format) -> int:
     torch.qint32: ScalarType.QINT32,
     torch.bfloat16: ScalarType.BFLOAT16,
     torch.quint4x2: ScalarType.QUINT4x2,
+    torch.uint16: ScalarType.Bits16,
 }
diff --git a/extension/module/module.cpp b/extension/module/module.cpp
index 3a5649b558c..6141ab862eb 100644
--- a/extension/module/module.cpp
+++ b/extension/module/module.cpp
@@ -165,4 +165,11 @@ Result<std::vector<EValue>> Module::execute(
   return outputs;
 }

+Error Module::set_output_data_ptr(Tensor& output_tensor, size_t output_index) {
+  ET_CHECK_OK_OR_RETURN_ERROR(load_method("forward"));
+  auto& method = methods_.at("forward").method;
+  return method->set_output_data_ptr(
+      output_tensor.mutable_data_ptr(), output_tensor.nbytes(), output_index);
+}
+
 } // namespace torch::executor
diff --git a/extension/module/module.h b/extension/module/module.h
index 1cf10022269..983faf1faa9 100644
--- a/extension/module/module.h
+++ b/extension/module/module.h
@@ -194,6 +194,16 @@ class Module final {
     return event_tracer_.get();
   }

+  /**
+   * Set output data pointer for the 'forward' method.
+   *
+   * @param[in] output_tensor A Tensor for the output of 'forward' method.
+   * @param[in] output_index Index of the output in 'forward' method.
+   *
+   * @returns An Error to indicate success or failure of setting the output data pointer.
+   */
+  Error set_output_data_ptr(Tensor& output_tensor, size_t output_index);
+
  private:
   struct MethodHolder {
     std::vector<std::vector<uint8_t>> planned_buffers;