@@ -946,15 +946,6 @@ def __call__(
946946 if output_hidden_states :
947947 all_hidden_states += (hidden_states ,)
948948
949- # postln is already applied in every layer
950- if self .config .use_final_ln_encoder and self .config .ln_positions != "postln" :
951- hidden_states = norm (
952- self .config .ln_type ,
953- dtype = self .dtype ,
954- epsilon = 1e-05 ,
955- use_scale = self .config .force_ln_scale ,
956- )(hidden_states )
957-
958949 outputs = [
959950 hidden_states ,
960951 all_hidden_states ,
@@ -1034,7 +1025,7 @@ def __call__(
10341025 self .config ,
10351026 dtype = self .dtype ,
10361027 add_norm = self .config .ln_positions == "postln" ,
1037- name = "FlaxBartEncoderLayers " ,
1028+ name = "FlaxBartDecoderLayers " ,
10381029 )(
10391030 hidden_states ,
10401031 attention_mask ,
@@ -1086,15 +1077,6 @@ def __call__(
10861077 if output_hidden_states :
10871078 all_hidden_states += (hidden_states ,)
10881079
1089- # postln is already applied in every layer
1090- if self .config .use_final_ln_decoder and self .config .ln_positions != "postln" :
1091- hidden_states = norm (
1092- self .config .ln_type ,
1093- dtype = self .dtype ,
1094- epsilon = 1e-05 ,
1095- use_scale = self .config .force_ln_scale ,
1096- )(hidden_states )
1097-
10981080 outputs = [
10991081 hidden_states ,
11001082 all_hidden_states ,
@@ -1146,6 +1128,17 @@ def setup(self):
11461128 self .config .ln_type , dtype = self .dtype , epsilon = 1e-05
11471129 )
11481130
1131+ # postln is already applied in every layer
1132+ if self .config .use_final_ln_encoder and self .config .ln_positions != "postln" :
1133+ self .final_ln = norm (
1134+ self .config .ln_type ,
1135+ dtype = self .dtype ,
1136+ epsilon = 1e-05 ,
1137+ use_scale = self .config .force_ln_scale ,
1138+ )
1139+ else :
1140+ self .final_ln = None
1141+
11491142 def __call__ (
11501143 self ,
11511144 input_ids ,
@@ -1177,11 +1170,16 @@ def __call__(
11771170 return_dict = return_dict ,
11781171 )
11791172
1173+ if self .final_ln is None :
1174+ final_output = outputs [0 ]
1175+ else :
1176+ final_output = self .final_ln (outputs [0 ])
1177+
11801178 if not return_dict :
1181- return outputs
1179+ return ( final_output ,) + outputs [ 1 :]
11821180
11831181 return FlaxBaseModelOutput (
1184- last_hidden_state = outputs . last_hidden_state ,
1182+ last_hidden_state = final_output ,
11851183 hidden_states = outputs .hidden_states ,
11861184 attentions = outputs .attentions ,
11871185 )
@@ -1223,6 +1221,17 @@ def setup(self):
12231221 self .config .ln_type , dtype = self .dtype , epsilon = 1e-05
12241222 )
12251223
1224+ # postln is already applied in every layer
1225+ if self .config .use_final_ln_decoder and self .config .ln_positions != "postln" :
1226+ self .final_ln = norm (
1227+ self .config .ln_type ,
1228+ dtype = self .dtype ,
1229+ epsilon = 1e-05 ,
1230+ use_scale = self .config .force_ln_scale ,
1231+ )
1232+ else :
1233+ self .final_ln = None
1234+
12261235 def __call__ (
12271236 self ,
12281237 input_ids ,
@@ -1260,11 +1269,16 @@ def __call__(
12601267 return_dict = return_dict ,
12611268 )
12621269
1270+ if self .final_ln is None :
1271+ final_output = outputs [0 ]
1272+ else :
1273+ final_output = self .final_ln (outputs [0 ])
1274+
12631275 if not return_dict :
1264- return outputs
1276+ return ( final_output ,) + outputs [ 1 :]
12651277
12661278 return FlaxBaseModelOutputWithPastAndCrossAttentions (
1267- last_hidden_state = outputs . last_hidden_state ,
1279+ last_hidden_state = final_output ,
12681280 hidden_states = outputs .hidden_states ,
12691281 attentions = outputs .attentions ,
12701282 cross_attentions = outputs .cross_attentions ,