From 8eb99327c78a29dcf15fdd013b006e1cdea8ca41 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Thu, 11 Apr 2024 09:43:15 -0700 Subject: [PATCH 1/2] fix et-view (#2843) Summary: et-view should always copy the data pointer. Reviewed By: JacobSzwejbka Differential Revision: D55715318 --- kernels/prim_ops/et_view.cpp | 13 +------------ kernels/prim_ops/test/prim_ops_test.cpp | 8 +++----- 2 files changed, 4 insertions(+), 17 deletions(-) diff --git a/kernels/prim_ops/et_view.cpp b/kernels/prim_ops/et_view.cpp index 69a75170260..b3d3592fe7b 100644 --- a/kernels/prim_ops/et_view.cpp +++ b/kernels/prim_ops/et_view.cpp @@ -87,18 +87,7 @@ void et_view(RuntimeContext& context, EValue** stack) { // Do some checks ET_CHECK(self.numel() == out.numel()); - // If out has a data_ptr, it must match self - // We hit this path for memory-planned tensors - if (out.const_data_ptr() != nullptr) { - ET_CHECK_MSG( - self.const_data_ptr() == out.const_data_ptr(), - "out has a non-null data_ptr, but it does not equal self's data_ptr."); - - // nothing else to do - return; - } - - // out.const_data_ptr() == nullptr now + // Update data ptr ET_CHECK_MSG( internal::set_tensor_data( out, diff --git a/kernels/prim_ops/test/prim_ops_test.cpp b/kernels/prim_ops/test/prim_ops_test.cpp index fdcc13cf13e..7d91a0f6820 100644 --- a/kernels/prim_ops/test/prim_ops_test.cpp +++ b/kernels/prim_ops/test/prim_ops_test.cpp @@ -331,14 +331,13 @@ TEST_F(RegisterPrimOpsTest, TestETView) { EValue(good_outs[0]), EValue(good_outs[1])}; // bad outs expect death - constexpr int N_BAD_OUTS = 3; + constexpr int N_BAD_OUTS = 2; Tensor bad_outs[N_BAD_OUTS] = { tf.ones({1, 3, 2, 1}), // wrong rank - tf.ones({1, 3, 3}), // wrong size - tf.ones({1, 3, 2}) // occupied data_ptr + tf.ones({1, 3, 3}) // wrong size }; EValue bad_out_evalues[N_BAD_OUTS] = { - EValue(bad_outs[0]), EValue(bad_outs[1]), EValue(bad_outs[2])}; + EValue(bad_outs[0]), EValue(bad_outs[1])}; // *************************************************************************** // Run tests @@ -349,7 +348,6 @@ TEST_F(RegisterPrimOpsTest, TestETView) { // Bad out stacks {&self_evalue, &size_int_list_evalue, &bad_out_evalues[0]}, {&self_evalue, &size_int_list_evalue, &bad_out_evalues[1]}, - {&self_evalue, &size_int_list_evalue, &bad_out_evalues[2]}, // Bad size stacks {&self_evalue, &bad_size_int_list_evalue1, &good_out_evalues[0]}, {&self_evalue, &bad_size_int_list_evalue2, &good_out_evalues[0]}}; From 3170cd888eb31e4d44c1565a1d34735422c2c74a Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Thu, 11 Apr 2024 09:43:15 -0700 Subject: [PATCH 2/2] Replace view copy with view (3/3) (#2463) Summary: Design: https://docs.google.com/document/d/1l9x925EOrE8mHFJdRCC59nBJXyqBdnoeK-EgNQScXD0/edit#heading=h.kocb2mvchnib This stack replaces view_copy nodes with memory.view nodes. In the first diff (D54816555), I write a pass to normalize view_copy nodes by making their base point to the upstream non-view node. This means if we have something like op -> view_copy1 -> view_copy2, then after normalization, both view copies will point to op in their base (assuming op is not a view node). Note that this pass combined with dead-code elimination removes redundant view copies. This is because a redundant view copy will have no users have this pass. In the second diff (D54827305), I write a pass to convert view_copy nodes to memory.view nodes. A memory.view is similar to torch.ops.aten.view.default, but it is its own function so that we can handle it specially during memory planning and emission. A memory.view node has a special TensorSpec of type _MemoryViewSpec. This spec is immutable and dynamically looks up non-size related fields from its base's TensorSpec. Because it is immutable, fields on a _MemoryViewSpec cannot be set, but if a field is updated on the base spec, this update is reflected in the memory.view node's _MemoryViewSpec. Not all view_copy nodes are converted to memory.view nodes. Only static nodes that are memory planned are converted. Not all static nodes are memory planned in ExecuTorch. For example, there is an option to turn off memory planning for input nodes, and outputs from some higher order ops like cond are not memory planned. Which nodes are memory planned is not easily available, and I did not try to cover all cases of nodes that can be converted. We can expand this list over time. In the third diff (D54827438), I implement the actual view_copy elimination. In the ExecutorchBackendConfig, there is a new option remove_static_view_copy. If remove_static_view_copy = True, the memory planning passes are [NormalizeViewCopyBasePass(), ReplaceViewCopyWithMemoryViewPass(), config.to_out_var_pass, config.memory_planning_pass]; if remove_static_view_copy = False, the memory planning passes are [config.to_out_var_pass, config.memory_planning_pass] (state today). Let's look at the flow when remove_static_view_copy = True: NormalizeViewCopyBasePass(), ReplaceViewCopyWithMemoryViewPass(), config.to_out_var_pass, config.memory_planning_pass. The first two steps are the just the first and second diff described above. In config.to_out_var_pass, the memory.view nodes are skipped. In config.memory_planning_pass, when a spec is requested for a memory.view node (e.g., to update the lifetime), we return the spec of its base. Returning the spec for the base means that whenever we see a memory.view node, we actually update the lifetime of the base to cover it. Moreover, the memory.view node's special _MemoryViewSpec sees this update reflected. (Note that an exception would be thrown if we kept the usual flow and returned the spec for the memory.view node. This is because the special _MemoryViewSpec is immutable and would not allow the memory_planning_pass to update its lifetime.) Finally, during emission the memory.view is emitted as an evalue. There are two more diffs on the stack D54866523 and D54866539. The first of these replaces the old RemoveRedundantViewCopy pass with a NormalizeViewCopyBasePass + dead code elimination. The second converts view-like ops (squeeze, unsqueeze, slice) to view ops when safe to do so to take advantage of the view_copy elimination. Reviewed By: larryliu0820 Differential Revision: D54827438 --- examples/selective_build/CMakeLists.txt | 6 +- exir/capture/_config.py | 4 + exir/emit/_emitter.py | 29 +++ exir/emit/test/test_emit.py | 16 +- exir/memory_planning.py | 9 +- exir/passes/__init__.py | 1 + .../replace_view_copy_with_view_pass.py | 245 +++++++++++++----- exir/program/TARGETS | 2 + exir/program/_program.py | 29 ++- exir/tests/TARGETS | 14 + exir/tests/test_passes.py | 68 ++--- exir/tests/test_quant_fusion_pass.py | 2 +- exir/tests/test_remove_view_copy.py | 202 +++++++++++++++ 13 files changed, 505 insertions(+), 122 deletions(-) create mode 100644 exir/tests/test_remove_view_copy.py diff --git a/examples/selective_build/CMakeLists.txt b/examples/selective_build/CMakeLists.txt index 29791187185..239cdc828de 100644 --- a/examples/selective_build/CMakeLists.txt +++ b/examples/selective_build/CMakeLists.txt @@ -118,7 +118,11 @@ add_executable(selective_build_test ${_executor_runner__srcs}) if(CMAKE_BUILD_TYPE EQUAL "RELEASE") target_link_options(selective_build_test PRIVATE "LINKER:--gc-sections") endif() -target_link_libraries(selective_build_test executorch gflags select_build_lib) +target_link_libraries( + selective_build_test PRIVATE executorch gflags select_build_lib +) +target_link_options_shared_lib(select_build_lib) +target_link_options_shared_lib(executorch) target_compile_options(selective_build_test PUBLIC ${_common_compile_options}) # Print all summary diff --git a/exir/capture/_config.py b/exir/capture/_config.py index d743e4b0329..a2d3b53bcb6 100644 --- a/exir/capture/_config.py +++ b/exir/capture/_config.py @@ -75,3 +75,7 @@ class ExecutorchBackendConfig: # be a power of 2. If not provided, uses the value in the schema file. delegate_alignment: Optional[int] = None sym_shape_eval_pass: PassType = HintBasedSymShapeEvalPass() + + # If set to true, view_copy operations will be converted to lightweight + # view operations in the ET runtime + remove_view_copy: bool = True diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py index af5614bf208..3238c23eda0 100644 --- a/exir/emit/_emitter.py +++ b/exir/emit/_emitter.py @@ -844,6 +844,32 @@ def _emit_control_flow( ) ) + def _emit_view(self, args: Tuple[_Argument, ...]) -> _EmitterValue: + assert len(args) == 2 + + self_arg = self._emit_argument(args[0], torch.TensorType) # pyre-ignore[6] + size_arg = self._emit_argument(args[1], torch.ListType.ofInts()) + out_arg = self._emit_argument( + self._emit_spec(self.node.meta["spec"]), torch.TensorType # pyre-ignore[6] + ) + + op_idx, op = self._get_operator( + name="executorch_prim::et_view", + overload="default", + ) + kernel = Instruction( + KernelCall( + op_idx, + args=[ + self_arg.id, + size_arg.id, + out_arg.id, + ], + ) + ) + self.chain.instructions.append(kernel) + return out_arg + def _add_debug_handle(self, emitter_id: int, target: _Target) -> None: """Updates the debug handle information for the current node. @@ -1198,6 +1224,9 @@ def call_function( assert len(args) == 1 return self._emit_spec(self.node.meta["spec"]) + elif target == memory.view: + return self._emit_view(args) + elif target == memory.free: assert len(args) == 1 # pyre-ignore diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py index 3eebe52faef..b55fb5e5dae 100644 --- a/exir/emit/test/test_emit.py +++ b/exir/emit/test/test_emit.py @@ -265,16 +265,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: edge = to_edge(export(f, inputs)) removed_ops = ["aten::relu_", "aten::view"] - expected_ops = ["aten::sin", "aten::relu", "aten::max", "aten::view_copy"] + expected_ops = [ + "aten::sin", + "aten::relu", + "aten::max", + "executorch_prim::et_view", # aten::view_copy if ExecutorchBackendConfig.remove_view_copy = False + ] for opname in removed_ops: self.assertEqual( self.count_node(edge.exported_program().graph_module, opname), 0 ) for opname in expected_ops: - self.assertTrue( - self.count_node(edge.exported_program().graph_module, opname) >= 1 - ) + if ( + opname != "executorch_prim::et_view" + ): # et_view appears as call_function with target = memory.view in graph + self.assertTrue( + self.count_node(edge.exported_program().graph_module, opname) >= 1 + ) program = edge.to_executorch().executorch_program for opname in removed_ops: diff --git a/exir/memory_planning.py b/exir/memory_planning.py index b8c47b440c5..675f196fcd8 100644 --- a/exir/memory_planning.py +++ b/exir/memory_planning.py @@ -397,6 +397,7 @@ def collect_specs_from_nodes( # noqa: C901 or node.target in [ memory.alloc, + memory.view, operator.getitem, torch.ops.higher_order.cond, exir_while, @@ -534,7 +535,13 @@ def get_node_tensor_specs( has no tensor specs. """ # get tensor specs - specs = node.meta.get("spec") + if node.target == memory.view: + base = node.args[0] + assert isinstance(base, torch.fx.Node) + specs = base.meta.get("spec") + else: + specs = node.meta.get("spec") + if isinstance(specs, TensorSpec): specs = [specs] if not isinstance(specs, (list, tuple)): diff --git a/exir/passes/__init__.py b/exir/passes/__init__.py index 2611d6a1541..f43b2973a4e 100644 --- a/exir/passes/__init__.py +++ b/exir/passes/__init__.py @@ -248,6 +248,7 @@ def callWithLoggerEnabled(self, graph_module: torch.fx.GraphModule) -> None: # we won't see it in the input graph to the to_out_variant pass, unless # it's retraced after running to_out_variant with the first trace. memory.alloc, + memory.view, executorch_call_delegate, torch.ops.aten.copy_.default, } diff --git a/exir/passes/replace_view_copy_with_view_pass.py b/exir/passes/replace_view_copy_with_view_pass.py index 33f98304174..a9304f3eec8 100644 --- a/exir/passes/replace_view_copy_with_view_pass.py +++ b/exir/passes/replace_view_copy_with_view_pass.py @@ -6,9 +6,9 @@ # pyre-strict +import copy import logging -import math -from typing import Any, Dict, List, Tuple +from typing import Any, List, Tuple import torch from executorch.exir import memory @@ -36,28 +36,113 @@ def _is_view_copy(node: torch.fx.Node) -> bool: _VIEW_OP = memory.view +class _Guard: + def __init__( + self, name: str, field_lambda, expected_val: Any # pyre-ignore[2] + ) -> None: + self.name: str = name + self.field_lambda = field_lambda # pyre-ignore[4] + self.expected_val = copy.deepcopy(expected_val) # pyre-ignore[4] + + def __call__(self, view_spec) -> None: # pyre-ignore[2] + assert view_spec._unguarded_access + observed_val = self.field_lambda(view_spec) + if observed_val != self.expected_val: + raise Exception( + f"Guard {self.name} failed. Expected to see value {self.expected_val}, but saw value {observed_val}." + ) + + class _ViewSpec(TensorSpec): def __init__(self, base: TensorSpec, shape: List[int]) -> None: """ - A ViewSpec is an immutable TensorSpec that mirrors its base for non-size - related information. - """ + A _ViewSpec is TensorSpec that shares non-size related fields with its base. + The size-related fields are: shape, stride, dim_order, and shape_dynamism. - if math.prod(base.shape) != math.prod(shape): - raise Exception( - f"Cannot create a ViewSpec because the provided shape {shape} is not consistent with the number of elements in the provided base ({math.prod(base.shape)})." - ) + If either the base or view spec updates a non-size related field, the change + is reflected in both specs. But size related fields are not linked and can + be set separately. - self._init_setters = [ - "_frozen", - "_base", - "_guards", + A _ViewSpec can only be created from a non-sparse, strided TensorSpec. + On creation, a _ViewSpec must be compatible with its base with respect to + shape_dynamism, dtype, and nbytes. + + A _ViewSpec contains _guards that are evaluated on every __getattribute__ call. + The purpose of the guards is to make sure the _ViewSpec is still compatible + with its base. + """ + + # Explicitly put all attributes into _self_fields or _base_fields + # Any attribute that is not in _self_fields or _base_fields will + # raise an Exception. If TensorSpec is extended with a new attribute, + # we should explicitly decide how _ViewSpec will handle it. + self._self_fields = [ + # We need to get the debug method from self + # so that the object id it prints is correct. + "debug", # method + "__repr__", # method + # The following are related to size and should use self "shape", "stride", "dim_order", "shape_dynamism", + "nbytes", # method + "allocated_memory", # property + "is_dynamic_shape_tensor", # property + "is_static_shape_tensor", # property + "is_upper_bound_tensor", # property + "is_dynamic_unbound_tensor", # property + ] + self._base_fields = [ + "scalar_type", + "const", + "alignment", + "storage", + "requires_grad", + "layout", + "is_sparse", + "init_mem_planning_fields", # method + "realign", # method + "from_tensor", # class method + "lifetime", + "mem_id", + "mem_obj_id", + "mem_offset", + "dtype", # property ] - self._frozen = False + + # Make sure _self_fields and _base_fields are disjoint + assert len(set(self._self_fields) & set(self._base_fields)) == 0 + + self._guards: List[_Guard] = [] + self._unguarded_access = False + + # Make sure base is not sparse and add a guard + if base.is_sparse: + raise Exception( + "_ViewSpec can only be created from non-sparse TensorSpec, but base.is_sparse=True." + ) + self._guards.append( + _Guard( + "is_sparse", + lambda view_spec: view_spec.is_sparse, + False, + ) + ) + + # Make sure base layout is strided and add a guard + if base.layout != torch.strided: + raise Exception( + f"_ViewSpec can only be created from TensorSpec with layout={torch.strided}, but got layout={base.layout}." + ) + self._guards.append( + _Guard( + "layout", + lambda view_spec: view_spec.layout, + torch.strided, + ) + ) + self._base = base self.shape: List[int] = shape self.stride: Tuple[int] = contiguous_stride_from_shape(torch.Size(self.shape)) @@ -66,66 +151,108 @@ def __init__(self, base: TensorSpec, shape: List[int]) -> None: torch.Size(self.shape) ) - # This spec gives a view into its base. - # The base can be modified (e.g., mem_id) and this spec will - # update accordingly, but certain fields we do not expect to change - # We create guards for these - self._guards: Dict[str, Any] = { - "shape_dynamism": base.shape_dynamism, - "scalar_type": base.scalar_type, - "layout": base.layout, - "is_sparse": base.is_sparse, - } - self._frozen = True - - def _check_guards(self) -> None: - for name in self._guards: - if getattr(self._base, name) != self._guards[name]: - raise Exception( - f"The guarded attribute '{name}' has changed value. At creation of the ViewSpec, it was {self._guards[name]}, but it is now {getattr(self._base, name)}." - ) + # Check compatibility with base on creation + if self.shape_dynamism != base.shape_dynamism: + raise Exception( + f"_ViewSpec is incompatible with its base on creation. It has shape_dynamism={self.shape_dynamism}, but its base has shape_dynamism={base.shape_dynamism}." + ) + self._guards.append( + _Guard( + "shape_dynamism_init", + lambda view_spec: view_spec.shape_dynamism, + base.shape_dynamism, + ) + ) + self._guards.append( + _Guard( + "shape_dynamism_eq_base", + lambda view_spec: view_spec.shape_dynamism + == view_spec._base.shape_dynamism, + True, + ) + ) + + if self.dtype != base.dtype: + raise Exception( + f"_ViewSpec is incompatible with its base on creation. It has dtype={self.dtype}, but its base has dtype={base.dtype}." + ) + self._guards.append( + _Guard("dtype", lambda view_spec: view_spec.dtype, base.dtype) + ) + + # We do not guard nbytes because dynamic symints are replaced by upper bounds. + # We do guard on rank, though + if self.nbytes() != base.nbytes(): + raise Exception( + f"_ViewSpec is incompatible with its base on creation. It has nbytes={self.nbytes()}, but its base has nbytes={base.nbytes()}." + ) + self._guards.append( + _Guard("rank", lambda view_spec: len(view_spec.shape), len(shape)) + ) - def __getattribute__(self, name): # pyre-ignore + def _run_guards(self) -> None: + unguarded_access = self._unguarded_access + try: + self._unguarded_access = True + for g in self._guards: + g(self) + finally: + self._unguarded_access = unguarded_access + + def __getattribute__(self, name: str): # pyre-ignore + # Special field so we don't recurse infinitely if name in [ - "_init_setters", - "_frozen", "_base", + "_self_fields", + "_base_fields", "_guards", - "_check_guards", - # Adding debug is needed so that view_spec.debug() shows the right id in - # its string (if debug is excluded, it shows the id(view_spec._base) instead - # of id(view_spec)) - "debug", + "_unguarded_access", + "_run_guards", ]: return object.__getattribute__(self, name) - # Guard check after freeze - if self._frozen: - self._check_guards() + # Get some attributes from self + if name in self._self_fields: + val = object.__getattribute__(self, name) + elif name in self._base_fields: + val = object.__getattribute__(self._base, name) + else: + if len(name) > 0 and name[0] != "_": + logger.warning( + f"Getting non-private attribute {name} on self, but it is not in _self_fields or _base_fields. Is this intended?" + ) + val = object.__getattribute__(self, name) - # self._init_setters attributes come from self, others come from base - if name in self._init_setters: - return object.__getattribute__(self, name) - return getattr(self._base, name) + if not self._unguarded_access: + self._run_guards() + return val def __setattr__(self, name: str, val) -> None: # pyre-ignore - if name in ["_init_setters", "_frozen"]: + # Special field so we don't recurse infinitely + if name in [ + "_base", + "_self_fields", + "_base_fields", + "_guards", + "_unguarded_access", + "_run_guards", + ]: object.__setattr__(self, name, val) return - # Allow setting during initialization - if name in self._init_setters and not self._frozen: + if name in self._self_fields: object.__setattr__(self, name, val) return - if name in self._init_setters: - raise Exception( - f"ViewSpec is immutable. Cannot set the attribute '{name}' after creation." - ) + if name in self._base_fields: + object.__setattr__(self._base, name, val) + return - raise Exception( - f"ViewSpec is immutable. To update the non-size related attribute '{name}', update the base." - ) + if len(name) > 0 and name[0] != "_": + logger.warning( + f"Setting non-private attribute {name} on self, but it is not in _self_fields or _base_fields. Is this intended?" + ) + object.__setattr__(self, name, val) class ReplaceViewCopyWithViewPass(PassBase): @@ -151,8 +278,8 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: node.target = _VIEW_OP # Create spec for the node. - # _ViewSpec is an immutable TensorSpec gives a view into - # its base spec for non-size related information. + # _ViewSpec gives a view into its base spec for non-size + # related information. # the shape is not the same as node.args[1] because node.args[1] # can have an inferred sizes (-1). diff --git a/exir/program/TARGETS b/exir/program/TARGETS index 49da0648a06..5ae3cf1ac59 100644 --- a/exir/program/TARGETS +++ b/exir/program/TARGETS @@ -33,8 +33,10 @@ python_library( "//executorch/exir/emit:lib", "//executorch/exir/passes:insert_write_back_for_buffers_pass", "//executorch/exir/passes:lib", + "//executorch/exir/passes:normalize_view_copy_base_pass", "//executorch/exir/passes:remove_graph_asserts_pass", "//executorch/exir/passes:remove_mixed_type_operators", + "//executorch/exir/passes:replace_view_copy_with_view_pass", "//executorch/exir/passes:spec_prop_pass", "//executorch/exir/verification:verifier", ], diff --git a/exir/program/_program.py b/exir/program/_program.py index 10d0043398f..086768b879d 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -31,8 +31,14 @@ from executorch.exir.passes.insert_write_back_for_buffers_pass import ( insert_write_back_for_buffers_pass, ) +from executorch.exir.passes.normalize_view_copy_base_pass import ( + NormalizeViewCopyBasePass, +) from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass from executorch.exir.passes.remove_mixed_type_operators import RemoveMixedTypeOperators +from executorch.exir.passes.replace_view_copy_with_view_pass import ( + ReplaceViewCopyWithViewPass, +) from executorch.exir.passes.spec_prop_pass import SpecPropPass from executorch.exir.print_program import pretty_print, print_program from executorch.exir.schema import Program @@ -615,8 +621,24 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram": return new_ep +def pre_memory_planning_passes(config: ExecutorchBackendConfig) -> List[PassType]: + if config.remove_view_copy: + # pyre-ignore + return [ + NormalizeViewCopyBasePass(), + ReplaceViewCopyWithViewPass(), + config.sym_shape_eval_pass, + config.to_out_var_pass, + ] + else: + # pyre-ignore + return [ + config.sym_shape_eval_pass, + config.to_out_var_pass, + ] + + def edge_to_executorch_passes(config: ExecutorchBackendConfig) -> List[PassType]: - # pyre-ignore passes: List[PassType] = [ *config.passes, SpecPropPass(), @@ -625,9 +647,8 @@ def edge_to_executorch_passes(config: ExecutorchBackendConfig) -> List[PassType] # there exists an unbacked symint operation. EdgeToBackendOpsPass(), RemoveGraphAssertsPass(), - config.sym_shape_eval_pass, - config.to_out_var_pass, - ] + ] + pre_memory_planning_passes(config) + return passes diff --git a/exir/tests/TARGETS b/exir/tests/TARGETS index 0c3232916d6..94a82d8a2bc 100644 --- a/exir/tests/TARGETS +++ b/exir/tests/TARGETS @@ -411,3 +411,17 @@ python_unittest( "//executorch/exir:print_program", ], ) + +python_unittest( + name = "test_remove_view_copy", + srcs = [ + "test_remove_view_copy.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir:lib", + "//executorch/exir:memory", + "//executorch/exir/capture:config", + "//executorch/exir/passes:lib", + ], +) diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index 848269c6573..bfa0d393235 100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -1498,61 +1498,29 @@ def __init__(self): self.parameter = torch.nn.Parameter(torch.ones(1)) def forward(self, x): - o1 = torch.ops.aten.view_copy.default( - self.parameter, [1] - ) # replaceable parameter - o2 = torch.ops.aten.view_copy.default(x, [1]) # replaceable user input - o3 = torch.ops.aten.view_copy.default( - torch.ops.aten.relu.default(x), [1] - ) # replaceable dynamic unbound - o4 = torch.ops.aten.view_copy.default( - torch.ops.aten.gelu.default(x), [1] - ) # replaceable dynamic bound - o5 = torch.ops.aten.view_copy.default( - torch.ops.aten.tanh.default(x), [1] - ) # replaceable static - return o1, o2, o3, o4, o5 + o1 = torch.ops.aten.view_copy.default(x, [1]) + o2 = torch.ops.aten.view_copy.default(self.parameter, [1]) + return o1, o2 ep = torch.export.export( TestViewCopies(), args=(torch.ones(1),), ) - self.assertEqual(len(ep.graph.nodes), 11) for node in ep.graph.nodes: if node.op == "placeholder": node.meta["spec"] = TensorSpec.from_tensor(torch.empty(1)) node.meta["spec"].shape_dynamism = TensorShapeDynamism.STATIC - elif node.target == torch.ops.aten.relu.default: - node.meta["spec"] = TensorSpec.from_tensor(torch.empty(1)) - node.meta["spec"].shape_dynamism = TensorShapeDynamism.DYNAMIC_UNBOUND - elif node.target == torch.ops.aten.gelu.default: - node.meta["spec"] = TensorSpec.from_tensor(torch.empty(1)) - node.meta["spec"].shape_dynamism = TensorShapeDynamism.DYNAMIC_BOUND - elif node.target == torch.ops.aten.tanh.default: - node.meta["spec"] = TensorSpec.from_tensor(torch.empty(1)) - node.meta["spec"].shape_dynamism = TensorShapeDynamism.STATIC - elif node.target == torch.ops.aten.view_copy.default: - node.meta["spec"] = TensorSpec.from_tensor(torch.empty(1)) - node.meta["spec"].shape_dynamism = ( - node.args[0].meta["spec"].shape_dynamism - ) - else: - pass # Run tests gm = ep.graph_module # Check before transformation - n_view_copy_before = 0 - n_memory_view_before = 0 - for node in gm.graph.nodes: - if is_view_copy(node): - n_view_copy_before += 1 - if is_memory_view(node): - n_memory_view_before += 1 - - self.assertEqual(n_view_copy_before, 5) - self.assertEqual(n_memory_view_before, 0) + FileCheck().check_count( + "torch.ops.aten.view_copy.default", 2, exactly=True + ).run(gm.code) + FileCheck().check_count("executorch_exir_memory_view", 0, exactly=True).run( + gm.code + ) # Do transformation p = ReplaceViewCopyWithViewPass() @@ -1560,14 +1528,10 @@ def forward(self, x): assert gm_res is not None gm = gm_res.graph_module - # Check after transformation - n_view_copy_after = 0 - n_memory_view_after = 0 - for node in gm.graph.nodes: - if is_view_copy(node): - n_view_copy_after += 1 - if is_memory_view(node): - n_memory_view_after += 1 - - self.assertEqual(n_view_copy_after, 0) - self.assertEqual(n_memory_view_after, 5) + # Check before transformation + FileCheck().check_count( + "torch.ops.aten.view_copy.default", 0, exactly=True + ).run(gm.code) + FileCheck().check_count("executorch_exir_memory_view", 2, exactly=True).run( + gm.code + ) diff --git a/exir/tests/test_quant_fusion_pass.py b/exir/tests/test_quant_fusion_pass.py index 00269da92d7..69610a73abe 100644 --- a/exir/tests/test_quant_fusion_pass.py +++ b/exir/tests/test_quant_fusion_pass.py @@ -117,7 +117,7 @@ def forward(self, x, y): m.exported_program.graph_module.code ) - m = m.to_executorch() + m = m.to_executorch(exir.ExecutorchBackendConfig(remove_view_copy=False)) # check that we are using out variant of q/dq/add FileCheck().check("torch.ops.quantized_decomposed.add.out").check( "torch.ops.aten.view_copy.out" diff --git a/exir/tests/test_remove_view_copy.py b/exir/tests/test_remove_view_copy.py new file mode 100644 index 00000000000..0c5b61f8d8f --- /dev/null +++ b/exir/tests/test_remove_view_copy.py @@ -0,0 +1,202 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import unittest + +import torch +import torch.nn as nn +from executorch.exir import memory, to_edge +from executorch.exir.capture._config import ExecutorchBackendConfig +from executorch.exir.passes import MemoryPlanningPass + + +class TestModel1(nn.Module): + def __init__(self): + super().__init__() + self.parameter = nn.Parameter(torch.rand(5, 6)) + self.parameter.requires_grad = False + + def forward(self, x): + v1 = self.parameter.view( + 6, 5 + ) # removed, lifetime of parameter will be extended + v2 = x.view(6, 5) # not removed + v3 = torch.ops.aten.mul.Tensor(v1, v2).view( + 30 + ) # removed, lifetime of mul.Tensor will be extended + return v3 + + def get_example_inputs(self): + return (torch.rand(5, 6),) + + +class TestRemoveViewCopy(unittest.TestCase): + def test_disable(self) -> None: + model = TestModel1() + model.eval() + example_inputs = model.get_example_inputs() + ep = torch.export.export(model, example_inputs) + etpm = to_edge(ep).to_executorch( + config=ExecutorchBackendConfig( + remove_view_copy=False, + memory_planning_pass=MemoryPlanningPass( + "greedy", alloc_graph_input=False + ), + ), + ) + + for node in etpm.exported_program().graph_module.graph.nodes: + assert node.target != memory.view + + def test_output_matches(self) -> None: + model = TestModel1() + model.eval() + example_inputs = model.get_example_inputs() + ep = torch.export.export(model, example_inputs) + + epm_remove = to_edge(ep) + epm_no_remove = copy.deepcopy( + epm_remove + ) # to_executorch modifies the edge_program, so we make a copy + + # Run pass with no removal + etpm_remove = epm_remove.to_executorch( + config=ExecutorchBackendConfig( + remove_view_copy=True, + memory_planning_pass=MemoryPlanningPass( + "greedy", alloc_graph_input=False + ), + ), + ) + + # Run pass with removal + etpm_no_remove = epm_no_remove.to_executorch( + config=ExecutorchBackendConfig( + remove_view_copy=True, + memory_planning_pass=MemoryPlanningPass( + "greedy", alloc_graph_input=False + ), + ), + ) + + out_remove = etpm_remove.exported_program().module()(*example_inputs) + out_no_remove = etpm_no_remove.exported_program().module()(*example_inputs) + + self.assertTrue(torch.allclose(out_remove, out_no_remove)) + + def test_spec(self) -> None: + model = TestModel1() + model.eval() + example_inputs = model.get_example_inputs() + ep = torch.export.export(model, example_inputs) + + etpm = to_edge(ep).to_executorch( + config=ExecutorchBackendConfig( + remove_view_copy=True, + memory_planning_pass=MemoryPlanningPass( + "greedy", alloc_graph_input=False + ), + ), + ) + + # etpm.exported_program().graph.print_tabular() + + # idx opcode name target args kwargs + # --- ------------- ------------------------ ---------------------------------- -------------------------------------------------- -------------- + # 0 placeholder p_parameter p_parameter () {} + # 1 placeholder x x () {} + # 2 call_function aten_view_copy_default (p_parameter, [6, 5]) {} + # 3 call_function aten_view_copy_default_1 (x, [6, 5]) {} + # 4 call_function alloc (((6, 5), torch.float32),) {} + # 5 call_function aten_mul_tensor aten.mul.out (aten_view_copy_default, aten_view_copy_default_1) {'out': alloc} + # 6 call_function aten_view_copy_default_2 (aten_mul_tensor, [30]) {} + # 7 output output_1 output ((aten_view_copy_default_2,),) {} + + for node in etpm.exported_program().graph.nodes: + if node.name == "p_parameter": + # p_parameter's lifetime is extended through aten_view_copy_default (memory.view) to idx 5 + self.assertEqual(node.meta["spec"].lifetime, [0, 5]) + elif node.name == "aten_view_copy_default": + # aten_view_copy_default is a memory.view of p_parameter. + # p_parameter is a constant with storage, so we check that the view's storage matches the base + + # assert base is p_parameter + self.assertEqual(node.args[0].name, "p_parameter") + + # assert base is const with storage + self.assertTrue(node.args[0].meta["spec"].const) + self.assertTrue(node.args[0].meta["spec"].storage is not None) + self.assertTrue(node.args[0].meta["spec"].mem_id is None) + self.assertTrue(node.args[0].meta["spec"].mem_offset is None) + + # assert self is const with storage + self.assertTrue(node.meta["spec"].const) + self.assertTrue(node.meta["spec"].storage is not None) + self.assertTrue(node.meta["spec"].mem_id is None) + self.assertTrue(node.meta["spec"].mem_offset is None) + + # assert storage matches + self.assertEqual( + node.meta["spec"].storage, node.args[0].meta["spec"].storage + ) + + # assert lifetime matches + self.assertEqual( + node.meta["spec"].lifetime, node.args[0].meta["spec"].lifetime + ) + elif node.name == "aten_mul_tensor": + # aten_mul_tensor's lifetime is extended through aten_view_copy_default_2 (memory.view) to idx 7 + self.assertEqual(node.meta["spec"].lifetime, [4, 7]) + elif node.name == "aten_view_copy_default_2": + # aten_view_copy_default_2 is a memory.view of aten_mul_tensor + + # assert base is aten_mul_tensor + self.assertEqual(node.args[0].name, "aten_mul_tensor") + + # assert base and self are not const, do not have storage, + # but do have mem_id and mem_offset + self.assertFalse(node.args[0].meta["spec"].const) + self.assertTrue(node.args[0].meta["spec"].storage is None) + self.assertTrue(node.args[0].meta["spec"].mem_id is not None) + self.assertTrue(node.args[0].meta["spec"].mem_offset is not None) + + self.assertFalse(node.meta["spec"].const) + self.assertTrue(node.meta["spec"].storage is None) + self.assertTrue(node.meta["spec"].mem_id is not None) + self.assertTrue(node.meta["spec"].mem_offset is not None) + + # assert self and base mem_id, mem_offset, and lifetime matches + self.assertEqual( + node.meta["spec"].mem_id, node.args[0].meta["spec"].mem_id + ) + self.assertEqual( + node.meta["spec"].mem_offset, node.args[0].meta["spec"].mem_offset + ) + self.assertEqual( + node.meta["spec"].lifetime, node.args[0].meta["spec"].lifetime + ) + + # Test evalues in execution plan + plan = etpm.executorch_program.execution_plan[0] + self.assertEqual(plan.operators[0].name, "executorch_prim::et_view") + self.assertEqual(plan.operators[1].name, "aten::mul") + + instructions = plan.chains[0].instructions + self.assertEqual(len(instructions), 4) + + self.assertEqual( + instructions[0].instr_args.op_index, 0 # pyre-ignore + ) # view @ idx2 + self.assertEqual( + instructions[1].instr_args.op_index, 0 # pyre-ignore + ) # view @ idx3 + self.assertEqual( + instructions[2].instr_args.op_index, 1 # pyre-ignore + ) # aten:mul @ idx5 + self.assertEqual( + instructions[3].instr_args.op_index, 0 # pyre-ignore + ) # view @ idx6