Skip to content

Commit 07a6f9a

Browse files
authored
feat: scan layers + gradient checkpointing (borisdayma#161)
* scan layers for faster compilation
* support gradient checkpointing
1 parent 0199604 commit 07a6f9a

File tree

5 files changed

+252
-64
lines changed

5 files changed

+252
-64
lines changed

src/dalle_mini/model/configuration.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,23 +51,24 @@ def __init__(
5151
activation_dropout=0.0,
5252
init_std=0.02,
5353
scale_embedding=False,
54-
gradient_checkpointing=False,
54+
gradient_checkpointing=True,
55+
use_scan=None,
5556
use_cache=True,
5657
is_encoder_decoder=True,
5758
forced_eos_token_id=None,
5859
tie_word_embeddings=False, # different modalities and sizes
5960
do_sample=True,
6061
# transformer variants
6162
use_bias=False, # use bias in attention and dense layers (except for lm_head)
62-
ln_type="layernorm", # layer normalization type, "rmsnorm", "layernorm"
63+
ln_type="rmsnorm", # layer normalization type, "rmsnorm", "layernorm"
6364
ln_positions="normformer", # layer normalization positions, "normformer", "swinv2", "cogview", "postln", "preln", "deepnet" (same as postln)
6465
use_head_scale=False, # used in NormFormer
6566
use_cosine_attention=False, # used in Swin v2
6667
tau_init=0.05, # used only in cosine attention (Swin v2)
6768
use_absolute_position_embeddings=True, # default
6869
use_swin_position_embeddings=False, # used in Swin v1/v2
6970
use_deepnet_scaling=False, # used in Deepnet
70-
use_glu=False, # "GLU Variants Improve Transformer"
71+
use_glu=True, # "GLU Variants Improve Transformer"
7172
use_alibi=False, # Not implemented yet - from "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation"
7273
sinkhorn_iters=1, # used in SinkFormers
7374
use_final_ln_encoder=True, # final layer normalization in encoder
@@ -136,6 +137,11 @@ def __init__(
136137
self.init_std = init_std
137138
self.use_cache = use_cache
138139
self.gradient_checkpointing = gradient_checkpointing
140+
# all layers are the same in most configurations
141+
self.use_scan = use_scan if use_scan is not None else ln_positions != "swinv2"
142+
assert not (
143+
self.use_scan and ln_positions == "swinv2"
144+
), "scan cannot be used with 'swinv2'"
139145
self.scale_embedding = (
140146
scale_embedding # scale factor will be sqrt(d_model) if True
141147
)

src/dalle_mini/model/modeling.py

Lines changed: 158 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,9 @@ def __call__(
619619
deterministic: bool = True,
620620
) -> Tuple[jnp.ndarray]:
621621

622+
if self.config.use_scan:
623+
hidden_states = hidden_states[0]
624+
622625
res_gain = (
623626
deepnet_gain["encoder"]["alpha"](self.config)
624627
if self.config.use_deepnet_scaling
@@ -679,12 +682,8 @@ def __call__(
679682
)
680683
hidden_states = ff_block(hidden_states, deterministic=deterministic)
681684
hidden_states = residual * res_gain + hidden_states
682-
if self.add_norm or self.config.ln_positions in ["postln"]:
683-
use_scale = (
684-
self.use_scale
685-
or self.config.ln_positions == "postln"
686-
or self.config.force_ln_scale
687-
)
685+
if self.add_norm:
686+
use_scale = self.use_scale or self.config.force_ln_scale
688687
hidden_states = norm(
689688
self.config.ln_type,
690689
dtype=self.dtype,
@@ -697,6 +696,9 @@ def __call__(
697696
if output_attentions:
698697
outputs += (attn_weights,)
699698

699+
if self.config.use_scan:
700+
outputs = (outputs, None)
701+
700702
return outputs
701703

702704

@@ -710,7 +712,7 @@ class FlaxBartDecoderLayer(nn.Module):
710712
config: DalleBartConfig
711713
dtype: jnp.dtype = jnp.float32
712714
add_norm: bool = False
713-
use_scale: bool = False
715+
use_scale: bool = True
714716

715717
@nn.compact
716718
def __call__(
@@ -724,6 +726,9 @@ def __call__(
724726
deterministic: bool = True,
725727
) -> Tuple[jnp.ndarray]:
726728

729+
if self.config.use_scan:
730+
hidden_states = hidden_states[0]
731+
727732
res_gain = (
728733
deepnet_gain["decoder"]["alpha"](self.config)
729734
if self.config.use_deepnet_scaling
@@ -831,12 +836,8 @@ def __call__(
831836
)
832837
hidden_states = ff_block(hidden_states, deterministic=deterministic)
833838
hidden_states = residual * res_gain + hidden_states
834-
if self.add_norm or self.config.ln_positions in ["postln"]:
835-
use_scale = (
836-
self.use_scale
837-
or self.config.ln_positions == "postln"
838-
or self.config.force_ln_scale
839-
)
839+
if self.add_norm:
840+
use_scale = self.use_scale or self.config.force_ln_scale
840841
hidden_states = norm(
841842
self.config.ln_type,
842843
dtype=self.dtype,
@@ -849,6 +850,9 @@ def __call__(
849850
if output_attentions:
850851
outputs += (attn_weights, cross_attn_weights)
851852

853+
if self.config.use_scan:
854+
outputs = (outputs, None)
855+
852856
return outputs
853857

854858

@@ -876,35 +880,80 @@ def __call__(
876880

877881
n_layers = self.config.encoder_layers
878882
layer = (
879-
remat(FlaxBartEncoderLayer, static_argnums=(2, 3))
883+
remat(
884+
FlaxBartEncoderLayer,
885+
static_argnums=(2, 3),
886+
prevent_cse=not self.config.use_scan,
887+
)
880888
if self.config.gradient_checkpointing
881889
else FlaxBartEncoderLayer
882890
)
883-
for i in range(n_layers):
884-
if output_hidden_states:
885-
all_hidden_states += (hidden_states,)
886-
# final layernorm on the output of the last layer
887-
# or every 6 layers for Swin v2
888-
add_norm = (
889-
self.config.ln_positions == "swinv2" and ((i + 1) % 6 == 0)
890-
) or (self.config.use_final_ln_encoder and (i == n_layers - 1))
891-
# we don't need to scale the norm for the last layer
892-
use_scale = i != n_layers - 1
893-
layer_outputs = layer(
894-
self.config, dtype=self.dtype, add_norm=add_norm, use_scale=use_scale
891+
892+
if self.config.use_scan:
893+
# all blocks are the same so we use nn.scan
894+
assert not output_attentions, "cannot scan with output_attentions"
895+
assert not output_hidden_states, "cannot scan with output_hidden_states"
896+
hidden_states = (hidden_states,)
897+
# we use a scale on all norms (even last layer) to allow scanning
898+
hidden_states, _ = nn.scan(
899+
layer,
900+
variable_axes={"params": 0},
901+
split_rngs={"params": True, "dropout": True},
902+
in_axes=(nn.broadcast, nn.broadcast, nn.broadcast),
903+
length=n_layers,
904+
)(
905+
self.config,
906+
dtype=self.dtype,
907+
add_norm=self.config.ln_positions == "postln",
908+
name="FlaxBartEncoderLayers",
895909
)(
896910
hidden_states,
897911
attention_mask,
898912
output_attentions,
899913
deterministic,
900914
)
901-
hidden_states = layer_outputs[0]
902-
if output_attentions:
903-
all_self_attns += (layer_outputs[1],)
915+
hidden_states = hidden_states[0]
916+
else:
917+
for i in range(n_layers):
918+
if output_hidden_states:
919+
all_hidden_states += (hidden_states,)
920+
# final layernorm on the output of the last layer
921+
# or every 6 layers for Swin v2
922+
add_norm = self.config.ln_positions == "postln" or (
923+
self.config.ln_positions == "swinv2"
924+
and ((i + 1) % 6 == 0)
925+
and (i != n_layers - 1)
926+
)
927+
# we don't need to scale the norm for the last layer
928+
use_scale = i != n_layers - 1
929+
layer_outputs = layer(
930+
self.config,
931+
dtype=self.dtype,
932+
add_norm=add_norm,
933+
use_scale=use_scale,
934+
name=f"FlaxBartEncoderLayer_{i}",
935+
)(
936+
hidden_states,
937+
attention_mask,
938+
output_attentions,
939+
deterministic,
940+
)
941+
hidden_states = layer_outputs[0]
942+
if output_attentions:
943+
all_self_attns += (layer_outputs[1],)
904944

905-
# add hidden states from the last layer
906-
if output_hidden_states:
907-
all_hidden_states += (hidden_states,)
945+
# add hidden states from the last layer
946+
if output_hidden_states:
947+
all_hidden_states += (hidden_states,)
948+
949+
# postln is already applied in every layer
950+
if self.config.use_final_ln_encoder and self.config.ln_positions != "postln":
951+
hidden_states = norm(
952+
self.config.ln_type,
953+
dtype=self.dtype,
954+
epsilon=1e-05,
955+
use_scale=self.config.force_ln_scale,
956+
)(hidden_states)
908957

909958
outputs = [
910959
hidden_states,
@@ -953,22 +1002,39 @@ def __call__(
9531002

9541003
n_layers = self.config.decoder_layers
9551004
layer = (
956-
remat(FlaxBartDecoderLayer, static_argnums=(4, 5, 6))
1005+
remat(
1006+
FlaxBartDecoderLayer,
1007+
static_argnums=(4, 5, 6),
1008+
prevent_cse=not self.config.use_scan,
1009+
)
9571010
if self.config.gradient_checkpointing
9581011
else FlaxBartDecoderLayer
9591012
)
960-
for i in range(n_layers):
961-
if output_hidden_states:
962-
all_hidden_states += (hidden_states,)
963-
# final layernorm on the output of the last layer
964-
# or every 6 layers for Swin v2
965-
add_norm = (
966-
self.config.ln_positions == "swinv2" and ((i + 1) % 6 == 0)
967-
) or (self.config.use_final_ln_decoder and (i == n_layers - 1))
968-
# we don't need to scale the norm for the last layer
969-
use_scale = i != n_layers - 1
970-
layer_outputs = layer(
971-
self.config, dtype=self.dtype, add_norm=add_norm, use_scale=use_scale
1013+
1014+
if self.config.use_scan:
1015+
# all blocks are the same so we use nn.scan
1016+
assert not output_attentions, "cannot scan with output_attentions"
1017+
assert not output_hidden_states, "cannot scan with output_hidden_states"
1018+
hidden_states = (hidden_states,)
1019+
# we use a scale on all norms (even last layer) to allow scanning
1020+
hidden_states, _ = nn.scan(
1021+
layer,
1022+
variable_axes={"params": 0},
1023+
split_rngs={"params": True, "dropout": True},
1024+
in_axes=(
1025+
nn.broadcast,
1026+
nn.broadcast,
1027+
nn.broadcast,
1028+
nn.broadcast,
1029+
nn.broadcast,
1030+
nn.broadcast,
1031+
),
1032+
length=n_layers,
1033+
)(
1034+
self.config,
1035+
dtype=self.dtype,
1036+
add_norm=self.config.ln_positions == "postln",
1037+
name="FlaxBartEncoderLayers",
[NOTE(review): this is the decoder scan, yet the commit names the scanned module "FlaxBartEncoderLayers" — presumably a copy-paste from the encoder block; the partition rules in set_partitions match on "FlaxBartDecoderLayers", so this looks inconsistent — verify against the upstream repository before relying on this commit's partitioning of decoder params.]
9721038
)(
9731039
hidden_states,
9741040
attention_mask,
@@ -978,17 +1044,56 @@ def __call__(
9781044
output_attentions,
9791045
deterministic,
9801046
)
1047+
hidden_states = hidden_states[0]
9811048

982-
hidden_states = layer_outputs[0]
983-
if output_attentions:
984-
all_self_attns += (layer_outputs[1],)
1049+
else:
1050+
for i in range(n_layers):
1051+
if output_hidden_states:
1052+
all_hidden_states += (hidden_states,)
1053+
# final layernorm on the output of the last layer
1054+
# or every 6 layers for Swin v2
1055+
add_norm = self.config.ln_positions == "postln" or (
1056+
self.config.ln_positions == "swinv2"
1057+
and ((i + 1) % 6 == 0)
1058+
and (i != n_layers - 1)
1059+
)
1060+
# we don't need to scale the norm for the last layer
1061+
use_scale = i != n_layers - 1
1062+
layer_outputs = layer(
1063+
self.config,
1064+
dtype=self.dtype,
1065+
add_norm=add_norm,
1066+
use_scale=use_scale,
1067+
name=f"FlaxBartDecoderLayer_{i}",
1068+
)(
1069+
hidden_states,
1070+
attention_mask,
1071+
encoder_hidden_states,
1072+
encoder_attention_mask,
1073+
init_cache,
1074+
output_attentions,
1075+
deterministic,
1076+
)
1077+
1078+
hidden_states = layer_outputs[0]
1079+
if output_attentions:
1080+
all_self_attns += (layer_outputs[1],)
1081+
1082+
if encoder_hidden_states is not None:
1083+
all_cross_attentions += (layer_outputs[2],)
9851084

986-
if encoder_hidden_states is not None:
987-
all_cross_attentions += (layer_outputs[2],)
1085+
# add hidden states from the last decoder layer
1086+
if output_hidden_states:
1087+
all_hidden_states += (hidden_states,)
9881088

989-
# add hidden states from the last decoder layer
990-
if output_hidden_states:
991-
all_hidden_states += (hidden_states,)
1089+
# postln is already applied in every layer
1090+
if self.config.use_final_ln_decoder and self.config.ln_positions != "postln":
1091+
hidden_states = norm(
1092+
self.config.ln_type,
1093+
dtype=self.dtype,
1094+
epsilon=1e-05,
1095+
use_scale=self.config.force_ln_scale,
1096+
)(hidden_states)
9921097

9931098
outputs = [
9941099
hidden_states,

src/dalle_mini/model/partitions.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,22 @@ def _get_partition_rules():
5555
]
5656

5757

58-
def set_partitions(in_dict):
58+
def set_partitions(in_dict, use_scan):
5959
rules = _get_partition_rules()
6060
replace = _replacement_rules(rules)
6161
initd = {k: _unmatched for k in flatten_dict(in_dict)}
6262
result = {k: replace(k, v) for k, v in initd.items()}
6363
for k, v in result.items():
6464
if v == _unmatched:
6565
print(f"Unmatched -> {k}")
66+
l = list(result.keys())
67+
if use_scan:
68+
# add None dimension to scanned layers
69+
result = {
70+
k: (P(*(None,) + v) if v is not None else None)
71+
if any(x in k for x in ["FlaxBartEncoderLayers", "FlaxBartDecoderLayers"])
72+
else v
73+
for k, v in result.items()
74+
}
6675
assert _unmatched not in result.values(), "Incomplete partition spec."
6776
return freeze(unflatten_dict(result))

tools/train/config/mega/config.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77
"decoder_attention_heads": 32,
88
"decoder_ffn_dim": 4096,
99
"decoder_layerdrop": 0.0,
10-
"decoder_layers": 25,
10+
"decoder_layers": 26,
1111
"decoder_start_token_id": 16384,
1212
"do_sample": true,
1313
"dropout": 0.0,
1414
"encoder_attention_heads": 32,
1515
"encoder_ffn_dim": 4096,
1616
"encoder_layerdrop": 0.0,
17-
"encoder_layers": 25,
17+
"encoder_layers": 26,
1818
"encoder_vocab_size": 50272,
1919
"eos_token_id": 16385,
2020
"force_ln_scale": false,

0 commit comments

Comments (0)