diff --git a/packages/data-designer-config/src/data_designer/config/run_config.py b/packages/data-designer-config/src/data_designer/config/run_config.py index 03b2ed297..23b7cbec6 100644 --- a/packages/data-designer-config/src/data_designer/config/run_config.py +++ b/packages/data-designer-config/src/data_designer/config/run_config.py @@ -22,7 +22,7 @@ class RunConfig(ConfigBase): shutdown_error_rate: Error rate threshold (0.0-1.0) that triggers early shutdown when early shutdown is enabled. Default is 0.5. shutdown_error_window: Minimum number of completed tasks before error rate - monitoring begins. Must be >= 0. Default is 10. + monitoring begins. Must be >= 1. Default is 10. buffer_size: Number of records to process in each batch during dataset generation. A batch is processed end-to-end (column generation, post-batch processors, and writing the batch to artifact storage) before moving on to the next batch. Must be > 0. Default is 1000. @@ -33,15 +33,18 @@ class RunConfig(ConfigBase): max_conversation_correction_steps: Maximum number of correction rounds permitted within a single conversation when generation tasks call `ModelFacade.generate(...)`. Must be >= 0. Default is 0. + async_trace: If True, collect per-task tracing data when using the async engine + (DATA_DESIGNER_ASYNC_ENGINE=1). Has no effect on the sync path. Default is False. """ disable_early_shutdown: bool = False shutdown_error_rate: float = Field(default=0.5, ge=0.0, le=1.0) - shutdown_error_window: int = Field(default=10, ge=0) + shutdown_error_window: int = Field(default=10, ge=1) buffer_size: int = Field(default=1000, gt=0) non_inference_max_parallel_workers: int = Field(default=4, ge=1) max_conversation_restarts: int = Field(default=5, ge=0) max_conversation_correction_steps: int = Field(default=0, ge=0) + async_trace: bool = False @model_validator(mode="after") def normalize_shutdown_settings(self) -> Self: diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py index 922952d5b..2ff15a506 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py @@ -7,6 +7,9 @@ import contextlib import logging import time +from collections import deque +from dataclasses import dataclass +from pathlib import Path from typing import TYPE_CHECKING, Any, Callable import data_designer.lazy_heavy_imports as lazy @@ -27,6 +30,16 @@ logger = logging.getLogger(__name__) +@dataclass +class _RowGroupState: + """Lifecycle state for a single admitted row group.""" + + size: int + seeds_dispatched: bool = False + pre_batch_done: bool = False + in_flight_count: int = 0 + + class AsyncTaskScheduler: """Dependency-aware async task scheduler for the dataset builder. @@ -46,6 +59,12 @@ def __init__( max_submitted_tasks: int = 256, salvage_max_rounds: int = 2, on_row_group_complete: Callable[[int], None] | None = None, + on_checkpoint_complete: Callable[[Path | str], None] | None = None, + on_seeds_complete: Callable[[int, int], None] | None = None, + on_before_checkpoint: Callable[[int, int], None] | None = None, + shutdown_error_rate: float = 0.5, + shutdown_error_window: int = 10, + disable_early_shutdown: bool = False, trace: bool = False, ) -> None: self._generators = generators @@ -62,6 +81,15 @@ def __init__( self._wake_event = asyncio.Event() self._salvage_max_rounds = salvage_max_rounds self._on_row_group_complete = on_row_group_complete + self._on_checkpoint_complete = on_checkpoint_complete + self._on_seeds_complete = on_seeds_complete + self._on_before_checkpoint = on_before_checkpoint + + # Error rate shutdown (caller passes pre-normalized values via RunConfig) + self._shutdown_error_rate = shutdown_error_rate + self._shutdown_error_window = shutdown_error_window + self._disable_early_shutdown = disable_early_shutdown + self._early_shutdown = False # Multi-column dedup: group output columns by generator identity instance_to_columns: dict[int, list[str]] = {} @@ -75,31 +103,31 @@ def __init__( if gen.is_order_dependent and id(gen) not in self._stateful_locks: self._stateful_locks[id(gen)] = asyncio.Lock() + # Per-RG lifecycle state (admitted but not yet checkpointed) + self._rg_states: dict[int, _RowGroupState] = {} + # Deferred retryable failures (retried in salvage rounds) self._deferred: list[Task] = [] - # Active row groups (admitted but not yet checkpointed) - self._active_rgs: list[tuple[int, int]] = [] - self._admitted_rg_ids: set[int] = set() - # Tracing self._trace = trace self.traces: list[TaskTrace] = [] - # Stats - self._success_count = 0 - self._error_count = 0 + # Sliding window for error rate shutdown + self._recent_outcomes: deque[bool] = deque(maxlen=shutdown_error_window) self._all_rgs_admitted = False # Pre-compute row-group sizes for O(1) lookup self._rg_size_map: dict[int, int] = dict(row_groups) + # Pre-compute seed columns (graph is static) + self._seed_cols: frozenset[str] = frozenset(c for c in graph.columns if not graph.get_upstream_columns(c)) + async def _admit_row_groups(self) -> None: """Admit row groups as semaphore slots become available.""" for rg_id, rg_size in self._row_groups: await self._rg_semaphore.acquire() - self._active_rgs.append((rg_id, rg_size)) - self._admitted_rg_ids.add(rg_id) + self._rg_states[rg_id] = _RowGroupState(size=rg_size) if self._buffer_manager is not None: self._buffer_manager.init_row_group(rg_id, rg_size) @@ -112,25 +140,45 @@ async def _admit_row_groups(self) -> None: async def run(self) -> None: """Main scheduler loop.""" all_columns = self._graph.columns + seed_cols = self._seed_cols + has_pre_batch = self._on_seeds_complete is not None # Launch admission as a background task so it interleaves with dispatch. admission_task = asyncio.create_task(self._admit_row_groups()) # Main dispatch loop while True: + if self._early_shutdown: + logger.warning("Early shutdown triggered - error rate exceeded threshold") + self._checkpoint_completed_row_groups(all_columns) + break + self._wake_event.clear() - ready = self._tracker.get_ready_tasks(self._dispatched, self._admitted_rg_ids) + if has_pre_batch: + self._run_seeds_complete_check(seed_cols) + + admitted_ids = set(self._rg_states) + ready = self._tracker.get_ready_tasks(self._dispatched, admitted_ids) + # Gate non-seed tasks on pre-batch completion when a pre-batch callback is configured + if has_pre_batch: + ready = [ + t + for t in ready + if (s := self._rg_states.get(t.row_group)) is not None and s.pre_batch_done or t.column in seed_cols + ] for task in ready: await self._submission_semaphore.acquire() self._dispatched.add(task) self._in_flight.add(task) + if (s := self._rg_states.get(task.row_group)) is not None: + s.in_flight_count += 1 asyncio.create_task(self._execute_task(task)) self._checkpoint_completed_row_groups(all_columns) # Are we done? - all_done = self._all_rgs_admitted and not self._active_rgs and not self._in_flight + all_done = self._all_rgs_admitted and not self._rg_states and not self._in_flight if all_done: break @@ -185,29 +233,42 @@ async def run(self) -> None: Task(column=task.column, row_group=task.row_group, row_index=None, task_type="batch") ) self._in_flight.add(task) + if (s := self._rg_states.get(task.row_group)) is not None: + s.in_flight_count += 1 asyncio.create_task(self._execute_seed_task(task, gid)) else: self._dispatched.discard(task) # Drain: dispatch frontier tasks and any newly-ready downstream tasks # until nothing remains in-flight or in the frontier. - await self._drain_frontier() + await self._drain_frontier(seed_cols, has_pre_batch, all_columns) self._checkpoint_completed_row_groups(all_columns) - if self._active_rgs: - incomplete = [rg_id for rg_id, _ in self._active_rgs] + if self._rg_states: + incomplete = list(self._rg_states) logger.error( - f"Scheduler exited with {len(self._active_rgs)} unfinished row group(s): {incomplete}. " + f"Scheduler exited with {len(self._rg_states)} unfinished row group(s): {incomplete}. " "These row groups were not checkpointed." ) - async def _drain_frontier(self) -> None: + async def _drain_frontier(self, seed_cols: frozenset[str], has_pre_batch: bool, all_columns: list[str]) -> None: """Dispatch all frontier tasks and their downstream until quiescent.""" while True: - ready = self._tracker.get_ready_tasks(self._dispatched, self._admitted_rg_ids) + if has_pre_batch: + self._run_seeds_complete_check(seed_cols) + admitted_ids = set(self._rg_states) + ready = self._tracker.get_ready_tasks(self._dispatched, admitted_ids) + if has_pre_batch: + ready = [ + t + for t in ready + if (s := self._rg_states.get(t.row_group)) is not None and s.pre_batch_done or t.column in seed_cols + ] for task in ready: await self._submission_semaphore.acquire() self._dispatched.add(task) self._in_flight.add(task) + if (s := self._rg_states.get(task.row_group)) is not None: + s.in_flight_count += 1 asyncio.create_task(self._execute_task(task)) if not self._in_flight: break @@ -217,25 +278,85 @@ async def _drain_frontier(self) -> None: def _checkpoint_completed_row_groups(self, all_columns: list[str]) -> None: """Checkpoint any row groups that reached completion.""" completed = [ - (rg_id, rg_size) - for rg_id, rg_size in self._active_rgs - if self._tracker.is_row_group_complete(rg_id, rg_size, all_columns) + (rg_id, state.size) + for rg_id, state in self._rg_states.items() + if self._tracker.is_row_group_complete(rg_id, state.size, all_columns) ] for rg_id, rg_size in completed: - self._active_rgs.remove((rg_id, rg_size)) + dropped = False try: - if self._buffer_manager is not None: - self._buffer_manager.checkpoint_row_group(rg_id) - if self._on_row_group_complete: + del self._rg_states[rg_id] + if self._on_before_checkpoint: + try: + self._on_before_checkpoint(rg_id, rg_size) + except Exception: + # Post-batch is mandatory; drop rather than checkpoint unprocessed data. + logger.error( + f"on_before_checkpoint failed for row group {rg_id}, dropping row group.", + exc_info=True, + ) + for ri in range(rg_size): + self._tracker.drop_row(rg_id, ri) + if self._buffer_manager: + self._buffer_manager.drop_row(rg_id, ri) + dropped = True + if not dropped and self._buffer_manager is not None: + if self._on_checkpoint_complete is not None: + + def on_complete(final_path: Path | str | None) -> None: + if final_path is not None: + self._on_checkpoint_complete(final_path) + + self._buffer_manager.checkpoint_row_group(rg_id, on_complete=on_complete) + else: + self._buffer_manager.checkpoint_row_group(rg_id) + if not dropped and self._on_row_group_complete: self._on_row_group_complete(rg_id) except Exception: logger.error(f"Failed to checkpoint row group {rg_id}.", exc_info=True) finally: self._rg_semaphore.release() + def _run_seeds_complete_check(self, seed_cols: frozenset[str]) -> None: + """Run pre-batch callbacks for row groups whose seeds just completed.""" + for rg_id, state in list(self._rg_states.items()): + if state.seeds_dispatched and not state.pre_batch_done: + all_seeds_done = all(self._tracker.is_column_complete_for_rg(col, rg_id) for col in seed_cols) + if all_seeds_done and state.in_flight_count == 0: + state.pre_batch_done = True + if self._on_seeds_complete: + try: + self._on_seeds_complete(rg_id, state.size) + except Exception: + logger.warning( + f"Pre-batch processor failed for row group {rg_id}, skipping.", + exc_info=True, + ) + for ri in range(state.size): + self._tracker.drop_row(rg_id, ri) + if self._buffer_manager: + self._buffer_manager.drop_row(rg_id, ri) + + def _in_flight_for_rg(self, rg_id: int) -> bool: + """Check if any tasks are in-flight for a given row group.""" + state = self._rg_states.get(rg_id) + return state is not None and state.in_flight_count > 0 + + def _check_error_rate(self, *, success: bool) -> None: + """Trigger early shutdown if recent error rate exceeds threshold.""" + if self._disable_early_shutdown or self._early_shutdown: + return + self._recent_outcomes.append(success) + if len(self._recent_outcomes) < self._shutdown_error_window: + return + errors = sum(1 for ok in self._recent_outcomes if not ok) + if errors / self._shutdown_error_window >= self._shutdown_error_rate: + self._early_shutdown = True + async def _dispatch_seeds(self, rg_id: int, rg_size: int) -> None: """Dispatch from_scratch tasks for a row group.""" - seed_cols = [col for col in self._graph.get_topological_order() if not self._graph.get_upstream_columns(col)] + self._rg_states[rg_id].seeds_dispatched = True + seed_cols = self._seed_cols seen_instances: set[int] = set() for col in seed_cols: @@ -268,6 +389,8 @@ async def _dispatch_seeds(self, rg_id: int, rg_size: int) -> None: ) self._dispatched.add(Task(column=sibling_col, row_group=rg_id, row_index=None, task_type="batch")) self._in_flight.add(task) + if (s := self._rg_states.get(task.row_group)) is not None: + s.in_flight_count += 1 asyncio.create_task(self._execute_seed_task(task, gid)) async def _execute_seed_task(self, task: Task, generator_id: int) -> None: @@ -300,7 +423,7 @@ async def _execute_task_inner(self, task: Task) -> None: # Skip tasks whose row group was already checkpointed (can happen # when a vacuously-ready downstream is dispatched via create_task # in the same loop iteration that checkpoints the row group). - if not any(rg_id == task.row_group for rg_id, _ in self._active_rgs): + if task.row_group not in self._rg_states: skipped = True return @@ -324,12 +447,12 @@ async def _execute_task_inner(self, task: Task) -> None: else: self._tracker.mark_cell_complete(col, task.row_group, task.row_index) - self._success_count += 1 + self._check_error_rate(success=True) if self._trace and trace: trace.status = "ok" except Exception as exc: - self._error_count += 1 + self._check_error_rate(success=False) if self._trace and trace: trace.status = "error" trace.error = str(exc) @@ -360,6 +483,8 @@ async def _execute_task_inner(self, task: Task) -> None: self.traces.append(trace) self._in_flight.discard(task) + if (s := self._rg_states.get(task.row_group)) is not None: + s.in_flight_count = max(0, s.in_flight_count - 1) if not retryable and not skipped: self._dispatched.discard(task) self._submission_semaphore.release() @@ -415,22 +540,21 @@ async def _run_batch(self, task: Task, generator: ColumnGenerator) -> Any: """Execute a full-column/batch task.""" if self._buffer_manager is not None: batch_df = self._buffer_manager.get_dataframe(task.row_group) + # Snapshot dropped rows before the await so the row-count expectation + # is consistent with batch_df (concurrent tasks may drop rows during agenerate). + rg_size = self._get_rg_size(task.row_group) + pre_dropped: set[int] = {ri for ri in range(rg_size) if self._buffer_manager.is_dropped(task.row_group, ri)} else: batch_df = lazy.pd.DataFrame() + rg_size = self._get_rg_size(task.row_group) + pre_dropped = set() result_df = await generator.agenerate(batch_df) # Merge result columns back to buffer if self._buffer_manager is not None: output_cols = self._instance_to_columns.get(id(generator), [task.column]) - rg_size = self._get_rg_size(task.row_group) - dropped = set() - for ri in range(rg_size): - if self._buffer_manager.is_dropped(task.row_group, ri): - dropped.add(ri) - - # Map result rows (which exclude dropped) back to buffer indices - active_rows = rg_size - len(dropped) + active_rows = rg_size - len(pre_dropped) if len(result_df) != active_rows: raise ValueError( f"Batch generator for '{task.column}' returned {len(result_df)} rows " @@ -438,11 +562,13 @@ async def _run_batch(self, task: Task, generator: ColumnGenerator) -> Any: ) result_idx = 0 for ri in range(rg_size): - if ri in dropped: + if ri in pre_dropped: continue - for col in output_cols: - if col in result_df.columns: - self._buffer_manager.update_cell(task.row_group, ri, col, result_df.iloc[result_idx][col]) + # Skip writing to rows dropped by concurrent tasks during the await + if not self._buffer_manager.is_dropped(task.row_group, ri): + for col in output_cols: + if col in result_df.columns: + self._buffer_manager.update_cell(task.row_group, ri, col, result_df.iloc[result_idx][col]) result_idx += 1 return result_df diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py index c76977688..5fd3e609f 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/column_wise_builder.py @@ -33,7 +33,7 @@ from data_designer.engine.dataset_builders.utils.concurrency import ConcurrentThreadExecutor from data_designer.engine.dataset_builders.utils.config_compiler import compile_dataset_builder_column_configs from data_designer.engine.dataset_builders.utils.dataset_batch_manager import DatasetBatchManager -from data_designer.engine.dataset_builders.utils.processor_runner import ProcessorRunner +from data_designer.engine.dataset_builders.utils.processor_runner import ProcessorRunner, ProcessorStage from data_designer.engine.dataset_builders.utils.progress_tracker import ProgressTracker from data_designer.engine.models.telemetry import InferenceEvent, NemoSourceEnum, TaskStatusEnum, TelemetryHandler from data_designer.engine.processing.processors.base import Processor @@ -47,6 +47,7 @@ import pandas as pd from data_designer.engine.column_generators.generators.base import ColumnGeneratorWithModelRegistry + from data_designer.engine.dataset_builders.utils.task_model import TaskTrace from data_designer.engine.models.usage import ModelUsageStats logger = logging.getLogger(__name__) @@ -54,6 +55,7 @@ DATA_DESIGNER_ASYNC_ENGINE = os.environ.get("DATA_DESIGNER_ASYNC_ENGINE", "0") == "1" if DATA_DESIGNER_ASYNC_ENGINE: + import asyncio import sys if sys.version_info < (3, 11): @@ -61,9 +63,15 @@ "DATA_DESIGNER_ASYNC_ENGINE requires Python 3.11+ (asyncio.TaskGroup). " f"Current version: {sys.version_info.major}.{sys.version_info.minor}" ) - from data_designer.engine.dataset_builders.utils.async_concurrency import AsyncConcurrentExecutor + from data_designer.engine.dataset_builders.async_scheduler import AsyncTaskScheduler + from data_designer.engine.dataset_builders.utils.async_concurrency import ( + AsyncConcurrentExecutor, + ensure_async_engine_loop, + ) + from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker + from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph + from data_designer.engine.dataset_builders.utils.row_group_buffer import RowGroupBufferManager - logger.info("⚡ DATA_DESIGNER_ASYNC_ENGINE is enabled — using async concurrency") _CLIENT_VERSION: str = get_library_version() @@ -80,6 +88,7 @@ def __init__( self._records_to_drop: set[int] = set() self._cell_resize_results: list[dict | list[dict] | None] = [] self._cell_resize_mode = False + self._task_traces: list[TaskTrace] = [] self._registry = registry or DataDesignerRegistry() self._data_designer_config = compile_data_designer_config(data_designer_config, resource_provider) @@ -99,6 +108,10 @@ def artifact_storage(self) -> ArtifactStorage: def processors(self) -> tuple[Processor, ...]: return self._processor_runner.processors + @property + def task_traces(self) -> list[TaskTrace]: + return self._task_traces + def set_processor_runner(self, processors: list[Processor]) -> None: """Replace the processor runner with a new one using the given processors.""" self._processor_runner = ProcessorRunner( @@ -150,22 +163,26 @@ def build( generators = self._initialize_generators() start_time = time.perf_counter() - group_id = uuid.uuid4().hex - buffer_size = self._resource_provider.run_config.buffer_size - self.batch_manager.start(num_records=num_records, buffer_size=buffer_size) - for batch_idx in range(self.batch_manager.num_batches): - logger.info(f"⏳ Processing batch {batch_idx + 1} of {self.batch_manager.num_batches}") - self._run_batch( - generators, - batch_mode="batch", - group_id=group_id, - current_batch_number=batch_idx, - on_batch_complete=on_batch_complete, - ) - self.batch_manager.finish() - self._processor_runner.run_after_generation(buffer_size) + if DATA_DESIGNER_ASYNC_ENGINE: + self._validate_async_compatibility() + self._build_async(generators, num_records, buffer_size, on_batch_complete) + else: + group_id = uuid.uuid4().hex + self.batch_manager.start(num_records=num_records, buffer_size=buffer_size) + for batch_idx in range(self.batch_manager.num_batches): + logger.info(f"⏳ Processing batch {batch_idx + 1} of {self.batch_manager.num_batches}") + self._run_batch( + generators, + batch_mode="batch", + group_id=group_id, + current_batch_number=batch_idx, + on_batch_complete=on_batch_complete, + ) + self.batch_manager.finish() + + self._processor_runner.run_after_generation(buffer_size) self._resource_provider.model_registry.log_model_usage(time.perf_counter() - start_time) return self.artifact_storage.final_dataset_path @@ -190,6 +207,120 @@ def build_preview(self, *, num_records: int) -> pd.DataFrame: return dataset + def _validate_async_compatibility(self) -> None: + """Raise if any column uses allow_resize=True with the async scheduler.""" + offending = [config.name for config in self.single_column_configs if getattr(config, "allow_resize", False)] + if offending: + raise DatasetGenerationError( + f"allow_resize=True is not supported with DATA_DESIGNER_ASYNC_ENGINE=1. " + f"Offending column(s): {offending}. Either remove allow_resize=True or " + f"disable the async scheduler." + ) + + def _build_async( + self, + generators: list[ColumnGenerator], + num_records: int, + buffer_size: int, + on_batch_complete: Callable[[Path], None] | None = None, + ) -> None: + """Async task-queue builder path — dispatches tasks based on dependency readiness.""" + logger.info("⚡ DATA_DESIGNER_ASYNC_ENGINE is enabled - using async task-queue builder") + + # Build strategy map from generators + strategies: dict[str, GenerationStrategy] = {} + gen_map: dict[str, ColumnGenerator] = {} + for gen in generators: + if isinstance(gen.config, MultiColumnConfig): + for sub in gen.config.columns: + strategies[sub.name] = gen.get_generation_strategy() + gen_map[sub.name] = gen + else: + strategies[gen.config.name] = gen.get_generation_strategy() + gen_map[gen.config.name] = gen + + graph = ExecutionGraph.create(self._column_configs, strategies) + + # Log pre-generation info for all generators + for gen in generators: + gen.log_pre_generation() + + # Partition into row groups + row_groups: list[tuple[int, int]] = [] + remaining = num_records + rg_id = 0 + while remaining > 0: + size = min(buffer_size, remaining) + row_groups.append((rg_id, size)) + remaining -= size + rg_id += 1 + + tracker = CompletionTracker.with_graph(graph, row_groups) + buffer_manager = RowGroupBufferManager(self.artifact_storage) + settings = self._resource_provider.run_config + + # Pre-batch processor callback: runs after seed tasks complete for a row group. + # If it raises, the scheduler drops all rows in the row group (skips it). + def on_seeds_complete(rg_id: int, rg_size: int) -> None: + if not self._processor_runner.has_processors_for(ProcessorStage.PRE_BATCH): + return + df = buffer_manager.get_dataframe(rg_id) + df = self._processor_runner.run_pre_batch_on_df(df) + buffer_manager.replace_dataframe(rg_id, df) + # Sync newly-dropped rows to the tracker so downstream cell tasks are skipped + for ri in range(rg_size): + if buffer_manager.is_dropped(rg_id, ri) and not tracker.is_dropped(rg_id, ri): + tracker.drop_row(rg_id, ri) + + # Post-batch processor callback: runs after all columns, before checkpoint. + # rg_id is used as current_batch_number; both are 0-based sequential indices today. + def on_before_checkpoint(rg_id: int, rg_size: int) -> None: + df = buffer_manager.get_dataframe(rg_id) + df = self._processor_runner.run_post_batch(df, current_batch_number=rg_id) + buffer_manager.replace_dataframe(rg_id, df) + + # Telemetry snapshot + group_id = uuid.uuid4().hex + pre_batch_snapshot = self._resource_provider.model_registry.get_model_usage_snapshot() + + trace_enabled = settings.async_trace or os.environ.get("DATA_DESIGNER_ASYNC_TRACE", "0") == "1" + + scheduler = AsyncTaskScheduler( + generators=gen_map, + graph=graph, + tracker=tracker, + row_groups=row_groups, + buffer_manager=buffer_manager, + on_checkpoint_complete=on_batch_complete, + on_seeds_complete=( + on_seeds_complete if self._processor_runner.has_processors_for(ProcessorStage.PRE_BATCH) else None + ), + on_before_checkpoint=( + on_before_checkpoint if self._processor_runner.has_processors_for(ProcessorStage.POST_BATCH) else None + ), + shutdown_error_rate=settings.shutdown_error_rate, + shutdown_error_window=settings.shutdown_error_window, + disable_early_shutdown=settings.disable_early_shutdown, + trace=trace_enabled, + ) + + # Run on background event loop + loop = ensure_async_engine_loop() + future = asyncio.run_coroutine_threadsafe(scheduler.run(), loop) + future.result() + + self._task_traces = scheduler.traces + + # Emit telemetry + try: + usage_deltas = self._resource_provider.model_registry.get_usage_deltas(pre_batch_snapshot) + self._emit_batch_inference_events("batch", usage_deltas, group_id) + except Exception: + logger.debug("Failed to emit batch telemetry for async run", exc_info=True) + + # Write metadata + buffer_manager.write_metadata(target_num_records=num_records, buffer_size=buffer_size) + def process_preview(self, dataset: pd.DataFrame) -> pd.DataFrame: df = self._processor_runner.run_post_batch(dataset.copy(), current_batch_number=None) return self._processor_runner.run_after_generation_on_df(df) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/async_concurrency.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/async_concurrency.py index f7d18bd26..105278233 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/async_concurrency.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/async_concurrency.py @@ -23,7 +23,7 @@ The background loop is a process-wide singleton. LiteLLM and similar libraries bind internal async state to a specific event loop, so creating per-call or per-instance loops breaks connection reuse and triggers - cross-loop errors. ``_ensure_async_engine_loop()`` creates one daemon + cross-loop errors. ``ensure_async_engine_loop()`` creates one daemon loop thread and reuses it for all executor instances. Startup Handshake: @@ -86,7 +86,7 @@ def _run_loop(loop: asyncio.AbstractEventLoop, ready: threading.Event) -> None: loop.run_forever() -def _ensure_async_engine_loop() -> asyncio.AbstractEventLoop: +def ensure_async_engine_loop() -> asyncio.AbstractEventLoop: """Get or create a persistent event loop for async engine work. A single event loop is shared across all AsyncConcurrentExecutor instances @@ -172,7 +172,7 @@ def run(self, work_items: list[tuple[Coroutine[Any, Any, Any], dict | None]]) -> f"AsyncConcurrentExecutor: launching {len(work_items)} tasks " f"with max_workers={self._max_workers} for column '{self._column_name}'" ) - loop = _ensure_async_engine_loop() + loop = ensure_async_engine_loop() future = asyncio.run_coroutine_threadsafe(self._run_all(work_items), loop) future.result() diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py index 80ee1b9de..881bf6c54 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py @@ -78,6 +78,17 @@ def is_all_complete(self, cells: list[SliceRef]) -> bool: return False return True + def is_column_complete_for_rg(self, column: str, row_group_index: int) -> bool: + """Check if *column* has been fully completed for *row_group_index*.""" + if column in self._batch_complete.get(row_group_index, set()): + return True + rg_size = self._row_group_sizes.get(row_group_index, 0) + if rg_size == 0: + return False + completed = self._completed.get(row_group_index, {}).get(column, set()) + dropped = self._dropped.get(row_group_index, set()) + return all(ri in completed or ri in dropped for ri in range(rg_size)) + def drop_row(self, row_group: int, row_index: int) -> None: self._validate_row_group(row_group) self._dropped[row_group].add(row_index) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/processor_runner.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/processor_runner.py index 8163c807a..61bba8919 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/processor_runner.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/processor_runner.py @@ -72,6 +72,10 @@ def run_pre_batch(self, batch_manager: DatasetBatchManager) -> None: df = self._run_stage(df, ProcessorStage.PRE_BATCH) batch_manager.replace_buffer(df.to_dict(orient="records"), allow_resize=True) + def run_pre_batch_on_df(self, df: pd.DataFrame) -> pd.DataFrame: + """Run PRE_BATCH processors on a DataFrame and return the result.""" + return self._run_stage(df, ProcessorStage.PRE_BATCH) + def run_post_batch(self, df: pd.DataFrame, current_batch_number: int | None) -> pd.DataFrame: """Run process_after_batch() on processors that implement it.""" return self._run_stage(df, ProcessorStage.POST_BATCH, current_batch_number=current_batch_number) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/row_group_buffer.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/row_group_buffer.py index a63be31fa..b20ce3aca 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/row_group_buffer.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/row_group_buffer.py @@ -68,6 +68,23 @@ def get_dataframe(self, row_group: int) -> pd.DataFrame: rows = [row for i, row in enumerate(self._buffers[row_group]) if i not in dropped] return lazy.pd.DataFrame(rows) + def replace_dataframe(self, row_group: int, df: pd.DataFrame) -> None: + """Replace the buffer for a row group from a DataFrame (non-dropped rows only). + + If *df* has fewer rows than active slots, trailing slots are marked as dropped. + """ + dropped = self._dropped.get(row_group, set()) + records = df.to_dict(orient="records") + buf_idx = 0 + for ri in range(self._row_group_sizes[row_group]): + if ri in dropped: + continue + if buf_idx < len(records): + self._buffers[row_group][ri] = records[buf_idx] + else: + self._dropped.setdefault(row_group, set()).add(ri) + buf_idx += 1 + def drop_row(self, row_group: int, row_index: int) -> None: self._dropped.setdefault(row_group, set()).add(row_index) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py new file mode 100644 index 000000000..1302fa905 --- /dev/null +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py @@ -0,0 +1,259 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import math +from unittest.mock import MagicMock, Mock + +import pytest + +import data_designer.lazy_heavy_imports as lazy +from data_designer.config.column_configs import ( + ExpressionColumnConfig, + GenerationStrategy, + LLMTextColumnConfig, + SamplerColumnConfig, +) +from data_designer.config.sampler_params import SamplerType +from data_designer.engine.column_generators.generators.base import ( + ColumnGenerator, + ColumnGeneratorFullColumn, + FromScratchColumnGenerator, +) +from data_designer.engine.dataset_builders.async_scheduler import AsyncTaskScheduler +from data_designer.engine.dataset_builders.column_wise_builder import ColumnWiseDatasetBuilder +from data_designer.engine.dataset_builders.errors import DatasetGenerationError +from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker +from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph +from data_designer.engine.dataset_builders.utils.row_group_buffer import RowGroupBufferManager +from data_designer.engine.resources.resource_provider import ResourceProvider + +MODEL_ALIAS = "stub" + + +# -- Mock generators for integration tests ------------------------------------ + + +def _expr_config(name: str = "test") -> ExpressionColumnConfig: + return ExpressionColumnConfig(name=name, expr="{{ x }}", dtype="str") + + +def _mock_provider() -> MagicMock: + return MagicMock(spec=ResourceProvider) + + +class MockSeed(FromScratchColumnGenerator[ExpressionColumnConfig]): + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return GenerationStrategy.FULL_COLUMN + + def generate(self, data: lazy.pd.DataFrame) -> lazy.pd.DataFrame: + return data + + def generate_from_scratch(self, num_records: int) -> lazy.pd.DataFrame: + return lazy.pd.DataFrame({"seed": list(range(num_records))}) + + +class MockCell(ColumnGenerator[ExpressionColumnConfig]): + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return GenerationStrategy.CELL_BY_CELL + + def generate(self, data: dict) -> dict: + data["cell_out"] = f"val_{data.get('seed', '?')}" + return data + + +class MockFullCol(ColumnGeneratorFullColumn[ExpressionColumnConfig]): + def generate(self, data: lazy.pd.DataFrame) -> lazy.pd.DataFrame: + data["expr_out"] = "computed" + return data + + +# -- allow_resize validation test --------------------------------------------- + + +@pytest.mark.parametrize( + "configs,should_raise", + [ + pytest.param( + [Mock(name="col_a", allow_resize=True), Mock(name="col_b", allow_resize=False)], + True, + id="raises_on_allow_resize", + ), + pytest.param( + [Mock(name="col_a", allow_resize=False), Mock(name="col_b", allow_resize=False)], + False, + id="passes_without_allow_resize", + ), + ], +) +def test_validate_async_compatibility(configs: list[Mock], should_raise: bool) -> None: + """Validation rejects allow_resize=True with the async engine.""" + builder = Mock(spec=ColumnWiseDatasetBuilder) + builder.single_column_configs = configs + if should_raise: + with pytest.raises(DatasetGenerationError, match="allow_resize=True"): + ColumnWiseDatasetBuilder._validate_async_compatibility(builder) + else: + ColumnWiseDatasetBuilder._validate_async_compatibility(builder) + + +# -- _build_async integration test with mock generators ----------------------- + + +@pytest.mark.asyncio(loop_scope="session") +async def test_build_async_end_to_end() -> None: + """Test _build_async with mock generators produces correct results in buffer.""" + + provider = _mock_provider() + seed_gen = MockSeed(config=_expr_config("seed"), resource_provider=provider) + cell_gen = MockCell(config=_expr_config("cell_out"), resource_provider=provider) + expr_gen = MockFullCol(config=_expr_config("expr_out"), resource_provider=provider) + + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="cell_out", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ExpressionColumnConfig(name="expr_out", expr="{{ cell_out }}"), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "cell_out": GenerationStrategy.CELL_BY_CELL, + "expr_out": GenerationStrategy.FULL_COLUMN, + } + gen_map = { + "seed": seed_gen, + "cell_out": cell_gen, + "expr_out": expr_gen, + } + + num_records = 4 + buffer_size = 2 + graph = ExecutionGraph.create(configs, strategies) + + row_groups: list[tuple[int, int]] = [] + remaining = num_records + rg_id = 0 + while remaining > 0: + size = min(buffer_size, remaining) + row_groups.append((rg_id, size)) + remaining -= size + rg_id += 1 + + tracker = CompletionTracker.with_graph(graph, row_groups) + storage = MagicMock() + storage.dataset_name = "test" + storage.get_file_paths.return_value = {} + storage.write_batch_to_parquet_file.return_value = "/fake.parquet" + storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet" + + buffer_manager = RowGroupBufferManager(storage) + + checkpointed: list[int] = [] + + scheduler = AsyncTaskScheduler( + generators=gen_map, + graph=graph, + tracker=tracker, + row_groups=row_groups, + buffer_manager=buffer_manager, + on_row_group_complete=lambda rg: checkpointed.append(rg), + ) + await scheduler.run() + + # Both row groups should be checkpointed + assert sorted(checkpointed) == [0, 1] + assert buffer_manager.actual_num_records == 4 + + # All columns should be complete + all_cols = ["seed", "cell_out", "expr_out"] + assert tracker.is_row_group_complete(0, 2, all_cols) + assert tracker.is_row_group_complete(1, 2, all_cols) + + +# -- Test that existing sync path is unaffected -------------------------------- + + +def test_sync_path_unaffected_by_async_engine_flag() -> None: + """DATA_DESIGNER_ASYNC_ENGINE=0 keeps the sync path unchanged.""" + import data_designer.engine.dataset_builders.column_wise_builder as builder_mod + + assert hasattr(builder_mod, "DATA_DESIGNER_ASYNC_ENGINE") + assert isinstance(builder_mod.DATA_DESIGNER_ASYNC_ENGINE, bool) + + +# -- Test execution graph integration with real column configs ----------------- + + +def test_execution_graph_from_real_configs() -> None: + """Build execution graph from real column config objects.""" + configs = [ + SamplerColumnConfig(name="id", sampler_type=SamplerType.UUID, params={}), + LLMTextColumnConfig(name="question", prompt="{{ id }}", model_alias=MODEL_ALIAS), + LLMTextColumnConfig(name="answer", prompt="{{ question }}", model_alias=MODEL_ALIAS), + ExpressionColumnConfig(name="combined", expr="{{ question }} {{ answer }}"), + ] + strategies = { + "id": GenerationStrategy.FULL_COLUMN, + "question": GenerationStrategy.CELL_BY_CELL, + "answer": GenerationStrategy.CELL_BY_CELL, + "combined": GenerationStrategy.FULL_COLUMN, + } + + graph = ExecutionGraph.create(configs, strategies) + + order = graph.get_topological_order() + idx = {col: i for i, col in enumerate(order)} + + assert idx["id"] < idx["question"] + assert idx["question"] < idx["answer"] + assert idx["answer"] < idx["combined"] + + # Task counts + counts = graph.compute_task_count(num_records=10, buffer_size=3) + assert counts["id"] == math.ceil(10 / 3) + assert counts["question"] == 10 + assert counts["answer"] == 10 + assert counts["combined"] == math.ceil(10 / 3) + + +# -- Test checkpoint correctness ----------------------------------------------- + + +@pytest.mark.asyncio(loop_scope="session") +async def test_checkpoint_produces_correct_parquet_calls() -> None: + """Verify checkpoint writes parquet for each row group.""" + + provider = _mock_provider() + seed_gen = MockSeed(config=_expr_config("seed"), resource_provider=provider) + + configs = [SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["X"]})] + strategies = {"seed": GenerationStrategy.FULL_COLUMN} + gen_map = {"seed": seed_gen} + + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 3), (1, 2)] + tracker = CompletionTracker.with_graph(graph, row_groups) + + storage = MagicMock() + storage.dataset_name = "test" + storage.get_file_paths.return_value = {} + storage.write_batch_to_parquet_file.return_value = "/fake.parquet" + storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet" + + buffer_manager = RowGroupBufferManager(storage) + + scheduler = AsyncTaskScheduler( + generators=gen_map, + graph=graph, + tracker=tracker, + row_groups=row_groups, + buffer_manager=buffer_manager, + ) + await scheduler.run() + + # Two row groups → two write_batch_to_parquet_file calls + assert storage.write_batch_to_parquet_file.call_count == 2 + assert storage.move_partial_result_to_final_file_path.call_count == 2 + assert buffer_manager.actual_num_records == 5 diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py index dc91e2ee7..1027bf872 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py @@ -3,6 +3,7 @@ from __future__ import annotations +import asyncio from typing import Any from unittest.mock import MagicMock @@ -463,7 +464,7 @@ async def test_scheduler_eager_row_drop_skips_downstream_of_failed_column() -> N """When fail_col drops a row, a downstream column never processes it.""" provider = _mock_provider() - # Pipeline: seed → fail_col (cell, permanent failure) → downstream (cell) + # Pipeline: seed -> fail_col (cell, permanent failure) -> downstream (cell) # downstream depends on fail_col, so its tasks only enter the frontier # after fail_col completes for each row. Since fail_col always fails, # the row is dropped before downstream is ever enqueued. @@ -502,7 +503,7 @@ async def test_scheduler_eager_row_drop_skips_downstream_of_failed_column() -> N async def test_scheduler_non_retryable_seed_failure_no_keyerror_on_downstream() -> None: """Non-retryable seed failure does not cause KeyError on vacuously-ready downstream. - Pipeline: seed (full_column) → cell_out (cell_by_cell) → full_out (full_column). + Pipeline: seed (full_column) -> cell_out (cell_by_cell) -> full_out (full_column). When seed fails non-retryably, all rows are dropped. cell_out's cell tasks become vacuously complete (all rows dropped), which makes full_out ready. full_out must not crash with a KeyError when its row group buffer has been @@ -557,3 +558,365 @@ async def test_scheduler_non_retryable_seed_failure_no_keyerror_on_downstream() # full_out was either never dispatched or silently skipped (no KeyError) full_out_errors = [t for t in scheduler.traces if t.column == "full_out" and t.status == "error"] assert len(full_out_errors) == 0 + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_error_rate_shutdown() -> None: + """Early shutdown triggers when error rate exceeds threshold.""" + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="fail_col", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "fail_col": GenerationStrategy.CELL_BY_CELL, + } + generators = { + "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), + "fail_col": MockFailingGenerator(config=_expr_config("fail_col"), resource_provider=provider), + } + + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 10)] + tracker = CompletionTracker.with_graph(graph, row_groups) + + storage = MagicMock() + storage.dataset_name = "test" + storage.get_file_paths.return_value = {} + storage.write_batch_to_parquet_file.return_value = "/fake.parquet" + storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet" + buffer_mgr = RowGroupBufferManager(storage) + + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + buffer_manager=buffer_mgr, + shutdown_error_rate=0.5, + shutdown_error_window=2, + ) + await scheduler.run() + + # Early shutdown: not all rows should be checkpointed (some row groups incomplete) + assert buffer_mgr.actual_num_records < 10 + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_early_shutdown_disabled() -> None: + """shutdown_error_rate=1.0 prevents shutdown even at 100% error rate.""" + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="fail_col", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "fail_col": GenerationStrategy.CELL_BY_CELL, + } + generators = { + "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), + "fail_col": MockFailingGenerator(config=_expr_config("fail_col"), resource_provider=provider), + } + + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 5)] + tracker = CompletionTracker.with_graph(graph, row_groups) + + storage = MagicMock() + storage.dataset_name = "test" + storage.get_file_paths.return_value = {} + buffer_mgr = RowGroupBufferManager(storage) + + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + buffer_manager=buffer_mgr, + shutdown_error_rate=1.0, + ) + await scheduler.run() + + # All rows dropped (all fail) but no early shutdown - all row groups processed + assert all(tracker.is_dropped(0, ri) for ri in range(5)) + assert tracker.is_row_group_complete(0, 5, ["seed", "fail_col"]) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_sliding_window_error_rate_recovers() -> None: + """Transient errors diluted by successes do not trigger early shutdown.""" + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="col", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "col": GenerationStrategy.CELL_BY_CELL, + } + # First 2 calls fail (retryable 503), rest succeed. + # With window=10 and 10 cell tasks, at most 2/10 = 20% error rate + # when the window first fills - well below the 0.4 threshold. + generators = { + "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), + "col": MockFailingGenerator(config=_expr_config("col"), resource_provider=provider, transient_failures=2), + } + + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 10)] + tracker = CompletionTracker.with_graph(graph, row_groups) + + storage = MagicMock() + storage.dataset_name = "test" + storage.get_file_paths.return_value = {} + storage.write_batch_to_parquet_file.return_value = "/fake.parquet" + storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet" + buffer_mgr = RowGroupBufferManager(storage) + + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + buffer_manager=buffer_mgr, + shutdown_error_rate=0.4, + shutdown_error_window=10, + ) + await scheduler.run() + + # No early shutdown - transient errors recovered in salvage + assert not scheduler._early_shutdown + assert tracker.is_row_group_complete(0, 10, ["seed", "col"]) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_on_before_checkpoint_callback() -> None: + """on_before_checkpoint is called before each row group is checkpointed.""" + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + ] + strategies = {"seed": GenerationStrategy.FULL_COLUMN} + generators = {"seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider)} + + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 3), (1, 2)] + tracker = CompletionTracker.with_graph(graph, row_groups) + + storage = MagicMock() + storage.dataset_name = "test" + storage.get_file_paths.return_value = {} + storage.write_batch_to_parquet_file.return_value = "/fake.parquet" + storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet" + + buffer_mgr = RowGroupBufferManager(storage) + callback_log: list[tuple[int, int]] = [] + + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + buffer_manager=buffer_mgr, + on_before_checkpoint=lambda rg, sz: callback_log.append((rg, sz)), + ) + await scheduler.run() + + assert sorted(callback_log) == [(0, 3), (1, 2)] + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_on_checkpoint_complete_callback_receives_final_path() -> None: + """on_checkpoint_complete is called with the written parquet file path.""" + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + ] + strategies = {"seed": GenerationStrategy.FULL_COLUMN} + generators = {"seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider)} + + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 3)] + tracker = CompletionTracker.with_graph(graph, row_groups) + + storage = MagicMock() + storage.dataset_name = "test" + storage.get_file_paths.return_value = {} + storage.write_batch_to_parquet_file.return_value = "/fake.parquet" + storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet" + + buffer_mgr = RowGroupBufferManager(storage) + callback_log: list[str] = [] + + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + buffer_manager=buffer_mgr, + on_checkpoint_complete=lambda path: callback_log.append(path), + ) + await scheduler.run() + + assert callback_log == ["/fake_final.parquet"] + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_on_checkpoint_complete_skips_empty_row_group() -> None: + """on_checkpoint_complete is not called when a row group writes no file.""" + provider = _mock_provider() + storage = MagicMock() + storage.dataset_name = "test" + storage.get_file_paths.return_value = {} + + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + ] + strategies = {"seed": GenerationStrategy.FULL_COLUMN} + generators = { + "seed": MockFailingSeedGenerator(config=_expr_config("seed"), resource_provider=provider), + } + + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 3)] + tracker = CompletionTracker.with_graph(graph, row_groups) + buffer_mgr = RowGroupBufferManager(storage) + callback = MagicMock() + + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + buffer_manager=buffer_mgr, + on_checkpoint_complete=callback, + ) + await scheduler.run() + + callback.assert_not_called() + storage.write_batch_to_parquet_file.assert_not_called() + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_pre_batch_failure_skips_row_group() -> None: + """Pre-batch processor failure drops all rows in the row group; other row groups continue.""" + provider = _mock_provider() + seed_gen = MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider) + cell_gen = MockCellGenerator(config=_expr_config("cell_out"), resource_provider=provider) + + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="cell_out", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "cell_out": GenerationStrategy.CELL_BY_CELL, + } + generators = {"seed": seed_gen, "cell_out": cell_gen} + + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 3), (1, 2)] + tracker = CompletionTracker.with_graph(graph, row_groups) + + storage = MagicMock() + storage.dataset_name = "test" + storage.get_file_paths.return_value = {} + storage.write_batch_to_parquet_file.return_value = "/fake.parquet" + storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet" + + buffer_mgr = RowGroupBufferManager(storage) + + def failing_pre_batch(rg_id: int, rg_size: int) -> None: + if rg_id == 0: + raise RuntimeError("pre-batch processor failed") + + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + buffer_manager=buffer_mgr, + on_seeds_complete=failing_pre_batch, + ) + await scheduler.run() + + # Row group 0: all rows dropped due to pre-batch failure + assert all(tracker.is_dropped(0, ri) for ri in range(3)) + # Row group 1: completed normally + assert tracker.is_row_group_complete(1, 2, ["seed", "cell_out"]) + + +class _SlowSeedGenerator(FromScratchColumnGenerator[ExpressionColumnConfig]): + """Seed generator whose async cost scales with row count. + + Both RGs' seed tasks run concurrently. The task with fewer rows sleeps for + less real time, causing its downstream to be dispatched and completed first. + """ + + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return GenerationStrategy.FULL_COLUMN + + def generate(self, data: lazy.pd.DataFrame) -> lazy.pd.DataFrame: + return data + + def generate_from_scratch(self, num_records: int) -> lazy.pd.DataFrame: + return lazy.pd.DataFrame({self.config.name: list(range(num_records))}) + + async def agenerate_from_scratch(self, num_records: int) -> lazy.pd.DataFrame: + await asyncio.sleep(num_records * 0.02) + return self.generate_from_scratch(num_records) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_out_of_order_row_group_completion() -> None: + """Row groups may complete out of order; both are checkpointed correctly.""" + provider = _mock_provider() + # Slow seed generator: RG 0 (5 rows) sleeps 100ms, RG 1 (1 row) sleeps 20ms. + # RG 1 finishes seeds first, its downstream is dispatched and completes before RG 0. + slow_seed = _SlowSeedGenerator(config=_expr_config("seed"), resource_provider=provider) + cell_gen = MockCellGenerator(config=_expr_config("cell_out"), resource_provider=provider) + + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="cell_out", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "cell_out": GenerationStrategy.CELL_BY_CELL, + } + generators = {"seed": slow_seed, "cell_out": cell_gen} + + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 5), (1, 1)] + tracker = CompletionTracker.with_graph(graph, row_groups) + + storage = MagicMock() + storage.dataset_name = "test" + storage.get_file_paths.return_value = {} + storage.write_batch_to_parquet_file.return_value = "/fake.parquet" + storage.move_partial_result_to_final_file_path.return_value = "/fake_final.parquet" + buffer_mgr = RowGroupBufferManager(storage) + + checkpoint_order: list[int] = [] + + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + buffer_manager=buffer_mgr, + max_concurrent_row_groups=2, + on_row_group_complete=lambda rg_id: checkpoint_order.append(rg_id), + ) + await scheduler.run() + + # Both row groups completed + assert tracker.is_row_group_complete(0, 5, ["seed", "cell_out"]) + assert tracker.is_row_group_complete(1, 1, ["seed", "cell_out"]) + # Both were checkpointed + assert set(checkpoint_order) == {0, 1} + # RG 1 (fewer rows, fewer seed yields) checkpoints before RG 0 + assert checkpoint_order.index(1) < checkpoint_order.index(0) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_row_group_buffer.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_row_group_buffer.py index 50819ee64..189ddcac7 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_row_group_buffer.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_row_group_buffer.py @@ -7,6 +7,7 @@ import pytest +import data_designer.lazy_heavy_imports as lazy from data_designer.engine.dataset_builders.utils.row_group_buffer import RowGroupBufferManager @@ -124,6 +125,57 @@ def test_checkpoint_calls_on_complete() -> None: callback.assert_called_once_with("/fake/final.parquet") +def test_replace_dataframe_same_size() -> None: + """replace_dataframe with same number of rows replaces data in-place.""" + mgr = RowGroupBufferManager(_mock_artifact_storage()) + mgr.init_row_group(0, 3) + mgr.update_batch(0, "col", ["a", "b", "c"]) + + df = lazy.pd.DataFrame({"col": ["x", "y", "z"]}) + mgr.replace_dataframe(0, df) + + assert mgr.get_row(0, 0) == {"col": "x"} + assert mgr.get_row(0, 1) == {"col": "y"} + assert mgr.get_row(0, 2) == {"col": "z"} + + +def test_replace_dataframe_with_dropped_rows() -> None: + """replace_dataframe skips dropped rows and replaces only active slots.""" + mgr = RowGroupBufferManager(_mock_artifact_storage()) + mgr.init_row_group(0, 4) + mgr.update_batch(0, "col", ["a", "b", "c", "d"]) + mgr.drop_row(0, 1) # drop row 1 + + # 3 active rows: indices 0, 2, 3 + df = lazy.pd.DataFrame({"col": ["x", "y", "z"]}) + mgr.replace_dataframe(0, df) + + assert mgr.get_row(0, 0) == {"col": "x"} + assert mgr.is_dropped(0, 1) + assert mgr.get_row(0, 2) == {"col": "y"} + assert mgr.get_row(0, 3) == {"col": "z"} + + +def test_replace_dataframe_fewer_rows_marks_trailing_dropped() -> None: + """replace_dataframe with fewer rows marks trailing active slots as dropped.""" + mgr = RowGroupBufferManager(_mock_artifact_storage()) + mgr.init_row_group(0, 4) + mgr.update_batch(0, "col", ["a", "b", "c", "d"]) + + # Only 2 rows - should drop indices 2 and 3 + df = lazy.pd.DataFrame({"col": ["x", "y"]}) + mgr.replace_dataframe(0, df) + + assert mgr.get_row(0, 0) == {"col": "x"} + assert mgr.get_row(0, 1) == {"col": "y"} + assert mgr.is_dropped(0, 2) + assert mgr.is_dropped(0, 3) + + # get_dataframe should only return the 2 active rows + result_df = mgr.get_dataframe(0) + assert len(result_df) == 2 + + def test_checkpoint_calls_on_complete_when_all_rows_dropped() -> None: storage = _mock_artifact_storage() callback = Mock() diff --git a/packages/data-designer/src/data_designer/interface/data_designer.py b/packages/data-designer/src/data_designer/interface/data_designer.py index d5df0ff5b..c195a174d 100644 --- a/packages/data-designer/src/data_designer/interface/data_designer.py +++ b/packages/data-designer/src/data_designer/interface/data_designer.py @@ -226,6 +226,8 @@ def create( except Exception as e: raise DataDesignerGenerationError(f"🛑 Error generating dataset: {e}") from e + task_traces = builder.task_traces + try: dataset_for_profiler = builder.artifact_storage.load_dataset_with_dropped_columns() except Exception as e: @@ -262,6 +264,7 @@ def create( analysis=analysis, config_builder=config_builder, dataset_metadata=dataset_metadata, + task_traces=task_traces, ) def preview( diff --git a/packages/data-designer/src/data_designer/interface/results.py b/packages/data-designer/src/data_designer/interface/results.py index 4a4006a95..07692ff00 100644 --- a/packages/data-designer/src/data_designer/interface/results.py +++ b/packages/data-designer/src/data_designer/interface/results.py @@ -17,6 +17,8 @@ if TYPE_CHECKING: import pandas as pd + from data_designer.engine.dataset_builders.utils.task_model import TaskTrace + class DatasetCreationResults(WithRecordSamplerMixin): """Results container for a Data Designer dataset creation run. @@ -33,6 +35,7 @@ def __init__( analysis: DatasetProfilerResults, config_builder: DataDesignerConfigBuilder, dataset_metadata: DatasetMetadata, + task_traces: list[TaskTrace] | None = None, ): """Creates a new instance with results based on a dataset creation run. @@ -41,11 +44,13 @@ def __init__( analysis: Profiling results for the generated dataset. config_builder: Configuration builder used to create the dataset. dataset_metadata: Metadata about the generated dataset (e.g., seed column names). + task_traces: Optional list of TaskTrace objects from the async scheduler. """ self.artifact_storage = artifact_storage self._analysis = analysis self._config_builder = config_builder self.dataset_metadata = dataset_metadata + self.task_traces: list[TaskTrace] = task_traces or [] def load_analysis(self) -> DatasetProfilerResults: """Load the profiling analysis results for the generated dataset. diff --git a/tests_e2e/tests/test_async_engine.py b/tests_e2e/tests/test_async_engine.py new file mode 100644 index 000000000..a15f4d10e --- /dev/null +++ b/tests_e2e/tests/test_async_engine.py @@ -0,0 +1,120 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import json +import os +import subprocess +import sys +from pathlib import Path + +import pytest + +NUM_RECORDS = 2 +PARALLEL_COLUMNS = ("summary", "analysis") + + +def _run_async_engine_concurrency_case(tmp_path: Path) -> dict[str, object]: + repo_root = Path(__file__).resolve().parents[2] + script = f""" +from __future__ import annotations + +import json +from collections import defaultdict +from pathlib import Path + +import data_designer.config as dd +from data_designer.config.run_config import RunConfig +from data_designer.interface import DataDesigner + +NUM_RECORDS = {NUM_RECORDS} +PARALLEL_COLUMNS = {PARALLEL_COLUMNS!r} +tmp_path = Path({str(tmp_path)!r}) + +dd_instance = DataDesigner(artifact_path=str(tmp_path)) +dd_instance.set_run_config(RunConfig(buffer_size=NUM_RECORDS, async_trace=True)) + +config = dd.DataDesignerConfigBuilder() +config.add_column( + dd.SamplerColumnConfig( + name="topic", + sampler_type=dd.SamplerType.CATEGORY, + params=dd.CategorySamplerParams(values=["science", "history", "art"]), + ) +) +for col in PARALLEL_COLUMNS: + config.add_column( + dd.LLMTextColumnConfig( + name=col, + model_alias="nvidia-text", + prompt="Write one sentence about {{{{ topic }}}} (" + col + ").", + ) + ) + +result = dd_instance.create(config, num_records=NUM_RECORDS, dataset_name="async_e2e") +df = result.load_dataset() +traces = result.task_traces + +by_col: dict[str, list[tuple[float, float]]] = defaultdict(list) +for trace in traces: + if trace.task_type == "cell" and trace.status == "ok" and trace.slot_acquired_at and trace.completed_at: + by_col[trace.column].append((trace.slot_acquired_at, trace.completed_at)) + +overlap_found = False +cols = [col for col in PARALLEL_COLUMNS if by_col[col]] +for i, col_a in enumerate(cols): + for col_b in cols[i + 1 :]: + for start_a, end_a in by_col[col_a]: + for start_b, end_b in by_col[col_b]: + if start_a < end_b and start_b < end_a: + overlap_found = True + break + if overlap_found: + break + if overlap_found: + break + if overlap_found: + break + +payload = {{ + "rows": len(df), + "columns": list(df.columns), + "non_null": {{col: bool(df[col].notna().all()) for col in ("topic", *PARALLEL_COLUMNS)}}, + "trace_count": len(traces), + "overlap_found": overlap_found, +}} +print("RESULT_JSON=" + json.dumps(payload)) +""" + env = os.environ.copy() + env["DATA_DESIGNER_ASYNC_ENGINE"] = "1" + completed = subprocess.run( + [sys.executable, "-c", script], + check=True, + capture_output=True, + text=True, + cwd=repo_root, + env=env, + ) + for line in completed.stdout.splitlines(): + if line.startswith("RESULT_JSON="): + return json.loads(line.removeprefix("RESULT_JSON=")) + raise AssertionError(f"Missing RESULT_JSON marker in subprocess output:\n{completed.stdout}") + + +def test_async_engine_concurrent_columns(tmp_path: Path) -> None: + """Verify the async engine runs independent LLM columns concurrently.""" + if os.environ.get("NVIDIA_API_KEY") is None: + pytest.skip("NVIDIA_API_KEY must be set") + + payload = _run_async_engine_concurrency_case(tmp_path) + + assert payload["rows"] == NUM_RECORDS + for col in ("topic", *PARALLEL_COLUMNS): + assert col in payload["columns"] + assert payload["non_null"][col] + + assert payload["trace_count"] > 0 + assert payload["overlap_found"], ( + "No overlapping execution found between parallel columns - async scheduler may not be dispatching concurrently" + )