Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions fast_llm/data/preparation/gpt_memmap/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

import datasets
import huggingface_hub
import huggingface_hub.utils
import numpy as np
import requests
import torch.distributed
import tqdm
import transformers
Expand Down Expand Up @@ -81,27 +81,26 @@ def _load_dataset(self) -> datasets.Dataset:
return dataset

def _get_croissant_metadata(self):
url = f"https://huggingface.co/api/datasets/{self._config.dataset.path}/croissant"
token = huggingface_hub.get_token()
try:
# Retrieve the dataset metadata in croissant format
url = f"https://huggingface.co/api/datasets/{self._config.dataset.path}/croissant"
if token is None:
response = requests.get(url)
else:
response = requests.get(url, headers={"Authorization": f"Bearer {token}"})
headers = None if token is None else {"Authorization": f"Bearer {token}"}
response = huggingface_hub.utils.http_backoff("GET", url, headers=headers, timeout=10)

if response.status_code != 200:
logger.warning(
f"Failed to get croissant metadata, status_code: {response.status_code}, body: {response.text}"
f"Failed to get croissant metadata from {url}, "
f"status_code: {response.status_code}, body: {response.text}"
)
return None

data = response.json()
except Exception as e:
logger.warning(f"Failed to get croissant metadata, {e}")
logger.warning(f"Failed to get croissant metadata from {url} after retries, {e}")
return None
if "error" in data:
logger.warning(f"Failed to get croissant metadata, error: {data['error']}")
logger.warning(f"Failed to get croissant metadata from {url}, error: {data['error']}")
return None

return data
Expand Down
14 changes: 12 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,16 @@ def pytest_addoption(parser):
class WorkerResources:
torchrun_port: int
rendezvous_port: int
data_streaming_port: int
model_streaming_port: int


MAX_TEST_MEMORY = 5e9
CUDA_CONTEXT_SIZE = 7e8
TORCHRUN_DEFAULT_PORT = 25900
PORTS_PER_WORKER = 32
DATA_STREAMING_PORT_OFFSET = 2
MODEL_STREAMING_PORT_OFFSET = 8


def pytest_configure(config):
Expand Down Expand Up @@ -128,10 +133,15 @@ def pytest_configure(config):
f"Please reduce the number of workers to {int(gpu_memory/(MAX_TEST_MEMORY + CUDA_CONTEXT_SIZE))*num_gpus} or less."
)

worker_port_base = TORCHRUN_DEFAULT_PORT + PORTS_PER_WORKER * worker_id
config.worker_resources = WorkerResources(
# Each worker needs its own set of ports for safe distributed run. Hopefully these are free.
torchrun_port=TORCHRUN_DEFAULT_PORT + 2 * worker_id,
rendezvous_port=TORCHRUN_DEFAULT_PORT + 2 * worker_id + 1,
torchrun_port=worker_port_base,
rendezvous_port=worker_port_base + 1,
data_streaming_port=worker_port_base + DATA_STREAMING_PORT_OFFSET,
# Model streaming uses a contiguous range starting here: one port per redis
# producer plus one port per weights broadcast rendezvous.
model_streaming_port=worker_port_base + MODEL_STREAMING_PORT_OFFSET,
)

