Commit b798ed3 (parent: 79557f9)

feat: no gradient checkpointing for params init


tools/train/train.py

Lines changed: 7 additions & 6 deletions
```diff
@@ -531,8 +531,6 @@ def main():
     # Set up our new model config
     if model_args.config_name:
         config = DalleBartConfig.from_pretrained(model_args.config_name)
-        # initializing params with gradient checkpointing create issues
-        config.gradient_checkpointing = False
     else:
         config = None
 
@@ -545,25 +543,28 @@
             dtype=getattr(jnp, model_args.dtype),
             abstract_init=True,
             load_on_cpu=True,
+            # initializing params with gradient checkpointing creates issues
+            # we correctly set it later per training_args
+            gradient_checkpointing=False,
         )
     else:
         model = DalleBart(
             config,
             seed=training_args.seed_model,
             dtype=getattr(jnp, model_args.dtype),
             load_on_cpu=True,
+            gradient_checkpointing=False,
         )
 
     # update model config per training args
     # Done after initialization of weights to avoid issues with remat
     # This is still considered correctly during training as function is pjitted
     model.config.gradient_checkpointing = training_args.gradient_checkpointing
 
-    # eval model cannot use remat
-    eval_config = copy.deepcopy(model.config)
-    eval_config.gradient_checkpointing = False
-
     if training_args.gradient_checkpointing:
+        # eval model cannot use remat
+        eval_config = copy.deepcopy(model.config)
+        eval_config.gradient_checkpointing = False
         eval_model = DalleBart(
             eval_config,
             seed=training_args.seed_model,
```
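The pattern the commit lands on: always construct and initialize the model with `gradient_checkpointing=False`, then set `model.config.gradient_checkpointing` from `training_args` once the weights exist, since initializing params under remat creates issues. Below is a minimal flax sketch of this init-then-reconfigure pattern; `Block` and `Model` are hypothetical modules for illustration, not the dalle-mini code.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class Block(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        return nn.relu(nn.Dense(self.features)(x))


class Model(nn.Module):
    features: int
    gradient_checkpointing: bool = False

    @nn.compact
    def __call__(self, x):
        # nn.remat recomputes this block's activations in the backward pass
        # instead of storing them; the explicit name keeps the param path
        # identical whether or not the block is rematted
        block_cls = nn.remat(Block) if self.gradient_checkpointing else Block
        return block_cls(self.features, name="block")(x)


x = jnp.ones((2, 8))

# 1) initialize with gradient checkpointing disabled, as in this commit
params = Model(features=16, gradient_checkpointing=False).init(
    jax.random.PRNGKey(0), x
)

# 2) train with checkpointing enabled: remat leaves the parameter
#    structure unchanged, so the same params apply cleanly
train_model = Model(features=16, gradient_checkpointing=True)
out = train_model.apply(params, x)
```

The relocated eval hunk follows from the same flag: the eval model cannot use remat, so a separate model is built from a deepcopied config with `gradient_checkpointing = False`. Moving that setup inside the `if training_args.gradient_checkpointing:` branch means the copy is only made when training actually uses checkpointing and a distinct eval model is needed.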
