Skip to content

Commit e71ee02

Browse files
lingvo-bot authored and copybara-github committed
- Support checkpoint loading using init_from_checkpoint_rules for eager trainer.
- Make gradient tape persistent so that it can be used for multiple optimizers in eager mode. PiperOrigin-RevId: 415600445
1 parent 20ec550 commit e71ee02

File tree

3 files changed

+31
-13
lines changed

3 files changed

+31
-13
lines changed

lingvo/core/checkpointer.py

Lines changed: 28 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -193,6 +193,13 @@ def _ResolveCkptPath(ckpt_rules):
193193
return res_rules
194194

195195
self._restore_fns = []
196+
if py_utils.IsEagerMode():
197+
if self._model:
198+
all_vars = list(self._model.GetVariablesDict().values())
199+
else:
200+
raise TypeError('self._model cannot be None in eager mode.')
201+
else:
202+
all_vars = tf.global_variables()
196203

197204
# Add graph nodes to restore specific variables based on
198205
# init_from_checkpoint_rules.
@@ -202,14 +209,16 @@ def _ResolveCkptPath(ckpt_rules):
202209
tp = task.params.train
203210
if tp.init_from_checkpoint_rules:
204211
rules = _ResolveCkptPath(tp.init_from_checkpoint_rules)
205-
fn = py_utils.OverrideVarsFromCheckpoints(tf.global_variables(),
206-
rules)
212+
fn = py_utils.OverrideVarsFromCheckpoints(all_vars, rules)
207213
self._restore_fns.append((f'OverrideVarsFromCheckpoints {rules}', fn))
208214

209215
if self._params and self._params.train.init_from_checkpoint_rules:
216+
if self._model is None:
217+
raise TypeError(
218+
'self._model cannot be None when using init_from_checkpoint_rules.')
210219
tp = self._params.train
211220
rules = _ResolveCkptPath(tp.init_from_checkpoint_rules)
212-
fn = py_utils.OverrideVarsFromCheckpoints(tf.global_variables(), rules)
221+
fn = py_utils.OverrideVarsFromCheckpoints(all_vars, rules)
213222
self._restore_fns.append((f'OverrideVarsFromCheckpoints {rules}', fn))
214223

215224
@property
@@ -337,11 +346,9 @@ def Restore(self, sess=None, force_reinitialize=False):
337346
sess.run(self._init_op)
338347
tf.logging.info('Initialized all vars.')
339348

340-
if self._restore_fns:
341-
for msg, fn in self._restore_fns:
342-
tf.logging.info(msg)
343-
fn(sess)
344-
tf.logging.info('Restored vars using checkpoint rules.')
349+
for msg, fn in self._restore_fns:
350+
tf.logging.info(msg)
351+
fn(sess)
345352
return None
346353

347354
def RestoreIfNeeded(self, sess):
@@ -474,7 +481,16 @@ def _GetSaver(self):
474481
def Restore(self, sess=None, force_reinitialize=None):
475482
"""`sess` and `force_reinitialize` are unused in Eager context."""
476483
assert sess is None
477-
return self._RestoreFromLatestCheckpoint(sess)
484+
path = self._RestoreFromLatestCheckpoint(sess)
485+
if path:
486+
tf.logging.info('Eager checkpoint is restored with path: %s', path)
487+
return path
488+
# No checkpoint is loaded, we need to initialize the variables,
489+
# and apply the init_from_checkpoint_rules if applicable.
490+
for msg, fn in self._restore_fns:
491+
tf.logging.info(msg)
492+
fn(sess)
493+
return path
478494

479495
def RestoreGlobalStepIfNeeded(self, sess=None):
480496
"""`sess` is unused in Eager context."""
@@ -503,7 +519,9 @@ def RestoreFromPath(self, sess=None, checkpoint_path=None):
503519

504520
assert not self._save_only
505521
tf.logging.info('Load from checkpoint (V1) %s.', checkpoint_path)
506-
self._saver.restore(sess=None, save_path=checkpoint_path)
522+
load_status = self._saver.restore(sess=None, save_path=checkpoint_path)
523+
# Check all model vars are matched from the checkpoint
524+
load_status.assert_existing_objects_matched()
507525
tf.logging.info('Load checkpoint done.')
508526

509527
def Save(self, sess=None, gsteps=None):

lingvo/core/program.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,7 @@ def TpuTrainStep(self, *args):
539539
with py_utils.OpportunisticVariableReuseScope(True):
540540
with contextlib.ExitStack() as stack:
541541
if py_utils.IsEagerMode():
542-
stack.enter_context(py_utils.GradientTape())
542+
stack.enter_context(py_utils.GradientTape(persistent=True))
543543
self._model.ConstructFPropBPropGraph()
544544
per_step_eval_metrics = self._eval_metrics.SetMetrics(
545545
self._task.eval_metrics, args)

lingvo/eager_runners.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def Start(self):
4545

4646
@tf.function(autograph=False)
4747
def TrainFunc():
48-
with py_utils.GradientTape():
48+
with py_utils.GradientTape(persistent=True):
4949
model.ConstructFPropBPropGraph()
5050
return task.eval_metrics, task.per_example_tensors
5151

@@ -146,7 +146,7 @@ def Start(self):
146146
@tf.function(autograph=False)
147147
def ModelFunc():
148148
with self._summary_writer.as_default():
149-
with py_utils.GradientTape():
149+
with py_utils.GradientTape(persistent=True):
150150
model.ConstructFPropBPropGraph()
151151
return task.eval_metrics
152152

0 commit comments

Comments (0)