@@ -444,7 +444,7 @@ def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
444444 w = nn .Dense (
445445 self .ffn_dim ,
446446 dtype = self .dtype ,
447- use_bias = False ,
447+ use_bias = self . config . use_bias ,
448448 kernel_init = deepnet_init (gain )
449449 if self .config .use_deepnet_scaling
450450 else jax .nn .initializers .normal (self .config .init_std ),
@@ -453,7 +453,7 @@ def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
453453 v = nn .Dense (
454454 self .ffn_dim ,
455455 dtype = self .dtype ,
456- use_bias = False ,
456+ use_bias = self . config . use_bias ,
457457 kernel_init = deepnet_init (gain )
458458 if self .config .use_deepnet_scaling
459459 else jax .nn .initializers .normal (self .config .init_std ),
@@ -473,7 +473,7 @@ def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
473473 x = nn .Dense (
474474 self .embed_dim ,
475475 dtype = self .dtype ,
476- use_bias = False ,
476+ use_bias = self . config . use_bias ,
477477 kernel_init = deepnet_init (gain )
478478 if self .config .use_deepnet_scaling
479479 else jax .nn .initializers .normal (self .config .init_std ),
@@ -509,7 +509,7 @@ def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
509509 x = nn .Dense (
510510 self .ffn_dim ,
511511 dtype = self .dtype ,
512- use_bias = False ,
512+ use_bias = self . config . use_bias ,
513513 kernel_init = deepnet_init (gain )
514514 if self .config .use_deepnet_scaling
515515 else jax .nn .initializers .normal (self .config .init_std ),
@@ -528,7 +528,7 @@ def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
528528 x = nn .Dense (
529529 self .embed_dim ,
530530 dtype = self .dtype ,
531- use_bias = False ,
531+ use_bias = self . config . use_bias ,
532532 kernel_init = deepnet_init (gain )
533533 if self .config .use_deepnet_scaling
534534 else jax .nn .initializers .normal (self .config .init_std ),
@@ -580,7 +580,7 @@ def __call__(
580580 embed_dim = embed_dim ,
581581 num_heads = self .config .encoder_attention_heads ,
582582 dropout = self .config .attention_dropout ,
583- bias = False ,
583+ bias = self . config . use_bias ,
584584 dtype = self .dtype ,
585585 is_encoder = True ,
586586 )(hidden_states = hidden_states , attention_mask = attention_mask )
@@ -686,7 +686,7 @@ def __call__(
686686 num_heads = self .config .decoder_attention_heads ,
687687 dropout = self .config .attention_dropout ,
688688 causal = True ,
689- bias = False ,
689+ bias = self . config . use_bias ,
690690 dtype = self .dtype ,
691691 is_encoder = False ,
692692 )(
@@ -724,7 +724,7 @@ def __call__(
724724 embed_dim = embed_dim ,
725725 num_heads = self .config .decoder_attention_heads ,
726726 dropout = self .config .attention_dropout ,
727- bias = False ,
727+ bias = self . config . use_bias ,
728728 dtype = self .dtype ,
729729 is_encoder = False ,
730730 )(
0 commit comments