Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 56 additions & 4 deletions src/agent/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2244,15 +2244,22 @@ impl Channel {
success,
..
} => {
let mut workers = self.state.active_workers.write().await;
if workers.remove(worker_id).is_none() {
// Use worker_handles as the source of truth for active workers.
// (active_workers is never populated because Worker is consumed by .run())
if self
.state
.worker_handles
.write()
.await
.remove(worker_id)
.is_none()
{
return Ok(());
}
drop(workers);

run_logger.log_worker_completed(*worker_id, result, *success);

self.state.worker_handles.write().await.remove(worker_id);
self.state.active_workers.write().await.remove(worker_id);
self.state.worker_inputs.write().await.remove(worker_id);
Comment on lines +2247 to 2263

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor robustness tweak: the early return on missing worker_handles skips cleanup of worker_inputs (and any future active_workers usage) if we ever see duplicate/late WorkerComplete events. Consider always cleaning those up, and reword the comment to avoid the absolute “never populated” claim.

Suggested change
// Use worker_handles as the source of truth for active workers.
// (active_workers is never populated because Worker is consumed by .run())
if self.state.worker_handles.write().await.remove(worker_id).is_none() {
return Ok(());
}
drop(workers);
run_logger.log_worker_completed(*worker_id, result, *success);
self.state.worker_handles.write().await.remove(worker_id);
self.state.active_workers.write().await.remove(worker_id);
self.state.worker_inputs.write().await.remove(worker_id);
// `worker_handles` is the source of truth for running workers.
let removed_handle = self.state.worker_handles.write().await.remove(worker_id);
self.state.active_workers.write().await.remove(worker_id);
self.state.worker_inputs.write().await.remove(worker_id);
if removed_handle.is_none() {
return Ok(());
}
run_logger.log_worker_completed(*worker_id, result, *success);


if *notify {
Expand Down Expand Up @@ -2619,4 +2626,49 @@ mod tests {

assert!(should_process_event_for_channel(&event, &channel_id));
}

#[test]
fn worker_complete_event_matches_own_channel() {
let channel_id: ChannelId = Arc::from("channel-a");
let event = ProcessEvent::WorkerComplete {
agent_id: Arc::from("agent"),
worker_id: uuid::Uuid::new_v4(),
channel_id: Some(channel_id.clone()),
result: "done".to_string(),
notify: true,
success: true,
};

assert!(should_process_event_for_channel(&event, &channel_id));
}

#[test]
fn worker_complete_event_ignored_for_other_channel() {
let channel_id: ChannelId = Arc::from("channel-a");
let event = ProcessEvent::WorkerComplete {
agent_id: Arc::from("agent"),
worker_id: uuid::Uuid::new_v4(),
channel_id: Some(Arc::from("channel-b")),
result: "done".to_string(),
notify: true,
success: true,
};

assert!(!should_process_event_for_channel(&event, &channel_id));
}

#[test]
fn worker_complete_event_ignored_when_no_channel() {
let channel_id: ChannelId = Arc::from("channel-a");
let event = ProcessEvent::WorkerComplete {
agent_id: Arc::from("agent"),
worker_id: uuid::Uuid::new_v4(),
channel_id: None,
result: "done".to_string(),
notify: true,
success: true,
};

assert!(!should_process_event_for_channel(&event, &channel_id));
}
}
36 changes: 36 additions & 0 deletions src/agent/channel_dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -739,4 +739,40 @@ mod tests {
other => panic!("unexpected event: {other:?}"),
}
}

#[tokio::test]
async fn spawn_worker_task_carries_channel_id() {
let (event_tx, mut event_rx) = broadcast::channel(8);
let worker_id: WorkerId = Uuid::new_v4();
let channel_id: crate::ChannelId = Arc::from("test-channel");

let handle = spawn_worker_task(
worker_id,
event_tx,
Arc::<str>::from("agent"),
Some(channel_id.clone()),
None,
async { Ok::<String, crate::Error>("result".to_string()) },
);

let event = tokio::time::timeout(Duration::from_secs(2), event_rx.recv())
.await
.expect("worker completion event should be delivered")
.expect("broadcast receive should succeed");
handle.await.expect("worker task should join cleanly");

match event {
ProcessEvent::WorkerComplete {
channel_id: event_channel_id,
worker_id: completed_worker_id,
success,
..
} => {
assert_eq!(completed_worker_id, worker_id);
assert_eq!(event_channel_id, Some(channel_id));
assert!(success);
}
other => panic!("unexpected event: {other:?}"),
}
}
}