Skip to content

Commit 3109050

Browse files
Commit authored — feat: support multi validation datasets (borisdayma#192)
1 parent: 21944e2 · commit: 3109050

File tree

2 files changed

+140
-49
lines changed


src/dalle_mini/data.py

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import random
22
from dataclasses import dataclass, field
33
from functools import partial
4+
from pathlib import Path
45

56
import jax
67
import jax.numpy as jnp
@@ -34,8 +35,10 @@ class Dataset:
3435
max_clip_score: float = None
3536
filter_column: str = None
3637
filter_value: str = None
38+
multi_eval_ds: bool = False
3739
train_dataset: Dataset = field(init=False)
3840
eval_dataset: Dataset = field(init=False)
41+
other_eval_datasets: list = field(init=False)
3942
rng_dataset: jnp.ndarray = field(init=False)
4043
multi_hosts: bool = field(init=False)
4144

@@ -75,6 +78,21 @@ def __post_init__(self):
7578
else:
7679
data_files = None
7780

81+
# multiple validation datasets
82+
if self.multi_eval_ds:
83+
assert Path(
84+
self.dataset_repo_or_path
85+
).is_dir(), f"{self.dataset_repo_or_path} is not a directory, required for multi_eval_ds"
86+
data_files = {
87+
split.name: [str(f) for f in split.glob("*.parquet")]
88+
for split in Path(self.dataset_repo_or_path).glob("*")
89+
}
90+
# rename "valid" to "validation" if present for consistency
91+
if "valid" in data_files:
92+
data_files["validation"] = data_files["valid"]
93+
del data_files["valid"]
94+
self.dataset_repo_or_path = "parquet"
95+
7896
# load dataset
7997
dataset = load_dataset(
8098
self.dataset_repo_or_path,
@@ -102,6 +120,11 @@ def __post_init__(self):
102120
if self.streaming
103121
else self.eval_dataset.select(range(self.max_eval_samples))
104122
)
123+
# other eval datasets
124+
other_eval_splits = dataset.keys() - {"train", "validation"}
125+
self.other_eval_datasets = {
126+
split: dataset[split] for split in other_eval_splits
127+
}
105128

106129
def preprocess(self, tokenizer, config):
107130
# get required config variables
@@ -143,6 +166,20 @@ def preprocess(self, tokenizer, config):
143166
)
144167
),
145168
)
169+
if hasattr(self, "other_eval_datasets"):
170+
self.other_eval_datasets = {
171+
split: (
172+
ds.filter(partial_filter_function)
173+
if self.streaming
174+
else ds.filter(
175+
partial_filter_function,
176+
num_proc=self.preprocessing_num_workers,
177+
load_from_cache_file=not self.overwrite_cache,
178+
desc="Filtering datasets",
179+
)
180+
)
181+
for split, ds in self.other_eval_datasets.items()
182+
}
146183

147184
# normalize text
148185
if normalize_text:
@@ -168,6 +205,20 @@ def preprocess(self, tokenizer, config):
168205
)
169206
),
170207
)
208+
if hasattr(self, "other_eval_datasets"):
209+
self.other_eval_datasets = {
210+
split: (
211+
ds.map(partial_normalize_function)
212+
if self.streaming
213+
else ds.map(
214+
partial_normalize_function,
215+
num_proc=self.preprocessing_num_workers,
216+
load_from_cache_file=not self.overwrite_cache,
217+
desc="Normalizing datasets",
218+
)
219+
)
220+
for split, ds in self.other_eval_datasets.items()
221+
}
171222

172223
# blank captions
173224
if self.blank_caption_prob:
@@ -225,6 +276,29 @@ def preprocess(self, tokenizer, config):
225276
)
226277
),
227278
)
279+
if hasattr(self, "other_eval_datasets"):
280+
self.other_eval_datasets = {
281+
split: (
282+
ds.map(
283+
partial_preprocess_function,
284+
batched=True,
285+
remove_columns=[
286+
self.text_column,
287+
self.encoding_column,
288+
],
289+
)
290+
if self.streaming
291+
else ds.map(
292+
partial_preprocess_function,
293+
batched=True,
294+
remove_columns=getattr(ds, "column_names"),
295+
num_proc=self.preprocessing_num_workers,
296+
load_from_cache_file=not self.overwrite_cache,
297+
desc="Preprocessing datasets",
298+
)
299+
)
300+
for split, ds in self.other_eval_datasets.items()
301+
}
228302

