Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
5137757
stuff
jlamypoirier Mar 26, 2025
f0cb32a
Merge remote-tracking branch 'origin/main' into config_updates
jlamypoirier Mar 26, 2025
f26010e
Update pretrained config
jlamypoirier Mar 27, 2025
b930a39
stuff
jlamypoirier Mar 27, 2025
918a7a8
Merge branch 'config_updates' into update_pretrained_config
jlamypoirier Mar 27, 2025
8117c47
fixes
jlamypoirier Mar 27, 2025
1c995d3
fix
jlamypoirier Mar 27, 2025
3f90475
Merge branch 'main' into config_updates
jlamypoirier Mar 27, 2025
e389058
Merge branch 'config_updates' into update_pretrained_config
jlamypoirier Mar 27, 2025
506fe92
fixes
jlamypoirier Mar 27, 2025
971d3ef
fixes
jlamypoirier Mar 27, 2025
6bf20cb
Tests wip
jlamypoirier Mar 28, 2025
c13fb19
misc
jlamypoirier Mar 29, 2025
a20fcec
tests
jlamypoirier Apr 1, 2025
9af26a7
Merge branch 'main' into config_updates
jlamypoirier Apr 1, 2025
9af372d
Tests, fixes, remove tuple format
jlamypoirier Apr 1, 2025
dded00a
fix
jlamypoirier Apr 2, 2025
42d5ca4
Merge remote-tracking branch 'origin/main' into config_updates
jlamypoirier Apr 2, 2025
986f9f3
fix
jlamypoirier Apr 2, 2025
5abc087
Merge branch 'config_updates' into update_pretrained_config
jlamypoirier Apr 2, 2025
8e3e795
fixes
jlamypoirier Apr 2, 2025
da6eb7b
fixes
jlamypoirier Apr 3, 2025
67e08aa
Merge branch 'main' into config_updates
jlamypoirier Apr 3, 2025
a09e6f3
Merge branch 'config_updates' into update_pretrained_config
jlamypoirier Apr 3, 2025
baad705
fix
jlamypoirier Apr 3, 2025
b702837
Test, fixes
jlamypoirier Apr 5, 2025
a8684f8
Knowledge distillation, fix cross-entropy
jlamypoirier Apr 11, 2025
b781729
Fixes, distillation
jlamypoirier Apr 13, 2025
db6504b
fixes
jlamypoirier Apr 14, 2025
7c2933a
Merge remote-tracking branch 'origin/main' into config_updates
jlamypoirier Apr 14, 2025
a017c11
Merge branch 'config_updates' into update_pretrained_config
jlamypoirier Apr 14, 2025
368a6bf
Merge remote-tracking branch 'origin/main' into update_pretrained_config
jlamypoirier Apr 14, 2025
e0c82a0
Merge remote-tracking branch 'origin/main' into distillation
jlamypoirier Apr 14, 2025
16a3dd7
Merge branch 'update_pretrained_config' into distillation
jlamypoirier Apr 14, 2025
cff9892
fixes
jlamypoirier Apr 14, 2025
793ecde
Merge branch 'update_pretrained_config' into distillation
jlamypoirier Apr 14, 2025
b67006a
fixes
jlamypoirier Apr 15, 2025
2014108
Add constraints
jlamypoirier Apr 16, 2025
4fb78e4
Merge remote-tracking branch 'origin/main' into distillation
jlamypoirier Apr 16, 2025
fa3d556
Add constraints
jlamypoirier Apr 16, 2025
48141e5
Merge remote-tracking branch 'origin/main' into update_pretrained_config
jlamypoirier Apr 17, 2025
e6e5a32
Merge branch 'update_pretrained_config' into distillation
jlamypoirier Apr 17, 2025
537deca
fix
jlamypoirier Apr 17, 2025
3d5dc94
Merge commit '6ad0a96c9328234b907d01a82c4c52bd48752b2f' into update_p…
jlamypoirier Apr 18, 2025
2bb0c08
Merge branch 'update_pretrained_config' into distillation
jlamypoirier Apr 18, 2025
067ba97
Merge remote-tracking branch 'origin/main' into distillation
jlamypoirier Apr 21, 2025
d2b3154
misc
jlamypoirier Apr 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions fast_llm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __enter__(self):
global _AUTO_VALIDATE
self._old_value = _AUTO_VALIDATE
_AUTO_VALIDATE = False
return _AUTO_VALIDATE

def __exit__(self, exc_type, exc_val, exc_tb):
global _AUTO_VALIDATE
Expand Down
3 changes: 2 additions & 1 deletion fast_llm/engine/distributed/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,6 @@ def _validate(self) -> None:
if self.reference_config.reference_config is not None:
self.reference_config = self.reference_config.reference_config
assert self.reference_config.reference_config is None
self.compare(self.reference_config, ValueError)
self.distributed_dims = self.reference_config.distributed_dims
else:
self.distributed_dims = {}
Expand Down Expand Up @@ -368,6 +367,8 @@ def _validate(self) -> None:

