diff --git a/backend/src/events.py b/backend/src/events.py
index 74a486519..4b0401be1 100644
--- a/backend/src/events.py
+++ b/backend/src/events.py
@@ -1,8 +1,10 @@
 from __future__ import annotations
 
 import asyncio
+import threading
+import time
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Literal, TypedDict
 
 import navi
@@ -160,3 +162,119 @@ async def wait_until_empty(self, timeout: float) -> None:
     async def put_and_wait(self, event: Event, timeout: float = float("inf")) -> None:
         await self.queue.put(event)
         await self.wait_until_empty(timeout)
+
+
+@dataclass
+class ThrottledProgressQueue(EventConsumer):
+    """
+    A smart event queue that batches and throttles progress events.
+
+    Key features:
+    - Progress events are deduplicated per nodeId (only latest is kept)
+    - Batched events are flushed at regular intervals (default 50ms = 20 updates/sec)
+    - Non-progress events pass through immediately
+    - Thread-safe for use from executor threads
+
+    This significantly reduces CPU overhead when processing many iterations,
+    as we avoid sending redundant intermediate progress updates.
+    """
+
+    queue: EventConsumer
+    flush_interval: float = 0.05  # 50ms = 20 updates/sec max
+
+    # Internal state (init=False: excluded from the generated __init__)
+    _pending_progress: dict[NodeId, NodeProgressUpdateEvent] = field(
+        default_factory=dict, init=False
+    )
+    _lock: threading.Lock = field(default_factory=threading.Lock, init=False)
+    _last_flush_time: float = field(default_factory=time.monotonic, init=False)
+    _flush_scheduled: bool = field(default=False, init=False)
+    _loop: asyncio.AbstractEventLoop | None = field(default=None, init=False)
+
+    def set_loop(self, loop: asyncio.AbstractEventLoop) -> None:
+        """Set the event loop for scheduling async flushes."""
+        self._loop = loop
+
+    def put(self, event: Event) -> None:
+        event_type = event["event"]
+
+        # Progress events get batched and deduplicated
+        if event_type == "node-progress":
+            self._buffer_progress(event)  # type: ignore
+            return
+
+        # Flush any pending progress before passing any other event through.
+        # Flushing on every pass-through event (not just node-finish)
+        # preserves emission order: a consumer never sees a later event
+        # (e.g. a node-broadcast) arrive before earlier progress updates.
+        self._flush_now()
+
+        self.queue.put(event)
+
+    def _buffer_progress(self, event: NodeProgressUpdateEvent) -> None:
+        """Buffer a progress event, replacing any existing one for the same node."""
+        node_id = event["data"]["nodeId"]
+
+        with self._lock:
+            # Always keep the latest progress for each node
+            self._pending_progress[node_id] = event
+
+            # Check if we should flush based on time
+            now = time.monotonic()
+            elapsed = now - self._last_flush_time
+
+            if elapsed >= self.flush_interval:
+                # Time to flush immediately
+                self._flush_locked()
+            elif not self._flush_scheduled and self._loop is not None:
+                # Schedule a flush for later
+                # Use call_soon_threadsafe since this may be called from worker threads
+                self._flush_scheduled = True
+                delay = self.flush_interval - elapsed
+                self._loop.call_soon_threadsafe(
+                    lambda d=delay: self._schedule_delayed_flush(d)
+                )
+
+    def _schedule_delayed_flush(self, delay: float) -> None:
+        """Schedule a delayed flush. Must be called from the event loop thread."""
+        if self._loop is not None:
+            self._loop.call_later(delay, self._scheduled_flush)
+
+    def _scheduled_flush(self) -> None:
+        """Called by the event loop after the flush delay."""
+        with self._lock:
+            self._flush_scheduled = False
+            if self._pending_progress:
+                self._flush_locked()
+
+    def _flush_locked(self) -> None:
+        """Flush pending progress events. Must be called with lock held."""
+        if not self._pending_progress:
+            return
+
+        # Send all pending progress events
+        for event in self._pending_progress.values():
+            self.queue.put(event)
+
+        self._pending_progress.clear()
+        self._last_flush_time = time.monotonic()
+
+    def _flush_now(self) -> None:
+        """Immediately flush any pending progress events."""
+        with self._lock:
+            self._flush_locked()
+
+    def flush(self) -> None:
+        """
+        Public method to flush pending events.
+        Call this when execution completes to ensure all progress is sent.
+        """
+        self._flush_now()
+
+    def clear_node(self, node_id: NodeId) -> None:
+        """
+        Clear any pending progress for a specific node.
+        Useful when a node is aborted or restarted.
+        """
+        with self._lock:
+            self._pending_progress.pop(node_id, None)
diff --git a/backend/src/server.py b/backend/src/server.py
index d90489768..87e552ef8 100644
--- a/backend/src/server.py
+++ b/backend/src/server.py
@@ -27,7 +27,12 @@
 from chain.json import JsonNode, parse_json
 from chain.optimize import optimize
 from dependencies.store import installed_packages
