
Commit 80d791a

feat(train): allow editing dropout during training
1 parent b6f5026 commit 80d791a

File tree

1 file changed: +18 −0 lines

tools/train/train.py

Lines changed: 18 additions & 0 deletions
@@ -106,6 +106,18 @@ class ModelArguments:
             "help": "Restore optimizer and training state. Can be True (will retrieve associated wandb artifact), a local directory or a Google bucket path."
         },
     )
+    dropout: Optional[float] = field(
+        default=None,
+        metadata={"help": "Dropout rate. Overwrites config."},
+    )
+    activation_dropout: Optional[float] = field(
+        default=None,
+        metadata={"help": "Activation dropout rate. Overwrites config."},
+    )
+    attention_dropout: Optional[float] = field(
+        default=None,
+        metadata={"help": "Attention dropout rate. Overwrites config."},
+    )
 
     def __post_init__(self):
         if self.tokenizer_name is None:
@@ -674,6 +686,9 @@ def main():
     if model_args.config_name:
         config = DalleBartConfig.from_pretrained(model_args.config_name)
         config.gradient_checkpointing = training_args.gradient_checkpointing
+        config.dropout = model_args.dropout
+        config.activation_dropout = model_args.activation_dropout
+        config.attention_dropout = model_args.attention_dropout
     else:
         config = None
 
@@ -686,6 +701,9 @@ def main():
             dtype=getattr(jnp, model_args.dtype),
             _do_init=False,  # we overwrite them with loaded checkpoint
             gradient_checkpointing=training_args.gradient_checkpointing,
+            dropout=model_args.dropout,
+            activation_dropout=model_args.activation_dropout,
+            attention_dropout=model_args.attention_dropout,
         )
     else:
         model = DalleBart(

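The change adds three Optional[float] fields to ModelArguments and threads them into the model config, both when a config is loaded by name and when the model is restored with from_pretrained. Below is a minimal sketch, outside this commit, of how such dataclass fields surface as command-line flags, assuming ModelArguments is parsed with transformers.HfArgumentParser as Hugging Face-style training scripts typically do; the parser call and example flag values are illustrative assumptions, not code from tools/train/train.py.

# Sketch only: shows how the new Optional[float] fields become CLI flags when
# parsed with HfArgumentParser. The real ModelArguments class has many more fields.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class ModelArguments:
    # Mirrors the three fields added in this commit.
    dropout: Optional[float] = field(
        default=None,
        metadata={"help": "Dropout rate. Overwrites config."},
    )
    activation_dropout: Optional[float] = field(
        default=None,
        metadata={"help": "Activation dropout rate. Overwrites config."},
    )
    attention_dropout: Optional[float] = field(
        default=None,
        metadata={"help": "Attention dropout rate. Overwrites config."},
    )


parser = HfArgumentParser(ModelArguments)
# Roughly equivalent to: python tools/train/train.py --dropout 0.1 --attention_dropout 0.0 ...
(model_args,) = parser.parse_args_into_dataclasses(
    args=["--dropout", "0.1", "--attention_dropout", "0.0"]
)
print(model_args.dropout, model_args.activation_dropout, model_args.attention_dropout)
# -> 0.1 None 0.0  (flags left unset stay at their None default)

Because the defaults are None, only the flags actually passed on the command line carry explicit rates; the other fields remain None when handed to the config.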