
Commit 728a3c3

feat: better multi-node support (borisdayma#158)
* reproducible data loader
* custom sharding
* model parallel across multiple nodes
1 parent: 49be45e

File tree: 4 files changed (+90, -21 lines)

src/dalle_mini/data.py

Lines changed: 12 additions & 3 deletions
@@ -43,6 +43,8 @@ def __post_init__(self):
         if self.seed_dataset is None:
             # create a random seed
             self.seed_dataset = random.randint(0, 2**32 - 1)
+        # set numpy rng
+        self.np_rng = np.random.default_rng(self.seed_dataset)
         self.multi_hosts = jax.process_count() > 1
         # feed blank captions only in streaming mode for now
         # otherwise dataset could be cached with same blanked captions
@@ -173,14 +175,17 @@ def preprocess(self, tokenizer, config):
                 blank_caption_function,
                 text_column=self.text_column,
                 blank_caption_prob=self.blank_caption_prob,
+                rng=self.np_rng,
             )
             if hasattr(self, "train_dataset"):
                 self.train_dataset = (
                     self.train_dataset.map(partial_blank_caption_function)
                     if self.streaming
                     else self.train_dataset.map(
                         partial_blank_caption_function,
-                        num_proc=self.preprocessing_num_workers,
+                        num_proc=None
+                        if self.seed_dataset
+                        else self.preprocessing_num_workers,
                         load_from_cache_file=False,
                         desc="Blanking some captions",
                     )
@@ -316,8 +321,12 @@ def shift_tokens_right(input_ids: np.array, decoder_start_token_id: int):
     return shifted_input_ids
 
 
-def blank_caption_function(example, text_column, blank_caption_prob):
-    if blank_caption_prob and np.random.rand() < blank_caption_prob:
+def blank_caption_function(example, text_column, blank_caption_prob, rng=None):
+    if (
+        blank_caption_prob
+        and (rng.random() if rng is not None else np.random.random())
+        < blank_caption_prob
+    ):
         example[text_column] = ""
     return example
 

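The data loader change above routes caption blanking through a numpy Generator seeded with seed_dataset instead of the global numpy RNG, so the same captions are blanked on every run and on every host that shares the seed; forcing num_proc=None when a seed is set presumably keeps parallel map workers from advancing the generator in a nondeterministic order. A minimal standalone sketch of the idea (not the repository's code, names and values are illustrative only):

# Minimal sketch: a seeded Generator makes caption blanking reproducible.
import numpy as np

def blank_caption(example, text_column, blank_caption_prob, rng):
    # replace the caption with an empty string with probability blank_caption_prob
    if blank_caption_prob and rng.random() < blank_caption_prob:
        example[text_column] = ""
    return example

examples = [{"caption": f"image {i}"} for i in range(5)]

# same seed -> the same captions get blanked on every run and on every host
rng_a = np.random.default_rng(42)
rng_b = np.random.default_rng(42)
run_a = [blank_caption(dict(e), "caption", 0.3, rng_a)["caption"] for e in examples]
run_b = [blank_caption(dict(e), "caption", 0.3, rng_b)["caption"] for e in examples]
assert run_a == run_b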
Lines changed: 27 additions & 8 deletions
@@ -1,30 +1,49 @@
 {
   "activation_dropout": 0.0,
-  "activation_function": "gelu",
+  "activation_function": "swish",
   "attention_dropout": 0.0,
   "bos_token_id": 16385,
   "d_model": 2048,
   "decoder_attention_heads": 32,
-  "decoder_ffn_dim": 8192,
+  "decoder_ffn_dim": 4096,
   "decoder_layerdrop": 0.0,
-  "decoder_layers": 24,
+  "decoder_layers": 25,
   "decoder_start_token_id": 16384,
+  "do_sample": true,
   "dropout": 0.0,
   "encoder_attention_heads": 32,
-  "encoder_ffn_dim": 8192,
+  "encoder_ffn_dim": 4096,
   "encoder_layerdrop": 0.0,
-  "encoder_layers": 24,
-  "encoder_vocab_size": 50264,
+  "encoder_layers": 25,
+  "encoder_vocab_size": 50272,
   "eos_token_id": 16385,
+  "force_ln_scale": false,
+  "gradient_checkpointing": false,
   "image_length": 256,
-  "image_vocab_size": 16391,
+  "image_vocab_size": 16415,
   "init_std": 0.01,
   "is_encoder_decoder": true,
+  "ln_positions": "normformer",
+  "ln_type": "layernorm",
+  "max_length": 257,
   "max_text_length": 64,
+  "min_length": 257,
   "model_type": "dallebart",
   "normalize_text": true,
   "pad_token_id": 16385,
   "scale_embedding": false,
+  "sinkhorn_iters": 1,
+  "tau_init": 0.05,
   "tie_word_embeddings": false,
-  "use_cache": true
+  "use_absolute_position_embeddings": true,
+  "use_alibi": false,
+  "use_bias": false,
+  "use_cache": true,
+  "use_cosine_attention": false,
+  "use_deepnet_scaling": false,
+  "use_final_ln_decoder": true,
+  "use_final_ln_encoder": true,
+  "use_glu": true,
+  "use_head_scale": false,
+  "use_swin_position_embeddings": false
 }

tools/train/config/mini/config.json

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
   "eos_token_id": 16385,
   "gradient_checkpointing": false,
   "image_length": 256,
-  "image_vocab_size": 16384,
+  "image_vocab_size": 16391,
   "init_std": 0.02,
   "is_encoder_decoder": true,
   "max_text_length": 64,

tools/train/train.py

Lines changed: 50 additions & 9 deletions
@@ -368,6 +368,12 @@ class TrainingArguments:
             "help": "Whether to quantize optimizer (only supported with Distributed Shampoo)."
         },
     )
+    shard_shampoo_across: str = field(
+        default="dp",
+        metadata={
+            "help": "Whether to shard the optimizer across data devices (dp), model devices (mp) or both (2d)."
+        },
+    )
 
     num_train_epochs: int = field(
         default=3, metadata={"help": "Total number of training epochs to perform."}
@@ -450,6 +456,11 @@ class TrainingArguments:
         metadata={"help": "Verify that TPU is not in use."},
     )
 
+    use_vmap_trick: bool = field(
+        default=True,
+        metadata={"help": "Verify that TPU is not in use."},
+    )
+
     mp_devices: Optional[int] = field(
         default=1,
         metadata={
@@ -500,6 +511,11 @@ def __post_init__(self):
                 f"Output directory ({self.output_dir}) already exists and is not empty."
                 "Use --overwrite_output_dir to overcome."
             )
+        assert self.shard_shampoo_across in [
+            "dp",
+            "mp",
+            "2d",
+        ], f"Shard shampoo across {self.shard_shampoo_across} not supported."
         assert (
             self.mp_devices > 0
         ), f"Number of devices for model parallelism must be > 0"
@@ -530,6 +546,12 @@ def main():
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
+    # check arguments
+    if training_args.mp_devices > jax.local_device_count():
+        assert (
+            data_args.seed_dataset is not None
+        ), "Seed dataset must be provided when model is split over multiple hosts"
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -748,8 +770,20 @@ def create_learning_rate_fn() -> Callable[[int], jnp.array]:
             graft_type=graft_type,
             nesterov=False,
             exponent_override=0,
-            statistics_partition_spec=PartitionSpec(None, "dp", None),
-            preconditioner_partition_spec=PartitionSpec("dp", None, None),
+            statistics_partition_spec=PartitionSpec(
+                None, training_args.shard_shampoo_across, None
+            )
+            if training_args.shard_shampoo_across != "2d"
+            else PartitionSpec(None, "dp", "mp"),
+            preconditioner_partition_spec=PartitionSpec(
+                training_args.shard_shampoo_across, None, None
+            )
+            if training_args.shard_shampoo_across != "2d"
+            else PartitionSpec(
+                "mp" if training_args.mp_devices > training_args.dp_devices else "dp",
+                None,
+                None,
+            ),
             num_devices_for_pjit=training_args.dp_devices,
             shard_optimizer_states=True,
             inverse_failure_threshold=0.1,
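The Shampoo partition specs now follow the new shard_shampoo_across flag: "dp" or "mp" shards optimizer statistics and preconditioners along that single mesh axis, while "2d" spreads statistics over both axes and places preconditioners on whichever axis has more devices. A hedged sketch of the same selection logic, assuming a 2D device mesh with axes named "dp" and "mp" (the helper name and import path are illustrative, not the training script itself):

# Sketch of the partition-spec selection above, factored into a helper.
from jax.sharding import PartitionSpec  # older JAX: jax.experimental.PartitionSpec

def shampoo_partition_specs(shard_across, dp_devices, mp_devices):
    if shard_across in ("dp", "mp"):
        # shard statistics and preconditioners along a single mesh axis
        stats = PartitionSpec(None, shard_across, None)
        precond = PartitionSpec(shard_across, None, None)
    else:  # "2d": spread statistics over both axes
        stats = PartitionSpec(None, "dp", "mp")
        # put preconditioners on whichever axis has more devices
        precond = PartitionSpec("mp" if mp_devices > dp_devices else "dp", None, None)
    return stats, precond

stats_spec, precond_spec = shampoo_partition_specs("2d", dp_devices=8, mp_devices=16)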
@@ -917,7 +951,7 @@ def loss_fn(logits, labels):
 
     # "vmap trick" avoids a crash when mp_devices > 1 (not sure why it happens)
     # lead to better perf: see https://wandb.ai/dalle-mini/dalle-mini/reports/JAX-pmap-vs-pjit--VmlldzoxNDg1ODA2
-    use_vmap_trick = True
+    use_vmap_trick = training_args.use_vmap_trick
 
     # make grad_param_spec for vmap
     if use_vmap_trick:
@@ -1145,7 +1179,8 @@ def update_state_metrics(self, state):
         self.log_time("train_per_log", delta_time, offset=False)
 
     def log_time(self, key, duration, offset=True):
-        wandb.log({f"time/{key}": duration, **self.state_dict})
+        if jax.process_index() == 0:
+            wandb.log({f"time/{key}": duration, **self.state_dict})
         if offset:
             self.offset_time += duration
 
@@ -1191,7 +1226,11 @@ def run_evaluation():
         # ======================== Evaluating ==============================
         if training_args.do_eval:
             start_eval_time = time.perf_counter()
-            eval_loader = dataset.dataloader("eval", eval_batch_size_per_step)
+            eval_loader = dataset.dataloader(
+                "eval",
+                eval_batch_size_per_step
+                * max(1, training_args.mp_devices // jax.local_device_count()),
+            )
             eval_steps = (
                 len_eval_dataset // eval_batch_size_per_step
                 if len_eval_dataset is not None
@@ -1353,10 +1392,12 @@ def run_save_model(state, eval_metrics=None):
         metrics_logger.update_state_metrics(local_state)
         metrics_logger.log({})
 
-        # Generate an epoch by shuffling sampling indices from the train dataset
+        # load data - may be replicated on multiple nodes
+        node_groups = max(1, training_args.mp_devices // jax.local_device_count())
+        loader_bs = batch_size_per_node * node_groups
         train_loader = dataset.dataloader(
             "train",
-            batch_size_per_node,
+            loader_bs,
             epoch,
         )
         # train
@@ -1373,12 +1414,12 @@ def run_save_model(state, eval_metrics=None):
 
             # set correct shape to batch
             # - add grad_step dim if gradient_accumulation_steps > 1
-            # - split per dp device if not multi-host for vmap trick (does not work in multi-host)
             bs_shape = (
-                (batch_size_per_node_per_grad_step,)
+                (batch_size_per_node_per_grad_step * node_groups,)
                 if not use_vmap_trick
                 else (
                     jax.local_device_count()
+                    * node_groups
                     // training_args.mp_devices,  # local dp devices
                     training_args.per_device_train_batch_size,
                 )
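The training loader batch size and the batch shape used by the vmap trick are now scaled by node_groups, the number of hosts that jointly hold one model replica, so each host loads the samples for its whole group. A worked example with hypothetical numbers (8 local devices per host, mp_devices=16, per-device batch size 4, no gradient accumulation); it mirrors the arithmetic above but is not the training script itself:

# Worked example with hypothetical values; assumes gradient_accumulation_steps = 1.
local_device_count = 8            # devices on this host
mp_devices = 16                   # one model replica is sharded across 2 hosts
per_device_train_batch_size = 4

# per-host batch size for the local data-parallel devices: 4 * 8 // 16 = 2
batch_size_per_node = per_device_train_batch_size * local_device_count // mp_devices

# hosts that together hold one model replica must all load the same samples
node_groups = max(1, mp_devices // local_device_count)   # 2
loader_bs = batch_size_per_node * node_groups             # 4 samples per host per step

# shape fed to the vmap trick: (local dp devices, per-device batch)
bs_shape = (
    local_device_count * node_groups // mp_devices,  # 8 * 2 // 16 = 1
    per_device_train_batch_size,                     # 4
)
print(node_groups, loader_bs, bs_shape)  # -> 2 4 (1, 4)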
