Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 9 additions & 65 deletions fast_llm/data/dataset.py β†’ fast_llm/data/blended.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import abc
import logging
import pathlib
import time

import numpy as np
import torch.utils.data

from fast_llm.core.distributed import ProcessGroup, safe_barrier
from fast_llm.data.config import SampledDataset
from fast_llm.engine.config_utils.run import log_main_rank
from fast_llm.utils import Assert

Expand All @@ -20,43 +19,6 @@
logger = logging.getLogger(__name__)


class Dataset(abc.ABC):
"""
A generic dataset class compatible with torch.utils.data.Dataset but with a slightly different signature.
"""

@abc.abstractmethod
def __getitem__(self, index: int):
pass

@abc.abstractmethod
def __len__(self):
pass

@property
@abc.abstractmethod
def name(self):
"""
A name for the dataset to facilitate identification and debugging.
"""


class RawDataset(Dataset): # noqa
"""
A raw dataset class containing a list of unsampled, unprocessed samples, i.e., matching what is stored on disk.
(Excluding off-line processing prior to training.)
Functionally identical to a `Dataset`, but renamed for clarity.
"""


class SampledDataset(Dataset): # noqa
"""
A sampled dataset class containing a prepared list of samples to be indexed sequentially (as-is) during training.
(See the `Sampler` class below.)
Functionally identical to a `Dataset`, but renamed for clarity.
"""


class BlendedDataset(SampledDataset):
"""
A blended sampling of multiple sampled datasets, where each dataset is sampled with the provided probability.
Expand All @@ -72,7 +34,7 @@ def __init__(
*,
name: str = "blended",
num_samples: int,
cache_dir: pathlib.Path | None = None,
cache_directory: pathlib.Path | None = None,
group: ProcessGroup | None = None,
verbose: bool = True,
data_sample_warn_time_ms: float = 1000,
Expand All @@ -83,19 +45,20 @@ def __init__(
self._weights = weights
self._data_sample_warn_time_ms = data_sample_warn_time_ms

if cache_dir is None:
if cache_directory is None:
self._dataset_idx_filename, self._sample_idx_filename = None, None
self._dataset_index, self._sample_index = self._build_blending_indices(verbose and len(datasets) <= 20)
else:
self._dataset_idx_filename = cache_dir / (self._name + "_blending_dataset_idx.npy")
self._sample_idx_filename = cache_dir / (self._name + "_blending_sample_idx.npy")
self._dataset_idx_filename = cache_directory / (self._name + "_blending_dataset_idx.npy")
self._sample_idx_filename = cache_directory / (self._name + "_blending_sample_idx.npy")

# Build the indexed mapping if it doesn't exist.
# TODO: This only works if the dataset location is accessible by all job.
if (group is None or group.rank() == 0) and not (
self._dataset_idx_filename.is_file() and self._sample_idx_filename.is_file()
):
dataset_index, sample_index = self._build_blending_indices(verbose and len(datasets) <= 20)
cache_directory.mkdir(exist_ok=True, parents=True)
np.save(self._dataset_idx_filename, dataset_index)
np.save(self._sample_idx_filename, sample_index)

Expand Down Expand Up @@ -140,7 +103,9 @@ def __len__(self):
return self._num_samples

def _build_blending_indices(self, verbose: bool):
assert _extension_available, "Please run `make -C ./fast_llm/csrc/` first."
assert _extension_available, (
"The C++ extension for dataset blending is missing." " Please make sure Fast-LLM is installed correctly."
)
Assert.lt(len(self._datasets), 32767)
dataset_index = np.zeros(self._num_samples, dtype=np.int16)
dataset_sample_index = np.zeros(self._num_samples, dtype=np.int64)
Expand Down Expand Up @@ -191,24 +156,3 @@ def __getitem__(self, idx):
@property
def name(self):
return self._name


class Sampler(torch.utils.data.Sampler):
"""
A distributed sampler generating indices for a `SampledDataset` (i.e., the natural numbers).
To be used as the `batch_sampler` of a `torch.utils.data.DataLoader`.
"""

def __init__(self, total_samples, begin_index, micro_batch_size, data_rank, data_parallel):
self._total_samples = total_samples
self._begin_index = begin_index
self._batch_size = micro_batch_size * data_parallel
self._start_idx = data_rank * micro_batch_size
self._end_idx = (data_rank + 1) * micro_batch_size

def __len__(self):
return self._total_samples

def __iter__(self):
for idx in range(self._begin_index, self._total_samples - self._batch_size + 1, self._batch_size):
yield list(range(idx + self._start_idx, idx + self._end_idx))
80 changes: 33 additions & 47 deletions fast_llm/data/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,11 @@ class TokenizerConfig(Config):


@config_class()
class AbstractDataConfig(Config):
class DataConfig(Config):
_abstract = True


class AbstractData(abc.ABC):
class Data(abc.ABC):
# TODO: Improve interface
@abc.abstractmethod
def setup(self, distributed: "Distributed", samples_per_phase: dict[PhaseType, int]):
Expand All @@ -146,52 +146,38 @@ def get_iterator(
pass


@config_class()
class DataConfig(AbstractDataConfig):
class Dataset(abc.ABC):
"""
Configuration for the dataset(s), split and sampling.
Currently hard-coded to a GPT dataset.
TODO: Extract generalizable content.
A generic dataset class compatible with torch.utils.data.Dataset but with a slightly different signature.
"""

_abstract = False
@abc.abstractmethod
def __getitem__(self, index: int):
pass

tokenizer: TokenizerConfig = Field(
default_factory=TokenizerConfig,
desc="Configuration for the tokenizer (for FIM).",
hint=FieldHint.feature,
)
fim: FimConfig = Field(
default_factory=FimConfig,
desc="Configuration for Fill In the Middle (FIM).",
hint=FieldHint.feature,
)
# TODO: set default to [1,0,0]?
split: list[float] = Field(
default_factory=lambda: [969, 30, 1],
desc="Split ratio for train, valid and test datasets.",
hint=FieldHint.core,
valid=_validate_split,
)
format: DatasetSource = Field(
default=DatasetSource.list,
desc="Format for the dataset definition.",
hint=FieldHint.core,
)
path: list[str] = Field(
default_factory=list,
desc="Path or list of paths and weights.",
hint=FieldHint.core,
valid=_validate_path,
)
data_sample_warn_time_ms: float = Field(
default=1000,
desc="Warn if a sample takes too long to load.",
hint=FieldHint.feature,
valid=check_field(Assert.gt, 0),
)
multiprocessing_context: MultiprocessingContext = Field(
default=MultiprocessingContext.spawn,
desc="Multiprocessing context. Do not touch.",
hint=FieldHint.expert,
)
@abc.abstractmethod
def __len__(self):
pass

@property
@abc.abstractmethod
def name(self):
"""
A name for the dataset to facilitate identification and debugging.
"""


class RawDataset(Dataset): # noqa
"""
A raw dataset class containing a list of unsampled, unprocessed samples, i.e., matching what is stored on disk.
(Excluding off-line processing prior to training.)
Functionally identical to a `Dataset`, but renamed for clarity.
"""


class SampledDataset(Dataset): # noqa
"""
A sampled dataset class containing a prepared list of samples to be indexed sequentially (as-is) during training.
(See the `Sampler` class below.)
Functionally identical to a `Dataset`, but renamed for clarity.
"""
Loading