From 6084122ad14af4190588eb0c4bc263b0a11f0d08 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 3 Oct 2025 23:51:53 -0400 Subject: [PATCH 1/7] language_model --- fast_llm/layers/block/block.py | 1 + fast_llm/layers/language_model/config.py | 33 ++++++++------- .../layers/language_model/language_model.py | 40 ++++++++++++------- fast_llm/models/gpt/config.py | 6 +++ fast_llm/models/gpt/conversion/llama.py | 16 ++++---- fast_llm/models/gpt/model.py | 8 ++-- fast_llm/utils.py | 4 +- tests/layers/test_lm_head.py | 2 +- 8 files changed, 64 insertions(+), 46 deletions(-) diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index ab6cb22b0..67ce5eea9 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -103,6 +103,7 @@ def __init__( config: ConfigType, distributed_config: DistributedConfig, *, + # TODO: Review. Use `input_dim(s)` and `output_dim(s)` instead? hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index d2fbc4909..25fa2d91e 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -2,7 +2,6 @@ import typing from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.base_model.config import ModuleConfig from fast_llm.engine.config_utils.parameter import OptionalParameterConfig, ParameterConfig, combine_lr_scales from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig @@ -16,6 +15,7 @@ if typing.TYPE_CHECKING: from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.language_model.head import LanguageModelHead, LanguageModelHeadBase + from fast_llm.layers.language_model.language_model import LanguageModel from fast_llm.layers.language_model.multi_token_prediction import 
MultiTokenPrediction @@ -41,12 +41,6 @@ class LanguageModelEmbeddingsConfig(BlockConfig): desc="Configuration for the word embedding (weight).", hint=FieldHint.architecture, ) - hidden_size: int = Field( - default=1024, - desc="Size of the model's main hidden dimension, e.g., for its input and output layers.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) vocab_size: int = Field( default=49152, desc="Size of the vocabulary, i.e., number of vocabulary embeddings and logits.", @@ -295,24 +289,29 @@ def max_prediction_distance(self) -> int: @config_class() -class LanguageModelConfig(ModuleConfig): - # TODO: block +class LanguageModelConfig(BlockConfig): decoder: BlockSequenceConfig = Field( desc="Configuration for the language model decoder.", hint=FieldHint.architecture, ) - embeddings: LanguageModelEmbeddingsConfig = Field() - head: LanguageModelHeadBaseConfig = Field() - # TODO: Allow overriding in sub-models? - peft: PeftConfig = Field( - desc="Configuration for parameter-efficient fine tuning.", + embeddings: LanguageModelEmbeddingsConfig = Field( hint=FieldHint.architecture, + desc="Configuration for the language model embeddings.", + ) + head: LanguageModelHeadBaseConfig = Field( + hint=FieldHint.architecture, desc="Configuration for the language model head(s)." ) tied_embedding_weight: bool = Field( default=False, desc="Tie the output weights (logits) with the vocabulary embedding.", hint=FieldHint.architecture, ) + hidden_size: int = Field( + default=1024, + desc="Size of the model's main hidden dimension, e.g., for its input and output layers.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) sequence_first: bool | None = Field( default=None, desc="Override the default dimension ordering", @@ -321,3 +320,9 @@ class LanguageModelConfig(ModuleConfig): " Setting this parameter overrides the default choice. 
Note that setting to `False` will either do nothing or raise an error.", hint=FieldHint.testing, ) + + @property + def layer_class(self) -> "type[LanguageModel]": + from fast_llm.layers.language_model.language_model import LanguageModel + + return LanguageModel diff --git a/fast_llm/layers/language_model/language_model.py b/fast_llm/layers/language_model/language_model.py index 9a3bef195..56d41dc3a 100644 --- a/fast_llm/layers/language_model/language_model.py +++ b/fast_llm/layers/language_model/language_model.py @@ -1,52 +1,64 @@ import logging import typing -from fast_llm.config import Configurable -from fast_llm.engine.base_model.base_model import Layer, LayerBase +import torch + +from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.base_model.config import LossDef from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.block.block import BlockBase +from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import LanguageModelConfig from fast_llm.layers.language_model.embedding import LanguageModelEmbedding logger = logging.getLogger(__name__) -class LanguageModel[ConfigType: LanguageModelConfig](Configurable[ConfigType], LayerBase): +class LanguageModel[ConfigType: LanguageModelConfig](BlockBase[ConfigType]): _config: ConfigType def __init__( self, config: ConfigType, distributed_config: DistributedConfig, + *, + # Unused, but required by the `BlockBase` interface. 
+ hidden_dim: TensorDim | None = None, + lr_scale: float | None, + peft: PeftConfig | None, ): - super().__init__(config, distributed_config) - - self._hidden_dim = TensorDim("hidden", config.embeddings.hidden_size) + super().__init__( + config, + distributed_config, + hidden_dim=TensorDim("hidden", self._config.hidden_size), + lr_scale=lr_scale, + peft=peft, + ) self.embeddings: LanguageModelEmbedding = self._config.embeddings.get_layer( distributed_config, hidden_dim=self._hidden_dim, - lr_scale=None, - peft=self._config.peft, + lr_scale=self._lr_scale, + peft=self._peft, ) self.decoder = self._config.decoder.get_layer( distributed_config, self._hidden_dim, - lr_scale=None, - peft=self._config.peft, + lr_scale=self._lr_scale, + peft=self._peft, ) self.head = self._config.head.get_layer( distributed_config, self._config.embeddings, hidden_dim=self._hidden_dim, - lr_scale=None, - peft=self._config.peft, + lr_scale=self._lr_scale, + peft=self._peft, ) - def get_layers(self) -> list["Layer"]: + def get_layers(self) -> list[Layer]: return self.embeddings.get_layers() + self.decoder.get_layers() + self.head.get_layers() - def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None: + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: # Needed because the base class uses `get_layers` which may bypass the decoder and head. TODO: Avoidable? 
self.embeddings.preprocess(batch, kwargs) self.decoder.preprocess(batch, kwargs) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 1e57f3b8c..a901a0466 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -10,6 +10,7 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.schedule.config import BatchConfig from fast_llm.engine.training.config import TrainerConfig +from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import LanguageModelConfig, MultiTokenPredictionConfig from fast_llm.models.gpt.conversion.config import ( AprielHybridSSMCheckpointFormat, @@ -84,6 +85,11 @@ def micro_batch_splits(self) -> int: class GPTBaseModelConfig(LanguageModelConfig, BaseModelConfig): _abstract = False + # TODO: Allow overriding in sub-models? + peft: PeftConfig = Field( + desc="Configuration for parameter-efficient fine tuning.", + hint=FieldHint.architecture, + ) # Debug, to get an exact match with megatron init. 
use_megatron_initialization: bool = Field( default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index 786d923f2..a92492260 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -449,19 +449,13 @@ def get_converters( class LlamaEmbeddingsConverter: @classmethod def import_config(cls, config: dict) -> dict: - return { - "vocab_size": config["vocab_size"], - "hidden_size": config["hidden_size"], - } + return {"vocab_size": config["vocab_size"]} @classmethod def export_config(cls, config: LanguageModelEmbeddingsConfig) -> dict: Assert.custom(isinstance, config, LanguageModelEmbeddingsConfig) assert not config.position_embeddings.enabled - return { - "vocab_size": config.vocab_size, - "hidden_size": config.hidden_size, - } + return {"vocab_size": config.vocab_size} @classmethod def get_converters( @@ -516,6 +510,7 @@ def import_config(cls, config: dict) -> dict: "embeddings": cls.embeddings_converter_class.import_config(config), "decoder": cls.decoder_converter_class.import_config(config), "head": cls.head_converter_class.import_config(config), + "hidden_size": config["hidden_size"], "tied_embedding_weight": config["tie_word_embeddings"], } @@ -526,7 +521,10 @@ def export_config(cls, config: GPTBaseModelConfig) -> dict: cls.embeddings_converter_class.export_config(config.embeddings), cls.decoder_converter_class.export_config(config.decoder), cls.head_converter_class.export_config(config.head), - {"tie_word_embeddings": config.tied_embedding_weight}, + { + "tie_word_embeddings": config.tied_embedding_weight, + "hidden_size": config.hidden_size, + }, ) @classmethod diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 2c1fb0e4a..158bbd92c 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -30,16 +30,14 @@ class 
GPTBaseModel[ConfigType: GPTBaseModelConfig](LanguageModel[ConfigType], Ba def __init__( self, - config: GPTBaseModelConfig, + config: ConfigType, distributed_config: DistributedConfig, ): - super().__init__(config, distributed_config) + super().__init__(config, distributed_config, lr_scale=self._config.lr_scale, peft=self._config.peft) if self._config.use_megatron_initialization: for param in self.parameters(): Assert.custom(isinstance, param, ParameterMeta) - param.init_parameter = get_init_megatron( - param, self._config.decoder.block, config.embeddings.hidden_size - ) # Noqa + param.init_parameter = get_init_megatron(param, self._config.decoder.block, config.hidden_size) # Noqa def preprocess_meta( self, batch_meta: GPTBatchConfig | torch.Tensor, phase: PhaseType diff --git a/fast_llm/utils.py b/fast_llm/utils.py index bbd69ae8a..1f9feceb4 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -316,9 +316,7 @@ def new_decorator(*args, **kwargs): return new_decorator -def compare_nested( - config_a, config_b, errors: list | None = None, prefix: tuple = (), ignore_missing: tuple[str, ...] = () -): +def compare_nested(config_a, config_b, errors: list | None = None, prefix: tuple = ()): if errors is None: errors = [] # Check for equality of both values and types. 
diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 0de823e2a..d65d33a8b 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -255,7 +255,7 @@ def test_lm_head( logit_weight = torch.nn.Parameter( torch.empty( VOCAB_SIZE, HIDDEN_SIZE, dtype=distributed.config.compute_dtype.torch, device=distributed.device - ).normal_(config.embeddings.hidden_size**-0.5) + ).normal_(config.hidden_size**-0.5) ) else: logit_weight = None From 4a9698003c78326108f465d89da5b82800dc4366 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 6 Oct 2025 16:27:07 -0400 Subject: [PATCH 2/7] fixes --- Dockerfile | 5 ++++ examples/mistral.yaml | 2 +- fast_llm/engine/checkpoint/huggingface.py | 5 ++-- .../layers/language_model/language_model.py | 4 +-- fast_llm/models/gpt/model.py | 2 +- tests/layers/test_lm_head.py | 26 +++++++++---------- tests/test_config.py | 13 ++++------ tests/utils/model_configs.py | 2 +- 8 files changed, 29 insertions(+), 30 deletions(-) diff --git a/Dockerfile b/Dockerfile index 526026fa4..00e13d957 100644 --- a/Dockerfile +++ b/Dockerfile @@ -47,3 +47,8 @@ COPY --chmod=777 ./tests tests COPY --chmod=777 ./tools tools COPY --chmod=777 ./fast_llm_external_models fast_llm_external_models COPY --chmod=777 --exclude=./fast_llm/csrc/ ./fast_llm/ fast_llm/ + +# Set a dummy default user so we don't run in root by default. +# The image is still compatible with any user id. 
+RUN useradd user +USER user diff --git a/examples/mistral.yaml b/examples/mistral.yaml index 2e4a57de7..904325c5c 100644 --- a/examples/mistral.yaml +++ b/examples/mistral.yaml @@ -28,7 +28,6 @@ optimizer: model: base_model: embeddings: - hidden_size: 4096 vocab_size: 32000 dropout: 0.0 decoder: @@ -58,6 +57,7 @@ model: normalization: type: rms_norm epsilon: 1.0e-05 + hidden_size: 4096 tied_embedding_weight: false multi_stage: zero_stage: 2 diff --git a/fast_llm/engine/checkpoint/huggingface.py b/fast_llm/engine/checkpoint/huggingface.py index afe381295..96fb53321 100644 --- a/fast_llm/engine/checkpoint/huggingface.py +++ b/fast_llm/engine/checkpoint/huggingface.py @@ -150,7 +150,6 @@ def _load_weights( ].values() } elif (config.path / transformers.utils.WEIGHTS_NAME).is_file(): - # TODO: Prevent unsafe by default paths = {config.path / transformers.utils.WEIGHTS_NAME} elif (config.path / transformers.utils.WEIGHTS_INDEX_NAME).is_file(): logger.info(f"Loading index from {config.path / transformers.utils.WEIGHTS_INDEX_NAME}") @@ -170,7 +169,7 @@ def _load_weights( for key in f.keys(): yield key, "weights", f.get_slice(key) elif path.suffix == ".bin": - # TODO: Prevent unsafe by default - yield from torch.load(path) + # TODO: Confirm that loading works with `weights_only=True` + yield from torch.load(path, weights_only=True) else: raise NotImplementedError(f"Unknown file format for {path}") diff --git a/fast_llm/layers/language_model/language_model.py b/fast_llm/layers/language_model/language_model.py index 56d41dc3a..2e46bb57a 100644 --- a/fast_llm/layers/language_model/language_model.py +++ b/fast_llm/layers/language_model/language_model.py @@ -23,7 +23,7 @@ def __init__( config: ConfigType, distributed_config: DistributedConfig, *, - # Unused, but required by the `BlockBase` interface. + # TODO: Unused, but required by the `BlockBase` interface. 
hidden_dim: TensorDim | None = None, lr_scale: float | None, peft: PeftConfig | None, @@ -31,7 +31,7 @@ def __init__( super().__init__( config, distributed_config, - hidden_dim=TensorDim("hidden", self._config.hidden_size), + hidden_dim=TensorDim("hidden", config.hidden_size), lr_scale=lr_scale, peft=peft, ) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 158bbd92c..efa348ecb 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -33,7 +33,7 @@ def __init__( config: ConfigType, distributed_config: DistributedConfig, ): - super().__init__(config, distributed_config, lr_scale=self._config.lr_scale, peft=self._config.peft) + super().__init__(config, distributed_config, lr_scale=config.lr_scale, peft=config.peft) if self._config.use_megatron_initialization: for param in self.parameters(): Assert.custom(isinstance, param, ParameterMeta) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index d65d33a8b..5c044596f 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -171,22 +171,20 @@ def test_lm_head( } config = GPTBaseModelConfig.from_dict( { - "decoder": { - "num_blocks": 0, - }, - "embeddings": { - "vocab_size": VOCAB_SIZE, - "hidden_size": HIDDEN_SIZE, - }, + "decoder": {"num_blocks": 0}, + "embeddings": {"vocab_size": VOCAB_SIZE}, "head": ( - head_config - if prediction_heads == 1 - else { - "type": "multi_token_prediction", - "head": head_config, - "prediction_heads": prediction_heads, - } + ( + head_config + if prediction_heads == 1 + else { + "type": "multi_token_prediction", + "head": head_config, + "prediction_heads": prediction_heads, + } + ), ), + "hidden_size": HIDDEN_SIZE, }, config_dict, update_type=UpdateType.update, diff --git a/tests/test_config.py b/tests/test_config.py index 326200537..63f2606f1 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -74,9 +74,6 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): 
pretrained_model_config = GPTModelConfig.from_dict( { "base_model": { - "embeddings": { - "hidden_size": 1024, # Default - }, "decoder": { "block": { "mixer": { @@ -92,6 +89,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): }, "num_blocks": 12, # Default }, + "hidden_size": 1024, # Default "tied_embedding_weight": False, }, "multi_stage": {"zero_stage": 3}, @@ -105,7 +103,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): pretrained_model_config.save_metadata(save_config) base_model_update = { - "embeddings": {"hidden_size": 512, "vocab_size": 1000}, + "embeddings": {"vocab_size": 1000}, "decoder": { "block": { "mixer": { @@ -115,6 +113,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): "normalization": {"implementation": "triton"}, # Update non-default nested }, }, + "hidden_size": 512, "peft": {"type": "lora", "freeze_others": False}, # Update default nested, change type } pretrained_config = PretrainedGPTModelConfig.from_dict( @@ -134,10 +133,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): expected_config["distributed"].update({"seed": 1234, "compute_dtype": "float16"}) if load_config in (ModelConfigType.fast_llm, ModelConfigType.model): expected_config["base_model"] = { - "embeddings": { - "hidden_size": 512, - "vocab_size": 1000, - }, + "embeddings": {"vocab_size": 1000}, "decoder": { "block": { "mixer": { @@ -152,6 +148,7 @@ def test_pretrained_config(load_config: ModelConfigType, result_path): }, "num_blocks": 12, }, + "hidden_size": 512, "tied_embedding_weight": False, "peft": {"freeze_others": False}, } diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 6b313aa8a..77d038259 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -192,7 +192,6 @@ def _update_and_add_testing_config( "embeddings": { "word_embeddings": init_1, "position_embeddings": {"enabled": True, **init_1}, - "hidden_size": 256, 
"num_position_embeddings": 512, "vocab_size": MODEL_TEST_VOCAB_SIZE, }, @@ -216,6 +215,7 @@ def _update_and_add_testing_config( "num_blocks": 2, }, "head": {"output_weight": init_1}, + "hidden_size": 256, "tied_embedding_weight": True, }, "multi_stage": { From 785413806b34b12ea5d720d3fee2db640885a044 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 6 Oct 2025 17:09:37 -0400 Subject: [PATCH 3/7] fixes --- fast_llm/engine/config_utils/parameter.py | 5 +++-- tests/layers/test_lm_head.py | 16 +++++++--------- tests/utils/model_configs.py | 2 ++ 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/fast_llm/engine/config_utils/parameter.py b/fast_llm/engine/config_utils/parameter.py index 76416d365..c0910c09a 100644 --- a/fast_llm/engine/config_utils/parameter.py +++ b/fast_llm/engine/config_utils/parameter.py @@ -1,7 +1,8 @@ import math import typing -from fast_llm.config import Config, Field, FieldHint, config_class +from fast_llm.config import Field, FieldHint, config_class +from fast_llm.engine.base_model.config import ModuleConfig from fast_llm.engine.config_utils.initialization import Initialization, InitializationConfig from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.layers.common.peft.config import PeftConfig @@ -36,7 +37,7 @@ def combine_lr_scales(*lr_scales: float | None | tuple[float | None, ...]): @config_class() -class ParameterConfig(Config): +class ParameterConfig(ModuleConfig): initialization: InitializationConfig = Field( desc="If provided, override the default initialization method set by the parent layer.", hint=FieldHint.feature, diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 5c044596f..0dc2421a7 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -174,15 +174,13 @@ def test_lm_head( "decoder": {"num_blocks": 0}, "embeddings": {"vocab_size": VOCAB_SIZE}, "head": ( - ( - head_config - if prediction_heads == 1 - else { - "type": 
"multi_token_prediction", - "head": head_config, - "prediction_heads": prediction_heads, - } - ), + head_config + if prediction_heads == 1 + else { + "type": "multi_token_prediction", + "head": head_config, + "prediction_heads": prediction_heads, + } ), "hidden_size": HIDDEN_SIZE, }, diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 77d038259..c02521d7b 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -344,6 +344,8 @@ def _update_and_add_testing_config( ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, }, ) +del MODEL_CONFIGS["starcoder_2"].config_dict["model"]["base_model"]["embeddings"]["num_position_embeddings"] + _update_and_add_testing_config( # Main tested model. From 47c8d2010bc2ef16e8772e5fd6403be072dfe9e7 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 14 Oct 2025 15:36:49 -0400 Subject: [PATCH 4/7] stuff --- fast_llm/data/config.py | 11 +- fast_llm/data/data/abstract.py | 3 +- fast_llm/data/data/gpt/config.py | 12 +- fast_llm/data/data/gpt/data.py | 52 +-- fast_llm/data/dataset/config.py | 54 ++- fast_llm/data/dataset/gpt/config.py | 163 ++------- fast_llm/data/dataset/gpt/fim.py | 4 +- fast_llm/data/dataset/gpt/indexed.py | 60 ---- fast_llm/data/dataset/gpt/memmap.py | 318 ------------------ fast_llm/data/dataset/gpt/random.py | 4 +- fast_llm/data/dataset/indexed.py | 43 ++- fast_llm/data/dataset/memmap.py | 103 ++++++ fast_llm/data/dataset/sample/__init__.py | 25 ++ fast_llm/data/dataset/sample/abstract.py | 56 +++ fast_llm/data/dataset/sample/config.py | 147 ++++++++ .../data/dataset/sample/language_model.py | 119 +++++++ fast_llm/data/dataset/sample/range.py | 89 +++++ fast_llm/data/dataset/sample/token.py | 95 ++++++ fast_llm/data/dataset/{gpt => }/sampled.py | 146 ++------ fast_llm/data/preparator/gpt_memmap/config.py | 2 +- .../data/preparator/gpt_memmap/prepare.py | 15 +- fast_llm/engine/config_utils/data_type.py | 14 +- fast_llm/engine/evaluation/config.py | 
4 + .../engine/evaluation/lm_eval/evaluator.py | 2 +- fast_llm/functional/dpo.py | 40 +-- tests/data/common.py | 37 +- tests/data/test_blending.py | 16 +- tests/data/test_concatenate.py | 6 +- tests/data/test_dataset_from_file.py | 4 +- tests/data/test_fim.py | 6 +- tests/data/test_memmap.py | 10 +- tests/data/test_prepare_gpt_memmap.py | 17 +- tests/data/test_sampling.py | 11 +- tests/data/test_slice.py | 10 +- tests/functional/test_functional.py | 4 +- tests/models/test_match_megatron.py | 8 +- tests/utils/dataset.py | 35 +- tests/utils/global_variables.py | 2 +- tools/concatenate_dataset.py | 4 +- 39 files changed, 899 insertions(+), 852 deletions(-) delete mode 100644 fast_llm/data/dataset/gpt/indexed.py delete mode 100644 fast_llm/data/dataset/gpt/memmap.py create mode 100644 fast_llm/data/dataset/memmap.py create mode 100644 fast_llm/data/dataset/sample/__init__.py create mode 100644 fast_llm/data/dataset/sample/abstract.py create mode 100644 fast_llm/data/dataset/sample/config.py create mode 100644 fast_llm/data/dataset/sample/language_model.py create mode 100644 fast_llm/data/dataset/sample/range.py create mode 100644 fast_llm/data/dataset/sample/token.py rename fast_llm/data/dataset/{gpt => }/sampled.py (75%) diff --git a/fast_llm/data/config.py b/fast_llm/data/config.py index 4c041945d..633367c80 100644 --- a/fast_llm/data/config.py +++ b/fast_llm/data/config.py @@ -1,9 +1,13 @@ import enum import pathlib +import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class from fast_llm.utils import Assert +if typing.TYPE_CHECKING: + from fast_llm.data.tokenizer import Tokenizer + class MultiprocessingContext(str, enum.Enum): # Fast but risk of segfaults due to interactions with triton @@ -29,7 +33,7 @@ class TokenizerConfig(Config): hint=FieldHint.deprecated, valid=check_field(Assert.eq, TokenizerFromFile), ) - path: pathlib.Path | None = Field( + path: pathlib.Path = Field( default=None, desc="Path to the tokenizer file.", 
hint=FieldHint.core, @@ -39,3 +43,8 @@ class TokenizerConfig(Config): desc="BOS token to use if the tokenizer doesn't define one; must be an existing token.", hint=FieldHint.core, ) + + def get_tokenizer(self) -> "Tokenizer": + from fast_llm.data.tokenizer import Tokenizer + + return Tokenizer(self) diff --git a/fast_llm/data/data/abstract.py b/fast_llm/data/data/abstract.py index e24d39985..22f5fd194 100644 --- a/fast_llm/data/data/abstract.py +++ b/fast_llm/data/data/abstract.py @@ -5,6 +5,7 @@ from fast_llm.config import Configurable from fast_llm.data.data.config import DataConfig from fast_llm.data.dataset.config import SamplingParameters +from fast_llm.data.dataset.sample import Batch from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.schedule.config import BatchConfig @@ -47,5 +48,5 @@ def get_iterator( num_workers: int, prefetch_factor: int | None = None, timeout: float = 60, - ) -> typing.Iterator[typing.Any]: + ) -> typing.Iterator[Batch]: pass diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index efee46959..cb8cb38f6 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -1,9 +1,9 @@ import logging from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class -from fast_llm.data.config import MultiprocessingContext, TokenizerConfig +from fast_llm.data.config import MultiprocessingContext from fast_llm.data.data.config import DataConfig -from fast_llm.data.dataset.gpt.config import GPTSampledDatasetConfig, GPTSamplingConfig +from fast_llm.data.dataset.config import SampledDatasetConfig, SamplingConfig from fast_llm.utils import Assert logger = logging.getLogger(__name__) @@ -19,17 +19,13 @@ class GPTDataConfig(DataConfig): _abstract = False - tokenizer: TokenizerConfig = Field( - desc="Configuration for the tokenizer (for FIM).", - hint=FieldHint.feature, - ) # TODO: Review field. Move closer to phase definition in training config? 
- datasets: dict[str, GPTSampledDatasetConfig] = Field( + datasets: dict[str, SampledDatasetConfig] = Field( default_factory=dict, desc="Configuration for the dataset(s).", hint=FieldHint.core, ) - sampling: GPTSamplingConfig = FieldUpdate() + sampling: SamplingConfig = FieldUpdate() data_sample_warn_time_ms: float = Field( default=1000, desc="Warn if a sample takes too long to load.", diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 6724afb59..1934b617b 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -3,9 +3,7 @@ import pathlib import typing import warnings -from functools import partial -import numpy as np import torch import torch.utils.data @@ -14,14 +12,13 @@ from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.gpt.config import GPTSamplingData, GPTSamplingParameters -from fast_llm.data.dataset.gpt.sampled import GPTSample from fast_llm.data.dataset.monitor import DatasetMonitor +from fast_llm.data.dataset.sample.language_model import LanguageModelBatch from fast_llm.data.iterator import SampledDatasetIterator -from fast_llm.data.tokenizer import Tokenizer from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.schedule.config import BatchConfig +from fast_llm.models.gpt.config import GPTBatchConfig from fast_llm.utils import Assert logger = logging.getLogger(__name__) @@ -36,28 +33,6 @@ class GPTBatch: rejected_spans: list[torch.Tensor] | None = None -def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch: - stacked_ids = np.stack([sample.token_ids for sample in batch]) - stacked_spans = None - sequence_lengths = None - stacked_chosen_spans = None - stacked_rejected_spans = None - if 
sampling_parameters.use_loss_masking_spans: - stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch] - if sampling_parameters.use_preference_loss_spans: - stacked_chosen_spans = [torch.from_numpy(sample.chosen_span) for sample in batch] - stacked_rejected_spans = [torch.from_numpy(sample.rejected_span) for sample in batch] - if not sampling_parameters.cross_document_attention: - sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch] - return GPTBatch( - token_ids=torch.from_numpy(stacked_ids), - loss_masking_spans=stacked_spans, - sequence_lengths=sequence_lengths, - chosen_spans=stacked_chosen_spans, - rejected_spans=stacked_rejected_spans, - ) - - class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]): """ A global class for all dataset needs, including loading, splitting, sampling and iteration. @@ -67,7 +42,6 @@ class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]): _datasets: dict[str, SampledDataset] _sampling_parameters: dict[str, GPTSamplingParameters] - _tokenizer: Tokenizer | None _is_setup: bool = False def __init__( @@ -108,7 +82,6 @@ def setup( ) log_main_rank(f"Preparing dataset. This may take several minutes.") - self._tokenizer = None if self._config.tokenizer.path is None else Tokenizer(self._config.tokenizer) if self._cache_directory is None: # TODO: Avoid this @@ -116,11 +89,6 @@ def setup( self._datasets = {} for dataset_name, sampling_parameters in self._sampling_parameters.items(): - if self._tokenizer is not None: - # NOTE: Some models like Qwen2-1.5B-Instruct - # have vocab_size bigger in model config than in tokenizer - # TODO: Still, is it too constraining? 
- Assert.geq(sampling_parameters.vocab_size, self._tokenizer.vocab_size) if sampling_parameters.num_samples > 0: sampling = GPTSamplingData( config=self._config.sampling, @@ -128,7 +96,6 @@ def setup( cache_directory=self._cache_directory, distributed=distributed, dataset_name=dataset_name, - tokenizer=self._tokenizer, ) dataset = self._config.datasets[dataset_name].build_and_sample(sampling) self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms) @@ -136,21 +103,16 @@ def setup( safe_barrier(self._distributed.world_group, "data_preparation", timeout) self._is_setup = True - @property - def tokenizer(self) -> Tokenizer: - assert self._is_setup - return self._tokenizer - def get_iterator( self, - batch_config: BatchConfig, + batch_config: GPTBatchConfig, dataset_name: str, *, consumed_samples: int, num_workers: int, prefetch_factor: int | None = None, timeout: float = 60, - ) -> typing.Iterator[typing.Any]: + ) -> typing.Iterator[LanguageModelBatch]: assert self._is_setup # Some dataset names may come from phases and are capitalized, @@ -175,10 +137,8 @@ def get_iterator( num_workers=num_workers, prefetch_factor=prefetch_factor, pin_memory=True, - collate_fn=partial( - gpt_data_collate_fn, - sampling_parameters=sampling_parameters, - ), + # TODO: ====== Make sure the samples are compatible ===== + collate_fn=LanguageModelBatch.from_samples, multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, ) ) diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 0c1b0cd09..49025a60a 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -1,4 +1,5 @@ import dataclasses +import enum import functools import itertools import math @@ -14,17 +15,41 @@ from fast_llm.engine.distributed.distributed import Distributed +class ShufflingType(str, enum.Enum): + # Shuffle all epochs together. Not extendable. 
+ full = "full" + # Shuffle all epochs separately. Default mode, recommended if the dataset doesn't come pre-shuffled. + epoch = "epoch" + # Shuffle all epochs except the first one. Recommended for pre-shuffled datasets, especially big ones. + skip_first_epoch = "skip_first_epoch" + # Disable shuffling entirely. + disabled = "disabled" + + @config_class() class SamplingConfig(Config): """ A dataset-dependent configuration for sampling. """ + # TODO: ====== DocumentSamplingConfig? ====== seed: int = Field( default=784569, desc="Seed for random sampling.", hint=FieldHint.feature, ) + gpu: bool = Field( + default=True, + desc="Enable fast sampling on GPU." + " Note that random sampling works differently on GPU," + " so the sample won't match the CPU equivalent.", + hint=FieldHint.feature, + ) + shuffle: ShufflingType = Field( + default=ShufflingType.epoch, + desc="Shuffling strategy.", + hint=FieldHint.feature, + ) @dataclasses.dataclass(kw_only=True) @@ -34,6 +59,14 @@ class SamplingParameters: """ num_samples: int + # TODO: ====== Always return sequence lengths, let the model decide. ====== + cross_document_attention: bool = True + # TODO: ====== DocumentSamplingParameter? ====== + truncate_documents: bool = True + sequence_length: int + # How many extra tokens to add to the sequence length. + # This is used to provide labels even for the last tokens in the sequence. + extra_tokens: int = 1 @dataclasses.dataclass(kw_only=True) @@ -68,7 +101,7 @@ class DatasetConfig(Config): _abstract: typing.ClassVar[bool] = True -@config_class() +@config_class(registry=True) class SampledDatasetConfig(DatasetConfig): """ A sampled dataset containing a prepared list of samples to be indexed sequentially (as-is) during training. 
@@ -93,7 +126,7 @@ def _build(self) -> "IndexedDataset": raise NotImplementedError() -@config_class() +@config_class(dynamic_type={SampledDatasetConfig: "concatenated"}) class ConcatenatedDatasetConfig(SamplableDatasetConfig): """ Concatenate multiple indexed datasets as if they were one. @@ -116,13 +149,10 @@ class ConcatenatedDatasetConfig(SamplableDatasetConfig): def build(self) -> "ConcatenatedDataset": from fast_llm.data.dataset.indexed import ConcatenatedDataset - return self._build(ConcatenatedDataset) - - def _build[T: ConcatenatedDataset](self, cls: type[T]) -> T: - return cls(self.name, [dataset.build() for dataset in self.datasets]) + return ConcatenatedDataset(self.name, [dataset.build() for dataset in self.datasets]) -@config_class() +@config_class(dynamic_type={SampledDatasetConfig: "slice"}) class DatasetSliceConfig(SamplableDatasetConfig): """ Use a fraction of an indexed dataset, specified by the range (begin, end). @@ -152,12 +182,10 @@ class DatasetSliceConfig(SamplableDatasetConfig): def build(self) -> "DatasetSlice": from fast_llm.data.dataset.indexed import DatasetSlice - return self._build(DatasetSlice) - - def _build[T: DatasetSlice](self, cls: type[T]) -> T: dataset = self.dataset.build() size = len(dataset) - return cls( + + return DatasetSlice( f"{dataset.name}_{self.begin}_{self.end}", dataset, round(self.begin * size), @@ -165,7 +193,7 @@ def _build[T: DatasetSlice](self, cls: type[T]) -> T: ) -@config_class() +@config_class(dynamic_type={SampledDatasetConfig: "sampled"}) class SampledDatasetUpdateConfig(SampledDatasetConfig): """ Wrap a dataset to explicitly sample from it and optionally update its configuration parameters. 
@@ -186,7 +214,7 @@ def build_and_sample(self, data: SamplingData) -> SampledDataset: return self.dataset.build_and_sample(data.update_config(self.sampling)) -@config_class() +@config_class(dynamic_type={SampledDatasetConfig: "blended"}) class BlendedDatasetConfig(SampledDatasetConfig): _abstract = False name: str = Field( diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 656cd7d24..078df37b8 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -1,63 +1,18 @@ import dataclasses -import enum import pathlib import time import typing import yaml -from fast_llm.config import Config, Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none +from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.data.config import TokenizerConfig from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset -from fast_llm.data.dataset.config import ( - BlendedDatasetConfig, - ConcatenatedDatasetConfig, - DatasetSliceConfig, - IndexedDatasetConfig, - SamplableDatasetConfig, - SampledDatasetConfig, - SampledDatasetUpdateConfig, - SamplingConfig, - SamplingData, - SamplingParameters, -) +from fast_llm.data.dataset.config import SamplableDatasetConfig, SampledDatasetConfig, SamplingData, SamplingParameters from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.data.dataset.gpt.indexed import GPTConcatenatedDataset, GPTDatasetSlice, GPTIndexedDataset - from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset from fast_llm.data.dataset.gpt.random import GPTRandomDataset - from fast_llm.data.tokenizer import Tokenizer - - -class ShufflingType(str, enum.Enum): - # Shuffle all epochs together. Not extendable. - full = "full" - # Shuffle all epochs separately. Default mode, recommended if the dataset doesn't come pre-shuffled. - epoch = "epoch" - # Shuffle all epochs except the first one. 
Recommended for pre-shuffled datasets, especially big ones. - skip_first_epoch = "skip_first_epoch" - # Disable shuffling entirely. - disabled = "disabled" - - -@config_class() -class GPTSamplingConfig(SamplingConfig): - """ - A dataset-dependent configuration for sampling. - """ - - gpu: bool = Field( - default=True, - desc="Enable fast sampling on GPU." - " Note that random sampling works differently on GPU," - " so the sample won't match the CPU equivalent.", - hint=FieldHint.feature, - ) - shuffle: ShufflingType = Field( - default=ShufflingType.epoch, - desc="Shuffling strategy.", - hint=FieldHint.feature, - ) @dataclasses.dataclass(kw_only=True) @@ -66,15 +21,11 @@ class GPTSamplingParameters(SamplingParameters): Sampling parameters set externally to the dataset and data, ex. determined by the trainer or model. """ - sequence_length: int + # TODO: ====== The dataset should know it already ====== vocab_size: int + # TODO: ====== Where to put? ====== use_loss_masking_spans: bool = False use_preference_loss_spans: bool = False - cross_document_attention: bool = True - truncate_documents: bool = True - # How many extra tokens to add to the sequence length. - # This is used to provide labels even for the last tokens in the sequence. - extra_tokens: int = 1 @dataclasses.dataclass(kw_only=True) @@ -84,29 +35,11 @@ class GPTSamplingData(SamplingData): usage-dependent ones (`GPTSamplingParameters`), and others set by the `Data`. 
""" - config: GPTSamplingConfig parameters: GPTSamplingParameters - tokenizer: "Tokenizer" - -@config_class(registry=True) -class GPTSampledDatasetConfig(SampledDatasetConfig): - pass - -@config_class() -class GPTSamplableDatasetConfig(SamplableDatasetConfig, GPTSampledDatasetConfig): - pass - - -@config_class() -class GPTIndexedDatasetConfig(GPTSamplableDatasetConfig, IndexedDatasetConfig): - def build(self) -> "GPTIndexedDataset": - raise NotImplementedError() - - -@config_class(dynamic_type={GPTSampledDatasetConfig: "random"}) -class GPTRandomDatasetConfig(GPTSamplableDatasetConfig): +@config_class(dynamic_type={SampledDatasetConfig: "random"}) +class GPTRandomDatasetConfig(SamplableDatasetConfig): _abstract: typing.ClassVar[bool] = False name: str = Field( default="dummy", @@ -120,68 +53,8 @@ def build(self) -> "GPTRandomDataset": return GPTRandomDataset(self.name) -@config_class(dynamic_type={GPTSampledDatasetConfig: "memmap"}) -class GPTMemmapDatasetConfig(GPTIndexedDatasetConfig): - _abstract: typing.ClassVar[bool] = False - path: pathlib.Path = Field( - default=None, - desc="The path to the dataset, excluding the `.bin` or `.idx` suffix.", - hint=FieldHint.core, - ) - num_documents: int | None = Field( - default=None, - desc="Expected number of documents in the dataset.", - hint=FieldHint.optional, - ) - num_tokens: int | None = Field( - default=None, - desc="Expected number of tokens in the dataset.", - hint=FieldHint.optional, - ) - - def build(self) -> "GPTMemmapDataset": - from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset - - return GPTMemmapDataset(str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens) - - -@config_class(dynamic_type={GPTSampledDatasetConfig: "concatenated"}) -class GPTConcatenatedDatasetConfig(ConcatenatedDatasetConfig, GPTIndexedDatasetConfig): - _abstract: typing.ClassVar[bool] = False - datasets: list[GPTIndexedDatasetConfig] = FieldUpdate() - - def build(self) -> "GPTConcatenatedDataset": - 
from fast_llm.data.dataset.gpt.indexed import GPTConcatenatedDataset - - return self._build(GPTConcatenatedDataset) - - -@config_class(dynamic_type={GPTSampledDatasetConfig: "slice"}) -class GPTDatasetSliceConfig(DatasetSliceConfig, GPTIndexedDatasetConfig): - _abstract: typing.ClassVar[bool] = False - dataset: GPTIndexedDatasetConfig = FieldUpdate() - - def build(self) -> "GPTDatasetSlice": - from fast_llm.data.dataset.gpt.indexed import GPTDatasetSlice - - return self._build(GPTDatasetSlice) - - -@config_class(dynamic_type={GPTSampledDatasetConfig: "sampled"}) -class GPTSampledDatasetUpdateConfig(SampledDatasetUpdateConfig, GPTSampledDatasetConfig): - _abstract = False - sampling: GPTSamplingConfig = FieldUpdate() - dataset: GPTSampledDatasetConfig = FieldUpdate() - - -@config_class(dynamic_type={GPTSampledDatasetConfig: "blended"}) -class GPTBlendedDatasetConfig(BlendedDatasetConfig, GPTSampledDatasetConfig): - _abstract: typing.ClassVar[bool] = False - datasets: list[GPTSampledDatasetConfig] = FieldUpdate() - - -@config_class(dynamic_type={GPTSampledDatasetConfig: "file"}) -class GPTDatasetFromFileConfig(GPTSamplableDatasetConfig): +@config_class(dynamic_type={SampledDatasetConfig: "file"}) +class GPTDatasetFromFileConfig(SamplableDatasetConfig): _abstract: typing.ClassVar[bool] = False path: pathlib.Path = Field( default=None, @@ -195,12 +68,12 @@ def build_and_sample(self, sampling: SamplingData) -> SampledDataset: def build(self) -> SamplableDataset: config = self._load_config() - assert isinstance(config, GPTSamplableDatasetConfig) + assert isinstance(config, SamplableDatasetConfig) return config.build() def _load_config(self): assert self.path.is_file(), f"File {self.path} does not exist." 
- return GPTSampledDatasetConfig.from_dict(self._convert_paths(yaml.safe_load(self.path.open("r")))) + return SampledDatasetConfig.from_dict(self._convert_paths(yaml.safe_load(self.path.open("r")))) def _convert_paths(self, config): # Recursively convert paths relative to `self.path.parent` to make them relative to cwd. @@ -224,6 +97,10 @@ class FimConfig(Config): Configuration for FIM. """ + tokenizer: TokenizerConfig = Field( + desc="Configuration for the tokenizer.", + hint=FieldHint.feature, + ) rate: float = Field( # TODO: Use meaningful default now that fim is a wrapper? default=0.0, @@ -286,15 +163,15 @@ class FimConfig(Config): ) -@config_class(dynamic_type={GPTSampledDatasetConfig: "fim"}) -class GPTFimSampledDatasetConfig(GPTSampledDatasetConfig, FimConfig): +@config_class(dynamic_type={SampledDatasetConfig: "fim"}) +class GPTFimSampledDatasetConfig(SampledDatasetConfig, FimConfig): """ Configuration for FIM. """ _abstract: typing.ClassVar[bool] = False - dataset: GPTSampledDatasetConfig = Field( + dataset: SampledDatasetConfig = Field( default=None, desc="The dataset to wrap with fim.", hint=FieldHint.core, @@ -309,8 +186,8 @@ def build_and_sample( return GPTFimDataset(self, self.dataset.build_and_sample(sampling), sampling) -@config_class(dynamic_type={GPTSampledDatasetConfig: "test_slow"}) -class GPTTestSlowDatasetConfig(GPTSampledDatasetConfig): +@config_class(dynamic_type={SampledDatasetConfig: "test_slow"}) +class GPTTestSlowDatasetConfig(SampledDatasetConfig): """ A mock dataset that mimics a slow dataset creation on one rank, which may trigger a timeout. 
""" diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index 2b2c8b3be..7c86aef81 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -2,7 +2,7 @@ from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.gpt.config import FimConfig, GPTSamplingData -from fast_llm.data.dataset.gpt.sampled import GPTSample +from fast_llm.data.dataset.sampled import GPTSample from fast_llm.engine.distributed.config import MAX_SEED @@ -26,7 +26,7 @@ def __init__( self._dataset = dataset self._seed = sampling.config.seed - self._tokenizer = sampling.tokenizer + self._tokenizer = self._config.tokenizer.get_tokenizer() if self._tokenizer is None: raise ValueError("Fim requires a tokenizer") self._suffix_tok_id, self._prefix_tok_id, self._middle_tok_id, self._pad_tok_id = ( diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py deleted file mode 100644 index 896229772..000000000 --- a/fast_llm/data/dataset/gpt/indexed.py +++ /dev/null @@ -1,60 +0,0 @@ -import abc -import typing - -import numpy as np - -from fast_llm.data.dataset.gpt.config import GPTSamplingData -from fast_llm.data.dataset.indexed import ConcatenatedDataset, DatasetSlice, IndexedDataset - -if typing.TYPE_CHECKING: - from fast_llm.data.dataset.gpt.sampled import GPTSampledIndexedDataset - - -class GPTIndexedDataset(IndexedDataset): - @abc.abstractmethod - def get_document_sizes(self) -> np.ndarray: - """ - The size of each document in the dataset. - The resulting array could be very large, so this method should be called cautiously, - and derived classes should try to avoid holding the whole array im memory. - """ - - @abc.abstractmethod - def get_document_size(self, index: int) -> int: - """ - The size of a document in the dataset. 
- """ - - def sample(self, sampling: GPTSamplingData) -> "GPTSampledIndexedDataset": - from fast_llm.data.dataset.gpt.sampled import GPTSampledIndexedDataset - - return GPTSampledIndexedDataset(self, sampling) - - -class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[IndexedDatasetType], GPTIndexedDataset): - """ - A GPT dataset, which reads samples from (a split of) a `MMapIndexedDataset` pointing to a GPT dataset. - """ - - _dataset: GPTIndexedDataset - - def get_document_sizes(self) -> np.ndarray: - # TODO: This can be really big. - return self._dataset.get_document_sizes()[self._begin : self._end] - - def get_document_size(self, index: int) -> int: - return self._dataset.get_document_size(self._begin + index) - - -class GPTConcatenatedDataset[IndexedDatasetType: GPTIndexedDataset]( - ConcatenatedDataset[IndexedDatasetType], GPTIndexedDataset -): - _datasets: list[GPTIndexedDataset] - - def get_document_sizes(self) -> np.ndarray: - # TODO: This can be really big. 
- return np.concatenate([dataset.get_document_sizes() for dataset in self._datasets]) - - def get_document_size(self, index: int) -> int: - dataset = np.searchsorted(self._dataset_splits[1:], index, side="right") - return self._datasets[dataset].get_document_size(index - self._dataset_splits[dataset].item()) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py deleted file mode 100644 index f39fd56f4..000000000 --- a/fast_llm/data/dataset/gpt/memmap.py +++ /dev/null @@ -1,318 +0,0 @@ -import pathlib -import struct -import typing - -import numpy as np - -from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset -from fast_llm.data.dataset.gpt.sampled import GPTSample -from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, MEMMAP_DTYPES_INV, MEMMAP_INDEX_HEADER -from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.utils import Assert, div - - -class GPTMemmapDataset(GPTIndexedDataset): - """ - A memory map dataset, which handles lazy loading of a pre-processed dataset in the Megatron-LM format, - i.e. a pair of numpy file containing - 1. A data file (`{prefix}.bin`) containing a flat buffer containing the concatenated, tokenized documents. - 2. An index file (`{prefix}.idx`) containing a list of document sizes and pointers (start index) in the data file. - See https://github.com/NVIDIA/Megatron-LM?tab=readme-ov-file#data-preprocessing for more details. 
- """ - - def __init__( - self, - name: str, - prefix: pathlib.Path | str, - num_documents: int | None = None, - num_tokens: int | None = None, - ): - self._init(name, prefix, num_documents, num_tokens) - - def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None, num_tokens: int | None) -> None: - super().__init__() - self._name = name - self._prefix = pathlib.Path(prefix) - self._has_spans = 0 - self._has_preference_spans = False - - with self._prefix.with_suffix(".idx").open("rb") as stream: - Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER, msg=f"File: {stream.name}") - self._version = struct.unpack("= 2: - self._has_spans = struct.unpack("= 3: - self._has_preference_spans = struct.unpack("= 2: - self._spans = [] - self._num_spans = np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=self._num_documents, - offset=offset + self._document_sizes.nbytes + self._pointers.nbytes, - ) - span_offset = offset + self._document_sizes.nbytes + self._pointers.nbytes + self._num_spans.nbytes - self._num_spans_cumsum = np.r_[0, np.cumsum(self._num_spans[:-1], dtype=np.int64)] - for idx in range(self._num_documents): - self._spans.append( - np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=self._num_spans[idx] * 2, - offset=span_offset + self._num_spans_cumsum[idx] * 2 * np.dtype(np.int32).itemsize, - ).reshape(-1, 2) - ) - - # read preference spans - self._chosen_spans = None - self._rejected_spans = None - if self._has_preference_spans and self._version >= 3: - self._chosen_spans = [] - self._rejected_spans = [] - chosen_span_offset = offset + self._document_sizes.nbytes + self._pointers.nbytes - for idx in range(self._num_documents): - self._chosen_spans.append( - np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=2, - offset=chosen_span_offset + idx * 2 * np.dtype(np.int32).itemsize, - ) - ) - - rejected_span_offset = ( - offset + self._document_sizes.nbytes + self._pointers.nbytes + 
np.array(self._chosen_spans).nbytes - ) - for idx in range(self._num_documents): - self._rejected_spans.append( - np.frombuffer( - self._index_bin_buffer, - dtype=np.int32, - count=2, - offset=rejected_span_offset + idx * 2 * np.dtype(np.int32).itemsize, - ) - ) - - self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C") - self._bin_buffer = memoryview(self._bin_buffer_mmap) - - self._num_tokens = div(self._bin_buffer_mmap.size, np.dtype(self._dtype).itemsize) - if num_tokens is not None: - assert self._num_tokens == num_tokens - - def __getstate__(self) -> tuple[str, pathlib.Path, int | None, int | None]: - return (self._name, self._prefix, self._num_documents, self._num_tokens) - - def __setstate__(self, state: tuple[str, pathlib.Path, int | None, int | None]): - self._init(*state) - - def __del__(self): - if hasattr(self, "_bin_buffer_mmap"): - self._bin_buffer_mmap._mmap.close() # noqa - del self._bin_buffer_mmap - if hasattr(self, "_index_bin_buffer"): - self._index_bin_buffer_mmap._mmap.close() # noqa - del self._index_bin_buffer_mmap - - def get( - self, - idx: int, - offset: int = 0, - length: int | None = None, - use_loss_masking_spans: bool = False, - use_preference_loss_spans: bool = False, - ) -> GPTSample: - token_ids = np.frombuffer( - self._bin_buffer, - dtype=self._dtype, - count=self._document_sizes[idx] - offset if length is None else length, - offset=self._pointers[idx] + offset * np.dtype(self._dtype).itemsize, - ) - sample_spans = None - if use_loss_masking_spans and self._spans is not None: - sample_spans = self._spans[idx] - - # filter spans that are outside the range of the selected tokens in the document - sample_spans = sample_spans[ - (sample_spans[:, 0] < offset + len(token_ids)) & (sample_spans[:, 1] >= offset) - ] - - # subtract by offset to normalize span boundaries - sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset # offset - sample_spans[:, 1] = np.minimum(sample_spans[:, 1], 
offset + len(token_ids) - 1) - offset - - chosen_span = None - rejected_span = None - - if use_preference_loss_spans: - if not self._has_preference_spans: - raise ValueError("No preference spans found in memmap dataset.") - elif self._has_preference_spans and self._chosen_spans is None: - raise ValueError("Failed to read chosen spans from memmap dataset.") - elif self._has_preference_spans and self._rejected_spans is None: - raise ValueError("Failed to read rejected spans from memmap dataset.") - else: - chosen_span = self._chosen_spans[idx] - - # filter spans that are outside the range of the selected tokens in the document - chosen_span = chosen_span[(chosen_span[0] < offset + len(token_ids)) & (chosen_span[1] >= offset)][0] - - # subtract by offset to normalize span boundaries - chosen_span[0] = np.maximum(chosen_span[0], offset) - offset # offset - chosen_span[1] = np.minimum(chosen_span[1], offset + len(token_ids) - 1) - offset - - rejected_span = self._rejected_spans[idx] - - # filter spans that are outside the range of the selected tokens in the document - rejected_span = rejected_span[ - (rejected_span[0] < offset + len(token_ids)) & (rejected_span[1] >= offset) - ][0] - - # subtract by offset to normalize span boundaries - rejected_span[0] = np.maximum(rejected_span[0], offset) - offset # offset - rejected_span[1] = np.minimum(rejected_span[1], offset + len(token_ids) - 1) - offset - - return GPTSample( - token_ids=token_ids, - loss_masking_spans=sample_spans, - chosen_span=chosen_span, - rejected_span=rejected_span, - ) - - @property - def name(self) -> str: - return self._name - - def __len__(self) -> int: - return self._num_documents - - @property - def num_tokens(self) -> int: - return self._num_tokens - - def get_document_sizes(self) -> np.ndarray: - """ - The size of each document in the dataset. 
- The resulting array could be very large, so this method should be called cautiously, - and derived classes should try to avoid holding the whole array im memory. - """ - return self._document_sizes - - def get_document_size(self, index: int) -> int: - return self._document_sizes[index].item() - - @classmethod - def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GPTSample]): - # Initialize metadata - dtype = None - num_documents = 0 - lengths = [] - pointers = [] - offset = 0 - # number of spans for each document - num_spans = [] - spans = [] - chosen_spans = [] - rejected_spans = [] - - prefix = pathlib.Path(prefix) - prefix.parent.mkdir(parents=True, exist_ok=True) - - # Write the binary data file (.bin) lazily - with prefix.with_suffix(".bin").open("wb") as bin_stream: - for document in documents: - # Infer dtype from the first document - if dtype is None: - dtype = document.token_ids.dtype - assert dtype is not None, "Document dtype could not be inferred from the data." - - # Ensure all documents have the same dtype - assert document.token_ids.dtype == dtype, f"Expected dtype {dtype}, got {document.token_ids.dtype}." 
- - # Write document to binary file - bin_stream.write(document.token_ids.tobytes(order="C")) - - # Update metadata - doc_length = len(document.token_ids) - lengths.append(doc_length) - pointers.append(offset) - if document.loss_masking_spans is not None: - num_spans.append(len(document.loss_masking_spans)) - spans.append(document.loss_masking_spans) - if document.chosen_span is not None: - chosen_spans.append(document.chosen_span) - if document.rejected_span is not None: - rejected_spans.append(document.rejected_span) - offset += doc_length * np.dtype(dtype).itemsize - num_documents += 1 - - # Finalize metadata arrays - lengths = np.array(lengths, dtype=np.int32) - pointers = np.array(pointers, dtype=np.int64) - num_spans = np.array(num_spans, dtype=np.int32) - if len(spans) > 0: - spans = np.vstack(spans, dtype=np.int32) - else: - spans = np.array(spans, dtype=np.int32) - chosen_spans = np.array(chosen_spans, dtype=np.int32).reshape(-1, 2) - rejected_spans = np.array(rejected_spans, dtype=np.int32).reshape(-1, 2) - - # Write the index file (.idx) - with prefix.with_suffix(".idx").open("wb") as idx_stream: - idx_stream.write(MEMMAP_INDEX_HEADER) - # Indicates the version - # Version 2 optionally adds loss-masking spans - # Version 3 optionally adds chosen/rejected spans - idx_stream.write(struct.pack(" 0 else 0)) - # Flag to indicate whether preference loss-masking spans are present - idx_stream.write(struct.pack(" 0 and rejected_spans.size > 0 else 0)) - # Data type - idx_stream.write(struct.pack(" int: return self._num_samples def __getitem__(self, idx) -> np.ndarray: - return GPTSample( + return LanguageModelSample( np.random.RandomState(self._seed + 48576439 + 74593 * idx).randint( 0, self._vocab_size, size=(self._sequence_length + 1,), dtype=np.int64 ) diff --git a/fast_llm/data/dataset/indexed.py b/fast_llm/data/dataset/indexed.py index 09ed52779..43f675c81 100644 --- a/fast_llm/data/dataset/indexed.py +++ b/fast_llm/data/dataset/indexed.py @@ -2,19 +2,21 @@ 
import typing import numpy as np +import torch from fast_llm.data.dataset.abstract import SamplableDataset +from fast_llm.data.dataset.config import SamplingData +from fast_llm.data.dataset.sample import Sample from fast_llm.utils import Assert, padded_cumsum class IndexedDataset(SamplableDataset): """ - A dataset containing a list of samples. - TODO: Move sampling responsibility here? + A dataset containing a list of `documents`, to be merged into `samples`. """ @abc.abstractmethod - def get(self, index: int, *args, **kwargs) -> typing.Any: + def get_document(self, index: int, begin: int, end: int) -> Sample: pass @abc.abstractmethod @@ -23,6 +25,25 @@ def __len__(self) -> int: Number of samples in the dataset. """ + @abc.abstractmethod + def get_document_sizes(self) -> torch.Tensor: + """ + The size of each document in the dataset. + The resulting array could be very large, so this method should be called cautiously, + and derived classes should try to avoid holding the whole array in memory. + """ + + @abc.abstractmethod + def get_document_size(self, index: int) -> int: + """ + The size of a document in the dataset. + """ + + def sample(self, sampling: SamplingData) -> "SampledIndexedDataset": + from fast_llm.data.dataset.sampled import SampledIndexedDataset + + return SampledIndexedDataset(self, sampling) + class DatasetSlice[IndexedDatasetType: IndexedDataset](IndexedDataset): @@ -59,6 +80,13 @@ def get( def __len__(self) -> int: return self._end - self._begin + def get_document_sizes(self) -> torch.Tensor: + # TODO: This can be really big.
+ return self._dataset.get_document_sizes()[self._begin : self._end] + + def get_document_size(self, index: int) -> int: + return self._dataset.get_document_size(self._begin + index) + @property def name(self) -> str: return self._name @@ -80,9 +108,18 @@ def __len__(self) -> int: return self._dataset_splits[-1].item() def get(self, index: int, *args, **kwargs): + dataset = np.searchsorted(self._dataset_splits[1:], index, side="right") return self._datasets[dataset].get(index - self._dataset_splits[dataset].item(), *args, **kwargs) + def get_document_sizes(self) -> torch.Tensor: + # TODO: This can be really big. + return torch.cat([dataset.get_document_sizes() for dataset in self._datasets]) + + def get_document_size(self, index: int) -> int: + dataset = np.searchsorted(self._dataset_splits[1:], index, side="right") + return self._datasets[dataset].get_document_size(index - self._dataset_splits[dataset].item()) + @property def name(self) -> str: return self._name diff --git a/fast_llm/data/dataset/memmap.py b/fast_llm/data/dataset/memmap.py new file mode 100644 index 000000000..1228fcaea --- /dev/null +++ b/fast_llm/data/dataset/memmap.py @@ -0,0 +1,103 @@ +import json +import pathlib +import typing + +import numpy as np +import torch + +from fast_llm.data.dataset.indexed import IndexedDataset +from fast_llm.data.dataset.sample import Sample +from fast_llm.data.dataset.sample.abstract import MemmapReader +from fast_llm.data.dataset.sample.config import MemmapIndexDatasetReaderConfig + +FILE_HEADER = b"fast_llm_prepared_dataset" + + +class MemmapDataset(IndexedDataset): + """ + A memory map dataset, which handles lazy loading of a pre-processed dataset. + """ + + def __init__( + self, + name: str, + path: pathlib.Path | str, + ): + self._init(name, path) + + def _init(self, name: str, path: pathlib.Path | str) -> None: + super().__init__() + self._name = name + self._path = path + + with self._path.open("rb") as stream: + # Verify file type.
+ assert stream.read(len(FILE_HEADER)) == FILE_HEADER + # Go to reader configs. + stream.seek(int.from_bytes(stream.read(4), signed=False)) + # Read the reader config. + reader_config = MemmapIndexDatasetReaderConfig.from_dict( + json.loads(stream.read(int.from_bytes(stream.read(4), signed=False)).decode("utf-8")) + ) + + self._memmap = np.memmap(self._path, mode="r") + # TODO: ===== Check num_documents, num_tokens ====== + self._reader = reader_config.get_reader(memoryview(self._memmap)) + + def __getstate__(self) -> tuple[str, pathlib.Path]: + return (self._name, self._path) + + def __setstate__(self, state: tuple[str, pathlib.Path]): + self._init(*state) + + def __del__(self): + if hasattr(self, "_memmap"): + self._memmap._mmap.close() # noqa + del self._memmap + + def get( + self, + index: int, + begin: int, + end: int, + ) -> Sample: + return self._reader.get(index, begin, end) + + @property + def name(self) -> str: + return self._name + + def __len__(self) -> int: + return self._reader + + # TODO: ====== needed? ====== + # @property + # def num_tokens(self) -> int: + # return self._reader.num_tokens + + def get_document_sizes(self) -> torch.Tensor: + return self._reader.get_document_sizes() + + def get_document_size(self, index: int) -> int: + return self._reader.get_document_size(index) + + @classmethod + def write_dataset(cls, path: pathlib.Path, documents: typing.Iterable[Sample], reader_class: type[MemmapReader]): + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("wb") as stream: + # Write the file type header. + stream.write(FILE_HEADER) + # Leave space for a pointer to the reader config. + # We write the config at the end since we don't know it yet. + start = stream.tell() + stream.seek(start + 4) + # Write the data. + reader_config = reader_class.write(documents, stream) + # Write the reader config. 
+ config_offset = stream.tell() + reader_config_bytes = json.dumps(reader_config.to_dict()).encode("utf-8") + stream.write(len(reader_config_bytes).to_bytes(4, signed=False)) + stream.write(reader_config_bytes) + # Write a pointer to the reader config. + stream.seek(start) + stream.write(config_offset.to_bytes(4, signed=False)) diff --git a/fast_llm/data/dataset/sample/__init__.py b/fast_llm/data/dataset/sample/__init__.py new file mode 100644 index 000000000..b42cac943 --- /dev/null +++ b/fast_llm/data/dataset/sample/__init__.py @@ -0,0 +1,25 @@ +import abc +import typing + +from fast_llm.data.dataset.sample.abstract import Batch, Sample + + +class LanguageModelSample(Sample): + + @classmethod + @abc.abstractmethod + def merge_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: + pass + + @classmethod + @abc.abstractmethod + def merge_into_batch(cls, samples: typing.Iterable[typing.Self]) -> "LanguageModelBatch": + pass + + @abc.abstractmethod + def crop(self, offset: int = 0, length: int | None = None): + pass + + +class LanguageModelBatch(Batch): + pass diff --git a/fast_llm/data/dataset/sample/abstract.py b/fast_llm/data/dataset/sample/abstract.py new file mode 100644 index 000000000..cc33bda48 --- /dev/null +++ b/fast_llm/data/dataset/sample/abstract.py @@ -0,0 +1,56 @@ +import abc +import io +import typing + +from fast_llm.config import Configurable +from fast_llm.data.dataset.sample.config import MemmapIndexDatasetReaderConfig, MemmapReaderBaseConfig + +if typing.TYPE_CHECKING: + import torch + + +class Sample(abc.ABC): + @classmethod + @abc.abstractmethod + def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: + pass + + @abc.abstractmethod + def crop(self, begin: int, end: int): + pass + + @abc.abstractmethod + def __len__(self) -> int: + pass + + +class Batch(abc.ABC): + @classmethod + @abc.abstractmethod + def from_samples(cls, samples: typing.Iterable[Sample]) -> typing.Self: + pass + + +class 
MemmapReader[ConfigType: MemmapReaderBaseConfig](Configurable[ConfigType]): + def __init__(self, config: ConfigType, buffer: memoryview): + super().__init__(config) + self._buffer = buffer[self._config.begin : self._config.end] + + @abc.abstractmethod + def get_document(self, index: int, begin: int, end: int) -> Sample: + pass + + @classmethod + @abc.abstractmethod + def write(cls, documents: typing.Iterable[Sample], stream: io.BufferedWriter) -> MemmapReaderBaseConfig: + pass + + +class MemmapIndexedDatasetReader[ConfigType: MemmapIndexDatasetReaderConfig](MemmapReader[ConfigType]): + @abc.abstractmethod + def get_document_sizes(self) -> "torch.Tensor": + pass + + @abc.abstractmethod + def get_document_size(self, index: int) -> int: + pass diff --git a/fast_llm/data/dataset/sample/config.py b/fast_llm/data/dataset/sample/config.py new file mode 100644 index 000000000..61afd627c --- /dev/null +++ b/fast_llm/data/dataset/sample/config.py @@ -0,0 +1,147 @@ +import abc +import pathlib +import typing + +from fast_llm.config import Config, Field, FieldHint, config_class +from fast_llm.data.dataset.config import IndexedDatasetConfig, SampledDatasetConfig +from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.utils import Assert + +if typing.TYPE_CHECKING: + from fast_llm.data.dataset.memmap import MemmapDataset + from fast_llm.data.dataset.sample.abstract import MemmapIndexedDatasetReader, MemmapReader + from fast_llm.data.dataset.sample.language_model import LanguageModelReader + from fast_llm.data.dataset.sample.range import RangeReader + from fast_llm.data.dataset.sample.token import TokenReader + + +@config_class(registry=True) +class MemmapReaderBaseConfig(Config): + @classmethod + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: + if cls is MemmapReaderBaseConfig and cls.get_subclass(default.get("type")) is None: + # Default subclass, necessary for loading configs where some components could be absent. 
+ return NullReaderConfig._from_dict(default, strict) + return super()._from_dict(default, strict=strict) + + @abc.abstractmethod + def get_reader(self, buffer: memoryview) -> "MemmapReader|None": + pass + + @property + @abc.abstractmethod + def expected_buffer_size(self) -> int: + """ + The expected buffer size in bytes. Used for self-validation. + """ + + +@config_class(dynamic_type={MemmapReaderBaseConfig: "none"}) +class NullReaderConfig(MemmapReaderBaseConfig): + def get_reader(self, buffer: memoryview) -> None: + return None + + @property + def expected_buffer_size(self) -> int: + return 0 + + +@config_class(registry=True) +class MemmapReaderConfig(MemmapReaderBaseConfig): + begin: int = Field() + end: int = Field() + + @property + def reader_class(self) -> "type[MemmapReader]": + raise NotImplementedError() + + def get_reader(self, buffer: memoryview) -> "MemmapReader": + return self.reader_class(self, buffer) + + def _validate(self): + super()._validate() + Assert.eq(self.end - self.begin, self.expected_buffer_size) + + +@config_class() +class MemmapIndexDatasetReaderConfig(MemmapReaderConfig): + @property + def reader_class(self) -> "type[MemmapIndexedDatasetReader]": + raise NotImplementedError() + + # + def get_reader( + self, + buffer: memoryview, + ) -> "MemmapIndexedDatasetReader": + return self.reader_class(self, buffer[self.begin : self.end]) + + +@config_class(dynamic_type={MemmapReaderBaseConfig: "range"}) +class RangeReaderConfig(MemmapReaderConfig): + num_documents: int = Field() + num_ranges: int = Field() + + @property + def reader_class(self) -> "type[RangeReader]": + from fast_llm.data.dataset.sample.range import RangeReader + + return RangeReader + + @property + def expected_buffer_size(self) -> int: + return (self.num_ranges + 1) * 4 * 2 + (self.num_documents + 1) * 4 + + +@config_class(dynamic_type={MemmapReaderBaseConfig: "token"}) +class TokenReaderConfig(MemmapIndexDatasetReaderConfig): + num_documents: int = Field() + num_tokens: int 
= Field() + data_type: DataType = Field() + + @property + def reader_class(self) -> "type[TokenReader]": + from fast_llm.data.dataset.sample.token import TokenReader + + return TokenReader + + @property + def expected_buffer_size(self) -> int: + return self.num_tokens * self.data_type.numpy.itemsize + (self.num_documents + 1) * 8 + + +@config_class(dynamic_type={MemmapReaderBaseConfig: "language_model"}) +class LanguageModelReaderConfig(MemmapIndexDatasetReaderConfig): + tokens: TokenReaderConfig = Field() + # Using dynamic type for optional readers for enabling/disabling + loss_masking_spans: MemmapReaderBaseConfig = Field() + preference_spans: MemmapReaderBaseConfig = Field() + + @property + def reader_class(self) -> "type[LanguageModelReader]": + from fast_llm.data.dataset.sample.language_model import LanguageModelReader + + return LanguageModelReader + + @property + def expected_buffer_size(self) -> int: + return ( + self.tokens.expected_buffer_size + + self.loss_masking_spans.expected_buffer_size + + self.preference_spans.expected_buffer_size + ) + + +@config_class(dynamic_type={SampledDatasetConfig: "memmap"}) +class MemmapDatasetConfig(IndexedDatasetConfig): + _abstract: typing.ClassVar[bool] = False + path: pathlib.Path = Field( + default=None, + desc="The path to the dataset, excluding the `.bin` or `.idx` suffix.", + hint=FieldHint.core, + ) + + def build(self) -> "MemmapDataset": + from fast_llm.data.dataset.memmap import MemmapDataset + + return MemmapDataset(str(self.path).replace("/", "__"), self.path) diff --git a/fast_llm/data/dataset/sample/language_model.py b/fast_llm/data/dataset/sample/language_model.py new file mode 100644 index 000000000..96f27187e --- /dev/null +++ b/fast_llm/data/dataset/sample/language_model.py @@ -0,0 +1,119 @@ +import io +import typing + +import numpy as np + +from fast_llm.data.dataset.sample import Batch, Sample +from fast_llm.data.dataset.sample.abstract import MemmapIndexedDatasetReader +from 
fast_llm.data.dataset.sample.config import LanguageModelReaderConfig, NullReaderConfig +from fast_llm.data.dataset.sample.range import RangeBatch, RangeReader, RangeSample +from fast_llm.data.dataset.sample.token import TokenBatch, TokenReader, TokenSample +from fast_llm.utils import Assert, get_unique + + +class LanguageModelSample(Sample): + def __init__( + self, + tokens: TokenSample, + loss_masking_spans: RangeSample | None = None, + preference_spans: RangeSample | None = None, + ): + self.tokens = tokens + self.loss_masking_spans = loss_masking_spans + self.preference_spans = preference_spans + + @classmethod + def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: + return cls( + TokenSample.from_documents(document.tokens for document in documents), + _merge_optional(RangeSample.from_documents, (document.loss_masking_spans for document in documents)), + _merge_optional(RangeSample.from_documents, (document.preference_spans for document in documents)), + ) + + def crop(self, begin: int, end: int) -> typing.Self: + return self.__class__( + self.tokens.crop(begin, end), + None if self.loss_masking_spans is None else self.loss_masking_spans.crop(begin, end), + None if self.preference_spans is None else self.preference_spans.crop(begin, end), + ) + + def __len__(self) -> int: + return len(self.tokens) + + +class LanguageModelBatch(Batch): + def __init__( + self, + tokens: TokenBatch, + loss_masking_spans: RangeBatch | None = None, + preference_spans: RangeBatch | None = None, + ): + self.tokens = tokens + self.loss_masking_spans = loss_masking_spans + self.preference_spans = preference_spans + + @classmethod + def from_samples(cls, samples: typing.Iterable[LanguageModelSample]) -> typing.Self: + return cls( + TokenBatch.from_samples(sample.tokens for sample in samples), + _merge_optional(RangeBatch.from_samples, (sample.loss_masking_spans for sample in samples)), + _merge_optional(RangeBatch.from_samples, (sample.preference_spans for 
sample in samples)), + ) + + +def _merge_optional[T](fn: typing.Callable[[typing.Iterable], T], args: typing.Iterable) -> T | None: + return None if any(arg is None for arg in args) else fn(args) + + +class LanguageModelReader[ConfigType: LanguageModelReaderConfig](MemmapIndexedDatasetReader[ConfigType]): + def __init__(self, config: ConfigType, buffer: memoryview): + super().__init__(config, buffer) + # Using `buffer` and not `self._buffer` because nested offsets (`begin`, `end`) are global. + self._tokens = self._config.tokens.get_reader(buffer) + self._loss_masking_spans = self._config.loss_masking_spans.get_reader(buffer) + self._preference_spans = self._config.preference_spans.get_reader(buffer) + + def get_document(self, index: int, begin: int, end: int) -> Sample: + return LanguageModelSample( + self._tokens.get_document(index, begin, end), + self._loss_masking_spans.get_document(index, begin, end), + self._preference_spans.get_document(index, begin, end), + ) + + def get_document_sizes(self) -> "np.ndarray": + return self._tokens.get_document_sizes() + + def get_document_size(self, index: int) -> int: + return self._tokens.get_document_size(index) + + @classmethod + def write( + cls, documents: typing.Iterable[LanguageModelSample], stream: io.BufferedWriter + ) -> LanguageModelReaderConfig: + begin = stream.tell() + tokens = TokenReader.write((document.tokens for document in documents), stream) + + # Ensure either all samples have loss masking spans or none of them do. 
+ if get_unique(document.loss_masking_spans is not None for document in documents): + loss_masking_spans = RangeReader.write((document.loss_masking_spans for document in documents), stream) + else: + loss_masking_spans = NullReaderConfig() + + # If enabled, ensure all samples have exactly 2 spans + num_preference_spans = get_unique( + None if document.preference_spans is None else len(document.preference_spans.ranges) + for document in documents + ) + if num_preference_spans == 2: + preference_spans = RangeReader.write((document.preference_spans for document in documents), stream) + else: + Assert.none(num_preference_spans) + preference_spans = NullReaderConfig() + + return LanguageModelReaderConfig( + begin=begin, + end=stream.tell(), + tokens=tokens, + loss_masking_spans=loss_masking_spans, + preference_spans=preference_spans, + ) diff --git a/fast_llm/data/dataset/sample/range.py b/fast_llm/data/dataset/sample/range.py new file mode 100644 index 000000000..def3e7986 --- /dev/null +++ b/fast_llm/data/dataset/sample/range.py @@ -0,0 +1,89 @@ +import abc +import io +import typing + +import numpy as np + +from fast_llm.data.dataset.sample import Sample +from fast_llm.data.dataset.sample.abstract import MemmapReader +from fast_llm.data.dataset.sample.config import RangeReaderConfig + + +class RangeSample(Sample): + """ + A reusable component holding a set of ranges in a sample. + """ + + def __init__(self, sample_size: int, ranges: tuple[tuple[int, int], ...] 
= ()): + self.sample_size = sample_size + self.ranges = ranges + + @classmethod + def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: + document: RangeSample + ranges = [] + sample_size = 0 + for document in documents: + for begin, end in document.ranges: + ranges.extend((begin + sample_size, end + sample_size)) + sample_size += document.sample_size + return cls(sample_size, tuple(ranges)) + + def crop(self, begin: int, end: int) -> typing.Self: + sample_size = end - begin + cropped_ranges = ((max(begin_ - begin, 0), min(end_ - begin, sample_size)) for begin_, end_ in self.ranges) + return self.__class__(sample_size, tuple((begin_, end_) for begin_, end_ in cropped_ranges if end_ > begin_)) + + def __len__(self) -> int: + return self.sample_size + + +class RangeBatch(abc.ABC): + def __init__(self, ranges: tuple[tuple[tuple[int, int], ...], ...]): + self._ranges = ranges + + @classmethod + def from_samples(cls, samples: typing.Iterable[RangeSample]) -> typing.Self: + return cls(tuple(sample.ranges for sample in samples)) + + +class RangeReader[ConfigType: RangeReaderConfig](MemmapReader[ConfigType]): + # ====== TODO: Standardize to end = 1 past last index (spans use last index) ====== + def __init__(self, config: ConfigType, buffer: memoryview): + super().__init__(config, buffer) + self._ranges = np.frombuffer( + self._buffer, + dtype=np.int32, + count=self._config.num_ranges, + ).reshape(-1, 2) + self._count_cumsums = np.frombuffer( + self._buffer, + dtype=np.int32, + count=self._config.num_documents + 1, + offset=self._ranges.nbytes, + ) + + def get(self, index: int, begin: int, end: int) -> RangeSample: + sample_size = end - begin + cropped_ranges = ( + (max(begin_ - begin, 0), min(end_ - begin, sample_size)) + for begin_, end_ in self._ranges[self._count_cumsums[index] : self._count_cumsums[index + 1]].tolist() + ) + return RangeSample(sample_size, tuple((begin_, end_) for begin_, end_ in cropped_ranges if end_ > begin_)) + + 
@classmethod + def write(cls, documents: typing.Iterable[RangeSample], stream: io.BufferedWriter) -> RangeReaderConfig: + begin = stream.tell() + count_cumsum = [0] + for document in documents: + stream.write(np.array(document.ranges, dtype=np.uint32).tobytes(order="C")) + count_cumsum.append(count_cumsum[-1] + len(document.ranges)) + stream.write(np.array(count_cumsum, dtype=np.uint32).tobytes(order="C")) + end = stream.tell() + + return RangeReaderConfig( + begin=begin, + end=end, + num_documents=len(count_cumsum) - 1, + num_ranges=count_cumsum[-1], + ) diff --git a/fast_llm/data/dataset/sample/token.py b/fast_llm/data/dataset/sample/token.py new file mode 100644 index 000000000..649f4e0eb --- /dev/null +++ b/fast_llm/data/dataset/sample/token.py @@ -0,0 +1,95 @@ +import io +import typing + +import numpy as np +import torch + +from fast_llm.data.dataset.sample import Batch, Sample +from fast_llm.data.dataset.sample.abstract import MemmapIndexedDatasetReader +from fast_llm.data.dataset.sample.config import TokenReaderConfig +from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.utils import Assert + + +class TokenSample(Sample): + def __init__(self, tokens: torch.Tensor, lengths: list[int] | None = None): + self.tokens = tokens + # Length of each document in the sample. TODO: Use cumsums instead? + if lengths is None: + lengths = [len(tokens)] + self.lengths = lengths + + @classmethod + def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: + return cls( + torch.cat([document.tokens for document in documents]), + sum((document.lengths for document in documents), []), + ) + + def crop(self, begin: int, end: int) -> typing.Self: + # We only expect to crop documents, not samples. TODO: Support other cases? 
+ Assert.eq(self.lengths, [len(self.tokens)]) + return self.__class__(self.tokens[begin:end], [end - begin]) + + def __len__(self) -> int: + return len(self.tokens) + + +class TokenBatch(Batch): + def __init__(self, tokens: torch.Tensor, lengths: list[list[int]]) -> None: + self.tokens = tokens + self.lengths = lengths + + @classmethod + def from_samples(cls, samples: typing.Iterable[TokenSample]) -> typing.Self: + return cls( + torch.stack([sample.tokens for sample in samples]), + [sample.lengths for sample in samples], + ) + + +class TokenReader[ConfigType: TokenReaderConfig](MemmapIndexedDatasetReader[ConfigType]): + def __init__(self, config: ConfigType, buffer: memoryview): + super().__init__(config, buffer) + self._tokens = np.frombuffer( + self._buffer, + dtype=self._config.data_type.numpy, + count=self._config.num_tokens, + ) + self._size_cumsums = np.frombuffer( + self._buffer, dtype=np.uint64, count=self._config.num_documents + 1, offset=self._tokens.nbytes + ) + + def get_document(self, index: int, begin: int, end: int) -> Sample: + begin_ = self._size_cumsums[index].item() + return TokenSample(torch.from_numpy(self._tokens[begin_ + begin : begin_ + end]), [end - begin]) + + def get_document_sizes(self) -> torch.Tensor: + return torch.from_numpy(self._size_cumsums[1:] - self._size_cumsums[:-1]) + + def get_document_size(self, index: int) -> int: + return self._size_cumsums[index + 1].item() - self._size_cumsums[index].item() + + @classmethod + def write(cls, documents: typing.Iterable[TokenSample], stream: io.BufferedWriter) -> TokenReaderConfig: + begin = stream.tell() + size_cumsum = [0] + data_type = None + for document in documents: + if data_type is None: + data_type = document.tokens.dtype + else: + Assert.eq(data_type, document.tokens.dtype) + stream.write(document.tokens.numpy().tobytes()) + size_cumsum.append(size_cumsum[-1] + len(document.tokens)) + + # Write the cumsums (pointers and sizes) + stream.write(np.array(size_cumsum, 
dtype=np.uint64).tobytes(order="C")) + + return TokenReaderConfig( + begin=begin, + end=stream.tell(), + num_documents=len(size_cumsum) - 1, + num_tokens=size_cumsum[-1], + data_type=DataType.from_torch(data_type), + ) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/sampled.py similarity index 75% rename from fast_llm/data/dataset/gpt/sampled.py rename to fast_llm/data/dataset/sampled.py index 95006f18e..493300a6b 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/sampled.py @@ -1,4 +1,3 @@ -import dataclasses import logging import math import pathlib @@ -10,11 +9,10 @@ import yaml from fast_llm.data.dataset.abstract import SampledDataset -from fast_llm.data.dataset.gpt.config import GPTSamplingData, ShufflingType -from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset +from fast_llm.data.dataset.config import SamplingData, ShufflingType +from fast_llm.data.dataset.indexed import IndexedDataset from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.engine.config_utils.run import log_main_rank -from fast_llm.utils import Assert try: from fast_llm.csrc.data import build_padded_token_cumsum # noqa @@ -26,13 +24,13 @@ logger = logging.getLogger(__name__) -@dataclasses.dataclass -class GPTSample: - token_ids: np.ndarray - loss_masking_spans: np.ndarray | None = None - chosen_span: np.ndarray | None = None - rejected_span: np.ndarray | None = None - sequence_lengths: np.ndarray | None = None +# @dataclasses.dataclass +# class GPTSample: +# token_ids: np.ndarray +# loss_masking_spans: np.ndarray | None = None +# chosen_span: np.ndarray | None = None +# rejected_span: np.ndarray | None = None +# sequence_lengths: np.ndarray | None = None class MemmapArray: @@ -75,17 +73,16 @@ def _lazy_load(self): TOKEN_CUMSUM_RATE = 10 -class GPTSampledIndexedDataset(SampledDataset): +class SampledIndexedDataset(SampledDataset): """ A sampled GPT dataset. 
""" def __init__( self, - indexed_dataset: GPTIndexedDataset, - sampling: GPTSamplingData, + indexed_dataset: IndexedDataset, + sampling: SamplingData, ): - assert isinstance(sampling, GPTSamplingData) self._indexed_dataset = indexed_dataset self._config = sampling.config self._parameters = sampling.parameters @@ -115,13 +112,6 @@ def __init__( self._token_cumsum_unshuffled = MemmapArray(base_path.with_name(base_path.name + "_unshuffled_cumsum.npy")) self._yaml_path = base_path.with_suffix(".yaml") - # keep document sizes and len filtered docs for preference loss masking - if self._parameters.use_preference_loss_spans: - self._document_sizes = MemmapArray(base_path.with_name(base_path.name + "_doc_sizes.npy")) - self._doc_length_filtered_indicies = MemmapArray( - base_path.with_name(base_path.name + "_doc_length_filtered_indices.npy") - ) - # Sample or validate the dataset of a given rank. if sampling.distributed.config.rank == sampling.get_next_rank(): self._sample() @@ -133,7 +123,7 @@ def _sample(self) -> None: Create a `GPTSampledDataset` with the requested parameters. """ # Get the document sizes, the main information needed for sampling. - document_sizes = torch.from_numpy(self._indexed_dataset.get_document_sizes()).to(self._device) + document_sizes = self._indexed_dataset.get_document_sizes().to(self._device) documents_per_epoch = document_sizes.numel() tokens_per_epoch = document_sizes.sum().item() @@ -159,10 +149,7 @@ def _sample(self) -> None: # We produce sequences of length `self._sequence_length + extra_tokens` so the last token has a label for all prediction heads, # but in case of truncations we also include those last labels in the following sample, # so we need `sequence_length * num_samples + extra_tokens` tokens in total. 
- if self._parameters.use_preference_loss_spans: - documents_per_epoch = (~long_docs_filter).sum().item() - num_epochs = math.ceil(self._parameters.num_samples / documents_per_epoch) - elif self._truncate_documents: + if self._truncate_documents: num_epochs = math.ceil( (self._parameters.sequence_length * self._parameters.num_samples + self._parameters.extra_tokens) / tokens_per_epoch @@ -266,24 +253,6 @@ def _sample(self) -> None: else: raise NotImplementedError(f"Unknown shuffling type: {self._config.shuffle}") - if self._parameters.use_preference_loss_spans: - yaml_data["unshuffled_tokens"] = 0 # not used, ignore - - # index of all documents less than seq length long - doc_length_filtered_indicies = torch.nonzero(~long_docs_filter, as_tuple=True)[0] - self._doc_length_filtered_indicies.save(doc_length_filtered_indicies.numpy(force=self._config.gpu)) - - # apply shuffling on doc_length_filtered_indicies - if shuffled_epochs > 0: - self._document_shuffling.save( - document_shuffling[: self._parameters.num_samples].numpy(force=self._config.gpu) - ) - self._document_sizes.save(document_sizes.numpy(force=self._config.gpu)) - if self._yaml_path is not None: - self._yaml_path.parent.mkdir(parents=True, exist_ok=True) - yaml.safe_dump(yaml_data, self._yaml_path.open("w")) - return - # To get a sample on the fly we need to know where it begins, # and this is a non-trivial information because the documents have variable length. # The starting point `(document[idx], token[idx])` corresponds to the `(idx * sequence_length)` th token, i.e. 
@@ -383,39 +352,6 @@ def __getitem__(self, index: int) -> typing.Any: """ self._lazy_load() - if self._parameters.use_preference_loss_spans: - if index < self._unshuffled_documents: - document_index = self._doc_length_filtered_indicies[index % self._documents_per_epoch] - else: - document_index = self._doc_length_filtered_indicies[ - self._document_shuffling[index - self._unshuffled_documents].item() - ] - - sample = self._indexed_dataset.get( - document_index, - offset=0, - length=self._document_sizes[document_index], - use_loss_masking_spans=self._parameters.use_loss_masking_spans, - use_preference_loss_spans=self._parameters.use_preference_loss_spans, - ) - - chosen_span_end = sample.chosen_span[1] + 1 - sequence_lengths = [ - chosen_span_end, - len(sample.token_ids) - chosen_span_end, - ] - - # compute padding size - padding = np.full((self._parameters.sequence_length + 1,), 0) - padding[: len(sample.token_ids)] = sample.token_ids - sequence_lengths.append(self._parameters.sequence_length - len(sample.token_ids)) - sample.token_ids = padding - - if not self._parameters.cross_document_attention: - sample.sequence_lengths = np.array(sequence_lengths) - - return sample - # tokens at the boundary are included in only one sample when we pack without truncations # in case of packing with truncations, the last token from the previous sample is also the first token of the next sample sample_length = ( @@ -440,8 +376,7 @@ def __getitem__(self, index: int) -> typing.Any: token_count = token_start_array[token_start_cumsum_index] - token_ids = [] - loss_masking_spans = [] + documents = [] while token_count < token_end: # Find the document index in the dataset. if document_sampling_index < self._unshuffled_documents: @@ -461,9 +396,8 @@ def __getitem__(self, index: int) -> typing.Any: # Document belongs to the next sample, need to account for padding. 
padding_size = self._parameters.sequence_length + 1 - tokens_in_sample if token_count > token_start: - # Add padding tokens to current sample - token_ids.append(np.full((padding_size,), -100, dtype=np.int64)) - Assert.eq(token_count + padding_size, token_end) + # TODO: ====== Handle padding ====== + documents.append(PaddingSample(padding_size)) break else: # Move on to the next sample. @@ -474,41 +408,19 @@ def __getitem__(self, index: int) -> typing.Any: # Determine which part of the document belong to the sample, and add it to the list. token_start_index_in_document = max(token_start - token_count, 0) token_end_index_in_document = min(token_end - token_count, document_size) - sample = self._indexed_dataset.get( - document_index, - offset=token_start_index_in_document, - length=token_end_index_in_document - token_start_index_in_document, - use_loss_masking_spans=self._parameters.use_loss_masking_spans, + documents.append( + self._indexed_dataset.get( + document_index, + offset=token_start_index_in_document, + length=token_end_index_in_document - token_start_index_in_document, + ) ) - token_ids.append(sample.token_ids) - if self._parameters.use_loss_masking_spans: - for loss_masking_span in sample.loss_masking_spans: - span = np.clip( - loss_masking_span + token_count - token_start, - 0, - self._parameters.sequence_length + self._parameters.extra_tokens, - ) - if span[1] >= span[0]: - loss_masking_spans.append(span) - # Go to the next document. 
document_sampling_index += 1 token_count += document_size - sequence_lengths = ( - np.array([ids.size - (idx == len(token_ids) - 1) for idx, ids in enumerate(token_ids)], dtype=np.int32) - if not self._parameters.cross_document_attention - else None - ) - token_ids = np.concatenate(token_ids, dtype=np.int64) - loss_masking_spans = ( - (np.stack(loss_masking_spans, dtype=np.int32) if loss_masking_spans else np.array([])) - if self._parameters.use_loss_masking_spans - else None - ) - Assert.eq(len(token_ids), self._parameters.sequence_length + self._parameters.extra_tokens) - - return GPTSample(token_ids=token_ids, loss_masking_spans=loss_masking_spans, sequence_lengths=sequence_lengths) + # TODO: ====== Better way to get the class method? ====== + return documents[0].merge_documents(documents) @property def name(self) -> str: @@ -521,13 +433,5 @@ def _lazy_load(self): def _load_yaml_data(self, data: dict[str, typing.Any]) -> None: self._documents_per_epoch = data["dataset"]["documents_per_epoch"] - if self._parameters.use_preference_loss_spans: - data["unshuffled_tokens"] = 0 # not used, ignore - elif "unshuffled_tokens" not in data: - # Backward compatibility - # TODO v0.x: Remove - assert self._truncate_documents - data["unshuffled_tokens"] = data["tokens_per_epoch"] * data["unshuffled_epochs"] - self._unshuffled_tokens = data["unshuffled_tokens"] self._unshuffled_documents = data["unshuffled_epochs"] * self._documents_per_epoch diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index d2aaee5e2..d16181ee0 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -22,7 +22,7 @@ 8: DataType.uint16, } MEMMAP_DTYPES_INV = {y: x for x, y in MEMMAP_DTYPES.items()} -MEMMAP_INDEX_HEADER = b"MMIDIDX\x00\x00" +MEMMAP_INDEX_HEADER = b"MMIDIDX\x00\x01" @config_class(registry=True) diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py 
b/fast_llm/data/preparator/gpt_memmap/prepare.py index 33c40bf8f..a4c554951 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -21,8 +21,9 @@ GPTMemmapDatasetConfig, GPTSampledDatasetConfig, ) -from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset -from fast_llm.data.dataset.gpt.sampled import GPTSample +from fast_llm.data.dataset.memmap import MemmapDataset +from fast_llm.data.dataset.sample.language_model import LanguageModelReader +from fast_llm.data.dataset.sampled import GPTSample from fast_llm.data.preparator.config import DatasetPreparator from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, TextColumnConfig from fast_llm.data.tokenizer import Tokenizer @@ -137,8 +138,6 @@ def _tokenize_preference_batch_with_spans(self, batch: dict[str, list[typing.Any def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetConfig: shard_idx, shard_dataset = args - prefix = f"shard_{self._config.distributed.rank}_{shard_idx}" - shard_output_path = self._config.output_path / prefix def _document_generator(): if "token_spans" in shard_dataset.column_names and self._loss_masking_spans_column is not None: @@ -163,7 +162,11 @@ def _document_generator(): for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): yield GPTSample(np.array(item["input_ids"], dtype=self._data_type.numpy)) - GPTMemmapDataset.write_dataset(prefix=shard_output_path, documents=_document_generator()) + MemmapDataset.write_dataset( + path=self._config.output_path / f"shard_{self._config.distributed.rank}_{shard_idx}.fast_llm_dataset", + documents=_document_generator(), + reader_class=LanguageModelReader, + ) return GPTMemmapDatasetConfig.from_dict( { @@ -240,7 +243,7 @@ def run(self) -> None: datasets.builder.has_sufficient_disk_space = lambda needed_bytes, directory=".": True # Load tokenizer - self._tokenizer = Tokenizer(config=self._config.tokenizer) 
+ self._tokenizer = self._config.tokenizer.get_tokenizer() # Decide the datatype based on the tokenizer vocabulary size self._data_type = ( diff --git a/fast_llm/engine/config_utils/data_type.py b/fast_llm/engine/config_utils/data_type.py index f4a2cfd6c..add121c50 100644 --- a/fast_llm/engine/config_utils/data_type.py +++ b/fast_llm/engine/config_utils/data_type.py @@ -23,8 +23,10 @@ class DataType(enum.StrEnum): int32 = "int32" int16 = "int16" int8 = "int8" - uint8 = "uint8" + uint64 = "uint64" + uint32 = "uint32" uint16 = "uint16" + uint8 = "uint8" @classmethod def _missing_(cls, dtype: str) -> "DataType": @@ -105,6 +107,9 @@ def _set_torch_dtype_map() -> None: DataType.int32: torch.int32, DataType.int16: torch.int16, DataType.int8: torch.int8, + DataType.uint64: torch.uint64, + DataType.uint32: torch.uint32, + DataType.uint16: torch.uint16, DataType.uint8: torch.uint8, } _TORCH_DTYPE_MAP_INV = {y: x for x, y in _TORCH_DTYPE_MAP.items()} @@ -127,8 +132,10 @@ def _set_numpy_dtype_map() -> None: DataType.int32: np.int32, DataType.int16: np.int16, DataType.int8: np.int8, - DataType.uint8: np.uint8, + DataType.uint64: np.uint64, + DataType.uint32: np.uint32, DataType.uint16: np.uint16, + DataType.uint8: np.uint8, } _NUMPY_DTYPE_MAP_INV = {y: x for x, y in _NUMPY_DTYPE_MAP.items()} @@ -151,6 +158,9 @@ def _set_triton_dtype_map() -> None: DataType.int32: tl.int32, DataType.int16: tl.int16, DataType.int8: tl.int8, + DataType.uint64: tl.uint64, + DataType.uint32: tl.uint32, + DataType.uint16: tl.uint16, DataType.uint8: tl.uint8, } diff --git a/fast_llm/engine/evaluation/config.py b/fast_llm/engine/evaluation/config.py index 4f035e174..f8dfd4825 100644 --- a/fast_llm/engine/evaluation/config.py +++ b/fast_llm/engine/evaluation/config.py @@ -2,6 +2,7 @@ import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.data.config import TokenizerConfig from fast_llm.engine.schedule.config import BatchConfig 
from fast_llm.utils import Assert @@ -63,6 +64,9 @@ def get_evaluator( class LmEvalEvaluatorConfig(EvaluatorConfig): _abstract: typing.ClassVar[bool] = False + tokenizer: TokenizerConfig = Field( + desc="Configuration for the tokenizer.", + ) cli_args: list[str] = Field( default_factory=lambda: [], desc="lm_eval CLI arguments, excluding those related to model, wandb, batch sizes, and device.", diff --git a/fast_llm/engine/evaluation/lm_eval/evaluator.py b/fast_llm/engine/evaluation/lm_eval/evaluator.py index 14aed65c4..5bfb544ed 100644 --- a/fast_llm/engine/evaluation/lm_eval/evaluator.py +++ b/fast_llm/engine/evaluation/lm_eval/evaluator.py @@ -60,7 +60,7 @@ def setup( self._flm_wrapper = FastLLMLmEvalWrapper( model=self._hf_model, - tokenizer=self._data.tokenizer.tokenizer, + tokenizer=self._config.tokenizer.get_tokenizer(), truncation=self._config.truncation, logits_cache=self._config.logits_cache, add_bos_token=self._config.add_bos_token, diff --git a/fast_llm/functional/dpo.py b/fast_llm/functional/dpo.py index 3a70f308f..a56af3cf8 100644 --- a/fast_llm/functional/dpo.py +++ b/fast_llm/functional/dpo.py @@ -1,11 +1,9 @@ import torch -def _compute_logprobs_for_preference_spans( +def _get_logratios( logits: torch.Tensor, targets: torch.Tensor, chosen_spans: torch.Tensor, rejected_spans: torch.Tensor ): - assert torch.all(targets < logits.size(-1)), "Target out of vocab range" - log_probs = torch.nn.functional.log_softmax(logits, dim=-1) # gather log probabilities corresponding to the target tokens @@ -21,23 +19,7 @@ def _compute_logprobs_for_preference_spans( for idx, span in enumerate(rejected_spans): rejected_logp += selected_log_probs[idx][span[0].item() : span[1].item() + 1].sum() - return chosen_logp, rejected_logp, selected_log_probs - - -def _compute_dpo_loss( - policy_chosen_logps: torch.Tensor, - policy_rejected_logps: torch.Tensor, - reference_chosen_logps: torch.Tensor, - reference_rejected_logps: torch.Tensor, - beta: float, -): - pi_logratios = 
policy_chosen_logps - policy_rejected_logps - ref_logratios = reference_chosen_logps - reference_rejected_logps - - diff_logratios = pi_logratios - ref_logratios - - losses = -torch.nn.functional.logsigmoid(beta * diff_logratios) - return losses + return chosen_logp - rejected_logp def compute_dpo_loss( @@ -53,21 +35,9 @@ def compute_dpo_loss( logits_ = logits.float().detach().requires_grad_() reference_model_logits_ = reference_model_logits.float().detach() - policy_chosen_logps, policy_rejected_logps, _ = _compute_logprobs_for_preference_spans( - logits_, targets, chosen_spans, rejected_spans - ) - - reference_chosen_logps, reference_rejected_logps, _ = _compute_logprobs_for_preference_spans( - reference_model_logits_, targets, chosen_spans, rejected_spans - ) - - losses = _compute_dpo_loss( - policy_chosen_logps=policy_chosen_logps, - policy_rejected_logps=policy_rejected_logps, - reference_chosen_logps=reference_chosen_logps, - reference_rejected_logps=reference_rejected_logps, - beta=beta, - ) + pi_logratios = _get_logratios(logits_, targets, chosen_spans, rejected_spans) + ref_logratios = _get_logratios(reference_model_logits_, targets, chosen_spans, rejected_spans) + losses = -torch.nn.functional.logsigmoid(beta * (pi_logratios - ref_logratios)) if grad_output is None: loss = None diff --git a/tests/data/common.py b/tests/data/common.py index d8cc6fff2..232ea090a 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -8,17 +8,10 @@ from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.data.gpt.data import GPTData from fast_llm.data.dataset.abstract import SampledDataset -from fast_llm.data.dataset.gpt.config import ( - GPTIndexedDatasetConfig, - GPTSampledDatasetConfig, - GPTSamplingConfig, - GPTSamplingData, - GPTSamplingParameters, - ShufflingType, -) -from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset -from fast_llm.data.dataset.gpt.sampled import GPTSampledIndexedDataset -from fast_llm.data.tokenizer import 
Tokenizer +from fast_llm.data.dataset.config import IndexedDatasetConfig, SampledDatasetConfig, SamplingConfig, ShufflingType +from fast_llm.data.dataset.gpt.config import GPTSamplingData, GPTSamplingParameters +from fast_llm.data.dataset.indexed import IndexedDataset +from fast_llm.data.dataset.sampled import SampledIndexedDataset from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.distributed.distributed import Distributed from fast_llm.models.gpt.config import GPTBatchConfig @@ -34,7 +27,6 @@ def get_sampling_data( phase=PhaseType.training, sequence_length: int = 512, vocab_size=TEST_VOCAB_SIZE, - tokenizer: Tokenizer | None = None, gpu: bool = False, shuffle: ShufflingType = ShufflingType.epoch, truncate_documents=True, @@ -42,7 +34,7 @@ def get_sampling_data( # Config with convenient defaults. distributed = Distributed(DistributedConfig(), use_cpu=True) return GPTSamplingData( - config=GPTSamplingConfig( + config=SamplingConfig( seed=seed, gpu=gpu, shuffle=shuffle, @@ -56,12 +48,11 @@ def get_sampling_data( cache_directory=cache_directory, distributed=distributed, dataset_name=phase.value, - tokenizer=tokenizer, ) -def get_dataset_config[T: GPTSampledDatasetConfig](config: dict[str, typing.Any], cls: type[T]) -> T: - dataset_config = GPTSampledDatasetConfig.from_dict(config) +def get_dataset_config[T: SampledDatasetConfig](config: dict[str, typing.Any], cls: type[T]) -> T: + dataset_config = SampledDatasetConfig.from_dict(config) Assert.custom(isinstance, dataset_config, cls) return typing.cast(cls, dataset_config) @@ -96,7 +87,7 @@ def get_test_data_and_compare_samples( expected_samples = {PhaseType.training.value.lower(): expected_samples} assert "sampling" not in config - config["sampling"] = GPTSamplingConfig(seed=seed, gpu=gpu, shuffle=shuffle) + config["sampling"] = SamplingConfig(seed=seed, gpu=gpu, shuffle=shuffle) data = GPTData(GPTDataConfig.from_dict(config), distributed_config) data.setup(distributed, 
sampling_parameters, cache_directory) with NoAutoValidate(): @@ -115,7 +106,7 @@ def get_test_data_and_compare_samples( def compare_indexed_dataset( - dataset: GPTIndexedDataset, + dataset: IndexedDataset, length: int, num_tokens: int, expected_samples: dict[int, list[int]], @@ -142,9 +133,7 @@ def compare_sampled_dataset(sampled: SampledDataset, expected_samples: list[list Assert.all_equal([sampled[i].token_ids for i in range(len(expected_samples))], expected_samples) -def validate_indexed_dataset_sampling( - sampled: GPTSampledIndexedDataset, expected_samples: list[list[int]] | None = None -): +def validate_indexed_dataset_sampling(sampled: SampledIndexedDataset, expected_samples: list[list[int]] | None = None): """ Compare `GPTSampledIndexedDataset` sampling against a more basic approach """ @@ -184,8 +173,8 @@ def validate_indexed_dataset_sampling( return token_ids -@config_class(dynamic_type={GPTSampledDatasetConfig: "mock_memmap"}) -class MockGPTMemmapDatasetConfig(GPTIndexedDatasetConfig): +@config_class(dynamic_type={SampledDatasetConfig: "mock_memmap"}) +class MockGPTMemmapDatasetConfig(IndexedDatasetConfig): _abstract: typing.ClassVar[bool] = False num_documents: int | None = Field( default=None, @@ -207,7 +196,7 @@ def num_tokens(self) -> int: return self.num_documents * self.num_tokens_per_document -class MockGPTMemmapDataset(GPTIndexedDataset): +class MockGPTMemmapDataset(IndexedDataset): def __init__(self, config: MockGPTMemmapDatasetConfig): self._config = config diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index e64b47020..cf23eede2 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from fast_llm.data.dataset.gpt.config import GPTBlendedDatasetConfig +from fast_llm.data.dataset.config import BlendedDatasetConfig from fast_llm.utils import Assert, normalize_probabilities from tests.data.common import ( compare_sampled_dataset, @@ -12,13 +12,13 @@ 
get_test_data_and_compare_samples, ) from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_CACHE, DATASET_PREFIX +from tests.utils.global_variables import DATASET_CACHE, DATASET_PATH _DATASET_PREFIX_MIX_1 = DATASET_CACHE / "blended_mix_1" / "dataset" def _get_test_dataset_mix_1(): - return get_test_dataset(prefix=_DATASET_PREFIX_MIX_1, seed=2345) + return get_test_dataset(path=_DATASET_PREFIX_MIX_1, seed=2345) def _get_blending_alt(probs: list[float], num_samples: int) -> tuple[np.ndarray, np.ndarray]: @@ -117,7 +117,7 @@ def test_gpt_blended(): { "type": "blended", "datasets": [ - {"type": "memmap", "path": DATASET_PREFIX}, + {"type": "memmap", "path": DATASET_PATH}, {"type": "memmap", "path": _DATASET_PREFIX_MIX_1}, ], "weights": [0.75, 0.25], @@ -136,7 +136,7 @@ def test_gpt_blended_data(): "training": { "type": "blended", "datasets": [ - {"type": "memmap", "path": DATASET_PREFIX}, + {"type": "memmap", "path": DATASET_PATH}, {"type": "memmap", "path": _DATASET_PREFIX_MIX_1}, ], "weights": [0.75, 0.25], @@ -156,12 +156,12 @@ def test_gpt_blended_mixed(): { "type": "blended", "datasets": [ - {"type": "memmap", "path": DATASET_PREFIX}, + {"type": "memmap", "path": DATASET_PATH}, {"type": "random"}, ], "weights": [0.6, 0.4], }, - GPTBlendedDatasetConfig, + BlendedDatasetConfig, ).build_and_sample(get_sampling_data(8, sequence_length=5)) compare_sampled_dataset(sampled, GPT_BLENDED_MIXED_SAMPLES) @@ -173,7 +173,7 @@ def test_gpt_blended_mixed_data(): "datasets": { "training": { "type": "blended", - "datasets": [{"type": "memmap", "path": DATASET_PREFIX}, {"type": "random"}], + "datasets": [{"type": "memmap", "path": DATASET_PATH}, {"type": "random"}], "weights": [0.6, 0.4], } } diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index 2c025cbaf..0691dde9d 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -8,7 +8,7 @@ ) from tests.data.test_memmap import 
MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_SAMPLES, MEMMAP_DATASET_TOKENS from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_PREFIX +from tests.utils.global_variables import DATASET_PATH GPT_CONCATENATED_SAMPLES = [ [4709, 819, 79, 207, 277, 1790], @@ -26,7 +26,7 @@ def test_gpt_concatenate(): # Make sure the dataset concatenation works and check for unintended changes in behavior. get_test_dataset() dataset = get_dataset_config( - {"type": "concatenated", "datasets": [{"type": "memmap", "path": DATASET_PREFIX} for _ in range(3)]}, + {"type": "concatenated", "datasets": [{"type": "memmap", "path": DATASET_PATH} for _ in range(3)]}, GPTConcatenatedDatasetConfig, ).build() compare_indexed_dataset( @@ -46,7 +46,7 @@ def test_gpt_concatenate_data(): "datasets": { "training": { "type": "concatenated", - "datasets": [{"type": "memmap", "path": DATASET_PREFIX} for _ in range(3)], + "datasets": [{"type": "memmap", "path": DATASET_PATH} for _ in range(3)], } } }, diff --git a/tests/data/test_dataset_from_file.py b/tests/data/test_dataset_from_file.py index c149e1395..af91df1e2 100644 --- a/tests/data/test_dataset_from_file.py +++ b/tests/data/test_dataset_from_file.py @@ -2,11 +2,11 @@ from tests.data.common import compare_indexed_dataset, get_dataset_config from tests.data.test_memmap import MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_SAMPLES, MEMMAP_DATASET_TOKENS from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_PREFIX +from tests.utils.global_variables import DATASET_PATH def test_dataset_from_file(): get_test_dataset() - dataset_config = {"type": "file", "path": str(DATASET_PREFIX.parent.joinpath("fast_llm_config.yaml"))} + dataset_config = {"type": "file", "path": str(DATASET_PATH.parent.joinpath("fast_llm_config.yaml"))} dataset = get_dataset_config(dataset_config, GPTDatasetFromFileConfig).build() compare_indexed_dataset(dataset, MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_TOKENS, 
MEMMAP_DATASET_SAMPLES) diff --git a/tests/data/test_fim.py b/tests/data/test_fim.py index c9212d6e3..1211b92d6 100644 --- a/tests/data/test_fim.py +++ b/tests/data/test_fim.py @@ -8,7 +8,7 @@ get_test_data_and_compare_samples, ) from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_PREFIX, TOKENIZER_PATH +from tests.utils.global_variables import DATASET_PATH, TOKENIZER_PATH GPT_FIM_SAMPLES = [ [4709, 819, 79, 207, 277, 1790], @@ -35,7 +35,7 @@ def test_gpt_fim(): sampled = get_dataset_config( { "type": "fim", - "dataset": {"type": "memmap", "path": DATASET_PREFIX}, + "dataset": {"type": "memmap", "path": DATASET_PATH}, "rate": 0.5, "prefix_token": "w", "middle_token": "x", @@ -54,7 +54,7 @@ def test_gpt_fim_data(): "datasets": { "training": { "type": "fim", - "dataset": {"type": "memmap", "path": DATASET_PREFIX}, + "dataset": {"type": "memmap", "path": DATASET_PATH}, "rate": 0.5, "prefix_token": "w", "middle_token": "x", diff --git a/tests/data/test_memmap.py b/tests/data/test_memmap.py index 1286bddd7..48efd174e 100644 --- a/tests/data/test_memmap.py +++ b/tests/data/test_memmap.py @@ -2,10 +2,10 @@ import pytest -from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig +from fast_llm.data.dataset.sample.config import MemmapDatasetConfig from tests.data.common import compare_indexed_dataset, get_dataset_config from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_CACHE, DATASET_PREFIX, DATASET_SAMPLING_CACHE +from tests.utils.global_variables import DATASET_CACHE, DATASET_PATH, DATASET_SAMPLING_CACHE MEMMAP_DATASET_LENGTH = 6153 MEMMAP_DATASET_TOKENS = 508327 @@ -21,7 +21,7 @@ def test_gpt_memmap(cache_directory): # Make sure the memmap dataset works and check for unintended changes in behavior. 
get_test_dataset() - dataset = get_dataset_config({"type": "memmap", "path": DATASET_PREFIX}, GPTMemmapDatasetConfig).build() + dataset = get_dataset_config({"type": "memmap", "path": DATASET_PATH}, MemmapDatasetConfig).build() compare_indexed_dataset(dataset, MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_TOKENS, MEMMAP_DATASET_SAMPLES) @@ -36,13 +36,13 @@ def test_gpt_memmap(cache_directory): def test_gpt_data_with_spans(): - get_test_dataset(prefix=_DATASET_PREFIX_SPANS, max_spans=5) + get_test_dataset(path=_DATASET_PREFIX_SPANS, max_spans=5) dataset = get_dataset_config( { "type": "memmap", "path": _DATASET_PREFIX_SPANS, }, - GPTMemmapDatasetConfig, + MemmapDatasetConfig, ).build() compare_indexed_dataset( dataset, MEMMAP_DATASET_LENGTH, MEMMAP_DATASET_TOKENS, MEMMAP_DATASET_SAMPLES, MEMMAP_DATASET_SPANS diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py index 17ba5de01..c34463f5d 100644 --- a/tests/data/test_prepare_gpt_memmap.py +++ b/tests/data/test_prepare_gpt_memmap.py @@ -6,8 +6,9 @@ import pytest from fast_llm.data.dataset.gpt.config import GPTIndexedDatasetConfig -from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset -from fast_llm.data.dataset.gpt.sampled import GPTSample +from fast_llm.data.dataset.memmap import MemmapDataset +from fast_llm.data.dataset.sample.language_model import LanguageModelReader +from fast_llm.data.dataset.sampled import GPTSample from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, GPTMemmapDatasetPreparatorConfig from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator from fast_llm.utils import Assert @@ -30,9 +31,9 @@ def get_preparator(output_path: str, dataset_path_name: str) -> GPTMemmapDataset def test_write_memmap_dataset(dtype): documents = [GPTSample(np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype)) for _ in range(100)] with tempfile.TemporaryDirectory() as temp_dir: - prefix = pathlib.Path(temp_dir) - 
GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents) - dataset = GPTMemmapDataset(name="foo", prefix=prefix) + path = pathlib.Path(temp_dir) + MemmapDataset.write_dataset(path, documents, LanguageModelReader) + dataset = MemmapDataset(name="foo", path=path) for i, document in enumerate(documents): assert np.array_equal( dataset.get(i).token_ids, document.token_ids, equal_nan=True @@ -58,9 +59,9 @@ def generate_valid_span(max_seq_length): for _ in range(num_samples) ] with tempfile.TemporaryDirectory() as temp_dir: - prefix = pathlib.Path(temp_dir) - GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents) - dataset = GPTMemmapDataset(name="foo", prefix=prefix) + path = pathlib.Path(temp_dir) + MemmapDataset.write_dataset(path, documents, LanguageModelReader) + dataset = MemmapDataset(name="foo", path=path) for i, document in enumerate(documents): dataset_item = dataset.get(i, use_preference_loss_spans=True) assert np.array_equal( diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index 6a2be3dcc..8474d8e58 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -3,9 +3,10 @@ import numpy as np import pytest -from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, ShufflingType +from fast_llm.data.dataset.config import ShufflingType +from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset -from fast_llm.data.dataset.gpt.sampled import GPTSample +from fast_llm.data.dataset.sampled import GPTSample from fast_llm.utils import Assert from tests.data.common import ( get_dataset_config, @@ -14,7 +15,7 @@ validate_indexed_dataset_sampling, ) from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_PREFIX +from tests.utils.global_variables import DATASET_PATH try: from fast_llm.csrc.data import build_padded_token_cumsum # noqa @@ -39,7 +40,7 @@ def test_gpt_sampled(): # Make sure 
the memmap dataset works and check for unintended changes in behavior. get_test_dataset() - sampled = get_dataset_config({"type": "memmap", "path": DATASET_PREFIX}, GPTMemmapDatasetConfig).build_and_sample( + sampled = get_dataset_config({"type": "memmap", "path": DATASET_PATH}, GPTMemmapDatasetConfig).build_and_sample( get_sampling_data(8, sequence_length=5) ) validate_indexed_dataset_sampling(sampled, GPT_MEMMAP_SAMPLES) @@ -52,7 +53,7 @@ def test_gpt_sampled_data(): "datasets": { "training": { "type": "memmap", - "path": DATASET_PREFIX, + "path": DATASET_PATH, } } }, diff --git a/tests/data/test_slice.py b/tests/data/test_slice.py index 1fc8df1eb..07a9bd776 100644 --- a/tests/data/test_slice.py +++ b/tests/data/test_slice.py @@ -8,7 +8,7 @@ ) from tests.data.test_memmap import MEMMAP_DATASET_SAMPLES from tests.utils.dataset import get_test_dataset -from tests.utils.global_variables import DATASET_PREFIX +from tests.utils.global_variables import DATASET_PATH GPT_SLICE_TRAINING_SAMPLES = [ [80, 268, 79, 260, 207, 3086], @@ -33,7 +33,7 @@ def test_gpt_slice(): get_test_dataset() # samples[9:18] dataset = get_dataset_config( - {"type": "slice", "dataset": {"type": "memmap", "path": DATASET_PREFIX}, "begin": 0.0015, "end": 0.003}, + {"type": "slice", "dataset": {"type": "memmap", "path": DATASET_PATH}, "begin": 0.0015, "end": 0.003}, GPTDatasetSliceConfig, ).build() compare_indexed_dataset(dataset, 9, 544, {i - 9: sample for i, sample in MEMMAP_DATASET_SAMPLES.items()}) @@ -47,19 +47,19 @@ def test_gpt_slice_data(): "datasets": { "training": { "type": "slice", - "dataset": {"type": "memmap", "path": DATASET_PREFIX}, + "dataset": {"type": "memmap", "path": DATASET_PATH}, "begin": 0, "end": 0.0015, }, "validation": { "type": "slice", - "dataset": {"type": "memmap", "path": DATASET_PREFIX}, + "dataset": {"type": "memmap", "path": DATASET_PATH}, "begin": 0.0015, "end": 0.003, }, "test": { "type": "slice", - "dataset": {"type": "memmap", "path": DATASET_PREFIX}, + 
"dataset": {"type": "memmap", "path": DATASET_PATH}, "begin": 0.003, "end": 1, }, diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index 3fae970f8..65c7587b2 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -4,7 +4,7 @@ import torch from fast_llm.functional.config import ActivationType, MLPRecomputeLevel -from fast_llm.functional.dpo import _compute_dpo_loss, _compute_logprobs_for_preference_spans +from fast_llm.functional.dpo import _compute_dpo_loss, _get_logratios from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped, torch_mlp_activation from fast_llm.functional.triton.sparse_copy import get_sparse_map from fast_llm.utils import Assert @@ -96,7 +96,7 @@ def random_split(seq_length): logits, targets, attention_mask, prompt_id_lens, packed_seq_lens ) - chosen_logps, rejected_logps, selected_log_probs = _compute_logprobs_for_preference_spans( + chosen_logps, rejected_logps, selected_log_probs = _get_logratios( logits=logits, targets=targets[:, 1:], chosen_spans=chosen_span, diff --git a/tests/models/test_match_megatron.py b/tests/models/test_match_megatron.py index 6aa541b8c..c4146e94b 100644 --- a/tests/models/test_match_megatron.py +++ b/tests/models/test_match_megatron.py @@ -7,8 +7,8 @@ from fast_llm.config import Field, FieldHint, config_class from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, GPTSampledDatasetConfig, GPTSamplingData -from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset -from fast_llm.data.dataset.gpt.sampled import GPTSample, logger +from fast_llm.data.dataset.memmap import MemmapDataset +from fast_llm.data.dataset.sampled import GPTSample, logger from fast_llm.utils import Assert from tests.utils.compare_tensor_logs import CompareConfig from tests.utils.dataset import get_model_test_dataset @@ -87,13 +87,13 @@ class 
GPTMegatronDatasetConfig(GPTMemmapDatasetConfig): hint=FieldHint.core, ) - def build(self) -> "GPTMemmapDataset": + def build(self) -> "MemmapDataset": return GPTMegatronMemmapDataset( str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens ) -class GPTMegatronMemmapDataset(GPTMemmapDataset): +class GPTMegatronMemmapDataset(MemmapDataset): def sample(self, sampling: GPTSamplingData) -> "MegatronGPTSampledIndexedDataset": return MegatronGPTSampledIndexedDataset(self, sampling) diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index 680faa931..7abf0fcdd 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -2,12 +2,14 @@ import random import numpy as np +import torch import yaml -from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset -from fast_llm.data.dataset.gpt.sampled import GPTSample +from fast_llm.data.dataset.memmap import MemmapDataset +from fast_llm.data.dataset.sample.language_model import LanguageModelReader, LanguageModelSample +from fast_llm.data.dataset.sample.token import TokenSample from tests.utils.global_variables import ( - DATASET_PREFIX, + DATASET_PATH, MODEL_DATASET_PREFIX, MODEL_TEST_VOCAB_SIZE, TEST_CHARACTERS, @@ -26,7 +28,7 @@ def download_santacoder_tokenizer(): def get_test_dataset( - prefix: pathlib.Path = DATASET_PREFIX, + path: pathlib.Path = DATASET_PATH, seed: int = 1234, num_tokens: int = TEST_DATASET_TOKENS, characters: str = TEST_CHARACTERS, @@ -35,34 +37,33 @@ def get_test_dataset( ): download_santacoder_tokenizer() - if not ( - prefix.with_suffix(".idx").is_file() - and prefix.with_suffix(".bin").is_file() - and prefix.parent.joinpath("fast_llm_config.yaml").is_file() - ): + if not (path.is_file() and path.parent.joinpath("fast_llm_config.yaml").is_file()): import transformers texts = "".join(random.Random(seed).choices(characters, k=num_tokens)).splitlines() tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH) samples = [ - 
GPTSample(np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size) for document in texts + LanguageModelSample( + TokenSample( + torch.from_numpy(np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size), + ) + ) + for document in texts ] if max_spans > 0: - lengths = np.array([max(len(sample.token_ids), 1) for sample in samples]) + lengths = np.array([max(len(sample), 1) for sample in samples]) spans = np.sort(np.random.RandomState(seed + 3847).randint(0, lengths[:, None], [len(samples), max_spans])) for sample, span in zip(samples, spans): span = np.unique(span) - sample.loss_masking_spans = span[: len(span) // 2 * 2].reshape(-1, 2) + sample.loss_masking_spans = torch.from_numpy(span[: len(span) // 2 * 2].reshape(-1, 2)) - GPTMemmapDataset.write_dataset(prefix, samples) - yaml.safe_dump( - {"type": "memmap", "path": prefix.name}, prefix.parent.joinpath("fast_llm_config.yaml").open("w") - ) + MemmapDataset.write_dataset(path, samples, LanguageModelReader) + yaml.safe_dump({"type": "memmap", "path": path.name}, path.parent.joinpath("fast_llm_config.yaml").open("w")) def get_model_test_dataset( prefix: pathlib.Path = MODEL_DATASET_PREFIX, vocab_size: int = MODEL_TEST_VOCAB_SIZE, ): - return get_test_dataset(prefix=prefix, vocab_size=vocab_size) + return get_test_dataset(path=prefix, vocab_size=vocab_size) diff --git a/tests/utils/global_variables.py b/tests/utils/global_variables.py index 42e588911..8ff6d2a9f 100644 --- a/tests/utils/global_variables.py +++ b/tests/utils/global_variables.py @@ -37,7 +37,7 @@ def set_testing_global_variables(): TOKENIZER_PATH = SHARED_RESULT_PATH / "tokenizer" TOKENIZER_FILE = TOKENIZER_PATH / "tokenizer.json" DATASET_CACHE = SHARED_RESULT_PATH / "dataset" -DATASET_PREFIX = DATASET_CACHE / "common_dataset" +DATASET_PATH = DATASET_CACHE / "common_dataset.fast_llm_dataset" DATASET_SAMPLING_CACHE = TEST_RESULTS_PATH / "dataset_sampling_cache" TEST_VOCAB_SIZE = 8192 # Random lowercase: 80.7% (3.1% each); 
space: 18.6%; doc end: 0.6% diff --git a/tools/concatenate_dataset.py b/tools/concatenate_dataset.py index bbfa4b21a..a2c957a3b 100644 --- a/tools/concatenate_dataset.py +++ b/tools/concatenate_dataset.py @@ -3,7 +3,7 @@ import pathlib from fast_llm.config import Field, config_class -from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset +from fast_llm.data.dataset.memmap import MemmapDataset from fast_llm.engine.config_utils.runnable import RunnableConfig logger = logging.getLogger(__name__) @@ -31,7 +31,7 @@ def run(self): for path in self.directory.glob("**/*.idx"): prefix = path.with_suffix("") logger.info(str(prefix)) - dataset = GPTMemmapDataset("dataset", prefix) + dataset = MemmapDataset("dataset", prefix) dataset_dict = { "prefix": str(prefix.relative_to(self.directory)), "num_documents": len(dataset), From 517a67b4f908e5772b8c9f5e9eeee4b955f05199 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 15 Oct 2025 17:04:25 -0400 Subject: [PATCH 5/7] Fix merge --- fast_llm/config.py | 21 ++++- fast_llm/data/data/abstract.py | 2 +- fast_llm/data/data/gpt/config.py | 2 +- fast_llm/data/data/gpt/data.py | 2 +- fast_llm/data/dataset/abstract.py | 20 +++-- fast_llm/data/dataset/blended.py | 41 ++++----- fast_llm/data/dataset/config.py | 43 +++++----- fast_llm/data/dataset/gpt/config.py | 20 ++--- fast_llm/data/dataset/gpt/fim.py | 19 +++-- fast_llm/data/dataset/gpt/random.py | 14 ++-- fast_llm/data/dataset/indexed.py | 84 ++++++++++--------- fast_llm/data/dataset/memmap.py | 5 +- fast_llm/data/dataset/monitor.py | 14 ++-- fast_llm/data/dataset/sample/__init__.py | 25 ------ fast_llm/data/dataset/sampled.py | 16 +--- .../data/preparator/gpt_memmap/prepare.py | 57 +++++++------ fast_llm/data/sample/__init__.py | 0 .../data/{dataset => }/sample/abstract.py | 2 +- fast_llm/data/{dataset => }/sample/config.py | 14 ++-- .../{dataset => }/sample/language_model.py | 10 +-- fast_llm/data/{dataset => }/sample/range.py | 6 +- fast_llm/data/{dataset => 
}/sample/token.py | 6 +- fast_llm/functional/dpo.py | 40 +++++++-- tests/data/test_memmap.py | 2 +- tests/data/test_prepare_gpt_memmap.py | 15 ++-- tests/utils/dataset.py | 4 +- 26 files changed, 261 insertions(+), 223 deletions(-) delete mode 100644 fast_llm/data/dataset/sample/__init__.py create mode 100644 fast_llm/data/sample/__init__.py rename fast_llm/data/{dataset => }/sample/abstract.py (93%) rename fast_llm/data/{dataset => }/sample/config.py (89%) rename fast_llm/data/{dataset => }/sample/language_model.py (92%) rename fast_llm/data/{dataset => }/sample/range.py (94%) rename fast_llm/data/{dataset => }/sample/token.py (94%) diff --git a/fast_llm/config.py b/fast_llm/config.py index 9644df9c1..658ad5666 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -492,6 +492,10 @@ def _validate_element(cls, value, type_, name: str): value = cls._validate_dict(value, type_, name) elif origin is type: value = cls._validate_type(value, type_, name) + elif issubclass(origin, Config): + # TODO: Validate arguments for config generics. 
+ cls._validate_element_type(value, type_.__origin__, strict=False) + value.validate(_is_validating=True) else: raise FieldTypeError(f"Unsupported __origin__ `{origin}`") elif not isinstance(type_, type): @@ -806,6 +810,8 @@ def _from_dict_nested(cls, value, type_, strict: bool): value = cls._from_dict_array(value, type_, strict) elif issubclass(origin, dict): value = cls._from_dict_dict(value, type_, strict) + elif issubclass(origin, Config): + value = cls._from_dict_config(value, type_, strict) elif origin is type: pass else: @@ -813,10 +819,15 @@ def _from_dict_nested(cls, value, type_, strict: bool): elif not isinstance(type_, type): raise FieldTypeError(f"Not a type: {type_}.") elif issubclass(type_, Config): - if value is MISSING: - value = {} - if isinstance(value, dict): - value = type_._from_dict(value, strict) + value = cls._from_dict_config(value, type_, strict) + return value + + @classmethod + def _from_dict_config(cls, value, type_, strict: bool): + if value is MISSING: + value = {} + if isinstance(value, dict): + value = type_._from_dict(value, strict) return value @classmethod @@ -938,6 +949,7 @@ def __init_subclass__(cls): We need to postpone validation until the class has been processed by the dataclass wrapper. """ Assert.eq(cls.__name__, cls.__qualname__) + super().__init_subclass__() for base_class in cls.__mro__: if issubclass(base_class, Config) and base_class is not cls: assert cls.__class_validated__, ( @@ -1006,6 +1018,7 @@ def __init__(self, config: ConfigType, *args, **kwargs): def __init_subclass__(cls): # Automatically set `config_class` based on the bound type. # Make sure `ConfigType` is bound and respects class hierarchy. 
+ super().__init_subclass__() try: config_class = None for base in types.get_original_bases(cls): diff --git a/fast_llm/data/data/abstract.py b/fast_llm/data/data/abstract.py index 22f5fd194..c67dc0321 100644 --- a/fast_llm/data/data/abstract.py +++ b/fast_llm/data/data/abstract.py @@ -5,7 +5,7 @@ from fast_llm.config import Configurable from fast_llm.data.data.config import DataConfig from fast_llm.data.dataset.config import SamplingParameters -from fast_llm.data.dataset.sample import Batch +from fast_llm.data.sample.abstract import Batch from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.schedule.config import BatchConfig diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index cb8cb38f6..432aa09c3 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -20,7 +20,7 @@ class GPTDataConfig(DataConfig): _abstract = False # TODO: Review field. Move closer to phase definition in training config? 
- datasets: dict[str, SampledDatasetConfig] = Field( + datasets: dict[str, SampledDatasetConfig[GPTSample]] = Field( default_factory=dict, desc="Configuration for the dataset(s).", hint=FieldHint.core, diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 1934b617b..36144105a 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -13,8 +13,8 @@ from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.gpt.config import GPTSamplingData, GPTSamplingParameters from fast_llm.data.dataset.monitor import DatasetMonitor -from fast_llm.data.dataset.sample.language_model import LanguageModelBatch from fast_llm.data.iterator import SampledDatasetIterator +from fast_llm.data.sample.language_model import LanguageModelBatch from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed diff --git a/fast_llm/data/dataset/abstract.py b/fast_llm/data/dataset/abstract.py index b470c0159..33942708b 100644 --- a/fast_llm/data/dataset/abstract.py +++ b/fast_llm/data/dataset/abstract.py @@ -1,11 +1,13 @@ import abc import typing +from fast_llm.data.sample.abstract import Sample + if typing.TYPE_CHECKING: from fast_llm.data.dataset.config import SamplingData -class Dataset(abc.ABC): +class Dataset[SampleType: Sample](abc.ABC): """ A generic dataset class compatible with torch.utils.data.Dataset but with a slightly different signature. """ @@ -17,15 +19,23 @@ def name(self) -> str: A name for the dataset to facilitate identification and debugging. """ + def __getstate__(self): + state = super().__getstate__() + # Pickling sometimes fails with bound `SampleType`. + # This is not needed at runtime, so we just drop it. 
+ if "__orig_class__" in state: + del state["__orig_class__"] + return state + -class SampledDataset(Dataset): +class SampledDataset[SampleType: Sample](Dataset[SampleType]): """ A sampled dataset class containing a prepared list of samples to be indexed sequentially (as-is) during training. (See the `Sampler` class below.) """ @abc.abstractmethod - def __getitem__(self, index: int) -> typing.Any: + def __getitem__(self, index: int) -> SampleType: pass @abc.abstractmethod @@ -33,8 +43,8 @@ def __len__(self) -> int: pass -class SamplableDataset(Dataset): +class SamplableDataset[SampleType: Sample](Dataset[SampleType]): @abc.abstractmethod - def sample(self, config: "SamplingData") -> SampledDataset: + def sample(self, config: "SamplingData") -> SampledDataset[SampleType]: pass diff --git a/fast_llm/data/dataset/blended.py b/fast_llm/data/dataset/blended.py index 24b0fa76f..264eb373d 100644 --- a/fast_llm/data/dataset/blended.py +++ b/fast_llm/data/dataset/blended.py @@ -1,16 +1,16 @@ import logging -import typing -import numpy as np +import torch from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.config import SamplingData +from fast_llm.data.sample.abstract import Sample from fast_llm.utils import Assert, normalize_probabilities logger = logging.getLogger(__name__) -class BlendedDataset(SampledDataset): +class BlendedDataset[SampleType: Sample](SampledDataset[SampleType]): """ A blended sampling of multiple sampled datasets, where each dataset is sampled with the provided probability. 
The sampling order of each dataset is respected, but there is no strict guarantee @@ -21,7 +21,7 @@ class BlendedDataset(SampledDataset): def __init__( self, name: str, - datasets: list[SampledDataset], + datasets: list[SampledDataset[SampleType]], weights: list[float], sampling_config: SamplingData, ): @@ -29,51 +29,52 @@ def __init__( assert len(datasets) > 0 Assert.eq(len(datasets), len(weights)) self._datasets = datasets - self._weights = np.array(normalize_probabilities(weights)) + self._weights = torch.from_numpy(normalize_probabilities(weights, return_array=True)) self._num_samples = sampling_config.parameters.num_samples def __len__(self) -> int: return self._num_samples - def __getitem__(self, idx: int) -> typing.Any: + def __getitem__(self, index: int) -> SampleType: """ Blending is typically done in one of the following iterative way (ex. in Megatron datasets): ```python dataset_index=np.zeros(num_samples) sample_index=np.zeros(num_samples) sampled=np.zeros(len(weights)) - for idx in range(num_samples): - error = weights * (idx + 1) - sampled + for index in range(num_samples): + error = weights * (index + 1) - sampled dataset_index_ = np.argmax(error) - dataset_index[idx] = dataset_index_ - sample_index[idx] = sampled[dataset_index_] + dataset_index[index] = dataset_index_ + sample_index[index] = sampled[dataset_index_] sampled[dataset_index_] +=1 ``` I.e. it iteratively picks samples to minimize the error `weights * sum(sampled) - sampled`. This implementation computes values on the fly instead of pre-computing them all. """ # We find the number of samples taken from each dataset prior to this point. - sampled = self._get_sampled(idx) + sampled = self._get_sampled(index) # Then get the present sample. 
- dataset_index = self._get_next_dataset(idx, sampled) - return self._datasets[dataset_index][sampled[dataset_index]] + dataset_index = self._get_next_dataset(index, sampled) + return self._datasets[dataset_index][sampled[dataset_index].item()] - def _get_sampled(self, num_samples: int): + def _get_sampled(self, num_samples: int) -> torch.Tensor: # First we determine a lower bound. # This is indeed a lower bound because a lower value for one dataset would involve more sampling below, # and it would be from that same dataset because it would have the highest error, - sampled = np.floor(self._weights * num_samples).astype(int) + + sampled = (self._weights * num_samples).to(torch.int64) # Then we sample until we reach the target number of samples. # This may not match the actual sampling order, but the final value of `sampled` is correct. - for idx in range(sampled.sum(), num_samples): - dataset_index = self._get_next_dataset(idx, sampled) + for index in range(sampled.sum().item(), num_samples): + dataset_index = self._get_next_dataset(index, sampled) sampled[dataset_index] += 1 return sampled - def _get_next_dataset(self, idx, sampled): + def _get_next_dataset(self, index: int, sampled: torch.Tensor) -> int: # The next sample is the one with the highest error. 
- return (self._weights * (idx + 1) - sampled).argmax() + return (self._weights * (index + 1) - sampled).argmax().item() @property - def name(self): + def name(self) -> str: return self._name diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 49025a60a..893a02382 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -8,6 +8,7 @@ from fast_llm.config import Config, Field, FieldHint, UpdateType, check_field, config_class from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset +from fast_llm.data.sample.abstract import Sample from fast_llm.utils import Assert, normalize_probabilities if typing.TYPE_CHECKING: @@ -97,37 +98,38 @@ def get_next_rank(self) -> int: @config_class() -class DatasetConfig(Config): +class DatasetConfig[SampleType: Sample](Config): _abstract: typing.ClassVar[bool] = True @config_class(registry=True) -class SampledDatasetConfig(DatasetConfig): +class SampledDatasetConfig[SampleType: Sample](DatasetConfig[SampleType]): """ A sampled dataset containing a prepared list of samples to be indexed sequentially (as-is) during training. """ - def build_and_sample(self, sampling: SamplingData) -> SampledDataset: + def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: + # TODO: ====== `SamplingData` contains more than needed (ex. 
`num_samples`) raise NotImplementedError() @config_class() -class SamplableDatasetConfig(SampledDatasetConfig): - def build(self) -> SamplableDataset: +class SamplableDatasetConfig[SampleType: Sample](SampledDatasetConfig[SampleType]): + def build(self) -> SamplableDataset[SampleType]: raise NotImplementedError() - def build_and_sample(self, sampling: SamplingData) -> SampledDataset: + def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: return self.build().sample(sampling) @config_class() -class IndexedDatasetConfig(SamplableDatasetConfig): - def _build(self) -> "IndexedDataset": +class IndexedDatasetConfig[SampleType: Sample](SamplableDatasetConfig[SampleType]): + def build(self) -> "IndexedDataset[SampleType]": raise NotImplementedError() @config_class(dynamic_type={SampledDatasetConfig: "concatenated"}) -class ConcatenatedDatasetConfig(SamplableDatasetConfig): +class ConcatenatedDatasetConfig[SampleType: Sample](SamplableDatasetConfig[SampleType]): """ Concatenate multiple indexed datasets as if they were one. TODO: Make a post-sampling version? (staged training) @@ -139,7 +141,7 @@ class ConcatenatedDatasetConfig(SamplableDatasetConfig): desc="The name of the dataset.", hint=FieldHint.core, ) - datasets: list[IndexedDatasetConfig] = Field( + datasets: list[IndexedDatasetConfig[SampleType]] = Field( default_factory=list, desc="The datasets to concatenate.", hint=FieldHint.core, @@ -153,7 +155,7 @@ def build(self) -> "ConcatenatedDataset": @config_class(dynamic_type={SampledDatasetConfig: "slice"}) -class DatasetSliceConfig(SamplableDatasetConfig): +class DatasetSliceConfig[SampleType: Sample](SamplableDatasetConfig[SampleType]): """ Use a fraction of an indexed dataset, specified by the range (begin, end). Typically used to subsample a dataset, or to reserve part of the dataset for validation and/or testing. 
@@ -163,7 +165,7 @@ class DatasetSliceConfig(SamplableDatasetConfig): """ _abstract = False - dataset: IndexedDatasetConfig = Field( + dataset: IndexedDatasetConfig[SampleType] = Field( default=None, desc="The dataset to split.", hint=FieldHint.core, @@ -184,8 +186,7 @@ def build(self) -> "DatasetSlice": dataset = self.dataset.build() size = len(dataset) - - return DatasetSlice( + return DatasetSlice[SampleType]( f"{dataset.name}_{self.begin}_{self.end}", dataset, round(self.begin * size), @@ -194,7 +195,7 @@ def build(self) -> "DatasetSlice": @config_class(dynamic_type={SampledDatasetConfig: "sampled"}) -class SampledDatasetUpdateConfig(SampledDatasetConfig): +class SampledDatasetUpdateConfig[SampleType: Sample](SampledDatasetConfig[SampleType]): """ Wrap a dataset to explicitly sample from it and optionally update its configuration parameters. Only explicitly set parameters (not None) will be updated, other will still be taken from `build_and_sample`'s argument. @@ -205,24 +206,24 @@ class SampledDatasetUpdateConfig(SampledDatasetConfig): desc="Optional override to sampling configuration parameters.", hint=FieldHint.core, ) - dataset: SampledDatasetConfig = Field( + dataset: SampledDatasetConfig[SampleType] = Field( desc="The dataset to sample from.", hint=FieldHint.core, ) - def build_and_sample(self, data: SamplingData) -> SampledDataset: + def build_and_sample(self, data: SamplingData) -> SampledDataset[SampleType]: return self.dataset.build_and_sample(data.update_config(self.sampling)) @config_class(dynamic_type={SampledDatasetConfig: "blended"}) -class BlendedDatasetConfig(SampledDatasetConfig): +class BlendedDatasetConfig[SampleType: Sample](SampledDatasetConfig[SampleType]): _abstract = False name: str = Field( default="blended", desc="The name of the dataset.", hint=FieldHint.core, ) - datasets: list[SampledDatasetConfig] = Field( + datasets: list[SampledDatasetConfig[SampleType]] = Field( default_factory=list, desc="The datasets to blend.", 
hint=FieldHint.core, @@ -242,7 +243,7 @@ def _validate(self) -> None: def build_and_sample( self, sampling: SamplingData, - ) -> SampledDataset: + ) -> SampledDataset[SampleType]: from fast_llm.data.dataset.blended import BlendedDataset # Build and sample the datasets. @@ -263,7 +264,7 @@ def build_and_sample( for i, (dataset, weight) in enumerate(zip(self.datasets, self.weights, strict=True)) ] # Blend the datasets. - return BlendedDataset( + return BlendedDataset[SampleType]( self.name, sampled_datasets, self.weights, diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 078df37b8..0eef93522 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -39,7 +39,7 @@ class GPTSamplingData(SamplingData): @config_class(dynamic_type={SampledDatasetConfig: "random"}) -class GPTRandomDatasetConfig(SamplableDatasetConfig): +class GPTRandomDatasetConfig[SampleType: GPTSample](SamplableDatasetConfig[SampleType]): _abstract: typing.ClassVar[bool] = False name: str = Field( default="dummy", @@ -54,7 +54,7 @@ def build(self) -> "GPTRandomDataset": @config_class(dynamic_type={SampledDatasetConfig: "file"}) -class GPTDatasetFromFileConfig(SamplableDatasetConfig): +class GPTDatasetFromFileConfig[SampleType: GPTSample](SamplableDatasetConfig[SampleType]): _abstract: typing.ClassVar[bool] = False path: pathlib.Path = Field( default=None, @@ -62,18 +62,18 @@ class GPTDatasetFromFileConfig(SamplableDatasetConfig): hint=FieldHint.core, ) - def build_and_sample(self, sampling: SamplingData) -> SampledDataset: + def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: config = self._load_config() return config.build_and_sample(sampling) - def build(self) -> SamplableDataset: + def build(self) -> SamplableDataset[SampleType]: config = self._load_config() assert isinstance(config, SamplableDatasetConfig) return config.build() - def _load_config(self): + def _load_config(self) -> 
SampledDatasetConfig[SampleType]: assert self.path.is_file(), f"File {self.path} does not exist." - return SampledDatasetConfig.from_dict(self._convert_paths(yaml.safe_load(self.path.open("r")))) + return SampledDatasetConfig[SampleType].from_dict(self._convert_paths(yaml.safe_load(self.path.open("r")))) def _convert_paths(self, config): # Recursively convert paths relative to `self.path.parent` to make them relative to cwd. @@ -164,7 +164,7 @@ class FimConfig(Config): @config_class(dynamic_type={SampledDatasetConfig: "fim"}) -class GPTFimSampledDatasetConfig(SampledDatasetConfig, FimConfig): +class GPTFimSampledDatasetConfig[SampleType: GPTSample](SampledDatasetConfig[SampleType], FimConfig): """ Configuration for FIM. """ @@ -187,7 +187,7 @@ def build_and_sample( @config_class(dynamic_type={SampledDatasetConfig: "test_slow"}) -class GPTTestSlowDatasetConfig(SampledDatasetConfig): +class GPTTestSlowDatasetConfig[SampleType: GPTSample](SampledDatasetConfig[SampleType]): """ A mock dataset that mimics a slow dataset creation on one rank, which may trigger a timeout. 
""" @@ -200,8 +200,8 @@ class GPTTestSlowDatasetConfig(SampledDatasetConfig): hint=FieldHint.core, ) - def build_and_sample(self, sampling: SamplingData) -> SampledDataset: + def build_and_sample(self, sampling: SamplingData) -> SampledDataset[SampleType]: assert sampling.distributed.config.world_size > 1 if sampling.distributed.config.rank == 0: time.sleep(self.sleep) - return GPTRandomDatasetConfig().build_and_sample(sampling) + return GPTRandomDatasetConfig[SampleType]().build_and_sample(sampling) diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index 7c86aef81..dc07e844c 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -1,12 +1,12 @@ import numpy as np +import torch from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.gpt.config import FimConfig, GPTSamplingData -from fast_llm.data.dataset.sampled import GPTSample from fast_llm.engine.distributed.config import MAX_SEED -class GPTFimDataset(SampledDataset): +class GPTFimDataset[SampleType: GPTSample](SampledDataset[SampleType]): """ An implementation of FIM (fill in the middle) post-processing of GPT datasets. Adapted from https://github.com/EleutherAI/gpt-neox/blob/FIM-clean/megatron/data/gpt2_dataset.py @@ -15,7 +15,7 @@ class GPTFimDataset(SampledDataset): def __init__( self, config: FimConfig, - dataset: SampledDataset, + dataset: SampledDataset[SampleType], sampling: GPTSamplingData, ): if sampling.parameters.use_loss_masking_spans: @@ -40,11 +40,15 @@ def __init__( def __len__(self) -> int: return len(self._dataset) - def __getitem__(self, idx: int) -> np.ndarray: - fim_token_ids = self._fim( - self._dataset[idx].token_ids, np.random.RandomState(seed=(self._seed + idx) % MAX_SEED) + def __getitem__(self, index: int) -> SampleType: + # TODO: Use torch methods to avoid back and forth. 
+ return GPTSample( + torch.from_numpy( + self._fim( + self._dataset[index].token_ids.numpy(), np.random.RandomState(seed=(self._seed + index) % MAX_SEED) + ) + ) ) - return GPTSample(fim_token_ids) @property def name(self) -> str: @@ -55,6 +59,7 @@ def _fim(self, sample: np.ndarray, np_rng: np.random.RandomState) -> np.ndarray: # TODO: permute segments in sample_list, before concatenating. sample_len = sample.shape[0] eod = self._tokenizer.eod + # TODO: Available through `tokens.lengths` segment_breaks = np.argwhere(sample == eod) # split sample by document if segment_breaks.shape != (0, 1): # then there is an EOD token in this example diff --git a/fast_llm/data/dataset/gpt/random.py b/fast_llm/data/dataset/gpt/random.py index 91fbe38c9..0901f5006 100644 --- a/fast_llm/data/dataset/gpt/random.py +++ b/fast_llm/data/dataset/gpt/random.py @@ -1,8 +1,8 @@ import numpy as np +import torch from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset from fast_llm.data.dataset.gpt.config import GPTSamplingData -from fast_llm.data.dataset.sample import LanguageModelSample class GPTRandomDataset(SamplableDataset): @@ -21,7 +21,7 @@ def name(self) -> str: return self._name -class GPTRandomSampledDataset(SampledDataset): +class GPTRandomSampledDataset[SampleType: GPTSample](SampledDataset[SampleType]): def __init__(self, sampling: GPTSamplingData, name: str): self._name = name self._seed = sampling.config.seed @@ -32,10 +32,12 @@ def __init__(self, sampling: GPTSamplingData, name: str): def __len__(self) -> int: return self._num_samples - def __getitem__(self, idx) -> np.ndarray: - return LanguageModelSample( - np.random.RandomState(self._seed + 48576439 + 74593 * idx).randint( - 0, self._vocab_size, size=(self._sequence_length + 1,), dtype=np.int64 + def __getitem__(self, index: int) -> SampleType: + return GPTSample( + torch.from_numpy( + np.random.RandomState(self._seed + 48576439 + 74593 * index).randint( + 0, self._vocab_size, size=(self._sequence_length 
+ 1,), dtype=np.int64 + ) ) ) diff --git a/fast_llm/data/dataset/indexed.py b/fast_llm/data/dataset/indexed.py index 43f675c81..c6eac9e28 100644 --- a/fast_llm/data/dataset/indexed.py +++ b/fast_llm/data/dataset/indexed.py @@ -1,30 +1,19 @@ import abc -import typing -import numpy as np import torch from fast_llm.data.dataset.abstract import SamplableDataset -from fast_llm.data.dataset.config import SamplingData -from fast_llm.data.dataset.sample import Sample +from fast_llm.data.dataset.config import SamplingData, SamplingParameters +from fast_llm.data.sample.abstract import Sample from fast_llm.utils import Assert, padded_cumsum -class IndexedDataset(SamplableDataset): +class IndexedDataset[SampleType: Sample](SamplableDataset[SampleType]): """ - A dataset containing a list of `documents`, to be merged into `samples`. + A dataset containing a list of samples. + TODO: Move sampling responsibility here? """ - @abc.abstractmethod - def get_document(self, index: int, begin: int, end: int) -> Sample: - pass - - @abc.abstractmethod - def __len__(self) -> int: - """ - Number of samples in the dataset. - """ - @abc.abstractmethod def get_document_sizes(self) -> torch.Tensor: """ @@ -39,18 +28,30 @@ def get_document_size(self, index: int) -> int: The size of a document in the dataset. """ - def sample(self, sampling: SamplingData) -> "SampledIndexedDataset": + @abc.abstractmethod + def get_document( + self, index: int, begin: int = 0, end: int | None = None, parameters: SamplingParameters | None = None + ) -> SampleType: + pass + + @abc.abstractmethod + def __len__(self) -> int: + """ + Number of samples in the dataset. 
+ """ + + def sample(self, sampling: SamplingData) -> "GPTSampledIndexedDataset": from fast_llm.data.dataset.sampled import SampledIndexedDataset return SampledIndexedDataset(self, sampling) -class DatasetSlice[IndexedDatasetType: IndexedDataset](IndexedDataset): +class DatasetSlice[SampleType: Sample](IndexedDataset[SampleType]): def __init__( self, name: str, - dataset: IndexedDataset, + dataset: IndexedDataset[SampleType], begin: int | None = None, end: int | None = None, ): @@ -67,59 +68,62 @@ def __init__( except Exception as e: raise AssertionError(f"Invalid document indices for dataset {name} with length {num_samples}") from e - def get( - self, document: int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False - ) -> typing.Any: + def get_document_sizes(self) -> torch.Tensor: + # TODO: This can be really big. + return self._dataset.get_document_sizes()[self._begin : self._end] + + def get_document_size(self, index: int) -> int: + return self._dataset.get_document_size(self._begin + index) + + def get_document( + self, index: int, begin: int = 0, end: int | None = None, parameters: SamplingParameters | None = None + ) -> SampleType: """ Get the sample (document) with the given index (in the dataset slice), - optionally sub-sampled to a specific offset (starting point) and maximum length + optionally subsampled to a specific offset (starting point) and maximum length (end = min(offset + length, sample_length). """ - return self._dataset.get(document + self._begin, offset, length, use_loss_masking_spans) + return self._dataset.get_document(index + self._begin, begin, end, parameters) def __len__(self) -> int: return self._end - self._begin - def get_document_sizes(self) -> torch.Tensor: - # TODO: This can be really big. 
- return self._dataset.get_document_sizes()[self._begin : self._end] - - def get_document_size(self, index: int) -> int: - return self._dataset.get_document_size(self._begin + index) - @property def name(self) -> str: return self._name -class ConcatenatedDataset[IndexedDatasetType: IndexedDataset](IndexedDataset): +class ConcatenatedDataset[SampleType: Sample](IndexedDataset[SampleType]): def __init__( self, name: str, - datasets: list[IndexedDataset], + datasets: list[IndexedDataset[SampleType]], ): self._name = name self._datasets = datasets sizes = [len(dataset) for dataset in self._datasets] - self._dataset_splits = padded_cumsum(sizes) + self._dataset_splits = torch.from_numpy(padded_cumsum(sizes)) def __len__(self) -> int: return self._dataset_splits[-1].item() - def get(self, index: int, *args, **kwargs): - - dataset = np.searchsorted(self._dataset_splits[1:], index, side="right") - return self._datasets[dataset].get(index - self._dataset_splits[dataset].item(), *args, **kwargs) - def get_document_sizes(self) -> torch.Tensor: # TODO: This can be really big. 
return torch.cat([dataset.get_document_sizes() for dataset in self._datasets]) def get_document_size(self, index: int) -> int: - dataset = np.searchsorted(self._dataset_splits[1:], index, side="right") + dataset = torch.searchsorted(self._dataset_splits[1:], index, side="right") return self._datasets[dataset].get_document_size(index - self._dataset_splits[dataset].item()) + def get_document( + self, index: int, begin: int = 0, end: int | None = None, parameters: SamplingParameters | None = None + ) -> SampleType: + dataset = torch.searchsorted(self._dataset_splits[1:], index, side="right") + return self._datasets[dataset].get_document( + index - self._dataset_splits[dataset].item(), begin, end, parameters + ) + @property def name(self) -> str: return self._name diff --git a/fast_llm/data/dataset/memmap.py b/fast_llm/data/dataset/memmap.py index 1228fcaea..82a45ae9e 100644 --- a/fast_llm/data/dataset/memmap.py +++ b/fast_llm/data/dataset/memmap.py @@ -6,9 +6,8 @@ import torch from fast_llm.data.dataset.indexed import IndexedDataset -from fast_llm.data.dataset.sample import Sample -from fast_llm.data.dataset.sample.abstract import MemmapReader -from fast_llm.data.dataset.sample.config import MemmapIndexDatasetReaderConfig +from fast_llm.data.sample.abstract import MemmapReader, Sample +from fast_llm.data.sample.config import MemmapIndexDatasetReaderConfig FILE_HEADER = b"fast_llm_prepared_dataset" diff --git a/fast_llm/data/dataset/monitor.py b/fast_llm/data/dataset/monitor.py index 86bc080fe..01f3195e4 100644 --- a/fast_llm/data/dataset/monitor.py +++ b/fast_llm/data/dataset/monitor.py @@ -1,8 +1,8 @@ import logging import time -import typing from fast_llm.data.dataset.abstract import SampledDataset +from fast_llm.data.sample.abstract import Sample try: from fast_llm.csrc.data import build_blending_indices # noqa @@ -14,7 +14,7 @@ logger = logging.getLogger(__name__) -class DatasetMonitor(SampledDataset): +class DatasetMonitor[SampleType: 
Sample](SampledDataset[SampleType]): """ A blended sampling of multiple sampled datasets, where each dataset is sampled with the provided probability. The sampling order of each dataset is respected, but there is no strict guarantee @@ -24,7 +24,7 @@ class DatasetMonitor(SampledDataset): def __init__( self, - dataset: SampledDataset, + dataset: SampledDataset[SampleType], data_sample_warn_time_ms: float, ): self._dataset = dataset @@ -33,19 +33,19 @@ def __init__( def __len__(self) -> int: return len(self._dataset) - def __getitem__(self, idx) -> typing.Any: + def __getitem__(self, index: int) -> SampleType: start_time = time.perf_counter() try: - sample = self._dataset[idx] + sample = self._dataset[index] sample_time = (time.perf_counter() - start_time) * 1000 if sample_time > self._data_sample_warn_time_ms: logger.warning( - f"Sample {idx} from dataset {self._dataset.name})" f" took {sample_time:,.2f} ms to load" + f"Sample {index} from dataset {self._dataset.name})" f" took {sample_time:,.2f} ms to load" ) return sample except Exception: - logger.error(f"Failed to get sample {idx} from dataset {self._dataset.name}") + logger.error(f"Failed to get sample {index} from dataset {self._dataset.name}") raise @property diff --git a/fast_llm/data/dataset/sample/__init__.py b/fast_llm/data/dataset/sample/__init__.py deleted file mode 100644 index b42cac943..000000000 --- a/fast_llm/data/dataset/sample/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -import abc -import typing - -from fast_llm.data.dataset.sample.abstract import Batch, Sample - - -class LanguageModelSample(Sample): - - @classmethod - @abc.abstractmethod - def merge_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: - pass - - @classmethod - @abc.abstractmethod - def merge_into_batch(cls, samples: typing.Iterable[typing.Self]) -> "LanguageModelBatch": - pass - - @abc.abstractmethod - def crop(self, offset: int = 0, length: int | None = None): - pass - - -class LanguageModelBatch(Batch): - 
pass diff --git a/fast_llm/data/dataset/sampled.py b/fast_llm/data/dataset/sampled.py index 493300a6b..42466293a 100644 --- a/fast_llm/data/dataset/sampled.py +++ b/fast_llm/data/dataset/sampled.py @@ -11,6 +11,7 @@ from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.config import SamplingData, ShufflingType from fast_llm.data.dataset.indexed import IndexedDataset +from fast_llm.data.sample.abstract import Sample from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.engine.config_utils.run import log_main_rank @@ -24,15 +25,6 @@ logger = logging.getLogger(__name__) -# @dataclasses.dataclass -# class GPTSample: -# token_ids: np.ndarray -# loss_masking_spans: np.ndarray | None = None -# chosen_span: np.ndarray | None = None -# rejected_span: np.ndarray | None = None -# sequence_lengths: np.ndarray | None = None - - class MemmapArray: """ An array with lazy loading in memmap mode. @@ -73,14 +65,14 @@ def _lazy_load(self): TOKEN_CUMSUM_RATE = 10 -class SampledIndexedDataset(SampledDataset): +class SampledIndexedDataset[SampleType: Sample](SampledDataset[SampleType]): """ A sampled GPT dataset. """ def __init__( self, - indexed_dataset: IndexedDataset, + indexed_dataset: IndexedDataset[SampleType], sampling: SamplingData, ): self._indexed_dataset = indexed_dataset @@ -344,7 +336,7 @@ def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: DataType) - def __len__(self) -> int: return self._parameters.num_samples - def __getitem__(self, index: int) -> typing.Any: + def __getitem__(self, index: int) -> SampleType: """ Get the sample, (fixed-length sequence of tokens holding one or more complete or partial documents) with the requested sampling index. 
diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index a4c554951..317f87dca 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -14,18 +14,17 @@ import transformers import yaml -from fast_llm.data.dataset.gpt.config import ( - GPTBlendedDatasetConfig, - GPTDatasetSliceConfig, - GPTIndexedDatasetConfig, - GPTMemmapDatasetConfig, - GPTSampledDatasetConfig, +from fast_llm.data.dataset.config import ( + BlendedDatasetConfig, + DatasetSliceConfig, + IndexedDatasetConfig, + SampledDatasetConfig, ) from fast_llm.data.dataset.memmap import MemmapDataset -from fast_llm.data.dataset.sample.language_model import LanguageModelReader -from fast_llm.data.dataset.sampled import GPTSample from fast_llm.data.preparator.config import DatasetPreparator from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, TextColumnConfig +from fast_llm.data.sample.config import MemmapDatasetConfig +from fast_llm.data.sample.language_model import LanguageModelReader from fast_llm.data.tokenizer import Tokenizer from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum @@ -38,6 +37,7 @@ class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](D _data_type: DataType _text_column: str _loss_masking_spans_column: str | None + _sample_type: typing.ClassVar[type[GPTSample]] = GPTSample def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: input_ids = [ @@ -136,15 +136,15 @@ def _tokenize_preference_batch_with_spans(self, batch: dict[str, list[typing.Any "num_tokens": num_tokens, } - def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetConfig: + def _save_shard(self, args: tuple[int, datasets.Dataset]) -> MemmapDatasetConfig: shard_idx, shard_dataset = args def 
_document_generator(): if "token_spans" in shard_dataset.column_names and self._loss_masking_spans_column is not None: for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): yield GPTSample( - np.array(item["input_ids"], dtype=self._data_type.numpy), - np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2), + torch.tensor(item["input_ids"], dtype=self._data_type.torch), + torch.tensor(item["token_spans"], dtype=torch.int32).reshape(-1, 2), ) elif ( "chosen_token_spans" in shard_dataset.column_names @@ -154,13 +154,13 @@ def _document_generator(): ): for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): yield GPTSample( - token_ids=np.array(item["input_ids"], dtype=self._data_type.numpy), - chosen_span=np.array(item["chosen_token_spans"], dtype=np.int32).reshape(-1, 2), - rejected_span=np.array(item["rejected_token_spans"], dtype=np.int32).reshape(-1, 2), + token_ids=torch.tensor(item["input_ids"], dtype=self._data_type.torch), + chosen_span=torch.tensor(item["chosen_token_spans"], dtype=torch.int32).reshape(-1, 2), + rejected_span=torch.tensor(item["rejected_token_spans"], dtype=torch.int32).reshape(-1, 2), ) else: for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield GPTSample(np.array(item["input_ids"], dtype=self._data_type.numpy)) + yield GPTSample(torch.tensor(item["input_ids"], dtype=self._data_type.torch)) MemmapDataset.write_dataset( path=self._config.output_path / f"shard_{self._config.distributed.rank}_{shard_idx}.fast_llm_dataset", @@ -168,7 +168,7 @@ def _document_generator(): reader_class=LanguageModelReader, ) - return GPTMemmapDatasetConfig.from_dict( + return MemmapDatasetConfig.from_dict( { "type": "memmap", "path": prefix, @@ -346,7 +346,7 @@ def run(self) -> None: self.generate_config_yaml_for_sharded_dst(dataset_configs) - def generate_config_yaml_for_sharded_dst(self, dataset_configs: list[GPTMemmapDatasetConfig]) -> None: + def 
generate_config_yaml_for_sharded_dst(self, dataset_configs: list[MemmapDatasetConfig]) -> None: # Gather dataset_dicts from all ranks to rank 0 if self._config.distributed.world_size > 1: if self._config.distributed.rank == 0: @@ -379,7 +379,9 @@ def generate_config_yaml_for_sharded_dst(self, dataset_configs: list[GPTMemmapDa torch.distributed.destroy_process_group() @classmethod - def _save_dataset_config(cls, dataset_config: GPTIndexedDatasetConfig, output_path: pathlib.Path) -> None: + def _save_dataset_config( + cls, dataset_config: IndexedDatasetConfig[_sample_type], output_path: pathlib.Path + ) -> None: logger.info(f"Saving config to {output_path}") yaml.safe_dump( dataset_config.to_dict(), @@ -387,10 +389,12 @@ def _save_dataset_config(cls, dataset_config: GPTIndexedDatasetConfig, output_pa ) @classmethod - def _blend_dataset_configs(cls, dataset_configs: list[GPTMemmapDatasetConfig]) -> GPTIndexedDatasetConfig: + def _blend_dataset_configs( + cls, dataset_configs: list[MemmapDatasetConfig[_sample_type]] + ) -> IndexedDatasetConfig[_sample_type]: if len(dataset_configs) == 1: return dataset_configs[0] - return GPTSampledDatasetConfig.from_dict( + return SampledDatasetConfig[cls._sample_type].from_dict( { "type": "blended", "datasets": dataset_configs, @@ -400,8 +404,11 @@ def _blend_dataset_configs(cls, dataset_configs: list[GPTMemmapDatasetConfig]) - @classmethod def _split_and_blend_dataset_configs( - cls, dataset_configs: list[GPTMemmapDatasetConfig], splits: dict[str, int | float], output_path: pathlib.Path - ) -> dict[str, GPTSampledDatasetConfig]: + cls, + dataset_configs: list[MemmapDatasetConfig[_sample_type]], + splits: dict[str, int | float], + output_path: pathlib.Path, + ) -> dict[str, SampledDatasetConfig[_sample_type]]: split_cumsum = padded_cumsum(normalize_probabilities(list(splits.values()), return_array=True)).tolist() dataset_sizes = [dataset_config.num_tokens for dataset_config in dataset_configs] dataset_probabilities = 
normalize_probabilities(dataset_sizes) @@ -430,13 +437,13 @@ def _split_and_blend_dataset_configs( # Part of the dataset belongs to the split. # TODO: Somehow getting a segfault when merging two lines below (numpy bug?). dataset = dataset_config.to_copy({"path": output_path / dataset_config.path}).build() - sizes_cumsum = dataset.get_document_sizes().cumsum() + sizes_cumsum = dataset.get_document_sizes().numpy().cumsum() Assert.eq(sizes_cumsum[-1], dataset_config.num_tokens) begin_index = _get_nearest_split(sizes_cumsum, split_begin_in_dataset * dataset_config.num_tokens) end_index = _get_nearest_split(sizes_cumsum, split_end_in_dataset * dataset_config.num_tokens) if end_index > begin_index: datasets_in_split.append( - GPTDatasetSliceConfig.from_dict( + DatasetSliceConfig[cls._sample_type].from_dict( { "type": "slice", "dataset": dataset_configs[dataset_index], @@ -458,7 +465,7 @@ def _split_and_blend_dataset_configs( elif len(datasets_in_split) == 1: dataset_splits[split_name] = datasets_in_split[0] else: - dataset_splits[split_name] = GPTBlendedDatasetConfig.from_dict( + dataset_splits[split_name] = BlendedDatasetConfig[cls._sample_type].from_dict( { "type": "blended", "datasets": datasets_in_split, diff --git a/fast_llm/data/sample/__init__.py b/fast_llm/data/sample/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/data/dataset/sample/abstract.py b/fast_llm/data/sample/abstract.py similarity index 93% rename from fast_llm/data/dataset/sample/abstract.py rename to fast_llm/data/sample/abstract.py index cc33bda48..92b24fc33 100644 --- a/fast_llm/data/dataset/sample/abstract.py +++ b/fast_llm/data/sample/abstract.py @@ -3,7 +3,7 @@ import typing from fast_llm.config import Configurable -from fast_llm.data.dataset.sample.config import MemmapIndexDatasetReaderConfig, MemmapReaderBaseConfig +from fast_llm.data.sample.config import MemmapIndexDatasetReaderConfig, MemmapReaderBaseConfig if typing.TYPE_CHECKING: import torch diff --git 
a/fast_llm/data/dataset/sample/config.py b/fast_llm/data/sample/config.py similarity index 89% rename from fast_llm/data/dataset/sample/config.py rename to fast_llm/data/sample/config.py index 61afd627c..512aae50b 100644 --- a/fast_llm/data/dataset/sample/config.py +++ b/fast_llm/data/sample/config.py @@ -9,10 +9,10 @@ if typing.TYPE_CHECKING: from fast_llm.data.dataset.memmap import MemmapDataset - from fast_llm.data.dataset.sample.abstract import MemmapIndexedDatasetReader, MemmapReader - from fast_llm.data.dataset.sample.language_model import LanguageModelReader - from fast_llm.data.dataset.sample.range import RangeReader - from fast_llm.data.dataset.sample.token import TokenReader + from fast_llm.data.sample.abstract import MemmapIndexedDatasetReader, MemmapReader + from fast_llm.data.sample.language_model import LanguageModelReader + from fast_llm.data.sample.range import RangeReader + from fast_llm.data.sample.token import TokenReader @config_class(registry=True) @@ -84,7 +84,7 @@ class RangeReaderConfig(MemmapReaderConfig): @property def reader_class(self) -> "type[RangeReader]": - from fast_llm.data.dataset.sample.range import RangeReader + from fast_llm.data.sample.range import RangeReader return RangeReader @@ -101,7 +101,7 @@ class TokenReaderConfig(MemmapIndexDatasetReaderConfig): @property def reader_class(self) -> "type[TokenReader]": - from fast_llm.data.dataset.sample.token import TokenReader + from fast_llm.data.sample.token import TokenReader return TokenReader @@ -119,7 +119,7 @@ class LanguageModelReaderConfig(MemmapIndexDatasetReaderConfig): @property def reader_class(self) -> "type[LanguageModelReader]": - from fast_llm.data.dataset.sample.language_model import LanguageModelReader + from fast_llm.data.sample.language_model import LanguageModelReader return LanguageModelReader diff --git a/fast_llm/data/dataset/sample/language_model.py b/fast_llm/data/sample/language_model.py similarity index 92% rename from 
fast_llm/data/dataset/sample/language_model.py rename to fast_llm/data/sample/language_model.py index 96f27187e..aa53ce854 100644 --- a/fast_llm/data/dataset/sample/language_model.py +++ b/fast_llm/data/sample/language_model.py @@ -3,11 +3,11 @@ import numpy as np -from fast_llm.data.dataset.sample import Batch, Sample -from fast_llm.data.dataset.sample.abstract import MemmapIndexedDatasetReader -from fast_llm.data.dataset.sample.config import LanguageModelReaderConfig, NullReaderConfig -from fast_llm.data.dataset.sample.range import RangeBatch, RangeReader, RangeSample -from fast_llm.data.dataset.sample.token import TokenBatch, TokenReader, TokenSample +from fast_llm.data.sample import Batch, Sample +from fast_llm.data.sample.abstract import MemmapIndexedDatasetReader +from fast_llm.data.sample.config import LanguageModelReaderConfig, NullReaderConfig +from fast_llm.data.sample.range import RangeBatch, RangeReader, RangeSample +from fast_llm.data.sample.token import TokenBatch, TokenReader, TokenSample from fast_llm.utils import Assert, get_unique diff --git a/fast_llm/data/dataset/sample/range.py b/fast_llm/data/sample/range.py similarity index 94% rename from fast_llm/data/dataset/sample/range.py rename to fast_llm/data/sample/range.py index def3e7986..68532e5bc 100644 --- a/fast_llm/data/dataset/sample/range.py +++ b/fast_llm/data/sample/range.py @@ -4,9 +4,9 @@ import numpy as np -from fast_llm.data.dataset.sample import Sample -from fast_llm.data.dataset.sample.abstract import MemmapReader -from fast_llm.data.dataset.sample.config import RangeReaderConfig +from fast_llm.data.sample import Sample +from fast_llm.data.sample.abstract import MemmapReader +from fast_llm.data.sample.config import RangeReaderConfig class RangeSample(Sample): diff --git a/fast_llm/data/dataset/sample/token.py b/fast_llm/data/sample/token.py similarity index 94% rename from fast_llm/data/dataset/sample/token.py rename to fast_llm/data/sample/token.py index 649f4e0eb..22ec87a10 100644 
--- a/fast_llm/data/dataset/sample/token.py +++ b/fast_llm/data/sample/token.py @@ -4,9 +4,9 @@ import numpy as np import torch -from fast_llm.data.dataset.sample import Batch, Sample -from fast_llm.data.dataset.sample.abstract import MemmapIndexedDatasetReader -from fast_llm.data.dataset.sample.config import TokenReaderConfig +from fast_llm.data.sample import Batch, Sample +from fast_llm.data.sample.abstract import MemmapIndexedDatasetReader +from fast_llm.data.sample.config import TokenReaderConfig from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert diff --git a/fast_llm/functional/dpo.py b/fast_llm/functional/dpo.py index a56af3cf8..3a70f308f 100644 --- a/fast_llm/functional/dpo.py +++ b/fast_llm/functional/dpo.py @@ -1,9 +1,11 @@ import torch -def _get_logratios( +def _compute_logprobs_for_preference_spans( logits: torch.Tensor, targets: torch.Tensor, chosen_spans: torch.Tensor, rejected_spans: torch.Tensor ): + assert torch.all(targets < logits.size(-1)), "Target out of vocab range" + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) # gather log probabilities corresponding to the target tokens @@ -19,7 +21,23 @@ def _get_logratios( for idx, span in enumerate(rejected_spans): rejected_logp += selected_log_probs[idx][span[0].item() : span[1].item() + 1].sum() - return chosen_logp - rejected_logp + return chosen_logp, rejected_logp, selected_log_probs + + +def _compute_dpo_loss( + policy_chosen_logps: torch.Tensor, + policy_rejected_logps: torch.Tensor, + reference_chosen_logps: torch.Tensor, + reference_rejected_logps: torch.Tensor, + beta: float, +): + pi_logratios = policy_chosen_logps - policy_rejected_logps + ref_logratios = reference_chosen_logps - reference_rejected_logps + + diff_logratios = pi_logratios - ref_logratios + + losses = -torch.nn.functional.logsigmoid(beta * diff_logratios) + return losses def compute_dpo_loss( @@ -35,9 +53,21 @@ def compute_dpo_loss( logits_ = 
logits.float().detach().requires_grad_() reference_model_logits_ = reference_model_logits.float().detach() - pi_logratios = _get_logratios(logits_, targets, chosen_spans, rejected_spans) - ref_logratios = _get_logratios(reference_model_logits_, targets, chosen_spans, rejected_spans) - losses = -torch.nn.functional.logsigmoid(beta * (pi_logratios - ref_logratios)) + policy_chosen_logps, policy_rejected_logps, _ = _compute_logprobs_for_preference_spans( + logits_, targets, chosen_spans, rejected_spans + ) + + reference_chosen_logps, reference_rejected_logps, _ = _compute_logprobs_for_preference_spans( + reference_model_logits_, targets, chosen_spans, rejected_spans + ) + + losses = _compute_dpo_loss( + policy_chosen_logps=policy_chosen_logps, + policy_rejected_logps=policy_rejected_logps, + reference_chosen_logps=reference_chosen_logps, + reference_rejected_logps=reference_rejected_logps, + beta=beta, + ) if grad_output is None: loss = None diff --git a/tests/data/test_memmap.py b/tests/data/test_memmap.py index 48efd174e..d6e7dc911 100644 --- a/tests/data/test_memmap.py +++ b/tests/data/test_memmap.py @@ -2,7 +2,7 @@ import pytest -from fast_llm.data.dataset.sample.config import MemmapDatasetConfig +from fast_llm.data.sample.config import MemmapDatasetConfig from tests.data.common import compare_indexed_dataset, get_dataset_config from tests.utils.dataset import get_test_dataset from tests.utils.global_variables import DATASET_CACHE, DATASET_PATH, DATASET_SAMPLING_CACHE diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py index c34463f5d..68b6c7e21 100644 --- a/tests/data/test_prepare_gpt_memmap.py +++ b/tests/data/test_prepare_gpt_memmap.py @@ -5,12 +5,11 @@ import numpy as np import pytest -from fast_llm.data.dataset.gpt.config import GPTIndexedDatasetConfig +from fast_llm.data.dataset.config import IndexedDatasetConfig from fast_llm.data.dataset.memmap import MemmapDataset -from fast_llm.data.dataset.sample.language_model 
import LanguageModelReader -from fast_llm.data.dataset.sampled import GPTSample from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, GPTMemmapDatasetPreparatorConfig from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator +from fast_llm.data.sample.language_model import LanguageModelReader from fast_llm.utils import Assert from tests.data.common import MockGPTMemmapDatasetConfig # Noqa @@ -127,7 +126,7 @@ def test_absent_metadata_local(): def test_split_dataset(): - dataset_config_0 = GPTIndexedDatasetConfig.from_dict(DATASET_DICT_0.copy()) + dataset_config_0 = IndexedDatasetConfig.from_dict(DATASET_DICT_0.copy()) config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( [dataset_config_0], {"training": 3, "validation": 1}, @@ -155,8 +154,8 @@ def test_split_dataset(): def test_split_datasets_0(): - dataset_config_0 = GPTIndexedDatasetConfig.from_dict(DATASET_DICT_0.copy()) - dataset_config_1 = GPTIndexedDatasetConfig.from_dict(DATASET_DICT_1.copy()) + dataset_config_0 = IndexedDatasetConfig.from_dict(DATASET_DICT_0.copy()) + dataset_config_1 = IndexedDatasetConfig.from_dict(DATASET_DICT_1.copy()) config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( [dataset_config_0, dataset_config_1], {"training": 1, "validation": 1}, @@ -174,8 +173,8 @@ def test_split_datasets_0(): def test_split_datasets_1(): - dataset_config_0 = GPTIndexedDatasetConfig.from_dict(DATASET_DICT_0.copy()) - dataset_config_1 = GPTIndexedDatasetConfig.from_dict(DATASET_DICT_1.copy()) + dataset_config_0 = IndexedDatasetConfig.from_dict(DATASET_DICT_0.copy()) + dataset_config_1 = IndexedDatasetConfig.from_dict(DATASET_DICT_1.copy()) config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( [dataset_config_0, dataset_config_1], {"training": 3, "validation": 1}, pathlib.Path(".") ) diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index 7abf0fcdd..3e3b19a8d 100644 --- a/tests/utils/dataset.py +++ 
b/tests/utils/dataset.py @@ -6,8 +6,8 @@ import yaml from fast_llm.data.dataset.memmap import MemmapDataset -from fast_llm.data.dataset.sample.language_model import LanguageModelReader, LanguageModelSample -from fast_llm.data.dataset.sample.token import TokenSample +from fast_llm.data.sample.language_model import LanguageModelReader, LanguageModelSample +from fast_llm.data.sample.token import TokenSample from tests.utils.global_variables import ( DATASET_PATH, MODEL_DATASET_PREFIX, From e8333a783ff6c60727d3ef9994f8e2e6968cec49 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 16 Oct 2025 23:57:05 -0400 Subject: [PATCH 6/7] Fix merge --- fast_llm/data/data/gpt/config.py | 10 +- fast_llm/data/data/gpt/data.py | 13 -- fast_llm/data/dataset/config.py | 5 +- fast_llm/data/dataset/gpt/config.py | 18 +-- fast_llm/data/dataset/gpt/fim.py | 35 ++--- fast_llm/data/dataset/gpt/random.py | 32 +++-- fast_llm/data/dataset/sampled.py | 25 ++-- fast_llm/data/sample/abstract.py | 17 ++- fast_llm/data/sample/language_model.py | 75 +++++++++-- fast_llm/data/sample/range.py | 28 ++-- fast_llm/data/sample/token.py | 43 +++++- fast_llm/engine/config_utils/data_type.py | 1 + fast_llm/functional/dpo.py | 71 +++------- fast_llm/layers/attention/attention.py | 96 ++++++------- fast_llm/layers/attention/config.py | 23 +++- fast_llm/layers/language_model/config.py | 7 + fast_llm/layers/language_model/embedding.py | 7 +- fast_llm/models/gpt/config.py | 6 - fast_llm/models/gpt/huggingface.py | 7 +- fast_llm/models/gpt/model.py | 141 ++++++-------------- fast_llm/models/gpt/trainer.py | 1 - 21 files changed, 355 insertions(+), 306 deletions(-) diff --git a/fast_llm/data/data/gpt/config.py b/fast_llm/data/data/gpt/config.py index 432aa09c3..ba5be883a 100644 --- a/fast_llm/data/data/gpt/config.py +++ b/fast_llm/data/data/gpt/config.py @@ -1,11 +1,14 @@ import logging +import typing -from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class +from fast_llm.config 
import Field, FieldHint, check_field, config_class from fast_llm.data.config import MultiprocessingContext from fast_llm.data.data.config import DataConfig -from fast_llm.data.dataset.config import SampledDatasetConfig, SamplingConfig +from fast_llm.data.dataset.config import SampledDatasetConfig from fast_llm.utils import Assert +if typing.TYPE_CHECKING: + from fast_llm.data.sample.language_model import LanguageModelSample logger = logging.getLogger(__name__) @@ -20,12 +23,11 @@ class GPTDataConfig(DataConfig): _abstract = False # TODO: Review field. Move closer to phase definition in training config? - datasets: dict[str, SampledDatasetConfig[GPTSample]] = Field( + datasets: dict[str, SampledDatasetConfig["LanguageModelSample"]] = Field( default_factory=dict, desc="Configuration for the dataset(s).", hint=FieldHint.core, ) - sampling: SamplingConfig = FieldUpdate() data_sample_warn_time_ms: float = Field( default=1000, desc="Warn if a sample takes too long to load.", diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 36144105a..de47ef761 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -1,4 +1,3 @@ -import dataclasses import logging import pathlib import typing @@ -24,20 +23,9 @@ logger = logging.getLogger(__name__) -@dataclasses.dataclass -class GPTBatch: - token_ids: torch.Tensor - loss_masking_spans: list[torch.Tensor] | None = None - sequence_lengths: list[torch.Tensor] | None = None - chosen_spans: list[torch.Tensor] | None = None - rejected_spans: list[torch.Tensor] | None = None - - class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]): """ A global class for all dataset needs, including loading, splitting, sampling and iteration. - Currently hard-coded to a GPT dataset. - TODO: Separate generic and GPT classes. 
""" _datasets: dict[str, SampledDataset] @@ -137,7 +125,6 @@ def get_iterator( num_workers=num_workers, prefetch_factor=prefetch_factor, pin_memory=True, - # TODO: ====== Make sure the samples are compatible ===== collate_fn=LanguageModelBatch.from_samples, multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, ) diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 893a02382..8f541c65f 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -59,12 +59,9 @@ class SamplingParameters: Sampling parameters set externally to the dataset and data, ex. determined by the trainer or model. """ + sequence_length: int num_samples: int - # TODO: ====== Always return sequence lengths, let the model decide. ====== - cross_document_attention: bool = True - # TODO: ====== DocumentSamplingParameter? ====== truncate_documents: bool = True - sequence_length: int # How many extra tokens to add to the sequence length. # This is used to provide labels even for the last tokens in the sequence. 
extra_tokens: int = 1 diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 0eef93522..fdd8ddff4 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -9,9 +9,11 @@ from fast_llm.data.config import TokenizerConfig from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset from fast_llm.data.dataset.config import SamplableDatasetConfig, SampledDatasetConfig, SamplingData, SamplingParameters +from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert if typing.TYPE_CHECKING: + from fast_llm.data.dataset.gpt.fim import GPTFimDataset from fast_llm.data.dataset.gpt.random import GPTRandomDataset @@ -39,7 +41,7 @@ class GPTSamplingData(SamplingData): @config_class(dynamic_type={SampledDatasetConfig: "random"}) -class GPTRandomDatasetConfig[SampleType: GPTSample](SamplableDatasetConfig[SampleType]): +class GPTRandomDatasetConfig[SampleType: LanguageModelSample](SamplableDatasetConfig[SampleType]): _abstract: typing.ClassVar[bool] = False name: str = Field( default="dummy", @@ -47,14 +49,14 @@ class GPTRandomDatasetConfig[SampleType: GPTSample](SamplableDatasetConfig[Sampl hint=FieldHint.core, ) - def build(self) -> "GPTRandomDataset": + def build(self) -> "GPTRandomDataset[SampleType]": from fast_llm.data.dataset.gpt.random import GPTRandomDataset return GPTRandomDataset(self.name) @config_class(dynamic_type={SampledDatasetConfig: "file"}) -class GPTDatasetFromFileConfig[SampleType: GPTSample](SamplableDatasetConfig[SampleType]): +class GPTDatasetFromFileConfig[SampleType: LanguageModelSample](SamplableDatasetConfig[SampleType]): _abstract: typing.ClassVar[bool] = False path: pathlib.Path = Field( default=None, @@ -164,14 +166,14 @@ class FimConfig(Config): @config_class(dynamic_type={SampledDatasetConfig: "fim"}) -class GPTFimSampledDatasetConfig[SampleType: GPTSample](SampledDatasetConfig[SampleType], FimConfig): +class 
GPTFimSampledDatasetConfig[SampleType: LanguageModelSample](SampledDatasetConfig[SampleType], FimConfig): """ Configuration for FIM. """ _abstract: typing.ClassVar[bool] = False - dataset: SampledDatasetConfig = Field( + dataset: SampledDatasetConfig[SampleType] = Field( default=None, desc="The dataset to wrap with fim.", hint=FieldHint.core, @@ -180,14 +182,14 @@ class GPTFimSampledDatasetConfig[SampleType: GPTSample](SampledDatasetConfig[Sam def build_and_sample( self, sampling: GPTSamplingData, - ) -> SampledDataset: + ) -> "GPTFimDataset[SampleType]": from fast_llm.data.dataset.gpt.fim import GPTFimDataset - return GPTFimDataset(self, self.dataset.build_and_sample(sampling), sampling) + return GPTFimDataset[SampleType](self, self.dataset.build_and_sample(sampling), sampling) @config_class(dynamic_type={SampledDatasetConfig: "test_slow"}) -class GPTTestSlowDatasetConfig[SampleType: GPTSample](SampledDatasetConfig[SampleType]): +class GPTTestSlowDatasetConfig[SampleType: LanguageModelSample](SampledDatasetConfig[SampleType]): """ A mock dataset that mimics a slow dataset creation on one rank, which may trigger a timeout. """ diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index dc07e844c..1fde74530 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -3,10 +3,12 @@ from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.gpt.config import FimConfig, GPTSamplingData +from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.data.sample.token import TokenSample from fast_llm.engine.distributed.config import MAX_SEED -class GPTFimDataset[SampleType: GPTSample](SampledDataset[SampleType]): +class GPTFimDataset[SampleType: LanguageModelSample](SampledDataset[SampleType]): """ An implementation of FIM (fill in the middle) post-processing of GPT datasets. 
Adapted from https://github.com/EleutherAI/gpt-neox/blob/FIM-clean/megatron/data/gpt2_dataset.py @@ -42,10 +44,13 @@ def __len__(self) -> int: def __getitem__(self, index: int) -> SampleType: # TODO: Use torch methods to avoid back and forth. - return GPTSample( - torch.from_numpy( - self._fim( - self._dataset[index].token_ids.numpy(), np.random.RandomState(seed=(self._seed + index) % MAX_SEED) + return LanguageModelSample( + TokenSample( + torch.from_numpy( + self._fim( + self._dataset[index].tokens.tokens.numpy(), + np.random.RandomState(seed=(self._seed + index) % MAX_SEED), + ) ) ) ) @@ -78,19 +83,19 @@ def _fim(self, sample: np.ndarray, np_rng: np.random.RandomState) -> np.ndarray: permuted = self._fim_split_and_permute_sequence(sample[curr_start_position:], np_rng) new_samples.append(permuted) - sample = np.concatenate(new_samples) + fim_sample = np.concatenate(new_samples) else: - sample = self._fim_split_and_permute_sequence(sample, np_rng) + fim_sample = self._fim_split_and_permute_sequence(sample, np_rng) # Truncate or pad sequence to max-length - diff = sample.shape[0] - sample_len + diff = fim_sample.shape[0] - sample_len if diff > 0: # too long - sample = sample[:sample_len] + fim_sample = fim_sample[:sample_len] elif diff < 0: # too short - sample = np.concatenate([sample, np.full((-1 * diff), self._pad_tok_id)]) + fim_sample = np.concatenate([fim_sample, np.full((-1 * diff), self._pad_tok_id)]) # noqa - assert sample.shape[0] == sample_len - return sample + assert fim_sample.shape[0] == sample_len + return fim_sample.astype(sample.dtype) def _fim_split_and_permute_sequence(self, sequence: np.ndarray, np_rng: np.random.RandomState) -> np.ndarray: """ @@ -163,9 +168,9 @@ def _fim_permute_sequence( middle = contents[boundaries[0] : boundaries[1]] suffix = contents[boundaries[1] :] - prefix = np.array([*self._tokenizer.tokenize(prefix, end=False)], dtype=np.int64) - middle = np.array([*self._tokenizer.tokenize(middle, begin=False, end=False)], 
dtype=np.int64) - suffix = np.array([*self._tokenizer.tokenize(suffix, begin=False)], dtype=np.int64) + prefix = np.array([*self._tokenizer.tokenize(prefix, end=False)], dtype=sequence.dtype) + middle = np.array([*self._tokenizer.tokenize(middle, begin=False, end=False)], dtype=sequence.dtype) + suffix = np.array([*self._tokenizer.tokenize(suffix, begin=False)], dtype=sequence.dtype) # here we truncate each given segment to fit the same length as it was before # A consequence is that we never reach the end of a file? diff --git a/fast_llm/data/dataset/gpt/random.py b/fast_llm/data/dataset/gpt/random.py index 0901f5006..463c5a7d6 100644 --- a/fast_llm/data/dataset/gpt/random.py +++ b/fast_llm/data/dataset/gpt/random.py @@ -3,9 +3,12 @@ from fast_llm.data.dataset.abstract import SamplableDataset, SampledDataset from fast_llm.data.dataset.gpt.config import GPTSamplingData +from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.data.sample.token import TokenSample +from fast_llm.engine.config_utils.data_type import get_unsigned_integer_type -class GPTRandomDataset(SamplableDataset): +class GPTRandomDataset[SampleType: LanguageModelSample](SamplableDataset[SampleType]): """ A dummy dataset that always returns the same random sample, for debugging purposes. """ @@ -21,23 +24,30 @@ def name(self) -> str: return self._name -class GPTRandomSampledDataset[SampleType: GPTSample](SampledDataset[SampleType]): +class GPTRandomSampledDataset[SampleType: LanguageModelSample](SampledDataset[SampleType]): def __init__(self, sampling: GPTSamplingData, name: str): self._name = name self._seed = sampling.config.seed - self._sequence_length = sampling.parameters.sequence_length - self._vocab_size = sampling.parameters.vocab_size - self._num_samples = sampling.parameters.num_samples + self._parameters = sampling.parameters + # TODO: Support? 
+ assert not self._parameters.use_loss_masking_spans + assert not self._parameters.use_preference_loss_spans + self._dtype = get_unsigned_integer_type(self._parameters.vocab_size).torch def __len__(self) -> int: - return self._num_samples + return self._parameters.num_samples def __getitem__(self, index: int) -> SampleType: - return GPTSample( - torch.from_numpy( - np.random.RandomState(self._seed + 48576439 + 74593 * index).randint( - 0, self._vocab_size, size=(self._sequence_length + 1,), dtype=np.int64 - ) + # TODO: Sample in self._dtype (breaking) + return LanguageModelSample( + TokenSample( + torch.from_numpy( + np.random.RandomState(self._seed + 48576439 + 74593 * index).randint( + 0, + self._parameters.vocab_size, + size=(self._parameters.sequence_length + self._parameters.extra_tokens,), + ) + ).to(self._dtype), ) ) diff --git a/fast_llm/data/dataset/sampled.py b/fast_llm/data/dataset/sampled.py index 42466293a..46a518cd0 100644 --- a/fast_llm/data/dataset/sampled.py +++ b/fast_llm/data/dataset/sampled.py @@ -14,6 +14,7 @@ from fast_llm.data.sample.abstract import Sample from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.engine.config_utils.run import log_main_rank +from fast_llm.utils import Assert try: from fast_llm.csrc.data import build_padded_token_cumsum # noqa @@ -67,7 +68,7 @@ def _lazy_load(self): class SampledIndexedDataset[SampleType: Sample](SampledDataset[SampleType]): """ - A sampled GPT dataset. + A sampled dataset. """ def __init__( @@ -108,11 +109,11 @@ def __init__( if sampling.distributed.config.rank == sampling.get_next_rank(): self._sample() # No barrier yet to allow running in parallel. - # There needs to be one before calling `__getitem__`, normally handled through `GPTData`. + # There needs to be one before calling `__getitem__`, normally handled through `Data`. def _sample(self) -> None: """ - Create a `GPTSampledDataset` with the requested parameters. 
+ Create a `SampledDataset` with the requested parameters. """ # Get the document sizes, the main information needed for sampling. document_sizes = self._indexed_dataset.get_document_sizes().to(self._device) @@ -340,7 +341,7 @@ def __getitem__(self, index: int) -> SampleType: """ Get the sample, (fixed-length sequence of tokens holding one or more complete or partial documents) with the requested sampling index. - The returned sample is ready to be concatenated, then fed to a `GPTModel` (see `GPTModel.preprocess`). + The returned sample is ready to be concatenated, then fed to a `Model`. """ self._lazy_load() @@ -368,7 +369,7 @@ def __getitem__(self, index: int) -> SampleType: token_count = token_start_array[token_start_cumsum_index] - documents = [] + documents: list[SampleType] = [] while token_count < token_end: # Find the document index in the dataset. if document_sampling_index < self._unshuffled_documents: @@ -388,8 +389,8 @@ def __getitem__(self, index: int) -> SampleType: # Document belongs to the next sample, need to account for padding. padding_size = self._parameters.sequence_length + 1 - tokens_in_sample if token_count > token_start: - # TODO: ====== Handle padding ====== - documents.append(PaddingSample(padding_size)) + documents.append(documents[-1].get_padding(padding_size)) + Assert.eq(token_count + padding_size, token_end) break else: # Move on to the next sample. @@ -401,18 +402,20 @@ def __getitem__(self, index: int) -> SampleType: token_start_index_in_document = max(token_start - token_count, 0) token_end_index_in_document = min(token_end - token_count, document_size) documents.append( - self._indexed_dataset.get( + self._indexed_dataset.get_document( document_index, - offset=token_start_index_in_document, - length=token_end_index_in_document - token_start_index_in_document, + begin=token_start_index_in_document, + end=token_end_index_in_document, + parameters=self._parameters, ) ) + # Go to the next document. 
document_sampling_index += 1 token_count += document_size # TODO: ====== Better way to get the class method? ====== - return documents[0].merge_documents(documents) + return documents[0].from_documents(documents) @property def name(self) -> str: diff --git a/fast_llm/data/sample/abstract.py b/fast_llm/data/sample/abstract.py index 92b24fc33..a88c3adef 100644 --- a/fast_llm/data/sample/abstract.py +++ b/fast_llm/data/sample/abstract.py @@ -16,20 +16,35 @@ def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: pass @abc.abstractmethod - def crop(self, begin: int, end: int): + def crop(self, begin: int, end: int) -> typing.Self: pass @abc.abstractmethod def __len__(self) -> int: pass + @abc.abstractmethod + def get_padding(self, size: int) -> typing.Self: + pass + class Batch(abc.ABC): + # TODO: Relate to `BatchConfig`? @classmethod @abc.abstractmethod def from_samples(cls, samples: typing.Iterable[Sample]) -> typing.Self: pass + @abc.abstractmethod + def to_samples(self) -> list[Sample]: + pass + + def crop(self, begin: int, end: int) -> typing.Self: + return self.from_samples(sample.crop(begin, end) for sample in self.to_samples()) + + def to_device_(self, device: "torch.device | str"): + pass + class MemmapReader[ConfigType: MemmapReaderBaseConfig](Configurable[ConfigType]): def __init__(self, config: ConfigType, buffer: memoryview): diff --git a/fast_llm/data/sample/language_model.py b/fast_llm/data/sample/language_model.py index aa53ce854..c3ffab038 100644 --- a/fast_llm/data/sample/language_model.py +++ b/fast_llm/data/sample/language_model.py @@ -3,8 +3,7 @@ import numpy as np -from fast_llm.data.sample import Batch, Sample -from fast_llm.data.sample.abstract import MemmapIndexedDatasetReader +from fast_llm.data.sample.abstract import Batch, MemmapIndexedDatasetReader, Sample from fast_llm.data.sample.config import LanguageModelReaderConfig, NullReaderConfig from fast_llm.data.sample.range import RangeBatch, RangeReader, RangeSample 
from fast_llm.data.sample.token import TokenBatch, TokenReader, TokenSample @@ -16,55 +15,103 @@ def __init__( self, tokens: TokenSample, loss_masking_spans: RangeSample | None = None, - preference_spans: RangeSample | None = None, + chosen_spans: RangeSample | None = None, + rejected_spans: RangeSample | None = None, ): self.tokens = tokens self.loss_masking_spans = loss_masking_spans - self.preference_spans = preference_spans + self.chosen_spans = chosen_spans + self.rejected_spans = rejected_spans @classmethod def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: return cls( - TokenSample.from_documents(document.tokens for document in documents), - _merge_optional(RangeSample.from_documents, (document.loss_masking_spans for document in documents)), - _merge_optional(RangeSample.from_documents, (document.preference_spans for document in documents)), + TokenSample.from_documents([document.tokens for document in documents]), + _merge_optional(RangeSample.from_documents, [document.loss_masking_spans for document in documents]), + _merge_optional(RangeSample.from_documents, [document.chosen_spans for document in documents]), + _merge_optional(RangeSample.from_documents, [document.rejected_spans for document in documents]), ) def crop(self, begin: int, end: int) -> typing.Self: return self.__class__( self.tokens.crop(begin, end), - None if self.loss_masking_spans is None else self.loss_masking_spans.crop(begin, end), - None if self.preference_spans is None else self.preference_spans.crop(begin, end), + _crop_optional(self.loss_masking_spans, begin, end), + _crop_optional(self.chosen_spans, begin, end), + _crop_optional(self.rejected_spans, begin, end), ) def __len__(self) -> int: return len(self.tokens) + def get_padding(self, size: int) -> typing.Self: + return LanguageModelSample( + self.tokens.get_padding(size), + None if self.loss_masking_spans is None else self.loss_masking_spans.get_padding(size), + None if self.chosen_spans is None 
else self.chosen_spans.get_padding(size), + None if self.rejected_spans is None else self.rejected_spans.get_padding(size), + ) + class LanguageModelBatch(Batch): def __init__( self, tokens: TokenBatch, loss_masking_spans: RangeBatch | None = None, - preference_spans: RangeBatch | None = None, + chosen_spans: RangeBatch | None = None, + rejected_spans: RangeBatch | None = None, ): self.tokens = tokens self.loss_masking_spans = loss_masking_spans - self.preference_spans = preference_spans + self.chosen_spans = chosen_spans + self.rejected_spans = rejected_spans @classmethod def from_samples(cls, samples: typing.Iterable[LanguageModelSample]) -> typing.Self: return cls( - TokenBatch.from_samples(sample.tokens for sample in samples), - _merge_optional(RangeBatch.from_samples, (sample.loss_masking_spans for sample in samples)), - _merge_optional(RangeBatch.from_samples, (sample.preference_spans for sample in samples)), + TokenBatch.from_samples([sample.tokens for sample in samples]), + _merge_optional(RangeBatch.from_samples, [sample.loss_masking_spans for sample in samples]), + _merge_optional(RangeBatch.from_samples, [sample.chosen_spans for sample in samples]), + _merge_optional(RangeBatch.from_samples, [sample.rejected_spans for sample in samples]), + ) + + def to_samples(self) -> list[LanguageModelSample]: + return [ + LanguageModelSample(tokens, loss_masking_spans, chosen_spans, rejected_spans) + for tokens, loss_masking_spans, chosen_spans, rejected_spans in zip( + self.tokens.to_samples(), + self.loss_masking_spans.to_samples(), + self.chosen_spans.to_samples(), + self.rejected_spans.to_samples(), + strict=True, + ) + ] + + def crop(self, begin: int, end: int) -> typing.Self: + return self.__class__( + self.tokens.crop(begin, end), + _crop_optional(self.loss_masking_spans, begin, end), + _crop_optional(self.chosen_spans, begin, end), + _crop_optional(self.rejected_spans, begin, end), ) + def to_device_(self, device: "torch.device | str"): + 
self.tokens.to_device_(device) + if self.loss_masking_spans is not None: + self.loss_masking_spans.to_device_(device) + if self.chosen_spans is not None: + self.chosen_spans.to_device_(device) + if self.rejected_spans is not None: + self.rejected_spans.to_device_(device) + def _merge_optional[T](fn: typing.Callable[[typing.Iterable], T], args: typing.Iterable) -> T | None: return None if any(arg is None for arg in args) else fn(args) +def _crop_optional[T: Sample | Batch](sample_or_batch: T, begin: int, end: int) -> T | None: + return None if sample_or_batch is None else sample_or_batch.crop(begin, end) + + class LanguageModelReader[ConfigType: LanguageModelReaderConfig](MemmapIndexedDatasetReader[ConfigType]): def __init__(self, config: ConfigType, buffer: memoryview): super().__init__(config, buffer) diff --git a/fast_llm/data/sample/range.py b/fast_llm/data/sample/range.py index 68532e5bc..224031794 100644 --- a/fast_llm/data/sample/range.py +++ b/fast_llm/data/sample/range.py @@ -1,12 +1,11 @@ -import abc import io import typing import numpy as np -from fast_llm.data.sample import Sample -from fast_llm.data.sample.abstract import MemmapReader +from fast_llm.data.sample.abstract import Batch, MemmapReader, Sample from fast_llm.data.sample.config import RangeReaderConfig +from fast_llm.utils import get_unique class RangeSample(Sample): @@ -14,9 +13,9 @@ class RangeSample(Sample): A reusable component holding a set of ranges in a sample. """ - def __init__(self, sample_size: int, ranges: tuple[tuple[int, int], ...] 
= ()): - self.sample_size = sample_size + def __init__(self, ranges: list[tuple[int, int]], sample_size: int): self.ranges = ranges + self.sample_size = sample_size @classmethod def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: @@ -27,24 +26,31 @@ def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: for begin, end in document.ranges: ranges.extend((begin + sample_size, end + sample_size)) sample_size += document.sample_size - return cls(sample_size, tuple(ranges)) + return cls(ranges, sample_size) def crop(self, begin: int, end: int) -> typing.Self: sample_size = end - begin cropped_ranges = ((max(begin_ - begin, 0), min(end_ - begin, sample_size)) for begin_, end_ in self.ranges) - return self.__class__(sample_size, tuple((begin_, end_) for begin_, end_ in cropped_ranges if end_ > begin_)) + return self.__class__([(begin_, end_) for begin_, end_ in cropped_ranges if end_ > begin_], sample_size) def __len__(self) -> int: return self.sample_size + def get_padding(self, size: int) -> typing.Self: + return RangeSample([], size) + -class RangeBatch(abc.ABC): - def __init__(self, ranges: tuple[tuple[tuple[int, int], ...], ...]): - self._ranges = ranges +class RangeBatch(Batch): + def __init__(self, ranges: list[list[tuple[int, int]]], sample_size: int): + self.sample_size = sample_size + self.ranges = ranges @classmethod def from_samples(cls, samples: typing.Iterable[RangeSample]) -> typing.Self: - return cls(tuple(sample.ranges for sample in samples)) + return cls([sample.ranges for sample in samples], get_unique(sample.sample_size for sample in samples)) + + def to_samples(self) -> list[RangeSample]: + return [RangeSample(sample_ranges, self.sample_size) for sample_ranges in self.ranges] class RangeReader[ConfigType: RangeReaderConfig](MemmapReader[ConfigType]): diff --git a/fast_llm/data/sample/token.py b/fast_llm/data/sample/token.py index 22ec87a10..d7c5c0920 100644 --- a/fast_llm/data/sample/token.py +++ 
b/fast_llm/data/sample/token.py @@ -4,8 +4,7 @@ import numpy as np import torch -from fast_llm.data.sample import Batch, Sample -from fast_llm.data.sample.abstract import MemmapIndexedDatasetReader +from fast_llm.data.sample.abstract import Batch, MemmapIndexedDatasetReader, Sample from fast_llm.data.sample.config import TokenReaderConfig from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert @@ -17,6 +16,8 @@ def __init__(self, tokens: torch.Tensor, lengths: list[int] | None = None): # Length of each document in the sample. TODO: Use cumsums instead? if lengths is None: lengths = [len(tokens)] + else: + Assert.eq(sum(lengths), len(tokens)) self.lengths = lengths @classmethod @@ -27,17 +28,35 @@ def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: ) def crop(self, begin: int, end: int) -> typing.Self: - # We only expect to crop documents, not samples. TODO: Support other cases? - Assert.eq(self.lengths, [len(self.tokens)]) - return self.__class__(self.tokens[begin:end], [end - begin]) + sample_size = end - begin + if self.lengths == [len(self.tokens)]: + # Shortcut for the frequent case of a single document. 
+ lengths = [sample_size] + else: + begin_ = 0 + lengths = [] + for length in self.lengths: + end_ = begin_ + length + cropped_length = min(end_, end) - max(begin_, begin) + if cropped_length > 0: + lengths.append(cropped_length) + if end_ > end: + break + begin_ = end_ + return self.__class__(self.tokens[begin:end], lengths) def __len__(self) -> int: return len(self.tokens) + def get_padding(self, size: int) -> typing.Self: + return TokenSample(torch.full([size], -100, dtype=self.tokens.dtype), [size]) + class TokenBatch(Batch): - def __init__(self, tokens: torch.Tensor, lengths: list[list[int]]) -> None: + def __init__(self, tokens: torch.Tensor, lengths: list[list[int]] | None) -> None: self.tokens = tokens + if lengths is None: + lengths = [[tokens.size(1)]] * tokens.size(0) self.lengths = lengths @classmethod @@ -47,6 +66,18 @@ def from_samples(cls, samples: typing.Iterable[TokenSample]) -> typing.Self: [sample.lengths for sample in samples], ) + def to_samples(self) -> list[TokenSample]: + return [TokenSample(tokens, lengths) for tokens, lengths in zip(self.tokens, self.lengths, strict=True)] + + def crop(self, begin: int, end: int) -> typing.Self: + return self.__class__( + self.tokens[:, begin:end], [sample.crop(begin, end).lengths for sample in self.to_samples()] + ) + + def to_device_(self, device: "torch.device | str"): + # Also standardize the dtype while we're here. 
+ self.tokens = self.tokens.to(device, dtype=torch.int64, non_blocking=True) + class TokenReader[ConfigType: TokenReaderConfig](MemmapIndexedDatasetReader[ConfigType]): def __init__(self, config: ConfigType, buffer: memoryview): diff --git a/fast_llm/engine/config_utils/data_type.py b/fast_llm/engine/config_utils/data_type.py index add121c50..1a0fed91b 100644 --- a/fast_llm/engine/config_utils/data_type.py +++ b/fast_llm/engine/config_utils/data_type.py @@ -168,6 +168,7 @@ def _set_triton_dtype_map() -> None: def get_unsigned_integer_type(max_size: int) -> DataType: + # TODO: Use uint types (recently added for torch, not enough methods supported yet) if max_size < 2**8: return DataType.uint8 elif max_size < 2**15: diff --git a/fast_llm/functional/dpo.py b/fast_llm/functional/dpo.py index 3a70f308f..7ab0b9ff6 100644 --- a/fast_llm/functional/dpo.py +++ b/fast_llm/functional/dpo.py @@ -1,51 +1,25 @@ import torch -def _compute_logprobs_for_preference_spans( - logits: torch.Tensor, targets: torch.Tensor, chosen_spans: torch.Tensor, rejected_spans: torch.Tensor -): - assert torch.all(targets < logits.size(-1)), "Target out of vocab range" +def _get_target_log_probabilities(logits: torch.Tensor, targets: torch.Tensor): + # Gather log probabilities corresponding to the target tokens + return torch.nn.functional.log_softmax(logits, dim=-1).gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1) - log_probs = torch.nn.functional.log_softmax(logits, dim=-1) - # gather log probabilities corresponding to the target tokens - selected_log_probs = log_probs.gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1) - - # apply chosen mask - chosen_logp = 0 - for idx, span in enumerate(chosen_spans): - chosen_logp += selected_log_probs[idx][span[0].item() : span[1].item() + 1].sum() - - # apply rejected mask - rejected_logp = 0 - for idx, span in enumerate(rejected_spans): - rejected_logp += selected_log_probs[idx][span[0].item() : span[1].item() + 1].sum() - - return chosen_logp, 
rejected_logp, selected_log_probs - - -def _compute_dpo_loss( - policy_chosen_logps: torch.Tensor, - policy_rejected_logps: torch.Tensor, - reference_chosen_logps: torch.Tensor, - reference_rejected_logps: torch.Tensor, - beta: float, -): - pi_logratios = policy_chosen_logps - policy_rejected_logps - ref_logratios = reference_chosen_logps - reference_rejected_logps - - diff_logratios = pi_logratios - ref_logratios - - losses = -torch.nn.functional.logsigmoid(beta * diff_logratios) - return losses +def _get_target_log_probability_for_spans(log_probabilities: torch.Tensor, spans: list[list[tuple[int, int]]]): + return sum( + log_probabilities[sample_index, begin:end].sum() + for sample_index, sample_spans in enumerate(spans) + for begin, end in sample_spans + ) def compute_dpo_loss( logits: torch.Tensor, targets: torch.Tensor, reference_model_logits: torch.Tensor, - chosen_spans: torch.Tensor, - rejected_spans: torch.Tensor, + chosen_spans: list[list[tuple[int, int]]], + rejected_spans: list[list[tuple[int, int]]], beta: float, grad_output: float | None, ) -> tuple[torch.Tensor, torch.Tensor]: @@ -53,21 +27,18 @@ def compute_dpo_loss( logits_ = logits.float().detach().requires_grad_() reference_model_logits_ = reference_model_logits.float().detach() - policy_chosen_logps, policy_rejected_logps, _ = _compute_logprobs_for_preference_spans( - logits_, targets, chosen_spans, rejected_spans - ) + policy_log_probabilities = _get_target_log_probabilities(logits_, targets) + policy_log_ratios = _get_target_log_probability_for_spans( + policy_log_probabilities, chosen_spans + ) - _get_target_log_probability_for_spans(policy_log_probabilities, rejected_spans) - reference_chosen_logps, reference_rejected_logps, _ = _compute_logprobs_for_preference_spans( - reference_model_logits_, targets, chosen_spans, rejected_spans - ) + reference_log_probabilities = _get_target_log_probabilities(reference_model_logits_, targets) + reference_log_ratios = 
_get_target_log_probability_for_spans( + reference_log_probabilities, chosen_spans + ) - _get_target_log_probability_for_spans(reference_log_probabilities, rejected_spans) - losses = _compute_dpo_loss( - policy_chosen_logps=policy_chosen_logps, - policy_rejected_logps=policy_rejected_logps, - reference_chosen_logps=reference_chosen_logps, - reference_rejected_logps=reference_rejected_logps, - beta=beta, - ) + # TODO: ====== Shouldn't the sigmoid be computed independently for each document? + losses = -torch.nn.functional.logsigmoid(beta * (policy_log_ratios - reference_log_ratios)) if grad_output is None: loss = None diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 167184193..ffbe9955e 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -5,11 +5,12 @@ from fast_llm.core.distributed import set_generator from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim from fast_llm.engine.base_model.config import ResourceUsageConfig +from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import wrap_forward_backward -from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs +from fast_llm.layers.attention.config import AttentionConfig, AttentionImplementation, AttentionKwargs from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.block import BlockWithBias @@ -79,7 +80,12 @@ def __init__( peft=peft, return_bias=return_bias, ) - self._use_flash_attention = self._config.do_use_flash_attention(self._distributed_config) + self._implementation = 
self._config.implementation + if self._implementation == AttentionImplementation.auto: + if _flash_available and self._distributed_config.compute_dtype in (DataType.float16, DataType.bfloat16): + self._implementation = AttentionImplementation.flash + else: + self._implementation = AttentionImplementation.backup self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) self._sequence_data_parallel_dim = self._distributed_config.get_distributed_dim( @@ -209,8 +215,7 @@ def _attn_fused( attn_weights = torch.where(mask, attn_weights, mask_value) attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(query.dtype) - with set_generator(self._distributed.tp_generator): - attn_weights = torch.dropout(attn_weights, self._config.dropout, self.training) + attn_weights = torch.dropout(attn_weights, self._config.dropout, self.training) attn_output = torch.bmm( attn_weights.view(b * self._local_head_groups, sq * self._local_heads_per_group, sk), value ) @@ -328,29 +333,10 @@ def _forward( query, key = self._rotary(query, key, kwargs) window_size = (-1, -1) if self._config.window_size is None else (self._config.window_size - 1, 0) - - if self._use_flash_attention: - assert _flash_available - with set_generator(self._distributed.tp_generator): - if (cu_seqlens_q := kwargs.get(AttentionKwargs.cu_seqlens_q, None)) is not None: - out_dims = query.size() - query = query.view(-1, query.size(-2), query.size(-1)) - key = key.view(-1, key.size(-2), key.size(-1)) - value = value.view(-1, value.size(-2), value.size(-1)) - input_ = _flash_attn_varlen_func( - query, - key, - value, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=kwargs.get(AttentionKwargs.cu_seqlens_k), - max_seqlen_q=kwargs.get(AttentionKwargs.max_seqlen_q), - max_seqlen_k=kwargs.get(AttentionKwargs.max_seqlen_k), - dropout_p=self._config.dropout if self.training else 0.0, - window_size=window_size, - causal=self._config.causal, - softmax_scale=self._softmax_scale, - 
).view(*out_dims) - else: + with set_generator(self._distributed.tp_generator): + if self._implementation == AttentionImplementation.flash: + assert _flash_available + if self._config.cross_document_attention: input_ = _flash_attn_func( query, key, @@ -359,17 +345,36 @@ def _forward( dropout_p=self._config.dropout if self.training else 0.0, causal=self._config.causal, softmax_scale=self._softmax_scale, + ).flatten(-2) + else: + input_ = ( + _flash_attn_varlen_func( + query.view(-1, query.size(-2), query.size(-1)), + key.view(-1, key.size(-2), key.size(-1)), + value.view(-1, value.size(-2), value.size(-1)), + cu_seqlens_q=kwargs.get(AttentionKwargs.cu_seqlens_q), + cu_seqlens_k=kwargs.get(AttentionKwargs.cu_seqlens_k), + max_seqlen_q=kwargs.get(AttentionKwargs.max_seqlen_q), + max_seqlen_k=kwargs.get(AttentionKwargs.max_seqlen_k), + dropout_p=self._config.dropout if self.training else 0.0, + window_size=window_size, + causal=self._config.causal, + softmax_scale=self._softmax_scale, + ) + .view(query.size()) + .flatten(-2) ) - input_ = input_.flatten(-2) - else: - # TODO: Avoid the flattens. - input_ = self._attn_fused( - query.flatten(-2), - key.flatten(-2), - value.flatten(-2), - kwargs[AttentionKwargs.attention_mask], - kwargs[AttentionKwargs.attention_mask_value], - ) + elif self._implementation == AttentionImplementation.backup: + # TODO: Avoid the flattens. 
+ input_ = self._attn_fused( + query.flatten(-2), + key.flatten(-2), + value.flatten(-2), + kwargs[AttentionKwargs.attention_mask], + kwargs[AttentionKwargs.attention_mask_value], + ) + else: + raise NotImplementedError(self._implementation) if self._debug.enabled: self._debug(query, "query", self._query_dims, kwargs) @@ -413,8 +418,9 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c attention_compute = sequence_q * sequence_k * attn_compute_base - if (not config.hardware) or self._use_flash_attention: + if (not config.hardware) or self._implementation in AttentionImplementation.flash: # Remove non-causal part. (TODO: Support non-causal) + # TODO: Compute is overestimated without cross-document attention. attention_compute -= (sequence_q * (sequence_q - 1) * attn_compute_base) // 2 if self._config.window_size is not None: @@ -439,10 +445,10 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: self._rotary.preprocess(batch, kwargs) - if not self._use_flash_attention: + if self._implementation == AttentionImplementation.backup: self._preprocess_for_backup_attention(batch, kwargs) - elif AttentionKwargs.sequence_lengths in kwargs: - self._preprocess_for_varlen(batch, kwargs) + elif self._implementation == AttentionImplementation.flash: + self._preprocess_for_flash_attention(batch, kwargs) def _preprocess_for_backup_attention(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: if ( @@ -471,11 +477,11 @@ def _preprocess_for_backup_attention(self, batch: torch.Tensor, kwargs: dict[str kwargs[AttentionKwargs.attention_mask] = self._backup_attention_mask[ None, None, sequence_k - sequence_q : sequence_k, None, :sequence_k ] - if (sequence_lengths := kwargs.get(AttentionKwargs.sequence_lengths, None)) is not None: + if not self._config.cross_document_attention: seq_ids = torch.stack( [ torch.cat([torch.full((x,), 
i) for i, x in enumerate(sample_lens)]) - for sample_lens in sequence_lengths + for sample_lens in kwargs[AttentionKwargs.sequence_lengths] ] ) document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None]).to(batch.device) @@ -485,7 +491,7 @@ def _preprocess_for_backup_attention(self, batch: torch.Tensor, kwargs: dict[str ) kwargs[AttentionKwargs.attention_mask_value] = self._backup_attention_mask_value - def _preprocess_for_varlen(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + def _preprocess_for_flash_attention(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: """ Prepares cu_seqlens_q and cu_seqlens_k for flash_attn_varlen_func: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py#L1375 @@ -495,7 +501,7 @@ def _preprocess_for_varlen(self, batch: torch.Tensor, kwargs: dict[str, typing.A also contain previous tokens from the first document in micro-sequence. We use individual sequence lengths of each document to (optionally) find the micro-sequences in the batch and compute the cumulative lengths. 
""" - if AttentionKwargs.sequence_lengths not in kwargs: + if self._config.cross_document_attention: return sequence_lengths = kwargs[AttentionKwargs.sequence_lengths] sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index 68b6dde91..206fa6e6f 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -1,10 +1,9 @@ +import enum import logging import typing import warnings from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import TritonConfig from fast_llm.layers.attention.rotary.config import RotaryConfig from fast_llm.layers.block.config import BlockKwargs @@ -32,6 +31,12 @@ class AttentionKwargs(BlockKwargs): past_key_values = "past_key_values" +class AttentionImplementation(enum.StrEnum): + auto = "auto" + flash = "flash" + backup = "backup" + + @config_class(dynamic_type={MixerConfig: "attention"}) class AttentionConfig(MixerConfig): # TODO: Make mixer class dynamic. @@ -107,6 +112,17 @@ class AttentionConfig(MixerConfig): " Under muP (if scaling number of heads instead of head_size): use 0.5.", valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) + implementation: AttentionImplementation = Field( + default=AttentionImplementation.auto, + desc="The implementation to use for the attention layer. 
Default: `flash` if supported, otherwise `backup`.", + hint=FieldHint.feature, + ) + cross_document_attention: bool = Field( + default=True, + desc="Allow for cross-document attention.", + doc="Disable to prevent attention between tokens belonging to different documents.", + hint=FieldHint.feature, + ) def _validate(self) -> None: super()._validate() @@ -121,6 +137,3 @@ def layer_class(self) -> "type[Attention]": from fast_llm.layers.attention.attention import Attention return Attention - - def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: - return self.use_flash_attention and distributed_config.compute_dtype in (DataType.float16, DataType.bfloat16) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 25fa2d91e..18c64acc4 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -53,6 +53,13 @@ class LanguageModelEmbeddingsConfig(BlockConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) + cross_document_position_embeddings: bool = Field( + default=True, + desc="Allow for cross-document position embeddings.", + doc="Disable to reset position ids at the beginning of each document.", + hint=FieldHint.feature, + ) + dropout: float = Field( default=0.0, desc="Dropout applied to the embedding layer.", diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 0ad3225c8..61ca1cfc0 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -136,9 +136,12 @@ def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None self._create_position_embeddings(kwargs[LanguageModelKwargs.sequence_length], batch.device) sequence_k = kwargs[LanguageModelKwargs.sequence_k_dim].size sequence_q = kwargs[LanguageModelKwargs.sequence_q_dim].size - if (sequence_lengths := kwargs.get(LanguageModelKwargs.sequence_lengths)) is not 
None: + if not self._config.cross_document_position_embeddings: position_ids = torch.stack( - [torch.cat([torch.arange(x) for x in sample_lens]) for sample_lens in sequence_lengths] + [ + torch.cat([torch.arange(x) for x in sample_lens]) + for sample_lens in kwargs[LanguageModelKwargs.sequence_lengths] + ] ).to(batch.device, dtype=torch.int64) position_ids = position_ids[:, sequence_k - sequence_q : sequence_k] if kwargs[LanguageModelKwargs.sequence_first]: diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index a901a0466..c1ee246f7 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -48,12 +48,6 @@ class GPTBatchConfig(BatchConfig): hint=FieldHint.performance, valid=check_field(Assert.gt, 0), ) - # TODO: Find a better place for these? - cross_document_attention: bool = Field( - default=True, - desc="Applies attention to tokens from other documents in the packed sequence. Set to False for masking attention to other documents.", - hint=FieldHint.feature, - ) use_loss_masking_spans: bool = Field( default=False, desc="Read loss masking spans from the dataset.", diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index 9215e6dc7..34e38469a 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -5,7 +5,8 @@ import torch import transformers.modeling_outputs -from fast_llm.data.data.gpt.data import GPTBatch +from fast_llm.data.sample.language_model import LanguageModelBatch +from fast_llm.data.sample.token import TokenBatch from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.inference.config import HuggingfaceModelConfig from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM @@ -80,7 +81,9 @@ def inner_forward( # Iteration serves as a random seed, using random module because it's not seeded by Fast LLM iteration = random.randint(0, 2**32) batch = self.fast_llm_base_model.preprocess_batch( - 
GPTBatch(input_ids, sequence_lengths=sequence_lenghts), phase=PhaseType.inference, iteration=iteration + LanguageModelBatch(TokenBatch(input_ids, lengths=sequence_lenghts)), + phase=PhaseType.inference, + iteration=iteration, ) ((input_, kwargs),) = batch diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index efa348ecb..3295295f6 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -3,7 +3,7 @@ import torch -from fast_llm.data.data.gpt.data import GPTBatch +from fast_llm.data.sample.language_model import LanguageModelBatch from fast_llm.engine.base_model.base_model import BaseModel from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType @@ -40,7 +40,7 @@ def __init__( param.init_parameter = get_init_megatron(param, self._config.decoder.block, config.hidden_size) # Noqa def preprocess_meta( - self, batch_meta: GPTBatchConfig | torch.Tensor, phase: PhaseType + self, batch_meta: GPTBatchConfig | LanguageModelBatch, phase: PhaseType ) -> list[tuple[TensorMeta, dict]]: # TODO Remove (Move batch splitting elsewhere) # TODO: Use parallel/sequential dims, distinguish micro and full batch/sequence @@ -51,7 +51,7 @@ def preprocess_meta( micro_sequence_length = batch_meta.micro_sequence_length truncate_documents = batch_meta.truncate_documents else: - micro_batch_size, sequence_length = batch_meta.shape + micro_batch_size, sequence_length = batch_meta.tokens.tokens.shape if phase != PhaseType.inference: sequence_length -= self._config.head.prediction_heads micro_sequence_length = sequence_length @@ -151,7 +151,7 @@ def preprocess_meta( def preprocess_batch( self, - batch: GPTBatch, + batch: LanguageModelBatch, preprocessed_meta: list[tuple[TensorMeta, dict]] | None = None, *, phase: PhaseType, @@ -161,19 +161,10 @@ def preprocess_batch( # TODO Move batch splitting elsewhere, align interface with LayerBase assert 
self._is_setup - if preprocessed_meta is None: - preprocessed_meta = self.preprocess_meta(batch.token_ids, phase) - - _, common_kwargs = preprocessed_meta[0] - sequence_q = common_kwargs[AttentionKwargs.sequence_q_dim].size - sequence_first = common_kwargs[AttentionKwargs.sequence_first] - max_prediction_distance = self._config.head.max_prediction_distance + batch.to_device_(self._distributed.device) - batch.token_ids = batch.token_ids.to( - device=self._distributed.device, - dtype=torch.int64, - non_blocking=True, - ) + if preprocessed_meta is None: + preprocessed_meta = self.preprocess_meta(batch, phase) reference_logits = [{} for _ in preprocessed_meta] for name, reference_model in self._reference_models.items(): @@ -191,103 +182,59 @@ def preprocess_batch( reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration) reference_logits[i][f"{name}_logits"] = reference_kwargs["logits"] - token_ids = batch.token_ids - if sequence_first: - # Move the sequence dimension first to make sequence parallel ops more efficient. - token_ids = token_ids.transpose(0, 1).contiguous() - preprocessed = [] presents = None for i, (_, kwargs_meta) in enumerate(preprocessed_meta): - sequence_k = kwargs_meta[AttentionKwargs.sequence_k_dim].size - if sequence_first: - tokens = token_ids[sequence_k - sequence_q : sequence_k] - else: - # TODO: Avoid multiple contiguous calls? 
- tokens = token_ids[:, sequence_k - sequence_q : sequence_k].contiguous() - if batch.sequence_lengths is not None: - kwargs_meta[AttentionKwargs.sequence_lengths] = batch.sequence_lengths - if batch.chosen_spans is not None: - kwargs_meta[LanguageModelKwargs.chosen_spans] = batch.chosen_spans - if batch.rejected_spans is not None: - kwargs_meta[LanguageModelKwargs.rejected_spans] = batch.rejected_spans + tokens_end = kwargs_meta[AttentionKwargs.sequence_k_dim].size + tokens_begin = tokens_end - kwargs_meta[AttentionKwargs.sequence_q_dim].size + cropped_tokens = batch.tokens.crop(tokens_begin, tokens_end) # TODO: Add pasts/presents to meta input? # Use lists as pointers so `past_key_values` is populated during the previous micro_sequence. pasts = presents presents = None if i == len(preprocessed_meta) - 1 else [] - kwargs = { + + kwargs: dict[str, typing.Any] = { **kwargs_meta, AttentionKwargs.past_key_values: pasts, AttentionKwargs.presents: presents, + AttentionKwargs.sequence_lengths: batch.tokens.lengths, + **reference_logits[i], } + if phase != PhaseType.inference: - sequence_offset = sequence_k - sequence_q + 1 # +1 for shift in labels - if sequence_first: - labels = token_ids[sequence_offset : sequence_k + max_prediction_distance] - else: - # TODO: Avoid multiple contiguous calls? 
- labels = token_ids[:, sequence_offset : sequence_k + max_prediction_distance].contiguous() - # We set label indices to -100 for masked spans, inline with ignore_index in torch.nn.CrossEntropyLoss - # TODO: take ignore_index from config + labels_begin = tokens_begin + 1 + labels_end = tokens_end + self._config.head.max_prediction_distance + + labels = batch.tokens.crop(labels_begin, labels_end).tokens + if batch.loss_masking_spans is not None: - # avoid changing input tokens - labels = labels.clone() - for idx, spans in enumerate(batch.loss_masking_spans): - if not spans.numel(): - continue - valid_spans = spans[ - (spans[:, 0] <= sequence_k + max_prediction_distance - 1) - & (spans[:, 1] >= sequence_offset) - ] - if valid_spans.numel(): - # if span is partially within the sequence, truncate parts of spans that are outside of the sequence - valid_spans[:, 0].clamp_(min=sequence_offset) - valid_spans[:, 1].clamp_(max=sequence_k + max_prediction_distance - 1) - valid_spans -= sequence_offset - loss_mask = torch.ones_like(labels, dtype=torch.bool) - for start, end in valid_spans: - if sequence_first: - loss_mask[start : end + 1, idx] = False - else: - loss_mask[idx, start : end + 1] = False - if self._config.output_layer.distillation_model is not None: - kwargs[LanguageModelKwargs.loss_mask] = loss_mask - labels = torch.where(loss_mask, labels, -100) - kwargs[LanguageModelKwargs.labels] = labels - kwargs.update(reference_logits[i]) + loss_masking_spans = batch.loss_masking_spans.crop(labels_begin, labels_end) + loss_mask = torch.ones_like(labels, dtype=torch.bool) + for sample_index, loss_masking_spans in enumerate(loss_masking_spans.ranges): + for begin, end in loss_masking_spans: + loss_mask[sample_index, begin:end] = False + if self._config.output_layer.distillation_model is not None: + kwargs[LanguageModelKwargs.loss_mask] = loss_mask + labels = torch.where(loss_mask, labels, -100) + + kwargs[LanguageModelKwargs.labels] = ( + labels.transpose(0, 1) if 
kwargs[AttentionKwargs.sequence_first] else labels + ).contiguous() if batch.chosen_spans is not None: - chosen_valid_spans = [] - for spans in batch.chosen_spans: - if not spans.numel(): - continue - # only keep spans within the sequence or partially within the sequence - valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)][0] - if valid_spans.numel(): - # if span is partially within the sequence, truncate parts of spans that are outside of the sequence - valid_spans[0].clamp_(min=sequence_offset) - valid_spans[1].clamp_(max=sequence_k) - valid_spans -= sequence_offset - - chosen_valid_spans.append(valid_spans) - kwargs[LanguageModelKwargs.chosen_spans] = chosen_valid_spans - - rejected_valid_spans = [] - for spans in batch.rejected_spans: - if not spans.numel(): - continue - # only keep spans within the sequence or partially within the sequence - valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)][0] - if valid_spans.numel(): - # if span is partially within the sequence, truncate parts of spans that are outside of the sequence - valid_spans[0].clamp_(min=sequence_offset) - valid_spans[1].clamp_(max=sequence_k) - valid_spans -= sequence_offset - - rejected_valid_spans.append(valid_spans) - kwargs[LanguageModelKwargs.rejected_spans] = rejected_valid_spans - + kwargs[LanguageModelKwargs.chosen_spans] = batch.chosen_spans.crop(labels_begin, labels_end).ranges + + if batch.rejected_spans is not None: + kwargs[LanguageModelKwargs.rejected_spans] = batch.rejected_spans.crop( + labels_begin, labels_end + ).ranges + + tokens = ( + cropped_tokens.tokens.transpose(0, 1) + if kwargs[AttentionKwargs.sequence_first] + else cropped_tokens.tokens + ).contiguous() self.preprocess(tokens, kwargs) preprocessed.append((tokens, kwargs)) diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 54ea13dc4..b8fb22ebb 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -27,7 +27,6 
@@ def _get_sampling_parameters( "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, # OK since DPO is not supported for MTP. "use_preference_loss_spans": getattr(self._config.model.base_model.head, "enable_dpo", False), - "cross_document_attention": self._config.batch.cross_document_attention, "truncate_documents": self._config.batch.truncate_documents, "extra_tokens": self._config.model.base_model.head.max_prediction_distance, } From 158ff64cb6044d47bab464764257316a2082f251 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 17 Oct 2025 00:16:27 -0400 Subject: [PATCH 7/7] Fix merge --- fast_llm/data/dataset/config.py | 1 - fast_llm/data/dataset/gpt/config.py | 2 +- fast_llm/data/preparator/gpt_memmap/config.py | 2 +- .../data/preparator/gpt_memmap/prepare.py | 25 ++- tests/data/common.py | 57 +++-- tests/data/test_blending.py | 13 +- tests/data/test_concatenate.py | 5 +- tests/data/test_fim.py | 6 +- tests/data/test_memmap.py | 10 +- tests/data/test_prepare_gpt_memmap.py | 81 ++++--- tests/data/test_sampling.py | 36 ++-- tests/data/test_slice.py | 5 +- tests/functional/test_functional.py | 197 +++++------------- tests/models/test_match_megatron.py | 39 ++-- tests/test_attention.py | 4 +- tests/test_config.py | 4 +- tests/utils/dataset.py | 44 ++-- tests/utils/global_variables.py | 2 +- tests/utils/model_configs.py | 10 +- 19 files changed, 249 insertions(+), 294 deletions(-) diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index 8f541c65f..20e40b66e 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -33,7 +33,6 @@ class SamplingConfig(Config): A dataset-dependent configuration for sampling. """ - # TODO: ====== DocumentSamplingConfig? 
====== seed: int = Field( default=784569, desc="Seed for random sampling.", diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index fdd8ddff4..2124fedd2 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -52,7 +52,7 @@ class GPTRandomDatasetConfig[SampleType: LanguageModelSample](SamplableDatasetCo def build(self) -> "GPTRandomDataset[SampleType]": from fast_llm.data.dataset.gpt.random import GPTRandomDataset - return GPTRandomDataset(self.name) + return GPTRandomDataset[SampleType](self.name) @config_class(dynamic_type={SampledDatasetConfig: "file"}) diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index d16181ee0..d2aaee5e2 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -22,7 +22,7 @@ 8: DataType.uint16, } MEMMAP_DTYPES_INV = {y: x for x, y in MEMMAP_DTYPES.items()} -MEMMAP_INDEX_HEADER = b"MMIDIDX\x00\x01" +MEMMAP_INDEX_HEADER = b"MMIDIDX\x00\x00" @config_class(registry=True) diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 317f87dca..a4200a8f5 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -24,7 +24,7 @@ from fast_llm.data.preparator.config import DatasetPreparator from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, TextColumnConfig from fast_llm.data.sample.config import MemmapDatasetConfig -from fast_llm.data.sample.language_model import LanguageModelReader +from fast_llm.data.sample.language_model import LanguageModelReader, LanguageModelSample from fast_llm.data.tokenizer import Tokenizer from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum @@ -37,7 +37,7 @@ class 
GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](D _data_type: DataType _text_column: str _loss_masking_spans_column: str | None - _sample_type: typing.ClassVar[type[GPTSample]] = GPTSample + _sample_type: typing.ClassVar[type[LanguageModelSample]] = LanguageModelSample def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: input_ids = [ @@ -140,11 +140,14 @@ def _save_shard(self, args: tuple[int, datasets.Dataset]) -> MemmapDatasetConfig shard_idx, shard_dataset = args def _document_generator(): + # TODO: ====== Yield `LanguageModelSample` ====== if "token_spans" in shard_dataset.column_names and self._loss_masking_spans_column is not None: for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield GPTSample( + yield ( torch.tensor(item["input_ids"], dtype=self._data_type.torch), torch.tensor(item["token_spans"], dtype=torch.int32).reshape(-1, 2), + None, + None, ) elif ( "chosen_token_spans" in shard_dataset.column_names @@ -153,14 +156,20 @@ def _document_generator(): and self._config.dataset.rejected_text is not None ): for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield GPTSample( - token_ids=torch.tensor(item["input_ids"], dtype=self._data_type.torch), - chosen_span=torch.tensor(item["chosen_token_spans"], dtype=torch.int32).reshape(-1, 2), - rejected_span=torch.tensor(item["rejected_token_spans"], dtype=torch.int32).reshape(-1, 2), + yield ( + torch.tensor(item["input_ids"], dtype=self._data_type.torch), + None, + torch.tensor(item["chosen_token_spans"], dtype=torch.int32).reshape(-1, 2), + torch.tensor(item["rejected_token_spans"], dtype=torch.int32).reshape(-1, 2), ) else: for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): - yield GPTSample(torch.tensor(item["input_ids"], dtype=self._data_type.torch)) + yield ( + torch.tensor(item["input_ids"], dtype=self._data_type.torch), + None, + 
None, + None, + ) MemmapDataset.write_dataset( path=self._config.output_path / f"shard_{self._config.distributed.rank}_{shard_idx}.fast_llm_dataset", diff --git a/tests/data/common.py b/tests/data/common.py index 232ea090a..e6ab8a265 100644 --- a/tests/data/common.py +++ b/tests/data/common.py @@ -8,10 +8,17 @@ from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.data.gpt.data import GPTData from fast_llm.data.dataset.abstract import SampledDataset -from fast_llm.data.dataset.config import IndexedDatasetConfig, SampledDatasetConfig, SamplingConfig, ShufflingType +from fast_llm.data.dataset.config import ( + IndexedDatasetConfig, + SampledDatasetConfig, + SamplingConfig, + SamplingParameters, + ShufflingType, +) from fast_llm.data.dataset.gpt.config import GPTSamplingData, GPTSamplingParameters from fast_llm.data.dataset.indexed import IndexedDataset from fast_llm.data.dataset.sampled import SampledIndexedDataset +from fast_llm.data.sample.abstract import Sample from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.distributed.distributed import Distributed from fast_llm.models.gpt.config import GPTBatchConfig @@ -53,7 +60,7 @@ def get_sampling_data( def get_dataset_config[T: SampledDatasetConfig](config: dict[str, typing.Any], cls: type[T]) -> T: dataset_config = SampledDatasetConfig.from_dict(config) - Assert.custom(isinstance, dataset_config, cls) + Assert.custom(isinstance, dataset_config, getattr(cls, "__origin__", cls)) return typing.cast(cls, dataset_config) @@ -96,12 +103,15 @@ def get_test_data_and_compare_samples( batch_config.validate() tokens = { phase: torch.stack( - [batch.token_ids[0] for batch in data.get_iterator(batch_config, phase, consumed_samples=0, num_workers=0)] + [ + batch.tokens.tokens[0] + for batch in data.get_iterator(batch_config, phase, consumed_samples=0, num_workers=0) + ] ) for phase, samples in samples_per_dataset.items() } for phase, expected_samples_ in 
expected_samples.items(): - Assert.all_equal(tokens[phase], expected_samples_) + Assert.all_equal(tokens[phase].to(torch.int64), expected_samples_) return data @@ -116,21 +126,30 @@ def compare_indexed_dataset( sizes = dataset.get_document_sizes() # Assert.eq(sizes.sum(), num_tokens) Assert.all_equal( - [len(dataset.get(i).token_ids) for i in range(min(len(dataset), 100))], sizes[: min(len(dataset), 100)] + [len(dataset.get_document(i).tokens.tokens) for i in range(min(len(dataset), 100))], + sizes[: min(len(dataset), 100)], ) for i, expected_sample in expected_samples.items(): - Assert.all_equal(dataset.get(i).token_ids, np.array(expected_sample, dtype=np.uint16)) + Assert.all_equal(dataset.get_document(i).tokens.tokens, np.array(expected_sample, dtype=np.int64)) if loss_masking_spans: for i, loss_masking_span in loss_masking_spans.items(): - Assert.all_equal( - dataset.get(i, use_loss_masking_spans=True).loss_masking_spans, - np.array(loss_masking_spans[i], dtype=np.int32).reshape(-1, 2), + print(i) + Assert.eq( + dataset.get_document( + i, + parameters=GPTSamplingParameters( + num_samples=0, sequence_length=0, vocab_size=0, use_loss_masking_spans=True + ), + ).loss_masking_spans.ranges, + loss_masking_spans[i], ) def compare_sampled_dataset(sampled: SampledDataset, expected_samples: list[list[int] | np.ndarray]) -> None: Assert.eq(len(sampled), len(expected_samples)) - Assert.all_equal([sampled[i].token_ids for i in range(len(expected_samples))], expected_samples) + Assert.all_equal( + torch.stack([sampled[i].tokens.tokens for i in range(len(expected_samples))]).to(torch.int64), expected_samples + ) def validate_indexed_dataset_sampling(sampled: SampledIndexedDataset, expected_samples: list[list[int]] | None = None): @@ -154,7 +173,7 @@ def validate_indexed_dataset_sampling(sampled: SampledIndexedDataset, expected_s ) seen_tokens = 0 for document_index in document_sampling: - document = sampled._indexed_dataset.get(document_index).token_ids + document = 
sampled._indexed_dataset.get_document(document_index).tokens.tokens all_tokens[seen_tokens : seen_tokens + len(document)] = document[: num_tokens - seen_tokens] seen_tokens += len(document) @@ -165,7 +184,7 @@ def validate_indexed_dataset_sampling(sampled: SampledIndexedDataset, expected_s all_tokens[index * sampled._parameters.sequence_length : (index + 1) * sampled._parameters.sequence_length + 1] for index in range(sampled._parameters.num_samples) ] - token_ids = [sampled[i].token_ids for i in range(len(sampled))] + token_ids = torch.stack([sampled[i].tokens.tokens for i in range(len(sampled))]).to(torch.int64) Assert.all_equal(token_ids, validate_samples) if expected_samples is not None: @@ -188,15 +207,15 @@ class MockGPTMemmapDatasetConfig(IndexedDatasetConfig): ) path: pathlib.Path = Field(default=".") - def build(self) -> "GPTIndexedDataset": - return MockGPTMemmapDataset(self) + def build(self) -> "IndexedDataset": + return MockMemmapDataset(self) @property def num_tokens(self) -> int: return self.num_documents * self.num_tokens_per_document -class MockGPTMemmapDataset(IndexedDataset): +class MockMemmapDataset[SampleType: Sample](IndexedDataset[SampleType]): def __init__(self, config: MockGPTMemmapDatasetConfig): self._config = config @@ -207,11 +226,13 @@ def name(self) -> str: def __len__(self) -> int: return self._config.num_documents - def get_document_sizes(self) -> np.ndarray: - return np.full(self._config.num_documents, self._config.num_tokens_per_document, dtype=np.int64) + def get_document_sizes(self) -> torch.Tensor: + return torch.full([self._config.num_documents], self._config.num_tokens_per_document, dtype=torch.int64) def get_document_size(self, index: int) -> int: return self._config.num_tokens_per_document - def get(self, index: int, *args, **kwargs) -> typing.Any: + def get_document( + self, index: int, begin: int = 0, end: int | None = None, parameters: SamplingParameters | None = None + ) -> SampleType: raise NotImplementedError() diff 
--git a/tests/data/test_blending.py b/tests/data/test_blending.py index cf23eede2..a441cc9b7 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -4,6 +4,7 @@ import pytest from fast_llm.data.dataset.config import BlendedDatasetConfig +from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert, normalize_probabilities from tests.data.common import ( compare_sampled_dataset, @@ -14,11 +15,11 @@ from tests.utils.dataset import get_test_dataset from tests.utils.global_variables import DATASET_CACHE, DATASET_PATH -_DATASET_PREFIX_MIX_1 = DATASET_CACHE / "blended_mix_1" / "dataset" +_DATASET_PATH_MIX_1 = DATASET_CACHE / "blended_mix_1" / "dataset.fast_llm_dataset" def _get_test_dataset_mix_1(): - return get_test_dataset(path=_DATASET_PREFIX_MIX_1, seed=2345) + return get_test_dataset(path=_DATASET_PATH_MIX_1, seed=2345) def _get_blending_alt(probs: list[float], num_samples: int) -> tuple[np.ndarray, np.ndarray]: @@ -118,11 +119,11 @@ def test_gpt_blended(): "type": "blended", "datasets": [ {"type": "memmap", "path": DATASET_PATH}, - {"type": "memmap", "path": _DATASET_PREFIX_MIX_1}, + {"type": "memmap", "path": _DATASET_PATH_MIX_1}, ], "weights": [0.75, 0.25], }, - GPTBlendedDatasetConfig, + BlendedDatasetConfig[LanguageModelSample], ).build_and_sample(get_sampling_data(8, sequence_length=5)) compare_sampled_dataset(sampled, GPT_BLENDED_SAMPLES) @@ -137,7 +138,7 @@ def test_gpt_blended_data(): "type": "blended", "datasets": [ {"type": "memmap", "path": DATASET_PATH}, - {"type": "memmap", "path": _DATASET_PREFIX_MIX_1}, + {"type": "memmap", "path": _DATASET_PATH_MIX_1}, ], "weights": [0.75, 0.25], } @@ -161,7 +162,7 @@ def test_gpt_blended_mixed(): ], "weights": [0.6, 0.4], }, - BlendedDatasetConfig, + BlendedDatasetConfig[LanguageModelSample], ).build_and_sample(get_sampling_data(8, sequence_length=5)) compare_sampled_dataset(sampled, GPT_BLENDED_MIXED_SAMPLES) diff --git 
a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index 0691dde9d..7b009bbf6 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -1,4 +1,5 @@ -from fast_llm.data.dataset.gpt.config import GPTConcatenatedDatasetConfig +from fast_llm.data.dataset.config import ConcatenatedDatasetConfig +from fast_llm.data.sample.language_model import LanguageModelSample from tests.data.common import ( compare_indexed_dataset, compare_sampled_dataset, @@ -27,7 +28,7 @@ def test_gpt_concatenate(): get_test_dataset() dataset = get_dataset_config( {"type": "concatenated", "datasets": [{"type": "memmap", "path": DATASET_PATH} for _ in range(3)]}, - GPTConcatenatedDatasetConfig, + ConcatenatedDatasetConfig[LanguageModelSample], ).build() compare_indexed_dataset( dataset, diff --git a/tests/data/test_fim.py b/tests/data/test_fim.py index 1211b92d6..b9dc7fe32 100644 --- a/tests/data/test_fim.py +++ b/tests/data/test_fim.py @@ -1,6 +1,4 @@ -from fast_llm.data.config import TokenizerConfig from fast_llm.data.dataset.gpt.config import GPTFimSampledDatasetConfig -from fast_llm.data.tokenizer import Tokenizer from tests.data.common import ( compare_sampled_dataset, get_dataset_config, @@ -29,13 +27,13 @@ def test_gpt_fim(): sampling_config = get_sampling_data( 8, sequence_length=5, - tokenizer=Tokenizer(TokenizerConfig.from_dict({"path": TOKENIZER_PATH})), vocab_size=49157, ) sampled = get_dataset_config( { "type": "fim", "dataset": {"type": "memmap", "path": DATASET_PATH}, + "tokenizer": {"path": TOKENIZER_PATH}, "rate": 0.5, "prefix_token": "w", "middle_token": "x", @@ -55,6 +53,7 @@ def test_gpt_fim_data(): "training": { "type": "fim", "dataset": {"type": "memmap", "path": DATASET_PATH}, + "tokenizer": {"path": TOKENIZER_PATH}, "rate": 0.5, "prefix_token": "w", "middle_token": "x", @@ -62,7 +61,6 @@ def test_gpt_fim_data(): "suffix_token": "z", } }, - "tokenizer": {"path": TOKENIZER_PATH}, }, 8, sequence_length=5, diff --git 
a/tests/data/test_memmap.py b/tests/data/test_memmap.py index d6e7dc911..70874f320 100644 --- a/tests/data/test_memmap.py +++ b/tests/data/test_memmap.py @@ -27,20 +27,20 @@ def test_gpt_memmap(cache_directory): MEMMAP_DATASET_SPANS = { 9: [], - 10: [[0, 4], [6, 8]], - 13: [[1, 2]], + 10: [(0, 2), (2, 7), (7, 10)], + 13: [(0, 2)], 15: [], } -_DATASET_PREFIX_SPANS = DATASET_CACHE / "with_spans" / "dataset" +_DATASET_PATH_SPANS = DATASET_CACHE / "with_spans" / "dataset.fast_llm_dataset" def test_gpt_data_with_spans(): - get_test_dataset(path=_DATASET_PREFIX_SPANS, max_spans=5) + get_test_dataset(path=_DATASET_PATH_SPANS, max_spans=5) dataset = get_dataset_config( { "type": "memmap", - "path": _DATASET_PREFIX_SPANS, + "path": _DATASET_PATH_SPANS, }, MemmapDatasetConfig, ).build() diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py index 68b6c7e21..ea50cb5aa 100644 --- a/tests/data/test_prepare_gpt_memmap.py +++ b/tests/data/test_prepare_gpt_memmap.py @@ -4,12 +4,14 @@ import numpy as np import pytest +import torch from fast_llm.data.dataset.config import IndexedDatasetConfig +from fast_llm.data.dataset.gpt.config import GPTSamplingParameters from fast_llm.data.dataset.memmap import MemmapDataset from fast_llm.data.preparator.gpt_memmap.config import MEMMAP_DTYPES, GPTMemmapDatasetPreparatorConfig from fast_llm.data.preparator.gpt_memmap.prepare import GPTMemmapDatasetPreparator -from fast_llm.data.sample.language_model import LanguageModelReader +from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert from tests.data.common import MockGPTMemmapDatasetConfig # Noqa @@ -28,52 +30,45 @@ def get_preparator(output_path: str, dataset_path_name: str) -> GPTMemmapDataset @pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) def test_write_memmap_dataset(dtype): - documents = [GPTSample(np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype)) for _ in range(100)] + documents 
= [ + (torch.from_numpy(np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype)), None, None, None) + for _ in range(100) + ] with tempfile.TemporaryDirectory() as temp_dir: - path = pathlib.Path(temp_dir) - MemmapDataset.write_dataset(path, documents, LanguageModelReader) - dataset = MemmapDataset(name="foo", path=path) - for i, document in enumerate(documents): - assert np.array_equal( - dataset.get(i).token_ids, document.token_ids, equal_nan=True - ), f"Mismatch for document {i}: {document} != {dataset.get(i)}." + prefix = pathlib.Path(temp_dir) + MemmapDataset.write_dataset(prefix=prefix, documents=documents) + dataset = MemmapDataset(name="foo", prefix=prefix) + for i, (tokens, _, _, _) in enumerate(documents): + Assert.all_equal(dataset.get_document(i).tokens.tokens, tokens.to(torch.int64)) -@pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) -def test_write_memmap_preference_dataset(dtype): - def generate_valid_span(max_seq_length): - span = np.random.choice(np.arange(0, max_seq_length - 1), size=2, replace=False) - return np.sort(span) +def _generate_valid_span(max_seq_length): + return np.sort(np.random.choice(np.arange(0, max_seq_length - 1), size=2, replace=False)).tolist() - vocab_size = 1000 - max_seq_length = 8192 - num_samples = 100 +@pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) +def test_write_memmap_preference_dataset(dtype): documents = [ - GPTSample( - token_ids=np.random.randint(vocab_size, size=max_seq_length).astype(dtype), - chosen_span=generate_valid_span(max_seq_length=max_seq_length), - rejected_span=generate_valid_span(max_seq_length=max_seq_length), + ( + torch.from_numpy(np.random.randint(1000, size=100).astype(dtype)), + None, + _generate_valid_span(100), + _generate_valid_span(100), ) - for _ in range(num_samples) + for _ in range(50) ] with tempfile.TemporaryDirectory() as temp_dir: - path = pathlib.Path(temp_dir) - MemmapDataset.write_dataset(path, documents, LanguageModelReader) - dataset = 
MemmapDataset(name="foo", path=path) - for i, document in enumerate(documents): - dataset_item = dataset.get(i, use_preference_loss_spans=True) - assert np.array_equal( - dataset_item.token_ids, document.token_ids, equal_nan=True - ), f"Token ids mismatch for document {i}: {document} != {dataset.get(i)}." - - assert np.array_equal( - dataset_item.chosen_span, document.chosen_span, equal_nan=True - ), f"Chosen loss masking spans mismatch for document {i}: {document.chosen_span} != {dataset.get(i).chosen_span}." - - assert np.array_equal( - dataset_item.rejected_span, document.rejected_span, equal_nan=True - ), f"Rejected loss masking spans mismatch for document {i}: {document.rejected_span} != {dataset.get(i).rejected_span}." + prefix = pathlib.Path(temp_dir) + MemmapDataset.write_dataset(prefix=prefix, documents=documents) + dataset = MemmapDataset(name="foo", prefix=prefix) + parameters = GPTSamplingParameters( + num_samples=0, sequence_length=0, vocab_size=0, use_preference_loss_spans=True + ) + for i, (token_ids, _, (chosen_begin, chosen_end), (rejected_begin, rejected_end)) in enumerate(documents): + document = dataset.get_document(i, parameters=parameters) + Assert.all_equal(document.tokens.tokens, token_ids.to(torch.int64)) + Assert.eq(document.chosen_spans.ranges, [(chosen_begin, chosen_end + 1)]) + Assert.eq(document.rejected_spans.ranges, [(rejected_begin, rejected_end + 1)]) def test_load_metadata_from_hub(): @@ -126,7 +121,7 @@ def test_absent_metadata_local(): def test_split_dataset(): - dataset_config_0 = IndexedDatasetConfig.from_dict(DATASET_DICT_0.copy()) + dataset_config_0 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_0.copy()) config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( [dataset_config_0], {"training": 3, "validation": 1}, @@ -154,8 +149,8 @@ def test_split_dataset(): def test_split_datasets_0(): - dataset_config_0 = IndexedDatasetConfig.from_dict(DATASET_DICT_0.copy()) - dataset_config_1 = 
IndexedDatasetConfig.from_dict(DATASET_DICT_1.copy()) + dataset_config_0 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_0.copy()) + dataset_config_1 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_1.copy()) config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( [dataset_config_0, dataset_config_1], {"training": 1, "validation": 1}, @@ -173,8 +168,8 @@ def test_split_datasets_0(): def test_split_datasets_1(): - dataset_config_0 = IndexedDatasetConfig.from_dict(DATASET_DICT_0.copy()) - dataset_config_1 = IndexedDatasetConfig.from_dict(DATASET_DICT_1.copy()) + dataset_config_0 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_0.copy()) + dataset_config_1 = IndexedDatasetConfig[LanguageModelSample].from_dict(DATASET_DICT_1.copy()) config = GPTMemmapDatasetPreparator._split_and_blend_dataset_configs( [dataset_config_0, dataset_config_1], {"training": 3, "validation": 1}, pathlib.Path(".") ) diff --git a/tests/data/test_sampling.py b/tests/data/test_sampling.py index 8474d8e58..2007eafac 100644 --- a/tests/data/test_sampling.py +++ b/tests/data/test_sampling.py @@ -1,12 +1,13 @@ -import typing - import numpy as np import pytest +import torch from fast_llm.data.dataset.config import ShufflingType -from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig -from fast_llm.data.dataset.gpt.indexed import GPTIndexedDataset -from fast_llm.data.dataset.sampled import GPTSample +from fast_llm.data.dataset.gpt.config import GPTSamplingParameters +from fast_llm.data.dataset.indexed import IndexedDataset +from fast_llm.data.sample.config import MemmapDatasetConfig +from fast_llm.data.sample.language_model import LanguageModelSample +from fast_llm.data.sample.token import TokenSample from fast_llm.utils import Assert from tests.data.common import ( get_dataset_config, @@ -40,7 +41,7 @@ def test_gpt_sampled(): # Make sure the memmap dataset works and check for unintended changes in behavior. 
get_test_dataset() - sampled = get_dataset_config({"type": "memmap", "path": DATASET_PATH}, GPTMemmapDatasetConfig).build_and_sample( + sampled = get_dataset_config({"type": "memmap", "path": DATASET_PATH}, MemmapDatasetConfig).build_and_sample( get_sampling_data(8, sequence_length=5) ) validate_indexed_dataset_sampling(sampled, GPT_MEMMAP_SAMPLES) @@ -53,7 +54,7 @@ def test_gpt_sampled_data(): "datasets": { "training": { "type": "memmap", - "path": DATASET_PATH, + "path": DATASET_PREFIX, } } }, @@ -63,24 +64,23 @@ def test_gpt_sampled_data(): ) -class SimpleGPTIndexedDataset(GPTIndexedDataset): +class SimpleGPTIndexedDataset[SampleType: LanguageModelSample](IndexedDataset[SampleType]): # TODO: worth adding to the main codebase? def __init__(self, samples): self._samples = samples - def get(self, index: int, offset=0, length=None, use_loss_masking_spans: bool = False) -> typing.Any: - if length is None: - length = len(self._samples[index]) - assert not use_loss_masking_spans - return GPTSample( - token_ids=np.array(self._samples[index][offset : offset + length], dtype=np.int64), loss_masking_spans=None - ) + def get_document( + self, index: int, begin: int = 0, end: int | None = None, parameters: GPTSamplingParameters | None = None + ) -> SampleType: + if end is None: + end = len(self._samples[index]) + return LanguageModelSample(TokenSample(torch.tensor(self._samples[index][begin:end], dtype=torch.int64))) def __len__(self) -> int: return len(self._samples) - def get_document_sizes(self) -> np.ndarray: - return np.array([self.get_document_size(index) for index in range(len(self))], dtype=np.int64) + def get_document_sizes(self) -> torch.Tensor: + return torch.tensor([self.get_document_size(index) for index in range(len(self))], dtype=torch.int64) def get_document_size(self, index: int) -> int: return len(self._samples[index]) @@ -181,4 +181,4 @@ def test_gpt_sample_padding(): else: sampled = dataset.sample(sampling) for idx in range(len(expected_samples)): - 
Assert.all_equal(sampled[idx].token_ids, np.array(expected_samples[idx])) + Assert.all_equal(sampled[idx].tokens.tokens, np.array(expected_samples[idx])) diff --git a/tests/data/test_slice.py b/tests/data/test_slice.py index 07a9bd776..3a6b999cd 100644 --- a/tests/data/test_slice.py +++ b/tests/data/test_slice.py @@ -1,4 +1,5 @@ -from fast_llm.data.dataset.gpt.config import GPTDatasetSliceConfig +from fast_llm.data.dataset.config import DatasetSliceConfig +from fast_llm.data.sample.language_model import LanguageModelSample from tests.data.common import ( compare_indexed_dataset, get_dataset_config, @@ -34,7 +35,7 @@ def test_gpt_slice(): # samples[9:18] dataset = get_dataset_config( {"type": "slice", "dataset": {"type": "memmap", "path": DATASET_PATH}, "begin": 0.0015, "end": 0.003}, - GPTDatasetSliceConfig, + DatasetSliceConfig[LanguageModelSample], ).build() compare_indexed_dataset(dataset, 9, 544, {i - 9: sample for i, sample in MEMMAP_DATASET_SAMPLES.items()}) sampled = dataset.sample(get_sampling_data(8, sequence_length=5)) diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index 65c7587b2..489f5e1c1 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -1,167 +1,80 @@ -import random - import pytest import torch from fast_llm.functional.config import ActivationType, MLPRecomputeLevel -from fast_llm.functional.dpo import _compute_dpo_loss, _get_logratios +from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped, torch_mlp_activation from fast_llm.functional.triton.sparse_copy import get_sparse_map from fast_llm.utils import Assert +from tests.utils.dataset import get_random_spans from tests.utils.utils import requires_cuda -def ref_log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor, temperature: float = 1.0) -> torch.Tensor: - if temperature != 1.0: - logits.div_(temperature) - batch_dim = 
logits.shape[:-1] - last_dim = logits.shape[-1] - - output = torch.nn.functional.cross_entropy(logits.reshape(-1, last_dim), labels.reshape(-1), reduction="none") - log_probs_labels = -output.view(*batch_dim) - - return log_probs_labels - - -def ref_packed_get_batch_logps( - logits: torch.FloatTensor, - labels: torch.LongTensor, - attention_mask, - prompt_id_lens, - packed_seq_lens, -) -> torch.FloatTensor: - labels = labels[:, 1:] - logits = logits[:, :-1, :] - per_token_logps = ref_log_probs_from_logits(logits, labels) - - loss_masks = attention_mask.clone().bool() - - index = 0 - for i, seq_len in enumerate(packed_seq_lens): - loss_masks[0, index : index + prompt_id_lens[i]] = False - index = index + seq_len - - loss_masks = loss_masks[:, 1:] - - logprobs_sums = [] - index = 0 - for i, seq_len in enumerate(packed_seq_lens): - seq = per_token_logps[0, index : index + seq_len - 1] - mask = loss_masks[0, index : index + seq_len - 1] - logprobs_sums.append((seq * mask).sum()) - index = index + seq_len - chosen_logps = logprobs_sums[: len(packed_seq_lens) // 2] - rejected_logps = logprobs_sums[len(packed_seq_lens) // 2 :] - - return torch.tensor(chosen_logps), torch.tensor(rejected_logps) - - -@pytest.mark.slow -@pytest.mark.parametrize( - ("batch_size", "seq_length", "vocab_size"), - ( - (2, 32, 50), - (1, 32, 50), - (2, 100, 50), - (2, 32, 200), - ), -) -def test_preference_logps(batch_size, seq_length, vocab_size): - random.seed(0) - torch.manual_seed(0) - - def random_split(seq_length): - min_val = int(seq_length * 0.3) - max_val = int(seq_length * 0.7) - - if max_val < min_val: - max_val = min_val - - a = random.randint(min_val, max_val) - b = seq_length - a - return [a, b] - - logits = torch.randn(batch_size, seq_length, vocab_size) - targets = torch.randint(0, vocab_size, (batch_size, seq_length)) - packed_seq_lens = random_split(seq_length) # simulate different chosen/rejected lengths - prompt_id_lens = [int(min(packed_seq_lens) * 0.75)] * 2 # sequences are 
75% prompt 25% generation - attention_mask = torch.tensor([1] * packed_seq_lens[0] + [2] * packed_seq_lens[1]).unsqueeze(0) - - chosen_span = torch.tensor([[prompt_id_lens[0], packed_seq_lens[0] - 1]]) - 1 # shift by 1 due to label shifting - rejected_span = ( - torch.tensor([[packed_seq_lens[0] + prompt_id_lens[1], packed_seq_lens[0] + packed_seq_lens[1] - 1]]) - 1 - ) # shift by 1 due to label shifting - - ref_chosen_logps, ref_rejected_logps = ref_packed_get_batch_logps( - logits, targets, attention_mask, prompt_id_lens, packed_seq_lens +def _get_target_log_probability_for_spans(log_probabilities: torch.Tensor, spans: list[list[tuple[int, int]]]): + return sum( + log_probabilities[sample_index, begin:end].sum() + for sample_index, sample_spans in enumerate(spans) + for begin, end in sample_spans ) - chosen_logps, rejected_logps, selected_log_probs = _get_logratios( - logits=logits, - targets=targets[:, 1:], - chosen_spans=chosen_span, - rejected_spans=rejected_span, - ) - - ref_logps = ref_log_probs_from_logits(logits[:, :-1, :], targets[:, 1:]) - - # check all logps - Assert.custom(torch.allclose, ref_logps, selected_log_probs, rtol=1e-5) - # check chosen and rejected summed logps - Assert.custom(torch.allclose, ref_chosen_logps, chosen_logps, rtol=1e-5) - Assert.custom(torch.allclose, ref_rejected_logps, rejected_logps, rtol=1e-5) - - -def ref_dpo_loss_fcn( - policy_chosen_logps: torch.Tensor, - policy_rejected_logps: torch.Tensor, - reference_chosen_logps: torch.Tensor, - reference_rejected_logps: torch.Tensor, - beta=1, - label_smoothing=0, +def reference_dpo_loss( + logits: torch.Tensor, + targets: torch.Tensor, + reference_model_logits: torch.Tensor, + chosen_spans: torch.Tensor, + rejected_spans: torch.Tensor, + beta: float, ) -> torch.Tensor: + # TODO: Too similar to the actual implementation. 
+ policy_log_probs = ( + torch.nn.functional.log_softmax(logits.float(), dim=-1).gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1) + ) + policy_chosen_logps = sum( + policy_log_probs[sample_index, begin:end].sum() + for sample_index, sample_spans in enumerate(chosen_spans) + for begin, end in sample_spans + ) + policy_rejected_logps = sum( + policy_log_probs[sample_index, begin:end].sum() + for sample_index, sample_spans in enumerate(rejected_spans) + for begin, end in sample_spans + ) + reference_log_probs = ( + torch.nn.functional.log_softmax(reference_model_logits.float(), dim=-1) + .gather(dim=-1, index=targets.unsqueeze(-1)) + .squeeze(-1) + ) + reference_chosen_logps = sum( + reference_log_probs[sample_index, begin:end].sum() + for sample_index, sample_spans in enumerate(chosen_spans) + for begin, end in sample_spans + ) + reference_rejected_logps = sum( + reference_log_probs[sample_index, begin:end].sum() + for sample_index, sample_spans in enumerate(rejected_spans) + for begin, end in sample_spans + ) pi_logratios = policy_chosen_logps - policy_rejected_logps ref_logratios = reference_chosen_logps - reference_rejected_logps - logits = pi_logratios - ref_logratios - - # Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 
7 of https://arxiv.org/pdf/2305.18290.pdf) - losses = ( - -torch.nn.functional.logsigmoid(beta * logits) * (1 - label_smoothing) - - torch.nn.functional.logsigmoid(-beta * logits) * label_smoothing - ) - - loss = losses.mean() - - return loss + return -torch.nn.functional.logsigmoid(beta * (pi_logratios - ref_logratios)).mean() def test_dpo_loss(): torch.manual_seed(0) + logits = torch.randn((10, 50, 100), requires_grad=True) + reference_model_logits = torch.randn((10, 50, 100)) + targets = torch.randint(0, 100, (10, 50)) - NUM_SAMPLES = 20 - policy_chosen_logps = torch.rand(NUM_SAMPLES) - policy_rejected_logps = torch.rand(NUM_SAMPLES) - reference_chosen_logps = torch.rand(NUM_SAMPLES) - reference_rejected_logps = torch.rand(NUM_SAMPLES) - betas = torch.rand(NUM_SAMPLES) + spans = get_random_spans(10, 10, 50) - for i in range(NUM_SAMPLES): - fastllm_dpo_loss = _compute_dpo_loss( - policy_chosen_logps=policy_chosen_logps[i], - policy_rejected_logps=policy_rejected_logps[i], - reference_chosen_logps=reference_chosen_logps[i], - reference_rejected_logps=reference_rejected_logps[i], - beta=betas[i].item(), - ) - ref_dpo_loss = ref_dpo_loss_fcn( - policy_chosen_logps=policy_chosen_logps[i].unsqueeze(0), - policy_rejected_logps=policy_rejected_logps[i].unsqueeze(0), - reference_chosen_logps=reference_chosen_logps[i].unsqueeze(0), - reference_rejected_logps=reference_rejected_logps[i].unsqueeze(0), - beta=betas[i].item(), - ) - Assert.rms_close(fastllm_dpo_loss, ref_dpo_loss, 1e-5) + fastllm_loss, fast_llm_grad = compute_dpo_loss( + logits, targets, reference_model_logits, spans[::2], spans[1::2], beta=1, grad_output=1 + ) + reference_loss = reference_dpo_loss(logits, targets, reference_model_logits, spans[::2], spans[1::2], beta=1) + reference_loss.backward() + Assert.rms_close(fastllm_loss, reference_loss, 1e-5) + Assert.rms_close(fast_llm_grad, logits.grad, 1e-5) @requires_cuda diff --git a/tests/models/test_match_megatron.py b/tests/models/test_match_megatron.py 
index c4146e94b..9c3b15d8e 100644 --- a/tests/models/test_match_megatron.py +++ b/tests/models/test_match_megatron.py @@ -6,14 +6,17 @@ from fast_llm.config import Field, FieldHint, config_class from fast_llm.data.dataset.abstract import SampledDataset -from fast_llm.data.dataset.gpt.config import GPTMemmapDatasetConfig, GPTSampledDatasetConfig, GPTSamplingData +from fast_llm.data.dataset.config import SampledDatasetConfig +from fast_llm.data.dataset.gpt.config import GPTSamplingData from fast_llm.data.dataset.memmap import MemmapDataset -from fast_llm.data.dataset.sampled import GPTSample, logger +from fast_llm.data.dataset.sampled import logger +from fast_llm.data.sample.config import MemmapDatasetConfig +from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.utils import Assert from tests.utils.compare_tensor_logs import CompareConfig from tests.utils.dataset import get_model_test_dataset from tests.utils.distributed_configs import DistributedTestingConfig -from tests.utils.global_variables import MODEL_DATASET_PREFIX +from tests.utils.global_variables import MODEL_DATASET_PATH from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda @@ -67,7 +70,7 @@ def test_match_megatron(run_test_script_for_all_models, model_testing_config, co compare="megatron", config_args=[ "model.distributed.compute_dtype=fp32", - f'data.datasets.training={{"type":"megatron","path":{MODEL_DATASET_PREFIX}}}', + f'data.datasets.training={{"type":"megatron","path":{MODEL_DATASET_PATH}}}', "data.sampling.seed=1234", "model.base_model.use_megatron_initialization=True", ], @@ -79,15 +82,15 @@ def test_match_megatron(run_test_script_for_all_models, model_testing_config, co compare_results_for_all_models(distributed_testing_config) -@config_class(dynamic_type={GPTSampledDatasetConfig: "megatron"}) -class GPTMegatronDatasetConfig(GPTMemmapDatasetConfig): +@config_class(dynamic_type={SampledDatasetConfig: "megatron"}) +class 
GPTMegatronDatasetConfig(MemmapDatasetConfig): _abstract: typing.ClassVar[bool] = False path: str = Field( desc="Dataset path (prefix).", hint=FieldHint.core, ) - def build(self) -> "MemmapDataset": + def build(self) -> "GPTMemmapDataset": return GPTMegatronMemmapDataset( str(self.path).replace("/", "__"), self.path, self.num_documents, self.num_tokens ) @@ -141,18 +144,16 @@ def __getitem__(self, idx: int) -> typing.Any: shuffled_idx = self._shuffle_idx[idx] doc_f, offset_f = self._sample_idx[shuffled_idx] doc_l, offset_l = self._sample_idx[shuffled_idx + 1] - sample_list = [ - self._indexed_dataset.get( - self._doc_idx[doc].item(), - offset=(doc == doc_f) * offset_f, - length=offset_l + 1 - (doc == doc_f) * offset_f if doc == doc_l else None, - ) - for doc in range(doc_f, doc_l + 1) - ] - token_ids = np.concatenate([sample.token_ids for sample in sample_list], dtype=np.int64) - Assert.eq(len(token_ids), self._sequence_length + 1) - - return GPTSample(token_ids=token_ids) + return LanguageModelSample.from_documents( + [ + self._indexed_dataset.get_document( + self._doc_idx[doc].item(), + begin=(doc == doc_f) * offset_f, + end=offset_l + 1 if doc == doc_l else None, + ) + for doc in range(doc_f, doc_l + 1) + ] + ) @property def name(self) -> str: diff --git a/tests/test_attention.py b/tests/test_attention.py index a19cba8f0..b86cc95fa 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -3,7 +3,7 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.attention.attention import Attention -from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs +from fast_llm.layers.attention.config import AttentionConfig, AttentionImplementation, AttentionKwargs from fast_llm.layers.block.config import BlockDimNames from fast_llm.utils import Assert @@ -29,7 +29,7 @@ def test_varlen_preprocessing(): micro_sequence_length = 12 sequence_length = 36 
attention = Attention( - AttentionConfig(head_size=64), + AttentionConfig(head_size=64, implementation=AttentionImplementation.flash, cross_document_attention=False), DistributedConfig(compute_dtype="bfloat16"), hidden_dim=TensorDim("", 1), lr_scale=None, diff --git a/tests/test_config.py b/tests/test_config.py index 63f2606f1..9a1f542a0 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -6,7 +6,7 @@ import yaml from fast_llm.config import NoAutoValidate -from fast_llm.data.dataset.gpt.config import GPTSamplingConfig +from fast_llm.data.dataset.config import SamplingConfig from fast_llm.engine.checkpoint.config import CheckpointSaveMetadataConfig, ModelConfigType from fast_llm.engine.distributed.config import DistributedConfig, DistributedDim, DistributedDimNames from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig, PretrainedGPTModelConfig @@ -60,7 +60,7 @@ def test_validate_example_config(): GPTTrainerConfig.from_dict(fast_llm_config_dict) -@pytest.mark.parametrize("cls", (GPTSamplingConfig, GPTModelConfig)) +@pytest.mark.parametrize("cls", (SamplingConfig, GPTModelConfig)) def test_serialize_default_config_updates(cls): # Config classes used as config updates should have a default that serializes to an empty dict # so no value is incorrectly overridden. 
diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index 3e3b19a8d..97f7b39fa 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -6,11 +6,10 @@ import yaml from fast_llm.data.dataset.memmap import MemmapDataset -from fast_llm.data.sample.language_model import LanguageModelReader, LanguageModelSample -from fast_llm.data.sample.token import TokenSample +from fast_llm.data.sample.language_model import LanguageModelReader from tests.utils.global_variables import ( DATASET_PATH, - MODEL_DATASET_PREFIX, + MODEL_DATASET_PATH, MODEL_TEST_VOCAB_SIZE, TEST_CHARACTERS, TEST_DATASET_TOKENS, @@ -27,6 +26,15 @@ def download_santacoder_tokenizer(): transformers.AutoTokenizer.from_pretrained("bigcode/santacoder").save_pretrained(TOKENIZER_PATH) +def get_random_spans(num_samples: int, max_spans: int, lengths: np.ndarray | int, seed: int = 0): + spans = np.sort(np.random.RandomState(seed + 3847).randint(0, lengths, [num_samples, max_spans * 2])) + spans = [np.unique(sample_spans).tolist() for sample_spans in spans] + return [ + [(begin, end) for begin, end in zip(sample_spans[::2], sample_spans[1::2], strict=False)] + for sample_spans in spans + ] + + def get_test_dataset( path: pathlib.Path = DATASET_PATH, seed: int = 1234, @@ -44,26 +52,34 @@ def get_test_dataset( tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH) samples = [ - LanguageModelSample( - TokenSample( - torch.from_numpy(np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size), - ) + ( + torch.from_numpy(np.array(tokenizer(document)["input_ids"], dtype=np.uint16) % vocab_size), + None, + None, + None, ) for document in texts ] if max_spans > 0: - lengths = np.array([max(len(sample), 1) for sample in samples]) - spans = np.sort(np.random.RandomState(seed + 3847).randint(0, lengths[:, None], [len(samples), max_spans])) - for sample, span in zip(samples, spans): - span = np.unique(span) - sample.loss_masking_spans = torch.from_numpy(span[: len(span) // 2 * 
2].reshape(-1, 2)) + spans = get_random_spans( + len(samples), max_spans, np.array([[max(len(tokens), 1)] for tokens, _, _, _ in samples]), seed + ) + samples = [ + ( + tokens, + torch.tensor(sample_spans, dtype=torch.int32).reshape(-1, 2), + None, + None, + ) + for (tokens, _, _, _), sample_spans in zip(samples, spans, strict=True) + ] MemmapDataset.write_dataset(path, samples, LanguageModelReader) yaml.safe_dump({"type": "memmap", "path": path.name}, path.parent.joinpath("fast_llm_config.yaml").open("w")) def get_model_test_dataset( - prefix: pathlib.Path = MODEL_DATASET_PREFIX, + path: pathlib.Path = MODEL_DATASET_PATH, vocab_size: int = MODEL_TEST_VOCAB_SIZE, ): - return get_test_dataset(path=prefix, vocab_size=vocab_size) + return get_test_dataset(path=path, vocab_size=vocab_size) diff --git a/tests/utils/global_variables.py b/tests/utils/global_variables.py index 8ff6d2a9f..783c62634 100644 --- a/tests/utils/global_variables.py +++ b/tests/utils/global_variables.py @@ -44,5 +44,5 @@ def set_testing_global_variables(): TEST_CHARACTERS = (string.ascii_lowercase) * 5 + " " * 30 + "\n" TEST_DATASET_TOKENS = 1000000 -MODEL_DATASET_PREFIX = DATASET_CACHE / "model_dataset" +MODEL_DATASET_PATH = DATASET_CACHE / "model_dataset.fast_llm_dataset" MODEL_TEST_VOCAB_SIZE = 384 diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index c02521d7b..adcf84b18 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -22,7 +22,7 @@ Qwen2CheckpointFormat, ) from tests.utils.distributed_configs import DistributedTestingConfig -from tests.utils.global_variables import MODEL_DATASET_PREFIX, MODEL_TEST_VOCAB_SIZE +from tests.utils.global_variables import MODEL_DATASET_PATH, MODEL_TEST_VOCAB_SIZE from fast_llm.engine.evaluation.evaluators import ( # isort:skip # needed for dynamic type registration EvaluatorsConfig, @@ -234,18 +234,18 @@ def _update_and_add_testing_config( "data": { "datasets": { "training": { - "dataset": {"type": "memmap", 
"path": MODEL_DATASET_PREFIX}, + "dataset": {"type": "memmap", "path": MODEL_DATASET_PATH}, "type": "slice", "end": 0.969, }, "validation": { - "dataset": {"type": "memmap", "path": MODEL_DATASET_PREFIX}, + "dataset": {"type": "memmap", "path": MODEL_DATASET_PATH}, "type": "slice", "begin": 0.969, "end": 0.999, }, "test": { - "dataset": {"type": "memmap", "path": MODEL_DATASET_PREFIX}, + "dataset": {"type": "memmap", "path": MODEL_DATASET_PATH}, "type": "slice", "begin": 0.999, "end": 1, @@ -279,7 +279,7 @@ def _update_and_add_testing_config( "--tokenizer-type=NullTokenizer", # Megatron messes with the vocab size, so we have to subtract 1. f"--vocab-size={MODEL_TEST_VOCAB_SIZE - 1}", - f"--data-path={MODEL_DATASET_PREFIX}", + f"--data-path={MODEL_DATASET_PATH}", "--split=1,0,0", "--lr-decay-style=constant", # Initialization is set up to match MCore models (MCore inverts self-attn qkv and dense layers compared to original Megatron)