# Skip slow autotune for tests. The default config has the highest block size, so this shouldn't hide any bug.
Expand Down
6 changes: 4 additions & 2 deletions tests/data/test_preparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def test_preparator_split_sharded():
def test_dataset_preparator_from_hub():
# TODO: Find or make a smaller dataset to speed things up.
output_path = DATASET_CACHE / "preparator_from_hub"
expected_url = "https://huggingface.co/datasets/openai/gsm8k"
preparator_config = GPTMemmapDatasetPreparatorConfig.from_dict(
{
"dataset": {
Expand All @@ -181,8 +182,9 @@ def test_dataset_preparator_from_hub():
)
preparator_config.run()

assert (croissant_path := output_path / "croissant.json").is_file()
Assert.eq(json.load(croissant_path.open("r"))["url"], "https://huggingface.co/datasets/openai/gsm8k")
croissant_path = output_path / "croissant.json"
assert croissant_path.is_file(), f"Croissant metadata not fetched from {expected_url}"
Assert.eq(json.load(croissant_path.open("r"))["url"], expected_url)

dataset = GPTDatasetFromFileConfig(path=output_path / "fast_llm_config.yaml").build()
Assert.custom(isinstance, dataset, MemmapDataset)
Expand Down
18 changes: 9 additions & 9 deletions tests/data/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_streaming_dataset(
worker_resources: WorkerResources,
):
"""StreamingDataset should read a message and convert it into LanguageModelSample."""
stream_config = StreamingDatasetConfig(port=worker_resources.torchrun_port, timeout=1)
stream_config = StreamingDatasetConfig(port=worker_resources.data_streaming_port, timeout=1)
dataset_iterator = RedisStreamingDataset(stream_config).iterate(SamplingConfig(), len(documents), 0)
documents = [document if isinstance(document, dict) else {"tokens": list(document)} for document in documents]
for document in documents:
Expand Down Expand Up @@ -126,7 +126,7 @@ def test_streaming_sampled_dataset(
):
"""StreamingDataset should read a message and convert it into LanguageModelSample."""
dataset_iterator = iter(
StreamingDatasetConfig(port=worker_resources.torchrun_port, timeout=1).build_and_sample(
StreamingDatasetConfig(port=worker_resources.data_streaming_port, timeout=1).build_and_sample(
SamplingConfig(truncate_documents=False, micro_batch_size=5, predicted_tokens=0), 1, 0
)
)
Expand Down Expand Up @@ -161,13 +161,13 @@ def _get_distributed_config(distributed_config_dict: dict[str, typing.Any], worl


def _run_test_data_streaming(
path: pathlib.Path, distributed_config: DistributedConfig, port: int, num_workers: int = 1
path: pathlib.Path, distributed_config: DistributedConfig, redis_port: int, num_workers: int = 1
):
redis_config = RedisConfig(port=port + 100, timeout=1)
redis_config = RedisConfig(port=redis_port, timeout=1)

data = GPTData(
GPTDataConfig(
datasets={"train": {"type": "streaming", "port": port + 100}},
datasets={"train": {"type": "streaming", "port": redis_port}},
micro_batch_size=_SEQUENCE_LENGTH,
truncate_documents=False,
),
Expand Down Expand Up @@ -218,7 +218,7 @@ def check_data_streaming_results(path: pathlib.Path, distributed_config: Distrib


def _run_test_data_streaming_distributed(
test_context: DistributedTestContext, base_path: pathlib.Path, port: int
test_context: DistributedTestContext, base_path: pathlib.Path, redis_port: int
) -> None:
# Import all dynamic classes. TODO: needed?
import fast_llm.cli # noqa
Expand All @@ -228,14 +228,14 @@ def _run_test_data_streaming_distributed(
logger.info(name, subtest.do_run)
if subtest.do_run:
distributed_config = _get_distributed_config(distributed_config_dict, num_gpus)
_run_test_data_streaming(base_path / name, distributed_config, port)
_run_test_data_streaming(base_path / name, distributed_config, redis_port)


@pytest.mark.parametrize("num_workers", (0, 1))
def test_data_streaming(data_result_path, worker_resources, num_workers):
distributed_config = _get_distributed_config({})
path = data_result_path / f"data_streaming/single_gpu_workers_{num_workers}"
_run_test_data_streaming(path, distributed_config, worker_resources.torchrun_port, num_workers)
_run_test_data_streaming(path, distributed_config, worker_resources.data_streaming_port, num_workers)
check_data_streaming_results(path, distributed_config)


Expand All @@ -258,7 +258,7 @@ def test_data_streaming(data_result_path, worker_resources, num_workers):
def test_run_data_streaming_distributed(run_parallel_script, data_result_path, worker_resources):
run_parallel_script(
_run_test_data_streaming_distributed,
(data_result_path / "data_streaming", worker_resources.torchrun_port),
(data_result_path / "data_streaming", worker_resources.data_streaming_port),
world_size=4,
backend=DistributedBackend.gloo,
use_cuda=False, # Disable device count check.
Expand Down
19 changes: 12 additions & 7 deletions tests/models/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from fast_llm.engine.training.config import StreamingTrainerCallbackConfig
from fast_llm.engine.training.streaming import REDIS_TRAINING_FIELD, REDIS_TRAINING_STREAM
from fast_llm.utils import Assert
from tests.conftest import WorkerResources
from tests.conftest import MODEL_STREAMING_PORT_OFFSET, PORTS_PER_WORKER, WorkerResources
from tests.models.test_checkpoint import compare_safetensor_files
from tests.utils.distributed_configs import DistributedTestingConfig
from tests.utils.model_configs import ModelTestingConfig, ModelTestingGroup, update_and_add_testing_config
Expand Down Expand Up @@ -59,6 +59,9 @@ def total_gpus(self) -> int:
),
]

# Each config consumes one redis producer port plus one weights-broadcast rendezvous port.
Assert.leq(MODEL_STREAMING_PORT_OFFSET + 2 * len(_DISTRIBUTED_STREAMING_CONFIGS), PORTS_PER_WORKER)


def _run_event_consumer(
streaming_config: StreamingTrainerCallbackConfig, consumer_index: int, base_path: pathlib.Path
Expand Down Expand Up @@ -131,13 +134,17 @@ def _run_event_consumer(


def _run_model_streaming_configs(
test_context: DistributedTestContext, base_path: pathlib.Path, model_testing_config: ModelTestingConfig, port: int
test_context: DistributedTestContext,
base_path: pathlib.Path,
model_testing_config: ModelTestingConfig,
streaming_port: int,
) -> None:
# Import all dynamic classes.
import fast_llm.cli # noqa

for config_index, config in enumerate(_DISTRIBUTED_STREAMING_CONFIGS):
config_port = port + config_index
config_port = streaming_port + config_index
broadcast_port = streaming_port + len(_DISTRIBUTED_STREAMING_CONFIGS) + config_index
model_testing_config = update_and_add_testing_config(
model_testing_config,
None,
Expand All @@ -149,7 +156,7 @@ def _run_model_streaming_configs(
"type": "streaming",
"port": config_port,
"broadcast": {
"port": config_port + 1000,
"port": broadcast_port,
"external_world_size": config.consumer_count,
},
"export": {"format": model_testing_config.checkpoint_format.name},
Expand Down Expand Up @@ -203,11 +210,9 @@ def test_run_model_distributed_streaming(
if torch.cuda.device_count() < 2:
pytest.skip(f"Not enough GPUs")
model_testing_config.get_dataset()
# Use a fixed shift to avoid port conflicts with other distributed tests.
port = worker_resources.torchrun_port + 4321
run_parallel_script(
_run_model_streaming_configs,
(run_test_script_base_path, model_testing_config, port),
(run_test_script_base_path, model_testing_config, worker_resources.model_streaming_port),
world_size=torch.cuda.device_count(),
backend=model_testing_config.distributed_backend,
)
Expand Down
Loading