Commit 89bc9d4

fix(train): overwrite dropout only when specified
1 parent 65bb95f commit 89bc9d4

File tree

1 file changed

+9 -7 lines changed

tools/train/train.py

@@ -131,7 +131,7 @@ def __post_init__(self):
             ), "Restoring state only available with W&B artifact reference"
 
     def get_metadata(self):
-        if ":" in self.model_name_or_path:
+        if self.model_name_or_path is not None and ":" in self.model_name_or_path:
             if jax.process_index() == 0:
                 artifact = wandb.run.use_artifact(self.model_name_or_path)
             else:
@@ -685,12 +685,16 @@ def main():
     )
 
     # Set up our new model config
+    config_args = {
+        k: getattr(model_args, k)
+        for k in ["dropout", "activation_dropout", "attention_dropout"]
+        if getattr(model_args, k) is not None
+    }
     if model_args.config_name:
         config = DalleBartConfig.from_pretrained(model_args.config_name)
         config.gradient_checkpointing = training_args.gradient_checkpointing
-        config.dropout = model_args.dropout
-        config.activation_dropout = model_args.activation_dropout
-        config.attention_dropout = model_args.attention_dropout
+        for k, v in config_args.items():
+            setattr(config, k, v)
     else:
         config = None
@@ -703,9 +707,7 @@ def main():
             dtype=getattr(jnp, model_args.dtype),
             _do_init=False,  # we overwrite them with loaded checkpoint
             gradient_checkpointing=training_args.gradient_checkpointing,
-            dropout=model_args.dropout,
-            activation_dropout=model_args.activation_dropout,
-            attention_dropout=model_args.attention_dropout,
+            **config_args,
         )
     else:
         model = DalleBart(
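
Why the change matters: the dropout flags on model_args default to None when not passed on the command line, so the old unconditional assignments stamped None over the values stored in a pretrained checkpoint's config. The added guard in get_metadata likewise avoids a TypeError from evaluating ":" in None when no model path is given. The filtered dict is a general "override only when specified" pattern; below is a minimal, self-contained sketch of it, using hypothetical TrainArgs/ModelConfig stand-ins rather than the actual dalle-mini classes:

    from dataclasses import dataclass
    from typing import Optional

    # Hypothetical stand-ins for illustration only; not the real dalle-mini classes.
    @dataclass
    class TrainArgs:
        # None means "not specified on the command line".
        dropout: Optional[float] = None
        activation_dropout: Optional[float] = None
        attention_dropout: Optional[float] = None

    @dataclass
    class ModelConfig:
        # Values as loaded from a pretrained checkpoint.
        dropout: float = 0.1
        activation_dropout: float = 0.0
        attention_dropout: float = 0.0

    def apply_overrides(config: ModelConfig, args: TrainArgs) -> ModelConfig:
        # Keep only the fields the user actually set; omitted flags become a no-op.
        overrides = {
            k: getattr(args, k)
            for k in ["dropout", "activation_dropout", "attention_dropout"]
            if getattr(args, k) is not None
        }
        for k, v in overrides.items():
            setattr(config, k, v)
        return config

    config = apply_overrides(ModelConfig(), TrainArgs(attention_dropout=0.05))
    assert config.dropout == 0.1  # unspecified flag leaves the checkpoint value intact
    assert config.attention_dropout == 0.05  # explicit flag overrides it

The same filtered dict also unpacks directly into the from_pretrained call as **config_args, which is why the last hunk can replace the three explicit keyword arguments with a single splat.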
