@@ -205,14 +205,16 @@ def Start(self):
205205 self ._task .params .eval .load_checkpoint_from )
206206
207207 if self ._eval_path :
208- self ._EvalOnce (self ._eval_path )
208+ self ._EvalOnce (path = self ._eval_path )
209209 py_utils .UpdateProcessedCheckpoints (self ._eval_dir , self ._eval_path )
210210 elif self ._task .params .eval .eval_all_checkpoints :
211- self ._RunOnAllCheckpoints (self ._EvalOnce , self ._eval_dir )
211+ self ._RunOnAllCheckpoints (
212+ runner_fn = self ._EvalOnce , runner_dir = self ._eval_dir )
212213 else :
213- self ._RunOnLatestCheckpoints (self ._EvalOnce , self ._eval_dir )
214+ self ._RunOnLatestCheckpoints (
215+ runner_fn = self ._EvalOnce , runner_dir = self ._eval_dir )
214216
215- def _EvalOnce (self , path ):
217+ def _EvalOnce (self , sess = None , path = '' ):
216218 """Eval a single checkpoint."""
217219 with self ._cluster :
218220 # Attempt to restore the checkpoint
@@ -337,20 +339,22 @@ def Start(self):
337339 self ._task .params .eval .load_checkpoint_from )
338340
339341 if self ._decode_path :
340- self ._DecodeOnce (self ._decode_path )
342+ self ._DecodeOnce (path = self ._decode_path )
341343 py_utils .UpdateProcessedCheckpoints (self ._decoder_dir , self ._decode_path )
342344 elif self ._task .params .eval .decode_all_checkpoints :
343- self ._RunOnAllCheckpoints (self ._DecodeOnce , self ._decoder_dir )
345+ self ._RunOnAllCheckpoints (
346+ runner_fn = self ._DecodeOnce , runner_dir = self ._decoder_dir )
344347 else :
345- self ._RunOnLatestCheckpoints (self ._DecodeOnce , self ._decoder_dir )
348+ self ._RunOnLatestCheckpoints (
349+ runner_fn = self ._DecodeOnce , runner_dir = self ._decoder_dir )
346350
347351 @classmethod
348352 def GetDecodeOutPath (cls , decoder_dir , checkpoint_id ):
349353 """Gets the path to decode out file."""
350354 out_dir = cls ._GetTtlDir (decoder_dir , duration = '7d' )
351355 return os .path .join (out_dir , 'decoder_out_%09d' % checkpoint_id )
352356
353- def _DecodeOnce (self , path ):
357+ def _DecodeOnce (self , sess = None , path = '' ):
354358 """Decode a single checkpoint."""
355359 with self ._cluster :
356360 # Attempt to restore the checkpoint
0 commit comments