From 23a3efd7b58dd41c2cd5a00a36b731f0acc6d315 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Tue, 11 Mar 2025 12:54:16 -0700 Subject: [PATCH] [ExecuTorch][Weight Sharing] Track Named Data Store in EdgeProgramManager We enable Backends to return Named Data by adding NamedDataStoreOutput to the preprocess result. This is a completely backwards-compatible (BC) change, as no backends with an implemented preprocess will see any change if nothing is explicitly implemented. For backend developers to leverage the new NamedDataStore, they can initialize a new NamedDataStore() within preprocess, add_named_data to the data store, and return the NamedDataStore.get_named_data_store_output() in the preprocess result like so: ``` def preprocess(ExportedProgram, List[CompileSpecs]) -> PreprocessResult: named_data_store = NamedDataStore() for node in exported_program.graph.nodes: named_data_store.add_named_data("name", bytes) return PreprocessResult( processed_bytes=bytes, debug_handle_map={}, data_store_output= named_data_store.get_named_data_store_output() ) ``` Under the hood, the data store output is embedded in the LoweredBackendModule (serializing the LoweredBackendModule by itself with a named_data_store_output is still a TODO). But via the EdgeProgramManager path, we add the named_data_store_outputs to the edge_program_manager's named data store to keep track of all the named data returned by backends. 
Differential Revision: [D70451660](https://our.internmc.facebook.com/intern/diff/D70451660/) [ghstack-poisoned] --- exir/backend/backend_api.py | 9 +- exir/backend/backend_details.py | 6 + exir/backend/test/TARGETS | 56 +++++++++ .../test/backend_with_named_data_map.py | 114 ++++++++++++++++++ .../test/test_backend_with_named_data_map.py | 85 +++++++++++++ exir/lowered_backend_module.py | 12 ++ exir/program/_program.py | 24 +++- 7 files changed, 302 insertions(+), 4 deletions(-) create mode 100644 exir/backend/test/backend_with_named_data_map.py create mode 100644 exir/backend/test/test_backend_with_named_data_map.py diff --git a/exir/backend/backend_api.py b/exir/backend/backend_api.py index 519f184871a..9927dee13fb 100644 --- a/exir/backend/backend_api.py +++ b/exir/backend/backend_api.py @@ -9,11 +9,15 @@ import logging from contextlib import contextmanager, nullcontext from functools import singledispatch -from typing import Generator, List +from typing import Generator, List, Optional import torch -from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult +from executorch.exir.backend.backend_details import ( + BackendDetails, + PreprocessResult, +) +from executorch.exir._serialize._named_data_store import NamedDataStore from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.backend.partitioner import Partitioner, PartitionResult @@ -120,6 +124,7 @@ def to_backend( backend_id=backend_id, processed_bytes=preprocess_result.processed_bytes, compile_specs=compile_specs, + named_data_store_output=preprocess_result.data_store_output ) lowered_module.meta = { "debug_handle_map": preprocess_result.debug_handle_map diff --git a/exir/backend/backend_details.py b/exir/backend/backend_details.py index bdbc1a1fafd..167cc5374d2 100644 --- a/exir/backend/backend_details.py +++ b/exir/backend/backend_details.py @@ -11,6 +11,7 @@ from executorch.exir.backend.compile_spec_schema import CompileSpec from 
torch.export.exported_program import ExportedProgram +from executorch.exir._serialize._named_data_store import NamedDataStoreOutput def enforcedmethod(func): @@ -24,6 +25,11 @@ class PreprocessResult: debug_handle_map: Optional[Union[Dict[int, Tuple[int]], Dict[str, Tuple[int]]]] = ( None ) + # Data Store output created from NamedDataStore. + + # Named Data store contains all the named data that is stored in the PTE file, + # but retrieveable by delegates via the NamedDataMap at runtime. + data_store_output: Optional[NamedDataStoreOutput] = None """ diff --git a/exir/backend/test/TARGETS b/exir/backend/test/TARGETS index b453f4c722a..f0ba618936d 100644 --- a/exir/backend/test/TARGETS +++ b/exir/backend/test/TARGETS @@ -38,6 +38,62 @@ python_library( ], ) +python_library( + name = "backend_with_named_data_map", + srcs = [ + "backend_with_named_data_map.py", + ], + visibility = [ + "//executorch/...", + "//executorch/test/...", + ], + deps = [ + "//caffe2:torch", + "//caffe2/functorch:functorch_src", + "//executorch/exir:delegate", + "//executorch/exir:graph_module", + "//executorch/exir:lib", + "//executorch/exir:lowered_backend_module", + "//executorch/exir:print_program", + "//executorch/exir:schema", + "//executorch/exir/backend:backend_api", + "//executorch/exir/backend:compile_spec_schema", + "//executorch/exir/backend:partitioner", + "//executorch/exir/dialects:lib", + "//executorch/extension/pybindings:portable_lib", # @manual + "//executorch/extension/pytree:pylib", + "//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib", + ], +) + +python_unittest( + name = "test_backend_with_named_data_map", + srcs = [ + "test_backend_with_named_data_map.py", + ], + visibility = [ + "//executorch/...", + "//executorch/test/...", + ], + deps = [ + "//caffe2:torch", + "//caffe2/functorch:functorch_src", + "//executorch/exir:delegate", + "//executorch/exir:graph_module", + "//executorch/exir:lib", + "//executorch/exir:lowered_backend_module", + 
"//executorch/exir:print_program", + "//executorch/exir:schema", + "//executorch/exir/backend:backend_api", + "//executorch/exir/backend:compile_spec_schema", + "//executorch/exir/backend:partitioner", + "//executorch/exir/dialects:lib", + "//executorch/extension/pybindings:portable_lib", # @manual + "//executorch/extension/pytree:pylib", + ":backend_with_named_data_map", + ], +) + python_library( name = "qnn_backend_demo", srcs = [ diff --git a/exir/backend/test/backend_with_named_data_map.py b/exir/backend/test/backend_with_named_data_map.py new file mode 100644 index 00000000000..0a77a8752be --- /dev/null +++ b/exir/backend/test/backend_with_named_data_map.py @@ -0,0 +1,114 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import final, List, Dict, Tuple + +import torch + +from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult +from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.exir.dialects._ops import ops as exir_ops +from torch.export.exported_program import ExportedProgram +from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import ( + generate_pattern_op_partitions, +) + +from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.exir.backend.partitioner import ( + DelegationSpec, + Partitioner, + PartitionResult, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.graph_module import get_control_flow_submodules +from torch.export import ExportedProgram +from torch.fx.passes.operator_support import OperatorSupportBase +from executorch.exir._serialize._named_data_store import NamedDataStore + + +# Backend details are final (cannot be subclassed). 
+@final +class BackendWithNamedDataMap(BackendDetails): + """ + Test Backend for Named Data Map Functionality + + This backend returns no processed_bytes, instead it uses + the named data store and serializes the name of the op + as the key and the data as its code value + """ + + @staticmethod + def preprocess( + edge_program: ExportedProgram, + compile_specs: List[CompileSpec], + ) -> PreprocessResult: + op_codes = { + exir_ops.edge.aten.sin.default: 0, + exir_ops.edge.aten.add.Tensor: 1, + exir_ops.edge.aten.sub.Tensor: 2, + exir_ops.edge.aten.mul.Tensor: 3, + exir_ops.edge.aten.div.Tensor: 4 + } + ndm = NamedDataStore() + for node in edge_program.graph.nodes: + if node.op == "call_function": + if node.target in op_codes.keys(): + ndm.add_named_data(node.target.__name__, bytes(op_codes[node.target])) + + + return PreprocessResult( + processed_bytes=bytes(b""), + debug_handle_map={}, + data_store_output=ndm.get_named_data_store_output(), + ) + +class SimpleOperatorSupport(OperatorSupportBase): + def is_node_supported(self, submodules, node:torch.fx.Node) -> bool: + return node.op == "call_function" and node.target in [ + exir_ops.edge.aten.sin.default, + exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.sub.Tensor, + exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.div.Tensor + ] + +@final +class BackendWithNDMPartitioner(Partitioner): + def __init__(self) -> None: + self._op_support = SimpleOperatorSupport() + self.backend_id = BackendWithNamedDataMap.__name__ + + def _partition_gm(self, graph_module: torch.fx.GraphModule, id_start:int = 0) -> Tuple[int, Dict[str, DelegationSpec]]: + partition_tags: Dict[str, DelegationSpec] = {} + partition_list = generate_pattern_op_partitions( + graph_module, op_support=self._op_support + ) + + num_partitions_in_gm = len(partition_list) + for partition in partition_list: + curr_par_id = partition.id or 0 + delegation_tag =f"tag_{curr_par_id + id_start}" + for node in partition.nodes: + node.meta["delegation_tag"] = 
delegation_tag + delegation_spec = DelegationSpec(self.backend_id, []) + partition_tags[delegation_tag] = delegation_spec + + start_idx_for_submodules = num_partitions_in_gm + for _, submodule, _ in get_control_flow_submodules(graph_module): + start_idx_for_submodules, ret_partition_tags = self._partition_gm( + submodule, start_idx_for_submodules + ) + partition_tags.update(ret_partition_tags) + + + return start_idx_for_submodules, partition_tags + + def partition(self, edge_program: ExportedProgram) -> PartitionResult: + _, partition_tags = self._partition_gm(edge_program.graph_module) + return PartitionResult( + tagged_exported_program=edge_program, + partition_tags=partition_tags, + ) diff --git a/exir/backend/test/test_backend_with_named_data_map.py b/exir/backend/test/test_backend_with_named_data_map.py new file mode 100644 index 00000000000..11fe691385e --- /dev/null +++ b/exir/backend/test/test_backend_with_named_data_map.py @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest + +import torch + +from executorch.exir import to_edge + +from executorch.exir.backend.test.backend_with_named_data_map import ( + BackendWithNamedDataMap, + BackendWithNDMPartitioner +) +from executorch.exir.backend.backend_api import to_backend + +from torch.testing import FileCheck +from torch.export.exported_program import ExportedProgram + +class TestBackendWithNamedDataMap(unittest.TestCase): + def test_lowered_backend_module_has_output(self): + class M(torch.nn.Module): + def forward(self, x): + return x + x + + ep = to_edge(torch.export.export(M(), (torch.randn(1, 2),))) + lowered = to_backend( + BackendWithNamedDataMap.__name__, ep.exported_program(), [] + ) + + buffer_entries = lowered.named_data_store_output.buffers + self.assertTrue(len(buffer_entries) == 1) + stored_data = lowered.named_data_store_output.pte_data + + self.assertTrue("aten.add.Tensor" in stored_data) + self.assertTrue(buffer_entries[0].buffer == bytes(1)) + + def test_named_data_with_partitioner(self): + class M(torch.nn.Module): + def forward(self, x): + y = x + x + y = torch.cos(y) + y = y + y + y = torch.sin(y) + return y - y + + ep = to_edge(torch.export.export(M(), (torch.randn(1, 2),))) + ep.to_backend(BackendWithNDMPartitioner()) + + ndm_output = ep._named_data_store.get_named_data_store_output() + buffer_entries = ndm_output.buffers + stored_data =ndm_output.pte_data + self.assertEqual(len(buffer_entries), 3) + self.assertTrue("aten.add.Tensor" in stored_data) + self.assertTrue("aten.sub.Tensor" in stored_data) + self.assertTrue("aten.sin.default" in stored_data) + + def test_named_data_with_control_flow(self): + class M(torch.nn.Module): + def true_branch(self, x): + y = x * x + y = torch.cos(y) + return torch.sin(y) + + def false_branch(self, x): + return torch.sin(x) + + def forward(self, x, y): + z = x/y + z = torch.cond(z > 1, self.true_branch, self.false_branch, [x]) + return z - z + + ep = to_edge(torch.export.export(M(), (torch.randn(1, 2), 
torch.randn(1, 2)))) + ep.to_backend(BackendWithNDMPartitioner()) + + ndm_output = ep._named_data_store.get_named_data_store_output() + buffer_entries = ndm_output.buffers + stored_data =ndm_output.pte_data + self.assertEqual(len(buffer_entries), 4) + self.assertTrue("aten.sub.Tensor" in stored_data) + self.assertTrue("aten.div.Tensor" in stored_data) + self.assertTrue("aten.sin.default" in stored_data) + self.assertTrue("aten.mul.Tensor" in stored_data) diff --git a/exir/lowered_backend_module.py b/exir/lowered_backend_module.py index dde6a397d9a..77a8d831c46 100644 --- a/exir/lowered_backend_module.py +++ b/exir/lowered_backend_module.py @@ -23,6 +23,7 @@ from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass from executorch.exir.passes.spec_prop_pass import make_spec, SpecPropPass from executorch.exir.schema import Program +from executorch.exir._serialize._named_data_store import NamedDataStoreOutput from executorch.exir.tracer import Value from torch._library.fake_class_registry import FakeScriptObject @@ -62,6 +63,7 @@ class LoweredBackendModule(torch.nn.Module): CompileSpec ] # A list of backend-specific objects with static metadata to configure the "compilation" process. 
_original_exported_program: ExportedProgram # The original EXIR module + _named_data_store_output: Optional[NamedDataStoreOutput] # Named Data serialized by the backend def __init__( self, @@ -69,12 +71,14 @@ def __init__( backend_id: str, processed_bytes: bytes, compile_specs: List[CompileSpec], + named_data_store_output: Optional[NamedDataStoreOutput] = None, ) -> None: super().__init__() self._original_exported_program = edge_program self._backend_id = backend_id self._processed_bytes = processed_bytes self._compile_specs = compile_specs + self._named_data_store_output = named_data_store_output # pyre-ignore def __deepcopy__(self, memo: Optional[Dict[int, Any]]) -> "LoweredBackendModule": @@ -133,6 +137,13 @@ def original_module(self) -> ExportedProgram: Returns the original EXIR module """ return self._original_exported_program + + @property + def named_data_store_output(self) -> Optional[NamedDataStoreOutput]: + """ + Returns the Named Data Store Output + """ + return self._named_data_store_output # TODO(chenlai): consolidate the seriailization config with serialize_to_flatbuffer api def buffer( @@ -154,6 +165,7 @@ def buffer( segment_alignment=segment_alignment, constant_tensor_alignment=constant_tensor_alignment, delegate_alignment=delegate_alignment, + named_data=self.named_data_store_output, ) ) return out diff --git a/exir/program/_program.py b/exir/program/_program.py index ed9dace34d1..aab488cae9a 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -41,6 +41,7 @@ MemoryFormatOpsPass, OpReplacePass, ) +from executorch.exir.delegate import executorch_call_delegate, is_lowered_module from executorch.exir.passes.external_constants_pass import ( external_constants_pass, external_mutable_weights_pass, @@ -1304,6 +1305,7 @@ def __init__( constant_methods: Optional[Dict[str, Any]] = None, compile_config: Optional[EdgeCompileConfig] = None, ops_set_to_not_decompose: Optional[List[torch._ops.OpOverload]] = None, + named_data_store: 
Optional[NamedDataStore] = None, ): """ Should not be called directly by users. User should use :func:'to_edge' instead. @@ -1327,7 +1329,7 @@ def __init__( self._edge_programs: Dict[str, ExportedProgram] = edge_programs self._config_methods = constant_methods - self._named_data_store = NamedDataStore() + self._named_data_store = named_data_store or NamedDataStore() @property def methods(self) -> Set[str]: @@ -1436,10 +1438,28 @@ def to_backend( else: # apply partitioner to every method for name, program in self._edge_programs.items(): new_edge_programs[name] = to_backend(program, partitioner) + + # collected all the named data into the named data store for deduplication + def collect_named_data_store_outputs( + graph_module: torch.fx.GraphModule, + ) -> None: + for node in graph_module.graph.nodes: + if node.target == executorch_call_delegate: + lbm = getattr(graph_module, node.args[0].name) + assert(is_lowered_module(lbm)) + data_store_output = lbm.named_data_store_output + if data_store_output is not None: + self._named_data_store.merge_named_data_store(data_store_output) + + for _, submod, _ in get_control_flow_submodules(graph_module): + collect_named_data_store_outputs(submod) + + for name, program in new_edge_programs.items(): + collect_named_data_store_outputs(program.graph_module) config = EdgeCompileConfig(_check_ir_validity=False) return EdgeProgramManager( - new_edge_programs, copy.deepcopy(self._config_methods), config + new_edge_programs, copy.deepcopy(self._config_methods), config, named_data_store=self._named_data_store ) @et_logger("to_executorch")