Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions docs/user_guide/timeout_troubleshooting.rst
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,78 @@ Via Configuration Files
}


Streaming Stall Guardrail (``comm_config.json``)
------------------------------------------------

For large payload/model transfers, configure F3 stream stall detection in
``comm_config.json`` (server and client startup kits).

**Runtime defaults** (if not set explicitly):

- ``streaming_send_timeout``: ``30.0`` seconds
- ``streaming_ack_progress_timeout``: ``60.0`` seconds
- ``streaming_ack_progress_check_interval``: ``5.0`` seconds
- ``sfm_send_stall_timeout``: ``45.0`` seconds
- ``sfm_close_stalled_connection``: ``false`` (warn-only)
- ``sfm_send_stall_consecutive_checks``: ``3``

**Recommended deployment guideline**:

1. Start with **warn-only** to observe behavior safely.
2. If repeated stall warnings are observed during large-model streaming, enable auto-close.
3. Keep the guard enabled with consecutive checks to reduce false alarms.

Warn-only baseline:

.. code-block:: json

{
"sfm_close_stalled_connection": false,
"sfm_send_stall_timeout": 75,
"sfm_send_stall_consecutive_checks": 3
}

Auto-recovery mode (when needed):

.. code-block:: json

{
"sfm_close_stalled_connection": true,
"sfm_send_stall_timeout": 75,
"sfm_send_stall_consecutive_checks": 3
}

**Timing relationship (important)**:

- ``sfm_send_stall_timeout`` is compared against the total continuous blocked-send duration.
- ``sfm_send_stall_consecutive_checks`` counts consecutive heartbeat monitor ticks (every 5 seconds),
not multiples of ``sfm_send_stall_timeout``.

Approximate auto-close window (when ``sfm_close_stalled_connection=true``):

.. code-block:: text

close_lower_bound ~= sfm_send_stall_timeout
close_upper_bound ~= sfm_send_stall_timeout + (HEARTBEAT_TICK * sfm_send_stall_consecutive_checks)

With ``sfm_send_stall_timeout=75`` and ``sfm_send_stall_consecutive_checks=3``, close typically occurs
around ``75``-``90`` seconds of continuous stall (not 225 seconds).

**Outer-timeout guideline**:

Set higher-layer timeouts (for example ``communication_timeout`` or task/request timeouts that include
message transfer time) greater than ``close_upper_bound`` plus a safety margin.

Example: ``communication_timeout=300`` is safely larger than the ~``90`` second stall auto-close window.

**How to interpret logs**:

- Expected warning on real stalls:
  ``Detected stalled send on ...: blocked <seconds>s (N/M)``, where ``M`` is the configured
  ``sfm_send_stall_consecutive_checks`` value (default ``3``).
- In healthy/normal streaming, no stall warning should be emitted.
- Intermittent stalls should not close the connection unless the threshold is reached in consecutive checks.


Recommended Settings by Scenario
================================

Expand Down
24 changes: 24 additions & 0 deletions nvflare/fuel/f3/comm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ class VarName:
STREAMING_ACK_INTERVAL = "streaming_ack_interval"
STREAMING_MAX_OUT_SEQ_CHUNKS = "streaming_max_out_seq_chunks"
STREAMING_READ_TIMEOUT = "streaming_read_timeout"
STREAMING_SEND_TIMEOUT = "streaming_send_timeout"
STREAMING_ACK_PROGRESS_TIMEOUT = "streaming_ack_progress_timeout"
STREAMING_ACK_PROGRESS_CHECK_INTERVAL = "streaming_ack_progress_check_interval"
SFM_SEND_STALL_TIMEOUT = "sfm_send_stall_timeout"
SFM_CLOSE_STALLED_CONNECTION = "sfm_close_stalled_connection"
SFM_SEND_STALL_CONSECUTIVE_CHECKS = "sfm_send_stall_consecutive_checks"