-from events import EventConsumer, EventQueue, ExecutionErrorData
+from events import (
+    EventConsumer,
+    EventQueue,
+    ExecutionErrorData,
+    ThrottledProgressQueue,
+)
 
 # Logger will be initialized when AppContext is created
 # For now, use a fallback logger
@@ -239,6 +244,11 @@ async def run(request: Request):
 
     logger.info("Running new executor...")
 
+    # Wrap queue with throttling for progress events
+    # This batches and deduplicates progress updates to reduce CPU overhead
+    throttled_queue = ThrottledProgressQueue(queue=ctx.queue)
+    throttled_queue.set_loop(app.loop)
+
     use_new_executor = full_data.get("useExperimentalFeatures", False)
     if use_new_executor:
         executor = NewExecutor(
@@ -247,7 +257,7 @@ async def run(request: Request):
             send_broadcast_data=full_data["sendBroadcastData"],
             options=ExecutionOptions.parse(full_data["options"]),
             loop=app.loop,
-            queue=ctx.queue,
+            queue=throttled_queue,
             pool=ctx.pool,
             storage_dir=ctx.storage_dir,
             parent_cache=OutputCache(static_data=ctx.cache.copy()),
@@ -259,7 +269,7 @@ async def run(request: Request):
             send_broadcast_data=full_data["sendBroadcastData"],
             options=ExecutionOptions.parse(full_data["options"]),
             loop=app.loop,
-            queue=ctx.queue,
+            queue=throttled_queue,
             pool=ctx.pool,
             storage_dir=ctx.storage_dir,
             parent_cache=OutputCache(static_data=ctx.cache.copy()),
@@ -270,6 +280,8 @@ async def run(request: Request):
     except Aborted:
         pass
     finally:
+        # Ensure any pending progress events are flushed
+        throttled_queue.flush()
         ctx.executor = None
         gc.collect()
 
@@ -367,10 +379,15 @@ async def run_individual(request: Request):
             chain.inputs.set(node_id, node.data.inputs[index].id, i)
 
     # only yield certain types of events
-    queue = EventConsumer.filter(
+    filtered_queue = EventConsumer.filter(
         ctx.queue, {"node-finish", "node-broadcast", "execution-error"}
     )
 
+    # Wrap with throttling (though individual runs typically don't have many
+    # progress events, this keeps the pattern consistent)
+    throttled_queue = ThrottledProgressQueue(queue=filtered_queue)
+    throttled_queue.set_loop(app.loop)
+
     execution_id = ExecutionId("individual-executor " + node_id)
 
     use_new_executor = full_data.get("useExperimentalFeatures", False)
@@ -381,7 +398,7 @@ async def run_individual(request: Request):
             send_broadcast_data=True,
             options=ExecutionOptions.parse(full_data["options"]),
             loop=app.loop,
-            queue=queue,
+            queue=throttled_queue,
             storage_dir=ctx.storage_dir,
             pool=ctx.pool,
         )
@@ -392,7 +409,7 @@ async def run_individual(request: Request):
             send_broadcast_data=True,
             options=ExecutionOptions.parse(full_data["options"]),
             loop=app.loop,
-            queue=queue,
+            queue=throttled_queue,
             storage_dir=ctx.storage_dir,
             pool=ctx.pool,
         )
@@ -432,6 +449,8 @@ async def run_individual(request: Request):
     except Aborted:
         pass
     finally:
+        # Ensure any pending progress events are flushed
+        throttled_queue.flush()
         if ctx.individual_executors.get(execution_id, None) == executor:
             ctx.individual_executors.pop(execution_id, None)
         gc.collect()