@@ -619,6 +619,9 @@ def __call__(
619619 deterministic : bool = True ,
620620 ) -> Tuple [jnp .ndarray ]:
621621
622+ if self .config .use_scan :
623+ hidden_states = hidden_states [0 ]
624+
622625 res_gain = (
623626 deepnet_gain ["encoder" ]["alpha" ](self .config )
624627 if self .config .use_deepnet_scaling
@@ -679,12 +682,8 @@ def __call__(
679682 )
680683 hidden_states = ff_block (hidden_states , deterministic = deterministic )
681684 hidden_states = residual * res_gain + hidden_states
682- if self .add_norm or self .config .ln_positions in ["postln" ]:
683- use_scale = (
684- self .use_scale
685- or self .config .ln_positions == "postln"
686- or self .config .force_ln_scale
687- )
685+ if self .add_norm :
686+ use_scale = self .use_scale or self .config .force_ln_scale
688687 hidden_states = norm (
689688 self .config .ln_type ,
690689 dtype = self .dtype ,
@@ -697,6 +696,9 @@ def __call__(
697696 if output_attentions :
698697 outputs += (attn_weights ,)
699698
699+ if self .config .use_scan :
700+ outputs = (outputs , None )
701+
700702 return outputs
701703
702704
@@ -710,7 +712,7 @@ class FlaxBartDecoderLayer(nn.Module):
710712 config : DalleBartConfig
711713 dtype : jnp .dtype = jnp .float32
712714 add_norm : bool = False
713- use_scale : bool = False
715+ use_scale : bool = True
714716
715717 @nn .compact
716718 def __call__ (
@@ -724,6 +726,9 @@ def __call__(
724726 deterministic : bool = True ,
725727 ) -> Tuple [jnp .ndarray ]:
726728
729+ if self .config .use_scan :
730+ hidden_states = hidden_states [0 ]
731+
727732 res_gain = (
728733 deepnet_gain ["decoder" ]["alpha" ](self .config )
729734 if self .config .use_deepnet_scaling
@@ -831,12 +836,8 @@ def __call__(
831836 )
832837 hidden_states = ff_block (hidden_states , deterministic = deterministic )
833838 hidden_states = residual * res_gain + hidden_states
834- if self .add_norm or self .config .ln_positions in ["postln" ]:
835- use_scale = (
836- self .use_scale
837- or self .config .ln_positions == "postln"
838- or self .config .force_ln_scale
839- )
839+ if self .add_norm :
840+ use_scale = self .use_scale or self .config .force_ln_scale
840841 hidden_states = norm (
841842 self .config .ln_type ,
842843 dtype = self .dtype ,
@@ -849,6 +850,9 @@ def __call__(
849850 if output_attentions :
850851 outputs += (attn_weights , cross_attn_weights )
851852
853+ if self .config .use_scan :
854+ outputs = (outputs , None )
855+
852856 return outputs
853857
854858
@@ -876,35 +880,80 @@ def __call__(
876880
877881 n_layers = self .config .encoder_layers
878882 layer = (
879- remat (FlaxBartEncoderLayer , static_argnums = (2 , 3 ))
883+ remat (
884+ FlaxBartEncoderLayer ,
885+ static_argnums = (2 , 3 ),
886+ prevent_cse = not self .config .use_scan ,
887+ )
880888 if self .config .gradient_checkpointing
881889 else FlaxBartEncoderLayer
882890 )
883- for i in range (n_layers ):
884- if output_hidden_states :
885- all_hidden_states += (hidden_states ,)
886- # final layernorm on the output of the last layer
887- # or every 6 layers for Swin v2
888- add_norm = (
889- self .config .ln_positions == "swinv2" and ((i + 1 ) % 6 == 0 )
890- ) or (self .config .use_final_ln_encoder and (i == n_layers - 1 ))
891- # we don't need to scale the norm for the last layer
892- use_scale = i != n_layers - 1
893- layer_outputs = layer (
894- self .config , dtype = self .dtype , add_norm = add_norm , use_scale = use_scale
891+
892+ if self .config .use_scan :
893+ # all blocks are the same so we use nn.scan
894+ assert not output_attentions , "cannot scan with output_attentions"
895+ assert not output_hidden_states , "cannot scan with output_hidden_states"
896+ hidden_states = (hidden_states ,)
897+ # we use a scale on all norms (even last layer) to allow scanning
898+ hidden_states , _ = nn .scan (
899+ layer ,
900+ variable_axes = {"params" : 0 },
901+ split_rngs = {"params" : True , "dropout" : True },
902+ in_axes = (nn .broadcast , nn .broadcast , nn .broadcast ),
903+ length = n_layers ,
904+ )(
905+ self .config ,
906+ dtype = self .dtype ,
907+ add_norm = self .config .ln_positions == "postln" ,
908+ name = "FlaxBartEncoderLayers" ,
895909 )(
896910 hidden_states ,
897911 attention_mask ,
898912 output_attentions ,
899913 deterministic ,
900914 )
901- hidden_states = layer_outputs [0 ]
902- if output_attentions :
903- all_self_attns += (layer_outputs [1 ],)
915+ hidden_states = hidden_states [0 ]
916+ else :
917+ for i in range (n_layers ):
918+ if output_hidden_states :
919+ all_hidden_states += (hidden_states ,)
920+ # layernorm in every layer for postln
921+ # or every 6 layers (except the last one) for Swin v2
922+ add_norm = self .config .ln_positions == "postln" or (
923+ self .config .ln_positions == "swinv2"
924+ and ((i + 1 ) % 6 == 0 )
925+ and (i != n_layers - 1 )
926+ )
927+ # we don't need to scale the norm for the last layer
928+ use_scale = i != n_layers - 1
929+ layer_outputs = layer (
930+ self .config ,
931+ dtype = self .dtype ,
932+ add_norm = add_norm ,
933+ use_scale = use_scale ,
934+ name = f"FlaxBartEncoderLayer_{ i } " ,
935+ )(
936+ hidden_states ,
937+ attention_mask ,
938+ output_attentions ,
939+ deterministic ,
940+ )
941+ hidden_states = layer_outputs [0 ]
942+ if output_attentions :
943+ all_self_attns += (layer_outputs [1 ],)
904944
905- # add hidden states from the last layer
906- if output_hidden_states :
907- all_hidden_states += (hidden_states ,)
945+ # add hidden states from the last layer
946+ if output_hidden_states :
947+ all_hidden_states += (hidden_states ,)
948+
949+ # postln is already applied in every layer
950+ if self .config .use_final_ln_encoder and self .config .ln_positions != "postln" :
951+ hidden_states = norm (
952+ self .config .ln_type ,
953+ dtype = self .dtype ,
954+ epsilon = 1e-05 ,
955+ use_scale = self .config .force_ln_scale ,
956+ )(hidden_states )
908957
909958 outputs = [
910959 hidden_states ,
@@ -953,22 +1002,39 @@ def __call__(
9531002
9541003 n_layers = self .config .decoder_layers
9551004 layer = (
956- remat (FlaxBartDecoderLayer , static_argnums = (4 , 5 , 6 ))
1005+ remat (
1006+ FlaxBartDecoderLayer ,
1007+ static_argnums = (4 , 5 , 6 ),
1008+ prevent_cse = not self .config .use_scan ,
1009+ )
9571010 if self .config .gradient_checkpointing
9581011 else FlaxBartDecoderLayer
9591012 )
960- for i in range (n_layers ):
961- if output_hidden_states :
962- all_hidden_states += (hidden_states ,)
963- # final layernorm on the output of the last layer
964- # or every 6 layers for Swin v2
965- add_norm = (
966- self .config .ln_positions == "swinv2" and ((i + 1 ) % 6 == 0 )
967- ) or (self .config .use_final_ln_decoder and (i == n_layers - 1 ))
968- # we don't need to scale the norm for the last layer
969- use_scale = i != n_layers - 1
970- layer_outputs = layer (
971- self .config , dtype = self .dtype , add_norm = add_norm , use_scale = use_scale
1013+
1014+ if self .config .use_scan :
1015+ # all blocks are the same so we use nn.scan
1016+ assert not output_attentions , "cannot scan with output_attentions"
1017+ assert not output_hidden_states , "cannot scan with output_hidden_states"
1018+ hidden_states = (hidden_states ,)
1019+ # we use a scale on all norms (even last layer) to allow scanning
1020+ hidden_states , _ = nn .scan (
1021+ layer ,
1022+ variable_axes = {"params" : 0 },
1023+ split_rngs = {"params" : True , "dropout" : True },
1024+ in_axes = (
1025+ nn .broadcast ,
1026+ nn .broadcast ,
1027+ nn .broadcast ,
1028+ nn .broadcast ,
1029+ nn .broadcast ,
1030+ nn .broadcast ,
1031+ ),
1032+ length = n_layers ,
1033+ )(
1034+ self .config ,
1035+ dtype = self .dtype ,
1036+ add_norm = self .config .ln_positions == "postln" ,
1037+ name = "FlaxBartEncoderLayers" ,
9721038 )(
9731039 hidden_states ,
9741040 attention_mask ,
@@ -978,17 +1044,56 @@ def __call__(
9781044 output_attentions ,
9791045 deterministic ,
9801046 )
1047+ hidden_states = hidden_states [0 ]
9811048
982- hidden_states = layer_outputs [0 ]
983- if output_attentions :
984- all_self_attns += (layer_outputs [1 ],)
1049+ else :
1050+ for i in range (n_layers ):
1051+ if output_hidden_states :
1052+ all_hidden_states += (hidden_states ,)
1053+ # layernorm in every layer for postln
1054+ # or every 6 layers (except the last one) for Swin v2
1055+ add_norm = self .config .ln_positions == "postln" or (
1056+ self .config .ln_positions == "swinv2"
1057+ and ((i + 1 ) % 6 == 0 )
1058+ and (i != n_layers - 1 )
1059+ )
1060+ # we don't need to scale the norm for the last layer
1061+ use_scale = i != n_layers - 1
1062+ layer_outputs = layer (
1063+ self .config ,
1064+ dtype = self .dtype ,
1065+ add_norm = add_norm ,
1066+ use_scale = use_scale ,
1067+ name = f"FlaxBartDecoderLayer_{ i } " ,
1068+ )(
1069+ hidden_states ,
1070+ attention_mask ,
1071+ encoder_hidden_states ,
1072+ encoder_attention_mask ,
1073+ init_cache ,
1074+ output_attentions ,
1075+ deterministic ,
1076+ )
1077+
1078+ hidden_states = layer_outputs [0 ]
1079+ if output_attentions :
1080+ all_self_attns += (layer_outputs [1 ],)
1081+
1082+ if encoder_hidden_states is not None :
1083+ all_cross_attentions += (layer_outputs [2 ],)
9851084
986- if encoder_hidden_states is not None :
987- all_cross_attentions += (layer_outputs [2 ],)
1085+ # add hidden states from the last decoder layer
1086+ if output_hidden_states :
1087+ all_hidden_states += (hidden_states ,)
9881088
989- # add hidden states from the last decoder layer
990- if output_hidden_states :
991- all_hidden_states += (hidden_states ,)
1089+ # postln is already applied in every layer
1090+ if self .config .use_final_ln_decoder and self .config .ln_positions != "postln" :
1091+ hidden_states = norm (
1092+ self .config .ln_type ,
1093+ dtype = self .dtype ,
1094+ epsilon = 1e-05 ,
1095+ use_scale = self .config .force_ln_scale ,
1096+ )(hidden_states )
9921097
9931098 outputs = [
9941099 hidden_states ,
0 commit comments