Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions fast_llm/engine/checkpoint/safe_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def __enter__(self) -> "SafeLoad":
triton_fill(self_shard, math.nan)
# Reset and count shard pads
for _, fsdp, fsdp_shards in self._model.split_shards_by_fsdp(self._self_shards):
for fsdp_shard in fsdp_shards.values():
self._loaded += fsdp.reset_shard_pad(fsdp_shard)
for shard_name, fsdp_shard in fsdp_shards.items():
self._loaded += fsdp.reset_shard_pad(fsdp_shard, shard_name)
return self

def __exit__(self, exc_type, exc_val, exc_tb):
Expand Down
5 changes: 5 additions & 0 deletions fast_llm/engine/multi_stage/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@
logger = logging.getLogger(__name__)


class ShardName:
weights = "weights"
grads = "grads"


class StageMode(str, enum.Enum):
# Allow forward and backward passes and optimizer.
# TODO: Add mode for forward and backward but not optimizer?
Expand Down
14 changes: 11 additions & 3 deletions fast_llm/engine/multi_stage/fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from fast_llm.engine.config_utils.tensor_space import TensorDim
from fast_llm.engine.distributed.config import DistributedDim
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.engine.multi_stage.config import SHARD_PAD_TO_MULTIPLE, StageMode
from fast_llm.engine.multi_stage.config import SHARD_PAD_TO_MULTIPLE, ShardName, StageMode
from fast_llm.functional.triton.pointwise import triton_add, triton_copy
from fast_llm.logging import log_distributed_tensor
from fast_llm.tensor import ParameterMeta, SafeTensorSlice, TensorMeta
Expand Down Expand Up @@ -246,13 +246,14 @@ def setup(
)
self._parameter_buffers[parameter_name] = parameter_buffer

def reset_shard_pad(self, shard: torch.Tensor) -> int:
def reset_shard_pad(self, shard: torch.Tensor, shard_name: str) -> int:
assert self._is_setup
assert self._mode.on_device
# TODO: Needed?
# Prevent nans with the padded values
# Also ensures a correct parameter count in loading context.
self._weight_shard_meta.validate(shard)
shard_meta = self._weight_shard_meta if shard_name == ShardName.weights else self._grad_shard_meta
shard_meta.validate(shard)
if self._shard_pad > 0:
shard[-self._shard_pad :].zero_()
return self._shard_pad
Expand Down Expand Up @@ -452,5 +453,12 @@ def copy_shard_overlaps(
begin, end = self._parameter_range_in_shard(name)

for shard_name, shard in shards.items():
# Shards can be empty (frozen weights)
if shard.numel() == 0:
continue
if loaded_shards[shard_name].numel() == 0:
shard[begin:end][overlap_mask] = 0
counter += overlap_count
continue
shard[begin:end][overlap_mask] = loaded_shards[shard_name][overlap_index_map_masked]
counter += overlap_count
7 changes: 1 addition & 6 deletions fast_llm/engine/multi_stage/multi_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from fast_llm.engine.config_utils.tensor_space import TensorDim
from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode
from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName, StageMode
from fast_llm.engine.multi_stage.fsdp import FSDP
from fast_llm.engine.multi_stage.stage import Stage
from fast_llm.engine.optimizer.config import ParamGroup
Expand All @@ -24,11 +24,6 @@
logger = logging.getLogger(__name__)


class ShardName:
weights = "weights"
grads = "grads"


class MultiStageModel[ConfigType: FastLLMModelConfig](Configurable[ConfigType]):
config_class: typing.ClassVar[type[FastLLMModelConfig]] = FastLLMModelConfig
base_model_class: typing.ClassVar[type[BaseModel]] = BaseModel
Expand Down
2 changes: 1 addition & 1 deletion fast_llm/engine/multi_stage/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def reduce_gradients(self, accumulate=False) -> None:
level=self._config.debug_param_gradients,
global_=False,
)
if self._config.debug_all_param_gradients:
if self._config.debug_all_param_gradients and fsdp.requires_grad:
fsdp.log_shard(
name="gradient",
shard=fsdp.grad_shard,
Expand Down
4 changes: 2 additions & 2 deletions fast_llm/engine/multi_stage/stage_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from fast_llm.engine.config_utils.data_type import DataType
from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.engine.multi_stage.config import StageConfig, StageMode
from fast_llm.engine.multi_stage.config import ShardName, StageConfig, StageMode
from fast_llm.engine.multi_stage.fsdp import FSDP
from fast_llm.engine.optimizer.config import ParamGroup
from fast_llm.logging import log_generator
Expand Down Expand Up @@ -209,7 +209,7 @@ def initialize_weights(self) -> None:
meta.init_parameter(parameter, self._distributed)

if self.mode.on_device:
fsdp.reset_shard_pad(fsdp.weight_shard)
fsdp.reset_shard_pad(fsdp.weight_shard, ShardName.weights)

if self._config.debug_param_init:
log_generator("CPU generator after reset", torch.random.default_generator)
Expand Down
3 changes: 2 additions & 1 deletion tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ def run_test_script(
config: CompareConfig | None = None,
prepare_fn=None,
compare_fn=None,
do_compare: bool = True,
):
if torch.cuda.device_count() < num_gpus:
pytest.skip(f"Not enough GPUs to run test ({torch.cuda.device_count()}<{num_gpus})")
Expand Down Expand Up @@ -413,7 +414,7 @@ def run_test_script(
completed_proc = subprocess.run(command, env=env, timeout=60)
if completed_proc.returncode:
raise RuntimeError(f"Process failed with return code {completed_proc.returncode}")
if compare:
if compare and do_compare:
if compare_fn is not None:
compare_fn(TEST_RESULTS_PATH / name, TEST_RESULTS_PATH / compare)
compare_tensor_logs(
Expand Down
22 changes: 20 additions & 2 deletions tests/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
FastLLMCheckpointFormat,
ModelConfigType,
)
from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageMode
from fast_llm.engine.multi_stage.multi_stage import ShardName
from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName, StageMode
from fast_llm.models.auto import model_registry
from fast_llm.tools.convert import ConversionConfig
from tests.common import (
Expand Down Expand Up @@ -76,6 +75,7 @@ def _compare_resume_fn(test_path: pathlib.Path, compare_path: pathlib.Path):

@pytest.mark.depends(on=["test_checkpoint_and_eval"])
def test_resume():
# Resume from iteration=1 and compare outputs with the baseline run.
run_test_script(
f"test_{TEST_MODEL}_resume",
CONFIG_COMMON
Expand All @@ -90,6 +90,24 @@ def test_resume():
)


@pytest.mark.depends(on=["test_checkpoint_and_eval"])
def test_resume_frozen():
Comment thread
RaymondLi0 marked this conversation as resolved.
# Resume with frozen mlp. No comparison.
run_test_script(
f"test_{TEST_MODEL}_resume_frozen",
CONFIG_COMMON
+ [
"training.checkpoint.interval=1",
"training.evaluations.validation.interval=2",
"training.evaluations.validation.iterations=1",
"model.base_model.transformer.mlp_lr_scale=0.",
],
compare=f"test_{TEST_MODEL}_checkpoint_and_eval",
prepare_fn=_prepare_resume_fn,
do_compare=False,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not just removing compare?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the prepare_resume_fn uses compare to copy the checkpoints

)


def _run_conversion(config: ConversionConfig):
if config.output.path.is_dir() and not REUSE_RESULTS:
shutil.rmtree(config.output.path)
Expand Down