diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index efe30e5cbb50..027fd6f824db 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -214,7 +214,9 @@ TVM_DLL Pass BindSymbolicVars(Map binding_map, Optional func_name = NullOpt); /*! - * \brief Fold constant expressions. + * \brief Fold constant expressions within dataflow blocks. + * + * \note ConvertToDataflow may need to be called first to provide dataflow blocks. * * \return The Pass. */ @@ -458,6 +460,8 @@ class PatternCheckContext : public ObjectRef { * of the return value as the target. If it is not specified, the first return value will be the * target. * \return The Pass. + * + * \note ConvertToDataflow may need to be called first to provide dataflow blocks. */ TVM_DLL Pass Gradient(String func_name, Optional> require_grads = NullOpt, int target_index = 0); @@ -477,6 +481,8 @@ TVM_DLL Pass Gradient(String func_name, Optional> require_grads = Nul * This must be True if the created composite functions are intended to be offloaded to * an external backend without using the MergeCompositeFunctions pass. * \return The Pass. + * + * \note Only operates within dataflow blocks. ConvertToDataflow may need to be called first. */ TVM_DLL Pass FuseOpsByPattern(const tvm::Array& patterns, bool bind_constants = true, bool annotate_codegen = false); @@ -548,6 +554,7 @@ TVM_DLL Pass AlterOpImpl(const Map& op_impl_map, * \brief Layout conversion pass. * \param desired_layouts The desired layouts for some operators. * \return The Pass. + * \note Operates only on dataflow blocks. ConvertToDataflow may need to be called first. */ TVM_DLL Pass ConvertLayout(Map> desired_layouts); @@ -564,10 +571,13 @@ TVM_DLL Pass ConvertToDataflow(int min_size = 2); * \brief Dead code elimination. * \sa RemoveAllUnused * Currently it removes: - * 1. Unused local VarBindings in a DataflowBlock. - * 2. Unused DataflowBlocks in a function. - * 3. Unused Relax functions in the module. + * 1. Unused local VarBindings + * (those where the bound var is unused and no impure operation is used). + * 2. Unused Relax functions in the module. * We detect the call chain from the entry function, and remove all unused functions. + * + * Any binding blocks that are left empty will be removed by the normalizer. + * * \return The Pass. */ TVM_DLL Pass DeadCodeElimination(Array entry_functions); @@ -578,6 +588,7 @@ TVM_DLL Pass DeadCodeElimination(Array entry_functions); * Supported operators will be replaced by calls to `call_tir_inplace` that invoke in-place * PrimFunc implementations of those operators (which are based on the legalizations of those * operators). + * \note ConvertToDataflow may need to be called first to provide dataflow blocks. * \return The pass. */ TVM_DLL Pass DataflowUseInplaceCalls(); @@ -589,6 +600,8 @@ TVM_DLL Pass DataflowUseInplaceCalls(); * \param fp16_input_names The names of function parameters whose dtype should become fp16. The * function signature would change accordingly. * \return The Pass. + * + * \note Mainly operates within dataflow blocks. ConvertToDataflow may need to be called first. */ TVM_DLL Pass ToMixedPrecision(const DataType& out_dtype, Optional> fp16_input_names = NullOpt); diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index e360c09392f3..b2aaa3e331a1 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -52,7 +52,7 @@ def Gradient( """Reverse-mode automatic differentiation. This pass will differentiate one function in the IRModule. Now the input function must have only - one dataflow block. + one dataflow block (ConvertToDataflow may need to be called first). For a given function specified by `func_name`, it generates a new function with the name `func_name + "_adjoint"`. The new function computes the gradient of the **differentiation @@ -260,6 +260,8 @@ def DataflowUseInplaceCalls() -> tvm.ir.transform.Pass: in-place PrimFunc implementations of those operators (which are based on the legalizations of those operators). + Note: ConvertToDataflow may need to be called first to provide dataflow blocks. + Returns ------- ret: tvm.ir.transform.Pass @@ -282,6 +284,8 @@ def ConvertToDataflow(min_size: int = 2) -> tvm.ir.transform.Pass: """A pass that converts consecutive dataflow operations inside binding blocks into dataflow blocks. + Note: ConvertToDataflow may need to be called first. + Params ------ min_size: int @@ -395,6 +399,8 @@ def RewriteDataflowReshape() -> tvm.ir.transform.Pass: operation at runtime, instead of doing real data copy. Here "reshape-like" includes reshape, expand_dims, flatten, etc. + Note: Operates only in dataflow blocks. ConvertToDataflow may need to be called first. + Returns ------- ret : tvm.ir.transform.Pass @@ -584,7 +590,9 @@ def RunCodegen( def FoldConstant() -> tvm.ir.transform.Pass: - """Fold constant expressions. + """Fold constant expressions within dataflow blocks. + + Note: ConvertToDataflow may need to be called first to provide dataflow blocks. Returns ------- @@ -651,6 +659,8 @@ def FuseOps(fuse_opt_level=-1) -> tvm.ir.transform.Pass: A follow-up pass named "FuseTIR" will generate a TIR PrimFunc for each grouped function. + Note: ConvertToDataflow may need to be called first to provide dataflow blocks. + Parameters ---------- fuse_opt_level : int @@ -764,6 +774,8 @@ def FuseOpsByPattern( The end result is similar to FuseOps, but fusion is driven completely by the provided patterns. + Note: Only operates within dataflow blocks. ConvertToDataflow may need to be called first. + Parameters ---------- patterns : List[Union[FusionPattern, Tuple]] @@ -1172,11 +1184,12 @@ def DeadCodeElimination(entry_functions: Optional[List[str]] = None) -> tvm.ir.t """Remove dead code in the IRModule. Currently it removes: - 1. Unused local VarBindings in a DataflowBlock. - 2. Unused DataflowBlocks in a function. - 3. Unused Relax functions in the module. + 1. Unused local VarBindings + (those where the bound var is unused and no impure operation is used). + 2. Unused Relax functions in the module. We detect the call chain from the entry function, and remove all unused functions. + Any binding blocks that are left empty will be removed by the normalizer. Notes ----- @@ -1203,6 +1216,8 @@ def ToMixedPrecision( """Automatic mixed precision pass. Currently the pass assumes the input module to be fp32 only, and will automatically cast fp32 to fp16 for certain ops. + Note: Mainly operates within dataflow blocks. ConvertToDataflow may need to be called first. + Parameters ---------- out_dtype : str diff --git a/src/relax/transform/dead_code_elimination.cc b/src/relax/transform/dead_code_elimination.cc index 248e4c1c00b7..73f66d2ef362 100644 --- a/src/relax/transform/dead_code_elimination.cc +++ b/src/relax/transform/dead_code_elimination.cc @@ -24,10 +24,12 @@ * \sa tvm/relax/ir/binding_rewrite.cc * * Currently it removes: - * 1. Unused local VarBindings in a DataflowBlock. - * 2. Unused DataflowBlocks in a function. - * 3. Unused Relax functions in the module. + * 1. Unused local VarBindings + * (those where the bound var is unused and no impure operation is used). + * 2. Unused Relax functions in the module. * We detect the call chain from the entry function, and remove all unused functions. + * + * Any binding blocks that are left empty will be removed by the normalizer. */ #include