@@ -368,6 +368,12 @@ class TrainingArguments:
368368 "help" : "Whether to quantize optimizer (only supported with Distributed Shampoo)."
369369 },
370370 )
371+ shard_shampoo_across : str = field (
372+ default = "dp" ,
373+ metadata = {
374+ "help" : "Whether to shard the optimizer across data devices (dp), model devices (mp) or both (2d)."
375+ },
376+ )
371377
372378 num_train_epochs : int = field (
373379 default = 3 , metadata = {"help" : "Total number of training epochs to perform." }
@@ -450,6 +456,11 @@ class TrainingArguments:
450456 metadata = {"help" : "Verify that TPU is not in use." },
451457 )
452458
459+ use_vmap_trick : bool = field (
460+ default = True ,
 461+ metadata = {"help" : "Whether to use the vmap trick (avoids a crash when mp_devices > 1 and can lead to better performance)." },
462+ )
463+
453464 mp_devices : Optional [int ] = field (
454465 default = 1 ,
455466 metadata = {
@@ -500,6 +511,11 @@ def __post_init__(self):
500511 f"Output directory ({ self .output_dir } ) already exists and is not empty."
501512 "Use --overwrite_output_dir to overcome."
502513 )
514+ assert self .shard_shampoo_across in [
515+ "dp" ,
516+ "mp" ,
517+ "2d" ,
518+ ], f"Shard shampoo across { self .shard_shampoo_across } not supported."
503519 assert (
504520 self .mp_devices > 0
505521 ), f"Number of devices for model parallelism must be > 0"
@@ -530,6 +546,12 @@ def main():
530546 else :
531547 model_args , data_args , training_args = parser .parse_args_into_dataclasses ()
532548
549+ # check arguments
550+ if training_args .mp_devices > jax .local_device_count ():
551+ assert (
552+ data_args .seed_dataset is not None
553+ ), "Seed dataset must be provided when model is split over multiple hosts"
554+
533555 # Make one log on every process with the configuration for debugging.
534556 logging .basicConfig (
535557 format = "%(asctime)s - %(levelname)s - %(name)s - %(message)s" ,
@@ -748,8 +770,20 @@ def create_learning_rate_fn() -> Callable[[int], jnp.array]:
748770 graft_type = graft_type ,
749771 nesterov = False ,
750772 exponent_override = 0 ,
751- statistics_partition_spec = PartitionSpec (None , "dp" , None ),
752- preconditioner_partition_spec = PartitionSpec ("dp" , None , None ),
773+ statistics_partition_spec = PartitionSpec (
774+ None , training_args .shard_shampoo_across , None
775+ )
776+ if training_args .shard_shampoo_across != "2d"
777+ else PartitionSpec (None , "dp" , "mp" ),
778+ preconditioner_partition_spec = PartitionSpec (
779+ training_args .shard_shampoo_across , None , None
780+ )
781+ if training_args .shard_shampoo_across != "2d"
782+ else PartitionSpec (
783+ "mp" if training_args .mp_devices > training_args .dp_devices else "dp" ,
784+ None ,
785+ None ,
786+ ),
753787 num_devices_for_pjit = training_args .dp_devices ,
754788 shard_optimizer_states = True ,
755789 inverse_failure_threshold = 0.1 ,
@@ -917,7 +951,7 @@ def loss_fn(logits, labels):
917951
918952 # "vmap trick" avoids a crash when mp_devices > 1 (not sure why it happens)
919953 # lead to better perf: see https://wandb.ai/dalle-mini/dalle-mini/reports/JAX-pmap-vs-pjit--VmlldzoxNDg1ODA2
920- use_vmap_trick = True
954+ use_vmap_trick = training_args . use_vmap_trick
921955
922956 # make grad_param_spec for vmap
923957 if use_vmap_trick :
@@ -1145,7 +1179,8 @@ def update_state_metrics(self, state):
11451179 self .log_time ("train_per_log" , delta_time , offset = False )
11461180
11471181 def log_time (self , key , duration , offset = True ):
1148- wandb .log ({f"time/{ key } " : duration , ** self .state_dict })
1182+ if jax .process_index () == 0 :
1183+ wandb .log ({f"time/{ key } " : duration , ** self .state_dict })
11491184 if offset :
11501185 self .offset_time += duration
11511186
@@ -1191,7 +1226,11 @@ def run_evaluation():
11911226 # ======================== Evaluating ==============================
11921227 if training_args .do_eval :
11931228 start_eval_time = time .perf_counter ()
1194- eval_loader = dataset .dataloader ("eval" , eval_batch_size_per_step )
1229+ eval_loader = dataset .dataloader (
1230+ "eval" ,
1231+ eval_batch_size_per_step
1232+ * max (1 , training_args .mp_devices // jax .local_device_count ()),
1233+ )
11951234 eval_steps = (
11961235 len_eval_dataset // eval_batch_size_per_step
11971236 if len_eval_dataset is not None
@@ -1353,10 +1392,12 @@ def run_save_model(state, eval_metrics=None):
13531392 metrics_logger .update_state_metrics (local_state )
13541393 metrics_logger .log ({})
13551394
1356- # Generate an epoch by shuffling sampling indices from the train dataset
1395+ # load data - may be replicated on multiple nodes
1396+ node_groups = max (1 , training_args .mp_devices // jax .local_device_count ())
1397+ loader_bs = batch_size_per_node * node_groups
13571398 train_loader = dataset .dataloader (
13581399 "train" ,
1359- batch_size_per_node ,
1400+ loader_bs ,
13601401 epoch ,
13611402 )
13621403 # train
@@ -1373,12 +1414,12 @@ def run_save_model(state, eval_metrics=None):
13731414
13741415 # set correct shape to batch
13751416 # - add grad_step dim if gradient_accumulation_steps > 1
1376- # - split per dp device if not multi-host for vmap trick (does not work in multi-host)
13771417 bs_shape = (
1378- (batch_size_per_node_per_grad_step ,)
1418+ (batch_size_per_node_per_grad_step * node_groups ,)
13791419 if not use_vmap_trick
13801420 else (
13811421 jax .local_device_count ()
1422+ * node_groups
13821423 // training_args .mp_devices , # local dp devices
13831424 training_args .per_device_train_batch_size ,
13841425 )
0 commit comments