Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
ecd1918
clean history
jlamypoirier Sep 22, 2025
9114ce2
Vision multimodal
jlamypoirier Sep 26, 2025
a44642c
Drop varlen mamba
jlamypoirier Sep 26, 2025
ddf2143
cleanup
jlamypoirier Sep 26, 2025
8ee7d5e
cleanup
jlamypoirier Sep 26, 2025
43ca913
cleanup
jlamypoirier Sep 26, 2025
15405a1
cleanup
jlamypoirier Sep 26, 2025
a3dc89d
stuff
jlamypoirier Sep 26, 2025
414f87e
stuff
jlamypoirier Sep 26, 2025
4a21360
stuff
jlamypoirier Sep 26, 2025
2180ea5
Merge branch 'jlp/mlp_block' into jlp/vision_multimodal
jlamypoirier Sep 26, 2025
bb7c62d
Embeddings
jlamypoirier Sep 26, 2025
47b9a44
Model interface
jlamypoirier Oct 1, 2025
09b0215
Merge branch 'main' into jlp/vision_multimodal
jlamypoirier Oct 3, 2025
f31a313
Fix merge
jlamypoirier Oct 3, 2025
3d84972
model
jlamypoirier Oct 4, 2025
8f8ef19
cleanup
jlamypoirier Oct 4, 2025
6084122
language_model
jlamypoirier Oct 4, 2025
4a96980
fixes
jlamypoirier Oct 6, 2025
7854138
fixes
jlamypoirier Oct 6, 2025
0350e17
Merge branch 'jlp/language_model_block' into jlp/vision_multimodal
jlamypoirier Oct 6, 2025
1a18929
Dataset interface
jlamypoirier Oct 15, 2025
fd63846
misc
jlamypoirier Oct 15, 2025
2486caf
fix
jlamypoirier Oct 15, 2025
92e93e8
Language model sample
jlamypoirier Oct 16, 2025
d6f6944
fix
jlamypoirier Oct 16, 2025
5c802fa
fixes
jlamypoirier Oct 16, 2025
95d1840
test
jlamypoirier Oct 16, 2025
eafd9cb
fixes
jlamypoirier Oct 17, 2025
c56df69
cleanup
jlamypoirier Oct 17, 2025
7f437e1
misc
jlamypoirier Oct 17, 2025
dfd27f5
misc
jlamypoirier Oct 17, 2025
d937f58
Merge branch 'main' into jlp/vision_multimodal
jlamypoirier Oct 17, 2025
11bbee2
Merge branch 'jlp/lm_sample' into jlp/vision_multimodal
jlamypoirier Oct 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ COPY --chmod=777 ./fast_llm/__init__.py fast_llm/
COPY --chmod=777 ./fast_llm/csrc/ fast_llm/csrc/

# Install dependencies within the virtual environment.
RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,GENERATION,DEV]" triton==3.1.0
RUN pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,DEV]" triton==3.1.0

# Copy the remaining source code with universal write permissions.
COPY --chmod=777 ./Megatron-LM Megatron-LM
Expand Down
21 changes: 17 additions & 4 deletions fast_llm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,10 @@ def _validate_element(cls, value, type_, name: str):
value = cls._validate_dict(value, type_, name)
elif origin is type:
value = cls._validate_type(value, type_, name)
elif issubclass(origin, Config):
# TODO: Validate arguments for config generics.
cls._validate_element_type(value, type_.__origin__, strict=False)
value.validate(_is_validating=True)
else:
raise FieldTypeError(f"Unsupported __origin__ `{origin}`")
elif not isinstance(type_, type):
Expand Down Expand Up @@ -806,17 +810,24 @@ def _from_dict_nested(cls, value, type_, strict: bool):
value = cls._from_dict_array(value, type_, strict)
elif issubclass(origin, dict):
value = cls._from_dict_dict(value, type_, strict)
elif issubclass(origin, Config):
value = cls._from_dict_config(value, type_, strict)
elif origin is type:
pass
else:
raise FieldTypeError(f"Unsupported __origin__ `{origin}`")
elif not isinstance(type_, type):
raise FieldTypeError(f"Not a type: {type_}.")
elif issubclass(type_, Config):
if value is MISSING:
value = {}
if isinstance(value, dict):
value = type_._from_dict(value, strict)
value = cls._from_dict_config(value, type_, strict)
return value