229303
def dataloader(self, split, batch_size, epoch=None):
230304
def _dataloader_datasets_non_streaming(
@@ -283,7 +357,7 @@ def _dataloader_datasets_streaming(
283357
elif split == "eval":
284358
ds = self.eval_dataset
285359
else:
286-
raise ValueError(f'split must be "train" or "eval", got {split}')
360+
ds = self.other_eval_datasets[split]
287361

288362
if self.streaming:
289363
return _dataloader_datasets_streaming(ds, epoch)

tools/train/train.py

Lines changed: 65 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,12 @@ class DataTrainingArguments:
250250
default=None,
251251
metadata={"help": "Class value to be kept during filtering."},
252252
)
253+
multi_eval_ds: Optional[bool] = field(
254+
default=False,
255+
metadata={
256+
"help": "Whether to look for multiple validation datasets (local support only)."
257+
},
258+
)
253259
max_train_samples: Optional[int] = field(
254260
default=None,
255261
metadata={
@@ -1383,62 +1389,73 @@ def run_evaluation():
13831389
# ======================== Evaluating ==============================
13841390
if training_args.do_eval:
13851391
start_eval_time = time.perf_counter()
1386-
eval_loader = dataset.dataloader(
1387-
"eval",
1388-
eval_batch_size_per_step
1389-
* max(1, training_args.mp_devices // jax.local_device_count()),
1390-
)
1391-
eval_steps = (
1392-
len_eval_dataset // eval_batch_size_per_step
1393-
if len_eval_dataset is not None
1394-
else None
1392+
# get validation datasets
1393+
val_datasets = list(
1394+
dataset.other_eval_datasets.keys()
1395+
if hasattr(dataset, "other_eval_datasets")
1396+
else []
13951397
)
1396-
eval_loss = []
1397-
for batch in tqdm(
1398-
eval_loader,
1399-
desc="Evaluating...",
1400-
position=2,
1401-
leave=False,
1402-
total=eval_steps,
1403-
disable=jax.process_index() > 0,
1404-
):
1405-
# need to keep only eval_batch_size_per_node items relevant to the node
1406-
batch = jax.tree_map(
1407-
lambda x: x.reshape(
1408-
(jax.process_count(), eval_batch_size_per_node) + x.shape[1:]
1409-
),
1410-
batch,
1398+
val_datasets += ["eval"]
1399+
for val_dataset in val_datasets:
1400+
eval_loader = dataset.dataloader(
1401+
val_dataset,
1402+
eval_batch_size_per_step
1403+
* max(1, training_args.mp_devices // jax.local_device_count()),
14111404
)
1412-
batch = jax.tree_map(lambda x: x[jax.process_index()], batch)
1413-
1414-
# add dp dimension when using "vmap trick"
1415-
if use_vmap_trick:
1416-
bs_shape = (
1417-
jax.local_device_count() // training_args.mp_devices,
1418-
training_args.per_device_eval_batch_size,
1419-
)
1405+
eval_steps = (
1406+
len_eval_dataset // eval_batch_size_per_step
1407+
if len_eval_dataset is not None
1408+
else None
1409+
)
1410+
eval_loss = []
1411+
for batch in tqdm(
1412+
eval_loader,
1413+
desc="Evaluating...",
1414+
position=2,
1415+
leave=False,
1416+
total=eval_steps,
1417+
disable=jax.process_index() > 0,
1418+
):
1419+
# need to keep only eval_batch_size_per_node items relevant to the node
14201420
batch = jax.tree_map(
1421-
lambda x: x.reshape(bs_shape + x.shape[1:]), batch
1421+
lambda x: x.reshape(
1422+
(jax.process_count(), eval_batch_size_per_node)
1423+
+ x.shape[1:]
1424+
),
1425+
batch,
14221426
)
1427+
batch = jax.tree_map(lambda x: x[jax.process_index()], batch)
14231428

1424-
# freeze batch to pass safely to jax transforms
1425-
batch = freeze(batch)
1426-
# accumulate losses async
1427-
eval_loss.append(p_eval_step(state, batch))
1429+
# add dp dimension when using "vmap trick"
1430+
if use_vmap_trick:
1431+
bs_shape = (
1432+
jax.local_device_count() // training_args.mp_devices,
1433+
training_args.per_device_eval_batch_size,
1434+
)
1435+
batch = jax.tree_map(
1436+
lambda x: x.reshape(bs_shape + x.shape[1:]), batch
1437+
)
14281438

1429-
# get the mean of the loss
1430-
eval_loss = jnp.stack(eval_loss)
1431-
eval_loss = jnp.mean(eval_loss)
1432-
eval_metrics = {"loss": eval_loss}
1439+
# freeze batch to pass safely to jax transforms
1440+
batch = freeze(batch)
1441+
# accumulate losses async
1442+
eval_loss.append(p_eval_step(state, batch))
14331443

1434-
# log metrics
1435-
metrics_logger.log(eval_metrics, prefix="eval")
1436-
metrics_logger.log_time("eval", time.perf_counter() - start_eval_time)
1444+
# get the mean of the loss
1445+
eval_loss = jnp.stack(eval_loss)
1446+
eval_loss = jnp.mean(eval_loss)
1447+
eval_metrics = {"loss": eval_loss}
1448+
1449+
# log metrics
1450+
metrics_logger.log(eval_metrics, prefix=val_dataset)
14371451

1438-
# Print metrics and update progress bar
1439-
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
1440-
epochs.write(desc)
1441-
epochs.desc = desc
1452+
# Print metrics and update progress bar
1453+
desc = f"Epoch... ({epoch + 1}/{num_epochs} | {val_dataset} Loss: {eval_metrics['loss']})"
1454+
epochs.write(desc)
1455+
epochs.desc = desc
1456+
1457+
# log time
1458+
metrics_logger.log_time("eval", time.perf_counter() - start_eval_time)
14421459

14431460
return eval_metrics
14441461

0 commit comments

Comments (0)