2121import numpy as np
2222
2323from ..measure import MeasureInput , create_measure_batch
24+ from ..util import format_si_prefix
2425
2526from ..env import GLOBAL_SCOPE
2627
@@ -87,7 +88,7 @@ def update(self, inputs, results):
8788 """
8889
8990
90- def tune (self , n_trial , measure_option , early_stopping = None , callbacks = ()):
91+ def tune (self , n_trial , measure_option , early_stopping = None , callbacks = (), si_prefix = 'G' ):
9192 """Begin tuning
9293
9394 Parameters
@@ -104,13 +105,18 @@ def tune(self, n_trial, measure_option, early_stopping=None, callbacks=()):
104105 (Tuner, List of MeasureInput, List of MeasureResult)
105106 with no return value. These callback functions will be called on
106107 every measurement pair. See autotvm/tuner/callback.py for some examples.
108+ si_prefix: str
109+ One of tvm.autotvm.util.SI_PREFIXES. The SI prefix to use when reporting FLOPS.
107110 """
108111 measure_batch = create_measure_batch (self .task , measure_option )
109112 n_parallel = getattr (measure_batch , 'n_parallel' , 1 )
110113 early_stopping = early_stopping or 1e9
111114 self .n_trial = n_trial
112115 self .early_stopping = early_stopping
113116
117+ # Validate si_prefix arg
118+ format_si_prefix (0 , si_prefix )
119+
114120 old_level = logger .level
115121
116122 GLOBAL_SCOPE .in_tuning = True
@@ -140,9 +146,9 @@ def tune(self, n_trial, measure_option, early_stopping=None, callbacks=()):
140146 self .best_measure_pair = (inp , res )
141147 self .best_iter = i + k
142148
143- logger .debug ("No: %d\t GFLOPS : %.2f/%.2f\t result: %s\t %s" ,
144- i + k + 1 , flops / 1e9 , self . best_flops / 1e9 ,
145- res , config )
149+ logger .debug ("No: %d\t %sFLOPS : %.2f/%.2f\t result: %s\t %s" ,
150+ i + k + 1 , si_prefix , format_si_prefix ( flops , si_prefix ) ,
151+ format_si_prefix ( self . best_flops , si_prefix ), res , config )
146152
147153 i += len (results )
148154 self .ttl = min (early_stopping + self .best_iter , n_trial ) - i
0 commit comments