Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions include/tvm/relax/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,9 @@ TVM_DLL Pass BindSymbolicVars(Map<ObjectRef, PrimExpr> binding_map,
Optional<String> func_name = NullOpt);

/*!
* \brief Fold constant expressions.
* \brief Fold constant expressions within dataflow blocks.
*
* \note ConvertToDataflow may need to be called first to provide dataflow blocks.
*
* \return The Pass.
*/
Expand Down Expand Up @@ -458,6 +460,8 @@ class PatternCheckContext : public ObjectRef {
* of the return value as the target. If it is not specified, the first return value will be the
* target.
* \return The Pass.
*
* \note ConvertToDataflow may need to be called first to provide dataflow blocks.
*/
TVM_DLL Pass Gradient(String func_name, Optional<Array<Var>> require_grads = NullOpt,
int target_index = 0);
Expand All @@ -477,6 +481,8 @@ TVM_DLL Pass Gradient(String func_name, Optional<Array<Var>> require_grads = Nul
* This must be True if the created composite functions are intended to be offloaded to
* an external backend without using the MergeCompositeFunctions pass.
* \return The Pass.
*
* \note Only operates within dataflow blocks. ConvertToDataflow may need to be called first.
*/
TVM_DLL Pass FuseOpsByPattern(const tvm::Array<FusionPattern>& patterns, bool bind_constants = true,
bool annotate_codegen = false);
Expand Down Expand Up @@ -548,6 +554,7 @@ TVM_DLL Pass AlterOpImpl(const Map<String, tir::PrimFunc>& op_impl_map,
* \brief Layout conversion pass.
* \param desired_layouts The desired layouts for some operators.
* \return The Pass.
* \note Operates only on dataflow blocks. ConvertToDataflow may need to be called first.
*/
TVM_DLL Pass ConvertLayout(Map<String, Array<String>> desired_layouts);

Expand All @@ -564,10 +571,13 @@ TVM_DLL Pass ConvertToDataflow(int min_size = 2);
* \brief Dead code elimination.
* \sa RemoveAllUnused
* Currently it removes:
* 1. Unused local VarBindings in a DataflowBlock.
* 2. Unused DataflowBlocks in a function.
* 3. Unused Relax functions in the module.
* 1. Unused local VarBindings
* (those where the bound var is unused and no impure operation is used).
* 2. Unused Relax functions in the module.
* We detect the call chain from the entry function, and remove all unused functions.
*
* Any binding blocks that are left empty will be removed by the normalizer.
*
* \return The Pass.
*/
TVM_DLL Pass DeadCodeElimination(Array<runtime::String> entry_functions);
Expand All @@ -578,6 +588,7 @@ TVM_DLL Pass DeadCodeElimination(Array<runtime::String> entry_functions);
* Supported operators will be replaced by calls to `call_tir_inplace` that invoke in-place
* PrimFunc implementations of those operators (which are based on the legalizations of those
* operators).
* \note ConvertToDataflow may need to be called first to provide dataflow blocks.
* \return The pass.
*/
TVM_DLL Pass DataflowUseInplaceCalls();
Expand All @@ -589,6 +600,8 @@ TVM_DLL Pass DataflowUseInplaceCalls();
* \param fp16_input_names The names of function parameters whose dtype should become fp16. The
* function signature would change accordingly.
* \return The Pass.
*
* \note Mainly operates within dataflow blocks. ConvertToDataflow may need to be called first.
*/
TVM_DLL Pass ToMixedPrecision(const DataType& out_dtype,
Optional<Array<String>> fp16_input_names = NullOpt);
Expand Down
25 changes: 20 additions & 5 deletions python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def Gradient(
"""Reverse-mode automatic differentiation.

This pass will differentiate one function in the IRModule. Now the input function must have only
one dataflow block.
one dataflow block (ConvertToDataflow may need to be called first).

For a given function specified by `func_name`, it generates a new function with the name
`func_name + "_adjoint"`. The new function computes the gradient of the **differentiation
Expand Down Expand Up @@ -260,6 +260,8 @@ def DataflowUseInplaceCalls() -> tvm.ir.transform.Pass:
in-place PrimFunc implementations of those operators (which are based on the legalizations of
those operators).

Note: ConvertToDataflow may need to be called first to provide dataflow blocks.

Returns
-------
ret: tvm.ir.transform.Pass
Expand All @@ -282,6 +284,8 @@ def ConvertToDataflow(min_size: int = 2) -> tvm.ir.transform.Pass:
"""A pass that converts consecutive dataflow operations
inside binding blocks into dataflow blocks.

Note: ConvertToDataflow may need to be called first.

Params
------
min_size: int
Expand Down Expand Up @@ -395,6 +399,8 @@ def RewriteDataflowReshape() -> tvm.ir.transform.Pass:
operation at runtime, instead of doing real data copy.
Here "reshape-like" includes reshape, expand_dims, flatten, etc.

Note: Operates only in dataflow blocks. ConvertToDataflow may need to be called first.

Returns
-------
ret : tvm.ir.transform.Pass
Expand Down Expand Up @@ -584,7 +590,9 @@ def RunCodegen(


def FoldConstant() -> tvm.ir.transform.Pass:
"""Fold constant expressions.
"""Fold constant expressions within dataflow blocks.

Note: ConvertToDataflow may need to be called first to provide dataflow blocks.

Returns
-------
Expand Down Expand Up @@ -651,6 +659,8 @@ def FuseOps(fuse_opt_level=-1) -> tvm.ir.transform.Pass:

A follow-up pass named "FuseTIR" will generate a TIR PrimFunc for each grouped function.

Note: ConvertToDataflow may need to be called first to provide dataflow blocks.

Parameters
----------
fuse_opt_level : int
Expand Down Expand Up @@ -764,6 +774,8 @@ def FuseOpsByPattern(

The end result is similar to FuseOps, but fusion is driven completely by the provided patterns.

Note: Only operates within dataflow blocks. ConvertToDataflow may need to be called first.

Parameters
----------
patterns : List[Union[FusionPattern, Tuple]]
Expand Down Expand Up @@ -1172,11 +1184,12 @@ def DeadCodeElimination(entry_functions: Optional[List[str]] = None) -> tvm.ir.t
"""Remove dead code in the IRModule.
Currently it removes:

1. Unused local VarBindings in a DataflowBlock.
2. Unused DataflowBlocks in a function.
3. Unused Relax functions in the module.
1. Unused local VarBindings
(those where the bound var is unused and no impure operation is used).
2. Unused Relax functions in the module.
We detect the call chain from the entry function, and remove all unused functions.

Any binding blocks that are left empty will be removed by the normalizer.

Notes
-----
Expand All @@ -1203,6 +1216,8 @@ def ToMixedPrecision(
"""Automatic mixed precision pass. Currently the pass assumes the input module to be fp32
only, and will automatically cast fp32 to fp16 for certain ops.

Note: Mainly operates within dataflow blocks. ConvertToDataflow may need to be called first.

Parameters
----------
out_dtype : str
Expand Down
8 changes: 5 additions & 3 deletions src/relax/transform/dead_code_elimination.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@
* \sa tvm/relax/ir/binding_rewrite.cc
*
* Currently it removes:
* 1. Unused local VarBindings in a DataflowBlock.
* 2. Unused DataflowBlocks in a function.
* 3. Unused Relax functions in the module.
* 1. Unused local VarBindings
* (those where the bound var is unused and no impure operation is used).
* 2. Unused Relax functions in the module.
* We detect the call chain from the entry function, and remove all unused functions.
*
* Any binding blocks that are left empty will be removed by the normalizer.
*/

#include <tvm/relax/analysis.h>
Expand Down