We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent b798ed3 commit 2e02683Copy full SHA for 2e02683
tools/train/train.py
@@ -531,6 +531,9 @@ def main():
531
# Set up our new model config
532
if model_args.config_name:
533
config = DalleBartConfig.from_pretrained(model_args.config_name)
534
+ # initializing params with gradient checkpointing creates issues
535
+ # we correctly set it later per training_args
536
+ config.gradient_checkpointing = False
537
else:
538
config = None
539
@@ -553,7 +556,6 @@ def main():
553
556
seed=training_args.seed_model,
554
557
dtype=getattr(jnp, model_args.dtype),
555
558
load_on_cpu=True,
- gradient_checkpointing=False,
559
)
560
561
# update model config per training args
0 commit comments