Skip to content

Commit 2e02683

Browse files
committed
fix: no gradient checkpointing for new model
1 parent b798ed3 commit 2e02683

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

tools/train/train.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -531,6 +531,9 @@ def main():
531531
# Set up our new model config
532532
if model_args.config_name:
533533
config = DalleBartConfig.from_pretrained(model_args.config_name)
534+
# initializing params with gradient checkpointing creates issues
535+
# we correctly set it later per training_args
536+
config.gradient_checkpointing = False
534537
else:
535538
config = None
536539

@@ -553,7 +556,6 @@ def main():
553556
seed=training_args.seed_model,
554557
dtype=getattr(jnp, model_args.dtype),
555558
load_on_cpu=True,
556-
gradient_checkpointing=False,
557559
)
558560

559561
# update model config per training args

0 commit comments

Comments
 (0)