We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 9f5e879 commit bc4734f — Copy full SHA for bc4734f
tools/train/train.py
@@ -688,7 +688,8 @@ def create_learning_rate_fn() -> Callable[[int], jnp.array]:
688
staircase=training_args.lr_staircase,
689
)
690
schedule_fn = optax.join_schedules(
691
- schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
+ schedules=[warmup_fn, decay_fn],
692
+ boundaries=[model_metadata.get("step", 0) + training_args.warmup_steps],
693
694
return schedule_fn
695
0 commit comments