Skip to content

Commit 89b4c45

Browse files
committed
feat(train): arg to offset lr for resumed runs
1 parent 23c1ef6 commit 89b4c45

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

tools/train/train.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -406,7 +406,12 @@ class TrainingArguments:
406406
"help": "Whether to use staircase or continuous learning rate when using exponential decay."
407407
},
408408
)
409-
409+
lr_resume_offset: bool = field(
410+
default=False,
411+
metadata={
412+
"help": "Whether to offset the learning rate function with current step when resuming a run."
413+
},
414+
)
410415
logging_steps: int = field(
411416
default=40, metadata={"help": "Log every X updates steps."}
412417
)
@@ -781,7 +786,7 @@ def create_learning_rate_fn() -> Callable[[int], jnp.array]:
781786
transition_steps=training_args.warmup_steps + 1, # ensure not 0
782787
)
783788
# offset step when resuming
784-
if model_metadata.get("step", 0):
789+
if model_metadata.get("step", 0) and training_args.lr_resume_offset:
785790
warmup_fn = optax.join_schedules(
786791
schedules=[optax.constant_schedule(0.0), warmup_fn],
787792
boundaries=[model_metadata["step"]],

0 commit comments

Comments
 (0)