Skip to content
3 changes: 2 additions & 1 deletion fast_llm/engine/config_utils/tensor_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,9 @@ def __init__(self, distributed_config: DistributedConfig):
self.add_tensor_dim(TensorDim(DefaultDimNames.scalar, 1))

def setup(self, distributed: "Distributed") -> None:
assert distributed.config is self._distributed_config
assert not self._is_setup
if distributed.config is not self._distributed_config:
distributed.config.compare(self._distributed_config, ValueError)
self._is_setup = True
self._distributed = distributed

Expand Down
134 changes: 77 additions & 57 deletions fast_llm/engine/distributed/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,12 @@ class DistributedConfig(Config):
desc="Ensure the initialization is the same for any distributed configuration.",
hint=FieldHint.testing,
)
reference_config: "DistributedConfig|None" = Field(
default=None,
init=False,
desc="Pointer to the distributed config this one is an identical copy of.",
hint=FieldHint.derived,
)

def _validate(self) -> None:
if self.world_size is None:
Expand Down Expand Up @@ -281,76 +287,90 @@ def _validate(self) -> None:
if self.tensor_parallel == 1:
self.sequence_tensor_parallel = False

self.distributed_dims = {}
if self.reference_config is not None:
self.reference_config.validate()
if self.reference_config.reference_config is not None:
self.reference_config = self.reference_config.reference_config
assert self.reference_config.reference_config is None
self.compare(self.reference_config, ValueError)
self.distributed_dims = self.reference_config.distributed_dims
else:
self.distributed_dims = {}

