Skip to content

Commit 2b6e845

Browse files
tqchendhruvaray
authored andcommitted
[RELAY] Remove re-exports of tvm.transform (apache#5337)
1 parent 16d3da1 commit 2b6e845

38 files changed

Lines changed: 169 additions & 229 deletions

docs/api/python/ir.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,11 @@ tvm.ir
2121
:members:
2222
:imported-members:
2323
:autosummary:
24+
25+
26+
tvm.transform
27+
-------------
28+
.. automodule:: tvm.transform
29+
:members:
30+
:imported-members:
31+
:autosummary:

docs/dev/convert_layout.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ ConvertLayout pass is extremely easy to use. The pass is not a part of default r
227227
228228
# Convert the layout to NCHW
229229
# RemoveUnunsedFunctions is used to clean up the graph.
230-
seq = relay.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
230+
seq = tvm.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
231231
relay.transform.ConvertLayout('NCHW')])
232232
with relay.transform.PassContext(opt_level=3):
233233
mod = seq(mod)

docs/dev/relay_pass_infra.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,7 @@ using ``Sequential`` associated with other types of passes.
582582
func = relay.Function([x], z2)
583583
584584
# Customize the optimization pipeline.
585-
seq = _transform.Sequential([
585+
seq = tvm.transform.Sequential([
586586
relay.transform.InferType(),
587587
relay.transform.FoldConstant(),
588588
relay.transform.EliminateCommonSubexpr(),
@@ -609,7 +609,7 @@ sequential pass example could be like the following to enable IR dumping for
609609

610610
.. code:: python
611611
612-
seq = _transform.Sequential([
612+
seq = tvm.transform.Sequential([
613613
relay.transform.InferType(),
614614
relay.transform.FoldConstant(),
615615
relay.transform.PrintIR(),

include/tvm/ir/transform.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,9 +361,11 @@ TVM_DLL Pass CreateModulePass(
361361

362362
/*!
363363
* \brief A special trace pass that prints the header and IR to LOG(INFO).
364+
* \param header The header to be attached to the output.
365+
* \param show_meta_data Whether should we show meta data.
364366
* \return The pass.
365367
*/
366-
TVM_DLL Pass PrintIR(std::string header);
368+
TVM_DLL Pass PrintIR(std::string header = "", bool show_meta_data = false);
367369

368370
} // namespace transform
369371
} // namespace tvm

python/tvm/ir/json_compact.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def _update_global_key(item, _):
106106
"relay.PassInfo": _rename("transform.PassInfo"),
107107
"relay.PassContext": _rename("transform.PassContext"),
108108
"relay.ModulePass": _rename("transform.ModulePass"),
109-
"relay.Sequantial": _rename("transform.Sequantial"),
109+
"relay.Sequential": _rename("transform.Sequential"),
110110
# TIR
111111
"Variable": _update_tir_var("tir.Var"),
112112
"SizeVar": _update_tir_var("tir.SizeVar"),

python/tvm/ir/transform.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -329,16 +329,19 @@ def create_module_pass(pass_arg):
329329
return create_module_pass
330330

331331

332-
def PrintIR(header):
332+
def PrintIR(header="", show_meta_data=False):
333333
"""A special trace pass that prints the header and IR.
334334
335335
Parameters
336336
----------
337337
header : str
338338
The header to be displayed along with the dump.
339339
340+
show_meta_data : bool
341+
A boolean flag to indicate if meta data should be printed.
342+
340343
Returns
341344
--------
342345
The pass
343346
"""
344-
return _ffi_transform_api.PrintIR(header)
347+
return _ffi_transform_api.PrintIR(header, show_meta_data)

python/tvm/relay/__init__.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -128,20 +128,9 @@
128128
# Scope builder
129129
ScopeBuilder = scope_builder.ScopeBuilder
130130

131-
module_pass = transform.module_pass
132-
function_pass = transform.function_pass
133-
134131
# Parser
135132
fromtext = parser.fromtext
136133

137134
# Param Serialization
138135
save_param_dict = param_dict.save_param_dict
139136
load_param_dict = param_dict.load_param_dict
140-
141-
# Pass manager
142-
PassInfo = transform.PassInfo
143-
PassContext = transform.PassContext
144-
Pass = transform.Pass
145-
ModulePass = transform.ModulePass
146-
FunctionPass = transform.FunctionPass
147-
Sequential = transform.Sequential

python/tvm/relay/backend/interpreter.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -210,10 +210,10 @@ def optimize(self):
210210
opt_mod : tvm.IRModule
211211
The optimized module.
212212
"""
213-
seq = transform.Sequential([transform.SimplifyInference(),
214-
transform.FuseOps(0),
215-
transform.ToANormalForm(),
216-
transform.InferType()])
213+
seq = tvm.transform.Sequential([transform.SimplifyInference(),
214+
transform.FuseOps(0),
215+
transform.ToANormalForm(),
216+
transform.InferType()])
217217
return seq(self.mod)
218218

219219
def _make_executor(self, expr=None):

python/tvm/relay/qnn/transform.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def @main(%quantized_data: Tensor[(200), int32]) -> Tensor[(200), int8] {
6060
6161
Returns
6262
-------
63-
ret : tvm.relay.Pass
63+
ret : tvm.transform.Pass
6464
The registered pass that canonicalizes QNN ops to Relay ops.
6565
"""
6666

@@ -108,7 +108,7 @@ def Legalize():
108108
109109
Returns
110110
-------
111-
ret : tvm.relay.Pass
111+
ret : tvm.transform.Pass
112112
The registered pass that legalizes QNN ops.
113113
"""
114114

python/tvm/relay/quantize/quantize.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#pylint: disable=unused-argument, not-context-manager
1818
"""Automatic quantization toolkit."""
1919
import tvm.ir
20+
import tvm
2021
from tvm.runtime import Object
2122

2223
from . import _quantize
@@ -240,7 +241,7 @@ def partition():
240241
241242
Returns
242243
-------
243-
ret: tvm.relay.Pass
244+
ret: tvm.transform.Pass
244245
The registered pass for VTA rewrite.
245246
"""
246247
return _quantize.QuantizePartition()
@@ -253,7 +254,7 @@ def annotate():
253254
254255
Returns
255256
-------
256-
ret: tvm.relay.Pass
257+
ret: tvm.transform.Pass
257258
The registered pass for quantization annotation.
258259
"""
259260
return _quantize.QuantizeAnnotate()
@@ -267,7 +268,7 @@ def realize():
267268
268269
Returns
269270
-------
270-
ret: tvm.relay.Pass
271+
ret: tvm.transform.Pass
271272
The registered pass for quantization realization.
272273
"""
273274
return _quantize.QuantizeRealize()
@@ -298,11 +299,12 @@ def prerequisite_optimize(mod, params=None):
298299
""" Prerequisite optimization passes for quantization. Perform
299300
"SimplifyInference", "FoldScaleAxis", "FoldConstant", and
300301
"CanonicalizeOps" optimization before quantization. """
301-
optimize = _transform.Sequential([_transform.SimplifyInference(),
302-
_transform.FoldConstant(),
303-
_transform.FoldScaleAxis(),
304-
_transform.CanonicalizeOps(),
305-
_transform.FoldConstant()])
302+
optimize = tvm.transform.Sequential(
303+
[_transform.SimplifyInference(),
304+
_transform.FoldConstant(),
305+
_transform.FoldScaleAxis(),
306+
_transform.CanonicalizeOps(),
307+
_transform.FoldConstant()])
306308

307309
if params:
308310
mod['main'] = _bind_params(mod['main'], params)
@@ -336,19 +338,20 @@ def quantize(mod, params=None, dataset=None):
336338
"""
337339
mod = prerequisite_optimize(mod, params)
338340

339-
calibrate_pass = _transform.module_pass(calibrate(dataset), opt_level=1,
340-
name="QuantizeCalibrate")
341+
calibrate_pass = tvm.transform.module_pass(
342+
calibrate(dataset), opt_level=1,
343+
name="QuantizeCalibrate")
341344
quant_passes = [partition(),
342345
annotate(),
343346
calibrate_pass]
344347
if not current_qconfig().do_simulation:
345348
quant_passes.append(realize())
346349
quant_passes.append(_transform.FoldConstant())
347-
quantize_seq = _transform.Sequential(quant_passes)
348-
with _transform.PassContext(opt_level=3,
349-
required_pass=["QuantizeAnnotate",
350-
"QuantizeCalibrate",
351-
"QuantizeRealize"]):
350+
quantize_seq = tvm.transform.Sequential(quant_passes)
351+
with tvm.transform.PassContext(opt_level=3,
352+
required_pass=["QuantizeAnnotate",
353+
"QuantizeCalibrate",
354+
"QuantizeRealize"]):
352355
with quantize_context():
353356
mod = quantize_seq(mod)
354357

0 commit comments

Comments
 (0)