Skip to content

Commit 803ccbf

Browse files
authored
feat: support pod (borisdayma#139)
1 parent 2e02683 commit 803ccbf

File tree

14 files changed

+873
-350
lines changed

14 files changed

+873
-350
lines changed

src/dalle_mini/data.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,19 @@ class Dataset:
2727
do_eval: bool = True
2828
seed_dataset: int = None
2929
shard_by_host: bool = False
30+
blank_caption_prob: float = 0.0
3031
train_dataset: Dataset = field(init=False)
3132
eval_dataset: Dataset = field(init=False)
3233
rng_dataset: jnp.ndarray = field(init=False)
3334
multi_hosts: bool = field(init=False)
3435

3536
def __post_init__(self):
3637
self.multi_hosts = jax.process_count() > 1
38+
# feed blank captions only in streaming mode for now
39+
if self.blank_caption_prob:
40+
assert (
41+
self.streaming is True
42+
), "blank_caption_prob can only be used in streaming mode"
3743
# define data_files
3844
if self.train_file is not None or self.validation_file is not None:
3945
# accept braceexpand notation
@@ -101,6 +107,25 @@ def preprocess(self, tokenizer, config):
101107
self.seed_dataset = np.random.get_state()[1][0]
102108
self.rng_dataset = jax.random.PRNGKey(self.seed_dataset)
103109

110+
# blank captions
111+
if self.blank_caption_prob:
112+
partial_blank_caption_function = partial(
113+
blank_caption_function,
114+
text_column=self.text_column,
115+
blank_caption_prob=self.blank_caption_prob,
116+
)
117+
if hasattr(self, "train_dataset"):
118+
self.train_dataset = (
119+
self.train_dataset.map(partial_blank_caption_function)
120+
if self.streaming
121+
else self.train_dataset.map(
122+
partial_blank_caption_function,
123+
num_proc=self.preprocessing_num_workers,
124+
load_from_cache_file=False,
125+
desc="Blanking some captions",
126+
)
127+
)
128+
104129
# normalize text
105130
if normalize_text:
106131
text_normalizer = TextNormalizer()
@@ -144,6 +169,10 @@ def preprocess(self, tokenizer, config):
144169
getattr(self, ds).map(
145170
partial_preprocess_function,
146171
batched=True,
172+
remove_columns=[
173+
self.text_column,
174+
self.encoding_column,
175+
],
147176
)
148177
if self.streaming
149178
else getattr(self, ds).map(
@@ -193,8 +222,8 @@ def _dataloader_datasets_streaming(
193222
while (self.multi_hosts and split == "train") or first_loop:
194223
# in multi-host, we run forever (no epoch) as hosts need to stop
195224
# at the same time and training data may not be split equally
196-
# For validation data we put the entire set on each host as we could lose
197-
# too many samples on pods
225+
# For validation data we put the entire batch on each host and then
226+
# keep only the one specific to each host (could be improved but not necessary)
198227
if epoch is not None:
199228
assert split == "train"
200229
# reshuffle training data at each epoch
@@ -252,6 +281,12 @@ def shift_tokens_right(input_ids: np.array, decoder_start_token_id: int):
252281
return shifted_input_ids
253282

254283

284+
def blank_caption_function(example, text_column, blank_caption_prob):
285+
if blank_caption_prob and np.random.rand() < blank_caption_prob:
286+
example[text_column] = ""
287+
return example
288+
289+
255290
def normalize_function(example, text_column, text_normalizer):
256291
example[text_column] = text_normalizer(example[text_column])
257292
return example

src/dalle_mini/model/modeling.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# coding=utf-8
2-
# Copyright 2021 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team and the DalleBart team. All rights reserved.
2+
# Copyright 2021-2022 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team and the DALL·E Mini team. All rights reserved.
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
55
# you may not use this file except in compliance with the License.
@@ -328,6 +328,7 @@ def __init__(
328328
dtype: jnp.dtype = jnp.float32,
329329
abstract_init: bool = False,
330330
load_on_cpu: bool = False,
331+
init_weights: bool = True,
331332
**kwargs,
332333
):
333334
module = self.module_class(config=config, dtype=dtype, **kwargs)
@@ -347,25 +348,34 @@ def __init__(
347348
self.key = PRNGKey(seed)
348349
self.dtype = dtype
349350

350-
# init weights on CPU
351-
if load_on_cpu:
352-
# init weights on CPU
353-
init_fn = jax.jit(self.init_weights, static_argnums=(1,), backend="cpu")
354-
else:
355-
init_fn = self.init_weights
351+
if init_weights:
352+
# get shape of params only
353+
random_params = self.init_weights(
354+
self.key,
355+
input_shape,
356+
abstract_init=abstract_init,
357+
load_on_cpu=load_on_cpu,
358+
)
359+
360+
# save required_params as set
361+
self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
362+
self.params = random_params
356363

357-
# randomly initialized parameters
358-
random_params = self.init_weights(self.key, input_shape)
364+
def init_weights(
365+
self, rng=None, input_shape=(1, 1), abstract_init=False, load_on_cpu=False
366+
):
367+
if rng is None:
368+
rng = self.key
369+
init_fn = super().init_weights
370+
if load_on_cpu:
371+
init_fn = jax.jit(init_fn, static_argnums=(1,), backend="cpu")
359372
if abstract_init:
360373
# only set shape and dtype, load parameters separately
361374
init_fn = partial(init_fn, input_shape=input_shape)
362-
random_params = jax.eval_shape(init_fn, self.key)
375+
params = jax.eval_shape(init_fn, rng)
363376
else:
364-
random_params = init_fn(self.key, input_shape)
365-
366-
# save required_params as set
367-
self._required_params = set(flatten_dict(unfreeze(random_params)).keys())
368-
self.params = random_params
377+
params = init_fn(rng, input_shape)
378+
return params
369379

370380
@property
371381
def num_params(self):

src/dalle_mini/model/utils.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
2323
else:
2424
artifact = wandb.Api().artifact(pretrained_model_name_or_path)
2525
pretrained_model_name_or_path = artifact.download(tmp_dir)
26-
if artifact.metadata.get("bucket_path"):
27-
pretrained_model_name_or_path = artifact.metadata["bucket_path"]
28-
29-
if pretrained_model_name_or_path.startswith("gs://"):
30-
copy_blobs(pretrained_model_name_or_path, tmp_dir)
31-
pretrained_model_name_or_path = tmp_dir
3226

3327
return super(PretrainedFromWandbMixin, cls).from_pretrained(
3428
pretrained_model_name_or_path, *model_args, **kwargs

tools/inference/inference_pipeline.ipynb

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@
8383
"VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\"\n",
8484
"\n",
8585
"# CLIP model\n",
86-
"CLIP_REPO = \"openai/clip-vit-base-patch16\"\n",
86+
"CLIP_REPO = \"openai/clip-vit-large-patch14\"\n",
8787
"CLIP_COMMIT_ID = None"
8888
]
8989
},
@@ -129,7 +129,6 @@
129129
"from dalle_mini.model import DalleBart, DalleBartTokenizer\n",
130130
"from vqgan_jax.modeling_flax_vqgan import VQModel\n",
131131
"from transformers import CLIPProcessor, FlaxCLIPModel\n",
132-
"import wandb\n",
133132
"\n",
134133
"# Load dalle-mini\n",
135134
"model = DalleBart.from_pretrained(\n",
@@ -168,9 +167,9 @@
168167
"if dtype == jnp.bfloat16:\n",
169168
" model.params = model.to_bf16(model.params)\n",
170169
"\n",
171-
"model_params = replicate(model.params)\n",
172-
"vqgan_params = replicate(vqgan.params)\n",
173-
"clip_params = replicate(clip.params)"
170+
"model._params = replicate(model.params)\n",
171+
"vqgan._params = replicate(vqgan.params)\n",
172+
"clip._params = replicate(clip.params)"
174173
]
175174
},
176175
{
@@ -292,7 +291,7 @@
292291
},
293292
"outputs": [],
294293
"source": [
295-
"prompt = \"a blue table\""
294+
"prompt = \"view of the beach during sunset\""
296295
]
297296
},
298297
{
@@ -414,12 +413,12 @@
414413
" key, subkey = jax.random.split(key)\n",
415414
" # generate images\n",
416415
" encoded_images = p_generate(\n",
417-
" tokenized_prompt, shard_prng_key(subkey), model_params, gen_top_k, gen_top_p\n",
416+
" tokenized_prompt, shard_prng_key(subkey), model.params, gen_top_k, gen_top_p\n",
418417
" )\n",
419418
" # remove BOS\n",
420419
" encoded_images = encoded_images.sequences[..., 1:]\n",
421420
" # decode images\n",
422-
" decoded_images = p_decode(encoded_images, vqgan_params)\n",
421+
" decoded_images = p_decode(encoded_images, vqgan.params)\n",
423422
" decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
424423
" for img in decoded_images:\n",
425424
" images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
@@ -453,7 +452,7 @@
453452
" max_length=77,\n",
454453
" truncation=True,\n",
455454
").data\n",
456-
"logits = p_clip(shard(clip_inputs), clip_params)\n",
455+
"logits = p_clip(shard(clip_inputs), clip.params)\n",
457456
"logits = logits.squeeze().flatten()"
458457
]
459458
},
@@ -479,6 +478,13 @@
479478
" display(images[idx])\n",
480479
" print(f\"Score: {logits[idx]:.2f}\\n\")"
481480
]
481+
},
482+
{
483+
"cell_type": "code",
484+
"execution_count": null,
485+
"metadata": {},
486+
"outputs": [],
487+
"source": []
482488
}
483489
],
484490
"metadata": {

tools/train/config/medium/config.json

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,5 @@
2828
"pad_token_id": 16385,
2929
"scale_embedding": false,
3030
"tie_word_embeddings": false,
31-
"transformers_version": "4.13.0.dev0",
3231
"use_cache": true
3332
}

