Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 119 additions & 1 deletion backend/src/events.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import asyncio
import threading
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Literal, TypedDict

import navi
Expand Down Expand Up @@ -160,3 +162,119 @@ async def wait_until_empty(self, timeout: float) -> None:
async def put_and_wait(self, event: Event, timeout: float = float("inf")) -> None:
await self.queue.put(event)
await self.wait_until_empty(timeout)


@dataclass
class ThrottledProgressQueue(EventConsumer):
    """
    A smart event queue that batches and throttles progress events.

    Key features:
    - Progress events are deduplicated per nodeId (only latest is kept)
    - Batched events are flushed at regular intervals (default 50ms = 20 updates/sec)
    - Non-progress events pass through immediately
    - Thread-safe for use from executor threads

    This significantly reduces CPU overhead when processing many iterations,
    as we avoid sending redundant intermediate progress updates.

    NOTE(review): events that pass through immediately but are NOT in the
    flush-trigger list in `put` (e.g. broadcasts) can overtake progress that
    is still buffered for the same node, so the downstream consumer may see
    a slightly reordered stream — confirm consumers tolerate that.
    """

    # Downstream consumer that receives the (throttled) event stream.
    # NOTE(review): this class calls `self.queue.put(...)` synchronously, from
    # worker threads and from the event loop — confirm the wrapped consumer's
    # `put` is synchronous and thread-safe (EventConsumer base not visible here).
    queue: EventConsumer
    flush_interval: float = 0.05  # 50ms = 20 updates/sec max

    # Internal state (excluded from the generated __init__ via init=False).
    # Latest buffered progress event per node; newer updates overwrite older ones.
    _pending_progress: dict[NodeId, NodeProgressUpdateEvent] = field(
        default_factory=dict, init=False
    )
    # Guards _pending_progress, _last_flush_time and _flush_scheduled.
    _lock: threading.Lock = field(default_factory=threading.Lock, init=False)
    # time.monotonic() of the last flush — monotonic so wall-clock jumps can't
    # stall or spam flushing.
    _last_flush_time: float = field(default_factory=time.monotonic, init=False)
    # True while a delayed flush callback is (being) queued on the event loop.
    _flush_scheduled: bool = field(default=False, init=False)
    # Event loop used to schedule delayed flushes; None until set_loop() is called,
    # in which case only the "interval already elapsed" inline flush path runs.
    _loop: asyncio.AbstractEventLoop | None = field(default=None, init=False)

    def set_loop(self, loop: asyncio.AbstractEventLoop) -> None:
        """Set the event loop for scheduling async flushes.

        Should be called before events start arriving; without a loop,
        buffered progress is only flushed when the interval has already
        elapsed or a flush-trigger event passes through `put`.
        """
        self._loop = loop

    def put(self, event: Event) -> None:
        """Accept an event: buffer progress, pass everything else through.

        May be called from worker threads; buffering is lock-protected.
        """
        event_type = event["event"]

        # Progress events get batched and deduplicated
        if event_type == "node-progress":
            self._buffer_progress(event)  # type: ignore
            return

        # Flush any pending progress before important state changes
        # This ensures progress is up-to-date before node-finish
        if event_type in ("node-finish", "chain-start", "execution-error"):
            self._flush_now()

        # All other events pass through immediately
        self.queue.put(event)

    def _buffer_progress(self, event: NodeProgressUpdateEvent) -> None:
        """Buffer a progress event, replacing any existing one for the same node."""
        node_id = event["data"]["nodeId"]

        with self._lock:
            # Always keep the latest progress for each node
            self._pending_progress[node_id] = event

            # Check if we should flush based on time
            now = time.monotonic()
            elapsed = now - self._last_flush_time

            if elapsed >= self.flush_interval:
                # Time to flush immediately.
                # NOTE(review): this forwards to the downstream queue while the
                # lock is still held — fine as long as downstream put() is fast
                # and never re-enters this object. If a delayed flush was
                # already scheduled, _flush_scheduled stays True until that
                # callback fires and finds an empty (or refreshed) buffer;
                # that later flush is harmless.
                self._flush_locked()
            elif not self._flush_scheduled and self._loop is not None:
                # Schedule a flush for later
                # Use call_soon_threadsafe since this may be called from worker threads
                self._flush_scheduled = True
                delay = self.flush_interval - elapsed
                # lambda default binds `delay` now, not at callback time
                self._loop.call_soon_threadsafe(
                    lambda d=delay: self._schedule_delayed_flush(d)
                )

    def _schedule_delayed_flush(self, delay: float) -> None:
        """Schedule a delayed flush. Must be called from the event loop thread."""
        # call_later is not thread-safe, hence the call_soon_threadsafe hop above.
        # Loop may have been cleared in the meantime; then the pending progress
        # waits for the next inline-flush or flush-trigger event.
        if self._loop is not None:
            self._loop.call_later(delay, self._scheduled_flush)

    def _scheduled_flush(self) -> None:
        """Called by the event loop after the flush delay."""
        with self._lock:
            # Clear the flag first so a concurrent _buffer_progress can
            # schedule the next cycle.
            self._flush_scheduled = False
            if self._pending_progress:
                self._flush_locked()

    def _flush_locked(self) -> None:
        """Flush pending progress events. Must be called with lock held."""
        if not self._pending_progress:
            return

        # Send all pending progress events (dict insertion order: first-buffered
        # node first, each carrying its latest update).
        for event in self._pending_progress.values():
            self.queue.put(event)

        self._pending_progress.clear()
        self._last_flush_time = time.monotonic()

    def _flush_now(self) -> None:
        """Immediately flush any pending progress events (acquires the lock)."""
        with self._lock:
            self._flush_locked()

    def flush(self) -> None:
        """
        Public method to flush pending events.
        Call this when execution completes to ensure all progress is sent.
        """
        self._flush_now()

    def clear_node(self, node_id: NodeId) -> None:
        """
        Clear any pending progress for a specific node.
        Useful when a node is aborted or restarted.
        """
        with self._lock:
            # pop with default: no-op if nothing was buffered for this node
            self._pending_progress.pop(node_id, None)
