Skip to content

Commit f7ec222

Browse files
Add throttled progress queue for reduced CPU load
1 parent badba2d commit f7ec222

File tree

2 files changed

+144
-7
lines changed

2 files changed

+144
-7
lines changed

backend/src/events.py

Lines changed: 119 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from __future__ import annotations
22

33
import asyncio
4+
import threading
5+
import time
46
from abc import ABC, abstractmethod
5-
from dataclasses import dataclass
7+
from dataclasses import dataclass, field
68
from typing import Literal, TypedDict
79

810
import navi
@@ -160,3 +162,119 @@ async def wait_until_empty(self, timeout: float) -> None:
160162
async def put_and_wait(self, event: Event, timeout: float = float("inf")) -> None:
161163
await self.queue.put(event)
162164
await self.wait_until_empty(timeout)
165+
166+
167+
@dataclass
class ThrottledProgressQueue(EventConsumer):
    """
    A smart event queue that batches and throttles progress events.

    Key features:
    - Progress events are deduplicated per nodeId (only latest is kept)
    - Batched events are flushed at regular intervals (default 50ms = 20 updates/sec)
    - Non-progress events pass through immediately
    - Thread-safe for use from executor threads

    This significantly reduces CPU overhead when processing many iterations,
    as we avoid sending redundant intermediate progress updates.

    Delivery to the wrapped consumer always happens with the internal lock
    released, so a slow (or re-entrant) downstream consumer can never block
    worker threads that are buffering progress, nor deadlock on this lock.
    """

    # The wrapped consumer that actually receives events.
    # NOTE(review): events are forwarded with a synchronous ``put`` call, same
    # as the original code — assumes the wrapped consumer's put is non-async.
    queue: EventConsumer
    flush_interval: float = 0.05  # 50ms = 20 updates/sec max

    # Internal state (initialized by the dataclass machinery, not by callers)
    _pending_progress: dict[NodeId, NodeProgressUpdateEvent] = field(
        default_factory=dict, init=False
    )
    _lock: threading.Lock = field(default_factory=threading.Lock, init=False)
    _last_flush_time: float = field(default_factory=time.monotonic, init=False)
    _flush_scheduled: bool = field(default=False, init=False)
    _loop: asyncio.AbstractEventLoop | None = field(default=None, init=False)

    def set_loop(self, loop: asyncio.AbstractEventLoop) -> None:
        """Set the event loop for scheduling async flushes."""
        self._loop = loop

    def put(self, event: Event) -> None:
        """Accept an event: batch progress, pass everything else through."""
        event_type = event["event"]

        # Progress events get batched and deduplicated
        if event_type == "node-progress":
            self._buffer_progress(event)  # type: ignore
            return

        # Flush any pending progress before important state changes.
        # This ensures progress is up-to-date before node-finish.
        if event_type in ("node-finish", "chain-start", "execution-error"):
            self._flush_now()

        # All other events pass through immediately
        self.queue.put(event)

    def _buffer_progress(self, event: NodeProgressUpdateEvent) -> None:
        """Buffer a progress event, replacing any existing one for the same node."""
        node_id = event["data"]["nodeId"]

        to_send: list[NodeProgressUpdateEvent] = []
        with self._lock:
            # Always keep the latest progress for each node
            self._pending_progress[node_id] = event

            # Check if we should flush based on time
            elapsed = time.monotonic() - self._last_flush_time

            if elapsed >= self.flush_interval:
                # Time to flush: drain under the lock, deliver after releasing
                # it so downstream put() is never called while the lock is held.
                to_send = self._drain_locked()
            elif not self._flush_scheduled and self._loop is not None:
                # Schedule a flush for later.
                # call_soon_threadsafe is required since this may be called
                # from worker threads, not the event loop thread.
                self._flush_scheduled = True
                delay = self.flush_interval - elapsed
                self._loop.call_soon_threadsafe(
                    lambda d=delay: self._schedule_delayed_flush(d)
                )

        self._deliver(to_send)

    def _schedule_delayed_flush(self, delay: float) -> None:
        """Schedule a delayed flush. Must be called from the event loop thread."""
        if self._loop is not None:
            self._loop.call_later(delay, self._scheduled_flush)

    def _scheduled_flush(self) -> None:
        """Called by the event loop after the flush delay."""
        with self._lock:
            self._flush_scheduled = False
            to_send = self._drain_locked()
        self._deliver(to_send)

    def _drain_locked(self) -> list[NodeProgressUpdateEvent]:
        """
        Take ownership of all pending progress events and reset the flush clock.

        Must be called with the lock held. Returns the drained events so the
        caller can deliver them AFTER releasing the lock — holding the lock
        across downstream put() calls risks blocking worker threads and
        deadlocking on re-entrant consumers.
        """
        if not self._pending_progress:
            return []
        events = list(self._pending_progress.values())
        self._pending_progress.clear()
        self._last_flush_time = time.monotonic()
        return events

    def _deliver(self, events: list[NodeProgressUpdateEvent]) -> None:
        """Forward drained events to the wrapped consumer. Lock must NOT be held."""
        for event in events:
            self.queue.put(event)

    def _flush_now(self) -> None:
        """Immediately flush any pending progress events."""
        with self._lock:
            to_send = self._drain_locked()
        self._deliver(to_send)

    def flush(self) -> None:
        """
        Public method to flush pending events.
        Call this when execution completes to ensure all progress is sent.
        """
        self._flush_now()

    def clear_node(self, node_id: NodeId) -> None:
        """
        Clear any pending progress for a specific node.
        Useful when a node is aborted or restarted.
        """
        with self._lock:
            self._pending_progress.pop(node_id, None)

