Skip to content

Commit 72c3557

Browse files
committed
Trainer config fixes
1 parent 57cdfcd commit 72c3557

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

config/ds_z3_bf16_config.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
"stage3_param_persistence_threshold": "auto",
3030
"stage3_max_live_parameters": 1e9,
3131
"stage3_max_reuse_distance": 1e9,
32-
"stage3_gather_16bit_weights_on_model_save": false
32+
"stage3_gather_16bit_weights_on_model_save": true
3333
},
3434
"gradient_accumulation_steps": "auto",
3535
"gradient_clipping": "auto",

training/trainer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,8 @@ def train(
181181
evaluation_strategy="steps",
182182
eval_steps=10,
183183
save_strategy="steps",
184-
save_steps=1000,
185-
save_total_limit=None,
184+
save_steps=200,
185+
save_total_limit=1,
186186
load_best_model_at_end=True,
187187
report_to="tensorboard",
188188
disable_tqdm=True,
@@ -204,6 +204,9 @@ def train(
204204
logger.info("Training")
205205
trainer.train()
206206

207+
logger.info(f"Saving Model to {local_output_dir}")
208+
trainer.save_model(output_dir=local_output_dir)
209+
207210
if dbfs_output_dir:
208211
logger.info(f"Saving Model to {dbfs_output_dir}")
209212
trainer.save_model(output_dir=dbfs_output_dir)

0 commit comments

Comments
 (0)