@@ -119,7 +119,7 @@ def __post_init__(self):
119119 ), "Restoring state only available with W&B artifact reference"
120120
121121 def get_metadata (self ):
122- if self .restore_state :
122+ if ":" in self .model_name_or_path :
123123 if jax .process_index () == 0 :
124124 artifact = wandb .run .use_artifact (self .model_name_or_path )
125125 else :
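With this change, artifact metadata is fetched whenever model_name_or_path looks like a W&B artifact reference, not only when restore_state is set. A minimal illustration, assuming the usual entity/project/name:alias reference format (the names below are hypothetical):

    # W&B artifact references carry a ":alias" suffix; plain Hugging Face
    # model ids do not, so the ":" test separates the two cases.
    wandb_ref = "my-entity/my-project/model-3f0e09bd:latest"  # hypothetical
    hf_model_id = "facebook/bart-large"

    assert ":" in wandb_ref        # metadata is pulled from the artifact
    assert ":" not in hf_model_id  # no artifact lookup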
@@ -413,11 +413,9 @@ class TrainingArguments:
413413 "help" : "Whether to use staircase or continuous learning rate when using exponential decay."
414414 },
415415 )
416- lr_resume_offset : bool = field (
417- default = False ,
418- metadata = {
419- "help" : "Whether to offset the learning rate function with current step when resuming a run."
420- },
416+ lr_offset : int = field (
417+ default = 0 ,
418+ metadata = {"help" : "Number of steps to offset learning rate and keep it at 0." },
421419 )
422420 logging_steps : int = field (
423421 default = 40 , metadata = {"help" : "Log every X updates steps." }
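Note the semantic shift: lr_resume_offset was a boolean that borrowed the offset from the checkpoint's "step" metadata, while lr_offset is an explicit step count, so it also works for runs that are not resuming. To reproduce the old behaviour when resuming, one would presumably pass the restored step count by hand (e.g. --lr_offset set to the checkpoint step, assuming the usual HfArgumentParser-style CLI).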
@@ -796,14 +794,14 @@ def create_learning_rate_fn() -> Callable[[int], jnp.array]:
         end_value=training_args.learning_rate,
         transition_steps=training_args.warmup_steps + 1,  # ensure not 0
     )
-    # offset step when resuming
     last_boundary = training_args.warmup_steps
-    if model_metadata.get("step", 0) and training_args.lr_resume_offset:
+    # offset step when resuming
+    if training_args.lr_offset:
         warmup_fn = optax.join_schedules(
             schedules=[optax.constant_schedule(0.0), warmup_fn],
-            boundaries=[model_metadata["step"]],
+            boundaries=[training_args.lr_offset],
         )
-        last_boundary += training_args.lr_offset
+        last_boundary += training_args.lr_offset
     if training_args.lr_decay is None:
         return warmup_fn
     elif training_args.lr_decay == "linear":
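A runnable sketch of the schedule this hunk builds, with hypothetical values for lr_offset, warmup_steps and learning_rate: optax.join_schedules holds the learning rate at 0 for the first lr_offset steps, then restarts the warmup counter at the boundary.

    import optax

    lr_offset, warmup_steps, learning_rate = 100, 500, 3e-4  # hypothetical values

    warmup_fn = optax.linear_schedule(
        init_value=0.0,
        end_value=learning_rate,
        transition_steps=warmup_steps + 1,  # ensure not 0
    )
    warmup_fn = optax.join_schedules(
        schedules=[optax.constant_schedule(0.0), warmup_fn],
        boundaries=[lr_offset],
    )

    assert warmup_fn(lr_offset - 1) == 0.0  # held at 0 before the boundary
    assert warmup_fn(lr_offset + 1) > 0.0   # warmup restarts from the boundary

Any later decay schedule is then joined at last_boundary, which is why the offset is also added to it.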
@@ -1005,6 +1003,14 @@ def maybe_init_params(params):
 
     with mesh:
         logger.info("Creating state")
+
+        # restore metadata
+        attr_state = {}
+        keys = ["train_time", "train_samples"]
+        if model_args.restore_state:
+            keys += ["step", "epoch"]
+        attr_state = {k: v for k, v in model_metadata.items() if k in keys}
+
         if not model_args.restore_state:
 
             def init_state(params):
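The hoisted block applies to both branches: the cumulative counters (train_time, train_samples) are always carried over from the artifact metadata, while step and epoch are only restored together with the optimizer state. A small sketch with a hypothetical metadata dict:

    # Hypothetical metadata as stored with the W&B artifact.
    model_metadata = {"step": 5000, "epoch": 1, "train_time": 3600.0,
                      "train_samples": 640_000, "learning_rate": 3e-4}

    keys = ["train_time", "train_samples"]  # fresh state: counters only
    attr_state = {k: v for k, v in model_metadata.items() if k in keys}
    # -> {"train_time": 3600.0, "train_samples": 640000}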
@@ -1013,6 +1019,7 @@ def init_state(params):
                     tx=optimizer,
                     params=maybe_init_params(params),
                     dropout_rng=dropout_rng,
+                    **attr_state,
                 )
 
         state = pjit(
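For **attr_state to be accepted here, the custom TrainState must expose these names as fields. A plausible shape, assuming a Flax train_state.TrainState subclass (field names taken from this diff, defaults assumed):

    import jax.numpy as jnp
    from flax.training import train_state

    class TrainState(train_state.TrainState):
        dropout_rng: jnp.ndarray = None
        epoch: int = 0
        train_time: float = 0.0   # cumulative training time
        train_samples: int = 0    # cumulative samples seen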
@@ -1028,12 +1035,6 @@ def init_state(params):
             # load opt_state
             opt_state = from_bytes(opt_state_shape, model_args.get_opt_state())
 
-            # restore other attributes
-            attr_state = {
-                k: model_metadata[k]
-                for k in ["step", "epoch", "train_time", "train_samples"]
-            }
-
             def restore_state(params, opt_state):
                 return TrainState(
                     apply_fn=model.__call__,
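With attr_state computed once further up, the per-branch copy here becomes redundant and is dropped; the fresh-init and restore paths now draw the same bookkeeping fields from a single place.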