Skip to content

Commit 971e0bc

Browse files
Jonathan Shen authored and copybara-github committed
Fix eager evaler/decoders.
PiperOrigin-RevId: 415432320
1 parent 4fb8acd commit 971e0bc

File tree

2 files changed

19 additions and 11 deletions (lines changed)

2 files changed

+19
-11
lines changed

lingvo/eager_runners.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -205,14 +205,16 @@ def Start(self):
205205
self._task.params.eval.load_checkpoint_from)
206206

207207
if self._eval_path:
208-
self._EvalOnce(self._eval_path)
208+
self._EvalOnce(path=self._eval_path)
209209
py_utils.UpdateProcessedCheckpoints(self._eval_dir, self._eval_path)
210210
elif self._task.params.eval.eval_all_checkpoints:
211-
self._RunOnAllCheckpoints(self._EvalOnce, self._eval_dir)
211+
self._RunOnAllCheckpoints(
212+
runner_fn=self._EvalOnce, runner_dir=self._eval_dir)
212213
else:
213-
self._RunOnLatestCheckpoints(self._EvalOnce, self._eval_dir)
214+
self._RunOnLatestCheckpoints(
215+
runner_fn=self._EvalOnce, runner_dir=self._eval_dir)
214216

215-
def _EvalOnce(self, path):
217+
def _EvalOnce(self, sess=None, path=''):
216218
"""Eval a single checkpoint."""
217219
with self._cluster:
218220
# Attempt to restore the checkpoint
@@ -337,20 +339,22 @@ def Start(self):
337339
self._task.params.eval.load_checkpoint_from)
338340

339341
if self._decode_path:
340-
self._DecodeOnce(self._decode_path)
342+
self._DecodeOnce(path=self._decode_path)
341343
py_utils.UpdateProcessedCheckpoints(self._decoder_dir, self._decode_path)
342344
elif self._task.params.eval.decode_all_checkpoints:
343-
self._RunOnAllCheckpoints(self._DecodeOnce, self._decoder_dir)
345+
self._RunOnAllCheckpoints(
346+
runner_fn=self._DecodeOnce, runner_dir=self._decoder_dir)
344347
else:
345-
self._RunOnLatestCheckpoints(self._DecodeOnce, self._decoder_dir)
348+
self._RunOnLatestCheckpoints(
349+
runner_fn=self._DecodeOnce, runner_dir=self._decoder_dir)
346350

347351
@classmethod
348352
def GetDecodeOutPath(cls, decoder_dir, checkpoint_id):
349353
"""Gets the path to decode out file."""
350354
out_dir = cls._GetTtlDir(decoder_dir, duration='7d')
351355
return os.path.join(out_dir, 'decoder_out_%09d' % checkpoint_id)
352356

353-
def _DecodeOnce(self, path):
357+
def _DecodeOnce(self, sess=None, path=''):
354358
"""Decode a single checkpoint."""
355359
with self._cluster:
356360
# Attempt to restore the checkpoint

lingvo/trainer.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -415,11 +415,15 @@ def UpdateClusterParamsFromFlags(self, cluster, job_name):
415415
cluster.input.replicas = FLAGS.input_replicas
416416
cluster.input.targets = FLAGS.input_targets
417417

418-
cluster.evaler.name = FLAGS.evaler_job
418+
if py_utils.IsEagerMode():
419+
cluster.evaler.name = '/job:localhost'
420+
cluster.decoder.name = '/job:localhost'
421+
else:
422+
cluster.evaler.name = FLAGS.evaler_job
423+
cluster.decoder.name = FLAGS.decoder_job
424+
419425
cluster.evaler.replicas = FLAGS.evaler_replicas
420426
cluster.evaler.gpus_per_replica = FLAGS.evaler_gpus
421-
422-
cluster.decoder.name = FLAGS.decoder_job
423427
cluster.decoder.replicas = FLAGS.decoder_replicas
424428
cluster.decoder.gpus_per_replica = FLAGS.decoder_gpus
425429

0 commit comments

Comments
 (0)