@@ -525,6 +525,8 @@ def __post_init__(self):
525525 self .per_device_eval_batch_size = self .per_device_train_batch_size
526526 if self .log_norm_steps is True :
527527 self .log_norm_steps = self .logging_steps
528+ if not self .do_train :
529+ self .num_train_epochs = 1
528530 if (
529531 os .path .exists (self .output_dir )
530532 and os .listdir (self .output_dir )
@@ -1354,6 +1356,8 @@ def log(self, metrics, prefix=None):
13541356 # init variables
13551357 start_time = time .perf_counter () - local_state ["train_time" ]
13561358 train_metrics = None
1359+ evaluation_ran = False
1360+ save_model_ran = False
13571361 metrics_logger = MetricsLogger (local_state ["step" ])
13581362 epochs = tqdm (
13591363 range (local_state ["epoch" ], num_epochs ),
@@ -1532,85 +1536,98 @@ def run_save_model(state, eval_metrics=None):
15321536 metrics_logger .update_state_metrics (local_state )
15331537 metrics_logger .log ({})
15341538
1535- # load data - may be replicated on multiple nodes
1536- node_groups = max (1 , training_args .mp_devices // jax .local_device_count ())
1537- loader_bs = batch_size_per_node * node_groups
1538- train_loader = dataset .dataloader (
1539- "train" ,
1540- loader_bs ,
1541- epoch ,
1542- )
1543- # train
1544- for batch in tqdm (
1545- train_loader ,
1546- desc = "Training..." ,
1547- position = 1 ,
1548- leave = False ,
1549- total = steps_per_epoch ,
1550- disable = jax .process_index () > 0 ,
1551- ):
1552- # calculate delta time (we have a lag of one step but it's ok)
1553- train_time = time .perf_counter () - start_time
1554-
1555- # set correct shape to batch
1556- # - add grad_step dim if gradient_accumulation_steps > 1
1557- bs_shape = (
1558- (batch_size_per_node_per_grad_step * node_groups ,)
1559- if not use_vmap_trick
1560- else (
1561- jax .local_device_count ()
1562- * node_groups
1563- // training_args .mp_devices , # local dp devices
1564- training_args .per_device_train_batch_size ,
1565- )
1539+ if training_args .do_train :
1540+ # load data - may be replicated on multiple nodes
1541+ node_groups = max (
1542+ 1 , training_args .mp_devices // jax .local_device_count ()
15661543 )
1567- if training_args .gradient_accumulation_steps > 1 :
1568- # reshape data into (gradient_accumulation_steps, batch_per_node, ...)
1569- # to avoid any data redistribution when sharding
1570- bs_shape = (training_args .gradient_accumulation_steps ,) + bs_shape
1571-
1572- # reshape batch
1573- batch = jax .tree_map (
1574- lambda x : x .reshape (bs_shape + x .shape [1 :]),
1575- batch ,
1544+ loader_bs = batch_size_per_node * node_groups
1545+ train_loader = dataset .dataloader (
1546+ "train" ,
1547+ loader_bs ,
1548+ epoch ,
15761549 )
1577- # freeze batch to pass safely to jax transforms
1578- batch = freeze (batch )
1579-
1580- # train step
1581- state , train_metrics = p_train_step (state , batch , train_time )
1582- local_state ["step" ] += 1
1583- local_state ["train_time" ] = train_time
1584- local_state ["train_samples" ] += batch_size_per_step
1585-
1586- if (
1587- local_state ["step" ] % training_args .logging_steps == 0
1588- and jax .process_index () == 0
1550+ # train
1551+ for batch in tqdm (
1552+ train_loader ,
1553+ desc = "Training..." ,
1554+ position = 1 ,
1555+ leave = False ,
1556+ total = steps_per_epoch ,
1557+ disable = jax .process_index () > 0 ,
15891558 ):
1590- metrics_logger .update_state_metrics (local_state )
1591- metrics_logger .log (train_metrics , prefix = "train" )
1592-
1593- eval_metrics = None
1594- if local_state ["step" ] % training_args .eval_steps == 0 :
1595- eval_metrics = run_evaluation ()
1559+ # calculate delta time (we have a lag of one step but it's ok)
1560+ train_time = time .perf_counter () - start_time
15961561
1597- if local_state ["step" ] % training_args .save_steps == 0 :
1598- run_save_model (state , eval_metrics )
1562+ # reset control variables
1563+ evaluation_ran = False
1564+ save_model_ran = False
15991565
1600- # log final train metrics
1601- if train_metrics is not None :
1602- metrics_logger .update_state_metrics (state )
1603- metrics_logger .log (train_metrics , prefix = "train" )
1566+ # set correct shape to batch
1567+ # - add grad_step dim if gradient_accumulation_steps > 1
1568+ bs_shape = (
1569+ (batch_size_per_node_per_grad_step * node_groups ,)
1570+ if not use_vmap_trick
1571+ else (
1572+ jax .local_device_count ()
1573+ * node_groups
1574+ // training_args .mp_devices , # local dp devices
1575+ training_args .per_device_train_batch_size ,
1576+ )
1577+ )
1578+ if training_args .gradient_accumulation_steps > 1 :
1579+ # reshape data into (gradient_accumulation_steps, batch_per_node, ...)
1580+ # to avoid any data redistribution when sharding
1581+ bs_shape = (
1582+ training_args .gradient_accumulation_steps ,
1583+ ) + bs_shape
1584+
1585+ # reshape batch
1586+ batch = jax .tree_map (
1587+ lambda x : x .reshape (bs_shape + x .shape [1 :]),
1588+ batch ,
1589+ )
1590+ # freeze batch to pass safely to jax transforms
1591+ batch = freeze (batch )
1592+
1593+ # train step
1594+ state , train_metrics = p_train_step (state , batch , train_time )
1595+ local_state ["step" ] += 1
1596+ local_state ["train_time" ] = train_time
1597+ local_state ["train_samples" ] += batch_size_per_step
1598+
1599+ if (
1600+ local_state ["step" ] % training_args .logging_steps == 0
1601+ and jax .process_index () == 0
1602+ ):
1603+ metrics_logger .update_state_metrics (local_state )
1604+ metrics_logger .log (train_metrics , prefix = "train" )
1605+
1606+ eval_metrics = None
1607+ if local_state ["step" ] % training_args .eval_steps == 0 :
1608+ eval_metrics = run_evaluation ()
1609+ evaluation_ran = True
1610+
1611+ if local_state ["step" ] % training_args .save_steps == 0 :
1612+ run_save_model (state , eval_metrics )
1613+ save_model_ran = True
1614+
1615+ # log final train metrics
1616+ if train_metrics is not None :
1617+ metrics_logger .update_state_metrics (state )
1618+ metrics_logger .log (train_metrics , prefix = "train" )
16041619
1605- epochs .write (
1606- f"Epoch... ({ epoch + 1 } /{ num_epochs } | Loss: { train_metrics ['loss' ]} , Learning Rate: { train_metrics ['learning_rate' ]} )"
1607- )
1620+ epochs .write (
1621+ f"Epoch... ({ epoch + 1 } /{ num_epochs } | Loss: { train_metrics ['loss' ]} , Learning Rate: { train_metrics ['learning_rate' ]} )"
1622+ )
16081623
1609- # Final evaluation
1610- eval_metrics = run_evaluation ()
1624+ # Final evaluation at the end of each epoch
1625+ if not evaluation_ran :
1626+ eval_metrics = run_evaluation ()
16111627
16121628 # save checkpoint after each epoch
1613- run_save_model (state , eval_metrics )
1629+ if not save_model_ran :
1630+ run_save_model (state , eval_metrics )
16141631
16151632
16161633if __name__ == "__main__" :
0 commit comments