From a1e3286cf8d1d5c7ba21a584e4905d8947ab3127 Mon Sep 17 00:00:00 2001
From: shewu-quic
Date: Fri, 9 Aug 2024 16:45:59 +0800
Subject: [PATCH] Unexpected graph for mutable buffer in Quantization

---
 backends/qualcomm/tests/models.py | 3 ++-
 backends/qualcomm/tests/utils.py  | 9 +++++++--
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index ff52fc61b57..ef8b2851651 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -459,11 +459,12 @@ def __init__(self):
         self.register_buffer(
             "k_cache",
             torch.zeros((1, 1024, 12, 64), dtype=torch.float32),
+            persistent=False,
         )
 
     def forward(self, input_pos, k_val):
         k_out = torch.ops.aten.index_put_(self.k_cache, [None, input_pos], k_val)
-        return k_out
+        return k_out + k_out
 
 
 class LayerNorm(torch.nn.Module):
diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py
index ef0ac0f202f..f046541bc47 100644
--- a/backends/qualcomm/tests/utils.py
+++ b/backends/qualcomm/tests/utils.py
@@ -293,7 +293,8 @@ def lower_module_and_test_output(
 
         # this is needed for the ETRecord as lowering modifies the graph in-place
         edge_copy = copy.deepcopy(delegated_program)
-
+        from executorch.backends.qualcomm.utils.utils import draw_graph
+        draw_graph("before_lower", ".", delegated_program.exported_program.graph_module)
         delegated_program.exported_program = to_backend(
             delegated_program.exported_program, qnn_partitioner
         )
@@ -342,7 +343,10 @@ def get_qdq_module(
         custom_quant_annotations: Tuple[Callable] = (),
         quant_dtype: QuantDtype = QuantDtype.use_8a8w,
     ) -> torch.fx.GraphModule:
+        # Recommended API for capturing the graph
         m = torch.export.export(module, inputs).module()
+        # Deprecated API:
+        # m = torch._export.capture_pre_autograd_graph(module, inputs)
 
         quantizer = QnnQuantizer()
         quantizer.add_custom_quant_annotations(custom_quant_annotations)
@@ -363,7 +367,8 @@ def get_qdq_module(
 
         prepared = prepare_pt2e(m, quantizer)
         prepared(*inputs)
-        quantized_module = convert_pt2e(prepared)
+        # fold_quantize controls whether quantize ops are folded into constants
+        quantized_module = convert_pt2e(prepared, fold_quantize=True)
         nodes = {node.target for node in quantized_module.graph.nodes}
         q_and_dq = {
             torch.ops.quantized_decomposed.quantize_per_tensor.default,