class CommConfigurator:
Expand Down Expand Up @@ -114,6 +120,24 @@ def get_streaming_max_out_seq_chunks(self, default):
def get_streaming_read_timeout(self, default):
return ConfigService.get_int_var(VarName.STREAMING_READ_TIMEOUT, self.config, default)

def get_streaming_send_timeout(self, default):
    """Return the ``streaming_send_timeout`` config value as a float, or *default* if unset."""
    configured = ConfigService.get_float_var(VarName.STREAMING_SEND_TIMEOUT, self.config, default=default)
    return configured

def get_streaming_ack_progress_timeout(self, default):
    """Return the ``streaming_ack_progress_timeout`` config value as a float, or *default* if unset."""
    configured = ConfigService.get_float_var(VarName.STREAMING_ACK_PROGRESS_TIMEOUT, self.config, default=default)
    return configured

def get_streaming_ack_progress_check_interval(self, default):
    """Return the ``streaming_ack_progress_check_interval`` config value as a float, or *default* if unset."""
    configured = ConfigService.get_float_var(VarName.STREAMING_ACK_PROGRESS_CHECK_INTERVAL, self.config, default=default)
    return configured

def get_sfm_send_stall_timeout(self, default):
    """Return the ``sfm_send_stall_timeout`` config value as a float, or *default* if unset."""
    configured = ConfigService.get_float_var(VarName.SFM_SEND_STALL_TIMEOUT, self.config, default=default)
    return configured

def get_sfm_close_stalled_connection(self, default=False):
    """Return the ``sfm_close_stalled_connection`` config value as a bool (defaults to warn-only False)."""
    configured = ConfigService.get_bool_var(VarName.SFM_CLOSE_STALLED_CONNECTION, self.config, default=default)
    return configured

def get_sfm_send_stall_consecutive_checks(self, default=3):
    """Return the ``sfm_send_stall_consecutive_checks`` config value as an int, or *default* if unset."""
    configured = ConfigService.get_int_var(VarName.SFM_SEND_STALL_CONSECUTIVE_CHECKS, self.config, default=default)
    return configured

def get_int_var(self, name: str, default=None):
    """Return the named config variable as an int, or *default* if unset."""
    configured = ConfigService.get_int_var(name, self.config, default=default)
    return configured

Expand Down
31 changes: 30 additions & 1 deletion nvflare/fuel/f3/drivers/socket_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import select
import socket
import time
from socketserver import BaseRequestHandler
from typing import Any, Union

from nvflare.fuel.f3.comm_config import CommConfigurator
from nvflare.fuel.f3.comm_error import CommError
from nvflare.fuel.f3.connection import BytesAlike, Connection
from nvflare.fuel.f3.drivers.driver import ConnectorInfo
Expand All @@ -35,6 +38,7 @@ def __init__(self, sock: Any, connector: ConnectorInfo, secure: bool = False):
self.secure = secure
self.closing = False
self.conn_props = self._get_socket_properties()
self.send_timeout = CommConfigurator().get_streaming_send_timeout(30.0)

def get_conn_properties(self) -> dict:
return self.conn_props
Expand All @@ -52,11 +56,36 @@ def close(self):

def send_frame(self, frame: BytesAlike):
    """Send one frame over the socket, converting low-level failures to CommError.

    All errors are suppressed while the connection is already closing.
    """
    try:
        self._send_with_timeout(frame, self.send_timeout)
    except CommError as error:
        if self.closing:
            return
        # A send timeout may occur after partial bytes are already written to the stream.
        # Close the connection to avoid frame-boundary desync on subsequent sends.
        if error.code == CommError.TIMEOUT:
            self.close()
        raise
    except Exception as ex:
        if not self.closing:
            raise CommError(CommError.ERROR, f"Error sending frame on conn {self}: {secure_format_exception(ex)}")

