From 558539c2bf0186fd6139b7d3259e9c39774715b6 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 7 Nov 2024 10:25:21 -0500 Subject: [PATCH 01/13] Flexible dataset prototype --- fast_llm/data/{dataset.py => blended.py} | 64 +--- fast_llm/data/config.py | 76 ++-- fast_llm/data/data.py | 230 ------------ fast_llm/data/gpt/__init__.py | 0 fast_llm/data/gpt/concatenated.py | 44 +++ fast_llm/data/gpt/config.py | 429 +++++++++++++++++++++++ fast_llm/data/gpt/data.py | 236 +++++++++++++ fast_llm/data/gpt/dummy.py | 39 +++ fast_llm/data/{mmap.py => gpt/memmap.py} | 36 +- fast_llm/data/{gpt.py => gpt/sampled.py} | 206 +++-------- fast_llm/data/gpt/slice.py | 71 ++++ fast_llm/data/sampler.py | 22 ++ fast_llm/engine/config_utils/run.py | 3 - fast_llm/engine/training/trainer.py | 4 +- fast_llm/models/custom/config.py | 4 +- fast_llm/models/custom/data.py | 4 +- fast_llm/models/gpt/config.py | 4 +- fast_llm/models/gpt/trainer.py | 4 +- tests/common.py | 2 +- tools/concatenate_dataset.py | 2 +- 20 files changed, 960 insertions(+), 520 deletions(-) rename fast_llm/data/{dataset.py => blended.py} (75%) delete mode 100644 fast_llm/data/data.py create mode 100644 fast_llm/data/gpt/__init__.py create mode 100644 fast_llm/data/gpt/concatenated.py create mode 100644 fast_llm/data/gpt/config.py create mode 100644 fast_llm/data/gpt/data.py create mode 100644 fast_llm/data/gpt/dummy.py rename fast_llm/data/{mmap.py => gpt/memmap.py} (76%) rename fast_llm/data/{gpt.py => gpt/sampled.py} (60%) create mode 100644 fast_llm/data/gpt/slice.py create mode 100644 fast_llm/data/sampler.py diff --git a/fast_llm/data/dataset.py b/fast_llm/data/blended.py similarity index 75% rename from fast_llm/data/dataset.py rename to fast_llm/data/blended.py index 37dd4c948..6373d3a79 100644 --- a/fast_llm/data/dataset.py +++ b/fast_llm/data/blended.py @@ -1,12 +1,12 @@ -import abc import logging import pathlib import time +import typing import numpy as np -import torch.utils.data from
fast_llm.core.distributed import ProcessGroup, safe_barrier +from fast_llm.data.config import SampledDataset from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.utils import Assert @@ -20,43 +20,6 @@ logger = logging.getLogger(__name__) -class Dataset(abc.ABC): - """ - A generic dataset class compatible with torch.utils.data.Dataset but with a slightly different signature. - """ - - @abc.abstractmethod - def __getitem__(self, index: int): - pass - - @abc.abstractmethod - def __len__(self): - pass - - @property - @abc.abstractmethod - def name(self): - """ - A name for the dataset to facilitate identification and debugging. - """ - - -class RawDataset(Dataset): # noqa - """ - A raw dataset class containing a list of unsampled, unprocessed samples, i.e., matching what is stored on disk. - (Excluding off-line processing prior to training.) - Functionally identical to a `Dataset`, but renamed for clarity. - """ - - -class SampledDataset(Dataset): # noqa - """ - A sampled dataset class containing a prepared list of samples to be indexed sequentially (as-is) during training. - (See the `Sampler` class below.) - Functionally identical to a `Dataset`, but renamed for clarity. - """ - - class BlendedDataset(SampledDataset): """ A blended sampling of multiple sampled datasets, where each dataset is sampled with the provided probability. @@ -73,7 +36,7 @@ def __init__( name: str = "blended", num_samples: int, cache_dir: pathlib.Path | None = None, - group: ProcessGroup | None = None, + group: typing.Optional["ProcessGroup"] = None, verbose: bool = True, data_sample_warn_time_ms: float = 1000, ): @@ -191,24 +154,3 @@ def __getitem__(self, idx): @property def name(self): return self._name - - -class Sampler(torch.utils.data.Sampler): - """ - A distributed sampler generating indices for a `SampledDataset` (i.e., the natural numbers). - To be used as the `batch_sampler` of a `torch.utils.data.DataLoader`. 
- """ - - def __init__(self, total_samples, begin_index, micro_batch_size, data_rank, data_parallel): - self._total_samples = total_samples - self._begin_index = begin_index - self._batch_size = micro_batch_size * data_parallel - self._start_idx = data_rank * micro_batch_size - self._end_idx = (data_rank + 1) * micro_batch_size - - def __len__(self): - return self._total_samples - - def __iter__(self): - for idx in range(self._begin_index, self._total_samples - self._batch_size + 1, self._batch_size): - yield list(range(idx + self._start_idx, idx + self._end_idx)) diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index f105c5054..e6f33eebf 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -146,52 +146,38 @@ def get_iterator( pass -@config_class() -class DataConfig(AbstractDataConfig): +class Dataset(abc.ABC): """ - Configuration for the dataset(s), split and sampling. - Currently hard-coded to a GPT dataset. - TODO: Extract generalizable content. + A generic dataset class compatible with torch.utils.data.Dataset but with a slightly different signature. """ - _abstract = False + @abc.abstractmethod + def __getitem__(self, index: int): + pass - tokenizer: TokenizerConfig = Field( - default_factory=TokenizerConfig, - desc="Configuration for the tokenizer (for FIM).", - hint=FieldHint.feature, - ) - fim: FimConfig = Field( - default_factory=FimConfig, - desc="Configuration for Fill In the Middle (FIM).", - hint=FieldHint.feature, - ) - # TODO: set default to [1,0,0]? 
- split: list[float] = Field( - default_factory=lambda: [969, 30, 1], - desc="Split ratio for train, valid and test datasets.", - hint=FieldHint.core, - valid=_validate_split, - ) - format: DatasetSource = Field( - default=DatasetSource.list, - desc="Format for the dataset definition.", - hint=FieldHint.core, - ) - path: list[str] = Field( - default_factory=list, - desc="Path or list of paths and weights.", - hint=FieldHint.core, - valid=_validate_path, - ) - data_sample_warn_time_ms: float = Field( - default=1000, - desc="Warn if a sample takes too long to load.", - hint=FieldHint.feature, - valid=check_field(Assert.gt, 0), - ) - multiprocessing_context: MultiprocessingContext = Field( - default=MultiprocessingContext.spawn, - desc="Multiprocessing context. Do not touch.", - hint=FieldHint.expert, - ) + @abc.abstractmethod + def __len__(self): + pass + + @property + @abc.abstractmethod + def name(self): + """ + A name for the dataset to facilitate identification and debugging. + """ + + +class RawDataset(Dataset): # noqa + """ + A raw dataset class containing a list of unsampled, unprocessed samples, i.e., matching what is stored on disk. + (Excluding off-line processing prior to training.) + Functionally identical to a `Dataset`, but renamed for clarity. + """ + + +class SampledDataset(Dataset): # noqa + """ + A sampled dataset class containing a prepared list of samples to be indexed sequentially (as-is) during training. + (See the `Sampler` class in `fast_llm.data.sampler`.) + Functionally identical to a `Dataset`, but renamed for clarity.
+ """ diff --git a/fast_llm/data/data.py b/fast_llm/data/data.py deleted file mode 100644 index e58b62c4a..000000000 --- a/fast_llm/data/data.py +++ /dev/null @@ -1,230 +0,0 @@ -import json -import logging -import math -import pathlib -import typing -import warnings - -import numpy as np -import torch -import torch.utils.data - -from fast_llm.data.config import AbstractData, DataConfig, DatasetSource -from fast_llm.data.dataset import BlendedDataset, SampledDataset, Sampler -from fast_llm.data.gpt import DummyGPTDataset, GPTDataset, GPTSampledDataset -from fast_llm.data.mmap import MMapIndexedDataset -from fast_llm.data.tokenizer import Tokenizer -from fast_llm.engine.config_utils.run import get_run, log_main_rank -from fast_llm.engine.distributed.config import DistributedConfig, PhaseType -from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.schedule.config import BatchConfig -from fast_llm.utils import Assert - -logger = logging.getLogger(__name__) - - -def normalize_probs(p: list[float]) -> list[float]: - p = np.array(p) - Assert.custom(lambda x: np.all(x >= 0), p) - p_sum = p.sum() - Assert.gt(p_sum, 0) - return (p / p_sum).tolist() - - -class Data(AbstractData): - """ - A global class for all dataset needs, including loading, splitting, sampling and iteration. - Currently hard-coded to a GPT dataset. - TODO: Separate generic and GPT classes. - """ - - _sampled_datasets: dict[PhaseType, dict[str, SampledDataset]] - _blended_datasets: dict[PhaseType, SampledDataset] - _tokenizer: Tokenizer | None - _distributed: Distributed - _cache_dir: pathlib.Path | None - _samples_per_phase: dict[PhaseType, int] - _phases: typing.ClassVar[tuple[PhaseType, ...]] = (PhaseType.training, PhaseType.validation, PhaseType.test) - - def __init__( - self, - config: DataConfig, - distributed_config: DistributedConfig, - vocab_size: int, - max_sequence_length: int, - ): - """ - Create the data and gather some basic information on the dataset(s). 
- Should be `setup` before use. - """ - self._config = config.validate() - self._distributed_config = distributed_config.validate() - self._vocab_size = vocab_size - self._max_sequence_length = max_sequence_length - Assert.eq(len(self._config.split), len(self._phases)) - self._phase_split = { - phase: ratio for phase, ratio in zip(self._phases, normalize_probs(self._config.split)) if ratio > 0 - } - - data_base_path = None - if self._config.format == DatasetSource.file: - Assert.eq(len(self._config.path), 1) - data_path = pathlib.Path(self._config.path[0]) - dataset_defs = json.load(data_path.open("r")) - data_base_path = data_path.parent - dataset_prefixes = [dataset_def["prefix"] for dataset_def in dataset_defs["datasets"]] - dataset_weights = normalize_probs([dataset_def["weight"] for dataset_def in dataset_defs["datasets"]]) - self._build_and_sample_dataset = self._build_and_sample_gpt_dataset - elif self._config.format == DatasetSource.list: - Assert.geq(len(self._config.path), 1) - if len(self._config.path) == 1: - dataset_prefixes, dataset_weights = [self._config.path[0].strip()], [1.0] - else: - Assert.custom(lambda x: x % 2 == 0, len(self._config.path)) - dataset_prefixes = [x.strip() for x in self._config.path[1::2]] - assert len(dataset_prefixes) == len(set(dataset_prefixes)) - dataset_weights = normalize_probs([float(x) for x in self._config.path[::2]]) - self._build_and_sample_dataset = self._build_and_sample_gpt_dataset - elif self._config.format == DatasetSource.sample: - Assert.eq(len(self._config.path), 1) - dataset_prefixes, dataset_weights = [self._config.path[0].strip()], [1.0] - self._build_and_sample_dataset = self._build_and_sample_dummy_dataset - elif self._config.format == DatasetSource.random: - Assert.eq(len(self._config.path), 0) - dataset_prefixes, dataset_weights = [None], [1.0] - self._build_and_sample_dataset = self._build_and_sample_dummy_dataset - else: - raise NotImplementedError(self._config.format) - - dataset_names = [ - 
f"dataset_{i}_{'dummy' if prefix is None else prefix.replace('/','__')}" - for i, prefix in enumerate(dataset_prefixes) - ] - self._num_datasets = len(dataset_names) - self._dataset_prefixes = { - name: ( - None - if prefix is None - else ( - pathlib.Path(prefix).resolve() - if data_base_path is None - else (pathlib.Path(data_base_path) / prefix).resolve() - ) - ) - for name, prefix in zip(dataset_names, dataset_prefixes) - } - self._dataset_weights = {name: weight for name, weight in zip(dataset_names, dataset_weights)} - - def setup(self, distributed: Distributed, samples_per_phase: dict[PhaseType, int]): - """ - Load the datasets, and prepare or load the samplings. - This may take a while and a significant amount of cpu memory. - """ - run = get_run() - Assert.leq(set(samples_per_phase), set(self._phase_split)) - log_main_rank(f"Preparing {self._num_datasets} datasets. This may take several minutes.") - self._tokenizer = Tokenizer(self._config.tokenizer) if self._config.fim.rate > 0 else None - self._distributed = distributed - self._cache_dir = run.dataset_cache_dir - self._samples_per_phase = samples_per_phase - if self._cache_dir is None: - warnings.warn(f"Using the dataset directory for the index cache.") - - # Build and split datasets. - self._sampled_datasets = {phase: {} for phase in self._samples_per_phase} - for i, (name, weight) in enumerate(self._dataset_weights.items()): - if i % 100 == 0 and i > 0: - log_main_rank(f"Prepared {i} of {self._num_datasets} datasets.") - dataset_samples_per_phase = {} - for phase, samples_per_phase in self._samples_per_phase.items(): - expected_samples = self._dataset_weights[name] * samples_per_phase - # Add 5 times the standard deviation (of a binomial distribution) - # so the probability of sampling more than this amount during blending is negligible. 
- dataset_samples_per_phase[phase] = math.ceil( - expected_samples - + 5 * math.sqrt(expected_samples * self._dataset_weights[name] * (1 - self._dataset_weights[name])) - ) - sampled_datasets = self._build_and_sample_dataset(name, dataset_samples_per_phase) - for phase, dataset in sampled_datasets.items(): - self._sampled_datasets[phase][name] = dataset - - self._blended_datasets = { - phase: ( - list(datasets.values())[0] - if len(datasets) == 1 - else BlendedDataset( - list(datasets.values()), - weights=[self._dataset_weights[name] for name in datasets], - name=phase.value, - num_samples=self._samples_per_phase[phase], - cache_dir=self._cache_dir, - group=self._distributed.world_group, - verbose=run.is_main_rank, - data_sample_warn_time_ms=self._config.data_sample_warn_time_ms, - ) - ) - for phase, datasets in self._sampled_datasets.items() - } - - def get_iterator( - self, - batch_config: BatchConfig, - phase: PhaseType, - *, - consumed_samples: int, - num_workers: int, - prefetch_factor: int | None = None, - ): - Assert.incl(phase, self._blended_datasets) - Assert.in_range_incl(batch_config.sequence_length, 1, self._max_sequence_length) - log_main_rank(f"Initializing {phase} data iterator from sample {consumed_samples}...") - return iter( - torch.utils.data.DataLoader( - self._blended_datasets[phase], # noqa - batch_sampler=Sampler( - total_samples=len(self._blended_datasets[phase]), - begin_index=consumed_samples, - micro_batch_size=batch_config.micro_batch_size, - data_rank=self._distributed.config.batch_data_rank, - data_parallel=self._distributed.config.batch_data_parallel, - ), - num_workers=num_workers, - prefetch_factor=prefetch_factor, - pin_memory=True, - multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, - ) - ) - - def _build_and_sample_gpt_dataset(self, name: str, dataset_samples_per_phase: dict[PhaseType, int]): - dataset_split = GPTDataset.from_splits( - name, 
MMapIndexedDataset(self._dataset_prefixes[name]), self._phase_split - ) - - sampled_datasets = {} - for phase, num_samples in dataset_samples_per_phase.items(): - if num_samples == 0: - continue - sampled_datasets[phase] = GPTSampledDataset( - dataset_split[phase], - num_samples=num_samples, - sequence_length=self._max_sequence_length, - seed=self._distributed.config.seed, - group=self._distributed.world_group, - config=self._config, - tokenizer=self._tokenizer, - cache_dir=self._dataset_prefixes[name].parent if self._cache_dir is None else self._cache_dir, - verbose=self._num_datasets <= 5, - ) - return sampled_datasets - - def _build_and_sample_dummy_dataset(self, name: str, dataset_samples_per_phase: dict[PhaseType, int]): - return { - phase: DummyGPTDataset( - self._dataset_prefixes[name], - dataset_samples_per_phase[phase], - self._max_sequence_length, - self._vocab_size, - name, - ) - for phase in dataset_samples_per_phase - } diff --git a/fast_llm/data/gpt/__init__.py b/fast_llm/data/gpt/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/data/gpt/concatenated.py b/fast_llm/data/gpt/concatenated.py new file mode 100644 index 000000000..19883dd00 --- /dev/null +++ b/fast_llm/data/gpt/concatenated.py @@ -0,0 +1,44 @@ +import numpy as np + +from fast_llm.data.gpt.config import GPTConcatenatedDatasetConfig, GPTRawDataset +from fast_llm.utils import padded_cumsum + + +class GPTConcatenatedDataset(GPTRawDataset): + + def __init__( + self, + config: GPTConcatenatedDatasetConfig, + datasets: list[GPTRawDataset], + ): + self._config = config + self._datasets = datasets + sizes = [dataset.num_documents for dataset in self._datasets] + self._dataset_splits = padded_cumsum(sizes) + self._num_documents = sum(sizes) + + @property + def num_tokens(self): + return sum(dataset.num_tokens for dataset in self._datasets) + + def num_documents(self): + return self._num_documents + + def __getitem__(self, index: int): + """ + Get the sample 
(document) with the given index (in the split dataset). + """ + return self.get(index) + + def get(self, document: int, offset: int = 0, length: int | None = None): + """ + Get the sample (document) with the given index (in the concatenated dataset), + optionally sub-sampled to a specific offset (starting point) and maximum length + (end = min(offset + length, sample_length)). + """ + dataset = np.searchsorted(self._dataset_splits[1:], document, side="right") + return self._datasets[dataset].get(document - self._dataset_splits[dataset], offset, length) + + @property + def name(self): + return self._config.name diff --git a/fast_llm/data/gpt/config.py b/fast_llm/data/gpt/config.py new file mode 100644 index 000000000..118670da8 --- /dev/null +++ b/fast_llm/data/gpt/config.py @@ -0,0 +1,429 @@ +import abc +import enum +import logging +import pathlib +import typing + +from fast_llm.config import Config, Field, FieldHint, check_field, config_class +from fast_llm.data.config import ( + AbstractDataConfig, + FimConfig, + MultiprocessingContext, + RawDataset, + SampledDataset, + TokenizerConfig, +) +from fast_llm.engine.distributed.config import PhaseType +from fast_llm.utils import Assert + +if typing.TYPE_CHECKING: + from fast_llm.data.gpt.data import GPTData + +logger = logging.getLogger(__name__) + + +class GPTRawDataset(RawDataset): + def __len__(self): + return self.num_documents + + def get(self, document: int, offset: int = 0, length: int | None = None): + pass + + def __getitem__(self, index: int): + """ + Get the sample (document) with the given index (in the split dataset).
+ """ + return self.get(index) + + @property + @abc.abstractmethod + def num_documents(self): + pass + + @property + @abc.abstractmethod + def num_tokens(self): + pass + + @property + @abc.abstractmethod + def document_sizes(self): + pass + + +class GPTDatasetConfigType(str, enum.Enum): + split = "split" + splits = "splits" + concatenated = "concatenated" + blended = "blended" + memmap = "random" + + +@config_class() +class GPTDatasetConfig(Config): + + _abstract = True + type: GPTDatasetConfigType = Field( + desc="Format for the dataset definition.", + hint=FieldHint.core, + ) + prefix = Field( + desc="A prefix for the dataset name, set by wrapping datasets.", + init=False, + ) + + def _validate(self): + assert hasattr(self, "prefix") + super()._validate() + + @classmethod + def from_dict( + cls, + default: typing.Union["Config", dict[str, typing.Any]], + *updates: typing.Union["Config", dict[str | tuple[str, ...], typing.Any]], + strict: bool = True, + ): + if cls.type == GPTDatasetConfigType.split: + cls_ = GPTSplitDatasetConfig + elif cls.type == GPTDatasetConfigType.splits: + cls_ = GPTDatasetSplitsConfig + elif cls.type == GPTDatasetConfigType.concatenated: + cls_ = GPTConcatenatedDatasetConfig + elif cls.type == GPTDatasetConfigType.blended: + cls_ = GPTBlendedDatasetConfig + elif cls.type == GPTDatasetConfigType.memmap: + cls_ = GPTMemmapDatasetConfig + else: + raise NotImplementedError(cls.type) + return cls_.from_dict(default, *updates, strict=strict) + + @property + def sampled(self) -> bool: + raise NotImplementedError() + + @property + def split(self) -> bool: + raise NotImplementedError() + + def build_split_sampled(self, data: "GPTData") -> dict[PhaseType, SampledDataset]: + if self.split: + return self._build_split_sampled(data) + else: + return {PhaseType.training: self.build_unsplit_sampled(data)} + + def build_unsplit_sampled(self, data: "GPTData") -> SampledDataset: + assert not self.split + if self.sampled: + return 
self._build_unsplit_sampled(data) + else: + return self._sample(self.build_unsplit_unsampled(), data) + + def build_split_unsampled(self) -> dict[PhaseType, GPTRawDataset]: + assert not self.sampled + if self.split: + return self._build_split_unsampled() + else: + return {PhaseType.training: self.build_unsplit_unsampled()} + + def build_unsplit_unsampled(self) -> GPTRawDataset: + assert not self.split + assert not self.sampled + return self._build_unsplit_unsampled() + + def _build(self) -> GPTRawDataset | SampledDataset | dict[PhaseType, GPTRawDataset | SampledDataset]: + raise NotImplementedError() + + def _build_split_sampled(self, data: "GPTData") -> dict[PhaseType, SampledDataset]: + raise NotImplementedError() + + def _build_unsplit_sampled(self, data: "GPTData") -> SampledDataset: + raise NotImplementedError() + + def _build_split_unsampled(self) -> dict[PhaseType, GPTRawDataset]: + raise NotImplementedError() + + def _build_unsplit_unsampled(self) -> GPTRawDataset: + raise NotImplementedError() + + @property + def full_name(self) -> str: + return f"{self.prefix}{self._base_name}" + + @property + def _base_name(self) -> str: + raise NotImplementedError() + + +@config_class() +class GPTMemmapDatasetConfig(GPTDatasetConfig): + # Path -> (unsampled, unsplit) + _abstract = False + path: pathlib.Path = Field( + desc="The path to the dataset, excluding the `.bin` or `.idx` suffix.", + hint=FieldHint.core, + ) + + @property + def split(self) -> bool: + return False + + @property + def sampled(self) -> bool: + return False + + def _build_unsplit_unsampled(self) -> GPTRawDataset: + from fast_llm.data.gpt.memmap import GPTMemmapDataset + + return GPTMemmapDataset(self) + + @property + def _base_name(self) -> str: + return self.path.stem + + +@config_class() +class GPTConcatenatedDatasetConfig(GPTDatasetConfig): + """ + Concatenate multiple datasets as if they were one. + Must be done before sampling and splitting. + TODO: OK after sampling (staged training?) 
or splitting (Equal split for each sub-dataset, probably better?) + [(unsampled, unsplit)] -> (unsampled, unsplit) + """ + + _abstract = False + name: str = Field( + default="concatenated", + desc="The name of the dataset.", + hint=FieldHint.core, + ) + datasets: list[GPTDatasetConfig] = Field( + desc="The datasets to concatenate.", + hint=FieldHint.core, + ) + + def _validate(self): + if not hasattr(self, "prefix"): + self.prefix = "" + for i, dataset in enumerate(self.datasets): + # The phase name will also appear in the phase suffix, + # but we still need this to disambiguate the non-suffixed name. + dataset.prefix = f"{self.prefix}{self._base_name}{i}_" + super()._validate() + for dataset in self.datasets: + assert not dataset.split + assert not any(dataset.sampled for dataset in self.datasets) + + @property + def split(self) -> bool: + return False + + @property + def sampled(self) -> bool: + return False + + def _build_unsplit_unsampled(self) -> GPTRawDataset: + from fast_llm.data.gpt.concatenated import GPTConcatenatedDataset + + return GPTConcatenatedDataset(self, [dataset.build_unsplit_unsampled() for dataset in self.datasets]) + + @property + def _base_name(self) -> str: + return f"{self.name}/" + + +@config_class() +class GPTSplitDatasetConfig(GPTDatasetConfig): + """ + Split a single dataset into multiple phases. + Must be done before sampling. + TODO: Ok after sampling?
+ (unsampled, unsplit) -> (unsampled, split) + """ + + _abstract = False + dataset: GPTDatasetConfig = Field( + desc="The dataset to split.", + hint=FieldHint.core, + ) + ratios: dict[PhaseType, float] = Field( + desc="The split ratio for each phase", + hint=FieldHint.core, + ) + name: str = Field( + default="split", + desc="The name of the dataset.", + hint=FieldHint.core, + ) + + def _validate(self): + if not hasattr(self, "prefix"): + self.prefix = "" + self.dataset.prefix = f"{self.prefix}{self._base_name}" + super()._validate() + assert not self.dataset.split + assert not self.dataset.sampled + + @property + def sampled(self) -> bool: + return False + + @property + def split(self) -> bool: + return True + + def _build_split_unsampled(self) -> dict[PhaseType, GPTRawDataset]: + from fast_llm.data.gpt.slice import GPTDatasetSlice + + return GPTDatasetSlice.from_splits(self.dataset.build_unsplit_unsampled(), self.ratios) + + @property + def _base_name(self) -> str: + return f"{self.name}/" + + +@config_class() +class GPTDatasetSplitsConfig(GPTDatasetConfig): + """ + Create a separate dataset for each phase. + May be done before or after sampling. + {phase:(?sampled, unsplit)} -> (?sampled, split) + """ + + _abstract = False + datasets: dict[PhaseType, GPTDatasetConfig] = Field( + desc="The dataset to split.", + hint=FieldHint.core, + ) + + def _validate(self): + if not hasattr(self, "prefix"): + self.prefix = "" + for phase, dataset in self.datasets.items(): + # The phase name will also appear in the phase suffix, + # but we still need this to disambiguate the non-suffixed name. 
+ dataset.prefix = f"{self.prefix}{phase.value}/" + super()._validate() + for phase, dataset in self.datasets.items(): + assert not dataset.split, phase + _ = self.sampled + + @property + def split(self) -> bool: + return True + + @property + def sampled(self) -> bool: + sampled = {dataset.sampled for dataset in self.datasets.values()} + assert len(sampled) == 1, sampled + return sampled.pop() + + def _build_split_sampled(self, data: "GPTData") -> dict[PhaseType, SampledDataset]: + return {phase: dataset.build_unsplit_sampled(data) for phase, dataset in self.datasets.items()} + + def _build_split_unsampled(self) -> dict[PhaseType, GPTRawDataset]: + return {phase: dataset.build_unsplit_unsampled() for phase, dataset in self.datasets.items()} + + @property + def _base_name(self) -> str: + return "" + + +@config_class() +class GPTBlendedDatasetConfig(GPTDatasetConfig): + # [(?sampled, ?split)] -> (sampled, split) + _abstract = False + datasets: list[GPTDatasetConfig] = Field( + desc="The datasets to concatenate.", + hint=FieldHint.core, + ) + weights: list[float] = Field( + desc="The blending weight of each dataset.", + hint=FieldHint.core, + ) + + def _validate(self): + super()._validate() + for dataset in self.datasets: + assert not dataset.split + assert not any(dataset.sampled for dataset in self.datasets) + + @property + def split(self) -> bool: + return True + + @property + def sampled(self) -> bool: + return True + + def _build_split_sampled(self, data: "GPTData") -> dict[PhaseType, SampledDataset]: + from fast_llm.data.blended import BlendedDataset + + datasets = {} + for dataset in self.datasets: + dataset_split = dataset.build_split_sampled(data) + if datasets: + Assert.eq(set(datasets), set(dataset_split)) + else: + datasets = {phase: [] for phase in dataset_split} + for phase, phase_datasets in datasets.items(): + phase_datasets.append(dataset_split[phase]) + return { + phase: BlendedDataset(phase_datasets, self.weights, data) for phase, phase_datasets in 
datasets.items() + } + + +@config_class() +class GPTDataConfig(AbstractDataConfig): + """ + Configuration for the dataset(s), split and sampling. + Currently hard-coded to a GPT dataset. + TODO: Extract generalizable content. + """ + + _abstract = False + + dataset: GPTDatasetConfig = Field( + # TODO: Dummy default? + default_factory=GPTDatasetConfig, + desc="Configuration for the dataset.", + hint=FieldHint.core, + ) + tokenizer: TokenizerConfig = Field( + default_factory=TokenizerConfig, + desc="Configuration for the tokenizer (for FIM).", + hint=FieldHint.feature, + ) + fim: FimConfig = Field( + default_factory=FimConfig, + desc="Configuration for Fill In the Middle (FIM).", + hint=FieldHint.feature, + ) + # TODO: set default to [1,0,0]? + # split: list[float] = Field( + # default_factory=lambda: [969, 30, 1], + # desc="Split ratio for train, valid and test datasets.", + # hint=FieldHint.core, + # valid=_validate_split, + # ) + # format: DatasetSource = Field( + # default=DatasetSource.list, + # desc="Format for the dataset definition.", + # hint=FieldHint.core, + # ) + # path: list[str] = Field( + # default_factory=list, + # desc="Path or list of paths and weights.", + # hint=FieldHint.core, + # valid=_validate_path, + # ) + data_sample_warn_time_ms: float = Field( + default=1000, + desc="Warn if a sample takes too long to load.", + hint=FieldHint.feature, + valid=check_field(Assert.gt, 0), + ) + multiprocessing_context: MultiprocessingContext = Field( + default=MultiprocessingContext.spawn, + desc="Multiprocessing context. 
Do not touch.", + hint=FieldHint.expert, + ) diff --git a/fast_llm/data/gpt/data.py b/fast_llm/data/gpt/data.py new file mode 100644 index 000000000..1158cc73e --- /dev/null +++ b/fast_llm/data/gpt/data.py @@ -0,0 +1,236 @@ +import logging +import pathlib +import typing + +import numpy as np +import torch +import torch.utils.data + +from fast_llm.data.config import AbstractData, SampledDataset +from fast_llm.data.gpt.config import GPTDataConfig +from fast_llm.data.sampler import Sampler +from fast_llm.data.tokenizer import Tokenizer +from fast_llm.engine.config_utils.run import get_run, log_main_rank +from fast_llm.engine.distributed.config import DistributedConfig, PhaseType +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.engine.schedule.config import BatchConfig +from fast_llm.utils import Assert + +logger = logging.getLogger(__name__) + + +def normalize_probs(p: list[float]) -> list[float]: + p = np.array(p) + Assert.custom(lambda x: np.all(x >= 0), p) + p_sum = p.sum() + Assert.gt(p_sum, 0) + return (p / p_sum).tolist() + + +class GPTData(AbstractData): + """ + A global class for all dataset needs, including loading, splitting, sampling and iteration. + Currently hard-coded to a GPT dataset. + TODO: Separate generic and GPT classes. + """ + + _sampled_dataset: dict[PhaseType, SampledDataset] + _tokenizer: Tokenizer | None + _distributed: Distributed + _cache_directory: pathlib.Path | None + _samples_per_phase: dict[PhaseType, int] + _phases: typing.ClassVar[tuple[PhaseType, ...]] = (PhaseType.training, PhaseType.validation, PhaseType.test) + + def __init__( + self, + config: GPTDataConfig, + distributed_config: DistributedConfig, + vocab_size: int, + max_sequence_length: int, + ): + """ + Create the data and gather some basic information on the dataset(s). + Should be `setup` before use. 
+ """ + self._config = config.validate() + self._distributed_config = distributed_config.validate() + self._vocab_size = vocab_size + self._max_sequence_length = max_sequence_length + # Assert.eq(len(self._config.split), len(self._phases)) + # self._phase_split = { + # phase: ratio for phase, ratio in zip(self._phases, normalize_probs(self._config.split)) if ratio > 0 + # } + + # data_base_path = None + # if self._config.format == DatasetSource.file: + # Assert.eq(len(self._config.path), 1) + # data_path = pathlib.Path(self._config.path[0]) + # dataset_defs = json.load(data_path.open("r")) + # data_base_path = data_path.parent + # dataset_prefixes = [dataset_def["prefix"] for dataset_def in dataset_defs["datasets"]] + # dataset_weights = normalize_probs([dataset_def["weight"] for dataset_def in dataset_defs["datasets"]]) + # self._build_and_sample_dataset = self._build_and_sample_gpt_dataset + # elif self._config.format == DatasetSource.list: + # Assert.geq(len(self._config.path), 1) + # if len(self._config.path) == 1: + # dataset_prefixes, dataset_weights = [self._config.path[0].strip()], [1.0] + # else: + # Assert.custom(lambda x: x % 2 == 0, len(self._config.path)) + # dataset_prefixes = [x.strip() for x in self._config.path[1::2]] + # assert len(dataset_prefixes) == len(set(dataset_prefixes)) + # dataset_weights = normalize_probs([float(x) for x in self._config.path[::2]]) + # self._build_and_sample_dataset = self._build_and_sample_gpt_dataset + # elif self._config.format == DatasetSource.sample: + # Assert.eq(len(self._config.path), 1) + # dataset_prefixes, dataset_weights = [self._config.path[0].strip()], [1.0] + # self._build_and_sample_dataset = self._build_and_sample_dummy_dataset + # elif self._config.format == DatasetSource.random: + # Assert.eq(len(self._config.path), 0) + # dataset_prefixes, dataset_weights = [None], [1.0] + # self._build_and_sample_dataset = self._build_and_sample_dummy_dataset + # else: + # raise 
NotImplementedError(self._config.format) + + # dataset_names = [ + # f"dataset_{i}_{'dummy' if prefix is None else prefix.replace('/','__')}" + # for i, prefix in enumerate(dataset_prefixes) + # ] + # self._num_datasets = len(dataset_names) + # self._dataset_prefixes = { + # name: ( + # None + # if prefix is None + # else ( + # pathlib.Path(prefix).resolve() + # if data_base_path is None + # else (pathlib.Path(data_base_path) / prefix).resolve() + # ) + # ) + # for name, prefix in zip(dataset_names, dataset_prefixes) + # } + # self._dataset_weights = {name: weight for name, weight in zip(dataset_names, dataset_weights)} + + def setup(self, distributed: Distributed, samples_per_phase: dict[PhaseType, int]): + """ + Load the datasets, and prepare or load the samplings. + This may take a while and a significant amount of cpu memory. + """ + run = get_run() + # Assert.leq(set(samples_per_phase), set(self._phase_split)) + log_main_rank(f"Preparing {self._num_datasets} datasets. This may take several minutes.") + self._tokenizer = Tokenizer(self._config.tokenizer) if self._config.fim.rate > 0 else None + self._distributed = distributed + self._samples_per_phase = samples_per_phase + # TODO: experiment_directory can be none. + self._cache_directory = run.experiment_directory / "dataset_cache" + + self._sampled_dataset = self._config.dataset.build_split_sampled(self) + + # log_main_rank(f"Preparing {self._num_datasets} datasets. This may take several minutes.") + # self._tokenizer = Tokenizer(self._config.tokenizer) if self._config.fim.rate > 0 else None + # self._distributed = distributed + # self._cache_dir = run.dataset_cache_dir + # self._samples_per_phase = samples_per_phase + # if self._cache_dir is None: + # warnings.warn(f"Using the dataset directory for the index cache.") + + # Build and split datasets. 
+ # self._sampled_datasets = {phase: {} for phase in self._samples_per_phase} + # for i, (name, weight) in enumerate(self._dataset_weights.items()): + # if i % 100 == 0 and i > 0: + # log_main_rank(f"Prepared {i} of {self._num_datasets} datasets.") + # dataset_samples_per_phase = {} + # for phase, samples_per_phase in self._samples_per_phase.items(): + # expected_samples = self._dataset_weights[name] * samples_per_phase + # # Add 5 times the standard deviation (of a binomial distribution) + # # so the probability of sampling more than this amount during blending is negligible. + # dataset_samples_per_phase[phase] = math.ceil( + # expected_samples + # + 5 * math.sqrt(expected_samples * self._dataset_weights[name] * (1 - self._dataset_weights[name])) + # ) + # sampled_datasets = self._build_and_sample_dataset(name, dataset_samples_per_phase) + # for phase, dataset in sampled_datasets.items(): + # self._sampled_datasets[phase][name] = dataset + + # self._blended_datasets = { + # phase: ( + # list(datasets.values())[0] + # if len(datasets) == 1 + # else BlendedDataset( + # list(datasets.values()), + # weights=[self._dataset_weights[name] for name in datasets], + # name=phase.value, + # num_samples=self._samples_per_phase[phase], + # cache_dir=self._cache_dir, + # group=self._distributed.world_group, + # verbose=run.is_main_rank, + # data_sample_warn_time_ms=self._config.data_sample_warn_time_ms, + # ) + # ) + # for phase, datasets in self._sampled_datasets.items() + # } + + def get_iterator( + self, + batch_config: BatchConfig, + phase: PhaseType, + *, + consumed_samples: int, + num_workers: int, + prefetch_factor: int | None = None, + ): + Assert.incl(phase, self._sampled_dataset) + Assert.in_range_incl(batch_config.sequence_length, 1, self._max_sequence_length) + log_main_rank(f"Initializing {phase} data iterator from sample {consumed_samples}...") + return iter( + torch.utils.data.DataLoader( + self._blended_datasets[phase], # noqa + batch_sampler=Sampler( + 
total_samples=len(self._sampled_dataset[phase]), + begin_index=consumed_samples, + micro_batch_size=batch_config.micro_batch_size, + data_rank=self._distributed.config.batch_data_rank, + data_parallel=self._distributed.config.batch_data_parallel, + ), + num_workers=num_workers, + prefetch_factor=prefetch_factor, + pin_memory=True, + multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, + ) + ) + + # def _build_and_sample_gpt_dataset(self, name: str, dataset_samples_per_phase: dict[PhaseType, int]): + # dataset_split = GPTDataset.from_splits( + # name, MMapIndexedDataset(self._dataset_prefixes[name]), self._phase_split + # ) + + +# +# sampled_datasets = {} +# for phase, num_samples in dataset_samples_per_phase.items(): +# if num_samples == 0: +# continue +# sampled_datasets[phase] = GPTSampledDataset( +# dataset_split[phase], +# num_samples=num_samples, +# sequence_length=self._max_sequence_length, +# seed=self._distributed.config.seed, +# group=self._distributed.world_group, +# config=self._config, +# tokenizer=self._tokenizer, +# cache_dir=self._dataset_prefixes[name].parent if self._cache_dir is None else self._cache_dir, +# verbose=self._num_datasets <= 5, +# ) +# return sampled_datasets + +# def _build_and_sample_dummy_dataset(self, name: str, dataset_samples_per_phase: dict[PhaseType, int]): +# return { +# phase: DummyGPTDataset( +# self._dataset_prefixes[name], +# dataset_samples_per_phase[phase], +# self._max_sequence_length, +# self._vocab_size, +# name, +# ) +# for phase in dataset_samples_per_phase +# } diff --git a/fast_llm/data/gpt/dummy.py b/fast_llm/data/gpt/dummy.py new file mode 100644 index 000000000..a43e8bb9b --- /dev/null +++ b/fast_llm/data/gpt/dummy.py @@ -0,0 +1,39 @@ +import pathlib + +import numpy as np + +from fast_llm.data.config import SampledDataset +from fast_llm.engine.config_utils.run import log_main_rank +from fast_llm.utils import Assert + + +class DummyGPTDataset(SampledDataset): + """ + 
A dummy dataset that always returns the same sample, for debugging purposes. + The sample can be purely random, or read from a file to allow reproducing in other runs. + """ + + def __init__( + self, prefix: pathlib.Path | None, num_samples: int, sequence_length: int, vocab_size: int, name: str = "dummy" + ): + self._num_samples = num_samples + if prefix is None: + self._dummy_sample = np.random.randint(0, vocab_size, size=(sequence_length + 1,), dtype=np.int64) + else: + log_main_rank(f"> Loading dummy dataset from file {prefix}") + self._dummy_sample = np.load(prefix, allow_pickle=True)[: sequence_length + 1] + Assert.eq(self._dummy_sample.shape, (sequence_length + 1,)) + Assert.eq(self._dummy_sample.dtype, np.int64) + Assert.lt(self._dummy_sample.max(), vocab_size) + Assert.geq(self._dummy_sample.min(), 0) + self._name = name + + def __len__(self): + return self._num_samples + + def __getitem__(self, idx): + return self._dummy_sample + + @property + def name(self): + return self._name diff --git a/fast_llm/data/mmap.py b/fast_llm/data/gpt/memmap.py similarity index 76% rename from fast_llm/data/mmap.py rename to fast_llm/data/gpt/memmap.py index f18fc08b7..2c595e499 100644 --- a/fast_llm/data/mmap.py +++ b/fast_llm/data/gpt/memmap.py @@ -3,10 +3,11 @@ import numpy as np +from fast_llm.data.gpt.config import GPTMemmapDatasetConfig, GPTRawDataset from fast_llm.utils import Assert, div, padded_cumsum -class MMapIndexedDataset: +class GPTMemmapDataset(GPTRawDataset): """ A memory map dataset, which handles lazy loading of a pre-processed dataset in the Megatron-LM format, i.e. 
a pair of numpy file containing @@ -27,14 +28,14 @@ class MMapIndexedDataset: } _INDEX_HEADER = b"MMIDIDX\x00\x00" - def __init__(self, prefix: pathlib.Path | str): - self._init(prefix) + def __init__(self, config: GPTMemmapDatasetConfig): + self._init(config) - def _init(self, prefix: pathlib.Path | str): + def _init(self, config: GPTMemmapDatasetConfig): super().__init__() - self._prefix = pathlib.Path(prefix) + self._config = config - with self._prefix.with_suffix(".idx").open("rb") as stream: + with self._config.path.with_suffix(".idx").open("rb") as stream: Assert.eq(stream.read(9), self._INDEX_HEADER) Assert.eq(struct.unpack(" 1 and last_epoch_samples < 0.8 * samples_per_epoch - - doc_idx = np.tile(np.arange(self._split_begin, self._split_end, dtype=np.int32), num_epochs) - if separate_last_epoch: - np_rng.shuffle(doc_idx[: -len(self)]) - np_rng.shuffle(doc_idx[-len(self) :]) - else: - np_rng.shuffle(doc_idx) - - assert _extension_available, "Please run `make -C ./fast_llm/csrc/` first." - sample_idx = build_sample_idx( - self._indexed_dataset.sizes, doc_idx, sequence_length, num_epochs, tokens_per_epoch, verbose - ) - - # shuffle-idx. - # -1 is due to data structure used to retrieve the index: - # sample i --> [sample_idx[i], sample_idx[i+1]) - total_size = sample_idx.shape[0] - 1 - # TODO: Isn't the dataset already shuffled above? - shuffle_idx = np.arange( - 0, total_size, dtype=np.int64 if total_size >= (np.iinfo(np.uint32).max - 1) else np.uint32 - ) - if separate_last_epoch: - np_rng.shuffle(shuffle_idx[:main_epochs_samples]) - np_rng.shuffle(shuffle_idx[main_epochs_samples:]) - else: - np_rng.shuffle(shuffle_idx) - - Assert.geq(len(shuffle_idx), num_samples) - # TODO: The doc and sample idx are way bigger than needed when sampling for << 1 epoch. 
- return doc_idx, sample_idx, shuffle_idx[:num_samples] - - class GPTSampledDataset(SampledDataset): """ A GPT dataset augmented with a sampling, i.e., @@ -142,12 +32,12 @@ class GPTSampledDataset(SampledDataset): def __init__( self, - dataset: GPTDataset, + dataset: GPTRawDataset, num_samples: int, sequence_length: int, seed: int, group: ProcessGroup | None, - config: DataConfig, + config: GPTDataConfig, tokenizer: Tokenizer | None, cache_dir: pathlib.Path, verbose: bool = True, @@ -179,7 +69,7 @@ def __init__( ): if verbose: log_main_rank(" > Building the index map on rank 0 ...") - doc_idx, sample_idx, shuffle_idx = self._dataset.sample(num_samples, sequence_length, np_rng, verbose) + doc_idx, sample_idx, shuffle_idx = self._sample(num_samples, sequence_length, np_rng, verbose) np.save(self._doc_idx_filename, doc_idx) np.save(self._sample_idx_filename, sample_idx) np.save(self._shuffle_idx_filename, shuffle_idx) @@ -187,6 +77,52 @@ def __init__( safe_barrier(group, self._dataset.name) self._load_mappings(verbose) + def _sample(self, num_samples: int, sequence_length: int, np_rng: numpy.random.RandomState, verbose: bool): + """ + Create a `GPTSampledDataset` with the requested parameters. + """ + tokens_per_epoch = self._dataset.num_tokens + num_epochs = math.ceil((sequence_length * num_samples + 1) / tokens_per_epoch) + # For the last epoch, decide whether include the entire epoch + # in the global shuffle or not. + # Get the number of samples for the last epoch + main_epochs_samples = ((num_epochs - 1) * tokens_per_epoch - 1) // sequence_length + last_epoch_samples = num_samples - main_epochs_samples + samples_per_epoch = (tokens_per_epoch - 1) // sequence_length + # If we have less than 80% of the samples for the last epoch, separate out the epoch and treat it differently. + # Note: the 80% number is just based on common sense and can be adjusted if needed. 
+ separate_last_epoch = num_epochs > 1 and last_epoch_samples < 0.8 * samples_per_epoch + + doc_idx = np.tile(np.arange(len(self._dataset), dtype=np.int32), num_epochs) + if separate_last_epoch: + np_rng.shuffle(doc_idx[: -len(self._dataset)]) + np_rng.shuffle(doc_idx[-len(self._dataset) :]) + else: + np_rng.shuffle(doc_idx) + + assert _extension_available, "Please run `make -C ./fast_llm/csrc/` first." + sample_idx = build_sample_idx( + self._dataset.document_sizes, doc_idx, sequence_length, num_epochs, tokens_per_epoch, verbose + ) + + # shuffle-idx. + # -1 is due to data structure used to retrieve the index: + # sample i --> [sample_idx[i], sample_idx[i+1]) + total_size = sample_idx.shape[0] - 1 + # TODO: Isn't the dataset already shuffled above? + shuffle_idx = np.arange( + 0, total_size, dtype=np.int64 if total_size >= (np.iinfo(np.uint32).max - 1) else np.uint32 + ) + if separate_last_epoch: + np_rng.shuffle(shuffle_idx[:main_epochs_samples]) + np_rng.shuffle(shuffle_idx[main_epochs_samples:]) + else: + np_rng.shuffle(shuffle_idx) + + Assert.geq(len(shuffle_idx), num_samples) + # TODO: The doc and sample idx are way bigger than needed when sampling for << 1 epoch. + return doc_idx, sample_idx, shuffle_idx[:num_samples] + def __getstate__(self): return ( self._dataset, @@ -257,35 +193,3 @@ def __getitem__(self, idx): @property def name(self): return self._dataset.name - - -class DummyGPTDataset(SampledDataset): - """ - A dummy dataset that always returns the same sample, for debugging purposes. - The sample can be purely random, or read from a file to allow reproducing in other runs. 
- """ - - def __init__( - self, prefix: pathlib.Path | None, num_samples: int, sequence_length: int, vocab_size: int, name: str = "dummy" - ): - self._num_samples = num_samples - if prefix is None: - self._dummy_sample = np.random.randint(0, vocab_size, size=(sequence_length + 1,), dtype=np.int64) - else: - log_main_rank(f"> Loading dummy dataset from file {prefix}") - self._dummy_sample = np.load(prefix, allow_pickle=True)[: sequence_length + 1] - Assert.eq(self._dummy_sample.shape, (sequence_length + 1,)) - Assert.eq(self._dummy_sample.dtype, np.int64) - Assert.lt(self._dummy_sample.max(), vocab_size) - Assert.geq(self._dummy_sample.min(), 0) - self._name = name - - def __len__(self): - return self._num_samples - - def __getitem__(self, idx): - return self._dummy_sample - - @property - def name(self): - return self._name diff --git a/fast_llm/data/gpt/slice.py b/fast_llm/data/gpt/slice.py new file mode 100644 index 000000000..2b761e35d --- /dev/null +++ b/fast_llm/data/gpt/slice.py @@ -0,0 +1,71 @@ +import numpy as np + +from fast_llm.data.gpt.config import GPTRawDataset +from fast_llm.engine.distributed.config import PhaseType +from fast_llm.utils import Assert, padded_cumsum + + +class GPTDatasetSlice(GPTRawDataset): + """ + A GPT dataset, which reads samples from (a split of) a `MMapIndexedDataset` pointing to a GPT dataset. + """ + + def __init__( + self, + name: str, + dataset: GPTRawDataset, + begin: int | None = None, + end: int | None = None, + ): + self._name = name + self._dataset = dataset + self._begin = 0 if begin is None else begin + self._end = len(dataset) if end is None else end + + # Checks + try: + Assert.geq(self._begin, 0) + Assert.in_range_incl(self._end, self._begin + 1, len(dataset)) + except Exception as e: + raise AssertionError(f"Invalid document indices for dataset {name} with length {len(dataset)}") from e + + def __getitem__(self, index: int): + """ + Get the sample (document) with the given index (in the split dataset). 
+ """ + return self.get(index) + + def get(self, document: int, offset: int = 0, length: int | None = None): + """ + Get the sample (document) with the given index (in the dataset slice), + optionally sub-sampled to a specific offset (starting point) and maximum length + (end = min(offset + length, sample_length). + """ + return self._dataset.get(document + self._begin, offset, length) + + @property + def num_documents(self): + return self._end - self._begin + + @property + def num_tokens(self): + return np.sum(self._dataset.document_sizes[self._begin : self._end]) + + @property + def name(self): + return self._name + + @classmethod + def from_splits(cls, dataset: GPTRawDataset, phase_split: dict[PhaseType, float]): + """ + Create a set of GPT datasets from a MMapIndexedDataset, + each containing approximately the requested proportion of the total tokens. + """ + split_probs = list(phase_split.values()) + Assert.eq(sum(split_probs), 1) + num_documents = dataset.num_documents + splits = [round(x) for x in padded_cumsum(split_probs) * num_documents] + return { + phase: GPTDatasetSlice(f"{dataset.name}_{phase.value}", dataset, split_begin, split_end) + for phase, split_begin, split_end in zip(phase_split, splits[:-1], splits[1:]) + } diff --git a/fast_llm/data/sampler.py b/fast_llm/data/sampler.py new file mode 100644 index 000000000..609134310 --- /dev/null +++ b/fast_llm/data/sampler.py @@ -0,0 +1,22 @@ +import torch.utils.data + + +class Sampler(torch.utils.data.Sampler): + """ + A distributed sampler generating indices for a `SampledDataset` (i.e., the natural numbers). + To be used as the `batch_sampler` of a `torch.utils.data.DataLoader`. 
+ """ + + def __init__(self, total_samples, begin_index, micro_batch_size, data_rank, data_parallel): + self._total_samples = total_samples + self._begin_index = begin_index + self._batch_size = micro_batch_size * data_parallel + self._start_idx = data_rank * micro_batch_size + self._end_idx = (data_rank + 1) * micro_batch_size + + def __len__(self): + return self._total_samples + + def __iter__(self): + for idx in range(self._begin_index, self._total_samples - self._batch_size + 1, self._batch_size): + yield list(range(idx + self._start_idx, idx + self._end_idx)) diff --git a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index cd32b7d51..a0255ea66 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -151,13 +151,11 @@ def __init__( if self._config.experiment_dir is not None: self._experiment_directory = self._config.experiment_dir.resolve() - self.dataset_cache_dir = self._experiment_directory / "dataset_cache" if self._is_main_rank: (self._experiment_directory / "runs").mkdir(exist_ok=True, parents=True) run = len(list((self._experiment_directory / "runs").iterdir())) (self._experiment_directory / "runs" / str(run)).mkdir() yaml.safe_dump(config_dict, (self._experiment_directory / "config.yaml").open("w")) - self.dataset_cache_dir.mkdir(exist_ok=True) else: run = 0 # Make sure all the workers agree on the run. This also acts as a barrier. 
@@ -167,7 +165,6 @@ def __init__( log_dir = run_dir / "logs" else: _experiment_directory, self._artifact_dir, log_dir = None, None, None - self.dataset_cache_dir = None self.index = None if self._config.structured_logs: diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index 0c3b80b3e..404c6e85c 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -10,7 +10,7 @@ from fast_llm.core.distributed import safe_barrier from fast_llm.data.config import AbstractData -from fast_llm.data.data import Data +from fast_llm.data.gpt.data import GPTData from fast_llm.engine.config_utils.run import Run, is_main_rank, log_main_rank, log_pipeline_parallel_main_rank from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.distributed.distributed import Distributed @@ -112,7 +112,7 @@ def setup(self, distributed: Distributed, run: Run): @abc.abstractmethod def _get_data(self) -> AbstractData: - return Data( + return GPTData( config=self._config.data, distributed_config=self._config.distributed, # TODO: `vocab_size` is not generic. diff --git a/fast_llm/models/custom/config.py b/fast_llm/models/custom/config.py index 0881a6e54..f0f32dc3d 100644 --- a/fast_llm/models/custom/config.py +++ b/fast_llm/models/custom/config.py @@ -1,7 +1,7 @@ import typing from fast_llm.config import FieldUpdate, config_class -from fast_llm.data.config import DataConfig +from fast_llm.data.gpt.config import GPTDataConfig from fast_llm.models.gpt.config import ( GPTArchitectureConfig, GPTBaseModelConfig, @@ -12,7 +12,7 @@ @config_class() -class CustomDataConfig(DataConfig): +class CustomDataConfig(GPTDataConfig): # TODO: If needed, inherit from AbstractDataConfig instead and re-implement everything. 
pass diff --git a/fast_llm/models/custom/data.py b/fast_llm/models/custom/data.py index fd9ee6cfc..567466d9a 100644 --- a/fast_llm/models/custom/data.py +++ b/fast_llm/models/custom/data.py @@ -1,9 +1,9 @@ -from fast_llm.data.data import Data +from fast_llm.data.gpt.data import GPTData from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.models.custom.config import CustomDataConfig -class CustomData(Data): +class CustomData(GPTData): # TODO: If needed, inherit from AbstractData instead and re-implement everything. def __init__( self, diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 5c4d11770..fbe069514 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -1,7 +1,7 @@ import typing from fast_llm.config import Field, FieldHint, FieldUpdate, config_class -from fast_llm.data.config import DataConfig +from fast_llm.data.gpt.config import GPTDataConfig from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig @@ -122,7 +122,7 @@ class PretrainedGPTModelConfig(PretrainedFastLLMModelConfig): @config_class() class GPTTrainerConfig(PretrainedGPTModelConfig, TrainerConfig): - data: DataConfig = FieldUpdate(default_factory=DataConfig) + data: GPTDataConfig = FieldUpdate(default_factory=GPTDataConfig) def _setup(self): super()._setup() diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 568df761b..5d0558b86 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -1,6 +1,6 @@ import logging -from fast_llm.data.data import Data +from fast_llm.data.gpt.data import GPTData from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.training.trainer import Trainer from fast_llm.models.gpt.config import GPTTrainerConfig @@ -16,7 +16,7 @@ class 
GPTTrainer(Trainer): model_class = GPTModel def _get_data(self): - return Data( + return GPTData( config=self._config.data, distributed_config=self._config.distributed, vocab_size=self._config.base_model.vocab_size, diff --git a/tests/common.py b/tests/common.py index aab2655c7..75e003dff 100644 --- a/tests/common.py +++ b/tests/common.py @@ -10,7 +10,7 @@ import pytest import torch -from fast_llm.data.mmap import MMapIndexedDataset +from fast_llm.data.gpt.memmap import MMapIndexedDataset from fast_llm.models.gpt.config import ( MistralGPTHuggingfaceCheckpointFormat, MixtralGPTHuggingfaceCheckpointFormat, diff --git a/tools/concatenate_dataset.py b/tools/concatenate_dataset.py index e0db453e9..b4c4ed1c9 100644 --- a/tools/concatenate_dataset.py +++ b/tools/concatenate_dataset.py @@ -3,7 +3,7 @@ import pathlib from fast_llm.config import Field, config_class -from fast_llm.data.mmap import MMapIndexedDataset +from fast_llm.data.gpt.memmap import MMapIndexedDataset from fast_llm.engine.config_utils.runnable import RunnableConfig logger = logging.getLogger(__name__) From 76596ced3aaeb3b9a9eace6cf2b505c639493b2a Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 8 Nov 2024 10:49:42 -0500 Subject: [PATCH 02/13] Reduce diff --- fast_llm/data/blended.py | 16 +++++++++------- fast_llm/data/config.py | 4 ++-- fast_llm/data/gpt/config.py | 4 ++-- fast_llm/data/gpt/data.py | 13 ++----------- fast_llm/engine/training/config.py | 6 +++--- fast_llm/engine/training/trainer.py | 4 ++-- fast_llm/utils.py | 10 ++++++++++ tests/common.py | 4 ++-- tools/concatenate_dataset.py | 4 ++-- 9 files changed, 34 insertions(+), 31 deletions(-) diff --git a/fast_llm/data/blended.py b/fast_llm/data/blended.py index 6373d3a79..405a6adf7 100644 --- a/fast_llm/data/blended.py +++ b/fast_llm/data/blended.py @@ -1,7 +1,6 @@ import logging import pathlib import time -import typing import numpy as np @@ -35,8 +34,8 @@ def __init__( *, name: str = "blended", num_samples: int, - cache_dir: 
pathlib.Path | None = None, - group: typing.Optional["ProcessGroup"] = None, + cache_directory: pathlib.Path | None = None, + group: ProcessGroup | None = None, verbose: bool = True, data_sample_warn_time_ms: float = 1000, ): @@ -46,12 +45,12 @@ def __init__( self._weights = weights self._data_sample_warn_time_ms = data_sample_warn_time_ms - if cache_dir is None: + if cache_directory is None: self._dataset_idx_filename, self._sample_idx_filename = None, None self._dataset_index, self._sample_index = self._build_blending_indices(verbose and len(datasets) <= 20) else: - self._dataset_idx_filename = cache_dir / (self._name + "_blending_dataset_idx.npy") - self._sample_idx_filename = cache_dir / (self._name + "_blending_sample_idx.npy") + self._dataset_idx_filename = cache_directory / (self._name + "_blending_dataset_idx.npy") + self._sample_idx_filename = cache_directory / (self._name + "_blending_sample_idx.npy") # Build the indexed mapping if it doesn't exist. # TODO: This only works if the dataset location is accessible by all job. @@ -59,6 +58,7 @@ def __init__( self._dataset_idx_filename.is_file() and self._sample_idx_filename.is_file() ): dataset_index, sample_index = self._build_blending_indices(verbose and len(datasets) <= 20) + cache_directory.mkdir(exist_ok=True, parents=True) np.save(self._dataset_idx_filename, dataset_index) np.save(self._sample_idx_filename, sample_index) @@ -103,7 +103,9 @@ def __len__(self): return self._num_samples def _build_blending_indices(self, verbose: bool): - assert _extension_available, "Please run `make -C ./fast_llm/csrc/` first." + assert _extension_available, ( + "The C++ extension for dataset blending is missing." " Please make sure Fast-LLM is installed correctly." 
+ ) Assert.lt(len(self._datasets), 32767) dataset_index = np.zeros(self._num_samples, dtype=np.int16) dataset_sample_index = np.zeros(self._num_samples, dtype=np.int64) diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index e6f33eebf..09fe6eb63 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -123,11 +123,11 @@ class TokenizerConfig(Config): @config_class() -class AbstractDataConfig(Config): +class DataConfig(Config): _abstract = True -class AbstractData(abc.ABC): +class Data(abc.ABC): # TODO: Improve interface @abc.abstractmethod def setup(self, distributed: "Distributed", samples_per_phase: dict[PhaseType, int]): diff --git a/fast_llm/data/gpt/config.py b/fast_llm/data/gpt/config.py index 118670da8..0220b80b1 100644 --- a/fast_llm/data/gpt/config.py +++ b/fast_llm/data/gpt/config.py @@ -6,7 +6,7 @@ from fast_llm.config import Config, Field, FieldHint, check_field, config_class from fast_llm.data.config import ( - AbstractDataConfig, + DataConfig, FimConfig, MultiprocessingContext, RawDataset, @@ -373,7 +373,7 @@ def _build_split_sampled(self, data: "GPTData") -> dict[PhaseType, SampledDatase @config_class() -class GPTDataConfig(AbstractDataConfig): +class GPTDataConfig(DataConfig): """ Configuration for the dataset(s), split and sampling. Currently hard-coded to a GPT dataset. 
diff --git a/fast_llm/data/gpt/data.py b/fast_llm/data/gpt/data.py index 1158cc73e..0b090e849 100644 --- a/fast_llm/data/gpt/data.py +++ b/fast_llm/data/gpt/data.py @@ -2,11 +2,10 @@ import pathlib import typing -import numpy as np import torch import torch.utils.data -from fast_llm.data.config import AbstractData, SampledDataset +from fast_llm.data.config import Data, SampledDataset from fast_llm.data.gpt.config import GPTDataConfig from fast_llm.data.sampler import Sampler from fast_llm.data.tokenizer import Tokenizer @@ -19,15 +18,7 @@ logger = logging.getLogger(__name__) -def normalize_probs(p: list[float]) -> list[float]: - p = np.array(p) - Assert.custom(lambda x: np.all(x >= 0), p) - p_sum = p.sum() - Assert.gt(p_sum, 0) - return (p / p_sum).tolist() - - -class GPTData(AbstractData): +class GPTData(Data): """ A global class for all dataset needs, including loading, splitting, sampling and iteration. Currently hard-coded to a GPT dataset. diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 9e090670f..fe1336152 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -7,7 +7,7 @@ import typing from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none -from fast_llm.data.config import AbstractDataConfig +from fast_llm.data.config import DataConfig from fast_llm.engine.checkpoint.config import ( CheckpointLoadConfig, CheckpointSaveConfig, @@ -330,8 +330,8 @@ class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig): schedule: ScheduleConfig = Field( default_factory=ScheduleConfig, desc="Configuration for the scheduling of each iteration.", hint=FieldHint.core ) - data: AbstractDataConfig = Field( - default_factory=AbstractDataConfig, + data: DataConfig = Field( + default_factory=DataConfig, desc="Configuration for the dataset and model-independent preprocessing.", hint=FieldHint.core, ) diff --git 
a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index 404c6e85c..e4677a4c8 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -9,7 +9,7 @@ import torch from fast_llm.core.distributed import safe_barrier -from fast_llm.data.config import AbstractData +from fast_llm.data.config import Data from fast_llm.data.gpt.data import GPTData from fast_llm.engine.config_utils.run import Run, is_main_rank, log_main_rank, log_pipeline_parallel_main_rank from fast_llm.engine.distributed.config import PhaseType @@ -111,7 +111,7 @@ def setup(self, distributed: Distributed, run: Run): self._data.setup(distributed, self._samples_per_split) @abc.abstractmethod - def _get_data(self) -> AbstractData: + def _get_data(self) -> Data: return GPTData( config=self._config.data, distributed_config=self._config.distributed, diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 4d61dffdc..937aacd8c 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -208,3 +208,13 @@ def log(*message, log_fn: typing.Union[BaseException, typing.Callable] = logger. 
raise log_fn(message) else: return log_fn(message) + + +def normalize_probabilities(p: list[float]) -> list[float]: + import numpy as np + + p = np.array(p) + Assert.custom(lambda x: np.all(x >= 0), p) + p_sum = p.sum() + Assert.gt(p_sum, 0) + return (p / p_sum).tolist() diff --git a/tests/common.py b/tests/common.py index 75e003dff..a944dd826 100644 --- a/tests/common.py +++ b/tests/common.py @@ -10,7 +10,7 @@ import pytest import torch -from fast_llm.data.gpt.memmap import MMapIndexedDataset +from fast_llm.data.gpt.memmap import GPTMemmapDataset from fast_llm.models.gpt.config import ( MistralGPTHuggingfaceCheckpointFormat, MixtralGPTHuggingfaceCheckpointFormat, @@ -186,7 +186,7 @@ def get_test_data(): tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH) documents = [np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % 8192 for document in documents] - MMapIndexedDataset.write_dataset(DATASET_PREFIX, documents) + GPTMemmapDataset.write_dataset(DATASET_PREFIX, documents) def run_test_script( diff --git a/tools/concatenate_dataset.py b/tools/concatenate_dataset.py index b4c4ed1c9..a90e909e8 100644 --- a/tools/concatenate_dataset.py +++ b/tools/concatenate_dataset.py @@ -3,7 +3,7 @@ import pathlib from fast_llm.config import Field, config_class -from fast_llm.data.gpt.memmap import MMapIndexedDataset +from fast_llm.data.gpt.memmap import GPTMemmapDataset from fast_llm.engine.config_utils.runnable import RunnableConfig logger = logging.getLogger(__name__) @@ -31,7 +31,7 @@ def run(self): for path in self.directory.glob("**/*.idx"): prefix = path.with_suffix("") logger.info(str(prefix)) - dataset = MMapIndexedDataset(prefix) + dataset = GPTMemmapDataset(prefix) dataset_dict = { "prefix": str(prefix.relative_to(self.directory)), "num_documents": dataset.num_documents, From 1b9485a0ef62f8d4ea236cf78e4061a9aa6795f2 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 8 Nov 2024 10:52:08 -0500 Subject: [PATCH 03/13] Reduce diff --- 
fast_llm/data/gpt/data.py | 2 +- fast_llm/data/{sampler.py => iterator.py} | 3 ++- fast_llm/models/custom/config.py | 4 ++-- fast_llm/models/gpt/config.py | 4 ++-- fast_llm/models/gpt/conversion.py | 2 +- 5 files changed, 8 insertions(+), 7 deletions(-) rename fast_llm/data/{sampler.py => iterator.py} (90%) diff --git a/fast_llm/data/gpt/data.py b/fast_llm/data/gpt/data.py index 0b090e849..a9c1b3110 100644 --- a/fast_llm/data/gpt/data.py +++ b/fast_llm/data/gpt/data.py @@ -7,7 +7,7 @@ from fast_llm.data.config import Data, SampledDataset from fast_llm.data.gpt.config import GPTDataConfig -from fast_llm.data.sampler import Sampler +from fast_llm.data.iterator import Sampler from fast_llm.data.tokenizer import Tokenizer from fast_llm.engine.config_utils.run import get_run, log_main_rank from fast_llm.engine.distributed.config import DistributedConfig, PhaseType diff --git a/fast_llm/data/sampler.py b/fast_llm/data/iterator.py similarity index 90% rename from fast_llm/data/sampler.py rename to fast_llm/data/iterator.py index 609134310..8a8fdcd25 100644 --- a/fast_llm/data/sampler.py +++ b/fast_llm/data/iterator.py @@ -1,13 +1,14 @@ import torch.utils.data -class Sampler(torch.utils.data.Sampler): +class SampledDatasetIterator(torch.utils.data.Sampler): """ A distributed sampler generating indices for a `SampledDataset` (i.e., the natural numbers). To be used as the `batch_sampler` of a `torch.utils.data.DataLoader`. 
""" def __init__(self, total_samples, begin_index, micro_batch_size, data_rank, data_parallel): + super().__init__() self._total_samples = total_samples self._begin_index = begin_index self._batch_size = micro_batch_size * data_parallel diff --git a/fast_llm/models/custom/config.py b/fast_llm/models/custom/config.py index f0f32dc3d..8f965b8f8 100644 --- a/fast_llm/models/custom/config.py +++ b/fast_llm/models/custom/config.py @@ -1,7 +1,7 @@ import typing from fast_llm.config import FieldUpdate, config_class -from fast_llm.data.gpt.config import GPTDataConfig +from fast_llm.data.gpt.config import DataConfig from fast_llm.models.gpt.config import ( GPTArchitectureConfig, GPTBaseModelConfig, @@ -12,7 +12,7 @@ @config_class() -class CustomDataConfig(GPTDataConfig): +class CustomDataConfig(DataConfig): # TODO: If needed, inherit from AbstractDataConfig instead and re-implement everything. pass diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index fbe069514..467bd63e5 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -1,7 +1,7 @@ import typing from fast_llm.config import Field, FieldHint, FieldUpdate, config_class -from fast_llm.data.gpt.config import GPTDataConfig +from fast_llm.data.gpt.config import DataConfig from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig @@ -122,7 +122,7 @@ class PretrainedGPTModelConfig(PretrainedFastLLMModelConfig): @config_class() class GPTTrainerConfig(PretrainedGPTModelConfig, TrainerConfig): - data: GPTDataConfig = FieldUpdate(default_factory=GPTDataConfig) + data: DataConfig = FieldUpdate(default_factory=DataConfig) def _setup(self): super()._setup() diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 257a9edfa..eea98b8b6 100644 --- 
a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -258,7 +258,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: # TODO: Llama supports biases ConstantExportParamConverter(None, "attention_bias", False), ConstantExportParamConverter(None, "mlp_bias", False), - ConstantExportParamConverter(None, "rope_scaling", False), + ConstantExportParamConverter(None, "rope_scaling", None), ] def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str): From 6e43321394360e4389fba0060d770558290c031d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 8 Nov 2024 11:04:31 -0500 Subject: [PATCH 04/13] fix --- setup.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 33f455f5b..8d5efbb9c 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,13 @@ import sys -import pybind11 -import setuptools +try: + import pybind11 + import setuptools +except ImportError: + raise ImportError( + "Could not import third party module during setup." + " Please make sure it is installed before installing Fast-LLM, and use `--no-build-isolation" + ) # Minimum setuptools version required to parse setup.cfg metadata. 
_SETUPTOOLS_MIN_VERSION = "30.3" From 363458030a280ab4eef889aefd0d80450f25a189 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 8 Nov 2024 15:33:08 -0500 Subject: [PATCH 05/13] fix --- fast_llm/data/config.py | 25 +++------ fast_llm/data/gpt/config.py | 59 ++++----------------- fast_llm/data/gpt/data.py | 4 +- fast_llm/data/gpt/dataset.py | 99 ++++++++++++++++++++++++++++++++++++ fast_llm/data/gpt/memmap.py | 5 +- fast_llm/data/gpt/sampled.py | 5 +- fast_llm/data/gpt/slice.py | 8 +-- 7 files changed, 129 insertions(+), 76 deletions(-) create mode 100644 fast_llm/data/gpt/dataset.py diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index 09fe6eb63..836a6d17b 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -151,14 +151,6 @@ class Dataset(abc.ABC): A generic dataset class compatible with torch.utils.data.Dataset but with a slightly different signature. """ - @abc.abstractmethod - def __getitem__(self, index: int): - pass - - @abc.abstractmethod - def __len__(self): - pass - @property @abc.abstractmethod def name(self): @@ -167,17 +159,16 @@ def name(self): """ -class RawDataset(Dataset): # noqa - """ - A raw dataset class containing a list of unsampled, unprocessed samples, i.e., matching what is stored on disk. - (Excluding off-line processing prior to training.) - Functionally identical to a `Dataset`, but renamed for clarity. - """ - - class SampledDataset(Dataset): # noqa """ A sampled dataset class containing a prepared list of samples to be indexed sequentially (as-is) during training. (See the `Sampler` class below.) - Functionally identical to a `Dataset`, but renamed for clarity. 
""" + + @abc.abstractmethod + def __getitem__(self, index: int): + pass + + @abc.abstractmethod + def __len__(self): + pass diff --git a/fast_llm/data/gpt/config.py b/fast_llm/data/gpt/config.py index 0220b80b1..c56d778e9 100644 --- a/fast_llm/data/gpt/config.py +++ b/fast_llm/data/gpt/config.py @@ -1,56 +1,20 @@ -import abc import enum import logging import pathlib import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class -from fast_llm.data.config import ( - DataConfig, - FimConfig, - MultiprocessingContext, - RawDataset, - SampledDataset, - TokenizerConfig, -) +from fast_llm.data.config import DataConfig, FimConfig, MultiprocessingContext, SampledDataset, TokenizerConfig from fast_llm.engine.distributed.config import PhaseType from fast_llm.utils import Assert if typing.TYPE_CHECKING: from fast_llm.data.gpt.data import GPTData + from fast_llm.data.gpt.dataset import GPTIndexedDataset logger = logging.getLogger(__name__) -class GPTRawDataset(RawDataset): - def __len__(self): - return self.num_documents - - def get(self, document: int, offset: int = 0, length: int | None = None): - pass - - def __getitem__(self, index: int): - """ - Get the sample (document) with the given index (in the split dataset). 
- """ - return self.get(index) - - @property - @abc.abstractmethod - def num_documents(self): - pass - - @property - @abc.abstractmethod - def num_tokens(self): - pass - - @property - @abc.abstractmethod - def document_sizes(self): - pass - - class GPTDatasetConfigType(str, enum.Enum): split = "split" splits = "splits" @@ -118,31 +82,28 @@ def build_unsplit_sampled(self, data: "GPTData") -> SampledDataset: else: return self._sample(self.build_unsplit_unsampled(), data) - def build_split_unsampled(self) -> dict[PhaseType, GPTRawDataset]: + def build_split_unsampled(self) -> dict[PhaseType, "GPTIndexedDataset"]: assert not self.sampled if self.split: return self._build_split_unsampled() else: return {PhaseType.training: self.build_unsplit_unsampled()} - def build_unsplit_unsampled(self) -> GPTRawDataset: + def build_unsplit_unsampled(self) -> "GPTIndexedDataset": assert not self.split assert not self.sampled return self._build_unsplit_unsampled() - def _build(self) -> GPTRawDataset | SampledDataset | dict[PhaseType, GPTRawDataset | SampledDataset]: - raise NotImplementedError() - def _build_split_sampled(self, data: "GPTData") -> dict[PhaseType, SampledDataset]: raise NotImplementedError() def _build_unsplit_sampled(self, data: "GPTData") -> SampledDataset: raise NotImplementedError() - def _build_split_unsampled(self) -> dict[PhaseType, GPTRawDataset]: + def _build_split_unsampled(self) -> dict[PhaseType, "GPTIndexedDataset"]: raise NotImplementedError() - def _build_unsplit_unsampled(self) -> GPTRawDataset: + def _build_unsplit_unsampled(self) -> "GPTIndexedDataset": raise NotImplementedError() @property @@ -171,7 +132,7 @@ def split(self) -> bool: def sampled(self) -> bool: return False - def _build_unsplit_unsampled(self) -> GPTRawDataset: + def _build_unsplit_unsampled(self) -> "GPTIndexedDataset": from fast_llm.data.gpt.memmap import GPTMemmapDataset return GPTMemmapDataset(self) @@ -221,7 +182,7 @@ def split(self) -> bool: def sampled(self) -> bool: return 
False - def _build_unsplit_unsampled(self) -> GPTRawDataset: + def _build_unsplit_unsampled(self) -> "GPTIndexedDataset": from fast_llm.data.gpt.concatenated import GPTConcatenatedDataset return GPTConcatenatedDataset(self, [dataset.build_unsplit_unsampled() for dataset in self.datasets]) @@ -271,7 +232,7 @@ def sampled(self) -> bool: def split(self) -> bool: return True - def _build_split_unsampled(self) -> dict[PhaseType, GPTRawDataset]: + def _build_split_unsampled(self) -> dict[PhaseType, "GPTIndexedDataset"]: from fast_llm.data.gpt.slice import GPTDatasetSlice return GPTDatasetSlice.from_splits(self.dataset.build_unsplit_unsampled(), self.ratios) @@ -320,7 +281,7 @@ def sampled(self) -> bool: def _build_split_sampled(self, data: "GPTData") -> dict[PhaseType, SampledDataset]: return {phase: dataset.build_unsplit_sampled(data) for phase, dataset in self.datasets.items()} - def _build_split_unsampled(self) -> dict[PhaseType, GPTRawDataset]: + def _build_split_unsampled(self) -> dict[PhaseType, "GPTIndexedDataset"]: return {phase: dataset.build_unsplit_unsampled() for phase, dataset in self.datasets.items()} @property diff --git a/fast_llm/data/gpt/data.py b/fast_llm/data/gpt/data.py index a9c1b3110..f03dfdd18 100644 --- a/fast_llm/data/gpt/data.py +++ b/fast_llm/data/gpt/data.py @@ -7,7 +7,7 @@ from fast_llm.data.config import Data, SampledDataset from fast_llm.data.gpt.config import GPTDataConfig -from fast_llm.data.iterator import Sampler +from fast_llm.data.iterator import SampledDatasetIterator from fast_llm.data.tokenizer import Tokenizer from fast_llm.engine.config_utils.run import get_run, log_main_rank from fast_llm.engine.distributed.config import DistributedConfig, PhaseType @@ -176,7 +176,7 @@ def get_iterator( return iter( torch.utils.data.DataLoader( self._blended_datasets[phase], # noqa - batch_sampler=Sampler( + batch_sampler=SampledDatasetIterator( total_samples=len(self._sampled_dataset[phase]), begin_index=consumed_samples, 
micro_batch_size=batch_config.micro_batch_size, diff --git a/fast_llm/data/gpt/dataset.py b/fast_llm/data/gpt/dataset.py new file mode 100644 index 000000000..010a3f6f3 --- /dev/null +++ b/fast_llm/data/gpt/dataset.py @@ -0,0 +1,99 @@ +import abc +import math + +import numpy as np +import numpy.random + +from fast_llm.data.config import Dataset +from fast_llm.utils import Assert + +try: + from fast_llm.csrc.data import build_sample_idx # noqa + + _extension_available = True +except ImportError: + _extension_available = False + + +class GPTIndexedDataset(Dataset): + """ + A GPT dataset containing a list of unsampled, unprocessed samples. + TODO: Move sampling responsibility here? + """ + + def get(self, document: int, offset: int = 0, length: int | None = None): + pass + + @property + def num_documents(self) -> int: + """ + Number of documents in the dataset. + Can be calculated from document sizes but may be overridden if there is a better method. + """ + return len(self.get_document_sizes()) + + @property + def num_tokens(self) -> int: + """ + Number of tokens in the dataset. + Can be calculated from document sizes but may be overridden if there is a better method. + """ + return self.get_document_sizes().sum() + + @abc.abstractmethod + def get_document_sizes(self) -> "np.ndarray": + """ + The size of each document in the dataset. + The resulting array could be very large, so this method should be called cautiously, + and derived classes should try to avoid holding the whole array im memory. + """ + + def sample(self, num_samples: int, sequence_length: int, np_rng: numpy.random.RandomState, verbose: bool): + """ + Create a `GPTSampledDataset` with the requested parameters. 
+ """ + document_sizes = self.get_document_sizes() + num_documents = len(document_sizes) + num_tokens = document_sizes.sum() + + num_epochs = math.ceil((sequence_length * num_samples + 1) / num_tokens) + # For the last epoch, decide whether include the entire epoch + # in the global shuffle or not. + # Get the number of samples for the last epoch + main_epochs_samples = ((num_epochs - 1) * num_tokens - 1) // sequence_length + last_epoch_samples = num_samples - main_epochs_samples + samples_per_epoch = (num_tokens - 1) // sequence_length + # If we have less than 80% of the samples for the last epoch, separate out the epoch and treat it differently. + # Note: the 80% number is just based on common sense and can be adjusted if needed. + separate_last_epoch = num_epochs > 1 and last_epoch_samples < 0.8 * samples_per_epoch + + doc_idx = np.tile(np.arange(num_documents, dtype=np.int32), num_epochs) + if separate_last_epoch: + np_rng.shuffle(doc_idx[:-num_documents]) + np_rng.shuffle(doc_idx[-num_documents:]) + else: + np_rng.shuffle(doc_idx) + + assert _extension_available, ( + "The C++ extension for dataset sampling is missing." " Please make sure Fast-LLM is installed correctly." + ) + + sample_idx = build_sample_idx(document_sizes, doc_idx, sequence_length, num_epochs, num_tokens, verbose) + + # shuffle-idx. + # -1 is due to data structure used to retrieve the index: + # sample i --> [sample_idx[i], sample_idx[i+1]) + total_size = sample_idx.shape[0] - 1 + # TODO: Isn't the dataset already shuffled above? + shuffle_idx = np.arange( + 0, total_size, dtype=np.int64 if total_size >= (np.iinfo(np.uint32).max - 1) else np.uint32 + ) + if separate_last_epoch: + np_rng.shuffle(shuffle_idx[:main_epochs_samples]) + np_rng.shuffle(shuffle_idx[main_epochs_samples:]) + else: + np_rng.shuffle(shuffle_idx) + + Assert.geq(len(shuffle_idx), num_samples) + # TODO: The doc and sample idx are way bigger than needed when sampling for << 1 epoch. 
+ return doc_idx, sample_idx, shuffle_idx[:num_samples] diff --git a/fast_llm/data/gpt/memmap.py b/fast_llm/data/gpt/memmap.py index 2c595e499..272508866 100644 --- a/fast_llm/data/gpt/memmap.py +++ b/fast_llm/data/gpt/memmap.py @@ -3,11 +3,12 @@ import numpy as np -from fast_llm.data.gpt.config import GPTMemmapDatasetConfig, GPTRawDataset +from fast_llm.data.gpt.config import GPTMemmapDatasetConfig +from fast_llm.data.gpt.dataset import GPTIndexedDataset from fast_llm.utils import Assert, div, padded_cumsum -class GPTMemmapDataset(GPTRawDataset): +class GPTMemmapDataset(GPTIndexedDataset): """ A memory map dataset, which handles lazy loading of a pre-processed dataset in the Megatron-LM format, i.e. a pair of numpy file containing diff --git a/fast_llm/data/gpt/sampled.py b/fast_llm/data/gpt/sampled.py index f4867885b..f7e8d1390 100644 --- a/fast_llm/data/gpt/sampled.py +++ b/fast_llm/data/gpt/sampled.py @@ -8,7 +8,8 @@ from fast_llm.core.distributed import safe_barrier from fast_llm.data.config import SampledDataset from fast_llm.data.fim import Fim -from fast_llm.data.gpt.config import GPTDataConfig, GPTRawDataset +from fast_llm.data.gpt.config import GPTDataConfig +from fast_llm.data.gpt.dataset import GPTIndexedDataset from fast_llm.data.tokenizer import Tokenizer from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.distributed.config import MAX_SEED @@ -32,7 +33,7 @@ class GPTSampledDataset(SampledDataset): def __init__( self, - dataset: GPTRawDataset, + dataset: GPTIndexedDataset, num_samples: int, sequence_length: int, seed: int, diff --git a/fast_llm/data/gpt/slice.py b/fast_llm/data/gpt/slice.py index 2b761e35d..d7ee939e6 100644 --- a/fast_llm/data/gpt/slice.py +++ b/fast_llm/data/gpt/slice.py @@ -1,11 +1,11 @@ import numpy as np -from fast_llm.data.gpt.config import GPTRawDataset +from fast_llm.data.gpt.dataset import GPTIndexedDataset from fast_llm.engine.distributed.config import PhaseType from fast_llm.utils import Assert, 
padded_cumsum -class GPTDatasetSlice(GPTRawDataset): +class GPTDatasetSlice(GPTIndexedDataset): """ A GPT dataset, which reads samples from (a split of) a `MMapIndexedDataset` pointing to a GPT dataset. """ @@ -13,7 +13,7 @@ class GPTDatasetSlice(GPTRawDataset): def __init__( self, name: str, - dataset: GPTRawDataset, + dataset: GPTIndexedDataset, begin: int | None = None, end: int | None = None, ): @@ -56,7 +56,7 @@ def name(self): return self._name @classmethod - def from_splits(cls, dataset: GPTRawDataset, phase_split: dict[PhaseType, float]): + def from_splits(cls, dataset: GPTIndexedDataset, phase_split: dict[PhaseType, float]): """ Create a set of GPT datasets from a MMapIndexedDataset, each containing approximately the requested proportion of the total tokens. From 9b1ad5577d17b3cd1caf410ddcb8d9a19c7abeff Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 8 Nov 2024 15:35:08 -0500 Subject: [PATCH 06/13] stuff --- fast_llm/data/gpt/slice.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/fast_llm/data/gpt/slice.py b/fast_llm/data/gpt/slice.py index d7ee939e6..cc81736ca 100644 --- a/fast_llm/data/gpt/slice.py +++ b/fast_llm/data/gpt/slice.py @@ -1,8 +1,6 @@ -import numpy as np - from fast_llm.data.gpt.dataset import GPTIndexedDataset from fast_llm.engine.distributed.config import PhaseType -from fast_llm.utils import Assert, padded_cumsum +from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum class GPTDatasetSlice(GPTIndexedDataset): @@ -20,14 +18,15 @@ def __init__( self._name = name self._dataset = dataset self._begin = 0 if begin is None else begin - self._end = len(dataset) if end is None else end + dataset_documents = dataset.num_documents + self._end = dataset_documents if end is None else end # Checks try: Assert.geq(self._begin, 0) - Assert.in_range_incl(self._end, self._begin + 1, len(dataset)) + Assert.in_range_incl(self._end, self._begin + 1, dataset_documents) except Exception 
as e: - raise AssertionError(f"Invalid document indices for dataset {name} with length {len(dataset)}") from e + raise AssertionError(f"Invalid document indices for dataset {name} with length {dataset_documents}") from e def __getitem__(self, index: int): """ @@ -47,9 +46,9 @@ def get(self, document: int, offset: int = 0, length: int | None = None): def num_documents(self): return self._end - self._begin - @property - def num_tokens(self): - return np.sum(self._dataset.document_sizes[self._begin : self._end]) + def get_document_sizes(self): + # TODO: This can be really big. + return self._dataset.get_document_sizes()[self._begin : self._end] @property def name(self): @@ -61,10 +60,8 @@ def from_splits(cls, dataset: GPTIndexedDataset, phase_split: dict[PhaseType, fl Create a set of GPT datasets from a MMapIndexedDataset, each containing approximately the requested proportion of the total tokens. """ - split_probs = list(phase_split.values()) - Assert.eq(sum(split_probs), 1) - num_documents = dataset.num_documents - splits = [round(x) for x in padded_cumsum(split_probs) * num_documents] + probabilities = normalize_probabilities(list(phase_split.values())) + splits = [round(x) for x in padded_cumsum(probabilities) * dataset.num_documents] return { phase: GPTDatasetSlice(f"{dataset.name}_{phase.value}", dataset, split_begin, split_end) for phase, split_begin, split_end in zip(phase_split, splits[:-1], splits[1:]) From f63ef0b02a0d0e6d414aa7dc9ce1d384916c567d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 8 Nov 2024 15:58:30 -0500 Subject: [PATCH 07/13] stuff --- fast_llm/data/gpt/config.py | 79 ++++++++++++++++++++++++---- fast_llm/data/gpt/dataset.py | 99 ------------------------------------ fast_llm/data/gpt/memmap.py | 3 +- fast_llm/data/gpt/sampled.py | 41 ++++++++------- fast_llm/data/gpt/slice.py | 2 +- 5 files changed, 93 insertions(+), 131 deletions(-) diff --git a/fast_llm/data/gpt/config.py b/fast_llm/data/gpt/config.py index c56d778e9..e5447a8dd 
100644 --- a/fast_llm/data/gpt/config.py +++ b/fast_llm/data/gpt/config.py @@ -1,16 +1,25 @@ +import abc import enum import logging import pathlib import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class -from fast_llm.data.config import DataConfig, FimConfig, MultiprocessingContext, SampledDataset, TokenizerConfig +from fast_llm.data.config import ( + DataConfig, + Dataset, + FimConfig, + MultiprocessingContext, + SampledDataset, + TokenizerConfig, +) from fast_llm.engine.distributed.config import PhaseType from fast_llm.utils import Assert if typing.TYPE_CHECKING: + import numpy as np + from fast_llm.data.gpt.data import GPTData - from fast_llm.data.gpt.dataset import GPTIndexedDataset logger = logging.getLogger(__name__) @@ -23,6 +32,56 @@ class GPTDatasetConfigType(str, enum.Enum): memmap = "random" +class GPTDataset(Dataset): + @abc.abstractmethod + def sample( + self, + data: "GPTData", + num_samples: int, + sequence_length: int, + np_rng: "np.random.RandomState", + verbose: bool, + ): + pass + + +class GPTIndexedDataset(Dataset): + """ + A GPT dataset containing a list of unsampled, unprocessed samples. + TODO: Move sampling responsibility here? + """ + + def get(self, document: int, offset: int = 0, length: int | None = None): + pass + + @property + def num_documents(self) -> int: + """ + Number of documents in the dataset. + Can be calculated from document sizes but may be overridden if there is a better method. + """ + return len(self.get_document_sizes()) + + @property + def num_tokens(self) -> int: + """ + Number of tokens in the dataset. + Can be calculated from document sizes but may be overridden if there is a better method. + """ + return self.get_document_sizes().sum() + + @abc.abstractmethod + def get_document_sizes(self) -> "np.ndarray": + """ + The size of each document in the dataset. 
+ The resulting array could be very large, so this method should be called cautiously, + and derived classes should try to avoid holding the whole array im memory. + """ + + def sample(self, num_samples: int, sequence_length: int, np_rng: "np.random.RandomState", verbose: bool): + return + + @config_class() class GPTDatasetConfig(Config): @@ -82,14 +141,14 @@ def build_unsplit_sampled(self, data: "GPTData") -> SampledDataset: else: return self._sample(self.build_unsplit_unsampled(), data) - def build_split_unsampled(self) -> dict[PhaseType, "GPTIndexedDataset"]: + def build_split_unsampled(self) -> dict[PhaseType, GPTIndexedDataset]: assert not self.sampled if self.split: return self._build_split_unsampled() else: return {PhaseType.training: self.build_unsplit_unsampled()} - def build_unsplit_unsampled(self) -> "GPTIndexedDataset": + def build_unsplit_unsampled(self) -> GPTIndexedDataset: assert not self.split assert not self.sampled return self._build_unsplit_unsampled() @@ -100,10 +159,10 @@ def _build_split_sampled(self, data: "GPTData") -> dict[PhaseType, SampledDatase def _build_unsplit_sampled(self, data: "GPTData") -> SampledDataset: raise NotImplementedError() - def _build_split_unsampled(self) -> dict[PhaseType, "GPTIndexedDataset"]: + def _build_split_unsampled(self) -> dict[PhaseType, GPTIndexedDataset]: raise NotImplementedError() - def _build_unsplit_unsampled(self) -> "GPTIndexedDataset": + def _build_unsplit_unsampled(self) -> GPTIndexedDataset: raise NotImplementedError() @property @@ -132,7 +191,7 @@ def split(self) -> bool: def sampled(self) -> bool: return False - def _build_unsplit_unsampled(self) -> "GPTIndexedDataset": + def _build_unsplit_unsampled(self) -> GPTIndexedDataset: from fast_llm.data.gpt.memmap import GPTMemmapDataset return GPTMemmapDataset(self) @@ -182,7 +241,7 @@ def split(self) -> bool: def sampled(self) -> bool: return False - def _build_unsplit_unsampled(self) -> "GPTIndexedDataset": + def _build_unsplit_unsampled(self) -> 
GPTIndexedDataset: from fast_llm.data.gpt.concatenated import GPTConcatenatedDataset return GPTConcatenatedDataset(self, [dataset.build_unsplit_unsampled() for dataset in self.datasets]) @@ -232,7 +291,7 @@ def sampled(self) -> bool: def split(self) -> bool: return True - def _build_split_unsampled(self) -> dict[PhaseType, "GPTIndexedDataset"]: + def _build_split_unsampled(self) -> dict[PhaseType, GPTIndexedDataset]: from fast_llm.data.gpt.slice import GPTDatasetSlice return GPTDatasetSlice.from_splits(self.dataset.build_unsplit_unsampled(), self.ratios) @@ -281,7 +340,7 @@ def sampled(self) -> bool: def _build_split_sampled(self, data: "GPTData") -> dict[PhaseType, SampledDataset]: return {phase: dataset.build_unsplit_sampled(data) for phase, dataset in self.datasets.items()} - def _build_split_unsampled(self) -> dict[PhaseType, "GPTIndexedDataset"]: + def _build_split_unsampled(self) -> dict[PhaseType, GPTIndexedDataset]: return {phase: dataset.build_unsplit_unsampled() for phase, dataset in self.datasets.items()} @property diff --git a/fast_llm/data/gpt/dataset.py b/fast_llm/data/gpt/dataset.py index 010a3f6f3..e69de29bb 100644 --- a/fast_llm/data/gpt/dataset.py +++ b/fast_llm/data/gpt/dataset.py @@ -1,99 +0,0 @@ -import abc -import math - -import numpy as np -import numpy.random - -from fast_llm.data.config import Dataset -from fast_llm.utils import Assert - -try: - from fast_llm.csrc.data import build_sample_idx # noqa - - _extension_available = True -except ImportError: - _extension_available = False - - -class GPTIndexedDataset(Dataset): - """ - A GPT dataset containing a list of unsampled, unprocessed samples. - TODO: Move sampling responsibility here? - """ - - def get(self, document: int, offset: int = 0, length: int | None = None): - pass - - @property - def num_documents(self) -> int: - """ - Number of documents in the dataset. - Can be calculated from document sizes but may be overridden if there is a better method. 
- """ - return len(self.get_document_sizes()) - - @property - def num_tokens(self) -> int: - """ - Number of tokens in the dataset. - Can be calculated from document sizes but may be overridden if there is a better method. - """ - return self.get_document_sizes().sum() - - @abc.abstractmethod - def get_document_sizes(self) -> "np.ndarray": - """ - The size of each document in the dataset. - The resulting array could be very large, so this method should be called cautiously, - and derived classes should try to avoid holding the whole array im memory. - """ - - def sample(self, num_samples: int, sequence_length: int, np_rng: numpy.random.RandomState, verbose: bool): - """ - Create a `GPTSampledDataset` with the requested parameters. - """ - document_sizes = self.get_document_sizes() - num_documents = len(document_sizes) - num_tokens = document_sizes.sum() - - num_epochs = math.ceil((sequence_length * num_samples + 1) / num_tokens) - # For the last epoch, decide whether include the entire epoch - # in the global shuffle or not. - # Get the number of samples for the last epoch - main_epochs_samples = ((num_epochs - 1) * num_tokens - 1) // sequence_length - last_epoch_samples = num_samples - main_epochs_samples - samples_per_epoch = (num_tokens - 1) // sequence_length - # If we have less than 80% of the samples for the last epoch, separate out the epoch and treat it differently. - # Note: the 80% number is just based on common sense and can be adjusted if needed. - separate_last_epoch = num_epochs > 1 and last_epoch_samples < 0.8 * samples_per_epoch - - doc_idx = np.tile(np.arange(num_documents, dtype=np.int32), num_epochs) - if separate_last_epoch: - np_rng.shuffle(doc_idx[:-num_documents]) - np_rng.shuffle(doc_idx[-num_documents:]) - else: - np_rng.shuffle(doc_idx) - - assert _extension_available, ( - "The C++ extension for dataset sampling is missing." " Please make sure Fast-LLM is installed correctly." 
- ) - - sample_idx = build_sample_idx(document_sizes, doc_idx, sequence_length, num_epochs, num_tokens, verbose) - - # shuffle-idx. - # -1 is due to data structure used to retrieve the index: - # sample i --> [sample_idx[i], sample_idx[i+1]) - total_size = sample_idx.shape[0] - 1 - # TODO: Isn't the dataset already shuffled above? - shuffle_idx = np.arange( - 0, total_size, dtype=np.int64 if total_size >= (np.iinfo(np.uint32).max - 1) else np.uint32 - ) - if separate_last_epoch: - np_rng.shuffle(shuffle_idx[:main_epochs_samples]) - np_rng.shuffle(shuffle_idx[main_epochs_samples:]) - else: - np_rng.shuffle(shuffle_idx) - - Assert.geq(len(shuffle_idx), num_samples) - # TODO: The doc and sample idx are way bigger than needed when sampling for << 1 epoch. - return doc_idx, sample_idx, shuffle_idx[:num_samples] diff --git a/fast_llm/data/gpt/memmap.py b/fast_llm/data/gpt/memmap.py index 272508866..77b986f58 100644 --- a/fast_llm/data/gpt/memmap.py +++ b/fast_llm/data/gpt/memmap.py @@ -3,8 +3,7 @@ import numpy as np -from fast_llm.data.gpt.config import GPTMemmapDatasetConfig -from fast_llm.data.gpt.dataset import GPTIndexedDataset +from fast_llm.data.gpt.config import GPTIndexedDataset, GPTMemmapDatasetConfig from fast_llm.utils import Assert, div, padded_cumsum diff --git a/fast_llm/data/gpt/sampled.py b/fast_llm/data/gpt/sampled.py index f7e8d1390..d528be3fc 100644 --- a/fast_llm/data/gpt/sampled.py +++ b/fast_llm/data/gpt/sampled.py @@ -8,8 +8,7 @@ from fast_llm.core.distributed import safe_barrier from fast_llm.data.config import SampledDataset from fast_llm.data.fim import Fim -from fast_llm.data.gpt.config import GPTDataConfig -from fast_llm.data.gpt.dataset import GPTIndexedDataset +from fast_llm.data.gpt.config import GPTDataConfig, GPTIndexedDataset from fast_llm.data.tokenizer import Tokenizer from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.distributed.config import MAX_SEED @@ -33,7 +32,7 @@ class 
GPTSampledDataset(SampledDataset): def __init__( self, - dataset: GPTIndexedDataset, + indexed_dataset: GPTIndexedDataset, num_samples: int, sequence_length: int, seed: int, @@ -43,7 +42,7 @@ def __init__( cache_dir: pathlib.Path, verbose: bool = True, ): - self._dataset = dataset + self._indexed_dataset = indexed_dataset if config.fim.rate > 0: assert tokenizer is not None @@ -75,37 +74,41 @@ def __init__( np.save(self._sample_idx_filename, sample_idx) np.save(self._shuffle_idx_filename, shuffle_idx) - safe_barrier(group, self._dataset.name) + safe_barrier(group, self._indexed_dataset.name) self._load_mappings(verbose) def _sample(self, num_samples: int, sequence_length: int, np_rng: numpy.random.RandomState, verbose: bool): """ Create a `GPTSampledDataset` with the requested parameters. """ - tokens_per_epoch = self._dataset.num_tokens - num_epochs = math.ceil((sequence_length * num_samples + 1) / tokens_per_epoch) + document_sizes = self._indexed_dataset.get_document_sizes() + num_documents = len(document_sizes) + num_tokens = document_sizes.sum() + + num_epochs = math.ceil((sequence_length * num_samples + 1) / num_tokens) # For the last epoch, decide whether include the entire epoch # in the global shuffle or not. # Get the number of samples for the last epoch - main_epochs_samples = ((num_epochs - 1) * tokens_per_epoch - 1) // sequence_length + main_epochs_samples = ((num_epochs - 1) * num_tokens - 1) // sequence_length last_epoch_samples = num_samples - main_epochs_samples - samples_per_epoch = (tokens_per_epoch - 1) // sequence_length + samples_per_epoch = (num_tokens - 1) // sequence_length # If we have less than 80% of the samples for the last epoch, separate out the epoch and treat it differently. # Note: the 80% number is just based on common sense and can be adjusted if needed. 
separate_last_epoch = num_epochs > 1 and last_epoch_samples < 0.8 * samples_per_epoch - doc_idx = np.tile(np.arange(len(self._dataset), dtype=np.int32), num_epochs) + doc_idx = np.tile(np.arange(num_documents, dtype=np.int32), num_epochs) if separate_last_epoch: - np_rng.shuffle(doc_idx[: -len(self._dataset)]) - np_rng.shuffle(doc_idx[-len(self._dataset) :]) + np_rng.shuffle(doc_idx[:-num_documents]) + np_rng.shuffle(doc_idx[-num_documents:]) else: np_rng.shuffle(doc_idx) - assert _extension_available, "Please run `make -C ./fast_llm/csrc/` first." - sample_idx = build_sample_idx( - self._dataset.document_sizes, doc_idx, sequence_length, num_epochs, tokens_per_epoch, verbose + assert _extension_available, ( + "The C++ extension for dataset sampling is missing." " Please make sure Fast-LLM is installed correctly." ) + sample_idx = build_sample_idx(document_sizes, doc_idx, sequence_length, num_epochs, num_tokens, verbose) + # shuffle-idx. # -1 is due to data structure used to retrieve the index: # sample i --> [sample_idx[i], sample_idx[i+1]) @@ -126,7 +129,7 @@ def _sample(self, num_samples: int, sequence_length: int, np_rng: numpy.random.R def __getstate__(self): return ( - self._dataset, + self._indexed_dataset, self._fim, self._seed, self._doc_idx_filename, @@ -136,7 +139,7 @@ def __getstate__(self): def __setstate__(self, state): ( - self._dataset, + self._indexed_dataset, self._fim, self._seed, self._doc_idx_filename, @@ -175,7 +178,7 @@ def __getitem__(self, idx): doc_f, offset_f = self._sample_idx[shuffled_idx] doc_l, offset_l = self._sample_idx[shuffled_idx + 1] sample_list = [ - self._dataset.get( + self._indexed_dataset.get( self._doc_idx[doc], offset=(doc == doc_f) * offset_f, length=offset_l + 1 - (doc == doc_f) * offset_f if doc == doc_l else None, @@ -193,4 +196,4 @@ def __getitem__(self, idx): @property def name(self): - return self._dataset.name + return self._indexed_dataset.name diff --git a/fast_llm/data/gpt/slice.py b/fast_llm/data/gpt/slice.py 
index cc81736ca..d97fe5c53 100644 --- a/fast_llm/data/gpt/slice.py +++ b/fast_llm/data/gpt/slice.py @@ -1,4 +1,4 @@ -from fast_llm.data.gpt.dataset import GPTIndexedDataset +from fast_llm.data.gpt.config import GPTIndexedDataset from fast_llm.engine.distributed.config import PhaseType from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum From 84e54e904371c746637e25037627ffc50f1f2d62 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 11 Nov 2024 16:38:44 -0500 Subject: [PATCH 08/13] fix --- fast_llm/data/gpt/dummy.py | 18 +------- fast_llm/data/gpt/sampled.py | 74 ++++++++++++++++---------------- fast_llm/data/gpt/slice.py | 2 +- fast_llm/models/custom/config.py | 4 +- fast_llm/models/gpt/config.py | 4 +- 5 files changed, 43 insertions(+), 59 deletions(-) diff --git a/fast_llm/data/gpt/dummy.py b/fast_llm/data/gpt/dummy.py index a43e8bb9b..76a4379bd 100644 --- a/fast_llm/data/gpt/dummy.py +++ b/fast_llm/data/gpt/dummy.py @@ -1,10 +1,6 @@ -import pathlib - import numpy as np from fast_llm.data.config import SampledDataset -from fast_llm.engine.config_utils.run import log_main_rank -from fast_llm.utils import Assert class DummyGPTDataset(SampledDataset): @@ -13,19 +9,9 @@ class DummyGPTDataset(SampledDataset): The sample can be purely random, or read from a file to allow reproducing in other runs. 
""" - def __init__( - self, prefix: pathlib.Path | None, num_samples: int, sequence_length: int, vocab_size: int, name: str = "dummy" - ): + def __init__(self, num_samples: int, sequence_length: int, vocab_size: int, name: str = "dummy"): self._num_samples = num_samples - if prefix is None: - self._dummy_sample = np.random.randint(0, vocab_size, size=(sequence_length + 1,), dtype=np.int64) - else: - log_main_rank(f"> Loading dummy dataset from file {prefix}") - self._dummy_sample = np.load(prefix, allow_pickle=True)[: sequence_length + 1] - Assert.eq(self._dummy_sample.shape, (sequence_length + 1,)) - Assert.eq(self._dummy_sample.dtype, np.int64) - Assert.lt(self._dummy_sample.max(), vocab_size) - Assert.geq(self._dummy_sample.min(), 0) + self._dummy_sample = np.random.randint(0, vocab_size, size=(sequence_length + 1,), dtype=np.int64) self._name = name def __len__(self): diff --git a/fast_llm/data/gpt/sampled.py b/fast_llm/data/gpt/sampled.py index d528be3fc..0479b900f 100644 --- a/fast_llm/data/gpt/sampled.py +++ b/fast_llm/data/gpt/sampled.py @@ -1,15 +1,12 @@ import math -import pathlib import numpy as np -import numpy.random -from torch._C._distributed_c10d import ProcessGroup from fast_llm.core.distributed import safe_barrier from fast_llm.data.config import SampledDataset from fast_llm.data.fim import Fim -from fast_llm.data.gpt.config import GPTDataConfig, GPTIndexedDataset -from fast_llm.data.tokenizer import Tokenizer +from fast_llm.data.gpt.data import GPTData +from fast_llm.data.gpt.dataset import GPTIndexedDataset, GPTSamplingConfig from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.distributed.config import MAX_SEED from fast_llm.utils import Assert @@ -22,7 +19,7 @@ _extension_available = False -class GPTSampledDataset(SampledDataset): +class GPTSampledIndexedDataset(SampledDataset): """ A GPT dataset augmented with a sampling, i.e., a pre-computed, shuffled list of samples to be indexed sequentially (as-is) during 
training. @@ -33,33 +30,29 @@ class GPTSampledDataset(SampledDataset): def __init__( self, indexed_dataset: GPTIndexedDataset, - num_samples: int, - sequence_length: int, - seed: int, - group: ProcessGroup | None, - config: GPTDataConfig, - tokenizer: Tokenizer | None, - cache_dir: pathlib.Path, - verbose: bool = True, + config: GPTSamplingConfig, + data: GPTData, ): + assert isinstance(config, GPTSamplingConfig) + assert isinstance(data, GPTData) self._indexed_dataset = indexed_dataset + self._config = config - if config.fim.rate > 0: - assert tokenizer is not None - self._fim = Fim(config.fim, tokenizer) + if data.config.fim.rate > 0: + assert data.tokenizer is not None + self._fim = Fim(data.config.fim, data.tokenizer) else: self._fim = None - self._seed = seed - # rng state - np_rng = np.random.RandomState(seed=self._seed) - - cache_prefix = f"{self.name}_ns_{num_samples}_sl_{sequence_length}_s_{seed}" + cache_prefix = ( + f"{self.name}_ns_{self._config.num_samples}_sl_{self._config.sequence_length}_s_{self._config.seed}" + ) # TODO: Any way to combine into a single file? (Memmap is harder) - self._doc_idx_filename = cache_dir / (cache_prefix + "_doc_idx.npy") - self._sample_idx_filename = cache_dir / (cache_prefix + "_sample_idx.npy") - self._shuffle_idx_filename = cache_dir / (cache_prefix + "_shuffle_idx.npy") + self._doc_idx_filename = self._config.cache_directory / (cache_prefix + "_doc_idx.npy") + self._sample_idx_filename = self._config.cache_directory / (cache_prefix + "_sample_idx.npy") + self._shuffle_idx_filename = self._config.cache_directory / (cache_prefix + "_shuffle_idx.npy") + group = data.distributed.world_group # Build the indexed mapping if it doesn't exist. # TODO: This only works if the dataset location is accessible by all job. 
if (group is None or group.rank() == 0) and not ( @@ -67,31 +60,33 @@ def __init__( and self._sample_idx_filename.is_file() and self._shuffle_idx_filename.is_file() ): - if verbose: + if self._config.verbose: log_main_rank(" > Building the index map on rank 0 ...") - doc_idx, sample_idx, shuffle_idx = self._sample(num_samples, sequence_length, np_rng, verbose) + doc_idx, sample_idx, shuffle_idx = self._sample() + self._config.cache_directory.mkdir(parents=True, exist_ok=True) np.save(self._doc_idx_filename, doc_idx) np.save(self._sample_idx_filename, sample_idx) np.save(self._shuffle_idx_filename, shuffle_idx) safe_barrier(group, self._indexed_dataset.name) - self._load_mappings(verbose) + self._load_mappings(self._config.verbose) - def _sample(self, num_samples: int, sequence_length: int, np_rng: numpy.random.RandomState, verbose: bool): + def _sample(self): """ Create a `GPTSampledDataset` with the requested parameters. """ document_sizes = self._indexed_dataset.get_document_sizes() num_documents = len(document_sizes) num_tokens = document_sizes.sum() + np_rng = np.random.RandomState(seed=self._config.seed) - num_epochs = math.ceil((sequence_length * num_samples + 1) / num_tokens) + num_epochs = math.ceil((self._config.sequence_length * self._config.num_samples + 1) / num_tokens) # For the last epoch, decide whether include the entire epoch # in the global shuffle or not. # Get the number of samples for the last epoch - main_epochs_samples = ((num_epochs - 1) * num_tokens - 1) // sequence_length - last_epoch_samples = num_samples - main_epochs_samples - samples_per_epoch = (num_tokens - 1) // sequence_length + main_epochs_samples = ((num_epochs - 1) * num_tokens - 1) // self._config.sequence_length + last_epoch_samples = self._config.num_samples - main_epochs_samples + samples_per_epoch = (num_tokens - 1) // self._config.sequence_length # If we have less than 80% of the samples for the last epoch, separate out the epoch and treat it differently. 
# Note: the 80% number is just based on common sense and can be adjusted if needed. separate_last_epoch = num_epochs > 1 and last_epoch_samples < 0.8 * samples_per_epoch @@ -107,7 +102,9 @@ def _sample(self, num_samples: int, sequence_length: int, np_rng: numpy.random.R "The C++ extension for dataset sampling is missing." " Please make sure Fast-LLM is installed correctly." ) - sample_idx = build_sample_idx(document_sizes, doc_idx, sequence_length, num_epochs, num_tokens, verbose) + sample_idx = build_sample_idx( + document_sizes, doc_idx, self._config.sequence_length, num_epochs, num_tokens, self._config.verbose + ) # shuffle-idx. # -1 is due to data structure used to retrieve the index: @@ -123,15 +120,15 @@ def _sample(self, num_samples: int, sequence_length: int, np_rng: numpy.random.R else: np_rng.shuffle(shuffle_idx) - Assert.geq(len(shuffle_idx), num_samples) + Assert.geq(len(shuffle_idx), self._config.num_samples) # TODO: The doc and sample idx are way bigger than needed when sampling for << 1 epoch. 
- return doc_idx, sample_idx, shuffle_idx[:num_samples] + return doc_idx, sample_idx, shuffle_idx[: self._config.num_samples] def __getstate__(self): return ( self._indexed_dataset, self._fim, - self._seed, + self._config.to_serialized(), self._doc_idx_filename, self._sample_idx_filename, self._shuffle_idx_filename, @@ -141,11 +138,12 @@ def __setstate__(self, state): ( self._indexed_dataset, self._fim, - self._seed, + config, self._doc_idx_filename, self._sample_idx_filename, self._shuffle_idx_filename, ) = state + self._config = GPTSamplingConfig.from_dict(config) self._load_mappings(False) def _load_mappings(self, verbose): diff --git a/fast_llm/data/gpt/slice.py b/fast_llm/data/gpt/slice.py index d97fe5c53..cc81736ca 100644 --- a/fast_llm/data/gpt/slice.py +++ b/fast_llm/data/gpt/slice.py @@ -1,4 +1,4 @@ -from fast_llm.data.gpt.config import GPTIndexedDataset +from fast_llm.data.gpt.dataset import GPTIndexedDataset from fast_llm.engine.distributed.config import PhaseType from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum diff --git a/fast_llm/models/custom/config.py b/fast_llm/models/custom/config.py index 8f965b8f8..f0f32dc3d 100644 --- a/fast_llm/models/custom/config.py +++ b/fast_llm/models/custom/config.py @@ -1,7 +1,7 @@ import typing from fast_llm.config import FieldUpdate, config_class -from fast_llm.data.gpt.config import DataConfig +from fast_llm.data.gpt.config import GPTDataConfig from fast_llm.models.gpt.config import ( GPTArchitectureConfig, GPTBaseModelConfig, @@ -12,7 +12,7 @@ @config_class() -class CustomDataConfig(DataConfig): +class CustomDataConfig(GPTDataConfig): # TODO: If needed, inherit from AbstractDataConfig instead and re-implement everything. 
pass diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 467bd63e5..fbe069514 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -1,7 +1,7 @@ import typing from fast_llm.config import Field, FieldHint, FieldUpdate, config_class -from fast_llm.data.gpt.config import DataConfig +from fast_llm.data.gpt.config import GPTDataConfig from fast_llm.engine.checkpoint.config import CheckpointFormat, CheckpointHandler from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig @@ -122,7 +122,7 @@ class PretrainedGPTModelConfig(PretrainedFastLLMModelConfig): @config_class() class GPTTrainerConfig(PretrainedGPTModelConfig, TrainerConfig): - data: DataConfig = FieldUpdate(default_factory=DataConfig) + data: GPTDataConfig = FieldUpdate(default_factory=GPTDataConfig) def _setup(self): super()._setup() From 27aa65730eb18ba3ab71a8d7406b3db9dc374590 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 11 Nov 2024 16:39:17 -0500 Subject: [PATCH 09/13] fix --- fast_llm/data/config.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index 836a6d17b..7487265aa 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -1,5 +1,6 @@ import abc import enum +import pathlib import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none @@ -159,7 +160,20 @@ def name(self): """ -class SampledDataset(Dataset): # noqa +@config_class +class SamplingConfig(Config): + num_samples: int = Field(default=1, desc="Number of samples to generate.") + seed: int = Field(default=0, desc="Random seed.") + cache_directory: pathlib.Path | None = Field(default=None, desc="Path to the sampling cache directory.") + verbose: bool = Field(default=True, desc="Log sampling progress.") + + +class 
SamplableDataset(Dataset): + def sample(self, config: SamplingConfig, data: Data): + pass + + +class SampledDataset(Dataset): """ A sampled dataset class containing a prepared list of samples to be indexed sequentially (as-is) during training. (See the `Sampler` class below.) From 11a3b4e4afacf01467ac922857fe07b84188a489 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 11 Nov 2024 16:50:20 -0500 Subject: [PATCH 10/13] fix --- fast_llm/data/gpt/dataset.py | 62 ++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/fast_llm/data/gpt/dataset.py b/fast_llm/data/gpt/dataset.py index e69de29bb..bae81c4f8 100644 --- a/fast_llm/data/gpt/dataset.py +++ b/fast_llm/data/gpt/dataset.py @@ -0,0 +1,62 @@ +import abc +import typing + +import numpy as np + +from fast_llm.config import Field, config_class +from fast_llm.data.config import SamplableDataset, SamplingConfig + +if typing.TYPE_CHECKING: + from fast_llm.data.gpt.data import GPTData + + +try: + from fast_llm.csrc.data import build_sample_idx # noqa + + _extension_available = True +except ImportError: + _extension_available = False + + +@config_class +class GPTSamplingConfig(SamplingConfig): + sequence_length: int = Field(default=None, desc="Number of token in each sample.") + + +class GPTIndexedDataset(SamplableDataset): + """ + A GPT dataset containing a list of unsampled, unprocessed samples. + TODO: Move sampling responsibility here? + """ + + def get(self, document: int, offset: int = 0, length: int | None = None): + pass + + @property + def num_documents(self) -> int: + """ + Number of documents in the dataset. + Can be calculated from document sizes but may be overridden if there is a better method. + """ + return len(self.get_document_sizes()) + + @property + def num_tokens(self) -> int: + """ + Number of tokens in the dataset. + Can be calculated from document sizes but may be overridden if there is a better method. 
+ """ + return self.get_document_sizes().sum() + + @abc.abstractmethod + def get_document_sizes(self) -> "np.ndarray": + """ + The size of each document in the dataset. + The resulting array could be very large, so this method should be called cautiously, + and derived classes should try to avoid holding the whole array im memory. + """ + + def sample(self, config: GPTSamplingConfig, data: "GPTData"): + from fast_llm.data.gpt.sampled import GPTSampledIndexedDataset + + return GPTSampledIndexedDataset(self, config, data) From 3ed7c3430f7f80f0415619bb376949abc3dbe942 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 11 Nov 2024 16:52:48 -0500 Subject: [PATCH 11/13] fix --- fast_llm/data/gpt/config.py | 38 +------------------------------------ fast_llm/data/gpt/memmap.py | 3 ++- 2 files changed, 3 insertions(+), 38 deletions(-) diff --git a/fast_llm/data/gpt/config.py b/fast_llm/data/gpt/config.py index e5447a8dd..a63e23e50 100644 --- a/fast_llm/data/gpt/config.py +++ b/fast_llm/data/gpt/config.py @@ -13,6 +13,7 @@ SampledDataset, TokenizerConfig, ) +from fast_llm.data.gpt.dataset import GPTIndexedDataset from fast_llm.engine.distributed.config import PhaseType from fast_llm.utils import Assert @@ -45,43 +46,6 @@ def sample( pass -class GPTIndexedDataset(Dataset): - """ - A GPT dataset containing a list of unsampled, unprocessed samples. - TODO: Move sampling responsibility here? - """ - - def get(self, document: int, offset: int = 0, length: int | None = None): - pass - - @property - def num_documents(self) -> int: - """ - Number of documents in the dataset. - Can be calculated from document sizes but may be overridden if there is a better method. - """ - return len(self.get_document_sizes()) - - @property - def num_tokens(self) -> int: - """ - Number of tokens in the dataset. - Can be calculated from document sizes but may be overridden if there is a better method. 
- """ - return self.get_document_sizes().sum() - - @abc.abstractmethod - def get_document_sizes(self) -> "np.ndarray": - """ - The size of each document in the dataset. - The resulting array could be very large, so this method should be called cautiously, - and derived classes should try to avoid holding the whole array im memory. - """ - - def sample(self, num_samples: int, sequence_length: int, np_rng: "np.random.RandomState", verbose: bool): - return - - @config_class() class GPTDatasetConfig(Config): diff --git a/fast_llm/data/gpt/memmap.py b/fast_llm/data/gpt/memmap.py index 77b986f58..272508866 100644 --- a/fast_llm/data/gpt/memmap.py +++ b/fast_llm/data/gpt/memmap.py @@ -3,7 +3,8 @@ import numpy as np -from fast_llm.data.gpt.config import GPTIndexedDataset, GPTMemmapDatasetConfig +from fast_llm.data.gpt.config import GPTMemmapDatasetConfig +from fast_llm.data.gpt.dataset import GPTIndexedDataset from fast_llm.utils import Assert, div, padded_cumsum From 2f7574621399711af49a08ace3bcd42a0a3e88e1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 11 Nov 2024 16:58:55 -0500 Subject: [PATCH 12/13] fix --- fast_llm/data/config.py | 4 ++-- fast_llm/data/gpt/dataset.py | 2 +- fast_llm/data/gpt/memmap.py | 29 +++++++++++++++++++++-------- 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index 7487265aa..1df544bd8 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -154,7 +154,7 @@ class Dataset(abc.ABC): @property @abc.abstractmethod - def name(self): + def name(self) -> str: """ A name for the dataset to facilitate identification and debugging. 
""" @@ -169,7 +169,7 @@ class SamplingConfig(Config): class SamplableDataset(Dataset): - def sample(self, config: SamplingConfig, data: Data): + def sample(self, config: SamplingConfig, data: Data) -> "SampledDataset": pass diff --git a/fast_llm/data/gpt/dataset.py b/fast_llm/data/gpt/dataset.py index bae81c4f8..37520ded9 100644 --- a/fast_llm/data/gpt/dataset.py +++ b/fast_llm/data/gpt/dataset.py @@ -49,7 +49,7 @@ def num_tokens(self) -> int: return self.get_document_sizes().sum() @abc.abstractmethod - def get_document_sizes(self) -> "np.ndarray": + def get_document_sizes(self) -> np.ndarray: """ The size of each document in the dataset. The resulting array could be very large, so this method should be called cautiously, diff --git a/fast_llm/data/gpt/memmap.py b/fast_llm/data/gpt/memmap.py index 272508866..4ca3f507c 100644 --- a/fast_llm/data/gpt/memmap.py +++ b/fast_llm/data/gpt/memmap.py @@ -47,9 +47,14 @@ def _init(self, config: GPTMemmapDatasetConfig): self._index_bin_buffer_mmap = np.memmap(self._config.path.with_suffix(".idx"), mode="r", order="C") self._index_bin_buffer = memoryview(self._index_bin_buffer_mmap) - self._sizes = np.frombuffer(self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset) + self._document_sizes = np.frombuffer( + self._index_bin_buffer, dtype=np.int32, count=self._num_documents, offset=offset + ) self._pointers = np.frombuffer( - self._index_bin_buffer, dtype=np.int64, count=self._num_documents, offset=offset + self._sizes.nbytes + self._index_bin_buffer, + dtype=np.int64, + count=self._num_documents, + offset=offset + self._document_sizes.nbytes, ) self._bin_buffer_mmap = np.memmap(self._config.path.with_suffix(".bin"), mode="r", order="C") self._bin_buffer = memoryview(self._bin_buffer_mmap) @@ -73,21 +78,29 @@ def get(self, document: int, offset: int = 0, length: int | None = None): return np.frombuffer( self._bin_buffer, dtype=self._dtype, - count=self._sizes[document] - offset if length is None else 
length, + count=self._document_sizes[document] - offset if length is None else length, offset=self._pointers[document] + offset * np.dtype(self._dtype).itemsize, ) @property - def num_documents(self): + def name(self) -> str: + return self._name + + @property + def num_documents(self) -> int: return self._num_documents @property - def num_tokens(self): + def num_tokens(self) -> int: return div(self._bin_buffer_mmap.size, np.dtype(self._dtype).itemsize) - @property - def document_sizes(self): - return self._sizes + def get_document_sizes(self) -> np.ndarray: + """ + The size of each document in the dataset. + The resulting array could be very large, so this method should be called cautiously, + and derived classes should try to avoid holding the whole array im memory. + """ + return self._document_sizes @classmethod def write_dataset(cls, prefix: pathlib.Path | str, documents: list[np.ndarray]): From b54747fbae8ab4772da2695a47fea0550fa2bb8e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 11 Nov 2024 17:02:47 -0500 Subject: [PATCH 13/13] fix --- fast_llm/data/gpt/concatenated.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/fast_llm/data/gpt/concatenated.py b/fast_llm/data/gpt/concatenated.py index 19883dd00..9a6de0750 100644 --- a/fast_llm/data/gpt/concatenated.py +++ b/fast_llm/data/gpt/concatenated.py @@ -1,15 +1,16 @@ import numpy as np -from fast_llm.data.gpt.config import GPTConcatenatedDatasetConfig, GPTRawDataset +from fast_llm.data.gpt.config import GPTConcatenatedDatasetConfig +from fast_llm.data.gpt.dataset import GPTIndexedDataset from fast_llm.utils import padded_cumsum -class GPTConcatenatedDataset(GPTRawDataset): +class GPTConcatenatedDataset(GPTIndexedDataset): def __init__( self, config: GPTConcatenatedDatasetConfig, - datasets: list[GPTRawDataset], + datasets: list[GPTIndexedDataset], ): self._config = config self._datasets = datasets @@ -24,12 +25,6 @@ def num_tokens(self): def num_documents(self): 
return self._num_documents - def __getitem__(self, index: int): - """ - Get the sample (document) with the given index (in the split dataset). - """ - return self.get(index) - def get(self, document: int, offset: int = 0, length: int | None = None): """ Get the sample (document) with the given index (in the dataset slice),