@classmethod
def _from_dict_config(cls, value, type_, strict: bool):
    """
    Convert a serialized value into an instance of the config type `type_`.

    A `MISSING` value is treated as an empty dict.
    Dict values are converted through `type_._from_dict(value, strict)`.
    Any other value is returned unchanged.
    """
    config_dict = {} if value is MISSING else value
    if not isinstance(config_dict, dict):
        # Not a serialized config, leave as-is.
        return config_dict
    return type_._from_dict(config_dict, strict)

@classmethod
Expand Down Expand Up @@ -938,6 +949,7 @@ def __init_subclass__(cls):
We need to postpone validation until the class has been processed by the dataclass wrapper.
"""
Assert.eq(cls.__name__, cls.__qualname__)
super().__init_subclass__()
for base_class in cls.__mro__:
if issubclass(base_class, Config) and base_class is not cls:
assert cls.__class_validated__, (
Expand Down Expand Up @@ -1006,6 +1018,7 @@ def __init__(self, config: ConfigType, *args, **kwargs):
def __init_subclass__(cls):
# Automatically set `config_class` based on the bound type.
# Make sure `ConfigType` is bound and respects class hierarchy.
super().__init_subclass__()
try:
config_class = None
for base in types.get_original_bases(cls):
Expand Down
11 changes: 10 additions & 1 deletion fast_llm/data/config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import enum
import pathlib
import typing

from fast_llm.config import Config, Field, FieldHint, check_field, config_class
from fast_llm.utils import Assert

if typing.TYPE_CHECKING:
from fast_llm.data.tokenizer import Tokenizer


class MultiprocessingContext(str, enum.Enum):
# Fast but risk of segfaults due to interactions with triton
Expand All @@ -29,7 +33,7 @@ class TokenizerConfig(Config):
hint=FieldHint.deprecated,
valid=check_field(Assert.eq, TokenizerFromFile),
)
path: pathlib.Path | None = Field(
path: pathlib.Path = Field(
default=None,
desc="Path to the tokenizer file.",
hint=FieldHint.core,
Expand All @@ -39,3 +43,8 @@ class TokenizerConfig(Config):
desc="BOS token to use if the tokenizer doesn't define one; must be an existing token.",
hint=FieldHint.core,
)

def get_tokenizer(self) -> "Tokenizer":
    """
    Instantiate the tokenizer described by this config.

    Returns:
        A new `Tokenizer` built with this config as its only argument.
    """
    # `Tokenizer` is only imported at module level for type checking, so import it here for runtime use.
    # NOTE(review): presumably deferred to avoid a circular or heavy import -- confirm.
    from fast_llm.data.tokenizer import Tokenizer

    tokenizer = Tokenizer(self)
    return tokenizer
3 changes: 2 additions & 1 deletion fast_llm/data/data/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from fast_llm.config import Configurable
from fast_llm.data.data.config import DataConfig
from fast_llm.data.dataset.config import SamplingParameters
from fast_llm.data.sample.abstract import Batch
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.engine.schedule.config import BatchConfig

Expand Down Expand Up @@ -47,5 +48,5 @@ def get_iterator(
num_workers: int,
prefetch_factor: int | None = None,
timeout: float = 60,
) -> typing.Iterator[typing.Any]:
) -> typing.Iterator[Batch]:
pass
16 changes: 7 additions & 9 deletions fast_llm/data/data/gpt/config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import logging
import typing

from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class
from fast_llm.data.config import MultiprocessingContext, TokenizerConfig
from fast_llm.config import Field, FieldHint, check_field, config_class
from fast_llm.data.config import MultiprocessingContext
from fast_llm.data.data.config import DataConfig
from fast_llm.data.dataset.gpt.config import GPTSampledDatasetConfig, GPTSamplingConfig
from fast_llm.data.dataset.config import SampledDatasetConfig
from fast_llm.utils import Assert

if typing.TYPE_CHECKING:
from fast_llm.data.sample.language_model import LanguageModelSample
logger = logging.getLogger(__name__)


Expand All @@ -19,17 +22,12 @@ class GPTDataConfig(DataConfig):

_abstract = False

tokenizer: TokenizerConfig = Field(
desc="Configuration for the tokenizer (for FIM).",
hint=FieldHint.feature,
)
# TODO: Review field. Move closer to phase definition in training config?
datasets: dict[str, GPTSampledDatasetConfig] = Field(
datasets: dict[str, SampledDatasetConfig["LanguageModelSample"]] = Field(
default_factory=dict,
desc="Configuration for the dataset(s).",
hint=FieldHint.core,
)
sampling: GPTSamplingConfig = FieldUpdate()
data_sample_warn_time_ms: float = Field(
default=1000,
desc="Warn if a sample takes too long to load.",
Expand Down
63 changes: 5 additions & 58 deletions fast_llm/data/data/gpt/data.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import dataclasses
import logging
import pathlib
import typing
import warnings
from functools import partial

import numpy as np
import torch
import torch.utils.data

Expand All @@ -14,60 +11,25 @@
from fast_llm.data.data.gpt.config import GPTDataConfig
from fast_llm.data.dataset.abstract import SampledDataset
from fast_llm.data.dataset.gpt.config import GPTSamplingData, GPTSamplingParameters
from fast_llm.data.dataset.gpt.sampled import GPTSample
from fast_llm.data.dataset.monitor import DatasetMonitor
from fast_llm.data.iterator import SampledDatasetIterator
from fast_llm.data.tokenizer import Tokenizer
from fast_llm.data.sample.language_model import LanguageModelBatch
from fast_llm.engine.config_utils.run import log_main_rank
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.engine.schedule.config import BatchConfig
from fast_llm.models.gpt.config import GPTBatchConfig
from fast_llm.utils import Assert

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class GPTBatch:
    """
    Container for a collated batch of samples, as assembled by `gpt_data_collate_fn`.

    Only `token_ids` is required; every other field defaults to `None`.
    """

    # Token ids of all the samples in the batch, stacked into a single tensor.
    token_ids: torch.Tensor
    # One tensor per sample, filled by the collate function only when loss masking spans are enabled.
    loss_masking_spans: list[torch.Tensor] | None = None
    # One tensor per sample, filled by the collate function only when cross-document attention is disabled.
    sequence_lengths: list[torch.Tensor] | None = None
    # One tensor per sample, filled by the collate function only when preference loss spans are enabled.
    chosen_spans: list[torch.Tensor] | None = None
    # One tensor per sample, filled together with `chosen_spans`.
    rejected_spans: list[torch.Tensor] | None = None


def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch:
    """
    Collate a list of samples into a single `GPTBatch`.

    The token ids of all samples are stacked into one tensor.
    The optional fields are per-sample lists of tensors, each set only when
    the matching option of `sampling_parameters` asks for it, and `None` otherwise:
      * `loss_masking_spans`: when `use_loss_masking_spans` is set.
      * `chosen_spans` and `rejected_spans`: when `use_preference_loss_spans` is set.
      * `sequence_lengths`: when `cross_document_attention` is not set.
    """
    token_id_array = np.stack([sample.token_ids for sample in batch])

    loss_masking_spans = (
        [torch.from_numpy(sample.loss_masking_spans) for sample in batch]
        if sampling_parameters.use_loss_masking_spans
        else None
    )

    if sampling_parameters.use_preference_loss_spans:
        chosen_spans = [torch.from_numpy(sample.chosen_span) for sample in batch]
        rejected_spans = [torch.from_numpy(sample.rejected_span) for sample in batch]
    else:
        chosen_spans = rejected_spans = None

    sequence_lengths = (
        None
        if sampling_parameters.cross_document_attention
        else [torch.tensor(sample.sequence_lengths) for sample in batch]
    )

    return GPTBatch(
        token_ids=torch.from_numpy(token_id_array),
        loss_masking_spans=loss_masking_spans,
        sequence_lengths=sequence_lengths,
        chosen_spans=chosen_spans,
        rejected_spans=rejected_spans,
    )


class GPTData[ConfigType: GPTDataConfig](Data[ConfigType]):
"""
A global class for all dataset needs, including loading, splitting, sampling and iteration.
Currently hard-coded to a GPT dataset.
TODO: Separate generic and GPT classes.
"""

_datasets: dict[str, SampledDataset]
_sampling_parameters: dict[str, GPTSamplingParameters]
_tokenizer: Tokenizer | None
_is_setup: bool = False

def __init__(
Expand Down Expand Up @@ -108,49 +70,37 @@ def setup(
)

log_main_rank(f"Preparing dataset. This may take several minutes.")
self._tokenizer = None if self._config.tokenizer.path is None else Tokenizer(self._config.tokenizer)

if self._cache_directory is None:
# TODO: Avoid this
warnings.warn(f"Using the dataset directory for the index cache.")

self._datasets = {}
for dataset_name, sampling_parameters in self._sampling_parameters.items():
if self._tokenizer is not None:
# NOTE: Some models like Qwen2-1.5B-Instruct
# have vocab_size bigger in model config than in tokenizer
# TODO: Still, is it too constraining?
Assert.geq(sampling_parameters.vocab_size, self._tokenizer.vocab_size)
if sampling_parameters.num_samples > 0:
sampling = GPTSamplingData(
config=self._config.sampling,
parameters=sampling_parameters,
cache_directory=self._cache_directory,
distributed=distributed,
dataset_name=dataset_name,
tokenizer=self._tokenizer,
)
dataset = self._config.datasets[dataset_name].build_and_sample(sampling)
self._datasets[dataset_name] = DatasetMonitor(dataset, self._config.data_sample_warn_time_ms)

safe_barrier(self._distributed.world_group, "data_preparation", timeout)
self._is_setup = True

@property
def tokenizer(self) -> Tokenizer:
assert self._is_setup
return self._tokenizer

def get_iterator(
self,
batch_config: BatchConfig,
batch_config: GPTBatchConfig,
dataset_name: str,
*,
consumed_samples: int,
num_workers: int,
prefetch_factor: int | None = None,
timeout: float = 60,
) -> typing.Iterator[typing.Any]:
) -> typing.Iterator[LanguageModelBatch]:
assert self._is_setup

# Some dataset names may come from phases and are capitalized,
Expand All @@ -175,10 +125,7 @@ def get_iterator(
num_workers=num_workers,
prefetch_factor=prefetch_factor,
pin_memory=True,
collate_fn=partial(
gpt_data_collate_fn,
sampling_parameters=sampling_parameters,
),
collate_fn=LanguageModelBatch.from_samples,
multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None,
)
)
20 changes: 15 additions & 5 deletions fast_llm/data/dataset/abstract.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import abc
import typing

from fast_llm.data.sample.abstract import Sample

if typing.TYPE_CHECKING:
from fast_llm.data.dataset.config import SamplingData


class Dataset(abc.ABC):
class Dataset[SampleType: Sample](abc.ABC):
"""
A generic dataset class compatible with torch.utils.data.Dataset but with a slightly different signature.
"""
Expand All @@ -17,24 +19,32 @@ def name(self) -> str:
A name for the dataset to facilitate identification and debugging.
"""

