diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index 02c1b6c01..c6fece9d7 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -32,18 +32,29 @@ class GPTBatch: token_ids: torch.Tensor loss_masking_spans: list[torch.Tensor] | None = None sequence_lengths: list[torch.Tensor] | None = None + chosen_spans: list[torch.Tensor] | None = None + rejected_spans: list[torch.Tensor] | None = None def gpt_data_collate_fn(batch: list[GPTSample], sampling_parameters: GPTSamplingParameters) -> GPTBatch: stacked_ids = np.stack([sample.token_ids for sample in batch]) stacked_spans = None sequence_lengths = None + stacked_chosen_spans = None + stacked_rejected_spans = None if sampling_parameters.use_loss_masking_spans: stacked_spans = [torch.from_numpy(sample.loss_masking_spans) for sample in batch] + if sampling_parameters.use_preference_loss_spans: + stacked_chosen_spans = [torch.from_numpy(sample.chosen_span) for sample in batch] + stacked_rejected_spans = [torch.from_numpy(sample.rejected_span) for sample in batch] if not sampling_parameters.cross_document_attention: sequence_lengths = [torch.tensor(sample.sequence_lengths) for sample in batch] return GPTBatch( - token_ids=torch.from_numpy(stacked_ids), loss_masking_spans=stacked_spans, sequence_lengths=sequence_lengths + token_ids=torch.from_numpy(stacked_ids), + loss_masking_spans=stacked_spans, + sequence_lengths=sequence_lengths, + chosen_spans=stacked_chosen_spans, + rejected_spans=stacked_rejected_spans, ) @@ -149,6 +160,7 @@ def get_iterator( sampling_parameters = self._sampling_parameters[dataset_name] Assert.in_range_incl(batch_config.sequence_length, 1, sampling_parameters.sequence_length) log_main_rank(f"Initializing {dataset_name} dataset iterator from sample {consumed_samples}...") + return iter( torch.utils.data.DataLoader( self._datasets[dataset_name], # noqa diff --git a/fast_llm/data/dataset/gpt/config.py b/fast_llm/data/dataset/gpt/config.py index ed9128c6e..acb8dfd7b 100644 --- a/fast_llm/data/dataset/gpt/config.py +++ b/fast_llm/data/dataset/gpt/config.py @@ -73,6 +73,7 @@ class GPTSamplingParameters(SamplingParameters): sequence_length: int vocab_size: int use_loss_masking_spans: bool = False + use_preference_loss_spans: bool = False cross_document_attention: bool = True # How many extra tokens to add to the sequence length. # This is used to provide labels even for the last tokens in the sequence. diff --git a/fast_llm/data/dataset/gpt/fim.py b/fast_llm/data/dataset/gpt/fim.py index 63b7f4378..2b2c8b3be 100644 --- a/fast_llm/data/dataset/gpt/fim.py +++ b/fast_llm/data/dataset/gpt/fim.py @@ -20,8 +20,11 @@ def __init__( ): if sampling.parameters.use_loss_masking_spans: raise NotImplementedError("FIM is currently not compatible with loss masking.") + if sampling.parameters.use_preference_loss_spans: + raise NotImplementedError("FIM is currently not compatible with preference loss masking.") self._config = config self._dataset = dataset + self._seed = sampling.config.seed self._tokenizer = sampling.tokenizer if self._tokenizer is None: diff --git a/fast_llm/data/dataset/gpt/memmap.py b/fast_llm/data/dataset/gpt/memmap.py index ef060b008..f39fd56f4 100644 --- a/fast_llm/data/dataset/gpt/memmap.py +++ b/fast_llm/data/dataset/gpt/memmap.py @@ -34,13 +34,16 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None self._name = name self._prefix = pathlib.Path(prefix) self._has_spans = 0 + self._has_preference_spans = False with self._prefix.with_suffix(".idx").open("rb") as stream: Assert.eq(stream.read(9), MEMMAP_INDEX_HEADER, msg=f"File: {stream.name}") self._version = struct.unpack("= 2: self._has_spans = struct.unpack("= 3: + self._has_preference_spans = struct.unpack("= 2: self._spans = [] self._num_spans = np.frombuffer( self._index_bin_buffer, @@ -83,6 +91,36 @@ def _init(self, name: str, prefix: pathlib.Path | str, num_documents: int | None ).reshape(-1, 2) ) + # read preference spans + self._chosen_spans = None + self._rejected_spans = None + if self._has_preference_spans and self._version >= 3: + self._chosen_spans = [] + self._rejected_spans = [] + chosen_span_offset = offset + self._document_sizes.nbytes + self._pointers.nbytes + for idx in range(self._num_documents): + self._chosen_spans.append( + np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=2, + offset=chosen_span_offset + idx * 2 * np.dtype(np.int32).itemsize, + ) + ) + + rejected_span_offset = ( + offset + self._document_sizes.nbytes + self._pointers.nbytes + np.array(self._chosen_spans).nbytes + ) + for idx in range(self._num_documents): + self._rejected_spans.append( + np.frombuffer( + self._index_bin_buffer, + dtype=np.int32, + count=2, + offset=rejected_span_offset + idx * 2 * np.dtype(np.int32).itemsize, + ) + ) + self._bin_buffer_mmap = np.memmap(self._prefix.with_suffix(".bin"), mode="r", order="C") self._bin_buffer = memoryview(self._bin_buffer_mmap) @@ -105,7 +143,12 @@ def __del__(self): del self._index_bin_buffer_mmap def get( - self, idx: int, offset: int = 0, length: int | None = None, use_loss_masking_spans: bool = False + self, + idx: int, + offset: int = 0, + length: int | None = None, + use_loss_masking_spans: bool = False, + use_preference_loss_spans: bool = False, ) -> GPTSample: token_ids = np.frombuffer( self._bin_buffer, @@ -116,13 +159,53 @@ def get( sample_spans = None if use_loss_masking_spans and self._spans is not None: sample_spans = self._spans[idx] - # adjust the spans for the offset and length + + # filter spans that are outside the range of the selected tokens in the document sample_spans = sample_spans[ (sample_spans[:, 0] < offset + len(token_ids)) & (sample_spans[:, 1] >= offset) ] - sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset + + # subtract by offset to normalize span boundaries + sample_spans[:, 0] = np.maximum(sample_spans[:, 0], offset) - offset # offset sample_spans[:, 1] = np.minimum(sample_spans[:, 1], offset + len(token_ids) - 1) - offset - return GPTSample(token_ids=token_ids, loss_masking_spans=sample_spans) + + chosen_span = None + rejected_span = None + + if use_preference_loss_spans: + if not self._has_preference_spans: + raise ValueError("No preference spans found in memmap dataset.") + elif self._has_preference_spans and self._chosen_spans is None: + raise ValueError("Failed to read chosen spans from memmap dataset.") + elif self._has_preference_spans and self._rejected_spans is None: + raise ValueError("Failed to read rejected spans from memmap dataset.") + else: + chosen_span = self._chosen_spans[idx] + + # filter spans that are outside the range of the selected tokens in the document + chosen_span = chosen_span[(chosen_span[0] < offset + len(token_ids)) & (chosen_span[1] >= offset)][0] + + # subtract by offset to normalize span boundaries + chosen_span[0] = np.maximum(chosen_span[0], offset) - offset # offset + chosen_span[1] = np.minimum(chosen_span[1], offset + len(token_ids) - 1) - offset + + rejected_span = self._rejected_spans[idx] + + # filter spans that are outside the range of the selected tokens in the document + rejected_span = rejected_span[ + (rejected_span[0] < offset + len(token_ids)) & (rejected_span[1] >= offset) + ][0] + + # subtract by offset to normalize span boundaries + rejected_span[0] = np.maximum(rejected_span[0], offset) - offset # offset + rejected_span[1] = np.minimum(rejected_span[1], offset + len(token_ids) - 1) - offset + + return GPTSample( + token_ids=token_ids, + loss_masking_spans=sample_spans, + chosen_span=chosen_span, + rejected_span=rejected_span, + ) @property def name(self) -> str: @@ -157,6 +240,8 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP # number of spans for each document num_spans = [] spans = [] + chosen_spans = [] + rejected_spans = [] prefix = pathlib.Path(prefix) prefix.parent.mkdir(parents=True, exist_ok=True) @@ -182,6 +267,10 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP if document.loss_masking_spans is not None: num_spans.append(len(document.loss_masking_spans)) spans.append(document.loss_masking_spans) + if document.chosen_span is not None: + chosen_spans.append(document.chosen_span) + if document.rejected_span is not None: + rejected_spans.append(document.rejected_span) offset += doc_length * np.dtype(dtype).itemsize num_documents += 1 @@ -193,15 +282,20 @@ def write_dataset(cls, prefix: pathlib.Path | str, documents: typing.Iterable[GP spans = np.vstack(spans, dtype=np.int32) else: spans = np.array(spans, dtype=np.int32) + chosen_spans = np.array(chosen_spans, dtype=np.int32).reshape(-1, 2) + rejected_spans = np.array(rejected_spans, dtype=np.int32).reshape(-1, 2) # Write the index file (.idx) with prefix.with_suffix(".idx").open("wb") as idx_stream: idx_stream.write(MEMMAP_INDEX_HEADER) # Indicates the version # Version 2 optionally adds loss-masking spans - idx_stream.write(struct.pack(" 0 else 0)) + # Flag to indicate whether preference loss-masking spans are present + idx_stream.write(struct.pack(" 0 and rejected_spans.size > 0 else 0)) # Data type idx_stream.write(struct.pack(" None: raise RuntimeError( f" > No documents shorter than {self._parameters.sequence_length+1} tokens found in dataset {self._indexed_dataset.name}." ) + # We produce sequences of length `self._sequence_length + extra_tokens` so the last token has a label for all prediction heads, # but in case of truncations we also include those last labels in the following sample, # so we need `sequence_length * num_samples + extra_tokens` tokens in total. - if self._truncate_documents: + if self._parameters.use_preference_loss_spans: + documents_per_epoch = (~long_docs_filter).sum().item() + num_epochs = math.ceil(self._parameters.num_samples / documents_per_epoch) + elif self._truncate_documents: num_epochs = math.ceil( (self._parameters.sequence_length * self._parameters.num_samples + self._parameters.extra_tokens) / tokens_per_epoch @@ -187,8 +201,8 @@ def _sample(self) -> None: if self._yaml_path is not None and self._yaml_path.is_file(): loaded_yaml_data = yaml.safe_load(self._yaml_path.open("r")) - self._load_yaml_data(loaded_yaml_data) - if not self._truncate_documents: + self._load_yaml_data(yaml_data) + if not self._truncate_documents and not self._parameters.use_preference_loss_spans: del loaded_yaml_data["unshuffled_tokens"] if loaded_yaml_data != yaml_data: @@ -251,6 +265,24 @@ def _sample(self) -> None: else: raise NotImplementedError(f"Unknown shuffling type: {self._config.shuffle}") + if self._parameters.use_preference_loss_spans: + yaml_data["unshuffled_tokens"] = 0 # not used, ignore + + # index of all documents less than seq length long + doc_length_filtered_indicies = torch.nonzero(~long_docs_filter, as_tuple=True)[0] + self._doc_length_filtered_indicies.save(doc_length_filtered_indicies.numpy(force=self._config.gpu)) + + # apply shuffling on doc_length_filtered_indicies + if shuffled_epochs > 0: + self._document_shuffling.save( + document_shuffling[: self._parameters.num_samples].numpy(force=self._config.gpu) + ) + self._document_sizes.save(document_sizes.numpy(force=self._config.gpu)) + if self._yaml_path is not None: + self._yaml_path.parent.mkdir(parents=True, exist_ok=True) + yaml.safe_dump(yaml_data, self._yaml_path.open("w")) + return + # To get a sample on the fly we need to know where it begins, # and this is a non-trivial information because the documents have variable length. # The starting point `(document[idx], token[idx])` corresponds to the `(idx * sequence_length)` th token, i.e. @@ -349,6 +381,40 @@ def __getitem__(self, index: int) -> typing.Any: The returned sample is ready to be concatenated, then fed to a `GPTModel` (see `GPTModel.preprocess`). """ self._lazy_load() + + if self._parameters.use_preference_loss_spans: + if index < self._unshuffled_documents: + document_index = self._doc_length_filtered_indicies[index % self._documents_per_epoch] + else: + document_index = self._doc_length_filtered_indicies[ + self._document_shuffling[index - self._unshuffled_documents].item() + ] + + sample = self._indexed_dataset.get( + document_index, + offset=0, + length=self._document_sizes[document_index], + use_loss_masking_spans=self._parameters.use_loss_masking_spans, + use_preference_loss_spans=self._parameters.use_preference_loss_spans, + ) + + chosen_span_end = sample.chosen_span[1] + 1 + sequence_lengths = [ + chosen_span_end, + len(sample.token_ids) - chosen_span_end, + ] + + # compute padding size + padding = np.full((self._parameters.sequence_length + 1,), 0) + padding[: len(sample.token_ids)] = sample.token_ids + sequence_lengths.append(self._parameters.sequence_length - len(sample.token_ids)) + sample.token_ids = padding + + if not self._parameters.cross_document_attention: + sample.sequence_lengths = np.array(sequence_lengths) + + return sample + # tokens at the boundary are included in only one sample when we pack without truncations # in case of packing with truncations, the last token from the previous sample is also the first token of the next sample sample_length = ( @@ -454,7 +520,9 @@ def _lazy_load(self): def _load_yaml_data(self, data: dict[str, typing.Any]) -> None: self._documents_per_epoch = data["dataset"]["documents_per_epoch"] - if "unshuffled_tokens" not in data: + if self._parameters.use_preference_loss_spans: + data["unshuffled_tokens"] = 0 # not used, ignore + elif "unshuffled_tokens" not in data: # Backward compatibility # TODO v0.x: Remove assert self._truncate_documents @@ -485,6 +553,8 @@ def __init__( ) self._config = sampling.config self._parameters = sampling.parameters + if self._parameters.use_preference_loss_spans: + raise NotImplementedError("Legacy sampling does not support preference loss masking.") if sampling.cache_directory is None: log_main_rank( diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index 2c4311c37..ce60f00e0 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -59,6 +59,12 @@ class GPTHuggingfaceDatasetConfig(Config): loss_masking_spans: None | str = Field( default=None, desc="Field containing character spans to mask for loss computation", hint=FieldHint.optional ) + chosen_text: None | str = Field( + default=None, desc="Field containing chosen text for preference optimization", hint=FieldHint.optional + ) + rejected_text: None | str = Field( + default=None, desc="Field containing rejected text for preference optimization", hint=FieldHint.optional + ) data_type: DataType | None = Field( default=None, desc="Data type of the dataset field." diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 23e497bf8..0cba3aa1c 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -74,6 +74,70 @@ def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict "num_tokens": num_tokens, } + def _tokenize_preference_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]: + packed_texts = [] + chosen_spans = [] + rejected_spans = [] + + for conv_history, chosen_text, rejected_text in zip( + batch[self._config.dataset.field], + batch[self._config.dataset.chosen_text], + batch[self._config.dataset.rejected_text], + ): + # compute chosen span + full_chosen_text = conv_history + chosen_text + self._tokenizer.tokenizer.eos_token + chosen_span = [len(conv_history), len(full_chosen_text) - 1] + offset = len(full_chosen_text) + chosen_spans.append(chosen_span) + + # compute rejected span + full_rejected_text = self._tokenizer.tokenizer.bos_token + conv_history + rejected_text + rejected_span = [ + offset + len(self._tokenizer.tokenizer.bos_token + conv_history), + offset + len(full_rejected_text) - 1, + ] + rejected_spans.append(rejected_span) + + # pack texts + packed_text = full_chosen_text + full_rejected_text + + assert ( + packed_text[chosen_span[0] : chosen_span[1] + 1] == chosen_text + self._tokenizer.tokenizer.eos_token + ), f"{packed_text[chosen_span[0]: chosen_span[1] + 1]} does not match {chosen_text}" + + assert ( + packed_text[rejected_span[0] : rejected_span[1] + 1] == rejected_text + ), f"{packed_text[rejected_span[0]: rejected_span[1] + 1]} does not match {rejected_text}" + packed_texts.append(packed_text) + + # tokenize with spans + input_ids, chosen_token_spans, rejected_token_spans = map( + list, + zip( + *[ + ( + np.array(input_ids, dtype=self._data_type.numpy), + np.array(token_spans[0], dtype=np.int32), + np.array( + [token_spans[1][0], token_spans[1][1] + 1], dtype=np.int32 + ), # adding 1 to end for eos token + ) + for input_ids, token_spans in [ + self._tokenizer.tokenize_with_spans(text, [chosen_span, rejected_span]) + for text, chosen_span, rejected_span in zip(packed_texts, chosen_spans, rejected_spans) + ] + ] + ), + ) + + num_tokens = [len(x) for x in input_ids] + return { + "input_ids": input_ids, + "chosen_token_spans": chosen_token_spans, + "rejected_token_spans": rejected_token_spans, + "num_tokens": num_tokens, + } + def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetConfig: shard_idx, shard_dataset = args prefix = f"shard_{self._config.distributed.rank}_{shard_idx}" @@ -86,6 +150,18 @@ def _document_generator(): np.array(item["input_ids"], dtype=self._data_type.numpy), np.array(item["token_spans"], dtype=np.int32).reshape(-1, 2), ) + elif ( + "chosen_token_spans" in shard_dataset.column_names + and "rejected_token_spans" in shard_dataset.column_names + and self._config.dataset.chosen_text is not None + and self._config.dataset.rejected_text is not None + ): + for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): + yield GPTSample( + token_ids=np.array(item["input_ids"], dtype=self._data_type.numpy), + chosen_span=np.array(item["chosen_token_spans"], dtype=np.int32).reshape(-1, 2), + rejected_span=np.array(item["rejected_token_spans"], dtype=np.int32).reshape(-1, 2), + ) else: for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"): yield GPTSample(np.array(item["input_ids"], dtype=self._data_type.numpy)) @@ -214,10 +290,24 @@ def run(self) -> None: ) if self._config.dataset.field not in dataset.column_names: raise ValueError(f"Dataset does not have field '{self._config.dataset.field}'.") + if self._config.dataset.loss_masking_spans is not None and ( + self._config.dataset.chosen_text is not None or self._config.dataset.rejected_text is not None + ): + raise ValueError(f"Can not enable both loss masking spans and chosen/rejected loss masking spans.") + if (self._config.dataset.chosen_text is None) != (self._config.dataset.rejected_text is None): + raise ValueError(f"Both chosen and rejected loss masking spans must be specified if one is specified.") + + # route tokenize function if self._config.dataset.loss_masking_spans is not None: if self._config.dataset.loss_masking_spans not in dataset.column_names: raise ValueError(f"Dataset does not have spans field '{self._config.dataset.loss_masking_spans}'.") tokenize_fn = self._tokenize_batch_with_spans + elif self._config.dataset.chosen_text is not None and self._config.dataset.rejected_text is not None: + if self._config.dataset.chosen_text not in dataset.column_names: + raise ValueError(f"Dataset does not have chosen spans field '{self._config.dataset.chosen_text}'.") + if self._config.dataset.rejected_text not in dataset.column_names: + raise ValueError(f"Dataset does not have rejected spans field '{self._config.dataset.rejected_text}'.") + tokenize_fn = self._tokenize_preference_batch_with_spans else: tokenize_fn = self._tokenize_batch diff --git a/fast_llm/data/tokenizer.py b/fast_llm/data/tokenizer.py index 28e105ee8..988e23e76 100644 --- a/fast_llm/data/tokenizer.py +++ b/fast_llm/data/tokenizer.py @@ -52,6 +52,7 @@ def tokenize_with_spans( token_spans = [] char_pos = 0 beginning_of_text = True + for start, end in char_spans: if char_pos < start: curr_text = text[char_pos:start] diff --git a/fast_llm/engine/schedule/schedule.py b/fast_llm/engine/schedule/schedule.py index 44a5f677f..91ce0d892 100644 --- a/fast_llm/engine/schedule/schedule.py +++ b/fast_llm/engine/schedule/schedule.py @@ -203,6 +203,7 @@ def _create_index(self) -> None: Assert.incl(step.type_, (StepType.forward, StepType.backward)) step.global_index = i # TODO: More configurable placement? + step.pipeline_rank = step.stage % self._distributed.pipeline_parallel step.local_index = len(self._device_steps[step.pipeline_rank]) self._device_steps[step.pipeline_rank].append(step) @@ -222,6 +223,7 @@ def _create_index(self) -> None: Assert.empty(step_map) # Related steps + for i, step in enumerate(self._steps): if self._is_training: if step.type_ == StepType.forward: @@ -229,6 +231,7 @@ def _create_index(self) -> None: step.backward_step = self.get_step(StepType.backward, *step.map_index[1:]) else: step.forward_step = self.get_step(StepType.forward, *step.map_index[1:]) + if step.type_ == StepType.forward and step.stage == 0: step.prev_step = None elif step.type_ == StepType.backward and step.stage == self._num_stages - 1: diff --git a/fast_llm/functional/dpo.py b/fast_llm/functional/dpo.py new file mode 100644 index 000000000..3a70f308f --- /dev/null +++ b/fast_llm/functional/dpo.py @@ -0,0 +1,78 @@ +import torch + + +def _compute_logprobs_for_preference_spans( + logits: torch.Tensor, targets: torch.Tensor, chosen_spans: torch.Tensor, rejected_spans: torch.Tensor +): + assert torch.all(targets < logits.size(-1)), "Target out of vocab range" + + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) + + # gather log probabilities corresponding to the target tokens + selected_log_probs = log_probs.gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1) + + # apply chosen mask + chosen_logp = 0 + for idx, span in enumerate(chosen_spans): + chosen_logp += selected_log_probs[idx][span[0].item() : span[1].item() + 1].sum() + + # apply rejected mask + rejected_logp = 0 + for idx, span in enumerate(rejected_spans): + rejected_logp += selected_log_probs[idx][span[0].item() : span[1].item() + 1].sum() + + return chosen_logp, rejected_logp, selected_log_probs + + +def _compute_dpo_loss( + policy_chosen_logps: torch.Tensor, + policy_rejected_logps: torch.Tensor, + reference_chosen_logps: torch.Tensor, + reference_rejected_logps: torch.Tensor, + beta: float, +): + pi_logratios = policy_chosen_logps - policy_rejected_logps + ref_logratios = reference_chosen_logps - reference_rejected_logps + + diff_logratios = pi_logratios - ref_logratios + + losses = -torch.nn.functional.logsigmoid(beta * diff_logratios) + return losses + + +def compute_dpo_loss( + logits: torch.Tensor, + targets: torch.Tensor, + reference_model_logits: torch.Tensor, + chosen_spans: torch.Tensor, + rejected_spans: torch.Tensor, + beta: float, + grad_output: float | None, +) -> tuple[torch.Tensor, torch.Tensor]: + with torch.enable_grad(): + logits_ = logits.float().detach().requires_grad_() + reference_model_logits_ = reference_model_logits.float().detach() + + policy_chosen_logps, policy_rejected_logps, _ = _compute_logprobs_for_preference_spans( + logits_, targets, chosen_spans, rejected_spans + ) + + reference_chosen_logps, reference_rejected_logps, _ = _compute_logprobs_for_preference_spans( + reference_model_logits_, targets, chosen_spans, rejected_spans + ) + + losses = _compute_dpo_loss( + policy_chosen_logps=policy_chosen_logps, + policy_rejected_logps=policy_rejected_logps, + reference_chosen_logps=reference_chosen_logps, + reference_rejected_logps=reference_rejected_logps, + beta=beta, + ) + + if grad_output is None: + loss = None + else: + loss = losses.mean() + loss.backward(torch.full_like(loss, grad_output)) + loss.detach() + return loss.detach(), logits_.grad.detach().to(logits.dtype) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index d0f03ccf2..89625b552 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -34,6 +34,8 @@ class LanguageModelKwargs: # TODO: These are generic labels = "labels" phase = "phase" + chosen_spans = "chosen_spans" + rejected_spans = "rejected_spans" loss_mask = "loss_mask" @@ -88,6 +90,21 @@ class LanguageModelBaseConfig(BaseModelConfig): desc="Min value for clamping initialized weights of the vocabulary embedding and output (logits).", hint=FieldHint.feature, ) + enable_dpo: bool | None = Field( + default=False, + desc="Whether to enable DPO loss", + hint=FieldHint.feature, + ) + dpo_beta: float | None = Field( + default=1.0, + desc="Beta value for DPO loss.", + hint=FieldHint.feature, + ) + dpo_reference_model: str | None = Field( + default=None, + desc="Name of the reference model to use for dpo.", + hint=FieldHint.feature, + ) cross_entropy_impl: CrossEntropyImpl = Field( default=CrossEntropyImpl.auto, desc="Implementation for the cross-entropy computation.", diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 813dcc076..20f43dd30 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -12,6 +12,7 @@ from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward from fast_llm.functional.config import CrossEntropyImpl, TargetFormat, TritonConfig from fast_llm.functional.cross_entropy import cross_entropy_forward_backward +from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.language_model.config import ( @@ -73,14 +74,18 @@ def __init__( self._init_output_weights(hidden_dim, config) - self._cross_entropy_impl = config.cross_entropy_impl - if self._cross_entropy_impl == CrossEntropyImpl.auto: - if self._parallel_embeddings: - self._cross_entropy_impl = CrossEntropyImpl.fused - elif TritonConfig.TRITON_ENABLED: - self._cross_entropy_impl = CrossEntropyImpl.triton - else: - self._cross_entropy_impl = CrossEntropyImpl.fused + self._use_dpo_loss = config.enable_dpo + if self._use_dpo_loss: + self.dpo_beta = config.dpo_beta + else: + self._cross_entropy_impl = config.cross_entropy_impl + if self._cross_entropy_impl == CrossEntropyImpl.auto: + if self._parallel_embeddings: + self._cross_entropy_impl = CrossEntropyImpl.fused + elif TritonConfig.TRITON_ENABLED: + self._cross_entropy_impl = CrossEntropyImpl.triton + else: + self._cross_entropy_impl = CrossEntropyImpl.fused self._forward = wrap_forward_backward(self._forward_backward, grad_is_context) @@ -143,12 +148,12 @@ def _forward_backward( ) -> tuple[torch.Tensor, torch.Tensor | None]: target = kwargs.get( LanguageModelKwargs.labels - if self._config.distillation_model is None + if self._use_dpo_loss or self._config.distillation_model is None else f"{self._config.distillation_model}_logits" ) # Loss mask for distillation. (Labels are already masked.) loss_mask = None - if target is not None: + if target is not None and not self._use_dpo_loss: if self._config.distillation_model is None: # MTP: Shift the labels target_sequence_length = ( @@ -309,16 +314,28 @@ def _logits_cross_entropy_forward_backward( if target is None: return logits * self._logits_scale_factor, None - loss, grad = cross_entropy_forward_backward( - logits.flatten(0, -2), - target, - loss_mask, - group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, - grad_output=grad_output, - implementation=self._cross_entropy_impl, - logits_scale_factor=self._logits_scale_factor, - target_format=TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits, - ) + if self._use_dpo_loss: + loss, grad = compute_dpo_loss( + logits, + target, + kwargs.get(f"{self._config.dpo_reference_model}_logits"), + kwargs[LanguageModelKwargs.chosen_spans], + kwargs[LanguageModelKwargs.rejected_spans], + self.dpo_beta, + grad_output, + ) + else: + loss, grad = cross_entropy_forward_backward( + logits.flatten(0, -2), + target, + loss_mask, + group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, + grad_output=grad_output, + implementation=self._cross_entropy_impl, + logits_scale_factor=self._logits_scale_factor, + target_format=TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits, + ) + # TODO: de-allocate earlier. del logits return loss, output_parallel_linear_backward(grad, context) diff --git a/fast_llm/layers/language_model/preprocessing.py b/fast_llm/layers/language_model/preprocessing.py index 7e95bb5cc..d719bef3d 100644 --- a/fast_llm/layers/language_model/preprocessing.py +++ b/fast_llm/layers/language_model/preprocessing.py @@ -69,3 +69,54 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: tensor_name=LanguageModelKwargs.position_ids, dtype=torch.int64, ) + + +class PreferenceSpanPreprocessor(Preprocessor): + def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): + self._config = config + self._tensor_space = tensor_space + self._distributed_config = self._tensor_space.distributed_config + self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) + + def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: + return + + def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: + sequence_q = kwargs[TransformerKwargs.sequence_q_dim].size + sequence_k = kwargs[TransformerKwargs.sequence_k_dim].size + sequence_offset = sequence_k - sequence_q + 1 # +1 for shift in labels + + if LanguageModelKwargs.chosen_spans not in kwargs or LanguageModelKwargs.rejected_spans not in kwargs: + raise ValueError("Expected chosen spans or rejected spans to be found within the batch.") + + chosen_spans = kwargs[LanguageModelKwargs.chosen_spans] + chosen_valid_spans = [] + for spans in chosen_spans: + if not spans.numel(): + continue + # only keep spans within the sequence or partially within the sequence + valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)][0] + if valid_spans.numel(): + # if span is partially within the sequence, truncate parts of spans that are outside of the sequence + valid_spans[0].clamp_(min=sequence_offset) + valid_spans[1].clamp_(max=sequence_k) + valid_spans -= sequence_offset + + chosen_valid_spans.append(valid_spans) + kwargs[LanguageModelKwargs.chosen_spans] = chosen_valid_spans + + rejected_spans = kwargs[LanguageModelKwargs.rejected_spans] + rejected_valid_spans = [] + for spans in rejected_spans: + if not spans.numel(): + continue + # only keep spans within the sequence or partially within the sequence + valid_spans = spans[(spans[0] <= sequence_k) & (spans[1] >= sequence_offset)][0] + if valid_spans.numel(): + # if span is partially within the sequence, truncate parts of spans that are outside of the sequence + valid_spans[0].clamp_(min=sequence_offset) + valid_spans[1].clamp_(max=sequence_k) + valid_spans -= sequence_offset + + rejected_valid_spans.append(valid_spans) + kwargs[LanguageModelKwargs.rejected_spans] = rejected_valid_spans diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 418f948e3..df4222772 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -173,14 +173,29 @@ def _validate(self) -> None: if self.model.base_model.use_megatron_initialization: set_megatron_distributed_seeds(self.model.distributed) super()._validate() - if (name := self.model.base_model.distillation_model) is None: - Assert.empty(self.reference_models) - else: - Assert.eq(self.reference_models.keys(), {name}) + if self.model.base_model.use_absolute_position_embeddings: Assert.geq(self.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length) + + distillation_model = self.model.base_model.distillation_model + dpo_reference_model = self.model.base_model.dpo_reference_model + + if self.model.base_model.enable_dpo: + assert dpo_reference_model is not None + Assert.none(distillation_model) + else: + Assert.none(dpo_reference_model) + + if distillation_model is None and dpo_reference_model is None: + Assert.empty(self.reference_models) + else: + assert distillation_model is None or dpo_reference_model is None # currently don't support both + expected_names = {name for name in (distillation_model, dpo_reference_model) if name is not None} + Assert.eq(self.reference_models.keys(), expected_names) + for reference_model in self.reference_models.values(): Assert.none(reference_model.model.base_model.distillation_model) + Assert.none(reference_model.model.base_model.dpo_reference_model) # TODO: Support more LM head features. Assert.none(reference_model.model.base_model.cross_entropy_splits) Assert.eq(reference_model.model.base_model.parallel_embeddings, self.model.base_model.parallel_embeddings) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index d177a41d2..b548ab525 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -13,7 +13,7 @@ from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead -from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor +from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor, PreferenceSpanPreprocessor from fast_llm.layers.transformer.config import ( RoutingType, TransformerDimNames, @@ -70,6 +70,9 @@ def __init__( else: self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._tensor_space)) + if self._config.enable_dpo: # TODO better way to pass in? + self._preprocessors.append(PreferenceSpanPreprocessor(self._config, self._tensor_space)) + def get_output_layers(self) -> list[Layer]: layers = [] for i in range(self._config.prediction_heads): @@ -283,6 +286,10 @@ def preprocess( tokens = batch.token_ids[:, sequence_k - sequence_q : sequence_k].contiguous() if batch.sequence_lengths is not None: kwargs_meta[TransformerKwargs.sequence_lengths] = batch.sequence_lengths + if batch.chosen_spans is not None: + kwargs_meta[LanguageModelKwargs.chosen_spans] = batch.chosen_spans + if batch.rejected_spans is not None: + kwargs_meta[LanguageModelKwargs.rejected_spans] = batch.rejected_spans # TODO: Add pasts/presents to meta input? # Use lists as pointers so `past_key_values` is populated during the previous micro_sequence. @@ -294,7 +301,7 @@ def preprocess( TransformerKwargs.presents: presents, } if phase != PhaseType.inference: - sequence_offset = sequence_k - sequence_q + 1 + sequence_offset = sequence_k - sequence_q + 1 # +1 for shift in labels if sequence_first: labels = batch.token_ids[sequence_offset : sequence_k + prediction_heads] else: @@ -312,6 +319,7 @@ def preprocess( (spans[:, 0] <= sequence_k + prediction_heads - 1) & (spans[:, 1] >= sequence_offset) ] if valid_spans.numel(): + # if span is partially within the sequence, truncate parts of spans that are outside of the sequence valid_spans[:, 0].clamp_(min=sequence_offset) valid_spans[:, 1].clamp_(max=sequence_k + prediction_heads - 1) valid_spans -= sequence_offset diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 3bdb05c3a..cc39d7f70 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -28,6 +28,7 @@ def _get_sampling_parameters( "vocab_size": self._config.model.base_model.vocab_size, "sequence_length": self._config.batch.sequence_length, "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, + "use_preference_loss_spans": self._config.model.base_model.enable_dpo, "cross_document_attention": self._config.batch.cross_document_attention, "extra_tokens": self._config.model.base_model.prediction_heads, } diff --git a/tests/data/test_prepare_gpt_memmap.py b/tests/data/test_prepare_gpt_memmap.py index 9dd7975c2..17ba5de01 100644 --- a/tests/data/test_prepare_gpt_memmap.py +++ b/tests/data/test_prepare_gpt_memmap.py @@ -39,6 +39,43 @@ def test_write_memmap_dataset(dtype): ), f"Mismatch for document {i}: {document} != {dataset.get(i)}." +@pytest.mark.parametrize("dtype", MEMMAP_DTYPES.values()) +def test_write_memmap_preference_dataset(dtype): + def generate_valid_span(max_seq_length): + span = np.random.choice(np.arange(0, max_seq_length - 1), size=2, replace=False) + return np.sort(span) + + vocab_size = 1000 + max_seq_length = 8192 + num_samples = 100 + + documents = [ + GPTSample( + token_ids=np.random.randint(vocab_size, size=max_seq_length).astype(dtype), + chosen_span=generate_valid_span(max_seq_length=max_seq_length), + rejected_span=generate_valid_span(max_seq_length=max_seq_length), + ) + for _ in range(num_samples) + ] + with tempfile.TemporaryDirectory() as temp_dir: + prefix = pathlib.Path(temp_dir) + GPTMemmapDataset.write_dataset(prefix=prefix, documents=documents) + dataset = GPTMemmapDataset(name="foo", prefix=prefix) + for i, document in enumerate(documents): + dataset_item = dataset.get(i, use_preference_loss_spans=True) + assert np.array_equal( + dataset_item.token_ids, document.token_ids, equal_nan=True + ), f"Token ids mismatch for document {i}: {document} != {dataset.get(i)}." + + assert np.array_equal( + dataset_item.chosen_span, document.chosen_span, equal_nan=True + ), f"Chosen loss masking spans mismatch for document {i}: {document.chosen_span} != {dataset.get(i).chosen_span}." + + assert np.array_equal( + dataset_item.rejected_span, document.rejected_span, equal_nan=True + ), f"Rejected loss masking spans mismatch for document {i}: {document.rejected_span} != {dataset.get(i).rejected_span}." + + def test_load_metadata_from_hub(): with tempfile.TemporaryDirectory(suffix="test") as local_folder: get_preparator(local_folder, "lhoestq/demo1")._save_croissant_metadata() diff --git a/tests/test_functional.py b/tests/test_functional.py index 3e5c7f873..34a7a77fc 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1,13 +1,162 @@ +import random + import pytest import torch from fast_llm.functional.config import ActivationType, MLPRecomputeLevel +from fast_llm.functional.dpo import _compute_dpo_loss, _compute_logprobs_for_preference_spans from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped, torch_mlp_activation from fast_llm.functional.triton.sparse_copy import get_sparse_map from fast_llm.utils import Assert from tests.common import requires_cuda +def ref_log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor, temperature: float = 1.0) -> torch.Tensor: + if temperature != 1.0: + logits.div_(temperature) + batch_dim = logits.shape[:-1] + last_dim = logits.shape[-1] + + output = torch.nn.functional.cross_entropy(logits.reshape(-1, last_dim), labels.reshape(-1), reduction="none") + log_probs_labels = -output.view(*batch_dim) + + return log_probs_labels + + +def ref_packed_get_batch_logps( + logits: torch.FloatTensor, + labels: torch.LongTensor, + attention_mask, + prompt_id_lens, + packed_seq_lens, +) -> torch.FloatTensor: + labels = labels[:, 1:] + logits = logits[:, :-1, :] + per_token_logps = ref_log_probs_from_logits(logits, labels) + + loss_masks = attention_mask.clone().bool() + + index = 0 + for i, seq_len in enumerate(packed_seq_lens): + loss_masks[0, index : index + prompt_id_lens[i]] = False + index = index + seq_len + + loss_masks = loss_masks[:, 1:] + + logprobs_sums = [] + index = 0 + for i, seq_len in enumerate(packed_seq_lens): + seq = per_token_logps[0, index : index + seq_len - 1] + mask = loss_masks[0, index : index + seq_len - 1] + logprobs_sums.append((seq * mask).sum()) + index = index + seq_len + chosen_logps = logprobs_sums[: len(packed_seq_lens) // 2] + rejected_logps = logprobs_sums[len(packed_seq_lens) // 2 :] + + return torch.tensor(chosen_logps), torch.tensor(rejected_logps) + + +@pytest.mark.parametrize("batch_size", [1, 2, 4, 8]) +@pytest.mark.parametrize("seq_length", [1024, 4096, 8192]) +@pytest.mark.parametrize("vocab_size", [1000, 2000, 8000]) +def test_preference_logps(batch_size, seq_length, vocab_size): + random.seed(0) + torch.manual_seed(0) + + def random_split(seq_length): + min_val = int(seq_length * 0.3) + max_val = int(seq_length * 0.7) + + if max_val < min_val: + max_val = min_val + + a = random.randint(min_val, max_val) + b = seq_length - a + return [a, b] + + logits = torch.randn(batch_size, seq_length, vocab_size) + targets = torch.randint(0, vocab_size, (batch_size, seq_length)) + packed_seq_lens = random_split(seq_length) # simulate different chosen/rejected lengths + prompt_id_lens = [int(min(packed_seq_lens) * 0.75)] * 2 # sequences are 75% prompt 25% generation + attention_mask = torch.tensor([1] * packed_seq_lens[0] + [2] * packed_seq_lens[1]).unsqueeze(0) + + chosen_span = torch.tensor([[prompt_id_lens[0], packed_seq_lens[0] - 1]]) - 1 # shift by 1 due to label shifting + rejected_span = ( + torch.tensor([[packed_seq_lens[0] + prompt_id_lens[1], packed_seq_lens[0] + packed_seq_lens[1] - 1]]) - 1 + ) # shift by 1 due to label shifting + + ref_chosen_logps, ref_rejected_logps = ref_packed_get_batch_logps( + logits, targets, attention_mask, prompt_id_lens, packed_seq_lens + ) + + chosen_logps, rejected_logps, selected_log_probs = _compute_logprobs_for_preference_spans( + logits=logits, + targets=targets[:, 1:], + chosen_spans=chosen_span, + rejected_spans=rejected_span, + ) + + ref_logps = ref_log_probs_from_logits(logits[:, :-1, :], targets[:, 1:]) + + # check all logps + Assert.custom(torch.allclose, ref_logps, selected_log_probs, rtol=1e-5) + + # check chosen and rejected summed logps + Assert.custom(torch.allclose, ref_chosen_logps, chosen_logps, rtol=1e-5) + Assert.custom(torch.allclose, ref_rejected_logps, rejected_logps, rtol=1e-5) + + +def ref_dpo_loss_fcn( + policy_chosen_logps: torch.Tensor, + policy_rejected_logps: torch.Tensor, + reference_chosen_logps: torch.Tensor, + reference_rejected_logps: torch.Tensor, + beta=1, + label_smoothing=0, +) -> torch.Tensor: + pi_logratios = policy_chosen_logps - policy_rejected_logps + ref_logratios = reference_chosen_logps - reference_rejected_logps + logits = pi_logratios - ref_logratios + + # Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf) + losses = ( + -torch.nn.functional.logsigmoid(beta * logits) * (1 - label_smoothing) + - torch.nn.functional.logsigmoid(-beta * logits) * label_smoothing + ) + + loss = losses.mean() + + return loss + + +def test_dpo_loss(): + torch.manual_seed(0) + + NUM_SAMPLES = 20 + policy_chosen_logps = torch.rand(NUM_SAMPLES) + policy_rejected_logps = torch.rand(NUM_SAMPLES) + reference_chosen_logps = torch.rand(NUM_SAMPLES) + reference_rejected_logps = torch.rand(NUM_SAMPLES) + betas = torch.rand(NUM_SAMPLES) + + for i in range(NUM_SAMPLES): + fastllm_dpo_loss = _compute_dpo_loss( + policy_chosen_logps=policy_chosen_logps[i], + policy_rejected_logps=policy_rejected_logps[i], + reference_chosen_logps=reference_chosen_logps[i], + reference_rejected_logps=reference_rejected_logps[i], + beta=betas[i].item(), + ) + ref_dpo_loss = ref_dpo_loss_fcn( + policy_chosen_logps=policy_chosen_logps[i].unsqueeze(0), + policy_rejected_logps=policy_rejected_logps[i].unsqueeze(0), + reference_chosen_logps=reference_chosen_logps[i].unsqueeze(0), + reference_rejected_logps=reference_rejected_logps[i].unsqueeze(0), + beta=betas[i].item(), + ) + Assert.rms_close(fastllm_dpo_loss, ref_dpo_loss, 1e-5) + + @requires_cuda @pytest.mark.parametrize("gated", [True, False]) @pytest.mark.parametrize(