Skip to content

Commit c3e93df

Browse files
authored
feat: support LR offset (borisdayma#174)
1 parent 79a3849 commit c3e93df

File tree

1 file changed

+17
-16
lines changed

1 file changed

+17
-16
lines changed

tools/train/train.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def __post_init__(self):
119119
), "Restoring state only available with W&B artifact reference"
120120

121121
def get_metadata(self):
122-
if self.restore_state:
122+
if ":" in self.model_name_or_path:
123123
if jax.process_index() == 0:
124124
artifact = wandb.run.use_artifact(self.model_name_or_path)
125125
else:
@@ -413,11 +413,9 @@ class TrainingArguments:
413413
"help": "Whether to use staircase or continuous learning rate when using exponential decay."
414414
},
415415
)
416-
lr_resume_offset: bool = field(
417-
default=False,
418-
metadata={
419-
"help": "Whether to offset the learning rate function with current step when resuming a run."
420-
},
416+
lr_offset: int = field(
417+
default=0,
418+
metadata={"help": "Number of steps to offset learning rate and keep it at 0."},
421419
)
422420
logging_steps: int = field(
423421
default=40, metadata={"help": "Log every X updates steps."}
@@ -796,14 +794,14 @@ def create_learning_rate_fn() -> Callable[[int], jnp.array]:
796794
end_value=training_args.learning_rate,
797795
transition_steps=training_args.warmup_steps + 1, # ensure not 0
798796
)
799-
# offset step when resuming
800797
last_boundary = training_args.warmup_steps
801-
if model_metadata.get("step", 0) and training_args.lr_resume_offset:
798+
# offset step when resuming
799+
if training_args.lr_offset:
802800
warmup_fn = optax.join_schedules(
803801
schedules=[optax.constant_schedule(0.0), warmup_fn],
804-
boundaries=[model_metadata["step"]],
802+
boundaries=[training_args.lr_offset],
805803
)
806-
last_boundary += model_metadata["step"]
804+
last_boundary += training_args.lr_offset
807805
if training_args.lr_decay is None:
808806
return warmup_fn
809807
elif training_args.lr_decay == "linear":
@@ -1005,6 +1003,14 @@ def maybe_init_params(params):
10051003

10061004
with mesh:
10071005
logger.info(" Creating state")
1006+
1007+
# restore metadata
1008+
attr_state = {}
1009+
keys = ["train_time", "train_samples"]
1010+
if model_args.restore_state:
1011+
keys += ["step", "epoch"]
1012+
attr_state = {k: v for k, v in model_metadata.items() if k in keys}
1013+
10081014
if not model_args.restore_state:
10091015

10101016
def init_state(params):
@@ -1013,6 +1019,7 @@ def init_state(params):
10131019
tx=optimizer,
10141020
params=maybe_init_params(params),
10151021
dropout_rng=dropout_rng,
1022+
**attr_state,
10161023
)
10171024

10181025
state = pjit(
@@ -1028,12 +1035,6 @@ def init_state(params):
10281035
# load opt_state
10291036
opt_state = from_bytes(opt_state_shape, model_args.get_opt_state())
10301037

1031-
# restore other attributes
1032-
attr_state = {
1033-
k: model_metadata[k]
1034-
for k in ["step", "epoch", "train_time", "train_samples"]
1035-
}
1036-
10371038
def restore_state(params, opt_state):
10381039
return TrainState(
10391040
apply_fn=model.__call__,

0 commit comments

Comments (0)