def _send_with_timeout(self, frame: BytesAlike, timeout_sec: float):
    """Write the whole frame to the socket, enforcing an overall deadline.

    Waits for socket writability with select() before each partial send so a
    stalled peer cannot block the sender forever. Raises CommError.TIMEOUT
    when the deadline passes (possibly after a partial write) and
    CommError.CLOSED when the peer has gone away.

    NOTE(review): for TLS sockets, select() writability does not strictly
    guarantee a non-blocking send of a full record — confirm acceptable for
    secure connections.
    """
    data = memoryview(frame)
    total = len(data)
    pos = 0
    must_finish_by = time.monotonic() + timeout_sec
    while pos < total:
        time_left = must_finish_by - time.monotonic()
        if time_left <= 0:
            raise CommError(CommError.TIMEOUT, f"send_frame timeout after {timeout_sec} seconds on {self.name}")

        _, ready, _ = select.select([], [self.sock], [], time_left)
        if not ready:
            raise CommError(CommError.TIMEOUT, f"send_frame timeout after {timeout_sec} seconds on {self.name}")

        n = self.sock.send(data[pos:])
        if n <= 0:
            raise CommError(CommError.CLOSED, f"Connection {self.name} is closed while sending")
        pos += n

def read_loop(self):
try:
self.read_frame_loop()
Expand Down
31 changes: 30 additions & 1 deletion nvflare/fuel/f3/sfm/heartbeat_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

HEARTBEAT_TICK = 5
DEFAULT_HEARTBEAT_INTERVAL = 60
DEFAULT_SEND_STALL_CONSECUTIVE_CHECKS = 3


class HeartbeatMonitor(Thread):
Expand All @@ -33,7 +34,14 @@ def __init__(self, conns: Dict[str, SfmConnection]):
self.conns = conns
self.stopped = Event()
self.curr_time = 0
self.interval = CommConfigurator().get_heartbeat_interval(DEFAULT_HEARTBEAT_INTERVAL)
config = CommConfigurator()
self.interval = config.get_heartbeat_interval(DEFAULT_HEARTBEAT_INTERVAL)
self.send_stall_timeout = config.get_sfm_send_stall_timeout(45.0)
self.close_stalled_connection = config.get_sfm_close_stalled_connection(False)
self.stall_consecutive_checks = max(
1, config.get_sfm_send_stall_consecutive_checks(DEFAULT_SEND_STALL_CONSECUTIVE_CHECKS)
)
self.stall_counts = {}
if self.interval < HEARTBEAT_TICK:
log.warning(f"Heartbeat interval is too small ({self.interval} < {HEARTBEAT_TICK})")

Expand All @@ -55,7 +63,24 @@ def run(self):

def _check_heartbeat(self):

active_keys = set()
for sfm_conn in self.conns.values():
conn_key = sfm_conn.get_name() if hasattr(sfm_conn, "get_name") else str(id(sfm_conn))
active_keys.add(conn_key)

stall_sec = sfm_conn.get_send_stall_seconds()
if stall_sec > self.send_stall_timeout:
count = self.stall_counts.get(conn_key, 0) + 1
self.stall_counts[conn_key] = count
log.warning(
f"Detected stalled send on {sfm_conn.conn}: blocked {stall_sec:.1f}s "
f"({count}/{self.stall_consecutive_checks})"
)
if self.close_stalled_connection and count >= self.stall_consecutive_checks:
sfm_conn.conn.close()
continue

self.stall_counts[conn_key] = 0

driver = sfm_conn.conn.connector.driver
caps = driver.capabilities()
Expand All @@ -65,3 +90,7 @@ def _check_heartbeat(self):
if self.curr_time - sfm_conn.last_activity > self.interval:
sfm_conn.send_heartbeat(Types.PING)
log.debug(f"Heartbeat sent to connection: {sfm_conn.conn}")

stale_keys = [k for k in self.stall_counts.keys() if k not in active_keys]
for k in stale_keys:
self.stall_counts.pop(k, None)
16 changes: 15 additions & 1 deletion nvflare/fuel/f3/sfm/sfm_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def __init__(self, conn: Connection, local_endpoint: Endpoint):
self.last_activity = 0
self.sequence = 0
self.lock = threading.Lock()
self.send_state_lock = threading.Lock()
self.send_started_at = 0.0

