|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import asyncio |
| 4 | +import threading |
| 5 | +import time |
4 | 6 | from abc import ABC, abstractmethod |
5 | | -from dataclasses import dataclass |
| 7 | +from dataclasses import dataclass, field |
6 | 8 | from typing import Literal, TypedDict |
7 | 9 |
|
8 | 10 | import navi |
@@ -160,3 +162,119 @@ async def wait_until_empty(self, timeout: float) -> None: |
    async def put_and_wait(self, event: Event, timeout: float = float("inf")) -> None:
        """Enqueue *event* and then wait until the queue has fully drained.

        Delegates to ``self.queue.put`` for the enqueue and to
        ``self.wait_until_empty`` for the drain; *timeout* is passed straight
        through to ``wait_until_empty`` (timeout semantics — raise vs. return —
        are defined by that method, not here).
        """
        await self.queue.put(event)
        await self.wait_until_empty(timeout)
| 165 | + |
| 166 | + |
@dataclass
class ThrottledProgressQueue(EventConsumer):
    """
    A smart event queue that batches and throttles progress events.

    Key features:
    - Progress events are deduplicated per nodeId (only the latest is kept)
    - Batched events are flushed at regular intervals (default 50ms = 20 updates/sec)
    - Non-progress events pass through immediately
    - Thread-safe for use from executor threads (all buffer access is guarded
      by an internal lock; async scheduling goes through call_soon_threadsafe)

    This significantly reduces CPU overhead when processing many iterations,
    as we avoid sending redundant intermediate progress updates.
    """

    # Downstream consumer that receives the (throttled) event stream.
    queue: EventConsumer
    flush_interval: float = 0.05  # 50ms = 20 updates/sec max

    # Internal state. These fields are excluded from __init__ and populated by
    # their dataclass defaults / default factories (there is no __post_init__).
    _pending_progress: dict[NodeId, NodeProgressUpdateEvent] = field(
        default_factory=dict, init=False
    )
    _lock: threading.Lock = field(default_factory=threading.Lock, init=False)
    _last_flush_time: float = field(default_factory=time.monotonic, init=False)
    _flush_scheduled: bool = field(default=False, init=False)
    _loop: asyncio.AbstractEventLoop | None = field(default=None, init=False)

    def set_loop(self, loop: asyncio.AbstractEventLoop) -> None:
        """Set the event loop used for scheduling delayed async flushes."""
        self._loop = loop

    def put(self, event: Event) -> None:
        """Accept an event: throttle "node-progress", forward everything else."""
        event_type = event["event"]

        # Progress events get batched and deduplicated
        if event_type == "node-progress":
            self._buffer_progress(event)  # type: ignore
            return

        # Flush any pending progress before important state changes.
        # This ensures progress is up-to-date before node-finish.
        if event_type in ("node-finish", "chain-start", "execution-error"):
            self._flush_now()

        # All other events pass through immediately
        self.queue.put(event)

    def _buffer_progress(self, event: NodeProgressUpdateEvent) -> None:
        """Buffer a progress event, replacing any existing one for the same node."""
        node_id = event["data"]["nodeId"]

        with self._lock:
            # Always keep the latest progress for each node
            self._pending_progress[node_id] = event

            # Check if we should flush based on elapsed time
            now = time.monotonic()
            elapsed = now - self._last_flush_time

            if elapsed >= self.flush_interval:
                # Interval already elapsed: flush right here on the calling thread
                self._flush_locked()
            elif not self._flush_scheduled and self._loop is not None:
                # Schedule a flush for later.
                # Use call_soon_threadsafe since this may be called from worker threads.
                self._flush_scheduled = True
                delay = self.flush_interval - elapsed
                try:
                    self._loop.call_soon_threadsafe(
                        lambda d=delay: self._schedule_delayed_flush(d)
                    )
                except RuntimeError:
                    # The loop is closed (e.g. during shutdown). Don't let a worker
                    # thread crash on a progress report; clear the flag so a later
                    # time-based flush can still drain the buffer.
                    self._flush_scheduled = False

    def _schedule_delayed_flush(self, delay: float) -> None:
        """Arm a delayed flush. Must be called from the event loop thread."""
        if self._loop is not None:
            self._loop.call_later(delay, self._scheduled_flush)
        else:
            # Loop was cleared after scheduling; reset the flag so future
            # buffering attempts aren't permanently suppressed.
            with self._lock:
                self._flush_scheduled = False

    def _scheduled_flush(self) -> None:
        """Callback run by the event loop once the flush delay elapses."""
        with self._lock:
            self._flush_scheduled = False
            if self._pending_progress:
                self._flush_locked()

    def _flush_locked(self) -> None:
        """Flush pending progress events. Must be called with the lock held."""
        if not self._pending_progress:
            return

        # Send all pending progress events (one per node, latest only). Done
        # under the lock so a concurrent buffer+flush on another thread cannot
        # emit a newer update before this older one for the same node.
        for event in self._pending_progress.values():
            self.queue.put(event)

        self._pending_progress.clear()
        self._last_flush_time = time.monotonic()

    def _flush_now(self) -> None:
        """Immediately flush any pending progress events."""
        with self._lock:
            self._flush_locked()

    def flush(self) -> None:
        """
        Public method to flush pending events.
        Call this when execution completes to ensure all progress is sent.
        """
        self._flush_now()

    def clear_node(self, node_id: NodeId) -> None:
        """
        Clear any pending progress for a specific node.
        Useful when a node is aborted or restarted.
        """
        with self._lock:
            self._pending_progress.pop(node_id, None)
0 commit comments