31 changes: 25 additions & 6 deletions backend/src/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@
from chain.json import JsonNode, parse_json
from chain.optimize import optimize
from dependencies.store import installed_packages
from events import EventConsumer, EventQueue, ExecutionErrorData
from events import (
EventConsumer,
EventQueue,
ExecutionErrorData,
ThrottledProgressQueue,
)

# Logger will be initialized when AppContext is created
# For now, use a fallback logger
Expand Down Expand Up @@ -239,6 +244,11 @@ async def run(request: Request):

logger.info("Running new executor...")

# Wrap queue with throttling for progress events
# This batches and deduplicates progress updates to reduce CPU overhead
throttled_queue = ThrottledProgressQueue(queue=ctx.queue)
throttled_queue.set_loop(app.loop)

use_new_executor = full_data.get("useExperimentalFeatures", False)
if use_new_executor:
executor = NewExecutor(
Expand All @@ -247,7 +257,7 @@ async def run(request: Request):
send_broadcast_data=full_data["sendBroadcastData"],
options=ExecutionOptions.parse(full_data["options"]),
loop=app.loop,
queue=ctx.queue,
queue=throttled_queue,
pool=ctx.pool,
storage_dir=ctx.storage_dir,
parent_cache=OutputCache(static_data=ctx.cache.copy()),
Expand All @@ -259,7 +269,7 @@ async def run(request: Request):
send_broadcast_data=full_data["sendBroadcastData"],
options=ExecutionOptions.parse(full_data["options"]),
loop=app.loop,
queue=ctx.queue,
queue=throttled_queue,
pool=ctx.pool,
storage_dir=ctx.storage_dir,
parent_cache=OutputCache(static_data=ctx.cache.copy()),
Expand All @@ -270,6 +280,8 @@ async def run(request: Request):
except Aborted:
pass
finally:
# Ensure any pending progress events are flushed
throttled_queue.flush()
ctx.executor = None
gc.collect()

Expand Down Expand Up @@ -367,10 +379,15 @@ async def run_individual(request: Request):
chain.inputs.set(node_id, node.data.inputs[index].id, i)

# only yield certain types of events
queue = EventConsumer.filter(
filtered_queue = EventConsumer.filter(
ctx.queue, {"node-finish", "node-broadcast", "execution-error"}
)

# Wrap with throttling (though individual runs typically don't have many
# progress events, this keeps the pattern consistent)
throttled_queue = ThrottledProgressQueue(queue=filtered_queue)
throttled_queue.set_loop(app.loop)

execution_id = ExecutionId("individual-executor " + node_id)

use_new_executor = full_data.get("useExperimentalFeatures", False)
Expand All @@ -381,7 +398,7 @@ async def run_individual(request: Request):
send_broadcast_data=True,
options=ExecutionOptions.parse(full_data["options"]),
loop=app.loop,
queue=queue,
queue=throttled_queue,
storage_dir=ctx.storage_dir,
pool=ctx.pool,
)
Expand All @@ -392,7 +409,7 @@ async def run_individual(request: Request):
send_broadcast_data=True,
options=ExecutionOptions.parse(full_data["options"]),
loop=app.loop,
queue=queue,
queue=throttled_queue,
storage_dir=ctx.storage_dir,
pool=ctx.pool,
)
Expand Down Expand Up @@ -432,6 +449,8 @@ async def run_individual(request: Request):
except Aborted:
pass
finally:
# Ensure any pending progress events are flushed
throttled_queue.flush()
if ctx.individual_executors.get(execution_id, None) == executor:
ctx.individual_executors.pop(execution_id, None)
gc.collect()
Expand Down