From 8847be8edd717796af6c4ee3adc3821587d003e1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 14 May 2026 16:39:01 -0400 Subject: [PATCH 1/2] Fix parallel test flakes --- fast_llm/data/preparation/gpt_memmap/prepare.py | 17 ++++++++--------- tests/conftest.py | 14 ++++++++++++-- tests/data/test_preparator.py | 10 ++++++++-- tests/data/test_streaming.py | 14 +++++++------- tests/models/test_streaming.py | 14 ++++++++------ 5 files changed, 43 insertions(+), 26 deletions(-) diff --git a/fast_llm/data/preparation/gpt_memmap/prepare.py b/fast_llm/data/preparation/gpt_memmap/prepare.py index 10cc15e06..410758147 100644 --- a/fast_llm/data/preparation/gpt_memmap/prepare.py +++ b/fast_llm/data/preparation/gpt_memmap/prepare.py @@ -11,8 +11,8 @@ import datasets import huggingface_hub +import huggingface_hub.utils import numpy as np -import requests import torch.distributed import tqdm import transformers @@ -81,27 +81,26 @@ def _load_dataset(self) -> datasets.Dataset: return dataset def _get_croissant_metadata(self): + url = f"https://huggingface.co/api/datasets/{self._config.dataset.path}/croissant" token = huggingface_hub.get_token() try: # Retrieve the dataset metadata in croissant format - url = f"https://huggingface.co/api/datasets/{self._config.dataset.path}/croissant" - if token is None: - response = requests.get(url) - else: - response = requests.get(url, headers={"Authorization": f"Bearer {token}"}) + headers = None if token is None else {"Authorization": f"Bearer {token}"} + response = huggingface_hub.utils.http_backoff("GET", url, headers=headers, timeout=10) if response.status_code != 200: logger.warning( - f"Failed to get croissant metadata, status_code: {response.status_code}, body: {response.text}" + f"Failed to get croissant metadata from {url}, " + f"status_code: {response.status_code}, body: {response.text}" ) return None data = response.json() except Exception as e: - logger.warning(f"Failed to get croissant metadata, {e}") + logger.warning(f"Failed to get croissant metadata from {url} after retries, {e}") return None if "error" in data: - logger.warning(f"Failed to get croissant metadata, error: {data['error']}") + logger.warning(f"Failed to get croissant metadata from {url}, error: {data['error']}") return None return data diff --git a/tests/conftest.py b/tests/conftest.py index 43f1fc65f..73b49b6cd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -81,11 +81,16 @@ def pytest_addoption(parser): class WorkerResources: torchrun_port: int rendezvous_port: int + data_streaming_port: int + model_streaming_port: int MAX_TEST_MEMORY = 5e9 CUDA_CONTEXT_SIZE = 7e8 TORCHRUN_DEFAULT_PORT = 25900 +PORTS_PER_WORKER = 32 +DATA_STREAMING_PORT_OFFSET = 2 +MODEL_STREAMING_PORT_OFFSET = 8 def pytest_configure(config): @@ -128,10 +133,15 @@ def pytest_configure(config): f"Please reduce the number of workers to {int(gpu_memory/(MAX_TEST_MEMORY + CUDA_CONTEXT_SIZE))*num_gpus} or less." ) + worker_port_base = TORCHRUN_DEFAULT_PORT + PORTS_PER_WORKER * worker_id config.worker_resources = WorkerResources( # Each worker needs its own set of ports for safe distributed run. Hopefully these are free. - torchrun_port=TORCHRUN_DEFAULT_PORT + 2 * worker_id, - rendezvous_port=TORCHRUN_DEFAULT_PORT + 2 * worker_id + 1, + torchrun_port=worker_port_base, + rendezvous_port=worker_port_base + 1, + data_streaming_port=worker_port_base + DATA_STREAMING_PORT_OFFSET, + # Model streaming uses a contiguous range starting here: one port per redis + # producer plus one port per weights broadcast rendezvous. + model_streaming_port=worker_port_base + MODEL_STREAMING_PORT_OFFSET, ) # Skip slow autotune for tests. The default config has the highest block size, so this shouldn't hide any bug. diff --git a/tests/data/test_preparator.py b/tests/data/test_preparator.py index 79db01b55..444d9a11a 100644 --- a/tests/data/test_preparator.py +++ b/tests/data/test_preparator.py @@ -167,6 +167,7 @@ def test_preparator_split_sharded(): def test_dataset_preparator_from_hub(): # TODO: Find or make a smaller dataset to speed things up. output_path = DATASET_CACHE / "preparator_from_hub" + expected_url = "https://huggingface.co/datasets/openai/gsm8k" preparator_config = GPTMemmapDatasetPreparatorConfig.from_dict( { "dataset": { @@ -181,8 +182,13 @@ def test_dataset_preparator_from_hub(): ) preparator_config.run() - assert (croissant_path := output_path / "croissant.json").is_file() - Assert.eq(json.load(croissant_path.open("r"))["url"], "https://huggingface.co/datasets/openai/gsm8k") + croissant_path = output_path / "croissant.json" + assert croissant_path.is_file(), ( + "Expected the preparator to fetch Croissant metadata from the Hugging Face Hub " + f"and save it to {croissant_path}. If this fails intermittently, check network/DNS " + f"and the availability of {expected_url}/tree/main or the Croissant API endpoint." + ) + Assert.eq(json.load(croissant_path.open("r"))["url"], expected_url) dataset = GPTDatasetFromFileConfig(path=output_path / "fast_llm_config.yaml").build() Assert.custom(isinstance, dataset, MemmapDataset) diff --git a/tests/data/test_streaming.py b/tests/data/test_streaming.py index d26ef4053..2aad6664c 100644 --- a/tests/data/test_streaming.py +++ b/tests/data/test_streaming.py @@ -56,7 +56,7 @@ def test_streaming_dataset( worker_resources: WorkerResources, ): """StreamingDataset should read a message and convert it into LanguageModelSample.""" - stream_config = StreamingDatasetConfig(port=worker_resources.torchrun_port, timeout=1) + stream_config = StreamingDatasetConfig(port=worker_resources.data_streaming_port, timeout=1) dataset_iterator = RedisStreamingDataset(stream_config).iterate(SamplingConfig(), len(documents), 0) documents = [document if isinstance(document, dict) else {"tokens": list(document)} for document in documents] for document in documents: @@ -126,7 +126,7 @@ def test_streaming_sampled_dataset( ): """StreamingDataset should read a message and convert it into LanguageModelSample.""" dataset_iterator = iter( - StreamingDatasetConfig(port=worker_resources.torchrun_port, timeout=1).build_and_sample( + StreamingDatasetConfig(port=worker_resources.data_streaming_port, timeout=1).build_and_sample( SamplingConfig(truncate_documents=False, micro_batch_size=5, predicted_tokens=0), 1, 0 ) ) @@ -161,13 +161,13 @@ def _get_distributed_config(distributed_config_dict: dict[str, typing.Any], worl def _run_test_data_streaming( - path: pathlib.Path, distributed_config: DistributedConfig, port: int, num_workers: int = 1 + path: pathlib.Path, distributed_config: DistributedConfig, redis_port: int, num_workers: int = 1 ): - redis_config = RedisConfig(port=port + 100, timeout=1) + redis_config = RedisConfig(port=redis_port, timeout=1) data = GPTData( GPTDataConfig( - datasets={"train": {"type": "streaming", "port": port + 100}}, + datasets={"train": {"type": "streaming", "port": redis_port}}, micro_batch_size=_SEQUENCE_LENGTH, truncate_documents=False, ), @@ -235,7 +235,7 @@ def _run_test_data_streaming_distributed( def test_data_streaming(data_result_path, worker_resources, num_workers): distributed_config = _get_distributed_config({}) path = data_result_path / f"data_streaming/single_gpu_workers_{num_workers}" - _run_test_data_streaming(path, distributed_config, worker_resources.torchrun_port, num_workers) + _run_test_data_streaming(path, distributed_config, worker_resources.data_streaming_port, num_workers) check_data_streaming_results(path, distributed_config) @@ -258,7 +258,7 @@ def test_data_streaming(data_result_path, worker_resources, num_workers): def test_run_data_streaming_distributed(run_parallel_script, data_result_path, worker_resources): run_parallel_script( _run_test_data_streaming_distributed, - (data_result_path / "data_streaming", worker_resources.torchrun_port), + (data_result_path / "data_streaming", worker_resources.data_streaming_port), world_size=4, backend=DistributedBackend.gloo, use_cuda=False, # Disable device count check. diff --git a/tests/models/test_streaming.py b/tests/models/test_streaming.py index e65c128f6..91dfcc432 100644 --- a/tests/models/test_streaming.py +++ b/tests/models/test_streaming.py @@ -131,13 +131,17 @@ def _run_event_consumer( def _run_model_streaming_configs( - test_context: DistributedTestContext, base_path: pathlib.Path, model_testing_config: ModelTestingConfig, port: int + test_context: DistributedTestContext, + base_path: pathlib.Path, + model_testing_config: ModelTestingConfig, + streaming_port: int, ) -> None: # Import all dynamic classes. import fast_llm.cli # noqa for config_index, config in enumerate(_DISTRIBUTED_STREAMING_CONFIGS): - config_port = port + config_index + config_port = streaming_port + config_index + broadcast_port = streaming_port + len(_DISTRIBUTED_STREAMING_CONFIGS) + config_index model_testing_config = update_and_add_testing_config( model_testing_config, None, @@ -149,7 +153,7 @@ def _run_model_streaming_configs( "type": "streaming", "port": config_port, "broadcast": { - "port": config_port + 1000, + "port": broadcast_port, "external_world_size": config.consumer_count, }, "export": {"format": model_testing_config.checkpoint_format.name}, @@ -203,11 +207,9 @@ def test_run_model_distributed_streaming( if torch.cuda.device_count() < 2: pytest.skip(f"Not enough GPUs") model_testing_config.get_dataset() - # Use a fixed shift to avoid port conflicts with other distributed tests. - port = worker_resources.torchrun_port + 4321 run_parallel_script( _run_model_streaming_configs, - (run_test_script_base_path, model_testing_config, port), + (run_test_script_base_path, model_testing_config, worker_resources.model_streaming_port), world_size=torch.cuda.device_count(), backend=model_testing_config.distributed_backend, ) From 9e96c153954740580940e796fc5c7f0da34f5a07 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 14 May 2026 17:17:48 -0400 Subject: [PATCH 2/2] Address self-review nits - Assert the model-streaming port range stays within PORTS_PER_WORKER so adding entries to _DISTRIBUTED_STREAMING_CONFIGS fails loudly instead of silently colliding with the next worker's range. - Collapse the Croissant fetch assertion message to a single line. - Rename the lingering `port` parameter on `_run_test_data_streaming_distributed` to `redis_port` for consistency with the helper it delegates to. Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/data/test_preparator.py | 6 +----- tests/data/test_streaming.py | 4 ++-- tests/models/test_streaming.py | 5 ++++- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/data/test_preparator.py b/tests/data/test_preparator.py index 444d9a11a..415d5f1af 100644 --- a/tests/data/test_preparator.py +++ b/tests/data/test_preparator.py @@ -183,11 +183,7 @@ def test_dataset_preparator_from_hub(): preparator_config.run() croissant_path = output_path / "croissant.json" - assert croissant_path.is_file(), ( - "Expected the preparator to fetch Croissant metadata from the Hugging Face Hub " - f"and save it to {croissant_path}. If this fails intermittently, check network/DNS " - f"and the availability of {expected_url}/tree/main or the Croissant API endpoint." - ) + assert croissant_path.is_file(), f"Croissant metadata not fetched from {expected_url}" Assert.eq(json.load(croissant_path.open("r"))["url"], expected_url) dataset = GPTDatasetFromFileConfig(path=output_path / "fast_llm_config.yaml").build() diff --git a/tests/data/test_streaming.py b/tests/data/test_streaming.py index 2aad6664c..e938de7d0 100644 --- a/tests/data/test_streaming.py +++ b/tests/data/test_streaming.py @@ -218,7 +218,7 @@ def check_data_streaming_results(path: pathlib.Path, distributed_config: Distrib def _run_test_data_streaming_distributed( - test_context: DistributedTestContext, base_path: pathlib.Path, port: int + test_context: DistributedTestContext, base_path: pathlib.Path, redis_port: int ) -> None: # Import all dynamic classes. TODO: needed? import fast_llm.cli # noqa @@ -228,7 +228,7 @@ def _run_test_data_streaming_distributed( logger.info(name, subtest.do_run) if subtest.do_run: distributed_config = _get_distributed_config(distributed_config_dict, num_gpus) - _run_test_data_streaming(base_path / name, distributed_config, port) + _run_test_data_streaming(base_path / name, distributed_config, redis_port) @pytest.mark.parametrize("num_workers", (0, 1)) diff --git a/tests/models/test_streaming.py b/tests/models/test_streaming.py index 91dfcc432..3a7e27f4e 100644 --- a/tests/models/test_streaming.py +++ b/tests/models/test_streaming.py @@ -16,7 +16,7 @@ from fast_llm.engine.training.config import StreamingTrainerCallbackConfig from fast_llm.engine.training.streaming import REDIS_TRAINING_FIELD, REDIS_TRAINING_STREAM from fast_llm.utils import Assert -from tests.conftest import WorkerResources +from tests.conftest import MODEL_STREAMING_PORT_OFFSET, PORTS_PER_WORKER, WorkerResources from tests.models.test_checkpoint import compare_safetensor_files from tests.utils.distributed_configs import DistributedTestingConfig from tests.utils.model_configs import ModelTestingConfig, ModelTestingGroup, update_and_add_testing_config @@ -59,6 +59,9 @@ def total_gpus(self) -> int: ), ] +# Each config consumes one redis producer port plus one weights-broadcast rendezvous port. +Assert.leq(MODEL_STREAMING_PORT_OFFSET + 2 * len(_DISTRIBUTED_STREAMING_CONFIGS), PORTS_PER_WORKER) + def _run_event_consumer( streaming_config: StreamingTrainerCallbackConfig, consumer_index: int, base_path: pathlib.Path