diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index 6995f9f73a9..7701ca7b8ff 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -93,7 +93,8 @@ from executorch.exir.types import LeafValueSpec, ValueSpec from torch._subclasses.fake_tensor import FakeTensor -from torch.export.exported_program import ExportedProgram +from torch.export.exported_program import ExportedProgram, ExportGraphSignature +from torch.fx.node import Node from torch.utils import _pytree as pytree from typing_extensions import TypeAlias @@ -209,11 +210,11 @@ class _AbstractValue: ] -# pyre-ignore[13]: Attribute `node` is never initialized. class _Emitter(torch.fx.Interpreter): """An abstract interpreter (https://wiki.mozilla.org/Abstract_Interpretation) used to emit the given traced torch.fx.GraphModule to the flatbuffer schema.""" + # pyre-ignore[13]: Attribute `node` is never initialized. node: torch.fx.Node def __init__( @@ -1633,6 +1634,28 @@ def placeholder( # noqa: C901 if isinstance(target, str) and isinstance(spec, TensorSpec): fqn, is_mutable_buffer = self._find_fqn_for_placeholder(target, spec) + def _is_buffer(node: Node, graph_signature: ExportGraphSignature) -> bool: + """ + Check if the node is buffer according to the provided graph signature. + If it is one return its fqn as well + """ + if node.op == "placeholder": + if isinstance(node.target, str): + if node.target in graph_signature.inputs_to_buffers: + return True + return False + + # If the spec does not appear in the mutable section of the graph signature it still might + # overall be considered a mutable buffer if it has already been memory planned. This would + # suggest that the same abstract buffer is mutable in another entry point so we should + # compel it to be considered mutable in all entry points at emission just as the user did with + # memory planning. + is_mutable_buffer |= ( + _is_buffer(self.node, self.exported_program.graph_signature) + and spec.mem_id is not None + and spec.mem_offset is not None + ) + # If the placeholder has a constant_tag, it is external to the PTE file # and requires a fqn and location=TensorDataLocation.EXTERNAL if constant_tag is not None: diff --git a/exir/memory_planning.py b/exir/memory_planning.py index e08d3e55772..0394ed9c529 100644 --- a/exir/memory_planning.py +++ b/exir/memory_planning.py @@ -245,6 +245,8 @@ def verify_graph_input_output(self) -> None: assert len(specs) > 0, "Expect tensor specs" specs = list(filter(lambda spec: not spec.const, specs)) if len(specs) == 0: + # all outputs are const so no need to allocate memory just say we suceeded + graph_output_allocated = self.alloc_graph_output continue allocated = any( spec is None or spec.mem_offset is not None for spec in specs @@ -408,6 +410,7 @@ def collect_specs_from_nodes( # noqa: C901 ignore_graph_input: bool = False, ignore_graph_output: bool = False, ignore_mutable_buffers: bool = False, + share_mutable_buffers: bool = False, ignore_const: bool = True, ignore_out_var_node: bool = True, dedup: bool = True, diff --git a/exir/passes/memory_planning_pass.py b/exir/passes/memory_planning_pass.py index 9bd4ab20bf5..2636b61780c 100644 --- a/exir/passes/memory_planning_pass.py +++ b/exir/passes/memory_planning_pass.py @@ -4,10 +4,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import itertools import logging import warnings +from dataclasses import dataclass, field from functools import partial -from typing import Any, Callable, List, Optional +from typing import Any, Callable, Dict, List, Optional, Set, Tuple import torch from executorch.exir._warnings import deprecated @@ -16,14 +18,18 @@ from executorch.exir.memory_planning import ( _is_out_var_node, apply_algo, + collect_specs_from_nodes, + filter_nodes, get_node_tensor_specs, MemoryPlanningAlgorithmSuite, Verifier, ) from executorch.exir.operator.convert import get_out_args_from_opoverload from executorch.exir.pass_base import PassBase, PassResult -from executorch.exir.tensor import ALIGNMENT +from executorch.exir.tensor import ALIGNMENT, TensorSpec +from torch import fx from torch.export.exported_program import ExportGraphSignature +from torch.fx import Node # copied from https://stackoverflow.com/questions/75582932/python-how-can-i-print-the-function-name-of-a-partial-function @@ -37,6 +43,106 @@ def _callable_name(any_callable: Callable[..., Any]) -> str: return str(any_callable) +def _is_buffer( + node: Node, graph_signature: ExportGraphSignature +) -> Tuple[bool, Optional[str]]: + """ + Check if the node is buffer according to the provided graph signature. + If it is one return its fqn as well + """ + if node.op == "placeholder": + if isinstance(node.target, str): + if node.target in graph_signature.inputs_to_buffers: + fqn = graph_signature.inputs_to_buffers[node.target] + return (True, fqn) + return (False, None) + + +def _is_mutable_buffer( + node: Node, graph_signature: ExportGraphSignature +) -> Tuple[bool, Optional[str]]: + """ + Check if the node is mutable buffer according to the provided graph signature. + If it is one return its fqn as well + """ + if node.op == "placeholder": + if isinstance(node.target, str): + if node.target in graph_signature.inputs_to_buffers: + fqn = graph_signature.inputs_to_buffers[node.target] + # if the buffer is mutated then record that + if fqn in graph_signature.buffers_to_mutate.values(): + return True, fqn + return False, None + + +def _get_spec_from_node(node: fx.Node) -> TensorSpec: + specs = get_node_tensor_specs(node) + return specs[0] + + +def _insert_mutable_buffer_specs( + state: "_MemoryPlanningState", gm: torch.fx.GraphModule, gs: ExportGraphSignature +): + for node in gm.graph.nodes: + is_mutable, fqn = _is_mutable_buffer(node, gs) + if is_mutable: + assert fqn + spec = _get_spec_from_node(node) + if ( + getattr(spec, "mem_id", None) is not None + or getattr(spec, "mem_offset", None) is not None + ): + raise ValueError( + "Cannot share mutable buffers if they already have a mem_id or mem_offset assigned" + ) + if fqn not in state.mutable_buffers.keys(): + state.mutable_buffers[fqn] = set() + state.mutable_buffers[fqn].add(spec) + continue + is_buffer, fqn = _is_buffer(node, gs) + # If it is not a mutable buffer it might just appear to be a buffer in this entry point. Think model.get_state() + # So cache it and later double check that this buffer never appears mutable + if is_buffer: + assert fqn + spec = _get_spec_from_node(node) + if ( + getattr(spec, "mem_id", None) is not None + or getattr(spec, "mem_offset", None) is not None + ): + raise ValueError( + "Cannot share mutable buffers if they already have a mem_id or mem_offset assigned" + ) + if fqn not in state.maybe_mutable_buffers.keys(): + state.maybe_mutable_buffers[fqn] = set() + state.maybe_mutable_buffers[fqn].add(spec) + + +def _check_default_mem_ids(gm: torch.fx.GraphModule): + for node in gm.graph.nodes: + for spec in collect_specs_from_nodes( + filter_nodes(itertools.chain([node], node.args, node.kwargs.values())), + None, + ignore_graph_input=False, + ignore_const=False, + ignore_out_var_node=False, + dedup=False, + do_assertion=False, + ignore_dynamic_unbound_tensor=False, + ): + mem_id = getattr(spec, "mem_id", None) + if mem_id is not None and mem_id != 1: + raise ValueError( + "Cannot share mutable buffers if all other tensors are not on the default mem_id of 1" + ) + + +@dataclass +class _MemoryPlanningState: + mutable_buffers: Dict[str, Set[TensorSpec]] = field(default_factory=dict) + maybe_mutable_buffers: Dict[str, Set[TensorSpec]] = field(default_factory=dict) + graph_modules: List[torch.fx.GraphModule] = field(default_factory=list) + + class MemoryPlanningPass(PassBase): def __init__( self, @@ -45,6 +151,7 @@ def __init__( alloc_graph_input: bool = True, alloc_graph_output: bool = True, alloc_mutable_buffers: bool = True, + share_mutable_buffers: bool = False, alignment: int = ALIGNMENT, ) -> None: r""" @@ -55,12 +162,18 @@ def __init__( """ if memory_planning_algo is None: memory_planning_algo = MemoryPlanningAlgorithmSuite() + if share_mutable_buffers and not alloc_mutable_buffers: + raise ValueError( + "share_mutable_buffers is only meaningful when alloc_mutable_buffers is True" + ) self.memory_planning_algo: Callable[..., List[int]] = memory_planning_algo self.allow_lifetime_and_storage_overlap = allow_lifetime_and_storage_overlap self.alloc_graph_input = alloc_graph_input self.alloc_graph_output = alloc_graph_output self.alloc_mutable_buffers = alloc_mutable_buffers + self.share_mutable_buffers = share_mutable_buffers self.alignment = alignment + self.state = _MemoryPlanningState() def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None: """ @@ -134,9 +247,17 @@ def run( graph_signature, self.alloc_graph_input, self.alloc_graph_output, - self.alloc_mutable_buffers, + # If we are sharing the mutable buffers then do not allocate them in + # memory planning algo, instead collect all of the specs over all the entry + # points and then allocate them directly in the run_multimethod name call + self.alloc_mutable_buffers and not self.share_mutable_buffers, ) + if self.share_mutable_buffers and graph_signature is not None: + self.state.graph_modules.append(graph_module) + _check_default_mem_ids(graph_module) + _insert_mutable_buffer_specs(self.state, graph_module, graph_signature) + # TODO: make the verifier do the work recursively to handle # control flow verifier = Verifier( @@ -164,3 +285,31 @@ def run( # I dont know if that is a valid thing but if it is we should adjust verify_storage_reuse function verifier.verify_storage_reuse() return PassResult(graph_module, True) + + def run_multimethod(self): + "Resolve any memory planning done across entry points" + if self.share_mutable_buffers: + arena: int = 0 + + # Every spec that shares an fqn is the same tensor! So we give it the same id and offset + # anywhere it appears. + for fqn, specs_set in self.state.mutable_buffers.items(): + specs = list(specs_set) + # If the same buffer appears in mutable and maybe mutable then we know it is in fact mutable. + if fqn in self.state.maybe_mutable_buffers.keys(): + specs.extend(self.state.maybe_mutable_buffers[fqn]) + for spec in specs: + # Assume a default memory planning placed all activations on 1, place shared state on 2. + spec.mem_id = 2 + spec.realign(self.alignment) + # State is persistent, so the memory never overlaps. + spec.mem_offset = arena + # They should all be the same size since they are the same tensor, so just bump off the first. + arena += specs[0].allocated_memory + + for graph_module in self.state.graph_modules: + if len(graph_module.meta["non_const_buffer_sizes"]) != 2: + raise ValueError( + "Cannot share mutable state if not using default memory ids" + ) + graph_module.meta["non_const_buffer_sizes"].append(arena) diff --git a/exir/program/_program.py b/exir/program/_program.py index f3d9eef9221..a33d715ca3b 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -1681,7 +1681,7 @@ def to_backend( return epm @et_logger("to_executorch") - def to_executorch( + def to_executorch( # noqa (FLAKE8) C901 self, config: Optional[ExecutorchBackendConfig] = None, ) -> "ExecutorchProgramManager": @@ -1745,11 +1745,9 @@ def to_executorch( memory_planning_pass = config.memory_planning_pass # TODO(jakeszwe): Follow up with compiler on if the deepcopy is necessary and if so how to make it work if hasattr(memory_planning_pass, "run"): - new_gm_res = memory_planning_pass.run( # pyre-ignore[16] - new_gm, new_signature - ) + new_gm_res = memory_planning_pass.run(new_gm, new_signature) else: - new_gm_res = memory_planning_pass(new_gm) # pyre-ignore[29] + new_gm_res = memory_planning_pass(new_gm) # WARNING: DO NOT ADD ANY MORE PASSES AFTER MEMORY PLANNING PASS. # THERE ARE A LOT OF ASSUMPTIONS IN THE STACK THAT MEMORY PLANNING IS THE LAST PASS BEFORE THE EMITTER. @@ -1758,6 +1756,15 @@ def to_executorch( _copy_module(program.graph_module, new_gm) execution_programs[name] = program + # After running memory planning on all entry points we can run the cross entry point memory planning + if isinstance(config.memory_planning_pass, dict): + for memory_planning_pass in config.memory_planning_pass.values(): + if hasattr(memory_planning_pass, "run_multimethod"): + memory_planning_pass.run_multimethod() + else: + memory_planning_pass = config.memory_planning_pass + if hasattr(memory_planning_pass, "run_multimethod"): + memory_planning_pass.run_multimethod() et_pm = ExecutorchProgramManager( execution_programs, diff --git a/exir/tests/test_memory_planning.py b/exir/tests/test_memory_planning.py index 426cc54dc66..ce20de8f820 100644 --- a/exir/tests/test_memory_planning.py +++ b/exir/tests/test_memory_planning.py @@ -14,6 +14,7 @@ import torch from executorch.exir import ExecutorchBackendConfig, to_edge +from executorch.exir.capture._capture import patch_forward from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.memory_planning import ( _do_user_inputs_exist, @@ -93,6 +94,24 @@ def get_random_inputs(self) -> Tuple[torch.Tensor, ...]: return (torch.randn(10), torch.randn(10)) +class MultiEntryPointStatefulModel(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.register_buffer("state", torch.zeros(2, 2)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.state.add_(x).view(-1) * 2 + + def set_state(self, state: torch.Tensor) -> None: + self.state.copy_(state) + + def get_state(self) -> torch.Tensor: + return self.state + + def get_example_inputs(self) -> Tuple[torch.Tensor, ...]: + return (torch.ones(1),) + + class ModelWithDifferentTensorSizes(torch.nn.Module): def __init__(self) -> None: super(ModelWithDifferentTensorSizes, self).__init__() @@ -1081,3 +1100,36 @@ def test_multi_map(self) -> None: verifier.storage_overlap(outer_spec, inner_spec), f"Outer spec {outer_spec.shape=} {outer_spec.dtype=} {outer_spec.lifetime=} and inner spec {inner_spec} have storage overlap", ) + + def test_multi_state_plan(self) -> None: + eager_module = MultiEntryPointStatefulModel().eval() + forward = export(eager_module, eager_module.get_example_inputs()) + with patch_forward(eager_module, eager_module.get_state): + get_state = export(eager_module, ()) + with patch_forward(eager_module, eager_module.set_state): + set_state = export(eager_module, (torch.zeros(1),)) + edge = to_edge( + {"forward": forward, "set_state": set_state, "get_state": get_state} + ) + et = edge.to_executorch( + ExecutorchBackendConfig( + memory_planning_pass=MemoryPlanningPass(share_mutable_buffers=True), + emit_mutable_buffer_names=True, + ) + ) + et_prog = et.executorch_program + count = 0 + for plan in et_prog.execution_plan: + for value in plan.values: + if ( + hasattr(value.val, "allocation_info") + and value.val.allocation_info is not None + and value.val.allocation_info.memory_id == 2 + ): + count += 1 + self.assertEqual(value.val.allocation_info.memory_offset_low, 0) + self.assertTrue(value.val.extra_tensor_info is not None) + self.assertEqual( + value.val.extra_tensor_info.fully_qualified_name, "state" + ) + self.assertEqual(count, 3)