def __getstate__(self):
    """
    Return the state to pickle for this dataset, without the `__orig_class__` entry.

    Returns:
        The state from `super().__getstate__()`. When that state is a dict containing
        `__orig_class__`, a copy without that key is returned instead.
        Non-dict states (ex. `None`) are passed through unchanged.
    """
    state = super().__getstate__()
    # Pickling sometimes fails with bound `SampleType`.
    # This is not needed at runtime, so we just drop it.
    # The default state is not always a dict, so only filter when it is one.
    if isinstance(state, dict) and "__orig_class__" in state:
        # The default state can be the instance's own `__dict__`,
        # so filter into a copy rather than deleting from the live object.
        state = {key: value for key, value in state.items() if key != "__orig_class__"}
    return state


class SampledDataset(Dataset):
class SampledDataset[SampleType: Sample](Dataset[SampleType]):
"""
A sampled dataset class containing a prepared list of samples to be indexed sequentially (as-is) during training.
(See the `Sampler` class below.)
"""

@abc.abstractmethod
def __getitem__(self, index: int) -> typing.Any:
def __getitem__(self, index: int) -> SampleType:
pass

@abc.abstractmethod
def __len__(self) -> int:
pass


class SamplableDataset(Dataset):
class SamplableDataset[SampleType: Sample](Dataset[SampleType]):

@abc.abstractmethod
def sample(self, config: "SamplingData") -> SampledDataset:
def sample(self, config: "SamplingData") -> SampledDataset[SampleType]:
pass
Loading
Loading