Skip to content

Commit 61d1031

Browse files
authored
fix: use correct version of orbax (borisdayma#337)
1 parent fd818d2 commit 61d1031

File tree

5 files changed

+561
-564
lines changed

5 files changed

+561
-564
lines changed

app/gradio/app.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ def infer(prompt):
2222
with gr.Group():
2323
with gr.Box():
2424
with gr.Row().style(mobile_collapse=False, equal_height=True):
25-
2625
text = gr.Textbox(
2726
label="Enter your prompt", show_label=False, max_lines=1
2827
).style(

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ install_requires =
3131
pillow
3232
jax==0.3.25
3333
flax==0.6.3
34+
orbax==0.0.23
3435
wandb
3536

3637
[options.extras_require]

src/dalle_mini/model/modeling.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def _smelu(x: Any) -> Any:
7777

7878
ACT2FN.update({"smelu": smelu()})
7979

80+
8081
# deepnet initialization
8182
def deepnet_init(init_std, gain=1):
8283
init = jax.nn.initializers.normal(init_std)
@@ -498,7 +499,6 @@ class GLU(nn.Module):
498499

499500
@nn.compact
500501
def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
501-
502502
if self.config.use_deepnet_scaling:
503503
gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](
504504
self.config
@@ -567,7 +567,6 @@ class FFN(nn.Module):
567567

568568
@nn.compact
569569
def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
570-
571570
if self.config.use_deepnet_scaling:
572571
gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](
573572
self.config
@@ -634,7 +633,6 @@ def __call__(
634633
output_attentions: bool = True,
635634
deterministic: bool = True,
636635
) -> Tuple[jnp.ndarray]:
637-
638636
if self.config.use_scan:
639637
hidden_states = hidden_states[0]
640638

@@ -742,7 +740,6 @@ def __call__(
742740
output_attentions: bool = True,
743741
deterministic: bool = True,
744742
) -> Tuple[jnp.ndarray]:
745-
746743
if self.config.use_scan:
747744
hidden_states = hidden_states[0]
748745

0 commit comments

Comments
 (0)