From 12ecb5a93e9c2cf18852cbc4c825fac248701b18 Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Wed, 17 Sep 2025 17:05:02 -0700 Subject: [PATCH] Add option in memory planning to put shared state on same location across entry points (#14230) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/14230 API that lets you place the same state tensor on the same id and offset across entry points. Lets you have get and set state more natively in the runtime if the underlying arenas are the same. Reviewed By: GregoryComer Differential Revision: D82250153 --- exir/emit/_emitter.py | 27 ++++- exir/memory_planning.py | 3 + exir/passes/memory_planning_pass.py | 155 +++++++++++++++++++++++++++- exir/program/_program.py | 17 ++- exir/tests/test_memory_planning.py | 52 ++++++++++ 5 files changed, 244 insertions(+), 10 deletions(-) diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index 6995f9f73a9..7701ca7b8ff 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -93,7 +93,8 @@ from executorch.exir.types import LeafValueSpec, ValueSpec from torch._subclasses.fake_tensor import FakeTensor -from torch.export.exported_program import ExportedProgram +from torch.export.exported_program import ExportedProgram, ExportGraphSignature +from torch.fx.node import Node from torch.utils import _pytree as pytree from typing_extensions import TypeAlias @@ -209,11 +210,11 @@ class _AbstractValue: ] -# pyre-ignore[13]: Attribute `node` is never initialized. class _Emitter(torch.fx.Interpreter): """An abstract interpreter (https://wiki.mozilla.org/Abstract_Interpretation) used to emit the given traced torch.fx.GraphModule to the flatbuffer schema.""" + # pyre-ignore[13]: Attribute `node` is never initialized. 
node: torch.fx.Node def __init__( @@ -1633,6 +1634,28 @@ def placeholder( # noqa: C901 if isinstance(target, str) and isinstance(spec, TensorSpec): fqn, is_mutable_buffer = self._find_fqn_for_placeholder(target, spec) + def _is_buffer(node: Node, graph_signature: ExportGraphSignature) -> bool: + """ + Check if the node is a buffer according to the provided graph signature. + Returns only a bool; unlike the memory planning variant it does not return the fqn. + """ + if node.op == "placeholder": + if isinstance(node.target, str): + if node.target in graph_signature.inputs_to_buffers: + return True + return False + + # If the spec does not appear in the mutable section of the graph signature it still might + # overall be considered a mutable buffer if it has already been memory planned. This would + # suggest that the same abstract buffer is mutable in another entry point so we should + # compel it to be considered mutable in all entry points at emission just as the user did with + # memory planning. + is_mutable_buffer |= ( + _is_buffer(self.node, self.exported_program.graph_signature) + and spec.mem_id is not None + and spec.mem_offset is not None + ) + # If the placeholder has a constant_tag, it is external to the PTE file # and requires a fqn and location=TensorDataLocation.EXTERNAL if constant_tag is not None: diff --git a/exir/memory_planning.py b/exir/memory_planning.py index e08d3e55772..0394ed9c529 100644 --- a/exir/memory_planning.py +++ b/exir/memory_planning.py @@ -245,6 +245,8 @@ def verify_graph_input_output(self) -> None: assert len(specs) > 0, "Expect tensor specs" specs = list(filter(lambda spec: not spec.const, specs)) if len(specs) == 0: + # all outputs are const so no need to allocate memory just say we succeeded + graph_output_allocated = self.alloc_graph_output continue allocated = any( spec is None or spec.mem_offset is not None for spec in specs ) @@ -408,6 +410,7 @@ def collect_specs_from_nodes( # noqa: C901 ignore_graph_input: bool = False, ignore_graph_output: bool = False,
ignore_mutable_buffers: bool = False, + share_mutable_buffers: bool = False, ignore_const: bool = True, ignore_out_var_node: bool = True, dedup: bool = True, diff --git a/exir/passes/memory_planning_pass.py b/exir/passes/memory_planning_pass.py index 9bd4ab20bf5..2636b61780c 100644 --- a/exir/passes/memory_planning_pass.py +++ b/exir/passes/memory_planning_pass.py @@ -4,10 +4,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import itertools import logging import warnings +from dataclasses import dataclass, field from functools import partial -from typing import Any, Callable, List, Optional +from typing import Any, Callable, Dict, List, Optional, Set, Tuple import torch from executorch.exir._warnings import deprecated @@ -16,14 +18,18 @@ from executorch.exir.memory_planning import ( _is_out_var_node, apply_algo, + collect_specs_from_nodes, + filter_nodes, get_node_tensor_specs, MemoryPlanningAlgorithmSuite, Verifier, ) from executorch.exir.operator.convert import get_out_args_from_opoverload from executorch.exir.pass_base import PassBase, PassResult -from executorch.exir.tensor import ALIGNMENT +from executorch.exir.tensor import ALIGNMENT, TensorSpec +from torch import fx from torch.export.exported_program import ExportGraphSignature +from torch.fx import Node # copied from https://stackoverflow.com/questions/75582932/python-how-can-i-print-the-function-name-of-a-partial-function @@ -37,6 +43,106 @@ def _callable_name(any_callable: Callable[..., Any]) -> str: return str(any_callable) +def _is_buffer( + node: Node, graph_signature: ExportGraphSignature +) -> Tuple[bool, Optional[str]]: + """ + Check if the node is buffer according to the provided graph signature. 
+ If it is one return its fqn as well + """ + if node.op == "placeholder": + if isinstance(node.target, str): + if node.target in graph_signature.inputs_to_buffers: + fqn = graph_signature.inputs_to_buffers[node.target] + return (True, fqn) + return (False, None) + + +def _is_mutable_buffer( + node: Node, graph_signature: ExportGraphSignature +) -> Tuple[bool, Optional[str]]: + """ + Check if the node is mutable buffer according to the provided graph signature. + If it is one return its fqn as well + """ + if node.op == "placeholder": + if isinstance(node.target, str): + if node.target in graph_signature.inputs_to_buffers: + fqn = graph_signature.inputs_to_buffers[node.target] + # if the buffer is mutated then record that + if fqn in graph_signature.buffers_to_mutate.values(): + return True, fqn + return False, None + + +def _get_spec_from_node(node: fx.Node) -> TensorSpec: + specs = get_node_tensor_specs(node) + return specs[0] + + +def _insert_mutable_buffer_specs( + state: "_MemoryPlanningState", gm: torch.fx.GraphModule, gs: ExportGraphSignature +): + for node in gm.graph.nodes: + is_mutable, fqn = _is_mutable_buffer(node, gs) + if is_mutable: + assert fqn + spec = _get_spec_from_node(node) + if ( + getattr(spec, "mem_id", None) is not None + or getattr(spec, "mem_offset", None) is not None + ): + raise ValueError( + "Cannot share mutable buffers if they already have a mem_id or mem_offset assigned" + ) + if fqn not in state.mutable_buffers.keys(): + state.mutable_buffers[fqn] = set() + state.mutable_buffers[fqn].add(spec) + continue + is_buffer, fqn = _is_buffer(node, gs) + # If it is not a mutable buffer it might just appear to be a buffer in this entry point. 
Think model.get_state() + # So cache it and later double check that this buffer never appears mutable + if is_buffer: + assert fqn + spec = _get_spec_from_node(node) + if ( + getattr(spec, "mem_id", None) is not None + or getattr(spec, "mem_offset", None) is not None + ): + raise ValueError( + "Cannot share mutable buffers if they already have a mem_id or mem_offset assigned" + ) + if fqn not in state.maybe_mutable_buffers.keys(): + state.maybe_mutable_buffers[fqn] = set() + state.maybe_mutable_buffers[fqn].add(spec) + + +def _check_default_mem_ids(gm: torch.fx.GraphModule): + for node in gm.graph.nodes: + for spec in collect_specs_from_nodes( + filter_nodes(itertools.chain([node], node.args, node.kwargs.values())), + None, + ignore_graph_input=False, + ignore_const=False, + ignore_out_var_node=False, + dedup=False, + do_assertion=False, + ignore_dynamic_unbound_tensor=False, + ): + mem_id = getattr(spec, "mem_id", None) + if mem_id is not None and mem_id != 1: + raise ValueError( + "Cannot share mutable buffers if all other tensors are not on the default mem_id of 1" + ) + + +@dataclass +class _MemoryPlanningState: + mutable_buffers: Dict[str, Set[TensorSpec]] = field(default_factory=dict) + maybe_mutable_buffers: Dict[str, Set[TensorSpec]] = field(default_factory=dict) + graph_modules: List[torch.fx.GraphModule] = field(default_factory=list) + + class MemoryPlanningPass(PassBase): def __init__( self, @@ -45,6 +151,7 @@ def __init__( alloc_graph_input: bool = True, alloc_graph_output: bool = True, alloc_mutable_buffers: bool = True, + share_mutable_buffers: bool = False, alignment: int = ALIGNMENT, ) -> None: r""" @@ -55,12 +162,18 @@ def __init__( """ if memory_planning_algo is None: memory_planning_algo = MemoryPlanningAlgorithmSuite() + if share_mutable_buffers and not alloc_mutable_buffers: + raise ValueError( + "share_mutable_buffers is only meaningful when alloc_mutable_buffers is True" + ) self.memory_planning_algo: Callable[..., List[int]] = 
memory_planning_algo self.allow_lifetime_and_storage_overlap = allow_lifetime_and_storage_overlap self.alloc_graph_input = alloc_graph_input self.alloc_graph_output = alloc_graph_output self.alloc_mutable_buffers = alloc_mutable_buffers + self.share_mutable_buffers = share_mutable_buffers self.alignment = alignment + self.state = _MemoryPlanningState() def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None: """ @@ -134,9 +247,17 @@ def run( graph_signature, self.alloc_graph_input, self.alloc_graph_output, - self.alloc_mutable_buffers, + # If we are sharing the mutable buffers then do not allocate them in + # memory planning algo, instead collect all of the specs over all the entry + # points and then allocate them directly in the run_multimethod name call + self.alloc_mutable_buffers and not self.share_mutable_buffers, ) + if self.share_mutable_buffers and graph_signature is not None: + self.state.graph_modules.append(graph_module) + _check_default_mem_ids(graph_module) + _insert_mutable_buffer_specs(self.state, graph_module, graph_signature) + # TODO: make the verifier do the work recursively to handle # control flow verifier = Verifier( @@ -164,3 +285,31 @@ def run( # I dont know if that is a valid thing but if it is we should adjust verify_storage_reuse function verifier.verify_storage_reuse() return PassResult(graph_module, True) + + def run_multimethod(self): + "Resolve any memory planning done across entry points" + if self.share_mutable_buffers: + arena: int = 0 + + # Every spec that shares an fqn is the same tensor! So we give it the same id and offset + # anywhere it appears. + for fqn, specs_set in self.state.mutable_buffers.items(): + specs = list(specs_set) + # If the same buffer appears in mutable and maybe mutable then we know it is in fact mutable. 
+ if fqn in self.state.maybe_mutable_buffers.keys(): + specs.extend(self.state.maybe_mutable_buffers[fqn]) + for spec in specs: + # Assume a default memory planning placed all activations on 1, place shared state on 2. + spec.mem_id = 2 + spec.realign(self.alignment) + # State is persistent, so the memory never overlaps. + spec.mem_offset = arena + # They should all be the same size since they are the same tensor, so just bump off the first. + arena += specs[0].allocated_memory + + for graph_module in self.state.graph_modules: + if len(graph_module.meta["non_const_buffer_sizes"]) != 2: + raise ValueError( + "Cannot share mutable state if not using default memory ids" + ) + graph_module.meta["non_const_buffer_sizes"].append(arena) diff --git a/exir/program/_program.py b/exir/program/_program.py index f3d9eef9221..a33d715ca3b 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -1681,7 +1681,7 @@ def to_backend( return epm @et_logger("to_executorch") - def to_executorch( + def to_executorch( # noqa (FLAKE8) C901 self, config: Optional[ExecutorchBackendConfig] = None, ) -> "ExecutorchProgramManager": @@ -1745,11 +1745,9 @@ def to_executorch( memory_planning_pass = config.memory_planning_pass # TODO(jakeszwe): Follow up with compiler on if the deepcopy is necessary and if so how to make it work if hasattr(memory_planning_pass, "run"): - new_gm_res = memory_planning_pass.run( # pyre-ignore[16] - new_gm, new_signature - ) + new_gm_res = memory_planning_pass.run(new_gm, new_signature) else: - new_gm_res = memory_planning_pass(new_gm) # pyre-ignore[29] + new_gm_res = memory_planning_pass(new_gm) # WARNING: DO NOT ADD ANY MORE PASSES AFTER MEMORY PLANNING PASS. # THERE ARE A LOT OF ASSUMPTIONS IN THE STACK THAT MEMORY PLANNING IS THE LAST PASS BEFORE THE EMITTER. 
@@ -1758,6 +1756,15 @@ def to_executorch( _copy_module(program.graph_module, new_gm) execution_programs[name] = program + # After running memory planning on all entry points we can run the cross entry point memory planning + if isinstance(config.memory_planning_pass, dict): + for memory_planning_pass in config.memory_planning_pass.values(): + if hasattr(memory_planning_pass, "run_multimethod"): + memory_planning_pass.run_multimethod() + else: + memory_planning_pass = config.memory_planning_pass + if hasattr(memory_planning_pass, "run_multimethod"): + memory_planning_pass.run_multimethod() et_pm = ExecutorchProgramManager( execution_programs, diff --git a/exir/tests/test_memory_planning.py b/exir/tests/test_memory_planning.py index 426cc54dc66..ce20de8f820 100644 --- a/exir/tests/test_memory_planning.py +++ b/exir/tests/test_memory_planning.py @@ -14,6 +14,7 @@ import torch from executorch.exir import ExecutorchBackendConfig, to_edge +from executorch.exir.capture._capture import patch_forward from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.memory_planning import ( _do_user_inputs_exist, @@ -93,6 +94,24 @@ def get_random_inputs(self) -> Tuple[torch.Tensor, ...]: return (torch.randn(10), torch.randn(10)) +class MultiEntryPointStatefulModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.register_buffer("state", torch.zeros(2, 2)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.state.add_(x).view(-1) * 2 + + def set_state(self, state: torch.Tensor) -> None: + self.state.copy_(state) + + def get_state(self) -> torch.Tensor: + return self.state + + def get_example_inputs(self) -> Tuple[torch.Tensor, ...]: + return (torch.ones(1),) + + class ModelWithDifferentTensorSizes(torch.nn.Module): def __init__(self) -> None: super(ModelWithDifferentTensorSizes, self).__init__() @@ -1081,3 +1100,36 @@ def test_multi_map(self) -> None: verifier.storage_overlap(outer_spec, inner_spec), f"Outer spec 
{outer_spec.shape=} {outer_spec.dtype=} {outer_spec.lifetime=} and inner spec {inner_spec} have storage overlap", ) + + def test_multi_state_plan(self) -> None: + eager_module = MultiEntryPointStatefulModel().eval() + forward = export(eager_module, eager_module.get_example_inputs()) + with patch_forward(eager_module, eager_module.get_state): + get_state = export(eager_module, ()) + with patch_forward(eager_module, eager_module.set_state): + set_state = export(eager_module, (torch.zeros(1),)) + edge = to_edge( + {"forward": forward, "set_state": set_state, "get_state": get_state} + ) + et = edge.to_executorch( + ExecutorchBackendConfig( + memory_planning_pass=MemoryPlanningPass(share_mutable_buffers=True), + emit_mutable_buffer_names=True, + ) + ) + et_prog = et.executorch_program + count = 0 + for plan in et_prog.execution_plan: + for value in plan.values: + if ( + hasattr(value.val, "allocation_info") + and value.val.allocation_info is not None + and value.val.allocation_info.memory_id == 2 + ): + count += 1 + self.assertEqual(value.val.allocation_info.memory_offset_low, 0) + self.assertTrue(value.val.extra_tensor_info is not None) + self.assertEqual( + value.val.extra_tensor_info.fully_qualified_name, "state" + ) + self.assertEqual(count, 3)