super()._validate()

if self.reference_config is not None:
self.compare(self.reference_config, ValueError)
Assert.in_range(self.rank, 0, self.world_size)
Assert.in_range(self.local_rank, 0, self.local_world_size)

Expand Down
3 changes: 2 additions & 1 deletion fast_llm/engine/multi_stage/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,8 @@ def _validate(self) -> None:
self.pretrained.setup(self.model)
self.pretrained.validate()
if self.pretrained.path is not None:
self.model = self.model.from_pretrained(self.pretrained, self.model)
with NoAutoValidate():
self.model = self.model.from_pretrained(self.pretrained, self.model)
self._setup()
super()._validate()

Expand Down
15 changes: 6 additions & 9 deletions fast_llm/engine/training/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,17 +385,15 @@ class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig):

def _validate(self) -> None:
self.training.export.setup(self.model)
self.model.validate()
for reference_model in self.reference_models.values():
_add_reference_distributed_to_pretrained(reference_model, self.model.distributed)
super()._validate()
if self.reference_models:
# TODO: Add support.
Assert.eq(self.model.distributed.pipeline_parallel, 1)
# TODO: Check if these work.
Assert.eq(self.model.distributed.tensor_parallel, 1)
Assert.eq(self.model.distributed.sequence_data_parallel, 1)

for reference_model in self.reference_models.values():
_add_reference_distributed_to_pretrained(reference_model, self.model.distributed)
super()._validate()
if self.run.experiment_dir is None:
assert not self.training.checkpoint.enabled()

Expand Down Expand Up @@ -431,13 +429,12 @@ def _add_reference_distributed_to_pretrained(pretrained: PretrainedFastLLMModelC

def new_setup():
# Make sure the distributed config isn't set
# TODO!!!!!!!!!!!!!: Uncomment after #205
# pretrained.model.distributed.validate()
# Assert.leq(pretrained.model.distributed.to_dict().keys(), {"world_size", "rank", "local_world_size"})
pretrained.model.distributed.validate()
Assert.leq(pretrained.model.distributed.to_dict().keys(), {"world_size", "rank", "local_world_size"})
with NoAutoValidate():
pretrained.model.distributed = distributed.to_copy()
# Allow sharing the `Distributed` instance.
pretrained.model.distributed.reference_config = distributed
old_setup()

pretrained._setup = new_setup
object.__setattr__(pretrained, "_setup", new_setup)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bit of a hack, no?

Copy link
Copy Markdown
Collaborator Author

@jlamypoirier jlamypoirier Apr 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kind of, adapting the current setattr to deal with method override would be very difficult so I'm bypassing it. (And that's how frozen dataclasses do it)

6 changes: 6 additions & 0 deletions fast_llm/functional/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,9 @@ class CrossEntropyImpl(str, enum.Enum):
torch = "torch"
fused = "fused"
triton = "triton"


class TargetFormat(enum.StrEnum):
labels = "labels"
logits = "logits"
probabilities = "probabilities"
191 changes: 108 additions & 83 deletions fast_llm/functional/cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch.autograd

from fast_llm.core.distributed import ProcessGroup, ReduceOp, all_reduce
from fast_llm.functional.config import CrossEntropyImpl
from fast_llm.functional.config import CrossEntropyImpl, TargetFormat
from fast_llm.functional.triton.cross_entropy import triton_cross_entropy_forward_backward
from fast_llm.utils import Assert

