@@ -131,7 +131,7 @@ def __post_init__(self):
             ), "Restoring state only available with W&B artifact reference"

     def get_metadata(self):
-        if ":" in self.model_name_or_path:
+        if self.model_name_or_path is not None and ":" in self.model_name_or_path:
             if jax.process_index() == 0:
                 artifact = wandb.run.use_artifact(self.model_name_or_path)
             else:
@@ -685,12 +685,16 @@ def main():
     )

     # Set up our new model config
+    config_args = {
+        k: getattr(model_args, k)
+        for k in ["dropout", "activation_dropout", "attention_dropout"]
+        if getattr(model_args, k) is not None
+    }
     if model_args.config_name:
         config = DalleBartConfig.from_pretrained(model_args.config_name)
         config.gradient_checkpointing = training_args.gradient_checkpointing
-        config.dropout = model_args.dropout
-        config.activation_dropout = model_args.activation_dropout
-        config.attention_dropout = model_args.attention_dropout
+        for k, v in config_args.items():
+            setattr(config, k, v)
     else:
         config = None
@@ -703,9 +707,7 @@ def main():
             dtype=getattr(jnp, model_args.dtype),
             _do_init=False,  # we overwrite them with loaded checkpoint
             gradient_checkpointing=training_args.gradient_checkpointing,
-            dropout=model_args.dropout,
-            activation_dropout=model_args.activation_dropout,
-            attention_dropout=model_args.attention_dropout,
+            **config_args,
         )
     else:
         model = DalleBart(
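For context, the point of the new config_args dict is that a dropout value left at None on the command line no longer overwrites whatever the pretrained config already contains. A minimal self-contained sketch of the pattern (hypothetical ModelArguments defaults, not the full training script):

# Sketch of the "only apply explicitly set overrides" pattern used in the diff above.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelArguments:
    # None means "keep the value already stored in the pretrained config"
    dropout: Optional[float] = None
    activation_dropout: Optional[float] = None
    attention_dropout: Optional[float] = None


model_args = ModelArguments(dropout=0.1)  # user only overrode dropout

# Collect only the fields that were actually set.
config_args = {
    k: getattr(model_args, k)
    for k in ["dropout", "activation_dropout", "attention_dropout"]
    if getattr(model_args, k) is not None
}

print(config_args)  # {'dropout': 0.1}; the other two dropouts stay at the pretrained values

The resulting dict can then be applied with setattr on a loaded config, or splatted as keyword arguments into from_pretrained, exactly as the two hunks above do.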