From fff0bae85e7fce90bdd62ae104b4127534db1478 Mon Sep 17 00:00:00 2001 From: cjchanh Date: Fri, 10 Apr 2026 18:09:17 -0600 Subject: [PATCH] fix: keep worker accept loop alive on transient TCP errors The worker accept loop (`Worker::run`) previously used `while let Ok(...) = self.listener.accept()` which silently exited on any accept error, permanently killing the TCP listener. This caused a "one connection works, then dead worker" failure mode, particularly severe on iOS where the first master disconnect left the worker unreachable until a full app restart. Changes: - Hold TcpListener in Option for in-place rebinding - Classify transient errors (ECONNABORTED, EINTR, TimedOut, WouldBlock) and retry immediately instead of exiting - On fatal accept errors, drop and rebind the listener on the same port without reloading model weights - Log listener fd and local address before each accept cycle for diagnostics - Add describe_listener() helper for cross-platform fd reporting Tested: 20 sequential inference requests over 6 minutes on an iOS worker (iPad M3) with zero drops or reconnections needed. Fixes #79 --- cake-core/src/cake/sharding/worker.rs | 173 +++++++++++++++++++++----- 1 file changed, 139 insertions(+), 34 deletions(-) diff --git a/cake-core/src/cake/sharding/worker.rs b/cake-core/src/cake/sharding/worker.rs index a0efe2c..7b2dfe0 100644 --- a/cake-core/src/cake/sharding/worker.rs +++ b/cake-core/src/cake/sharding/worker.rs @@ -1,12 +1,15 @@ +#[cfg(unix)] +use std::os::fd::AsRawFd; use std::{ collections::HashMap, + io::ErrorKind, net::SocketAddr, sync::Arc, time::{Duration, Instant}, }; -use crate::cake::{Context, Forwarder}; use super::{Message, WorkerInfo}; +use crate::cake::{Context, Forwarder}; use crate::models::Generator; use anyhow::Result; @@ -77,11 +80,45 @@ impl WorkerContext { /// Cake worker node. 
pub struct Worker { - listener: TcpListener, + listener: Option<TcpListener>, context: WorkerContext, } impl Worker { + fn describe_listener(listener: &TcpListener) -> String { + let addr = listener + .local_addr() + .map(|addr| addr.to_string()) + .unwrap_or_else(|_| "".to_string()); + #[cfg(unix)] + { + format!("{addr} fd={}", listener.as_raw_fd()) + } + #[cfg(not(unix))] + { + addr + } + } + + fn accept_error_is_transient(kind: ErrorKind) -> bool { + matches!( + kind, + ErrorKind::ConnectionAborted + | ErrorKind::Interrupted + | ErrorKind::TimedOut + | ErrorKind::WouldBlock + ) + } + + async fn bind_listener(bind_address: &str) -> Result<TcpListener> { + let listener = TcpListener::bind(bind_address).await?; + log::info!( + "worker listener bound on {}", + Self::describe_listener(&listener) + ); + Ok(listener) + } + /// Detect how many CUDA devices are available. fn detect_cuda_device_count() -> usize { #[cfg(feature = "cuda")] @@ -199,17 +236,14 @@ impl Worker { .context() .bind_to_thread() .map_err(|e| { - anyhow!( - "failed to bind CUDA context for GPU {gpu_idx}: {e:?}" - ) + anyhow!("failed to bind CUDA context for GPU {gpu_idx}: {e:?}") })?; } let mut results = Vec::new(); for layer_name in layers { log::info!("loading {} on cuda:{} ...", &layer_name, gpu_idx); - let block = - G::Shardable::load(layer_name.clone(), &thread_ctx)?; + let block = G::Shardable::load(layer_name.clone(), &thread_ctx)?; results.push((layer_name, dev.clone(), block)); } Ok(results) @@ -243,16 +277,24 @@ impl Worker { let listener = { let taken = ctx.listener_override.lock().unwrap().take(); if let Some(existing) = taken { + log::info!( + "using pre-bound worker listener {}", + Self::describe_listener(&existing) + ); existing } else { - TcpListener::bind(&ctx.args.address).await? + Self::bind_listener(&ctx.args.address).await? 
} }; log::info!( "listening on {} (mem:{}) ...", &ctx.args.address, - human_bytes::human_bytes(memory_stats::memory_stats().map(|m| m.physical_mem).unwrap_or(0) as f64) + human_bytes::human_bytes( + memory_stats::memory_stats() + .map(|m| m.physical_mem) + .unwrap_or(0) as f64 + ) ); let device = ctx.device.clone(); @@ -268,7 +310,10 @@ impl Worker { context: ctx.clone(), }; - Ok(Self { listener, context }) + Ok(Self { + listener: Some(listener), + context, + }) } /// Read a message from the socket and return elapsed time, message size and message. @@ -355,12 +400,19 @@ impl Worker { let mut write_buf = Vec::with_capacity(64 * 1024); // keep reading messages - while let Ok((read_time, read_size, op_message)) = { - let start = Instant::now(); - Message::from_reader_buf(&mut socket, &mut read_buf) - .await - .map(|(size, msg)| (start.elapsed(), size, msg)) - } { + loop { + let (read_time, read_size, op_message) = match { + let start = Instant::now(); + Message::from_reader_buf(&mut socket, &mut read_buf) + .await + .map(|(size, msg)| (start.elapsed(), size, msg)) + } { + Ok(result) => result, + Err(e) => { + log::info!("[{}] connection loop ended: {}", &client, e); + break; + } + }; if matches!(op_message, Message::Goodbye) { log::debug!("[{}] goodbye", &client); context @@ -574,24 +626,67 @@ impl Worker { msg_idx += 1; } + log::info!("[{}] handler exiting", &client); Ok(()) } /// Run the worker server accept loop. 
pub async fn run(&mut self) -> Result<()> { - while let Ok((socket, client)) = self.listener.accept().await { - let _ = socket.set_nodelay(true); - log::debug!("{} connected", &client); - - let context = self.context.get_client_context(); - tokio::spawn(async move { - if let Err(e) = Self::handle_master_client(socket, client, context).await { - log::error!("{}", e); + loop { + let listener_desc = self + .listener + .as_ref() + .map(Self::describe_listener) + .unwrap_or_else(|| "".to_string()); + log::info!("worker accept loop awaiting master on {}", listener_desc); + + let accept_result = match self.listener.as_mut() { + Some(listener) => listener.accept().await, + None => { + let bind_address = self.context.context.args.address.clone(); + log::warn!( + "worker listener missing before accept; rebinding on {}", + bind_address + ); + self.listener = Some(Self::bind_listener(&bind_address).await?); + continue; } - }); - } + }; - Ok(()) + match accept_result { + Ok((socket, client)) => { + let _ = socket.set_nodelay(true); + log::info!("[{}] accepted on {}", &client, listener_desc); + + let context = self.context.get_client_context(); + tokio::spawn(async move { + if let Err(e) = Self::handle_master_client(socket, client, context).await { + log::error!("{}", e); + } + }); + } + Err(e) if Self::accept_error_is_transient(e.kind()) => { + log::warn!( + "transient accept error on {}: {} ({:?})", + listener_desc, + e, + e.kind() + ); + } + Err(e) => { + let bind_address = self.context.context.args.address.clone(); + log::error!( + "accept failed on {}: {} ({:?}); dropping listener and rebinding {}", + listener_desc, + e, + e.kind(), + bind_address + ); + self.listener.take(); + self.listener = Some(Self::bind_listener(&bind_address).await?); + } + } + } } } @@ -742,7 +837,10 @@ mod tests { // New context should have a fresh cache (as_new clears KV entries) assert!(client_ctx.context.cache.is_some()); // Device and dtype should be copied - assert_eq!(format!("{:?}", 
client_ctx.device), format!("{:?}", Device::Cpu)); + assert_eq!( + format!("{:?}", client_ctx.device), + format!("{:?}", Device::Cpu) + ); assert_eq!(client_ctx.dtype, DType::F32); } @@ -760,8 +858,9 @@ mod tests { let (mut server, mut client) = duplex(65536); // Write from server side - let (write_dur, write_size) = - >::write_message_timed(&mut server, msg).await.unwrap(); + let (write_dur, write_size) = >::write_message_timed(&mut server, msg) + .await + .unwrap(); assert!(write_size > 0); assert!(write_dur.as_nanos() > 0); @@ -773,7 +872,12 @@ mod tests { // Verify the message was correctly serialized/deserialized match read_msg { - Message::SingleOp { layer_name, x, index_pos, block_idx } => { + Message::SingleOp { + layer_name, + x, + index_pos, + block_idx, + } => { assert_eq!(layer_name, "test_layer"); assert_eq!(index_pos, 0); assert_eq!(block_idx, 0); @@ -797,10 +901,11 @@ mod tests { let msg = Message::from_batch(&tensor, batch); let (mut server, mut client) = duplex(65536); - >::write_message_timed(&mut server, msg).await.unwrap(); + >::write_message_timed(&mut server, msg) + .await + .unwrap(); - let (_dur, _size, read_msg) = - >::read_message_timed(&mut client).await.unwrap(); + let (_dur, _size, read_msg) = >::read_message_timed(&mut client).await.unwrap(); match read_msg { Message::Batch { x, batch } => { let t = x.to_tensor(&Device::Cpu).unwrap();