Skip to content

Commit b993d27

Browse files
Authored commit message: feat: vmap optimizer (borisdayma#166)
1 parent: 2f1e5d9 · commit: b993d27

File tree

4 files changed

+156
-115
lines changed

4 files changed

+156
-115
lines changed

src/dalle_mini/model/modeling.py

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -946,15 +946,6 @@ def __call__(
946946
if output_hidden_states:
947947
all_hidden_states += (hidden_states,)
948948

949-
# postln is already applied in every layer
950-
if self.config.use_final_ln_encoder and self.config.ln_positions != "postln":
951-
hidden_states = norm(
952-
self.config.ln_type,
953-
dtype=self.dtype,
954-
epsilon=1e-05,
955-
use_scale=self.config.force_ln_scale,
956-
)(hidden_states)
957-
958949
outputs = [
959950
hidden_states,
960951
all_hidden_states,
@@ -1034,7 +1025,7 @@ def __call__(
10341025
self.config,
10351026
dtype=self.dtype,
10361027
add_norm=self.config.ln_positions == "postln",
1037-
name="FlaxBartEncoderLayers",
1028+
name="FlaxBartDecoderLayers",
10381029
)(
10391030
hidden_states,
10401031
attention_mask,
@@ -1086,15 +1077,6 @@ def __call__(
10861077
if output_hidden_states:
10871078
all_hidden_states += (hidden_states,)
10881079

1089-
# postln is already applied in every layer
1090-
if self.config.use_final_ln_decoder and self.config.ln_positions != "postln":
1091-
hidden_states = norm(
1092-
self.config.ln_type,
1093-
dtype=self.dtype,
1094-
epsilon=1e-05,
1095-
use_scale=self.config.force_ln_scale,
1096-
)(hidden_states)
1097-
10981080
outputs = [
10991081
hidden_states,
11001082
all_hidden_states,
@@ -1146,6 +1128,17 @@ def setup(self):
11461128
self.config.ln_type, dtype=self.dtype, epsilon=1e-05
11471129
)
11481130

1131+
# postln is already applied in every layer
1132+
if self.config.use_final_ln_encoder and self.config.ln_positions != "postln":
1133+
self.final_ln = norm(
1134+
self.config.ln_type,
1135+
dtype=self.dtype,
1136+
epsilon=1e-05,
1137+
use_scale=self.config.force_ln_scale,
1138+
)
1139+
else:
1140+
self.final_ln = None
1141+
11491142
def __call__(
11501143
self,
11511144
input_ids,
@@ -1177,11 +1170,16 @@ def __call__(
11771170
return_dict=return_dict,
11781171
)
11791172

1173+
if self.final_ln is None:
1174+
final_output = outputs[0]
1175+
else:
1176+
final_output = self.final_ln(outputs[0])
1177+
11801178
if not return_dict:
1181-
return outputs
1179+
return (final_output,) + outputs[1:]
11821180

11831181
return FlaxBaseModelOutput(
1184-
last_hidden_state=outputs.last_hidden_state,
1182+
last_hidden_state=final_output,
11851183
hidden_states=outputs.hidden_states,
11861184
attentions=outputs.attentions,
11871185
)
@@ -1223,6 +1221,15 @@ def setup(self):
12231221
self.config.ln_type, dtype=self.dtype, epsilon=1e-05
12241222
)
12251223

1224+
# postln is already applied in every layer
1225+
if self.config.use_final_ln_decoder and self.config.ln_positions != "postln":
1226+
self.final_ln = norm(
1227+
self.config.ln_type,
1228+
dtype=self.dtype,
1229+
epsilon=1e-05,
1230+
use_scale=self.config.force_ln_scale,
1231+
)
1232+
12261233
def __call__(
12271234
self,
12281235
input_ids,
@@ -1260,11 +1267,16 @@ def __call__(
12601267
return_dict=return_dict,
12611268
)
12621269

1270+
if self.final_ln is None:
1271+
final_output = outputs[0]
1272+
else:
1273+
final_output = self.final_ln(outputs[0])
1274+
12631275
if not return_dict:
1264-
return outputs
1276+
return (final_output,) + outputs[1:]
12651277

12661278
return FlaxBaseModelOutputWithPastAndCrossAttentions(
1267-
last_hidden_state=outputs.last_hidden_state,
1279+
last_hidden_state=final_output,
12681280
hidden_states=outputs.hidden_states,
12691281
attentions=outputs.attentions,
12701282
cross_attentions=outputs.cross_attentions,

src/dalle_mini/model/partitions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def set_partitions(in_dict, use_scan):
6565
print(f"Unmatched -> {k}")
6666
l = list(result.keys())
6767
if use_scan:
68-
# add None dimension to scanned layers
68+
# add None dimension to layers
6969
result = {
7070
k: (P(*(None,) + v) if v is not None else None)
7171
if any(x in k for x in ["FlaxBartEncoderLayers", "FlaxBartDecoderLayers"])

tools/train/config/mega/config.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77
"decoder_attention_heads": 32,
88
"decoder_ffn_dim": 4096,
99
"decoder_layerdrop": 0.0,
10-
"decoder_layers": 26,
10+
"decoder_layers": 24,
1111
"decoder_start_token_id": 16384,
1212
"do_sample": true,
1313
"dropout": 0.0,
1414
"encoder_attention_heads": 32,
1515
"encoder_ffn_dim": 4096,
1616
"encoder_layerdrop": 0.0,
17-
"encoder_layers": 26,
17+
"encoder_layers": 24,
1818
"encoder_vocab_size": 50272,
1919
"eos_token_id": 16385,
2020
"force_ln_scale": false,

0 commit comments

Comments (0)