self.add_distributed_dim(
DistributedDim(name=DistributedDimNames.world, size=self.world_size, rank=self.rank, id_=None, parent=None)
)
self.add_distributed_dim(
DistributedDim(
name=DistributedDimNames.data,
size=self.data_parallel,
rank=self.data_rank,
id_=f"x_{self.pipeline_rank}_{self.tensor_rank}",
parent=DistributedDimNames.world,
self._add_distributed_dim(
DistributedDim(
name=DistributedDimNames.world, size=self.world_size, rank=self.rank, id_=None, parent=None
)
)
)
self.add_distributed_dim(
DistributedDim(
name=DistributedDimNames.pipeline,
size=self.pipeline_parallel,
rank=self.pipeline_rank,
id_=f"x_{self.data_rank}_{self.tensor_rank}",
parent=DistributedDimNames.world,
self._add_distributed_dim(
DistributedDim(
name=DistributedDimNames.data,
size=self.data_parallel,
rank=self.data_rank,
id_=f"x_{self.pipeline_rank}_{self.tensor_rank}",
parent=DistributedDimNames.world,
)
)
)
self.add_distributed_dim(
DistributedDim(
name=DistributedDimNames.tensor,
size=self.tensor_parallel,
rank=self.tensor_rank,
id_=f"x_{self.data_rank}_{self.pipeline_rank}",
parent=DistributedDimNames.world,
self._add_distributed_dim(
DistributedDim(
name=DistributedDimNames.pipeline,
size=self.pipeline_parallel,
rank=self.pipeline_rank,
id_=f"x_{self.data_rank}_{self.tensor_rank}",
parent=DistributedDimNames.world,
)
)
)
self.add_distributed_dim(
DistributedDim(
name=DistributedDimNames.sequence_data,
size=self.sequence_data_parallel,
rank=self.sequence_data_rank,
id_=f"{self.batch_data_rank}_{self.pipeline_rank}_{self.tensor_rank}",
parent=DistributedDimNames.data,
self._add_distributed_dim(
DistributedDim(
name=DistributedDimNames.tensor,
size=self.tensor_parallel,
rank=self.tensor_rank,
id_=f"x_{self.data_rank}_{self.pipeline_rank}",
parent=DistributedDimNames.world,
)
)
)
self.add_distributed_dim(
DistributedDim(
name=DistributedDimNames.batch_data,
size=self.batch_data_parallel,
rank=self.batch_data_rank,
id_=f"{self.sequence_data_rank}_{self.pipeline_rank}_{self.tensor_rank}",
parent=DistributedDimNames.data,
self._add_distributed_dim(
DistributedDim(
name=DistributedDimNames.sequence_data,
size=self.sequence_data_parallel,
rank=self.sequence_data_rank,
id_=f"{self.batch_data_rank}_{self.pipeline_rank}_{self.tensor_rank}",
parent=DistributedDimNames.data,
)
)
)
self.add_distributed_dim(
DistributedDim(
name=DistributedDimNames.tensor_and_sequence_data,
size=self.sequence_data_parallel * self.tensor_parallel,
rank=self.tensor_rank + self.sequence_data_rank * self.tensor_parallel,
id_=f"{self.batch_data_rank}_{self.pipeline_rank}",
parent=(
DistributedDimNames.tensor
if self.sequence_data_parallel == 1
else DistributedDimNames.sequence_data if self.tensor_parallel == 1 else DistributedDimNames.world
),
self._add_distributed_dim(
DistributedDim(
name=DistributedDimNames.batch_data,
size=self.batch_data_parallel,
rank=self.batch_data_rank,
id_=f"{self.sequence_data_rank}_{self.pipeline_rank}_{self.tensor_rank}",
parent=DistributedDimNames.data,
)
)
self._add_distributed_dim(
DistributedDim(
name=DistributedDimNames.tensor_and_sequence_data,
size=self.sequence_data_parallel * self.tensor_parallel,
rank=self.tensor_rank + self.sequence_data_rank * self.tensor_parallel,
id_=f"{self.batch_data_rank}_{self.pipeline_rank}",
parent=(
DistributedDimNames.tensor
if self.sequence_data_parallel == 1
else (
DistributedDimNames.sequence_data
if self.tensor_parallel == 1
else DistributedDimNames.world
)
),
)
)
)

super()._validate()

Assert.in_range(self.rank, 0, self.world_size)
Assert.in_range(self.local_rank, 0, self.local_world_size)

def add_distributed_dim(self, distributed_dim: DistributedDim) -> None:
def _add_distributed_dim(self, distributed_dim: DistributedDim) -> None:
if distributed_dim.name in self.distributed_dims:
Assert.eq(distributed_dim, self.distributed_dims[distributed_dim.name])
else:
Expand Down
8 changes: 8 additions & 0 deletions fast_llm/engine/distributed/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class Distributed[ConfigType: DistributedConfig](Configurable[ConfigType]):

def __init__(self, config: DistributedConfig, use_cpu: bool = False):
super().__init__(config)
assert self._config.reference_config is None
self._use_cpu = use_cpu

if self._use_cpu:
Expand Down Expand Up @@ -148,6 +149,13 @@ def add_group(self, distributed_dim: DistributedDim) -> ProcessGroup | None:
distributed_dim.setup(group)
return group

def check_config(self, config: DistributedConfig) -> None:
# Allows using this `Distributed` on a model with a distributed config that is a copy of `self._config`
if config.reference_config is None:
Assert.is_(config, self._config)
else:
Assert.is_(config.reference_config, self._config)

def set_step(self, step: int, phase: PhaseType) -> None:
"""
Reseed pytorch for a given training step.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,57 +4,38 @@

import transformers.modeling_outputs

from fast_llm.config import NoAutoValidate
from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, FastLLMCheckpointFormat
from fast_llm.engine.distributed.config import PhaseType
from fast_llm.engine.huggingface.config import HuggingfaceModelConfig
from fast_llm.engine.inference.config import HuggingfaceModelConfig
from fast_llm.engine.inference.runner import InferenceRunner
from fast_llm.engine.multi_stage.config import StageMode
from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
from fast_llm.engine.schedule.config import BatchConfig, ScheduleConfig
from fast_llm.engine.schedule.runner import ScheduleRunner
from fast_llm.engine.schedule.schedule import Schedule


class HuggingfacePreTrainedModel(transformers.PreTrainedModel):
config_class: typing.ClassVar[type[HuggingfaceModelConfig]] = HuggingfaceModelConfig
model_class: typing.ClassVar[type[FastLLMModel]] = FastLLMModel
runner_class: typing.ClassVar[type[InferenceRunner]] = InferenceRunner
config: HuggingfaceModelConfig
# base_model_prefix = ""
# _no_split_modules = None
# _supports_cache_class = False
# _tied_weights_keys = []

def __init__(self, config: HuggingfaceModelConfig, fast_llm_model: FastLLMModel, **kwargs):
assert self.model_class.config_class is config.model_config_class
assert self.runner_class.model_class.config_class is config.model_config_class
assert config.fast_llm_config is fast_llm_model.config
assert isinstance(config, self.config_class)

super().__init__(config, **kwargs)
self._fast_llm_config = config.fast_llm_config
self._fast_llm_model = fast_llm_model

self._inference_runner = self.runner_class(fast_llm_model)
if not fast_llm_model.is_setup:
fast_llm_model.setup(mode=StageMode.inference)
self._inference_runner.setup()
# Transformers needs to be able to inspect the base model.
self.fast_llm_base_model = self._fast_llm_model.base_model
self._distributed_config = self._fast_llm_config.distributed
self.fast_llm_base_model = fast_llm_model.base_model
# TODO: Support distributed models?
assert self._distributed_config.world_size == 1
self._schedule_config = ScheduleConfig()
# We only need a basic schedule and don't care about dimensions.
# TODO: Sort things out.
with NoAutoValidate():
self._batch_config = BatchConfig()
self._batch_config.setup(self._distributed_config)
self._batch_config.validate()
self._runner = ScheduleRunner(
config=self._schedule_config, multi_stage=self._fast_llm_model, distributed_config=self._distributed_config
)
self._runner.setup(self._fast_llm_model.distributed)
# TODO: Random state? (Distributed.set_step)
self._schedule = Schedule(
multi_stage=self._fast_llm_model,
batch_config=self._batch_config,
schedule_config=self._schedule_config,
distributed_config=self._distributed_config,
phase=PhaseType.inference,
)
assert fast_llm_model.config.distributed.world_size == 1
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like us to work on this in a separate PR soon, because this is needed for running generative benchmarks during training.


with transformers.modeling_utils.no_init_weights():
self.post_init()

Expand All @@ -79,7 +60,7 @@ def from_pretrained(
config_updates[("distributed", "training_dtype")] = torch_dtype

# Create the model
fast_llm_model = cls.model_class.from_pretrained(
fast_llm_model = cls.runner_class.model_class.from_pretrained(
pretrained_model_name_or_path, config_updates=config_updates, mode=mode
)
config = cls.config_class(fast_llm_model.config)
Expand Down
58 changes: 58 additions & 0 deletions fast_llm/engine/inference/runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import abc
import typing

from fast_llm.config import NoAutoValidate
from fast_llm.engine.distributed.config import PhaseType
from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel
from fast_llm.engine.schedule.config import BatchConfig, ScheduleConfig
from fast_llm.engine.schedule.runner import ScheduleRunner
from fast_llm.engine.schedule.schedule import Schedule


class InferenceRunner(abc.ABC):
model_class: typing.ClassVar[type[FastLLMModel]] = FastLLMModel

def __init__(self, fast_llm_model: FastLLMModel):
assert isinstance(fast_llm_model, self.model_class)
self._fast_llm_model = fast_llm_model
# We only need a basic schedule and don't care about dimensions.
self._schedule_config = ScheduleConfig()
# TODO: Sort things out.
with NoAutoValidate():
self._batch_config = BatchConfig()
self._batch_config.setup(self._fast_llm_model.config.distributed)
self._batch_config.validate()
self._runner = ScheduleRunner(
config=self._schedule_config,
multi_stage=self._fast_llm_model,
distributed_config=self._fast_llm_model.config.distributed,
)
# TODO: Random state? (Distributed.set_step)
self._schedule = Schedule(
multi_stage=self._fast_llm_model,
batch_config=self._batch_config,
schedule_config=self._schedule_config,
distributed_config=self._fast_llm_model.config.distributed,
phase=PhaseType.inference,
)

@property
def fast_llm_model(self) -> FastLLMModel:
return self._fast_llm_model

def setup(self):
self._runner.setup(self._fast_llm_model.distributed)

def forward(
self, input_, kwargs: dict, *, iteration: int = 1, return_metrics: bool = False
) -> tuple[dict[str, float | int], dict[str, typing.Any] | None]:
# TODO: Return an actual model output.
reduced_losses, update_successful, metrics = self._runner.run_step(
iter((((input_, kwargs),),)),
self._schedule,
iteration=iteration,
return_metrics=return_metrics,
preprocessed=True,
)
assert update_successful
return reduced_losses, metrics
2 changes: 1 addition & 1 deletion fast_llm/engine/multi_stage/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from fast_llm.utils import Assert

if typing.TYPE_CHECKING:
from fast_llm.engine.huggingface.model import HuggingfacePreTrainedModel
from fast_llm.engine.inference.model import HuggingfacePreTrainedModel
from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel

logger = logging.getLogger(__name__)
Expand Down
12 changes: 10 additions & 2 deletions fast_llm/engine/multi_stage/multi_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,12 +209,15 @@ def __init__(
"Bfloat16 gradient accumulation and reduction is not recommended. (use --full_precision_gradients=1)"
)

def setup(self, distributed: Distributed, mode: StageMode = StageMode.training) -> None:
def setup(self, distributed: Distributed | None = None, mode: StageMode = StageMode.training) -> None:
# TODO: More checks?
stage: Stage
assert distributed.config is self._config.distributed
assert not self._is_setup
self._is_setup = True
if distributed is None:
distributed = Distributed(self._config.distributed)
else:
distributed.check_config(self._config.distributed)
self._distributed = distributed
self._mode = mode
self._base_model.setup(distributed)
Expand Down Expand Up @@ -381,6 +384,10 @@ def get_shard(self, name: str) -> torch.Tensor:
raise KeyError(f"Unknown shard name {name}")
return self._shards[name]

@property
def is_setup(self) -> bool:
return self._is_setup

@property
def support_forward(self) -> bool:
assert self._is_setup
Expand Down Expand Up @@ -442,6 +449,7 @@ def is_parameter_on_device(self, parameter_name: str) -> bool:

@property
def distributed(self) -> Distributed:
assert self._is_setup
return self._distributed

def invalidate_buffers(self) -> None:
Expand Down
2 changes: 1 addition & 1 deletion fast_llm/engine/multi_stage/stage_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def setup(
mode: StageMode = StageMode.training,
) -> None:
assert not self._is_setup
assert distributed.config is self._distributed_config
distributed.check_config(self._distributed_config)
self._mode = mode
self._is_setup = True
self._distributed = distributed
Expand Down
2 changes: 1 addition & 1 deletion fast_llm/engine/schedule/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __init__(

def setup(self, distributed: Distributed, optimizer: Optimizer | None = None) -> None:
assert not self._is_setup
assert distributed.config is self._distributed_config
distributed.check_config(self._distributed_config)
self._is_setup = True
self._optimizer = optimizer
assert self._multi_stage.support_forward
Expand Down
Loading