def get_name(self) -> str:
return self.conn.name
Expand Down Expand Up @@ -145,7 +147,19 @@ def send_frame(self, prefix: Prefix, headers: Optional[dict], payload: Optional[
log.debug(f"Sending frame: {prefix} on {self.conn}")
# Only one thread can send data on a connection. Otherwise, the frames may interleave.
with self.lock:
self.conn.send_frame(buffer)
with self.send_state_lock:
self.send_started_at = time.monotonic()
try:
self.conn.send_frame(buffer)
finally:
with self.send_state_lock:
self.send_started_at = 0.0

def get_send_stall_seconds(self) -> float:
    """Return how long (seconds) the in-flight send has been blocked, or 0.0 if no send is active."""
    with self.send_state_lock:
        started = self.send_started_at
    return (time.monotonic() - started) if started > 0.0 else 0.0

@staticmethod
def headers_to_bytes(headers: Optional[dict]) -> Optional[bytes]:
Expand Down
31 changes: 24 additions & 7 deletions nvflare/fuel/f3/streaming/byte_streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import logging
import threading
import time
from typing import Callable, Optional

from nvflare.fuel.f3.cellnet.core_cell import CoreCell
Expand Down Expand Up @@ -87,8 +88,13 @@ def __init__(
self.stream_future = StreamFuture(self.sid, task_handle=self)
self.stream_future.set_size(stream.get_size())

self.window_size = CommConfigurator().get_streaming_window_size(STREAM_WINDOW_SIZE)
self.ack_wait = CommConfigurator().get_streaming_ack_wait(STREAM_ACK_WAIT)
config = CommConfigurator()
self.window_size = config.get_streaming_window_size(STREAM_WINDOW_SIZE)
self.ack_wait = config.get_streaming_ack_wait(STREAM_ACK_WAIT)
self.ack_progress_timeout = config.get_streaming_ack_progress_timeout(60.0)
# Guard against zero/negative config to avoid wait(0) busy-spin loops.
self.ack_progress_check_interval = max(0.01, config.get_streaming_ack_progress_check_interval(5.0))
self.last_ack_progress_ts = time.monotonic()

def __str__(self):
return f"Tx[SID:{self.sid} to {self.target} for {self.channel}/{self.topic}]"
Expand All @@ -109,13 +115,23 @@ def send_loop(self):
# It may take several ACKs to clear up the window
while window > self.window_size:
log.debug(f"{self} window size {window} exceeds limit: {self.window_size}")
self.ack_waiter.clear()
wait_start = time.monotonic()

if not self.ack_waiter.wait(timeout=self.ack_wait):
self.stop(StreamError(f"{self} ACK timeouts after {self.ack_wait} seconds"))
return
while window > self.window_size:
now = time.monotonic()
if now - self.last_ack_progress_ts >= self.ack_progress_timeout:
self.stop(StreamError(f"{self} ACK made no progress for {self.ack_progress_timeout} seconds"))
return

window = self.offset - self.offset_ack
elapsed = now - wait_start
if elapsed >= self.ack_wait:
self.stop(StreamError(f"{self} ACK timeouts after {self.ack_wait} seconds"))
return

self.ack_waiter.clear()
wait_timeout = min(self.ack_progress_check_interval, self.ack_wait - elapsed)
self.ack_waiter.wait(timeout=wait_timeout)
window = self.offset - self.offset_ack

size = len(buf)
if size > self.chunk_size:
Expand Down Expand Up @@ -231,6 +247,7 @@ def handle_ack(self, message: Message):

if offset > self.offset_ack:
self.offset_ack = offset
self.last_ack_progress_ts = time.monotonic()

if not self.ack_waiter.is_set():
self.ack_waiter.set()
Expand Down
Loading