From 843f6d0bb9ff64348e7ae569ab41a82ff8145a4e Mon Sep 17 00:00:00 2001
From: Angela Yi
Date: Wed, 28 Aug 2024 10:50:40 -0700
Subject: [PATCH] Allow delegate to consume buffer mutations (#4830)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/4830

Fixes https://github.com/pytorch/executorch/issues/4209

Edge Program:
```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, b_b: "f32[3, 3]", x: "f32[3, 3]"):
            # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:631 in forward, code: self.b.add_(x)
            aten_add_tensor: "f32[3, 3]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(b_b, x);  b_b = None

            # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:632 in forward, code: return x + self.b
            aten_add_tensor_1: "f32[3, 3]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(x, aten_add_tensor);  x = None
            return (aten_add_tensor, aten_add_tensor_1)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_b'), target='b', persistent=True), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='aten_add_tensor'), target='b'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='aten_add_tensor_1'), target=None)])
```

Partitioned / lowered Exported Program (the buffer mutation is removed from the toplevel program):
```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[3, 3]"):
            # No stacktrace found for following nodes
            lowered_module_0 = self.lowered_module_0
            executorch_call_delegate = torch.ops.higher_order.executorch_call_delegate(lowered_module_0, x);  lowered_module_0 = x = None

            # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:632 in forward, code: return x + self.b
            getitem_1: "f32[3, 3]" = executorch_call_delegate[0];  executorch_call_delegate = None
            return (getitem_1,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem_1'), target=None)])
```

Delegate (consumes the buffer mutation):
```
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, b_b: "f32[3, 3]", x: "f32[3, 3]"):
            # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:631 in forward, code: self.b.add_(x)
            aten_add_tensor: "f32[3, 3]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(b_b, x);  b_b = None

            # File: /data/users/angelayi/fbsource/buck-out/v2/gen/fbcode/389acaeb40d57230/executorch/exir/backend/test/__test_partitioner__/test_partitioner#link-tree/executorch/exir/backend/test/test_partitioner.py:632 in forward, code: return x + self.b
            aten_add_tensor_1: "f32[3, 3]" = executorch_exir_dialects_edge__ops_aten_add_Tensor(x, aten_add_tensor);  x = None
            return (aten_add_tensor, aten_add_tensor_1)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='b_b'), target='b', persistent=True), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='aten_add_tensor'), target='b'), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='aten_add_tensor_1'), target=None)])
```

Differential Revision: D60838243
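Reviewer note: the end-to-end flow this change enables, distilled from the new `test_buffer_mutation1` test in the diff below. This is a minimal sketch, not part of the change itself; the import paths follow the test file and may differ in other setups.

```python
# Minimal sketch (uses the test-only AddAttributePartitionerDemo partitioner
# and the helpers imported by the new tests in this diff).
import torch

from executorch import exir
from executorch.exir.backend.test.op_partitioner_demo import AddAttributePartitionerDemo


class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("b", torch.ones(3, 3))

    def forward(self, x):
        self.b.add_(x)  # in-place buffer mutation
        return x + self.b


model_inputs = (torch.ones(3, 3),)
edge_program = exir.to_edge(torch.export.export(TestModule(), model_inputs))
lowered = edge_program.to_backend(AddAttributePartitionerDemo())

# The delegate now owns both the buffer and its mutation, so the toplevel
# program's graph signature no longer lists any buffers to mutate.
assert len(lowered.exported_program().graph_signature.buffers_to_mutate) == 0
```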
---
 backends/apple/mps/test/test_mps_utils.py |   2 +-
 exir/backend/test/TARGETS                 |   3 +
 exir/backend/test/op_partitioner_demo.py  |  50 +++++++++
 exir/backend/test/test_partitioner.py     | 112 +++++++++++++++++++
 exir/lowered_backend_module.py            | 128 +++++++++++++++++++++-
 extension/export_util/utils.py            |   2 +-
 6 files changed, 289 insertions(+), 8 deletions(-)

diff --git a/backends/apple/mps/test/test_mps_utils.py b/backends/apple/mps/test/test_mps_utils.py
index 77c02f533be..199a7fe1782 100644
--- a/backends/apple/mps/test/test_mps_utils.py
+++ b/backends/apple/mps/test/test_mps_utils.py
@@ -229,7 +229,7 @@ def lower_module_and_test_output(
     compile_specs = [CompileSpec("use_fp16", bytes([use_fp16]))]
 
     if use_partitioner:
-        logging.info(f"Edge IR graph:\n{edge_program.exported_program().graph}")
+        logging.info(f"Edge IR graph:\n{edge_program.exported_program()}")
         delegated_program = edge_program
         delegated_program = edge_program.to_backend(
             MPSPartitioner(compile_specs=compile_specs)
diff --git a/exir/backend/test/TARGETS b/exir/backend/test/TARGETS
index ed58b06b3dd..b99f374d83c 100644
--- a/exir/backend/test/TARGETS
+++ b/exir/backend/test/TARGETS
@@ -88,6 +88,8 @@ python_library(
         "//executorch/exir/backend:compile_spec_schema",
         "//executorch/exir/backend:partitioner",
         "//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
+        "//executorch/exir/backend/test/demos/rpc:executor_backend_partitioner",
+        "//executorch/exir/backend/test/demos/rpc:executor_backend_preprocess",
         "//executorch/exir/dialects:lib",
     ],
 )
@@ -290,6 +292,7 @@ python_unittest(
         "//executorch/exir/backend/test/demos/rpc:executor_backend_register",
     ],
     deps = [
+        ":op_partitioner_demo",
        "//caffe2:torch",
         "//executorch/exir:lib",
         "//executorch/exir/backend:backend_details",
diff --git a/exir/backend/test/op_partitioner_demo.py b/exir/backend/test/op_partitioner_demo.py
index dc20c03e68b..62a0aeb782c 100644
--- a/exir/backend/test/op_partitioner_demo.py
+++ b/exir/backend/test/op_partitioner_demo.py
@@ -21,6 +21,9 @@
 from executorch.exir.backend.test.backend_with_compiler_demo import (
     BackendWithCompilerDemo,
 )
+from executorch.exir.backend.test.demos.rpc.executor_backend_preprocess import (
+    ExecutorBackend,
+)
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.graph_module import get_control_flow_submodules
 from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
@@ -29,6 +32,11 @@
 from torch.fx.passes.operator_support import any_chain, OperatorSupportBase
 
 
+class AllOperatorSupport(OperatorSupportBase):
+    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
+        return node.op == "call_function"
+
+
 class AddOperatorSupport(OperatorSupportBase):
     def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
         return node.op == "call_function" and node.target in [
@@ -126,6 +134,48 @@ def partition(self, edge_exported_program: ExportedProgram) -> PartitionResult:
         )
 
 
+@final
+class AllNodesPartitionerDemo(Partitioner):
+    """
+    Partitions all nodes
+    """
+
+    def __init__(self) -> None:
+        self.op_support = AllOperatorSupport()
+        self.delegation_spec = DelegationSpec(ExecutorBackend.__name__, [])
+
+    def partition(self, edge_exported_program: ExportedProgram) -> PartitionResult:
+        partition_tags = {}
+        partition_list = generate_pattern_op_partitions(
+            edge_exported_program.graph_module, op_support=self.op_support
+        )
+        for partition in partition_list:
+            for node in partition.nodes:
+                delegation_tag = f"tag{partition.id}"
+                partition_tags[delegation_tag] = self.delegation_spec
+
+                # Tag the node
+                node.meta["delegation_tag"] = delegation_tag
+
+                for arg_node in node.args:
+                    if not isinstance(arg_node, torch.fx.Node):
+                        continue
+
+                    is_get_attr = arg_node.op == "get_attr"
+                    is_param_buffer = arg_node.op == "placeholder" and (
+                        is_param(edge_exported_program, arg_node)
+                        or is_buffer(edge_exported_program, arg_node)
+                        or is_lifted_tensor_constant(edge_exported_program, arg_node)
+                    )
+                    if is_get_attr or is_param_buffer:
+                        arg_node.meta["delegation_tag"] = delegation_tag
+                        # Tag constant inputs so the delegate consumes them too.
+
+        return PartitionResult(
+            tagged_exported_program=edge_exported_program, partition_tags=partition_tags
+        )
+
+
 ops_not_to_decompose = [
     torch.ops.aten.linear.default,
     torch.ops.aten.scaled_dot_product_attention.default,
diff --git a/exir/backend/test/test_partitioner.py b/exir/backend/test/test_partitioner.py
index 3ee6202ae8e..3973011a269 100644
--- a/exir/backend/test/test_partitioner.py
+++ b/exir/backend/test/test_partitioner.py
@@ -26,6 +26,10 @@
 from executorch.exir.backend.test.demos.rpc.executor_backend_preprocess import (
     ExecutorBackend,
 )
+from executorch.exir.backend.test.op_partitioner_demo import (
+    AddAttributePartitionerDemo,
+    AllNodesPartitionerDemo,
+)
 from executorch.exir.backend.utils import get_delegates, tag_constant_data
 
 from executorch.exir.dialects._ops import ops as exir_ops
@@ -619,3 +623,111 @@ def partition(
             and node.target == torch.ops.aten.copy_.default
         ]
         self.assertEqual(len(copy_node), 1)
+
+    def test_buffer_mutation1(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("b", torch.ones(3, 3))
+
+            def forward(self, x):
+                self.b.add_(x)
+                return x + self.b
+
+        model_inputs = (torch.ones(3, 3),)
+        orig_res = TestModule()(*model_inputs)
+        edge_program = exir.to_edge(torch.export.export(TestModule(), model_inputs))
+        lowered = edge_program.to_backend(AddAttributePartitionerDemo())
+
+        self.assertTrue(
+            torch.allclose(lowered.exported_program().module()(*model_inputs), orig_res)
+        )
+
+        self.assertEqual(
+            len(lowered.exported_program().graph_signature.buffers_to_mutate),
+            0,
+        )
+        lowered_module_nodes = get_delegates(lowered.exported_program().graph)
+        self.assertEqual(len(lowered_module_nodes), 1)
+        lowered_module_node = lowered_module_nodes[0]
+
+        # Get the call_delegate node
+        call_delegate_node = list(lowered_module_node.users.keys())[0]
+        self.assertEqual(len(call_delegate_node.args), 2)
+
+        lower_module = getattr(
+            lowered.exported_program().graph_module, lowered_module_node.name
+        )
+        delegated_ep = lower_module.original_module
+
+        self.assertEqual(len(delegated_ep.state_dict), 1)
+        self.assertEqual(len(delegated_ep.graph_signature.buffers_to_mutate), 1)
+        self.assertEqual(len(delegated_ep.graph_signature.buffers), 1)
+
+    def test_buffer_mutation_llama_repro(self):
+        SHAPE = (2, 3)
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("cache", torch.zeros(SHAPE, dtype=torch.float32))
+
+            def forward(self, q, k_val, input_pos):
+                q_T = q.transpose(0, 1)
+                k = torch.ops.aten.index_put_(self.cache, [input_pos, None], k_val)
+                attn = k.mm(q_T)
+                return attn
+
+        q = torch.rand(1, 3)
+        k = torch.rand(1, 3)
+        example_inputs = (q, k, torch.tensor([1, 1]))
+
+        model = Model()
+        model.eval()
+
+        exir_program_aten = torch.export.export(model, example_inputs)
+        exir_program_aten.module()(*example_inputs)
+        edge_program_manager = exir.to_edge(exir_program_aten)
+        lowered = edge_program_manager.to_backend(AllNodesPartitionerDemo())
+
+        self.assertEqual(
+            len(lowered.exported_program().graph_signature.buffers_to_mutate),
+            0,
+        )
+        lowered_module_nodes = get_delegates(lowered.exported_program().graph)
+        self.assertEqual(len(lowered_module_nodes), 1)
+        lowered_module_node = lowered_module_nodes[0]
+
+        # Get the call_delegate node
+        call_delegate_node = list(lowered_module_node.users.keys())[0]
+        self.assertEqual(len(call_delegate_node.args), 4)
+
+        lower_module = getattr(
+            lowered.exported_program().graph_module, lowered_module_node.name
+        )
+        delegated_ep = lower_module.original_module
+
+        self.assertEqual(len(delegated_ep.state_dict), 1)
+        self.assertEqual(len(delegated_ep.graph_signature.buffers_to_mutate), 1)
+        self.assertEqual(len(delegated_ep.graph_signature.buffers), 1)
+
+    def test_buffer_mutation_unsupported(self):
+        SHAPE = (2, 3)
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("state_1", torch.zeros(SHAPE, dtype=torch.float32))
+
+            def forward(self, x):
+                add = self.state_1.add_(x)
+                return add
+
+        model = Model()
+        model.eval()
+
+        example_inputs = (torch.randn(SHAPE),)
+        exir_program_aten = torch.export.export(model, example_inputs)
+        edge_program_manager = exir.to_edge(exir_program_aten)
+        with self.assertRaises(AssertionError):
+            edge_program_manager.to_backend(AddAttributePartitionerDemo())
diff --git a/exir/lowered_backend_module.py b/exir/lowered_backend_module.py
index 4d07fdcdf06..d93905a2bd0 100644
--- a/exir/lowered_backend_module.py
+++ b/exir/lowered_backend_module.py
@@ -8,6 +8,7 @@
 
 import copy
 import operator
+from collections import defaultdict
 from typing import Any, Dict, List, Optional, Set, Tuple, Union
 
 import torch
@@ -488,8 +489,12 @@ def _get_new_signature(  # noqa: C901
         else {}
     )
 
+    toplevel_output_node_to_sig: Dict[str, List[OutputSpec]] = defaultdict(list)
+    if not is_submodule:
+        for output_spec in old_signature.output_specs:
+            toplevel_output_node_to_sig[output_spec.arg.name].append(output_spec)
+
     for node in gm.graph.nodes:
-        is_tagged = tag is None or node.meta.get("delegation_tag", None) == tag
 
         if node.op == "placeholder":
             if node.name not in input_node_to_sig:
@@ -507,7 +512,7 @@
 
                 if not isinstance(orig_input_spec.arg, TensorArgument):
                     input_specs.append(orig_input_spec)
-                elif is_tagged:
+                elif node.meta.get("delegation_tag", None) == tag:
                     input_specs.append(orig_input_spec)
 
                     if orig_input_spec.kind == InputKind.USER_INPUT:
@@ -551,11 +556,72 @@
                 )
 
         if node.op == "output":
-            output_nodes = pytree.tree_leaves((node.args, node.kwargs))
+            buffer_mutation_idxs: Dict[int, List[OutputSpec]] = defaultdict(list)
+            for user in call_module_node.users.keys():
+                if user.name in toplevel_output_node_to_sig:
+                    assert (
+                        user.op == "call_function" and user.target == operator.getitem
+                    ), f"Invalid user {user}, user.op is {user.op} and user.target is {user.target}"
+                    getitem_idx = user.args[1]
+                    assert isinstance(
+                        getitem_idx, int
+                    ), f"Invalid getitem type: {type(getitem_idx)}"
+                    buffer_mutation_idxs[getitem_idx].extend(
+                        toplevel_output_node_to_sig[user.name]
+                    )
 
-            for output_node in output_nodes:
+            for i, output_node in enumerate(node.args[0]):
+                if i in buffer_mutation_idxs:
+                    assert isinstance(output_node, torch.fx.Node)
+                    orig_output_specs = buffer_mutation_idxs[i]
+
+                    if any(
+                        orig_output_spec.kind == OutputKind.BUFFER_MUTATION
+                        and orig_output_spec.target in new_state_dict
+                        for orig_output_spec in orig_output_specs
+                    ):
+                        # If the delegate consumes the buffer, then the delegate
+                        # should also consume the buffer mutation (the output
+                        # spec stays a BUFFER_MUTATION). Otherwise the delegate
+                        # just returns the result of the mutation as a
+                        # USER_OUTPUT.
+
+                        orig_output_spec = [
+                            orig_output_spec
+                            for orig_output_spec in orig_output_specs
+                            if orig_output_spec.kind == OutputKind.BUFFER_MUTATION
+                            and orig_output_spec.target in new_state_dict
+                        ][0]
+
+                        assert len(orig_output_specs) == 1, (
+                            f"Buffer {orig_output_spec.target} was tagged to be "
+                            "consumed by the delegate, and its mutation was "
+                            "therefore moved into the delegate as well. However, "
+                            "the mutated value is also used as another kind of "
+                            "output, which is currently not supported. Please "
+                            "file an issue on GitHub.\n\n"
+                            f"The toplevel program: {original_program}\n"
+                        )
+                        output_specs.append(
+                            OutputSpec(
+                                kind=OutputKind.BUFFER_MUTATION,
+                                arg=TensorArgument(name=output_node.name),
+                                target=orig_output_spec.target,
+                            )
+                        )
+                        output_specs_to_delete[orig_output_spec.arg.name] = (
+                            orig_output_spec
+                        )
+                    else:
+                        output_specs.append(
+                            OutputSpec(
+                                kind=OutputKind.USER_OUTPUT,
+                                arg=TensorArgument(name=output_node.name),
+                                target=None,
+                            )
+                        )
 
-            if not isinstance(output_node, torch.fx.Node):
+                elif not isinstance(output_node, torch.fx.Node):
                 output_specs.append(
                     OutputSpec(
                         kind=OutputKind.USER_OUTPUT,
@@ -630,6 +696,9 @@ def create_exported_program_from_submodule(
     in_spec = pytree.tree_flatten((tuple(subgraph_signature.user_inputs), {}))[1]
     out_spec = pytree.tree_flatten(subgraph_signature.user_outputs)[1]
 
+    print(submodule.graph)
+    print(subgraph_signature)
+
     return (
         ExportedProgram(
             root=submodule,
@@ -774,7 +843,7 @@ def get_lowered_backend_modules(
     return lowered_programs
 
 
-def _unsafe_adjust_original_program(
+def _unsafe_adjust_original_program(  # noqa: C901
     original_program: ExportedProgram,
     call_delegate_node: torch.fx.Node,
     input_specs_to_delete: Dict[str, InputSpec],
@@ -830,3 +899,50 @@ def _unsafe_adjust_original_program(
                 del original_program._constants[input_spec.target]
             else:
                 raise RuntimeError(f"Invalid input spec {input_spec} received")
+
+    # Delete buffer mutations from the output which were consumed by the delegate
+    toplevel_output_node = None
+    for node in reversed(original_program.graph.nodes):
+        if node.op == "output":
+            toplevel_output_node = node
+            break
+
+    assert toplevel_output_node is not None
+    assert (
+        len(toplevel_output_node.args) == 1
+    ), f"Invalid output node: {toplevel_output_node} with args {toplevel_output_node.args}"
+
+    new_output_args = [
+        arg
+        for arg in toplevel_output_node.args[0]
+        if not isinstance(arg, torch.fx.Node) or arg.name not in output_specs_to_delete
+    ]
+    toplevel_output_node.args = (tuple(new_output_args),)
+
+    # Delete the getitem nodes for the consumed buffer mutations
+    getitem_idxs: List[int] = []
+    user_nodes = list(call_delegate_node.users.keys())
+    for user in user_nodes:
+        if user.name in output_specs_to_delete:
+            assert (
+                user.op == "call_function" and user.target == operator.getitem
+            ), f"Invalid user {user}, user.op is {user.op} and user.target is {user.target}"
+            user_idx = user.args[1]
+            assert isinstance(user_idx, int), f"Invalid getitem type: {type(user_idx)}"
+            getitem_idxs.append(user_idx)
+            original_program.graph.erase_node(user)
+
+    getitem_idxs.sort(reverse=True)
+
+    # Adjust all the getitem indices after the deleted getitems
+    user_nodes = list(call_delegate_node.users.keys())
+    for user in user_nodes:
+        assert user.op == "call_function" and user.target == operator.getitem
+        user_idx = user.args[1]
+        assert isinstance(user_idx, int)
+        for i, idx in enumerate(getitem_idxs):
+            if user_idx > idx:
+                user.args = (user.args[0], user_idx - (len(getitem_idxs) - i))
+                break
+
+    original_program._validate()
diff --git a/extension/export_util/utils.py b/extension/export_util/utils.py
index 5c2700e6f5e..37e09babbb9 100644
--- a/extension/export_util/utils.py
+++ b/extension/export_util/utils.py
@@ -63,7 +63,7 @@ def _core_aten_to_edge(
         compile_config=edge_compile_config,
     )
 
     if verbose:
-        logging.info(f"Exported graph:\n{edge_manager.exported_program().graph}")
+        logging.info(f"Exported graph:\n{edge_manager.exported_program()}")
     return edge_manager
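Editor's note on the getitem re-indexing at the end of `_unsafe_adjust_original_program`: after the getitem nodes for consumed buffer mutations are erased, every remaining getitem index must drop by the number of deleted indices below it. A standalone sketch of that arithmetic follows; the names `deleted_idxs` and `adjust` are illustrative only and not part of the patch.

```python
# Illustration only: mirrors the index-adjustment loop in
# _unsafe_adjust_original_program. `deleted_idxs` are hypothetical getitem
# indices erased because their buffer mutations were consumed by the delegate;
# they are kept in descending order, as in the patch.
deleted_idxs = sorted([0, 2], reverse=True)


def adjust(user_idx: int) -> int:
    # Find the first deleted index smaller than user_idx; since the list is
    # descending, every later entry is also smaller, so subtract that count.
    for i, idx in enumerate(deleted_idxs):
        if user_idx > idx:
            return user_idx - (len(deleted_idxs) - i)
    return user_idx


# If the delegate originally returned four outputs and the getitems at indices
# 0 and 2 were removed, the surviving getitems at 1 and 3 become 0 and 1.
assert adjust(1) == 0
assert adjust(3) == 1
```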