From 818a16271d087a8dee541e0ea01161995a56995b Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Fri, 28 Mar 2025 22:35:16 +0000 Subject: [PATCH 01/47] initial dpo updates --- fast_llm/data/dataset/gpt/config.py | 10 + fast_llm/data/dataset/gpt/fim.py | 2 + fast_llm/data/dataset/gpt/memmap.py | 104 ++++++++- fast_llm/data/dataset/gpt/sampled.py | 208 +++++++++++------- fast_llm/data/preparator/gpt_memmap/config.py | 6 + .../data/preparator/gpt_memmap/prepare.py | 51 +++++ fast_llm/engine/schedule/schedule.py | 14 +- fast_llm/models/gpt/model.py | 5 + 8 files changed, 309 insertions(+), 91 deletions(-) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 74d8a0c35..a37d32b4e 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -57,6 +57,16 @@ class GPTSamplingConfig(SamplingConfig): desc="Read loss masking spans from the dataset.", hint=FieldHint.feature, ) + use_preference_loss_masking_spans: bool | None = Field( + default=None, + desc="Read preference loss masking spans from the dataset.", + hint=FieldHint.feature, + ) + enable_packing: bool | None = Field( + default=True, + desc="Whether to enable packing or not.", + hint=FieldHint.feature, + ) shuffle: ShufflingType | None = Field( default=None, desc="Shuffling strategy.", diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index 275505ba3..192e31315 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -20,6 +20,8 @@ def __init__( ): if sampling.config.use_loss_masking_spans: raise NotImplementedError("FIM is currently not compatible with loss masking.") + if sampling.config.use_preference_loss_masking_spans: + raise NotImplementedError("FIM is currently not compatible with preference loss masking.") self._config = config self._dataset = dataset self._seed = sampling.config.seed diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index c95b3705e..d560660a7 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -34,13 +34,17 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None self._name = name self._prefix = pathlib.Path(prefix) self._has_spans = 0 + self._has_preference_spans = False with self._prefix.with_suffix(".idx").open("rb") as stream: Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER) self._version = struct.unpack(" GPTSample: token_ids = np.frombuffer( self._bin_buffer, @@ -116,13 +153,47 @@ def get( sample_spans = None if use_loss_masking_spans and self._spans is not None: sample_spans = self._spans[idx] - # adjust the spans for the offset and length + + # filter spans that are outside the range of the selected tokens in the document sample_spans = sample_spans[ (sample_spans[:, 0] < offset + len(token_ids)) & (sample_spans[:, 1] >= offset) ] - sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset + + # subtract by offset to normalize span boundaries + sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset # offset sample_spans[:, 1] = np.minimum(sample_spans[:, 1], offset + len(token_ids) - 1) - offset - return GPTSample(token_ids=token_ids, loss_masking_spans=sample_spans) + + chosen_spans = None + rejected_spans = None + if use_preference_loss_masking_spans and self._chosen_spans is not None and self._rejected_spans is not None: + chosen_spans = self._chosen_spans[idx] + + # filter spans that are outside the range of the selected tokens in the document + chosen_sample_spans = chosen_spans[ + (chosen_spans[:, 0] < offset + len(token_ids)) & (chosen_spans[:, 1] >= offset) + ] + + # subtract by offset to normalize span boundaries + chosen_spans[:, 0] = np.maximum(chosen_spans[:, 0], offset) - offset # offset + chosen_spans[:, 1] = np.minimum(chosen_spans[:, 1], offset + len(token_ids) - 1) - offset + + rejected_spans = self._rejected_spans[idx] + + # filter spans that are outside the range of the selected tokens in the document + rejected_sample_spans = rejected_spans[ + (rejected_spans[:, 0] < offset + len(token_ids)) & (rejected_spans[:, 1] >= offset) + ] + + # subtract by offset to normalize span boundaries + rejected_spans[:, 0] = np.maximum(rejected_spans[:, 0], offset) - offset # offset + rejected_spans[:, 1] = np.minimum(rejected_spans[:, 1], offset + len(token_ids) - 1) - offset + + return GPTSample( + token_ids=token_ids, + loss_masking_spans=sample_spans, + chosen_loss_masking_spans=chosen_sample_spans, + rejected_loss_masking_spans=rejected_sample_spans + ) @property def name(self) -> str: @@ -157,6 +228,8 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP # number of spans for each document num_spans = [] spans = [] + chosen_spans = [] + rejected_spans = [] prefix = pathlib.Path(prefix) prefix.parent.mkdir(parents=True, exist_ok=True) @@ -182,6 +255,10 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP if document.loss_masking_spans is not None: num_spans.append(len(document.loss_masking_spans)) spans.append(document.loss_masking_spans) + if document.chosen_loss_masking_spans is not None: + chosen_spans.append(document.chosen_loss_masking_spans) + if document.rejected_loss_masking_spans is not None: + rejected_spans.append(document.rejected_loss_masking_spans) offset += doc_length * np.dtype(dtype).itemsize num_documents += 1 @@ -193,15 +270,26 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP spans = np.vstack(spans, dtype=np.int32) else: spans = np.array(spans, dtype=np.int32) + # if len(chosen_spans) > 0: + # chosen_spans = np.vstack(chosen_spans, dtype=np.int32) + # else: + chosen_spans = np.array(chosen_spans, dtype=np.int32).reshape(-1, 2) + # if len(rejected_spans) > 0: + # rejected_spans = np.vstack(rejected_spans, dtype=np.int32) + # else: + rejected_spans = np.array(rejected_spans, dtype=np.int32).reshape(-1, 2) # Write the index file (.idx) with prefix.with_suffix(".idx").open("wb") as idx_stream: idx_stream.write(MEMMAP_INDEX_HEADER) # Indicates the version # Version 2 optionally adds loss-masking spans - idx_stream.write(struct.pack(" 0 else 0)) + # Flag to indicate whether preference loss-masking spans are present + idx_stream.write(struct.pack(" 0 and rejected_spans.size > 0 else 0)) # Data type idx_stream.write(struct.pack(" None: Create a `GPTSampledDataset` with the requested parameters. """ # Get the document sizes, the main information needed for sampling. - document_sizes = torch.from_numpy(self._indexed_dataset.get_document_sizes()).to(self._device) + self.document_sizes = torch.from_numpy(self._indexed_dataset.get_document_sizes()).to(self._device) # Calculate basic stats. - documents_per_epoch = document_sizes.numel() - tokens_per_epoch = document_sizes.sum().item() + documents_per_epoch = self.document_sizes.numel() + tokens_per_epoch = self.document_sizes.sum().item() # We produce sequences of length `self._sequence_length + 1` so the last token has a label, # but we also include that last label in the following sample, # so we need `sequence_length * num_samples + 1` tokens in total. - num_epochs = math.ceil((self._sequence_length * self._num_samples + 1) / tokens_per_epoch) + if self._config.enable_packing: + num_epochs = math.ceil((self._sequence_length * self._num_samples + 1) / tokens_per_epoch) + else: + num_epochs = math.ceil(self._num_samples / documents_per_epoch) # Prepare for shuffling. generator = torch.Generator(device=self._device) @@ -231,37 +243,45 @@ def _sample(self) -> None: # So it is enough to pre-compute the (zero-padded) token cumsum at regular intervals `TOKEN_CUMSUM_RATE`. # Using `TOKEN_CUMSUM_RATE > 1` reduces pre-computation overhead at the cost of runtime computation. # Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::TOKEN_CUMSUM_RATE]))` - if shuffled_epochs > 0: - token_cumsum_shuffled = self._get_token_cumsum( - document_sizes[ - # Torch indexing only works with int32 or int64 - document_shuffling.to( - dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32 + + if self._config.enable_packing: + if shuffled_epochs > 0: + token_cumsum_shuffled = self._get_token_cumsum( + self.document_sizes[ + # Torch indexing only works with int32 or int64 + document_shuffling.to( + dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32 + ) + ], + offset=unshuffled_epochs * tokens_per_epoch, + dtype=get_unsigned_integer_type(tokens_per_epoch * num_epochs).torch, + ) + self._token_cumsum_shuffled.save(token_cumsum_shuffled.numpy(force=self._config.gpu)) + self._document_shuffling.save( + document_shuffling[: (token_cumsum_shuffled.numel() + 1) * TOKEN_CUMSUM_RATE].numpy( + force=self._config.gpu ) - ], - offset=unshuffled_epochs * tokens_per_epoch, - dtype=get_unsigned_integer_type(tokens_per_epoch * num_epochs).torch, - ) - self._token_cumsum_shuffled.save(token_cumsum_shuffled.numpy(force=self._config.gpu)) + ) + # Free memory + del token_cumsum_shuffled + del document_shuffling + + if unshuffled_epochs > 0: + token_cumsum_unshuffled = self._get_token_cumsum( + self.document_sizes, offset=0, dtype=get_unsigned_integer_type(tokens_per_epoch * num_epochs).torch + ) + self._token_cumsum_unshuffled.save(token_cumsum_unshuffled.numpy(force=self._config.gpu)) + else: self._document_shuffling.save( - document_shuffling[: (token_cumsum_shuffled.numel() + 1) * TOKEN_CUMSUM_RATE].numpy( + document_shuffling[:self._num_samples].numpy( force=self._config.gpu ) ) - # Free memory - del token_cumsum_shuffled - del document_shuffling - - if unshuffled_epochs > 0: - token_cumsum_unshuffled = self._get_token_cumsum( - document_sizes, offset=0, dtype=get_unsigned_integer_type(tokens_per_epoch * num_epochs).torch - ) - self._token_cumsum_unshuffled.save(token_cumsum_unshuffled.numpy(force=self._config.gpu)) def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: torch.dtype) -> torch.Tensor: # Create the output tensor. out = sizes.new_empty(sizes.numel() // TOKEN_CUMSUM_RATE + 1, dtype=dtype) - # Get partial sums for regular intervals, excluding the last incomplete interval. + # Get partial sums for regular intervals, excluding the last incomplete interval. (sum #tokens in groups of 10 documents) torch.sum( sizes[: sizes.numel() - sizes.numel() % TOKEN_CUMSUM_RATE].view(-1, TOKEN_CUMSUM_RATE), dim=1, out=out[1:] ) @@ -287,66 +307,90 @@ def __getitem__(self, index: int) -> typing.Any: The returned sample is ready to be concatenated, then fed to a `GPTModel` (see `GPTModel.preprocess`). """ self._lazy_load() - token_start = index * self._sequence_length - token_end = token_start + self._sequence_length + 1 + if self._config.enable_packing: + # which global token idx to start and end at + token_start = index * self._sequence_length + token_end = token_start + self._sequence_length + 1 + + if token_start < self._unshuffled_tokens: + token_start_array = self._token_cumsum_unshuffled.array + token_start_array_document_offset = 0 + else: + # cumulative sum array + token_start_array = self._token_cumsum_shuffled.array + token_start_array_document_offset = self._unshuffled_documents + + # Find the rightmost location `token_start_cumsum_index` in `token_cumsum` with `token_cumsum[token_start_cumsum_index] <= token_start` + token_start_cumsum_index = np.searchsorted(token_start_array, token_start, side="right").item() - 1 + + # which document to index from after shuffling + document_sampling_index = token_start_cumsum_index * TOKEN_CUMSUM_RATE + token_start_array_document_offset + + # the current token pointer (initialized at the start of document_sampling_index) + token_count = token_start_array[token_start_cumsum_index] + + token_ids = [] + loss_masking_spans = [] + while token_count < token_end: + # Find the document index in the dataset. + if document_sampling_index < self._unshuffled_documents: + document_index = document_sampling_index % self._documents_per_epoch + else: + # which document to index from + document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item() + + document_size = self._indexed_dataset.get_document_size(document_index) + # Determine if the document belongs to the requested sample. + if token_count + document_size >= token_start: + # Determine which part of the document belong to the sample, and add it to the list. + token_start_index_in_document = max(token_start - token_count, 0) + token_end_index_in_document = min(token_end - token_count, document_size) + sample = self._indexed_dataset.get( + document_index, + offset=token_start_index_in_document, + length=token_end_index_in_document - token_start_index_in_document, + use_loss_masking_spans=self._config.use_loss_masking_spans, + use_preference_loss_masking_spans=self._config.use_preference_loss_masking_spans + ) + token_ids.append(sample.token_ids) + if self._config.use_loss_masking_spans: + for loss_masking_span in sample.loss_masking_spans: + # offset span by token_count - token_start for packing sequences + span = np.clip(loss_masking_span + token_count - token_start, 0, self._sequence_length + 1) + if span[1] > span[0]: + loss_masking_spans.append(span) + + # Go to the next document. + document_sampling_index += 1 + token_count += document_size + + sequence_lengths = ( + np.array([ids.size - (idx == len(token_ids) - 1) for idx, ids in enumerate(token_ids)], dtype=np.int32) + if not self._cross_document_attention + else None + ) + token_ids = np.concatenate(token_ids, dtype=np.int64) + loss_masking_spans = ( + np.stack(loss_masking_spans, dtype=np.int32) if self._config.use_loss_masking_spans else None + ) + Assert.eq(len(token_ids), self._sequence_length + 1) - if token_start < self._unshuffled_tokens: - token_start_array = self._token_cumsum_unshuffled.array - token_start_array_document_offset = 0 + return GPTSample(token_ids=token_ids, loss_masking_spans=loss_masking_spans, sequence_lengths=sequence_lengths) else: - token_start_array = self._token_cumsum_shuffled.array - token_start_array_document_offset = self._unshuffled_documents - - # Find the rightmost location `token_start_cumsum_index` in `token_cumsum` with `token_cumsum[token_start_cumsum_index] <= token_start` - token_start_cumsum_index = np.searchsorted(token_start_array, token_start, side="right").item() - 1 - - document_sampling_index = token_start_cumsum_index * TOKEN_CUMSUM_RATE + token_start_array_document_offset - token_count = token_start_array[token_start_cumsum_index] - - token_ids = [] - loss_masking_spans = [] - while token_count < token_end: - # Find the document index in the dataset. - if document_sampling_index < self._unshuffled_documents: - document_index = document_sampling_index % self._documents_per_epoch + if index < self._unshuffled_documents: + document_index = index % self._documents_per_epoch else: - document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item() - - document_size = self._indexed_dataset.get_document_size(document_index) - # Determine if the document belongs to the requested sample. - if token_count + document_size >= token_start: - # Determine which part of the document belong to the sample, and add it to the list. - token_start_index_in_document = max(token_start - token_count, 0) - token_end_index_in_document = min(token_end - token_count, document_size) - sample = self._indexed_dataset.get( - document_index, - offset=token_start_index_in_document, - length=token_end_index_in_document - token_start_index_in_document, - use_loss_masking_spans=self._config.use_loss_masking_spans, - ) - token_ids.append(sample.token_ids) - if self._config.use_loss_masking_spans: - for loss_masking_span in sample.loss_masking_spans: - span = np.clip(loss_masking_span + token_count - token_start, 0, self._sequence_length + 1) - if span[1] > span[0]: - loss_masking_spans.append(span) - - # Go to the next document. - document_sampling_index += 1 - token_count += document_size - - sequence_lengths = ( - np.array([ids.size - (idx == len(token_ids) - 1) for idx, ids in enumerate(token_ids)], dtype=np.int32) - if not self._cross_document_attention - else None - ) - token_ids = np.concatenate(token_ids, dtype=np.int64) - loss_masking_spans = ( - np.stack(loss_masking_spans, dtype=np.int32) if self._config.use_loss_masking_spans else None - ) - Assert.eq(len(token_ids), self._sequence_length + 1) + document_index = self._document_shuffling[index - self._unshuffled_documents].item() + + sample = self._indexed_dataset.get( + document_index, + offset=0, + length=self.document_sizes[document_index], + use_loss_masking_spans=self._config.use_loss_masking_spans, + use_preference_loss_masking_spans=self._config.use_preference_loss_masking_spans + ) - return GPTSample(token_ids=token_ids, loss_masking_spans=loss_masking_spans, sequence_lengths=sequence_lengths) + return sample @property def name(self) -> str: diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 2c4311c37..f5708e9b3 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -59,6 +59,12 @@ class GPTHuggingfaceDatasetConfig(Config): loss_masking_spans: None | str = Field( default=None, desc="Field containing character spans to mask for loss computation", hint=FieldHint.optional ) + chosen_loss_masking_spans: None | str = Field( + default=None, desc="Field containing character chosen spans to mask for loss computation", hint=FieldHint.optional + ) + rejected_loss_masking_spans: None | str = Field( + default=None, desc="Field containing character rejected spans to mask for loss computation", hint=FieldHint.optional + ) data_type: DataType | None = Field( default=None, desc="Data type of the dataset field." diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index b3dae1df1..ba0be311c 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -73,6 +73,37 @@ def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict "token_spans": token_spans, "num_tokens": num_tokens, } + + def _tokenize_preference_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: + input_ids, chosen_token_spans, rejected_token_spans = map( + list, + zip( + *[ + ( + np.array(input_ids, dtype=self._data_type.numpy), + np.array(token_spans[0], dtype=np.int32), + np.array(token_spans[1], dtype=np.int32) + ) + for input_ids, token_spans in [ + self._tokenizer.tokenize_with_spans(text, [chosen_span, rejected_span]) + for text, chosen_span, rejected_span in zip( + batch[self._config.dataset.field], + batch[self._config.dataset.chosen_loss_masking_spans], + batch[self._config.dataset.rejected_loss_masking_spans] + ) + ] + ] + ), + ) + + num_tokens = [len(x) for x in input_ids] + + return { + "input_ids": input_ids, + "chosen_token_spans": chosen_token_spans, + "rejected_token_spans": rejected_token_spans, + "num_tokens": num_tokens + } def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetConfig: shard_idx, shard_dataset = args @@ -86,6 +117,13 @@ def _document_generator(): np.array(item["input_ids"], dtype=self._data_type.numpy), np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2), ) + elif "chosen_token_spans" in shard_dataset.column_names and "rejected_token_spans" in shard_dataset.column_names and self._config.dataset.chosen_loss_masking_spans is not None and self._config.dataset.rejected_loss_masking_spans is not None: + for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): + yield GPTSample( + token_ids=np.array(item["input_ids"], dtype=self._data_type.numpy), + chosen_loss_masking_spans=np.array(item["chosen_token_spans"], dtype=np.int32).reshape(-1, 2), + rejected_loss_masking_spans=np.array(item["rejected_token_spans"], dtype=np.int32).reshape(-1, 2) + ) else: for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): yield GPTSample(np.array(item["input_ids"], dtype=self._data_type.numpy)) @@ -214,10 +252,23 @@ def run(self) -> None: ) if self._config.dataset.field not in dataset.column_names: raise ValueError(f"Dataset does not have field '{self._config.dataset.field}'.") + if self._config.dataset.loss_masking_spans is not None and \ + (self._config.dataset.chosen_loss_masking_spans is not None or self._config.dataset.rejected_loss_masking_spans is not None): + raise ValueError(f"Can not enable both loss masking spans and chosen/rejected loss masking spans.") + if (self._config.dataset.chosen_loss_masking_spans is None) != (self._config.dataset.rejected_loss_masking_spans is None): + raise ValueError(f"Both chosen and rejected loss masking spans must be specified if one is specified.") + + # route tokenize function if self._config.dataset.loss_masking_spans is not None: if self._config.dataset.loss_masking_spans not in dataset.column_names: raise ValueError(f"Dataset does not have spans field '{self._config.dataset.loss_masking_spans}'.") tokenize_fn = self._tokenize_batch_with_spans + elif self._config.dataset.chosen_loss_masking_spans is not None and self._config.dataset.rejected_loss_masking_spans is not None: + if self._config.dataset.chosen_loss_masking_spans not in dataset.column_names: + raise ValueError(f"Dataset does not have chosen spans field '{self._config.dataset.chosen_loss_masking_spans}'.") + if self._config.dataset.rejected_loss_masking_spans not in dataset.column_names: + raise ValueError(f"Dataset does not have rejected spans field '{self._config.dataset.rejected_loss_masking_spans}'.") + tokenize_fn = self._tokenize_preference_batch_with_spans else: tokenize_fn = self._tokenize_batch diff --git a/fast_llm/engine/schedule/schedule.py b/fast_llm/engine/schedule/schedule.py index 87a12bfe9..255e2f51b 100644 --- a/fast_llm/engine/schedule/schedule.py +++ b/fast_llm/engine/schedule/schedule.py @@ -135,7 +135,7 @@ def __init__( if self._batch_config.num_inputs < self._distributed.pipeline_parallel: warnings.warn("Not enough input to achieve true pipeline parallelism.") - # Setup the activation metas. + # Setup the activation metas. (metadata for sequence parallel) self._preprocessed_meta = self._multi_stage.base_model.preprocess_meta( self._batch_config, phase=self._phase, @@ -191,8 +191,8 @@ def get_step( return self._step_map[(type_, stage, data_index)] def _create_index(self) -> None: - self._device_steps: list[list[Step]] = [[] for _ in range(self._distributed.pipeline_parallel)] - self._step_map = {} + self._device_steps: list[list[Step]] = [[] for _ in range(self._distributed.pipeline_parallel)] # steps for each device + self._step_map = {} # map index (type, stage, data index) => step for i, step in enumerate(self._steps): Assert.in_range(step.stage, 0, self._num_stages) Assert.in_range( @@ -203,6 +203,8 @@ def _create_index(self) -> None: Assert.incl(step.type_, (StepType.forward, StepType.backward)) step.global_index = i # TODO: More configurable placement? + + # perform looping here step.pipeline_rank = step.stage % self._distributed.pipeline_parallel step.local_index = len(self._device_steps[step.pipeline_rank]) self._device_steps[step.pipeline_rank].append(step) @@ -222,12 +224,16 @@ def _create_index(self) -> None: Assert.empty(step_map) # Related steps + for i, step in enumerate(self._steps): + # link forward and backward steps together if self._is_training: if step.type_ == StepType.forward: step.backward_step = self.get_step(StepType.backward, *step.map_index[1:]) else: step.forward_step = self.get_step(StepType.forward, *step.map_index[1:]) + + # link the previous step if step.type_ == StepType.forward and step.stage == 0: step.prev_step = None elif step.type_ == StepType.backward and step.stage == self._num_stages - 1: @@ -236,6 +242,8 @@ def _create_index(self) -> None: step.prev_step = self.get_step( step.type_, step.stage + (1 if step.type_ == StepType.backward else -1), *step.map_index[2:] ) + + # link the next step if step.type_ == StepType.backward and step.stage == 0: step.next_step = None elif step.type_ == StepType.forward and step.stage == self._num_stages - 1: diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index bb2f4ed5f..7aa5ed3b0 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -91,6 +91,8 @@ def setup(self, distributed: Distributed) -> None: self._tensor_space.setup(distributed) self._is_setup = True + + # perform preprocessing for sequence parallel def preprocess_meta( self, batch_meta: BatchConfig | torch.Tensor, phase: PhaseType ) -> list[tuple[TensorMeta, dict]]: @@ -136,6 +138,7 @@ def preprocess_meta( else sequence_q_dim ) + # determins if batch dim or sequence dim is first need_sequence_first = hidden_sequence_q_dim.size != sequence_length if self._config.sequence_first is None: sequence_first = need_sequence_first @@ -143,6 +146,7 @@ def preprocess_meta( sequence_first = self._config.sequence_first assert not (need_sequence_first and not sequence_first) + # hidden dim is model hidden size hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) hidden_dims = ( (hidden_sequence_q_dim, batch_dim, hidden_dim) @@ -167,6 +171,7 @@ def preprocess_meta( sequence_k = sequence_k_past + sequence_q_dim.size sequence_k_dim = TensorDim(TransformerDimNames.sequence_k, sequence_k) + # sequence_k_past is start and sequence_k is end of sequence tokens = TensorMeta.from_dims( hidden_dims[:2], tensor_name=f"tokens_{sequence_k_past}_to_{sequence_k-1}", dtype=torch.int64 ) From 40c96c8a913844de2bd355924a05be41cd8133d9 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Thu, 3 Apr 2025 20:42:49 +0000 Subject: [PATCH 02/47] dataset changes for dpo --- fast_llm/data/data/gpt/data.py | 19 +++++++++++++++-- fast_llm/data/dataset/gpt/memmap.py | 28 ++++++++++++------------- fast_llm/data/dataset/gpt/sampled.py | 31 ++++++++++++++++++---------- fast_llm/data/tokenizer.py | 17 +++++++++++++-- 4 files changed, 66 insertions(+), 29 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 8fc333765..a1134c507 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -32,20 +32,34 @@ class GPTBatch: token_ids: torch.Tensor loss_masking_spans: list[torch.Tensor] | None = None sequence_lengths: list[torch.Tensor] | None = None + chosen_loss_masking_spans: list[torch.Tensor] | None = None + rejected_loss_masking_spans: list[torch.Tensor] | None = None def gpt_data_collate_fn( - batch: list[GPTSample], use_loss_masking_spans: bool, cross_document_attention: bool + batch: list[GPTSample], + use_loss_masking_spans: bool, + cross_document_attention: bool, + use_preference_loss_masking_spans: bool ) -> GPTBatch: stacked_ids = np.stack([sample.token_ids for sample in batch]) stacked_spans = None sequence_lengths = None + stacked_chosen_spans = None + stacked_rejected_spans = None if use_loss_masking_spans: stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch] + if use_preference_loss_masking_spans: + stacked_chosen_spans = [torch.from_numpy(sample.chosen_loss_masking_spans) for sample in batch] + stacked_rejected_spans= [torch.from_numpy(sample.rejected_loss_masking_spans) for sample in batch] if not cross_document_attention: sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch] return GPTBatch( - token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, sequence_lengths=sequence_lengths + token_ids=torch.from_numpy(stacked_ids), + loss_masking_spans=stacked_spans, + sequence_lengths=sequence_lengths, + chosen_loss_masking_spans=stacked_chosen_spans, + rejected_loss_masking_spans=stacked_rejected_spans ) @@ -169,6 +183,7 @@ def get_iterator( gpt_data_collate_fn, use_loss_masking_spans=self._config.sampling.use_loss_masking_spans, cross_document_attention=self._cross_document_attention, + use_preference_loss_masking_spans=self._config.sampling.use_preference_loss_masking_spans ), multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, ) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 33c818abb..9576b965a 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -106,7 +106,7 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None dtype=np.int32, count=2, offset=chosen_span_offset + idx * 2 * np.dtype(np.int32).itemsize, - ).reshape(-1, 2) + ) ) rejected_span_offset = offset + self._document_sizes.nbytes + self._pointers.nbytes + np.array(self._chosen_spans).nbytes @@ -117,7 +117,7 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None dtype=np.int32, count=2, offset=rejected_span_offset + idx * 2 * np.dtype(np.int32).itemsize, - ).reshape(-1, 2) + ) ) self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C") @@ -169,30 +169,30 @@ def get( chosen_spans = self._chosen_spans[idx] # filter spans that are outside the range of the selected tokens in the document - chosen_sample_spans = chosen_spans[ - (chosen_spans[:, 0] < offset + len(token_ids)) & (chosen_spans[:, 1] >= offset) - ] + chosen_spans = chosen_spans[ + (chosen_spans[0] < offset + len(token_ids)) & (chosen_spans[1] >= offset) + ][0] # subtract by offset to normalize span boundaries - chosen_spans[:, 0] = np.maximum(chosen_spans[:, 0], offset) - offset # offset - chosen_spans[:, 1] = np.minimum(chosen_spans[:, 1], offset + len(token_ids) - 1) - offset + chosen_spans[0] = np.maximum(chosen_spans[0], offset) - offset # offset + chosen_spans[1] = np.minimum(chosen_spans[1], offset + len(token_ids) - 1) - offset rejected_spans = self._rejected_spans[idx] # filter spans that are outside the range of the selected tokens in the document - rejected_sample_spans = rejected_spans[ - (rejected_spans[:, 0] < offset + len(token_ids)) & (rejected_spans[:, 1] >= offset) - ] + rejected_spans = rejected_spans[ + (rejected_spans[0] < offset + len(token_ids)) & (rejected_spans[1] >= offset) + ][0] # subtract by offset to normalize span boundaries - rejected_spans[:, 0] = np.maximum(rejected_spans[:, 0], offset) - offset # offset - rejected_spans[:, 1] = np.minimum(rejected_spans[:, 1], offset + len(token_ids) - 1) - offset + rejected_spans[0] = np.maximum(rejected_spans[0], offset) - offset # offset + rejected_spans[1] = np.minimum(rejected_spans[1], offset + len(token_ids) - 1) - offset return GPTSample( token_ids=token_ids, loss_masking_spans=sample_spans, - chosen_loss_masking_spans=chosen_sample_spans, - rejected_loss_masking_spans=rejected_sample_spans + chosen_loss_masking_spans=chosen_spans, + rejected_loss_masking_spans=rejected_spans ) @property diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 4d1c02daf..e355ee5a1 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -120,6 +120,9 @@ def __init__( # contains cumulative sum of document sizes grouped by TOKEN_CUMSUM_RATE in shuffled order self._token_cumsum_shuffled = MemmapArray(base_path.with_name(base_path.name + "_shuffled_cumsum.npy")) self._token_cumsum_unshuffled = MemmapArray(base_path.with_name(base_path.name + "_unshuffled_cumsum.npy")) + + self._document_sizes = MemmapArray(base_path.with_name(base_path.name + "_shuffled_cumsum.npy")) + self._yaml_path = base_path.with_suffix(".yaml") # Sample or validate the dataset of a given rank. if sampling.distributed.config.rank == sampling.get_next_rank(): @@ -132,11 +135,11 @@ def _sample(self) -> None: Create a `GPTSampledDataset` with the requested parameters. """ # Get the document sizes, the main information needed for sampling. - self.document_sizes = torch.from_numpy(self._indexed_dataset.get_document_sizes()).to(self._device) + document_sizes = torch.from_numpy(self._indexed_dataset.get_document_sizes()).to(self._device) # Calculate basic stats. - documents_per_epoch = self.document_sizes.numel() - tokens_per_epoch = self.document_sizes.sum().item() + documents_per_epoch = document_sizes.numel() + tokens_per_epoch = document_sizes.sum().item() # We produce sequences of length `self._sequence_length + 1` so the last token has a label, # but we also include that last label in the following sample, # so we need `sequence_length * num_samples + 1` tokens in total. @@ -160,7 +163,7 @@ def _sample(self) -> None: "dataset": { "name": self._indexed_dataset.name, "documents_per_epoch": documents_per_epoch, - "tokens_per_epoch": tokens_per_epoch, + "tokens_per_epoch": tokens_per_epoch }, "num_samples": self._num_samples, "unshuffled_epochs": unshuffled_epochs, @@ -247,7 +250,7 @@ def _sample(self) -> None: if self._config.enable_packing: if shuffled_epochs > 0: token_cumsum_shuffled = self._get_token_cumsum( - self.document_sizes[ + document_sizes[ # Torch indexing only works with int32 or int64 document_shuffling.to( dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32 @@ -268,15 +271,17 @@ def _sample(self) -> None: if unshuffled_epochs > 0: token_cumsum_unshuffled = self._get_token_cumsum( - self.document_sizes, offset=0, dtype=get_unsigned_integer_type(tokens_per_epoch * num_epochs).torch + document_sizes, offset=0, dtype=get_unsigned_integer_type(tokens_per_epoch * num_epochs).torch ) self._token_cumsum_unshuffled.save(token_cumsum_unshuffled.numpy(force=self._config.gpu)) else: - self._document_shuffling.save( - document_shuffling[:self._num_samples].numpy( - force=self._config.gpu + if shuffled_epochs > 0: + self._document_shuffling.save( + document_shuffling[:self._num_samples].numpy( + force=self._config.gpu + ) ) - ) + self._document_sizes.save(document_sizes.numpy(force=self._config.gpu)) def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: torch.dtype) -> torch.Tensor: # Create the output tensor. @@ -385,11 +390,15 @@ def __getitem__(self, index: int) -> typing.Any: sample = self._indexed_dataset.get( document_index, offset=0, - length=self.document_sizes[document_index], + length=self._document_sizes[document_index], use_loss_masking_spans=self._config.use_loss_masking_spans, use_preference_loss_masking_spans=self._config.use_preference_loss_masking_spans ) + chosen_loss_masking_span_end = sample.chosen_loss_masking_spans[1] + 1 + sequence_lengths = np.array([chosen_loss_masking_span_end, len(sample.token_ids) - chosen_loss_masking_span_end]) + sample.sequence_lengths = sequence_lengths + return sample @property diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index 28e105ee8..77d2d315b 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -22,6 +22,8 @@ def __init__(self, config: TokenizerConfig): raise ValueError("Tokenizer does not have an BOS token.") self.eod_id = self.tokenizer.eos_token_id self.bod_id = self.tokenizer.bos_token_id + self.eod_token = self.tokenizer.eos_token + self.bod_token = self.tokenizer.bos_token @property def vocab_size(self) -> int: @@ -52,6 +54,9 @@ def tokenize_with_spans( token_spans = [] char_pos = 0 beginning_of_text = True + if text.startswith(self.bod_token): + beginning_of_text = False + for start, end in char_spans: if char_pos < start: curr_text = text[char_pos:start] @@ -60,7 +65,11 @@ def tokenize_with_spans( input_ids.extend(tokenized_text) curr_text = text[start : end + 1] if end >= len(text) - 1: - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) + tokenized_text = self.tokenize( + curr_text, + begin=beginning_of_text, + end=True if not curr_text.endswith(self.eod_token) else False + ) else: tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) beginning_of_text = False @@ -69,7 +78,11 @@ def tokenize_with_spans( char_pos = end + 1 if char_pos < len(text): curr_text = text[char_pos:] - tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) + tokenized_text = self.tokenize( + curr_text, + begin=beginning_of_text, + end=True if not curr_text.endswith(self.eod_token) else False + ) input_ids.extend(tokenized_text) return input_ids, token_spans From f7796d4b41336c9738f861dc5133a4185a43c7a9 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Thu, 3 Apr 2025 20:45:28 +0000 Subject: [PATCH 03/47] adding dpo loss --- fast_llm/functional/config.py | 4 ++ fast_llm/functional/dpo.py | 55 ++++++++++++++++++++++ fast_llm/layers/language_model/config.py | 14 +++++- fast_llm/layers/language_model/head.py | 59 ++++++++++++++++++++---- fast_llm/models/gpt/model.py | 28 ++++++++++- 5 files changed, 148 insertions(+), 12 deletions(-) create mode 100644 fast_llm/functional/dpo.py diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index 9f1fe005e..b45bc117e 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -91,3 +91,7 @@ class CrossEntropyImpl(str, enum.Enum): torch = "torch" fused = "fused" triton = "triton" + +class LossFunctionType(str, enum.Enum): + cross_entropy = "cross_entropy" + dpo = "dpo" \ No newline at end of file diff --git a/fast_llm/functional/dpo.py b/fast_llm/functional/dpo.py new file mode 100644 index 000000000..384f0e881 --- /dev/null +++ b/fast_llm/functional/dpo.py @@ -0,0 +1,55 @@ +import torch +import torch.nn.functional as F +from typing import Tuple + + +def compute_logps_for_spans( + logits: torch.Tensor, + targets: torch.Tensor, + chosen_span: torch.Tensor, + rejected_span: torch.Tensor + ): + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) + + # gather log probabilities corresponding to the target tokens + # selected_log_probs = log_probs[torch.arange(logits.shape[0] - 1), targets] + selected_log_probs = log_probs[:-1].gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1) + + # apply chosen mask + chosen_mask = torch.zeros_like(selected_log_probs, dtype=torch.bool) + chosen_mask[chosen_span[:, 0]: chosen_span[:, 1] + 1] = 1 + chosen_logp = (selected_log_probs * chosen_mask).sum() + + # apply rejected mask + rejected_mask = torch.zeros_like(selected_log_probs, dtype=torch.bool) + rejected_mask[rejected_span[:, 0]: rejected_span[:, 1] + 1] = 1 + rejected_logp = (selected_log_probs * rejected_mask).sum() + + # chosen_logp = selected_log_probs[chosen_span[:, 0]: chosen_span[:, 1] + 1].sum() + # rejected_logp = selected_log_probs[rejected_span[:, 0]: rejected_span[:, 1] + 1].sum() + + return chosen_logp, rejected_logp + +def compute_simplified_dpo_loss( + logits: torch.Tensor, + targets: torch.Tensor, + chosen_span: torch.Tensor, + rejected_span: torch.Tensor, + beta: float, + grad_output: float | None +) -> Tuple[torch.Tensor, torch.Tensor]: + with torch.enable_grad(): + logits_ = logits.float().detach().requires_grad_() + + policy_chosen_logps, policy_rejected_logps = compute_logps_for_spans(logits_, targets, chosen_span, rejected_span) + + pi_logratios = policy_chosen_logps - policy_rejected_logps + + losses = -F.logsigmoid(beta * pi_logratios) + if grad_output is None: + loss = None + else: + loss = losses.mean() + loss.backward(torch.full_like(loss, grad_output)) + loss.detach() + return loss.detach(), logits_.grad.detach().to(logits.dtype) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 8e3a467cc..864cc1577 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -4,7 +4,7 @@ from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames -from fast_llm.functional.config import CrossEntropyImpl +from fast_llm.functional.config import CrossEntropyImpl, LossFunctionType from fast_llm.layers.transformer.config import TransformerArchitectureConfig, TransformerConfig from fast_llm.utils import Assert @@ -28,6 +28,8 @@ class LanguageModelKwargs: # TODO: These are generic labels = "labels" phase = "phase" + chosen_spans = "chosen_spans" + rejected_spans = "rejected_spans" @config_class() @@ -128,6 +130,16 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): desc="Min value for clamping initialized weights of the vocabulary embedding and output (logits).", hint=FieldHint.feature, ) + loss_function_type: LossFunctionType = Field( + default=LossFunctionType.cross_entropy, + desc="Type of loss function to use", + hint=FieldHint.feature, + ) + beta: float | None = Field( + default=1.0, + desc="Beta value for DPO loss.", + hint=FieldHint.feature, + ) cross_entropy_impl: CrossEntropyImpl = Field( default=CrossEntropyImpl.auto, desc="Implementation for the cross-entropy computation.", diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index efca95b41..5aed40820 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -10,9 +10,10 @@ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward -from fast_llm.functional.config import CrossEntropyImpl, TritonConfig +from fast_llm.functional.config import CrossEntropyImpl, TritonConfig, LossFunctionType from fast_llm.functional.cross_entropy import cross_entropy_forward_backward from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward +from fast_llm.functional.dpo import compute_simplified_dpo_loss from fast_llm.layers.common.auxiliary_loss import z_loss from fast_llm.layers.language_model.config import ( LanguageModelBaseConfig, @@ -74,14 +75,20 @@ def __init__( ), ) - self._cross_entropy_impl = config.cross_entropy_impl - if self._cross_entropy_impl == CrossEntropyImpl.auto: - if self._parallel_embeddings: - self._cross_entropy_impl = CrossEntropyImpl.fused - elif TritonConfig.TRITON_ENABLED: - self._cross_entropy_impl = CrossEntropyImpl.triton - else: - self._cross_entropy_impl = CrossEntropyImpl.fused + self._loss_function_type = config.loss_function_type + if self._loss_function_type == LossFunctionType.cross_entropy: + self._cross_entropy_impl = config.cross_entropy_impl + if self._cross_entropy_impl == CrossEntropyImpl.auto: + if self._parallel_embeddings: + self._cross_entropy_impl = CrossEntropyImpl.fused + elif TritonConfig.TRITON_ENABLED: + self._cross_entropy_impl = CrossEntropyImpl.triton + else: + self._cross_entropy_impl = CrossEntropyImpl.fused + self._loss_fcn = self._logits_cross_entropy_forward_backward_split + else: + self._loss_fcn = self._logits_dpo + self.dpo_beta = config.beta self._forward = wrap_forward_backward(self._forward_backward, grad_is_context) @@ -127,7 +134,7 @@ def _forward_backward( ) output_weights = kwargs[WORD_EMBEDDINGS_WEIGHT] if self._tie_word_embeddings else self.output_weights - loss, ln_output_grad = self._logits_cross_entropy_forward_backward_split( + loss, ln_output_grad = self._loss_fcn( ln_output.detach(), labels, output_weights, grad_output, kwargs, losses ) @@ -136,6 +143,38 @@ def _forward_backward( return loss, input_.grad else: return loss, None + + def _logits_dpo( + self, + input_: torch.Tensor, + labels: torch.Tensor | None, + weight: torch.Tensor, + grad_output: float, + kwargs: dict, + losses: dict | None = None + ) -> tuple[torch.Tensor | None, torch.Tensor | None]: + logits, context = output_parallel_linear_forward( + input_=input_, + weight=weight, + bias=None, + group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, + sequence_parallel=self._sequence_parallel and self._parallel_embeddings, + ) + + loss, grad = compute_simplified_dpo_loss( + logits.flatten(0, -2), + labels, + kwargs[LanguageModelKwargs.chosen_spans], + kwargs[LanguageModelKwargs.rejected_spans], + self.dpo_beta, + grad_output + ) + + # TODO: de-allocate earlier. + del logits + return loss, output_parallel_linear_backward(grad, context).view_as(input_) + + def _logits_cross_entropy_forward_backward_split( self, diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 9a26f58d9..344818995 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -254,7 +254,7 @@ def preprocess( TransformerKwargs.presents: presents, } if phase != PhaseType.inference: - sequence_offset = sequence_k - sequence_q + 1 + sequence_offset = sequence_k - sequence_q + 1 # +1 for shift in labels if sequence_first: labels = batch.token_ids[sequence_offset : sequence_k + 1] else: @@ -266,8 +266,10 @@ def preprocess( for i, spans in enumerate(batch.loss_masking_spans): if not spans.numel(): continue + # filter spans within the sequence or partially within the sequence valid_spans = spans[(spans[:, 0] <= sequence_k) & (spans[:, 1] >= sequence_offset)] if valid_spans.numel(): + # if span is partially within the sequence, truncate parts of spans that are outside of the sequence valid_spans[:, 0].clamp_(min=sequence_offset) valid_spans[:, 1].clamp_(max=sequence_k) valid_spans -= sequence_offset @@ -276,6 +278,30 @@ def preprocess( labels[start : end + 1, i] = -100 else: labels[i, start : end + 1] = -100 + if batch.chosen_loss_masking_spans is not None: + for i, spans in enumerate(batch.chosen_loss_masking_spans): + if not spans.numel(): + continue + # filter spans within the sequence or partially within the sequence + valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)] + if valid_spans.numel(): + # if span is partially within the sequence, truncate parts of spans that are outside of the sequence + valid_spans[:, 0].clamp_(min=sequence_offset) + valid_spans[:, 1].clamp_(max=sequence_k) + valid_spans -= sequence_offset + kwargs[LanguageModelKwargs.chosen_spans] = valid_spans + if batch.rejected_loss_masking_spans is not None: + for i, spans in enumerate(batch.rejected_loss_masking_spans): + if not spans.numel(): + continue + # filter spans within the sequence or partially within the sequence + valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)] + if valid_spans.numel(): + # if span is partially within the sequence, truncate parts of spans that are outside of the sequence + valid_spans[:, 0].clamp_(min=sequence_offset) + valid_spans[:, 1].clamp_(max=sequence_k) + valid_spans -= sequence_offset + kwargs[LanguageModelKwargs.rejected_spans] = valid_spans kwargs[LanguageModelKwargs.labels] = labels if self._config.use_absolute_position_embeddings: self._position_embedding_preprocessor.preprocess(kwargs) From 3c0199f633bc36ab68a349093cb7a4bb5eb1f8cf Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Fri, 4 Apr 2025 21:15:07 +0000 Subject: [PATCH 04/47] packing disabled filter sequennces longer than seq length --- fast_llm/data/dataset/gpt/sampled.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 2499fda29..bfef9fb35 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -96,6 +96,8 @@ def __init__( if self._config.enable_packing and self._config.use_preference_loss_masking_spans: raise NotImplementedError("Packing currently not implemented with preference loss masking.") + if not self._config.enable_packing and self._truncate_documents: + raise NotImplementedError("If packing is disabled, document truncation must also be disabled.") if sampling.cache_directory is None: self._document_shuffling = MemmapArray() @@ -122,7 +124,9 @@ def __init__( self._token_cumsum_shuffled = MemmapArray(base_path.with_name(base_path.name + "_shuffled_cumsum.npy")) self._token_cumsum_unshuffled = MemmapArray(base_path.with_name(base_path.name + "_unshuffled_cumsum.npy")) - self._document_sizes = MemmapArray(base_path.with_name(base_path.name + "_doc_sizes.npy")) + if not self._config.enable_packing: + self._document_sizes = MemmapArray(base_path.with_name(base_path.name + "_doc_sizes.npy")) + self._doc_length_filtered_indicies = MemmapArray(base_path.with_name(base_path.name + "_doc_length_filtered_indices.npy")) self._yaml_path = base_path.with_suffix(".yaml") # Sample or validate the dataset of a given rank. @@ -168,6 +172,7 @@ def _sample(self) -> None: / tokens_per_epoch ) else: + documents_per_epoch = (~long_docs_filter).sum().item() num_epochs = math.ceil(self._num_samples / documents_per_epoch) # Prepare for shuffling. @@ -310,6 +315,18 @@ def _sample(self) -> None: # Free memory del document_shuffling else: + # index of all documents less than seq length long + doc_length_filtered_indicies = torch.nonzero(~long_docs_filter, as_tuple=True)[0] + self._doc_length_filtered_indicies.save( + doc_length_filtered_indicies.numpy( + force=self._config.gpu + ) + ) + + # # apply shuffling on doc_length_filtered_indicies + # document_shuffling_length_filtered_indices = torch.gather( + # doc_length_filtered_indicies, dim=0, index=document_shuffling.to(torch.int64) + # ) if shuffled_epochs > 0: self._document_shuffling.save( document_shuffling[:self._num_samples].numpy( @@ -321,8 +338,6 @@ def _sample(self) -> None: # yaml_data["unshuffled_tokens"] = num_tokens_unshuffled self._yaml_path.parent.mkdir(parents=True, exist_ok=True) yaml.safe_dump(yaml_data, self._yaml_path.open("w")) - - def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: DataType) -> tuple[np.ndarray, int | None]: if self._truncate_documents: @@ -454,9 +469,9 @@ def __getitem__(self, index: int) -> typing.Any: return GPTSample(token_ids=token_ids, loss_masking_spans=loss_masking_spans, sequence_lengths=sequence_lengths) else: if index < self._unshuffled_documents: - document_index = index % self._documents_per_epoch + document_index = self._doc_length_filtered_indicies[index % self._documents_per_epoch] else: - document_index = self._document_shuffling[index - self._unshuffled_documents].item() + document_index = self._doc_length_filtered_indicies[self._document_shuffling[index - self._unshuffled_documents].item()] sample = self._indexed_dataset.get( document_index, From 0e1335bb3eeee30b83ad9eadffe6a24d131014b5 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Fri, 4 Apr 2025 21:52:11 +0000 Subject: [PATCH 05/47] disable no packing for legacy sampling --- fast_llm/data/dataset/gpt/memmap.py | 7 +------ fast_llm/data/dataset/gpt/sampled.py | 6 +++++- fast_llm/engine/schedule/schedule.py | 7 +++---- fast_llm/functional/dpo.py | 4 ---- fast_llm/models/gpt/model.py | 10 +++------- 5 files changed, 12 insertions(+), 22 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 9576b965a..3e4490f92 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -270,13 +270,8 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP spans = np.vstack(spans, dtype=np.int32) else: spans = np.array(spans, dtype=np.int32) - # if len(chosen_spans) > 0: - # chosen_spans = np.vstack(chosen_spans, dtype=np.int32) - # else: + chosen_spans = np.array(chosen_spans, dtype=np.int32).reshape(-1, 2) - # if len(rejected_spans) > 0: - # rejected_spans = np.vstack(rejected_spans, dtype=np.int32) - # else: rejected_spans = np.array(rejected_spans, dtype=np.int32).reshape(-1, 2) # Write the index file (.idx) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index bfef9fb35..5120f950c 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -117,7 +117,7 @@ def __init__( ) # TODO: Names are confusing - # contains document indexes/pointers in order of traversal (shuffled) + # contains shuffled document indicies self._document_shuffling = MemmapArray(base_path.with_name(base_path.name + "_shuffling.npy")) # contains cumulative sum of document sizes grouped by TOKEN_CUMSUM_RATE in shuffled order @@ -521,6 +521,10 @@ def __init__( self._indexed_dataset = indexed_dataset self._num_samples = sampling.num_samples self._sequence_length = sampling.sequence_length + if not sampling.config.enable_packing: + raise NotImplementedError( + "Legacy sampling only supports document packing. Please use the latest dataset format." + ) if not sampling.truncate_documents: raise NotImplementedError( "Legacy sampling only supports document truncation. Please use the latest dataset format." diff --git a/fast_llm/engine/schedule/schedule.py b/fast_llm/engine/schedule/schedule.py index 0bc01c782..92a4b31c6 100644 --- a/fast_llm/engine/schedule/schedule.py +++ b/fast_llm/engine/schedule/schedule.py @@ -135,7 +135,7 @@ def __init__( if self._batch_config.num_inputs < self._distributed.pipeline_parallel: warnings.warn("Not enough input to achieve true pipeline parallelism.") - # Setup the activation metas. (metadata for sequence parallel) + # Setup the activation metas. self._preprocessed_meta = self._multi_stage.base_model.preprocess_meta( self._batch_config, phase=self._phase, @@ -191,8 +191,8 @@ def get_step( return self._step_map[(type_, stage, data_index)] def _create_index(self) -> None: - self._device_steps: list[list[Step]] = [[] for _ in range(self._distributed.pipeline_parallel)] # steps for each device - self._step_map = {} # map index (type, stage, data index) => step + self._device_steps: list[list[Step]] = [[] for _ in range(self._distributed.pipeline_parallel)] + self._step_map = {} for i, step in enumerate(self._steps): Assert.in_range(step.stage, 0, self._num_stages) Assert.in_range( @@ -204,7 +204,6 @@ def _create_index(self) -> None: step.global_index = i # TODO: More configurable placement? - # perform looping here step.pipeline_rank = step.stage % self._distributed.pipeline_parallel step.local_index = len(self._device_steps[step.pipeline_rank]) self._device_steps[step.pipeline_rank].append(step) diff --git a/fast_llm/functional/dpo.py b/fast_llm/functional/dpo.py index 384f0e881..b7ef8ccd3 100644 --- a/fast_llm/functional/dpo.py +++ b/fast_llm/functional/dpo.py @@ -12,7 +12,6 @@ def compute_logps_for_spans( log_probs = torch.nn.functional.log_softmax(logits, dim=-1) # gather log probabilities corresponding to the target tokens - # selected_log_probs = log_probs[torch.arange(logits.shape[0] - 1), targets] selected_log_probs = log_probs[:-1].gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1) # apply chosen mask @@ -25,9 +24,6 @@ def compute_logps_for_spans( rejected_mask[rejected_span[:, 0]: rejected_span[:, 1] + 1] = 1 rejected_logp = (selected_log_probs * rejected_mask).sum() - # chosen_logp = selected_log_probs[chosen_span[:, 0]: chosen_span[:, 1] + 1].sum() - # rejected_logp = selected_log_probs[rejected_span[:, 0]: rejected_span[:, 1] + 1].sum() - return chosen_logp, rejected_logp def compute_simplified_dpo_loss( diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 16072e1d4..01745e031 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -120,7 +120,6 @@ def setup(self, distributed: Distributed) -> None: self._is_setup = True - # perform preprocessing for sequence parallel def preprocess_meta( self, batch_meta: BatchConfig | torch.Tensor, phase: PhaseType ) -> list[tuple[TensorMeta, dict]]: @@ -166,7 +165,6 @@ def preprocess_meta( else sequence_q_dim ) - # determins if batch dim or sequence dim is first need_sequence_first = hidden_sequence_q_dim.size != sequence_length if self._config.sequence_first is None: sequence_first = need_sequence_first @@ -174,7 +172,6 @@ def preprocess_meta( sequence_first = self._config.sequence_first assert not (need_sequence_first and not sequence_first) - # hidden dim is model hidden size hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) hidden_dims = ( (hidden_sequence_q_dim, batch_dim, hidden_dim) @@ -199,7 +196,6 @@ def preprocess_meta( sequence_k = sequence_k_past + sequence_q_dim.size sequence_k_dim = TensorDim(TransformerDimNames.sequence_k, sequence_k) - # sequence_k_past is start and sequence_k is end of sequence tokens = TensorMeta.from_dims( hidden_dims[:2], tensor_name=f"tokens_{sequence_k_past}_to_{sequence_k-1}", dtype=torch.int64 ) @@ -294,7 +290,7 @@ def preprocess( for i, spans in enumerate(batch.loss_masking_spans): if not spans.numel(): continue - # filter spans within the sequence or partially within the sequence + # only keep spans within the sequence or partially within the sequence valid_spans = spans[(spans[:, 0] <= sequence_k) & (spans[:, 1] >= sequence_offset)] if valid_spans.numel(): # if span is partially within the sequence, truncate parts of spans that are outside of the sequence @@ -310,7 +306,7 @@ def preprocess( for i, spans in enumerate(batch.chosen_loss_masking_spans): if not spans.numel(): continue - # filter spans within the sequence or partially within the sequence + # only keep spans within the sequence or partially within the sequence valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)] if valid_spans.numel(): # if span is partially within the sequence, truncate parts of spans that are outside of the sequence @@ -322,7 +318,7 @@ def preprocess( for i, spans in enumerate(batch.rejected_loss_masking_spans): if not spans.numel(): continue - # filter spans within the sequence or partially within the sequence + # only keep spans within the sequence or partially within the sequence valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)] if valid_spans.numel(): # if span is partially within the sequence, truncate parts of spans that are outside of the sequence From 0e0909850c0e152394eee077c1294ec94ad7d46a Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Wed, 9 Apr 2025 00:24:01 +0000 Subject: [PATCH 06/47] adding dpo tests --- tests/data/test_prepare_gpt_memmap.py | 37 ++++++++++++++++++++++++++- tests/test_functional.py | 27 +++++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py index 9a15a051b..086e8b56c 100644 --- a/tests/data/test_prepare_gpt_memmap.py +++ b/tests/data/test_prepare_gpt_memmap.py @@ -25,7 +25,6 @@ def get_preparator(output_path: str, dataset_path_name: str) -> GPTMemmapDataset ) return config.get_dataset_preparator_class()(config=config) - @pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) def test_write_memmap_dataset(dtype): documents = [GPTSample(np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype)) for _ in range(100)] @@ -38,6 +37,42 @@ def test_write_memmap_dataset(dtype): dataset.get(i).token_ids, document.token_ids, equal_nan=True ), f"Mismatch for document {i}: {document} != {dataset.get(i)}." +@pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) +def test_write_memmap_preference_dataset(dtype): + def generate_valid_span(max_seq_length): + span = np.random.choice(np.arange(0, max_seq_length-1), size=2, replace=False) + return np.sort(span) + + vocab_size = 1000 + max_seq_length = 8192 + num_samples = 100 + + documents = [ + GPTSample( + token_ids=np.random.randint(vocab_size, size=max_seq_length).astype(dtype), + chosen_loss_masking_spans=generate_valid_span(max_seq_length=max_seq_length), + rejected_loss_masking_spans=generate_valid_span(max_seq_length=max_seq_length) + ) + for _ in range(num_samples) + ] + with tempfile.TemporaryDirectory() as temp_dir: + prefix = pathlib.Path(temp_dir) + GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents) + dataset = GPTMemmapDataset(name="foo", prefix=prefix) + for i, document in enumerate(documents): + dataset_item = dataset.get(i, use_preference_loss_masking_spans=True) + assert np.array_equal( + dataset_item.token_ids, document.token_ids, equal_nan=True + ), f"Token ids mismatch for document {i}: {document} != {dataset.get(i)}." + + assert np.array_equal( + dataset_item.chosen_loss_masking_spans, document.chosen_loss_masking_spans, equal_nan=True + ), f"Chosen loss masking spans mismatch for document {i}: {document.chosen_loss_masking_spans} != {dataset.get(i).chosen_loss_masking_spans}." + + assert np.array_equal( + dataset_item.rejected_loss_masking_spans, document.rejected_loss_masking_spans, equal_nan=True + ), f"Rejected loss masking spans mismatch for document {i}: {document.rejected_loss_masking_spans} != {dataset.get(i).rejected_loss_masking_spans}." + def test_load_metadata_from_hub(): with tempfile.TemporaryDirectory(suffix="test") as local_folder: diff --git a/tests/test_functional.py b/tests/test_functional.py index 3e5c7f873..bee568949 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -4,9 +4,36 @@ from fast_llm.functional.config import ActivationType, MLPRecomputeLevel from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped, torch_mlp_activation from fast_llm.functional.triton.sparse_copy import get_sparse_map +from fast_llm.functional.dpo import compute_simplified_dpo_loss from fast_llm.utils import Assert from tests.common import requires_cuda +def test_simplified_dpo_loss(): + torch.manual_seed(0) + vocab_size = 10 + seq_length = 10 + logits = torch.randn((seq_length, vocab_size)) + targets = torch.randint(vocab_size, size=(seq_length-1, )) + + dpo_loss, _ = compute_simplified_dpo_loss( + logits=logits, + targets=targets, + chosen_span=torch.tensor([[1, 2]]), + rejected_span=torch.tensor([[4, 5]]), + beta=0.1, + grad_output=0.25 + ) + Assert.rms_close(dpo_loss, torch.tensor(0.71527), 1e-5) + + dpo_loss, _ = compute_simplified_dpo_loss( + logits=logits, + targets=targets, + chosen_span=torch.tensor([[2, 3]]), + rejected_span=torch.tensor([[5, 7]]), + beta=0.3, + grad_output=0.25 + ) + Assert.rms_close(dpo_loss, torch.tensor(0.30449), 1e-5) @requires_cuda @pytest.mark.parametrize("gated", [True, False]) From 1075176a152efcba9b677721ccc78fb905ef82c7 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Thu, 10 Apr 2025 00:14:01 +0000 Subject: [PATCH 07/47] small fix --- fast_llm/data/dataset/gpt/memmap.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 3e4490f92..3270454df 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -72,7 +72,7 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None # read spans self._spans = None - if self._has_spans and self._version == 2: + if self._has_spans and self._version in {2, 3}: self._spans = [] self._num_spans = np.frombuffer( self._index_bin_buffer, @@ -270,7 +270,6 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP spans = np.vstack(spans, dtype=np.int32) else: spans = np.array(spans, dtype=np.int32) - chosen_spans = np.array(chosen_spans, dtype=np.int32).reshape(-1, 2) rejected_spans = np.array(rejected_spans, dtype=np.int32).reshape(-1, 2) From 415634956d3272012d011bc9335cc54617fe53e9 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Thu, 10 Apr 2025 01:44:32 +0000 Subject: [PATCH 08/47] span tokenization updates --- fast_llm/data/tokenizer.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index 77d2d315b..988e23e76 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -22,8 +22,6 @@ def __init__(self, config: TokenizerConfig): raise ValueError("Tokenizer does not have an BOS token.") self.eod_id = self.tokenizer.eos_token_id self.bod_id = self.tokenizer.bos_token_id - self.eod_token = self.tokenizer.eos_token - self.bod_token = self.tokenizer.bos_token @property def vocab_size(self) -> int: @@ -54,8 +52,6 @@ def tokenize_with_spans( token_spans = [] char_pos = 0 beginning_of_text = True - if text.startswith(self.bod_token): - beginning_of_text = False for start, end in char_spans: if char_pos < start: @@ -65,11 +61,7 @@ def tokenize_with_spans( input_ids.extend(tokenized_text) curr_text = text[start : end + 1] if end >= len(text) - 1: - tokenized_text = self.tokenize( - curr_text, - begin=beginning_of_text, - end=True if not curr_text.endswith(self.eod_token) else False - ) + tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) else: tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=False) beginning_of_text = False @@ -78,11 +70,7 @@ def tokenize_with_spans( char_pos = end + 1 if char_pos < len(text): curr_text = text[char_pos:] - tokenized_text = self.tokenize( - curr_text, - begin=beginning_of_text, - end=True if not curr_text.endswith(self.eod_token) else False - ) + tokenized_text = self.tokenize(curr_text, begin=beginning_of_text, end=True) input_ids.extend(tokenized_text) return input_ids, token_spans From 9669211383b8233bd90981962269bdc4137e803d Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Thu, 10 Apr 2025 03:34:09 +0000 Subject: [PATCH 09/47] enable chosen/rejected text for preparator --- fast_llm/data/preparator/gpt_memmap/config.py | 8 +-- .../data/preparator/gpt_memmap/prepare.py | 65 +++++++++++++++---- 2 files changed, 57 insertions(+), 16 deletions(-) diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index f5708e9b3..ce60f00e0 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -59,11 +59,11 @@ class GPTHuggingfaceDatasetConfig(Config): loss_masking_spans: None | str = Field( default=None, desc="Field containing character spans to mask for loss computation", hint=FieldHint.optional ) - chosen_loss_masking_spans: None | str = Field( - default=None, desc="Field containing character chosen spans to mask for loss computation", hint=FieldHint.optional + chosen_text: None | str = Field( + default=None, desc="Field containing chosen text for preference optimization", hint=FieldHint.optional ) - rejected_loss_masking_spans: None | str = Field( - default=None, desc="Field containing character rejected spans to mask for loss computation", hint=FieldHint.optional + rejected_text: None | str = Field( + default=None, desc="Field containing rejected text for preference optimization", hint=FieldHint.optional ) data_type: DataType | None = Field( default=None, diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index ba0be311c..3ef8b9721 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -75,6 +75,35 @@ def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict } def _tokenize_preference_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: + packed_texts = [] + chosen_spans = [] + rejected_spans = [] + + for conv_history, chosen_text, rejected_text in zip( + batch[self._config.dataset.field], + batch[self._config.dataset.chosen_text], + batch[self._config.dataset.rejected_text], + ): + # compute chosen span + full_chosen_text = conv_history + chosen_text + self._tokenizer.tokenizer.eos_token + chosen_span = [len(conv_history), len(full_chosen_text) - 1] + offset = len(full_chosen_text) + chosen_spans.append(chosen_span) + + # compute rejected span + full_rejected_text = self._tokenizer.tokenizer.bos_token + conv_history + rejected_text + rejected_span = [offset + len(self._tokenizer.tokenizer.bos_token + conv_history), offset + len(full_rejected_text) - 1] + rejected_spans.append(rejected_span) + + # pack texts + packed_text = full_chosen_text + full_rejected_text + + assert packed_text[chosen_span[0]: chosen_span[1] + 1] == chosen_text + self._tokenizer.tokenizer.eos_token, f"{packed_text[chosen_span[0]: chosen_span[1] + 1]} does not match {chosen_text}" + + assert packed_text[rejected_span[0]: rejected_span[1] + 1] == rejected_text, f"{packed_text[rejected_span[0]: rejected_span[1] + 1]} does not match {rejected_text}" + packed_texts.append(packed_text) + + # tokenize with spans input_ids, chosen_token_spans, rejected_token_spans = map( list, zip( @@ -82,20 +111,32 @@ def _tokenize_preference_batch_with_spans(self, batch: dict[str, list[typing.Any ( np.array(input_ids, dtype=self._data_type.numpy), np.array(token_spans[0], dtype=np.int32), - np.array(token_spans[1], dtype=np.int32) + np.array([token_spans[1][0], token_spans[1][1] + 1], dtype=np.int32) # adding 1 to end for eos token ) for input_ids, token_spans in [ self._tokenizer.tokenize_with_spans(text, [chosen_span, rejected_span]) for text, chosen_span, rejected_span in zip( - batch[self._config.dataset.field], - batch[self._config.dataset.chosen_loss_masking_spans], - batch[self._config.dataset.rejected_loss_masking_spans] + packed_texts, + chosen_spans, + rejected_spans ) ] ] ), ) + # verify span tokenization + for input_ids_arr, chosen_token_span, rejected_token_span, chosen_text, rejected_text in zip( + input_ids, + chosen_token_spans, + rejected_token_spans, + batch[self._config.dataset.chosen_text], + batch[self._config.dataset.rejected_text] + ): + assert self._tokenizer.tokenizer.decode(input_ids_arr[chosen_token_span[0]: chosen_token_span[1] + 1]) == chosen_text + self._tokenizer.tokenizer.eos_token + + assert self._tokenizer.tokenizer.decode(input_ids_arr[rejected_token_span[0]: rejected_token_span[1] + 1]) == rejected_text + self._tokenizer.tokenizer.eos_token + num_tokens = [len(x) for x in input_ids] return { @@ -117,7 +158,7 @@ def _document_generator(): np.array(item["input_ids"], dtype=self._data_type.numpy), np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2), ) - elif "chosen_token_spans" in shard_dataset.column_names and "rejected_token_spans" in shard_dataset.column_names and self._config.dataset.chosen_loss_masking_spans is not None and self._config.dataset.rejected_loss_masking_spans is not None: + elif "chosen_token_spans" in shard_dataset.column_names and "rejected_token_spans" in shard_dataset.column_names and self._config.dataset.chosen_text is not None and self._config.dataset.rejected_text is not None: for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): yield GPTSample( token_ids=np.array(item["input_ids"], dtype=self._data_type.numpy), @@ -253,9 +294,9 @@ def run(self) -> None: if self._config.dataset.field not in dataset.column_names: raise ValueError(f"Dataset does not have field '{self._config.dataset.field}'.") if self._config.dataset.loss_masking_spans is not None and \ - (self._config.dataset.chosen_loss_masking_spans is not None or self._config.dataset.rejected_loss_masking_spans is not None): + (self._config.dataset.chosen_text is not None or self._config.dataset.rejected_text is not None): raise ValueError(f"Can not enable both loss masking spans and chosen/rejected loss masking spans.") - if (self._config.dataset.chosen_loss_masking_spans is None) != (self._config.dataset.rejected_loss_masking_spans is None): + if (self._config.dataset.chosen_text is None) != (self._config.dataset.rejected_text is None): raise ValueError(f"Both chosen and rejected loss masking spans must be specified if one is specified.") # route tokenize function @@ -263,11 +304,11 @@ def run(self) -> None: if self._config.dataset.loss_masking_spans not in dataset.column_names: raise ValueError(f"Dataset does not have spans field '{self._config.dataset.loss_masking_spans}'.") tokenize_fn = self._tokenize_batch_with_spans - elif self._config.dataset.chosen_loss_masking_spans is not None and self._config.dataset.rejected_loss_masking_spans is not None: - if self._config.dataset.chosen_loss_masking_spans not in dataset.column_names: - raise ValueError(f"Dataset does not have chosen spans field '{self._config.dataset.chosen_loss_masking_spans}'.") - if self._config.dataset.rejected_loss_masking_spans not in dataset.column_names: - raise ValueError(f"Dataset does not have rejected spans field '{self._config.dataset.rejected_loss_masking_spans}'.") + elif self._config.dataset.chosen_text is not None and self._config.dataset.rejected_text is not None: + if self._config.dataset.chosen_text not in dataset.column_names: + raise ValueError(f"Dataset does not have chosen spans field '{self._config.dataset.chosen_text}'.") + if self._config.dataset.rejected_text not in dataset.column_names: + raise ValueError(f"Dataset does not have rejected spans field '{self._config.dataset.rejected_text}'.") tokenize_fn = self._tokenize_preference_batch_with_spans else: tokenize_fn = self._tokenize_batch From 257d236ce06805f3116cfadf0377a9a6e22987ab Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Thu, 10 Apr 2025 20:27:47 +0000 Subject: [PATCH 10/47] removing assert --- .../data/preparator/gpt_memmap/prepare.py | 60 +++++++++---------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 3ef8b9721..f7d0c1b77 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -73,12 +73,12 @@ def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict "token_spans": token_spans, "num_tokens": num_tokens, } - + def _tokenize_preference_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: packed_texts = [] chosen_spans = [] rejected_spans = [] - + for conv_history, chosen_text, rejected_text in zip( batch[self._config.dataset.field], batch[self._config.dataset.chosen_text], @@ -92,15 +92,22 @@ def _tokenize_preference_batch_with_spans(self, batch: dict[str, list[typing.Any # compute rejected span full_rejected_text = self._tokenizer.tokenizer.bos_token + conv_history + rejected_text - rejected_span = [offset + len(self._tokenizer.tokenizer.bos_token + conv_history), offset + len(full_rejected_text) - 1] + rejected_span = [ + offset + len(self._tokenizer.tokenizer.bos_token + conv_history), + offset + len(full_rejected_text) - 1, + ] rejected_spans.append(rejected_span) # pack texts packed_text = full_chosen_text + full_rejected_text - assert packed_text[chosen_span[0]: chosen_span[1] + 1] == chosen_text + self._tokenizer.tokenizer.eos_token, f"{packed_text[chosen_span[0]: chosen_span[1] + 1]} does not match {chosen_text}" + assert ( + packed_text[chosen_span[0] : chosen_span[1] + 1] == chosen_text + self._tokenizer.tokenizer.eos_token + ), f"{packed_text[chosen_span[0]: chosen_span[1] + 1]} does not match {chosen_text}" - assert packed_text[rejected_span[0]: rejected_span[1] + 1] == rejected_text, f"{packed_text[rejected_span[0]: rejected_span[1] + 1]} does not match {rejected_text}" + assert ( + packed_text[rejected_span[0] : rejected_span[1] + 1] == rejected_text + ), f"{packed_text[rejected_span[0]: rejected_span[1] + 1]} does not match {rejected_text}" packed_texts.append(packed_text) # tokenize with spans @@ -111,39 +118,24 @@ def _tokenize_preference_batch_with_spans(self, batch: dict[str, list[typing.Any ( np.array(input_ids, dtype=self._data_type.numpy), np.array(token_spans[0], dtype=np.int32), - np.array([token_spans[1][0], token_spans[1][1] + 1], dtype=np.int32) # adding 1 to end for eos token + np.array( + [token_spans[1][0], token_spans[1][1] + 1], dtype=np.int32 + ), # adding 1 to end for eos token ) for input_ids, token_spans in [ self._tokenizer.tokenize_with_spans(text, [chosen_span, rejected_span]) - for text, chosen_span, rejected_span in zip( - packed_texts, - chosen_spans, - rejected_spans - ) + for text, chosen_span, rejected_span in zip(packed_texts, chosen_spans, rejected_spans) ] ] ), ) - # verify span tokenization - for input_ids_arr, chosen_token_span, rejected_token_span, chosen_text, rejected_text in zip( - input_ids, - chosen_token_spans, - rejected_token_spans, - batch[self._config.dataset.chosen_text], - batch[self._config.dataset.rejected_text] - ): - assert self._tokenizer.tokenizer.decode(input_ids_arr[chosen_token_span[0]: chosen_token_span[1] + 1]) == chosen_text + self._tokenizer.tokenizer.eos_token - - assert self._tokenizer.tokenizer.decode(input_ids_arr[rejected_token_span[0]: rejected_token_span[1] + 1]) == rejected_text + self._tokenizer.tokenizer.eos_token - num_tokens = [len(x) for x in input_ids] - return { "input_ids": input_ids, "chosen_token_spans": chosen_token_spans, "rejected_token_spans": rejected_token_spans, - "num_tokens": num_tokens + "num_tokens": num_tokens, } def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetConfig: @@ -158,12 +150,19 @@ def _document_generator(): np.array(item["input_ids"], dtype=self._data_type.numpy), np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2), ) - elif "chosen_token_spans" in shard_dataset.column_names and "rejected_token_spans" in shard_dataset.column_names and self._config.dataset.chosen_text is not None and self._config.dataset.rejected_text is not None: + elif ( + "chosen_token_spans" in shard_dataset.column_names + and "rejected_token_spans" in shard_dataset.column_names + and self._config.dataset.chosen_text is not None + and self._config.dataset.rejected_text is not None + ): for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): yield GPTSample( token_ids=np.array(item["input_ids"], dtype=self._data_type.numpy), chosen_loss_masking_spans=np.array(item["chosen_token_spans"], dtype=np.int32).reshape(-1, 2), - rejected_loss_masking_spans=np.array(item["rejected_token_spans"], dtype=np.int32).reshape(-1, 2) + rejected_loss_masking_spans=np.array(item["rejected_token_spans"], dtype=np.int32).reshape( + -1, 2 + ), ) else: for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): @@ -293,12 +292,13 @@ def run(self) -> None: ) if self._config.dataset.field not in dataset.column_names: raise ValueError(f"Dataset does not have field '{self._config.dataset.field}'.") - if self._config.dataset.loss_masking_spans is not None and \ - (self._config.dataset.chosen_text is not None or self._config.dataset.rejected_text is not None): + if self._config.dataset.loss_masking_spans is not None and ( + self._config.dataset.chosen_text is not None or self._config.dataset.rejected_text is not None + ): raise ValueError(f"Can not enable both loss masking spans and chosen/rejected loss masking spans.") if (self._config.dataset.chosen_text is None) != (self._config.dataset.rejected_text is None): raise ValueError(f"Both chosen and rejected loss masking spans must be specified if one is specified.") - + # route tokenize function if self._config.dataset.loss_masking_spans is not None: if self._config.dataset.loss_masking_spans not in dataset.column_names: From aa8a87183ca38904b730702d74ce9c28fd575ca3 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Thu, 10 Apr 2025 20:33:32 +0000 Subject: [PATCH 11/47] moving dpo loss call --- fast_llm/layers/language_model/head.py | 71 +++++++++----------------- 1 file changed, 25 insertions(+), 46 deletions(-) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 80cd259a7..8a49be0a9 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -10,10 +10,10 @@ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward -from fast_llm.functional.config import CrossEntropyImpl, TritonConfig, LossFunctionType +from fast_llm.functional.config import CrossEntropyImpl, LossFunctionType, TritonConfig from fast_llm.functional.cross_entropy import cross_entropy_forward_backward -from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward from fast_llm.functional.dpo import compute_simplified_dpo_loss +from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.language_model.config import ( LanguageModelBaseConfig, @@ -89,10 +89,10 @@ def __init__( self._cross_entropy_impl = CrossEntropyImpl.triton else: self._cross_entropy_impl = CrossEntropyImpl.fused - self._loss_fcn = self._logits_cross_entropy_forward_backward_split - else: - self._loss_fcn = self._logits_dpo + elif self._loss_function_type == LossFunctionType.dpo: self.dpo_beta = config.beta + else: + raise NotImplementedError(f"Loss function type {self._loss_function_type} not supported.") self._forward = wrap_forward_backward(self._forward_backward, grad_is_context) @@ -172,7 +172,7 @@ def _forward_backward( ) output_weights = self._get_output_weights(kwargs) - loss, ln_output_grad = self._loss_fcn( + loss, ln_output_grad = self._logits_cross_entropy_forward_backward_split( ln_output.detach(), labels, output_weights, grad_output, kwargs, losses ) @@ -181,38 +181,6 @@ def _forward_backward( return loss, input_.grad else: return loss, None - - def _logits_dpo( - self, - input_: torch.Tensor, - labels: torch.Tensor | None, - weight: torch.Tensor, - grad_output: float, - kwargs: dict, - losses: dict | None = None - ) -> tuple[torch.Tensor | None, torch.Tensor | None]: - logits, context = output_parallel_linear_forward( - input_=input_, - weight=weight, - bias=None, - group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, - sequence_parallel=self._sequence_parallel and self._parallel_embeddings, - ) - - loss, grad = compute_simplified_dpo_loss( - logits.flatten(0, -2), - labels, - kwargs[LanguageModelKwargs.chosen_spans], - kwargs[LanguageModelKwargs.rejected_spans], - self.dpo_beta, - grad_output - ) - - # TODO: de-allocate earlier. - del logits - return loss, output_parallel_linear_backward(grad, context).view_as(input_) - - def _get_output_weights(self, kwargs: dict) -> torch.Tensor: if self._tie_word_embeddings: @@ -326,14 +294,25 @@ def _logits_cross_entropy_forward_backward( if labels is None: return logits * self._logits_scale_factor, None - loss, grad = cross_entropy_forward_backward( - logits.flatten(0, -2), - labels, - group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, - grad_output=grad_output, - implementation=self._cross_entropy_impl, - logits_scale_factor=self._logits_scale_factor, - ) + if self._loss_function_type == LossFunctionType.cross_entropy: + loss, grad = cross_entropy_forward_backward( + logits.flatten(0, -2), + labels, + group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, + grad_output=grad_output, + implementation=self._cross_entropy_impl, + logits_scale_factor=self._logits_scale_factor, + ) + elif self._loss_function_type == LossFunctionType.dpo: + loss, grad = compute_simplified_dpo_loss( + logits.flatten(0, -2), + labels, + kwargs[LanguageModelKwargs.chosen_spans], + kwargs[LanguageModelKwargs.rejected_spans], + self.dpo_beta, + grad_output, + ) + # TODO: de-allocate earlier. del logits return loss, output_parallel_linear_backward(grad, context) From d08bf4d9130c34610a64b5c47601fc0cadf05f68 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Thu, 10 Apr 2025 21:20:49 +0000 Subject: [PATCH 12/47] renaming --- fast_llm/data/data/gpt/data.py | 18 +++--- fast_llm/data/dataset/gpt/memmap.py | 59 ++++++++++--------- fast_llm/data/dataset/gpt/sampled.py | 49 ++++++++------- .../data/preparator/gpt_memmap/prepare.py | 4 +- tests/data/test_prepare_gpt_memmap.py | 16 ++--- 5 files changed, 79 insertions(+), 67 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 0cd4ff87a..cab925ea2 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -37,10 +37,10 @@ class GPTBatch: def gpt_data_collate_fn( - batch: list[GPTSample], - use_loss_masking_spans: bool, + batch: list[GPTSample], + use_loss_masking_spans: bool, cross_document_attention: bool, - use_preference_loss_masking_spans: bool + use_preference_loss_masking_spans: bool, ) -> GPTBatch: stacked_ids = np.stack([sample.token_ids for sample in batch]) stacked_spans = None @@ -50,16 +50,16 @@ def gpt_data_collate_fn( if use_loss_masking_spans: stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch] if use_preference_loss_masking_spans: - stacked_chosen_spans = [torch.from_numpy(sample.chosen_loss_masking_spans) for sample in batch] - stacked_rejected_spans= [torch.from_numpy(sample.rejected_loss_masking_spans) for sample in batch] + stacked_chosen_spans = [torch.from_numpy(sample.chosen_loss_masking_span) for sample in batch] + stacked_rejected_spans = [torch.from_numpy(sample.rejected_loss_masking_span) for sample in batch] if not cross_document_attention: sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch] return GPTBatch( - token_ids=torch.from_numpy(stacked_ids), - loss_masking_spans=stacked_spans, + token_ids=torch.from_numpy(stacked_ids), + loss_masking_spans=stacked_spans, sequence_lengths=sequence_lengths, chosen_loss_masking_spans=stacked_chosen_spans, - rejected_loss_masking_spans=stacked_rejected_spans + rejected_loss_masking_spans=stacked_rejected_spans, ) @@ -185,7 +185,7 @@ def get_iterator( gpt_data_collate_fn, use_loss_masking_spans=self._config.sampling.use_loss_masking_spans, cross_document_attention=self._cross_document_attention, - use_preference_loss_masking_spans=self._config.sampling.use_preference_loss_masking_spans + use_preference_loss_masking_spans=self._config.sampling.use_preference_loss_masking_spans, ), multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, ) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 3270454df..36701a8e5 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -108,8 +108,10 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None offset=chosen_span_offset + idx * 2 * np.dtype(np.int32).itemsize, ) ) - - rejected_span_offset = offset + self._document_sizes.nbytes + self._pointers.nbytes + np.array(self._chosen_spans).nbytes + + rejected_span_offset = ( + offset + self._document_sizes.nbytes + self._pointers.nbytes + np.array(self._chosen_spans).nbytes + ) for idx in range(self._num_documents): self._rejected_spans.append( np.frombuffer( @@ -142,7 +144,12 @@ def __del__(self): del self._index_bin_buffer_mmap def get( - self, idx: int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False, use_preference_loss_masking_spans: bool = False + self, + idx: int, + offset: int = 0, + length: int | None = None, + use_loss_masking_spans: bool = False, + use_preference_loss_masking_spans: bool = False, ) -> GPTSample: token_ids = np.frombuffer( self._bin_buffer, @@ -160,39 +167,37 @@ def get( ] # subtract by offset to normalize span boundaries - sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset # offset + sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset # offset sample_spans[:, 1] = np.minimum(sample_spans[:, 1], offset + len(token_ids) - 1) - offset - - chosen_spans = None - rejected_spans = None + + chosen_span = None + rejected_span = None if use_preference_loss_masking_spans and self._chosen_spans is not None and self._rejected_spans is not None: - chosen_spans = self._chosen_spans[idx] + chosen_span = self._chosen_spans[idx] # filter spans that are outside the range of the selected tokens in the document - chosen_spans = chosen_spans[ - (chosen_spans[0] < offset + len(token_ids)) & (chosen_spans[1] >= offset) - ][0] + chosen_span = chosen_span[(chosen_span[0] < offset + len(token_ids)) & (chosen_span[1] >= offset)][0] # subtract by offset to normalize span boundaries - chosen_spans[0] = np.maximum(chosen_spans[0], offset) - offset # offset - chosen_spans[1] = np.minimum(chosen_spans[1], offset + len(token_ids) - 1) - offset + chosen_span[0] = np.maximum(chosen_span[0], offset) - offset # offset + chosen_span[1] = np.minimum(chosen_span[1], offset + len(token_ids) - 1) - offset - rejected_spans = self._rejected_spans[idx] + rejected_span = self._rejected_spans[idx] # filter spans that are outside the range of the selected tokens in the document - rejected_spans = rejected_spans[ - (rejected_spans[0] < offset + len(token_ids)) & (rejected_spans[1] >= offset) - ][0] + rejected_span = rejected_span[(rejected_span[0] < offset + len(token_ids)) & (rejected_span[1] >= offset)][ + 0 + ] # subtract by offset to normalize span boundaries - rejected_spans[0] = np.maximum(rejected_spans[0], offset) - offset # offset - rejected_spans[1] = np.minimum(rejected_spans[1], offset + len(token_ids) - 1) - offset + rejected_span[0] = np.maximum(rejected_span[0], offset) - offset # offset + rejected_span[1] = np.minimum(rejected_span[1], offset + len(token_ids) - 1) - offset return GPTSample( - token_ids=token_ids, - loss_masking_spans=sample_spans, - chosen_loss_masking_spans=chosen_spans, - rejected_loss_masking_spans=rejected_spans + token_ids=token_ids, + loss_masking_spans=sample_spans, + chosen_loss_masking_span=chosen_span, + rejected_loss_masking_span=rejected_span, ) @property @@ -255,10 +260,10 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP if document.loss_masking_spans is not None: num_spans.append(len(document.loss_masking_spans)) spans.append(document.loss_masking_spans) - if document.chosen_loss_masking_spans is not None: - chosen_spans.append(document.chosen_loss_masking_spans) - if document.rejected_loss_masking_spans is not None: - rejected_spans.append(document.rejected_loss_masking_spans) + if document.chosen_loss_masking_span is not None: + chosen_spans.append(document.chosen_loss_masking_span) + if document.rejected_loss_masking_span is not None: + rejected_spans.append(document.rejected_loss_masking_span) offset += doc_length * np.dtype(dtype).itemsize num_documents += 1 diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 854ec9cb1..7927a859c 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -30,8 +30,8 @@ class GPTSample: token_ids: np.ndarray loss_masking_spans: np.ndarray | None = None - chosen_loss_masking_spans: np.ndarray | None = None - rejected_loss_masking_spans: np.ndarray | None = None + chosen_loss_masking_span: np.ndarray | None = None + rejected_loss_masking_span: np.ndarray | None = None sequence_lengths: np.ndarray | None = None @@ -126,7 +126,9 @@ def __init__( if not self._config.enable_packing: self._document_sizes = MemmapArray(base_path.with_name(base_path.name + "_doc_sizes.npy")) - self._doc_length_filtered_indicies = MemmapArray(base_path.with_name(base_path.name + "_doc_length_filtered_indices.npy")) + self._doc_length_filtered_indicies = MemmapArray( + base_path.with_name(base_path.name + "_doc_length_filtered_indices.npy") + ) self._yaml_path = base_path.with_suffix(".yaml") # Sample or validate the dataset of a given rank. @@ -168,7 +170,10 @@ def _sample(self) -> None: # so we need `sequence_length * num_samples + 1` tokens in total. if self._config.enable_packing: num_epochs = math.ceil( - ((self._sequence_length + 1 - self._truncate_documents) * self._num_samples + 1 * self._truncate_documents) + ( + (self._sequence_length + 1 - self._truncate_documents) * self._num_samples + + 1 * self._truncate_documents + ) / tokens_per_epoch ) else: @@ -190,7 +195,7 @@ def _sample(self) -> None: "dataset": { "name": self._indexed_dataset.name, "documents_per_epoch": documents_per_epoch, - "tokens_per_epoch": tokens_per_epoch + "tokens_per_epoch": tokens_per_epoch, }, "num_samples": self._num_samples, "unshuffled_epochs": unshuffled_epochs, @@ -317,22 +322,14 @@ def _sample(self) -> None: else: # index of all documents less than seq length long doc_length_filtered_indicies = torch.nonzero(~long_docs_filter, as_tuple=True)[0] - self._doc_length_filtered_indicies.save( - doc_length_filtered_indicies.numpy( - force=self._config.gpu - ) - ) + self._doc_length_filtered_indicies.save(doc_length_filtered_indicies.numpy(force=self._config.gpu)) # # apply shuffling on doc_length_filtered_indicies # document_shuffling_length_filtered_indices = torch.gather( # doc_length_filtered_indicies, dim=0, index=document_shuffling.to(torch.int64) # ) if shuffled_epochs > 0: - self._document_shuffling.save( - document_shuffling[:self._num_samples].numpy( - force=self._config.gpu - ) - ) + self._document_shuffling.save(document_shuffling[: self._num_samples].numpy(force=self._config.gpu)) self._document_sizes.save(document_sizes.numpy(force=self._config.gpu)) if self._yaml_path is not None: # yaml_data["unshuffled_tokens"] = num_tokens_unshuffled @@ -409,7 +406,9 @@ def __getitem__(self, index: int) -> typing.Any: if document_sampling_index < self._unshuffled_documents: document_index = document_sampling_index % self._documents_per_epoch else: - document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item() + document_index = self._document_shuffling[ + document_sampling_index - self._unshuffled_documents + ].item() document_size = self._indexed_dataset.get_document_size(document_index) @@ -466,23 +465,29 @@ def __getitem__(self, index: int) -> typing.Any: ) Assert.eq(len(token_ids), self._sequence_length + 1) - return GPTSample(token_ids=token_ids, loss_masking_spans=loss_masking_spans, sequence_lengths=sequence_lengths) + return GPTSample( + token_ids=token_ids, loss_masking_spans=loss_masking_spans, sequence_lengths=sequence_lengths + ) else: if index < self._unshuffled_documents: document_index = self._doc_length_filtered_indicies[index % self._documents_per_epoch] else: - document_index = self._doc_length_filtered_indicies[self._document_shuffling[index - self._unshuffled_documents].item()] - + document_index = self._doc_length_filtered_indicies[ + self._document_shuffling[index - self._unshuffled_documents].item() + ] + sample = self._indexed_dataset.get( document_index, offset=0, length=self._document_sizes[document_index], use_loss_masking_spans=self._config.use_loss_masking_spans, - use_preference_loss_masking_spans=self._config.use_preference_loss_masking_spans + use_preference_loss_masking_spans=self._config.use_preference_loss_masking_spans, ) - chosen_loss_masking_span_end = sample.chosen_loss_masking_spans[1] + 1 - sequence_lengths = np.array([chosen_loss_masking_span_end, len(sample.token_ids) - chosen_loss_masking_span_end]) + chosen_loss_masking_span_end = sample.chosen_loss_masking_span[1] + 1 + sequence_lengths = np.array( + [chosen_loss_masking_span_end, len(sample.token_ids) - chosen_loss_masking_span_end] + ) sample.sequence_lengths = sequence_lengths return sample diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index f7d0c1b77..c2e4f018b 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -159,8 +159,8 @@ def _document_generator(): for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): yield GPTSample( token_ids=np.array(item["input_ids"], dtype=self._data_type.numpy), - chosen_loss_masking_spans=np.array(item["chosen_token_spans"], dtype=np.int32).reshape(-1, 2), - rejected_loss_masking_spans=np.array(item["rejected_token_spans"], dtype=np.int32).reshape( + chosen_loss_masking_span=np.array(item["chosen_token_spans"], dtype=np.int32).reshape(-1, 2), + rejected_loss_masking_span=np.array(item["rejected_token_spans"], dtype=np.int32).reshape( -1, 2 ), ) diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py index 086e8b56c..b1a488f4f 100644 --- a/tests/data/test_prepare_gpt_memmap.py +++ b/tests/data/test_prepare_gpt_memmap.py @@ -25,6 +25,7 @@ def get_preparator(output_path: str, dataset_path_name: str) -> GPTMemmapDataset ) return config.get_dataset_preparator_class()(config=config) + @pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) def test_write_memmap_dataset(dtype): documents = [GPTSample(np.random.randint(1000, size=np.random.randint(1, 100)).astype(dtype)) for _ in range(100)] @@ -37,10 +38,11 @@ def test_write_memmap_dataset(dtype): dataset.get(i).token_ids, document.token_ids, equal_nan=True ), f"Mismatch for document {i}: {document} != {dataset.get(i)}." + @pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) def test_write_memmap_preference_dataset(dtype): def generate_valid_span(max_seq_length): - span = np.random.choice(np.arange(0, max_seq_length-1), size=2, replace=False) + span = np.random.choice(np.arange(0, max_seq_length - 1), size=2, replace=False) return np.sort(span) vocab_size = 1000 @@ -50,8 +52,8 @@ def generate_valid_span(max_seq_length): documents = [ GPTSample( token_ids=np.random.randint(vocab_size, size=max_seq_length).astype(dtype), - chosen_loss_masking_spans=generate_valid_span(max_seq_length=max_seq_length), - rejected_loss_masking_spans=generate_valid_span(max_seq_length=max_seq_length) + chosen_loss_masking_span=generate_valid_span(max_seq_length=max_seq_length), + rejected_loss_masking_span=generate_valid_span(max_seq_length=max_seq_length), ) for _ in range(num_samples) ] @@ -66,12 +68,12 @@ def generate_valid_span(max_seq_length): ), f"Token ids mismatch for document {i}: {document} != {dataset.get(i)}." assert np.array_equal( - dataset_item.chosen_loss_masking_spans, document.chosen_loss_masking_spans, equal_nan=True - ), f"Chosen loss masking spans mismatch for document {i}: {document.chosen_loss_masking_spans} != {dataset.get(i).chosen_loss_masking_spans}." + dataset_item.chosen_loss_masking_span, document.chosen_loss_masking_span, equal_nan=True + ), f"Chosen loss masking spans mismatch for document {i}: {document.chosen_loss_masking_span} != {dataset.get(i).chosen_loss_masking_span}." assert np.array_equal( - dataset_item.rejected_loss_masking_spans, document.rejected_loss_masking_spans, equal_nan=True - ), f"Rejected loss masking spans mismatch for document {i}: {document.rejected_loss_masking_spans} != {dataset.get(i).rejected_loss_masking_spans}." + dataset_item.rejected_loss_masking_span, document.rejected_loss_masking_span, equal_nan=True + ), f"Rejected loss masking spans mismatch for document {i}: {document.rejected_loss_masking_span} != {dataset.get(i).rejected_loss_masking_span}." def test_load_metadata_from_hub(): From b41021063382b428c05bd16ce413d339a7162e5f Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Sun, 13 Apr 2025 19:45:31 +0000 Subject: [PATCH 13/47] padding fix --- fast_llm/data/dataset/gpt/sampled.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 7927a859c..328e2f66c 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -320,6 +320,9 @@ def _sample(self) -> None: # Free memory del document_shuffling else: + if not self._truncate_documents: + yaml_data["unshuffled_tokens"] = None # not used with packing disabled + # index of all documents less than seq length long doc_length_filtered_indicies = torch.nonzero(~long_docs_filter, as_tuple=True)[0] self._doc_length_filtered_indicies.save(doc_length_filtered_indicies.numpy(force=self._config.gpu)) @@ -485,10 +488,18 @@ def __getitem__(self, index: int) -> typing.Any: ) chosen_loss_masking_span_end = sample.chosen_loss_masking_span[1] + 1 - sequence_lengths = np.array( - [chosen_loss_masking_span_end, len(sample.token_ids) - chosen_loss_masking_span_end] - ) - sample.sequence_lengths = sequence_lengths + sequence_lengths = [ + chosen_loss_masking_span_end, + len(sample.token_ids) - chosen_loss_masking_span_end, + ] + + # compute padding size + padding = np.full((self._sequence_length,), 0) + padding[: len(sample.token_ids)] = sample.token_ids + sequence_lengths.append(self._sequence_length - len(sample.token_ids)) + sample.token_ids = padding + + sample.sequence_lengths = np.array(sequence_lengths) return sample From 366a20b55652e93e0a888890bf4e5bf5733e0fb3 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Sun, 13 Apr 2025 21:08:37 +0000 Subject: [PATCH 14/47] dpo config changes --- fast_llm/functional/config.py | 5 +++-- fast_llm/layers/language_model/config.py | 8 ++++---- fast_llm/layers/language_model/head.py | 16 ++++++++-------- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index b45bc117e..df1f5bf09 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -92,6 +92,7 @@ class CrossEntropyImpl(str, enum.Enum): fused = "fused" triton = "triton" -class LossFunctionType(str, enum.Enum): + +class LossFunction(str, enum.Enum): cross_entropy = "cross_entropy" - dpo = "dpo" \ No newline at end of file + dpo = "dpo" diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 23ce17619..68f9d8696 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -4,7 +4,7 @@ from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames -from fast_llm.functional.config import CrossEntropyImpl, LossFunctionType +from fast_llm.functional.config import CrossEntropyImpl, LossFunction from fast_llm.layers.transformer.config import TransformerArchitectureConfig, TransformerConfig from fast_llm.utils import Assert @@ -142,12 +142,12 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): desc="Min value for clamping initialized weights of the vocabulary embedding and output (logits).", hint=FieldHint.feature, ) - loss_function_type: LossFunctionType = Field( - default=LossFunctionType.cross_entropy, + loss_function: LossFunction = Field( + default=LossFunction.cross_entropy, desc="Type of loss function to use", hint=FieldHint.feature, ) - beta: float | None = Field( + dpo_beta: float | None = Field( default=1.0, desc="Beta value for DPO loss.", hint=FieldHint.feature, diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 8a49be0a9..ffc803ab2 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -10,7 +10,7 @@ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward -from fast_llm.functional.config import CrossEntropyImpl, LossFunctionType, TritonConfig +from fast_llm.functional.config import CrossEntropyImpl, LossFunction, TritonConfig from fast_llm.functional.cross_entropy import cross_entropy_forward_backward from fast_llm.functional.dpo import compute_simplified_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward @@ -79,8 +79,8 @@ def __init__( self._init_output_weights(hidden_dim, config) - self._loss_function_type = config.loss_function_type - if self._loss_function_type == LossFunctionType.cross_entropy: + self._loss_function = config.loss_function + if self._loss_function == LossFunction.cross_entropy: self._cross_entropy_impl = config.cross_entropy_impl if self._cross_entropy_impl == CrossEntropyImpl.auto: if self._parallel_embeddings: @@ -89,10 +89,10 @@ def __init__( self._cross_entropy_impl = CrossEntropyImpl.triton else: self._cross_entropy_impl = CrossEntropyImpl.fused - elif self._loss_function_type == LossFunctionType.dpo: - self.dpo_beta = config.beta + elif self._loss_function == LossFunction.dpo: + self.dpo_beta = config.dpo_beta else: - raise NotImplementedError(f"Loss function type {self._loss_function_type} not supported.") + raise NotImplementedError(f"Loss function type {self._loss_function} not supported.") self._forward = wrap_forward_backward(self._forward_backward, grad_is_context) @@ -294,7 +294,7 @@ def _logits_cross_entropy_forward_backward( if labels is None: return logits * self._logits_scale_factor, None - if self._loss_function_type == LossFunctionType.cross_entropy: + if self._loss_function == LossFunction.cross_entropy: loss, grad = cross_entropy_forward_backward( logits.flatten(0, -2), labels, @@ -303,7 +303,7 @@ def _logits_cross_entropy_forward_backward( implementation=self._cross_entropy_impl, logits_scale_factor=self._logits_scale_factor, ) - elif self._loss_function_type == LossFunctionType.dpo: + elif self._loss_function == LossFunction.dpo: loss, grad = compute_simplified_dpo_loss( logits.flatten(0, -2), labels, From dca842efa1a5969e663ed52bab037209f9faf788 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Sun, 13 Apr 2025 21:08:56 +0000 Subject: [PATCH 15/47] memmap version fixes --- fast_llm/data/dataset/gpt/memmap.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 36701a8e5..3b2405711 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -40,10 +40,9 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER, msg=f"File: {stream.name}") self._version = struct.unpack("= 2: self._has_spans = struct.unpack("= 3: self._has_preference_spans = struct.unpack("= 2: self._spans = [] self._num_spans = np.frombuffer( self._index_bin_buffer, @@ -95,7 +94,7 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None # read preference spans self._chosen_spans = None self._rejected_spans = None - if self._has_preference_spans: + if self._has_preference_spans and self._version >= 3: self._chosen_spans = [] self._rejected_spans = [] chosen_span_offset = offset + self._document_sizes.nbytes + self._pointers.nbytes From ca86694db57805b4b31f3f3483524e6c199c5ee0 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Mon, 14 Apr 2025 23:27:23 +0000 Subject: [PATCH 16/47] removing dpo flags and new sampling class --- fast_llm/data/data/gpt/data.py | 6 +- fast_llm/data/dataset/gpt/config.py | 10 - fast_llm/data/dataset/gpt/indexed.py | 17 +- fast_llm/data/dataset/gpt/memmap.py | 3 +- fast_llm/data/dataset/gpt/sampled.py | 619 ++++++++++++++++++++++++++- 5 files changed, 632 insertions(+), 23 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index cab925ea2..76d90a7e7 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -168,9 +168,11 @@ def get_iterator( Assert.incl(dataset_name, self._datasets) Assert.in_range_incl(batch_config.sequence_length, 1, self._max_sequence_length) log_main_rank(f"Initializing {dataset_name} dataset iterator from sample {consumed_samples}...") + + dataset = self._datasets[dataset_name] # noqa return iter( torch.utils.data.DataLoader( - self._datasets[dataset_name], # noqa + dataset, batch_sampler=SampledDatasetIterator( total_samples=len(self._datasets[dataset_name]), begin_index=consumed_samples, @@ -185,7 +187,7 @@ def get_iterator( gpt_data_collate_fn, use_loss_masking_spans=self._config.sampling.use_loss_masking_spans, cross_document_attention=self._cross_document_attention, - use_preference_loss_masking_spans=self._config.sampling.use_preference_loss_masking_spans, + use_preference_loss_masking_spans=dataset._dataset._indexed_dataset._has_preference_spans, ), multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, ) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index fbe3b7d35..0f04884b6 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -57,16 +57,6 @@ class GPTSamplingConfig(SamplingConfig): desc="Read loss masking spans from the dataset.", hint=FieldHint.feature, ) - use_preference_loss_masking_spans: bool | None = Field( - default=None, - desc="Read preference loss masking spans from the dataset.", - hint=FieldHint.feature, - ) - enable_packing: bool | None = Field( - default=True, - desc="Whether to enable packing or not.", - hint=FieldHint.feature, - ) shuffle: ShufflingType | None = Field( default=None, desc="Shuffling strategy.", diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index 688ea6a70..2596a5af6 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -26,14 +26,19 @@ def get_document_size(self, index: int) -> int: """ def sample(self, sampling: GPTSamplingData) -> "GPTSampledIndexedDataset": - from fast_llm.data.dataset.gpt.sampled import GPTSampledIndexedDataset, LegacyGPTSampledIndexedDataset - - return ( - LegacyGPTSampledIndexedDataset(self, sampling) - if sampling.config.shuffle == ShufflingType.legacy - else GPTSampledIndexedDataset(self, sampling) + from fast_llm.data.dataset.gpt.sampled import ( + GPTSampledIndexedDataset, + GPTSampledPreferenceIndexedDataset, + LegacyGPTSampledIndexedDataset, ) + if sampling.config.shuffle == ShufflingType.legacy: + return LegacyGPTSampledIndexedDataset(self, sampling) + elif self._has_preference_spans: + return GPTSampledPreferenceIndexedDataset(self, sampling) + else: + return GPTSampledIndexedDataset(self, sampling) + class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[IndexedDatasetType], GPTIndexedDataset): """ diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 3b2405711..1725d6fac 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -148,7 +148,6 @@ def get( offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False, - use_preference_loss_masking_spans: bool = False, ) -> GPTSample: token_ids = np.frombuffer( self._bin_buffer, @@ -171,7 +170,7 @@ def get( chosen_span = None rejected_span = None - if use_preference_loss_masking_spans and self._chosen_spans is not None and self._rejected_spans is not None: + if self._has_preference_spans and self._chosen_spans is not None and self._rejected_spans is not None: chosen_span = self._chosen_spans[idx] # filter spans that are outside the range of the selected tokens in the document diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 328e2f66c..13888fb4e 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -75,7 +75,251 @@ def _lazy_load(self): TOKEN_CUMSUM_RATE = 10 -class GPTSampledIndexedDataset(SampledDataset): +class GPTSampledPreferenceIndexedDataset(SampledDataset): + """ + A sampled GPT dataset for preference data. + """ + + def __init__( + self, + indexed_dataset: GPTIndexedDataset, + sampling: GPTSamplingData, + ): + assert isinstance(sampling, GPTSamplingData) + self._indexed_dataset = indexed_dataset + self._num_samples = sampling.num_samples + self._sequence_length = sampling.sequence_length + self._cross_document_attention = sampling.cross_document_attention + self._config = sampling.config + self._truncate_documents = sampling.truncate_documents + self._device = torch.device("cuda" if self._config.gpu else "cpu") + + if self._truncate_documents: + raise NotImplementedError("Sampled preference indexed dataset does not support document trunctation.") + + if sampling.cache_directory is None: + self._document_shuffling = MemmapArray() + self._token_cumsum_shuffled = MemmapArray() + self._token_cumsum_unshuffled = MemmapArray() + self._yaml_path = None + log_main_rank( + " > No dataset cache directory provided, building the index map on all ranks." + "This may be very inefficient...", + log_fn=logger.warning, + ) + self._sample() + else: + base_path = ( + sampling.cache_directory / f"{self.name}_ns_{self._num_samples}_sl_{self._sequence_length}" + f"_s_{self._config.seed}" + ) + # TODO: Names are confusing + + # contains shuffled document indicies + self._document_shuffling = MemmapArray(base_path.with_name(base_path.name + "_shuffling.npy")) + + self._document_sizes = MemmapArray(base_path.with_name(base_path.name + "_doc_sizes.npy")) + self._doc_length_filtered_indicies = MemmapArray( + base_path.with_name(base_path.name + "_doc_length_filtered_indices.npy") + ) + + self._yaml_path = base_path.with_suffix(".yaml") + # Sample or validate the dataset of a given rank. + if sampling.distributed.config.rank == sampling.get_next_rank(): + self._sample() + # No barrier yet to allow running in parallel. + # There needs to be one before calling `__getitem__`, normally handled through `GPTData`. + + def _sample(self) -> None: + """ + Create a `GPTSampledDataset` with the requested parameters. + """ + # compute document sizes, documents per epoch, and tok per epoch + document_sizes = torch.from_numpy(self._indexed_dataset.get_document_sizes()).to(self._device) + documents_per_epoch = document_sizes.numel() + tokens_per_epoch = document_sizes.sum().item() + + # filter documents past the sequence length + long_docs_filter = document_sizes > self._sequence_length + 1 # TODO: do we need the +1 here if no truncation? + ignored_documents = sum(long_docs_filter) + if ignored_documents: + log_main_rank( + f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._sequence_length+1} tokens and will be ignored.", + log_fn=logger.warning, + ) + tokens_per_epoch = document_sizes[~long_docs_filter].sum().item() + if tokens_per_epoch == 0: + raise RuntimeError( + f" > No documents shorter than {self._sequence_length+1} tokens found in dataset {self._indexed_dataset.name}." + ) + + # update documents per epochs/num epochs based on filter + documents_per_epoch = (~long_docs_filter).sum().item() + num_epochs = math.ceil(self._num_samples / documents_per_epoch) + + # compute shuffled/unshuffled epochs + generator = torch.Generator(device=self._device) + if self._config.shuffle == ShufflingType.skip_first_epoch: + shuffled_epochs = num_epochs - 1 + elif self._config.shuffle == ShufflingType.disabled: + shuffled_epochs = 0 + else: + shuffled_epochs = num_epochs + shuffled_documents = documents_per_epoch * shuffled_epochs + unshuffled_epochs = num_epochs - shuffled_epochs + + yaml_data = { + "dataset": { + "name": self._indexed_dataset.name, + "documents_per_epoch": documents_per_epoch, + "tokens_per_epoch": tokens_per_epoch, + }, + "num_samples": self._num_samples, + "unshuffled_epochs": unshuffled_epochs, + "sequence_length": self._sequence_length, + "truncate_documents": self._truncate_documents, + "config": self._config.to_serialized(), + } + + if self._yaml_path is not None and self._yaml_path.is_file(): + loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r")) + self._load_yaml_data(yaml_data) + del loaded_yaml_data["unshuffled_tokens"] + + if loaded_yaml_data != yaml_data: + raise RuntimeError( + f"Invalid dataset cache for dataset {self.name}." + " If this is due to an intended configuration change," + " please delete the cache before continuing." + f"\nCurrent config:\n{yaml.safe_dump(yaml_data)}" + f"\nCached config:\n{yaml.safe_dump(loaded_yaml_data)}" + ) + # Dataset is already sampled, skip. + logger.info(f"Using existing sampling for dataset {self.name}") + return + + if shuffled_documents > 1e8: + warnings.warn( + f"Shuffling {shuffled_documents:.2e} documents for dataset {self._indexed_dataset.name}." + f" This may take a while and/or use an excessive amount of memory." + ) + elif documents_per_epoch > 1e8: + # TODO: Most of the damage is already done in `get_document_sizes`. Find a way to warn earlier? + warnings.warn( + f"The dataset {self._indexed_dataset.name} contains {documents_per_epoch:.2e} documents." + f" Sampling may take a while and/or use an excessive amount of memory." + ) + + # Use the smallest possible data type to save memory and disk usage. + document_shuffling_dtype = get_unsigned_integer_type(documents_per_epoch).torch + # Shuffle the dataset (documents) + # This generates a document shuffling index `all_document_index`, the unshuffled part is trivial + # so we only evaluate and store the shuffled part `document_shuffling`. + if self._config.shuffle == ShufflingType.full: + generator.manual_seed(self._config.seed) + # Equivalent to `shuffle(range(documents_per_epoch * num_epochs)) % documents_per_epoch` + document_shuffling = ( + torch.randperm( + shuffled_documents, + generator=generator, + dtype=get_unsigned_integer_type(shuffled_documents).torch, + device=self._device, + ) + .remainder_(documents_per_epoch) + .to(dtype=document_shuffling_dtype) + ) + elif self._config.shuffle in (ShufflingType.skip_first_epoch, ShufflingType.epoch): + document_shuffling = torch.empty( + shuffled_documents, + dtype=document_shuffling_dtype, + device=self._device, + ) + for i in range(shuffled_epochs): + generator.manual_seed(self._config.seed + i * 571) + torch.randperm( + documents_per_epoch, + generator=generator, + out=document_shuffling[i * documents_per_epoch : (i + 1) * documents_per_epoch], + ) + elif self._config.shuffle == ShufflingType.disabled: + document_shuffling = None + else: + raise NotImplementedError(f"Unknown shuffling type: {self._config.shuffle}") + + yaml_data["unshuffled_tokens"] = None # not used with packing disabled + + # index of all documents less than seq length long + doc_length_filtered_indicies = torch.nonzero(~long_docs_filter, as_tuple=True)[0] + self._doc_length_filtered_indicies.save(doc_length_filtered_indicies.numpy(force=self._config.gpu)) + + # save document shuffling and document sizes + if shuffled_epochs > 0: + self._document_shuffling.save(document_shuffling[: self._num_samples].numpy(force=self._config.gpu)) + self._document_sizes.save(document_sizes.numpy(force=self._config.gpu)) + + # save yaml + if self._yaml_path is not None: + # yaml_data["unshuffled_tokens"] = num_tokens_unshuffled + self._yaml_path.parent.mkdir(parents=True, exist_ok=True) + yaml.safe_dump(yaml_data, self._yaml_path.open("w")) + + def __len__(self) -> int: + return self._num_samples + + def __getitem__(self, index: int) -> typing.Any: + """ + Get the sample, (fixed-length sequence of tokens holding one or more complete or partial documents) + with the requested sampling index. + The returned sample is ready to be concatenated, then fed to a `GPTModel` (see `GPTModel.preprocess`). + """ + self._lazy_load() + + # get the document index to read + if index < self._unshuffled_documents: + document_index = self._doc_length_filtered_indicies[index % self._documents_per_epoch] + else: + document_index = self._doc_length_filtered_indicies[ + self._document_shuffling[index - self._unshuffled_documents].item() + ] + + # read document from memmap + sample = self._indexed_dataset.get( + document_index, + offset=0, + length=self._document_sizes[document_index], + use_loss_masking_spans=self._config.use_loss_masking_spans, + ) + + # compute sequence lengths + chosen_loss_masking_span_end = sample.chosen_loss_masking_span[1] + 1 + sequence_lengths = [ + chosen_loss_masking_span_end, + len(sample.token_ids) - chosen_loss_masking_span_end, + ] + + # compute padding size + padding = np.full((self._sequence_length,), 0) + padding[: len(sample.token_ids)] = sample.token_ids + sequence_lengths.append(self._sequence_length - len(sample.token_ids)) + sample.token_ids = padding + + if not self._cross_document_attention: # only add sequence lengths if no cross doc attention + sample.sequence_lengths = np.array(sequence_lengths) + + return sample + + @property + def name(self) -> str: + return self._indexed_dataset.name + + def _lazy_load(self): + if not hasattr(self, "_documents_per_epoch"): + self._load_yaml_data(yaml.safe_load(self._yaml_path.open("r"))) + + def _load_yaml_data(self, data: dict[str, typing.Any]) -> None: + self._documents_per_epoch = data["dataset"]["documents_per_epoch"] + self._unshuffled_documents = data["unshuffled_epochs"] * self._documents_per_epoch + """ A sampled GPT dataset. """ @@ -152,7 +396,7 @@ def _sample(self) -> None: "The C++ extension for dataset sampling is missing." " Please make sure Fast-LLM is installed correctly." ) - long_docs_filter = document_sizes > self._sequence_length + 1 + long_docs_filter = document_sizes > self._sequence_length ignored_documents = sum(long_docs_filter) if ignored_documents: log_main_rank( @@ -484,7 +728,6 @@ def __getitem__(self, index: int) -> typing.Any: offset=0, length=self._document_sizes[document_index], use_loss_masking_spans=self._config.use_loss_masking_spans, - use_preference_loss_masking_spans=self._config.use_preference_loss_masking_spans, ) chosen_loss_masking_span_end = sample.chosen_loss_masking_span[1] + 1 @@ -511,6 +754,376 @@ def _lazy_load(self): if not hasattr(self, "_documents_per_epoch"): self._load_yaml_data(yaml.safe_load(self._yaml_path.open("r"))) + def _load_yaml_data(self, data: dict[str, typing.Any]) -> None: + self._documents_per_epoch = data["dataset"]["documents_per_epoch"] + + # if "unshuffled_tokens" not in data: + # # Backward compatibility + # # TODO v0.x: Remove + # assert self._truncate_documents + # data["unshuffled_tokens"] = data["tokens_per_epoch"] * data["unshuffled_epochs"] + + # self._unshuffled_tokens = data["unshuffled_tokens"] + self._unshuffled_documents = data["unshuffled_epochs"] * self._documents_per_epoch + + +class GPTSampledIndexedDataset(SampledDataset): + """ + A sampled GPT dataset. + """ + + def __init__( + self, + indexed_dataset: GPTIndexedDataset, + sampling: GPTSamplingData, + ): + assert isinstance(sampling, GPTSamplingData) + self._indexed_dataset = indexed_dataset + self._num_samples = sampling.num_samples + self._sequence_length = sampling.sequence_length + self._cross_document_attention = sampling.cross_document_attention + self._config = sampling.config + self._truncate_documents = sampling.truncate_documents + self._device = torch.device("cuda" if self._config.gpu else "cpu") + + if sampling.cache_directory is None: + self._document_shuffling = MemmapArray() + self._token_cumsum_shuffled = MemmapArray() + self._token_cumsum_unshuffled = MemmapArray() + self._yaml_path = None + log_main_rank( + " > No dataset cache directory provided, building the index map on all ranks." + "This may be very inefficient...", + log_fn=logger.warning, + ) + self._sample() + else: + base_path = ( + sampling.cache_directory / f"{self.name}_ns_{self._num_samples}_sl_{self._sequence_length}" + f"_s_{self._config.seed}" + ) + # TODO: Names are confusing + self._document_shuffling = MemmapArray(base_path.with_name(base_path.name + "_shuffling.npy")) + self._token_cumsum_shuffled = MemmapArray(base_path.with_name(base_path.name + "_shuffled_cumsum.npy")) + self._token_cumsum_unshuffled = MemmapArray(base_path.with_name(base_path.name + "_unshuffled_cumsum.npy")) + self._yaml_path = base_path.with_suffix(".yaml") + # Sample or validate the dataset of a given rank. + if sampling.distributed.config.rank == sampling.get_next_rank(): + self._sample() + # No barrier yet to allow running in parallel. + # There needs to be one before calling `__getitem__`, normally handled through `GPTData`. + + def _sample(self) -> None: + """ + Create a `GPTSampledDataset` with the requested parameters. + """ + # Get the document sizes, the main information needed for sampling. + document_sizes = torch.from_numpy(self._indexed_dataset.get_document_sizes()).to(self._device) + documents_per_epoch = document_sizes.numel() + tokens_per_epoch = document_sizes.sum().item() + + # Calculate basic stats. + if not self._truncate_documents: + assert _extension_available, ( + "The C++ extension for dataset sampling is missing." + " Please make sure Fast-LLM is installed correctly." + ) + long_docs_filter = document_sizes > self._sequence_length + 1 + ignored_documents = sum(long_docs_filter) + if ignored_documents: + log_main_rank( + f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._sequence_length+1} tokens and will be ignored.", + log_fn=logger.warning, + ) + tokens_per_epoch = document_sizes[~long_docs_filter].sum().item() + if tokens_per_epoch == 0: + raise RuntimeError( + f" > No documents shorter than {self._sequence_length+1} tokens found in dataset {self._indexed_dataset.name}." + ) + # TODO MTP: Produce more labels to provide labels for the multi-token prediction heads? + # We produce sequences of length `self._sequence_length + 1` so the last token has a label, + # but in case of truncations we also include that last label in the following sample, + # so we need `sequence_length * num_samples + 1` tokens in total. + num_epochs = math.ceil( + ((self._sequence_length + 1 - self._truncate_documents) * self._num_samples + 1 * self._truncate_documents) + / tokens_per_epoch + ) + + # Prepare for shuffling. + generator = torch.Generator(device=self._device) + if self._config.shuffle == ShufflingType.skip_first_epoch: + shuffled_epochs = num_epochs - 1 + elif self._config.shuffle == ShufflingType.disabled: + shuffled_epochs = 0 + else: + shuffled_epochs = num_epochs + shuffled_documents = documents_per_epoch * shuffled_epochs + unshuffled_epochs = num_epochs - shuffled_epochs + + yaml_data = { + "dataset": { + "name": self._indexed_dataset.name, + "documents_per_epoch": documents_per_epoch, + "tokens_per_epoch": tokens_per_epoch, + }, + "num_samples": self._num_samples, + "unshuffled_epochs": unshuffled_epochs, + "sequence_length": self._sequence_length, + "truncate_documents": self._truncate_documents, + "config": self._config.to_serialized(), + } + if self._truncate_documents: + yaml_data["unshuffled_tokens"] = tokens_per_epoch * unshuffled_epochs + + if self._yaml_path is not None and self._yaml_path.is_file(): + loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r")) + self._load_yaml_data(yaml_data) + if not self._truncate_documents: + del loaded_yaml_data["unshuffled_tokens"] + + if loaded_yaml_data != yaml_data: + raise RuntimeError( + f"Invalid dataset cache for dataset {self.name}." + " If this is due to an intended configuration change," + " please delete the cache before continuing." + f"\nCurrent config:\n{yaml.safe_dump(yaml_data)}" + f"\nCached config:\n{yaml.safe_dump(loaded_yaml_data)}" + ) + # Dataset is already sampled, skip. + logger.info(f"Using existing sampling for dataset {self.name}") + return + + if shuffled_documents > 1e8: + warnings.warn( + f"Shuffling {shuffled_documents:.2e} documents for dataset {self._indexed_dataset.name}." + f" This may take a while and/or use an excessive amount of memory." + ) + elif documents_per_epoch > 1e8: + # TODO: Most of the damage is already done in `get_document_sizes`. Find a way to warn earlier? + warnings.warn( + f"The dataset {self._indexed_dataset.name} contains {documents_per_epoch:.2e} documents." + f" Sampling may take a while and/or use an excessive amount of memory." + ) + + # Use the smallest possible data type to save memory and disk usage. + document_shuffling_dtype = get_unsigned_integer_type(documents_per_epoch).torch + # Shuffle the dataset (documents) + # This generates a document shuffling index `all_document_index`, the unshuffled part is trivial + # so we only evaluate and store the shuffled part `document_shuffling`. + if self._config.shuffle == ShufflingType.full: + generator.manual_seed(self._config.seed) + # Equivalent to `shuffle(range(documents_per_epoch * num_epochs)) % documents_per_epoch` + document_shuffling = ( + torch.randperm( + shuffled_documents, + generator=generator, + dtype=get_unsigned_integer_type(shuffled_documents).torch, + device=self._device, + ) + .remainder_(documents_per_epoch) + .to(dtype=document_shuffling_dtype) + ) + elif self._config.shuffle in (ShufflingType.skip_first_epoch, ShufflingType.epoch): + document_shuffling = torch.empty( + shuffled_documents, + dtype=document_shuffling_dtype, + device=self._device, + ) + for i in range(shuffled_epochs): + generator.manual_seed(self._config.seed + i * 571) + torch.randperm( + documents_per_epoch, + generator=generator, + out=document_shuffling[i * documents_per_epoch : (i + 1) * documents_per_epoch], + ) + elif self._config.shuffle == ShufflingType.disabled: + document_shuffling = None + else: + raise NotImplementedError(f"Unknown shuffling type: {self._config.shuffle}") + + # To get a sample on the fly we need to know where it begins, + # and this is a non-trivial information because the documents have variable length. + # The starting point `(document[idx], token[idx])` corresponds to the `(idx * sequence_length)` th token, i.e. + # `document_sizes[all_document_index][:document[idx]].sum() + token[idx] == idx * sequence_length`. + # This can be computed quickly provided we know a (partial) sum close to `(idx * sequence_length)`. + # So it is enough to pre-compute the (zero-padded) token cumsum at regular intervals `TOKEN_CUMSUM_RATE`. + # Using `TOKEN_CUMSUM_RATE > 1` reduces pre-computation overhead at the cost of runtime computation. + # Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::TOKEN_CUMSUM_RATE]))` + if unshuffled_epochs > 0: + token_cumsum_unshuffled, unshuffled_tokens = self._get_token_cumsum( + document_sizes, + offset=0, + # TODO: Allowing for max 100% extra tokens for padding, is that enough? + dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs), + ) + self._token_cumsum_unshuffled.save(token_cumsum_unshuffled) + else: + unshuffled_tokens = 0 + + if not self._truncate_documents: + yaml_data["unshuffled_tokens"] = unshuffled_tokens + self._load_yaml_data(yaml_data) + if self._yaml_path is not None: + self._yaml_path.parent.mkdir(parents=True, exist_ok=True) + yaml.safe_dump(yaml_data, self._yaml_path.open("w")) + + if shuffled_epochs > 0: + token_cumsum_shuffled, _ = self._get_token_cumsum( + document_sizes[ + # Torch indexing only works with int32 or int64 + document_shuffling.to( + dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32 + ) + ], + offset=self._unshuffled_tokens, + # TODO: Allowing for max 100% extra tokens for padding, is that enough? + dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs), + ) + self._token_cumsum_shuffled.save(token_cumsum_shuffled) + self._document_shuffling.save( + document_shuffling[: (token_cumsum_shuffled.size + 1) * TOKEN_CUMSUM_RATE].numpy( + force=self._config.gpu + ) + ) + # Free memory + del document_shuffling + + def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: DataType) -> tuple[np.ndarray, int | None]: + if self._truncate_documents: + # Create the output tensor. + out = sizes.new_empty(sizes.numel() // TOKEN_CUMSUM_RATE + 1, dtype=dtype.torch) + # Get partial sums for regular intervals, excluding the last incomplete interval. + torch.sum( + sizes[: sizes.numel() - sizes.numel() % TOKEN_CUMSUM_RATE].view(-1, TOKEN_CUMSUM_RATE), + dim=1, + out=out[1:], + ) + # Pad with the begin offset + out[0] = offset + # Calculate the cumsum. + out.cumsum_(0) + # Crop unnecessary entries. + out = out[ + : torch.clamp_min_( + torch.searchsorted(out, self._num_samples * self._sequence_length, side="right"), + 0, + ) + ] + return out.numpy(force=self._config.gpu), None + else: + # TODO: dynamically handle int64 or int32 in CPP + out = build_padded_token_cumsum( + sizes.cpu().numpy(), (self._sequence_length + 1), TOKEN_CUMSUM_RATE, offset + ) + num_tokens = out[-1] + out = out[:-1][ + : np.clip(np.searchsorted(out, self._num_samples * (self._sequence_length + 1), side="right"), 0, None) + ] + return out, num_tokens + + def __len__(self) -> int: + return self._num_samples + + def __getitem__(self, index: int) -> typing.Any: + """ + Get the sample, (fixed-length sequence of tokens holding one or more complete or partial documents) + with the requested sampling index. + The returned sample is ready to be concatenated, then fed to a `GPTModel` (see `GPTModel.preprocess`). + """ + self._lazy_load() + # tokens at the boundary are included in only one sample when we pack without truncations + # in case of packing with truncations, the last token from the previous sample is also the first token of the next sample + token_start = index * (self._sequence_length + 1 - self._truncate_documents) + token_end = token_start + self._sequence_length + 1 + + if token_start < self._unshuffled_tokens: + token_start_array = self._token_cumsum_unshuffled.array + token_start_array_document_offset = 0 + else: + token_start_array = self._token_cumsum_shuffled.array + token_start_array_document_offset = self._unshuffled_documents + + # Find the rightmost location `token_start_cumsum_index` in `token_cumsum` with `token_cumsum[token_start_cumsum_index] <= token_start` + token_start_cumsum_index = np.searchsorted(token_start_array, token_start, side="right").item() - 1 + + document_sampling_index = token_start_cumsum_index * TOKEN_CUMSUM_RATE + token_start_array_document_offset + + token_count = token_start_array[token_start_cumsum_index] + + token_ids = [] + loss_masking_spans = [] + while token_count < token_end: + # Find the document index in the dataset. + if document_sampling_index < self._unshuffled_documents: + document_index = document_sampling_index % self._documents_per_epoch + else: + document_index = self._document_shuffling[document_sampling_index - self._unshuffled_documents].item() + + document_size = self._indexed_dataset.get_document_size(document_index) + + if not self._truncate_documents: + if document_size > self._sequence_length + 1: + # Document too long, ignore + document_sampling_index += 1 + continue + tokens_in_sample = token_count % (self._sequence_length + 1) + if document_size + tokens_in_sample > self._sequence_length + 1: + # Document belongs to the next sample, need to account for padding. + padding_size = self._sequence_length + 1 - tokens_in_sample + if token_count > token_start: + # Add padding tokens to current sample + token_ids.append(np.full((padding_size,), -100, dtype=np.int64)) + Assert.eq(token_count + padding_size, token_end) + break + else: + # Move on to the next sample. + token_count += padding_size + + # Determine if the document belongs to the requested sample. + if token_count + document_size >= token_start: + # Determine which part of the document belong to the sample, and add it to the list. + token_start_index_in_document = max(token_start - token_count, 0) + token_end_index_in_document = min(token_end - token_count, document_size) + sample = self._indexed_dataset.get( + document_index, + offset=token_start_index_in_document, + length=token_end_index_in_document - token_start_index_in_document, + use_loss_masking_spans=self._config.use_loss_masking_spans, + ) + token_ids.append(sample.token_ids) + if self._config.use_loss_masking_spans: + for loss_masking_span in sample.loss_masking_spans: + span = np.clip(loss_masking_span + token_count - token_start, 0, self._sequence_length + 1) + if span[1] > span[0]: + loss_masking_spans.append(span) + + # Go to the next document. + document_sampling_index += 1 + token_count += document_size + + sequence_lengths = ( + np.array([ids.size - (idx == len(token_ids) - 1) for idx, ids in enumerate(token_ids)], dtype=np.int32) + if not self._cross_document_attention + else None + ) + token_ids = np.concatenate(token_ids, dtype=np.int64) + loss_masking_spans = ( + (np.stack(loss_masking_spans, dtype=np.int32) if loss_masking_spans else np.array([])) + if self._config.use_loss_masking_spans + else None + ) + Assert.eq(len(token_ids), self._sequence_length + 1) + + return GPTSample(token_ids=token_ids, loss_masking_spans=loss_masking_spans, sequence_lengths=sequence_lengths) + + @property + def name(self) -> str: + return self._indexed_dataset.name + + def _lazy_load(self): + if not hasattr(self, "_documents_per_epoch"): + self._load_yaml_data(yaml.safe_load(self._yaml_path.open("r"))) + def _load_yaml_data(self, data: dict[str, typing.Any]) -> None: self._documents_per_epoch = data["dataset"]["documents_per_epoch"] From aa94f9a83b4c5e25d56924d19e8f18fe4fef429c Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Mon, 14 Apr 2025 23:30:31 +0000 Subject: [PATCH 17/47] removing extra lines --- fast_llm/data/dataset/gpt/sampled.py | 446 --------------------------- 1 file changed, 446 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 13888fb4e..95b60e99f 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -320,452 +320,6 @@ def _load_yaml_data(self, data: dict[str, typing.Any]) -> None: self._documents_per_epoch = data["dataset"]["documents_per_epoch"] self._unshuffled_documents = data["unshuffled_epochs"] * self._documents_per_epoch - """ - A sampled GPT dataset. - """ - - def __init__( - self, - indexed_dataset: GPTIndexedDataset, - sampling: GPTSamplingData, - ): - assert isinstance(sampling, GPTSamplingData) - self._indexed_dataset = indexed_dataset - self._num_samples = sampling.num_samples - self._sequence_length = sampling.sequence_length - self._cross_document_attention = sampling.cross_document_attention - self._config = sampling.config - self._truncate_documents = sampling.truncate_documents - self._device = torch.device("cuda" if self._config.gpu else "cpu") - - if self._config.enable_packing and self._config.use_preference_loss_masking_spans: - raise NotImplementedError("Packing currently not implemented with preference loss masking.") - if not self._config.enable_packing and self._truncate_documents: - raise NotImplementedError("If packing is disabled, document truncation must also be disabled.") - - if sampling.cache_directory is None: - self._document_shuffling = MemmapArray() - self._token_cumsum_shuffled = MemmapArray() - self._token_cumsum_unshuffled = MemmapArray() - self._yaml_path = None - log_main_rank( - " > No dataset cache directory provided, building the index map on all ranks." - "This may be very inefficient...", - log_fn=logger.warning, - ) - self._sample() - else: - base_path = ( - sampling.cache_directory / f"{self.name}_ns_{self._num_samples}_sl_{self._sequence_length}" - f"_s_{self._config.seed}" - ) - # TODO: Names are confusing - - # contains shuffled document indicies - self._document_shuffling = MemmapArray(base_path.with_name(base_path.name + "_shuffling.npy")) - - # contains cumulative sum of document sizes grouped by TOKEN_CUMSUM_RATE in shuffled order - self._token_cumsum_shuffled = MemmapArray(base_path.with_name(base_path.name + "_shuffled_cumsum.npy")) - self._token_cumsum_unshuffled = MemmapArray(base_path.with_name(base_path.name + "_unshuffled_cumsum.npy")) - - if not self._config.enable_packing: - self._document_sizes = MemmapArray(base_path.with_name(base_path.name + "_doc_sizes.npy")) - self._doc_length_filtered_indicies = MemmapArray( - base_path.with_name(base_path.name + "_doc_length_filtered_indices.npy") - ) - - self._yaml_path = base_path.with_suffix(".yaml") - # Sample or validate the dataset of a given rank. - if sampling.distributed.config.rank == sampling.get_next_rank(): - self._sample() - # No barrier yet to allow running in parallel. - # There needs to be one before calling `__getitem__`, normally handled through `GPTData`. - - def _sample(self) -> None: - """ - Create a `GPTSampledDataset` with the requested parameters. - """ - # Get the document sizes, the main information needed for sampling. - document_sizes = torch.from_numpy(self._indexed_dataset.get_document_sizes()).to(self._device) - documents_per_epoch = document_sizes.numel() - tokens_per_epoch = document_sizes.sum().item() - - # Calculate basic stats. - if not self._truncate_documents: - assert _extension_available, ( - "The C++ extension for dataset sampling is missing." - " Please make sure Fast-LLM is installed correctly." - ) - long_docs_filter = document_sizes > self._sequence_length - ignored_documents = sum(long_docs_filter) - if ignored_documents: - log_main_rank( - f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._sequence_length+1} tokens and will be ignored.", - log_fn=logger.warning, - ) - tokens_per_epoch = document_sizes[~long_docs_filter].sum().item() - if tokens_per_epoch == 0: - raise RuntimeError( - f" > No documents shorter than {self._sequence_length+1} tokens found in dataset {self._indexed_dataset.name}." - ) - # TODO MTP: Produce more labels to provide labels for the multi-token prediction heads? - # We produce sequences of length `self._sequence_length + 1` so the last token has a label, - # but in case of truncations we also include that last label in the following sample, - # so we need `sequence_length * num_samples + 1` tokens in total. - if self._config.enable_packing: - num_epochs = math.ceil( - ( - (self._sequence_length + 1 - self._truncate_documents) * self._num_samples - + 1 * self._truncate_documents - ) - / tokens_per_epoch - ) - else: - documents_per_epoch = (~long_docs_filter).sum().item() - num_epochs = math.ceil(self._num_samples / documents_per_epoch) - - # Prepare for shuffling. - generator = torch.Generator(device=self._device) - if self._config.shuffle == ShufflingType.skip_first_epoch: - shuffled_epochs = num_epochs - 1 - elif self._config.shuffle == ShufflingType.disabled: - shuffled_epochs = 0 - else: - shuffled_epochs = num_epochs - shuffled_documents = documents_per_epoch * shuffled_epochs - unshuffled_epochs = num_epochs - shuffled_epochs - - yaml_data = { - "dataset": { - "name": self._indexed_dataset.name, - "documents_per_epoch": documents_per_epoch, - "tokens_per_epoch": tokens_per_epoch, - }, - "num_samples": self._num_samples, - "unshuffled_epochs": unshuffled_epochs, - "sequence_length": self._sequence_length, - "truncate_documents": self._truncate_documents, - "config": self._config.to_serialized(), - } - if self._truncate_documents: - yaml_data["unshuffled_tokens"] = tokens_per_epoch * unshuffled_epochs - - if self._yaml_path is not None and self._yaml_path.is_file(): - loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r")) - self._load_yaml_data(yaml_data) - if not self._truncate_documents: - del loaded_yaml_data["unshuffled_tokens"] - - if loaded_yaml_data != yaml_data: - raise RuntimeError( - f"Invalid dataset cache for dataset {self.name}." - " If this is due to an intended configuration change," - " please delete the cache before continuing." - f"\nCurrent config:\n{yaml.safe_dump(yaml_data)}" - f"\nCached config:\n{yaml.safe_dump(loaded_yaml_data)}" - ) - # Dataset is already sampled, skip. - logger.info(f"Using existing sampling for dataset {self.name}") - return - - if shuffled_documents > 1e8: - warnings.warn( - f"Shuffling {shuffled_documents:.2e} documents for dataset {self._indexed_dataset.name}." - f" This may take a while and/or use an excessive amount of memory." - ) - elif documents_per_epoch > 1e8: - # TODO: Most of the damage is already done in `get_document_sizes`. Find a way to warn earlier? - warnings.warn( - f"The dataset {self._indexed_dataset.name} contains {documents_per_epoch:.2e} documents." - f" Sampling may take a while and/or use an excessive amount of memory." - ) - - # Use the smallest possible data type to save memory and disk usage. - document_shuffling_dtype = get_unsigned_integer_type(documents_per_epoch).torch - # Shuffle the dataset (documents) - # This generates a document shuffling index `all_document_index`, the unshuffled part is trivial - # so we only evaluate and store the shuffled part `document_shuffling`. - if self._config.shuffle == ShufflingType.full: - generator.manual_seed(self._config.seed) - # Equivalent to `shuffle(range(documents_per_epoch * num_epochs)) % documents_per_epoch` - document_shuffling = ( - torch.randperm( - shuffled_documents, - generator=generator, - dtype=get_unsigned_integer_type(shuffled_documents).torch, - device=self._device, - ) - .remainder_(documents_per_epoch) - .to(dtype=document_shuffling_dtype) - ) - elif self._config.shuffle in (ShufflingType.skip_first_epoch, ShufflingType.epoch): - document_shuffling = torch.empty( - shuffled_documents, - dtype=document_shuffling_dtype, - device=self._device, - ) - for i in range(shuffled_epochs): - generator.manual_seed(self._config.seed + i * 571) - torch.randperm( - documents_per_epoch, - generator=generator, - out=document_shuffling[i * documents_per_epoch : (i + 1) * documents_per_epoch], - ) - elif self._config.shuffle == ShufflingType.disabled: - document_shuffling = None - else: - raise NotImplementedError(f"Unknown shuffling type: {self._config.shuffle}") - - # To get a sample on the fly we need to know where it begins, - # and this is a non-trivial information because the documents have variable length. - # The starting point `(document[idx], token[idx])` corresponds to the `(idx * sequence_length)` th token, i.e. - # `document_sizes[all_document_index][:document[idx]].sum() + token[idx] == idx * sequence_length`. - # This can be computed quickly provided we know a (partial) sum close to `(idx * sequence_length)`. - # So it is enough to pre-compute the (zero-padded) token cumsum at regular intervals `TOKEN_CUMSUM_RATE`. - # Using `TOKEN_CUMSUM_RATE > 1` reduces pre-computation overhead at the cost of runtime computation. - # Equivalent to `torch.hstack((0, document_sizes[all_document_index].cumsum()[::TOKEN_CUMSUM_RATE]))` - if self._config.enable_packing: - if unshuffled_epochs > 0: - token_cumsum_unshuffled, unshuffled_tokens = self._get_token_cumsum( - document_sizes, - offset=0, - # TODO: Allowing for max 100% extra tokens for padding, is that enough? - dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs), - ) - self._token_cumsum_unshuffled.save(token_cumsum_unshuffled) - else: - unshuffled_tokens = 0 - - if not self._truncate_documents: - yaml_data["unshuffled_tokens"] = unshuffled_tokens - self._load_yaml_data(yaml_data) - if self._yaml_path is not None: - self._yaml_path.parent.mkdir(parents=True, exist_ok=True) - yaml.safe_dump(yaml_data, self._yaml_path.open("w")) - - if shuffled_epochs > 0: - token_cumsum_shuffled, _ = self._get_token_cumsum( - document_sizes[ - # Torch indexing only works with int32 or int64 - document_shuffling.to( - dtype=torch.int64 if document_shuffling.dtype == torch.int64 else torch.int32 - ) - ], - offset=self._unshuffled_tokens, - # TODO: Allowing for max 100% extra tokens for padding, is that enough? - dtype=get_unsigned_integer_type((2 - self._truncate_documents) * tokens_per_epoch * num_epochs), - ) - self._token_cumsum_shuffled.save(token_cumsum_shuffled) - self._document_shuffling.save( - document_shuffling[: (token_cumsum_shuffled.size + 1) * TOKEN_CUMSUM_RATE].numpy( - force=self._config.gpu - ) - ) - # Free memory - del document_shuffling - else: - if not self._truncate_documents: - yaml_data["unshuffled_tokens"] = None # not used with packing disabled - - # index of all documents less than seq length long - doc_length_filtered_indicies = torch.nonzero(~long_docs_filter, as_tuple=True)[0] - self._doc_length_filtered_indicies.save(doc_length_filtered_indicies.numpy(force=self._config.gpu)) - - # # apply shuffling on doc_length_filtered_indicies - # document_shuffling_length_filtered_indices = torch.gather( - # doc_length_filtered_indicies, dim=0, index=document_shuffling.to(torch.int64) - # ) - if shuffled_epochs > 0: - self._document_shuffling.save(document_shuffling[: self._num_samples].numpy(force=self._config.gpu)) - self._document_sizes.save(document_sizes.numpy(force=self._config.gpu)) - if self._yaml_path is not None: - # yaml_data["unshuffled_tokens"] = num_tokens_unshuffled - self._yaml_path.parent.mkdir(parents=True, exist_ok=True) - yaml.safe_dump(yaml_data, self._yaml_path.open("w")) - - def _get_token_cumsum(self, sizes: torch.Tensor, offset: int, dtype: DataType) -> tuple[np.ndarray, int | None]: - if self._truncate_documents: - # Create the output tensor. - out = sizes.new_empty(sizes.numel() // TOKEN_CUMSUM_RATE + 1, dtype=dtype.torch) - # Get partial sums for regular intervals, excluding the last incomplete interval. - torch.sum( - sizes[: sizes.numel() - sizes.numel() % TOKEN_CUMSUM_RATE].view(-1, TOKEN_CUMSUM_RATE), - dim=1, - out=out[1:], - ) - # Pad with the begin offset - out[0] = offset - # Calculate the cumsum. - out.cumsum_(0) - # Crop unnecessary entries. - out = out[ - : torch.clamp_min_( - torch.searchsorted(out, self._num_samples * self._sequence_length, side="right"), - 0, - ) - ] - return out.numpy(force=self._config.gpu), None - else: - # TODO: dynamically handle int64 or int32 in CPP - out = build_padded_token_cumsum( - sizes.cpu().numpy(), (self._sequence_length + 1), TOKEN_CUMSUM_RATE, offset - ) - num_tokens = out[-1] - out = out[:-1][ - : np.clip(np.searchsorted(out, self._num_samples * (self._sequence_length + 1), side="right"), 0, None) - ] - return out, num_tokens - - def __len__(self) -> int: - return self._num_samples - - def __getitem__(self, index: int) -> typing.Any: - """ - Get the sample, (fixed-length sequence of tokens holding one or more complete or partial documents) - with the requested sampling index. - The returned sample is ready to be concatenated, then fed to a `GPTModel` (see `GPTModel.preprocess`). - """ - self._lazy_load() - if self._config.enable_packing: - # tokens at the boundary are included in only one sample when we pack without truncations - # in case of packing with truncations, the last token from the previous sample is also the first token of the next sample - token_start = index * (self._sequence_length + 1 - self._truncate_documents) - token_end = token_start + self._sequence_length + 1 - - if token_start < self._unshuffled_tokens: - token_start_array = self._token_cumsum_unshuffled.array - token_start_array_document_offset = 0 - else: - token_start_array = self._token_cumsum_shuffled.array - token_start_array_document_offset = self._unshuffled_documents - - # Find the rightmost location `token_start_cumsum_index` in `token_cumsum` with `token_cumsum[token_start_cumsum_index] <= token_start` - token_start_cumsum_index = np.searchsorted(token_start_array, token_start, side="right").item() - 1 - - document_sampling_index = token_start_cumsum_index * TOKEN_CUMSUM_RATE + token_start_array_document_offset - - token_count = token_start_array[token_start_cumsum_index] - - token_ids = [] - loss_masking_spans = [] - while token_count < token_end: - # Find the document index in the dataset. - if document_sampling_index < self._unshuffled_documents: - document_index = document_sampling_index % self._documents_per_epoch - else: - document_index = self._document_shuffling[ - document_sampling_index - self._unshuffled_documents - ].item() - - document_size = self._indexed_dataset.get_document_size(document_index) - - if not self._truncate_documents: - if document_size > self._sequence_length + 1: - # Document too long, ignore - document_sampling_index += 1 - continue - tokens_in_sample = token_count % (self._sequence_length + 1) - if document_size + tokens_in_sample > self._sequence_length + 1: - # Document belongs to the next sample, need to account for padding. - padding_size = self._sequence_length + 1 - tokens_in_sample - if token_count > token_start: - # Add padding tokens to current sample - token_ids.append(np.full((padding_size,), -100, dtype=np.int64)) - Assert.eq(token_count + padding_size, token_end) - break - else: - # Move on to the next sample. - token_count += padding_size - - # Determine if the document belongs to the requested sample. - if token_count + document_size >= token_start: - # Determine which part of the document belong to the sample, and add it to the list. - token_start_index_in_document = max(token_start - token_count, 0) - token_end_index_in_document = min(token_end - token_count, document_size) - sample = self._indexed_dataset.get( - document_index, - offset=token_start_index_in_document, - length=token_end_index_in_document - token_start_index_in_document, - use_loss_masking_spans=self._config.use_loss_masking_spans, - ) - token_ids.append(sample.token_ids) - if self._config.use_loss_masking_spans: - for loss_masking_span in sample.loss_masking_spans: - span = np.clip(loss_masking_span + token_count - token_start, 0, self._sequence_length + 1) - if span[1] > span[0]: - loss_masking_spans.append(span) - - # Go to the next document. - document_sampling_index += 1 - token_count += document_size - - sequence_lengths = ( - np.array([ids.size - (idx == len(token_ids) - 1) for idx, ids in enumerate(token_ids)], dtype=np.int32) - if not self._cross_document_attention - else None - ) - token_ids = np.concatenate(token_ids, dtype=np.int64) - loss_masking_spans = ( - (np.stack(loss_masking_spans, dtype=np.int32) if loss_masking_spans else np.array([])) - if self._config.use_loss_masking_spans - else None - ) - Assert.eq(len(token_ids), self._sequence_length + 1) - - return GPTSample( - token_ids=token_ids, loss_masking_spans=loss_masking_spans, sequence_lengths=sequence_lengths - ) - else: - if index < self._unshuffled_documents: - document_index = self._doc_length_filtered_indicies[index % self._documents_per_epoch] - else: - document_index = self._doc_length_filtered_indicies[ - self._document_shuffling[index - self._unshuffled_documents].item() - ] - - sample = self._indexed_dataset.get( - document_index, - offset=0, - length=self._document_sizes[document_index], - use_loss_masking_spans=self._config.use_loss_masking_spans, - ) - - chosen_loss_masking_span_end = sample.chosen_loss_masking_span[1] + 1 - sequence_lengths = [ - chosen_loss_masking_span_end, - len(sample.token_ids) - chosen_loss_masking_span_end, - ] - - # compute padding size - padding = np.full((self._sequence_length,), 0) - padding[: len(sample.token_ids)] = sample.token_ids - sequence_lengths.append(self._sequence_length - len(sample.token_ids)) - sample.token_ids = padding - - sample.sequence_lengths = np.array(sequence_lengths) - - return sample - - @property - def name(self) -> str: - return self._indexed_dataset.name - - def _lazy_load(self): - if not hasattr(self, "_documents_per_epoch"): - self._load_yaml_data(yaml.safe_load(self._yaml_path.open("r"))) - - def _load_yaml_data(self, data: dict[str, typing.Any]) -> None: - self._documents_per_epoch = data["dataset"]["documents_per_epoch"] - - # if "unshuffled_tokens" not in data: - # # Backward compatibility - # # TODO v0.x: Remove - # assert self._truncate_documents - # data["unshuffled_tokens"] = data["tokens_per_epoch"] * data["unshuffled_epochs"] - - # self._unshuffled_tokens = data["unshuffled_tokens"] - self._unshuffled_documents = data["unshuffled_epochs"] * self._documents_per_epoch - class GPTSampledIndexedDataset(SampledDataset): """ From 7f37038101bf4b2030c7850f206d9783738f8cb6 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Tue, 15 Apr 2025 00:21:23 +0000 Subject: [PATCH 18/47] small data configuration updates --- fast_llm/data/data/gpt/data.py | 8 +++++++- fast_llm/data/dataset/gpt/fim.py | 6 ++++-- fast_llm/data/dataset/gpt/sampled.py | 4 ---- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 76d90a7e7..0e39f8c6a 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -170,6 +170,12 @@ def get_iterator( log_main_rank(f"Initializing {dataset_name} dataset iterator from sample {consumed_samples}...") dataset = self._datasets[dataset_name] # noqa + + if hasattr(dataset._dataset, "_indexed_dataset"): + use_preference_loss_masking_spans = dataset._dataset._indexed_dataset._has_preference_spans + else: + use_preference_loss_masking_spans = False + return iter( torch.utils.data.DataLoader( dataset, @@ -187,7 +193,7 @@ def get_iterator( gpt_data_collate_fn, use_loss_masking_spans=self._config.sampling.use_loss_masking_spans, cross_document_attention=self._cross_document_attention, - use_preference_loss_masking_spans=dataset._dataset._indexed_dataset._has_preference_spans, + use_preference_loss_masking_spans=use_preference_loss_masking_spans, ), multiprocessing_context=self._config.multiprocessing_context.value if num_workers > 0 else None, ) diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index 192e31315..580421fa4 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -20,10 +20,12 @@ def __init__( ): if sampling.config.use_loss_masking_spans: raise NotImplementedError("FIM is currently not compatible with loss masking.") - if sampling.config.use_preference_loss_masking_spans: - raise NotImplementedError("FIM is currently not compatible with preference loss masking.") self._config = config self._dataset = dataset + + if hasattr(self._dataset, "_indexed_dataset") and dataset._dataset._indexed_dataset._has_preference_spans: + raise NotImplementedError("FIM is currently not compatible with preference loss masking.") + self._seed = sampling.config.seed self._tokenizer = sampling.tokenizer if self._tokenizer is None: diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 95b60e99f..9eed88b93 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -708,10 +708,6 @@ def __init__( self._indexed_dataset = indexed_dataset self._num_samples = sampling.num_samples self._sequence_length = sampling.sequence_length - if not sampling.config.enable_packing: - raise NotImplementedError( - "Legacy sampling only supports document packing. Please use the latest dataset format." - ) if not sampling.truncate_documents: raise NotImplementedError( "Legacy sampling only supports document truncation. Please use the latest dataset format." From 0d7ccbde47a9b2860265153444936188157aefbd Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Tue, 15 Apr 2025 00:24:50 +0000 Subject: [PATCH 19/47] update test case --- tests/data/test_prepare_gpt_memmap.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py index b1a488f4f..cdd9606f1 100644 --- a/tests/data/test_prepare_gpt_memmap.py +++ b/tests/data/test_prepare_gpt_memmap.py @@ -62,7 +62,7 @@ def generate_valid_span(max_seq_length): GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents) dataset = GPTMemmapDataset(name="foo", prefix=prefix) for i, document in enumerate(documents): - dataset_item = dataset.get(i, use_preference_loss_masking_spans=True) + dataset_item = dataset.get(i) assert np.array_equal( dataset_item.token_ids, document.token_ids, equal_nan=True ), f"Token ids mismatch for document {i}: {document} != {dataset.get(i)}." From 5fd1c865561eeac622daec537f602bdc354b891b Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Tue, 15 Apr 2025 00:38:40 +0000 Subject: [PATCH 20/47] logp span using index instead --- fast_llm/functional/dpo.py | 76 ++++++++++++++++++++++++++------------ 1 file changed, 53 insertions(+), 23 deletions(-) diff --git a/fast_llm/functional/dpo.py b/fast_llm/functional/dpo.py index b7ef8ccd3..7140e3706 100644 --- a/fast_llm/functional/dpo.py +++ b/fast_llm/functional/dpo.py @@ -1,47 +1,77 @@ import torch -import torch.nn.functional as F -from typing import Tuple -def compute_logps_for_spans( - logits: torch.Tensor, - targets: torch.Tensor, - chosen_span: torch.Tensor, - rejected_span: torch.Tensor - ): +def compute_logprobs_for_spans( + logits: torch.Tensor, targets: torch.Tensor, chosen_span: torch.Tensor, rejected_span: torch.Tensor +): + assert torch.all(targets < logits.size(-1)), "Target out of vocab range" + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) - + # gather log probabilities corresponding to the target tokens selected_log_probs = log_probs[:-1].gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1) - + # apply chosen mask - chosen_mask = torch.zeros_like(selected_log_probs, dtype=torch.bool) - chosen_mask[chosen_span[:, 0]: chosen_span[:, 1] + 1] = 1 - chosen_logp = (selected_log_probs * chosen_mask).sum() + # chosen_mask = torch.zeros_like(selected_log_probs, dtype=torch.bool) + # chosen_mask[chosen_span[0][0].item(): chosen_span[0][1].item() + 1] = 1 + # chosen_logp = (selected_log_probs * chosen_mask).sum() + chosen_logp = selected_log_probs[chosen_span[0][0].item() : chosen_span[0][1].item() + 1].sum() # apply rejected mask - rejected_mask = torch.zeros_like(selected_log_probs, dtype=torch.bool) - rejected_mask[rejected_span[:, 0]: rejected_span[:, 1] + 1] = 1 - rejected_logp = (selected_log_probs * rejected_mask).sum() + # rejected_mask = torch.zeros_like(selected_log_probs, dtype=torch.bool) + # rejected_mask[rejected_span[0][0].item(): rejected_span[0][1].item() + 1] = 1 + # rejected_logp = (selected_log_probs * rejected_mask).sum() + rejected_logp = selected_log_probs[rejected_span[0][0].item() : rejected_span[0][1].item() + 1].sum() return chosen_logp, rejected_logp + +# def compute_simpo_loss( +# logits: torch.Tensor, +# targets: torch.Tensor, +# chosen_span: torch.Tensor, +# rejected_span: torch.Tensor, +# beta: float, +# grad_output: float | None +# ) -> tuple[torch.Tensor, torch.Tensor]: +# with torch.enable_grad(): +# logits_ = logits.float().detach().requires_grad_() +# policy_chosen_logps, policy_rejected_logps = compute_logprobs_for_spans(logits_, targets, chosen_span, rejected_span) + +# len_chosen_span = (chosen_span[0][1] - chosen_span[0][0]).item() +# len_rej_span = (rejected_span[0][1] - rejected_span[0][0]).item() + +# pi_logratios = (beta / (len_chosen_span + 1e-8)) * policy_chosen_logps - (beta / (len_rej_span + 1e-8)) * policy_rejected_logps + +# losses = -torch.nn.functional.logsigmoid(pi_logratios) + +# if grad_output is None: +# loss = None +# else: +# loss = losses.mean() +# loss.backward(torch.full_like(loss, grad_output)) +# loss.detach() +# return loss.detach(), logits_.grad.detach().to(logits.dtype) + + def compute_simplified_dpo_loss( - logits: torch.Tensor, - targets: torch.Tensor, - chosen_span: torch.Tensor, + logits: torch.Tensor, + targets: torch.Tensor, + chosen_span: torch.Tensor, rejected_span: torch.Tensor, beta: float, - grad_output: float | None -) -> Tuple[torch.Tensor, torch.Tensor]: + grad_output: float | None, +) -> tuple[torch.Tensor, torch.Tensor]: with torch.enable_grad(): logits_ = logits.float().detach().requires_grad_() - policy_chosen_logps, policy_rejected_logps = compute_logps_for_spans(logits_, targets, chosen_span, rejected_span) + policy_chosen_logps, policy_rejected_logps = compute_logprobs_for_spans( + logits_, targets, chosen_span, rejected_span + ) pi_logratios = policy_chosen_logps - policy_rejected_logps - losses = -F.logsigmoid(beta * pi_logratios) + losses = -torch.nn.functional.logsigmoid(beta * pi_logratios) if grad_output is None: loss = None else: From 63db041a1ae7cd4e858dfebde8d929aac3e3098e Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Tue, 15 Apr 2025 00:39:39 +0000 Subject: [PATCH 21/47] small updates --- fast_llm/data/data/gpt/data.py | 4 +++- fast_llm/data/dataset/gpt/fim.py | 6 +++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 0e39f8c6a..54fc39e75 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -171,7 +171,9 @@ def get_iterator( dataset = self._datasets[dataset_name] # noqa - if hasattr(dataset._dataset, "_indexed_dataset"): + if hasattr(dataset._dataset, "_indexed_dataset") and hasattr( + dataset._dataset._indexed_dataset, "_has_preference_spans" + ): use_preference_loss_masking_spans = dataset._dataset._indexed_dataset._has_preference_spans else: use_preference_loss_masking_spans = False diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index 580421fa4..94d68c3c1 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -23,7 +23,11 @@ def __init__( self._config = config self._dataset = dataset - if hasattr(self._dataset, "_indexed_dataset") and dataset._dataset._indexed_dataset._has_preference_spans: + if ( + hasattr(self._dataset, "_indexed_dataset") + and hasattr(self._dataset._indexed_dataset, "_has_preference_spans") + and dataset._dataset._indexed_dataset._has_preference_spans + ): raise NotImplementedError("FIM is currently not compatible with preference loss masking.") self._seed = sampling.config.seed From dab6dabf59f1b12301b4e0f9dd738830a966b59c Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Tue, 15 Apr 2025 00:47:28 +0000 Subject: [PATCH 22/47] small fix --- fast_llm/data/dataset/gpt/indexed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index 2596a5af6..e869827f6 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -34,7 +34,7 @@ def sample(self, sampling: GPTSamplingData) -> "GPTSampledIndexedDataset": if sampling.config.shuffle == ShufflingType.legacy: return LegacyGPTSampledIndexedDataset(self, sampling) - elif self._has_preference_spans: + elif hasattr(self, "_has_preference_spans") and self._has_preference_spans: return GPTSampledPreferenceIndexedDataset(self, sampling) else: return GPTSampledIndexedDataset(self, sampling) From 41fb3e3309904ccf342540eb400fa0f14e598556 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Tue, 15 Apr 2025 00:59:46 +0000 Subject: [PATCH 23/47] fixing fim --- fast_llm/data/dataset/gpt/fim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index 94d68c3c1..acd3d1dd3 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -26,7 +26,7 @@ def __init__( if ( hasattr(self._dataset, "_indexed_dataset") and hasattr(self._dataset._indexed_dataset, "_has_preference_spans") - and dataset._dataset._indexed_dataset._has_preference_spans + and self._dataset._indexed_dataset._has_preference_spans ): raise NotImplementedError("FIM is currently not compatible with preference loss masking.") From 3d77986d5624eb0f4ec29b46e2ddd5f518955717 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Tue, 15 Apr 2025 01:17:51 +0000 Subject: [PATCH 24/47] adding checks for chosen/rej spans in memmap dataset --- fast_llm/data/dataset/gpt/memmap.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index 1725d6fac..d719ec53d 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -190,6 +190,11 @@ def get( # subtract by offset to normalize span boundaries rejected_span[0] = np.maximum(rejected_span[0], offset) - offset # offset rejected_span[1] = np.minimum(rejected_span[1], offset + len(token_ids) - 1) - offset + elif self._has_preference_spans: + if self._chosen_spans is None: + raise ValueError("Failed to read chosen spans from memmap dataset.") + if self._rejected_spans is None: + raise ValueError("Failed to read rejected spans from memmap dataset.") return GPTSample( token_ids=token_ids, From 905bc00fc15bfacbb86584ecada88b4b27e06ea2 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Wed, 16 Apr 2025 21:45:37 +0000 Subject: [PATCH 25/47] refractor to preprocessor --- fast_llm/data/dataset/gpt/memmap.py | 12 ++--- .../layers/language_model/preprocessing.py | 49 +++++++++++++++++++ fast_llm/models/gpt/model.py | 34 +++---------- 3 files changed, 63 insertions(+), 32 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index d719ec53d..f32abe1d2 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -170,7 +170,12 @@ def get( chosen_span = None rejected_span = None - if self._has_preference_spans and self._chosen_spans is not None and self._rejected_spans is not None: + + if self._has_preference_spans and self._chosen_spans is None: + raise ValueError("Failed to read chosen spans from memmap dataset.") + elif self._has_preference_spans and self._rejected_spans is None: + raise ValueError("Failed to read rejected spans from memmap dataset.") + elif self._has_preference_spans: chosen_span = self._chosen_spans[idx] # filter spans that are outside the range of the selected tokens in the document @@ -190,11 +195,6 @@ def get( # subtract by offset to normalize span boundaries rejected_span[0] = np.maximum(rejected_span[0], offset) - offset # offset rejected_span[1] = np.minimum(rejected_span[1], offset + len(token_ids) - 1) - offset - elif self._has_preference_spans: - if self._chosen_spans is None: - raise ValueError("Failed to read chosen spans from memmap dataset.") - if self._rejected_spans is None: - raise ValueError("Failed to read rejected spans from memmap dataset.") return GPTSample( token_ids=token_ids, diff --git a/fast_llm/layers/language_model/preprocessing.py b/fast_llm/layers/language_model/preprocessing.py index 7e95bb5cc..56f36d137 100644 --- a/fast_llm/layers/language_model/preprocessing.py +++ b/fast_llm/layers/language_model/preprocessing.py @@ -69,3 +69,52 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: tensor_name=LanguageModelKwargs.position_ids, dtype=torch.int64, ) + + +class PreferenceSpanPreprocessor(Preprocessor): + def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): + self._config = config + self._tensor_space = tensor_space + self._distributed_config = self._tensor_space.distributed_config + self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) + + def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: + return + + def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: + sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size + sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size + sequence_offset = sequence_k - sequence_q + 1 # +1 for shift in labels + + if LanguageModelKwargs.chosen_spans not in kwargs or LanguageModelKwargs.rejected_spans not in kwargs: + raise ValueError("Expected chosen spans or rejected spans to be found within the batch.") + + chosen_loss_masking_spans = kwargs[LanguageModelKwargs.chosen_spans] + for spans in chosen_loss_masking_spans: + if not spans.numel(): + continue + # only keep spans within the sequence or partially within the sequence + valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)] + if valid_spans.numel(): + # if span is partially within the sequence, truncate parts of spans that are outside of the sequence + valid_spans[:, 0].clamp_(min=sequence_offset) + valid_spans[:, 1].clamp_(max=sequence_k) + valid_spans -= sequence_offset + + # TODO: check higher batch size + kwargs[LanguageModelKwargs.chosen_spans] = valid_spans + + rejected_loss_masking_spans = kwargs[LanguageModelKwargs.rejected_spans] + for spans in rejected_loss_masking_spans: + if not spans.numel(): + continue + # only keep spans within the sequence or partially within the sequence + valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)] + if valid_spans.numel(): + # if span is partially within the sequence, truncate parts of spans that are outside of the sequence + valid_spans[:, 0].clamp_(min=sequence_offset) + valid_spans[:, 1].clamp_(max=sequence_k) + valid_spans -= sequence_offset + + # TODO: check higher batch size + kwargs[LanguageModelKwargs.rejected_spans] = valid_spans diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 699ba93ff..3e44a76e0 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -15,7 +15,7 @@ from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead -from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor +from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor, PreferenceSpanPreprocessor from fast_llm.layers.transformer.config import ( RoutingType, TransformerDimNames, @@ -72,6 +72,9 @@ def __init__( else: self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._tensor_space)) + if self._config.loss_function == "dpo": # TODO better way to pass in? + self._preprocessors.append(PreferenceSpanPreprocessor(self._config, self._tensor_space)) + def get_output_layers(self) -> list[Layer]: return [ layer @@ -120,7 +123,6 @@ def setup(self, distributed: Distributed) -> None: self._tensor_space.setup(distributed) self._is_setup = True - def preprocess_meta( self, batch_meta: BatchConfig | torch.Tensor, phase: PhaseType ) -> list[tuple[TensorMeta, dict]]: @@ -254,6 +256,10 @@ def preprocess( tokens = batch.token_ids[:, sequence_k - sequence_q : sequence_k].contiguous() if batch.sequence_lengths is not None: kwargs_meta[TransformerKwargs.sequence_lengths] = batch.sequence_lengths + if batch.chosen_loss_masking_spans is not None: + kwargs_meta[LanguageModelKwargs.chosen_spans] = batch.chosen_loss_masking_spans + if batch.rejected_loss_masking_spans is not None: + kwargs_meta[LanguageModelKwargs.rejected_spans] = batch.rejected_loss_masking_spans # TODO: Add pasts/presents to meta input? # Use lists as pointers so `past_key_values` is populated during the previous micro_sequence. @@ -289,30 +295,6 @@ def preprocess( labels[start : end + 1, i] = -100 else: labels[i, start : end + 1] = -100 - if batch.chosen_loss_masking_spans is not None: - for i, spans in enumerate(batch.chosen_loss_masking_spans): - if not spans.numel(): - continue - # only keep spans within the sequence or partially within the sequence - valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)] - if valid_spans.numel(): - # if span is partially within the sequence, truncate parts of spans that are outside of the sequence - valid_spans[:, 0].clamp_(min=sequence_offset) - valid_spans[:, 1].clamp_(max=sequence_k) - valid_spans -= sequence_offset - kwargs[LanguageModelKwargs.chosen_spans] = valid_spans - if batch.rejected_loss_masking_spans is not None: - for i, spans in enumerate(batch.rejected_loss_masking_spans): - if not spans.numel(): - continue - # only keep spans within the sequence or partially within the sequence - valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)] - if valid_spans.numel(): - # if span is partially within the sequence, truncate parts of spans that are outside of the sequence - valid_spans[:, 0].clamp_(min=sequence_offset) - valid_spans[:, 1].clamp_(max=sequence_k) - valid_spans -= sequence_offset - kwargs[LanguageModelKwargs.rejected_spans] = valid_spans kwargs[LanguageModelKwargs.labels] = labels for preprocessor in self._preprocessors: preprocessor.preprocess(tokens, kwargs) From 52c8f9f3d360975bf6eb507e4cbdb9c40442f092 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Fri, 18 Apr 2025 18:21:43 +0000 Subject: [PATCH 26/47] moving puse_pref_loss_spans to sampling parameters and combining sampling classes --- fast_llm/data/data/gpt/data.py | 7 - fast_llm/data/dataset/gpt/config.py | 1 + fast_llm/data/dataset/gpt/fim.py | 9 +- fast_llm/data/dataset/gpt/indexed.py | 17 +- fast_llm/data/dataset/gpt/memmap.py | 52 +++-- fast_llm/data/dataset/gpt/sampled.py | 323 ++++++--------------------- fast_llm/models/gpt/config.py | 5 + fast_llm/models/gpt/trainer.py | 1 + 8 files changed, 114 insertions(+), 301 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index dd9826e01..6c568cab0 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -161,13 +161,6 @@ def get_iterator( dataset = self._datasets[dataset_name] # noqa - if hasattr(dataset._dataset, "_indexed_dataset") and hasattr( - dataset._dataset._indexed_dataset, "_has_preference_spans" - ): - dataset._dataset._indexed_dataset._has_preference_spans - else: - pass - return iter( torch.utils.data.DataLoader( dataset, diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index c347a5c70..942c38524 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -73,6 +73,7 @@ class GPTSamplingParameters(SamplingParameters): sequence_length: int vocab_size: int use_loss_masking_spans: bool = False + use_preference_loss_masking_spans: bool = False cross_document_attention: bool = True diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index 9c77a4f50..5eeba59b5 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -20,16 +20,11 @@ def __init__( ): if sampling.parameters.use_loss_masking_spans: raise NotImplementedError("FIM is currently not compatible with loss masking.") + if sampling.parameters.use_preference_loss_masking_spans: + raise NotImplementedError("FIM is currently not compatible with preference loss masking.") self._config = config self._dataset = dataset - if ( - hasattr(self._dataset, "_indexed_dataset") - and hasattr(self._dataset._indexed_dataset, "_has_preference_spans") - and self._dataset._indexed_dataset._has_preference_spans - ): - raise NotImplementedError("FIM is currently not compatible with preference loss masking.") - self._seed = sampling.config.seed self._tokenizer = sampling.tokenizer if self._tokenizer is None: diff --git a/fast_llm/data/dataset/gpt/indexed.py b/fast_llm/data/dataset/gpt/indexed.py index e869827f6..688ea6a70 100644 --- a/fast_llm/data/dataset/gpt/indexed.py +++ b/fast_llm/data/dataset/gpt/indexed.py @@ -26,18 +26,13 @@ def get_document_size(self, index: int) -> int: """ def sample(self, sampling: GPTSamplingData) -> "GPTSampledIndexedDataset": - from fast_llm.data.dataset.gpt.sampled import ( - GPTSampledIndexedDataset, - GPTSampledPreferenceIndexedDataset, - LegacyGPTSampledIndexedDataset, - ) + from fast_llm.data.dataset.gpt.sampled import GPTSampledIndexedDataset, LegacyGPTSampledIndexedDataset - if sampling.config.shuffle == ShufflingType.legacy: - return LegacyGPTSampledIndexedDataset(self, sampling) - elif hasattr(self, "_has_preference_spans") and self._has_preference_spans: - return GPTSampledPreferenceIndexedDataset(self, sampling) - else: - return GPTSampledIndexedDataset(self, sampling) + return ( + LegacyGPTSampledIndexedDataset(self, sampling) + if sampling.config.shuffle == ShufflingType.legacy + else GPTSampledIndexedDataset(self, sampling) + ) class GPTDatasetSlice[IndexedDatasetType: GPTIndexedDataset](DatasetSlice[IndexedDatasetType], GPTIndexedDataset): diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index f32abe1d2..a73f4a37e 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -148,6 +148,7 @@ def get( offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False, + use_preference_loss_masking_spans: bool = False, ) -> GPTSample: token_ids = np.frombuffer( self._bin_buffer, @@ -171,30 +172,33 @@ def get( chosen_span = None rejected_span = None - if self._has_preference_spans and self._chosen_spans is None: - raise ValueError("Failed to read chosen spans from memmap dataset.") - elif self._has_preference_spans and self._rejected_spans is None: - raise ValueError("Failed to read rejected spans from memmap dataset.") - elif self._has_preference_spans: - chosen_span = self._chosen_spans[idx] - - # filter spans that are outside the range of the selected tokens in the document - chosen_span = chosen_span[(chosen_span[0] < offset + len(token_ids)) & (chosen_span[1] >= offset)][0] - - # subtract by offset to normalize span boundaries - chosen_span[0] = np.maximum(chosen_span[0], offset) - offset # offset - chosen_span[1] = np.minimum(chosen_span[1], offset + len(token_ids) - 1) - offset - - rejected_span = self._rejected_spans[idx] - - # filter spans that are outside the range of the selected tokens in the document - rejected_span = rejected_span[(rejected_span[0] < offset + len(token_ids)) & (rejected_span[1] >= offset)][ - 0 - ] - - # subtract by offset to normalize span boundaries - rejected_span[0] = np.maximum(rejected_span[0], offset) - offset # offset - rejected_span[1] = np.minimum(rejected_span[1], offset + len(token_ids) - 1) - offset + if use_preference_loss_masking_spans: + if self._has_preference_spans and self._chosen_spans is None: + raise ValueError("Failed to read chosen spans from memmap dataset.") + elif self._has_preference_spans and self._rejected_spans is None: + raise ValueError("Failed to read rejected spans from memmap dataset.") + elif self._has_preference_spans: + chosen_span = self._chosen_spans[idx] + + # filter spans that are outside the range of the selected tokens in the document + chosen_span = chosen_span[(chosen_span[0] < offset + len(token_ids)) & (chosen_span[1] >= offset)][0] + + # subtract by offset to normalize span boundaries + chosen_span[0] = np.maximum(chosen_span[0], offset) - offset # offset + chosen_span[1] = np.minimum(chosen_span[1], offset + len(token_ids) - 1) - offset + + rejected_span = self._rejected_spans[idx] + + # filter spans that are outside the range of the selected tokens in the document + rejected_span = rejected_span[ + (rejected_span[0] < offset + len(token_ids)) & (rejected_span[1] >= offset) + ][0] + + # subtract by offset to normalize span boundaries + rejected_span[0] = np.maximum(rejected_span[0], offset) - offset # offset + rejected_span[1] = np.minimum(rejected_span[1], offset + len(token_ids) - 1) - offset + else: + raise ValueError("No preference spans found in memmap dataset.") return GPTSample( token_ids=token_ids, diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 9b8de18f2..48892fe65 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -75,252 +75,6 @@ def _lazy_load(self): TOKEN_CUMSUM_RATE = 10 -class GPTSampledPreferenceIndexedDataset(SampledDataset): - """ - A sampled GPT dataset for preference data. - """ - - def __init__( - self, - indexed_dataset: GPTIndexedDataset, - sampling: GPTSamplingData, - ): - assert isinstance(sampling, GPTSamplingData) - self._indexed_dataset = indexed_dataset - self._num_samples = sampling.num_samples - self._sequence_length = sampling.sequence_length - self._cross_document_attention = sampling.cross_document_attention - self._config = sampling.config - self._truncate_documents = sampling.truncate_documents - self._device = torch.device("cuda" if self._config.gpu else "cpu") - - if self._truncate_documents: - raise NotImplementedError("Sampled preference indexed dataset does not support document trunctation.") - - if sampling.cache_directory is None: - self._document_shuffling = MemmapArray() - self._token_cumsum_shuffled = MemmapArray() - self._token_cumsum_unshuffled = MemmapArray() - self._yaml_path = None - log_main_rank( - " > No dataset cache directory provided, building the index map on all ranks." - "This may be very inefficient...", - log_fn=logger.warning, - ) - self._sample() - else: - base_path = ( - sampling.cache_directory / f"{self.name}_ns_{self._num_samples}_sl_{self._sequence_length}" - f"_s_{self._config.seed}" - ) - # TODO: Names are confusing - - # contains shuffled document indicies - self._document_shuffling = MemmapArray(base_path.with_name(base_path.name + "_shuffling.npy")) - - self._document_sizes = MemmapArray(base_path.with_name(base_path.name + "_doc_sizes.npy")) - self._doc_length_filtered_indicies = MemmapArray( - base_path.with_name(base_path.name + "_doc_length_filtered_indices.npy") - ) - - self._yaml_path = base_path.with_suffix(".yaml") - # Sample or validate the dataset of a given rank. - if sampling.distributed.config.rank == sampling.get_next_rank(): - self._sample() - # No barrier yet to allow running in parallel. - # There needs to be one before calling `__getitem__`, normally handled through `GPTData`. - - def _sample(self) -> None: - """ - Create a `GPTSampledDataset` with the requested parameters. - """ - # compute document sizes, documents per epoch, and tok per epoch - document_sizes = torch.from_numpy(self._indexed_dataset.get_document_sizes()).to(self._device) - documents_per_epoch = document_sizes.numel() - tokens_per_epoch = document_sizes.sum().item() - - # filter documents past the sequence length - long_docs_filter = document_sizes > self._sequence_length + 1 # TODO: do we need the +1 here if no truncation? - ignored_documents = sum(long_docs_filter) - if ignored_documents: - log_main_rank( - f" > {ignored_documents}/{documents_per_epoch} documents are longer than {self._sequence_length+1} tokens and will be ignored.", - log_fn=logger.warning, - ) - tokens_per_epoch = document_sizes[~long_docs_filter].sum().item() - if tokens_per_epoch == 0: - raise RuntimeError( - f" > No documents shorter than {self._sequence_length+1} tokens found in dataset {self._indexed_dataset.name}." - ) - - # update documents per epochs/num epochs based on filter - documents_per_epoch = (~long_docs_filter).sum().item() - num_epochs = math.ceil(self._num_samples / documents_per_epoch) - - # compute shuffled/unshuffled epochs - generator = torch.Generator(device=self._device) - if self._config.shuffle == ShufflingType.skip_first_epoch: - shuffled_epochs = num_epochs - 1 - elif self._config.shuffle == ShufflingType.disabled: - shuffled_epochs = 0 - else: - shuffled_epochs = num_epochs - shuffled_documents = documents_per_epoch * shuffled_epochs - unshuffled_epochs = num_epochs - shuffled_epochs - - yaml_data = { - "dataset": { - "name": self._indexed_dataset.name, - "documents_per_epoch": documents_per_epoch, - "tokens_per_epoch": tokens_per_epoch, - }, - "num_samples": self._num_samples, - "unshuffled_epochs": unshuffled_epochs, - "sequence_length": self._sequence_length, - "truncate_documents": self._truncate_documents, - "config": self._config.to_serialized(), - } - - if self._yaml_path is not None and self._yaml_path.is_file(): - loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r")) - self._load_yaml_data(yaml_data) - del loaded_yaml_data["unshuffled_tokens"] - - if loaded_yaml_data != yaml_data: - raise RuntimeError( - f"Invalid dataset cache for dataset {self.name}." - " If this is due to an intended configuration change," - " please delete the cache before continuing." - f"\nCurrent config:\n{yaml.safe_dump(yaml_data)}" - f"\nCached config:\n{yaml.safe_dump(loaded_yaml_data)}" - ) - # Dataset is already sampled, skip. - logger.info(f"Using existing sampling for dataset {self.name}") - return - - if shuffled_documents > 1e8: - warnings.warn( - f"Shuffling {shuffled_documents:.2e} documents for dataset {self._indexed_dataset.name}." - f" This may take a while and/or use an excessive amount of memory." - ) - elif documents_per_epoch > 1e8: - # TODO: Most of the damage is already done in `get_document_sizes`. Find a way to warn earlier? - warnings.warn( - f"The dataset {self._indexed_dataset.name} contains {documents_per_epoch:.2e} documents." - f" Sampling may take a while and/or use an excessive amount of memory." - ) - - # Use the smallest possible data type to save memory and disk usage. - document_shuffling_dtype = get_unsigned_integer_type(documents_per_epoch).torch - # Shuffle the dataset (documents) - # This generates a document shuffling index `all_document_index`, the unshuffled part is trivial - # so we only evaluate and store the shuffled part `document_shuffling`. - if self._config.shuffle == ShufflingType.full: - generator.manual_seed(self._config.seed) - # Equivalent to `shuffle(range(documents_per_epoch * num_epochs)) % documents_per_epoch` - document_shuffling = ( - torch.randperm( - shuffled_documents, - generator=generator, - dtype=get_unsigned_integer_type(shuffled_documents).torch, - device=self._device, - ) - .remainder_(documents_per_epoch) - .to(dtype=document_shuffling_dtype) - ) - elif self._config.shuffle in (ShufflingType.skip_first_epoch, ShufflingType.epoch): - document_shuffling = torch.empty( - shuffled_documents, - dtype=document_shuffling_dtype, - device=self._device, - ) - for i in range(shuffled_epochs): - generator.manual_seed(self._config.seed + i * 571) - torch.randperm( - documents_per_epoch, - generator=generator, - out=document_shuffling[i * documents_per_epoch : (i + 1) * documents_per_epoch], - ) - elif self._config.shuffle == ShufflingType.disabled: - document_shuffling = None - else: - raise NotImplementedError(f"Unknown shuffling type: {self._config.shuffle}") - - yaml_data["unshuffled_tokens"] = None # not used with packing disabled - - # index of all documents less than seq length long - doc_length_filtered_indicies = torch.nonzero(~long_docs_filter, as_tuple=True)[0] - self._doc_length_filtered_indicies.save(doc_length_filtered_indicies.numpy(force=self._config.gpu)) - - # save document shuffling and document sizes - if shuffled_epochs > 0: - self._document_shuffling.save(document_shuffling[: self._num_samples].numpy(force=self._config.gpu)) - self._document_sizes.save(document_sizes.numpy(force=self._config.gpu)) - - # save yaml - if self._yaml_path is not None: - # yaml_data["unshuffled_tokens"] = num_tokens_unshuffled - self._yaml_path.parent.mkdir(parents=True, exist_ok=True) - yaml.safe_dump(yaml_data, self._yaml_path.open("w")) - - def __len__(self) -> int: - return self._num_samples - - def __getitem__(self, index: int) -> typing.Any: - """ - Get the sample, (fixed-length sequence of tokens holding one or more complete or partial documents) - with the requested sampling index. - The returned sample is ready to be concatenated, then fed to a `GPTModel` (see `GPTModel.preprocess`). - """ - self._lazy_load() - - # get the document index to read - if index < self._unshuffled_documents: - document_index = self._doc_length_filtered_indicies[index % self._documents_per_epoch] - else: - document_index = self._doc_length_filtered_indicies[ - self._document_shuffling[index - self._unshuffled_documents].item() - ] - - # read document from memmap - sample = self._indexed_dataset.get( - document_index, - offset=0, - length=self._document_sizes[document_index], - use_loss_masking_spans=self._config.use_loss_masking_spans, - ) - - # compute sequence lengths - chosen_loss_masking_span_end = sample.chosen_loss_masking_span[1] + 1 - sequence_lengths = [ - chosen_loss_masking_span_end, - len(sample.token_ids) - chosen_loss_masking_span_end, - ] - - # compute padding size - padding = np.full((self._sequence_length,), 0) - padding[: len(sample.token_ids)] = sample.token_ids - sequence_lengths.append(self._sequence_length - len(sample.token_ids)) - sample.token_ids = padding - - if not self._cross_document_attention: # only add sequence lengths if no cross doc attention - sample.sequence_lengths = np.array(sequence_lengths) - - return sample - - @property - def name(self) -> str: - return self._indexed_dataset.name - - def _lazy_load(self): - if not hasattr(self, "_documents_per_epoch"): - self._load_yaml_data(yaml.safe_load(self._yaml_path.open("r"))) - - def _load_yaml_data(self, data: dict[str, typing.Any]) -> None: - self._documents_per_epoch = data["dataset"]["documents_per_epoch"] - self._unshuffled_documents = data["unshuffled_epochs"] * self._documents_per_epoch - - class GPTSampledIndexedDataset(SampledDataset): """ A sampled GPT dataset. @@ -360,6 +114,14 @@ def __init__( self._token_cumsum_shuffled = MemmapArray(base_path.with_name(base_path.name + "_shuffled_cumsum.npy")) self._token_cumsum_unshuffled = MemmapArray(base_path.with_name(base_path.name + "_unshuffled_cumsum.npy")) self._yaml_path = base_path.with_suffix(".yaml") + + # keep document sizes and len filtered docs for preference loss masking + if self._parameters.use_preference_loss_masking_spans: + self._document_sizes = MemmapArray(base_path.with_name(base_path.name + "_doc_sizes.npy")) + self._doc_length_filtered_indicies = MemmapArray( + base_path.with_name(base_path.name + "_doc_length_filtered_indices.npy") + ) + # Sample or validate the dataset of a given rank. if sampling.distributed.config.rank == sampling.get_next_rank(): self._sample() @@ -397,13 +159,18 @@ def _sample(self) -> None: # We produce sequences of length `self._sequence_length + 1` so the last token has a label, # but in case of truncations we also include that last label in the following sample, # so we need `sequence_length * num_samples + 1` tokens in total. - num_epochs = math.ceil( - ( - (self._parameters.sequence_length + 1 - self._truncate_documents) * self._parameters.num_samples - + 1 * self._truncate_documents + + if self._parameters.use_preference_loss_masking_spans: + documents_per_epoch = (~long_docs_filter).sum().item() + num_epochs = math.ceil(self._parameters.num_samples / documents_per_epoch) + else: + num_epochs = math.ceil( + ( + (self._parameters.sequence_length + 1 - self._truncate_documents) * self._parameters.num_samples + + 1 * self._truncate_documents + ) + / tokens_per_epoch ) - / tokens_per_epoch - ) # Prepare for shuffling. generator = torch.Generator(device=self._device) @@ -497,6 +264,23 @@ def _sample(self) -> None: else: raise NotImplementedError(f"Unknown shuffling type: {self._config.shuffle}") + if self._parameters.use_preference_loss_masking_spans: + if not self._truncate_documents: + yaml_data["unshuffled_tokens"] = None + + # index of all documents less than seq length long + doc_length_filtered_indicies = torch.nonzero(~long_docs_filter, as_tuple=True)[0] + self._doc_length_filtered_indicies.save(doc_length_filtered_indicies.numpy(force=self._config.gpu)) + + # apply shuffling on doc_length_filtered_indicies + if shuffled_epochs > 0: + self._document_shuffling.save(document_shuffling[: self._num_samples].numpy(force=self._config.gpu)) + self._document_sizes.save(document_sizes.numpy(force=self._config.gpu)) + if self._yaml_path is not None: + self._yaml_path.parent.mkdir(parents=True, exist_ok=True) + yaml.safe_dump(yaml_data, self._yaml_path.open("w")) + return + # To get a sample on the fly we need to know where it begins, # and this is a non-trivial information because the documents have variable length. # The starting point `(document[idx], token[idx])` corresponds to the `(idx * sequence_length)` th token, i.e. @@ -595,6 +379,39 @@ def __getitem__(self, index: int) -> typing.Any: The returned sample is ready to be concatenated, then fed to a `GPTModel` (see `GPTModel.preprocess`). """ self._lazy_load() + + if self._parameters.use_preference_loss_masking_spans: + if index < self._unshuffled_documents: + document_index = self._doc_length_filtered_indicies[index % self._documents_per_epoch] + else: + document_index = self._doc_length_filtered_indicies[ + self._document_shuffling[index - self._unshuffled_documents].item() + ] + + sample = self._indexed_dataset.get( + document_index, + offset=0, + length=self._document_sizes[document_index], + use_loss_masking_spans=self._parameters.use_loss_masking_spans, + use_preference_loss_masking_spans=self._parameters.use_preference_loss_masking_spans, + ) + + chosen_loss_masking_span_end = sample.chosen_loss_masking_span[1] + 1 + sequence_lengths = [ + chosen_loss_masking_span_end, + len(sample.token_ids) - chosen_loss_masking_span_end, + ] + + # compute padding size + padding = np.full((self._parameters.sequence_length,), 0) + padding[: len(sample.token_ids)] = sample.token_ids + sequence_lengths.append(self._parameters.sequence_length - len(sample.token_ids)) + sample.token_ids = padding + + sample.sequence_lengths = np.array(sequence_lengths) + + return sample + # tokens at the boundary are included in only one sample when we pack without truncations # in case of packing with truncations, the last token from the previous sample is also the first token of the next sample token_start = index * (self._parameters.sequence_length + 1 - self._truncate_documents) @@ -722,6 +539,8 @@ def __init__( raise NotImplementedError( "Legacy sampling only supports document truncation. Please use the latest dataset format." ) + if sampling.use_preference_loss_masking_spans: + raise NotImplementedError("Legacy sampling does not support preference loss masking.") self._config = sampling.config self._parameters = sampling.parameters diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index b78c3311b..3c01b7dbc 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -93,6 +93,11 @@ class GPTBatchConfig(BatchConfig): desc="Read loss masking spans from the dataset.", hint=FieldHint.feature, ) + use_preference_loss_masking_spans: bool = Field( + default=False, + desc="Read loss masking spans from the dataset.", + hint=FieldHint.feature, + ) def _validate(self) -> None: if self.micro_sequence_length is None: diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index a269f5a63..9b24c972c 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -48,6 +48,7 @@ def _get_sampling_parameters( "vocab_size": self._config.model.base_model.vocab_size, "sequence_length": self._config.batch.sequence_length, "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, + "use_preference_loss_masking_spans": self._config.batch.use_preference_loss_masking_spans, "cross_document_attention": self._config.batch.cross_document_attention, } ) From 9067b6ad3e30a0d3d171f4944b0b6b0da116f2d8 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Wed, 23 Apr 2025 19:19:02 +0000 Subject: [PATCH 27/47] dpo loss enabling flag --- fast_llm/functional/config.py | 5 --- fast_llm/layers/language_model/config.py | 8 ++--- fast_llm/layers/language_model/head.py | 32 +++++++++---------- .../layers/language_model/preprocessing.py | 2 -- fast_llm/models/gpt/model.py | 2 +- 5 files changed, 20 insertions(+), 29 deletions(-) diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index b44939497..7284ca071 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -93,11 +93,6 @@ class CrossEntropyImpl(str, enum.Enum): triton = "triton" -class LossFunction(str, enum.Enum): - cross_entropy = "cross_entropy" - dpo = "dpo" - - class TargetFormat(enum.StrEnum): labels = "labels" logits = "logits" diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index cc16dead4..040594cc4 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -4,7 +4,7 @@ from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames -from fast_llm.functional.config import CrossEntropyImpl, LossFunction +from fast_llm.functional.config import CrossEntropyImpl from fast_llm.layers.transformer.config import TransformerArchitectureConfig, TransformerConfig from fast_llm.utils import Assert @@ -143,9 +143,9 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): desc="Min value for clamping initialized weights of the vocabulary embedding and output (logits).", hint=FieldHint.feature, ) - loss_function: LossFunction = Field( - default=LossFunction.cross_entropy, - desc="Type of loss function to use", + use_dpo_loss: bool | None = Field( + default=False, + desc="Whether to enable DPO loss", hint=FieldHint.feature, ) dpo_beta: float | None = Field( diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 62260063d..99d04fa0d 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -10,7 +10,7 @@ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward -from fast_llm.functional.config import CrossEntropyImpl, LossFunction, TargetFormat, TritonConfig +from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig from fast_llm.functional.cross_entropy import cross_entropy_forward_backward from fast_llm.functional.dpo import compute_simplified_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward @@ -72,8 +72,10 @@ def __init__( self._init_output_weights(hidden_dim, config) - self._loss_function = config.loss_function - if self._loss_function == LossFunction.cross_entropy: + self._use_dpo_loss = config.use_dpo_loss + if self._use_dpo_loss: + self.dpo_beta = config.dpo_beta + else: self._cross_entropy_impl = config.cross_entropy_impl if self._cross_entropy_impl == CrossEntropyImpl.auto: if self._parallel_embeddings: @@ -82,10 +84,6 @@ def __init__( self._cross_entropy_impl = CrossEntropyImpl.triton else: self._cross_entropy_impl = CrossEntropyImpl.fused - elif self._loss_function == LossFunction.dpo: - self.dpo_beta = config.dpo_beta - else: - raise NotImplementedError(f"Loss function type {self._loss_function} not supported.") self._forward = wrap_forward_backward(self._forward_backward, grad_is_context) @@ -301,7 +299,16 @@ def _logits_cross_entropy_forward_backward( if target is None: return logits * self._logits_scale_factor, None - if self._loss_function == LossFunction.cross_entropy: + if self._use_dpo_loss: + loss, grad = compute_simplified_dpo_loss( + logits.flatten(0, -2), + target, + kwargs[LanguageModelKwargs.chosen_spans], + kwargs[LanguageModelKwargs.rejected_spans], + self.dpo_beta, + grad_output, + ) + else: loss, grad = cross_entropy_forward_backward( logits.flatten(0, -2), target, @@ -311,15 +318,6 @@ def _logits_cross_entropy_forward_backward( logits_scale_factor=self._logits_scale_factor, target_format=TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits, ) - elif self._loss_function == LossFunction.dpo: - loss, grad = compute_simplified_dpo_loss( - logits.flatten(0, -2), - target, - kwargs[LanguageModelKwargs.chosen_spans], - kwargs[LanguageModelKwargs.rejected_spans], - self.dpo_beta, - grad_output, - ) # TODO: de-allocate earlier. del logits return loss, output_parallel_linear_backward(grad, context) diff --git a/fast_llm/layers/language_model/preprocessing.py b/fast_llm/layers/language_model/preprocessing.py index 56f36d137..bda4d02d7 100644 --- a/fast_llm/layers/language_model/preprocessing.py +++ b/fast_llm/layers/language_model/preprocessing.py @@ -101,7 +101,6 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: valid_spans[:, 1].clamp_(max=sequence_k) valid_spans -= sequence_offset - # TODO: check higher batch size kwargs[LanguageModelKwargs.chosen_spans] = valid_spans rejected_loss_masking_spans = kwargs[LanguageModelKwargs.rejected_spans] @@ -116,5 +115,4 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: valid_spans[:, 1].clamp_(max=sequence_k) valid_spans -= sequence_offset - # TODO: check higher batch size kwargs[LanguageModelKwargs.rejected_spans] = valid_spans diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 32273e1f6..40a5b2bb8 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -71,7 +71,7 @@ def __init__( else: self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._tensor_space)) - if self._config.loss_function == "dpo": # TODO better way to pass in? + if self._config.use_dpo_loss: # TODO better way to pass in? self._preprocessors.append(PreferenceSpanPreprocessor(self._config, self._tensor_space)) def get_output_layers(self) -> list[Layer]: From 062ce883b54ad369018e0b10ecdd5e93c4ce2936 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Thu, 24 Apr 2025 00:50:13 +0000 Subject: [PATCH 28/47] check for config compatibility --- fast_llm/models/gpt/config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 7877afcfc..22ba9eb7e 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -201,6 +201,7 @@ def _validate(self) -> None: if self.model.base_model.distillation_model is not None: # TODO: Support loss masking for distillation? assert not self.batch.use_loss_masking_spans + assert self.model.base_model.use_dpo_loss == self.batch.use_preference_loss_masking_spans for reference_model in self.reference_models.values(): Assert.none(reference_model.model.base_model.distillation_model) # TODO: Support more LM head features. From f53ac56d179ab1d4c4668c50c74af167db844ab5 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Mon, 28 Apr 2025 00:22:43 +0000 Subject: [PATCH 29/47] full dpo changes --- fast_llm/functional/dpo.py | 42 ++++++++------------------ fast_llm/layers/language_model/head.py | 20 +++++++----- 2 files changed, 24 insertions(+), 38 deletions(-) diff --git a/fast_llm/functional/dpo.py b/fast_llm/functional/dpo.py index 7140e3706..cb08eaa60 100644 --- a/fast_llm/functional/dpo.py +++ b/fast_llm/functional/dpo.py @@ -26,37 +26,10 @@ def compute_logprobs_for_spans( return chosen_logp, rejected_logp -# def compute_simpo_loss( -# logits: torch.Tensor, -# targets: torch.Tensor, -# chosen_span: torch.Tensor, -# rejected_span: torch.Tensor, -# beta: float, -# grad_output: float | None -# ) -> tuple[torch.Tensor, torch.Tensor]: -# with torch.enable_grad(): -# logits_ = logits.float().detach().requires_grad_() -# policy_chosen_logps, policy_rejected_logps = compute_logprobs_for_spans(logits_, targets, chosen_span, rejected_span) - -# len_chosen_span = (chosen_span[0][1] - chosen_span[0][0]).item() -# len_rej_span = (rejected_span[0][1] - rejected_span[0][0]).item() - -# pi_logratios = (beta / (len_chosen_span + 1e-8)) * policy_chosen_logps - (beta / (len_rej_span + 1e-8)) * policy_rejected_logps - -# losses = -torch.nn.functional.logsigmoid(pi_logratios) - -# if grad_output is None: -# loss = None -# else: -# loss = losses.mean() -# loss.backward(torch.full_like(loss, grad_output)) -# loss.detach() -# return loss.detach(), logits_.grad.detach().to(logits.dtype) - - -def compute_simplified_dpo_loss( +def compute_dpo_loss( logits: torch.Tensor, targets: torch.Tensor, + reference_model_logits: torch.Tensor, chosen_span: torch.Tensor, rejected_span: torch.Tensor, beta: float, @@ -64,14 +37,23 @@ def compute_simplified_dpo_loss( ) -> tuple[torch.Tensor, torch.Tensor]: with torch.enable_grad(): logits_ = logits.float().detach().requires_grad_() + reference_model_logits_ = reference_model_logits.float().detach() policy_chosen_logps, policy_rejected_logps = compute_logprobs_for_spans( logits_, targets, chosen_span, rejected_span ) + reference_chosen_logps, reference_rejected_logps = compute_logprobs_for_spans( + reference_model_logits_, targets, chosen_span, rejected_span + ) + pi_logratios = policy_chosen_logps - policy_rejected_logps + ref_logratios = reference_chosen_logps - reference_rejected_logps + + diff_logratios = pi_logratios - ref_logratios + + losses = -torch.nn.functional.logsigmoid(beta * diff_logratios) - losses = -torch.nn.functional.logsigmoid(beta * pi_logratios) if grad_output is None: loss = None else: diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 99d04fa0d..ef247b855 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -12,7 +12,7 @@ from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig from fast_llm.functional.cross_entropy import cross_entropy_forward_backward -from fast_llm.functional.dpo import compute_simplified_dpo_loss +from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.language_model.config import ( @@ -144,13 +144,16 @@ def forward( def _forward_backward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None ) -> tuple[torch.Tensor, torch.Tensor | None]: - target = kwargs.get( - LanguageModelKwargs.labels - if self._config.distillation_model is None - else f"{self._config.distillation_model}_logits" - ) + if kwargs.get(LanguageModelKwargs.labels) is None: + target = None + else: + target = kwargs.get( + LanguageModelKwargs.labels + if self._use_dpo_loss or self._config.distillation_model is None + else f"{self._config.distillation_model}_logits" + ) if target is not None: - if self._config.distillation_model is None: + if self._config.distillation_model is None or self._use_dpo_loss: # MTP: Shift the labels target = ( target[self._prediction_distance : self._prediction_distance + input_.size(0),] @@ -300,9 +303,10 @@ def _logits_cross_entropy_forward_backward( if target is None: return logits * self._logits_scale_factor, None if self._use_dpo_loss: - loss, grad = compute_simplified_dpo_loss( + loss, grad = compute_dpo_loss( logits.flatten(0, -2), target, + kwargs.get(f"{self._config.distillation_model}_logits").flatten(0, -2), kwargs[LanguageModelKwargs.chosen_spans], kwargs[LanguageModelKwargs.rejected_spans], self.dpo_beta, From 92f28ee09eb825529340eecfddc38dee83a89765 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Mon, 28 Apr 2025 00:32:47 +0000 Subject: [PATCH 30/47] adding distillation model check --- fast_llm/models/gpt/config.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 22ba9eb7e..d567d2879 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -202,6 +202,8 @@ def _validate(self) -> None: # TODO: Support loss masking for distillation? assert not self.batch.use_loss_masking_spans assert self.model.base_model.use_dpo_loss == self.batch.use_preference_loss_masking_spans + if self.model.base_model.use_dpo_loss: + assert self.model.base_model.distillation_model is not None for reference_model in self.reference_models.values(): Assert.none(reference_model.model.base_model.distillation_model) # TODO: Support more LM head features. From 2b2515f41c1e41966b5d3862964c88a3940db051 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Mon, 28 Apr 2025 18:24:34 +0000 Subject: [PATCH 31/47] update dpo test cases --- fast_llm/functional/dpo.py | 29 +++++++++++--- tests/test_functional.py | 77 +++++++++++++++++++++++++------------- 2 files changed, 74 insertions(+), 32 deletions(-) diff --git a/fast_llm/functional/dpo.py b/fast_llm/functional/dpo.py index cb08eaa60..7130aa53c 100644 --- a/fast_llm/functional/dpo.py +++ b/fast_llm/functional/dpo.py @@ -26,6 +26,22 @@ def compute_logprobs_for_spans( return chosen_logp, rejected_logp +def _compute_dpo_loss( + policy_chosen_logps: torch.Tensor, + policy_rejected_logps: torch.Tensor, + reference_chosen_logps: torch.Tensor, + reference_rejected_logps: torch.Tensor, + beta: float, +): + pi_logratios = policy_chosen_logps - policy_rejected_logps + ref_logratios = reference_chosen_logps - reference_rejected_logps + + diff_logratios = pi_logratios - ref_logratios + + losses = -torch.nn.functional.logsigmoid(beta * diff_logratios) + return losses + + def compute_dpo_loss( logits: torch.Tensor, targets: torch.Tensor, @@ -47,12 +63,13 @@ def compute_dpo_loss( reference_model_logits_, targets, chosen_span, rejected_span ) - pi_logratios = policy_chosen_logps - policy_rejected_logps - ref_logratios = reference_chosen_logps - reference_rejected_logps - - diff_logratios = pi_logratios - ref_logratios - - losses = -torch.nn.functional.logsigmoid(beta * diff_logratios) + losses = _compute_dpo_loss( + policy_chosen_logps=policy_chosen_logps, + policy_rejected_logps=policy_rejected_logps, + reference_chosen_logps=reference_chosen_logps, + reference_rejected_logps=reference_rejected_logps, + beta=beta, + ) if grad_output is None: loss = None diff --git a/tests/test_functional.py b/tests/test_functional.py index bee568949..61c913a2b 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -2,38 +2,63 @@ import torch from fast_llm.functional.config import ActivationType, MLPRecomputeLevel +from fast_llm.functional.dpo import _compute_dpo_loss from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped, torch_mlp_activation from fast_llm.functional.triton.sparse_copy import get_sparse_map -from fast_llm.functional.dpo import compute_simplified_dpo_loss from fast_llm.utils import Assert from tests.common import requires_cuda -def test_simplified_dpo_loss(): - torch.manual_seed(0) - vocab_size = 10 - seq_length = 10 - logits = torch.randn((seq_length, vocab_size)) - targets = torch.randint(vocab_size, size=(seq_length-1, )) - - dpo_loss, _ = compute_simplified_dpo_loss( - logits=logits, - targets=targets, - chosen_span=torch.tensor([[1, 2]]), - rejected_span=torch.tensor([[4, 5]]), - beta=0.1, - grad_output=0.25 - ) - Assert.rms_close(dpo_loss, torch.tensor(0.71527), 1e-5) - - dpo_loss, _ = compute_simplified_dpo_loss( - logits=logits, - targets=targets, - chosen_span=torch.tensor([[2, 3]]), - rejected_span=torch.tensor([[5, 7]]), - beta=0.3, - grad_output=0.25 + +def openrlhf_dpo_loss_fcn( + policy_chosen_logps: torch.Tensor, + policy_rejected_logps: torch.Tensor, + reference_chosen_logps: torch.Tensor, + reference_rejected_logps: torch.Tensor, + beta=1, + label_smoothing=0, +) -> torch.Tensor: + pi_logratios = policy_chosen_logps - policy_rejected_logps + ref_logratios = reference_chosen_logps - reference_rejected_logps + logits = pi_logratios - ref_logratios + + # Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf) + losses = ( + -torch.nn.functional.logsigmoid(beta * logits) * (1 - label_smoothing) + - torch.nn.functional.logsigmoid(-beta * logits) * label_smoothing ) - Assert.rms_close(dpo_loss, torch.tensor(0.30449), 1e-5) + + loss = losses.mean() + + return loss + + +def test_dpo_loss(): + torch.manual_seed(0) + + NUM_SAMPLES = 20 + policy_chosen_logps = torch.rand(NUM_SAMPLES) + policy_rejected_logps = torch.rand(NUM_SAMPLES) + reference_chosen_logps = torch.rand(NUM_SAMPLES) + reference_rejected_logps = torch.rand(NUM_SAMPLES) + betas = torch.rand(NUM_SAMPLES) + + for i in range(NUM_SAMPLES): + fastllm_dpo_loss = _compute_dpo_loss( + policy_chosen_logps=policy_chosen_logps[i], + policy_rejected_logps=policy_rejected_logps[i], + reference_chosen_logps=reference_chosen_logps[i], + reference_rejected_logps=reference_rejected_logps[i], + beta=betas[i].item(), + ) + openrlhf_dpo_loss = openrlhf_dpo_loss_fcn( + policy_chosen_logps=policy_chosen_logps[i].unsqueeze(0), + policy_rejected_logps=policy_rejected_logps[i].unsqueeze(0), + reference_chosen_logps=reference_chosen_logps[i].unsqueeze(0), + reference_rejected_logps=reference_rejected_logps[i].unsqueeze(0), + beta=betas[i].item(), + ) + Assert.rms_close(fastllm_dpo_loss, openrlhf_dpo_loss, 1e-5) + @requires_cuda @pytest.mark.parametrize("gated", [True, False]) From bd9142fdffda7867406a3c80e6c4571b5030a023 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Mon, 28 Apr 2025 18:26:43 +0000 Subject: [PATCH 32/47] FFixing sampled for dpo --- fast_llm/data/data/gpt/data.py | 4 +--- fast_llm/data/dataset/gpt/sampled.py | 9 +++++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 6c568cab0..e7e7327a2 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -159,11 +159,9 @@ def get_iterator( Assert.in_range_incl(batch_config.sequence_length, 1, sampling_parameters.sequence_length) log_main_rank(f"Initializing {dataset_name} dataset iterator from sample {consumed_samples}...") - dataset = self._datasets[dataset_name] # noqa - return iter( torch.utils.data.DataLoader( - dataset, + self._datasets[dataset_name], # noqa batch_sampler=SampledDatasetIterator( total_samples=len(self._datasets[dataset_name]), begin_index=consumed_samples, diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 40ccb4f63..cc482ccb1 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -202,7 +202,7 @@ def _sample(self) -> None: if self._yaml_path is not None and self._yaml_path.is_file(): loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r")) self._load_yaml_data(yaml_data) - if not self._truncate_documents: + if not self._truncate_documents and not self._parameters.use_preference_loss_masking_spans: del loaded_yaml_data["unshuffled_tokens"] if loaded_yaml_data != yaml_data: @@ -266,8 +266,7 @@ def _sample(self) -> None: raise NotImplementedError(f"Unknown shuffling type: {self._config.shuffle}") if self._parameters.use_preference_loss_masking_spans: - if not self._truncate_documents: - yaml_data["unshuffled_tokens"] = None + yaml_data["unshuffled_tokens"] = 0 # not used, ignore # index of all documents less than seq length long doc_length_filtered_indicies = torch.nonzero(~long_docs_filter, as_tuple=True)[0] @@ -518,7 +517,9 @@ def _lazy_load(self): def _load_yaml_data(self, data: dict[str, typing.Any]) -> None: self._documents_per_epoch = data["dataset"]["documents_per_epoch"] - if "unshuffled_tokens" not in data: + if self._parameters.use_preference_loss_masking_spans: + data["unshuffled_tokens"] = 0 # not used, ignore + elif "unshuffled_tokens" not in data: # Backward compatibility # TODO v0.x: Remove assert self._truncate_documents From 4f261009189945b9358bcad4302c79e3fa4a7237 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Mon, 28 Apr 2025 18:43:18 +0000 Subject: [PATCH 33/47] test case fixes --- fast_llm/data/data/gpt/data.py | 2 ++ fast_llm/data/dataset/gpt/sampled.py | 4 ++-- tests/data/test_prepare_gpt_memmap.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index e7e7327a2..c9e1a5e1d 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -40,6 +40,8 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling stacked_ids = np.stack([sample.token_ids for sample in batch]) stacked_spans = None sequence_lengths = None + stacked_chosen_spans = None + stacked_rejected_spans = None if sampling_parameters.use_loss_masking_spans: stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch] if sampling_parameters.use_preference_loss_masking_spans: diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index cc482ccb1..78bf254e7 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -548,10 +548,10 @@ def __init__( raise NotImplementedError( "Legacy sampling only supports document truncation. Please use the latest dataset format." ) - if sampling.use_preference_loss_masking_spans: - raise NotImplementedError("Legacy sampling does not support preference loss masking.") self._config = sampling.config self._parameters = sampling.parameters + if self._parameters.use_preference_loss_masking_spans: + raise NotImplementedError("Legacy sampling does not support preference loss masking.") if sampling.cache_directory is None: log_main_rank( diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py index b85e396f3..4cc263e99 100644 --- a/tests/data/test_prepare_gpt_memmap.py +++ b/tests/data/test_prepare_gpt_memmap.py @@ -62,7 +62,7 @@ def generate_valid_span(max_seq_length): GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents) dataset = GPTMemmapDataset(name="foo", prefix=prefix) for i, document in enumerate(documents): - dataset_item = dataset.get(i) + dataset_item = dataset.get(i, use_preference_loss_masking_spans=True) assert np.array_equal( dataset_item.token_ids, document.token_ids, equal_nan=True ), f"Token ids mismatch for document {i}: {document} != {dataset.get(i)}." From 41cc7fe61514eb103011621234fead048af75021 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Tue, 29 Apr 2025 00:41:03 +0000 Subject: [PATCH 34/47] adding preference logps test case --- fast_llm/functional/dpo.py | 8 +-- tests/test_functional.py | 109 +++++++++++++++++++++++++++++++++++-- 2 files changed, 109 insertions(+), 8 deletions(-) diff --git a/fast_llm/functional/dpo.py b/fast_llm/functional/dpo.py index 7130aa53c..783a24cf8 100644 --- a/fast_llm/functional/dpo.py +++ b/fast_llm/functional/dpo.py @@ -1,7 +1,7 @@ import torch -def compute_logprobs_for_spans( +def _compute_logprobs_for_preference_spans( logits: torch.Tensor, targets: torch.Tensor, chosen_span: torch.Tensor, rejected_span: torch.Tensor ): assert torch.all(targets < logits.size(-1)), "Target out of vocab range" @@ -23,7 +23,7 @@ def compute_logprobs_for_spans( # rejected_logp = (selected_log_probs * rejected_mask).sum() rejected_logp = selected_log_probs[rejected_span[0][0].item() : rejected_span[0][1].item() + 1].sum() - return chosen_logp, rejected_logp + return chosen_logp, rejected_logp, selected_log_probs def _compute_dpo_loss( @@ -55,11 +55,11 @@ def compute_dpo_loss( logits_ = logits.float().detach().requires_grad_() reference_model_logits_ = reference_model_logits.float().detach() - policy_chosen_logps, policy_rejected_logps = compute_logprobs_for_spans( + policy_chosen_logps, policy_rejected_logps, _ = _compute_logprobs_for_preference_spans( logits_, targets, chosen_span, rejected_span ) - reference_chosen_logps, reference_rejected_logps = compute_logprobs_for_spans( + reference_chosen_logps, reference_rejected_logps, _ = _compute_logprobs_for_preference_spans( reference_model_logits_, targets, chosen_span, rejected_span ) diff --git a/tests/test_functional.py b/tests/test_functional.py index 61c913a2b..a90037c76 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1,15 +1,116 @@ +import random + import pytest import torch from fast_llm.functional.config import ActivationType, MLPRecomputeLevel -from fast_llm.functional.dpo import _compute_dpo_loss +from fast_llm.functional.dpo import _compute_dpo_loss, _compute_logprobs_for_preference_spans from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped, torch_mlp_activation from fast_llm.functional.triton.sparse_copy import get_sparse_map from fast_llm.utils import Assert from tests.common import requires_cuda -def openrlhf_dpo_loss_fcn( +def ref_log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor, temperature: float = 1.0) -> torch.Tensor: + if temperature != 1.0: + logits.div_(temperature) + batch_dim = logits.shape[:-1] + last_dim = logits.shape[-1] + + output = torch.nn.functional.cross_entropy(logits.reshape(-1, last_dim), labels.reshape(-1), reduction="none") + log_probs_labels = -output.view(*batch_dim) + + return log_probs_labels + + +def ref_packed_get_batch_logps( + logits: torch.FloatTensor, + labels: torch.LongTensor, + attention_mask, + prompt_id_lens, + packed_seq_lens, +) -> torch.FloatTensor: + labels = labels[:, 1:] + logits = logits[:, :-1, :] + per_token_logps = ref_log_probs_from_logits(logits, labels) + + loss_masks = attention_mask.clone().bool() + + index = 0 + for i, seq_len in enumerate(packed_seq_lens): + loss_masks[0, index : index + prompt_id_lens[i]] = False + index = index + seq_len + + loss_masks = loss_masks[:, 1:] + + logprobs_sums = [] + index = 0 + for i, seq_len in enumerate(packed_seq_lens): + seq = per_token_logps[0, index : index + seq_len - 1] + mask = loss_masks[0, index : index + seq_len - 1] + logprobs_sums.append((seq * mask).sum()) + index = index + seq_len + chosen_logps = logprobs_sums[: len(packed_seq_lens) // 2] + rejected_logps = logprobs_sums[len(packed_seq_lens) // 2 :] + + return torch.tensor(chosen_logps), torch.tensor(rejected_logps) + + +def test_preference_logps(): + random.seed(0) + torch.manual_seed(0) + num_iters = 20 + + def random_split(seq_length): + min_val = int(seq_length * 0.3) + max_val = int(seq_length * 0.7) + + if max_val < min_val: + max_val = min_val + + a = random.randint(min_val, max_val) + b = seq_length - a + return [a, b] + + for _ in range(num_iters): + seq_length = random.choice([1024, 4096, 8192]) + vocab_size = random.choice([1000, 2000, 8000]) + + logits = torch.randn(1, seq_length, vocab_size) + targets = torch.randint(0, vocab_size, (1, seq_length)) + packed_seq_lens = random_split(seq_length) + prompt_id_lens = [int(min(packed_seq_lens) * 0.75)] * 2 + attention_mask = torch.tensor([1] * packed_seq_lens[0] + [2] * packed_seq_lens[1]).unsqueeze(0) + + chosen_span = ( + torch.tensor([[prompt_id_lens[0], packed_seq_lens[0] - 1]]) - 1 + ) # shift by 1 due to label shifting + rejected_span = ( + torch.tensor([[packed_seq_lens[0] + prompt_id_lens[1], packed_seq_lens[0] + packed_seq_lens[1] - 1]]) - 1 + ) # shift by 1 due to label shifting + + ref_chosen_logps, ref_rejected_logps = ref_packed_get_batch_logps( + logits, targets, attention_mask, prompt_id_lens, packed_seq_lens + ) + + chosen_logps, rejected_logps, selected_log_probs = _compute_logprobs_for_preference_spans( + logits=logits.flatten(0, -2), + targets=targets.flatten()[1:], + chosen_span=chosen_span, + rejected_span=rejected_span, + ) + + ref_logps = ref_log_probs_from_logits(logits[:, :-1, :], targets[:, 1:]) + + # check all logps + Assert.custom(torch.allclose, ref_logps, selected_log_probs, rtol=1e-5) + + # check chosen and rejected summed logps + Assert.custom(torch.allclose, ref_chosen_logps, chosen_logps, rtol=1e-5) + Assert.custom(torch.allclose, ref_rejected_logps, rejected_logps, rtol=1e-5) + + +def ref_dpo_loss_fcn( policy_chosen_logps: torch.Tensor, policy_rejected_logps: torch.Tensor, reference_chosen_logps: torch.Tensor, @@ -50,14 +151,14 @@ def test_dpo_loss(): reference_rejected_logps=reference_rejected_logps[i], beta=betas[i].item(), ) - openrlhf_dpo_loss = openrlhf_dpo_loss_fcn( + ref_dpo_loss = ref_dpo_loss_fcn( policy_chosen_logps=policy_chosen_logps[i].unsqueeze(0), policy_rejected_logps=policy_rejected_logps[i].unsqueeze(0), reference_chosen_logps=reference_chosen_logps[i].unsqueeze(0), reference_rejected_logps=reference_rejected_logps[i].unsqueeze(0), beta=betas[i].item(), ) - Assert.rms_close(fastllm_dpo_loss, openrlhf_dpo_loss, 1e-5) + Assert.rms_close(fastllm_dpo_loss, ref_dpo_loss, 1e-5) @requires_cuda From a6950f13e64ac79608447a66e0cb6b2661962365 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Tue, 29 Apr 2025 17:13:54 +0000 Subject: [PATCH 35/47] small fix --- fast_llm/data/dataset/gpt/sampled.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 78bf254e7..211447dd5 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -274,7 +274,9 @@ def _sample(self) -> None: # apply shuffling on doc_length_filtered_indicies if shuffled_epochs > 0: - self._document_shuffling.save(document_shuffling[: self._num_samples].numpy(force=self._config.gpu)) + self._document_shuffling.save( + document_shuffling[: self._parameters.num_samples].numpy(force=self._config.gpu) + ) self._document_sizes.save(document_sizes.numpy(force=self._config.gpu)) if self._yaml_path is not None: self._yaml_path.parent.mkdir(parents=True, exist_ok=True) From fb9803dd5f5b7f674147c9c92def9d7c29a29c53 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Wed, 30 Apr 2025 00:12:23 +0000 Subject: [PATCH 36/47] higher mbs fixes --- fast_llm/functional/dpo.py | 16 +++++++-------- fast_llm/layers/language_model/head.py | 7 ++++--- .../layers/language_model/preprocessing.py | 20 +++++++++++-------- 3 files changed, 23 insertions(+), 20 deletions(-) diff --git a/fast_llm/functional/dpo.py b/fast_llm/functional/dpo.py index 783a24cf8..f10e771c0 100644 --- a/fast_llm/functional/dpo.py +++ b/fast_llm/functional/dpo.py @@ -9,19 +9,17 @@ def _compute_logprobs_for_preference_spans( log_probs = torch.nn.functional.log_softmax(logits, dim=-1) # gather log probabilities corresponding to the target tokens - selected_log_probs = log_probs[:-1].gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1) + selected_log_probs = log_probs[:, :-1, :].gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1) # apply chosen mask - # chosen_mask = torch.zeros_like(selected_log_probs, dtype=torch.bool) - # chosen_mask[chosen_span[0][0].item(): chosen_span[0][1].item() + 1] = 1 - # chosen_logp = (selected_log_probs * chosen_mask).sum() - chosen_logp = selected_log_probs[chosen_span[0][0].item() : chosen_span[0][1].item() + 1].sum() + chosen_logp = 0 + for idx, span in enumerate(chosen_span): + chosen_logp += selected_log_probs[idx][span[0].item() : span[1].item() + 1].sum() # apply rejected mask - # rejected_mask = torch.zeros_like(selected_log_probs, dtype=torch.bool) - # rejected_mask[rejected_span[0][0].item(): rejected_span[0][1].item() + 1] = 1 - # rejected_logp = (selected_log_probs * rejected_mask).sum() - rejected_logp = selected_log_probs[rejected_span[0][0].item() : rejected_span[0][1].item() + 1].sum() + rejected_logp = 0 + for idx, span in enumerate(rejected_span): + rejected_logp += selected_log_probs[idx][span[0].item() : span[1].item() + 1].sum() return chosen_logp, rejected_logp, selected_log_probs diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index ef247b855..b50605eee 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -163,7 +163,8 @@ def _forward_backward( self._prediction_distance : self._prediction_distance + input_.size(1), ] ) - target = target.flatten() + if not self._use_dpo_loss: + target = target.flatten() else: # Target is reference model logits. target = target.flatten(0, -2) @@ -304,9 +305,9 @@ def _logits_cross_entropy_forward_backward( return logits * self._logits_scale_factor, None if self._use_dpo_loss: loss, grad = compute_dpo_loss( - logits.flatten(0, -2), + logits, target, - kwargs.get(f"{self._config.distillation_model}_logits").flatten(0, -2), + kwargs.get(f"{self._config.distillation_model}_logits"), kwargs[LanguageModelKwargs.chosen_spans], kwargs[LanguageModelKwargs.rejected_spans], self.dpo_beta, diff --git a/fast_llm/layers/language_model/preprocessing.py b/fast_llm/layers/language_model/preprocessing.py index bda4d02d7..dbe9c6040 100644 --- a/fast_llm/layers/language_model/preprocessing.py +++ b/fast_llm/layers/language_model/preprocessing.py @@ -90,29 +90,33 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: raise ValueError("Expected chosen spans or rejected spans to be found within the batch.") chosen_loss_masking_spans = kwargs[LanguageModelKwargs.chosen_spans] + chosen_valid_spans = [] for spans in chosen_loss_masking_spans: if not spans.numel(): continue # only keep spans within the sequence or partially within the sequence - valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)] + valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)][0] if valid_spans.numel(): # if span is partially within the sequence, truncate parts of spans that are outside of the sequence - valid_spans[:, 0].clamp_(min=sequence_offset) - valid_spans[:, 1].clamp_(max=sequence_k) + valid_spans[0].clamp_(min=sequence_offset) + valid_spans[1].clamp_(max=sequence_k) valid_spans -= sequence_offset - kwargs[LanguageModelKwargs.chosen_spans] = valid_spans + chosen_valid_spans.append(valid_spans) + kwargs[LanguageModelKwargs.chosen_spans] = chosen_valid_spans rejected_loss_masking_spans = kwargs[LanguageModelKwargs.rejected_spans] + rejected_valid_spans = [] for spans in rejected_loss_masking_spans: if not spans.numel(): continue # only keep spans within the sequence or partially within the sequence - valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)] + valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)][0] if valid_spans.numel(): # if span is partially within the sequence, truncate parts of spans that are outside of the sequence - valid_spans[:, 0].clamp_(min=sequence_offset) - valid_spans[:, 1].clamp_(max=sequence_k) + valid_spans[0].clamp_(min=sequence_offset) + valid_spans[1].clamp_(max=sequence_k) valid_spans -= sequence_offset - kwargs[LanguageModelKwargs.rejected_spans] = valid_spans + rejected_valid_spans.append(valid_spans) + kwargs[LanguageModelKwargs.rejected_spans] = rejected_valid_spans From 723f30edb03402de678eb3838d064c0bd668ccd6 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Wed, 30 Apr 2025 00:12:54 +0000 Subject: [PATCH 37/47] test higher mbs --- tests/test_functional.py | 64 +++++++++++++++++++--------------------- 1 file changed, 30 insertions(+), 34 deletions(-) diff --git a/tests/test_functional.py b/tests/test_functional.py index a90037c76..ffab0c85b 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -56,10 +56,12 @@ def ref_packed_get_batch_logps( return torch.tensor(chosen_logps), torch.tensor(rejected_logps) -def test_preference_logps(): +@pytest.mark.parametrize("batch_size", [1, 2, 4, 8]) +@pytest.mark.parametrize("seq_length", [1024, 4096, 8192]) +@pytest.mark.parametrize("vocab_size", [1000, 2000, 8000]) +def test_preference_logps(batch_size, seq_length, vocab_size): random.seed(0) torch.manual_seed(0) - num_iters = 20 def random_split(seq_length): min_val = int(seq_length * 0.3) @@ -72,42 +74,36 @@ def random_split(seq_length): b = seq_length - a return [a, b] - for _ in range(num_iters): - seq_length = random.choice([1024, 4096, 8192]) - vocab_size = random.choice([1000, 2000, 8000]) - - logits = torch.randn(1, seq_length, vocab_size) - targets = torch.randint(0, vocab_size, (1, seq_length)) - packed_seq_lens = random_split(seq_length) - prompt_id_lens = [int(min(packed_seq_lens) * 0.75)] * 2 - attention_mask = torch.tensor([1] * packed_seq_lens[0] + [2] * packed_seq_lens[1]).unsqueeze(0) - - chosen_span = ( - torch.tensor([[prompt_id_lens[0], packed_seq_lens[0] - 1]]) - 1 - ) # shift by 1 due to label shifting - rejected_span = ( - torch.tensor([[packed_seq_lens[0] + prompt_id_lens[1], packed_seq_lens[0] + packed_seq_lens[1] - 1]]) - 1 - ) # shift by 1 due to label shifting - - ref_chosen_logps, ref_rejected_logps = ref_packed_get_batch_logps( - logits, targets, attention_mask, prompt_id_lens, packed_seq_lens - ) + logits = torch.randn(batch_size, seq_length, vocab_size) + targets = torch.randint(0, vocab_size, (batch_size, seq_length)) + packed_seq_lens = random_split(seq_length) # simulate different chosen/rejected lengths + prompt_id_lens = [int(min(packed_seq_lens) * 0.75)] * 2 # sequences are 75% prompt 25% generation + attention_mask = torch.tensor([1] * packed_seq_lens[0] + [2] * packed_seq_lens[1]).unsqueeze(0) - chosen_logps, rejected_logps, selected_log_probs = _compute_logprobs_for_preference_spans( - logits=logits.flatten(0, -2), - targets=targets.flatten()[1:], - chosen_span=chosen_span, - rejected_span=rejected_span, - ) + chosen_span = torch.tensor([[prompt_id_lens[0], packed_seq_lens[0] - 1]]) - 1 # shift by 1 due to label shifting + rejected_span = ( + torch.tensor([[packed_seq_lens[0] + prompt_id_lens[1], packed_seq_lens[0] + packed_seq_lens[1] - 1]]) - 1 + ) # shift by 1 due to label shifting + + ref_chosen_logps, ref_rejected_logps = ref_packed_get_batch_logps( + logits, targets, attention_mask, prompt_id_lens, packed_seq_lens + ) + + chosen_logps, rejected_logps, selected_log_probs = _compute_logprobs_for_preference_spans( + logits=logits, + targets=targets[:, 1:], + chosen_span=chosen_span, + rejected_span=rejected_span, + ) - ref_logps = ref_log_probs_from_logits(logits[:, :-1, :], targets[:, 1:]) + ref_logps = ref_log_probs_from_logits(logits[:, :-1, :], targets[:, 1:]) - # check all logps - Assert.custom(torch.allclose, ref_logps, selected_log_probs, rtol=1e-5) + # check all logps + Assert.custom(torch.allclose, ref_logps, selected_log_probs, rtol=1e-5) - # check chosen and rejected summed logps - Assert.custom(torch.allclose, ref_chosen_logps, chosen_logps, rtol=1e-5) - Assert.custom(torch.allclose, ref_rejected_logps, rejected_logps, rtol=1e-5) + # check chosen and rejected summed logps + Assert.custom(torch.allclose, ref_chosen_logps, chosen_logps, rtol=1e-5) + Assert.custom(torch.allclose, ref_rejected_logps, rejected_logps, rtol=1e-5) def ref_dpo_loss_fcn( From 8063a21629d4778a6bfb18327b28976c778e72f4 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Wed, 30 Apr 2025 18:05:40 +0000 Subject: [PATCH 38/47] small change --- fast_llm/data/dataset/gpt/memmap.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index a73f4a37e..a7585c37c 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -173,11 +173,13 @@ def get( rejected_span = None if use_preference_loss_masking_spans: - if self._has_preference_spans and self._chosen_spans is None: + if not self._has_preference_spans: + raise ValueError("No preference spans found in memmap dataset.") + elif self._has_preference_spans and self._chosen_spans is None: raise ValueError("Failed to read chosen spans from memmap dataset.") elif self._has_preference_spans and self._rejected_spans is None: raise ValueError("Failed to read rejected spans from memmap dataset.") - elif self._has_preference_spans: + else: chosen_span = self._chosen_spans[idx] # filter spans that are outside the range of the selected tokens in the document @@ -197,8 +199,6 @@ def get( # subtract by offset to normalize span boundaries rejected_span[0] = np.maximum(rejected_span[0], offset) - offset # offset rejected_span[1] = np.minimum(rejected_span[1], offset + len(token_ids) - 1) - offset - else: - raise ValueError("No preference spans found in memmap dataset.") return GPTSample( token_ids=token_ids, From c3a8ebb3c93c10b4bfc1091f9ec1d4932f6a1776 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Wed, 30 Apr 2025 21:07:55 +0000 Subject: [PATCH 39/47] updates --- fast_llm/layers/language_model/config.py | 7 ++++++- fast_llm/layers/language_model/head.py | 11 +++++------ fast_llm/models/gpt/config.py | 24 +++++++++++++----------- fast_llm/models/gpt/model.py | 2 +- fast_llm/models/gpt/trainer.py | 2 +- 5 files changed, 26 insertions(+), 20 deletions(-) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 040594cc4..db0016ffd 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -143,7 +143,7 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): desc="Min value for clamping initialized weights of the vocabulary embedding and output (logits).", hint=FieldHint.feature, ) - use_dpo_loss: bool | None = Field( + enable_dpo: bool | None = Field( default=False, desc="Whether to enable DPO loss", hint=FieldHint.feature, @@ -153,6 +153,11 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig): desc="Beta value for DPO loss.", hint=FieldHint.feature, ) + dpo_reference_model: str | None = Field( + default=None, + desc="Name of the reference model to use for dpo.", + hint=FieldHint.feature, + ) cross_entropy_impl: CrossEntropyImpl = Field( default=CrossEntropyImpl.auto, desc="Implementation for the cross-entropy computation.", diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index b50605eee..5d064f0e4 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -72,7 +72,7 @@ def __init__( self._init_output_weights(hidden_dim, config) - self._use_dpo_loss = config.use_dpo_loss + self._use_dpo_loss = config.enable_dpo if self._use_dpo_loss: self.dpo_beta = config.dpo_beta else: @@ -152,8 +152,8 @@ def _forward_backward( if self._use_dpo_loss or self._config.distillation_model is None else f"{self._config.distillation_model}_logits" ) - if target is not None: - if self._config.distillation_model is None or self._use_dpo_loss: + if target is not None and not self._use_dpo_loss: + if self._config.distillation_model is None: # MTP: Shift the labels target = ( target[self._prediction_distance : self._prediction_distance + input_.size(0),] @@ -163,8 +163,7 @@ def _forward_backward( self._prediction_distance : self._prediction_distance + input_.size(1), ] ) - if not self._use_dpo_loss: - target = target.flatten() + target = target.flatten() else: # Target is reference model logits. target = target.flatten(0, -2) @@ -307,7 +306,7 @@ def _logits_cross_entropy_forward_backward( loss, grad = compute_dpo_loss( logits, target, - kwargs.get(f"{self._config.distillation_model}_logits"), + kwargs.get(f"{self._config.dpo_reference_model}_logits"), kwargs[LanguageModelKwargs.chosen_spans], kwargs[LanguageModelKwargs.rejected_spans], self.dpo_beta, diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index d567d2879..e6b271fe3 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -99,11 +99,6 @@ class GPTBatchConfig(BatchConfig): desc="Read loss masking spans from the dataset.", hint=FieldHint.feature, ) - use_preference_loss_masking_spans: bool = Field( - default=False, - desc="Read loss masking spans from the dataset.", - hint=FieldHint.feature, - ) def _validate(self) -> None: if self.micro_sequence_length is None: @@ -192,20 +187,27 @@ def _validate(self) -> None: if self.model.base_model.use_megatron_initialization: set_megatron_distributed_seeds(self.model.distributed) super()._validate() - if (name := self.model.base_model.distillation_model) is None: - Assert.empty(self.reference_models) - else: + + if (name := self.model.base_model.distillation_model) is not None: Assert.eq(self.reference_models.keys(), {name}) + elif (name := self.model.base_model.dpo_reference_model) is not None: + Assert.eq(self.reference_models.keys(), {name}) + else: + Assert.empty(self.reference_models) + if self.model.base_model.use_absolute_position_embeddings: Assert.geq(self.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length) if self.model.base_model.distillation_model is not None: # TODO: Support loss masking for distillation? assert not self.batch.use_loss_masking_spans - assert self.model.base_model.use_dpo_loss == self.batch.use_preference_loss_masking_spans - if self.model.base_model.use_dpo_loss: - assert self.model.base_model.distillation_model is not None + + if self.model.base_model.enable_dpo: + assert self.model.base_model.dpo_reference_model is not None + Assert.none(self.model.base_model.distillation_model) + for reference_model in self.reference_models.values(): Assert.none(reference_model.model.base_model.distillation_model) + Assert.none(reference_model.model.base_model.dpo_reference_model) # TODO: Support more LM head features. Assert.none(reference_model.model.base_model.cross_entropy_splits) Assert.eq(reference_model.model.base_model.parallel_embeddings, self.model.base_model.parallel_embeddings) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 40a5b2bb8..dc933b228 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -71,7 +71,7 @@ def __init__( else: self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._tensor_space)) - if self._config.use_dpo_loss: # TODO better way to pass in? + if self._config.enable_dpo: # TODO better way to pass in? self._preprocessors.append(PreferenceSpanPreprocessor(self._config, self._tensor_space)) def get_output_layers(self) -> list[Layer]: diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 8092dc62d..87f80a368 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -48,7 +48,7 @@ def _get_sampling_parameters( "vocab_size": self._config.model.base_model.vocab_size, "sequence_length": self._config.batch.sequence_length, "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, - "use_preference_loss_masking_spans": self._config.batch.use_preference_loss_masking_spans, + "use_preference_loss_masking_spans": self._config.model.base_model.enable_dpo, "cross_document_attention": self._config.batch.cross_document_attention, "extra_tokens": self._config.model.base_model.prediction_heads, } From db5242fd9efcfba46590d9643b83a09ba70d3426 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Wed, 30 Apr 2025 21:29:01 +0000 Subject: [PATCH 40/47] small changes --- fast_llm/models/gpt/config.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index e6b271fe3..93a659f70 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -188,12 +188,13 @@ def _validate(self) -> None: set_megatron_distributed_seeds(self.model.distributed) super()._validate() - if (name := self.model.base_model.distillation_model) is not None: - Assert.eq(self.reference_models.keys(), {name}) - elif (name := self.model.base_model.dpo_reference_model) is not None: - Assert.eq(self.reference_models.keys(), {name}) - else: - Assert.empty(self.reference_models) + # if self.model.base_model.distillation_model is None and self.model.base_model.dpo_reference_model is None: + # Assert.empty(self.reference_models) + # else: + # if (name := self.model.base_model.distillation_model) is not None: + # Assert.eq(self.reference_models.keys(), {name}) + # if (name := self.model.base_model.dpo_reference_model) is not None: + # Assert.eq(self.reference_models.keys(), {name}) if self.model.base_model.use_absolute_position_embeddings: Assert.geq(self.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length) @@ -204,6 +205,18 @@ def _validate(self) -> None: if self.model.base_model.enable_dpo: assert self.model.base_model.dpo_reference_model is not None Assert.none(self.model.base_model.distillation_model) + else: + Assert.none(self.model.base_model.dpo_reference_model) + + distillation_model = self.model.base_model.distillation_model + dpo_reference_model = self.model.base_model.dpo_reference_model + + if distillation_model is None and dpo_reference_model is None: + Assert.empty(self.reference_models) + else: + assert distillation_model is None or dpo_reference_model is None # currently don't support both + expected_names = {name for name in (distillation_model, dpo_reference_model) if name is not None} + Assert.eq(self.reference_models.keys(), expected_names) for reference_model in self.reference_models.values(): Assert.none(reference_model.model.base_model.distillation_model) From ab139ca3f35a9a1cf3e11d71da5a8cce820d9d05 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Wed, 30 Apr 2025 21:31:23 +0000 Subject: [PATCH 41/47] small changes --- fast_llm/models/gpt/config.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 93a659f70..05a51b81e 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -198,18 +198,19 @@ def _validate(self) -> None: if self.model.base_model.use_absolute_position_embeddings: Assert.geq(self.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length) - if self.model.base_model.distillation_model is not None: + + distillation_model = self.model.base_model.distillation_model + dpo_reference_model = self.model.base_model.dpo_reference_model + + if distillation_model is not None: # TODO: Support loss masking for distillation? assert not self.batch.use_loss_masking_spans if self.model.base_model.enable_dpo: - assert self.model.base_model.dpo_reference_model is not None - Assert.none(self.model.base_model.distillation_model) + assert dpo_reference_model is not None + Assert.none(distillation_model) else: - Assert.none(self.model.base_model.dpo_reference_model) - - distillation_model = self.model.base_model.distillation_model - dpo_reference_model = self.model.base_model.dpo_reference_model + Assert.none(dpo_reference_model) if distillation_model is None and dpo_reference_model is None: Assert.empty(self.reference_models) From 63041aa2df59fd162ed133a7ad0977e35afa62dd Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Wed, 30 Apr 2025 21:32:38 +0000 Subject: [PATCH 42/47] remove comments --- fast_llm/models/gpt/config.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 05a51b81e..4c560d0f7 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -188,14 +188,6 @@ def _validate(self) -> None: set_megatron_distributed_seeds(self.model.distributed) super()._validate() - # if self.model.base_model.distillation_model is None and self.model.base_model.dpo_reference_model is None: - # Assert.empty(self.reference_models) - # else: - # if (name := self.model.base_model.distillation_model) is not None: - # Assert.eq(self.reference_models.keys(), {name}) - # if (name := self.model.base_model.dpo_reference_model) is not None: - # Assert.eq(self.reference_models.keys(), {name}) - if self.model.base_model.use_absolute_position_embeddings: Assert.geq(self.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length) From e60ad62935bac00457936a7e9f06dd1ca771b7dd Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Thu, 1 May 2025 23:20:27 +0000 Subject: [PATCH 43/47] maxlen consistency --- fast_llm/data/dataset/gpt/sampled.py | 4 ++-- fast_llm/functional/dpo.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 211447dd5..f87fb63d4 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -405,9 +405,9 @@ def __getitem__(self, index: int) -> typing.Any: ] # compute padding size - padding = np.full((self._parameters.sequence_length,), 0) + padding = np.full((self._parameters.sequence_length + 1,), 0) padding[: len(sample.token_ids)] = sample.token_ids - sequence_lengths.append(self._parameters.sequence_length - len(sample.token_ids)) + sequence_lengths.append(self._parameters.sequence_length + 1 - len(sample.token_ids)) sample.token_ids = padding sample.sequence_lengths = np.array(sequence_lengths) diff --git a/fast_llm/functional/dpo.py b/fast_llm/functional/dpo.py index f10e771c0..3d66f91a3 100644 --- a/fast_llm/functional/dpo.py +++ b/fast_llm/functional/dpo.py @@ -9,7 +9,7 @@ def _compute_logprobs_for_preference_spans( log_probs = torch.nn.functional.log_softmax(logits, dim=-1) # gather log probabilities corresponding to the target tokens - selected_log_probs = log_probs[:, :-1, :].gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1) + selected_log_probs = log_probs.gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1) # apply chosen mask chosen_logp = 0 From 85613f78397d94f59fb2afb7877809eb0240fc60 Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Thu, 1 May 2025 23:39:20 +0000 Subject: [PATCH 44/47] remove comments --- fast_llm/engine/schedule/schedule.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/fast_llm/engine/schedule/schedule.py b/fast_llm/engine/schedule/schedule.py index 5f29da23f..91ce0d892 100644 --- a/fast_llm/engine/schedule/schedule.py +++ b/fast_llm/engine/schedule/schedule.py @@ -225,7 +225,6 @@ def _create_index(self) -> None: # Related steps for i, step in enumerate(self._steps): - # link forward and backward steps together if self._is_training: if step.type_ == StepType.forward: if step.stage >= self._first_grad_stage: @@ -233,7 +232,6 @@ def _create_index(self) -> None: else: step.forward_step = self.get_step(StepType.forward, *step.map_index[1:]) - # link the previous step if step.type_ == StepType.forward and step.stage == 0: step.prev_step = None elif step.type_ == StepType.backward and step.stage == self._num_stages - 1: From 274269208bd3593a32fcdee9979e60fe3e1f98aa Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Fri, 2 May 2025 19:06:33 +0000 Subject: [PATCH 45/47] refractoring --- fast_llm/data/data/gpt/data.py | 14 ++++----- fast_llm/data/dataset/gpt/config.py | 2 +- fast_llm/data/dataset/gpt/fim.py | 2 +- fast_llm/data/dataset/gpt/memmap.py | 16 +++++----- fast_llm/data/dataset/gpt/sampled.py | 29 ++++++++++--------- .../data/preparator/gpt_memmap/prepare.py | 6 ++-- fast_llm/functional/dpo.py | 14 ++++----- .../layers/language_model/preprocessing.py | 8 ++--- fast_llm/models/gpt/model.py | 8 ++--- fast_llm/models/gpt/trainer.py | 2 +- tests/data/test_prepare_gpt_memmap.py | 14 ++++----- 11 files changed, 57 insertions(+), 58 deletions(-) diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index c9e1a5e1d..c6fece9d7 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -32,8 +32,8 @@ class GPTBatch: token_ids: torch.Tensor loss_masking_spans: list[torch.Tensor] | None = None sequence_lengths: list[torch.Tensor] | None = None - chosen_loss_masking_spans: list[torch.Tensor] | None = None - rejected_loss_masking_spans: list[torch.Tensor] | None = None + chosen_spans: list[torch.Tensor] | None = None + rejected_spans: list[torch.Tensor] | None = None def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch: @@ -44,17 +44,17 @@ def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSampling stacked_rejected_spans = None if sampling_parameters.use_loss_masking_spans: stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch] - if sampling_parameters.use_preference_loss_masking_spans: - stacked_chosen_spans = [torch.from_numpy(sample.chosen_loss_masking_span) for sample in batch] - stacked_rejected_spans = [torch.from_numpy(sample.rejected_loss_masking_span) for sample in batch] + if sampling_parameters.use_preference_loss_spans: + stacked_chosen_spans = [torch.from_numpy(sample.chosen_span) for sample in batch] + stacked_rejected_spans = [torch.from_numpy(sample.rejected_span) for sample in batch] if not sampling_parameters.cross_document_attention: sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch] return GPTBatch( token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, sequence_lengths=sequence_lengths, - chosen_loss_masking_spans=stacked_chosen_spans, - rejected_loss_masking_spans=stacked_rejected_spans, + chosen_spans=stacked_chosen_spans, + rejected_spans=stacked_rejected_spans, ) diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index 0678eb2cd..acb8dfd7b 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -73,7 +73,7 @@ class GPTSamplingParameters(SamplingParameters): sequence_length: int vocab_size: int use_loss_masking_spans: bool = False - use_preference_loss_masking_spans: bool = False + use_preference_loss_spans: bool = False cross_document_attention: bool = True # How many extra tokens to add to the sequence length. # This is used to provide labels even for the last tokens in the sequence. diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index 5eeba59b5..2b2c8b3be 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -20,7 +20,7 @@ def __init__( ): if sampling.parameters.use_loss_masking_spans: raise NotImplementedError("FIM is currently not compatible with loss masking.") - if sampling.parameters.use_preference_loss_masking_spans: + if sampling.parameters.use_preference_loss_spans: raise NotImplementedError("FIM is currently not compatible with preference loss masking.") self._config = config self._dataset = dataset diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index a7585c37c..f39fd56f4 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -148,7 +148,7 @@ def get( offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False, - use_preference_loss_masking_spans: bool = False, + use_preference_loss_spans: bool = False, ) -> GPTSample: token_ids = np.frombuffer( self._bin_buffer, @@ -172,7 +172,7 @@ def get( chosen_span = None rejected_span = None - if use_preference_loss_masking_spans: + if use_preference_loss_spans: if not self._has_preference_spans: raise ValueError("No preference spans found in memmap dataset.") elif self._has_preference_spans and self._chosen_spans is None: @@ -203,8 +203,8 @@ def get( return GPTSample( token_ids=token_ids, loss_masking_spans=sample_spans, - chosen_loss_masking_span=chosen_span, - rejected_loss_masking_span=rejected_span, + chosen_span=chosen_span, + rejected_span=rejected_span, ) @property @@ -267,10 +267,10 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP if document.loss_masking_spans is not None: num_spans.append(len(document.loss_masking_spans)) spans.append(document.loss_masking_spans) - if document.chosen_loss_masking_span is not None: - chosen_spans.append(document.chosen_loss_masking_span) - if document.rejected_loss_masking_span is not None: - rejected_spans.append(document.rejected_loss_masking_span) + if document.chosen_span is not None: + chosen_spans.append(document.chosen_span) + if document.rejected_span is not None: + rejected_spans.append(document.rejected_span) offset += doc_length * np.dtype(dtype).itemsize num_documents += 1 diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index f87fb63d4..5679736a2 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -30,8 +30,8 @@ class GPTSample: token_ids: np.ndarray loss_masking_spans: np.ndarray | None = None - chosen_loss_masking_span: np.ndarray | None = None - rejected_loss_masking_span: np.ndarray | None = None + chosen_span: np.ndarray | None = None + rejected_span: np.ndarray | None = None sequence_lengths: np.ndarray | None = None @@ -116,7 +116,7 @@ def __init__( self._yaml_path = base_path.with_suffix(".yaml") # keep document sizes and len filtered docs for preference loss masking - if self._parameters.use_preference_loss_masking_spans: + if self._parameters.use_preference_loss_spans: self._document_sizes = MemmapArray(base_path.with_name(base_path.name + "_doc_sizes.npy")) self._doc_length_filtered_indicies = MemmapArray( base_path.with_name(base_path.name + "_doc_length_filtered_indices.npy") @@ -159,7 +159,7 @@ def _sample(self) -> None: # We produce sequences of length `self._sequence_length + extra_tokens` so the last token has a label for all prediction heads, # but in case of truncations we also include those last labels in the following sample, # so we need `sequence_length * num_samples + extra_tokens` tokens in total. - if self._parameters.use_preference_loss_masking_spans: + if self._parameters.use_preference_loss_spans: documents_per_epoch = (~long_docs_filter).sum().item() num_epochs = math.ceil(self._parameters.num_samples / documents_per_epoch) elif self._truncate_documents: @@ -202,7 +202,7 @@ def _sample(self) -> None: if self._yaml_path is not None and self._yaml_path.is_file(): loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r")) self._load_yaml_data(yaml_data) - if not self._truncate_documents and not self._parameters.use_preference_loss_masking_spans: + if not self._truncate_documents and not self._parameters.use_preference_loss_spans: del loaded_yaml_data["unshuffled_tokens"] if loaded_yaml_data != yaml_data: @@ -265,7 +265,7 @@ def _sample(self) -> None: else: raise NotImplementedError(f"Unknown shuffling type: {self._config.shuffle}") - if self._parameters.use_preference_loss_masking_spans: + if self._parameters.use_preference_loss_spans: yaml_data["unshuffled_tokens"] = 0 # not used, ignore # index of all documents less than seq length long @@ -382,7 +382,7 @@ def __getitem__(self, index: int) -> typing.Any: """ self._lazy_load() - if self._parameters.use_preference_loss_masking_spans: + if self._parameters.use_preference_loss_spans: if index < self._unshuffled_documents: document_index = self._doc_length_filtered_indicies[index % self._documents_per_epoch] else: @@ -395,13 +395,13 @@ def __getitem__(self, index: int) -> typing.Any: offset=0, length=self._document_sizes[document_index], use_loss_masking_spans=self._parameters.use_loss_masking_spans, - use_preference_loss_masking_spans=self._parameters.use_preference_loss_masking_spans, + use_preference_loss_spans=self._parameters.use_preference_loss_spans, ) - chosen_loss_masking_span_end = sample.chosen_loss_masking_span[1] + 1 + chosen_span_end = sample.chosen_span[1] + 1 sequence_lengths = [ - chosen_loss_masking_span_end, - len(sample.token_ids) - chosen_loss_masking_span_end, + chosen_span_end, + len(sample.token_ids) - chosen_span_end, ] # compute padding size @@ -410,7 +410,8 @@ def __getitem__(self, index: int) -> typing.Any: sequence_lengths.append(self._parameters.sequence_length + 1 - len(sample.token_ids)) sample.token_ids = padding - sample.sequence_lengths = np.array(sequence_lengths) + if not self._parameters.cross_document_attention: + sample.sequence_lengths = np.array(sequence_lengths) return sample @@ -519,7 +520,7 @@ def _lazy_load(self): def _load_yaml_data(self, data: dict[str, typing.Any]) -> None: self._documents_per_epoch = data["dataset"]["documents_per_epoch"] - if self._parameters.use_preference_loss_masking_spans: + if self._parameters.use_preference_loss_spans: data["unshuffled_tokens"] = 0 # not used, ignore elif "unshuffled_tokens" not in data: # Backward compatibility @@ -552,7 +553,7 @@ def __init__( ) self._config = sampling.config self._parameters = sampling.parameters - if self._parameters.use_preference_loss_masking_spans: + if self._parameters.use_preference_loss_spans: raise NotImplementedError("Legacy sampling does not support preference loss masking.") if sampling.cache_directory is None: diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 4c1d1d5fa..0cba3aa1c 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -159,10 +159,8 @@ def _document_generator(): for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): yield GPTSample( token_ids=np.array(item["input_ids"], dtype=self._data_type.numpy), - chosen_loss_masking_span=np.array(item["chosen_token_spans"], dtype=np.int32).reshape(-1, 2), - rejected_loss_masking_span=np.array(item["rejected_token_spans"], dtype=np.int32).reshape( - -1, 2 - ), + chosen_span=np.array(item["chosen_token_spans"], dtype=np.int32).reshape(-1, 2), + rejected_span=np.array(item["rejected_token_spans"], dtype=np.int32).reshape(-1, 2), ) else: for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): diff --git a/fast_llm/functional/dpo.py b/fast_llm/functional/dpo.py index 3d66f91a3..3a70f308f 100644 --- a/fast_llm/functional/dpo.py +++ b/fast_llm/functional/dpo.py @@ -2,7 +2,7 @@ def _compute_logprobs_for_preference_spans( - logits: torch.Tensor, targets: torch.Tensor, chosen_span: torch.Tensor, rejected_span: torch.Tensor + logits: torch.Tensor, targets: torch.Tensor, chosen_spans: torch.Tensor, rejected_spans: torch.Tensor ): assert torch.all(targets < logits.size(-1)), "Target out of vocab range" @@ -13,12 +13,12 @@ def _compute_logprobs_for_preference_spans( # apply chosen mask chosen_logp = 0 - for idx, span in enumerate(chosen_span): + for idx, span in enumerate(chosen_spans): chosen_logp += selected_log_probs[idx][span[0].item() : span[1].item() + 1].sum() # apply rejected mask rejected_logp = 0 - for idx, span in enumerate(rejected_span): + for idx, span in enumerate(rejected_spans): rejected_logp += selected_log_probs[idx][span[0].item() : span[1].item() + 1].sum() return chosen_logp, rejected_logp, selected_log_probs @@ -44,8 +44,8 @@ def compute_dpo_loss( logits: torch.Tensor, targets: torch.Tensor, reference_model_logits: torch.Tensor, - chosen_span: torch.Tensor, - rejected_span: torch.Tensor, + chosen_spans: torch.Tensor, + rejected_spans: torch.Tensor, beta: float, grad_output: float | None, ) -> tuple[torch.Tensor, torch.Tensor]: @@ -54,11 +54,11 @@ def compute_dpo_loss( reference_model_logits_ = reference_model_logits.float().detach() policy_chosen_logps, policy_rejected_logps, _ = _compute_logprobs_for_preference_spans( - logits_, targets, chosen_span, rejected_span + logits_, targets, chosen_spans, rejected_spans ) reference_chosen_logps, reference_rejected_logps, _ = _compute_logprobs_for_preference_spans( - reference_model_logits_, targets, chosen_span, rejected_span + reference_model_logits_, targets, chosen_spans, rejected_spans ) losses = _compute_dpo_loss( diff --git a/fast_llm/layers/language_model/preprocessing.py b/fast_llm/layers/language_model/preprocessing.py index dbe9c6040..d719bef3d 100644 --- a/fast_llm/layers/language_model/preprocessing.py +++ b/fast_llm/layers/language_model/preprocessing.py @@ -89,9 +89,9 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: if LanguageModelKwargs.chosen_spans not in kwargs or LanguageModelKwargs.rejected_spans not in kwargs: raise ValueError("Expected chosen spans or rejected spans to be found within the batch.") - chosen_loss_masking_spans = kwargs[LanguageModelKwargs.chosen_spans] + chosen_spans = kwargs[LanguageModelKwargs.chosen_spans] chosen_valid_spans = [] - for spans in chosen_loss_masking_spans: + for spans in chosen_spans: if not spans.numel(): continue # only keep spans within the sequence or partially within the sequence @@ -105,9 +105,9 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: chosen_valid_spans.append(valid_spans) kwargs[LanguageModelKwargs.chosen_spans] = chosen_valid_spans - rejected_loss_masking_spans = kwargs[LanguageModelKwargs.rejected_spans] + rejected_spans = kwargs[LanguageModelKwargs.rejected_spans] rejected_valid_spans = [] - for spans in rejected_loss_masking_spans: + for spans in rejected_spans: if not spans.numel(): continue # only keep spans within the sequence or partially within the sequence diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 61344490e..7e0993239 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -252,10 +252,10 @@ def preprocess( tokens = batch.token_ids[:, sequence_k - sequence_q : sequence_k].contiguous() if batch.sequence_lengths is not None: kwargs_meta[TransformerKwargs.sequence_lengths] = batch.sequence_lengths - if batch.chosen_loss_masking_spans is not None: - kwargs_meta[LanguageModelKwargs.chosen_spans] = batch.chosen_loss_masking_spans - if batch.rejected_loss_masking_spans is not None: - kwargs_meta[LanguageModelKwargs.rejected_spans] = batch.rejected_loss_masking_spans + if batch.chosen_spans is not None: + kwargs_meta[LanguageModelKwargs.chosen_spans] = batch.chosen_spans + if batch.rejected_spans is not None: + kwargs_meta[LanguageModelKwargs.rejected_spans] = batch.rejected_spans # TODO: Add pasts/presents to meta input? # Use lists as pointers so `past_key_values` is populated during the previous micro_sequence. diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 87f80a368..0d5542ccb 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -48,7 +48,7 @@ def _get_sampling_parameters( "vocab_size": self._config.model.base_model.vocab_size, "sequence_length": self._config.batch.sequence_length, "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, - "use_preference_loss_masking_spans": self._config.model.base_model.enable_dpo, + "use_preference_loss_spans": self._config.model.base_model.enable_dpo, "cross_document_attention": self._config.batch.cross_document_attention, "extra_tokens": self._config.model.base_model.prediction_heads, } diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py index 4cc263e99..17ba5de01 100644 --- a/tests/data/test_prepare_gpt_memmap.py +++ b/tests/data/test_prepare_gpt_memmap.py @@ -52,8 +52,8 @@ def generate_valid_span(max_seq_length): documents = [ GPTSample( token_ids=np.random.randint(vocab_size, size=max_seq_length).astype(dtype), - chosen_loss_masking_span=generate_valid_span(max_seq_length=max_seq_length), - rejected_loss_masking_span=generate_valid_span(max_seq_length=max_seq_length), + chosen_span=generate_valid_span(max_seq_length=max_seq_length), + rejected_span=generate_valid_span(max_seq_length=max_seq_length), ) for _ in range(num_samples) ] @@ -62,18 +62,18 @@ def generate_valid_span(max_seq_length): GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents) dataset = GPTMemmapDataset(name="foo", prefix=prefix) for i, document in enumerate(documents): - dataset_item = dataset.get(i, use_preference_loss_masking_spans=True) + dataset_item = dataset.get(i, use_preference_loss_spans=True) assert np.array_equal( dataset_item.token_ids, document.token_ids, equal_nan=True ), f"Token ids mismatch for document {i}: {document} != {dataset.get(i)}." assert np.array_equal( - dataset_item.chosen_loss_masking_span, document.chosen_loss_masking_span, equal_nan=True - ), f"Chosen loss masking spans mismatch for document {i}: {document.chosen_loss_masking_span} != {dataset.get(i).chosen_loss_masking_span}." + dataset_item.chosen_span, document.chosen_span, equal_nan=True + ), f"Chosen loss masking spans mismatch for document {i}: {document.chosen_span} != {dataset.get(i).chosen_span}." assert np.array_equal( - dataset_item.rejected_loss_masking_span, document.rejected_loss_masking_span, equal_nan=True - ), f"Rejected loss masking spans mismatch for document {i}: {document.rejected_loss_masking_span} != {dataset.get(i).rejected_loss_masking_span}." + dataset_item.rejected_span, document.rejected_span, equal_nan=True + ), f"Rejected loss masking spans mismatch for document {i}: {document.rejected_span} != {dataset.get(i).rejected_span}." def test_load_metadata_from_hub(): From 8b837c02606b59ce381d49e75ceaa78e4a151f3f Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Fri, 2 May 2025 19:33:43 +0000 Subject: [PATCH 46/47] fix --- tests/test_functional.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_functional.py b/tests/test_functional.py index ffab0c85b..34a7a77fc 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -92,8 +92,8 @@ def random_split(seq_length): chosen_logps, rejected_logps, selected_log_probs = _compute_logprobs_for_preference_spans( logits=logits, targets=targets[:, 1:], - chosen_span=chosen_span, - rejected_span=rejected_span, + chosen_spans=chosen_span, + rejected_spans=rejected_span, ) ref_logps = ref_log_probs_from_logits(logits[:, :-1, :], targets[:, 1:]) From 16136ac6405ab4a5a35db3c133b444dd167aeedf Mon Sep 17 00:00:00 2001 From: Toby Liang Date: Thu, 8 May 2025 18:12:29 +0000 Subject: [PATCH 47/47] fix --- fast_llm/data/dataset/gpt/sampled.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/data/dataset/gpt/sampled.py b/fast_llm/data/dataset/gpt/sampled.py index 5679736a2..8bb5f7370 100644 --- a/fast_llm/data/dataset/gpt/sampled.py +++ b/fast_llm/data/dataset/gpt/sampled.py @@ -407,7 +407,7 @@ def __getitem__(self, index: int) -> typing.Any: # compute padding size padding = np.full((self._parameters.sequence_length + 1,), 0) padding[: len(sample.token_ids)] = sample.token_ids - sequence_lengths.append(self._parameters.sequence_length + 1 - len(sample.token_ids)) + sequence_lengths.append(self._parameters.sequence_length - len(sample.token_ids)) sample.token_ids = padding if not self._parameters.cross_document_attention: