Commit 361a994

feat(model): allow bias (borisdayma#152)
1 parent: 02b2308

4 files changed: +15 −11 lines

src/dalle_mini/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
-__version__ = "0.0.3"
+__version__ = "0.0.4"
 
 from .model import DalleBart, DalleBartProcessor

src/dalle_mini/model/configuration.py

Lines changed: 4 additions & 2 deletions
@@ -58,14 +58,15 @@ def __init__(
         tie_word_embeddings=False,  # different modalities and sizes
         do_sample=True,
         # transformer variants
+        use_bias=False,  # use bias in attention and dense layers (except for lm_head)
         ln_type="layernorm",  # layer normalization type, "rmsnorm", "layernorm"
         ln_positions="normformer",  # layer normalization positions, "normformer", "swinv2", "cogview", "postln", "preln", "deepnet" (same as postln)
         use_head_scale=False,  # used in NormFormer
         use_cosine_attention=False,  # used in Swin v2
         tau_init=0.05,  # used only in cosine attention (Swin v2)
         use_deepnet_scaling=False,  # used in Deepnet
         use_glu=False,  # "GLU Variants Improve Transformer"
-        use_alibi=False,  # from "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation"
+        use_alibi=False,  # Not implemented yet - from "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation"
         sinkhorn_iters=1,  # used in SinkFormers
         use_final_ln_encoder=False,  # final layer normalization in encoder
         use_final_ln_decoder=False,  # final layer normalization in decoder
@@ -77,7 +78,7 @@ def __init__(
         self.normalize_text = normalize_text
 
         # transformer variants
-        self.use_head_scale = use_head_scale  # per Normformer
+        self.use_bias = use_bias
         assert ln_type in [
             "rmsnorm",
             "layernorm",
@@ -92,6 +93,7 @@ def __init__(
             "postln",
             "preln",
         ], "ln_positions must be 'normformer', 'swinv2', 'cogview', 'postln', 'preln'"
+        self.use_head_scale = use_head_scale
         assert use_alibi is False, "use_alibi is not supported yet"
         self.ln_positions = ln_positions
         self.use_cosine_attention = use_cosine_attention
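
The net effect in this file: use_bias becomes a constructor argument stored on the config, and the self.use_head_scale assignment moves below the ln_positions assertion so that self.use_bias takes its former slot under the "transformer variants" comment. A minimal usage sketch, assuming the class defined here is DalleBartConfig and that the remaining constructor arguments all have defaults (both assumptions; only the use_bias keyword and its default are confirmed by the diff):

    from dalle_mini.model.configuration import DalleBartConfig  # class name assumed

    # Enable bias in attention and dense layers; lm_head stays
    # bias-free per the comment in the diff above.
    config = DalleBartConfig(use_bias=True)
    assert config.use_bias  # stored on the config, read by modeling.py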

src/dalle_mini/model/modeling.py

Lines changed: 8 additions & 8 deletions
@@ -444,7 +444,7 @@ def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
         w = nn.Dense(
             self.ffn_dim,
             dtype=self.dtype,
-            use_bias=False,
+            use_bias=self.config.use_bias,
             kernel_init=deepnet_init(gain)
             if self.config.use_deepnet_scaling
             else jax.nn.initializers.normal(self.config.init_std),
@@ -453,7 +453,7 @@ def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
         v = nn.Dense(
             self.ffn_dim,
             dtype=self.dtype,
-            use_bias=False,
+            use_bias=self.config.use_bias,
             kernel_init=deepnet_init(gain)
             if self.config.use_deepnet_scaling
             else jax.nn.initializers.normal(self.config.init_std),
@@ -473,7 +473,7 @@ def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
         x = nn.Dense(
             self.embed_dim,
             dtype=self.dtype,
-            use_bias=False,
+            use_bias=self.config.use_bias,
             kernel_init=deepnet_init(gain)
             if self.config.use_deepnet_scaling
             else jax.nn.initializers.normal(self.config.init_std),
@@ -509,7 +509,7 @@ def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
         x = nn.Dense(
             self.ffn_dim,
             dtype=self.dtype,
-            use_bias=False,
+            use_bias=self.config.use_bias,
             kernel_init=deepnet_init(gain)
             if self.config.use_deepnet_scaling
             else jax.nn.initializers.normal(self.config.init_std),
@@ -528,7 +528,7 @@ def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
         x = nn.Dense(
             self.embed_dim,
             dtype=self.dtype,
-            use_bias=False,
+            use_bias=self.config.use_bias,
             kernel_init=deepnet_init(gain)
             if self.config.use_deepnet_scaling
             else jax.nn.initializers.normal(self.config.init_std),
@@ -580,7 +580,7 @@ def __call__(
             embed_dim=embed_dim,
             num_heads=self.config.encoder_attention_heads,
             dropout=self.config.attention_dropout,
-            bias=False,
+            bias=self.config.use_bias,
             dtype=self.dtype,
             is_encoder=True,
         )(hidden_states=hidden_states, attention_mask=attention_mask)
@@ -686,7 +686,7 @@ def __call__(
             num_heads=self.config.decoder_attention_heads,
             dropout=self.config.attention_dropout,
             causal=True,
-            bias=False,
+            bias=self.config.use_bias,
             dtype=self.dtype,
             is_encoder=False,
         )(
@@ -724,7 +724,7 @@ def __call__(
             embed_dim=embed_dim,
             num_heads=self.config.decoder_attention_heads,
             dropout=self.config.attention_dropout,
-            bias=False,
+            bias=self.config.use_bias,
             dtype=self.dtype,
             is_encoder=False,
         )(
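
All eight hunks make the same mechanical change: a hard-coded use_bias=False (bias=False for the attention modules) becomes a lookup on the shared config, so one flag controls every dense and attention projection. A self-contained Flax sketch of the pattern, with hypothetical names (BiasConfig, FFNBlock) standing in for the repo's actual classes:

    from dataclasses import dataclass

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    @dataclass
    class BiasConfig:
        use_bias: bool = False  # mirrors the new DalleBartConfig.use_bias flag

    class FFNBlock(nn.Module):
        config: BiasConfig
        ffn_dim: int

        @nn.compact
        def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
            # The flag is read from the config at every call site instead
            # of being hard-coded, matching the change in modeling.py.
            h = nn.Dense(self.ffn_dim, use_bias=self.config.use_bias)(x)
            return nn.Dense(x.shape[-1], use_bias=self.config.use_bias)(h)

    params = FFNBlock(BiasConfig(use_bias=True), ffn_dim=8).init(
        jax.random.PRNGKey(0), jnp.ones((1, 4))
    )
    # With use_bias=True each Dense layer carries a 'bias' parameter
    # alongside its 'kernel'.

Routing the flag through the config keeps the modeling change purely mechanical and serializes the choice with the rest of the hyperparameters, so a checkpoint records whether it was trained with biases.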

tools/train/train.py

Lines changed: 2 additions & 0 deletions
@@ -49,6 +49,7 @@
 from tqdm import tqdm
 from transformers import HfArgumentParser
 
+import dalle_mini
 from dalle_mini.data import Dataset
 from dalle_mini.model import (
     DalleBart,
@@ -675,6 +676,7 @@ def main():
                 "transformers": transformers.__version__,
                 "datasets": datasets.__version__,
                 "wandb": wandb.__version__,
+                "dalle_mini": dalle_mini.__version__,
             },
         }
     )
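
One detail worth noting: train.py already did `from dalle_mini.data import Dataset`, but that form never binds the name dalle_mini in the module namespace, so the plain `import dalle_mini` is required for the `dalle_mini.__version__` lookup. A hedged reconstruction of the version dict being logged (the keys are confirmed by the diff; the exact wandb call that receives it is only partially shown):

    import dalle_mini
    import datasets
    import transformers
    import wandb

    # Pinning exact library versions in the run metadata makes training
    # runs easier to reproduce; "dalle_mini" is the key this commit adds.
    versions = {
        "transformers": transformers.__version__,
        "datasets": datasets.__version__,
        "wandb": wandb.__version__,
        "dalle_mini": dalle_mini.__version__,
    }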