tools/train/config/mega/config.json

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,20 @@
55
"bos_token_id": 16385,
66
"classifier_dropout": 0.0,
77
"d_model": 2048,
8-
"decoder_attention_heads": 16,
9-
"decoder_ffn_dim": 4096,
8+
"decoder_attention_heads": 32,
9+
"decoder_ffn_dim": 8192,
1010
"decoder_layerdrop": 0.0,
11-
"decoder_layers": 31,
11+
"decoder_layers": 24,
1212
"decoder_start_token_id": 16384,
13-
"dropout": 0.1,
14-
"encoder_attention_heads": 16,
15-
"encoder_ffn_dim": 4096,
13+
"dropout": 0.0,
14+
"encoder_attention_heads": 32,
15+
"encoder_ffn_dim": 8192,
1616
"encoder_layerdrop": 0.0,
17-
"encoder_layers": 31,
17+
"encoder_layers": 24,
1818
"encoder_vocab_size": 50264,
1919
"eos_token_id": 16385,
20-
"gradient_checkpointing": false,
2120
"image_length": 256,
22-
"image_vocab_size": 16384,
21+
"image_vocab_size": 16391,
2322
"init_std": 0.01,
2423
"is_encoder_decoder": true,
2524
"max_text_length": 64,
@@ -28,6 +27,5 @@
2827
"pad_token_id": 16385,
2928
"scale_embedding": false,
3029
"tie_word_embeddings": false,
31-
"transformers_version": "4.13.0.dev0",
3230
"use_cache": true
3331
}