Expand All @@ -12,34 +12,67 @@ def torch_cross_entropy_forward_backward(
logits: torch.Tensor,
target: torch.Tensor,
grad_output: float | None,
logits_scale_factor: float = 1.0,
logits_scale_factor: float,
target_format: TargetFormat,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
A wrapper for the pytorch implementation of cross-entropy.
The cross-entropy kernels themselves are well-optimized, but the need for explicit casting
and separate forward and backward kernels leads to poor performance.
TODO: loss masking only works for this method if the masking index is set to -100.
TODO: loss masking only works with the labels format and if the masking index is set to -100.
"""
# Torch compile doesn't understand this.
with torch.enable_grad():
logits_ = logits.float().detach().requires_grad_()
if logits_scale_factor != 1.0:
logits_ *= logits_scale_factor
with torch.set_grad_enabled(grad_output is not None):
logits_ = logits.float().detach().requires_grad_(grad_output is not None)
if target_format == TargetFormat.logits:
if logits_scale_factor != 1.0:
target = target * logits_scale_factor
target = torch.softmax(target, dim=-1)
loss = torch.nn.functional.cross_entropy(
logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target
).mean()
if grad_output is None:
loss = None
grad = None
else:
loss = torch.nn.functional.cross_entropy(logits_, target).mean()
loss.backward(torch.full_like(loss, grad_output))
loss.detach_()
return loss.detach(), logits_.grad.detach().to(logits.dtype)
grad = logits_.grad.detach().to(logits.dtype)
return loss.detach_(), grad


# @torch.compile
def _fused_softmax_base(
logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup | None = None, dim: int = -1
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
logits = logits.float()
if logits_scale_factor != 1.0:
logits *= logits_scale_factor
logits_max = torch.max(logits, dim=dim, keepdim=True)[0]
if group is not None:
all_reduce(logits_max, op=ReduceOp.MAX, group=group)
logits_norm = (logits - logits_max).float()
exp_logits = logits_norm.exp()
sum_exp_logits = exp_logits.sum(dim=dim, keepdim=True)
if group is not None:
all_reduce(sum_exp_logits, op=ReduceOp.SUM, group=group)
return logits_norm, exp_logits, sum_exp_logits


# @torch.compile
def fused_softmax(
    logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup | None = None, dim: int = -1
) -> torch.Tensor:
    """
    Compute the softmax of `logits` along `dim` from the intermediates of `_fused_softmax_base`.

    Args:
        logits: The input scores.
        logits_scale_factor: Multiplier applied to the logits before normalization.
        group: Optional process group forwarded to `_fused_softmax_base`.
        dim: The dimension along which the softmax is taken.

    Returns:
        The exponentials divided by their sum, i.e. the softmax probabilities.
    """
    _, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group, dim)
    return exp_logits / sum_exp_logits


@torch.compile
def fused_cross_entropy_forward_backward(
logits: torch.Tensor,
target: torch.Tensor,
grad_output: float | None,
logits_scale_factor: float = 1.0,
logits_scale_factor: float,
target_format: TargetFormat,
group: ProcessGroup | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
A fused implementation of cross-entropy with torch compile.
Expand All @@ -48,82 +81,67 @@ def fused_cross_entropy_forward_backward(
"""
# Do the forward and backward passes all at once, and fused with dtype conversion.
# Way faster and more memory-efficient than the pytorch version.
loss_mask = target >= 0
# Ignore_index can go out of bounds, so set masked values to zero.
target = (target * loss_mask).unsqueeze(1)
logits_norm = logits.sub(torch.max(logits, dim=-1)[0].unsqueeze(dim=-1)).float()
if logits_scale_factor != 1.0:
logits_norm *= logits_scale_factor
exp_logits = logits_norm.exp()
sum_exp_logits = exp_logits.sum(dim=-1)

if grad_output is None:
grad = None
else:
exp_logits = exp_logits.scatter(1, target, exp_logits.gather(1, target) - sum_exp_logits.unsqueeze(dim=-1))
# exp_logits[torch.arange(0, logits.size(0), device=logits.device), target.squeeze(dim=-1)]-=sum_exp_logits
exp_logits = exp_logits.mul((grad_output / logits.size(0)) / sum_exp_logits.unsqueeze(dim=-1))

if logits_scale_factor != 1.0:
exp_logits *= logits_scale_factor

grad = torch.where(loss_mask.unsqueeze(1), exp_logits.to(logits.dtype), 0)

per_sample_loss = sum_exp_logits.log().sub(logits_norm.gather(1, target).squeeze(1)) * loss_mask

return per_sample_loss.mean(), grad

logits_norm, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group)

