File tree (expand/collapse): 1 file changed, +7 −2 lines changed
Original file line number | Diff line number | Diff line change @@ -406,7 +406,12 @@ class TrainingArguments:
406406             "help": "Whether to use staircase or continuous learning rate when using exponential decay."
407407         },
408408     )
409-
409+    lr_resume_offset: bool = field(
410+        default=False,
411+        metadata={
412+            "help": "Whether to offset the learning rate function with current step when resuming a run."
413+        },
414+    )
410415 logging_steps : int = field (
411416 default = 40 , metadata = {"help" : "Log every X updates steps." }
412417 )
@@ -781,7 +786,7 @@ def create_learning_rate_fn() -> Callable[[int], jnp.array]:
781786         transition_steps=training_args.warmup_steps + 1,  # ensure not 0
782787     )
783788 # offset step when resuming
784-     if model_metadata.get("step", 0):
789+     if model_metadata.get("step", 0) and training_args.lr_resume_offset:
785790         warmup_fn = optax.join_schedules(
786791             schedules=[optax.constant_schedule(0.0), warmup_fn],
787792             boundaries=[model_metadata["step"]],
You can’t perform that action at this time.
0 commit comments