1717#pylint: disable=unused-argument, not-context-manager
1818"""Automatic quantization toolkit."""
1919import tvm .ir
20+ import tvm
2021from tvm .runtime import Object
2122
2223from . import _quantize
@@ -240,7 +241,7 @@ def partition():
240241
241242 Returns
242243 -------
243- ret: tvm.relay .Pass
244+ ret: tvm.transform .Pass
244245 The registered pass for VTA rewrite.
245246 """
246247 return _quantize .QuantizePartition ()
@@ -253,7 +254,7 @@ def annotate():
253254
254255 Returns
255256 -------
256- ret: tvm.relay .Pass
257+ ret: tvm.transform .Pass
257258 The registered pass for quantization annotation.
258259 """
259260 return _quantize .QuantizeAnnotate ()
@@ -267,7 +268,7 @@ def realize():
267268
268269 Returns
269270 -------
270- ret: tvm.relay .Pass
271+ ret: tvm.transform .Pass
271272 The registered pass for quantization realization.
272273 """
273274 return _quantize .QuantizeRealize ()
@@ -298,11 +299,12 @@ def prerequisite_optimize(mod, params=None):
298299 """ Prerequisite optimization passes for quantization. Perform
299300 "SimplifyInference", "FoldScaleAxis", "FoldConstant", and
300301 "CanonicalizeOps" optimization before quantization. """
301- optimize = _transform .Sequential ([_transform .SimplifyInference (),
302- _transform .FoldConstant (),
303- _transform .FoldScaleAxis (),
304- _transform .CanonicalizeOps (),
305- _transform .FoldConstant ()])
302+ optimize = tvm .transform .Sequential (
303+ [_transform .SimplifyInference (),
304+ _transform .FoldConstant (),
305+ _transform .FoldScaleAxis (),
306+ _transform .CanonicalizeOps (),
307+ _transform .FoldConstant ()])
306308
307309 if params :
308310 mod ['main' ] = _bind_params (mod ['main' ], params )
@@ -336,19 +338,20 @@ def quantize(mod, params=None, dataset=None):
336338 """
337339 mod = prerequisite_optimize (mod , params )
338340
339- calibrate_pass = _transform .module_pass (calibrate (dataset ), opt_level = 1 ,
340- name = "QuantizeCalibrate" )
341+ calibrate_pass = tvm .transform .module_pass (
342+ calibrate (dataset ), opt_level = 1 ,
343+ name = "QuantizeCalibrate" )
341344 quant_passes = [partition (),
342345 annotate (),
343346 calibrate_pass ]
344347 if not current_qconfig ().do_simulation :
345348 quant_passes .append (realize ())
346349 quant_passes .append (_transform .FoldConstant ())
347- quantize_seq = _transform .Sequential (quant_passes )
348- with _transform .PassContext (opt_level = 3 ,
349- required_pass = ["QuantizeAnnotate" ,
350- "QuantizeCalibrate" ,
351- "QuantizeRealize" ]):
350+ quantize_seq = tvm . transform .Sequential (quant_passes )
351+ with tvm . transform .PassContext (opt_level = 3 ,
352+ required_pass = ["QuantizeAnnotate" ,
353+ "QuantizeCalibrate" ,
354+ "QuantizeRealize" ]):
352355 with quantize_context ():
353356 mod = quantize_seq (mod )
354357
0 commit comments