From 18f807aa5a77cf4729e00aa6ccce56e39580c925 Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Tue, 17 Mar 2026 21:14:11 +0000 Subject: [PATCH 01/23] feat: wire async task-queue scheduler into ColumnWiseDatasetBuilder --- .../src/data_designer/config/run_config.py | 1 + .../dataset_builders/async_scheduler.py | 52 ++++ .../dataset_builders/column_wise_builder.py | 155 +++++++++-- .../utils/row_group_buffer.py | 12 + .../test_async_builder_integration.py | 257 ++++++++++++++++++ .../dataset_builders/test_async_scheduler.py | 156 ++++++++++- .../data_designer/interface/data_designer.py | 3 + .../src/data_designer/interface/results.py | 3 + 8 files changed, 620 insertions(+), 19 deletions(-) create mode 100644 packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py diff --git a/packages/data-designer-config/src/data_designer/config/run_config.py b/packages/data-designer-config/src/data_designer/config/run_config.py index 03b2ed297..37799fae6 100644 --- a/packages/data-designer-config/src/data_designer/config/run_config.py +++ b/packages/data-designer-config/src/data_designer/config/run_config.py @@ -42,6 +42,7 @@ class RunConfig(ConfigBase): non_inference_max_parallel_workers: int = Field(default=4, ge=1) max_conversation_restarts: int = Field(default=5, ge=0) max_conversation_correction_steps: int = Field(default=0, ge=0) + async_trace: bool = False @model_validator(mode="after") def normalize_shutdown_settings(self) -> Self: diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py index 922952d5b..ea55ca601 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py @@ -46,6 +46,11 @@ def __init__( max_submitted_tasks: int = 256, salvage_max_rounds: int = 2, on_row_group_complete: Callable[[int], None] | None = None, + on_seeds_complete: Callable[[int, int], None] | None = None, + on_before_checkpoint: Callable[[int, int], None] | None = None, + shutdown_error_rate: float = 0.5, + shutdown_error_window: int = 10, + disable_early_shutdown: bool = False, trace: bool = False, ) -> None: self._generators = generators @@ -62,6 +67,13 @@ def __init__( self._wake_event = asyncio.Event() self._salvage_max_rounds = salvage_max_rounds self._on_row_group_complete = on_row_group_complete + self._on_seeds_complete = on_seeds_complete + self._on_before_checkpoint = on_before_checkpoint + + # Error rate shutdown + self._shutdown_error_rate = 1.0 if disable_early_shutdown else shutdown_error_rate + self._shutdown_error_window = shutdown_error_window + self._early_shutdown = False # Multi-column dedup: group output columns by generator identity instance_to_columns: dict[int, list[str]] = {} @@ -81,6 +93,8 @@ def __init__( # Active row groups (admitted but not yet checkpointed) self._active_rgs: list[tuple[int, int]] = [] self._admitted_rg_ids: set[int] = set() + self._seeds_dispatched_rgs: set[int] = set() + self._pre_batch_done_rgs: set[int] = set() # Tracing self._trace = trace @@ -118,8 +132,29 @@ async def run(self) -> None: # Main dispatch loop while True: + if self._early_shutdown: + logger.warning("Early shutdown triggered — error rate exceeded threshold") + break + self._wake_event.clear() + # Run pre-batch callbacks for row groups whose seeds just completed + for rg_id, rg_size in self._active_rgs: + if rg_id in self._seeds_dispatched_rgs and rg_id not in self._pre_batch_done_rgs: + seed_cols = {c for c in all_columns if not self._graph.get_upstream_columns(c)} + all_seeds_done = all(col in self._tracker._completed.get(rg_id, {}) for col in seed_cols) + if all_seeds_done and not self._in_flight_for_rg(rg_id): + self._pre_batch_done_rgs.add(rg_id) + if self._on_seeds_complete: + try: + self._on_seeds_complete(rg_id, rg_size) + except Exception as exc: + logger.warning(f"Pre-batch processor failed for row group {rg_id}, skipping: {exc}") + for ri in range(rg_size): + self._tracker.drop_row(rg_id, ri) + if self._buffer_manager: + self._buffer_manager.drop_row(rg_id, ri) + ready = self._tracker.get_ready_tasks(self._dispatched, self._admitted_rg_ids) for task in ready: await self._submission_semaphore.acquire() @@ -224,6 +259,8 @@ def _checkpoint_completed_row_groups(self, all_columns: list[str]) -> None: for rg_id, rg_size in completed: self._active_rgs.remove((rg_id, rg_size)) try: + if self._on_before_checkpoint: + self._on_before_checkpoint(rg_id, rg_size) if self._buffer_manager is not None: self._buffer_manager.checkpoint_row_group(rg_id) if self._on_row_group_complete: @@ -233,8 +270,22 @@ def _checkpoint_completed_row_groups(self, all_columns: list[str]) -> None: finally: self._rg_semaphore.release() + def _in_flight_for_rg(self, rg_id: int) -> bool: + """Check if any tasks are in-flight for a given row group.""" + return any(t.row_group == rg_id for t in self._in_flight) + + def _check_error_rate(self) -> None: + """Trigger early shutdown if error rate exceeds threshold.""" + completed = self._success_count + self._error_count + if completed < self._shutdown_error_window: + return + error_rate = self._error_count / max(1, completed) + if error_rate >= self._shutdown_error_rate: + self._early_shutdown = True + async def _dispatch_seeds(self, rg_id: int, rg_size: int) -> None: """Dispatch from_scratch tasks for a row group.""" + self._seeds_dispatched_rgs.add(rg_id) seed_cols = [col for col in self._graph.get_topological_order() if not self._graph.get_upstream_columns(col)] seen_instances: set[int] = set() @@ -330,6 +381,7 @@ async def _execute_task_inner(self, task: Task) -> None: except Exception as exc: self._error_count += 1 + self._check_error_rate() if self._trace and trace: trace.status = "error" trace.error = str(exc) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py index c76977688..2eac85278 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -33,7 +33,7 @@ from data_designer.engine.dataset_builders.utils.concurrency import ConcurrentThreadExecutor from data_designer.engine.dataset_builders.utils.config_compiler import compile_dataset_builder_column_configs from data_designer.engine.dataset_builders.utils.dataset_batch_manager import DatasetBatchManager -from data_designer.engine.dataset_builders.utils.processor_runner import ProcessorRunner +from data_designer.engine.dataset_builders.utils.processor_runner import ProcessorRunner, ProcessorStage from data_designer.engine.dataset_builders.utils.progress_tracker import ProgressTracker from data_designer.engine.models.telemetry import InferenceEvent, NemoSourceEnum, TaskStatusEnum, TelemetryHandler from data_designer.engine.processing.processors.base import Processor @@ -47,6 +47,7 @@ import pandas as pd from data_designer.engine.column_generators.generators.base import ColumnGeneratorWithModelRegistry + from data_designer.engine.dataset_builders.utils.task_model import TaskTrace from data_designer.engine.models.usage import ModelUsageStats logger = logging.getLogger(__name__) @@ -54,6 +55,7 @@ DATA_DESIGNER_ASYNC_ENGINE = os.environ.get("DATA_DESIGNER_ASYNC_ENGINE", "0") == "1" if DATA_DESIGNER_ASYNC_ENGINE: + import asyncio import sys if sys.version_info < (3, 11): @@ -61,9 +63,16 @@ "DATA_DESIGNER_ASYNC_ENGINE requires Python 3.11+ (asyncio.TaskGroup). " f"Current version: {sys.version_info.major}.{sys.version_info.minor}" ) - from data_designer.engine.dataset_builders.utils.async_concurrency import AsyncConcurrentExecutor + from data_designer.engine.dataset_builders.async_scheduler import AsyncTaskScheduler + from data_designer.engine.dataset_builders.utils.async_concurrency import ( + AsyncConcurrentExecutor, + _ensure_async_engine_loop, + ) + from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker + from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph + from data_designer.engine.dataset_builders.utils.row_group_buffer import RowGroupBufferManager - logger.info("⚡ DATA_DESIGNER_ASYNC_ENGINE is enabled — using async concurrency") + logger.info("⚡ DATA_DESIGNER_ASYNC_ENGINE is enabled — using async task-queue builder") _CLIENT_VERSION: str = get_library_version() @@ -80,6 +89,7 @@ def __init__( self._records_to_drop: set[int] = set() self._cell_resize_results: list[dict | list[dict] | None] = [] self._cell_resize_mode = False + self._task_traces: list[TaskTrace] = [] self._registry = registry or DataDesignerRegistry() self._data_designer_config = compile_data_designer_config(data_designer_config, resource_provider) @@ -99,6 +109,10 @@ def artifact_storage(self) -> ArtifactStorage: def processors(self) -> tuple[Processor, ...]: return self._processor_runner.processors + @property + def task_traces(self) -> list[TaskTrace]: + return self._task_traces + def set_processor_runner(self, processors: list[Processor]) -> None: """Replace the processor runner with a new one using the given processors.""" self._processor_runner = ProcessorRunner( @@ -150,22 +164,27 @@ def build( generators = self._initialize_generators() start_time = time.perf_counter() - group_id = uuid.uuid4().hex - buffer_size = self._resource_provider.run_config.buffer_size - self.batch_manager.start(num_records=num_records, buffer_size=buffer_size) - for batch_idx in range(self.batch_manager.num_batches): - logger.info(f"⏳ Processing batch {batch_idx + 1} of {self.batch_manager.num_batches}") - self._run_batch( - generators, - batch_mode="batch", - group_id=group_id, - current_batch_number=batch_idx, - on_batch_complete=on_batch_complete, - ) - self.batch_manager.finish() - self._processor_runner.run_after_generation(buffer_size) + if DATA_DESIGNER_ASYNC_ENGINE: + self._validate_async_compatibility() + logger.info("⚡ Using async task-queue builder") + self._build_async(generators, num_records, buffer_size, on_batch_complete) + else: + group_id = uuid.uuid4().hex + self.batch_manager.start(num_records=num_records, buffer_size=buffer_size) + for batch_idx in range(self.batch_manager.num_batches): + logger.info(f"⏳ Processing batch {batch_idx + 1} of {self.batch_manager.num_batches}") + self._run_batch( + generators, + batch_mode="batch", + group_id=group_id, + current_batch_number=batch_idx, + on_batch_complete=on_batch_complete, + ) + self.batch_manager.finish() + + self._processor_runner.run_after_generation(buffer_size) self._resource_provider.model_registry.log_model_usage(time.perf_counter() - start_time) return self.artifact_storage.final_dataset_path @@ -190,6 +209,108 @@ def build_preview(self, *, num_records: int) -> pd.DataFrame: return dataset + def _validate_async_compatibility(self) -> None: + """Raise if any column uses allow_resize=True with the async scheduler.""" + offending = [config.name for config in self.single_column_configs if getattr(config, "allow_resize", False)] + if offending: + raise DatasetGenerationError( + f"allow_resize=True is not supported with DATA_DESIGNER_ASYNC_ENGINE=1. " + f"Offending column(s): {offending}. Either remove allow_resize=True or " + f"disable the async scheduler." + ) + + def _build_async( + self, + generators: list[ColumnGenerator], + num_records: int, + buffer_size: int, + on_batch_complete: Callable[[Path], None] | None = None, + ) -> None: + """Async task-queue builder path — dispatches tasks based on dependency readiness.""" + # Build strategy map from generators + strategies: dict[str, GenerationStrategy] = {} + gen_map: dict[str, ColumnGenerator] = {} + for gen in generators: + if isinstance(gen.config, MultiColumnConfig): + for sub in gen.config.columns: + strategies[sub.name] = gen.get_generation_strategy() + gen_map[sub.name] = gen + else: + strategies[gen.config.name] = gen.get_generation_strategy() + gen_map[gen.config.name] = gen + + graph = ExecutionGraph.create(self._column_configs, strategies) + + # Log pre-generation info for all generators + for gen in generators: + gen.log_pre_generation() + + # Partition into row groups + row_groups: list[tuple[int, int]] = [] + remaining = num_records + rg_id = 0 + while remaining > 0: + size = min(buffer_size, remaining) + row_groups.append((rg_id, size)) + remaining -= size + rg_id += 1 + + tracker = CompletionTracker.with_graph(graph, row_groups) + buffer_manager = RowGroupBufferManager(self.artifact_storage) + settings = self._resource_provider.run_config + + # Pre-batch processor callback: runs after seed tasks complete for a row group. + # If it raises, the scheduler drops all rows in the row group (skips it). + def on_seeds_complete(rg_id: int, rg_size: int) -> None: + if not self._processor_runner.has_processors_for(ProcessorStage.PRE_BATCH): + return + df = buffer_manager.get_dataframe(rg_id) + df = self._processor_runner._run_stage(df, ProcessorStage.PRE_BATCH) + buffer_manager.replace_dataframe(rg_id, df) + + # Post-batch processor callback: runs after all columns, before checkpoint. + def on_before_checkpoint(rg_id: int, rg_size: int) -> None: + df = buffer_manager.get_dataframe(rg_id) + df = self._processor_runner.run_post_batch(df, current_batch_number=rg_id) + buffer_manager.replace_dataframe(rg_id, df) + + # Telemetry snapshot + group_id = uuid.uuid4().hex + pre_batch_snapshot = self._resource_provider.model_registry.get_model_usage_snapshot() + + trace_enabled = settings.async_trace or os.environ.get("DATA_DESIGNER_ASYNC_TRACE", "0") == "1" + + scheduler = AsyncTaskScheduler( + generators=gen_map, + graph=graph, + tracker=tracker, + row_groups=row_groups, + buffer_manager=buffer_manager, + on_seeds_complete=on_seeds_complete, + on_before_checkpoint=on_before_checkpoint, + shutdown_error_rate=settings.shutdown_error_rate, + shutdown_error_window=settings.shutdown_error_window, + disable_early_shutdown=settings.disable_early_shutdown, + trace=trace_enabled, + ) + + # Run on background event loop + loop = _ensure_async_engine_loop() + future = asyncio.run_coroutine_threadsafe(scheduler.run(), loop) + future.result() + + self._task_traces = scheduler.traces + + # Emit telemetry + try: + usage_deltas = self._resource_provider.model_registry.get_usage_deltas(pre_batch_snapshot) + self._emit_batch_inference_events("batch", usage_deltas, group_id) + except Exception: + pass + + # Write metadata + buffer_manager.write_metadata(target_num_records=num_records, buffer_size=buffer_size) + def process_preview(self, dataset: pd.DataFrame) -> pd.DataFrame: df = self._processor_runner.run_post_batch(dataset.copy(), current_batch_number=None) return self._processor_runner.run_after_generation_on_df(df) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/row_group_buffer.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/row_group_buffer.py index a63be31fa..3a6c86165 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/row_group_buffer.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/row_group_buffer.py @@ -68,6 +68,18 @@ def get_dataframe(self, row_group: int) -> pd.DataFrame: rows = [row for i, row in enumerate(self._buffers[row_group]) if i not in dropped] return lazy.pd.DataFrame(rows) + def replace_dataframe(self, row_group: int, df: pd.DataFrame) -> None: + """Replace the buffer for a row group from a DataFrame (non-dropped rows only).""" + dropped = self._dropped.get(row_group, set()) + records = df.to_dict(orient="records") + buf_idx = 0 + for ri in range(self._row_group_sizes[row_group]): + if ri in dropped: + continue + if buf_idx < len(records): + self._buffers[row_group][ri] = records[buf_idx] + buf_idx += 1 + def drop_row(self, row_group: int, row_index: int) -> None: self._dropped.setdefault(row_group, set()).add(row_index) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py new file mode 100644 index 000000000..211c6c76c --- /dev/null +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py @@ -0,0 +1,257 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import math +from unittest.mock import MagicMock, Mock + +import pytest + +import data_designer.lazy_heavy_imports as lazy +from data_designer.config.column_configs import ( + ExpressionColumnConfig, + GenerationStrategy, + LLMTextColumnConfig, + SamplerColumnConfig, +) +from data_designer.config.sampler_params import SamplerType +from data_designer.engine.column_generators.generators.base import ( + ColumnGenerator, + ColumnGeneratorFullColumn, + FromScratchColumnGenerator, +) +from data_designer.engine.dataset_builders.column_wise_builder import ColumnWiseDatasetBuilder +from data_designer.engine.dataset_builders.errors import DatasetGenerationError +from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker +from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph +from data_designer.engine.dataset_builders.utils.row_group_buffer import RowGroupBufferManager +from data_designer.engine.resources.resource_provider import ResourceProvider + +MODEL_ALIAS = "stub" + + +# -- Mock generators for integration tests ------------------------------------ + + +def _expr_config(name: str = "test") -> ExpressionColumnConfig: + return ExpressionColumnConfig(name=name, expr="{{ x }}", dtype="str") + + +def _mock_provider() -> MagicMock: + return MagicMock(spec=ResourceProvider) + + +class MockSeed(FromScratchColumnGenerator[ExpressionColumnConfig]): + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return GenerationStrategy.FULL_COLUMN + + def generate(self, data: lazy.pd.DataFrame) -> lazy.pd.DataFrame: + return data + + def generate_from_scratch(self, num_records: int) -> lazy.pd.DataFrame: + return lazy.pd.DataFrame({"seed": list(range(num_records))}) + + +class MockCell(ColumnGenerator[ExpressionColumnConfig]): + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return GenerationStrategy.CELL_BY_CELL + + def generate(self, data: dict) -> dict: + data["cell_out"] = f"val_{data.get('seed', '?')}" + return data + + +class MockFullCol(ColumnGeneratorFullColumn[ExpressionColumnConfig]): + def generate(self, data: lazy.pd.DataFrame) -> lazy.pd.DataFrame: + data["expr_out"] = "computed" + return data + + +# -- allow_resize validation test --------------------------------------------- + + +def test_validate_async_compatibility_raises_on_allow_resize() -> None: + """allow_resize=True with async scheduler raises DatasetGenerationError.""" + builder = Mock(spec=ColumnWiseDatasetBuilder) + builder.single_column_configs = [ + Mock(name="col_a", allow_resize=True), + Mock(name="col_b", allow_resize=False), + ] + # Call the unbound method on the mock + with pytest.raises(DatasetGenerationError, match="allow_resize=True"): + ColumnWiseDatasetBuilder._validate_async_compatibility(builder) + + +def test_validate_async_compatibility_passes_without_allow_resize() -> None: + """No allow_resize=True passes validation.""" + builder = Mock(spec=ColumnWiseDatasetBuilder) + builder.single_column_configs = [ + Mock(name="col_a", allow_resize=False), + Mock(name="col_b", allow_resize=False), + ] + # Should not raise + ColumnWiseDatasetBuilder._validate_async_compatibility(builder) + + +# -- _build_async integration test with mock generators ----------------------- + + +@pytest.mark.asyncio(loop_scope="session") +async def test_build_async_end_to_end() -> None: + """Test _build_async with mock generators produces correct results in buffer.""" + from data_designer.engine.dataset_builders.async_scheduler import AsyncTaskScheduler + + provider = _mock_provider() + seed_gen = MockSeed(config=_expr_config("seed"), resource_provider=provider) + cell_gen = MockCell(config=_expr_config("cell_out"), resource_provider=provider) + expr_gen = MockFullCol(config=_expr_config("expr_out"), resource_provider=provider) + + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="cell_out", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ExpressionColumnConfig(name="expr_out", expr="{{ cell_out }}"), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "cell_out": GenerationStrategy.CELL_BY_CELL, + "expr_out": GenerationStrategy.FULL_COLUMN, + } + gen_map = { + "seed": seed_gen, + "cell_out": cell_gen, + "expr_out": expr_gen, + } + + num_records = 4 + buffer_size = 2 + graph = ExecutionGraph.create(configs, strategies) + + row_groups: list[tuple[int, int]] = [] + remaining = num_records + rg_id = 0 + while remaining > 0: + size = min(buffer_size, remaining) + row_groups.append((rg_id, size)) + remaining -= size + rg_id += 1 + + tracker = CompletionTracker.with_graph(graph, row_groups) + storage = MagicMock() + storage.dataset_name = "test" + storage.get_file_paths.return_value = {} + storage.write_batch_to_parquet_file.return_value = "/fake.parquet" + storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet" + + buffer_manager = RowGroupBufferManager(storage) + + checkpointed: list[int] = [] + + scheduler = AsyncTaskScheduler( + generators=gen_map, + graph=graph, + tracker=tracker, + row_groups=row_groups, + buffer_manager=buffer_manager, + on_row_group_complete=lambda rg: checkpointed.append(rg), + ) + await scheduler.run() + + # Both row groups should be checkpointed + assert sorted(checkpointed) == [0, 1] + assert buffer_manager.actual_num_records == 4 + + # All columns should be complete + all_cols = ["seed", "cell_out", "expr_out"] + assert tracker.is_row_group_complete(0, 2, all_cols) + assert tracker.is_row_group_complete(1, 2, all_cols) + + +# -- Test that existing sync path is unaffected -------------------------------- + + +def test_sync_path_unaffected_by_async_engine_flag() -> None: + """DATA_DESIGNER_ASYNC_ENGINE=0 keeps the sync path unchanged.""" + import data_designer.engine.dataset_builders.column_wise_builder as builder_mod + + assert hasattr(builder_mod, "DATA_DESIGNER_ASYNC_ENGINE") + assert isinstance(builder_mod.DATA_DESIGNER_ASYNC_ENGINE, bool) + + +# -- Test execution graph integration with real column configs ----------------- + + +def test_execution_graph_from_real_configs() -> None: + """Build execution graph from real column config objects.""" + configs = [ + SamplerColumnConfig(name="id", sampler_type=SamplerType.UUID, params={}), + LLMTextColumnConfig(name="question", prompt="{{ id }}", model_alias=MODEL_ALIAS), + LLMTextColumnConfig(name="answer", prompt="{{ question }}", model_alias=MODEL_ALIAS), + ExpressionColumnConfig(name="combined", expr="{{ question }} {{ answer }}"), + ] + strategies = { + "id": GenerationStrategy.FULL_COLUMN, + "question": GenerationStrategy.CELL_BY_CELL, + "answer": GenerationStrategy.CELL_BY_CELL, + "combined": GenerationStrategy.FULL_COLUMN, + } + + graph = ExecutionGraph.create(configs, strategies) + + order = graph.get_topological_order() + idx = {col: i for i, col in enumerate(order)} + + assert idx["id"] < idx["question"] + assert idx["question"] < idx["answer"] + assert idx["answer"] < idx["combined"] + + # Task counts + counts = graph.compute_task_count(num_records=10, buffer_size=3) + assert counts["id"] == math.ceil(10 / 3) + assert counts["question"] == 10 + assert counts["answer"] == 10 + assert counts["combined"] == math.ceil(10 / 3) + + +# -- Test checkpoint correctness ----------------------------------------------- + + +@pytest.mark.asyncio(loop_scope="session") +async def test_checkpoint_produces_correct_parquet_calls() -> None: + """Verify checkpoint writes parquet for each row group.""" + from data_designer.engine.dataset_builders.async_scheduler import AsyncTaskScheduler + + provider = _mock_provider() + seed_gen = MockSeed(config=_expr_config("seed"), resource_provider=provider) + + configs = [SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["X"]})] + strategies = {"seed": GenerationStrategy.FULL_COLUMN} + gen_map = {"seed": seed_gen} + + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 3), (1, 2)] + tracker = CompletionTracker.with_graph(graph, row_groups) + + storage = MagicMock() + storage.dataset_name = "test" + storage.get_file_paths.return_value = {} + storage.write_batch_to_parquet_file.return_value = "/fake.parquet" + storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet" + + buffer_manager = RowGroupBufferManager(storage) + + scheduler = AsyncTaskScheduler( + generators=gen_map, + graph=graph, + tracker=tracker, + row_groups=row_groups, + buffer_manager=buffer_manager, + ) + await scheduler.run() + + # Two row groups → two write_batch_to_parquet_file calls + assert storage.write_batch_to_parquet_file.call_count == 2 + assert storage.move_partial_result_to_final_file_path.call_count == 2 + assert buffer_manager.actual_num_records == 5 diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py index dc91e2ee7..bffb4f453 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py @@ -463,7 +463,7 @@ async def test_scheduler_eager_row_drop_skips_downstream_of_failed_column() -> N """When fail_col drops a row, a downstream column never processes it.""" provider = _mock_provider() - # Pipeline: seed → fail_col (cell, permanent failure) → downstream (cell) + # Pipeline: seed -> fail_col (cell, permanent failure) -> downstream (cell) # downstream depends on fail_col, so its tasks only enter the frontier # after fail_col completes for each row. Since fail_col always fails, # the row is dropped before downstream is ever enqueued. @@ -502,7 +502,7 @@ async def test_scheduler_eager_row_drop_skips_downstream_of_failed_column() -> N async def test_scheduler_non_retryable_seed_failure_no_keyerror_on_downstream() -> None: """Non-retryable seed failure does not cause KeyError on vacuously-ready downstream. - Pipeline: seed (full_column) → cell_out (cell_by_cell) → full_out (full_column). + Pipeline: seed (full_column) -> cell_out (cell_by_cell) -> full_out (full_column). When seed fails non-retryably, all rows are dropped. cell_out's cell tasks become vacuously complete (all rows dropped), which makes full_out ready. full_out must not crash with a KeyError when its row group buffer has been @@ -557,3 +557,155 @@ async def test_scheduler_non_retryable_seed_failure_no_keyerror_on_downstream() # full_out was either never dispatched or silently skipped (no KeyError) full_out_errors = [t for t in scheduler.traces if t.column == "full_out" and t.status == "error"] assert len(full_out_errors) == 0 + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_error_rate_shutdown() -> None: + """Early shutdown triggers when error rate exceeds threshold.""" + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="fail_col", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "fail_col": GenerationStrategy.CELL_BY_CELL, + } + generators = { + "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), + "fail_col": MockFailingGenerator(config=_expr_config("fail_col"), resource_provider=provider), + } + + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 10)] + tracker = CompletionTracker.with_graph(graph, row_groups) + + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + shutdown_error_rate=0.5, + shutdown_error_window=2, + ) + await scheduler.run() + + assert scheduler._early_shutdown is True + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_early_shutdown_disabled() -> None: + """disable_early_shutdown=True prevents shutdown even at 100% error rate.""" + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="fail_col", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "fail_col": GenerationStrategy.CELL_BY_CELL, + } + generators = { + "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), + "fail_col": MockFailingGenerator(config=_expr_config("fail_col"), resource_provider=provider), + } + + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 5)] + tracker = CompletionTracker.with_graph(graph, row_groups) + + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + disable_early_shutdown=True, + ) + await scheduler.run() + + assert scheduler._early_shutdown is False + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_on_before_checkpoint_callback() -> None: + """on_before_checkpoint is called before each row group is checkpointed.""" + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + ] + strategies = {"seed": GenerationStrategy.FULL_COLUMN} + generators = {"seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider)} + + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 3), (1, 2)] + tracker = CompletionTracker.with_graph(graph, row_groups) + + storage = MagicMock() + storage.dataset_name = "test" + storage.get_file_paths.return_value = {} + storage.write_batch_to_parquet_file.return_value = "/fake.parquet" + storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet" + + buffer_mgr = RowGroupBufferManager(storage) + callback_log: list[tuple[int, int]] = [] + + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + buffer_manager=buffer_mgr, + on_before_checkpoint=lambda rg, sz: callback_log.append((rg, sz)), + ) + await scheduler.run() + + assert sorted(callback_log) == [(0, 3), (1, 2)] + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_pre_batch_failure_skips_row_group() -> None: + """Pre-batch processor failure drops all rows in the row group; other row groups continue.""" + provider = _mock_provider() + seed_gen = MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider) + cell_gen = MockCellGenerator(config=_expr_config("cell_out"), resource_provider=provider) + + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="cell_out", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "cell_out": GenerationStrategy.CELL_BY_CELL, + } + generators = {"seed": seed_gen, "cell_out": cell_gen} + + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 3), (1, 2)] + tracker = CompletionTracker.with_graph(graph, row_groups) + + storage = MagicMock() + storage.dataset_name = "test" + storage.get_file_paths.return_value = {} + storage.write_batch_to_parquet_file.return_value = "/fake.parquet" + storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet" + + buffer_mgr = RowGroupBufferManager(storage) + + def failing_pre_batch(rg_id: int, rg_size: int) -> None: + if rg_id == 0: + raise RuntimeError("pre-batch processor failed") + + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + buffer_manager=buffer_mgr, + on_seeds_complete=failing_pre_batch, + ) + await scheduler.run() + + # Row group 0: all rows dropped due to pre-batch failure + assert all(tracker.is_dropped(0, ri) for ri in range(3)) + # Row group 1: completed normally + assert tracker.is_row_group_complete(1, 2, ["seed", "cell_out"]) diff --git a/packages/data-designer/src/data_designer/interface/data_designer.py b/packages/data-designer/src/data_designer/interface/data_designer.py index 57374b9ef..66e87cbd4 100644 --- a/packages/data-designer/src/data_designer/interface/data_designer.py +++ b/packages/data-designer/src/data_designer/interface/data_designer.py @@ -215,6 +215,8 @@ def create( except Exception as e: raise DataDesignerGenerationError(f"🛑 Error generating dataset: {e}") from e + task_traces = builder.task_traces + try: dataset_for_profiler = builder.artifact_storage.load_dataset_with_dropped_columns() except Exception as e: @@ -251,6 +253,7 @@ def create( analysis=analysis, config_builder=config_builder, dataset_metadata=dataset_metadata, + task_traces=task_traces, ) def preview( diff --git a/packages/data-designer/src/data_designer/interface/results.py b/packages/data-designer/src/data_designer/interface/results.py index 4a4006a95..5c829e16d 100644 --- a/packages/data-designer/src/data_designer/interface/results.py +++ b/packages/data-designer/src/data_designer/interface/results.py @@ -33,6 +33,7 @@ def __init__( analysis: DatasetProfilerResults, config_builder: DataDesignerConfigBuilder, dataset_metadata: DatasetMetadata, + task_traces: list | None = None, ): """Creates a new instance with results based on a dataset creation run. @@ -41,11 +42,13 @@ def __init__( analysis: Profiling results for the generated dataset. config_builder: Configuration builder used to create the dataset. dataset_metadata: Metadata about the generated dataset (e.g., seed column names). + task_traces: Optional list of TaskTrace objects from the async scheduler. """ self.artifact_storage = artifact_storage self._analysis = analysis self._config_builder = config_builder self.dataset_metadata = dataset_metadata + self.task_traces: list = task_traces or [] def load_analysis(self) -> DatasetProfilerResults: """Load the profiling analysis results for the generated dataset. From 5474d4eea100d84c3751bdb777e5473bc8fd96a3 Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Tue, 17 Mar 2026 21:14:24 +0000 Subject: [PATCH 02/23] chore: add async benchmark notebook and demo scripts --- demo/async/bench_sync_vs_async.py | 281 + demo/async/helpers.py | 213 + demo/async/notebook_async_benchmark.html | 9380 ++++++++++++++++++++++ demo/async/notebook_async_benchmark.py | 704 ++ demo/async/test_edge_cases.py | 236 + demo/async/test_execution_order.py | 150 + demo/async/test_sync_vs_async_parity.py | 158 + 7 files changed, 11122 insertions(+) create mode 100644 demo/async/bench_sync_vs_async.py create mode 100644 demo/async/helpers.py create mode 100644 demo/async/notebook_async_benchmark.html create mode 100644 demo/async/notebook_async_benchmark.py create mode 100644 demo/async/test_edge_cases.py create mode 100644 demo/async/test_execution_order.py create mode 100644 demo/async/test_sync_vs_async_parity.py diff --git a/demo/async/bench_sync_vs_async.py b/demo/async/bench_sync_vs_async.py new file mode 100644 index 000000000..6e019da81 --- /dev/null +++ b/demo/async/bench_sync_vs_async.py @@ -0,0 +1,281 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Benchmark: sync vs async builder with repeated trials. + +Runs each engine N times (default 5), drops the first run of each as warmup, +and reports min / median / mean / max / stdev for the remaining runs. + +Uses real LLMs (openai-text). Set OPENAI_API_KEY in environment. + +Usage: + cd async_test && uv run python bench_sync_vs_async.py + cd async_test && uv run python bench_sync_vs_async.py --trials 7 --num-records 10 +""" + +from __future__ import annotations + +import logging +import math +import os +import sys +import tempfile +import time +import warnings +from argparse import ArgumentParser + +os.environ["DATA_DESIGNER_ASYNC_ENGINE"] = "1" + +warnings.filterwarnings("ignore", message=".*urllib3.*") +warnings.filterwarnings("ignore", message=".*Unclosed.*") + +import data_designer.engine.dataset_builders.column_wise_builder as cwb +from data_designer.config.column_configs import LLMTextColumnConfig, SamplerColumnConfig +from data_designer.config.config_builder import DataDesignerConfigBuilder +from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig +from data_designer.config.run_config import RunConfig +from data_designer.config.sampler_params import CategorySamplerParams, SamplerType +from data_designer.interface import DataDesigner + +NUM_RECORDS_DEFAULT = 5 +TRIALS_DEFAULT = 5 +MAX_PARALLEL_DEFAULT = 4 + +# DAG shapes: +# narrow: topic → summary → followup (sequential, 2 LLM cols) +# wide: topic → summary ─┐ +# topic → analysis ├→ synthesis (3 parallel + 1 merge, 4 LLM cols) +# topic → trivia ──┘ +DAG_DEFAULT = "wide" + + +def _build_config(dag: str, max_parallel: int) -> DataDesignerConfigBuilder: + config = DataDesignerConfigBuilder( + model_configs=[ + ModelConfig( + alias="openai-text", + model="gpt-4.1", + provider="openai", + inference_parameters=ChatCompletionInferenceParams(max_parallel_requests=max_parallel), + ) + ] + ) + config.add_column( + SamplerColumnConfig( + name="topic", + sampler_type=SamplerType.CATEGORY, + params=CategorySamplerParams(values=["science", "history", "art"]), + ) + ) + + if dag == "narrow": + config.add_column( + LLMTextColumnConfig( + name="summary", + model_alias="openai-text", + prompt="Write a one-sentence summary about {{ topic }}.", + ) + ) + config.add_column( + LLMTextColumnConfig( + name="followup", + model_alias="openai-text", + prompt="Given this summary: '{{ summary }}', write a follow-up question.", + ) + ) + elif dag == "wide": + config.add_column( + LLMTextColumnConfig( + name="summary", + model_alias="openai-text", + prompt="Write a one-sentence summary about {{ topic }}.", + ) + ) + config.add_column( + LLMTextColumnConfig( + name="analysis", + model_alias="openai-text", + prompt="Write a one-sentence analysis of why {{ topic }} matters.", + ) + ) + config.add_column( + LLMTextColumnConfig( + name="trivia", + model_alias="openai-text", + prompt="State one surprising fact about {{ topic }}.", + ) + ) + config.add_column( + LLMTextColumnConfig( + name="synthesis", + model_alias="openai-text", + prompt=( + "Combine these into one sentence: " + "summary='{{ summary }}', analysis='{{ analysis }}', trivia='{{ trivia }}'." + ), + ) + ) + else: + raise ValueError(f"Unknown dag shape: {dag!r}. Use 'narrow' or 'wide'.") + + return config + + +def _suppress_logs() -> None: + """Mute all data_designer / litellm / httpx loggers.""" + for name in ("data_designer", "LiteLLM", "litellm", "httpx", "asyncio"): + logging.getLogger(name).setLevel(logging.CRITICAL) + # Also prevent the root logger from propagating data_designer messages. + logging.getLogger().setLevel(logging.CRITICAL) + + +def _run_once(*, async_mode: bool, num_records: int, dag: str, max_parallel: int) -> float: + """Run a single creation and return wall-clock seconds.""" + _suppress_logs() + cwb.DATA_DESIGNER_ASYNC_ENGINE = async_mode + + with tempfile.TemporaryDirectory() as artifact_dir: + dd = DataDesigner(artifact_path=artifact_dir) + dd.set_run_config(RunConfig(buffer_size=1000, disable_early_shutdown=True, async_trace=False)) + config = _build_config(dag, max_parallel) + + t0 = time.perf_counter() + result = dd.create(config, num_records=num_records, dataset_name="bench") + elapsed = time.perf_counter() - t0 + + df = result.load_dataset() + assert len(df) == num_records, f"Expected {num_records} rows, got {len(df)}" + + return elapsed + + +def _stats(times: list[float]) -> dict[str, float]: + """Compute summary statistics for a list of times.""" + n = len(times) + s = sorted(times) + mean = sum(s) / n + if n % 2 == 1: + median = s[n // 2] + else: + median = (s[n // 2 - 1] + s[n // 2]) / 2 + variance = sum((t - mean) ** 2 for t in s) / n + return { + "min": s[0], + "median": median, + "mean": mean, + "max": s[-1], + "stdev": math.sqrt(variance), + } + + +def _print_row(label: str, st: dict[str, float]) -> None: + print( + f" {label:<7} " + f"{st['min']:>7.3f}s " + f"{st['median']:>7.3f}s " + f"{st['mean']:>7.3f}s " + f"{st['max']:>7.3f}s " + f"{st['stdev']:>7.3f}s" + ) + + +def main() -> None: + parser = ArgumentParser(description="Benchmark sync vs async builder") + parser.add_argument( + "--trials", type=int, default=TRIALS_DEFAULT, help="Total trials per engine (including 1 warmup)" + ) + parser.add_argument("--num-records", type=int, default=NUM_RECORDS_DEFAULT, help="Records per trial") + parser.add_argument( + "--dag", + type=str, + default=DAG_DEFAULT, + choices=["narrow", "wide"], + help="DAG shape: narrow (sequential) or wide (parallel branches)", + ) + parser.add_argument( + "--max-parallel", type=int, default=MAX_PARALLEL_DEFAULT, help="max_parallel_requests for the LLM model" + ) + args = parser.parse_args() + + trials: int = args.trials + num_records: int = args.num_records + dag: str = args.dag + max_parallel: int = args.max_parallel + + if trials < 2: + print("Need at least 2 trials (1 warmup + 1 measured). Setting trials=2.") + trials = 2 + + measured = trials - 1 + + print("=" * 70) + print(f"Benchmark: Sync vs Async ({trials} trials, first dropped as warmup)") + print(f"Records per trial: {num_records}, DAG: {dag}, max_parallel: {max_parallel}") + print("=" * 70) + + # --- Warmup: one sync run to prime health checks, caches, etc. --- + # First DataDesigner() call configures logging, so suppress via devnull. + print("\nWarmup (sync)...", end=" ", flush=True) + devnull = open(os.devnull, "w") # noqa: SIM115 + old_stderr = sys.stderr + sys.stderr = devnull + warmup_time = _run_once(async_mode=False, num_records=num_records, dag=dag, max_parallel=max_parallel) + sys.stderr = old_stderr + devnull.close() + print(f"{warmup_time:.3f}s") + + # --- Interleaved trials: ABABABAB to reduce temporal bias --- + sync_times: list[float] = [] + async_times: list[float] = [] + + for i in range(measured): + trial_num = i + 1 + + # Sync + print(f"\n Trial {trial_num}/{measured} — sync ...", end=" ", flush=True) + t_sync = _run_once(async_mode=False, num_records=num_records, dag=dag, max_parallel=max_parallel) + sync_times.append(t_sync) + print(f"{t_sync:.3f}s") + + # Async + print(f" Trial {trial_num}/{measured} — async...", end=" ", flush=True) + t_async = _run_once(async_mode=True, num_records=num_records, dag=dag, max_parallel=max_parallel) + async_times.append(t_async) + print(f"{t_async:.3f}s") + + # --- Stats --- + sync_st = _stats(sync_times) + async_st = _stats(async_times) + speedup_median = sync_st["median"] / async_st["median"] if async_st["median"] > 0 else float("inf") + speedup_mean = sync_st["mean"] / async_st["mean"] if async_st["mean"] > 0 else float("inf") + + print("\n" + "=" * 70) + print(f"Results ({measured} measured trials, {num_records} records each, dag={dag})") + print("=" * 70) + + header = f" {'engine':<7} {'min':>8} {'median':>8} {'mean':>8} {'max':>8} {'stdev':>8}" + print(header) + print(f" {'-' * (len(header.strip()) - 2)}") + _print_row("sync", sync_st) + _print_row("async", async_st) + + print(f"\n Speedup (median): {speedup_median:.2f}x") + print(f" Speedup (mean): {speedup_mean:.2f}x") + + # Per-trial detail + print("\n Per-trial times:") + print(f" {'trial':>5} {'sync':>8} {'async':>8} {'ratio':>7}") + print(f" {'-' * 33}") + for i, (s, a) in enumerate(zip(sync_times, async_times)): + ratio = s / a if a > 0 else float("inf") + print(f" {i + 1:>5} {s:>7.3f}s {a:>7.3f}s {ratio:>6.2f}x") + + print("\nDone.") + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\nInterrupted.") + sys.exit(130) diff --git a/demo/async/helpers.py b/demo/async/helpers.py new file mode 100644 index 000000000..d47865da7 --- /dev/null +++ b/demo/async/helpers.py @@ -0,0 +1,213 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Shared utilities for async engine test scripts.""" + +from __future__ import annotations + +import contextlib +import hashlib +import json +import random +from collections.abc import Iterator +from dataclasses import dataclass +from typing import Any + +# --------------------------------------------------------------------------- +# Assertion helper +# --------------------------------------------------------------------------- + + +def check(condition: bool, message: str) -> None: + status = "PASS" if condition else "FAIL" + print(f" [{status}] {message}") + if not condition: + raise AssertionError(message) + + +# --------------------------------------------------------------------------- +# Mock LLM responses +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class FakeMessage: + content: str + tool_calls: list[dict[str, Any]] | None = None + reasoning_content: str | None = None + + +@dataclass(frozen=True) +class FakeChoice: + message: FakeMessage + + +@dataclass(frozen=True) +class FakeResponse: + choices: list[FakeChoice] + usage: Any | None = None + model: str | None = None + + +def _stable_seed(model: str, messages: list[dict[str, Any]]) -> int: + payload = json.dumps( + {"model": model, "messages": messages}, + sort_keys=True, + separators=(",", ":"), + ensure_ascii=True, + default=str, + ) + digest = hashlib.sha256(payload.encode()).digest() + return int.from_bytes(digest[:8], "big") + + +def _mock_response_text(model: str, messages: list[dict[str, Any]]) -> str: + rng = random.Random(_stable_seed(model, messages)) + score = rng.uniform(0, 10) + category = rng.choice(["low", "mid", "high"]) + return f"mock:{model}|cat={category}|score={score:.3f}" + + +def _fake_response(model: str, messages: list[dict[str, Any]], **_kwargs: Any) -> FakeResponse: + text = _mock_response_text(model, messages) + return FakeResponse(choices=[FakeChoice(message=FakeMessage(content=text))], model=model) + + +@contextlib.contextmanager +def patch_llm_responses(*, fail_pattern: str | None = None, fail_rate: float = 1.0) -> Iterator[None]: + """Patch LLM completion to return deterministic mock responses. + + Args: + fail_pattern: If set, raise RuntimeError when this string appears in the + serialized messages. Used to simulate LLM failures for specific columns. + fail_rate: Probability of failure when fail_pattern matches (0.0-1.0). + """ + from data_designer.engine.models.litellm_overrides import CustomRouter + + original_completion = CustomRouter.completion + original_acompletion = getattr(CustomRouter, "acompletion", None) + + def _should_fail(messages: list[dict[str, Any]]) -> bool: + if fail_pattern is None: + return False + serialized = json.dumps(messages, default=str) + if fail_pattern not in serialized: + return False + return random.random() < fail_rate + + def fake_completion(self: Any, model: str, messages: list[dict[str, Any]], **kwargs: Any) -> FakeResponse: + if _should_fail(messages): + raise RuntimeError(f"Simulated LLM failure for {model}") + return _fake_response(model, messages, **kwargs) + + async def fake_acompletion(self: Any, model: str, messages: list[dict[str, Any]], **kwargs: Any) -> FakeResponse: + if _should_fail(messages): + raise RuntimeError(f"Simulated LLM failure for {model}") + return _fake_response(model, messages, **kwargs) + + CustomRouter.completion = fake_completion + CustomRouter.acompletion = fake_acompletion + try: + yield + finally: + CustomRouter.completion = original_completion + if original_acompletion is not None: + CustomRouter.acompletion = original_acompletion + else: + with contextlib.suppress(AttributeError): + delattr(CustomRouter, "acompletion") + + +# --------------------------------------------------------------------------- +# Dataset fingerprinting +# --------------------------------------------------------------------------- + + +def dataset_fingerprint(df: Any) -> str: + import numpy as np + import pandas as pd + + def _default(v: Any) -> Any: + if isinstance(v, np.generic): + return v.item() + if isinstance(v, np.ndarray): + return v.tolist() + if isinstance(v, (pd.Timestamp, pd.Timedelta)): + return v.isoformat() + if isinstance(v, set): + return sorted(v) + if isinstance(v, bytes): + return v.decode("utf-8", errors="replace") + return str(v) + + normalized = df.reset_index(drop=True) + normalized = normalized.reindex(sorted(normalized.columns), axis=1) + records = normalized.to_dict(orient="records") + payload = json.dumps(records, sort_keys=True, separators=(",", ":"), ensure_ascii=True, default=_default) + return hashlib.sha256(payload.encode()).hexdigest() + + +# --------------------------------------------------------------------------- +# Model / DataDesigner factory helpers +# --------------------------------------------------------------------------- + + +def create_mock_model_config() -> tuple: + """Return (ModelConfig, ModelProvider) for openai-text with mock backend.""" + from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig, ModelProvider + + provider = ModelProvider( + name="mock-provider", + endpoint="https://mock.local", + provider_type="openai", + api_key="mock-key", + ) + model_config = ModelConfig( + alias="openai-text", + model="gpt-4.1", + provider="mock-provider", + inference_parameters=ChatCompletionInferenceParams(max_parallel_requests=8), + skip_health_check=True, + ) + return model_config, provider + + +def create_data_designer(*, artifact_path: str, async_trace: bool = True) -> Any: + from data_designer.config.run_config import RunConfig + from data_designer.interface import DataDesigner + + _, provider = create_mock_model_config() + dd = DataDesigner(artifact_path=artifact_path, model_providers=[provider]) + dd.set_run_config( + RunConfig( + buffer_size=1000, + disable_early_shutdown=True, + async_trace=async_trace, + ) + ) + return dd + + +def create_base_config() -> Any: + """Create a DataDesignerConfigBuilder with openai-text model and a seed column.""" + from data_designer.config.column_configs import SamplerColumnConfig + from data_designer.config.config_builder import DataDesignerConfigBuilder + from data_designer.config.sampler_params import SamplerType, UniformSamplerParams + + model_config, _ = create_mock_model_config() + config = DataDesignerConfigBuilder(model_configs=[model_config]) + config.add_column( + SamplerColumnConfig( + name="seed_value", + sampler_type=SamplerType.UNIFORM, + params=UniformSamplerParams(low=0.0, high=100.0, decimal_places=2), + ) + ) + return config + + +def seed_rng(seed: int = 42) -> None: + import numpy as np + + random.seed(seed) + np.random.seed(seed) diff --git a/demo/async/notebook_async_benchmark.html b/demo/async/notebook_async_benchmark.html new file mode 100644 index 000000000..a6651a7e9 --- /dev/null +++ b/demo/async/notebook_async_benchmark.html @@ -0,0 +1,9380 @@ + + + + + +notebook_async_benchmark + + + + + + + + + + + + +
+ + + + + + + + + + + + + +
+ + diff --git a/demo/async/notebook_async_benchmark.py b/demo/async/notebook_async_benchmark.py new file mode 100644 index 000000000..917b7a013 --- /dev/null +++ b/demo/async/notebook_async_benchmark.py @@ -0,0 +1,704 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +# %% [markdown] +# # Async Scheduler Benchmark +# +# This notebook compares the **sync** (column-by-column) builder with the +# **async** (task-queue) scheduler and shows when the async path delivers +# real speedups. +# +# **Key finding:** the async scheduler needs *both* a wide DAG (independent +# columns) *and* enough LLM concurrency (`max_parallel_requests`) to +# outperform the sync engine. With a narrow/sequential DAG or low +# concurrency, both engines hit the same API throughput ceiling. +# +# **Requirements:** `OPENAI_API_KEY` in environment. Uses `gpt-4.1` via +# `openai-text` alias. + +# %% Setup +from __future__ import annotations + +import logging +import math +import os +import sys +import tempfile +import time +import warnings + +import matplotlib.patches as mpatches +import matplotlib.pyplot as plt +import pandas as pd + +os.environ["DATA_DESIGNER_ASYNC_ENGINE"] = "1" + +warnings.filterwarnings("ignore", message=".*urllib3.*") +warnings.filterwarnings("ignore", message=".*Unclosed.*") + +from IPython.display import HTML, display + +import data_designer.engine.dataset_builders.column_wise_builder as cwb +from data_designer.config.column_configs import ( + GenerationStrategy, + LLMTextColumnConfig, + SamplerColumnConfig, +) +from data_designer.config.config_builder import DataDesignerConfigBuilder +from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig +from data_designer.config.run_config import RunConfig +from data_designer.config.sampler_params import CategorySamplerParams, SamplerType +from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph +from data_designer.interface import DataDesigner + +NUM_RECORDS = 10 +TRIALS = 5 + +_STRATEGY_MAP: dict[type, GenerationStrategy] = { + SamplerColumnConfig: GenerationStrategy.FULL_COLUMN, + LLMTextColumnConfig: GenerationStrategy.CELL_BY_CELL, +} + + +def suppress_logs() -> None: + for name in ("data_designer", "LiteLLM", "litellm", "httpx", "asyncio"): + logging.getLogger(name).setLevel(logging.CRITICAL) + logging.getLogger().setLevel(logging.CRITICAL) + + +def render_mermaid(code: str) -> None: + """Render a Mermaid diagram inline via the Mermaid JS CDN.""" + display( + HTML( + f"""
{code}
+""" + ) + ) + + +def _execution_graph_from_builder(config: DataDesignerConfigBuilder): # noqa: ANN201 + """Build an ExecutionGraph directly from a config builder.""" + cols = list(config._column_configs.values()) + strategies = {c.name: _STRATEGY_MAP[type(c)] for c in cols} + return ExecutionGraph.create(cols, strategies) + + +# %% [markdown] +# ## DAG Shapes +# +# We test two DAG shapes: +# +# **Narrow** (sequential): `topic → summary → followup` +# - 2 LLM columns, fully sequential — no cross-column parallelism possible. +# +# **Wide** (parallel branches): +# ``` +# topic → summary ─┐ +# topic → analysis ─┼→ synthesis +# topic → trivia ─┘ +# ``` +# - 3 independent LLM columns off the sampler, then 1 merge column. +# - The async scheduler can overlap summary/analysis/trivia. + +# %% Config builders + + +def _model_configs(max_parallel: int) -> list[ModelConfig]: + return [ + ModelConfig( + alias="openai-text", + model="gpt-4.1", + provider="openai", + inference_parameters=ChatCompletionInferenceParams(max_parallel_requests=max_parallel), + ) + ] + + +def build_narrow(max_parallel: int) -> DataDesignerConfigBuilder: + config = DataDesignerConfigBuilder(model_configs=_model_configs(max_parallel)) + config.add_column( + SamplerColumnConfig( + name="topic", + sampler_type=SamplerType.CATEGORY, + params=CategorySamplerParams(values=["science", "history", "art"]), + ) + ) + config.add_column( + LLMTextColumnConfig( + name="summary", + model_alias="openai-text", + prompt="Write a one-sentence summary about {{ topic }}.", + ) + ) + config.add_column( + LLMTextColumnConfig( + name="followup", + model_alias="openai-text", + prompt="Given this summary: '{{ summary }}', write a follow-up question.", + ) + ) + return config + + +def build_wide(max_parallel: int) -> DataDesignerConfigBuilder: + config = DataDesignerConfigBuilder(model_configs=_model_configs(max_parallel)) + config.add_column( + SamplerColumnConfig( + name="topic", + sampler_type=SamplerType.CATEGORY, + params=CategorySamplerParams(values=["science", "history", "art"]), + ) + ) + config.add_column( + LLMTextColumnConfig( + name="summary", + model_alias="openai-text", + prompt="Write a one-sentence summary about {{ topic }}.", + ) + ) + config.add_column( + LLMTextColumnConfig( + name="analysis", + model_alias="openai-text", + prompt="Write a one-sentence analysis of why {{ topic }} matters.", + ) + ) + config.add_column( + LLMTextColumnConfig( + name="trivia", + model_alias="openai-text", + prompt="State one surprising fact about {{ topic }}.", + ) + ) + config.add_column( + LLMTextColumnConfig( + name="synthesis", + model_alias="openai-text", + prompt=( + "Combine these into one sentence: " + "summary='{{ summary }}', analysis='{{ analysis }}', trivia='{{ trivia }}'." + ), + ) + ) + return config + + +# %% [markdown] +# ## DAG Visualization +# +# The `ExecutionGraph` class models the column dependency DAG and annotates +# each node with its generation strategy (`full_column` for samplers, +# `cell_by_cell` for LLM columns). The `to_mermaid()` method renders this +# as a Mermaid flowchart. + +# %% +narrow_graph = _execution_graph_from_builder(build_narrow(4)) +wide_graph = _execution_graph_from_builder(build_wide(4)) + +print("Narrow DAG") +print(f" Critical path: {' → '.join(narrow_graph.get_longest_dependency_chain())}") +print(f" Task count (10 records, buffer=1000): {narrow_graph.compute_task_count(NUM_RECORDS, 1000)}") +render_mermaid(narrow_graph.to_mermaid()) + +# %% +print("Wide DAG") +print(f" Critical path: {' → '.join(wide_graph.get_longest_dependency_chain())}") +print(f" Task count (10 records, buffer=1000): {wide_graph.compute_task_count(NUM_RECORDS, 1000)}") +render_mermaid(wide_graph.to_mermaid()) + + +# %% Runner + + +def run_once( + *, + async_mode: bool, + config: DataDesignerConfigBuilder, + trace: bool = False, +) -> tuple[float, list]: + """Run one creation. Returns (elapsed_seconds, traces).""" + suppress_logs() + cwb.DATA_DESIGNER_ASYNC_ENGINE = async_mode + + with tempfile.TemporaryDirectory() as artifact_dir: + dd = DataDesigner(artifact_path=artifact_dir) + dd.set_run_config( + RunConfig( + buffer_size=1000, + disable_early_shutdown=True, + async_trace=trace and async_mode, + ) + ) + + t0 = time.perf_counter() + result = dd.create(config, num_records=NUM_RECORDS, dataset_name="bench") + elapsed = time.perf_counter() - t0 + + df = result.load_dataset() + assert len(df) == NUM_RECORDS, f"Expected {NUM_RECORDS} rows, got {len(df)}" + + traces = result.task_traces if hasattr(result, "task_traces") else [] + + return elapsed, traces + + +def benchmark( + config_fn, + max_parallel: int, + label: str, + trace_last: bool = False, +) -> dict: + """Run TRIALS interleaved sync/async and return stats + optional traces.""" + config_s = config_fn(max_parallel) + config_a = config_fn(max_parallel) + + # Warmup + devnull = open(os.devnull, "w") # noqa: SIM115 + old_stderr = sys.stderr + sys.stderr = devnull + run_once(async_mode=False, config=config_s) + sys.stderr = old_stderr + devnull.close() + + sync_times: list[float] = [] + async_times: list[float] = [] + last_traces: list = [] + + for i in range(TRIALS): + is_last = i == TRIALS - 1 + + config_s = config_fn(max_parallel) + t_sync, _ = run_once(async_mode=False, config=config_s) + sync_times.append(t_sync) + + config_a = config_fn(max_parallel) + t_async, traces = run_once( + async_mode=True, + config=config_a, + trace=trace_last and is_last, + ) + async_times.append(t_async) + if is_last: + last_traces = traces + + return { + "label": label, + "max_parallel": max_parallel, + "sync": stats(sync_times), + "async": stats(async_times), + "sync_times": sync_times, + "async_times": async_times, + "traces": last_traces, + } + + +def stats(times: list[float]) -> dict[str, float]: + n = len(times) + s = sorted(times) + mean = sum(s) / n + median = s[n // 2] if n % 2 == 1 else (s[n // 2 - 1] + s[n // 2]) / 2 + variance = sum((t - mean) ** 2 for t in s) / n + return {"min": s[0], "median": median, "mean": mean, "max": s[-1], "stdev": math.sqrt(variance)} + + +# %% [markdown] +# ## Experiment 1: Narrow DAG, default concurrency (4) +# +# With a sequential DAG and low concurrency, the async scheduler has no room +# to parallelize. Both engines process one column at a time with 4 workers. + +# %% +print("Running: narrow DAG, max_parallel=4 ...") +r_narrow_4 = benchmark(build_narrow, max_parallel=4, label="narrow, par=4") + +# %% [markdown] +# ## Experiment 2: Wide DAG, default concurrency (4) +# +# Even with 3 independent branches, `max_parallel_requests=4` means only 4 +# LLM calls can be in flight at once. The sync engine already saturates +# these 4 slots column-by-column. The async scheduler's cross-column +# dispatch hits the same ceiling. + +# %% +print("Running: wide DAG, max_parallel=4 ...") +r_wide_4 = benchmark(build_wide, max_parallel=4, label="wide, par=4") + +# %% [markdown] +# ## Experiment 3: Wide DAG, high concurrency (16) +# +# Now we raise `max_parallel_requests=16`. The sync engine still processes +# columns one at a time (16 workers per column). The async scheduler can +# dispatch tasks from summary, analysis, and trivia *simultaneously*, +# keeping 16 slots busy across columns. + +# %% +print("Running: wide DAG, max_parallel=16 ...") +r_wide_16 = benchmark(build_wide, max_parallel=16, label="wide, par=16", trace_last=True) + +# %% [markdown] +# ## Results Summary + +# %% +rows = [] +for r in [r_narrow_4, r_wide_4, r_wide_16]: + s, a = r["sync"], r["async"] + speedup = s["median"] / a["median"] if a["median"] > 0 else float("inf") + rows.append( + { + "Configuration": r["label"], + "Sync median (s)": round(s["median"], 1), + "Async median (s)": round(a["median"], 1), + "Speedup": f"{speedup:.2f}x", + } + ) + +summary_df = pd.DataFrame(rows) +display(summary_df.style.hide(axis="index").set_caption(f"Benchmark results ({TRIALS} trials, {NUM_RECORDS} records)")) + +# Per-trial detail +for r in [r_narrow_4, r_wide_4, r_wide_16]: + trial_rows = [] + for i, (st, at) in enumerate(zip(r["sync_times"], r["async_times"])): + ratio = st / at if at > 0 else float("inf") + trial_rows.append( + {"Trial": i + 1, "Sync (s)": round(st, 1), "Async (s)": round(at, 1), "Ratio": f"{ratio:.2f}x"} + ) + trial_df = pd.DataFrame(trial_rows) + display(trial_df.style.hide(axis="index").set_caption(r["label"])) + +# %% [markdown] +# ## Using Traces to Optimize +# +# The async scheduler records a `TaskTrace` for every dispatched task when +# `async_trace=True`. Each trace captures: +# +# | Field | Meaning | +# |---|---| +# | `column` | Which column this task generates | +# | `row_group` / `row_index` | Position in the dataset | +# | `task_type` | `from_scratch`, `batch`, or `cell` | +# | `dispatched_at` | When the scheduler queued the task | +# | `slot_acquired_at` | When it started executing (after waiting for a slot) | +# | `completed_at` | When it finished | +# | `status` | `ok` or `error` | +# +# From these we can compute **wait time** (queued → slot acquired) and +# **run time** (slot acquired → completed) for every task, revealing where +# the pipeline spends its time. + +# %% + +traces = r_wide_16["traces"] + +if traces: + t0 = min(t.dispatched_at for t in traces if t.dispatched_at) + + by_column: dict[str, list] = {} + for t in traces: + by_column.setdefault(t.column, []).append(t) + + # --- Per-column summary --- + col_rows = [] + for col in sorted(by_column): + col_traces = by_column[col] + waits = [ + (t.slot_acquired_at - t.dispatched_at) * 1000 for t in col_traces if t.slot_acquired_at and t.dispatched_at + ] + runs = [ + (t.completed_at - t.slot_acquired_at) * 1000 for t in col_traces if t.completed_at and t.slot_acquired_at + ] + errors = sum(1 for t in col_traces if t.status == "error") + col_rows.append( + { + "Column": col, + "Tasks": len(col_traces), + "Avg Wait (ms)": round(sum(waits) / len(waits), 0) if waits else 0, + "Avg Run (ms)": round(sum(runs) / len(runs), 0) if runs else 0, + "Total Run (ms)": round(sum(runs), 0), + "Errors": errors, + } + ) + + col_df = pd.DataFrame(col_rows) + display( + col_df.style.hide(axis="index") + .set_caption("Per-column timing (wide DAG, max_parallel=16, async)") + .background_gradient(subset=["Total Run (ms)"], cmap="YlOrRd") + ) + + # --- Column timeline --- + timeline_rows = [] + for col in sorted(by_column): + col_traces = by_column[col] + first = min(t.dispatched_at for t in col_traces if t.dispatched_at) - t0 + last = max(t.completed_at for t in col_traces if t.completed_at) - t0 + timeline_rows.append( + { + "Column": col, + "First Dispatch (s)": round(first, 3), + "Last Complete (s)": round(last, 3), + "Span (s)": round(last - first, 3), + } + ) + + timeline_df = pd.DataFrame(timeline_rows) + display( + timeline_df.style.hide(axis="index") + .set_caption("Column timeline (relative to first dispatch)") + .background_gradient(subset=["Span (s)"], cmap="Blues") + ) + +# %% [markdown] +# ### Per-row-group breakdown +# +# The async scheduler pipelines row groups: it can start independent columns +# for row group 1 while row group 0 is still generating dependent columns +# like `synthesis`. This table shows when each (row group, column) pair +# starts and finishes. + +# %% +if traces: + rg_rows = [] + for t in traces: + if t.dispatched_at and t.completed_at: + rg_rows.append( + { + "Row Group": t.row_group, + "Column": t.column, + "Task Type": t.task_type, + "Row Index": t.row_index if t.row_index is not None else "—", + "Start (s)": round(t.slot_acquired_at - t0, 3) if t.slot_acquired_at else None, + "End (s)": round(t.completed_at - t0, 3), + "Run (ms)": ( + round((t.completed_at - t.slot_acquired_at) * 1000, 0) if t.slot_acquired_at else None + ), + "Status": t.status, + } + ) + + rg_df = pd.DataFrame(rg_rows).sort_values(["Row Group", "Start (s)"], na_position="last") + + # Summary per (row_group, column) + rg_summary = ( + rg_df.groupby(["Row Group", "Column"]) + .agg( + Tasks=("Run (ms)", "count"), + First_Start=("Start (s)", "min"), + Last_End=("End (s)", "max"), + Avg_Run_ms=("Run (ms)", "mean"), + ) + .reset_index() + ) + rg_summary["Avg_Run_ms"] = rg_summary["Avg_Run_ms"].round(0) + rg_summary.columns = ["Row Group", "Column", "Tasks", "First Start (s)", "Last End (s)", "Avg Run (ms)"] + + display( + rg_summary.style.hide(axis="index") + .set_caption("Per-row-group column breakdown") + .background_gradient(subset=["Avg Run (ms)"], cmap="YlOrRd") + ) +else: + print("(no traces captured — run with trace=True)") + +# %% [markdown] +# ### Task Gantt Chart +# +# Each bar represents a task executing on the async scheduler. Bars are +# color-coded by column — overlapping bars across columns show the async +# scheduler exploiting cross-column parallelism. +# +# Notice that `synthesis` bars start *before* `summary`/`analysis`/`trivia` +# have finished all their rows. This is because `synthesis` is a +# **cell-by-cell** column: it only needs `summary[i]`, `analysis[i]`, and +# `trivia[i]` to be complete before generating `synthesis[i]`. The scheduler +# dispatches each synthesis cell as soon as its per-row dependencies are met, +# without waiting for the full column to finish. + +# %% +if traces: + end_time = max(t.completed_at for t in traces if t.completed_at) + total_secs = end_time - t0 + + # Use topological order so the chart reads top-to-bottom like the DAG + columns_ordered = wide_graph.get_topological_order() + cmap = plt.get_cmap("tab10") + col_colors = {col: cmap(i) for i, col in enumerate(columns_ordered)} + + # Build bars: one per task, y-position = column index + fig, ax = plt.subplots(figsize=(12, max(3, len(columns_ordered) * 0.8))) + + for col_idx, col in enumerate(columns_ordered): + for t in by_column[col]: + if t.slot_acquired_at and t.completed_at: + start = t.slot_acquired_at - t0 + duration = t.completed_at - t.slot_acquired_at + alpha = 0.4 if t.status == "error" else 0.85 + ax.barh( + col_idx, + duration, + left=start, + height=0.6, + color=col_colors[col], + alpha=alpha, + edgecolor="white", + linewidth=0.5, + ) + + ax.set_yticks(range(len(columns_ordered))) + ax.set_yticklabels(columns_ordered) + ax.set_xlabel("Time (s)") + ax.set_title("Task Gantt Chart — wide DAG, max_parallel=16, async") + ax.set_xlim(0, total_secs * 1.02) + ax.invert_yaxis() + + patches = [mpatches.Patch(color=col_colors[c], label=c) for c in columns_ordered] + ax.legend(handles=patches, loc="upper right", fontsize="small", framealpha=0.8) + plt.tight_layout() + plt.show() + +# %% [markdown] +# ### Cell-level Timeline +# +# This chart shows every individual cell task, with the y-axis representing +# row indices. Each colored bar is one (column, row) task. You can see how +# `synthesis` cells start filling in from the top (row 0) while the +# independent columns are still processing later rows — this is the +# cell-level pipelining that the async scheduler enables. + +# %% +if traces: + # Collect cell-by-cell tasks (row_index is not None) + cell_tasks = [t for t in traces if t.row_index is not None and t.slot_acquired_at and t.completed_at] + # Also include full_column tasks, drawn as spanning all rows + batch_tasks = [t for t in traces if t.row_index is None and t.slot_acquired_at and t.completed_at] + + if cell_tasks: + row_indices = sorted({t.row_index for t in cell_tasks}) + columns_ordered = wide_graph.get_topological_order() + cmap_cell = plt.get_cmap("tab10") + col_colors = {col: cmap_cell(i) for i, col in enumerate(columns_ordered)} + n_cols = len(columns_ordered) + + bar_height = 0.8 / n_cols # subdivide each row's vertical space + + fig, ax = plt.subplots(figsize=(12, max(3, len(row_indices) * 0.6))) + + for t in cell_tasks: + col_offset = columns_ordered.index(t.column) + y = t.row_index + (col_offset - n_cols / 2 + 0.5) * bar_height + start = t.slot_acquired_at - t0 + duration = t.completed_at - t.slot_acquired_at + alpha = 0.4 if t.status == "error" else 0.85 + ax.barh( + y, + duration, + left=start, + height=bar_height * 0.9, + color=col_colors[t.column], + alpha=alpha, + edgecolor="white", + linewidth=0.3, + ) + + # Show full_column tasks as a thin bar spanning all rows + for t in batch_tasks: + start = t.slot_acquired_at - t0 + duration = t.completed_at - t.slot_acquired_at + ax.barh( + len(row_indices) / 2 - 0.5, + duration, + left=start, + height=len(row_indices) * 0.05, + color=col_colors.get(t.column, "gray"), + alpha=0.3, + edgecolor="none", + ) + + ax.set_yticks(row_indices) + ax.set_yticklabels([f"row {r}" for r in row_indices]) + ax.set_xlabel("Time (s)") + ax.set_title("Cell-level Timeline — each bar is one (column, row) task") + ax.set_xlim(0, total_secs * 1.02) + ax.invert_yaxis() + + patches = [mpatches.Patch(color=col_colors[c], label=c) for c in columns_ordered] + ax.legend(handles=patches, loc="upper right", fontsize="small", framealpha=0.8) + plt.tight_layout() + plt.show() + +# %% [markdown] +# ### Optimization Hints + +# %% +if traces: + print("Optimization hints:") + + # 1. Slowest column by total run time + max_col = max( + by_column, + key=lambda c: sum( + (t.completed_at - t.slot_acquired_at) for t in by_column[c] if t.completed_at and t.slot_acquired_at + ), + ) + max_total = sum( + (t.completed_at - t.slot_acquired_at) for t in by_column[max_col] if t.completed_at and t.slot_acquired_at + ) + print(f" - Slowest column: '{max_col}' ({max_total:.1f}s total LLM time)") + + # 2. Average wait time + all_waits = [(t.slot_acquired_at - t.dispatched_at) for t in traces if t.slot_acquired_at and t.dispatched_at] + avg_wait = sum(all_waits) / len(all_waits) if all_waits else 0 + if avg_wait > 0.5: + print(f" - High avg wait: {avg_wait:.3f}s → consider increasing max_parallel_requests") + else: + print(f" - Low avg wait: {avg_wait:.3f}s → concurrency slots are well-utilized") + + # 3. Column overlap detection + col_spans = {} + for col in by_column: + first = min(t.dispatched_at for t in by_column[col] if t.dispatched_at) + last = max(t.completed_at for t in by_column[col] if t.completed_at) + col_spans[col] = (first, last) + + overlapping_pairs = [] + cols = sorted(col_spans) + for i, c1 in enumerate(cols): + for c2 in cols[i + 1 :]: + s1, e1 = col_spans[c1] + s2, e2 = col_spans[c2] + overlap = max(0, min(e1, e2) - max(s1, s2)) + if overlap > 0: + overlapping_pairs.append((c1, c2, overlap)) + + if overlapping_pairs: + print(f" - Column overlap detected ({len(overlapping_pairs)} pairs):") + for c1, c2, overlap in sorted(overlapping_pairs, key=lambda x: -x[2]): + print(f" {c1} ↔ {c2}: {overlap:.1f}s overlap") + else: + print(" - No column overlap — async scheduler isn't adding parallelism") + print(" → check that columns have independent dependencies") +else: + print("(no traces captured — run with trace=True)") + +# %% [markdown] +# ## Takeaways +# +# 1. **DAG shape matters.** Sequential pipelines (A → B → C) give the async +# scheduler nothing to overlap. Design pipelines with independent branches +# where possible. +# +# 2. **`max_parallel_requests` is the lever.** The default of 4 is +# conservative. If your provider supports higher concurrency, increase it +# to let the scheduler fill slots across columns. +# +# 3. **Traces reveal bottlenecks.** Use `RunConfig(async_trace=True)` and +# inspect `result.task_traces` to see: +# - Which columns are slowest (optimize prompts or use faster models) +# - Whether tasks are waiting for slots (increase `max_parallel_requests`) +# - Whether columns overlap in time (confirms async is helping) +# +# 4. **No regression.** Even when the async scheduler can't parallelize, +# it matches sync performance — the overhead is negligible. diff --git a/demo/async/test_edge_cases.py b/demo/async/test_edge_cases.py new file mode 100644 index 000000000..b200bad85 --- /dev/null +++ b/demo/async/test_edge_cases.py @@ -0,0 +1,236 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Test: edge cases for the async scheduler. + +Sub-tests: +1. Minimum generation (num_records=1) — produces correct dataset and traces. +2. Tiny buffer_size (buffer_size=1) — multiple row groups schedule correctly. +3. Early shutdown — failing column triggers error-rate shutdown; traces capture errors. +""" + +from __future__ import annotations + +import os +import sys +import tempfile + +os.environ["DATA_DESIGNER_ASYNC_ENGINE"] = "1" + +from helpers import ( + check, + create_base_config, + create_data_designer, + create_mock_model_config, + patch_llm_responses, + seed_rng, +) + +from data_designer.config.column_configs import LLMTextColumnConfig +from data_designer.config.run_config import RunConfig + + +def _simple_config(): + config = create_base_config() + config.add_column( + LLMTextColumnConfig( + name="output", + model_alias="openai-text", + prompt="Process seed {{ seed_value }}.", + ) + ) + return config + + +def _config_with_fragile_column(): + """Config where 'fragile' column prompt triggers mock failures.""" + config = create_base_config() + config.add_column( + LLMTextColumnConfig( + name="stable", + model_alias="openai-text", + prompt="Process seed {{ seed_value }}.", + ) + ) + config.add_column( + LLMTextColumnConfig( + name="fragile", + model_alias="openai-text", + prompt="FRAGILE_MARKER: analyze {{ seed_value }}.", + ) + ) + return config + + +# --------------------------------------------------------------------------- +# Sub-test 1: minimum generation +# --------------------------------------------------------------------------- + + +def test_single_record() -> None: + print("\n--- Sub-test 1: Single record (num_records=1) ---") + + with tempfile.TemporaryDirectory() as artifact_dir: + seed_rng(42) + dd = create_data_designer(artifact_path=artifact_dir, async_trace=True) + config = _simple_config() + + with patch_llm_responses(): + result = dd.create(config, num_records=1, dataset_name="single") + + df = result.load_dataset() + traces = result.task_traces + + print(f" rows={len(df)}, cols={len(df.columns)}, traces={len(traces)}") + + check(len(df) == 1, "Single row produced") + check("seed_value" in df.columns, "seed_value column present") + check("output" in df.columns, "output column present") + check(len(traces) > 0, f"Traces captured ({len(traces)})") + check( + all(t.status == "ok" for t in traces), + "All traces succeeded", + ) + + +# --------------------------------------------------------------------------- +# Sub-test 2: tiny buffer_size +# --------------------------------------------------------------------------- + + +def test_tiny_buffer() -> None: + # buffer_size=1 with 8 records creates 8 row groups, well above + # the default max_concurrent_row_groups=3. This exercises the + # scheduler's streaming admission (row groups admitted as slots free). + num_records = 8 + print(f"\n--- Sub-test 2: Tiny buffer_size (buffer_size=1, num_records={num_records}) ---") + + with tempfile.TemporaryDirectory() as artifact_dir: + seed_rng(42) + _, provider = create_mock_model_config() + + from data_designer.interface import DataDesigner + + dd = DataDesigner(artifact_path=artifact_dir, model_providers=[provider]) + dd.set_run_config( + RunConfig( + buffer_size=1, + disable_early_shutdown=True, + async_trace=True, + ) + ) + + config = _simple_config() + + with patch_llm_responses(): + result = dd.create(config, num_records=num_records, dataset_name="tiny-buffer") + + df = result.load_dataset() + traces = result.task_traces + + row_groups_seen = {t.row_group for t in traces} + + print(f" rows={len(df)}, traces={len(traces)}, row_groups={sorted(row_groups_seen)}") + + check(len(df) == num_records, f"All {num_records} rows produced") + check(len(row_groups_seen) == num_records, f"Got {num_records} distinct row groups") + check( + all(t.status == "ok" for t in traces), + "All traces succeeded", + ) + + +# --------------------------------------------------------------------------- +# Sub-test 3: early shutdown via error rate +# --------------------------------------------------------------------------- + + +def test_early_shutdown() -> None: + print("\n--- Sub-test 3: Early shutdown from error rate ---") + + with tempfile.TemporaryDirectory() as artifact_dir: + seed_rng(42) + _, provider = create_mock_model_config() + + from data_designer.interface import DataDesigner + from data_designer.interface.errors import DataDesignerGenerationError, DataDesignerProfilingError + + dd = DataDesigner(artifact_path=artifact_dir, model_providers=[provider]) + dd.set_run_config( + RunConfig( + buffer_size=1000, + disable_early_shutdown=False, + shutdown_error_rate=0.3, + shutdown_error_window=5, + async_trace=True, + ) + ) + + # Capture the builder so we can read traces even if profiling fails + # (early shutdown may produce no data → no parquet → profiling error). + captured_builder = None + original_create = dd._create_dataset_builder + + def _capture_builder(config, rp): + nonlocal captured_builder + captured_builder = original_create(config, rp) + return captured_builder + + dd._create_dataset_builder = _capture_builder + + config = _config_with_fragile_column() + num_records = 16 + + try: + with patch_llm_responses(fail_pattern="FRAGILE_MARKER", fail_rate=1.0): + result = dd.create(config, num_records=num_records, dataset_name="shutdown") + traces = result.task_traces + except (DataDesignerGenerationError, DataDesignerProfilingError): + # Early shutdown with 100% failures leaves no data — profiling fails. + traces = captured_builder.task_traces if captured_builder else [] + finally: + dd._create_dataset_builder = original_create + + error_traces = [t for t in traces if t.status == "error"] + ok_traces = [t for t in traces if t.status == "ok"] + + print(f" traces={len(traces)}, ok={len(ok_traces)}, errors={len(error_traces)}") + + check(len(traces) > 0, "Traces were captured") + check(len(error_traces) > 0, f"Error traces present ({len(error_traces)})") + check( + any(t.column == "fragile" for t in error_traces), + "Errors are on the 'fragile' column", + ) + + # With 100% fail rate on fragile column and low error window, + # the scheduler should have stopped before completing everything. + total_expected = num_records * 3 # seed_value + stable + fragile, each cell-by-cell + check( + len(traces) < total_expected, + f"Early shutdown: {len(traces)} traces < {total_expected} expected", + ) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main() -> None: + print("=" * 60) + print("Test: Edge Cases") + print("=" * 60) + + test_single_record() + test_tiny_buffer() + test_early_shutdown() + + print("\nDone.") + + +if __name__ == "__main__": + try: + main() + except AssertionError: + sys.exit(1) diff --git a/demo/async/test_execution_order.py b/demo/async/test_execution_order.py new file mode 100644 index 000000000..d09aa5c09 --- /dev/null +++ b/demo/async/test_execution_order.py @@ -0,0 +1,150 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Test: async scheduler respects DAG execution order. + +Builds a 4-column DAG: + + seed_value (sampler, no deps) + ├── stage1 (LLM, depends on seed_value) + │ └── stage2 (LLM, depends on stage1) + └── independent (LLM, depends on seed_value only) + +Verifies via TaskTraces that: +1. All seed tasks finish before any downstream task dispatches. +2. All stage1 tasks finish before any stage2 task dispatches. +3. "independent" tasks can overlap with stage1 (no dependency between them). +""" + +from __future__ import annotations + +import os +import sys +import tempfile + +os.environ["DATA_DESIGNER_ASYNC_ENGINE"] = "1" + +from helpers import ( + check, + create_base_config, + create_data_designer, + patch_llm_responses, + seed_rng, +) + +from data_designer.config.column_configs import LLMTextColumnConfig + +NUM_RECORDS = 8 + + +def _build_config(): + config = create_base_config() + config.add_column( + LLMTextColumnConfig( + name="stage1", + model_alias="openai-text", + prompt="Summarize seed {{ seed_value }}.", + ) + ) + config.add_column( + LLMTextColumnConfig( + name="stage2", + model_alias="openai-text", + prompt="Analyze {{ stage1 }}.", + ) + ) + config.add_column( + LLMTextColumnConfig( + name="independent", + model_alias="openai-text", + prompt="Independent thought on {{ seed_value }}.", + ) + ) + return config + + +def _traces_for(traces: list, column: str) -> list: + return [t for t in traces if t.column == column] + + +def main() -> None: + print("=" * 60) + print("Test: Execution Order Respects DAG") + print("=" * 60) + + with tempfile.TemporaryDirectory() as artifact_dir: + seed_rng(42) + dd = create_data_designer(artifact_path=artifact_dir, async_trace=True) + config = _build_config() + + with patch_llm_responses(): + result = dd.create(config, num_records=NUM_RECORDS, dataset_name="order") + + df = result.load_dataset() + traces = result.task_traces + + print(f"\n Dataset: {len(df)} rows, {len(df.columns)} cols") + print(f" Traces: {len(traces)} total") + + # Group traces by column + seed_traces = _traces_for(traces, "seed_value") + stage1_traces = _traces_for(traces, "stage1") + stage2_traces = _traces_for(traces, "stage2") + indep_traces = _traces_for(traces, "independent") + + print( + f" Per column: seed_value={len(seed_traces)}, stage1={len(stage1_traces)}, " + f"stage2={len(stage2_traces)}, independent={len(indep_traces)}" + ) + + # --- Assertion 1: seeds finish before downstream dispatches --- + print("\nResults:") + if seed_traces: + latest_seed_completion = max(t.completed_at for t in seed_traces) + downstream = stage1_traces + stage2_traces + indep_traces + if downstream: + earliest_downstream_dispatch = min(t.dispatched_at for t in downstream) + check( + latest_seed_completion <= earliest_downstream_dispatch, + "All seed tasks completed before first downstream dispatch", + ) + else: + check(False, "Expected downstream traces but found none") + else: + check(False, "Expected seed traces but found none") + + # --- Assertion 2: stage1 finishes before stage2 dispatches --- + if stage1_traces and stage2_traces: + latest_stage1 = max(t.completed_at for t in stage1_traces) + earliest_stage2 = min(t.dispatched_at for t in stage2_traces) + check( + latest_stage1 <= earliest_stage2, + "All stage1 tasks completed before first stage2 dispatch", + ) + + # --- Assertion 3: independent dispatched before stage2 --- + # With zero-latency mocks there's no temporal overlap — tasks complete + # instantly when the event loop schedules them. Instead we verify + # that independent tasks were dispatched in the same "wave" as stage1 + # (between seed completion and stage2 dispatch), proving the scheduler + # treats them as independent of stage1. + if indep_traces and stage2_traces: + latest_indep_dispatch = max(t.dispatched_at for t in indep_traces) + earliest_stage2_dispatch = min(t.dispatched_at for t in stage2_traces) + check( + latest_indep_dispatch <= earliest_stage2_dispatch, + "All independent tasks dispatched before first stage2 dispatch", + ) + + # --- Assertion 4: all traces succeeded --- + error_traces = [t for t in traces if t.status == "error"] + check(len(error_traces) == 0, f"No error traces (found {len(error_traces)})") + + print("\nDone.") + + +if __name__ == "__main__": + try: + main() + except AssertionError: + sys.exit(1) diff --git a/demo/async/test_sync_vs_async_parity.py b/demo/async/test_sync_vs_async_parity.py new file mode 100644 index 000000000..820794633 --- /dev/null +++ b/demo/async/test_sync_vs_async_parity.py @@ -0,0 +1,158 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Test: sync builder vs async scheduler produce structurally identical datasets. + +Uses real LLMs (openai-text). Compares: +- Row / column counts +- Non-null values in every cell +- TaskTrace availability (async only) +- Wall-clock timing +""" + +from __future__ import annotations + +import os +import sys +import tempfile +import time + +# Enable async engine before any data_designer imports. +os.environ["DATA_DESIGNER_ASYNC_ENGINE"] = "1" + +from helpers import check, seed_rng + +from data_designer.config.column_configs import LLMTextColumnConfig, SamplerColumnConfig +from data_designer.config.config_builder import DataDesignerConfigBuilder +from data_designer.config.run_config import RunConfig +from data_designer.config.sampler_params import CategorySamplerParams, SamplerType +from data_designer.interface import DataDesigner + +NUM_RECORDS = 5 + + +def _build_config() -> DataDesignerConfigBuilder: + config = DataDesignerConfigBuilder() + config.add_column( + SamplerColumnConfig( + name="topic", + sampler_type=SamplerType.CATEGORY, + params=CategorySamplerParams(values=["science", "history", "art"]), + ) + ) + config.add_column( + LLMTextColumnConfig( + name="summary", + model_alias="openai-text", + prompt="Write a one-sentence summary about {{ topic }}.", + ) + ) + config.add_column( + LLMTextColumnConfig( + name="followup", + model_alias="openai-text", + prompt="Given this summary: '{{ summary }}', write a follow-up question.", + ) + ) + return config + + +def _print_trace_table(traces: list) -> None: + if not traces: + print(" (no traces)") + return + + header = f" {'column':<12} {'rg':>3} {'row':>4} {'type':<13} {'wait':>7} {'run':>7} {'status':<6} {'error'}" + print(header) + print(f" {'-' * len(header.strip())}") + + for t in sorted(traces, key=lambda t: t.dispatched_at): + wait_ms = (t.slot_acquired_at - t.dispatched_at) * 1000 if t.slot_acquired_at and t.dispatched_at else 0 + run_ms = (t.completed_at - t.slot_acquired_at) * 1000 if t.completed_at and t.slot_acquired_at else 0 + row_str = str(t.row_index) if t.row_index is not None else "-" + err_str = (t.error or "")[:40] + print( + f" {t.column:<12} {t.row_group:>3} {row_str:>4} {t.task_type:<13} " + f"{wait_ms:>6.0f}ms {run_ms:>6.0f}ms {t.status:<6} {err_str}" + ) + + +def main() -> None: + print("=" * 60) + print("Test: Sync vs Async Parity (real LLMs)") + print("=" * 60) + + with tempfile.TemporaryDirectory() as async_dir, tempfile.TemporaryDirectory() as sync_dir: + # --- Sync run first (to warm up health checks, caches, etc.) --- + import data_designer.engine.dataset_builders.column_wise_builder as cwb + + cwb.DATA_DESIGNER_ASYNC_ENGINE = False + + print(f"\n[1/2] Running sync builder ({NUM_RECORDS} records)...") + seed_rng(42) + dd_sync = DataDesigner(artifact_path=sync_dir) + dd_sync.set_run_config(RunConfig(buffer_size=1000, disable_early_shutdown=True, async_trace=False)) + config_sync = _build_config() + + t0 = time.perf_counter() + result_sync = dd_sync.create(config_sync, num_records=NUM_RECORDS, dataset_name="sync") + sync_time = time.perf_counter() - t0 + + df_sync = result_sync.load_dataset() + + print(f" rows={len(df_sync)}, cols={len(df_sync.columns)}, time={sync_time:.3f}s") + + # --- Async run second --- + cwb.DATA_DESIGNER_ASYNC_ENGINE = True + + print(f"\n[2/2] Running async scheduler ({NUM_RECORDS} records)...") + seed_rng(42) + dd_async = DataDesigner(artifact_path=async_dir) + dd_async.set_run_config(RunConfig(buffer_size=1000, disable_early_shutdown=True, async_trace=True)) + config_async = _build_config() + + t0 = time.perf_counter() + result_async = dd_async.create(config_async, num_records=NUM_RECORDS, dataset_name="async") + async_time = time.perf_counter() - t0 + + df_async = result_async.load_dataset() + traces_async = result_async.task_traces + + print( + f" rows={len(df_async)}, cols={len(df_async.columns)}, time={async_time:.3f}s, traces={len(traces_async)}" + ) + + print("\n Async traces:") + _print_trace_table(traces_async) + + # --- Assertions --- + print("\nResults:") + check(len(df_async) == NUM_RECORDS, f"Async produced {NUM_RECORDS} rows") + check(len(df_sync) == NUM_RECORDS, f"Sync produced {NUM_RECORDS} rows") + check(len(df_async) == len(df_sync), f"Row counts match ({len(df_async)})") + check( + set(df_async.columns) == set(df_sync.columns), + f"Column sets match ({sorted(df_async.columns)})", + ) + check( + df_async.notna().all().all(), + "Async: no null values", + ) + check( + df_sync.notna().all().all(), + "Sync: no null values", + ) + check(len(traces_async) > 0, f"Async produced {len(traces_async)} TaskTraces") + check(len(result_sync.task_traces) == 0, "Sync produced no TaskTraces") + + speedup = sync_time / async_time if async_time > 0 else float("inf") + print(f"\n Timing: sync={sync_time:.3f}s async={async_time:.3f}s ratio={speedup:.2f}x") + + print("\nDone.") + + +if __name__ == "__main__": + try: + main() + except AssertionError: + sys.exit(1) From cbbf16255a6137c3d5f08480e6ccbc8d438aee7a Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Wed, 18 Mar 2026 14:49:08 +0000 Subject: [PATCH 03/23] fix: address all PR review comments on async builder integration - Wire on_batch_complete through on_row_group_complete callback - Mark trailing slots as dropped in replace_dataframe when processor filters rows - Ensure checkpoint still runs when on_before_checkpoint raises - Gate non-seed task dispatch on pre-batch completion - Add public run_pre_batch_on_df to ProcessorRunner (replaces private _run_stage call) - Add public is_column_complete_for_rg to CompletionTracker (replaces private _completed access) - Type task_traces as list[TaskTrace] in results.py - Add async_trace docstring to RunConfig - Move module-level log into _build_async - Add replace_dataframe unit tests (same-size, dropped rows, fewer rows) - Assert on public outcomes in scheduler tests instead of private attributes - Parametrize allow_resize validation tests - Cache seed_cols before main loop - Remove redundant disable_early_shutdown from AsyncTaskScheduler --- .../src/data_designer/config/run_config.py | 2 + .../dataset_builders/async_scheduler.py | 23 +++++--- .../dataset_builders/column_wise_builder.py | 9 ++-- .../utils/completion_tracker.py | 4 ++ .../utils/processor_runner.py | 4 ++ .../utils/row_group_buffer.py | 7 ++- .../test_async_builder_integration.py | 41 ++++++++------- .../dataset_builders/test_async_scheduler.py | 25 +++++++-- .../utils/test_row_group_buffer.py | 52 +++++++++++++++++++ .../src/data_designer/interface/results.py | 6 ++- 10 files changed, 137 insertions(+), 36 deletions(-) diff --git a/packages/data-designer-config/src/data_designer/config/run_config.py b/packages/data-designer-config/src/data_designer/config/run_config.py index 37799fae6..d1e378134 100644 --- a/packages/data-designer-config/src/data_designer/config/run_config.py +++ b/packages/data-designer-config/src/data_designer/config/run_config.py @@ -33,6 +33,8 @@ class RunConfig(ConfigBase): max_conversation_correction_steps: Maximum number of correction rounds permitted within a single conversation when generation tasks call `ModelFacade.generate(...)`. Must be >= 0. Default is 0. + async_trace: If True, collect per-task tracing data when using the async engine + (DATA_DESIGNER_ASYNC_ENGINE=1). Has no effect on the sync path. Default is False. """ disable_early_shutdown: bool = False diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py index ea55ca601..8b7499c0f 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py @@ -50,7 +50,6 @@ def __init__( on_before_checkpoint: Callable[[int, int], None] | None = None, shutdown_error_rate: float = 0.5, shutdown_error_window: int = 10, - disable_early_shutdown: bool = False, trace: bool = False, ) -> None: self._generators = generators @@ -70,8 +69,8 @@ def __init__( self._on_seeds_complete = on_seeds_complete self._on_before_checkpoint = on_before_checkpoint - # Error rate shutdown - self._shutdown_error_rate = 1.0 if disable_early_shutdown else shutdown_error_rate + # Error rate shutdown (caller passes pre-normalized values via RunConfig) + self._shutdown_error_rate = shutdown_error_rate self._shutdown_error_window = shutdown_error_window self._early_shutdown = False @@ -126,6 +125,8 @@ async def _admit_row_groups(self) -> None: async def run(self) -> None: """Main scheduler loop.""" all_columns = self._graph.columns + seed_cols = frozenset(c for c in all_columns if not self._graph.get_upstream_columns(c)) + has_pre_batch = self._on_seeds_complete is not None # Launch admission as a background task so it interleaves with dispatch. admission_task = asyncio.create_task(self._admit_row_groups()) @@ -133,7 +134,7 @@ async def run(self) -> None: # Main dispatch loop while True: if self._early_shutdown: - logger.warning("Early shutdown triggered — error rate exceeded threshold") + logger.warning("Early shutdown triggered - error rate exceeded threshold") break self._wake_event.clear() @@ -141,8 +142,7 @@ async def run(self) -> None: # Run pre-batch callbacks for row groups whose seeds just completed for rg_id, rg_size in self._active_rgs: if rg_id in self._seeds_dispatched_rgs and rg_id not in self._pre_batch_done_rgs: - seed_cols = {c for c in all_columns if not self._graph.get_upstream_columns(c)} - all_seeds_done = all(col in self._tracker._completed.get(rg_id, {}) for col in seed_cols) + all_seeds_done = all(self._tracker.is_column_complete_for_rg(col, rg_id) for col in seed_cols) if all_seeds_done and not self._in_flight_for_rg(rg_id): self._pre_batch_done_rgs.add(rg_id) if self._on_seeds_complete: @@ -156,6 +156,9 @@ async def run(self) -> None: self._buffer_manager.drop_row(rg_id, ri) ready = self._tracker.get_ready_tasks(self._dispatched, self._admitted_rg_ids) + # Gate non-seed tasks on pre-batch completion when a pre-batch callback is configured + if has_pre_batch: + ready = [t for t in ready if t.row_group in self._pre_batch_done_rgs or t.column in seed_cols] for task in ready: await self._submission_semaphore.acquire() self._dispatched.add(task) @@ -260,7 +263,13 @@ def _checkpoint_completed_row_groups(self, all_columns: list[str]) -> None: self._active_rgs.remove((rg_id, rg_size)) try: if self._on_before_checkpoint: - self._on_before_checkpoint(rg_id, rg_size) + try: + self._on_before_checkpoint(rg_id, rg_size) + except Exception: + logger.error( + f"on_before_checkpoint failed for row group {rg_id}, checkpointing un-processed data.", + exc_info=True, + ) if self._buffer_manager is not None: self._buffer_manager.checkpoint_row_group(rg_id) if self._on_row_group_complete: diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py index 2eac85278..bc03effbb 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -72,7 +72,6 @@ from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph from data_designer.engine.dataset_builders.utils.row_group_buffer import RowGroupBufferManager - logger.info("⚡ DATA_DESIGNER_ASYNC_ENGINE is enabled — using async task-queue builder") _CLIENT_VERSION: str = get_library_version() @@ -227,6 +226,8 @@ def _build_async( on_batch_complete: Callable[[Path], None] | None = None, ) -> None: """Async task-queue builder path — dispatches tasks based on dependency readiness.""" + logger.info("⚡ DATA_DESIGNER_ASYNC_ENGINE is enabled - using async task-queue builder") + # Build strategy map from generators strategies: dict[str, GenerationStrategy] = {} gen_map: dict[str, ColumnGenerator] = {} @@ -265,7 +266,7 @@ def on_seeds_complete(rg_id: int, rg_size: int) -> None: if not self._processor_runner.has_processors_for(ProcessorStage.PRE_BATCH): return df = buffer_manager.get_dataframe(rg_id) - df = self._processor_runner._run_stage(df, ProcessorStage.PRE_BATCH) + df = self._processor_runner.run_pre_batch_on_df(df) buffer_manager.replace_dataframe(rg_id, df) # Post-batch processor callback: runs after all columns, before checkpoint. @@ -288,9 +289,11 @@ def on_before_checkpoint(rg_id: int, rg_size: int) -> None: buffer_manager=buffer_manager, on_seeds_complete=on_seeds_complete, on_before_checkpoint=on_before_checkpoint, + on_row_group_complete=lambda rg_id: on_batch_complete(self.artifact_storage.final_dataset_path) + if on_batch_complete + else None, shutdown_error_rate=settings.shutdown_error_rate, shutdown_error_window=settings.shutdown_error_window, - disable_early_shutdown=settings.disable_early_shutdown, trace=trace_enabled, ) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py index 80ee1b9de..01631dfd4 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py @@ -78,6 +78,10 @@ def is_all_complete(self, cells: list[SliceRef]) -> bool: return False return True + def is_column_complete_for_rg(self, column: str, row_group: int) -> bool: + """Check if *column* has been completed (batch or all cells) for *row_group*.""" + return column in self._completed.get(row_group, {}) + def drop_row(self, row_group: int, row_index: int) -> None: self._validate_row_group(row_group) self._dropped[row_group].add(row_index) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/processor_runner.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/processor_runner.py index 8163c807a..61bba8919 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/processor_runner.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/processor_runner.py @@ -72,6 +72,10 @@ def run_pre_batch(self, batch_manager: DatasetBatchManager) -> None: df = self._run_stage(df, ProcessorStage.PRE_BATCH) batch_manager.replace_buffer(df.to_dict(orient="records"), allow_resize=True) + def run_pre_batch_on_df(self, df: pd.DataFrame) -> pd.DataFrame: + """Run PRE_BATCH processors on a DataFrame and return the result.""" + return self._run_stage(df, ProcessorStage.PRE_BATCH) + def run_post_batch(self, df: pd.DataFrame, current_batch_number: int | None) -> pd.DataFrame: """Run process_after_batch() on processors that implement it.""" return self._run_stage(df, ProcessorStage.POST_BATCH, current_batch_number=current_batch_number) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/row_group_buffer.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/row_group_buffer.py index 3a6c86165..b20ce3aca 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/row_group_buffer.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/row_group_buffer.py @@ -69,7 +69,10 @@ def get_dataframe(self, row_group: int) -> pd.DataFrame: return lazy.pd.DataFrame(rows) def replace_dataframe(self, row_group: int, df: pd.DataFrame) -> None: - """Replace the buffer for a row group from a DataFrame (non-dropped rows only).""" + """Replace the buffer for a row group from a DataFrame (non-dropped rows only). + + If *df* has fewer rows than active slots, trailing slots are marked as dropped. + """ dropped = self._dropped.get(row_group, set()) records = df.to_dict(orient="records") buf_idx = 0 @@ -78,6 +81,8 @@ def replace_dataframe(self, row_group: int, df: pd.DataFrame) -> None: continue if buf_idx < len(records): self._buffers[row_group][ri] = records[buf_idx] + else: + self._dropped.setdefault(row_group, set()).add(ri) buf_idx += 1 def drop_row(self, row_group: int, row_index: int) -> None: diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py index 211c6c76c..0ae343b2e 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py @@ -73,29 +73,32 @@ def generate(self, data: lazy.pd.DataFrame) -> lazy.pd.DataFrame: # -- allow_resize validation test --------------------------------------------- -def test_validate_async_compatibility_raises_on_allow_resize() -> None: - """allow_resize=True with async scheduler raises DatasetGenerationError.""" +@pytest.mark.parametrize( + "configs,should_raise", + [ + pytest.param( + [Mock(name="col_a", allow_resize=True), Mock(name="col_b", allow_resize=False)], + True, + id="raises_on_allow_resize", + ), + pytest.param( + [Mock(name="col_a", allow_resize=False), Mock(name="col_b", allow_resize=False)], + False, + id="passes_without_allow_resize", + ), + ], +) +def test_validate_async_compatibility(configs: list[Mock], should_raise: bool) -> None: + """Validation rejects allow_resize=True with the async engine.""" builder = Mock(spec=ColumnWiseDatasetBuilder) - builder.single_column_configs = [ - Mock(name="col_a", allow_resize=True), - Mock(name="col_b", allow_resize=False), - ] - # Call the unbound method on the mock - with pytest.raises(DatasetGenerationError, match="allow_resize=True"): + builder.single_column_configs = configs + if should_raise: + with pytest.raises(DatasetGenerationError, match="allow_resize=True"): + ColumnWiseDatasetBuilder._validate_async_compatibility(builder) + else: ColumnWiseDatasetBuilder._validate_async_compatibility(builder) -def test_validate_async_compatibility_passes_without_allow_resize() -> None: - """No allow_resize=True passes validation.""" - builder = Mock(spec=ColumnWiseDatasetBuilder) - builder.single_column_configs = [ - Mock(name="col_a", allow_resize=False), - Mock(name="col_b", allow_resize=False), - ] - # Should not raise - ColumnWiseDatasetBuilder._validate_async_compatibility(builder) - - # -- _build_async integration test with mock generators ----------------------- diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py index bffb4f453..0cabcffda 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py @@ -580,22 +580,31 @@ async def test_scheduler_error_rate_shutdown() -> None: row_groups = [(0, 10)] tracker = CompletionTracker.with_graph(graph, row_groups) + storage = MagicMock() + storage.dataset_name = "test" + storage.get_file_paths.return_value = {} + storage.write_batch_to_parquet_file.return_value = "/fake.parquet" + storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet" + buffer_mgr = RowGroupBufferManager(storage) + scheduler = AsyncTaskScheduler( generators=generators, graph=graph, tracker=tracker, row_groups=row_groups, + buffer_manager=buffer_mgr, shutdown_error_rate=0.5, shutdown_error_window=2, ) await scheduler.run() - assert scheduler._early_shutdown is True + # Early shutdown: not all rows should be checkpointed (some row groups incomplete) + assert buffer_mgr.actual_num_records < 10 @pytest.mark.asyncio(loop_scope="session") async def test_scheduler_early_shutdown_disabled() -> None: - """disable_early_shutdown=True prevents shutdown even at 100% error rate.""" + """shutdown_error_rate=1.0 prevents shutdown even at 100% error rate.""" provider = _mock_provider() configs = [ SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), @@ -614,16 +623,24 @@ async def test_scheduler_early_shutdown_disabled() -> None: row_groups = [(0, 5)] tracker = CompletionTracker.with_graph(graph, row_groups) + storage = MagicMock() + storage.dataset_name = "test" + storage.get_file_paths.return_value = {} + buffer_mgr = RowGroupBufferManager(storage) + scheduler = AsyncTaskScheduler( generators=generators, graph=graph, tracker=tracker, row_groups=row_groups, - disable_early_shutdown=True, + buffer_manager=buffer_mgr, + shutdown_error_rate=1.0, ) await scheduler.run() - assert scheduler._early_shutdown is False + # All rows dropped (all fail) but no early shutdown - all row groups processed + assert all(tracker.is_dropped(0, ri) for ri in range(5)) + assert tracker.is_row_group_complete(0, 5, ["seed", "fail_col"]) @pytest.mark.asyncio(loop_scope="session") diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_row_group_buffer.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_row_group_buffer.py index 50819ee64..189ddcac7 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_row_group_buffer.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_row_group_buffer.py @@ -7,6 +7,7 @@ import pytest +import data_designer.lazy_heavy_imports as lazy from data_designer.engine.dataset_builders.utils.row_group_buffer import RowGroupBufferManager @@ -124,6 +125,57 @@ def test_checkpoint_calls_on_complete() -> None: callback.assert_called_once_with("/fake/final.parquet") +def test_replace_dataframe_same_size() -> None: + """replace_dataframe with same number of rows replaces data in-place.""" + mgr = RowGroupBufferManager(_mock_artifact_storage()) + mgr.init_row_group(0, 3) + mgr.update_batch(0, "col", ["a", "b", "c"]) + + df = lazy.pd.DataFrame({"col": ["x", "y", "z"]}) + mgr.replace_dataframe(0, df) + + assert mgr.get_row(0, 0) == {"col": "x"} + assert mgr.get_row(0, 1) == {"col": "y"} + assert mgr.get_row(0, 2) == {"col": "z"} + + +def test_replace_dataframe_with_dropped_rows() -> None: + """replace_dataframe skips dropped rows and replaces only active slots.""" + mgr = RowGroupBufferManager(_mock_artifact_storage()) + mgr.init_row_group(0, 4) + mgr.update_batch(0, "col", ["a", "b", "c", "d"]) + mgr.drop_row(0, 1) # drop row 1 + + # 3 active rows: indices 0, 2, 3 + df = lazy.pd.DataFrame({"col": ["x", "y", "z"]}) + mgr.replace_dataframe(0, df) + + assert mgr.get_row(0, 0) == {"col": "x"} + assert mgr.is_dropped(0, 1) + assert mgr.get_row(0, 2) == {"col": "y"} + assert mgr.get_row(0, 3) == {"col": "z"} + + +def test_replace_dataframe_fewer_rows_marks_trailing_dropped() -> None: + """replace_dataframe with fewer rows marks trailing active slots as dropped.""" + mgr = RowGroupBufferManager(_mock_artifact_storage()) + mgr.init_row_group(0, 4) + mgr.update_batch(0, "col", ["a", "b", "c", "d"]) + + # Only 2 rows - should drop indices 2 and 3 + df = lazy.pd.DataFrame({"col": ["x", "y"]}) + mgr.replace_dataframe(0, df) + + assert mgr.get_row(0, 0) == {"col": "x"} + assert mgr.get_row(0, 1) == {"col": "y"} + assert mgr.is_dropped(0, 2) + assert mgr.is_dropped(0, 3) + + # get_dataframe should only return the 2 active rows + result_df = mgr.get_dataframe(0) + assert len(result_df) == 2 + + def test_checkpoint_calls_on_complete_when_all_rows_dropped() -> None: storage = _mock_artifact_storage() callback = Mock() diff --git a/packages/data-designer/src/data_designer/interface/results.py b/packages/data-designer/src/data_designer/interface/results.py index 5c829e16d..07692ff00 100644 --- a/packages/data-designer/src/data_designer/interface/results.py +++ b/packages/data-designer/src/data_designer/interface/results.py @@ -17,6 +17,8 @@ if TYPE_CHECKING: import pandas as pd + from data_designer.engine.dataset_builders.utils.task_model import TaskTrace + class DatasetCreationResults(WithRecordSamplerMixin): """Results container for a Data Designer dataset creation run. @@ -33,7 +35,7 @@ def __init__( analysis: DatasetProfilerResults, config_builder: DataDesignerConfigBuilder, dataset_metadata: DatasetMetadata, - task_traces: list | None = None, + task_traces: list[TaskTrace] | None = None, ): """Creates a new instance with results based on a dataset creation run. @@ -48,7 +50,7 @@ def __init__( self._analysis = analysis self._config_builder = config_builder self.dataset_metadata = dataset_metadata - self.task_traces: list = task_traces or [] + self.task_traces: list[TaskTrace] = task_traces or [] def load_analysis(self) -> DatasetProfilerResults: """Load the profiling analysis results for the generated dataset. From f06b90dbe525bc2ed0540acfd3cdccf9f73364b8 Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Wed, 18 Mar 2026 14:52:45 +0000 Subject: [PATCH 04/23] style: fix ruff format for lambda expression --- .../engine/dataset_builders/column_wise_builder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py index bc03effbb..e1881362a 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -289,9 +289,9 @@ def on_before_checkpoint(rg_id: int, rg_size: int) -> None: buffer_manager=buffer_manager, on_seeds_complete=on_seeds_complete, on_before_checkpoint=on_before_checkpoint, - on_row_group_complete=lambda rg_id: on_batch_complete(self.artifact_storage.final_dataset_path) - if on_batch_complete - else None, + on_row_group_complete=lambda rg_id: ( + on_batch_complete(self.artifact_storage.final_dataset_path) if on_batch_complete else None + ), shutdown_error_rate=settings.shutdown_error_rate, shutdown_error_window=settings.shutdown_error_window, trace=trace_enabled, From 3facfefd49bc0f46339a69f66e604f622012b1fe Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Wed, 18 Mar 2026 16:04:09 +0000 Subject: [PATCH 05/23] fix: address open review issues on async scheduler - Flush completed row groups before breaking on early shutdown (data loss) - Change error rate check from >= to > so disable_early_shutdown sentinel (1.0) never triggers at 100% failure rate - Extract seeds-complete check into helper and call it in salvage rounds via _drain_frontier, with pre-batch gating, so pre-batch processor runs even when seed tasks succeed only after retry - Fix is_column_complete_for_rg to check _batch_complete first, then verify all non-dropped rows for CELL_BY_CELL columns - Replace O(|in-flight|) scan in _in_flight_for_rg with per-RG counter --- .../dataset_builders/async_scheduler.py | 53 ++++++++++++------- .../utils/completion_tracker.py | 11 +++- 2 files changed, 43 insertions(+), 21 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py index 8b7499c0f..921cbc60d 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py @@ -86,6 +86,9 @@ def __init__( if gen.is_order_dependent and id(gen) not in self._stateful_locks: self._stateful_locks[id(gen)] = asyncio.Lock() + # Per-RG in-flight counters for O(1) lookup + self._in_flight_counts: dict[int, int] = {} + # Deferred retryable failures (retried in salvage rounds) self._deferred: list[Task] = [] @@ -135,25 +138,12 @@ async def run(self) -> None: while True: if self._early_shutdown: logger.warning("Early shutdown triggered - error rate exceeded threshold") + self._checkpoint_completed_row_groups(all_columns) break self._wake_event.clear() - # Run pre-batch callbacks for row groups whose seeds just completed - for rg_id, rg_size in self._active_rgs: - if rg_id in self._seeds_dispatched_rgs and rg_id not in self._pre_batch_done_rgs: - all_seeds_done = all(self._tracker.is_column_complete_for_rg(col, rg_id) for col in seed_cols) - if all_seeds_done and not self._in_flight_for_rg(rg_id): - self._pre_batch_done_rgs.add(rg_id) - if self._on_seeds_complete: - try: - self._on_seeds_complete(rg_id, rg_size) - except Exception as exc: - logger.warning(f"Pre-batch processor failed for row group {rg_id}, skipping: {exc}") - for ri in range(rg_size): - self._tracker.drop_row(rg_id, ri) - if self._buffer_manager: - self._buffer_manager.drop_row(rg_id, ri) + self._run_seeds_complete_check(seed_cols) ready = self._tracker.get_ready_tasks(self._dispatched, self._admitted_rg_ids) # Gate non-seed tasks on pre-batch completion when a pre-batch callback is configured @@ -163,6 +153,7 @@ async def run(self) -> None: await self._submission_semaphore.acquire() self._dispatched.add(task) self._in_flight.add(task) + self._in_flight_counts[task.row_group] = self._in_flight_counts.get(task.row_group, 0) + 1 asyncio.create_task(self._execute_task(task)) self._checkpoint_completed_row_groups(all_columns) @@ -223,12 +214,13 @@ async def run(self) -> None: Task(column=task.column, row_group=task.row_group, row_index=None, task_type="batch") ) self._in_flight.add(task) + self._in_flight_counts[task.row_group] = self._in_flight_counts.get(task.row_group, 0) + 1 asyncio.create_task(self._execute_seed_task(task, gid)) else: self._dispatched.discard(task) # Drain: dispatch frontier tasks and any newly-ready downstream tasks # until nothing remains in-flight or in the frontier. - await self._drain_frontier() + await self._drain_frontier(seed_cols, has_pre_batch, all_columns) self._checkpoint_completed_row_groups(all_columns) if self._active_rgs: @@ -238,14 +230,18 @@ async def run(self) -> None: "These row groups were not checkpointed." ) - async def _drain_frontier(self) -> None: + async def _drain_frontier(self, seed_cols: frozenset[str], has_pre_batch: bool, all_columns: list[str]) -> None: """Dispatch all frontier tasks and their downstream until quiescent.""" while True: + self._run_seeds_complete_check(seed_cols) ready = self._tracker.get_ready_tasks(self._dispatched, self._admitted_rg_ids) + if has_pre_batch: + ready = [t for t in ready if t.row_group in self._pre_batch_done_rgs or t.column in seed_cols] for task in ready: await self._submission_semaphore.acquire() self._dispatched.add(task) self._in_flight.add(task) + self._in_flight_counts[task.row_group] = self._in_flight_counts.get(task.row_group, 0) + 1 asyncio.create_task(self._execute_task(task)) if not self._in_flight: break @@ -279,9 +275,26 @@ def _checkpoint_completed_row_groups(self, all_columns: list[str]) -> None: finally: self._rg_semaphore.release() + def _run_seeds_complete_check(self, seed_cols: frozenset[str]) -> None: + """Run pre-batch callbacks for row groups whose seeds just completed.""" + for rg_id, rg_size in self._active_rgs: + if rg_id in self._seeds_dispatched_rgs and rg_id not in self._pre_batch_done_rgs: + all_seeds_done = all(self._tracker.is_column_complete_for_rg(col, rg_id) for col in seed_cols) + if all_seeds_done and not self._in_flight_for_rg(rg_id): + self._pre_batch_done_rgs.add(rg_id) + if self._on_seeds_complete: + try: + self._on_seeds_complete(rg_id, rg_size) + except Exception as exc: + logger.warning(f"Pre-batch processor failed for row group {rg_id}, skipping: {exc}") + for ri in range(rg_size): + self._tracker.drop_row(rg_id, ri) + if self._buffer_manager: + self._buffer_manager.drop_row(rg_id, ri) + def _in_flight_for_rg(self, rg_id: int) -> bool: """Check if any tasks are in-flight for a given row group.""" - return any(t.row_group == rg_id for t in self._in_flight) + return self._in_flight_counts.get(rg_id, 0) > 0 def _check_error_rate(self) -> None: """Trigger early shutdown if error rate exceeds threshold.""" @@ -289,7 +302,7 @@ def _check_error_rate(self) -> None: if completed < self._shutdown_error_window: return error_rate = self._error_count / max(1, completed) - if error_rate >= self._shutdown_error_rate: + if error_rate > self._shutdown_error_rate: self._early_shutdown = True async def _dispatch_seeds(self, rg_id: int, rg_size: int) -> None: @@ -328,6 +341,7 @@ async def _dispatch_seeds(self, rg_id: int, rg_size: int) -> None: ) self._dispatched.add(Task(column=sibling_col, row_group=rg_id, row_index=None, task_type="batch")) self._in_flight.add(task) + self._in_flight_counts[task.row_group] = self._in_flight_counts.get(task.row_group, 0) + 1 asyncio.create_task(self._execute_seed_task(task, gid)) async def _execute_seed_task(self, task: Task, generator_id: int) -> None: @@ -421,6 +435,7 @@ async def _execute_task_inner(self, task: Task) -> None: self.traces.append(trace) self._in_flight.discard(task) + self._in_flight_counts[task.row_group] = self._in_flight_counts.get(task.row_group, 0) - 1 if not retryable and not skipped: self._dispatched.discard(task) self._submission_semaphore.release() diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py index 01631dfd4..f8c974f4a 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py @@ -79,8 +79,15 @@ def is_all_complete(self, cells: list[SliceRef]) -> bool: return True def is_column_complete_for_rg(self, column: str, row_group: int) -> bool: - """Check if *column* has been completed (batch or all cells) for *row_group*.""" - return column in self._completed.get(row_group, {}) + """Check if *column* has been fully completed for *row_group*.""" + if column in self._batch_complete.get(row_group, set()): + return True + rg_size = self._row_group_sizes.get(row_group, 0) + if rg_size == 0: + return False + completed = self._completed.get(row_group, {}).get(column, set()) + dropped = self._dropped.get(row_group, set()) + return all(ri in completed or ri in dropped for ri in range(rg_size)) def drop_row(self, row_group: int, row_index: int) -> None: self._validate_row_group(row_group) From c650a2f7194bd877eac60b3535d3233552d977fe Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Wed, 18 Mar 2026 19:09:54 +0000 Subject: [PATCH 06/23] fix: sync pre-batch row drops to CompletionTracker and restore stderr safely Pre-batch processors that filter rows marked them as dropped in RowGroupBufferManager but not in CompletionTracker, causing unnecessary LLM calls for rows that would be discarded at checkpoint time. Also wrap the benchmark warmup stderr redirect in try/finally so stderr is restored if _run_once raises. --- demo/async/bench_sync_vs_async.py | 13 +++++++------ .../engine/dataset_builders/column_wise_builder.py | 4 ++++ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/demo/async/bench_sync_vs_async.py b/demo/async/bench_sync_vs_async.py index 6e019da81..89a9a9ca8 100644 --- a/demo/async/bench_sync_vs_async.py +++ b/demo/async/bench_sync_vs_async.py @@ -216,12 +216,13 @@ def main() -> None: # --- Warmup: one sync run to prime health checks, caches, etc. --- # First DataDesigner() call configures logging, so suppress via devnull. print("\nWarmup (sync)...", end=" ", flush=True) - devnull = open(os.devnull, "w") # noqa: SIM115 - old_stderr = sys.stderr - sys.stderr = devnull - warmup_time = _run_once(async_mode=False, num_records=num_records, dag=dag, max_parallel=max_parallel) - sys.stderr = old_stderr - devnull.close() + with open(os.devnull, "w") as devnull: + old_stderr = sys.stderr + sys.stderr = devnull + try: + warmup_time = _run_once(async_mode=False, num_records=num_records, dag=dag, max_parallel=max_parallel) + finally: + sys.stderr = old_stderr print(f"{warmup_time:.3f}s") # --- Interleaved trials: ABABABAB to reduce temporal bias --- diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py index e1881362a..0ece93b0f 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -268,6 +268,10 @@ def on_seeds_complete(rg_id: int, rg_size: int) -> None: df = buffer_manager.get_dataframe(rg_id) df = self._processor_runner.run_pre_batch_on_df(df) buffer_manager.replace_dataframe(rg_id, df) + # Sync newly-dropped rows to the tracker so downstream cell tasks are skipped + for ri in range(rg_size): + if buffer_manager.is_dropped(rg_id, ri) and not tracker.is_dropped(rg_id, ri): + tracker.drop_row(rg_id, ri) # Post-batch processor callback: runs after all columns, before checkpoint. def on_before_checkpoint(rg_id: int, rg_size: int) -> None: From 71e741279ba8b1aff73dac42f3e5808dd4c77533 Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Wed, 18 Mar 2026 19:24:29 +0000 Subject: [PATCH 07/23] fix: prune _admitted_rg_ids on row group checkpoint Prevents unbounded growth of the admission set across large runs. --- .../src/data_designer/engine/dataset_builders/async_scheduler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py index 921cbc60d..ce523db38 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py @@ -257,6 +257,7 @@ def _checkpoint_completed_row_groups(self, all_columns: list[str]) -> None: ] for rg_id, rg_size in completed: self._active_rgs.remove((rg_id, rg_size)) + self._admitted_rg_ids.discard(rg_id) try: if self._on_before_checkpoint: try: From d43fc41223935972fa00bb6dcebda58d85405e0e Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Wed, 18 Mar 2026 19:28:12 +0000 Subject: [PATCH 08/23] chore: remove demo/async files from PR Dev-time benchmarks and manual test scripts - kept locally, not needed in the PR. --- demo/async/bench_sync_vs_async.py | 282 - demo/async/helpers.py | 213 - demo/async/notebook_async_benchmark.html | 9380 ---------------------- demo/async/notebook_async_benchmark.py | 704 -- demo/async/test_edge_cases.py | 236 - demo/async/test_execution_order.py | 150 - demo/async/test_sync_vs_async_parity.py | 158 - 7 files changed, 11123 deletions(-) delete mode 100644 demo/async/bench_sync_vs_async.py delete mode 100644 demo/async/helpers.py delete mode 100644 demo/async/notebook_async_benchmark.html delete mode 100644 demo/async/notebook_async_benchmark.py delete mode 100644 demo/async/test_edge_cases.py delete mode 100644 demo/async/test_execution_order.py delete mode 100644 demo/async/test_sync_vs_async_parity.py diff --git a/demo/async/bench_sync_vs_async.py b/demo/async/bench_sync_vs_async.py deleted file mode 100644 index 89a9a9ca8..000000000 --- a/demo/async/bench_sync_vs_async.py +++ /dev/null @@ -1,282 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -"""Benchmark: sync vs async builder with repeated trials. - -Runs each engine N times (default 5), drops the first run of each as warmup, -and reports min / median / mean / max / stdev for the remaining runs. - -Uses real LLMs (openai-text). Set OPENAI_API_KEY in environment. - -Usage: - cd async_test && uv run python bench_sync_vs_async.py - cd async_test && uv run python bench_sync_vs_async.py --trials 7 --num-records 10 -""" - -from __future__ import annotations - -import logging -import math -import os -import sys -import tempfile -import time -import warnings -from argparse import ArgumentParser - -os.environ["DATA_DESIGNER_ASYNC_ENGINE"] = "1" - -warnings.filterwarnings("ignore", message=".*urllib3.*") -warnings.filterwarnings("ignore", message=".*Unclosed.*") - -import data_designer.engine.dataset_builders.column_wise_builder as cwb -from data_designer.config.column_configs import LLMTextColumnConfig, SamplerColumnConfig -from data_designer.config.config_builder import DataDesignerConfigBuilder -from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig -from data_designer.config.run_config import RunConfig -from data_designer.config.sampler_params import CategorySamplerParams, SamplerType -from data_designer.interface import DataDesigner - -NUM_RECORDS_DEFAULT = 5 -TRIALS_DEFAULT = 5 -MAX_PARALLEL_DEFAULT = 4 - -# DAG shapes: -# narrow: topic → summary → followup (sequential, 2 LLM cols) -# wide: topic → summary ─┐ -# topic → analysis ├→ synthesis (3 parallel + 1 merge, 4 LLM cols) -# topic → trivia ──┘ -DAG_DEFAULT = "wide" - - -def _build_config(dag: str, max_parallel: int) -> DataDesignerConfigBuilder: - config = DataDesignerConfigBuilder( - model_configs=[ - ModelConfig( - alias="openai-text", - model="gpt-4.1", - provider="openai", - inference_parameters=ChatCompletionInferenceParams(max_parallel_requests=max_parallel), - ) - ] - ) - config.add_column( - SamplerColumnConfig( - name="topic", - sampler_type=SamplerType.CATEGORY, - params=CategorySamplerParams(values=["science", "history", "art"]), - ) - ) - - if dag == "narrow": - config.add_column( - LLMTextColumnConfig( - name="summary", - model_alias="openai-text", - prompt="Write a one-sentence summary about {{ topic }}.", - ) - ) - config.add_column( - LLMTextColumnConfig( - name="followup", - model_alias="openai-text", - prompt="Given this summary: '{{ summary }}', write a follow-up question.", - ) - ) - elif dag == "wide": - config.add_column( - LLMTextColumnConfig( - name="summary", - model_alias="openai-text", - prompt="Write a one-sentence summary about {{ topic }}.", - ) - ) - config.add_column( - LLMTextColumnConfig( - name="analysis", - model_alias="openai-text", - prompt="Write a one-sentence analysis of why {{ topic }} matters.", - ) - ) - config.add_column( - LLMTextColumnConfig( - name="trivia", - model_alias="openai-text", - prompt="State one surprising fact about {{ topic }}.", - ) - ) - config.add_column( - LLMTextColumnConfig( - name="synthesis", - model_alias="openai-text", - prompt=( - "Combine these into one sentence: " - "summary='{{ summary }}', analysis='{{ analysis }}', trivia='{{ trivia }}'." - ), - ) - ) - else: - raise ValueError(f"Unknown dag shape: {dag!r}. Use 'narrow' or 'wide'.") - - return config - - -def _suppress_logs() -> None: - """Mute all data_designer / litellm / httpx loggers.""" - for name in ("data_designer", "LiteLLM", "litellm", "httpx", "asyncio"): - logging.getLogger(name).setLevel(logging.CRITICAL) - # Also prevent the root logger from propagating data_designer messages. - logging.getLogger().setLevel(logging.CRITICAL) - - -def _run_once(*, async_mode: bool, num_records: int, dag: str, max_parallel: int) -> float: - """Run a single creation and return wall-clock seconds.""" - _suppress_logs() - cwb.DATA_DESIGNER_ASYNC_ENGINE = async_mode - - with tempfile.TemporaryDirectory() as artifact_dir: - dd = DataDesigner(artifact_path=artifact_dir) - dd.set_run_config(RunConfig(buffer_size=1000, disable_early_shutdown=True, async_trace=False)) - config = _build_config(dag, max_parallel) - - t0 = time.perf_counter() - result = dd.create(config, num_records=num_records, dataset_name="bench") - elapsed = time.perf_counter() - t0 - - df = result.load_dataset() - assert len(df) == num_records, f"Expected {num_records} rows, got {len(df)}" - - return elapsed - - -def _stats(times: list[float]) -> dict[str, float]: - """Compute summary statistics for a list of times.""" - n = len(times) - s = sorted(times) - mean = sum(s) / n - if n % 2 == 1: - median = s[n // 2] - else: - median = (s[n // 2 - 1] + s[n // 2]) / 2 - variance = sum((t - mean) ** 2 for t in s) / n - return { - "min": s[0], - "median": median, - "mean": mean, - "max": s[-1], - "stdev": math.sqrt(variance), - } - - -def _print_row(label: str, st: dict[str, float]) -> None: - print( - f" {label:<7} " - f"{st['min']:>7.3f}s " - f"{st['median']:>7.3f}s " - f"{st['mean']:>7.3f}s " - f"{st['max']:>7.3f}s " - f"{st['stdev']:>7.3f}s" - ) - - -def main() -> None: - parser = ArgumentParser(description="Benchmark sync vs async builder") - parser.add_argument( - "--trials", type=int, default=TRIALS_DEFAULT, help="Total trials per engine (including 1 warmup)" - ) - parser.add_argument("--num-records", type=int, default=NUM_RECORDS_DEFAULT, help="Records per trial") - parser.add_argument( - "--dag", - type=str, - default=DAG_DEFAULT, - choices=["narrow", "wide"], - help="DAG shape: narrow (sequential) or wide (parallel branches)", - ) - parser.add_argument( - "--max-parallel", type=int, default=MAX_PARALLEL_DEFAULT, help="max_parallel_requests for the LLM model" - ) - args = parser.parse_args() - - trials: int = args.trials - num_records: int = args.num_records - dag: str = args.dag - max_parallel: int = args.max_parallel - - if trials < 2: - print("Need at least 2 trials (1 warmup + 1 measured). Setting trials=2.") - trials = 2 - - measured = trials - 1 - - print("=" * 70) - print(f"Benchmark: Sync vs Async ({trials} trials, first dropped as warmup)") - print(f"Records per trial: {num_records}, DAG: {dag}, max_parallel: {max_parallel}") - print("=" * 70) - - # --- Warmup: one sync run to prime health checks, caches, etc. --- - # First DataDesigner() call configures logging, so suppress via devnull. - print("\nWarmup (sync)...", end=" ", flush=True) - with open(os.devnull, "w") as devnull: - old_stderr = sys.stderr - sys.stderr = devnull - try: - warmup_time = _run_once(async_mode=False, num_records=num_records, dag=dag, max_parallel=max_parallel) - finally: - sys.stderr = old_stderr - print(f"{warmup_time:.3f}s") - - # --- Interleaved trials: ABABABAB to reduce temporal bias --- - sync_times: list[float] = [] - async_times: list[float] = [] - - for i in range(measured): - trial_num = i + 1 - - # Sync - print(f"\n Trial {trial_num}/{measured} — sync ...", end=" ", flush=True) - t_sync = _run_once(async_mode=False, num_records=num_records, dag=dag, max_parallel=max_parallel) - sync_times.append(t_sync) - print(f"{t_sync:.3f}s") - - # Async - print(f" Trial {trial_num}/{measured} — async...", end=" ", flush=True) - t_async = _run_once(async_mode=True, num_records=num_records, dag=dag, max_parallel=max_parallel) - async_times.append(t_async) - print(f"{t_async:.3f}s") - - # --- Stats --- - sync_st = _stats(sync_times) - async_st = _stats(async_times) - speedup_median = sync_st["median"] / async_st["median"] if async_st["median"] > 0 else float("inf") - speedup_mean = sync_st["mean"] / async_st["mean"] if async_st["mean"] > 0 else float("inf") - - print("\n" + "=" * 70) - print(f"Results ({measured} measured trials, {num_records} records each, dag={dag})") - print("=" * 70) - - header = f" {'engine':<7} {'min':>8} {'median':>8} {'mean':>8} {'max':>8} {'stdev':>8}" - print(header) - print(f" {'-' * (len(header.strip()) - 2)}") - _print_row("sync", sync_st) - _print_row("async", async_st) - - print(f"\n Speedup (median): {speedup_median:.2f}x") - print(f" Speedup (mean): {speedup_mean:.2f}x") - - # Per-trial detail - print("\n Per-trial times:") - print(f" {'trial':>5} {'sync':>8} {'async':>8} {'ratio':>7}") - print(f" {'-' * 33}") - for i, (s, a) in enumerate(zip(sync_times, async_times)): - ratio = s / a if a > 0 else float("inf") - print(f" {i + 1:>5} {s:>7.3f}s {a:>7.3f}s {ratio:>6.2f}x") - - print("\nDone.") - - -if __name__ == "__main__": - try: - main() - except KeyboardInterrupt: - print("\nInterrupted.") - sys.exit(130) diff --git a/demo/async/helpers.py b/demo/async/helpers.py deleted file mode 100644 index d47865da7..000000000 --- a/demo/async/helpers.py +++ /dev/null @@ -1,213 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -"""Shared utilities for async engine test scripts.""" - -from __future__ import annotations - -import contextlib -import hashlib -import json -import random -from collections.abc import Iterator -from dataclasses import dataclass -from typing import Any - -# --------------------------------------------------------------------------- -# Assertion helper -# --------------------------------------------------------------------------- - - -def check(condition: bool, message: str) -> None: - status = "PASS" if condition else "FAIL" - print(f" [{status}] {message}") - if not condition: - raise AssertionError(message) - - -# --------------------------------------------------------------------------- -# Mock LLM responses -# --------------------------------------------------------------------------- - - -@dataclass(frozen=True) -class FakeMessage: - content: str - tool_calls: list[dict[str, Any]] | None = None - reasoning_content: str | None = None - - -@dataclass(frozen=True) -class FakeChoice: - message: FakeMessage - - -@dataclass(frozen=True) -class FakeResponse: - choices: list[FakeChoice] - usage: Any | None = None - model: str | None = None - - -def _stable_seed(model: str, messages: list[dict[str, Any]]) -> int: - payload = json.dumps( - {"model": model, "messages": messages}, - sort_keys=True, - separators=(",", ":"), - ensure_ascii=True, - default=str, - ) - digest = hashlib.sha256(payload.encode()).digest() - return int.from_bytes(digest[:8], "big") - - -def _mock_response_text(model: str, messages: list[dict[str, Any]]) -> str: - rng = random.Random(_stable_seed(model, messages)) - score = rng.uniform(0, 10) - category = rng.choice(["low", "mid", "high"]) - return f"mock:{model}|cat={category}|score={score:.3f}" - - -def _fake_response(model: str, messages: list[dict[str, Any]], **_kwargs: Any) -> FakeResponse: - text = _mock_response_text(model, messages) - return FakeResponse(choices=[FakeChoice(message=FakeMessage(content=text))], model=model) - - -@contextlib.contextmanager -def patch_llm_responses(*, fail_pattern: str | None = None, fail_rate: float = 1.0) -> Iterator[None]: - """Patch LLM completion to return deterministic mock responses. - - Args: - fail_pattern: If set, raise RuntimeError when this string appears in the - serialized messages. Used to simulate LLM failures for specific columns. - fail_rate: Probability of failure when fail_pattern matches (0.0-1.0). - """ - from data_designer.engine.models.litellm_overrides import CustomRouter - - original_completion = CustomRouter.completion - original_acompletion = getattr(CustomRouter, "acompletion", None) - - def _should_fail(messages: list[dict[str, Any]]) -> bool: - if fail_pattern is None: - return False - serialized = json.dumps(messages, default=str) - if fail_pattern not in serialized: - return False - return random.random() < fail_rate - - def fake_completion(self: Any, model: str, messages: list[dict[str, Any]], **kwargs: Any) -> FakeResponse: - if _should_fail(messages): - raise RuntimeError(f"Simulated LLM failure for {model}") - return _fake_response(model, messages, **kwargs) - - async def fake_acompletion(self: Any, model: str, messages: list[dict[str, Any]], **kwargs: Any) -> FakeResponse: - if _should_fail(messages): - raise RuntimeError(f"Simulated LLM failure for {model}") - return _fake_response(model, messages, **kwargs) - - CustomRouter.completion = fake_completion - CustomRouter.acompletion = fake_acompletion - try: - yield - finally: - CustomRouter.completion = original_completion - if original_acompletion is not None: - CustomRouter.acompletion = original_acompletion - else: - with contextlib.suppress(AttributeError): - delattr(CustomRouter, "acompletion") - - -# --------------------------------------------------------------------------- -# Dataset fingerprinting -# --------------------------------------------------------------------------- - - -def dataset_fingerprint(df: Any) -> str: - import numpy as np - import pandas as pd - - def _default(v: Any) -> Any: - if isinstance(v, np.generic): - return v.item() - if isinstance(v, np.ndarray): - return v.tolist() - if isinstance(v, (pd.Timestamp, pd.Timedelta)): - return v.isoformat() - if isinstance(v, set): - return sorted(v) - if isinstance(v, bytes): - return v.decode("utf-8", errors="replace") - return str(v) - - normalized = df.reset_index(drop=True) - normalized = normalized.reindex(sorted(normalized.columns), axis=1) - records = normalized.to_dict(orient="records") - payload = json.dumps(records, sort_keys=True, separators=(",", ":"), ensure_ascii=True, default=_default) - return hashlib.sha256(payload.encode()).hexdigest() - - -# --------------------------------------------------------------------------- -# Model / DataDesigner factory helpers -# --------------------------------------------------------------------------- - - -def create_mock_model_config() -> tuple: - """Return (ModelConfig, ModelProvider) for openai-text with mock backend.""" - from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig, ModelProvider - - provider = ModelProvider( - name="mock-provider", - endpoint="https://mock.local", - provider_type="openai", - api_key="mock-key", - ) - model_config = ModelConfig( - alias="openai-text", - model="gpt-4.1", - provider="mock-provider", - inference_parameters=ChatCompletionInferenceParams(max_parallel_requests=8), - skip_health_check=True, - ) - return model_config, provider - - -def create_data_designer(*, artifact_path: str, async_trace: bool = True) -> Any: - from data_designer.config.run_config import RunConfig - from data_designer.interface import DataDesigner - - _, provider = create_mock_model_config() - dd = DataDesigner(artifact_path=artifact_path, model_providers=[provider]) - dd.set_run_config( - RunConfig( - buffer_size=1000, - disable_early_shutdown=True, - async_trace=async_trace, - ) - ) - return dd - - -def create_base_config() -> Any: - """Create a DataDesignerConfigBuilder with openai-text model and a seed column.""" - from data_designer.config.column_configs import SamplerColumnConfig - from data_designer.config.config_builder import DataDesignerConfigBuilder - from data_designer.config.sampler_params import SamplerType, UniformSamplerParams - - model_config, _ = create_mock_model_config() - config = DataDesignerConfigBuilder(model_configs=[model_config]) - config.add_column( - SamplerColumnConfig( - name="seed_value", - sampler_type=SamplerType.UNIFORM, - params=UniformSamplerParams(low=0.0, high=100.0, decimal_places=2), - ) - ) - return config - - -def seed_rng(seed: int = 42) -> None: - import numpy as np - - random.seed(seed) - np.random.seed(seed) diff --git a/demo/async/notebook_async_benchmark.html b/demo/async/notebook_async_benchmark.html deleted file mode 100644 index a6651a7e9..000000000 --- a/demo/async/notebook_async_benchmark.html +++ /dev/null @@ -1,9380 +0,0 @@ - - - - - -notebook_async_benchmark - - - - - - - - - - - - -
- - - - - - - - - - - - - -
- - diff --git a/demo/async/notebook_async_benchmark.py b/demo/async/notebook_async_benchmark.py deleted file mode 100644 index 917b7a013..000000000 --- a/demo/async/notebook_async_benchmark.py +++ /dev/null @@ -1,704 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -# %% [markdown] -# # Async Scheduler Benchmark -# -# This notebook compares the **sync** (column-by-column) builder with the -# **async** (task-queue) scheduler and shows when the async path delivers -# real speedups. -# -# **Key finding:** the async scheduler needs *both* a wide DAG (independent -# columns) *and* enough LLM concurrency (`max_parallel_requests`) to -# outperform the sync engine. With a narrow/sequential DAG or low -# concurrency, both engines hit the same API throughput ceiling. -# -# **Requirements:** `OPENAI_API_KEY` in environment. Uses `gpt-4.1` via -# `openai-text` alias. - -# %% Setup -from __future__ import annotations - -import logging -import math -import os -import sys -import tempfile -import time -import warnings - -import matplotlib.patches as mpatches -import matplotlib.pyplot as plt -import pandas as pd - -os.environ["DATA_DESIGNER_ASYNC_ENGINE"] = "1" - -warnings.filterwarnings("ignore", message=".*urllib3.*") -warnings.filterwarnings("ignore", message=".*Unclosed.*") - -from IPython.display import HTML, display - -import data_designer.engine.dataset_builders.column_wise_builder as cwb -from data_designer.config.column_configs import ( - GenerationStrategy, - LLMTextColumnConfig, - SamplerColumnConfig, -) -from data_designer.config.config_builder import DataDesignerConfigBuilder -from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig -from data_designer.config.run_config import RunConfig -from data_designer.config.sampler_params import CategorySamplerParams, SamplerType -from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph -from data_designer.interface import DataDesigner - -NUM_RECORDS = 10 -TRIALS = 5 - -_STRATEGY_MAP: dict[type, GenerationStrategy] = { - SamplerColumnConfig: GenerationStrategy.FULL_COLUMN, - LLMTextColumnConfig: GenerationStrategy.CELL_BY_CELL, -} - - -def suppress_logs() -> None: - for name in ("data_designer", "LiteLLM", "litellm", "httpx", "asyncio"): - logging.getLogger(name).setLevel(logging.CRITICAL) - logging.getLogger().setLevel(logging.CRITICAL) - - -def render_mermaid(code: str) -> None: - """Render a Mermaid diagram inline via the Mermaid JS CDN.""" - display( - HTML( - f"""
{code}
-""" - ) - ) - - -def _execution_graph_from_builder(config: DataDesignerConfigBuilder): # noqa: ANN201 - """Build an ExecutionGraph directly from a config builder.""" - cols = list(config._column_configs.values()) - strategies = {c.name: _STRATEGY_MAP[type(c)] for c in cols} - return ExecutionGraph.create(cols, strategies) - - -# %% [markdown] -# ## DAG Shapes -# -# We test two DAG shapes: -# -# **Narrow** (sequential): `topic → summary → followup` -# - 2 LLM columns, fully sequential — no cross-column parallelism possible. -# -# **Wide** (parallel branches): -# ``` -# topic → summary ─┐ -# topic → analysis ─┼→ synthesis -# topic → trivia ─┘ -# ``` -# - 3 independent LLM columns off the sampler, then 1 merge column. -# - The async scheduler can overlap summary/analysis/trivia. - -# %% Config builders - - -def _model_configs(max_parallel: int) -> list[ModelConfig]: - return [ - ModelConfig( - alias="openai-text", - model="gpt-4.1", - provider="openai", - inference_parameters=ChatCompletionInferenceParams(max_parallel_requests=max_parallel), - ) - ] - - -def build_narrow(max_parallel: int) -> DataDesignerConfigBuilder: - config = DataDesignerConfigBuilder(model_configs=_model_configs(max_parallel)) - config.add_column( - SamplerColumnConfig( - name="topic", - sampler_type=SamplerType.CATEGORY, - params=CategorySamplerParams(values=["science", "history", "art"]), - ) - ) - config.add_column( - LLMTextColumnConfig( - name="summary", - model_alias="openai-text", - prompt="Write a one-sentence summary about {{ topic }}.", - ) - ) - config.add_column( - LLMTextColumnConfig( - name="followup", - model_alias="openai-text", - prompt="Given this summary: '{{ summary }}', write a follow-up question.", - ) - ) - return config - - -def build_wide(max_parallel: int) -> DataDesignerConfigBuilder: - config = DataDesignerConfigBuilder(model_configs=_model_configs(max_parallel)) - config.add_column( - SamplerColumnConfig( - name="topic", - sampler_type=SamplerType.CATEGORY, - params=CategorySamplerParams(values=["science", "history", "art"]), - ) - ) - config.add_column( - LLMTextColumnConfig( - name="summary", - model_alias="openai-text", - prompt="Write a one-sentence summary about {{ topic }}.", - ) - ) - config.add_column( - LLMTextColumnConfig( - name="analysis", - model_alias="openai-text", - prompt="Write a one-sentence analysis of why {{ topic }} matters.", - ) - ) - config.add_column( - LLMTextColumnConfig( - name="trivia", - model_alias="openai-text", - prompt="State one surprising fact about {{ topic }}.", - ) - ) - config.add_column( - LLMTextColumnConfig( - name="synthesis", - model_alias="openai-text", - prompt=( - "Combine these into one sentence: " - "summary='{{ summary }}', analysis='{{ analysis }}', trivia='{{ trivia }}'." - ), - ) - ) - return config - - -# %% [markdown] -# ## DAG Visualization -# -# The `ExecutionGraph` class models the column dependency DAG and annotates -# each node with its generation strategy (`full_column` for samplers, -# `cell_by_cell` for LLM columns). The `to_mermaid()` method renders this -# as a Mermaid flowchart. - -# %% -narrow_graph = _execution_graph_from_builder(build_narrow(4)) -wide_graph = _execution_graph_from_builder(build_wide(4)) - -print("Narrow DAG") -print(f" Critical path: {' → '.join(narrow_graph.get_longest_dependency_chain())}") -print(f" Task count (10 records, buffer=1000): {narrow_graph.compute_task_count(NUM_RECORDS, 1000)}") -render_mermaid(narrow_graph.to_mermaid()) - -# %% -print("Wide DAG") -print(f" Critical path: {' → '.join(wide_graph.get_longest_dependency_chain())}") -print(f" Task count (10 records, buffer=1000): {wide_graph.compute_task_count(NUM_RECORDS, 1000)}") -render_mermaid(wide_graph.to_mermaid()) - - -# %% Runner - - -def run_once( - *, - async_mode: bool, - config: DataDesignerConfigBuilder, - trace: bool = False, -) -> tuple[float, list]: - """Run one creation. Returns (elapsed_seconds, traces).""" - suppress_logs() - cwb.DATA_DESIGNER_ASYNC_ENGINE = async_mode - - with tempfile.TemporaryDirectory() as artifact_dir: - dd = DataDesigner(artifact_path=artifact_dir) - dd.set_run_config( - RunConfig( - buffer_size=1000, - disable_early_shutdown=True, - async_trace=trace and async_mode, - ) - ) - - t0 = time.perf_counter() - result = dd.create(config, num_records=NUM_RECORDS, dataset_name="bench") - elapsed = time.perf_counter() - t0 - - df = result.load_dataset() - assert len(df) == NUM_RECORDS, f"Expected {NUM_RECORDS} rows, got {len(df)}" - - traces = result.task_traces if hasattr(result, "task_traces") else [] - - return elapsed, traces - - -def benchmark( - config_fn, - max_parallel: int, - label: str, - trace_last: bool = False, -) -> dict: - """Run TRIALS interleaved sync/async and return stats + optional traces.""" - config_s = config_fn(max_parallel) - config_a = config_fn(max_parallel) - - # Warmup - devnull = open(os.devnull, "w") # noqa: SIM115 - old_stderr = sys.stderr - sys.stderr = devnull - run_once(async_mode=False, config=config_s) - sys.stderr = old_stderr - devnull.close() - - sync_times: list[float] = [] - async_times: list[float] = [] - last_traces: list = [] - - for i in range(TRIALS): - is_last = i == TRIALS - 1 - - config_s = config_fn(max_parallel) - t_sync, _ = run_once(async_mode=False, config=config_s) - sync_times.append(t_sync) - - config_a = config_fn(max_parallel) - t_async, traces = run_once( - async_mode=True, - config=config_a, - trace=trace_last and is_last, - ) - async_times.append(t_async) - if is_last: - last_traces = traces - - return { - "label": label, - "max_parallel": max_parallel, - "sync": stats(sync_times), - "async": stats(async_times), - "sync_times": sync_times, - "async_times": async_times, - "traces": last_traces, - } - - -def stats(times: list[float]) -> dict[str, float]: - n = len(times) - s = sorted(times) - mean = sum(s) / n - median = s[n // 2] if n % 2 == 1 else (s[n // 2 - 1] + s[n // 2]) / 2 - variance = sum((t - mean) ** 2 for t in s) / n - return {"min": s[0], "median": median, "mean": mean, "max": s[-1], "stdev": math.sqrt(variance)} - - -# %% [markdown] -# ## Experiment 1: Narrow DAG, default concurrency (4) -# -# With a sequential DAG and low concurrency, the async scheduler has no room -# to parallelize. Both engines process one column at a time with 4 workers. - -# %% -print("Running: narrow DAG, max_parallel=4 ...") -r_narrow_4 = benchmark(build_narrow, max_parallel=4, label="narrow, par=4") - -# %% [markdown] -# ## Experiment 2: Wide DAG, default concurrency (4) -# -# Even with 3 independent branches, `max_parallel_requests=4` means only 4 -# LLM calls can be in flight at once. The sync engine already saturates -# these 4 slots column-by-column. The async scheduler's cross-column -# dispatch hits the same ceiling. - -# %% -print("Running: wide DAG, max_parallel=4 ...") -r_wide_4 = benchmark(build_wide, max_parallel=4, label="wide, par=4") - -# %% [markdown] -# ## Experiment 3: Wide DAG, high concurrency (16) -# -# Now we raise `max_parallel_requests=16`. The sync engine still processes -# columns one at a time (16 workers per column). The async scheduler can -# dispatch tasks from summary, analysis, and trivia *simultaneously*, -# keeping 16 slots busy across columns. - -# %% -print("Running: wide DAG, max_parallel=16 ...") -r_wide_16 = benchmark(build_wide, max_parallel=16, label="wide, par=16", trace_last=True) - -# %% [markdown] -# ## Results Summary - -# %% -rows = [] -for r in [r_narrow_4, r_wide_4, r_wide_16]: - s, a = r["sync"], r["async"] - speedup = s["median"] / a["median"] if a["median"] > 0 else float("inf") - rows.append( - { - "Configuration": r["label"], - "Sync median (s)": round(s["median"], 1), - "Async median (s)": round(a["median"], 1), - "Speedup": f"{speedup:.2f}x", - } - ) - -summary_df = pd.DataFrame(rows) -display(summary_df.style.hide(axis="index").set_caption(f"Benchmark results ({TRIALS} trials, {NUM_RECORDS} records)")) - -# Per-trial detail -for r in [r_narrow_4, r_wide_4, r_wide_16]: - trial_rows = [] - for i, (st, at) in enumerate(zip(r["sync_times"], r["async_times"])): - ratio = st / at if at > 0 else float("inf") - trial_rows.append( - {"Trial": i + 1, "Sync (s)": round(st, 1), "Async (s)": round(at, 1), "Ratio": f"{ratio:.2f}x"} - ) - trial_df = pd.DataFrame(trial_rows) - display(trial_df.style.hide(axis="index").set_caption(r["label"])) - -# %% [markdown] -# ## Using Traces to Optimize -# -# The async scheduler records a `TaskTrace` for every dispatched task when -# `async_trace=True`. Each trace captures: -# -# | Field | Meaning | -# |---|---| -# | `column` | Which column this task generates | -# | `row_group` / `row_index` | Position in the dataset | -# | `task_type` | `from_scratch`, `batch`, or `cell` | -# | `dispatched_at` | When the scheduler queued the task | -# | `slot_acquired_at` | When it started executing (after waiting for a slot) | -# | `completed_at` | When it finished | -# | `status` | `ok` or `error` | -# -# From these we can compute **wait time** (queued → slot acquired) and -# **run time** (slot acquired → completed) for every task, revealing where -# the pipeline spends its time. - -# %% - -traces = r_wide_16["traces"] - -if traces: - t0 = min(t.dispatched_at for t in traces if t.dispatched_at) - - by_column: dict[str, list] = {} - for t in traces: - by_column.setdefault(t.column, []).append(t) - - # --- Per-column summary --- - col_rows = [] - for col in sorted(by_column): - col_traces = by_column[col] - waits = [ - (t.slot_acquired_at - t.dispatched_at) * 1000 for t in col_traces if t.slot_acquired_at and t.dispatched_at - ] - runs = [ - (t.completed_at - t.slot_acquired_at) * 1000 for t in col_traces if t.completed_at and t.slot_acquired_at - ] - errors = sum(1 for t in col_traces if t.status == "error") - col_rows.append( - { - "Column": col, - "Tasks": len(col_traces), - "Avg Wait (ms)": round(sum(waits) / len(waits), 0) if waits else 0, - "Avg Run (ms)": round(sum(runs) / len(runs), 0) if runs else 0, - "Total Run (ms)": round(sum(runs), 0), - "Errors": errors, - } - ) - - col_df = pd.DataFrame(col_rows) - display( - col_df.style.hide(axis="index") - .set_caption("Per-column timing (wide DAG, max_parallel=16, async)") - .background_gradient(subset=["Total Run (ms)"], cmap="YlOrRd") - ) - - # --- Column timeline --- - timeline_rows = [] - for col in sorted(by_column): - col_traces = by_column[col] - first = min(t.dispatched_at for t in col_traces if t.dispatched_at) - t0 - last = max(t.completed_at for t in col_traces if t.completed_at) - t0 - timeline_rows.append( - { - "Column": col, - "First Dispatch (s)": round(first, 3), - "Last Complete (s)": round(last, 3), - "Span (s)": round(last - first, 3), - } - ) - - timeline_df = pd.DataFrame(timeline_rows) - display( - timeline_df.style.hide(axis="index") - .set_caption("Column timeline (relative to first dispatch)") - .background_gradient(subset=["Span (s)"], cmap="Blues") - ) - -# %% [markdown] -# ### Per-row-group breakdown -# -# The async scheduler pipelines row groups: it can start independent columns -# for row group 1 while row group 0 is still generating dependent columns -# like `synthesis`. This table shows when each (row group, column) pair -# starts and finishes. - -# %% -if traces: - rg_rows = [] - for t in traces: - if t.dispatched_at and t.completed_at: - rg_rows.append( - { - "Row Group": t.row_group, - "Column": t.column, - "Task Type": t.task_type, - "Row Index": t.row_index if t.row_index is not None else "—", - "Start (s)": round(t.slot_acquired_at - t0, 3) if t.slot_acquired_at else None, - "End (s)": round(t.completed_at - t0, 3), - "Run (ms)": ( - round((t.completed_at - t.slot_acquired_at) * 1000, 0) if t.slot_acquired_at else None - ), - "Status": t.status, - } - ) - - rg_df = pd.DataFrame(rg_rows).sort_values(["Row Group", "Start (s)"], na_position="last") - - # Summary per (row_group, column) - rg_summary = ( - rg_df.groupby(["Row Group", "Column"]) - .agg( - Tasks=("Run (ms)", "count"), - First_Start=("Start (s)", "min"), - Last_End=("End (s)", "max"), - Avg_Run_ms=("Run (ms)", "mean"), - ) - .reset_index() - ) - rg_summary["Avg_Run_ms"] = rg_summary["Avg_Run_ms"].round(0) - rg_summary.columns = ["Row Group", "Column", "Tasks", "First Start (s)", "Last End (s)", "Avg Run (ms)"] - - display( - rg_summary.style.hide(axis="index") - .set_caption("Per-row-group column breakdown") - .background_gradient(subset=["Avg Run (ms)"], cmap="YlOrRd") - ) -else: - print("(no traces captured — run with trace=True)") - -# %% [markdown] -# ### Task Gantt Chart -# -# Each bar represents a task executing on the async scheduler. Bars are -# color-coded by column — overlapping bars across columns show the async -# scheduler exploiting cross-column parallelism. -# -# Notice that `synthesis` bars start *before* `summary`/`analysis`/`trivia` -# have finished all their rows. This is because `synthesis` is a -# **cell-by-cell** column: it only needs `summary[i]`, `analysis[i]`, and -# `trivia[i]` to be complete before generating `synthesis[i]`. The scheduler -# dispatches each synthesis cell as soon as its per-row dependencies are met, -# without waiting for the full column to finish. - -# %% -if traces: - end_time = max(t.completed_at for t in traces if t.completed_at) - total_secs = end_time - t0 - - # Use topological order so the chart reads top-to-bottom like the DAG - columns_ordered = wide_graph.get_topological_order() - cmap = plt.get_cmap("tab10") - col_colors = {col: cmap(i) for i, col in enumerate(columns_ordered)} - - # Build bars: one per task, y-position = column index - fig, ax = plt.subplots(figsize=(12, max(3, len(columns_ordered) * 0.8))) - - for col_idx, col in enumerate(columns_ordered): - for t in by_column[col]: - if t.slot_acquired_at and t.completed_at: - start = t.slot_acquired_at - t0 - duration = t.completed_at - t.slot_acquired_at - alpha = 0.4 if t.status == "error" else 0.85 - ax.barh( - col_idx, - duration, - left=start, - height=0.6, - color=col_colors[col], - alpha=alpha, - edgecolor="white", - linewidth=0.5, - ) - - ax.set_yticks(range(len(columns_ordered))) - ax.set_yticklabels(columns_ordered) - ax.set_xlabel("Time (s)") - ax.set_title("Task Gantt Chart — wide DAG, max_parallel=16, async") - ax.set_xlim(0, total_secs * 1.02) - ax.invert_yaxis() - - patches = [mpatches.Patch(color=col_colors[c], label=c) for c in columns_ordered] - ax.legend(handles=patches, loc="upper right", fontsize="small", framealpha=0.8) - plt.tight_layout() - plt.show() - -# %% [markdown] -# ### Cell-level Timeline -# -# This chart shows every individual cell task, with the y-axis representing -# row indices. Each colored bar is one (column, row) task. You can see how -# `synthesis` cells start filling in from the top (row 0) while the -# independent columns are still processing later rows — this is the -# cell-level pipelining that the async scheduler enables. - -# %% -if traces: - # Collect cell-by-cell tasks (row_index is not None) - cell_tasks = [t for t in traces if t.row_index is not None and t.slot_acquired_at and t.completed_at] - # Also include full_column tasks, drawn as spanning all rows - batch_tasks = [t for t in traces if t.row_index is None and t.slot_acquired_at and t.completed_at] - - if cell_tasks: - row_indices = sorted({t.row_index for t in cell_tasks}) - columns_ordered = wide_graph.get_topological_order() - cmap_cell = plt.get_cmap("tab10") - col_colors = {col: cmap_cell(i) for i, col in enumerate(columns_ordered)} - n_cols = len(columns_ordered) - - bar_height = 0.8 / n_cols # subdivide each row's vertical space - - fig, ax = plt.subplots(figsize=(12, max(3, len(row_indices) * 0.6))) - - for t in cell_tasks: - col_offset = columns_ordered.index(t.column) - y = t.row_index + (col_offset - n_cols / 2 + 0.5) * bar_height - start = t.slot_acquired_at - t0 - duration = t.completed_at - t.slot_acquired_at - alpha = 0.4 if t.status == "error" else 0.85 - ax.barh( - y, - duration, - left=start, - height=bar_height * 0.9, - color=col_colors[t.column], - alpha=alpha, - edgecolor="white", - linewidth=0.3, - ) - - # Show full_column tasks as a thin bar spanning all rows - for t in batch_tasks: - start = t.slot_acquired_at - t0 - duration = t.completed_at - t.slot_acquired_at - ax.barh( - len(row_indices) / 2 - 0.5, - duration, - left=start, - height=len(row_indices) * 0.05, - color=col_colors.get(t.column, "gray"), - alpha=0.3, - edgecolor="none", - ) - - ax.set_yticks(row_indices) - ax.set_yticklabels([f"row {r}" for r in row_indices]) - ax.set_xlabel("Time (s)") - ax.set_title("Cell-level Timeline — each bar is one (column, row) task") - ax.set_xlim(0, total_secs * 1.02) - ax.invert_yaxis() - - patches = [mpatches.Patch(color=col_colors[c], label=c) for c in columns_ordered] - ax.legend(handles=patches, loc="upper right", fontsize="small", framealpha=0.8) - plt.tight_layout() - plt.show() - -# %% [markdown] -# ### Optimization Hints - -# %% -if traces: - print("Optimization hints:") - - # 1. Slowest column by total run time - max_col = max( - by_column, - key=lambda c: sum( - (t.completed_at - t.slot_acquired_at) for t in by_column[c] if t.completed_at and t.slot_acquired_at - ), - ) - max_total = sum( - (t.completed_at - t.slot_acquired_at) for t in by_column[max_col] if t.completed_at and t.slot_acquired_at - ) - print(f" - Slowest column: '{max_col}' ({max_total:.1f}s total LLM time)") - - # 2. Average wait time - all_waits = [(t.slot_acquired_at - t.dispatched_at) for t in traces if t.slot_acquired_at and t.dispatched_at] - avg_wait = sum(all_waits) / len(all_waits) if all_waits else 0 - if avg_wait > 0.5: - print(f" - High avg wait: {avg_wait:.3f}s → consider increasing max_parallel_requests") - else: - print(f" - Low avg wait: {avg_wait:.3f}s → concurrency slots are well-utilized") - - # 3. Column overlap detection - col_spans = {} - for col in by_column: - first = min(t.dispatched_at for t in by_column[col] if t.dispatched_at) - last = max(t.completed_at for t in by_column[col] if t.completed_at) - col_spans[col] = (first, last) - - overlapping_pairs = [] - cols = sorted(col_spans) - for i, c1 in enumerate(cols): - for c2 in cols[i + 1 :]: - s1, e1 = col_spans[c1] - s2, e2 = col_spans[c2] - overlap = max(0, min(e1, e2) - max(s1, s2)) - if overlap > 0: - overlapping_pairs.append((c1, c2, overlap)) - - if overlapping_pairs: - print(f" - Column overlap detected ({len(overlapping_pairs)} pairs):") - for c1, c2, overlap in sorted(overlapping_pairs, key=lambda x: -x[2]): - print(f" {c1} ↔ {c2}: {overlap:.1f}s overlap") - else: - print(" - No column overlap — async scheduler isn't adding parallelism") - print(" → check that columns have independent dependencies") -else: - print("(no traces captured — run with trace=True)") - -# %% [markdown] -# ## Takeaways -# -# 1. **DAG shape matters.** Sequential pipelines (A → B → C) give the async -# scheduler nothing to overlap. Design pipelines with independent branches -# where possible. -# -# 2. **`max_parallel_requests` is the lever.** The default of 4 is -# conservative. If your provider supports higher concurrency, increase it -# to let the scheduler fill slots across columns. -# -# 3. **Traces reveal bottlenecks.** Use `RunConfig(async_trace=True)` and -# inspect `result.task_traces` to see: -# - Which columns are slowest (optimize prompts or use faster models) -# - Whether tasks are waiting for slots (increase `max_parallel_requests`) -# - Whether columns overlap in time (confirms async is helping) -# -# 4. **No regression.** Even when the async scheduler can't parallelize, -# it matches sync performance — the overhead is negligible. diff --git a/demo/async/test_edge_cases.py b/demo/async/test_edge_cases.py deleted file mode 100644 index b200bad85..000000000 --- a/demo/async/test_edge_cases.py +++ /dev/null @@ -1,236 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -"""Test: edge cases for the async scheduler. - -Sub-tests: -1. Minimum generation (num_records=1) — produces correct dataset and traces. -2. Tiny buffer_size (buffer_size=1) — multiple row groups schedule correctly. -3. Early shutdown — failing column triggers error-rate shutdown; traces capture errors. -""" - -from __future__ import annotations - -import os -import sys -import tempfile - -os.environ["DATA_DESIGNER_ASYNC_ENGINE"] = "1" - -from helpers import ( - check, - create_base_config, - create_data_designer, - create_mock_model_config, - patch_llm_responses, - seed_rng, -) - -from data_designer.config.column_configs import LLMTextColumnConfig -from data_designer.config.run_config import RunConfig - - -def _simple_config(): - config = create_base_config() - config.add_column( - LLMTextColumnConfig( - name="output", - model_alias="openai-text", - prompt="Process seed {{ seed_value }}.", - ) - ) - return config - - -def _config_with_fragile_column(): - """Config where 'fragile' column prompt triggers mock failures.""" - config = create_base_config() - config.add_column( - LLMTextColumnConfig( - name="stable", - model_alias="openai-text", - prompt="Process seed {{ seed_value }}.", - ) - ) - config.add_column( - LLMTextColumnConfig( - name="fragile", - model_alias="openai-text", - prompt="FRAGILE_MARKER: analyze {{ seed_value }}.", - ) - ) - return config - - -# --------------------------------------------------------------------------- -# Sub-test 1: minimum generation -# --------------------------------------------------------------------------- - - -def test_single_record() -> None: - print("\n--- Sub-test 1: Single record (num_records=1) ---") - - with tempfile.TemporaryDirectory() as artifact_dir: - seed_rng(42) - dd = create_data_designer(artifact_path=artifact_dir, async_trace=True) - config = _simple_config() - - with patch_llm_responses(): - result = dd.create(config, num_records=1, dataset_name="single") - - df = result.load_dataset() - traces = result.task_traces - - print(f" rows={len(df)}, cols={len(df.columns)}, traces={len(traces)}") - - check(len(df) == 1, "Single row produced") - check("seed_value" in df.columns, "seed_value column present") - check("output" in df.columns, "output column present") - check(len(traces) > 0, f"Traces captured ({len(traces)})") - check( - all(t.status == "ok" for t in traces), - "All traces succeeded", - ) - - -# --------------------------------------------------------------------------- -# Sub-test 2: tiny buffer_size -# --------------------------------------------------------------------------- - - -def test_tiny_buffer() -> None: - # buffer_size=1 with 8 records creates 8 row groups, well above - # the default max_concurrent_row_groups=3. This exercises the - # scheduler's streaming admission (row groups admitted as slots free). - num_records = 8 - print(f"\n--- Sub-test 2: Tiny buffer_size (buffer_size=1, num_records={num_records}) ---") - - with tempfile.TemporaryDirectory() as artifact_dir: - seed_rng(42) - _, provider = create_mock_model_config() - - from data_designer.interface import DataDesigner - - dd = DataDesigner(artifact_path=artifact_dir, model_providers=[provider]) - dd.set_run_config( - RunConfig( - buffer_size=1, - disable_early_shutdown=True, - async_trace=True, - ) - ) - - config = _simple_config() - - with patch_llm_responses(): - result = dd.create(config, num_records=num_records, dataset_name="tiny-buffer") - - df = result.load_dataset() - traces = result.task_traces - - row_groups_seen = {t.row_group for t in traces} - - print(f" rows={len(df)}, traces={len(traces)}, row_groups={sorted(row_groups_seen)}") - - check(len(df) == num_records, f"All {num_records} rows produced") - check(len(row_groups_seen) == num_records, f"Got {num_records} distinct row groups") - check( - all(t.status == "ok" for t in traces), - "All traces succeeded", - ) - - -# --------------------------------------------------------------------------- -# Sub-test 3: early shutdown via error rate -# --------------------------------------------------------------------------- - - -def test_early_shutdown() -> None: - print("\n--- Sub-test 3: Early shutdown from error rate ---") - - with tempfile.TemporaryDirectory() as artifact_dir: - seed_rng(42) - _, provider = create_mock_model_config() - - from data_designer.interface import DataDesigner - from data_designer.interface.errors import DataDesignerGenerationError, DataDesignerProfilingError - - dd = DataDesigner(artifact_path=artifact_dir, model_providers=[provider]) - dd.set_run_config( - RunConfig( - buffer_size=1000, - disable_early_shutdown=False, - shutdown_error_rate=0.3, - shutdown_error_window=5, - async_trace=True, - ) - ) - - # Capture the builder so we can read traces even if profiling fails - # (early shutdown may produce no data → no parquet → profiling error). - captured_builder = None - original_create = dd._create_dataset_builder - - def _capture_builder(config, rp): - nonlocal captured_builder - captured_builder = original_create(config, rp) - return captured_builder - - dd._create_dataset_builder = _capture_builder - - config = _config_with_fragile_column() - num_records = 16 - - try: - with patch_llm_responses(fail_pattern="FRAGILE_MARKER", fail_rate=1.0): - result = dd.create(config, num_records=num_records, dataset_name="shutdown") - traces = result.task_traces - except (DataDesignerGenerationError, DataDesignerProfilingError): - # Early shutdown with 100% failures leaves no data — profiling fails. - traces = captured_builder.task_traces if captured_builder else [] - finally: - dd._create_dataset_builder = original_create - - error_traces = [t for t in traces if t.status == "error"] - ok_traces = [t for t in traces if t.status == "ok"] - - print(f" traces={len(traces)}, ok={len(ok_traces)}, errors={len(error_traces)}") - - check(len(traces) > 0, "Traces were captured") - check(len(error_traces) > 0, f"Error traces present ({len(error_traces)})") - check( - any(t.column == "fragile" for t in error_traces), - "Errors are on the 'fragile' column", - ) - - # With 100% fail rate on fragile column and low error window, - # the scheduler should have stopped before completing everything. - total_expected = num_records * 3 # seed_value + stable + fragile, each cell-by-cell - check( - len(traces) < total_expected, - f"Early shutdown: {len(traces)} traces < {total_expected} expected", - ) - - -# --------------------------------------------------------------------------- -# Main -# --------------------------------------------------------------------------- - - -def main() -> None: - print("=" * 60) - print("Test: Edge Cases") - print("=" * 60) - - test_single_record() - test_tiny_buffer() - test_early_shutdown() - - print("\nDone.") - - -if __name__ == "__main__": - try: - main() - except AssertionError: - sys.exit(1) diff --git a/demo/async/test_execution_order.py b/demo/async/test_execution_order.py deleted file mode 100644 index d09aa5c09..000000000 --- a/demo/async/test_execution_order.py +++ /dev/null @@ -1,150 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -"""Test: async scheduler respects DAG execution order. - -Builds a 4-column DAG: - - seed_value (sampler, no deps) - ├── stage1 (LLM, depends on seed_value) - │ └── stage2 (LLM, depends on stage1) - └── independent (LLM, depends on seed_value only) - -Verifies via TaskTraces that: -1. All seed tasks finish before any downstream task dispatches. -2. All stage1 tasks finish before any stage2 task dispatches. -3. "independent" tasks can overlap with stage1 (no dependency between them). -""" - -from __future__ import annotations - -import os -import sys -import tempfile - -os.environ["DATA_DESIGNER_ASYNC_ENGINE"] = "1" - -from helpers import ( - check, - create_base_config, - create_data_designer, - patch_llm_responses, - seed_rng, -) - -from data_designer.config.column_configs import LLMTextColumnConfig - -NUM_RECORDS = 8 - - -def _build_config(): - config = create_base_config() - config.add_column( - LLMTextColumnConfig( - name="stage1", - model_alias="openai-text", - prompt="Summarize seed {{ seed_value }}.", - ) - ) - config.add_column( - LLMTextColumnConfig( - name="stage2", - model_alias="openai-text", - prompt="Analyze {{ stage1 }}.", - ) - ) - config.add_column( - LLMTextColumnConfig( - name="independent", - model_alias="openai-text", - prompt="Independent thought on {{ seed_value }}.", - ) - ) - return config - - -def _traces_for(traces: list, column: str) -> list: - return [t for t in traces if t.column == column] - - -def main() -> None: - print("=" * 60) - print("Test: Execution Order Respects DAG") - print("=" * 60) - - with tempfile.TemporaryDirectory() as artifact_dir: - seed_rng(42) - dd = create_data_designer(artifact_path=artifact_dir, async_trace=True) - config = _build_config() - - with patch_llm_responses(): - result = dd.create(config, num_records=NUM_RECORDS, dataset_name="order") - - df = result.load_dataset() - traces = result.task_traces - - print(f"\n Dataset: {len(df)} rows, {len(df.columns)} cols") - print(f" Traces: {len(traces)} total") - - # Group traces by column - seed_traces = _traces_for(traces, "seed_value") - stage1_traces = _traces_for(traces, "stage1") - stage2_traces = _traces_for(traces, "stage2") - indep_traces = _traces_for(traces, "independent") - - print( - f" Per column: seed_value={len(seed_traces)}, stage1={len(stage1_traces)}, " - f"stage2={len(stage2_traces)}, independent={len(indep_traces)}" - ) - - # --- Assertion 1: seeds finish before downstream dispatches --- - print("\nResults:") - if seed_traces: - latest_seed_completion = max(t.completed_at for t in seed_traces) - downstream = stage1_traces + stage2_traces + indep_traces - if downstream: - earliest_downstream_dispatch = min(t.dispatched_at for t in downstream) - check( - latest_seed_completion <= earliest_downstream_dispatch, - "All seed tasks completed before first downstream dispatch", - ) - else: - check(False, "Expected downstream traces but found none") - else: - check(False, "Expected seed traces but found none") - - # --- Assertion 2: stage1 finishes before stage2 dispatches --- - if stage1_traces and stage2_traces: - latest_stage1 = max(t.completed_at for t in stage1_traces) - earliest_stage2 = min(t.dispatched_at for t in stage2_traces) - check( - latest_stage1 <= earliest_stage2, - "All stage1 tasks completed before first stage2 dispatch", - ) - - # --- Assertion 3: independent dispatched before stage2 --- - # With zero-latency mocks there's no temporal overlap — tasks complete - # instantly when the event loop schedules them. Instead we verify - # that independent tasks were dispatched in the same "wave" as stage1 - # (between seed completion and stage2 dispatch), proving the scheduler - # treats them as independent of stage1. - if indep_traces and stage2_traces: - latest_indep_dispatch = max(t.dispatched_at for t in indep_traces) - earliest_stage2_dispatch = min(t.dispatched_at for t in stage2_traces) - check( - latest_indep_dispatch <= earliest_stage2_dispatch, - "All independent tasks dispatched before first stage2 dispatch", - ) - - # --- Assertion 4: all traces succeeded --- - error_traces = [t for t in traces if t.status == "error"] - check(len(error_traces) == 0, f"No error traces (found {len(error_traces)})") - - print("\nDone.") - - -if __name__ == "__main__": - try: - main() - except AssertionError: - sys.exit(1) diff --git a/demo/async/test_sync_vs_async_parity.py b/demo/async/test_sync_vs_async_parity.py deleted file mode 100644 index 820794633..000000000 --- a/demo/async/test_sync_vs_async_parity.py +++ /dev/null @@ -1,158 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -"""Test: sync builder vs async scheduler produce structurally identical datasets. - -Uses real LLMs (openai-text). Compares: -- Row / column counts -- Non-null values in every cell -- TaskTrace availability (async only) -- Wall-clock timing -""" - -from __future__ import annotations - -import os -import sys -import tempfile -import time - -# Enable async engine before any data_designer imports. -os.environ["DATA_DESIGNER_ASYNC_ENGINE"] = "1" - -from helpers import check, seed_rng - -from data_designer.config.column_configs import LLMTextColumnConfig, SamplerColumnConfig -from data_designer.config.config_builder import DataDesignerConfigBuilder -from data_designer.config.run_config import RunConfig -from data_designer.config.sampler_params import CategorySamplerParams, SamplerType -from data_designer.interface import DataDesigner - -NUM_RECORDS = 5 - - -def _build_config() -> DataDesignerConfigBuilder: - config = DataDesignerConfigBuilder() - config.add_column( - SamplerColumnConfig( - name="topic", - sampler_type=SamplerType.CATEGORY, - params=CategorySamplerParams(values=["science", "history", "art"]), - ) - ) - config.add_column( - LLMTextColumnConfig( - name="summary", - model_alias="openai-text", - prompt="Write a one-sentence summary about {{ topic }}.", - ) - ) - config.add_column( - LLMTextColumnConfig( - name="followup", - model_alias="openai-text", - prompt="Given this summary: '{{ summary }}', write a follow-up question.", - ) - ) - return config - - -def _print_trace_table(traces: list) -> None: - if not traces: - print(" (no traces)") - return - - header = f" {'column':<12} {'rg':>3} {'row':>4} {'type':<13} {'wait':>7} {'run':>7} {'status':<6} {'error'}" - print(header) - print(f" {'-' * len(header.strip())}") - - for t in sorted(traces, key=lambda t: t.dispatched_at): - wait_ms = (t.slot_acquired_at - t.dispatched_at) * 1000 if t.slot_acquired_at and t.dispatched_at else 0 - run_ms = (t.completed_at - t.slot_acquired_at) * 1000 if t.completed_at and t.slot_acquired_at else 0 - row_str = str(t.row_index) if t.row_index is not None else "-" - err_str = (t.error or "")[:40] - print( - f" {t.column:<12} {t.row_group:>3} {row_str:>4} {t.task_type:<13} " - f"{wait_ms:>6.0f}ms {run_ms:>6.0f}ms {t.status:<6} {err_str}" - ) - - -def main() -> None: - print("=" * 60) - print("Test: Sync vs Async Parity (real LLMs)") - print("=" * 60) - - with tempfile.TemporaryDirectory() as async_dir, tempfile.TemporaryDirectory() as sync_dir: - # --- Sync run first (to warm up health checks, caches, etc.) --- - import data_designer.engine.dataset_builders.column_wise_builder as cwb - - cwb.DATA_DESIGNER_ASYNC_ENGINE = False - - print(f"\n[1/2] Running sync builder ({NUM_RECORDS} records)...") - seed_rng(42) - dd_sync = DataDesigner(artifact_path=sync_dir) - dd_sync.set_run_config(RunConfig(buffer_size=1000, disable_early_shutdown=True, async_trace=False)) - config_sync = _build_config() - - t0 = time.perf_counter() - result_sync = dd_sync.create(config_sync, num_records=NUM_RECORDS, dataset_name="sync") - sync_time = time.perf_counter() - t0 - - df_sync = result_sync.load_dataset() - - print(f" rows={len(df_sync)}, cols={len(df_sync.columns)}, time={sync_time:.3f}s") - - # --- Async run second --- - cwb.DATA_DESIGNER_ASYNC_ENGINE = True - - print(f"\n[2/2] Running async scheduler ({NUM_RECORDS} records)...") - seed_rng(42) - dd_async = DataDesigner(artifact_path=async_dir) - dd_async.set_run_config(RunConfig(buffer_size=1000, disable_early_shutdown=True, async_trace=True)) - config_async = _build_config() - - t0 = time.perf_counter() - result_async = dd_async.create(config_async, num_records=NUM_RECORDS, dataset_name="async") - async_time = time.perf_counter() - t0 - - df_async = result_async.load_dataset() - traces_async = result_async.task_traces - - print( - f" rows={len(df_async)}, cols={len(df_async.columns)}, time={async_time:.3f}s, traces={len(traces_async)}" - ) - - print("\n Async traces:") - _print_trace_table(traces_async) - - # --- Assertions --- - print("\nResults:") - check(len(df_async) == NUM_RECORDS, f"Async produced {NUM_RECORDS} rows") - check(len(df_sync) == NUM_RECORDS, f"Sync produced {NUM_RECORDS} rows") - check(len(df_async) == len(df_sync), f"Row counts match ({len(df_async)})") - check( - set(df_async.columns) == set(df_sync.columns), - f"Column sets match ({sorted(df_async.columns)})", - ) - check( - df_async.notna().all().all(), - "Async: no null values", - ) - check( - df_sync.notna().all().all(), - "Sync: no null values", - ) - check(len(traces_async) > 0, f"Async produced {len(traces_async)} TaskTraces") - check(len(result_sync.task_traces) == 0, "Sync produced no TaskTraces") - - speedup = sync_time / async_time if async_time > 0 else float("inf") - print(f"\n Timing: sync={sync_time:.3f}s async={async_time:.3f}s ratio={speedup:.2f}x") - - print("\nDone.") - - -if __name__ == "__main__": - try: - main() - except AssertionError: - sys.exit(1) From 259828d1e5d3c3fb0de1b7eabfa21b0129426177 Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Wed, 18 Mar 2026 19:44:59 +0000 Subject: [PATCH 09/23] fix: wire disable_early_shutdown into AsyncTaskScheduler RunConfig.disable_early_shutdown was forwarded to the sync executor but silently ignored in the async path. Now passed through to the scheduler's _check_error_rate. --- .../data_designer/engine/dataset_builders/async_scheduler.py | 4 ++++ .../engine/dataset_builders/column_wise_builder.py | 1 + 2 files changed, 5 insertions(+) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py index ce523db38..bdb32e383 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py @@ -50,6 +50,7 @@ def __init__( on_before_checkpoint: Callable[[int, int], None] | None = None, shutdown_error_rate: float = 0.5, shutdown_error_window: int = 10, + disable_early_shutdown: bool = False, trace: bool = False, ) -> None: self._generators = generators @@ -72,6 +73,7 @@ def __init__( # Error rate shutdown (caller passes pre-normalized values via RunConfig) self._shutdown_error_rate = shutdown_error_rate self._shutdown_error_window = shutdown_error_window + self._disable_early_shutdown = disable_early_shutdown self._early_shutdown = False # Multi-column dedup: group output columns by generator identity @@ -299,6 +301,8 @@ def _in_flight_for_rg(self, rg_id: int) -> bool: def _check_error_rate(self) -> None: """Trigger early shutdown if error rate exceeds threshold.""" + if self._disable_early_shutdown: + return completed = self._success_count + self._error_count if completed < self._shutdown_error_window: return diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py index 0ece93b0f..dbe7c361d 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -298,6 +298,7 @@ def on_before_checkpoint(rg_id: int, rg_size: int) -> None: ), shutdown_error_rate=settings.shutdown_error_rate, shutdown_error_window=settings.shutdown_error_window, + disable_early_shutdown=settings.disable_early_shutdown, trace=trace_enabled, ) From 40e16fe23ceffcb334ea09c8075cbe4adbcb3dae Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Wed, 18 Mar 2026 20:07:21 +0000 Subject: [PATCH 10/23] test: add e2e test for async engine concurrency Verifies the async scheduler dispatches independent LLM columns concurrently by checking for overlapping task trace intervals. Uses a wide DAG (sampler -> 2 parallel LLM columns) with 2 records. Requires NVIDIA_API_KEY. --- tests_e2e/tests/test_async_engine.py | 87 ++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) create mode 100644 tests_e2e/tests/test_async_engine.py diff --git a/tests_e2e/tests/test_async_engine.py b/tests_e2e/tests/test_async_engine.py new file mode 100644 index 000000000..0ea139ed3 --- /dev/null +++ b/tests_e2e/tests/test_async_engine.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import os +from collections import defaultdict +from pathlib import Path + +import pytest + +import data_designer.config as dd +import data_designer.engine.dataset_builders.column_wise_builder as cwb +from data_designer.config.run_config import RunConfig +from data_designer.interface import DataDesigner + +NUM_RECORDS = 2 +PARALLEL_COLUMNS = ("summary", "analysis") + + +def test_async_engine_concurrent_columns(tmp_path: Path) -> None: + """Verify the async engine runs independent LLM columns concurrently.""" + if os.environ.get("NVIDIA_API_KEY") is None: + pytest.skip("NVIDIA_API_KEY must be set") + + original = cwb.DATA_DESIGNER_ASYNC_ENGINE + cwb.DATA_DESIGNER_ASYNC_ENGINE = True + try: + dd_instance = DataDesigner(artifact_path=str(tmp_path)) + dd_instance.set_run_config(RunConfig(buffer_size=NUM_RECORDS, async_trace=True)) + + config = dd.DataDesignerConfigBuilder() + config.add_column( + dd.SamplerColumnConfig( + name="topic", + sampler_type=dd.SamplerType.CATEGORY, + params=dd.CategorySamplerParams(values=["science", "history", "art"]), + ) + ) + for col in PARALLEL_COLUMNS: + config.add_column( + dd.LLMTextColumnConfig( + name=col, + model_alias="nvidia-text", + prompt=f"Write one sentence about {{{{ topic }}}} ({col}).", + ) + ) + + result = dd_instance.create(config, num_records=NUM_RECORDS, dataset_name="async_e2e") + df = result.load_dataset() + finally: + cwb.DATA_DESIGNER_ASYNC_ENGINE = original + + # Dataset correctness + assert len(df) == NUM_RECORDS + for col in ("topic", *PARALLEL_COLUMNS): + assert col in df.columns + assert df[col].notna().all() + + # Concurrency: check that cell tasks from different columns overlapped + traces = result.task_traces + assert traces, "No task traces recorded - async_trace may not be enabled" + + by_col: dict[str, list] = defaultdict(list) + for t in traces: + if t.task_type == "cell" and t.status == "ok" and t.slot_acquired_at and t.completed_at: + by_col[t.column].append(t) + + overlap_found = False + cols = [c for c in PARALLEL_COLUMNS if by_col[c]] + for i, col_a in enumerate(cols): + for col_b in cols[i + 1 :]: + for ta in by_col[col_a]: + for tb in by_col[col_b]: + if ta.slot_acquired_at < tb.completed_at and tb.slot_acquired_at < ta.completed_at: + overlap_found = True + break + if overlap_found: + break + if overlap_found: + break + if overlap_found: + break + + assert overlap_found, ( + "No overlapping execution found between parallel columns - async scheduler may not be dispatching concurrently" + ) From 37e3b6271ba5f2720b83e1e324443914a8be773e Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Wed, 18 Mar 2026 20:13:05 +0000 Subject: [PATCH 11/23] fix: drop row group on on_before_checkpoint failure instead of writing unprocessed data Matches on_seeds_complete failure behavior and avoids silently checkpointing unfiltered rows when a post-batch processor fails. --- .../data_designer/engine/dataset_builders/async_scheduler.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py index bdb32e383..a1caa3bd6 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py @@ -266,9 +266,12 @@ def _checkpoint_completed_row_groups(self, all_columns: list[str]) -> None: self._on_before_checkpoint(rg_id, rg_size) except Exception: logger.error( - f"on_before_checkpoint failed for row group {rg_id}, checkpointing un-processed data.", + f"on_before_checkpoint failed for row group {rg_id}, dropping row group.", exc_info=True, ) + for ri in range(rg_size): + if self._buffer_manager: + self._buffer_manager.drop_row(rg_id, ri) if self._buffer_manager is not None: self._buffer_manager.checkpoint_row_group(rg_id) if self._on_row_group_complete: From 954117bbca5ef77bdf86e64b02b6fd6c644a8b4e Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Wed, 18 Mar 2026 20:31:48 +0000 Subject: [PATCH 12/23] fix: skip on_before_checkpoint when no POST_BATCH processors configured Avoids unnecessary DataFrame round-trip for every row group in the common case where no post-batch processors exist. --- .../engine/dataset_builders/column_wise_builder.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py index dbe7c361d..05aa205c5 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -292,7 +292,9 @@ def on_before_checkpoint(rg_id: int, rg_size: int) -> None: row_groups=row_groups, buffer_manager=buffer_manager, on_seeds_complete=on_seeds_complete, - on_before_checkpoint=on_before_checkpoint, + on_before_checkpoint=( + on_before_checkpoint if self._processor_runner.has_processors_for(ProcessorStage.POST_BATCH) else None + ), on_row_group_complete=lambda rg_id: ( on_batch_complete(self.artifact_storage.final_dataset_path) if on_batch_complete else None ), From d7dd2ee452c53b5696ff453d1526adbde327dd68 Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Thu, 19 Mar 2026 17:51:45 +0000 Subject: [PATCH 13/23] fix: address remaining review nits from nabinchha and greptile summary - Gate on_seeds_complete on PRE_BATCH processors (matches on_before_checkpoint pattern) - Cache seed_cols as instance attr instead of recomputing in _dispatch_seeds - Iterate list(self._active_rgs) snapshot in _run_seeds_complete_check - Add logger.debug to telemetry except block - Add design comment on on_before_checkpoint failure drop behavior - Rename row_group param to row_group_index in is_column_complete_for_rg - Document rg_id as current_batch_number equivalence - Use mock.patch.object in e2e test instead of direct mutation - Add max(0, ...) floor guard on _in_flight_counts decrement - Rename _ensure_async_engine_loop to public ensure_async_engine_loop - Move AsyncTaskScheduler import to module level in integration tests --- .../engine/dataset_builders/async_scheduler.py | 12 ++++++++---- .../engine/dataset_builders/column_wise_builder.py | 11 +++++++---- .../dataset_builders/utils/async_concurrency.py | 6 +++--- .../dataset_builders/utils/completion_tracker.py | 12 ++++++------ .../test_async_builder_integration.py | 3 +-- tests_e2e/tests/test_async_engine.py | 7 ++----- 6 files changed, 27 insertions(+), 24 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py index a1caa3bd6..12d5ef5cc 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py @@ -112,6 +112,9 @@ def __init__( # Pre-compute row-group sizes for O(1) lookup self._rg_size_map: dict[int, int] = dict(row_groups) + # Pre-compute seed columns (graph is static) + self._seed_cols: frozenset[str] = frozenset(c for c in graph.columns if not graph.get_upstream_columns(c)) + async def _admit_row_groups(self) -> None: """Admit row groups as semaphore slots become available.""" for rg_id, rg_size in self._row_groups: @@ -130,7 +133,7 @@ async def _admit_row_groups(self) -> None: async def run(self) -> None: """Main scheduler loop.""" all_columns = self._graph.columns - seed_cols = frozenset(c for c in all_columns if not self._graph.get_upstream_columns(c)) + seed_cols = self._seed_cols has_pre_batch = self._on_seeds_complete is not None # Launch admission as a background task so it interleaves with dispatch. @@ -265,6 +268,7 @@ def _checkpoint_completed_row_groups(self, all_columns: list[str]) -> None: try: self._on_before_checkpoint(rg_id, rg_size) except Exception: + # Post-batch is mandatory; drop rather than checkpoint unprocessed data. logger.error( f"on_before_checkpoint failed for row group {rg_id}, dropping row group.", exc_info=True, @@ -283,7 +287,7 @@ def _checkpoint_completed_row_groups(self, all_columns: list[str]) -> None: def _run_seeds_complete_check(self, seed_cols: frozenset[str]) -> None: """Run pre-batch callbacks for row groups whose seeds just completed.""" - for rg_id, rg_size in self._active_rgs: + for rg_id, rg_size in list(self._active_rgs): if rg_id in self._seeds_dispatched_rgs and rg_id not in self._pre_batch_done_rgs: all_seeds_done = all(self._tracker.is_column_complete_for_rg(col, rg_id) for col in seed_cols) if all_seeds_done and not self._in_flight_for_rg(rg_id): @@ -316,7 +320,7 @@ def _check_error_rate(self) -> None: async def _dispatch_seeds(self, rg_id: int, rg_size: int) -> None: """Dispatch from_scratch tasks for a row group.""" self._seeds_dispatched_rgs.add(rg_id) - seed_cols = [col for col in self._graph.get_topological_order() if not self._graph.get_upstream_columns(col)] + seed_cols = self._seed_cols seen_instances: set[int] = set() for col in seed_cols: @@ -443,7 +447,7 @@ async def _execute_task_inner(self, task: Task) -> None: self.traces.append(trace) self._in_flight.discard(task) - self._in_flight_counts[task.row_group] = self._in_flight_counts.get(task.row_group, 0) - 1 + self._in_flight_counts[task.row_group] = max(0, self._in_flight_counts.get(task.row_group, 0) - 1) if not retryable and not skipped: self._dispatched.discard(task) self._submission_semaphore.release() diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py index 05aa205c5..c8d148e7c 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -66,7 +66,7 @@ from data_designer.engine.dataset_builders.async_scheduler import AsyncTaskScheduler from data_designer.engine.dataset_builders.utils.async_concurrency import ( AsyncConcurrentExecutor, - _ensure_async_engine_loop, + ensure_async_engine_loop, ) from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph @@ -274,6 +274,7 @@ def on_seeds_complete(rg_id: int, rg_size: int) -> None: tracker.drop_row(rg_id, ri) # Post-batch processor callback: runs after all columns, before checkpoint. + # rg_id is used as current_batch_number; both are 0-based sequential indices today. def on_before_checkpoint(rg_id: int, rg_size: int) -> None: df = buffer_manager.get_dataframe(rg_id) df = self._processor_runner.run_post_batch(df, current_batch_number=rg_id) @@ -291,7 +292,9 @@ def on_before_checkpoint(rg_id: int, rg_size: int) -> None: tracker=tracker, row_groups=row_groups, buffer_manager=buffer_manager, - on_seeds_complete=on_seeds_complete, + on_seeds_complete=( + on_seeds_complete if self._processor_runner.has_processors_for(ProcessorStage.PRE_BATCH) else None + ), on_before_checkpoint=( on_before_checkpoint if self._processor_runner.has_processors_for(ProcessorStage.POST_BATCH) else None ), @@ -305,7 +308,7 @@ def on_before_checkpoint(rg_id: int, rg_size: int) -> None: ) # Run on background event loop - loop = _ensure_async_engine_loop() + loop = ensure_async_engine_loop() future = asyncio.run_coroutine_threadsafe(scheduler.run(), loop) future.result() @@ -316,7 +319,7 @@ def on_before_checkpoint(rg_id: int, rg_size: int) -> None: usage_deltas = self._resource_provider.model_registry.get_usage_deltas(pre_batch_snapshot) self._emit_batch_inference_events("batch", usage_deltas, group_id) except Exception: - pass + logger.debug("Failed to emit batch telemetry for async run", exc_info=True) # Write metadata buffer_manager.write_metadata(target_num_records=num_records, buffer_size=buffer_size) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/async_concurrency.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/async_concurrency.py index f7d18bd26..105278233 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/async_concurrency.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/async_concurrency.py @@ -23,7 +23,7 @@ The background loop is a process-wide singleton. LiteLLM and similar libraries bind internal async state to a specific event loop, so creating per-call or per-instance loops breaks connection reuse and triggers - cross-loop errors. ``_ensure_async_engine_loop()`` creates one daemon + cross-loop errors. ``ensure_async_engine_loop()`` creates one daemon loop thread and reuses it for all executor instances. Startup Handshake: @@ -86,7 +86,7 @@ def _run_loop(loop: asyncio.AbstractEventLoop, ready: threading.Event) -> None: loop.run_forever() -def _ensure_async_engine_loop() -> asyncio.AbstractEventLoop: +def ensure_async_engine_loop() -> asyncio.AbstractEventLoop: """Get or create a persistent event loop for async engine work. A single event loop is shared across all AsyncConcurrentExecutor instances @@ -172,7 +172,7 @@ def run(self, work_items: list[tuple[Coroutine[Any, Any, Any], dict | None]]) -> f"AsyncConcurrentExecutor: launching {len(work_items)} tasks " f"with max_workers={self._max_workers} for column '{self._column_name}'" ) - loop = _ensure_async_engine_loop() + loop = ensure_async_engine_loop() future = asyncio.run_coroutine_threadsafe(self._run_all(work_items), loop) future.result() diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py index f8c974f4a..881bf6c54 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py @@ -78,15 +78,15 @@ def is_all_complete(self, cells: list[SliceRef]) -> bool: return False return True - def is_column_complete_for_rg(self, column: str, row_group: int) -> bool: - """Check if *column* has been fully completed for *row_group*.""" - if column in self._batch_complete.get(row_group, set()): + def is_column_complete_for_rg(self, column: str, row_group_index: int) -> bool: + """Check if *column* has been fully completed for *row_group_index*.""" + if column in self._batch_complete.get(row_group_index, set()): return True - rg_size = self._row_group_sizes.get(row_group, 0) + rg_size = self._row_group_sizes.get(row_group_index, 0) if rg_size == 0: return False - completed = self._completed.get(row_group, {}).get(column, set()) - dropped = self._dropped.get(row_group, set()) + completed = self._completed.get(row_group_index, {}).get(column, set()) + dropped = self._dropped.get(row_group_index, set()) return all(ri in completed or ri in dropped for ri in range(rg_size)) def drop_row(self, row_group: int, row_index: int) -> None: diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py index 0ae343b2e..1302fa905 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py @@ -21,6 +21,7 @@ ColumnGeneratorFullColumn, FromScratchColumnGenerator, ) +from data_designer.engine.dataset_builders.async_scheduler import AsyncTaskScheduler from data_designer.engine.dataset_builders.column_wise_builder import ColumnWiseDatasetBuilder from data_designer.engine.dataset_builders.errors import DatasetGenerationError from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker @@ -105,7 +106,6 @@ def test_validate_async_compatibility(configs: list[Mock], should_raise: bool) - @pytest.mark.asyncio(loop_scope="session") async def test_build_async_end_to_end() -> None: """Test _build_async with mock generators produces correct results in buffer.""" - from data_designer.engine.dataset_builders.async_scheduler import AsyncTaskScheduler provider = _mock_provider() seed_gen = MockSeed(config=_expr_config("seed"), resource_provider=provider) @@ -224,7 +224,6 @@ def test_execution_graph_from_real_configs() -> None: @pytest.mark.asyncio(loop_scope="session") async def test_checkpoint_produces_correct_parquet_calls() -> None: """Verify checkpoint writes parquet for each row group.""" - from data_designer.engine.dataset_builders.async_scheduler import AsyncTaskScheduler provider = _mock_provider() seed_gen = MockSeed(config=_expr_config("seed"), resource_provider=provider) diff --git a/tests_e2e/tests/test_async_engine.py b/tests_e2e/tests/test_async_engine.py index 0ea139ed3..9eaea23db 100644 --- a/tests_e2e/tests/test_async_engine.py +++ b/tests_e2e/tests/test_async_engine.py @@ -6,6 +6,7 @@ import os from collections import defaultdict from pathlib import Path +from unittest.mock import patch import pytest @@ -23,9 +24,7 @@ def test_async_engine_concurrent_columns(tmp_path: Path) -> None: if os.environ.get("NVIDIA_API_KEY") is None: pytest.skip("NVIDIA_API_KEY must be set") - original = cwb.DATA_DESIGNER_ASYNC_ENGINE - cwb.DATA_DESIGNER_ASYNC_ENGINE = True - try: + with patch.object(cwb, "DATA_DESIGNER_ASYNC_ENGINE", True): dd_instance = DataDesigner(artifact_path=str(tmp_path)) dd_instance.set_run_config(RunConfig(buffer_size=NUM_RECORDS, async_trace=True)) @@ -48,8 +47,6 @@ def test_async_engine_concurrent_columns(tmp_path: Path) -> None: result = dd_instance.create(config, num_records=NUM_RECORDS, dataset_name="async_e2e") df = result.load_dataset() - finally: - cwb.DATA_DESIGNER_ASYNC_ENGINE = original # Dataset correctness assert len(df) == NUM_RECORDS From b1b67417b71b88af1f608cb3cbd9bf4a644aaccc Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Thu, 19 Mar 2026 17:46:43 -0300 Subject: [PATCH 14/23] fix: preserve async callback contract and e2e setup --- .../dataset_builders/async_scheduler.py | 13 +- .../dataset_builders/column_wise_builder.py | 4 +- .../dataset_builders/test_async_scheduler.py | 72 ++++++++ tests_e2e/tests/test_async_engine.py | 154 +++++++++++------- 4 files changed, 180 insertions(+), 63 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py index 12d5ef5cc..2f6fda2fd 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py @@ -7,6 +7,7 @@ import contextlib import logging import time +from pathlib import Path from typing import TYPE_CHECKING, Any, Callable import data_designer.lazy_heavy_imports as lazy @@ -46,6 +47,7 @@ def __init__( max_submitted_tasks: int = 256, salvage_max_rounds: int = 2, on_row_group_complete: Callable[[int], None] | None = None, + on_checkpoint_complete: Callable[[Path | str], None] | None = None, on_seeds_complete: Callable[[int, int], None] | None = None, on_before_checkpoint: Callable[[int, int], None] | None = None, shutdown_error_rate: float = 0.5, @@ -67,6 +69,7 @@ def __init__( self._wake_event = asyncio.Event() self._salvage_max_rounds = salvage_max_rounds self._on_row_group_complete = on_row_group_complete + self._on_checkpoint_complete = on_checkpoint_complete self._on_seeds_complete = on_seeds_complete self._on_before_checkpoint = on_before_checkpoint @@ -277,7 +280,15 @@ def _checkpoint_completed_row_groups(self, all_columns: list[str]) -> None: if self._buffer_manager: self._buffer_manager.drop_row(rg_id, ri) if self._buffer_manager is not None: - self._buffer_manager.checkpoint_row_group(rg_id) + if self._on_checkpoint_complete is not None: + + def on_complete(final_path: Path | str | None) -> None: + if final_path is not None: + self._on_checkpoint_complete(final_path) + + self._buffer_manager.checkpoint_row_group(rg_id, on_complete=on_complete) + else: + self._buffer_manager.checkpoint_row_group(rg_id) if self._on_row_group_complete: self._on_row_group_complete(rg_id) except Exception: diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py index c8d148e7c..df40e9e93 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -292,15 +292,13 @@ def on_before_checkpoint(rg_id: int, rg_size: int) -> None: tracker=tracker, row_groups=row_groups, buffer_manager=buffer_manager, + on_checkpoint_complete=on_batch_complete, on_seeds_complete=( on_seeds_complete if self._processor_runner.has_processors_for(ProcessorStage.PRE_BATCH) else None ), on_before_checkpoint=( on_before_checkpoint if self._processor_runner.has_processors_for(ProcessorStage.POST_BATCH) else None ), - on_row_group_complete=lambda rg_id: ( - on_batch_complete(self.artifact_storage.final_dataset_path) if on_batch_complete else None - ), shutdown_error_rate=settings.shutdown_error_rate, shutdown_error_window=settings.shutdown_error_window, disable_early_shutdown=settings.disable_early_shutdown, diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py index 0cabcffda..5b9c5042d 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py @@ -679,6 +679,78 @@ async def test_scheduler_on_before_checkpoint_callback() -> None: assert sorted(callback_log) == [(0, 3), (1, 2)] +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_on_checkpoint_complete_callback_receives_final_path() -> None: + """on_checkpoint_complete is called with the written parquet file path.""" + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + ] + strategies = {"seed": GenerationStrategy.FULL_COLUMN} + generators = {"seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider)} + + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 3)] + tracker = CompletionTracker.with_graph(graph, row_groups) + + storage = MagicMock() + storage.dataset_name = "test" + storage.get_file_paths.return_value = {} + storage.write_batch_to_parquet_file.return_value = "/fake.parquet" + storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet" + + buffer_mgr = RowGroupBufferManager(storage) + callback_log: list[str] = [] + + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + buffer_manager=buffer_mgr, + on_checkpoint_complete=lambda path: callback_log.append(path), + ) + await scheduler.run() + + assert callback_log == ["/fake_final.parquet"] + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_on_checkpoint_complete_skips_empty_row_group() -> None: + """on_checkpoint_complete is not called when a row group writes no file.""" + provider = _mock_provider() + storage = MagicMock() + storage.dataset_name = "test" + storage.get_file_paths.return_value = {} + + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + ] + strategies = {"seed": GenerationStrategy.FULL_COLUMN} + generators = { + "seed": MockFailingSeedGenerator(config=_expr_config("seed"), resource_provider=provider), + } + + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 3)] + tracker = CompletionTracker.with_graph(graph, row_groups) + buffer_mgr = RowGroupBufferManager(storage) + callback = MagicMock() + + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + buffer_manager=buffer_mgr, + on_checkpoint_complete=callback, + ) + await scheduler.run() + + callback.assert_not_called() + storage.write_batch_to_parquet_file.assert_not_called() + + @pytest.mark.asyncio(loop_scope="session") async def test_scheduler_pre_batch_failure_skips_row_group() -> None: """Pre-batch processor failure drops all rows in the row group; other row groups continue.""" diff --git a/tests_e2e/tests/test_async_engine.py b/tests_e2e/tests/test_async_engine.py index 9eaea23db..a15f4d10e 100644 --- a/tests_e2e/tests/test_async_engine.py +++ b/tests_e2e/tests/test_async_engine.py @@ -3,82 +3,118 @@ from __future__ import annotations +import json import os -from collections import defaultdict +import subprocess +import sys from pathlib import Path -from unittest.mock import patch import pytest +NUM_RECORDS = 2 +PARALLEL_COLUMNS = ("summary", "analysis") + + +def _run_async_engine_concurrency_case(tmp_path: Path) -> dict[str, object]: + repo_root = Path(__file__).resolve().parents[2] + script = f""" +from __future__ import annotations + +import json +from collections import defaultdict +from pathlib import Path + import data_designer.config as dd -import data_designer.engine.dataset_builders.column_wise_builder as cwb from data_designer.config.run_config import RunConfig from data_designer.interface import DataDesigner -NUM_RECORDS = 2 -PARALLEL_COLUMNS = ("summary", "analysis") - +NUM_RECORDS = {NUM_RECORDS} +PARALLEL_COLUMNS = {PARALLEL_COLUMNS!r} +tmp_path = Path({str(tmp_path)!r}) -def test_async_engine_concurrent_columns(tmp_path: Path) -> None: - """Verify the async engine runs independent LLM columns concurrently.""" - if os.environ.get("NVIDIA_API_KEY") is None: - pytest.skip("NVIDIA_API_KEY must be set") +dd_instance = DataDesigner(artifact_path=str(tmp_path)) +dd_instance.set_run_config(RunConfig(buffer_size=NUM_RECORDS, async_trace=True)) - with patch.object(cwb, "DATA_DESIGNER_ASYNC_ENGINE", True): - dd_instance = DataDesigner(artifact_path=str(tmp_path)) - dd_instance.set_run_config(RunConfig(buffer_size=NUM_RECORDS, async_trace=True)) - - config = dd.DataDesignerConfigBuilder() - config.add_column( - dd.SamplerColumnConfig( - name="topic", - sampler_type=dd.SamplerType.CATEGORY, - params=dd.CategorySamplerParams(values=["science", "history", "art"]), - ) +config = dd.DataDesignerConfigBuilder() +config.add_column( + dd.SamplerColumnConfig( + name="topic", + sampler_type=dd.SamplerType.CATEGORY, + params=dd.CategorySamplerParams(values=["science", "history", "art"]), + ) +) +for col in PARALLEL_COLUMNS: + config.add_column( + dd.LLMTextColumnConfig( + name=col, + model_alias="nvidia-text", + prompt="Write one sentence about {{{{ topic }}}} (" + col + ").", ) - for col in PARALLEL_COLUMNS: - config.add_column( - dd.LLMTextColumnConfig( - name=col, - model_alias="nvidia-text", - prompt=f"Write one sentence about {{{{ topic }}}} ({col}).", - ) - ) - - result = dd_instance.create(config, num_records=NUM_RECORDS, dataset_name="async_e2e") - df = result.load_dataset() - - # Dataset correctness - assert len(df) == NUM_RECORDS - for col in ("topic", *PARALLEL_COLUMNS): - assert col in df.columns - assert df[col].notna().all() - - # Concurrency: check that cell tasks from different columns overlapped - traces = result.task_traces - assert traces, "No task traces recorded - async_trace may not be enabled" - - by_col: dict[str, list] = defaultdict(list) - for t in traces: - if t.task_type == "cell" and t.status == "ok" and t.slot_acquired_at and t.completed_at: - by_col[t.column].append(t) - - overlap_found = False - cols = [c for c in PARALLEL_COLUMNS if by_col[c]] - for i, col_a in enumerate(cols): - for col_b in cols[i + 1 :]: - for ta in by_col[col_a]: - for tb in by_col[col_b]: - if ta.slot_acquired_at < tb.completed_at and tb.slot_acquired_at < ta.completed_at: - overlap_found = True - break - if overlap_found: + ) + +result = dd_instance.create(config, num_records=NUM_RECORDS, dataset_name="async_e2e") +df = result.load_dataset() +traces = result.task_traces + +by_col: dict[str, list[tuple[float, float]]] = defaultdict(list) +for trace in traces: + if trace.task_type == "cell" and trace.status == "ok" and trace.slot_acquired_at and trace.completed_at: + by_col[trace.column].append((trace.slot_acquired_at, trace.completed_at)) + +overlap_found = False +cols = [col for col in PARALLEL_COLUMNS if by_col[col]] +for i, col_a in enumerate(cols): + for col_b in cols[i + 1 :]: + for start_a, end_a in by_col[col_a]: + for start_b, end_b in by_col[col_b]: + if start_a < end_b and start_b < end_a: + overlap_found = True break if overlap_found: break if overlap_found: break + if overlap_found: + break + +payload = {{ + "rows": len(df), + "columns": list(df.columns), + "non_null": {{col: bool(df[col].notna().all()) for col in ("topic", *PARALLEL_COLUMNS)}}, + "trace_count": len(traces), + "overlap_found": overlap_found, +}} +print("RESULT_JSON=" + json.dumps(payload)) +""" + env = os.environ.copy() + env["DATA_DESIGNER_ASYNC_ENGINE"] = "1" + completed = subprocess.run( + [sys.executable, "-c", script], + check=True, + capture_output=True, + text=True, + cwd=repo_root, + env=env, + ) + for line in completed.stdout.splitlines(): + if line.startswith("RESULT_JSON="): + return json.loads(line.removeprefix("RESULT_JSON=")) + raise AssertionError(f"Missing RESULT_JSON marker in subprocess output:\n{completed.stdout}") + + +def test_async_engine_concurrent_columns(tmp_path: Path) -> None: + """Verify the async engine runs independent LLM columns concurrently.""" + if os.environ.get("NVIDIA_API_KEY") is None: + pytest.skip("NVIDIA_API_KEY must be set") + + payload = _run_async_engine_concurrency_case(tmp_path) + + assert payload["rows"] == NUM_RECORDS + for col in ("topic", *PARALLEL_COLUMNS): + assert col in payload["columns"] + assert payload["non_null"][col] - assert overlap_found, ( + assert payload["trace_count"] > 0 + assert payload["overlap_found"], ( "No overlapping execution found between parallel columns - async scheduler may not be dispatching concurrently" ) From 372e274c1d3d2ac85a4a6f02e768d7ab50b89865 Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Thu, 19 Mar 2026 20:51:53 +0000 Subject: [PATCH 15/23] fix: prune _seeds_dispatched_rgs and _pre_batch_done_rgs on checkpoint --- .../data_designer/engine/dataset_builders/async_scheduler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py index 2f6fda2fd..d798f2d25 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py @@ -266,6 +266,8 @@ def _checkpoint_completed_row_groups(self, all_columns: list[str]) -> None: for rg_id, rg_size in completed: self._active_rgs.remove((rg_id, rg_size)) self._admitted_rg_ids.discard(rg_id) + self._seeds_dispatched_rgs.discard(rg_id) + self._pre_batch_done_rgs.discard(rg_id) try: if self._on_before_checkpoint: try: From 2114d3b8af9f1088ee8fec01a005c2cd13f138b2 Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Thu, 19 Mar 2026 21:11:40 +0000 Subject: [PATCH 16/23] refactor: consolidate per-RG state into _RowGroupState dataclass Replace 5 independent collections (_active_rgs, _admitted_rg_ids, _seeds_dispatched_rgs, _pre_batch_done_rgs, _in_flight_counts) with a single _rg_states dict keyed by row group ID. Cleanup is now a single `del` instead of N separate discards, eliminating the class of bugs where one collection is missed during row group teardown. --- .../dataset_builders/async_scheduler.py | 95 +++++++++++-------- 1 file changed, 56 insertions(+), 39 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py index d798f2d25..0335177d3 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py @@ -7,6 +7,7 @@ import contextlib import logging import time +from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, Any, Callable @@ -28,6 +29,16 @@ logger = logging.getLogger(__name__) +@dataclass +class _RowGroupState: + """Lifecycle state for a single admitted row group.""" + + size: int + seeds_dispatched: bool = False + pre_batch_done: bool = False + in_flight_count: int = 0 + + class AsyncTaskScheduler: """Dependency-aware async task scheduler for the dataset builder. @@ -91,18 +102,12 @@ def __init__( if gen.is_order_dependent and id(gen) not in self._stateful_locks: self._stateful_locks[id(gen)] = asyncio.Lock() - # Per-RG in-flight counters for O(1) lookup - self._in_flight_counts: dict[int, int] = {} + # Per-RG lifecycle state (admitted but not yet checkpointed) + self._rg_states: dict[int, _RowGroupState] = {} # Deferred retryable failures (retried in salvage rounds) self._deferred: list[Task] = [] - # Active row groups (admitted but not yet checkpointed) - self._active_rgs: list[tuple[int, int]] = [] - self._admitted_rg_ids: set[int] = set() - self._seeds_dispatched_rgs: set[int] = set() - self._pre_batch_done_rgs: set[int] = set() - # Tracing self._trace = trace self.traces: list[TaskTrace] = [] @@ -122,8 +127,7 @@ async def _admit_row_groups(self) -> None: """Admit row groups as semaphore slots become available.""" for rg_id, rg_size in self._row_groups: await self._rg_semaphore.acquire() - self._active_rgs.append((rg_id, rg_size)) - self._admitted_rg_ids.add(rg_id) + self._rg_states[rg_id] = _RowGroupState(size=rg_size) if self._buffer_manager is not None: self._buffer_manager.init_row_group(rg_id, rg_size) @@ -153,21 +157,27 @@ async def run(self) -> None: self._run_seeds_complete_check(seed_cols) - ready = self._tracker.get_ready_tasks(self._dispatched, self._admitted_rg_ids) + admitted_ids = set(self._rg_states) + ready = self._tracker.get_ready_tasks(self._dispatched, admitted_ids) # Gate non-seed tasks on pre-batch completion when a pre-batch callback is configured if has_pre_batch: - ready = [t for t in ready if t.row_group in self._pre_batch_done_rgs or t.column in seed_cols] + ready = [ + t + for t in ready + if (s := self._rg_states.get(t.row_group)) is not None and s.pre_batch_done or t.column in seed_cols + ] for task in ready: await self._submission_semaphore.acquire() self._dispatched.add(task) self._in_flight.add(task) - self._in_flight_counts[task.row_group] = self._in_flight_counts.get(task.row_group, 0) + 1 + if (s := self._rg_states.get(task.row_group)) is not None: + s.in_flight_count += 1 asyncio.create_task(self._execute_task(task)) self._checkpoint_completed_row_groups(all_columns) # Are we done? - all_done = self._all_rgs_admitted and not self._active_rgs and not self._in_flight + all_done = self._all_rgs_admitted and not self._rg_states and not self._in_flight if all_done: break @@ -222,7 +232,8 @@ async def run(self) -> None: Task(column=task.column, row_group=task.row_group, row_index=None, task_type="batch") ) self._in_flight.add(task) - self._in_flight_counts[task.row_group] = self._in_flight_counts.get(task.row_group, 0) + 1 + if (s := self._rg_states.get(task.row_group)) is not None: + s.in_flight_count += 1 asyncio.create_task(self._execute_seed_task(task, gid)) else: self._dispatched.discard(task) @@ -231,10 +242,10 @@ async def run(self) -> None: await self._drain_frontier(seed_cols, has_pre_batch, all_columns) self._checkpoint_completed_row_groups(all_columns) - if self._active_rgs: - incomplete = [rg_id for rg_id, _ in self._active_rgs] + if self._rg_states: + incomplete = list(self._rg_states) logger.error( - f"Scheduler exited with {len(self._active_rgs)} unfinished row group(s): {incomplete}. " + f"Scheduler exited with {len(self._rg_states)} unfinished row group(s): {incomplete}. " "These row groups were not checkpointed." ) @@ -242,14 +253,20 @@ async def _drain_frontier(self, seed_cols: frozenset[str], has_pre_batch: bool, """Dispatch all frontier tasks and their downstream until quiescent.""" while True: self._run_seeds_complete_check(seed_cols) - ready = self._tracker.get_ready_tasks(self._dispatched, self._admitted_rg_ids) + admitted_ids = set(self._rg_states) + ready = self._tracker.get_ready_tasks(self._dispatched, admitted_ids) if has_pre_batch: - ready = [t for t in ready if t.row_group in self._pre_batch_done_rgs or t.column in seed_cols] + ready = [ + t + for t in ready + if (s := self._rg_states.get(t.row_group)) is not None and s.pre_batch_done or t.column in seed_cols + ] for task in ready: await self._submission_semaphore.acquire() self._dispatched.add(task) self._in_flight.add(task) - self._in_flight_counts[task.row_group] = self._in_flight_counts.get(task.row_group, 0) + 1 + if (s := self._rg_states.get(task.row_group)) is not None: + s.in_flight_count += 1 asyncio.create_task(self._execute_task(task)) if not self._in_flight: break @@ -259,15 +276,12 @@ async def _drain_frontier(self, seed_cols: frozenset[str], has_pre_batch: bool, def _checkpoint_completed_row_groups(self, all_columns: list[str]) -> None: """Checkpoint any row groups that reached completion.""" completed = [ - (rg_id, rg_size) - for rg_id, rg_size in self._active_rgs - if self._tracker.is_row_group_complete(rg_id, rg_size, all_columns) + (rg_id, state.size) + for rg_id, state in self._rg_states.items() + if self._tracker.is_row_group_complete(rg_id, state.size, all_columns) ] for rg_id, rg_size in completed: - self._active_rgs.remove((rg_id, rg_size)) - self._admitted_rg_ids.discard(rg_id) - self._seeds_dispatched_rgs.discard(rg_id) - self._pre_batch_done_rgs.discard(rg_id) + del self._rg_states[rg_id] try: if self._on_before_checkpoint: try: @@ -300,24 +314,25 @@ def on_complete(final_path: Path | str | None) -> None: def _run_seeds_complete_check(self, seed_cols: frozenset[str]) -> None: """Run pre-batch callbacks for row groups whose seeds just completed.""" - for rg_id, rg_size in list(self._active_rgs): - if rg_id in self._seeds_dispatched_rgs and rg_id not in self._pre_batch_done_rgs: + for rg_id, state in list(self._rg_states.items()): + if state.seeds_dispatched and not state.pre_batch_done: all_seeds_done = all(self._tracker.is_column_complete_for_rg(col, rg_id) for col in seed_cols) - if all_seeds_done and not self._in_flight_for_rg(rg_id): - self._pre_batch_done_rgs.add(rg_id) + if all_seeds_done and state.in_flight_count == 0: + state.pre_batch_done = True if self._on_seeds_complete: try: - self._on_seeds_complete(rg_id, rg_size) + self._on_seeds_complete(rg_id, state.size) except Exception as exc: logger.warning(f"Pre-batch processor failed for row group {rg_id}, skipping: {exc}") - for ri in range(rg_size): + for ri in range(state.size): self._tracker.drop_row(rg_id, ri) if self._buffer_manager: self._buffer_manager.drop_row(rg_id, ri) def _in_flight_for_rg(self, rg_id: int) -> bool: """Check if any tasks are in-flight for a given row group.""" - return self._in_flight_counts.get(rg_id, 0) > 0 + state = self._rg_states.get(rg_id) + return state is not None and state.in_flight_count > 0 def _check_error_rate(self) -> None: """Trigger early shutdown if error rate exceeds threshold.""" @@ -332,7 +347,7 @@ def _check_error_rate(self) -> None: async def _dispatch_seeds(self, rg_id: int, rg_size: int) -> None: """Dispatch from_scratch tasks for a row group.""" - self._seeds_dispatched_rgs.add(rg_id) + self._rg_states[rg_id].seeds_dispatched = True seed_cols = self._seed_cols seen_instances: set[int] = set() @@ -366,7 +381,8 @@ async def _dispatch_seeds(self, rg_id: int, rg_size: int) -> None: ) self._dispatched.add(Task(column=sibling_col, row_group=rg_id, row_index=None, task_type="batch")) self._in_flight.add(task) - self._in_flight_counts[task.row_group] = self._in_flight_counts.get(task.row_group, 0) + 1 + if (s := self._rg_states.get(task.row_group)) is not None: + s.in_flight_count += 1 asyncio.create_task(self._execute_seed_task(task, gid)) async def _execute_seed_task(self, task: Task, generator_id: int) -> None: @@ -399,7 +415,7 @@ async def _execute_task_inner(self, task: Task) -> None: # Skip tasks whose row group was already checkpointed (can happen # when a vacuously-ready downstream is dispatched via create_task # in the same loop iteration that checkpoints the row group). - if not any(rg_id == task.row_group for rg_id, _ in self._active_rgs): + if task.row_group not in self._rg_states: skipped = True return @@ -460,7 +476,8 @@ async def _execute_task_inner(self, task: Task) -> None: self.traces.append(trace) self._in_flight.discard(task) - self._in_flight_counts[task.row_group] = max(0, self._in_flight_counts.get(task.row_group, 0) - 1) + if (s := self._rg_states.get(task.row_group)) is not None: + s.in_flight_count = max(0, s.in_flight_count - 1) if not retryable and not skipped: self._dispatched.discard(task) self._submission_semaphore.release() From 166cfff690f0403e4354f15689a6dcdd62d82ea1 Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Thu, 19 Mar 2026 21:21:13 +0000 Subject: [PATCH 17/23] fix: skip checkpoint and callbacks when on_before_checkpoint fails When on_before_checkpoint raises and all rows are dropped, the code previously fell through to checkpoint_row_group and on_row_group_complete, sending a spurious progress notification for a batch with zero records. Now gates both on a `dropped` flag so they are skipped after failure. --- .../engine/dataset_builders/async_scheduler.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py index 0335177d3..49f9f53b1 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py @@ -282,6 +282,7 @@ def _checkpoint_completed_row_groups(self, all_columns: list[str]) -> None: ] for rg_id, rg_size in completed: del self._rg_states[rg_id] + dropped = False try: if self._on_before_checkpoint: try: @@ -295,7 +296,8 @@ def _checkpoint_completed_row_groups(self, all_columns: list[str]) -> None: for ri in range(rg_size): if self._buffer_manager: self._buffer_manager.drop_row(rg_id, ri) - if self._buffer_manager is not None: + dropped = True + if not dropped and self._buffer_manager is not None: if self._on_checkpoint_complete is not None: def on_complete(final_path: Path | str | None) -> None: @@ -305,7 +307,7 @@ def on_complete(final_path: Path | str | None) -> None: self._buffer_manager.checkpoint_row_group(rg_id, on_complete=on_complete) else: self._buffer_manager.checkpoint_row_group(rg_id) - if self._on_row_group_complete: + if not dropped and self._on_row_group_complete: self._on_row_group_complete(rg_id) except Exception: logger.error(f"Failed to checkpoint row group {rg_id}.", exc_info=True) From c8a823cdd798424b3b78dc54b38ecb5926e5e63d Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Fri, 20 Mar 2026 03:07:26 +0000 Subject: [PATCH 18/23] fix: snapshot dropped rows before await in _run_batch and sync tracker on checkpoint failure Two fixes: - _run_batch: snapshot dropped rows before `await agenerate` so the row-count expectation matches batch_df. Concurrent tasks can drop rows during the await, causing a spurious ValueError that would drop the entire row group. Write-back now re-checks is_dropped to skip rows dropped mid-flight. - _checkpoint_completed_row_groups: add tracker.drop_row alongside buffer_manager.drop_row when on_before_checkpoint fails, keeping both in sync. --- .../dataset_builders/async_scheduler.py | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py index 49f9f53b1..069d32f3c 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py @@ -294,6 +294,7 @@ def _checkpoint_completed_row_groups(self, all_columns: list[str]) -> None: exc_info=True, ) for ri in range(rg_size): + self._tracker.drop_row(rg_id, ri) if self._buffer_manager: self._buffer_manager.drop_row(rg_id, ri) dropped = True @@ -535,22 +536,21 @@ async def _run_batch(self, task: Task, generator: ColumnGenerator) -> Any: """Execute a full-column/batch task.""" if self._buffer_manager is not None: batch_df = self._buffer_manager.get_dataframe(task.row_group) + # Snapshot dropped rows before the await so the row-count expectation + # is consistent with batch_df (concurrent tasks may drop rows during agenerate). + rg_size = self._get_rg_size(task.row_group) + pre_dropped: set[int] = {ri for ri in range(rg_size) if self._buffer_manager.is_dropped(task.row_group, ri)} else: batch_df = lazy.pd.DataFrame() + rg_size = self._get_rg_size(task.row_group) + pre_dropped = set() result_df = await generator.agenerate(batch_df) # Merge result columns back to buffer if self._buffer_manager is not None: output_cols = self._instance_to_columns.get(id(generator), [task.column]) - rg_size = self._get_rg_size(task.row_group) - dropped = set() - for ri in range(rg_size): - if self._buffer_manager.is_dropped(task.row_group, ri): - dropped.add(ri) - - # Map result rows (which exclude dropped) back to buffer indices - active_rows = rg_size - len(dropped) + active_rows = rg_size - len(pre_dropped) if len(result_df) != active_rows: raise ValueError( f"Batch generator for '{task.column}' returned {len(result_df)} rows " @@ -558,11 +558,13 @@ async def _run_batch(self, task: Task, generator: ColumnGenerator) -> Any: ) result_idx = 0 for ri in range(rg_size): - if ri in dropped: + if ri in pre_dropped: continue - for col in output_cols: - if col in result_df.columns: - self._buffer_manager.update_cell(task.row_group, ri, col, result_df.iloc[result_idx][col]) + # Skip writing to rows dropped by concurrent tasks during the await + if not self._buffer_manager.is_dropped(task.row_group, ri): + for col in output_cols: + if col in result_df.columns: + self._buffer_manager.update_cell(task.row_group, ri, col, result_df.iloc[result_idx][col]) result_idx += 1 return result_df From d3c27126ab56b8c9462e7a22b1c638dc82c8d35d Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Fri, 20 Mar 2026 03:40:01 +0000 Subject: [PATCH 19/23] feat: sliding window error rate and out-of-order row group completion test Replace cumulative error counters with a deque-based sliding window so that early transient failures do not permanently inflate the error rate in long-running jobs. Add tests for the sliding window recovery path and for deterministic out-of-order row group checkpoint ordering. --- .../dataset_builders/async_scheduler.py | 23 ++-- .../dataset_builders/test_async_scheduler.py | 125 ++++++++++++++++++ 2 files changed, 136 insertions(+), 12 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py index 069d32f3c..9bbfe7740 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py @@ -7,6 +7,7 @@ import contextlib import logging import time +from collections import deque from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING, Any, Callable @@ -112,9 +113,8 @@ def __init__( self._trace = trace self.traces: list[TaskTrace] = [] - # Stats - self._success_count = 0 - self._error_count = 0 + # Sliding window for error rate shutdown + self._recent_outcomes: deque[bool] = deque(maxlen=shutdown_error_window) self._all_rgs_admitted = False # Pre-compute row-group sizes for O(1) lookup @@ -337,15 +337,15 @@ def _in_flight_for_rg(self, rg_id: int) -> bool: state = self._rg_states.get(rg_id) return state is not None and state.in_flight_count > 0 - def _check_error_rate(self) -> None: - """Trigger early shutdown if error rate exceeds threshold.""" + def _check_error_rate(self, *, success: bool) -> None: + """Trigger early shutdown if recent error rate exceeds threshold.""" if self._disable_early_shutdown: return - completed = self._success_count + self._error_count - if completed < self._shutdown_error_window: + self._recent_outcomes.append(success) + if len(self._recent_outcomes) < self._shutdown_error_window: return - error_rate = self._error_count / max(1, completed) - if error_rate > self._shutdown_error_rate: + errors = sum(1 for ok in self._recent_outcomes if not ok) + if errors / self._shutdown_error_window > self._shutdown_error_rate: self._early_shutdown = True async def _dispatch_seeds(self, rg_id: int, rg_size: int) -> None: @@ -442,13 +442,12 @@ async def _execute_task_inner(self, task: Task) -> None: else: self._tracker.mark_cell_complete(col, task.row_group, task.row_index) - self._success_count += 1 + self._check_error_rate(success=True) if self._trace and trace: trace.status = "ok" except Exception as exc: - self._error_count += 1 - self._check_error_rate() + self._check_error_rate(success=False) if self._trace and trace: trace.status = "error" trace.error = str(exc) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py index 5b9c5042d..a8e08b06d 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py @@ -3,6 +3,7 @@ from __future__ import annotations +import asyncio from typing import Any from unittest.mock import MagicMock @@ -643,6 +644,53 @@ async def test_scheduler_early_shutdown_disabled() -> None: assert tracker.is_row_group_complete(0, 5, ["seed", "fail_col"]) +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_sliding_window_error_rate_recovers() -> None: + """Transient errors diluted by successes do not trigger early shutdown.""" + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="col", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "col": GenerationStrategy.CELL_BY_CELL, + } + # First 2 calls fail (retryable 503), rest succeed. + # With window=10 and 10 cell tasks, at most 2/10 = 20% error rate + # when the window first fills - well below the 0.4 threshold. + generators = { + "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), + "col": MockFailingGenerator(config=_expr_config("col"), resource_provider=provider, transient_failures=2), + } + + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 10)] + tracker = CompletionTracker.with_graph(graph, row_groups) + + storage = MagicMock() + storage.dataset_name = "test" + storage.get_file_paths.return_value = {} + storage.write_batch_to_parquet_file.return_value = "/fake.parquet" + storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet" + buffer_mgr = RowGroupBufferManager(storage) + + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + buffer_manager=buffer_mgr, + shutdown_error_rate=0.4, + shutdown_error_window=10, + ) + await scheduler.run() + + # No early shutdown - transient errors recovered in salvage + assert not scheduler._early_shutdown + assert tracker.is_row_group_complete(0, 10, ["seed", "col"]) + + @pytest.mark.asyncio(loop_scope="session") async def test_scheduler_on_before_checkpoint_callback() -> None: """on_before_checkpoint is called before each row group is checkpointed.""" @@ -798,3 +846,80 @@ def failing_pre_batch(rg_id: int, rg_size: int) -> None: assert all(tracker.is_dropped(0, ri) for ri in range(3)) # Row group 1: completed normally assert tracker.is_row_group_complete(1, 2, ["seed", "cell_out"]) + + +class _SlowSeedGenerator(FromScratchColumnGenerator[ExpressionColumnConfig]): + """Seed generator whose async cost scales with row count via event-loop yields. + + Both RGs' seed tasks run as concurrent asyncio tasks. Because they interleave + via ``asyncio.sleep(0)``, the task with fewer yields (smaller RG) finishes first, + causing its downstream to be dispatched and completed before the larger RG. + """ + + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return GenerationStrategy.FULL_COLUMN + + def generate(self, data: lazy.pd.DataFrame) -> lazy.pd.DataFrame: + return data + + def generate_from_scratch(self, num_records: int) -> lazy.pd.DataFrame: + return lazy.pd.DataFrame({self.config.name: list(range(num_records))}) + + async def agenerate_from_scratch(self, num_records: int) -> lazy.pd.DataFrame: + for _ in range(num_records * 10): + await asyncio.sleep(0) + return self.generate_from_scratch(num_records) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_out_of_order_row_group_completion() -> None: + """Row groups may complete out of order; both are checkpointed correctly.""" + provider = _mock_provider() + # Slow seed generator: RG 0 (5 rows) does 50 yields, RG 1 (1 row) does 10 yields. + # Because seed tasks interleave in the event loop, RG 1 finishes seeds first, + # its downstream cell_out is dispatched and completes before RG 0. + slow_seed = _SlowSeedGenerator(config=_expr_config("seed"), resource_provider=provider) + cell_gen = MockCellGenerator(config=_expr_config("cell_out"), resource_provider=provider) + + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="cell_out", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "cell_out": GenerationStrategy.CELL_BY_CELL, + } + generators = {"seed": slow_seed, "cell_out": cell_gen} + + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 5), (1, 1)] + tracker = CompletionTracker.with_graph(graph, row_groups) + + storage = MagicMock() + storage.dataset_name = "test" + storage.get_file_paths.return_value = {} + storage.write_batch_to_parquet_file.return_value = "/fake.parquet" + storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet" + buffer_mgr = RowGroupBufferManager(storage) + + checkpoint_order: list[int] = [] + + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + buffer_manager=buffer_mgr, + max_concurrent_row_groups=2, + on_row_group_complete=lambda rg_id: checkpoint_order.append(rg_id), + ) + await scheduler.run() + + # Both row groups completed + assert tracker.is_row_group_complete(0, 5, ["seed", "cell_out"]) + assert tracker.is_row_group_complete(1, 1, ["seed", "cell_out"]) + # Both were checkpointed + assert set(checkpoint_order) == {0, 1} + # RG 1 (fewer rows, fewer seed yields) checkpoints before RG 0 + assert checkpoint_order.index(1) < checkpoint_order.index(0) From 986aff6010f5ddfd43ed56b51463528e7ff0b429 Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Fri, 20 Mar 2026 03:45:01 +0000 Subject: [PATCH 20/23] fix: use real time delays in out-of-order completion test asyncio.sleep(0) interleaving is not deterministic across Python versions. Switch to asyncio.sleep(num_records * 0.02) so the smaller row group genuinely finishes seeds first regardless of event loop scheduling. --- .../dataset_builders/test_async_scheduler.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py index a8e08b06d..1027bf872 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py @@ -849,11 +849,10 @@ def failing_pre_batch(rg_id: int, rg_size: int) -> None: class _SlowSeedGenerator(FromScratchColumnGenerator[ExpressionColumnConfig]): - """Seed generator whose async cost scales with row count via event-loop yields. + """Seed generator whose async cost scales with row count. - Both RGs' seed tasks run as concurrent asyncio tasks. Because they interleave - via ``asyncio.sleep(0)``, the task with fewer yields (smaller RG) finishes first, - causing its downstream to be dispatched and completed before the larger RG. + Both RGs' seed tasks run concurrently. The task with fewer rows sleeps for + less real time, causing its downstream to be dispatched and completed first. """ @staticmethod @@ -867,8 +866,7 @@ def generate_from_scratch(self, num_records: int) -> lazy.pd.DataFrame: return lazy.pd.DataFrame({self.config.name: list(range(num_records))}) async def agenerate_from_scratch(self, num_records: int) -> lazy.pd.DataFrame: - for _ in range(num_records * 10): - await asyncio.sleep(0) + await asyncio.sleep(num_records * 0.02) return self.generate_from_scratch(num_records) @@ -876,9 +874,8 @@ async def agenerate_from_scratch(self, num_records: int) -> lazy.pd.DataFrame: async def test_scheduler_out_of_order_row_group_completion() -> None: """Row groups may complete out of order; both are checkpointed correctly.""" provider = _mock_provider() - # Slow seed generator: RG 0 (5 rows) does 50 yields, RG 1 (1 row) does 10 yields. - # Because seed tasks interleave in the event loop, RG 1 finishes seeds first, - # its downstream cell_out is dispatched and completes before RG 0. + # Slow seed generator: RG 0 (5 rows) sleeps 100ms, RG 1 (1 row) sleeps 20ms. + # RG 1 finishes seeds first, its downstream is dispatched and completes before RG 0. slow_seed = _SlowSeedGenerator(config=_expr_config("seed"), resource_provider=provider) cell_gen = MockCellGenerator(config=_expr_config("cell_out"), resource_provider=provider) From f9215215265d0e1e620da70857452786b44ef2d0 Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Fri, 20 Mar 2026 11:12:41 +0000 Subject: [PATCH 21/23] fix: prevent ZeroDivisionError when shutdown_error_window is 0 Change RunConfig.shutdown_error_window constraint from ge=0 to ge=1 so the sliding window denominator is never zero. --- .../src/data_designer/config/run_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/data-designer-config/src/data_designer/config/run_config.py b/packages/data-designer-config/src/data_designer/config/run_config.py index d1e378134..23b7cbec6 100644 --- a/packages/data-designer-config/src/data_designer/config/run_config.py +++ b/packages/data-designer-config/src/data_designer/config/run_config.py @@ -22,7 +22,7 @@ class RunConfig(ConfigBase): shutdown_error_rate: Error rate threshold (0.0-1.0) that triggers early shutdown when early shutdown is enabled. Default is 0.5. shutdown_error_window: Minimum number of completed tasks before error rate - monitoring begins. Must be >= 0. Default is 10. + monitoring begins. Must be >= 1. Default is 10. buffer_size: Number of records to process in each batch during dataset generation. A batch is processed end-to-end (column generation, post-batch processors, and writing the batch to artifact storage) before moving on to the next batch. Must be > 0. Default is 1000. @@ -39,7 +39,7 @@ class RunConfig(ConfigBase): disable_early_shutdown: bool = False shutdown_error_rate: float = Field(default=0.5, ge=0.0, le=1.0) - shutdown_error_window: int = Field(default=10, ge=0) + shutdown_error_window: int = Field(default=10, ge=1) buffer_size: int = Field(default=1000, gt=0) non_inference_max_parallel_workers: int = Field(default=4, ge=1) max_conversation_restarts: int = Field(default=5, ge=0) From 2a8698e7c8a032c24e1ee8ef3257662b01753a3d Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Fri, 20 Mar 2026 14:07:00 +0000 Subject: [PATCH 22/23] fix: address Greptile review nits in async_scheduler - Move del _rg_states inside try/finally so semaphore is always released - Add exc_info=True to pre-batch failure log for consistent tracebacks - Short-circuit _check_error_rate when _early_shutdown already set --- .../engine/dataset_builders/async_scheduler.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py index 9bbfe7740..533a3911c 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py @@ -281,9 +281,9 @@ def _checkpoint_completed_row_groups(self, all_columns: list[str]) -> None: if self._tracker.is_row_group_complete(rg_id, state.size, all_columns) ] for rg_id, rg_size in completed: - del self._rg_states[rg_id] dropped = False try: + del self._rg_states[rg_id] if self._on_before_checkpoint: try: self._on_before_checkpoint(rg_id, rg_size) @@ -325,8 +325,11 @@ def _run_seeds_complete_check(self, seed_cols: frozenset[str]) -> None: if self._on_seeds_complete: try: self._on_seeds_complete(rg_id, state.size) - except Exception as exc: - logger.warning(f"Pre-batch processor failed for row group {rg_id}, skipping: {exc}") + except Exception: + logger.warning( + f"Pre-batch processor failed for row group {rg_id}, skipping.", + exc_info=True, + ) for ri in range(state.size): self._tracker.drop_row(rg_id, ri) if self._buffer_manager: @@ -339,7 +342,7 @@ def _in_flight_for_rg(self, rg_id: int) -> bool: def _check_error_rate(self, *, success: bool) -> None: """Trigger early shutdown if recent error rate exceeds threshold.""" - if self._disable_early_shutdown: + if self._disable_early_shutdown or self._early_shutdown: return self._recent_outcomes.append(success) if len(self._recent_outcomes) < self._shutdown_error_window: From 58875fad9384c23187bc371ab9e0381bcc367cab Mon Sep 17 00:00:00 2001 From: Andre Manoel Date: Fri, 20 Mar 2026 15:05:33 +0000 Subject: [PATCH 23/23] fix: address Greptile summary findings - Remove duplicate async engine log in build() (kept in _build_async) - Guard _run_seeds_complete_check with has_pre_batch at both call sites - Change error rate comparison from > to >= to match sync path semantics --- .../engine/dataset_builders/async_scheduler.py | 8 +++++--- .../engine/dataset_builders/column_wise_builder.py | 1 - 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py index 533a3911c..2ff15a506 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py @@ -155,7 +155,8 @@ async def run(self) -> None: self._wake_event.clear() - self._run_seeds_complete_check(seed_cols) + if has_pre_batch: + self._run_seeds_complete_check(seed_cols) admitted_ids = set(self._rg_states) ready = self._tracker.get_ready_tasks(self._dispatched, admitted_ids) @@ -252,7 +253,8 @@ async def run(self) -> None: async def _drain_frontier(self, seed_cols: frozenset[str], has_pre_batch: bool, all_columns: list[str]) -> None: """Dispatch all frontier tasks and their downstream until quiescent.""" while True: - self._run_seeds_complete_check(seed_cols) + if has_pre_batch: + self._run_seeds_complete_check(seed_cols) admitted_ids = set(self._rg_states) ready = self._tracker.get_ready_tasks(self._dispatched, admitted_ids) if has_pre_batch: @@ -348,7 +350,7 @@ def _check_error_rate(self, *, success: bool) -> None: if len(self._recent_outcomes) < self._shutdown_error_window: return errors = sum(1 for ok in self._recent_outcomes if not ok) - if errors / self._shutdown_error_window > self._shutdown_error_rate: + if errors / self._shutdown_error_window >= self._shutdown_error_rate: self._early_shutdown = True async def _dispatch_seeds(self, rg_id: int, rg_size: int) -> None: diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py index df40e9e93..5fd3e609f 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -167,7 +167,6 @@ def build( if DATA_DESIGNER_ASYNC_ENGINE: self._validate_async_compatibility() - logger.info("⚡ Using async task-queue builder") self._build_async(generators, num_records, buffer_size, on_batch_complete) else: group_id = uuid.uuid4().hex