@@ -193,6 +193,13 @@ def _ResolveCkptPath(ckpt_rules):
193193 return res_rules
194194
195195 self ._restore_fns = []
196+ if py_utils .IsEagerMode ():
197+ if self ._model :
198+ all_vars = list (self ._model .GetVariablesDict ().values ())
199+ else :
200+ raise TypeError ('self._model cannot be None in eager mode.' )
201+ else :
202+ all_vars = tf .global_variables ()
196203
197204 # Add graph nodes to restore specific variables based on
198205 # init_from_checkpoint_rules.
@@ -202,14 +209,16 @@ def _ResolveCkptPath(ckpt_rules):
202209 tp = task .params .train
203210 if tp .init_from_checkpoint_rules :
204211 rules = _ResolveCkptPath (tp .init_from_checkpoint_rules )
205- fn = py_utils .OverrideVarsFromCheckpoints (tf .global_variables (),
206- rules )
212+ fn = py_utils .OverrideVarsFromCheckpoints (all_vars , rules )
207213 self ._restore_fns .append ((f'OverrideVarsFromCheckpoints { rules } ' , fn ))
208214
209215 if self ._params and self ._params .train .init_from_checkpoint_rules :
216+ if self ._model is None :
217+ raise TypeError (
218+ 'self._model cannot be None when using init_from_checkpoint_rules.' )
210219 tp = self ._params .train
211220 rules = _ResolveCkptPath (tp .init_from_checkpoint_rules )
212- fn = py_utils .OverrideVarsFromCheckpoints (tf . global_variables () , rules )
221+ fn = py_utils .OverrideVarsFromCheckpoints (all_vars , rules )
213222 self ._restore_fns .append ((f'OverrideVarsFromCheckpoints { rules } ' , fn ))
214223
215224 @property
@@ -337,11 +346,9 @@ def Restore(self, sess=None, force_reinitialize=False):
337346 sess .run (self ._init_op )
338347 tf .logging .info ('Initialized all vars.' )
339348
340- if self ._restore_fns :
341- for msg , fn in self ._restore_fns :
342- tf .logging .info (msg )
343- fn (sess )
344- tf .logging .info ('Restored vars using checkpoint rules.' )
349+ for msg , fn in self ._restore_fns :
350+ tf .logging .info (msg )
351+ fn (sess )
345352 return None
346353
347354 def RestoreIfNeeded (self , sess ):
@@ -474,7 +481,16 @@ def _GetSaver(self):
474481 def Restore (self , sess = None , force_reinitialize = None ):
475482 """`sess` and `force_reinitialize` are unused in Eager context."""
476483 assert sess is None
477- return self ._RestoreFromLatestCheckpoint (sess )
484+ path = self ._RestoreFromLatestCheckpoint (sess )
485+ if path :
486+ tf .logging .info ('Eager checkpoint is restored with path: %s' , path )
487+ return path
488+ # No checkpoint was loaded, so the variables still need to be initialized:
489+ # apply the init_from_checkpoint_rules, if applicable.
490+ for msg , fn in self ._restore_fns :
491+ tf .logging .info (msg )
492+ fn (sess )
493+ return path
478494
479495 def RestoreGlobalStepIfNeeded (self , sess = None ):
480496 """`sess` is unused in Eager context."""
@@ -503,7 +519,9 @@ def RestoreFromPath(self, sess=None, checkpoint_path=None):
503519
504520 assert not self ._save_only
505521 tf .logging .info ('Load from checkpoint (V1) %s.' , checkpoint_path )
506- self ._saver .restore (sess = None , save_path = checkpoint_path )
522+ load_status = self ._saver .restore (sess = None , save_path = checkpoint_path )
523+ # Check that all model vars were matched by the checkpoint.
524+ load_status .assert_existing_objects_matched ()
507525 tf .logging .info ('Load checkpoint done.' )
508526
509527 def Save (self , sess = None , gsteps = None ):
0 commit comments