Skip to content

Commit 65bb95f

Browse files
committed
feat: allow eval only
1 parent a4d2af8 commit 65bb95f

File tree

1 file changed

+87
-70
lines changed

1 file changed

+87
-70
lines changed

tools/train/train.py

Lines changed: 87 additions & 70 deletions
Original file line number | Diff line number | Diff line change
@@ -525,6 +525,8 @@ def __post_init__(self):
525525
self.per_device_eval_batch_size = self.per_device_train_batch_size
526526
if self.log_norm_steps is True:
527527
self.log_norm_steps = self.logging_steps
528+
if not self.do_train:
529+
self.num_train_epochs = 1
528530
if (
529531
os.path.exists(self.output_dir)
530532
and os.listdir(self.output_dir)
@@ -1354,6 +1356,8 @@ def log(self, metrics, prefix=None):
13541356
# init variables
13551357
start_time = time.perf_counter() - local_state["train_time"]
13561358
train_metrics = None
1359+
evaluation_ran = False
1360+
save_model_ran = False
13571361
metrics_logger = MetricsLogger(local_state["step"])
13581362
epochs = tqdm(
13591363
range(local_state["epoch"], num_epochs),
@@ -1532,85 +1536,98 @@ def run_save_model(state, eval_metrics=None):
15321536
metrics_logger.update_state_metrics(local_state)
15331537
metrics_logger.log({})
15341538

1535-
# load data - may be replicated on multiple nodes
1536-
node_groups = max(1, training_args.mp_devices // jax.local_device_count())
1537-
loader_bs = batch_size_per_node * node_groups
1538-
train_loader = dataset.dataloader(
1539-
"train",
1540-
loader_bs,
1541-
epoch,
1542-
)
1543-
# train
1544-
for batch in tqdm(
1545-
train_loader,
1546-
desc="Training...",
1547-
position=1,
1548-
leave=False,
1549-
total=steps_per_epoch,
1550-
disable=jax.process_index() > 0,
1551-
):
1552-
# calculate delta time (we have a lag of one step but it's ok)
1553-
train_time = time.perf_counter() - start_time
1554-
1555-
# set correct shape to batch
1556-
# - add grad_step dim if gradient_accumulation_steps > 1
1557-
bs_shape = (
1558-
(batch_size_per_node_per_grad_step * node_groups,)
1559-
if not use_vmap_trick
1560-
else (
1561-
jax.local_device_count()
1562-
* node_groups
1563-
// training_args.mp_devices, # local dp devices
1564-
training_args.per_device_train_batch_size,
1565-
)
1539+
if training_args.do_train:
1540+
# load data - may be replicated on multiple nodes
1541+
node_groups = max(
1542+
1, training_args.mp_devices // jax.local_device_count()
15661543
)
1567-
if training_args.gradient_accumulation_steps > 1:
1568-
# reshape data into (gradient_accumulation_steps, batch_per_node, ...)
1569-
# to avoid any data redistribution when sharding
1570-
bs_shape = (training_args.gradient_accumulation_steps,) + bs_shape
1571-
1572-
# reshape batch
1573-
batch = jax.tree_map(
1574-
lambda x: x.reshape(bs_shape + x.shape[1:]),
1575-
batch,
1544+
loader_bs = batch_size_per_node * node_groups
1545+
train_loader = dataset.dataloader(
1546+
"train",
1547+
loader_bs,
1548+
epoch,
15761549
)
1577-
# freeze batch to pass safely to jax transforms
1578-
batch = freeze(batch)
1579-
1580-
# train step
1581-
state, train_metrics = p_train_step(state, batch, train_time)
1582-
local_state["step"] += 1
1583-
local_state["train_time"] = train_time
1584-
local_state["train_samples"] += batch_size_per_step
1585-
1586-
if (
1587-
local_state["step"] % training_args.logging_steps == 0
1588-
and jax.process_index() == 0
1550+
# train
1551+
for batch in tqdm(
1552+
train_loader,
1553+
desc="Training...",
1554+
position=1,
1555+
leave=False,
1556+
total=steps_per_epoch,
1557+
disable=jax.process_index() > 0,
15891558
):
1590-
metrics_logger.update_state_metrics(local_state)
1591-
metrics_logger.log(train_metrics, prefix="train")
1592-
1593-
eval_metrics = None
1594-
if local_state["step"] % training_args.eval_steps == 0:
1595-
eval_metrics = run_evaluation()
1559+
# calculate delta time (we have a lag of one step but it's ok)
1560+
train_time = time.perf_counter() - start_time
15961561

1597-
if local_state["step"] % training_args.save_steps == 0:
1598-
run_save_model(state, eval_metrics)
1562+
# reset control variables
1563+
evaluation_ran = False
1564+
save_model_ran = False
15991565

1600-
# log final train metrics
1601-
if train_metrics is not None:
1602-
metrics_logger.update_state_metrics(state)
1603-
metrics_logger.log(train_metrics, prefix="train")
1566+
# set correct shape to batch
1567+
# - add grad_step dim if gradient_accumulation_steps > 1
1568+
bs_shape = (
1569+
(batch_size_per_node_per_grad_step * node_groups,)
1570+
if not use_vmap_trick
1571+
else (
1572+
jax.local_device_count()
1573+
* node_groups
1574+
// training_args.mp_devices, # local dp devices
1575+
training_args.per_device_train_batch_size,
1576+
)
1577+
)
1578+
if training_args.gradient_accumulation_steps > 1:
1579+
# reshape data into (gradient_accumulation_steps, batch_per_node, ...)
1580+
# to avoid any data redistribution when sharding
1581+
bs_shape = (
1582+
training_args.gradient_accumulation_steps,
1583+
) + bs_shape
1584+
1585+
# reshape batch
1586+
batch = jax.tree_map(
1587+
lambda x: x.reshape(bs_shape + x.shape[1:]),
1588+
batch,
1589+
)
1590+
# freeze batch to pass safely to jax transforms
1591+
batch = freeze(batch)
1592+
1593+
# train step
1594+
state, train_metrics = p_train_step(state, batch, train_time)
1595+
local_state["step"] += 1
1596+
local_state["train_time"] = train_time
1597+
local_state["train_samples"] += batch_size_per_step
1598+
1599+
if (
1600+
local_state["step"] % training_args.logging_steps == 0
1601+
and jax.process_index() == 0
1602+
):
1603+
metrics_logger.update_state_metrics(local_state)
1604+
metrics_logger.log(train_metrics, prefix="train")
1605+
1606+
eval_metrics = None
1607+
if local_state["step"] % training_args.eval_steps == 0:
1608+
eval_metrics = run_evaluation()
1609+
evaluation_ran = True
1610+
1611+
if local_state["step"] % training_args.save_steps == 0:
1612+
run_save_model(state, eval_metrics)
1613+
save_model_ran = True
1614+
1615+
# log final train metrics
1616+
if train_metrics is not None:
1617+
metrics_logger.update_state_metrics(state)
1618+
metrics_logger.log(train_metrics, prefix="train")
16041619

1605-
epochs.write(
1606-
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"
1607-
)
1620+
epochs.write(
1621+
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"
1622+
)
16081623

1609-
# Final evaluation
1610-
eval_metrics = run_evaluation()
1624+
# Final evaluation at the end of each epoch
1625+
if not evaluation_ran:
1626+
eval_metrics = run_evaluation()
16111627

16121628
# save checkpoint after each epoch
1613-
run_save_model(state, eval_metrics)
1629+
if not save_model_ran:
1630+
run_save_model(state, eval_metrics)
16141631

16151632

16161633
if __name__ == "__main__":

0 commit comments

Comments (0)