diff --git a/Dockerfile b/Dockerfile index 00e13d957..6bc900ae7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -38,7 +38,7 @@ COPY --chmod=777 ./fast_llm/__init__.py fast_llm/ COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/ # Install dependencies within the virtual environment. -RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,GENERATION,DEV]" triton==3.1.0 +RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV]" triton==3.1.0 # Copy the remaining source code with universal write permissions. COPY --chmod=777 ./Megatron-LM Megatron-LM diff --git a/fast_llm/config.py b/fast_llm/config.py index 9644df9c1..658ad5666 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -492,6 +492,10 @@ def _validate_element(cls, value, type_, name: str): value = cls._validate_dict(value, type_, name) elif origin is type: value = cls._validate_type(value, type_, name) + elif issubclass(origin, Config): + # TODO: Validate arguments for config generics. + cls._validate_element_type(value, type_.__origin__, strict=False) + value.validate(_is_validating=True) else: raise FieldTypeError(f"Unsupported __origin__ `{origin}`") elif not isinstance(type_, type): @@ -806,6 +810,8 @@ def _from_dict_nested(cls, value, type_, strict: bool): value = cls._from_dict_array(value, type_, strict) elif issubclass(origin, dict): value = cls._from_dict_dict(value, type_, strict) + elif issubclass(origin, Config): + value = cls._from_dict_config(value, type_, strict) elif origin is type: pass else: @@ -813,10 +819,15 @@ def _from_dict_nested(cls, value, type_, strict: bool): elif not isinstance(type_, type): raise FieldTypeError(f"Not a type: {type_}.") elif issubclass(type_, Config): - if value is MISSING: - value = {} - if isinstance(value, dict): - value = type_._from_dict(value, strict) + value = cls._from_dict_config(value, type_, strict) + return value + + @classmethod + def _from_dict_config(cls, value, type_, strict: bool): + if value is MISSING: + value = {} + if isinstance(value, dict): + value = type_._from_dict(value, strict) return value @classmethod @@ -938,6 +949,7 @@ def __init_subclass__(cls): We need to postpone validation until the class has been processed by the dataclass wrapper. """ Assert.eq(cls.__name__, cls.__qualname__) + super().__init_subclass__() for base_class in cls.__mro__: if issubclass(base_class, Config) and base_class is not cls: assert cls.__class_validated__, ( @@ -1006,6 +1018,7 @@ def __init__(self, config: ConfigType, *args, **kwargs): def __init_subclass__(cls): # Automatically set `config_class` based on the bound type. # Make sure `ConfigType` is bound and respects class hierarchy. + super().__init_subclass__() try: config_class = None for base in types.get_original_bases(cls): diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index 4c041945d..633367c80 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -1,9 +1,13 @@ import enum import pathlib +import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class from fast_llm.utils import Assert +if typing.TYPE_CHECKING: + from fast_llm.data.tokenizer import Tokenizer + class MultiprocessingContext(str, enum.Enum): # Fast but risk of segfaults due to interactions with triton @@ -29,7 +33,7 @@ class TokenizerConfig(Config): hint=FieldHint.deprecated, valid=check_field(Assert.eq, TokenizerFromFile), ) - path: pathlib.Path | None = Field( + path: pathlib.Path = Field( default=None, desc="Path to the tokenizer file.", hint=FieldHint.core, @@ -39,3 +43,8 @@ class TokenizerConfig(Config): desc="BOS token to use if the tokenizer doesn't define one; must be an existing token.", hint=FieldHint.core, ) + + def get_tokenizer(self) -> "Tokenizer": + from fast_llm.data.tokenizer import Tokenizer + + return Tokenizer(self) diff --git a/fast_llm/data/data/abstract.py b/fast_llm/data/data/abstract.py index e24d39985..c67dc0321 100644 --- a/fast_llm/data/data/abstract.py +++ b/fast_llm/data/data/abstract.py @@ -5,6 +5,7 @@ from fast_llm.config import Configurable from fast_llm.data.data.config import DataConfig from fast_llm.data.dataset.config import SamplingParameters +from fast_llm.data.sample.abstract import Batch from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.schedule.config import BatchConfig @@ -47,5 +48,5 @@ def get_iterator( num_workers: int, prefetch_factor: int | None = None, timeout: float = 60, - ) -> typing.Iterator[typing.Any]: + ) -> typing.Iterator[Batch]: pass diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index efee46959..ba5be883a 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -1,11 +1,14 @@ import logging +import typing -from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class -from fast_llm.data.config import MultiprocessingContext, TokenizerConfig +from fast_llm.config import Field, FieldHint, check_field, config_class +from fast_llm.data.config import MultiprocessingContext from fast_llm.data.data.config import DataConfig -from fast_llm.data.dataset.gpt.config import GPTSampledDatasetConfig, GPTSamplingConfig +from fast_llm.data.dataset.config import SampledDatasetConfig from fast_llm.utils import Assert +if typing.TYPE_CHECKING: + from fast_llm.data.sample.language_model import LanguageModelSample logger = logging.getLogger(__name__) @@ -19,17 +22,12 @@ class GPTDataConfig(DataConfig): _abstract = False - tokenizer: TokenizerConfig = Field( - desc="Configuration for the tokenizer (for FIM).", - hint=FieldHint.feature, - ) # TODO: Review field. Move closer to phase definition in training config? - datasets: dict[str, GPTSampledDatasetConfig] = Field( + datasets: dict[str, SampledDatasetConfig["LanguageModelSample"]] = Field( default_factory=dict, desc="Configuration for the dataset(s).", hint=FieldHint.core, ) - sampling: GPTSamplingConfig = FieldUpdate() data_sample_warn_time_ms: float = Field( default=1000, desc="Warn if a sample takes too long to load.", diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 6724afb59..de47ef761 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -1,11 +1,8 @@ -import dataclasses import logging import pathlib import typing import warnings -from functools import partial -import numpy as np import torch import torch.utils.data @@ -14,60 +11,25 @@ from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.gpt.config import GPTSamplingData, GPTSamplingParameters -from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.dataset.monitor import DatasetMonitor from fast_llm.data.iterator import SampledDatasetIterator -from fast_llm.data.tokenizer import Tokenizer +from fast_llm.data.sample.language_model import LanguageModelBatch from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.schedule.config import BatchConfig +from fast_llm.models.gpt.config import GPTBatchConfig from fast_llm.utils import Assert logger = logging.getLogger(__name__) -@dataclasses.dataclass -class GPTBatch: - token_ids: torch.Tensor - loss_masking_spans: list[torch.Tensor] | None = None - sequence_lengths: list[torch.Tensor] | None = None - chosen_spans: list[torch.Tensor] | None = None - rejected_spans: list[torch.Tensor] | None = None - - -def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch: - stacked_ids = np.stack([sample.token_ids for sample in batch]) - stacked_spans = None - sequence_lengths = None - stacked_chosen_spans = None - stacked_rejected_spans = None - if sampling_parameters.use_loss_masking_spans: - stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch] - if sampling_parameters.use_preference_loss_spans: - stacked_chosen_spans = [torch.from_numpy(sample.chosen_span) for sample in batch] - stacked_rejected_spans = [torch.from_numpy(sample.rejected_span) for sample in batch] - if not sampling_parameters.cross_document_attention: - sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch] - return GPTBatch( - token_ids=torch.from_numpy(stacked_ids), - loss_masking_spans=stacked_spans, - sequence_lengths=sequence_lengths, - chosen_spans=stacked_chosen_spans, - rejected_spans=stacked_rejected_spans, - ) - - class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]): """ A global class for all dataset needs, including loading, splitting, sampling and iteration. - Currently hard-coded to a GPT dataset. - TODO: Separate generic and GPT classes. """ _datasets: dict[str, SampledDataset] _sampling_parameters: dict[str, GPTSamplingParameters] - _tokenizer: Tokenizer | None _is_setup: bool = False def __init__( @@ -108,7 +70,6 @@ def setup( ) log_main_rank(f"Preparing dataset. This may take several minutes.") - self._tokenizer = None if self._config.tokenizer.path is None else Tokenizer(self._config.tokenizer) if self._cache_directory is None: # TODO: Avoid this @@ -116,11 +77,6 @@ def setup( self._datasets = {} for dataset_name, sampling_parameters in self._sampling_parameters.items(): - if self._tokenizer is not None: - # NOTE: Some models like Qwen2-1.5B-Instruct - # have vocab_size bigger in model config than in tokenizer - # TODO: Still, is it too constraining? - Assert.geq(sampling_parameters.vocab_size, self._tokenizer.vocab_size) if sampling_parameters.num_samples > 0: sampling = GPTSamplingData( config=self._config.sampling, @@ -128,7 +84,6 @@ def setup( cache_directory=self._cache_directory, distributed=distributed, dataset_name=dataset_name, - tokenizer=self._tokenizer, ) dataset = self._config.datasets[dataset_name].build_and_sample(sampling) self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms) @@ -136,21 +91,16 @@ def setup( safe_barrier(self._distributed.world_group, "data_preparation", timeout) self._is_setup = True - @property - def tokenizer(self) -> Tokenizer: - assert self._is_setup - return self._tokenizer - def get_iterator( self, - batch_config: BatchConfig, + batch_config: GPTBatchConfig, dataset_name: str, *, consumed_samples: int, num_workers: int, prefetch_factor: int | None = None, timeout: float = 60, - ) -> typing.Iterator[typing.Any]: + ) -> typing.Iterator[LanguageModelBatch]: assert self._is_setup # Some dataset names may come from phases and are capitalized, @@ -175,10 +125,7 @@ def get_iterator( num_workers=num_workers, prefetch_factor=prefetch_factor, pin_memory=True, - collate_fn=partial( - gpt_data_collate_fn, - sampling_parameters=sampling_parameters, - ), + collate_fn=LanguageModelBatch.from_samples, multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, ) ) diff --git a/fast_llm/data/dataset/abstract.py b/fast_llm/data/dataset/abstract.py index b470c0159..33942708b 100644 --- a/fast_llm/data/dataset/abstract.py +++ b/fast_llm/data/dataset/abstract.py @@ -1,11 +1,13 @@ import abc import typing +from fast_llm.data.sample.abstract import Sample + if typing.TYPE_CHECKING: from fast_llm.data.dataset.config import SamplingData -class Dataset(abc.ABC): +class Dataset[SampleType: Sample](abc.ABC): """ A generic dataset class compatible with torch.utils.data.Dataset but with a slightly different signature. """ @@ -17,15 +19,23 @@ def name(self) -> str: A name for the dataset to facilitate identification and debugging. """ + def __getstate__(self): + state = super().__getstate__() + # Pickling sometimes fails with bound `SampleType`. + # This is not needed at runtime, so we just drop it. + if "__orig_class__" in state: + del state["__orig_class__"] + return state + -class SampledDataset(Dataset): +class SampledDataset[SampleType: Sample](Dataset[SampleType]): """ A sampled dataset class containing a prepared list of samples to be indexed sequentially (as-is) during training. (See the `Sampler` class below.) """ @abc.abstractmethod - def __getitem__(self, index: int) -> typing.Any: + def __getitem__(self, index: int) -> SampleType: pass @abc.abstractmethod @@ -33,8 +43,8 @@ def __len__(self) -> int: pass -class SamplableDataset(Dataset): +class SamplableDataset[SampleType: Sample](Dataset[SampleType]): @abc.abstractmethod - def sample(self, config: "SamplingData") -> SampledDataset: + def sample(self, config: "SamplingData") -> SampledDataset[SampleType]: pass diff --git a/fast_llm/data/dataset/blended.py b/fast_llm/data/dataset/blended.py index 24b0fa76f..264eb373d 100644 --- a/fast_llm/data/dataset/blended.py +++ b/fast_llm/data/dataset/blended.py @@ -1,16 +1,16 @@ import logging -import typing -import numpy as np +import torch from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.config import SamplingData +from fast_llm.data.sample.abstract import Sample from fast_llm.utils import Assert, normalize_probabilities logger = logging.getLogger(__name__) -class BlendedDataset(SampledDataset): +class BlendedDataset[SampleType: Sample](SampledDataset[SampleType]): """ A blended sampling of multiple sampled datasets, where each dataset is sampled with the provided probability. The sampling order of each dataset is respected, but there is no strict guarantee @@ -21,7 +21,7 @@ class BlendedDataset(SampledDataset): def __init__( self, name: str, - datasets: list[SampledDataset], + datasets: list[SampledDataset[SampleType]], weights: list[float], sampling_config: SamplingData, ): @@ -29,51 +29,52 @@ def __init__( assert len(datasets) > 0 Assert.eq(len(datasets), len(weights)) self._datasets = datasets - self._weights = np.array(normalize_probabilities(weights)) + self._weights = torch.from_numpy(normalize_probabilities(weights, return_array=True)) self._num_samples = sampling_config.parameters.num_samples def __len__(self) -> int: return self._num_samples - def __getitem__(self, idx: int) -> typing.Any: + def __getitem__(self, index: int) -> SampleType: """ Blending is typically done in one of the following iterative way (ex. in Megatron datasets): ```python dataset_index=np.zeros(num_samples) sample_index=np.zeros(num_samples) sampled=np.zeros(len(weights)) - for idx in range(num_samples): - error = weights * (idx + 1) - sampled + for index in range(num_samples): + error = weights * (index + 1) - sampled dataset_index_ = np.argmax(error) - dataset_index[idx] = dataset_index_ - sample_index[idx] = sampled[dataset_index_] + dataset_index[index] = dataset_index_ + sample_index[index] = sampled[dataset_index_] sampled[dataset_index_] +=1 ``` I.e. it iteratively picks samples to minimize the error `weights * sum(sampled) - sampled`. This implementation computes values on the fly instead of pre-computing them all. """ # We find the number of samples taken from each dataset prior to this point. - sampled = self._get_sampled(idx) + sampled = self._get_sampled(index) # Then get the present sample. - dataset_index = self._get_next_dataset(idx, sampled) - return self._datasets[dataset_index][sampled[dataset_index]] + dataset_index = self._get_next_dataset(index, sampled) + return self._datasets[dataset_index][sampled[dataset_index].item()] - def _get_sampled(self, num_samples: int): + def _get_sampled(self, num_samples: int) -> torch.Tensor: # First we determine a lower bound. # This is indeed a lower bound because a lower value for one dataset would involve more sampling below, # and it would be from that same dataset because it would have the highest error, - sampled = np.floor(self._weights * num_samples).astype(int) + + sampled = (self._weights * num_samples).to(torch.int64) # Then we sample until we reach the target number of samples. # This may not match the actual sampling order, but the final value of `sampled` is correct. - for idx in range(sampled.sum(), num_samples): - dataset_index = self._get_next_dataset(idx, sampled) + for index in range(sampled.sum().item(), num_samples): + dataset_index = self._get_next_dataset(index, sampled) sampled[dataset_index] += 1 return sampled - def _get_next_dataset(self, idx, sampled): + def _get_next_dataset(self, index: int, sampled: torch.Tensor) -> int: # The next sample is the one with the highest error. - return (self._weights * (idx + 1) - sampled).argmax() + return (self._weights * (index + 1) - sampled).argmax().item() @property - def name(self): + def name(self) -> str: return self._name diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 0c1b0cd09..20e40b66e 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -1,4 +1,5 @@ import dataclasses +import enum import functools import itertools import math @@ -7,6 +8,7 @@ from fast_llm.config import Config, Field, FieldHint, UpdateType, check_field, config_class from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset +from fast_llm.data.sample.abstract import Sample from fast_llm.utils import Assert, normalize_probabilities if typing.TYPE_CHECKING: @@ -14,6 +16,17 @@ from fast_llm.engine.distributed.distributed import Distributed +class ShufflingType(str, enum.Enum): + # Shuffle all epochs together. Not extendable. + full = "full" + # Shuffle all epochs separately. Default mode, recommended if the dataset doesn't come pre-shuffled. + epoch = "epoch" + # Shuffle all epochs except the first one. Recommended for pre-shuffled datasets, especially big ones. + skip_first_epoch = "skip_first_epoch" + # Disable shuffling entirely. + disabled = "disabled" + + @config_class() class SamplingConfig(Config): """ @@ -25,6 +38,18 @@ class SamplingConfig(Config): desc="Seed for random sampling.", hint=FieldHint.feature, ) + gpu: bool = Field( + default=True, + desc="Enable fast sampling on GPU." + " Note that random sampling works differently on GPU," + " so the sample won't match the CPU equivalent.", + hint=FieldHint.feature, + ) + shuffle: ShufflingType = Field( + default=ShufflingType.epoch, + desc="Shuffling strategy.", + hint=FieldHint.feature, + ) @dataclasses.dataclass(kw_only=True) @@ -33,7 +58,12 @@ class SamplingParameters: Sampling parameters set externally to the dataset and data, ex. determined by the trainer or model. """ + sequence_length: int num_samples: int + truncate_documents: bool = True + # How many extra tokens to add to the sequence length. + # This is used to provide labels even for the last tokens in the sequence. + extra_tokens: int = 1 @dataclasses.dataclass(kw_only=True) @@ -64,37 +94,38 @@ def get_next_rank(self) -> int: @config_class() -class DatasetConfig(Config): +class DatasetConfig[SampleType: Sample](Config): _abstract: typing.ClassVar[bool] = True -@config_class() -class SampledDatasetConfig(DatasetConfig): +@config_class(registry=True) +class SampledDatasetConfig[SampleType: Sample](DatasetConfig[SampleType]): """ A sampled dataset containing a prepared list of samples to be indexed sequentially (as-is) during training. """ - def build_and_sample(self, sampling: SamplingData) -> SampledDataset: + def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: + # TODO: ====== `SamplingData` contains more than needed (ex. `num_samples`) raise NotImplementedError() @config_class() -class SamplableDatasetConfig(SampledDatasetConfig): - def build(self) -> SamplableDataset: +class SamplableDatasetConfig[SampleType: Sample](SampledDatasetConfig[SampleType]): + def build(self) -> SamplableDataset[SampleType]: raise NotImplementedError() - def build_and_sample(self, sampling: SamplingData) -> SampledDataset: + def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: return self.build().sample(sampling) @config_class() -class IndexedDatasetConfig(SamplableDatasetConfig): - def _build(self) -> "IndexedDataset": +class IndexedDatasetConfig[SampleType: Sample](SamplableDatasetConfig[SampleType]): + def build(self) -> "IndexedDataset[SampleType]": raise NotImplementedError() -@config_class() -class ConcatenatedDatasetConfig(SamplableDatasetConfig): +@config_class(dynamic_type={SampledDatasetConfig: "concatenated"}) +class ConcatenatedDatasetConfig[SampleType: Sample](SamplableDatasetConfig[SampleType]): """ Concatenate multiple indexed datasets as if they were one. TODO: Make a post-sampling version? (staged training) @@ -106,7 +137,7 @@ class ConcatenatedDatasetConfig(SamplableDatasetConfig): desc="The name of the dataset.", hint=FieldHint.core, ) - datasets: list[IndexedDatasetConfig] = Field( + datasets: list[IndexedDatasetConfig[SampleType]] = Field( default_factory=list, desc="The datasets to concatenate.", hint=FieldHint.core, @@ -116,14 +147,11 @@ class ConcatenatedDatasetConfig(SamplableDatasetConfig): def build(self) -> "ConcatenatedDataset": from fast_llm.data.dataset.indexed import ConcatenatedDataset - return self._build(ConcatenatedDataset) - - def _build[T: ConcatenatedDataset](self, cls: type[T]) -> T: - return cls(self.name, [dataset.build() for dataset in self.datasets]) + return ConcatenatedDataset(self.name, [dataset.build() for dataset in self.datasets]) -@config_class() -class DatasetSliceConfig(SamplableDatasetConfig): +@config_class(dynamic_type={SampledDatasetConfig: "slice"}) +class DatasetSliceConfig[SampleType: Sample](SamplableDatasetConfig[SampleType]): """ Use a fraction of an indexed dataset, specified by the range (begin, end). Typically used to subsample a dataset, or to reserve part of the dataset for validation and/or testing. @@ -133,7 +161,7 @@ class DatasetSliceConfig(SamplableDatasetConfig): """ _abstract = False - dataset: IndexedDatasetConfig = Field( + dataset: IndexedDatasetConfig[SampleType] = Field( default=None, desc="The dataset to split.", hint=FieldHint.core, @@ -152,12 +180,9 @@ class DatasetSliceConfig(SamplableDatasetConfig): def build(self) -> "DatasetSlice": from fast_llm.data.dataset.indexed import DatasetSlice - return self._build(DatasetSlice) - - def _build[T: DatasetSlice](self, cls: type[T]) -> T: dataset = self.dataset.build() size = len(dataset) - return cls( + return DatasetSlice[SampleType]( f"{dataset.name}_{self.begin}_{self.end}", dataset, round(self.begin * size), @@ -165,8 +190,8 @@ def _build[T: DatasetSlice](self, cls: type[T]) -> T: ) -@config_class() -class SampledDatasetUpdateConfig(SampledDatasetConfig): +@config_class(dynamic_type={SampledDatasetConfig: "sampled"}) +class SampledDatasetUpdateConfig[SampleType: Sample](SampledDatasetConfig[SampleType]): """ Wrap a dataset to explicitly sample from it and optionally update its configuration parameters. Only explicitly set parameters (not None) will be updated, other will still be taken from `build_and_sample`'s argument. @@ -177,24 +202,24 @@ class SampledDatasetUpdateConfig(SampledDatasetConfig): desc="Optional override to sampling configuration parameters.", hint=FieldHint.core, ) - dataset: SampledDatasetConfig = Field( + dataset: SampledDatasetConfig[SampleType] = Field( desc="The dataset to sample from.", hint=FieldHint.core, ) - def build_and_sample(self, data: SamplingData) -> SampledDataset: + def build_and_sample(self, data: SamplingData) -> SampledDataset[SampleType]: return self.dataset.build_and_sample(data.update_config(self.sampling)) -@config_class() -class BlendedDatasetConfig(SampledDatasetConfig): +@config_class(dynamic_type={SampledDatasetConfig: "blended"}) +class BlendedDatasetConfig[SampleType: Sample](SampledDatasetConfig[SampleType]): _abstract = False name: str = Field( default="blended", desc="The name of the dataset.", hint=FieldHint.core, ) - datasets: list[SampledDatasetConfig] = Field( + datasets: list[SampledDatasetConfig[SampleType]] = Field( default_factory=list, desc="The datasets to blend.", hint=FieldHint.core, @@ -214,7 +239,7 @@ def _validate(self) -> None: def build_and_sample( self, sampling: SamplingData, - ) -> SampledDataset: + ) -> SampledDataset[SampleType]: from fast_llm.data.dataset.blended import BlendedDataset # Build and sample the datasets. @@ -235,7 +260,7 @@ def build_and_sample( for i, (dataset, weight) in enumerate(zip(self.datasets, self.weights, strict=True)) ] # Blend the datasets. - return BlendedDataset( + return BlendedDataset[SampleType]( self.name, sampled_datasets, self.weights, diff --git a/fast_llm/data/dataset/gpt/components/__init__.py b/fast_llm/data/dataset/gpt/components/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/data/dataset/gpt/components/config.py b/fast_llm/data/dataset/gpt/components/config.py new file mode 100644 index 000000000..59c419101 --- /dev/null +++ b/fast_llm/data/dataset/gpt/components/config.py @@ -0,0 +1,20 @@ +import dataclasses + +from fast_llm.engine.config_utils.data_type import DataType + +# TODO: Store span type? +# class SpanType(enum.StrEnum): +# none = "none" +# loss_masking = "loss_masking" +# preference = "preference" + + +@dataclasses.dataclass(kw_only=True) +class GPTMemmapDatasetHeader: + num_documents: int + token_data_type: DataType = DataType.int64 + has_spans: bool = False + has_images: bool = False + + def __post_init__(self): + self.token_data_type = DataType(self.token_data_type) diff --git a/fast_llm/data/dataset/gpt/components/images.py b/fast_llm/data/dataset/gpt/components/images.py new file mode 100644 index 000000000..cf217b5f0 --- /dev/null +++ b/fast_llm/data/dataset/gpt/components/images.py @@ -0,0 +1,243 @@ +import io +import math +import typing + +import numpy as np +import PIL.Image + +from fast_llm.data.dataset.gpt.components.config import GPTMemmapDatasetHeader +from fast_llm.data.dataset.gpt.config import GPTSamplingParameters +from fast_llm.data.dataset.gpt.memmap import BufferOffset, ShiftMap +from fast_llm.data.dataset.gpt.sampled import GPTSample +from fast_llm.utils import Assert, div + + +class GPTImageDatasetComponent: + def __init__( + self, + header: GPTMemmapDatasetHeader, + index_binary_buffer: memoryview, + binary_buffer: memoryview, + offset: BufferOffset, + ): + self._header = header + self._index_binary_buffer = index_binary_buffer + self._binary_buffer = binary_buffer + + self._count_cumsum = np.frombuffer( + self._index_binary_buffer, + dtype=np.int32, + count=self._header.num_documents + 1, + offset=offset.value, + ) + offset.value += self._count_cumsum.nbytes + self._sizes = np.frombuffer( + self._index_binary_buffer, + dtype=np.int32, + count=self._count_cumsum[-1] * 2, + offset=offset.value, + ).reshape(-1, 2) + offset.value += self._sizes.nbytes + self._positions = np.frombuffer( + self._index_binary_buffer, + dtype=np.int32, + count=self._count_cumsum[-1], + offset=offset.value, + ).reshape(-1, 2) + offset.value += self._positions.nbytes + + def get( + self, + index: int, + start_offset: int, + end_offset: int, + shift_map: ShiftMap, + buffer_offset: BufferOffset, + parameters: GPTSamplingParameters, + ) -> tuple[list[np.ndarray] | None, np.ndarray | None]: + # We get images from the document, discarding those outside the selected range. + images = [] + positions = [] + for image_index in range(self._count_cumsum[index], self._count_cumsum[index + 1]): + image_buffer_size = self._sizes[image_index].prod(initial=3) + image_position = shift_map.shift(self._positions[image_index].item()) + if start_offset <= image_position < end_offset: + images.append( + np.frombuffer( + self._binary_buffer, + dtype=np.dtype(np.uint8), + count=image_buffer_size, + offset=buffer_offset.value, + ).reshape(3, *self._sizes[image_index]) + ) + positions.append(self._positions[image_index]) + + buffer_offset.value += image_buffer_size + + def _get_insert(self, image_index: int, parameters: GPTSamplingParameters): + height, width = resized_image_length + height_patches = div(height, parameters.patch_size) + width_patches = div(width, parameters.patch_size) + image_size = height_patches * width_patches + if parameters.image_break_token is not None: + image_size += height_patches + elif parameters.image_end_token is not None: + image_size += 1 + + image_token_array = np.full((image_size,), -100, dtype=np.int64) + if parameters.image_break_token is not None: + for row in range(height_patches): + position = (row + 1) * width_patches + row + image_token_array[position] = parameters.image_break_token + + if parameters.image_end_token is not None: + # Will override the last image_break_token. + image_token_array[-1] = parameters.image_end_token + + start_pos = 0 + sample_token_ids = [] + for idx, im_position in enumerate(sample.image_positions): + # add placeholder masked tokens for images + # if image_break_token is set, it is appended after every row + # if image_end_token is set, it is appended at the end of the image instead of image_break_token + text_part = sample.token_ids[start_pos:im_position] + if parameters.image_break_token is not None: + height, width = resized_image_lengths[idx] + num_patches_h = div(height, parameters.patch_size) + num_patches_w = div(width, parameters.patch_size) + image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64) + # account for break tokens after each row + for row in range(num_patches_h - 1): + position = (row + 1) * num_patches_w + row + image_token_array[position] = parameters.image_break_token + # handle the last row separately + last_row_position = num_patches_h * num_patches_w + num_patches_h - 1 + if parameters.image_end_token is not None: + image_token_array[last_row_position] = parameters.image_end_token + else: + image_token_array[last_row_position] = parameters.image_break_token + else: + image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64) + if parameters.image_end_token is not None: + image_token_array[-1] = parameters.image_end_token + sample_token_ids.append(np.concatenate([text_part, image_token_array], dtype=np.int64)) + text_tokens_added += len(text_part) + image_positions.append(text_tokens_added + image_tokens_added) + image_sizes[idx] + start_pos = im_position + + resized_image_lengths = [ + get_resize_dims( + *image_length, + parameters.max_image_size, + parameters.max_image_size, + parameters.patch_size, + ) + for image_length in image_lengths + ] + return images, positions + + @classmethod + def write_document_and_gather_index( + cls, document: GPTSample, index_data: dict[str, typing.Any], binary_stream: io.BufferedWriter + ): + has_images = document.images is not None + if "has_images" in index_data: + Assert.eq(index_data["has_images"], has_images) + else: + index_data["has_images"] = has_images + if has_images: + if "image_sizes" not in index_data: + index_data["image_sizes"] = [] + if "image_positions" not in index_data: + index_data["image_positions"] = [] + if "num_pixels" not in index_data: + index_data["num_pixels"] = 0 + for image, image_position in zip(document.images, document.image_positions, strict=True): + # assume 3 channels (RGB) for all images + # TODO: Not consistent with GPTSample? + with PIL.Image.open(io.BytesIO(image["bytes"])) as img: + if img.mode != "RGB": + # Convert all images to RGB + img = img.convert("RGB") + pixels = np.array(img).transpose(2, 0, 1) # HWC to CHW + assert pixels.dtype == np.uint8, f"Expected uint8 pixels, got {pixels.dtype}." + index_data["image_sizes"].append(np.array(pixels.shape[1:])) + index_data["image_positions"].append(image_position) + # TODO: Shouldn't pixel count exclude the channel dimension? + index_data["num_pixels"] += pixels.size + binary_stream.write(pixels.tobytes(order="C")) + # Cumsum holds both image counts and buffer offsets. + if "image_cumsum" not in index_data: + index_data["image_cumsum"] = [0] + index_data["image_cumsum"].append(len(index_data["image_sizes"])) + + @classmethod + def write_index(self, index_data: dict[str, typing.Any], index_stream: io.BufferedWriter): + if index_data["has_images"]: + Assert.leq(index_data["image_cumsum"][-1], np.iinfo(np.int32).max) + Assert.eq(len(index_data["image_cumsum"]), index_data["num_documents"] + 1) + Assert.eq(len(index_data["image_sizes"]), index_data["image_cumsum"][-1]) + Assert.eq(len(index_data["image_positions"]), index_data["image_cumsum"][-1]) + index_stream.write(np.array(index_data["image_cumsum"], dtype=np.int32).tobytes(order="C")) + # n_pixels * 3 per image + index_stream.write(np.stack(index_data["image_sizes"], dtype=np.int32).tobytes(order="C")) + # Position of each image in the document + index_stream.write(np.array(index_data["image_positions"], dtype=np.int32).tobytes(order="C")) + + def get_sizes(self, index: int, parameters: GPTSamplingParameters) -> list[int]: + return [ + get_num_image_tokens( + *get_resize_dims( + *size.item(), + parameters.max_image_size, + parameters.max_image_size, + parameters.patch_size, + ), + parameters.patch_size, + image_break=parameters.image_break_token is not None, + image_end=parameters.image_end_token is not None, + ) + for size in self._sizes[self._count_cumsum[index] : self._count_cumsum[index + 1]] + ] + + def get_unshifted_positions_and_sizes( + self, index: int, parameters: GPTSamplingParameters + ) -> list[tuple[int, int]]: + return [ + (position, size) + for position, size in zip( + self._positions[self._count_cumsum[index] : self._count_cumsum[index + 1]], + self.get_sizes(index, parameters), + strict=True, + ) + ] + + +def get_num_image_tokens(height: int, width: int, patch_size: int, image_break: bool, image_end: bool) -> int: + """ + Calculate the number of image tokens. + If image_break is True, we consider 1 additional token after every row of patches. + """ + height_patches = div(height, patch_size) + width_patches = div(width, patch_size) + num_tokens = height_patches * width_patches + if image_break: + num_tokens += height_patches + elif image_end: + num_tokens += 1 + return num_tokens + + +def get_resize_dims(height: int, width: int, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: + """ + Calculate the new dimensions for resizing an image while maintaining the aspect ratio. + If the image is larger than the max dimensions, it will be resized to fit within them. + If the image is smaller, it will be resized to the nearest multiple of the patch size. + """ + ratio = max(height / max_height, width / max_width) + if ratio > 1: + # Resize to fit within max dimensions + height = int(height / ratio) + width = int(width / ratio) + return patch_size * math.ceil(height / patch_size), patch_size * math.ceil(width / patch_size) diff --git a/fast_llm/data/dataset/gpt/components/spans.py b/fast_llm/data/dataset/gpt/components/spans.py new file mode 100644 index 000000000..a4c331e00 --- /dev/null +++ b/fast_llm/data/dataset/gpt/components/spans.py @@ -0,0 +1,73 @@ +import io +import typing + +import numpy as np + +from fast_llm.data.dataset.gpt.components.config import GPTMemmapDatasetHeader +from fast_llm.data.dataset.gpt.memmap import BufferOffset, ShiftMap +from fast_llm.data.dataset.gpt.sampled import GPTSample +from fast_llm.utils import Assert + + +class GPTSpansDatasetComponent: + def __init__( + self, + header: GPTMemmapDatasetHeader, + index_binary_buffer: memoryview, + binary_buffer: memoryview, + offset: BufferOffset, + ): + self._header = header + self._index_binary_buffer = index_binary_buffer + self._binary_buffer = binary_buffer + + self._count_cumsum = np.frombuffer( + self._index_binary_buffer, + dtype=np.int32, + count=self._header.num_documents + 1, + offset=offset.value, + ) + offset.value += self._count_cumsum.nbytes + self._spans = np.frombuffer( + self._index_binary_buffer, + dtype=np.int32, + count=self._count_cumsum[-1] * 2, + offset=offset.value, + ).reshape(-1, 2) + offset.value += self._spans.nbytes + + def get(self, index: int, start_offset: int, end_offset: int, shift_map: ShiftMap) -> list[tuple[int, int]]: + loss_masking_spans = [] + for span_begin, span_end in self._spans[self._count_cumsum[index] : self._count_cumsum[index + 1]].tolist(): + span_begin = max(shift_map.shift(span_begin), start_offset) - start_offset + span_end = min(shift_map.shift(span_end), end_offset - 1) - start_offset + if span_end > span_begin: + loss_masking_spans.append((span_begin, span_end)) + return loss_masking_spans + + @classmethod + def write_document_and_gather_index( + cls, document: GPTSample, index_data: dict[str, typing.Any], binary_stream: io.BufferedWriter + ): + has_spans = document.loss_masking_spans is not None + if "has_span" in index_data: + Assert.eq(index_data["has_span"], has_spans) + else: + index_data["has_span"] = has_spans + if has_spans: + if "spans" not in index_data: + index_data["spans"] = [] + index_data["spans"].extend(document.loss_masking_spans) + if "spans_cumsum" not in index_data: + index_data["spans_cumsum"] = [0] + index_data["spans_cumsum"].append(len(index_data["spans"])) + + @classmethod + def write_index(self, index_data: dict[str, typing.Any], index_stream: io.BufferedWriter): + if index_data["has_spans"]: + # Should be ok, checking just in case. + Assert.leq(index_data["spans_cumsum"][-1], np.iinfo(np.int32).max) + Assert.eq(len(index_data["spans_cumsum"]), index_data["num_documents"] + 1) + Assert.eq(len(index_data["spans"]), index_data["spans_cumsum"][-1]) + index_stream.write(np.array(index_data["spans_cumsum"], dtype=np.int32).tobytes(order="C")) + index_stream.write(np.vstack(index_data["spans"], dtype=np.int32).tobytes(order="C")) diff --git a/fast_llm/data/dataset/gpt/components/tokens.py b/fast_llm/data/dataset/gpt/components/tokens.py new file mode 100644 index 000000000..3a91fef54 --- /dev/null +++ b/fast_llm/data/dataset/gpt/components/tokens.py @@ -0,0 +1,63 @@ +import io +import typing + +import numpy as np + +from fast_llm.data.dataset.gpt.components.config import GPTMemmapDatasetHeader +from fast_llm.data.dataset.gpt.memmap import BufferOffset, ShiftMap +from fast_llm.data.dataset.gpt.sampled import GPTSample +from fast_llm.utils import Assert + + +class GPTTokensDatasetComponent: + def __init__( + self, + header: GPTMemmapDatasetHeader, + index_binary_buffer: memoryview, + binary_buffer: memoryview, + offset: BufferOffset, + ): + self._header = header + self._index_binary_buffer = index_binary_buffer + self._binary_buffer = binary_buffer + self.sizes = np.frombuffer( + self._index_binary_buffer, dtype=np.int32, count=self._header.num_documents, offset=offset.value + ) + self._item_size = self._header.token_data_type.numpy.itemsize + offset.value += self.sizes.nbytes + + def get( + self, index: int, start_offset: int, end_offset: int, shift_map: ShiftMap, buffer_offset: BufferOffset + ) -> np.ndarray: + unshifted_start_offset = shift_map.unshift(start_offset) + token_ids = np.frombuffer( + self._binary_buffer, + dtype=self._header.token_data_type, + count=shift_map.unshift(end_offset) - unshifted_start_offset, + offset=buffer_offset.value + unshifted_start_offset * self._item_size, + ) + buffer_offset.value += self.sizes[index] * self._item_size + return token_ids + + @classmethod + def write_document_and_gather_index( + cls, document: GPTSample, index_data: dict[str, typing.Any], binary_stream: io.BufferedWriter + ): + if "token_data_type" in index_data: + Assert.eq(document.token_ids.dtype, index_data["token_data_type"]) + else: + index_data["token_data_type"] = document.token_ids.dtype + if "document_lengths" not in index_data: + index_data["document_lengths"] = [] + index_data["document_lengths"].append(document_length := len(document.token_ids)) + if "num_tokens" not in index_data: + index_data["num_tokens"] = 0 + index_data["num_tokens"] += document_length + + # Write document to binary file + binary_stream.write(document.token_ids.tobytes(order="C")) + + @classmethod + def write_index(self, index_data: dict[str, typing.Any], index_stream: io.BufferedWriter): + # Document (tokens) lengths. + index_stream.write(np.array(index_data["document_lengths"], dtype=np.int32).tobytes(order="C")) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 656cd7d24..8bd9dde0f 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -1,80 +1,47 @@ import dataclasses -import enum import pathlib import time import typing import yaml -from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none +from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.data.config import TokenizerConfig from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset from fast_llm.data.dataset.config import ( - BlendedDatasetConfig, - ConcatenatedDatasetConfig, - DatasetSliceConfig, IndexedDatasetConfig, SamplableDatasetConfig, SampledDatasetConfig, - SampledDatasetUpdateConfig, - SamplingConfig, SamplingData, SamplingParameters, ) +from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.data.dataset.gpt.indexed import GPTConcatenatedDataset, GPTDatasetSlice, GPTIndexedDataset + from fast_llm.data.dataset.gpt.fim import GPTFimDataset from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.gpt.random import GPTRandomDataset - from fast_llm.data.tokenizer import Tokenizer -class ShufflingType(str, enum.Enum): - # Shuffle all epochs together. Not extendable. - full = "full" - # Shuffle all epochs separately. Default mode, recommended if the dataset doesn't come pre-shuffled. - epoch = "epoch" - # Shuffle all epochs except the first one. Recommended for pre-shuffled datasets, especially big ones. - skip_first_epoch = "skip_first_epoch" - # Disable shuffling entirely. - disabled = "disabled" - - -@config_class() -class GPTSamplingConfig(SamplingConfig): - """ - A dataset-dependent configuration for sampling. - """ - - gpu: bool = Field( - default=True, - desc="Enable fast sampling on GPU." - " Note that random sampling works differently on GPU," - " so the sample won't match the CPU equivalent.", - hint=FieldHint.feature, - ) - shuffle: ShufflingType = Field( - default=ShufflingType.epoch, - desc="Shuffling strategy.", - hint=FieldHint.feature, - ) +@dataclasses.dataclass(kw_only=True) +class ImageSamplingParameters: + patch_size: int | None = None + max_image_size: int | None = None + image_break_token: int | None = None + image_end_token: int | None = None @dataclasses.dataclass(kw_only=True) -class GPTSamplingParameters(SamplingParameters): +class GPTSamplingParameters(SamplingParameters, ImageSamplingParameters): """ Sampling parameters set externally to the dataset and data, ex. determined by the trainer or model. """ - sequence_length: int vocab_size: int use_loss_masking_spans: bool = False use_preference_loss_spans: bool = False - cross_document_attention: bool = True - truncate_documents: bool = True - # How many extra tokens to add to the sequence length. - # This is used to provide labels even for the last tokens in the sequence. - extra_tokens: int = 1 + use_images: bool = False @dataclasses.dataclass(kw_only=True) @@ -84,29 +51,11 @@ class GPTSamplingData(SamplingData): usage-dependent ones (`GPTSamplingParameters`), and others set by the `Data`. """ - config: GPTSamplingConfig parameters: GPTSamplingParameters - tokenizer: "Tokenizer" - - -@config_class(registry=True) -class GPTSampledDatasetConfig(SampledDatasetConfig): - pass - - -@config_class() -class GPTSamplableDatasetConfig(SamplableDatasetConfig, GPTSampledDatasetConfig): - pass - -@config_class() -class GPTIndexedDatasetConfig(GPTSamplableDatasetConfig, IndexedDatasetConfig): - def build(self) -> "GPTIndexedDataset": - raise NotImplementedError() - -@config_class(dynamic_type={GPTSampledDatasetConfig: "random"}) -class GPTRandomDatasetConfig(GPTSamplableDatasetConfig): +@config_class(dynamic_type={SampledDatasetConfig: "random"}) +class GPTRandomDatasetConfig[SampleType: LanguageModelSample](SamplableDatasetConfig[SampleType]): _abstract: typing.ClassVar[bool] = False name: str = Field( default="dummy", @@ -114,14 +63,14 @@ class GPTRandomDatasetConfig(GPTSamplableDatasetConfig): hint=FieldHint.core, ) - def build(self) -> "GPTRandomDataset": + def build(self) -> "GPTRandomDataset[SampleType]": from fast_llm.data.dataset.gpt.random import GPTRandomDataset - return GPTRandomDataset(self.name) + return GPTRandomDataset[SampleType](self.name) -@config_class(dynamic_type={GPTSampledDatasetConfig: "memmap"}) -class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig): +@config_class(dynamic_type={SampledDatasetConfig: "memmap"}) +class GPTMemmapDatasetConfig[SampleType: LanguageModelSample](IndexedDatasetConfig[SampleType]): _abstract: typing.ClassVar[bool] = False path: pathlib.Path = Field( default=None, @@ -138,50 +87,22 @@ class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig): desc="Expected number of tokens in the dataset.", hint=FieldHint.optional, ) + num_pixels: int | None = Field( + default=None, + desc="Expected number of pixels in the dataset.", + hint=FieldHint.optional, + ) - def build(self) -> "GPTMemmapDataset": + def build(self) -> "GPTMemmapDataset[SampleType]": from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset - return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens) - - -@config_class(dynamic_type={GPTSampledDatasetConfig: "concatenated"}) -class GPTConcatenatedDatasetConfig(ConcatenatedDatasetConfig, GPTIndexedDatasetConfig): - _abstract: typing.ClassVar[bool] = False - datasets: list[GPTIndexedDatasetConfig] = FieldUpdate() - - def build(self) -> "GPTConcatenatedDataset": - from fast_llm.data.dataset.gpt.indexed import GPTConcatenatedDataset - - return self._build(GPTConcatenatedDataset) + return GPTMemmapDataset[SampleType]( + str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens, self.num_pixels + ) -@config_class(dynamic_type={GPTSampledDatasetConfig: "slice"}) -class GPTDatasetSliceConfig(DatasetSliceConfig, GPTIndexedDatasetConfig): - _abstract: typing.ClassVar[bool] = False - dataset: GPTIndexedDatasetConfig = FieldUpdate() - - def build(self) -> "GPTDatasetSlice": - from fast_llm.data.dataset.gpt.indexed import GPTDatasetSlice - - return self._build(GPTDatasetSlice) - - -@config_class(dynamic_type={GPTSampledDatasetConfig: "sampled"}) -class GPTSampledDatasetUpdateConfig(SampledDatasetUpdateConfig, GPTSampledDatasetConfig): - _abstract = False - sampling: GPTSamplingConfig = FieldUpdate() - dataset: GPTSampledDatasetConfig = FieldUpdate() - - -@config_class(dynamic_type={GPTSampledDatasetConfig: "blended"}) -class GPTBlendedDatasetConfig(BlendedDatasetConfig, GPTSampledDatasetConfig): - _abstract: typing.ClassVar[bool] = False - datasets: list[GPTSampledDatasetConfig] = FieldUpdate() - - -@config_class(dynamic_type={GPTSampledDatasetConfig: "file"}) -class GPTDatasetFromFileConfig(GPTSamplableDatasetConfig): +@config_class(dynamic_type={SampledDatasetConfig: "file"}) +class GPTDatasetFromFileConfig[SampleType: LanguageModelSample](SamplableDatasetConfig[SampleType]): _abstract: typing.ClassVar[bool] = False path: pathlib.Path = Field( default=None, @@ -189,18 +110,18 @@ class GPTDatasetFromFileConfig(GPTSamplableDatasetConfig): hint=FieldHint.core, ) - def build_and_sample(self, sampling: SamplingData) -> SampledDataset: + def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: config = self._load_config() return config.build_and_sample(sampling) - def build(self) -> SamplableDataset: + def build(self) -> SamplableDataset[SampleType]: config = self._load_config() - assert isinstance(config, GPTSamplableDatasetConfig) + assert isinstance(config, SamplableDatasetConfig) return config.build() - def _load_config(self): + def _load_config(self) -> SampledDatasetConfig[SampleType]: assert self.path.is_file(), f"File {self.path} does not exist." - return GPTSampledDatasetConfig.from_dict(self._convert_paths(yaml.safe_load(self.path.open("r")))) + return SampledDatasetConfig[SampleType].from_dict(self._convert_paths(yaml.safe_load(self.path.open("r")))) def _convert_paths(self, config): # Recursively convert paths relative to `self.path.parent` to make them relative to cwd. @@ -224,6 +145,10 @@ class FimConfig(Config): Configuration for FIM. """ + tokenizer: TokenizerConfig = Field( + desc="Configuration for the tokenizer.", + hint=FieldHint.feature, + ) rate: float = Field( # TODO: Use meaningful default now that fim is a wrapper? default=0.0, @@ -286,15 +211,15 @@ class FimConfig(Config): ) -@config_class(dynamic_type={GPTSampledDatasetConfig: "fim"}) -class GPTFimSampledDatasetConfig(GPTSampledDatasetConfig, FimConfig): +@config_class(dynamic_type={SampledDatasetConfig: "fim"}) +class GPTFimSampledDatasetConfig[SampleType: LanguageModelSample](SampledDatasetConfig[SampleType], FimConfig): """ Configuration for FIM. """ _abstract: typing.ClassVar[bool] = False - dataset: GPTSampledDatasetConfig = Field( + dataset: SampledDatasetConfig[SampleType] = Field( default=None, desc="The dataset to wrap with fim.", hint=FieldHint.core, @@ -303,14 +228,14 @@ class GPTFimSampledDatasetConfig(GPTSampledDatasetConfig, FimConfig): def build_and_sample( self, sampling: GPTSamplingData, - ) -> SampledDataset: + ) -> "GPTFimDataset[SampleType]": from fast_llm.data.dataset.gpt.fim import GPTFimDataset - return GPTFimDataset(self, self.dataset.build_and_sample(sampling), sampling) + return GPTFimDataset[SampleType](self, self.dataset.build_and_sample(sampling), sampling) -@config_class(dynamic_type={GPTSampledDatasetConfig: "test_slow"}) -class GPTTestSlowDatasetConfig(GPTSampledDatasetConfig): +@config_class(dynamic_type={SampledDatasetConfig: "test_slow"}) +class GPTTestSlowDatasetConfig[SampleType: LanguageModelSample](SampledDatasetConfig[SampleType]): """ A mock dataset that mimics a slow dataset creation on one rank, which may trigger a timeout. """ @@ -323,8 +248,8 @@ class GPTTestSlowDatasetConfig(GPTSampledDatasetConfig): hint=FieldHint.core, ) - def build_and_sample(self, sampling: SamplingData) -> SampledDataset: + def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: assert sampling.distributed.config.world_size > 1 if sampling.distributed.config.rank == 0: time.sleep(self.sleep) - return GPTRandomDatasetConfig().build_and_sample(sampling) + return GPTRandomDatasetConfig[SampleType]().build_and_sample(sampling) diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index 2b2c8b3be..fb6cae2ab 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -1,12 +1,14 @@ import numpy as np +import torch from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.gpt.config import FimConfig, GPTSamplingData -from fast_llm.data.dataset.gpt.sampled import GPTSample +from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.data.sample.token import TokenSample from fast_llm.engine.distributed.config import MAX_SEED -class GPTFimDataset(SampledDataset): +class GPTFimDataset[SampleType: LanguageModelSample](SampledDataset[SampleType]): """ An implementation of FIM (fill in the middle) post-processing of GPT datasets. Adapted from https://github.com/EleutherAI/gpt-neox/blob/FIM-clean/megatron/data/gpt2_dataset.py @@ -15,7 +17,7 @@ class GPTFimDataset(SampledDataset): def __init__( self, config: FimConfig, - dataset: SampledDataset, + dataset: SampledDataset[SampleType], sampling: GPTSamplingData, ): if sampling.parameters.use_loss_masking_spans: @@ -26,7 +28,7 @@ def __init__( self._dataset = dataset self._seed = sampling.config.seed - self._tokenizer = sampling.tokenizer + self._tokenizer = self._config.tokenizer.get_tokenizer() if self._tokenizer is None: raise ValueError("Fim requires a tokenizer") self._suffix_tok_id, self._prefix_tok_id, self._middle_tok_id, self._pad_tok_id = ( @@ -40,11 +42,18 @@ def __init__( def __len__(self) -> int: return len(self._dataset) - def __getitem__(self, idx: int) -> np.ndarray: - fim_token_ids = self._fim( - self._dataset[idx].token_ids, np.random.RandomState(seed=(self._seed + idx) % MAX_SEED) + def __getitem__(self, index: int) -> SampleType: + # TODO: Use torch methods to avoid back and forth. + return LanguageModelSample( + TokenSample( + torch.from_numpy( + self._fim( + self._dataset[index].tokens.tokens.numpy(), + np.random.RandomState(seed=(self._seed + index) % MAX_SEED), + ) + ) + ) ) - return GPTSample(fim_token_ids) @property def name(self) -> str: @@ -55,6 +64,7 @@ def _fim(self, sample: np.ndarray, np_rng: np.random.RandomState) -> np.ndarray: # TODO: permute segments in sample_list, before concatenating. sample_len = sample.shape[0] eod = self._tokenizer.eod + # TODO: Available through `tokens.lengths` segment_breaks = np.argwhere(sample == eod) # split sample by document if segment_breaks.shape != (0, 1): # then there is an EOD token in this example @@ -73,19 +83,19 @@ def _fim(self, sample: np.ndarray, np_rng: np.random.RandomState) -> np.ndarray: permuted = self._fim_split_and_permute_sequence(sample[curr_start_position:], np_rng) new_samples.append(permuted) - sample = np.concatenate(new_samples) + fim_sample = np.concatenate(new_samples) else: - sample = self._fim_split_and_permute_sequence(sample, np_rng) + fim_sample = self._fim_split_and_permute_sequence(sample, np_rng) # Truncate or pad sequence to max-length - diff = sample.shape[0] - sample_len + diff = fim_sample.shape[0] - sample_len if diff > 0: # too long - sample = sample[:sample_len] + fim_sample = fim_sample[:sample_len] elif diff < 0: # too short - sample = np.concatenate([sample, np.full((-1 * diff), self._pad_tok_id)]) + fim_sample = np.concatenate([fim_sample, np.full((-1 * diff), self._pad_tok_id)]) # noqa - assert sample.shape[0] == sample_len - return sample + assert fim_sample.shape[0] == sample_len + return fim_sample.astype(sample.dtype) def _fim_split_and_permute_sequence(self, sequence: np.ndarray, np_rng: np.random.RandomState) -> np.ndarray: """ @@ -158,9 +168,9 @@ def _fim_permute_sequence( middle = contents[boundaries[0] : boundaries[1]] suffix = contents[boundaries[1] :] - prefix = np.array([*self._tokenizer.tokenize(prefix, end=False)], dtype=np.int64) - middle = np.array([*self._tokenizer.tokenize(middle, begin=False, end=False)], dtype=np.int64) - suffix = np.array([*self._tokenizer.tokenize(suffix, begin=False)], dtype=np.int64) + prefix = np.array([*self._tokenizer.tokenize(prefix, add_eos=False)], dtype=sequence.dtype) + middle = np.array([*self._tokenizer.tokenize(middle, add_bos=False, add_eos=False)], dtype=sequence.dtype) + suffix = np.array([*self._tokenizer.tokenize(suffix, add_bos=False)], dtype=sequence.dtype) # here we truncate each given segment to fit the same length as it was before # A consequence is that we never reach the end of a file? diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py deleted file mode 100644 index 896229772..000000000 --- a/fast_llm/data/dataset/gpt/indexed.py +++ /dev/null @@ -1,60 +0,0 @@ -import abc -import typing - -import numpy as np - -from fast_llm.data.dataset.gpt.config import GPTSamplingData -from fast_llm.data.dataset.indexed import ConcatenatedDataset, DatasetSlice, IndexedDataset - -if typing.TYPE_CHECKING: - from fast_llm.data.dataset.gpt.sampled import GPTSampledIndexedDataset - - -class GPTIndexedDataset(IndexedDataset): - @abc.abstractmethod - def get_document_sizes(self) -> np.ndarray: - """ - The size of each document in the dataset. - The resulting array could be very large, so this method should be called cautiously, - and derived classes should try to avoid holding the whole array im memory. - """ - - @abc.abstractmethod - def get_document_size(self, index: int) -> int: - """ - The size of a document in the dataset. - """ - - def sample(self, sampling: GPTSamplingData) -> "GPTSampledIndexedDataset": - from fast_llm.data.dataset.gpt.sampled import GPTSampledIndexedDataset - - return GPTSampledIndexedDataset(self, sampling) - - -class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[IndexedDatasetType], GPTIndexedDataset): - """ - A GPT dataset, which reads samples from (a split of) a `MMapIndexedDataset` pointing to a GPT dataset. - """ - - _dataset: GPTIndexedDataset - - def get_document_sizes(self) -> np.ndarray: - # TODO: This can be really big. - return self._dataset.get_document_sizes()[self._begin : self._end] - - def get_document_size(self, index: int) -> int: - return self._dataset.get_document_size(self._begin + index) - - -class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset]( - ConcatenatedDataset[IndexedDatasetType], GPTIndexedDataset -): - _datasets: list[GPTIndexedDataset] - - def get_document_sizes(self) -> np.ndarray: - # TODO: This can be really big. - return np.concatenate([dataset.get_document_sizes() for dataset in self._datasets]) - - def get_document_size(self, index: int) -> int: - dataset = np.searchsorted(self._dataset_splits[1:], index, side="right") - return self._datasets[dataset].get_document_size(index - self._dataset_splits[dataset].item()) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index f39fd56f4..99d9957f8 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -1,16 +1,52 @@ +import functools +import json import pathlib import struct import typing import numpy as np +from fast_llm.data.dataset.gpt.components.config import GPTMemmapDatasetHeader +from fast_llm.data.dataset.gpt.components.images import GPTImageDatasetComponent +from fast_llm.data.dataset.gpt.components.spans import GPTSpansDatasetComponent +from fast_llm.data.dataset.gpt.components.tokens import GPTTokensDatasetComponent +from fast_llm.data.dataset.gpt.config import GPTSamplingParameters from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset from fast_llm.data.dataset.gpt.sampled import GPTSample -from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, MEMMAP_DTYPES_INV, MEMMAP_INDEX_HEADER -from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_INDEX_HEADER from fast_llm.utils import Assert, div +class BufferOffset: + # This makes offsets mutable. + def __init__(self, value: int): + self.value: int = value + + +class ShiftMap: + """ + A map between original and shifted token indices (i.e., accounting for extra content such as images). + Also serves as a cache so we don't have to recompute positions and sizes every time. + """ + + def __init__(self, positions_and_sizes: list[tuple[int, int]]): + self._positions_and_sizes = positions_and_sizes + + @functools.cached_property + def shifted_positions(self) -> list[int]: + return [self.shift(position) for position, _ in self._positions_and_sizes] + + def shift(self, index: int) -> int: + return index + sum(size for position, size in self._positions_and_sizes if index > position) + + def unshift(self, index: int) -> int: + return index - sum( + size + for shifted_position, (_, size) in zip(self.shifted_positions, self._positions_and_sizes, strict=True) + if shifted_position < index + ) + + class GPTMemmapDataset(GPTIndexedDataset): """ A memory map dataset, which handles lazy loading of a pre-processed dataset in the Megatron-LM format, @@ -26,293 +62,241 @@ def __init__( prefix: pathlib.Path | str, num_documents: int | None = None, num_tokens: int | None = None, + num_pixels: int | None = None, ): - self._init(name, prefix, num_documents, num_tokens) + self._init(name, prefix, num_documents, num_tokens, num_pixels) - def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None, num_tokens: int | None) -> None: + def _init( + self, + name: str, + prefix: pathlib.Path | str, + num_documents: int | None = None, + num_tokens: int | None = None, + num_pixels: int | None = None, + ) -> None: super().__init__() self._name = name self._prefix = pathlib.Path(prefix) - self._has_spans = 0 - self._has_preference_spans = False with self._prefix.with_suffix(".idx").open("rb") as stream: Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER, msg=f"File: {stream.name}") self._version = struct.unpack("= 2: - self._has_spans = struct.unpack("= 3: - self._has_preference_spans = struct.unpack("= 2 and bool(struct.unpack("= 3 and bool(struct.unpack("= 4 and bool(struct.unpack("= 2: - self._spans = [] - self._num_spans = np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=self._num_documents, - offset=offset + self._document_sizes.nbytes + self._pointers.nbytes, - ) - span_offset = offset + self._document_sizes.nbytes + self._pointers.nbytes + self._num_spans.nbytes - self._num_spans_cumsum = np.r_[0, np.cumsum(self._num_spans[:-1], dtype=np.int64)] - for idx in range(self._num_documents): - self._spans.append( - np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=self._num_spans[idx] * 2, - offset=span_offset + self._num_spans_cumsum[idx] * 2 * np.dtype(np.int32).itemsize, - ).reshape(-1, 2) - ) - - # read preference spans - self._chosen_spans = None - self._rejected_spans = None - if self._has_preference_spans and self._version >= 3: - self._chosen_spans = [] - self._rejected_spans = [] - chosen_span_offset = offset + self._document_sizes.nbytes + self._pointers.nbytes - for idx in range(self._num_documents): - self._chosen_spans.append( - np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=2, - offset=chosen_span_offset + idx * 2 * np.dtype(np.int32).itemsize, - ) - ) - - rejected_span_offset = ( - offset + self._document_sizes.nbytes + self._pointers.nbytes + np.array(self._chosen_spans).nbytes - ) - for idx in range(self._num_documents): - self._rejected_spans.append( - np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=2, - offset=rejected_span_offset + idx * 2 * np.dtype(np.int32).itemsize, - ) - ) - - self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C") - self._bin_buffer = memoryview(self._bin_buffer_mmap) - - self._num_tokens = div(self._bin_buffer_mmap.size, np.dtype(self._dtype).itemsize) + self._spans = ( + GPTSpansDatasetComponent(self._header, self._index_binary_buffer, self._binary_buffer, offset) + if self._header.has_spans + else None + ) + self._images = ( + GPTImageDatasetComponent(self._header, self._index_binary_buffer, self._binary_buffer, offset) + if self._header.has_images + else None + ) + + if num_pixels is not None: + Assert.eq(num_pixels, self._images.total_pixels) + + # TODO: Simplify. + self._num_tokens = ( + self._binary_buffer_mmap.size + if self._images is None + else self._binary_buffer_mmap.size - self._images.total_pixels + ) if num_tokens is not None: - assert self._num_tokens == num_tokens + Assert.eq(num_tokens, self._num_tokens) - def __getstate__(self) -> tuple[str, pathlib.Path, int | None, int | None]: - return (self._name, self._prefix, self._num_documents, self._num_tokens) + def __getstate__(self) -> tuple[str, pathlib.Path]: + return (self._name, self._prefix) - def __setstate__(self, state: tuple[str, pathlib.Path, int | None, int | None]): + def __setstate__(self, state: tuple[str, pathlib.Path]): self._init(*state) def __del__(self): if hasattr(self, "_bin_buffer_mmap"): - self._bin_buffer_mmap._mmap.close() # noqa - del self._bin_buffer_mmap + self._binary_buffer_mmap._mmap.close() # noqa + del self._binary_buffer_mmap if hasattr(self, "_index_bin_buffer"): - self._index_bin_buffer_mmap._mmap.close() # noqa - del self._index_bin_buffer_mmap + self._index_binary_buffer_mmap._mmap.close() # noqa + del self._index_binary_buffer_mmap def get( self, - idx: int, - offset: int = 0, - length: int | None = None, - use_loss_masking_spans: bool = False, - use_preference_loss_spans: bool = False, + index: int, + start_offset: int = 0, + end_offset: int | None = None, + parameters: GPTSamplingParameters | None = None, ) -> GPTSample: - token_ids = np.frombuffer( - self._bin_buffer, - dtype=self._dtype, - count=self._document_sizes[idx] - offset if length is None else length, - offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, - ) - sample_spans = None - if use_loss_masking_spans and self._spans is not None: - sample_spans = self._spans[idx] - - # filter spans that are outside the range of the selected tokens in the document - sample_spans = sample_spans[ - (sample_spans[:, 0] < offset + len(token_ids)) & (sample_spans[:, 1] >= offset) - ] - - # subtract by offset to normalize span boundaries - sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset # offset - sample_spans[:, 1] = np.minimum(sample_spans[:, 1], offset + len(token_ids) - 1) - offset - - chosen_span = None - rejected_span = None - - if use_preference_loss_spans: - if not self._has_preference_spans: - raise ValueError("No preference spans found in memmap dataset.") - elif self._has_preference_spans and self._chosen_spans is None: - raise ValueError("Failed to read chosen spans from memmap dataset.") - elif self._has_preference_spans and self._rejected_spans is None: - raise ValueError("Failed to read rejected spans from memmap dataset.") - else: - chosen_span = self._chosen_spans[idx] - # filter spans that are outside the range of the selected tokens in the document - chosen_span = chosen_span[(chosen_span[0] < offset + len(token_ids)) & (chosen_span[1] >= offset)][0] + if end_offset is None: + end_offset = self.get_document_size(index, parameters) - # subtract by offset to normalize span boundaries - chosen_span[0] = np.maximum(chosen_span[0], offset) - offset # offset - chosen_span[1] = np.minimum(chosen_span[1], offset + len(token_ids) - 1) - offset + shift_map = ShiftMap( + self._images.get_unshifted_positions_and_sizes(index, parameters) if parameters.use_images else [] + ) - rejected_span = self._rejected_spans[idx] + buffer_offset = BufferOffset(self._buffer_offsets[index].item()) + sample = GPTSample(token_ids=self._tokens.get(index, start_offset, end_offset, shift_map, buffer_offset)) - # filter spans that are outside the range of the selected tokens in the document - rejected_span = rejected_span[ - (rejected_span[0] < offset + len(token_ids)) & (rejected_span[1] >= offset) - ][0] + if parameters.use_loss_masking_spans: + sample.loss_masking_spans = self._spans.get(index, start_offset, end_offset, shift_map) - # subtract by offset to normalize span boundaries - rejected_span[0] = np.maximum(rejected_span[0], offset) - offset # offset - rejected_span[1] = np.minimum(rejected_span[1], offset + len(token_ids) - 1) - offset + if parameters.use_images: + sample.images, sample.image_positions = self._images.get( + index, start_offset, end_offset, shift_map, buffer_offset + ) - return GPTSample( - token_ids=token_ids, - loss_masking_spans=sample_spans, - chosen_span=chosen_span, - rejected_span=rejected_span, - ) + start_pos = 0 + sample_token_ids = [] + for idx, im_position in enumerate(sample.image_positions): + # add placeholder masked tokens for images + # if image_break_token is set, it is appended after every row + # if image_end_token is set, it is appended at the end of the image instead of image_break_token + text_part = sample.token_ids[start_pos:im_position] + if parameters.image_break_token is not None: + height, width = resized_image_lengths[idx] + num_patches_h = div(height, parameters.patch_size) + num_patches_w = div(width, parameters.patch_size) + image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64) + # account for break tokens after each row + for row in range(num_patches_h - 1): + position = (row + 1) * num_patches_w + row + image_token_array[position] = parameters.image_break_token + # handle the last row separately + last_row_position = num_patches_h * num_patches_w + num_patches_h - 1 + if parameters.image_end_token is not None: + image_token_array[last_row_position] = parameters.image_end_token + else: + image_token_array[last_row_position] = parameters.image_break_token + else: + image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64) + if parameters.image_end_token is not None: + image_token_array[-1] = parameters.image_end_token + sample_token_ids.append(np.concatenate([text_part, image_token_array], dtype=np.int64)) + text_tokens_added += len(text_part) + image_positions.append(text_tokens_added + image_tokens_added) + image_sizes[idx] + start_pos = im_position + + return sample @property def name(self) -> str: return self._name def __len__(self) -> int: - return self._num_documents + return self._header.num_documents - @property - def num_tokens(self) -> int: - return self._num_tokens - - def get_document_sizes(self) -> np.ndarray: + def get_document_sizes(self, parameters: GPTSamplingParameters | None = None) -> np.ndarray: """ The size of each document in the dataset. The resulting array could be very large, so this method should be called cautiously, and derived classes should try to avoid holding the whole array im memory. """ - return self._document_sizes + if parameters is not None and parameters.use_images: + # TODO: Optimize this. + return np.array([self.get_document_size(index, parameters) for index in range(self._header.num_documents)]) + return self._tokens.sizes + + def get_document_size(self, index: int, parameters: GPTSamplingParameters | None = None) -> int: + size = self._tokens.sizes[index].item() + if parameters is not None and parameters.use_images: + for _, size_ in self._images.get_positions_and_sizes(index, parameters): + size += size_ + return size + + def _shift_offset(self, offset, index: int, parameters: GPTSamplingParameters | None = None) -> int: + if parameters is not None and parameters.use_images: + offset += sum( + size for position, size in self._images.get_positions_and_sizes(index, parameters) if position < offset + ) + return offset - def get_document_size(self, index: int) -> int: - return self._document_sizes[index].item() + def _unshift_offset(self, offset, index: int, parameters: GPTSamplingParameters | None = None) -> int: + unshifted_offset = offset + if parameters is not None and parameters.use_images: + for position, size in self._images.get_positions_and_sizes(index, parameters): + shifted_position = self._shift_offset(position, index, parameters) + if shifted_position < offset: + unshifted_offset -= size + return unshifted_offset @classmethod def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GPTSample]): - # Initialize metadata - dtype = None + buffer_offsets = [] + index_data = {} num_documents = 0 - lengths = [] - pointers = [] - offset = 0 - # number of spans for each document - num_spans = [] - spans = [] - chosen_spans = [] - rejected_spans = [] + component_classes = (GPTTokensDatasetComponent, GPTSpansDatasetComponent, GPTImageDatasetComponent) prefix = pathlib.Path(prefix) prefix.parent.mkdir(parents=True, exist_ok=True) - # Write the binary data file (.bin) lazily - with prefix.with_suffix(".bin").open("wb") as bin_stream: + with prefix.with_suffix(".bin").open("wb") as binary_stream: + for document in documents: - # Infer dtype from the first document - if dtype is None: - dtype = document.token_ids.dtype - assert dtype is not None, "Document dtype could not be inferred from the data." - - # Ensure all documents have the same dtype - assert document.token_ids.dtype == dtype, f"Expected dtype {dtype}, got {document.token_ids.dtype}." - - # Write document to binary file - bin_stream.write(document.token_ids.tobytes(order="C")) - - # Update metadata - doc_length = len(document.token_ids) - lengths.append(doc_length) - pointers.append(offset) - if document.loss_masking_spans is not None: - num_spans.append(len(document.loss_masking_spans)) - spans.append(document.loss_masking_spans) - if document.chosen_span is not None: - chosen_spans.append(document.chosen_span) - if document.rejected_span is not None: - rejected_spans.append(document.rejected_span) - offset += doc_length * np.dtype(dtype).itemsize - num_documents += 1 + buffer_offsets.append(binary_stream.tell()) + for component_class in component_classes: + component_class.write_document_and_gather_index(document, index_data, binary_stream) - # Finalize metadata arrays - lengths = np.array(lengths, dtype=np.int32) - pointers = np.array(pointers, dtype=np.int64) - num_spans = np.array(num_spans, dtype=np.int32) - if len(spans) > 0: - spans = np.vstack(spans, dtype=np.int32) - else: - spans = np.array(spans, dtype=np.int32) - chosen_spans = np.array(chosen_spans, dtype=np.int32).reshape(-1, 2) - rejected_spans = np.array(rejected_spans, dtype=np.int32).reshape(-1, 2) + # TODO: Address + assert document.chosen_span is None and document.rejected_span is None + + num_documents += 1 # Write the index file (.idx) - with prefix.with_suffix(".idx").open("wb") as idx_stream: - idx_stream.write(MEMMAP_INDEX_HEADER) - # Indicates the version - # Version 2 optionally adds loss-masking spans - # Version 3 optionally adds chosen/rejected spans - idx_stream.write(struct.pack(" 0 else 0)) - # Flag to indicate whether preference loss-masking spans are present - idx_stream.write(struct.pack(" 0 and rejected_spans.size > 0 else 0)) - # Data type - idx_stream.write(struct.pack(" str: return self._name -class GPTRandomSampledDataset(SampledDataset): +class GPTRandomSampledDataset[SampleType: LanguageModelSample](SampledDataset[SampleType]): def __init__(self, sampling: GPTSamplingData, name: str): self._name = name self._seed = sampling.config.seed - self._sequence_length = sampling.parameters.sequence_length - self._vocab_size = sampling.parameters.vocab_size - self._num_samples = sampling.parameters.num_samples + self._parameters = sampling.parameters + # TODO: Support? + assert not self._parameters.use_loss_masking_spans + assert not self._parameters.use_preference_loss_spans + self._dtype = get_unsigned_integer_type(self._parameters.vocab_size).torch def __len__(self) -> int: - return self._num_samples - - def __getitem__(self, idx) -> np.ndarray: - return GPTSample( - np.random.RandomState(self._seed + 48576439 + 74593 * idx).randint( - 0, self._vocab_size, size=(self._sequence_length + 1,), dtype=np.int64 + return self._parameters.num_samples + + def __getitem__(self, index: int) -> SampleType: + # TODO: Sample in self._dtype (breaking) + return LanguageModelSample( + TokenSample( + torch.from_numpy( + np.random.RandomState(self._seed + 48576439 + 74593 * index).randint( + 0, + self._parameters.vocab_size, + size=(self._parameters.sequence_length + self._parameters.extra_tokens,), + ) + ).to(self._dtype), ) ) diff --git a/fast_llm/data/dataset/indexed.py b/fast_llm/data/dataset/indexed.py index 09ed52779..c6eac9e28 100644 --- a/fast_llm/data/dataset/indexed.py +++ b/fast_llm/data/dataset/indexed.py @@ -1,20 +1,37 @@ import abc -import typing -import numpy as np +import torch from fast_llm.data.dataset.abstract import SamplableDataset +from fast_llm.data.dataset.config import SamplingData, SamplingParameters +from fast_llm.data.sample.abstract import Sample from fast_llm.utils import Assert, padded_cumsum -class IndexedDataset(SamplableDataset): +class IndexedDataset[SampleType: Sample](SamplableDataset[SampleType]): """ A dataset containing a list of samples. TODO: Move sampling responsibility here? """ @abc.abstractmethod - def get(self, index: int, *args, **kwargs) -> typing.Any: + def get_document_sizes(self) -> torch.Tensor: + """ + The size of each document in the dataset. + The resulting array could be very large, so this method should be called cautiously, + and derived classes should try to avoid holding the whole array im memory. + """ + + @abc.abstractmethod + def get_document_size(self, index: int) -> int: + """ + The size of a document in the dataset. + """ + + @abc.abstractmethod + def get_document( + self, index: int, begin: int = 0, end: int | None = None, parameters: SamplingParameters | None = None + ) -> SampleType: pass @abc.abstractmethod @@ -23,13 +40,18 @@ def __len__(self) -> int: Number of samples in the dataset. """ + def sample(self, sampling: SamplingData) -> "GPTSampledIndexedDataset": + from fast_llm.data.dataset.sampled import SampledIndexedDataset + + return SampledIndexedDataset(self, sampling) -class DatasetSlice[IndexedDatasetType: IndexedDataset](IndexedDataset): + +class DatasetSlice[SampleType: Sample](IndexedDataset[SampleType]): def __init__( self, name: str, - dataset: IndexedDataset, + dataset: IndexedDataset[SampleType], begin: int | None = None, end: int | None = None, ): @@ -46,15 +68,22 @@ def __init__( except Exception as e: raise AssertionError(f"Invalid document indices for dataset {name} with length {num_samples}") from e - def get( - self, document: int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False - ) -> typing.Any: + def get_document_sizes(self) -> torch.Tensor: + # TODO: This can be really big. + return self._dataset.get_document_sizes()[self._begin : self._end] + + def get_document_size(self, index: int) -> int: + return self._dataset.get_document_size(self._begin + index) + + def get_document( + self, index: int, begin: int = 0, end: int | None = None, parameters: SamplingParameters | None = None + ) -> SampleType: """ Get the sample (document) with the given index (in the dataset slice), - optionally sub-sampled to a specific offset (starting point) and maximum length + optionally subsampled to a specific offset (starting point) and maximum length (end = min(offset + length, sample_length). """ - return self._dataset.get(document + self._begin, offset, length, use_loss_masking_spans) + return self._dataset.get_document(index + self._begin, begin, end, parameters) def __len__(self) -> int: return self._end - self._begin @@ -64,24 +93,36 @@ def name(self) -> str: return self._name -class ConcatenatedDataset[IndexedDatasetType: IndexedDataset](IndexedDataset): +class ConcatenatedDataset[SampleType: Sample](IndexedDataset[SampleType]): def __init__( self, name: str, - datasets: list[IndexedDataset], + datasets: list[IndexedDataset[SampleType]], ): self._name = name self._datasets = datasets sizes = [len(dataset) for dataset in self._datasets] - self._dataset_splits = padded_cumsum(sizes) + self._dataset_splits = torch.from_numpy(padded_cumsum(sizes)) def __len__(self) -> int: return self._dataset_splits[-1].item() - def get(self, index: int, *args, **kwargs): - dataset = np.searchsorted(self._dataset_splits[1:], index, side="right") - return self._datasets[dataset].get(index - self._dataset_splits[dataset].item(), *args, **kwargs) + def get_document_sizes(self) -> torch.Tensor: + # TODO: This can be really big. + return torch.cat([dataset.get_document_sizes() for dataset in self._datasets]) + + def get_document_size(self, index: int) -> int: + dataset = torch.searchsorted(self._dataset_splits[1:], index, side="right") + return self._datasets[dataset].get_document_size(index - self._dataset_splits[dataset].item()) + + def get_document( + self, index: int, begin: int = 0, end: int | None = None, parameters: SamplingParameters | None = None + ) -> SampleType: + dataset = torch.searchsorted(self._dataset_splits[1:], index, side="right") + return self._datasets[dataset].get_document( + index - self._dataset_splits[dataset].item(), begin, end, parameters + ) @property def name(self) -> str: diff --git a/fast_llm/data/dataset/monitor.py b/fast_llm/data/dataset/monitor.py index 86bc080fe..01f3195e4 100644 --- a/fast_llm/data/dataset/monitor.py +++ b/fast_llm/data/dataset/monitor.py @@ -1,8 +1,8 @@ import logging import time -import typing from fast_llm.data.dataset.abstract import SampledDataset +from fast_llm.data.sample.abstract import Sample try: from fast_llm.csrc.data import build_blending_indices # noqa @@ -14,7 +14,7 @@ logger = logging.getLogger(__name__) -class DatasetMonitor(SampledDataset): +class DatasetMonitor[SampleType: Sample](SampledDataset[SampleType]): """ A blended sampling of multiple sampled datasets, where each dataset is sampled with the provided probability. The sampling order of each dataset is respected, but there is no strict guarantee @@ -24,7 +24,7 @@ class DatasetMonitor(SampledDataset): def __init__( self, - dataset: SampledDataset, + dataset: SampledDataset[SampleType], data_sample_warn_time_ms: float, ): self._dataset = dataset @@ -33,19 +33,19 @@ def __init__( def __len__(self) -> int: return len(self._dataset) - def __getitem__(self, idx) -> typing.Any: + def __getitem__(self, index: int) -> SampleType: start_time = time.perf_counter() try: - sample = self._dataset[idx] + sample = self._dataset[index] sample_time = (time.perf_counter() - start_time) * 1000 if sample_time > self._data_sample_warn_time_ms: logger.warning( - f"Sample {idx} from dataset {self._dataset.name})" f" took {sample_time:,.2f} ms to load" + f"Sample {index} from dataset {self._dataset.name})" f" took {sample_time:,.2f} ms to load" ) return sample except Exception: - logger.error(f"Failed to get sample {idx} from dataset {self._dataset.name}") + logger.error(f"Failed to get sample {index} from dataset {self._dataset.name}") raise @property diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/sampled.py similarity index 65% rename from fast_llm/data/dataset/gpt/sampled.py rename to fast_llm/data/dataset/sampled.py index 95006f18e..8d22c6b99 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/sampled.py @@ -10,14 +10,16 @@ import yaml from fast_llm.data.dataset.abstract import SampledDataset -from fast_llm.data.dataset.gpt.config import GPTSamplingData, ShufflingType -from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset +from fast_llm.data.dataset.config import SamplingData, ShufflingType +from fast_llm.data.dataset.indexed import IndexedDataset +from fast_llm.data.sample.abstract import Sample from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.engine.config_utils.run import log_main_rank -from fast_llm.utils import Assert +from fast_llm.layers.vision.preprocessing import get_num_image_tokens, get_resize_dims +from fast_llm.utils import Assert, div try: - from fast_llm.csrc.data import build_padded_token_cumsum # noqa + from fast_llm.csrc.data import build_padded_token_cumsum, build_sample_idx # noqa _extension_available = True except ImportError: @@ -29,6 +31,8 @@ @dataclasses.dataclass class GPTSample: token_ids: np.ndarray + images: list[np.ndarray] | None = None + image_positions: np.ndarray | None = None loss_masking_spans: np.ndarray | None = None chosen_span: np.ndarray | None = None rejected_span: np.ndarray | None = None @@ -75,23 +79,30 @@ def _lazy_load(self): TOKEN_CUMSUM_RATE = 10 -class GPTSampledIndexedDataset(SampledDataset): +class SampledIndexedDataset[SampleType: Sample](SampledDataset[SampleType]): """ - A sampled GPT dataset. + A sampled dataset. """ def __init__( self, - indexed_dataset: GPTIndexedDataset, - sampling: GPTSamplingData, + indexed_dataset: IndexedDataset[SampleType], + sampling: SamplingData, ): - assert isinstance(sampling, GPTSamplingData) self._indexed_dataset = indexed_dataset self._config = sampling.config self._parameters = sampling.parameters self._truncate_documents = sampling.parameters.truncate_documents self._device = torch.device("cuda" if self._config.gpu else "cpu") + # TODO: address + assert not self._parameters.use_preference_loss_spans + + if self._parameters.use_images: + assert not self._truncate_documents, ( + "Truncating documents with images is not yet supported." " Please turn off truncation to use images." + ) + if sampling.cache_directory is None: self._document_shuffling = MemmapArray() self._token_cumsum_shuffled = MemmapArray() @@ -111,58 +122,119 @@ def __init__( ) # TODO: Names are confusing self._document_shuffling = MemmapArray(base_path.with_name(base_path.name + "_shuffling.npy")) - self._token_cumsum_shuffled = MemmapArray(base_path.with_name(base_path.name + "_shuffled_cumsum.npy")) - self._token_cumsum_unshuffled = MemmapArray(base_path.with_name(base_path.name + "_unshuffled_cumsum.npy")) self._yaml_path = base_path.with_suffix(".yaml") - # keep document sizes and len filtered docs for preference loss masking - if self._parameters.use_preference_loss_spans: - self._document_sizes = MemmapArray(base_path.with_name(base_path.name + "_doc_sizes.npy")) - self._doc_length_filtered_indicies = MemmapArray( - base_path.with_name(base_path.name + "_doc_length_filtered_indices.npy") - ) + self._token_cumsum_shuffled = MemmapArray(base_path.with_name(base_path.name + "_shuffled_cumsum.npy")) + self._token_cumsum_unshuffled = MemmapArray(base_path.with_name(base_path.name + "_unshuffled_cumsum.npy")) # Sample or validate the dataset of a given rank. if sampling.distributed.config.rank == sampling.get_next_rank(): self._sample() # No barrier yet to allow running in parallel. - # There needs to be one before calling `__getitem__`, normally handled through `GPTData`. + # There needs to be one before calling `__getitem__`, normally handled through `Data`. def _sample(self) -> None: """ - Create a `GPTSampledDataset` with the requested parameters. + Create a `SampledDataset` with the requested parameters. """ - # Get the document sizes, the main information needed for sampling. - document_sizes = torch.from_numpy(self._indexed_dataset.get_document_sizes()).to(self._device) - documents_per_epoch = document_sizes.numel() - tokens_per_epoch = document_sizes.sum().item() + # Get the size each document, the main information needed for sampling. + # Note: "document" may refer to more than just text. + document_sizes = torch.from_numpy(self._indexed_dataset.get_document_sizes(self._parameters)).to(self._device) + + documents_per_epoch, tokens_per_epoch, long_docs_filter = self._get_epoch_size(document_sizes) + num_epochs, shuffled_epochs = self._get_epoch_count(documents_per_epoch, tokens_per_epoch) + + shuffled_documents = documents_per_epoch * shuffled_epochs + unshuffled_epochs = num_epochs - shuffled_epochs + + yaml_data, cached = self._get_and_compare_yaml_data(documents_per_epoch, tokens_per_epoch, unshuffled_epochs) + if cached: + return + + if shuffled_documents > 1e8: + warnings.warn( + f"Shuffling {shuffled_documents:.2e} documents for dataset {self._indexed_dataset.name}." + f" This may take a while and/or use an excessive amount of memory." + ) + elif documents_per_epoch > 1e8: + # TODO: Most of the damage is already done in `get_document_sizes`. Find a way to warn earlier? + warnings.warn( + f"The dataset {self._indexed_dataset.name} contains {documents_per_epoch:.2e} documents." + f" Sampling may take a while and/or use an excessive amount of memory." + ) - # Calculate basic stats. - if not self._truncate_documents: + document_shuffling = self._get_document_shuffling(documents_per_epoch, shuffled_documents, shuffled_epochs) + + # To get a sample on the fly we need to know where it begins, + # and this is a non-trivial information because the documents have variable length. + # The starting point `(document[idx], token[idx])` corresponds to the `(idx * sequence_length)` th token, i.e. + # `document_sizes[all_document_index][:document[idx]].sum() + token[idx] == idx * sequence_length`. + # This can be computed quickly provided we know a (partial) sum close to `(idx * sequence_length)`. + # So it is enough to pre-compute the (zero-padded) token cumsum at regular intervals `TOKEN_CUMSUM_RATE`. + # Using `TOKEN_CUMSUM_RATE > 1` reduces pre-computation overhead at the cost of runtime computation. + # Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::TOKEN_CUMSUM_RATE]))` + + # TODO: Allowing for max 100% extra tokens for padding, is that enough? + cumsum_dtype = get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs) + if unshuffled_epochs > 0: + token_cumsum_unshuffled, unshuffled_tokens = self._get_token_cumsum(document_sizes, 0, cumsum_dtype) + self._token_cumsum_unshuffled.save(token_cumsum_unshuffled) + else: + unshuffled_tokens = 0 + + if shuffled_epochs > 0: + token_cumsum_shuffled, _ = self._get_token_cumsum( + document_sizes[ + # Torch indexing only works with int32 or int64 + document_shuffling.to( + dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32 + ) + ], + self._unshuffled_tokens, + cumsum_dtype, + ) + self._token_cumsum_shuffled.save(token_cumsum_shuffled) + self._document_shuffling.save( + document_shuffling[: (token_cumsum_shuffled.size + 1) * TOKEN_CUMSUM_RATE].numpy(force=True) + ) + + yaml_data["unshuffled_tokens"] = unshuffled_tokens + self._load_yaml_data(yaml_data) + if self._yaml_path is not None: + self._yaml_path.parent.mkdir(parents=True, exist_ok=True) + yaml.safe_dump(yaml_data, self._yaml_path.open("w")) + + def _get_epoch_size(self, document_sizes: torch.Tensor) -> tuple[int, int, torch.Tensor | None]: + documents_per_epoch = document_sizes.numel() + if self._truncate_documents: + tokens_per_epoch = document_sizes.sum().item() + long_docs_filter = None + else: assert _extension_available, ( "The C++ extension for dataset sampling is missing." " Please make sure Fast-LLM is installed correctly." ) - long_docs_filter = document_sizes > self._parameters.sequence_length + 1 - ignored_documents = long_docs_filter.sum().item() - if ignored_documents: + long_docs_filter = document_sizes <= self._parameters.sequence_length + 1 + documents_per_epoch_filtered = long_docs_filter.sum().item() + if ignored_documents := documents_per_epoch_filtered - documents_per_epoch: log_main_rank( - f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._parameters.sequence_length+1} tokens and will be ignored.", + f" > {ignored_documents}/{documents_per_epoch} documents" + f" are longer than {self._parameters.sequence_length+1} tokens and will be ignored.", log_fn=logger.warning, ) - tokens_per_epoch = document_sizes[~long_docs_filter].sum().item() + tokens_per_epoch = document_sizes[long_docs_filter].sum().item() if tokens_per_epoch == 0: raise RuntimeError( - f" > No documents shorter than {self._parameters.sequence_length+1} tokens found in dataset {self._indexed_dataset.name}." + f" > No documents shorter than {self._parameters.sequence_length+1}" + f" tokens found in dataset {self._indexed_dataset.name}." ) + return documents_per_epoch, tokens_per_epoch, long_docs_filter + def _get_epoch_count(self, documents_per_epoch: int, tokens_per_epoch: int) -> tuple[int, int]: # We produce sequences of length `self._sequence_length + extra_tokens` so the last token has a label for all prediction heads, # but in case of truncations we also include those last labels in the following sample, # so we need `sequence_length * num_samples + extra_tokens` tokens in total. - if self._parameters.use_preference_loss_spans: - documents_per_epoch = (~long_docs_filter).sum().item() - num_epochs = math.ceil(self._parameters.num_samples / documents_per_epoch) - elif self._truncate_documents: + if self._truncate_documents: num_epochs = math.ceil( (self._parameters.sequence_length * self._parameters.num_samples + self._parameters.extra_tokens) / tokens_per_epoch @@ -174,32 +246,34 @@ def _sample(self) -> None: ) # Prepare for shuffling. - generator = torch.Generator(device=self._device) if self._config.shuffle == ShufflingType.skip_first_epoch: shuffled_epochs = num_epochs - 1 elif self._config.shuffle == ShufflingType.disabled: shuffled_epochs = 0 else: shuffled_epochs = num_epochs - shuffled_documents = documents_per_epoch * shuffled_epochs - unshuffled_epochs = num_epochs - shuffled_epochs + return num_epochs, shuffled_epochs + def _get_and_compare_yaml_data( + self, + documents_per_epoch: int, + tokens_per_epoch: int, + unshuffled_epochs: int, + ) -> tuple[dict[str, typing.Any], bool]: yaml_data = { "dataset": { "name": self._indexed_dataset.name, "documents_per_epoch": documents_per_epoch, "tokens_per_epoch": tokens_per_epoch, }, - "num_samples": self._parameters.num_samples, + "sampling": self._parameters.__dict__, "unshuffled_epochs": unshuffled_epochs, - "sequence_length": self._parameters.sequence_length, - "truncate_documents": self._truncate_documents, "config": self._config.to_dict(), } if self._truncate_documents: yaml_data["unshuffled_tokens"] = tokens_per_epoch * unshuffled_epochs - if self._yaml_path is not None and self._yaml_path.is_file(): + if cached := (self._yaml_path is not None and self._yaml_path.is_file()): loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r")) # Hack to make sure unshuffled tokens are loaded if not self._truncate_documents: @@ -216,120 +290,8 @@ def _sample(self) -> None: ) # Dataset is already sampled, skip. logger.info(f"Using existing sampling for dataset {self.name}") - return - if shuffled_documents > 1e8: - warnings.warn( - f"Shuffling {shuffled_documents:.2e} documents for dataset {self._indexed_dataset.name}." - f" This may take a while and/or use an excessive amount of memory." - ) - elif documents_per_epoch > 1e8: - # TODO: Most of the damage is already done in `get_document_sizes`. Find a way to warn earlier? - warnings.warn( - f"The dataset {self._indexed_dataset.name} contains {documents_per_epoch:.2e} documents." - f" Sampling may take a while and/or use an excessive amount of memory." - ) - - # Use the smallest possible data type to save memory and disk usage. - document_shuffling_dtype = get_unsigned_integer_type(documents_per_epoch).torch - # Shuffle the dataset (documents) - # This generates a document shuffling index `all_document_index`, the unshuffled part is trivial - # so we only evaluate and store the shuffled part `document_shuffling`. - if self._config.shuffle == ShufflingType.full: - generator.manual_seed(self._config.seed) - # Equivalent to `shuffle(range(documents_per_epoch * num_epochs)) % documents_per_epoch` - document_shuffling = ( - torch.randperm( - shuffled_documents, - generator=generator, - dtype=get_unsigned_integer_type(shuffled_documents).torch, - device=self._device, - ) - .remainder_(documents_per_epoch) - .to(dtype=document_shuffling_dtype) - ) - elif self._config.shuffle in (ShufflingType.skip_first_epoch, ShufflingType.epoch): - document_shuffling = torch.empty( - shuffled_documents, - dtype=document_shuffling_dtype, - device=self._device, - ) - for i in range(shuffled_epochs): - generator.manual_seed(self._config.seed + i * 571) - torch.randperm( - documents_per_epoch, - generator=generator, - out=document_shuffling[i * documents_per_epoch : (i + 1) * documents_per_epoch], - ) - elif self._config.shuffle == ShufflingType.disabled: - document_shuffling = None - else: - raise NotImplementedError(f"Unknown shuffling type: {self._config.shuffle}") - - if self._parameters.use_preference_loss_spans: - yaml_data["unshuffled_tokens"] = 0 # not used, ignore - - # index of all documents less than seq length long - doc_length_filtered_indicies = torch.nonzero(~long_docs_filter, as_tuple=True)[0] - self._doc_length_filtered_indicies.save(doc_length_filtered_indicies.numpy(force=self._config.gpu)) - - # apply shuffling on doc_length_filtered_indicies - if shuffled_epochs > 0: - self._document_shuffling.save( - document_shuffling[: self._parameters.num_samples].numpy(force=self._config.gpu) - ) - self._document_sizes.save(document_sizes.numpy(force=self._config.gpu)) - if self._yaml_path is not None: - self._yaml_path.parent.mkdir(parents=True, exist_ok=True) - yaml.safe_dump(yaml_data, self._yaml_path.open("w")) - return - - # To get a sample on the fly we need to know where it begins, - # and this is a non-trivial information because the documents have variable length. - # The starting point `(document[idx], token[idx])` corresponds to the `(idx * sequence_length)` th token, i.e. - # `document_sizes[all_document_index][:document[idx]].sum() + token[idx] == idx * sequence_length`. - # This can be computed quickly provided we know a (partial) sum close to `(idx * sequence_length)`. - # So it is enough to pre-compute the (zero-padded) token cumsum at regular intervals `TOKEN_CUMSUM_RATE`. - # Using `TOKEN_CUMSUM_RATE > 1` reduces pre-computation overhead at the cost of runtime computation. - # Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::TOKEN_CUMSUM_RATE]))` - if unshuffled_epochs > 0: - token_cumsum_unshuffled, unshuffled_tokens = self._get_token_cumsum( - document_sizes, - offset=0, - # TODO: Allowing for max 100% extra tokens for padding, is that enough? - dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs), - ) - self._token_cumsum_unshuffled.save(token_cumsum_unshuffled) - else: - unshuffled_tokens = 0 - - if not self._truncate_documents: - yaml_data["unshuffled_tokens"] = unshuffled_tokens - self._load_yaml_data(yaml_data) - if self._yaml_path is not None: - self._yaml_path.parent.mkdir(parents=True, exist_ok=True) - yaml.safe_dump(yaml_data, self._yaml_path.open("w")) - - if shuffled_epochs > 0: - token_cumsum_shuffled, _ = self._get_token_cumsum( - document_sizes[ - # Torch indexing only works with int32 or int64 - document_shuffling.to( - dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32 - ) - ], - offset=self._unshuffled_tokens, - # TODO: Allowing for max 100% extra tokens for padding, is that enough? - dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs), - ) - self._token_cumsum_shuffled.save(token_cumsum_shuffled) - self._document_shuffling.save( - document_shuffling[: (token_cumsum_shuffled.size + 1) * TOKEN_CUMSUM_RATE].numpy( - force=self._config.gpu - ) - ) - # Free memory - del document_shuffling + return yaml_data, cached def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: DataType) -> tuple[np.ndarray, int | None]: if self._truncate_documents: @@ -372,50 +334,61 @@ def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: DataType) - ] return out, num_tokens + def _get_document_shuffling( + self, + documents_per_epoch: int, + shuffled_documents: int, + shuffled_epochs: int, + ) -> torch.Tensor | None: + generator = torch.Generator(device=self._device) + # Use the smallest possible data type to save memory and disk usage. + document_shuffling_dtype = get_unsigned_integer_type(documents_per_epoch).torch + # Shuffle the dataset (documents) + # This generates a document shuffling index `all_document_index`, the unshuffled part is trivial + # so we only evaluate and store the shuffled part `document_shuffling`. + if self._config.shuffle == ShufflingType.full: + generator.manual_seed(self._config.seed) + # Equivalent to `shuffle(range(documents_per_epoch * num_epochs)) % documents_per_epoch` + document_shuffling = ( + torch.randperm( + shuffled_documents, + generator=generator, + dtype=get_unsigned_integer_type(shuffled_documents).torch, + device=self._device, + ) + .remainder_(documents_per_epoch) + .to(dtype=document_shuffling_dtype) + ) + elif self._config.shuffle in (ShufflingType.skip_first_epoch, ShufflingType.epoch): + document_shuffling = torch.empty( + shuffled_documents, + dtype=document_shuffling_dtype, + device=self._device, + ) + for i in range(shuffled_epochs): + generator.manual_seed(self._config.seed + i * 571) + torch.randperm( + documents_per_epoch, + generator=generator, + out=document_shuffling[i * documents_per_epoch : (i + 1) * documents_per_epoch], + ) + elif self._config.shuffle == ShufflingType.disabled: + document_shuffling = None + else: + raise NotImplementedError(f"Unknown shuffling type: {self._config.shuffle}") + return document_shuffling + def __len__(self) -> int: return self._parameters.num_samples - def __getitem__(self, index: int) -> typing.Any: + def __getitem__(self, index: int) -> SampleType: """ Get the sample, (fixed-length sequence of tokens holding one or more complete or partial documents) with the requested sampling index. - The returned sample is ready to be concatenated, then fed to a `GPTModel` (see `GPTModel.preprocess`). + The returned sample is ready to be concatenated, then fed to a `Model`. """ self._lazy_load() - if self._parameters.use_preference_loss_spans: - if index < self._unshuffled_documents: - document_index = self._doc_length_filtered_indicies[index % self._documents_per_epoch] - else: - document_index = self._doc_length_filtered_indicies[ - self._document_shuffling[index - self._unshuffled_documents].item() - ] - - sample = self._indexed_dataset.get( - document_index, - offset=0, - length=self._document_sizes[document_index], - use_loss_masking_spans=self._parameters.use_loss_masking_spans, - use_preference_loss_spans=self._parameters.use_preference_loss_spans, - ) - - chosen_span_end = sample.chosen_span[1] + 1 - sequence_lengths = [ - chosen_span_end, - len(sample.token_ids) - chosen_span_end, - ] - - # compute padding size - padding = np.full((self._parameters.sequence_length + 1,), 0) - padding[: len(sample.token_ids)] = sample.token_ids - sequence_lengths.append(self._parameters.sequence_length - len(sample.token_ids)) - sample.token_ids = padding - - if not self._parameters.cross_document_attention: - sample.sequence_lengths = np.array(sequence_lengths) - - return sample - # tokens at the boundary are included in only one sample when we pack without truncations # in case of packing with truncations, the last token from the previous sample is also the first token of the next sample sample_length = ( @@ -441,7 +414,13 @@ def __getitem__(self, index: int) -> typing.Any: token_count = token_start_array[token_start_cumsum_index] token_ids = [] - loss_masking_spans = [] + if self._parameters.use_loss_masking_spans: + loss_masking_spans = [] + if self._parameters.use_images: + images = [] + image_positions = [] + image_tokens_added = 0 + text_tokens_added = 0 while token_count < token_end: # Find the document index in the dataset. if document_sampling_index < self._unshuffled_documents: @@ -449,7 +428,7 @@ def __getitem__(self, index: int) -> typing.Any: else: document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item() - document_size = self._indexed_dataset.get_document_size(document_index) + document_size = self._indexed_dataset.get_document_size(document_index, self._parameters) if not self._truncate_documents: if document_size > self._parameters.sequence_length + 1: @@ -461,28 +440,85 @@ def __getitem__(self, index: int) -> typing.Any: # Document belongs to the next sample, need to account for padding. padding_size = self._parameters.sequence_length + 1 - tokens_in_sample if token_count > token_start: - # Add padding tokens to current sample - token_ids.append(np.full((padding_size,), -100, dtype=np.int64)) + documents.append(documents[-1].get_padding(padding_size)) Assert.eq(token_count + padding_size, token_end) break else: # Move on to the next sample. token_count += padding_size + continue + elif document_size + tokens_in_sample == self._parameters.sequence_length + 1: + if token_count + document_size == token_start: + token_count += document_size + document_sampling_index += 1 + continue # Determine if the document belongs to the requested sample. if token_count + document_size > token_start: # Determine which part of the document belong to the sample, and add it to the list. token_start_index_in_document = max(token_start - token_count, 0) token_end_index_in_document = min(token_end - token_count, document_size) - sample = self._indexed_dataset.get( + sample: GPTSample = self._indexed_dataset.get( document_index, offset=token_start_index_in_document, length=token_end_index_in_document - token_start_index_in_document, use_loss_masking_spans=self._parameters.use_loss_masking_spans, ) - token_ids.append(sample.token_ids) + if self._parameters.use_images: + start_pos = 0 + sample_token_ids = [] + for idx, im_position in enumerate(sample.image_positions): + # add placeholder masked tokens for images + # if image_break_token is set, it is appended after every row + # if image_end_token is set, it is appended at the end of the image instead of image_break_token + text_part = sample.token_ids[start_pos:im_position] + if self._parameters.image_break_token is not None: + height, width = resized_image_lengths[idx] + num_patches_h = div(height, self._parameters.patch_size) + num_patches_w = div(width, self._parameters.patch_size) + image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64) + # account for break tokens after each row + for row in range(num_patches_h - 1): + position = (row + 1) * num_patches_w + row + image_token_array[position] = self._parameters.image_break_token + # handle the last row separately + last_row_position = num_patches_h * num_patches_w + num_patches_h - 1 + if self._parameters.image_end_token is not None: + image_token_array[last_row_position] = self._parameters.image_end_token + else: + image_token_array[last_row_position] = self._parameters.image_break_token + else: + image_token_array = np.full((image_sizes[idx],), -100, dtype=np.int64) + if self._parameters.image_end_token is not None: + image_token_array[-1] = self._parameters.image_end_token + sample_token_ids.append(np.concatenate([text_part, image_token_array], dtype=np.int64)) + text_tokens_added += len(text_part) + image_positions.append(text_tokens_added + image_tokens_added) + image_tokens_added += image_sizes[idx] + start_pos = im_position + # Add the last text segment after the last image + sample_token_ids.append(sample.token_ids[start_pos:]) + text_tokens_added += len(sample_token_ids[-1]) + token_ids.append(np.concatenate(sample_token_ids)) + images.append(sample.images) + else: + token_ids.append(sample.token_ids) + text_tokens_added += len(token_ids[-1]) if self._parameters.use_loss_masking_spans: for loss_masking_span in sample.loss_masking_spans: + if self._parameters.use_images: + # Shift the spans to account for the images. + loss_masking_span[0] += sum( + image_size + for image_size, image_position in zip(image_sizes, sample.image_positions) + if image_position < loss_masking_span[0] + ) + loss_masking_span[1] += sum( + image_size + for image_size, image_position in zip(image_sizes, sample.image_positions) + if image_position < loss_masking_span[1] + ) + span = np.clip( loss_masking_span + token_count - token_start, 0, @@ -500,34 +536,59 @@ def __getitem__(self, index: int) -> typing.Any: if not self._parameters.cross_document_attention else None ) + token_ids = np.concatenate(token_ids, dtype=np.int64) loss_masking_spans = ( (np.stack(loss_masking_spans, dtype=np.int32) if loss_masking_spans else np.array([])) if self._parameters.use_loss_masking_spans else None ) + images = [im for img_list in images for im in img_list] if self._parameters.use_images else None + image_positions = np.array(image_positions) if self._parameters.use_images else None Assert.eq(len(token_ids), self._parameters.sequence_length + self._parameters.extra_tokens) - return GPTSample(token_ids=token_ids, loss_masking_spans=loss_masking_spans, sequence_lengths=sequence_lengths) + return GPTSample( + token_ids=token_ids, + loss_masking_spans=loss_masking_spans, + sequence_lengths=sequence_lengths, + images=images, + image_positions=image_positions, + ) @property def name(self) -> str: return self._indexed_dataset.name + def _get_image_sizes(self, document_index: int): + # TODO: Duplicate of _get_document_sizes + image_lengths = self._indexed_dataset.get_image_size(document_index) + + resized_image_lengths = [ + get_resize_dims( + *image_length, + self._parameters.max_image_size, + self._parameters.max_image_size, + self._parameters.patch_size, + ) + for image_length in image_lengths + ] + image_sizes = [ + get_num_image_tokens( + *image_length, + self._parameters.patch_size, + image_break=self._parameters.image_break_token is not None, + image_end=self._parameters.image_end_token is not None, + ) + for image_length in resized_image_lengths + ] + image_tokens = sum(image_sizes) + return resized_image_lengths, image_sizes, image_tokens + def _lazy_load(self): if not hasattr(self, "_documents_per_epoch"): self._load_yaml_data(yaml.safe_load(self._yaml_path.open("r"))) def _load_yaml_data(self, data: dict[str, typing.Any]) -> None: self._documents_per_epoch = data["dataset"]["documents_per_epoch"] - - if self._parameters.use_preference_loss_spans: - data["unshuffled_tokens"] = 0 # not used, ignore - elif "unshuffled_tokens" not in data: - # Backward compatibility - # TODO v0.x: Remove - assert self._truncate_documents - data["unshuffled_tokens"] = data["tokens_per_epoch"] * data["unshuffled_epochs"] - self._unshuffled_tokens = data["unshuffled_tokens"] self._unshuffled_documents = data["unshuffled_epochs"] * self._documents_per_epoch diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index d2aaee5e2..da353793d 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -42,6 +42,18 @@ class TextColumnConfig(SourceSchemaConfig): ) +@config_class(dynamic_type={SourceSchemaConfig: "text_image_column"}) +class TextImageColumnConfig(TextColumnConfig): + images_column: str = Field( + default="images", + desc="Field containing images relevant to a document.", + ) + image_positions_column: None | str = Field( + default="image_positions", + desc="Field containing image positions within a document.", + ) + + @config_class() class GPTHuggingfaceDatasetConfig(Config): path: str = Field( @@ -175,6 +187,11 @@ class GPTMemmapDatasetPreparatorConfig(DatasetPreparatorConfig): desc="Configuration for the tokenizer.", hint=FieldHint.feature, ) + image_patch_size: int = Field( + default=16, + desc="Patch size for images. This is used solely for computing the number of tokens in an image to get an even split.", + hint=FieldHint.optional, + ) splits: dict[str, float] | None = Field( default=None, desc="Split the output dataset into multiple ones (ex, train/valid/test) with the specified ratios." diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 33c40bf8f..123ea8129 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -1,3 +1,5 @@ +import io +import itertools import json import logging import multiprocessing @@ -8,23 +10,27 @@ import datasets import huggingface_hub import numpy as np +import PIL.Image import requests import torch.distributed import tqdm import transformers import yaml -from fast_llm.data.dataset.gpt.config import ( - GPTBlendedDatasetConfig, - GPTDatasetSliceConfig, - GPTIndexedDatasetConfig, - GPTMemmapDatasetConfig, - GPTSampledDatasetConfig, +from fast_llm.data.dataset.config import ( + BlendedDatasetConfig, + DatasetSliceConfig, + IndexedDatasetConfig, + SampledDatasetConfig, ) +from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset -from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.preparator.config import DatasetPreparator -from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, TextColumnConfig +from fast_llm.data.preparator.gpt_memmap.config import ( + GPTMemmapDatasetPreparatorConfig, + TextColumnConfig, + TextImageColumnConfig, +) from fast_llm.data.tokenizer import Tokenizer from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum @@ -37,38 +43,47 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](D _data_type: DataType _text_column: str _loss_masking_spans_column: str | None + _sample_type: typing.ClassVar[type[LanguageModelSample]] = LanguageModelSample def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: - input_ids = [ - np.array(self._tokenizer.tokenize(text), dtype=self._data_type.numpy) for text in batch[self._text_column] - ] - num_tokens = [len(x) for x in input_ids] - return { - "input_ids": input_ids, - "num_tokens": num_tokens, - } - - def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: - input_ids, token_spans = map( + input_ids, token_spans, image_token_positions = map( list, zip( *[ ( np.array(input_ids, dtype=self._data_type.numpy), np.array(token_spans, dtype=np.int32).reshape(-1, 2), + np.array(image_token_positions, dtype=np.int32), ) - for input_ids, token_spans in [ - self._tokenizer.tokenize_with_spans(text, char_spans) - for text, char_spans in zip(batch[self._text_column], batch[self._loss_masking_spans_column]) + for input_ids, token_spans, image_token_positions in [ + self._tokenizer.tokenize( + text, + loss_mask_spans, + im_char_positions, + ) + for text, loss_mask_spans, im_char_positions in zip( + batch[self._text_column], + batch.get(self._loss_masking_spans_column, itertools.repeat(None)), + batch.get(self._image_positions_column, itertools.repeat(None)), + ) ] ] ), ) num_tokens = [len(x) for x in input_ids] + num_pixels = [0] * len(input_ids) + for idx, images in enumerate(batch.get("images", [])): + for bytes_im in images: + with PIL.Image.open(io.BytesIO(bytes_im["bytes"])) as im: + width, height = im.size + num_pixels[idx] += width * height * 3 + return { "input_ids": input_ids, + "image_positions": image_token_positions, "token_spans": token_spans, "num_tokens": num_tokens, + "num_pixels": num_pixels, } def _tokenize_preference_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: @@ -141,27 +156,22 @@ def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetCon shard_output_path = self._config.output_path / prefix def _document_generator(): - if "token_spans" in shard_dataset.column_names and self._loss_masking_spans_column is not None: - for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield GPTSample( - np.array(item["input_ids"], dtype=self._data_type.numpy), - np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2), - ) - elif ( - "chosen_token_spans" in shard_dataset.column_names - and "rejected_token_spans" in shard_dataset.column_names - and self._config.dataset.chosen_text is not None - and self._config.dataset.rejected_text is not None - ): - for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield GPTSample( - token_ids=np.array(item["input_ids"], dtype=self._data_type.numpy), - chosen_span=np.array(item["chosen_token_spans"], dtype=np.int32).reshape(-1, 2), - rejected_span=np.array(item["rejected_token_spans"], dtype=np.int32).reshape(-1, 2), - ) - else: - for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield GPTSample(np.array(item["input_ids"], dtype=self._data_type.numpy)) + has_preference_spans = ( + self._config.dataset.chosen_text is not None and self._config.dataset.rejected_text is not None + ) + for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): + yield GPTSample( + np.array(item["input_ids"], dtype=self._data_type.numpy), + item["images"] if self._images_column else None, + item["image_positions"] if self._image_positions_column else None, + ( + np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2) + if self._loss_masking_spans_column + else None + ), + item["chosen_token_spans"] if has_preference_spans else None, + item["rejected_token_spans"] if has_preference_spans else None, + ) GPTMemmapDataset.write_dataset(prefix=shard_output_path, documents=_document_generator()) @@ -171,6 +181,7 @@ def _document_generator(): "path": prefix, "num_documents": len(shard_dataset), # Use the length of the shard dataset directly "num_tokens": sum(len(doc["input_ids"]) for doc in shard_dataset), + "num_pixels": sum(doc["num_pixels"] for doc in shard_dataset), } ) @@ -240,7 +251,7 @@ def run(self) -> None: datasets.builder.has_sufficient_disk_space = lambda needed_bytes, directory=".": True # Load tokenizer - self._tokenizer = Tokenizer(config=self._config.tokenizer) + self._tokenizer = self._config.tokenizer.get_tokenizer() # Decide the datatype based on the tokenizer vocabulary size self._data_type = ( @@ -290,6 +301,11 @@ def run(self) -> None: if isinstance(self._config.dataset.source_schema, TextColumnConfig): self._text_column = self._config.dataset.source_schema.input_column self._loss_masking_spans_column = self._config.dataset.source_schema.loss_masking_spans_column + if isinstance(self._config.dataset.source_schema, TextImageColumnConfig): + self._images_column = self._config.dataset.source_schema.images_column + self._image_positions_column = self._config.dataset.source_schema.image_positions_column + # decoding bytes to images is slow and should be done only when needed + dataset = dataset.cast_column("images", datasets.Sequence(datasets.Image(decode=False))) else: raise ValueError( f"Dataset source_schema set incorrectly. source_schema: '{self._config.dataset.source_schema}'." @@ -298,18 +314,17 @@ def run(self) -> None: if self._text_column not in dataset.column_names: raise ValueError(f"Dataset does not have field '{self._text_column}'.") - if self._config.dataset.source_schema.loss_masking_spans_column is not None and ( + if self._loss_masking_spans_column is not None and ( self._config.dataset.chosen_text is not None or self._config.dataset.rejected_text is not None ): - raise ValueError(f"Can not enable both loss masking spans and chosen/rejected loss masking spans.") + if self._config.dataset.chosen_text is not None and self._config.dataset.rejected_text is not None: + raise ValueError(f"Can not enable both loss masking spans and chosen/rejected loss masking spans.") + if self._loss_masking_spans_column not in dataset.column_names: + raise ValueError(f"Dataset does not have spans field '{self._loss_masking_spans_column}'.") if (self._config.dataset.chosen_text is None) != (self._config.dataset.rejected_text is None): raise ValueError(f"Both chosen and rejected loss masking spans must be specified if one is specified.") # route tokenize function - if self._loss_masking_spans_column is not None: - if self._loss_masking_spans_column not in dataset.column_names: - raise ValueError(f"Dataset does not have spans field '{self._loss_masking_spans_column}'.") - tokenize_fn = self._tokenize_batch_with_spans elif self._config.dataset.chosen_text is not None and self._config.dataset.rejected_text is not None: if self._config.dataset.chosen_text not in dataset.column_names: raise ValueError(f"Dataset does not have chosen spans field '{self._config.dataset.chosen_text}'.") @@ -329,6 +344,13 @@ def run(self) -> None: # Calculate total number of tokens total_tokens = sum(tqdm.tqdm(tokenized_dataset["num_tokens"], desc="Counting tokens", unit="tokens")) + total_pixels = ( + sum(tqdm.tqdm(tokenized_dataset["num_pixels"], desc="Counting pixels", unit="pixels")) + if self._images_column + else 0 + ) + # Add the token-equivalent bytes of pixels to determine shard size + total_tokens += total_pixels // np.dtype(self._data_type.numpy).itemsize # Split dataset into shards based on number of tokens num_shards = int(np.ceil(total_tokens / self._config.tokens_per_shard)) @@ -357,7 +379,7 @@ def generate_config_yaml_for_sharded_dst(self, dataset_configs: list[GPTMemmapDa # Create the config file(s) on rank 0 if self._config.splits: for split_name, split_config in self._split_and_blend_dataset_configs( - dataset_configs, self._config.splits, self._config.output_path + dataset_configs, self._config.splits, self._config.output_path, self._config.image_patch_size ).items(): self._save_dataset_config( split_config, self._config.output_path / f"fast_llm_config_{split_name}.yaml" @@ -376,7 +398,9 @@ def generate_config_yaml_for_sharded_dst(self, dataset_configs: list[GPTMemmapDa torch.distributed.destroy_process_group() @classmethod - def _save_dataset_config(cls, dataset_config: GPTIndexedDatasetConfig, output_path: pathlib.Path) -> None: + def _save_dataset_config( + cls, dataset_config: IndexedDatasetConfig[_sample_type], output_path: pathlib.Path + ) -> None: logger.info(f"Saving config to {output_path}") yaml.safe_dump( dataset_config.to_dict(), @@ -384,10 +408,12 @@ def _save_dataset_config(cls, dataset_config: GPTIndexedDatasetConfig, output_pa ) @classmethod - def _blend_dataset_configs(cls, dataset_configs: list[GPTMemmapDatasetConfig]) -> GPTIndexedDatasetConfig: + def _blend_dataset_configs( + cls, dataset_configs: list[GPTMemmapDatasetConfig[_sample_type]] + ) -> IndexedDatasetConfig[_sample_type]: if len(dataset_configs) == 1: return dataset_configs[0] - return GPTSampledDatasetConfig.from_dict( + return SampledDatasetConfig[cls._sample_type].from_dict( { "type": "blended", "datasets": dataset_configs, @@ -397,7 +423,11 @@ def _blend_dataset_configs(cls, dataset_configs: list[GPTMemmapDatasetConfig]) - @classmethod def _split_and_blend_dataset_configs( - cls, dataset_configs: list[GPTMemmapDatasetConfig], splits: dict[str, int | float], output_path: pathlib.Path + cls, + dataset_configs: list[GPTMemmapDatasetConfig], + splits: dict[str, int | float], + output_path: pathlib.Path, + image_patch_size: None | int = None, ) -> dict[str, GPTSampledDatasetConfig]: split_cumsum = padded_cumsum(normalize_probabilities(list(splits.values()), return_array=True)).tolist() dataset_sizes = [dataset_config.num_tokens for dataset_config in dataset_configs] @@ -427,13 +457,23 @@ def _split_and_blend_dataset_configs( # Part of the dataset belongs to the split. # TODO: Somehow getting a segfault when merging two lines below (numpy bug?). dataset = dataset_config.to_copy({"path": output_path / dataset_config.path}).build() - sizes_cumsum = dataset.get_document_sizes().cumsum() - Assert.eq(sizes_cumsum[-1], dataset_config.num_tokens) - begin_index = _get_nearest_split(sizes_cumsum, split_begin_in_dataset * dataset_config.num_tokens) - end_index = _get_nearest_split(sizes_cumsum, split_end_in_dataset * dataset_config.num_tokens) + text_sizes, image_sizes = dataset.get_document_sizes() + tokens_cumsum = text_sizes.cumsum() + Assert.eq(tokens_cumsum[-1], dataset_config.num_tokens) + if image_sizes: + num_pixels_cumsum = np.cumsum([x.prod(axis=1).sum() for x in image_sizes]) + # We use the patch sizes only for the purposes of even splitting and blending weights. + # We can always use a different patch size for training without any significant impact + # Unless the patch size used at training time is significantly different from the one used here + image_tokens_cumsum = num_pixels_cumsum // (image_patch_size**2) + tokens_cumsum += image_tokens_cumsum + num_pixels_cumsum = num_pixels_cumsum * 3 + Assert.eq(num_pixels_cumsum[-1], dataset_config.num_pixels) + begin_index = _get_nearest_split(tokens_cumsum, split_begin_in_dataset * tokens_cumsum[-1]) + end_index = _get_nearest_split(tokens_cumsum, split_end_in_dataset * tokens_cumsum[-1]) if end_index > begin_index: datasets_in_split.append( - GPTDatasetSliceConfig.from_dict( + DatasetSliceConfig[cls._sample_type].from_dict( { "type": "slice", "dataset": dataset_configs[dataset_index], @@ -443,8 +483,8 @@ def _split_and_blend_dataset_configs( ) ) dataset_tokens_in_split.append( - sizes_cumsum[end_index - 1].item() - - (sizes_cumsum[begin_index - 1].item() if begin_index > 0 else 0) + tokens_cumsum[end_index - 1].item() + - (tokens_cumsum[begin_index - 1].item() if begin_index > 0 else 0) ) # [else] None of the dataset belongs to the split. @@ -455,7 +495,7 @@ def _split_and_blend_dataset_configs( elif len(datasets_in_split) == 1: dataset_splits[split_name] = datasets_in_split[0] else: - dataset_splits[split_name] = GPTBlendedDatasetConfig.from_dict( + dataset_splits[split_name] = BlendedDatasetConfig[cls._sample_type].from_dict( { "type": "blended", "datasets": datasets_in_split, diff --git a/fast_llm/data/sample/__init__.py b/fast_llm/data/sample/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/data/sample/abstract.py b/fast_llm/data/sample/abstract.py new file mode 100644 index 000000000..031002101 --- /dev/null +++ b/fast_llm/data/sample/abstract.py @@ -0,0 +1,42 @@ +import abc +import typing + +if typing.TYPE_CHECKING: + import torch + + +class Sample(abc.ABC): + @classmethod + @abc.abstractmethod + def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: + pass + + @abc.abstractmethod + def crop(self, begin: int, end: int) -> typing.Self: + pass + + @abc.abstractmethod + def __len__(self) -> int: + pass + + @abc.abstractmethod + def get_padding(self, size: int) -> typing.Self: + pass + + +class Batch(abc.ABC): + # TODO: Relate to `BatchConfig`? + @classmethod + @abc.abstractmethod + def from_samples(cls, samples: typing.Iterable[Sample]) -> typing.Self: + pass + + @abc.abstractmethod + def to_samples(self) -> list[Sample]: + pass + + def crop(self, begin: int, end: int) -> typing.Self: + return self.from_samples(sample.crop(begin, end) for sample in self.to_samples()) + + def to_device_(self, device: "torch.device | str"): + pass diff --git a/fast_llm/data/sample/language_model.py b/fast_llm/data/sample/language_model.py new file mode 100644 index 000000000..f30188553 --- /dev/null +++ b/fast_llm/data/sample/language_model.py @@ -0,0 +1,107 @@ +import typing + +from fast_llm.data.sample.abstract import Batch, Sample +from fast_llm.data.sample.range import RangeBatch, RangeSample +from fast_llm.data.sample.token import TokenBatch, TokenSample + + +class LanguageModelSample(Sample): + def __init__( + self, + tokens: TokenSample, + loss_masking_spans: RangeSample | None = None, + chosen_spans: RangeSample | None = None, + rejected_spans: RangeSample | None = None, + ): + self.tokens = tokens + self.loss_masking_spans = loss_masking_spans + self.chosen_spans = chosen_spans + self.rejected_spans = rejected_spans + + @classmethod + def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: + return cls( + TokenSample.from_documents([document.tokens for document in documents]), + _merge_optional(RangeSample.from_documents, [document.loss_masking_spans for document in documents]), + _merge_optional(RangeSample.from_documents, [document.chosen_spans for document in documents]), + _merge_optional(RangeSample.from_documents, [document.rejected_spans for document in documents]), + ) + + def crop(self, begin: int, end: int) -> typing.Self: + return self.__class__( + self.tokens.crop(begin, end), + _crop_optional(self.loss_masking_spans, begin, end), + _crop_optional(self.chosen_spans, begin, end), + _crop_optional(self.rejected_spans, begin, end), + ) + + def __len__(self) -> int: + return len(self.tokens) + + def get_padding(self, size: int) -> typing.Self: + return LanguageModelSample( + self.tokens.get_padding(size), + None if self.loss_masking_spans is None else self.loss_masking_spans.get_padding(size), + None if self.chosen_spans is None else self.chosen_spans.get_padding(size), + None if self.rejected_spans is None else self.rejected_spans.get_padding(size), + ) + + +class LanguageModelBatch(Batch): + def __init__( + self, + tokens: TokenBatch, + loss_masking_spans: RangeBatch | None = None, + chosen_spans: RangeBatch | None = None, + rejected_spans: RangeBatch | None = None, + ): + self.tokens = tokens + self.loss_masking_spans = loss_masking_spans + self.chosen_spans = chosen_spans + self.rejected_spans = rejected_spans + + @classmethod + def from_samples(cls, samples: typing.Iterable[LanguageModelSample]) -> typing.Self: + return cls( + TokenBatch.from_samples([sample.tokens for sample in samples]), + _merge_optional(RangeBatch.from_samples, [sample.loss_masking_spans for sample in samples]), + _merge_optional(RangeBatch.from_samples, [sample.chosen_spans for sample in samples]), + _merge_optional(RangeBatch.from_samples, [sample.rejected_spans for sample in samples]), + ) + + def to_samples(self) -> list[LanguageModelSample]: + return [ + LanguageModelSample(tokens, loss_masking_spans, chosen_spans, rejected_spans) + for tokens, loss_masking_spans, chosen_spans, rejected_spans in zip( + self.tokens.to_samples(), + self.loss_masking_spans.to_samples(), + self.chosen_spans.to_samples(), + self.rejected_spans.to_samples(), + strict=True, + ) + ] + + def crop(self, begin: int, end: int) -> typing.Self: + return self.__class__( + self.tokens.crop(begin, end), + _crop_optional(self.loss_masking_spans, begin, end), + _crop_optional(self.chosen_spans, begin, end), + _crop_optional(self.rejected_spans, begin, end), + ) + + def to_device_(self, device: "torch.device | str"): + self.tokens.to_device_(device) + if self.loss_masking_spans is not None: + self.loss_masking_spans.to_device_(device) + if self.chosen_spans is not None: + self.chosen_spans.to_device_(device) + if self.rejected_spans is not None: + self.rejected_spans.to_device_(device) + + +def _merge_optional[T](fn: typing.Callable[[typing.Iterable], T], args: typing.Iterable) -> T | None: + return None if any(arg is None for arg in args) else fn(args) + + +def _crop_optional[T: Sample | Batch](sample_or_batch: T, begin: int, end: int) -> T | None: + return None if sample_or_batch is None else sample_or_batch.crop(begin, end) diff --git a/fast_llm/data/sample/range.py b/fast_llm/data/sample/range.py new file mode 100644 index 000000000..d121a38b6 --- /dev/null +++ b/fast_llm/data/sample/range.py @@ -0,0 +1,49 @@ +import typing + +from fast_llm.data.sample.abstract import Batch, Sample +from fast_llm.utils import get_unique + + +class RangeSample(Sample): + """ + A reusable component holding a set of ranges in a sample. + """ + + def __init__(self, ranges: list[tuple[int, int]], sample_size: int): + self.ranges = ranges + self.sample_size = sample_size + + @classmethod + def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: + document: RangeSample + ranges = [] + sample_size = 0 + for document in documents: + for begin, end in document.ranges: + ranges.extend((begin + sample_size, end + sample_size)) + sample_size += document.sample_size + return cls(ranges, sample_size) + + def crop(self, begin: int, end: int) -> typing.Self: + sample_size = end - begin + cropped_ranges = ((max(begin_ - begin, 0), min(end_ - begin, sample_size)) for begin_, end_ in self.ranges) + return self.__class__([(begin_, end_) for begin_, end_ in cropped_ranges if end_ > begin_], sample_size) + + def __len__(self) -> int: + return self.sample_size + + def get_padding(self, size: int) -> typing.Self: + return RangeSample([], size) + + +class RangeBatch(Batch): + def __init__(self, ranges: list[list[tuple[int, int]]], sample_size: int): + self.sample_size = sample_size + self.ranges = ranges + + @classmethod + def from_samples(cls, samples: typing.Iterable[RangeSample]) -> typing.Self: + return cls([sample.ranges for sample in samples], get_unique(sample.sample_size for sample in samples)) + + def to_samples(self) -> list[RangeSample]: + return [RangeSample(sample_ranges, self.sample_size) for sample_ranges in self.ranges] diff --git a/fast_llm/data/sample/token.py b/fast_llm/data/sample/token.py new file mode 100644 index 000000000..62d1c0e67 --- /dev/null +++ b/fast_llm/data/sample/token.py @@ -0,0 +1,75 @@ +import typing + +import torch + +from fast_llm.data.sample.abstract import Batch, Sample +from fast_llm.utils import Assert + + +class TokenSample(Sample): + def __init__(self, tokens: torch.Tensor, lengths: list[int] | None = None): + self.tokens = tokens + # Length of each document in the sample. TODO: Use cumsums instead? + if lengths is None: + lengths = [len(tokens)] + else: + Assert.eq(sum(lengths), len(tokens)) + self.lengths = lengths + + @classmethod + def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: + return cls( + torch.cat([document.tokens for document in documents]), + sum((document.lengths for document in documents), []), + ) + + def crop(self, begin: int, end: int) -> typing.Self: + sample_size = end - begin + if self.lengths == [len(self.tokens)]: + # Shortcut for the frequent case of a single document. + lengths = [sample_size] + else: + begin_ = 0 + lengths = [] + for length in self.lengths: + end_ = begin_ + length + cropped_length = min(end_, end) - max(begin_, begin) + if cropped_length > 0: + lengths.append(cropped_length) + if end_ > end: + break + begin_ = end_ + return self.__class__(self.tokens[begin:end], lengths) + + def __len__(self) -> int: + return len(self.tokens) + + def get_padding(self, size: int) -> typing.Self: + return TokenSample(torch.full([size], -100, dtype=self.tokens.dtype), [size]) + + +class TokenBatch(Batch): + def __init__(self, tokens: torch.Tensor, lengths: list[list[int]] | None) -> None: + self.tokens = tokens + if lengths is None: + lengths = [[tokens.size(1)]] * tokens.size(0) + self.lengths = lengths + + @classmethod + def from_samples(cls, samples: typing.Iterable[TokenSample]) -> typing.Self: + return cls( + torch.stack([sample.tokens for sample in samples]), + [sample.lengths for sample in samples], + ) + + def to_samples(self) -> list[TokenSample]: + return [TokenSample(tokens, lengths) for tokens, lengths in zip(self.tokens, self.lengths, strict=True)] + + def crop(self, begin: int, end: int) -> typing.Self: + return self.__class__( + self.tokens[:, begin:end], [sample.crop(begin, end).lengths for sample in self.to_samples()] + ) + + def to_device_(self, device: "torch.device | str"): + # Also standardize the dtype while we're here. + self.tokens = self.tokens.to(device, dtype=torch.int64, non_blocking=True) diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index c74586207..23a839af7 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -41,44 +41,77 @@ def vocab(self) -> dict[str, int]: def inv_vocab(self) -> dict[int, str]: return self._inv_vocab - def tokenize(self, text: str, begin=True, end=True) -> list[int]: + def _tokenize(self, text: str, begin=True, end=True) -> list[int]: return ( ([self.bod_id] if begin else []) + self.tokenizer.encode(text, add_special_tokens=False) + ([self.eod_id] if end else []) ) - def tokenize_with_spans( - self, text: str, char_spans: list[tuple[int, int]] - ) -> tuple[list[int], list[tuple[int, int]]]: + def tokenize( + self, text: str, add_bos=True, add_eos=True, char_spans=None, image_positions=None + ) -> tuple[list[int], list[tuple[int, int]], list[int]]: """ - Perform span-aware tokenization and return the tokenized input_ids along with token spans. + Tokenize the input text and return the tokenized input_ids, token spans, and image token positions. + This version simplifies logic by merging all relevant positions, sorting, and tokenizing between them. """ - input_ids = [] + if not image_positions: + image_positions = [] + if not char_spans: + char_spans = [] + + # Collect all positions with their type + positions = [] + for pos in image_positions: + positions.append((pos, "image")) + + for start, end in char_spans: + positions.append((start, "span_start")) + positions.append((end + 1, "span_end")) + # Sort positions by character index. We assume that image and span positions are individually sorted and spans do not overlap + positions = sorted(positions, key=lambda x: x[0]) + + token_ids = [] token_spans = [] + image_token_positions = [] char_pos = 0 - beginning_of_text = True + current_span_start = None - for start, end in char_spans: - if char_pos < start: - curr_text = text[char_pos:start] - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) - beginning_of_text = False - input_ids.extend(tokenized_text) - curr_text = text[start : end + 1] - if end >= len(text) - 1: - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) - else: - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) - beginning_of_text = False - token_spans.append((len(input_ids), len(input_ids) + len(tokenized_text) - 1)) - input_ids.extend(tokenized_text) - char_pos = end + 1 + for position in positions: + # We only tokenize if there is at least one character, else we might potentially add begin/end multiple times + if char_pos < position[0]: + tokenized_text = self._tokenize( + text[char_pos : position[0]], + begin=add_bos and char_pos == 0, + end=add_eos and position[0] > len(text) - 1, + ) + token_ids.extend(tokenized_text) + char_pos = position[0] + # beginning_of_text = False + if position[1] == "image": + if position[0] == 0: + # image should be after the bos token + image_token_positions.append(1) + else: + image_token_positions.append(len(token_ids)) + elif position[1] == "span_start": + assert ( + current_span_start is None + ), "Starting a new span before current has ended, please check for overlapping spans" + current_span_start = len(token_ids) + elif position[1] == "span_end": + assert ( + current_span_start is not None + ), "Closing a span that has not started, please check for overlapping spans" + # spans are inclusive, so we take the index of the last token in the span + token_spans.append((current_span_start, len(token_ids) - 1)) + current_span_start = None + # Handle any remaining text after the last position and add EOS token if char_pos < len(text): - curr_text = text[char_pos:] - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) - input_ids.extend(tokenized_text) - return input_ids, token_spans + tokenized_text = self._tokenize(text[char_pos:], begin=add_bos and char_pos == 0, end=add_eos) + token_ids.extend(tokenized_text) + + return token_ids, token_spans, image_token_positions def detokenize(self, token_ids: int | list[int] | np.ndarray | torch.Tensor) -> str: return self.tokenizer.decode(token_ids) diff --git a/fast_llm/engine/config_utils/data_type.py b/fast_llm/engine/config_utils/data_type.py index f4a2cfd6c..1a0fed91b 100644 --- a/fast_llm/engine/config_utils/data_type.py +++ b/fast_llm/engine/config_utils/data_type.py @@ -23,8 +23,10 @@ class DataType(enum.StrEnum): int32 = "int32" int16 = "int16" int8 = "int8" - uint8 = "uint8" + uint64 = "uint64" + uint32 = "uint32" uint16 = "uint16" + uint8 = "uint8" @classmethod def _missing_(cls, dtype: str) -> "DataType": @@ -105,6 +107,9 @@ def _set_torch_dtype_map() -> None: DataType.int32: torch.int32, DataType.int16: torch.int16, DataType.int8: torch.int8, + DataType.uint64: torch.uint64, + DataType.uint32: torch.uint32, + DataType.uint16: torch.uint16, DataType.uint8: torch.uint8, } _TORCH_DTYPE_MAP_INV = {y: x for x, y in _TORCH_DTYPE_MAP.items()} @@ -127,8 +132,10 @@ def _set_numpy_dtype_map() -> None: DataType.int32: np.int32, DataType.int16: np.int16, DataType.int8: np.int8, - DataType.uint8: np.uint8, + DataType.uint64: np.uint64, + DataType.uint32: np.uint32, DataType.uint16: np.uint16, + DataType.uint8: np.uint8, } _NUMPY_DTYPE_MAP_INV = {y: x for x, y in _NUMPY_DTYPE_MAP.items()} @@ -151,6 +158,9 @@ def _set_triton_dtype_map() -> None: DataType.int32: tl.int32, DataType.int16: tl.int16, DataType.int8: tl.int8, + DataType.uint64: tl.uint64, + DataType.uint32: tl.uint32, + DataType.uint16: tl.uint16, DataType.uint8: tl.uint8, } @@ -158,6 +168,7 @@ def _set_triton_dtype_map() -> None: def get_unsigned_integer_type(max_size: int) -> DataType: + # TODO: Use uint types (recently added for torch, not enough methods supported yet) if max_size < 2**8: return DataType.uint8 elif max_size < 2**15: diff --git a/fast_llm/engine/evaluation/config.py b/fast_llm/engine/evaluation/config.py index 4f035e174..f8dfd4825 100644 --- a/fast_llm/engine/evaluation/config.py +++ b/fast_llm/engine/evaluation/config.py @@ -2,6 +2,7 @@ import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.data.config import TokenizerConfig from fast_llm.engine.schedule.config import BatchConfig from fast_llm.utils import Assert @@ -63,6 +64,9 @@ def get_evaluator( class LmEvalEvaluatorConfig(EvaluatorConfig): _abstract: typing.ClassVar[bool] = False + tokenizer: TokenizerConfig = Field( + desc="Configuration for the tokenizer.", + ) cli_args: list[str] = Field( default_factory=lambda: [], desc="lm_eval CLI arguments, excluding those related to model, wandb, batch sizes, and device.", diff --git a/fast_llm/engine/evaluation/lm_eval/evaluator.py b/fast_llm/engine/evaluation/lm_eval/evaluator.py index 14aed65c4..5bfb544ed 100644 --- a/fast_llm/engine/evaluation/lm_eval/evaluator.py +++ b/fast_llm/engine/evaluation/lm_eval/evaluator.py @@ -60,7 +60,7 @@ def setup( self._flm_wrapper = FastLLMLmEvalWrapper( model=self._hf_model, - tokenizer=self._data.tokenizer.tokenizer, + tokenizer=self._config.tokenizer.get_tokenizer(), truncation=self._config.truncation, logits_cache=self._config.logits_cache, add_bos_token=self._config.add_bos_token, diff --git a/fast_llm/functional/dpo.py b/fast_llm/functional/dpo.py index 3a70f308f..7ab0b9ff6 100644 --- a/fast_llm/functional/dpo.py +++ b/fast_llm/functional/dpo.py @@ -1,51 +1,25 @@ import torch -def _compute_logprobs_for_preference_spans( - logits: torch.Tensor, targets: torch.Tensor, chosen_spans: torch.Tensor, rejected_spans: torch.Tensor -): - assert torch.all(targets < logits.size(-1)), "Target out of vocab range" +def _get_target_log_probabilities(logits: torch.Tensor, targets: torch.Tensor): + # Gather log probabilities corresponding to the target tokens + return torch.nn.functional.log_softmax(logits, dim=-1).gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1) - log_probs = torch.nn.functional.log_softmax(logits, dim=-1) - # gather log probabilities corresponding to the target tokens - selected_log_probs = log_probs.gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1) - - # apply chosen mask - chosen_logp = 0 - for idx, span in enumerate(chosen_spans): - chosen_logp += selected_log_probs[idx][span[0].item() : span[1].item() + 1].sum() - - # apply rejected mask - rejected_logp = 0 - for idx, span in enumerate(rejected_spans): - rejected_logp += selected_log_probs[idx][span[0].item() : span[1].item() + 1].sum() - - return chosen_logp, rejected_logp, selected_log_probs - - -def _compute_dpo_loss( - policy_chosen_logps: torch.Tensor, - policy_rejected_logps: torch.Tensor, - reference_chosen_logps: torch.Tensor, - reference_rejected_logps: torch.Tensor, - beta: float, -): - pi_logratios = policy_chosen_logps - policy_rejected_logps - ref_logratios = reference_chosen_logps - reference_rejected_logps - - diff_logratios = pi_logratios - ref_logratios - - losses = -torch.nn.functional.logsigmoid(beta * diff_logratios) - return losses +def _get_target_log_probability_for_spans(log_probabilities: torch.Tensor, spans: list[list[tuple[int, int]]]): + return sum( + log_probabilities[sample_index, begin:end].sum() + for sample_index, sample_spans in enumerate(spans) + for begin, end in sample_spans + ) def compute_dpo_loss( logits: torch.Tensor, targets: torch.Tensor, reference_model_logits: torch.Tensor, - chosen_spans: torch.Tensor, - rejected_spans: torch.Tensor, + chosen_spans: list[list[tuple[int, int]]], + rejected_spans: list[list[tuple[int, int]]], beta: float, grad_output: float | None, ) -> tuple[torch.Tensor, torch.Tensor]: @@ -53,21 +27,18 @@ def compute_dpo_loss( logits_ = logits.float().detach().requires_grad_() reference_model_logits_ = reference_model_logits.float().detach() - policy_chosen_logps, policy_rejected_logps, _ = _compute_logprobs_for_preference_spans( - logits_, targets, chosen_spans, rejected_spans - ) + policy_log_probabilities = _get_target_log_probabilities(logits_, targets) + policy_log_ratios = _get_target_log_probability_for_spans( + policy_log_probabilities, chosen_spans + ) - _get_target_log_probability_for_spans(policy_log_probabilities, rejected_spans) - reference_chosen_logps, reference_rejected_logps, _ = _compute_logprobs_for_preference_spans( - reference_model_logits_, targets, chosen_spans, rejected_spans - ) + reference_log_probabilities = _get_target_log_probabilities(reference_model_logits_, targets) + reference_log_ratios = _get_target_log_probability_for_spans( + reference_log_probabilities, chosen_spans + ) - _get_target_log_probability_for_spans(reference_log_probabilities, rejected_spans) - losses = _compute_dpo_loss( - policy_chosen_logps=policy_chosen_logps, - policy_rejected_logps=policy_rejected_logps, - reference_chosen_logps=reference_chosen_logps, - reference_rejected_logps=reference_rejected_logps, - beta=beta, - ) + # TODO: ====== Shouldn't the sigmoid be computed independently for each document? + losses = -torch.nn.functional.logsigmoid(beta * (policy_log_ratios - reference_log_ratios)) if grad_output is None: loss = None diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 167184193..ffbe9955e 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -5,11 +5,12 @@ from fast_llm.core.distributed import set_generator from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim from fast_llm.engine.base_model.config import ResourceUsageConfig +from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import wrap_forward_backward -from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs +from fast_llm.layers.attention.config import AttentionConfig, AttentionImplementation, AttentionKwargs from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.block import BlockWithBias @@ -79,7 +80,12 @@ def __init__( peft=peft, return_bias=return_bias, ) - self._use_flash_attention = self._config.do_use_flash_attention(self._distributed_config) + self._implementation = self._config.implementation + if self._implementation == AttentionImplementation.auto: + if _flash_available and self._distributed_config.compute_dtype in (DataType.float16, DataType.bfloat16): + self._implementation = AttentionImplementation.flash + else: + self._implementation = AttentionImplementation.backup self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) self._sequence_data_parallel_dim = self._distributed_config.get_distributed_dim( @@ -209,8 +215,7 @@ def _attn_fused( attn_weights = torch.where(mask, attn_weights, mask_value) attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(query.dtype) - with set_generator(self._distributed.tp_generator): - attn_weights = torch.dropout(attn_weights, self._config.dropout, self.training) + attn_weights = torch.dropout(attn_weights, self._config.dropout, self.training) attn_output = torch.bmm( attn_weights.view(b * self._local_head_groups, sq * self._local_heads_per_group, sk), value ) @@ -328,29 +333,10 @@ def _forward( query, key = self._rotary(query, key, kwargs) window_size = (-1, -1) if self._config.window_size is None else (self._config.window_size - 1, 0) - - if self._use_flash_attention: - assert _flash_available - with set_generator(self._distributed.tp_generator): - if (cu_seqlens_q := kwargs.get(AttentionKwargs.cu_seqlens_q, None)) is not None: - out_dims = query.size() - query = query.view(-1, query.size(-2), query.size(-1)) - key = key.view(-1, key.size(-2), key.size(-1)) - value = value.view(-1, value.size(-2), value.size(-1)) - input_ = _flash_attn_varlen_func( - query, - key, - value, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=kwargs.get(AttentionKwargs.cu_seqlens_k), - max_seqlen_q=kwargs.get(AttentionKwargs.max_seqlen_q), - max_seqlen_k=kwargs.get(AttentionKwargs.max_seqlen_k), - dropout_p=self._config.dropout if self.training else 0.0, - window_size=window_size, - causal=self._config.causal, - softmax_scale=self._softmax_scale, - ).view(*out_dims) - else: + with set_generator(self._distributed.tp_generator): + if self._implementation == AttentionImplementation.flash: + assert _flash_available + if self._config.cross_document_attention: input_ = _flash_attn_func( query, key, @@ -359,17 +345,36 @@ def _forward( dropout_p=self._config.dropout if self.training else 0.0, causal=self._config.causal, softmax_scale=self._softmax_scale, + ).flatten(-2) + else: + input_ = ( + _flash_attn_varlen_func( + query.view(-1, query.size(-2), query.size(-1)), + key.view(-1, key.size(-2), key.size(-1)), + value.view(-1, value.size(-2), value.size(-1)), + cu_seqlens_q=kwargs.get(AttentionKwargs.cu_seqlens_q), + cu_seqlens_k=kwargs.get(AttentionKwargs.cu_seqlens_k), + max_seqlen_q=kwargs.get(AttentionKwargs.max_seqlen_q), + max_seqlen_k=kwargs.get(AttentionKwargs.max_seqlen_k), + dropout_p=self._config.dropout if self.training else 0.0, + window_size=window_size, + causal=self._config.causal, + softmax_scale=self._softmax_scale, + ) + .view(query.size()) + .flatten(-2) ) - input_ = input_.flatten(-2) - else: - # TODO: Avoid the flattens. - input_ = self._attn_fused( - query.flatten(-2), - key.flatten(-2), - value.flatten(-2), - kwargs[AttentionKwargs.attention_mask], - kwargs[AttentionKwargs.attention_mask_value], - ) + elif self._implementation == AttentionImplementation.backup: + # TODO: Avoid the flattens. + input_ = self._attn_fused( + query.flatten(-2), + key.flatten(-2), + value.flatten(-2), + kwargs[AttentionKwargs.attention_mask], + kwargs[AttentionKwargs.attention_mask_value], + ) + else: + raise NotImplementedError(self._implementation) if self._debug.enabled: self._debug(query, "query", self._query_dims, kwargs) @@ -413,8 +418,9 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c attention_compute = sequence_q * sequence_k * attn_compute_base - if (not config.hardware) or self._use_flash_attention: + if (not config.hardware) or self._implementation in AttentionImplementation.flash: # Remove non-causal part. (TODO: Support non-causal) + # TODO: Compute is overestimated without cross-document attention. attention_compute -= (sequence_q * (sequence_q - 1) * attn_compute_base) // 2 if self._config.window_size is not None: @@ -439,10 +445,10 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: self._rotary.preprocess(batch, kwargs) - if not self._use_flash_attention: + if self._implementation == AttentionImplementation.backup: self._preprocess_for_backup_attention(batch, kwargs) - elif AttentionKwargs.sequence_lengths in kwargs: - self._preprocess_for_varlen(batch, kwargs) + elif self._implementation == AttentionImplementation.flash: + self._preprocess_for_flash_attention(batch, kwargs) def _preprocess_for_backup_attention(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: if ( @@ -471,11 +477,11 @@ def _preprocess_for_backup_attention(self, batch: torch.Tensor, kwargs: dict[str kwargs[AttentionKwargs.attention_mask] = self._backup_attention_mask[ None, None, sequence_k - sequence_q : sequence_k, None, :sequence_k ] - if (sequence_lengths := kwargs.get(AttentionKwargs.sequence_lengths, None)) is not None: + if not self._config.cross_document_attention: seq_ids = torch.stack( [ torch.cat([torch.full((x,), i) for i, x in enumerate(sample_lens)]) - for sample_lens in sequence_lengths + for sample_lens in kwargs[AttentionKwargs.sequence_lengths] ] ) document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None]).to(batch.device) @@ -485,7 +491,7 @@ def _preprocess_for_backup_attention(self, batch: torch.Tensor, kwargs: dict[str ) kwargs[AttentionKwargs.attention_mask_value] = self._backup_attention_mask_value - def _preprocess_for_varlen(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + def _preprocess_for_flash_attention(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: """ Prepares cu_seqlens_q and cu_seqlens_k for flash_attn_varlen_func: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py#L1375 @@ -495,7 +501,7 @@ def _preprocess_for_varlen(self, batch: torch.Tensor, kwargs: dict[str, typing.A also contain previous tokens from the first document in micro-sequence. We use individual sequence lengths of each document to (optionally) find the micro-sequences in the batch and compute the cumulative lengths. """ - if AttentionKwargs.sequence_lengths not in kwargs: + if self._config.cross_document_attention: return sequence_lengths = kwargs[AttentionKwargs.sequence_lengths] sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index 68b6dde91..206fa6e6f 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -1,10 +1,9 @@ +import enum import logging import typing import warnings from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import TritonConfig from fast_llm.layers.attention.rotary.config import RotaryConfig from fast_llm.layers.block.config import BlockKwargs @@ -32,6 +31,12 @@ class AttentionKwargs(BlockKwargs): past_key_values = "past_key_values" +class AttentionImplementation(enum.StrEnum): + auto = "auto" + flash = "flash" + backup = "backup" + + @config_class(dynamic_type={MixerConfig: "attention"}) class AttentionConfig(MixerConfig): # TODO: Make mixer class dynamic. @@ -107,6 +112,17 @@ class AttentionConfig(MixerConfig): " Under muP (if scaling number of heads instead of head_size): use 0.5.", valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) + implementation: AttentionImplementation = Field( + default=AttentionImplementation.auto, + desc="The implementation to use for the attention layer. Default: `flash` if supported, otherwise `backup`.", + hint=FieldHint.feature, + ) + cross_document_attention: bool = Field( + default=True, + desc="Allow for cross-document attention.", + doc="Disable to prevent attention between tokens belonging to different documents.", + hint=FieldHint.feature, + ) def _validate(self) -> None: super()._validate() @@ -121,6 +137,3 @@ def layer_class(self) -> "type[Attention]": from fast_llm.layers.attention.attention import Attention return Attention - - def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: - return self.use_flash_attention and distributed_config.compute_dtype in (DataType.float16, DataType.bfloat16) diff --git a/fast_llm/layers/attention/rotary/config.py b/fast_llm/layers/attention/rotary/config.py index 26877ee0c..74b5cf21a 100644 --- a/fast_llm/layers/attention/rotary/config.py +++ b/fast_llm/layers/attention/rotary/config.py @@ -135,3 +135,11 @@ def _get_configurable_class(self) -> "type[YarnRotary]": from fast_llm.layers.attention.rotary.rotary import YarnRotary return YarnRotary + + +@config_class(dynamic_type={RotaryConfig: "default_2d"}) +class Rotary2DConfig(DefaultRotaryConfig): + def _get_configurable_class(self) -> "type[Rotary2D]": + from fast_llm.layers.transformer.rotary.rotary import Rotary2D + + return Rotary2D diff --git a/fast_llm/layers/attention/rotary/rotary.py b/fast_llm/layers/attention/rotary/rotary.py index d57d72947..6250fd4a9 100644 --- a/fast_llm/layers/attention/rotary/rotary.py +++ b/fast_llm/layers/attention/rotary/rotary.py @@ -12,6 +12,7 @@ DefaultRotaryConfig, Llama3RotaryConfig, NoRotaryConfig, + Rotary2DConfig, RotaryConfig, YarnRotaryConfig, ) @@ -174,3 +175,49 @@ def _get_correction(self, beta: float, dim: int) -> float: * math.log(self._config.original_context_length / (beta * 2 * math.pi)) / (2 * math.log(self._config.theta)) ) + + +class Rotary2D[ConfigType: Rotary2DConfig](DefaultRotary[ConfigType]): + _rotary_embedding_frequencies: torch.Tensor + _tensor_cache_max_num_patches: int = -1 + _config: ConfigType + + def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: + self._create_tensors( + kwargs[VisionEncoderKwargs.max_image_size] // kwargs[VisionEncoderKwargs.patch_size], batch.device + ) + position_ids = kwargs[VisionTransformerKwargs.patch_position_ids] + kwargs[AttentionKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[:, position_ids] + kwargs[AttentionKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, position_ids] + + def forward( + self, query: torch.Tensor, key: torch.Tensor, kwargs: dict[str, typing.Any] + ) -> tuple[torch.Tensor, torch.Tensor]: + rotary_fn = triton_rotary_autograd_ if self._config.triton else apply_rotary_embeddings + query = rotary_fn(query, kwargs[AttentionKwargs.rotary_freq_q]) + key = rotary_fn(key, kwargs[AttentionKwargs.rotary_freq_k]) + return query, key + + def _get_frequencies(self, sequence_length: int, head_size: int, device: torch.device) -> torch.Tensor: + max_num_patches = sequence_length + # Calculate complex frequencies by using alternating channels for width and height + height_positions = torch.arange(max_num_patches, device=device, dtype=torch.float64) + width_positions = torch.arange(max_num_patches, device=device, dtype=torch.float64) + frequencies = self._config.theta ** -torch.arange(0, 1, 2 / head_size, device=device, dtype=torch.float64) + angles_h = torch.outer(height_positions, frequencies[::2]) + angles_w = torch.outer(width_positions, frequencies[1::2]) + angles = torch.cat( + [ + angles_h[:, None, :].repeat(1, max_num_patches, 1), + angles_w[None, :, :].repeat(max_num_patches, 1, 1), + ], + dim=-1, + ).reshape(-1, head_size // 2) + + frequencies = torch.polar(torch.ones_like(angles), angles)[None, :, None, :].to(torch.complex64) + if not self._config.complex_format: + frequencies = convert_rotary_complex_to_real( + torch.view_as_real(frequencies).flatten(-2), head_size, 3 + ).contiguous() + + return frequencies diff --git a/fast_llm/layers/common/linear/config.py b/fast_llm/layers/common/linear/config.py index e7c6d9e92..0dc118269 100644 --- a/fast_llm/layers/common/linear/config.py +++ b/fast_llm/layers/common/linear/config.py @@ -1,7 +1,12 @@ import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class -from fast_llm.engine.config_utils.initialization import Initialization, init_uniform_centered_, init_zeros_ +from fast_llm.engine.config_utils.initialization import ( + Initialization, + init_normal_, + init_uniform_centered_, + init_zeros_, +) from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.functional.config import ActivationType @@ -9,7 +14,7 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.layers.common.linear.convolution import CausalConv1d + from fast_llm.layers.common.linear.convolution import CausalConv1d, Convolution2D from fast_llm.layers.common.linear.linear import LinearBase @@ -217,3 +222,44 @@ def get_layer( return CausalConv1d( weight, bias, activation=default_activation if self.activation is None else self.activation ) + + +@config_class +class Convolution2DConfig(AffineLinearBaseConfig): + def get_layer( + self, + in_dim: TensorDim, + out_dim: TensorDim, + kernel_dim_1: TensorDim, + kernel_dim_2: TensorDim, + *, + stride: tuple[int, int], + default_weight_initialization: Initialization | None = None, + default_bias_initialization: Initialization | None = None, + default_add_bias: bool = True, + lr_scale: float | None, + peft: PeftConfig | None, + ) -> "Convolution2D": + from fast_llm.layers.common.linear.convolution import Convolution2D + + if default_weight_initialization is None: + default_weight_initialization = init_normal_() + if default_bias_initialization is None: + default_bias_initialization = init_normal_() + + lr_scale = (combine_lr_scales(lr_scale, self.lr_scale),) + weight = self.weight.get_parameter( + (out_dim, in_dim, kernel_dim_1, kernel_dim_2), + default_initialization=default_weight_initialization, + lr_scale=lr_scale, + peft=peft, + ) + bias = self.bias.get_parameter( + (out_dim,), + default_initialization=default_bias_initialization, + lr_scale=lr_scale, + default_enabled=default_add_bias, + peft=peft, + ) + + return Convolution2D(weight, bias, stride=stride) diff --git a/fast_llm/layers/common/linear/convolution.py b/fast_llm/layers/common/linear/convolution.py index b88b7b2e6..6281348e1 100644 --- a/fast_llm/layers/common/linear/convolution.py +++ b/fast_llm/layers/common/linear/convolution.py @@ -55,3 +55,27 @@ def _forward_causal_conv1d(self, input_: torch.Tensor) -> torch.Tensor: def get_compute_usage(self, input_: TensorMeta, config: ResourceUsageConfig) -> int: raise NotImplementedError() + + +class Convolution2D(torch.nn.Module): + """ + TODO: Generalize to other convolutions? + """ + + def __init__( + self, + weight: ParameterMeta, + bias: ParameterMeta | None, + *, + stride: tuple[int, int], + ): + super().__init__() + self.weight = weight + self.bias = bias + self._stride = stride + + def forward(self, input_: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.conv2d(input_, self.weight, self.bias, stride=self._stride) + + def get_compute_usage(self, input_: TensorMeta, config: ResourceUsageConfig) -> int: + raise NotImplementedError() diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 25fa2d91e..18c64acc4 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -53,6 +53,13 @@ class LanguageModelEmbeddingsConfig(BlockConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) + cross_document_position_embeddings: bool = Field( + default=True, + desc="Allow for cross-document position embeddings.", + doc="Disable to reset position ids at the beginning of each document.", + hint=FieldHint.feature, + ) + dropout: float = Field( default=0.0, desc="Dropout applied to the embedding layer.", diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 0ad3225c8..b9d209274 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -3,7 +3,7 @@ import torch from fast_llm.core.distributed import set_generator -from fast_llm.core.ops import reduce_forward, split +from fast_llm.core.ops import gather, reduce_forward, split from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import TensorDim @@ -14,6 +14,8 @@ from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert +WORD_EMBEDDINGS_WEIGHT = "word_embeddings_weight" + class LanguageModelEmbedding[ConfigType: LanguageModelEmbeddingsConfig](Block[ConfigType]): """ @@ -26,7 +28,8 @@ class LanguageModelEmbedding[ConfigType: LanguageModelEmbeddingsConfig](Block[Co layer_count: float = 1000.0 _config: ConfigType - # Position embedding preprocessing + # Preprocessing + _rotary_embedding_frequencies: torch.Tensor _position_ids: torch.Tensor _tensor_cache_max_sequence_length: int = -1 @@ -75,34 +78,62 @@ def __init__( ) @torch.compile - def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask_inputs: bool) -> torch.Tensor: + def _forward( + self, + input_: torch.Tensor, + token_ids: torch.Tensor, + position_ids: torch.Tensor | None, + mask_inputs: bool, + # TODO: Flatten the batch and sequence in the map? + embedding_map: tuple[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]] | None, + ) -> torch.Tensor: Assert.eq(position_ids is None, self.position_embeddings_weight is None) group = self._parallel_dim.group if self._vocab_parallel: - input_mask = (input_ >= self._vocab_start_index) * (input_ < self._vocab_end_index) - masked_input = (input_ - self._vocab_start_index) * input_mask - embeddings = torch.embedding(self.word_embeddings_weight, masked_input) * input_mask.unsqueeze(2) # noqa + token_mask = (token_ids >= self._vocab_start_index) * (token_ids < self._vocab_end_index) + masked_input = (token_ids - self._vocab_start_index) * token_mask + embeddings = torch.embedding(self.word_embeddings_weight, masked_input) * token_mask.unsqueeze(2) # noqa embeddings = reduce_forward(embeddings, group) + # TODO: Input masking of position embeddings inconsistant with non-vocab-parallel if self.position_embeddings_weight is not None: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) + + if embedding_map is not None: + # TODO: Accumulate redundant with masking? + input_index, embedding_index = embedding_map + if self._sequence_parallel: + input_ = gather(input_, group=group, dim=0) + embeddings = embeddings.index_put(embedding_index, input_[input_index], accumulate=True) + if self._sequence_parallel: embeddings = split(embeddings, group=group, dim=0) else: if self._sequence_parallel: - input_ = split(input_, group=group, dim=0) + token_ids = split(token_ids, group=group, dim=0) if self.position_embeddings_weight is not None: position_ids = split(position_ids, group=group, dim=0) # handle masked tokens if mask_inputs: - input_mask = input_ >= 0 - masked_input = input_ * input_mask - embeddings = torch.embedding(self.word_embeddings_weight, masked_input) - else: - embeddings = torch.embedding(self.word_embeddings_weight, input_) + token_mask = token_ids >= 0 + token_ids = token_ids * token_mask + embeddings = torch.embedding(self.word_embeddings_weight, token_ids) if self.position_embeddings_weight is not None: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) if mask_inputs: - embeddings = embeddings * input_mask.unsqueeze(2) + embeddings = embeddings * token_mask.unsqueeze(2) + + if embedding_map is not None: + # TODO: Accumulate redundant with masking? + input_index, embedding_index = embedding_map + if self._sequence_parallel: + # TODO:: Filter and shift embedding map instead? (needs cuda sync) + input_ = gather(input_, group=group, dim=0) + embeddings_ = embeddings.new_zeros(embeddings.shape[0] * group.size(), *embeddings.shape[1:]) + embeddings_.index_put(embedding_index, input_[input_index], accumulate=True) + embeddings = embeddings + split(embeddings_, group=group, dim=0) + else: + embeddings = embeddings.index_put(embedding_index, input_[input_index], accumulate=True) + with set_generator( self._distributed.tp_generator if self._sequence_parallel else self._distributed.pp_generator ): @@ -119,11 +150,17 @@ def forward( if isinstance(input_, TensorMeta): return TensorMeta.from_dims( kwargs[LanguageModelKwargs.hidden_dims], - tensor_name=f"{self.module_name} output", + tensor_name="Embedding output", dtype=self._residual_dtype, ) + return self._forward( - input_, kwargs.get(LanguageModelKwargs.position_ids), kwargs.get(LanguageModelKwargs.mask_inputs) + input_, + kwargs.get(LanguageModelKwargs.token_ids), + kwargs.get(LanguageModelKwargs.position_ids), + # TODO ====== Vision ====== Review input masking. + kwargs.get(LanguageModelKwargs.mask_inputs), + kwargs.get(LanguageModelKwargs.embedding_map), ) def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: @@ -136,9 +173,12 @@ def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None self._create_position_embeddings(kwargs[LanguageModelKwargs.sequence_length], batch.device) sequence_k = kwargs[LanguageModelKwargs.sequence_k_dim].size sequence_q = kwargs[LanguageModelKwargs.sequence_q_dim].size - if (sequence_lengths := kwargs.get(LanguageModelKwargs.sequence_lengths)) is not None: + if not self._config.cross_document_position_embeddings: position_ids = torch.stack( - [torch.cat([torch.arange(x) for x in sample_lens]) for sample_lens in sequence_lengths] + [ + torch.cat([torch.arange(x) for x in sample_lens]) + for sample_lens in kwargs[LanguageModelKwargs.sequence_lengths] + ] ).to(batch.device, dtype=torch.int64) position_ids = position_ids[:, sequence_k - sequence_q : sequence_k] if kwargs[LanguageModelKwargs.sequence_first]: diff --git a/fast_llm/layers/vision/__init__.py b/fast_llm/layers/vision/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/layers/vision/config.py b/fast_llm/layers/vision/config.py new file mode 100644 index 000000000..1af986eef --- /dev/null +++ b/fast_llm/layers/vision/config.py @@ -0,0 +1,169 @@ +import typing + +from fast_llm.config import Config, Field, FieldHint, check_field, config_class +from fast_llm.layers.block.config import BlockConfig, BlockSequenceConfig +from fast_llm.layers.common.linear.config import Convolution2DConfig +from fast_llm.layers.common.normalization.config import NormalizationConfig +from fast_llm.layers.decoder.config import MLPBaseConfig +from fast_llm.utils import Assert + +if typing.TYPE_CHECKING: + from fast_llm.layers.vision.vision_encoder import VisionEncoder + + +@config_class() +class ImageNormalizationConfig(Config): + mean_r: float = Field( + default=0.48145466, + desc="Mean value for the red channel in the image normalization process.", + hint=FieldHint.optional, + ) + mean_g: float = Field( + default=0.4578275, + desc="Mean value for the green channel in the image normalization process.", + hint=FieldHint.optional, + ) + mean_b: float = Field( + default=0.40821073, + desc="Mean value for the blue channel in the image normalization process.", + hint=FieldHint.optional, + ) + std_r: float = Field( + default=0.26862954, + desc="Standard deviation value for the red channel in the image normalization process.", + hint=FieldHint.optional, + ) + std_g: float = Field( + default=0.26130258, + desc="Standard deviation value for the green channel in the image normalization process.", + hint=FieldHint.optional, + ) + std_b: float = Field( + default=0.27577711, + desc="Standard deviation value for the blue channel in the image normalization process.", + hint=FieldHint.optional, + ) + rescale_factor: float = Field( + default=255.0, + desc="Rescale factor for the image normalization process.", + hint=FieldHint.optional, + ) + + +@config_class() +class PatchConvolutionConfig(BlockConfig): + _abstract = False + convolution: Convolution2DConfig = Field( + desc="Configuration for the 2d convolution.", + hint=FieldHint.architecture, + ) + normalization: NormalizationConfig = Field( + desc="Configuration for the normalization layer.", + hint=FieldHint.architecture, + ) + patch_size: int = Field( + default=16, + desc="Size of image patches, in pixels (width and height).", + hint=FieldHint.core, + ) + input_channels: int = Field( + default=3, + desc="Number of pixel channels (usually 3).", + hint=FieldHint.feature, + ) + + +@config_class(registry=True) +class VisionEncoderConfig(BlockConfig): + _abstract = False + patch_convolution: PatchConvolutionConfig = Field( + desc="Configuration for the patch convolution layer.", + hint=FieldHint.architecture, + ) + adapter: MLPBaseConfig = Field( + desc="Configuration for the adapter layer.", + hint=FieldHint.architecture, + ) + # TODO: ====== Appropriate name?? ====== + decoder: BlockSequenceConfig = Field( + desc="Configuration for the vision decoder.", + hint=FieldHint.architecture, + ) + hidden_size: int = Field( + default=1024, + desc="Size of the vision encoder main hidden dimension.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + + @property + def layer_class(self) -> "type[VisionEncoder]": + from fast_llm.layers.vision.vision_encoder import VisionEncoder + + return VisionEncoder + + # transformer: TransformerConfig = Field( + # desc="Configuration for the vision transformer architecture.", + # hint=FieldHint.core, + # ) + # patch_size: int = Field( + # default=16, + # desc="Patch size for the image encoder.", + # hint=FieldHint.core, + # ) + # conv_bias: bool = Field( + # default=False, + # desc="Whether to use bias in the convolutional layer.", + # hint=FieldHint.optional, + # ) + # patch_norm: NormalizationConfig = Field( + # desc="Configuration for the normalization layers applied to the image patches.", + # hint=FieldHint.optional, + # ) + # adapter_size: int = Field( + # default=5120, + # desc="Intermediate size for the adapter linear layers. Assuming 2 linear layers", + # hint=FieldHint.core, + # ) + # adapter_activation_type: ActivationType = Field( + # default=ActivationType.gelu, + # desc="The intermediate activation type for multi-modal adapter. Default: GeLU.", + # hint=FieldHint.core, + # ) + # adapter_bias: bool = Field( + # default=True, + # desc="Whether to use bias in the adapter linear layer.", + # hint=FieldHint.optional, + # ) + # image_normalization: ImageNormalizationConfig = Field( + # desc="Configuration for the normalization layers applied to the image patches.", + # hint=FieldHint.optional, + # ) + # image_break_token: int | None = Field( + # default=None, + # desc="Token id to separate image rows. If None, no token id is applied.", + # hint=FieldHint.optional, + # ) + # image_end_token: int | None = Field( + # default=None, + # desc="Token id to indicate the end of an image. If None, no token id is applied.", + # hint=FieldHint.optional, + # ) + # adapter_lr_scale: float | None = Field( + # default=None, + # desc="Custom learning rate scale for the adapter weights.", + # hint=FieldHint.feature, + # valid=skip_valid_if_none(check_field(Assert.geq, 0)), + # ) + # conv_lr_scale: float | None = Field( + # default=None, + # desc="Custom learning rate scale for the convolutional layer weights.", + # hint=FieldHint.feature, + # valid=skip_valid_if_none(check_field(Assert.geq, 0)), + # ) + # adapter_init_method_std: float = Field( + # default=None, + # desc="Standard deviation for the normal initialization of the adapter weights. Default: adapter_size ** -0.5.", + # hint=FieldHint.optional, + # valid=check_field(Assert.geq, 0), + # ) diff --git a/fast_llm/layers/vision/patch_convolution.py b/fast_llm/layers/vision/patch_convolution.py new file mode 100644 index 000000000..46cf86708 --- /dev/null +++ b/fast_llm/layers/vision/patch_convolution.py @@ -0,0 +1,71 @@ +import typing + +import torch + +from fast_llm.core.ops import split +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.layers.attention.config import AttentionKwargs +from fast_llm.layers.block.block import Block +from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.vision.config import PatchConvolutionConfig +from fast_llm.tensor import TensorMeta + + +class PatchConvolution[ConfigType: PatchConvolutionConfig](Block[ConfigType]): + _config: ConfigType + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + *, + # TODO: Input or output dim? + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + ): + super().__init__( + config, + distributed_config, + hidden_dim=hidden_dim, + lr_scale=lr_scale, + peft=peft, + ) + input_dim = TensorDim("input_channels", self._config.input_channels) + patch_dim = TensorDim("patch", self._config.patch_size) + self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) + + self.convolution = self._config.convolution.get_layer( + self._hidden_dim, + input_dim, + patch_dim, + patch_dim, + stride=(self._config.patch_size, self._config.patch_size), + default_add_bias=False, + lr_scale=self._lr_scale, + peft=self._peft, + ) + self.normalization = self._config.normalization.get_layer(hidden_dim, lr_scale=self._lr_scale, peft=self._peft) + + def forward( + self, + input_: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict[str, typing.Any] | None = None, + metrics: dict | None = None, + ) -> torch.Tensor: + if isinstance(input_, TensorMeta): + return TensorMeta.from_dims( + input_.dims[:-1] + (self._hidden_dim,), tensor_name="patch conv output", dtype=input_.dtype + ) + # TODO: Avoid padding + input_ = self.convolution(input_) + patch_embeddings = self.normalization(input_.flatten(1)).view_as(input_) + + # TODO: Permute earlier? + if kwargs[AttentionKwargs.sequence_first]: + patch_embeddings = patch_embeddings.permute(1, 0, 2).contiguous() + if self._sequence_parallel: + patch_embeddings = split(patch_embeddings, group=self._parallel_dim.group, dim=0) + return patch_embeddings diff --git a/fast_llm/layers/vision/preprocessing.py b/fast_llm/layers/vision/preprocessing.py new file mode 100644 index 000000000..83331c739 --- /dev/null +++ b/fast_llm/layers/vision/preprocessing.py @@ -0,0 +1,194 @@ +import math +import typing + +import torch +import torchvision.transforms.v2 as torchvision_transforms + +from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.layers.attention.config import AttentionKwargs +from fast_llm.layers.language_model.config import LanguageModelKwargs +from fast_llm.layers.vision.config import ImageNormalizationConfig, VisionEncoderConfig +from fast_llm.utils import div + + +def get_num_patches(height: int, width: int, patch_size: int) -> int: + """ + Calculate the number of patches in height and width dimensions. + """ + return div(height, patch_size) * div(width, patch_size) + + +def get_num_image_tokens(height: int, width: int, patch_size: int, image_break: bool, image_end: bool) -> int: + """ + Calculate the number of image tokens. + If image_break is True, we consider 1 additional token after every row of patches. + """ + height_patches = div(height, patch_size) + width_patches = div(width, patch_size) + num_tokens = height_patches * width_patches + if image_break: + num_tokens += height_patches + elif image_end: + num_tokens += 1 + return num_tokens + + +def get_resize_dims(height: int, width: int, max_height: int, max_width: int, patch_size: int) -> tuple[int, int]: + """ + Calculate the new dimensions for resizing an image while maintaining the aspect ratio. + If the image is larger than the max dimensions, it will be resized to fit within them. + If the image is smaller, it will be resized to the nearest multiple of the patch size. + """ + ratio = max(height / max_height, width / max_width) + if ratio > 1: + # Resize to fit within max dimensions + height = int(height / ratio) + width = int(width / ratio) + return patch_size * math.ceil(height / patch_size), patch_size * math.ceil(width / patch_size) + + +def resize(image: torch.Tensor, target_height: int, target_width: int) -> torch.Tensor: + # cap the resizing to half of the current size as a workaround for large images + # See pytorch issue: https://github.com/pytorch/pytorch/issues/103589 + while max(image.size(1) / target_height, image.size(2) / target_width) > 2: + image = torchvision_transforms.functional.resize( + image, + size=(math.ceil(image.size(1) / 2), math.ceil(image.size(2) / 2)), + interpolation=torchvision_transforms.InterpolationMode.BICUBIC, + ) + + # TODO: options for interpolation mode? + return torchvision_transforms.functional.resize( + image, size=(target_height, target_width), interpolation=torchvision_transforms.InterpolationMode.BICUBIC + ) + + +def position_ids_in_meshgrid(height, width, max_size, patch_size) -> torch.Tensor: + patch_height = height // patch_size + patch_width = width // patch_size + return torch.arange(patch_height).repeat_interleave(patch_width) * max_size + torch.arange(patch_width).repeat( + patch_height + ) + + +class VisionPreprocessor: + def __init__(self, config: VisionEncoderConfig, distributed: Distributed): + self._config = config + self._distributed = distributed + + def preprocess(self, tokens, kwargs: dict[str, typing.Any]) -> None: + max_image_size = kwargs.get(VisionEncoderKwargs.max_image_size) + patch_size = self._config.patch_size + image_sizes = [] + + norm_config: ImageNormalizationConfig = kwargs["norm_config"] + + if LanguageModelKwargs.labels in kwargs: + labels = kwargs[LanguageModelKwargs.labels] + if (self._config.image_break_token is not None) or (self._config.image_end_token is not None): + # If image break or end token is present, we need to replace image token ids to -100 in labels + # TODO: avoid double cloning labels in case of loss masking spans? + labels = labels.clone() + patches = [] + patch_position_ids = [] + sequence_lengths = [0] + max_sequence_length = -1 + + for sample_index, (sample_images_, positions) in enumerate( + zip(kwargs[VisionEncoderKwargs.images], kwargs.get(VisionEncoderKwargs.image_positions), strict=True) + ): + image_sizes.append(sample_image_sizes := []) + + sample_sequence_length = 0 + + for image, position in zip(sample_images_, positions, strict=True): + height, width = get_resize_dims( + image.size(1), image.size(2), max_image_size, max_image_size, patch_size=patch_size + ) + + sample_image_sizes.append((height, width)) + + image = resize(image, height, width) + + # TODO: Normalize with constant dtype instead? + image = image.to(dtype=self._distributed.config.training_dtype.torch) + + image = torchvision_transforms.functional.normalize( + image / norm_config.rescale_factor, + mean=[norm_config.mean_r, norm_config.mean_g, norm_config.mean_b], + std=[norm_config.std_r, norm_config.std_g, norm_config.std_b], + ) + patches.extend( + torch.nn.functional.unfold(image, kernel_size=patch_size, stride=patch_size).T.reshape( + -1, 3, patch_size, patch_size + ) + ) + + num_height_patches = div(height, patch_size) + num_width_patches = div(width, patch_size) + grid_height = torch.arange(num_height_patches).repeat_interleave(num_width_patches) + grid_width = torch.arange(num_width_patches).repeat(num_height_patches) + grid_height * div(max_image_size, patch_size) + grid_width + patch_position_ids.append(grid_height * div(max_image_size, patch_size) + grid_width) + + if LanguageModelKwargs.labels in kwargs: + num_tokens = get_num_image_tokens( + height, + width, + patch_size=patch_size, + image_break=self._config.image_break_token is not None, + image_end=self._config.image_end_token is not None, + ) + # set labels for image patches to -100 + labels[sample_index, max(position - 1, 0) : position + num_tokens - 1] = -100 + + sequence_lengths.append(sequence_length := num_height_patches * num_width_patches) + if sequence_length > max_sequence_length: + max_sequence_length = sequence_length + sample_sequence_length += sequence_length + + # TODO: No need for padding with varlen? + padding_size = kwargs[AttentionKwargs.sequence_length] - sample_sequence_length + if padding_size > max_sequence_length: + max_sequence_length = padding_size + sequence_lengths.append(padding_size) + + patches.append( + torch.zeros(padding_size, 3, patch_size, patch_size).to( + dtype=self._tensor_space.distributed_config.training_dtype.torch, + device=self._tensor_space.distributed.device, + ), + ) + patch_position_ids.append(torch.full((padding_size,), 0, dtype=torch.int64)) + + kwargs[VisionEncoderKwargs.image_sizes] = image_sizes + kwargs[VisionEncoderKwargs.image_patches] = torch.cat(patches).to(device=self._distributed.device) + kwargs[VisionTransformerKwargs.patch_position_ids] = torch.cat(patch_position_ids).to( + device=self._distributed.device + ) + kwargs[VisionEncoderKwargs.max_image_tokens] = div(max_image_size**2, patch_size**2) + # sequence data parallel is not yet supported for images, so we use the same cu_seqlens for q and k + kwargs[VisionTransformerKwargs.cu_seqlens_q] = torch.tensor( + cu_seqlens, device=self._distributed.device, dtype=torch.int32 + ) + kwargs[VisionTransformerKwargs.cu_seqlens_k] = torch.tensor( + cu_seqlens, device=self._distributed.device, dtype=torch.int32 + ) + kwargs[VisionTransformerKwargs.max_seqlen_q] = max_sequence_length + kwargs[VisionTransformerKwargs.max_seqlen_k] = max_sequence_length + if LanguageModelKwargs.labels in kwargs: + kwargs[LanguageModelKwargs.labels] = labels + + # TODO: add proper preprocessing for attention-mask when not using flash attention + # Following is just a dummy code to run the tests. + kwargs[self._config.transformer._transformer_kwargs.attention_mask] = torch.ones( + (1, 1, kwargs[AttentionKwargs.sequence_length], 1, kwargs[AttentionKwargs.sequence_length]), + dtype=torch.bool, + device=self._tensor_space.distributed.device, + ) + kwargs[self._config.transformer._transformer_kwargs.attention_mask_value] = torch.full( + [], + torch.finfo(self._distributed.config.training_dtype.torch).min, + dtype=self._distributed.config.training_dtype.torch, + device=self._distributed.device, + ) diff --git a/fast_llm/layers/vision/vision_encoder.py b/fast_llm/layers/vision/vision_encoder.py new file mode 100644 index 000000000..b4fa189d5 --- /dev/null +++ b/fast_llm/layers/vision/vision_encoder.py @@ -0,0 +1,67 @@ +import logging +import typing + +import torch + +from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.base_model.config import LossDef +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.block.block import BlockBase +from fast_llm.layers.common.peft.config import PeftConfig +from fast_llm.layers.vision.config import VisionEncoderConfig + +logger = logging.getLogger(__name__) + + +class VisionEncoder[ConfigType: VisionEncoderConfig](BlockBase[VisionEncoderConfig]): + _config: ConfigType + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + *, + hidden_dim: TensorDim, + lr_scale: float | None, + peft: PeftConfig | None, + ): + vision_hidden_dim = TensorDim("hidden", self._config.hidden_size) + super().__init__(config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft) + self.patch_convolution = self._config.patch_convolution.get_layer( + distributed_config, + vision_hidden_dim, + lr_scale=self._lr_scale, + peft=self._peft, + ) + # TODO: ====== Appropriate name?? ====== + self.decoder = self._config.decoder.get_layer( + distributed_config, + vision_hidden_dim, + lr_scale=self._lr_scale, + peft=self._peft, + ) + # TODO: ====== Hidden dim ====== + self.adapter = self._config.adapter.get_layer( + distributed_config, + vision_hidden_dim, + lr_scale=self._lr_scale, + peft=self._peft, + ) + + def get_layers(self) -> list["Layer"]: + return self.patch_convolution.get_layers() + self.decoder.get_layers() + self.adapter.get_layers() + + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + # Needed because the base class uses `get_layers` which may bypass the decoder and head. TODO: Avoidable? + self.patch_convolution.preprocess(batch, kwargs) + self.decoder.preprocess(batch, kwargs) + self.adapter.preprocess(batch, kwargs) + + def get_loss_definitions(self, count: int = 1) -> list[LossDef]: + # Needed because the base class uses `get_layers` which may bypass the decoder and head. TODO: Avoidable? + return ( + self.patch_convolution.get_loss_definitions(count) + + self.decoder.get_loss_definitions(count) + + self.adapter.get_loss_definitions(count) + ) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index a901a0466..c1ee246f7 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -48,12 +48,6 @@ class GPTBatchConfig(BatchConfig): hint=FieldHint.performance, valid=check_field(Assert.gt, 0), ) - # TODO: Find a better place for these? - cross_document_attention: bool = Field( - default=True, - desc="Applies attention to tokens from other documents in the packed sequence. Set to False for masking attention to other documents.", - hint=FieldHint.feature, - ) use_loss_masking_spans: bool = Field( default=False, desc="Read loss masking spans from the dataset.", diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index 9215e6dc7..34e38469a 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -5,7 +5,8 @@ import torch import transformers.modeling_outputs -from fast_llm.data.data.gpt.data import GPTBatch +from fast_llm.data.sample.language_model import LanguageModelBatch +from fast_llm.data.sample.token import TokenBatch from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.inference.config import HuggingfaceModelConfig from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM @@ -80,7 +81,9 @@ def inner_forward( # Iteration serves as a random seed, using random module because it's not seeded by Fast LLM iteration = random.randint(0, 2**32) batch = self.fast_llm_base_model.preprocess_batch( - GPTBatch(input_ids, sequence_lengths=sequence_lenghts), phase=PhaseType.inference, iteration=iteration + LanguageModelBatch(TokenBatch(input_ids, lengths=sequence_lenghts)), + phase=PhaseType.inference, + iteration=iteration, ) ((input_, kwargs),) = batch diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index efa348ecb..3295295f6 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -3,7 +3,7 @@ import torch -from fast_llm.data.data.gpt.data import GPTBatch +from fast_llm.data.sample.language_model import LanguageModelBatch from fast_llm.engine.base_model.base_model import BaseModel from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType @@ -40,7 +40,7 @@ def __init__( param.init_parameter = get_init_megatron(param, self._config.decoder.block, config.hidden_size) # Noqa def preprocess_meta( - self, batch_meta: GPTBatchConfig | torch.Tensor, phase: PhaseType + self, batch_meta: GPTBatchConfig | LanguageModelBatch, phase: PhaseType ) -> list[tuple[TensorMeta, dict]]: # TODO Remove (Move batch splitting elsewhere) # TODO: Use parallel/sequential dims, distinguish micro and full batch/sequence @@ -51,7 +51,7 @@ def preprocess_meta( micro_sequence_length = batch_meta.micro_sequence_length truncate_documents = batch_meta.truncate_documents else: - micro_batch_size, sequence_length = batch_meta.shape + micro_batch_size, sequence_length = batch_meta.tokens.tokens.shape if phase != PhaseType.inference: sequence_length -= self._config.head.prediction_heads micro_sequence_length = sequence_length @@ -151,7 +151,7 @@ def preprocess_meta( def preprocess_batch( self, - batch: GPTBatch, + batch: LanguageModelBatch, preprocessed_meta: list[tuple[TensorMeta, dict]] | None = None, *, phase: PhaseType, @@ -161,19 +161,10 @@ def preprocess_batch( # TODO Move batch splitting elsewhere, align interface with LayerBase assert self._is_setup - if preprocessed_meta is None: - preprocessed_meta = self.preprocess_meta(batch.token_ids, phase) - - _, common_kwargs = preprocessed_meta[0] - sequence_q = common_kwargs[AttentionKwargs.sequence_q_dim].size - sequence_first = common_kwargs[AttentionKwargs.sequence_first] - max_prediction_distance = self._config.head.max_prediction_distance + batch.to_device_(self._distributed.device) - batch.token_ids = batch.token_ids.to( - device=self._distributed.device, - dtype=torch.int64, - non_blocking=True, - ) + if preprocessed_meta is None: + preprocessed_meta = self.preprocess_meta(batch, phase) reference_logits = [{} for _ in preprocessed_meta] for name, reference_model in self._reference_models.items(): @@ -191,103 +182,59 @@ def preprocess_batch( reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration) reference_logits[i][f"{name}_logits"] = reference_kwargs["logits"] - token_ids = batch.token_ids - if sequence_first: - # Move the sequence dimension first to make sequence parallel ops more efficient. - token_ids = token_ids.transpose(0, 1).contiguous() - preprocessed = [] presents = None for i, (_, kwargs_meta) in enumerate(preprocessed_meta): - sequence_k = kwargs_meta[AttentionKwargs.sequence_k_dim].size - if sequence_first: - tokens = token_ids[sequence_k - sequence_q : sequence_k] - else: - # TODO: Avoid multiple contiguous calls? - tokens = token_ids[:, sequence_k - sequence_q : sequence_k].contiguous() - if batch.sequence_lengths is not None: - kwargs_meta[AttentionKwargs.sequence_lengths] = batch.sequence_lengths - if batch.chosen_spans is not None: - kwargs_meta[LanguageModelKwargs.chosen_spans] = batch.chosen_spans - if batch.rejected_spans is not None: - kwargs_meta[LanguageModelKwargs.rejected_spans] = batch.rejected_spans + tokens_end = kwargs_meta[AttentionKwargs.sequence_k_dim].size + tokens_begin = tokens_end - kwargs_meta[AttentionKwargs.sequence_q_dim].size + cropped_tokens = batch.tokens.crop(tokens_begin, tokens_end) # TODO: Add pasts/presents to meta input? # Use lists as pointers so `past_key_values` is populated during the previous micro_sequence. pasts = presents presents = None if i == len(preprocessed_meta) - 1 else [] - kwargs = { + + kwargs: dict[str, typing.Any] = { **kwargs_meta, AttentionKwargs.past_key_values: pasts, AttentionKwargs.presents: presents, + AttentionKwargs.sequence_lengths: batch.tokens.lengths, + **reference_logits[i], } + if phase != PhaseType.inference: - sequence_offset = sequence_k - sequence_q + 1 # +1 for shift in labels - if sequence_first: - labels = token_ids[sequence_offset : sequence_k + max_prediction_distance] - else: - # TODO: Avoid multiple contiguous calls? - labels = token_ids[:, sequence_offset : sequence_k + max_prediction_distance].contiguous() - # We set label indices to -100 for masked spans, inline with ignore_index in torch.nn.CrossEntropyLoss - # TODO: take ignore_index from config + labels_begin = tokens_begin + 1 + labels_end = tokens_end + self._config.head.max_prediction_distance + + labels = batch.tokens.crop(labels_begin, labels_end).tokens + if batch.loss_masking_spans is not None: - # avoid changing input tokens - labels = labels.clone() - for idx, spans in enumerate(batch.loss_masking_spans): - if not spans.numel(): - continue - valid_spans = spans[ - (spans[:, 0] <= sequence_k + max_prediction_distance - 1) - & (spans[:, 1] >= sequence_offset) - ] - if valid_spans.numel(): - # if span is partially within the sequence, truncate parts of spans that are outside of the sequence - valid_spans[:, 0].clamp_(min=sequence_offset) - valid_spans[:, 1].clamp_(max=sequence_k + max_prediction_distance - 1) - valid_spans -= sequence_offset - loss_mask = torch.ones_like(labels, dtype=torch.bool) - for start, end in valid_spans: - if sequence_first: - loss_mask[start : end + 1, idx] = False - else: - loss_mask[idx, start : end + 1] = False - if self._config.output_layer.distillation_model is not None: - kwargs[LanguageModelKwargs.loss_mask] = loss_mask - labels = torch.where(loss_mask, labels, -100) - kwargs[LanguageModelKwargs.labels] = labels - kwargs.update(reference_logits[i]) + loss_masking_spans = batch.loss_masking_spans.crop(labels_begin, labels_end) + loss_mask = torch.ones_like(labels, dtype=torch.bool) + for sample_index, loss_masking_spans in enumerate(loss_masking_spans.ranges): + for begin, end in loss_masking_spans: + loss_mask[sample_index, begin:end] = False + if self._config.output_layer.distillation_model is not None: + kwargs[LanguageModelKwargs.loss_mask] = loss_mask + labels = torch.where(loss_mask, labels, -100) + + kwargs[LanguageModelKwargs.labels] = ( + labels.transpose(0, 1) if kwargs[AttentionKwargs.sequence_first] else labels + ).contiguous() if batch.chosen_spans is not None: - chosen_valid_spans = [] - for spans in batch.chosen_spans: - if not spans.numel(): - continue - # only keep spans within the sequence or partially within the sequence - valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)][0] - if valid_spans.numel(): - # if span is partially within the sequence, truncate parts of spans that are outside of the sequence - valid_spans[0].clamp_(min=sequence_offset) - valid_spans[1].clamp_(max=sequence_k) - valid_spans -= sequence_offset - - chosen_valid_spans.append(valid_spans) - kwargs[LanguageModelKwargs.chosen_spans] = chosen_valid_spans - - rejected_valid_spans = [] - for spans in batch.rejected_spans: - if not spans.numel(): - continue - # only keep spans within the sequence or partially within the sequence - valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)][0] - if valid_spans.numel(): - # if span is partially within the sequence, truncate parts of spans that are outside of the sequence - valid_spans[0].clamp_(min=sequence_offset) - valid_spans[1].clamp_(max=sequence_k) - valid_spans -= sequence_offset - - rejected_valid_spans.append(valid_spans) - kwargs[LanguageModelKwargs.rejected_spans] = rejected_valid_spans - + kwargs[LanguageModelKwargs.chosen_spans] = batch.chosen_spans.crop(labels_begin, labels_end).ranges + + if batch.rejected_spans is not None: + kwargs[LanguageModelKwargs.rejected_spans] = batch.rejected_spans.crop( + labels_begin, labels_end + ).ranges + + tokens = ( + cropped_tokens.tokens.transpose(0, 1) + if kwargs[AttentionKwargs.sequence_first] + else cropped_tokens.tokens + ).contiguous() self.preprocess(tokens, kwargs) preprocessed.append((tokens, kwargs)) diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 54ea13dc4..b8fb22ebb 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -27,7 +27,6 @@ def _get_sampling_parameters( "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, # OK since DPO is not supported for MTP. "use_preference_loss_spans": getattr(self._config.model.base_model.head, "enable_dpo", False), - "cross_document_attention": self._config.batch.cross_document_attention, "truncate_documents": self._config.batch.truncate_documents, "extra_tokens": self._config.model.base_model.head.max_prediction_distance, } diff --git a/fast_llm/models/multimodal/__init__.py b/fast_llm/models/multimodal/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/models/multimodal/config.py b/fast_llm/models/multimodal/config.py new file mode 100644 index 000000000..2415734e4 --- /dev/null +++ b/fast_llm/models/multimodal/config.py @@ -0,0 +1,89 @@ +import logging +import typing + +from fast_llm.config import Field, FieldHint, FieldUpdate, config_class +from fast_llm.engine.checkpoint.config import CheckpointFormat +from fast_llm.engine.config_utils.runnable import RunnableConfig +from fast_llm.engine.multi_stage.config import FastLLMModelConfig +from fast_llm.engine.training.config import TrainerConfig +from fast_llm.layers.vision.config import VisionEncoderConfig +from fast_llm.models.gpt.config import ( + GPTBaseModelConfig, + GPTBatchConfig, + GPTModelConfig, + GPTTrainerConfig, + PretrainedGPTModelConfig, +) + +if typing.TYPE_CHECKING: + from fast_llm.models.multimodal.huggingface import HuggingfaceMultiModalModelForCausalLM + from fast_llm.models.multimodal.model import MultiModalBaseModel, MultiModalModel, MultiModalModelInferenceRunner + from fast_llm.models.multimodal.trainer import MultiModalTrainer + +logger = logging.getLogger(__name__) + + +@config_class() +class MultiModalBatchConfig(GPTBatchConfig): + pass + + +@config_class() +class MultiModalBaseModelConfig(GPTBaseModelConfig): + vision_encoder: VisionEncoderConfig = Field( + hint=FieldHint.architecture, + desc="Configuration for the vision encoder.", + ) + + @property + def base_model_class(self) -> type["MultiModalBaseModel"]: + from fast_llm.models.multimodal.model import MultiModalBaseModel + + return MultiModalBaseModel + + +@config_class(dynamic_type={FastLLMModelConfig: "gpt"}) +class MultiModalModelConfig(GPTModelConfig): + _abstract = False + model_name: typing.ClassVar[str] = "gpt" + base_model: GPTBaseModelConfig = FieldUpdate() + # TODO: ====== Conversion ====== + checkpoint_formats: typing.ClassVar[tuple[type[CheckpointFormat], ...]] = FastLLMModelConfig.checkpoint_formats + + @classmethod + def get_model_class(cls) -> type["MultiModalModel"]: + from fast_llm.models.multimodal.model import MultiModalModel + + return MultiModalModel + + @classmethod + def get_inference_runner_class(cls) -> type["MultiModalModelInferenceRunner"]: + from fast_llm.models.multimodal.model import MultiModalModelInferenceRunner + + return MultiModalModelInferenceRunner + + @classmethod + def get_huggingface_model_for_causal_lm_class(cls) -> type["HuggingfaceMultiModalModelForCausalLM"]: + from fast_llm.models.multimodal.huggingface import HuggingfaceMultiModalModelForCausalLM + + return HuggingfaceMultiModalModelForCausalLM + + +@config_class() +class PretrainedMultiModalModelConfig(PretrainedGPTModelConfig): + _abstract = False + model: MultiModalModelConfig = FieldUpdate() + + +@config_class(dynamic_type={RunnableConfig: "train_gpt", TrainerConfig: "gpt"}) +class MultiModalTrainerConfig(PretrainedMultiModalModelConfig, GPTTrainerConfig): + data: MultiModalDataConfig = FieldUpdate() + batch: MultiModalBatchConfig = FieldUpdate() + # TODO: Use dynamic model type? + reference_models: dict[str, PretrainedMultiModalModelConfig] = FieldUpdate() + + @classmethod + def get_trainer_class(cls) -> type["MultiModalTrainer"]: + from fast_llm.models.multimodal.trainer import MultiModalTrainer + + return MultiModalTrainer diff --git a/fast_llm/models/multimodal/model.py b/fast_llm/models/multimodal/model.py new file mode 100644 index 000000000..7426191f7 --- /dev/null +++ b/fast_llm/models/multimodal/model.py @@ -0,0 +1,133 @@ +import logging +import typing + +import torch + +from fast_llm.data.data.gpt.data import GPTBatch +from fast_llm.engine.distributed.config import DistributedConfig, PhaseType +from fast_llm.engine.inference.runner import InferenceRunner +from fast_llm.models.gpt.config import GPTBatchConfig +from fast_llm.models.gpt.model import GPTBaseModel, GPTModel +from fast_llm.models.multimodal.config import MultiModalBaseModelConfig, MultiModalBatchConfig, MultiModalModelConfig +from fast_llm.tensor import TensorMeta + +logger = logging.getLogger(__name__) + + +class MultiModalBaseModel[ConfigType: MultiModalBaseModelConfig](GPTBaseModel[ConfigType]): + """ + A transformer-based language model generalizing the GPT model architecture. + """ + + _config: ConfigType + + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + ): + super().__init__(config, distributed_config) + self.vision_encoder = self._config.vision_encoder.get_layer( + distributed_config, + self._hidden_dim, + lr_scale=None, + peft=self._config.peft, + ) + + def preprocess_meta( + self, batch_meta: GPTBatchConfig | torch.Tensor, phase: PhaseType + ) -> list[tuple[TensorMeta, dict]]: + # TODO Remove (Move batch splitting elsewhere) + # TODO: Use parallel/sequential dims, distinguish micro and full batch/sequence + # TODO ====== Vision ====== + # if self._config.vision_encoder.enabled: + # try: + # max_image_size = batch_meta.max_image_size + # except AttributeError: + # max_image_size = 256 + # logger.warning("Inference mode: max_image_size not provided, defaulting to 256") + # vision_kwargs = { + # VisionEncoderKwargs.patch_size: self._config.vision_encoder.patch_size, + # VisionEncoderKwargs.max_image_size: max_image_size, + # VisionEncoderKwargs.rope_theta: self._config.vision_encoder.transformer.rotary.theta, + # VisionEncoderKwargs.kv_channels: self._tensor_space[VisionTransformerDimNames.kv_channels].size, + # VisionEncoderKwargs.out_channels: self._tensor_space[VisionEncoderDimNames.out_channels].size, + # } + # vision_hidden_dim = self._tensor_space[VisionTransformerDimNames.hidden] + # vision_hidden_dims = ( + # (hidden_sequence_q_dim, batch_dim, vision_hidden_dim) + # if sequence_first + # else (batch_dim, hidden_sequence_q_dim, vision_hidden_dim) + # ) + # vision_kwargs.update( + # { + # VisionTransformerKwargs.hidden_dims: vision_hidden_dims, + # } + # ) + # common_kwargs.update(vision_kwargs) + + # TODO ====== Vision ====== + # if self._config.vision_encoder.enabled: + # # patch_dimensions are (batch * sequence_length) x 3 x patch_size x patch_size + # preprocessed_meta.append((kwargs[VisionEncoderKwargs.image_patches_meta], kwargs)) + # else: + # preprocessed_meta.append((tokens, kwargs)) + pass + + def preprocess_batch( + self, + batch: GPTBatch, + preprocessed_meta: list[tuple[TensorMeta, dict]] | None = None, + *, + phase: PhaseType, + iteration: int, + metrics: dict | None = None, + ) -> list[tuple[torch.Tensor, dict]]: + # TODO Move batch splitting elsewhere, align interface with LayerBase + # TODO ====== Vision ====== + # if self._config.vision_encoder.enabled: + # if self._config.vision_encoder.image_break_token is not None: + # if not labels_cloned: + # labels = labels.clone() + # labels_cloned = True + # labels = torch.where(labels == self._config.vision_encoder.image_break_token, -100, labels) + # if self._config.vision_encoder.image_end_token is not None: + # if not labels_cloned: + # labels = labels.clone() + # labels_cloned = True + # labels = torch.where(labels == self._config.vision_encoder.image_end_token, -100, labels) + # Loss-masking for distillation losses + # TODO ====== Vision ====== + # if self._config.vision_encoder.enabled: + # batch_images = ( + # batch.images if batch.images is not None else [[]] * kwargs[AttentionKwargs.micro_batch_size] + # ) + # kwargs[VisionEncoderKwargs.images] = [ + # [ + # img.to(device=self._tensor_space.distributed.device, dtype=torch.uint8, non_blocking=True) + # for img in images + # ] + # for images in batch_images + # ] + # kwargs[VisionEncoderKwargs.image_positions] = ( + # batch.image_positions + # if batch.image_positions is not None + # else [[]] * kwargs[AttentionKwargs.micro_batch_size] + # ) + # kwargs[LanguageModelKwargs.tokens] = tokens + # image_patches = kwargs.get(VisionEncoderKwargs.image_patches, None) + # if image_patches is not None: + # preprocessed.append((image_patches, kwargs)) + # else: + # preprocessed.append((tokens, kwargs)) + pass + + +class MultiModalModel[ConfigType: MultiModalModelConfig](GPTModel[ConfigType]): + # TODO: Can we drop class? + pass + + +class MultiModalInferenceRunner(InferenceRunner): + model_class: typing.ClassVar[type[MultiModalModel]] = MultiModalModel + batch_config_class: typing.ClassVar[type[MultiModalBatchConfig]] = MultiModalBatchConfig diff --git a/fast_llm/models/multimodal/trainer.py b/fast_llm/models/multimodal/trainer.py new file mode 100644 index 000000000..c4071aafe --- /dev/null +++ b/fast_llm/models/multimodal/trainer.py @@ -0,0 +1,14 @@ +import logging + +from fast_llm.models.gpt.trainer import GPTTrainer +from fast_llm.models.multimodal.config import MultiModalTrainerConfig + +logger = logging.getLogger(__name__) + + +class MultiModalTrainer[ConfigType: MultiModalTrainerConfig](GPTTrainer[ConfigType]): + def _get_data(self) -> MultiModalData: + return MultiModalData( + config=self._config.data, + distributed_config=self._config.model.distributed, + ) diff --git a/setup.cfg b/setup.cfg index 77073ab55..2a1614554 100644 --- a/setup.cfg +++ b/setup.cfg @@ -43,7 +43,7 @@ OPTIONAL = # Huggingface tools HUGGINGFACE = - transformers>=4.52.4 + transformers==4.53.2 hf-transfer>=0.1.9 datasets>=3.6.0 huggingface-hub>=0.32.6 @@ -59,6 +59,13 @@ GENERATION = lm_eval>=0.4.9 +# Required for supporting vision inputs +VISION = + # Vision Tools + webp>=0.4.0 + pillow-simd>=9.5.0 + torchvision>=0.20.0 + DEV = # Pre-commit git hook pre-commit>=4.2.0 diff --git a/tests/data/common.py b/tests/data/common.py index d8cc6fff2..a52afc64b 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -8,17 +8,17 @@ from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.data.gpt.data import GPTData from fast_llm.data.dataset.abstract import SampledDataset -from fast_llm.data.dataset.gpt.config import ( - GPTIndexedDatasetConfig, - GPTSampledDatasetConfig, - GPTSamplingConfig, - GPTSamplingData, - GPTSamplingParameters, +from fast_llm.data.dataset.config import ( + IndexedDatasetConfig, + SampledDatasetConfig, + SamplingConfig, + SamplingParameters, ShufflingType, ) -from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset -from fast_llm.data.dataset.gpt.sampled import GPTSampledIndexedDataset -from fast_llm.data.tokenizer import Tokenizer +from fast_llm.data.dataset.gpt.config import GPTSamplingData, GPTSamplingParameters +from fast_llm.data.dataset.indexed import IndexedDataset +from fast_llm.data.dataset.sampled import SampledIndexedDataset +from fast_llm.data.sample.abstract import Sample from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.distributed.distributed import Distributed from fast_llm.models.gpt.config import GPTBatchConfig @@ -34,7 +34,6 @@ def get_sampling_data( phase=PhaseType.training, sequence_length: int = 512, vocab_size=TEST_VOCAB_SIZE, - tokenizer: Tokenizer | None = None, gpu: bool = False, shuffle: ShufflingType = ShufflingType.epoch, truncate_documents=True, @@ -42,7 +41,7 @@ def get_sampling_data( # Config with convenient defaults. distributed = Distributed(DistributedConfig(), use_cpu=True) return GPTSamplingData( - config=GPTSamplingConfig( + config=SamplingConfig( seed=seed, gpu=gpu, shuffle=shuffle, @@ -56,13 +55,12 @@ def get_sampling_data( cache_directory=cache_directory, distributed=distributed, dataset_name=phase.value, - tokenizer=tokenizer, ) -def get_dataset_config[T: GPTSampledDatasetConfig](config: dict[str, typing.Any], cls: type[T]) -> T: - dataset_config = GPTSampledDatasetConfig.from_dict(config) - Assert.custom(isinstance, dataset_config, cls) +def get_dataset_config[T: SampledDatasetConfig](config: dict[str, typing.Any], cls: type[T]) -> T: + dataset_config = SampledDatasetConfig.from_dict(config) + Assert.custom(isinstance, dataset_config, getattr(cls, "__origin__", cls)) return typing.cast(cls, dataset_config) @@ -96,7 +94,7 @@ def get_test_data_and_compare_samples( expected_samples = {PhaseType.training.value.lower(): expected_samples} assert "sampling" not in config - config["sampling"] = GPTSamplingConfig(seed=seed, gpu=gpu, shuffle=shuffle) + config["sampling"] = SamplingConfig(seed=seed, gpu=gpu, shuffle=shuffle) data = GPTData(GPTDataConfig.from_dict(config), distributed_config) data.setup(distributed, sampling_parameters, cache_directory) with NoAutoValidate(): @@ -105,46 +103,55 @@ def get_test_data_and_compare_samples( batch_config.validate() tokens = { phase: torch.stack( - [batch.token_ids[0] for batch in data.get_iterator(batch_config, phase, consumed_samples=0, num_workers=0)] + [ + batch.tokens.tokens[0] + for batch in data.get_iterator(batch_config, phase, consumed_samples=0, num_workers=0) + ] ) for phase, samples in samples_per_dataset.items() } for phase, expected_samples_ in expected_samples.items(): - Assert.all_equal(tokens[phase], expected_samples_) + Assert.all_equal(tokens[phase].to(torch.int64), expected_samples_) return data def compare_indexed_dataset( - dataset: GPTIndexedDataset, + dataset: IndexedDataset, length: int, num_tokens: int, expected_samples: dict[int, list[int]], loss_masking_spans: dict[int, list[int]] | None = None, ) -> None: Assert.eq(len(dataset), length) - sizes = dataset.get_document_sizes() + text_sizes, image_sizes = dataset.get_document_sizes() # Assert.eq(sizes.sum(), num_tokens) Assert.all_equal( - [len(dataset.get(i).token_ids) for i in range(min(len(dataset), 100))], sizes[: min(len(dataset), 100)] + [len(dataset.get(i).token_ids) for i in range(min(len(dataset), 100))], text_sizes[: min(len(dataset), 100)] ) for i, expected_sample in expected_samples.items(): - Assert.all_equal(dataset.get(i).token_ids, np.array(expected_sample, dtype=np.uint16)) + Assert.all_equal(dataset.get_document(i).tokens.tokens, np.array(expected_sample, dtype=np.int64)) if loss_masking_spans: for i, loss_masking_span in loss_masking_spans.items(): - Assert.all_equal( - dataset.get(i, use_loss_masking_spans=True).loss_masking_spans, - np.array(loss_masking_spans[i], dtype=np.int32).reshape(-1, 2), + print(i) + Assert.eq( + dataset.get_document( + i, + parameters=GPTSamplingParameters( + num_samples=0, sequence_length=0, vocab_size=0, use_loss_masking_spans=True + ), + ).loss_masking_spans.ranges, + loss_masking_spans[i], ) def compare_sampled_dataset(sampled: SampledDataset, expected_samples: list[list[int] | np.ndarray]) -> None: Assert.eq(len(sampled), len(expected_samples)) - Assert.all_equal([sampled[i].token_ids for i in range(len(expected_samples))], expected_samples) + Assert.all_equal( + torch.stack([sampled[i].tokens.tokens for i in range(len(expected_samples))]).to(torch.int64), expected_samples + ) -def validate_indexed_dataset_sampling( - sampled: GPTSampledIndexedDataset, expected_samples: list[list[int]] | None = None -): +def validate_indexed_dataset_sampling(sampled: SampledIndexedDataset, expected_samples: list[list[int]] | None = None): """ Compare `GPTSampledIndexedDataset` sampling against a more basic approach """ @@ -165,7 +172,7 @@ def validate_indexed_dataset_sampling( ) seen_tokens = 0 for document_index in document_sampling: - document = sampled._indexed_dataset.get(document_index).token_ids + document = sampled._indexed_dataset.get_document(document_index).tokens.tokens all_tokens[seen_tokens : seen_tokens + len(document)] = document[: num_tokens - seen_tokens] seen_tokens += len(document) @@ -176,7 +183,7 @@ def validate_indexed_dataset_sampling( all_tokens[index * sampled._parameters.sequence_length : (index + 1) * sampled._parameters.sequence_length + 1] for index in range(sampled._parameters.num_samples) ] - token_ids = [sampled[i].token_ids for i in range(len(sampled))] + token_ids = torch.stack([sampled[i].tokens.tokens for i in range(len(sampled))]).to(torch.int64) Assert.all_equal(token_ids, validate_samples) if expected_samples is not None: @@ -184,8 +191,8 @@ def validate_indexed_dataset_sampling( return token_ids -@config_class(dynamic_type={GPTSampledDatasetConfig: "mock_memmap"}) -class MockGPTMemmapDatasetConfig(GPTIndexedDatasetConfig): +@config_class(dynamic_type={SampledDatasetConfig: "mock_memmap"}) +class MockGPTMemmapDatasetConfig(IndexedDatasetConfig): _abstract: typing.ClassVar[bool] = False num_documents: int | None = Field( default=None, @@ -199,15 +206,15 @@ class MockGPTMemmapDatasetConfig(GPTIndexedDatasetConfig): ) path: pathlib.Path = Field(default=".") - def build(self) -> "GPTIndexedDataset": - return MockGPTMemmapDataset(self) + def build(self) -> "IndexedDataset": + return MockMemmapDataset(self) @property def num_tokens(self) -> int: return self.num_documents * self.num_tokens_per_document -class MockGPTMemmapDataset(GPTIndexedDataset): +class MockMemmapDataset[SampleType: Sample](IndexedDataset[SampleType]): def __init__(self, config: MockGPTMemmapDatasetConfig): self._config = config @@ -218,11 +225,13 @@ def name(self) -> str: def __len__(self) -> int: return self._config.num_documents - def get_document_sizes(self) -> np.ndarray: + def get_document_sizes(self, parameters: GPTSamplingParameters | None = None) -> np.ndarray: return np.full(self._config.num_documents, self._config.num_tokens_per_document, dtype=np.int64) - def get_document_size(self, index: int) -> int: + def get_document_size(self, index: int, parameters: GPTSamplingParameters | None = None) -> int: return self._config.num_tokens_per_document - def get(self, index: int, *args, **kwargs) -> typing.Any: + def get_document( + self, index: int, begin: int = 0, end: int | None = None, parameters: SamplingParameters | None = None + ) -> SampleType: raise NotImplementedError() diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index e64b47020..0099cb50b 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -3,7 +3,8 @@ import numpy as np import pytest -from fast_llm.data.dataset.gpt.config import GPTBlendedDatasetConfig +from fast_llm.data.dataset.config import BlendedDatasetConfig +from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert, normalize_probabilities from tests.data.common import ( compare_sampled_dataset, @@ -122,7 +123,7 @@ def test_gpt_blended(): ], "weights": [0.75, 0.25], }, - GPTBlendedDatasetConfig, + BlendedDatasetConfig[LanguageModelSample], ).build_and_sample(get_sampling_data(8, sequence_length=5)) compare_sampled_dataset(sampled, GPT_BLENDED_SAMPLES) @@ -161,7 +162,7 @@ def test_gpt_blended_mixed(): ], "weights": [0.6, 0.4], }, - GPTBlendedDatasetConfig, + BlendedDatasetConfig[LanguageModelSample], ).build_and_sample(get_sampling_data(8, sequence_length=5)) compare_sampled_dataset(sampled, GPT_BLENDED_MIXED_SAMPLES) diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index 2c025cbaf..5335e01c0 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -1,4 +1,5 @@ -from fast_llm.data.dataset.gpt.config import GPTConcatenatedDatasetConfig +from fast_llm.data.dataset.config import ConcatenatedDatasetConfig +from fast_llm.data.sample.language_model import LanguageModelSample from tests.data.common import ( compare_indexed_dataset, compare_sampled_dataset, @@ -27,7 +28,7 @@ def test_gpt_concatenate(): get_test_dataset() dataset = get_dataset_config( {"type": "concatenated", "datasets": [{"type": "memmap", "path": DATASET_PREFIX} for _ in range(3)]}, - GPTConcatenatedDatasetConfig, + ConcatenatedDatasetConfig[LanguageModelSample], ).build() compare_indexed_dataset( dataset, diff --git a/tests/data/test_fim.py b/tests/data/test_fim.py index c9212d6e3..438c5e7e3 100644 --- a/tests/data/test_fim.py +++ b/tests/data/test_fim.py @@ -1,6 +1,4 @@ -from fast_llm.data.config import TokenizerConfig from fast_llm.data.dataset.gpt.config import GPTFimSampledDatasetConfig -from fast_llm.data.tokenizer import Tokenizer from tests.data.common import ( compare_sampled_dataset, get_dataset_config, @@ -29,13 +27,13 @@ def test_gpt_fim(): sampling_config = get_sampling_data( 8, sequence_length=5, - tokenizer=Tokenizer(TokenizerConfig.from_dict({"path": TOKENIZER_PATH})), vocab_size=49157, ) sampled = get_dataset_config( { "type": "fim", "dataset": {"type": "memmap", "path": DATASET_PREFIX}, + "tokenizer": {"path": TOKENIZER_PATH}, "rate": 0.5, "prefix_token": "w", "middle_token": "x", @@ -55,6 +53,7 @@ def test_gpt_fim_data(): "training": { "type": "fim", "dataset": {"type": "memmap", "path": DATASET_PREFIX}, + "tokenizer": {"path": TOKENIZER_PATH}, "rate": 0.5, "prefix_token": "w", "middle_token": "x", @@ -62,7 +61,6 @@ def test_gpt_fim_data(): "suffix_token": "z", } }, - "tokenizer": {"path": TOKENIZER_PATH}, }, 8, sequence_length=5, diff --git a/tests/data/test_memmap.py b/tests/data/test_memmap.py index 1286bddd7..ca887f3c1 100644 --- a/tests/data/test_memmap.py +++ b/tests/data/test_memmap.py @@ -27,8 +27,8 @@ def test_gpt_memmap(cache_directory): MEMMAP_DATASET_SPANS = { 9: [], - 10: [[0, 4], [6, 8]], - 13: [[1, 2]], + 10: [(0, 2), (2, 7), (7, 10)], + 13: [(0, 2)], 15: [], } diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py index 17ba5de01..601abcf99 100644 --- a/tests/data/test_prepare_gpt_memmap.py +++ b/tests/data/test_prepare_gpt_memmap.py @@ -4,12 +4,14 @@ import numpy as np import pytest +import torch -from fast_llm.data.dataset.gpt.config import GPTIndexedDatasetConfig +from fast_llm.data.dataset.config import IndexedDatasetConfig +from fast_llm.data.dataset.gpt.config import GPTSamplingParameters from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset -from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, GPTMemmapDatasetPreparatorConfig from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator +from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert from tests.data.common import MockGPTMemmapDatasetConfig # Noqa @@ -28,52 +30,45 @@ def get_preparator(output_path: str, dataset_path_name: str) -> GPTMemmapDataset @pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) def test_write_memmap_dataset(dtype): - documents = [GPTSample(np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype)) for _ in range(100)] + documents = [ + (torch.from_numpy(np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype)), None, None, None) + for _ in range(100) + ] with tempfile.TemporaryDirectory() as temp_dir: prefix = pathlib.Path(temp_dir) GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents) dataset = GPTMemmapDataset(name="foo", prefix=prefix) - for i, document in enumerate(documents): - assert np.array_equal( - dataset.get(i).token_ids, document.token_ids, equal_nan=True - ), f"Mismatch for document {i}: {document} != {dataset.get(i)}." + for i, (tokens, _, _, _) in enumerate(documents): + Assert.all_equal(dataset.get_document(i).tokens.tokens, tokens.to(torch.int64)) -@pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) -def test_write_memmap_preference_dataset(dtype): - def generate_valid_span(max_seq_length): - span = np.random.choice(np.arange(0, max_seq_length - 1), size=2, replace=False) - return np.sort(span) +def _generate_valid_span(max_seq_length): + return np.sort(np.random.choice(np.arange(0, max_seq_length - 1), size=2, replace=False)).tolist() - vocab_size = 1000 - max_seq_length = 8192 - num_samples = 100 +@pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) +def test_write_memmap_preference_dataset(dtype): documents = [ - GPTSample( - token_ids=np.random.randint(vocab_size, size=max_seq_length).astype(dtype), - chosen_span=generate_valid_span(max_seq_length=max_seq_length), - rejected_span=generate_valid_span(max_seq_length=max_seq_length), + ( + torch.from_numpy(np.random.randint(1000, size=100).astype(dtype)), + None, + _generate_valid_span(100), + _generate_valid_span(100), ) - for _ in range(num_samples) + for _ in range(50) ] with tempfile.TemporaryDirectory() as temp_dir: prefix = pathlib.Path(temp_dir) GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents) dataset = GPTMemmapDataset(name="foo", prefix=prefix) - for i, document in enumerate(documents): - dataset_item = dataset.get(i, use_preference_loss_spans=True) - assert np.array_equal( - dataset_item.token_ids, document.token_ids, equal_nan=True - ), f"Token ids mismatch for document {i}: {document} != {dataset.get(i)}." - - assert np.array_equal( - dataset_item.chosen_span, document.chosen_span, equal_nan=True - ), f"Chosen loss masking spans mismatch for document {i}: {document.chosen_span} != {dataset.get(i).chosen_span}." - - assert np.array_equal( - dataset_item.rejected_span, document.rejected_span, equal_nan=True - ), f"Rejected loss masking spans mismatch for document {i}: {document.rejected_span} != {dataset.get(i).rejected_span}." + parameters = GPTSamplingParameters( + num_samples=0, sequence_length=0, vocab_size=0, use_preference_loss_spans=True + ) + for i, (token_ids, _, (chosen_begin, chosen_end), (rejected_begin, rejected_end)) in enumerate(documents): + document = dataset.get_document(i, parameters=parameters) + Assert.all_equal(document.tokens.tokens, token_ids.to(torch.int64)) + Assert.eq(document.chosen_spans.ranges, [(chosen_begin, chosen_end + 1)]) + Assert.eq(document.rejected_spans.ranges, [(rejected_begin, rejected_end + 1)]) def test_load_metadata_from_hub(): @@ -126,7 +121,7 @@ def test_absent_metadata_local(): def test_split_dataset(): - dataset_config_0 = GPTIndexedDatasetConfig.from_dict(DATASET_DICT_0.copy()) + dataset_config_0 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_0.copy()) config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( [dataset_config_0], {"training": 3, "validation": 1}, @@ -154,8 +149,8 @@ def test_split_dataset(): def test_split_datasets_0(): - dataset_config_0 = GPTIndexedDatasetConfig.from_dict(DATASET_DICT_0.copy()) - dataset_config_1 = GPTIndexedDatasetConfig.from_dict(DATASET_DICT_1.copy()) + dataset_config_0 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_0.copy()) + dataset_config_1 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_1.copy()) config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( [dataset_config_0, dataset_config_1], {"training": 1, "validation": 1}, @@ -173,8 +168,8 @@ def test_split_datasets_0(): def test_split_datasets_1(): - dataset_config_0 = GPTIndexedDatasetConfig.from_dict(DATASET_DICT_0.copy()) - dataset_config_1 = GPTIndexedDatasetConfig.from_dict(DATASET_DICT_1.copy()) + dataset_config_0 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_0.copy()) + dataset_config_1 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_1.copy()) config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( [dataset_config_0, dataset_config_1], {"training": 3, "validation": 1}, pathlib.Path(".") ) diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index 6a2be3dcc..58f4d3dab 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -1,11 +1,12 @@ -import typing - import numpy as np import pytest +import torch -from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, ShufflingType -from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset -from fast_llm.data.dataset.gpt.sampled import GPTSample +from fast_llm.data.dataset.config import ShufflingType +from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, GPTSamplingParameters +from fast_llm.data.dataset.indexed import IndexedDataset +from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.data.sample.token import TokenSample from fast_llm.utils import Assert from tests.data.common import ( get_dataset_config, @@ -62,24 +63,23 @@ def test_gpt_sampled_data(): ) -class SimpleGPTIndexedDataset(GPTIndexedDataset): +class SimpleGPTIndexedDataset[SampleType: LanguageModelSample](IndexedDataset[SampleType]): # TODO: worth adding to the main codebase? def __init__(self, samples): self._samples = samples - def get(self, index: int, offset=0, length=None, use_loss_masking_spans: bool = False) -> typing.Any: - if length is None: - length = len(self._samples[index]) - assert not use_loss_masking_spans - return GPTSample( - token_ids=np.array(self._samples[index][offset : offset + length], dtype=np.int64), loss_masking_spans=None - ) + def get_document( + self, index: int, begin: int = 0, end: int | None = None, parameters: GPTSamplingParameters | None = None + ) -> SampleType: + if end is None: + end = len(self._samples[index]) + return LanguageModelSample(TokenSample(torch.tensor(self._samples[index][begin:end], dtype=torch.int64))) def __len__(self) -> int: return len(self._samples) - def get_document_sizes(self) -> np.ndarray: - return np.array([self.get_document_size(index) for index in range(len(self))], dtype=np.int64) + def get_document_sizes(self) -> torch.Tensor: + return torch.tensor([self.get_document_size(index) for index in range(len(self))], dtype=torch.int64) def get_document_size(self, index: int) -> int: return len(self._samples[index]) @@ -180,4 +180,4 @@ def test_gpt_sample_padding(): else: sampled = dataset.sample(sampling) for idx in range(len(expected_samples)): - Assert.all_equal(sampled[idx].token_ids, np.array(expected_samples[idx])) + Assert.all_equal(sampled[idx].tokens.tokens, np.array(expected_samples[idx])) diff --git a/tests/data/test_slice.py b/tests/data/test_slice.py index 1fc8df1eb..3c6ae10d4 100644 --- a/tests/data/test_slice.py +++ b/tests/data/test_slice.py @@ -1,4 +1,5 @@ -from fast_llm.data.dataset.gpt.config import GPTDatasetSliceConfig +from fast_llm.data.dataset.config import DatasetSliceConfig +from fast_llm.data.sample.language_model import LanguageModelSample from tests.data.common import ( compare_indexed_dataset, get_dataset_config, @@ -34,7 +35,7 @@ def test_gpt_slice(): # samples[9:18] dataset = get_dataset_config( {"type": "slice", "dataset": {"type": "memmap", "path": DATASET_PREFIX}, "begin": 0.0015, "end": 0.003}, - GPTDatasetSliceConfig, + DatasetSliceConfig[LanguageModelSample], ).build() compare_indexed_dataset(dataset, 9, 544, {i - 9: sample for i, sample in MEMMAP_DATASET_SAMPLES.items()}) sampled = dataset.sample(get_sampling_data(8, sequence_length=5)) diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index 3fae970f8..489f5e1c1 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -1,167 +1,80 @@ -import random - import pytest import torch from fast_llm.functional.config import ActivationType, MLPRecomputeLevel -from fast_llm.functional.dpo import _compute_dpo_loss, _compute_logprobs_for_preference_spans +from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped, torch_mlp_activation from fast_llm.functional.triton.sparse_copy import get_sparse_map from fast_llm.utils import Assert +from tests.utils.dataset import get_random_spans from tests.utils.utils import requires_cuda -def ref_log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor, temperature: float = 1.0) -> torch.Tensor: - if temperature != 1.0: - logits.div_(temperature) - batch_dim = logits.shape[:-1] - last_dim = logits.shape[-1] - - output = torch.nn.functional.cross_entropy(logits.reshape(-1, last_dim), labels.reshape(-1), reduction="none") - log_probs_labels = -output.view(*batch_dim) - - return log_probs_labels - - -def ref_packed_get_batch_logps( - logits: torch.FloatTensor, - labels: torch.LongTensor, - attention_mask, - prompt_id_lens, - packed_seq_lens, -) -> torch.FloatTensor: - labels = labels[:, 1:] - logits = logits[:, :-1, :] - per_token_logps = ref_log_probs_from_logits(logits, labels) - - loss_masks = attention_mask.clone().bool() - - index = 0 - for i, seq_len in enumerate(packed_seq_lens): - loss_masks[0, index : index + prompt_id_lens[i]] = False - index = index + seq_len - - loss_masks = loss_masks[:, 1:] - - logprobs_sums = [] - index = 0 - for i, seq_len in enumerate(packed_seq_lens): - seq = per_token_logps[0, index : index + seq_len - 1] - mask = loss_masks[0, index : index + seq_len - 1] - logprobs_sums.append((seq * mask).sum()) - index = index + seq_len - chosen_logps = logprobs_sums[: len(packed_seq_lens) // 2] - rejected_logps = logprobs_sums[len(packed_seq_lens) // 2 :] - - return torch.tensor(chosen_logps), torch.tensor(rejected_logps) - - -@pytest.mark.slow -@pytest.mark.parametrize( - ("batch_size", "seq_length", "vocab_size"), - ( - (2, 32, 50), - (1, 32, 50), - (2, 100, 50), - (2, 32, 200), - ), -) -def test_preference_logps(batch_size, seq_length, vocab_size): - random.seed(0) - torch.manual_seed(0) - - def random_split(seq_length): - min_val = int(seq_length * 0.3) - max_val = int(seq_length * 0.7) - - if max_val < min_val: - max_val = min_val - - a = random.randint(min_val, max_val) - b = seq_length - a - return [a, b] - - logits = torch.randn(batch_size, seq_length, vocab_size) - targets = torch.randint(0, vocab_size, (batch_size, seq_length)) - packed_seq_lens = random_split(seq_length) # simulate different chosen/rejected lengths - prompt_id_lens = [int(min(packed_seq_lens) * 0.75)] * 2 # sequences are 75% prompt 25% generation - attention_mask = torch.tensor([1] * packed_seq_lens[0] + [2] * packed_seq_lens[1]).unsqueeze(0) - - chosen_span = torch.tensor([[prompt_id_lens[0], packed_seq_lens[0] - 1]]) - 1 # shift by 1 due to label shifting - rejected_span = ( - torch.tensor([[packed_seq_lens[0] + prompt_id_lens[1], packed_seq_lens[0] + packed_seq_lens[1] - 1]]) - 1 - ) # shift by 1 due to label shifting - - ref_chosen_logps, ref_rejected_logps = ref_packed_get_batch_logps( - logits, targets, attention_mask, prompt_id_lens, packed_seq_lens +def _get_target_log_probability_for_spans(log_probabilities: torch.Tensor, spans: list[list[tuple[int, int]]]): + return sum( + log_probabilities[sample_index, begin:end].sum() + for sample_index, sample_spans in enumerate(spans) + for begin, end in sample_spans ) - chosen_logps, rejected_logps, selected_log_probs = _compute_logprobs_for_preference_spans( - logits=logits, - targets=targets[:, 1:], - chosen_spans=chosen_span, - rejected_spans=rejected_span, - ) - - ref_logps = ref_log_probs_from_logits(logits[:, :-1, :], targets[:, 1:]) - - # check all logps - Assert.custom(torch.allclose, ref_logps, selected_log_probs, rtol=1e-5) - # check chosen and rejected summed logps - Assert.custom(torch.allclose, ref_chosen_logps, chosen_logps, rtol=1e-5) - Assert.custom(torch.allclose, ref_rejected_logps, rejected_logps, rtol=1e-5) - - -def ref_dpo_loss_fcn( - policy_chosen_logps: torch.Tensor, - policy_rejected_logps: torch.Tensor, - reference_chosen_logps: torch.Tensor, - reference_rejected_logps: torch.Tensor, - beta=1, - label_smoothing=0, +def reference_dpo_loss( + logits: torch.Tensor, + targets: torch.Tensor, + reference_model_logits: torch.Tensor, + chosen_spans: torch.Tensor, + rejected_spans: torch.Tensor, + beta: float, ) -> torch.Tensor: + # TODO: Too similar to the actual implementation. + policy_log_probs = ( + torch.nn.functional.log_softmax(logits.float(), dim=-1).gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1) + ) + policy_chosen_logps = sum( + policy_log_probs[sample_index, begin:end].sum() + for sample_index, sample_spans in enumerate(chosen_spans) + for begin, end in sample_spans + ) + policy_rejected_logps = sum( + policy_log_probs[sample_index, begin:end].sum() + for sample_index, sample_spans in enumerate(rejected_spans) + for begin, end in sample_spans + ) + reference_log_probs = ( + torch.nn.functional.log_softmax(reference_model_logits.float(), dim=-1) + .gather(dim=-1, index=targets.unsqueeze(-1)) + .squeeze(-1) + ) + reference_chosen_logps = sum( + reference_log_probs[sample_index, begin:end].sum() + for sample_index, sample_spans in enumerate(chosen_spans) + for begin, end in sample_spans + ) + reference_rejected_logps = sum( + reference_log_probs[sample_index, begin:end].sum() + for sample_index, sample_spans in enumerate(rejected_spans) + for begin, end in sample_spans + ) pi_logratios = policy_chosen_logps - policy_rejected_logps ref_logratios = reference_chosen_logps - reference_rejected_logps - logits = pi_logratios - ref_logratios - - # Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf) - losses = ( - -torch.nn.functional.logsigmoid(beta * logits) * (1 - label_smoothing) - - torch.nn.functional.logsigmoid(-beta * logits) * label_smoothing - ) - - loss = losses.mean() - - return loss + return -torch.nn.functional.logsigmoid(beta * (pi_logratios - ref_logratios)).mean() def test_dpo_loss(): torch.manual_seed(0) + logits = torch.randn((10, 50, 100), requires_grad=True) + reference_model_logits = torch.randn((10, 50, 100)) + targets = torch.randint(0, 100, (10, 50)) - NUM_SAMPLES = 20 - policy_chosen_logps = torch.rand(NUM_SAMPLES) - policy_rejected_logps = torch.rand(NUM_SAMPLES) - reference_chosen_logps = torch.rand(NUM_SAMPLES) - reference_rejected_logps = torch.rand(NUM_SAMPLES) - betas = torch.rand(NUM_SAMPLES) + spans = get_random_spans(10, 10, 50) - for i in range(NUM_SAMPLES): - fastllm_dpo_loss = _compute_dpo_loss( - policy_chosen_logps=policy_chosen_logps[i], - policy_rejected_logps=policy_rejected_logps[i], - reference_chosen_logps=reference_chosen_logps[i], - reference_rejected_logps=reference_rejected_logps[i], - beta=betas[i].item(), - ) - ref_dpo_loss = ref_dpo_loss_fcn( - policy_chosen_logps=policy_chosen_logps[i].unsqueeze(0), - policy_rejected_logps=policy_rejected_logps[i].unsqueeze(0), - reference_chosen_logps=reference_chosen_logps[i].unsqueeze(0), - reference_rejected_logps=reference_rejected_logps[i].unsqueeze(0), - beta=betas[i].item(), - ) - Assert.rms_close(fastllm_dpo_loss, ref_dpo_loss, 1e-5) + fastllm_loss, fast_llm_grad = compute_dpo_loss( + logits, targets, reference_model_logits, spans[::2], spans[1::2], beta=1, grad_output=1 + ) + reference_loss = reference_dpo_loss(logits, targets, reference_model_logits, spans[::2], spans[1::2], beta=1) + reference_loss.backward() + Assert.rms_close(fastllm_loss, reference_loss, 1e-5) + Assert.rms_close(fast_llm_grad, logits.grad, 1e-5) @requires_cuda diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 3c3bfb833..5db18d7ff 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -343,12 +343,15 @@ def test_huggingface_model(model_testing_config, get_convert_path): ) ) errors = [] - auto_model = ( - transformers.AutoModel - if model_testing_config.name in ("diffusion_llama", "dream") - else transformers.AutoModelForCausalLM - ) - model_as_hf = auto_model.from_pretrained(hf_path, trust_remote_code=True).cuda() + if model_testing_config.name in ("diffusion_llama", "dream"): + auto_model = transformers.AutoModel + elif model_testing_config.name in ("llava", "vision_hybrid_mamba2"): + auto_model = transformers.AutoModelForVision2Seq + else: + auto_model = transformers.AutoModelForCausalLM + model_as_hf = auto_model.from_pretrained( + hf_path, trust_remote_code=model_testing_config.checkpoint_format.trust_remote_code + ).cuda() for name, model in zip( ("From state dict", "From Huggingface", "Native Huggingface"), (model_from_fast_llm, model_from_hf, model_as_hf), diff --git a/tests/models/test_match_megatron.py b/tests/models/test_match_megatron.py index 6aa541b8c..7447e395a 100644 --- a/tests/models/test_match_megatron.py +++ b/tests/models/test_match_megatron.py @@ -6,9 +6,11 @@ from fast_llm.config import Field, FieldHint, config_class from fast_llm.data.dataset.abstract import SampledDataset -from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, GPTSampledDatasetConfig, GPTSamplingData +from fast_llm.data.dataset.config import SampledDatasetConfig +from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, GPTSamplingData from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset -from fast_llm.data.dataset.gpt.sampled import GPTSample, logger +from fast_llm.data.dataset.sampled import logger +from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert from tests.utils.compare_tensor_logs import CompareConfig from tests.utils.dataset import get_model_test_dataset @@ -79,7 +81,7 @@ def test_match_megatron(run_test_script_for_all_models, model_testing_config, co compare_results_for_all_models(distributed_testing_config) -@config_class(dynamic_type={GPTSampledDatasetConfig: "megatron"}) +@config_class(dynamic_type={SampledDatasetConfig: "megatron"}) class GPTMegatronDatasetConfig(GPTMemmapDatasetConfig): _abstract: typing.ClassVar[bool] = False path: str = Field( @@ -141,18 +143,16 @@ def __getitem__(self, idx: int) -> typing.Any: shuffled_idx = self._shuffle_idx[idx] doc_f, offset_f = self._sample_idx[shuffled_idx] doc_l, offset_l = self._sample_idx[shuffled_idx + 1] - sample_list = [ - self._indexed_dataset.get( - self._doc_idx[doc].item(), - offset=(doc == doc_f) * offset_f, - length=offset_l + 1 - (doc == doc_f) * offset_f if doc == doc_l else None, - ) - for doc in range(doc_f, doc_l + 1) - ] - token_ids = np.concatenate([sample.token_ids for sample in sample_list], dtype=np.int64) - Assert.eq(len(token_ids), self._sequence_length + 1) - - return GPTSample(token_ids=token_ids) + return LanguageModelSample.from_documents( + [ + self._indexed_dataset.get_document( + self._doc_idx[doc].item(), + begin=(doc == doc_f) * offset_f, + end=offset_l + 1 if doc == doc_l else None, + ) + for doc in range(doc_f, doc_l + 1) + ] + ) @property def name(self) -> str: diff --git a/tests/test_attention.py b/tests/test_attention.py index a19cba8f0..b86cc95fa 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -3,7 +3,7 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.attention.attention import Attention -from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs +from fast_llm.layers.attention.config import AttentionConfig, AttentionImplementation, AttentionKwargs from fast_llm.layers.block.config import BlockDimNames from fast_llm.utils import Assert @@ -29,7 +29,7 @@ def test_varlen_preprocessing(): micro_sequence_length = 12 sequence_length = 36 attention = Attention( - AttentionConfig(head_size=64), + AttentionConfig(head_size=64, implementation=AttentionImplementation.flash, cross_document_attention=False), DistributedConfig(compute_dtype="bfloat16"), hidden_dim=TensorDim("", 1), lr_scale=None, diff --git a/tests/test_config.py b/tests/test_config.py index 63f2606f1..9a1f542a0 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -6,7 +6,7 @@ import yaml from fast_llm.config import NoAutoValidate -from fast_llm.data.dataset.gpt.config import GPTSamplingConfig +from fast_llm.data.dataset.config import SamplingConfig from fast_llm.engine.checkpoint.config import CheckpointSaveMetadataConfig, ModelConfigType from fast_llm.engine.distributed.config import DistributedConfig, DistributedDim, DistributedDimNames from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig, PretrainedGPTModelConfig @@ -60,7 +60,7 @@ def test_validate_example_config(): GPTTrainerConfig.from_dict(fast_llm_config_dict) -@pytest.mark.parametrize("cls", (GPTSamplingConfig, GPTModelConfig)) +@pytest.mark.parametrize("cls", (SamplingConfig, GPTModelConfig)) def test_serialize_default_config_updates(cls): # Config classes used as config updates should have a default that serializes to an empty dict # so no value is incorrectly overridden. diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index 680faa931..428dec56b 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -2,10 +2,10 @@ import random import numpy as np +import torch import yaml from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset -from fast_llm.data.dataset.gpt.sampled import GPTSample from tests.utils.global_variables import ( DATASET_PREFIX, MODEL_DATASET_PREFIX, @@ -25,6 +25,15 @@ def download_santacoder_tokenizer(): transformers.AutoTokenizer.from_pretrained("bigcode/santacoder").save_pretrained(TOKENIZER_PATH) +def get_random_spans(num_samples: int, max_spans: int, lengths: np.ndarray | int, seed: int = 0): + spans = np.sort(np.random.RandomState(seed + 3847).randint(0, lengths, [num_samples, max_spans * 2])) + spans = [np.unique(sample_spans).tolist() for sample_spans in spans] + return [ + [(begin, end) for begin, end in zip(sample_spans[::2], sample_spans[1::2], strict=False)] + for sample_spans in spans + ] + + def get_test_dataset( prefix: pathlib.Path = DATASET_PREFIX, seed: int = 1234, @@ -46,14 +55,27 @@ def get_test_dataset( tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH) samples = [ - GPTSample(np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size) for document in texts + ( + torch.from_numpy(np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size), + None, + None, + None, + ) + for document in texts ] if max_spans > 0: - lengths = np.array([max(len(sample.token_ids), 1) for sample in samples]) - spans = np.sort(np.random.RandomState(seed + 3847).randint(0, lengths[:, None], [len(samples), max_spans])) - for sample, span in zip(samples, spans): - span = np.unique(span) - sample.loss_masking_spans = span[: len(span) // 2 * 2].reshape(-1, 2) + spans = get_random_spans( + len(samples), max_spans, np.array([[max(len(tokens), 1)] for tokens, _, _, _ in samples]), seed + ) + samples = [ + ( + tokens, + torch.tensor(sample_spans, dtype=torch.int32).reshape(-1, 2), + None, + None, + ) + for (tokens, _, _, _), sample_spans in zip(samples, spans, strict=True) + ] GPTMemmapDataset.write_dataset(prefix, samples) yaml.safe_dump( diff --git a/tools/concatenate_dataset.py b/tools/concatenate_dataset.py index bbfa4b21a..926bcc346 100644 --- a/tools/concatenate_dataset.py +++ b/tools/concatenate_dataset.py @@ -35,6 +35,7 @@ def run(self): dataset_dict = { "prefix": str(prefix.relative_to(self.directory)), "num_documents": len(dataset), + # Todo: fix "num_tokens": dataset.num_tokens, } if self.min_tokens is not None and dataset_dict["num_tokens"] < self.min_tokens: