From 194b7513d45059e3d15fb36c5dfcbaa74bcc27de Mon Sep 17 00:00:00 2001
From: Riley Dulin
Date: Wed, 10 Apr 2024 09:09:33 -0700
Subject: [PATCH] Fixes for constant_prop_pass (#2967)

Summary:
The constant_prop_pass didn't properly propagate constants when there were
simple primitives in the argument set. Extend it to treat floats, ints,
strings, etc. as constants.

This allows the pass to fold additional operations, such as quantize
functions applied to weights. Sometimes users don't want that, so allow them
to pass a callable that skips folding for selected nodes.

Differential Revision: D55942686
---
 exir/passes/constant_prop_pass.py | 52 +++++++++++++++++++++++++++----
 1 file changed, 46 insertions(+), 6 deletions(-)

diff --git a/exir/passes/constant_prop_pass.py b/exir/passes/constant_prop_pass.py
index 14ff651c936..764efffa18f 100644
--- a/exir/passes/constant_prop_pass.py
+++ b/exir/passes/constant_prop_pass.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import Callable, List, Optional
+
 import torch
 from torch._export.utils import get_buffer, get_param, is_buffer, is_param
 from torch._guards import detect_fake_mode
@@ -11,11 +13,27 @@
 from torch.export.exported_program import InputKind, InputSpec, TensorArgument
 
 
-def is_const(arg, exported_program, const_data_list) -> bool:
+_PRIMITIVE_TYPES = (
+    float,
+    int,
+    bool,
+    str,
+    torch.Tensor,
+    torch.device,
+    torch.dtype,
+    torch.layout,
+)
+
+
+def is_const(
+    arg: object, exported_program: ExportedProgram, const_data_list: List[str]
+) -> bool:
     if isinstance(arg, (tuple, list)):
         return all(is_const(x, exported_program, const_data_list) for x in arg)
     elif isinstance(arg, dict):
         return all(is_const(x, exported_program, const_data_list) for x in arg.values())
+    elif isinstance(arg, _PRIMITIVE_TYPES):
+        return True
     elif not isinstance(arg, torch.fx.Node) or arg.op != "placeholder":
         return False
     elif (
@@ -27,9 +45,11 @@ def is_const(arg, exported_program, const_data_list) -> bool:
     return False
 
 
-def get_data(exported_program, arg):
+def get_data(exported_program: ExportedProgram, arg):
     if isinstance(arg, (tuple, list)):
         return [get_data(exported_program, x) for x in arg]
+    elif isinstance(arg, _PRIMITIVE_TYPES):
+        return arg
     elif is_param(exported_program, arg):
         return get_param(exported_program, arg)
     elif is_buffer(exported_program, arg):
@@ -37,7 +57,10 @@ def get_data(exported_program, arg):
     return None
 
 
-def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
+def constant_prop_pass(
+    exported_program: ExportedProgram,
+    skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
+) -> ExportedProgram:
     """
     This pass is for constant propagation for Exported Program with lifted parameters,
     as the parameters will not be shown up as `get_attr` but as `placeholder` to the graph.
@@ -56,12 +79,14 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
     if len(has_cond) > 0:
         raise RuntimeError("constant_prop_pass for control flow is not supported yet.")
 
+    first_user_input_idx = -1
     first_user_input = None
-    for node in exported_program.graph.nodes:
+    for i, node in enumerate(exported_program.graph.nodes):
         if (
             node.op == "placeholder"
             and node.name in exported_program.graph_signature.user_inputs
         ):
+            first_user_input_idx = i
             first_user_input = node
             break
 
@@ -79,6 +104,9 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
     assert fake_mode is not None
 
     for node in exported_program.graph.nodes:
+        if skip_folding_node_fn is not None and skip_folding_node_fn(node):
+            # Do not process this node if we were told to skip it.
+            continue
         if node.op == "call_function":
             constant_data_name_list = [
                 input_spec.target for input_spec in prop_constant_data
@@ -115,9 +143,11 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
                     exported_program.state_dict[prop_constant_tensor_fqn] = (
                         prop_constant_tensor
                     )
-                    exported_program.graph_signature.input_specs.append(
-                        prop_constant_node_input_spec
+                    # Insert new buffers before the first user input.
+                    exported_program.graph_signature.input_specs.insert(
+                        first_user_input_idx, prop_constant_node_input_spec
                     )
+                    first_user_input_idx += 1
 
     # Remove the propogated buffer from the state dict
     for node in exported_program.graph.nodes:
@@ -128,6 +158,16 @@ def constant_prop_pass(exported_program: ExportedProgram) -> ExportedProgram:
         ):
             exported_program.state_dict.pop(node.name, None)
             exported_program.graph.erase_node(node)
+            # Delete the input spec for this deleted buffer.
+            to_erase_idx = []
+            for i, spec in enumerate(exported_program.graph_signature.input_specs):
+                if spec.arg.name == node.name:
+                    to_erase_idx.append(i)
+            assert (
+                len(to_erase_idx) == 1
+            ), f"Should only delete one spec per node, but deleting multiple: {to_erase_idx} {exported_program.graph_signature.input_specs}"
+            for i in reversed(to_erase_idx):
+                exported_program.graph_signature.input_specs.pop(i)
 
     exported_program.graph_module.recompile()
     return exported_program
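
For reference, a minimal usage sketch of the new skip hook (not part of the patch).
It assumes the module path executorch.exir.passes.constant_prop_pass, an
already-exported program named `ep`, and a made-up predicate name `skip_quantize`:

    import torch
    from executorch.exir.passes.constant_prop_pass import constant_prop_pass

    def skip_quantize(node: torch.fx.Node) -> bool:
        # Return True for nodes that should NOT be constant-folded,
        # e.g. quantize ops on weights the user wants to keep in the graph.
        return "quantize" in str(node.target)

    # `ep` is an existing torch.export.ExportedProgram. With this change,
    # primitive args (ints, floats, strings, ...) count as constants, and
    # any node for which skip_quantize returns True is left unfolded.
    ep = constant_prop_pass(ep, skip_folding_node_fn=skip_quantize)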