Commit 542378c
feat: implement transformer variants (borisdayma#144)
* added DeepNet
* added Swin v2
* added NormFormer
* added RMSNorm
* added GLU variants
1 parent b7b619a commit 542378c

File tree

13 files changed: +1030 −295 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -3,3 +3,4 @@ __pycache__
 .streamlit
 wandb/
 *.egg-info/
+jax_cache/

README.md

Lines changed: 88 additions & 20 deletions
@@ -94,26 +94,43 @@ Many thanks to the people who helped make it better:
 
 - the [DALLE-Pytorch](https://discord.gg/xBPBXfcFHd) and [EleutherAI](https://www.eleuther.ai/) communities for testing and exchanging cool ideas
 - [Rohan Anil](https://github.com/rohan-anil) for adding Distributed Shampoo optimizer
+- [Phil Wang](https://github.com/lucidrains) has provided a lot of cool implementations of transformer variants and gives interesting insights with [x-transformers](https://github.com/lucidrains/x-transformers)
 - [Katherine Crowson](https://github.com/crowsonkb) for [super conditioning](https://twitter.com/RiversHaveWings/status/1478093658716966912)
 
 ## Citing DALL·E mini
 
 If you find DALL·E mini useful in your research or wish to refer, please use the following BibTeX entry.
 
-```
+```text
 @misc{Dayma_DALL·E_Mini_2021,
-author = {Dayma, Boris and Patil, Suraj and Cuenca, Pedro and Saifullah, Khalid and Abraham, Tanishq and Lê Khắc, Phúc and Melas, Luke and Ghosh, Ritobrata},
-doi = {10.5281/zenodo.5146400},
-month = {7},
-title = {DALL·E Mini},
-url = {https://github.com/borisdayma/dalle-mini},
-year = {2021}
+  author = {Dayma, Boris and Patil, Suraj and Cuenca, Pedro and Saifullah, Khalid and Abraham, Tanishq and Lê Khắc, Phúc and Melas, Luke and Ghosh, Ritobrata},
+  doi = {10.5281/zenodo.5146400},
+  month = {7},
+  title = {DALL·E Mini},
+  url = {https://github.com/borisdayma/dalle-mini},
+  year = {2021}
 }
 ```
 
 ## References
 
-```
+Original DALL·E from "[Zero-Shot Text-to-Image Generation](https://arxiv.org/abs/2102.12092)" with image quantization from "[Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020)".
+
+Image encoder from "[Taming Transformers for High-Resolution Image Synthesis](https://arxiv.org/abs/2012.09841v2)".
+
+Sequence to sequence model based on "[BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension](https://arxiv.org/abs/1910.13461v1)" with implementation of a few variants:
+
+- "[GLU Variants Improve Transformer](https://arxiv.org/abs/2002.05202)"
+- "[Deepnet: Scaling Transformers to 1,000 Layers](https://arxiv.org/abs/2203.00555)"
+- "[NormFormer: Improved Transformer Pretraining with Extra Normalization](https://arxiv.org/abs/2110.09456)"
+- "[Swin Transformer: Hierarchical Vision Transformer using Shifted Windows](https://arxiv.org/abs/2103.14030)"
+- "[Root Mean Square Layer Normalization](https://arxiv.org/abs/1910.07467)"
+
+Main optimizer (Distributed Shampoo) from "[Scalable Second Order Optimization for Deep Learning](https://arxiv.org/abs/2002.09018)".
+
+### Citations
+
+```text
 @misc{ramesh2021zeroshot,
 title={Zero-Shot Text-to-Image Generation},
 author={Aditya Ramesh and Mikhail Pavlov and Gabriel Goh and Scott Gray and Chelsea Voss and Alec Radford and Mark Chen and Ilya Sutskever},
@@ -124,7 +141,18 @@ year = {2021}
 }
 ```
 
+```text
+@misc{radford2021learning,
+title={Learning Transferable Visual Models From Natural Language Supervision},
+author={Alec Radford and Jong Wook Kim and Chris Hallacy and Aditya Ramesh and Gabriel Goh and Sandhini Agarwal and Girish Sastry and Amanda Askell and Pamela Mishkin and Jack Clark and Gretchen Krueger and Ilya Sutskever},
+year={2021},
+eprint={2103.00020},
+archivePrefix={arXiv},
+primaryClass={cs.CV}
+}
 ```
+
+```text
 @misc{esser2021taming,
 title={Taming Transformers for High-Resolution Image Synthesis},
 author={Patrick Esser and Robin Rombach and Björn Ommer},
@@ -135,7 +163,7 @@ year = {2021}
 }
 ```
 
-```
+```text
 @misc{lewis2019bart,
 title={BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension},
 author={Mike Lewis and Yinhan Liu and Naman Goyal and Marjan Ghazvininejad and Abdelrahman Mohamed and Omer Levy and Ves Stoyanov and Luke Zettlemoyer},
@@ -146,24 +174,64 @@ year = {2021}
 }
 ```
 
-```
-@misc{radford2021learning,
-title={Learning Transferable Visual Models From Natural Language Supervision},
-author={Alec Radford and Jong Wook Kim and Chris Hallacy and Aditya Ramesh and Gabriel Goh and Sandhini Agarwal and Girish Sastry and Amanda Askell and Pamela Mishkin and Jack Clark and Gretchen Krueger and Ilya Sutskever},
+```text
+@misc{anil2021scalable,
+title={Scalable Second Order Optimization for Deep Learning},
+author={Rohan Anil and Vineet Gupta and Tomer Koren and Kevin Regan and Yoram Singer},
 year={2021},
-eprint={2103.00020},
+eprint={2002.09018},
 archivePrefix={arXiv},
-primaryClass={cs.CV}
+primaryClass={cs.LG}
 }
 ```
 
+```text
+@misc{shazeer2020glu,
+title={GLU Variants Improve Transformer},
+author={Noam Shazeer},
+year={2020},
+url={https://arxiv.org/abs/2002.05202}
+}
 ```
-@misc{anil2021scalable,
-title={Scalable Second Order Optimization for Deep Learning},
-author={Rohan Anil and Vineet Gupta and Tomer Koren and Kevin Regan and Yoram Singer},
+
+```text
+@misc{wang_ma_dong_huang_zhang_wei_2022,
+title={DeepNet: Scaling transformers to 1,000 layers},
+author={Wang, Hongyu and Ma, Shuming and Dong, Li and Huang, Shaohan and Zhang, Dongdong and Wei, Furu},
+year={2022},
+eprint={2203.00555}
+archivePrefix={arXiv},
+primaryClass={cs.LG}
+}
+```
+
+```text
+@misc{shleifer2021normformer,
+title={NormFormer: Improved Transformer Pretraining with Extra Normalization},
+author={Sam Shleifer and Jason Weston and Myle Ott},
 year={2021},
-eprint={2002.09018},
+eprint={2110.09456},
 archivePrefix={arXiv},
-primaryClass={cs.LG}
+primaryClass={cs.CL}
+}
+```
+
+```text
+@inproceedings{liu2021swinv2,
+title={Swin Transformer V2: Scaling Up Capacity and Resolution},
+author={Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
+booktitle={International Conference on Computer Vision and Pattern Recognition (CVPR)},
+year={2022}
+}
+```
+
+```text
+@misc{zhang2019root,
+title = {Root Mean Square Layer Normalization},
+author = {Biao Zhang and Rico Sennrich},
+year = {2019},
+eprint = {1910.07467},
+archivePrefix = {arXiv},
+primaryClass = {cs.LG}
 }
 ```
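Of the variants listed above, RMSNorm is simple enough to sketch in a few lines. The following is a minimal NumPy illustration of the formula from "Root Mean Square Layer Normalization" (Zhang & Sennrich, 2019), not the repo's Flax module:

```python
import numpy as np

def rms_norm(x, gain, eps=1e-6):
    # RMSNorm rescales by the root mean square of the activations,
    # skipping LayerNorm's mean subtraction and variance estimate.
    rms = np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + eps)
    return x / rms * gain

x = np.array([[3.0, 4.0]])
gain = np.ones(2)  # learned per-feature gain, initialized to 1
y = rms_norm(x, gain)  # each row now has (approximately) unit RMS
```

Dropping the mean statistics makes RMSNorm cheaper than LayerNorm while keeping the re-scaling invariance that stabilizes training.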

src/dalle_mini/data.py

Lines changed: 7 additions & 4 deletions
@@ -1,3 +1,4 @@
+import random
 from dataclasses import dataclass, field
 from functools import partial
 
@@ -39,6 +40,9 @@ class Dataset:
     multi_hosts: bool = field(init=False)
 
     def __post_init__(self):
+        if self.seed_dataset is None:
+            # create a random seed
+            self.seed_dataset = random.randint(0, 2**32 - 1)
         self.multi_hosts = jax.process_count() > 1
         # feed blank captions only in streaming mode for now
         # otherwise dataset could be cached with same blanked captions
@@ -106,11 +110,10 @@ def preprocess(self, tokenizer, config):
         if self.streaming:
             # we need to shuffle early in streaming mode
             if hasattr(self, "train_dataset"):
-                self.train_dataset = self.train_dataset.shuffle(5000, self.seed_dataset)
+                self.train_dataset = self.train_dataset.shuffle(
+                    buffer_size=5000, seed=self.seed_dataset
+                )
         else:
-            # prepare rng for later shuffling
-            if self.seed_dataset is None:
-                self.seed_dataset = np.random.get_state()[1][0]
             self.rng_dataset = jax.random.PRNGKey(self.seed_dataset)
 
         # filter data

src/dalle_mini/model/configuration.py

Lines changed: 29 additions & 6 deletions
@@ -44,25 +44,51 @@ def __init__(
         decoder_layers=12,
         decoder_ffn_dim=4096,
         decoder_attention_heads=16,
-        encoder_layerdrop=0.0,
-        decoder_layerdrop=0.0,
         activation_function="gelu",
         d_model=1024,
         dropout=0.1,
         attention_dropout=0.0,
         activation_dropout=0.0,
         init_std=0.02,
-        classifier_dropout=0.0,
         scale_embedding=False,
         gradient_checkpointing=False,
         use_cache=True,
         is_encoder_decoder=True,
         forced_eos_token_id=None,
         tie_word_embeddings=False,  # different modalities and sizes
         do_sample=True,
+        # transformer variants
+        head_scale=False,  # used in NormFormer
+        ln_type="layernorm",  # layer normalization type, "rmsnorm", "layernorm"
+        ln_positions="deepnet",  # layer normalization positions, "normformer", "swinv2", "deepnet" (same as post-ln)
+        use_cosine_attention=False,  # used in Swin v2
+        tau_init=0.05,  # used only in cosine attention (Swin v2)
+        use_deepnet_scaling=False,  # used in Deepnet
+        use_glu=False,  # "GLU Variants Improve Transformer"
         **kwargs,
     ):
+        # text normalizer
         self.normalize_text = normalize_text
+
+        # transformer variants
+        self.head_scale = head_scale  # per Normformer
+        assert ln_type in [
+            "rmsnorm",
+            "layernorm",
+        ], "ln_type must be 'rmsnorm' or 'layernorm'"
+        self.ln_type = ln_type
+        assert ln_positions in [
+            "normformer",
+            "swinv2",
+            "deepnet",
+        ], "ln_positions must be 'normformer', 'swinv2' or 'deepnet'"
+        self.ln_positions = ln_positions
+        self.use_cosine_attention = use_cosine_attention
+        self.tau_init = tau_init
+        self.use_deepnet_scaling = use_deepnet_scaling
+        self.use_glu = use_glu
+
+        # common parameters
         self.encoder_vocab_size = encoder_vocab_size
         self.image_vocab_size = image_vocab_size
         self.image_length = image_length
@@ -79,9 +105,6 @@ def __init__(
         self.activation_dropout = activation_dropout
         self.activation_function = activation_function
         self.init_std = init_std
-        self.encoder_layerdrop = encoder_layerdrop
-        self.decoder_layerdrop = decoder_layerdrop
-        self.classifier_dropout = classifier_dropout
         self.use_cache = use_cache
         self.gradient_checkpointing = gradient_checkpointing
         self.scale_embedding = (
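The configuration guards its new string-valued options with plain asserts rather than enums. The same check can be exercised standalone; this sketch copies the allowed values from the diff (the function name is ours, not the repo's):

```python
def check_variant_options(ln_type="layernorm", ln_positions="deepnet"):
    # allowed values copied from the configuration.py diff in this commit
    assert ln_type in [
        "rmsnorm",
        "layernorm",
    ], "ln_type must be 'rmsnorm' or 'layernorm'"
    assert ln_positions in [
        "normformer",
        "swinv2",
        "deepnet",
    ], "ln_positions must be 'normformer', 'swinv2' or 'deepnet'"
    return ln_type, ln_positions

check_variant_options(ln_type="rmsnorm", ln_positions="normformer")  # valid combination
```

A misspelled option fails fast with a readable message at config construction time, long before model weights are initialized, which is the point of validating in `__init__`.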
