@@ -106,6 +106,18 @@ class ModelArguments:
106106 "help" : "Restore optimizer and training state. Can be True (will retrieve associated wandb artifact), a local directory or a Google bucket path."
107107 },
108108 )
109+ dropout : Optional [float ] = field (
110+ default = None ,
111+ metadata = {"help" : "Dropout rate. Overwrites config." },
112+ )
113+ activation_dropout : Optional [float ] = field (
114+ default = None ,
115+ metadata = {"help" : "Activation dropout rate. Overwrites config." },
116+ )
117+ attention_dropout : Optional [float ] = field (
118+ default = None ,
119+ metadata = {"help" : "Attention dropout rate. Overwrites config." },
120+ )
109121
110122 def __post_init__ (self ):
111123 if self .tokenizer_name is None :
@@ -674,6 +686,9 @@ def main():
     if model_args.config_name:
         config = DalleBartConfig.from_pretrained(model_args.config_name)
         config.gradient_checkpointing = training_args.gradient_checkpointing
+        config.dropout = model_args.dropout
+        config.activation_dropout = model_args.activation_dropout
+        config.attention_dropout = model_args.attention_dropout
     else:
         config = None

@@ -686,6 +701,9 @@ def main():
             dtype=getattr(jnp, model_args.dtype),
             _do_init=False,  # we overwrite them with loaded checkpoint
             gradient_checkpointing=training_args.gradient_checkpointing,
+            dropout=model_args.dropout,
+            activation_dropout=model_args.activation_dropout,
+            attention_dropout=model_args.attention_dropout,
         )
     else:
         model = DalleBart(
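
Because `ModelArguments` is a dataclass parsed from the command line, each new field surfaces as a `--dropout` / `--activation_dropout` / `--attention_dropout` flag. The sketch below illustrates that mechanism, assuming the script parses `ModelArguments` with `transformers.HfArgumentParser` as HuggingFace-style training scripts typically do; the stripped-down class and example values are illustrative, not the full script.

# Minimal sketch: how the new dataclass fields become CLI flags.
# This stand-in ModelArguments only carries the three fields added in the diff;
# the real script parses several argument dataclasses together.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class ModelArguments:
    dropout: Optional[float] = field(
        default=None,
        metadata={"help": "Dropout rate. Overwrites config."},
    )
    activation_dropout: Optional[float] = field(
        default=None,
        metadata={"help": "Activation dropout rate. Overwrites config."},
    )
    attention_dropout: Optional[float] = field(
        default=None,
        metadata={"help": "Attention dropout rate. Overwrites config."},
    )


parser = HfArgumentParser(ModelArguments)
# Equivalent to invoking: python train.py --dropout 0.1 --attention_dropout 0.0
(model_args,) = parser.parse_args_into_dataclasses(
    args=["--dropout", "0.1", "--attention_dropout", "0.0"]
)
assert model_args.dropout == 0.1
assert model_args.activation_dropout is None  # unset flags keep their None default

The second and third hunks then route the parsed values onto the model: directly onto the config object when one is built from `config_name`, and as `from_pretrained` keyword overrides when loading from a checkpoint.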