tools/train/config/micro/config.json

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,21 @@
44
"attention_dropout": 0.0,
55
"bos_token_id": 16385,
66
"classifier_dropout": 0.0,
7-
"d_model": 1024,
8-
"decoder_attention_heads": 16,
9-
"decoder_ffn_dim": 2048,
7+
"d_model": 256,
8+
"decoder_attention_heads": 2,
9+
"decoder_ffn_dim": 256,
1010
"decoder_layerdrop": 0.0,
1111
"decoder_layers": 2,
1212
"decoder_start_token_id": 16384,
1313
"dropout": 0.0,
14-
"encoder_attention_heads": 16,
15-
"encoder_ffn_dim": 2048,
14+
"encoder_attention_heads": 2,
15+
"encoder_ffn_dim": 256,
1616
"encoder_layerdrop": 0.0,
1717
"encoder_layers": 2,
1818
"encoder_vocab_size": 50264,
1919
"eos_token_id": 16385,
20-
"gradient_checkpointing": false,
2120
"image_length": 256,
22-
"image_vocab_size": 16384,
21+
"image_vocab_size": 16391,
2322
"init_std": 0.02,
2423
"is_encoder_decoder": true,
2524
"max_text_length": 64,
@@ -28,6 +27,5 @@
2827
"pad_token_id": 16385,
2928
"scale_embedding": false,
3029
"tie_word_embeddings": false,
31-
"transformers_version": "4.13.0.dev0",
3230
"use_cache": true
3331
}

tools/train/config/mini/config.json

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,5 @@
2828
"pad_token_id": 16385,
2929
"scale_embedding": false,
3030
"tie_word_embeddings": false,
31-
"transformers_version": "4.13.0.dev0",
3231
"use_cache": true
3332
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Notes
2+
3+
Files copied from [google-research/scalable_shampoo/optax](https://github.com/google-research/google-research/tree/master/scalable_shampoo/optax).
4+
5+
Imports have been modified to be relative.
6+
7+
This will be replaced with `optax-shampoo` package eventually.

0 commit comments

Comments
 (0)