backend/src/server.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,12 @@
2727
from chain.json import JsonNode, parse_json
2828
from chain.optimize import optimize
2929
from dependencies.store import installed_packages
30-
from events import EventConsumer, EventQueue, ExecutionErrorData
30+
from events import (
31+
EventConsumer,
32+
EventQueue,
33+
ExecutionErrorData,
34+
ThrottledProgressQueue,
35+
)
3136

3237
# Logger will be initialized when AppContext is created
3338
# For now, use a fallback logger
@@ -239,6 +244,11 @@ async def run(request: Request):
239244

240245
logger.info("Running new executor...")
241246

247+
# Wrap queue with throttling for progress events
248+
# This batches and deduplicates progress updates to reduce CPU overhead
249+
throttled_queue = ThrottledProgressQueue(queue=ctx.queue)
250+
throttled_queue.set_loop(app.loop)
251+
242252
use_new_executor = full_data.get("useExperimentalFeatures", False)
243253
if use_new_executor:
244254
executor = NewExecutor(
@@ -247,7 +257,7 @@ async def run(request: Request):
247257
send_broadcast_data=full_data["sendBroadcastData"],
248258
options=ExecutionOptions.parse(full_data["options"]),
249259
loop=app.loop,
250-
queue=ctx.queue,
260+
queue=throttled_queue,
251261
pool=ctx.pool,
252262
storage_dir=ctx.storage_dir,
253263
parent_cache=OutputCache(static_data=ctx.cache.copy()),
@@ -259,7 +269,7 @@ async def run(request: Request):
259269
send_broadcast_data=full_data["sendBroadcastData"],
260270
options=ExecutionOptions.parse(full_data["options"]),
261271
loop=app.loop,
262-
queue=ctx.queue,
272+
queue=throttled_queue,
263273
pool=ctx.pool,
264274
storage_dir=ctx.storage_dir,
265275
parent_cache=OutputCache(static_data=ctx.cache.copy()),
@@ -270,6 +280,8 @@ async def run(request: Request):
270280
except Aborted:
271281
pass
272282
finally:
283+
# Ensure any pending progress events are flushed
284+
throttled_queue.flush()
273285
ctx.executor = None
274286
gc.collect()
275287

@@ -367,10 +379,15 @@ async def run_individual(request: Request):
367379
chain.inputs.set(node_id, node.data.inputs[index].id, i)
368380

369381
# only yield certain types of events
370-
queue = EventConsumer.filter(
382+
filtered_queue = EventConsumer.filter(
371383
ctx.queue, {"node-finish", "node-broadcast", "execution-error"}
372384
)
373385

386+
# Wrap with throttling (though individual runs typically don't have many
387+
# progress events, this keeps the pattern consistent)
388+
throttled_queue = ThrottledProgressQueue(queue=filtered_queue)
389+
throttled_queue.set_loop(app.loop)
390+
374391
execution_id = ExecutionId("individual-executor " + node_id)
375392

376393
use_new_executor = full_data.get("useExperimentalFeatures", False)
@@ -381,7 +398,7 @@ async def run_individual(request: Request):
381398
send_broadcast_data=True,
382399
options=ExecutionOptions.parse(full_data["options"]),
383400
loop=app.loop,
384-
queue=queue,
401+
queue=throttled_queue,
385402
storage_dir=ctx.storage_dir,
386403
pool=ctx.pool,
387404
)
@@ -392,7 +409,7 @@ async def run_individual(request: Request):
392409
send_broadcast_data=True,
393410
options=ExecutionOptions.parse(full_data["options"]),
394411
loop=app.loop,
395-
queue=queue,
412+
queue=throttled_queue,
396413
storage_dir=ctx.storage_dir,
397414
pool=ctx.pool,
398415
)
@@ -432,6 +449,8 @@ async def run_individual(request: Request):
432449
except Aborted:
433450
pass
434451
finally:
452+
# Ensure any pending progress events are flushed
453+
throttled_queue.flush()
435454
if ctx.individual_executors.get(execution_id, None) == executor:
436455
ctx.individual_executors.pop(execution_id, None)
437456
gc.collect()

0 commit comments

Comments
 (0)