@torch.compile
def parallel_cross_entropy_forward_backward(
logits: torch.Tensor,
target: torch.Tensor,
grad_output: float | None,
group: ProcessGroup,
logits_scale_factor: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
A fused implementation of cross-entropy with torch compile, with support for tensor parallelism.
Comes with a noticeable overhead, but reduces memory usage.
"""
# TODO: Compiled version incorrect for some inputs (32 bit indexing issue?).
# TODO: Optimize, overlap/combine reductions
loss_mask = target >= 0
target = target.unsqueeze(1)

logits_max = torch.max(logits, dim=-1)[0]
all_reduce(logits_max, op=ReduceOp.MAX, group=group)
logits_norm = logits.sub(logits_max.unsqueeze(dim=-1)).float()
if logits_scale_factor != 1.0:
logits_norm *= logits_scale_factor

exp_logits = logits_norm.exp()
sum_exp_logits = exp_logits.sum(dim=-1)
all_reduce(sum_exp_logits, op=ReduceOp.SUM, group=group)
if target_format == TargetFormat.logits:
target = fused_softmax(target, logits_scale_factor, group)

# Mask the target (fused)
# TODO: Could mask earlier on cpu or overlap with reduce?
vocab_start_index = logits.size(-1) * group.rank()
target_mask = (target >= vocab_start_index) * (target < vocab_start_index + logits.size(-1))
target = (target - vocab_start_index) * target_mask
if target_format == TargetFormat.labels:
target = target.unsqueeze(-1)
loss_mask = target >= 0
if group is None:
# Keep values within range for scatter and gather ops to work.
target = target * loss_mask
target_mask = None
else:
# Mask the target (fused)
# TODO: Could mask earlier on cpu or overlap with reduce?
vocab_start_index = logits.size(-1) * group.rank()
target_mask = (target >= vocab_start_index) * (target < vocab_start_index + logits.size(-1))
target = (target - vocab_start_index) * target_mask
else:
# TODO: Support masking
loss_mask = None
# Target should be tensor-parallel already, no further manipulation needed.
target_mask = None

if grad_output is None:
grad = None
else:
exp_logits1 = exp_logits.scatter(
1, target, exp_logits.gather(1, target) - target_mask * sum_exp_logits.unsqueeze(dim=-1)
)
exp_logits2 = exp_logits1.mul((grad_output / logits.size(0)) / sum_exp_logits.unsqueeze(dim=-1))
# grad / grad_output = exp_logits / sum_exp_logits - target_probabilities.
if target_format == TargetFormat.labels:
grad_base = exp_logits.scatter_add(
1, target, -sum_exp_logits if target_mask is None else -(target_mask * sum_exp_logits)
)
else:
grad_base = exp_logits - sum_exp_logits * target

grad = grad_base.mul((grad_output / logits.size(0)) / sum_exp_logits)
if logits_scale_factor != 1.0:
exp_logits2 *= logits_scale_factor
grad *= logits_scale_factor
grad = grad.to(logits.dtype)
if loss_mask is not None:
grad = torch.where(loss_mask, grad.to(logits.dtype), 0)

# loss = mean(log(sum_exp_logits) - sum(probabilities * logits))
if target_format == TargetFormat.labels:
predicted_logits = logits_norm.gather(1, target)
if group is not None:
predicted_logits = target_mask * predicted_logits
all_reduce(predicted_logits, op=ReduceOp.SUM, group=group)
else:
predicted_logits = (target * logits_norm).sum(dim=-1, keepdim=True)

grad = torch.where(loss_mask.unsqueeze(1), exp_logits2.to(logits.dtype), 0)
per_sample_loss = sum_exp_logits.log() - predicted_logits
if loss_mask is not None:
per_sample_loss = per_sample_loss * loss_mask

predicted_logits = (target_mask * logits_norm.gather(1, target)).squeeze(1)
all_reduce(predicted_logits, op=ReduceOp.SUM, group=group)
per_sample_loss = sum_exp_logits.log().sub(predicted_logits) * loss_mask
loss = per_sample_loss.mean()
if target_format != TargetFormat.labels and group is not None:
all_reduce(loss, op=ReduceOp.MEAN, group=group)

return per_sample_loss.mean(), grad
return loss, grad


_CROSS_ENTROPY_IMPLEMENTATIONS = {
Expand All @@ -134,25 +152,32 @@ def parallel_cross_entropy_forward_backward(


def cross_entropy_forward_backward(
logits,
target,
logits: torch.Tensor,
target: torch.Tensor,
grad_output: float | None,
group: ProcessGroup | None,
group: ProcessGroup | None = None,
implementation: CrossEntropyImpl = CrossEntropyImpl.fused,
logits_scale_factor: float = 1.0,
target_format: TargetFormat = TargetFormat.labels,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
Select the appropriate implementation of cross-entropy.
The triton implementation from the triton submodule is the fastest and recommended one.
It doesn't have a tensor-parallel implementation, but can be computed in a sequence-tensor-parallel way,
which is faster and has a relatively small memory overhead.
"""
if target_format == TargetFormat.labels:
Assert.eq(target.shape, logits.shape[:-1])
Assert.eq(target.dtype, torch.int64)
else:
Assert.eq(target.shape, logits.shape)
assert target.dtype.is_floating_point, target.dtype
if group:
Assert.eq(implementation, CrossEntropyImpl.fused)
return parallel_cross_entropy_forward_backward(
logits, target, grad_output, group, logits_scale_factor=logits_scale_factor
return fused_cross_entropy_forward_backward(
logits, target, grad_output, logits_scale_factor, target_format, group
)
else:
return _CROSS_ENTROPY_IMPLEMENTATIONS[implementation](
logits, target, grad_output, logits_scale_factor=logits_scale_factor
logits, target, grad_output, logits_scale_factor, target_format
)
Loading