Skip to content
4 changes: 2 additions & 2 deletions fast_llm/data/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ class Dataset(abc.ABC):

@property
@abc.abstractmethod
def name(self):
def name(self) -> str:
"""
A name for the dataset to facilitate identification and debugging.
"""
Expand All @@ -169,7 +169,7 @@ class SamplingConfig(Config):


class SamplableDataset(Dataset):
def sample(self, config: SamplingConfig, data: Data):
def sample(self, config: SamplingConfig, data: Data) -> "SampledDataset":
pass


Expand Down
13 changes: 5 additions & 8 deletions fast_llm/data/gpt/concatenated.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np

from fast_llm.data.gpt.config import GPTConcatenatedDatasetConfig
from fast_llm.data.gpt.dataset import GPTIndexedDataset
from fast_llm.utils import padded_cumsum

Expand All @@ -8,10 +9,10 @@ class GPTConcatenatedDataset(GPTIndexedDataset):

def __init__(
self,
name: str,
config: GPTConcatenatedDatasetConfig,
datasets: list[GPTIndexedDataset],
):
self._name = name
self._config = config
self._datasets = datasets
sizes = [dataset.num_documents for dataset in self._datasets]
self._dataset_splits = padded_cumsum(sizes)
Expand All @@ -22,11 +23,7 @@ def num_tokens(self):
return sum(dataset.num_tokens for dataset in self._datasets)

def num_documents(self):
return sum(dataset.num_documents for dataset in self._datasets)

def get_document_sizes(self) -> "np.ndarray":
# TODO: This can be really big.
return np.concatenate([dataset.get_document_sizes() for dataset in self._datasets])
return self._num_documents

def get(self, document: int, offset: int = 0, length: int | None = None):
"""
Expand All @@ -39,4 +36,4 @@ def get(self, document: int, offset: int = 0, length: int | None = None):

@property
def name(self):
return self._name
return self._config.name
Loading