Skip to content

Commit bc4734f

Browse files
committed
fix(train): consider schedule offset
1 parent 9f5e879 commit bc4734f

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tools/train/train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -688,7 +688,8 @@ def create_learning_rate_fn() -> Callable[[int], jnp.array]:
688688
staircase=training_args.lr_staircase,
689689
)
690690
schedule_fn = optax.join_schedules(
691-
schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
691+
schedules=[warmup_fn, decay_fn],
692+
boundaries=[model_metadata.get("step", 0) + training_args.warmup_steps],
692693
)
693694
return schedule_fn
694695

0 commit comments

Comments
 (0)