@@ -531,8 +531,6 @@ def main():
531531 # Set up our new model config
532532 if model_args .config_name :
533533 config = DalleBartConfig .from_pretrained (model_args .config_name )
534- # initializing params with gradient checkpointing create issues
535- config .gradient_checkpointing = False
536534 else :
537535 config = None
538536
@@ -545,25 +543,28 @@ def main():
545543 dtype = getattr (jnp , model_args .dtype ),
546544 abstract_init = True ,
547545 load_on_cpu = True ,
546+ # initializing params with gradient checkpointing creates issues
547+ # we correctly set it later per training_args
548+ gradient_checkpointing = False ,
548549 )
549550 else :
550551 model = DalleBart (
551552 config ,
552553 seed = training_args .seed_model ,
553554 dtype = getattr (jnp , model_args .dtype ),
554555 load_on_cpu = True ,
556+ gradient_checkpointing = False ,
555557 )
556558
557559 # update model config per training args
558560 # Done after initialization of weights to avoid issues with remat
559561 # This is still taken into account correctly during training, as the function is pjitted
560562 model .config .gradient_checkpointing = training_args .gradient_checkpointing
561563
562- # eval model cannot use remat
563- eval_config = copy .deepcopy (model .config )
564- eval_config .gradient_checkpointing = False
565-
566564 if training_args .gradient_checkpointing :
565+ # eval model cannot use remat
566+ eval_config = copy .deepcopy (model .config )
567+ eval_config .gradient_checkpointing = False
567568 eval_model = DalleBart (
568569 eval_config ,
569570 seed = training_args .seed_model ,
0 commit comments