Skip to content

Commit a65d8e6

Browse files
authored
Cherrypick from 2.7 #4004 Fix RxTask self-deadlock on stream error cleanup (#4213)
## Summary - avoid calling `RxTask.stop()` while holding `RxTask.map_lock` in `find_or_create_task` - stop the existing task only after leaving the map lock to prevent self-deadlock - include regression coverage for both the pre-fix deadlock path and the fixed path Related: #4204 (2.7 branch)
1 parent 0b7b8c7 commit a65d8e6

File tree

2 files changed

+147
-2
lines changed

2 files changed

+147
-2
lines changed

nvflare/fuel/f3/streaming/byte_receiver.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def find_or_create_task(cls, message: Message, cell: CoreCell) -> Optional["RxTa
9292
sid = message.get_header(StreamHeaderKey.STREAM_ID)
9393
origin = message.get_header(MessageHeaderKey.ORIGIN)
9494
error = message.get_header(StreamHeaderKey.ERROR_MSG, None)
95+
task_to_stop = None
9596

9697
with cls.map_lock:
9798
task = cls.rx_task_map.get(sid, None)
@@ -104,8 +105,11 @@ def find_or_create_task(cls, message: Message, cell: CoreCell) -> Optional["RxTa
104105
cls.rx_task_map[sid] = task
105106
else:
106107
if error:
107-
task.stop(StreamError(f"{task} Received error from {origin}: {error}"), notify=False)
108-
return None
108+
task_to_stop = task
109+
110+
if task_to_stop:
111+
task_to_stop.stop(StreamError(f"{task_to_stop} Received error from {origin}: {error}"), notify=False)
112+
return None
109113

110114
return task
111115

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import threading
16+
from types import SimpleNamespace
17+
18+
import pytest
19+
20+
from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey
21+
from nvflare.fuel.f3.message import Message
22+
from nvflare.fuel.f3.streaming.byte_receiver import RxTask
23+
from nvflare.fuel.f3.streaming.stream_const import StreamHeaderKey
24+
from nvflare.fuel.f3.streaming.stream_types import StreamError
25+
26+
27+
@pytest.fixture(autouse=True)
28+
def clean_rx_task_map():
29+
with RxTask.map_lock:
30+
RxTask.rx_task_map.clear()
31+
yield
32+
with RxTask.map_lock:
33+
RxTask.rx_task_map.clear()
34+
35+
36+
def _make_message(origin: str, sid: int, error: str = None) -> Message:
37+
message = Message()
38+
headers = {
39+
StreamHeaderKey.STREAM_ID: sid,
40+
MessageHeaderKey.ORIGIN: origin,
41+
}
42+
if error is not None:
43+
headers[StreamHeaderKey.ERROR_MSG] = error
44+
message.add_headers(headers)
45+
return message
46+
47+
48+
class _DeadlockDetectingLock:
49+
"""Lock that raises on same-thread re-acquire to model Lock self-deadlock."""
50+
51+
def __init__(self):
52+
self._lock = threading.Lock()
53+
self._owner = None
54+
55+
def __enter__(self):
56+
acquired = self.acquire()
57+
if not acquired:
58+
raise RuntimeError("failed to acquire lock")
59+
return self
60+
61+
def __exit__(self, exc_type, exc_val, exc_tb):
62+
self.release()
63+
return False
64+
65+
def acquire(self, blocking=True, timeout=-1):
66+
tid = threading.get_ident()
67+
if self._owner == tid:
68+
raise RuntimeError("self-deadlock: same thread re-acquiring map_lock")
69+
acquired = self._lock.acquire(blocking, timeout)
70+
if acquired:
71+
self._owner = tid
72+
return acquired
73+
74+
def release(self):
75+
self._owner = None
76+
self._lock.release()
77+
78+
def locked(self):
79+
return self._lock.locked()
80+
81+
82+
def _pre_fix_find_or_create_task(message: Message, cell):
83+
"""Original buggy logic: calls stop() while map_lock is still held."""
84+
85+
sid = message.get_header(StreamHeaderKey.STREAM_ID)
86+
origin = message.get_header(MessageHeaderKey.ORIGIN)
87+
error = message.get_header(StreamHeaderKey.ERROR_MSG, None)
88+
89+
with RxTask.map_lock:
90+
task = RxTask.rx_task_map.get(sid, None)
91+
if not task:
92+
if error:
93+
return None
94+
task = RxTask(sid, origin, cell)
95+
RxTask.rx_task_map[sid] = task
96+
else:
97+
if error:
98+
task.stop(StreamError(f"{task} Received error from {origin}: {error}"), notify=False)
99+
return None
100+
return task
101+
102+
103+
def test_pre_fix_find_or_create_task_would_deadlock(monkeypatch):
104+
monkeypatch.setattr(RxTask, "map_lock", _DeadlockDetectingLock())
105+
106+
origin = "site-1"
107+
sid = 99
108+
fake_cell = SimpleNamespace()
109+
110+
create_message = _make_message(origin=origin, sid=sid)
111+
task = _pre_fix_find_or_create_task(create_message, fake_cell)
112+
assert task is not None
113+
114+
error_message = _make_message(origin=origin, sid=sid, error="stream failed")
115+
with pytest.raises(RuntimeError, match="self-deadlock"):
116+
_pre_fix_find_or_create_task(error_message, fake_cell)
117+
118+
119+
def test_find_or_create_task_stops_outside_map_lock(monkeypatch):
120+
origin = "site-1"
121+
sid = 123
122+
fake_cell = SimpleNamespace()
123+
124+
create_message = _make_message(origin=origin, sid=sid)
125+
task = RxTask.find_or_create_task(create_message, fake_cell)
126+
assert task is not None
127+
128+
stop_invocation = {"called": False, "lock_held": None}
129+
130+
def fake_stop(self, error=None, notify=True):
131+
stop_invocation["called"] = True
132+
stop_invocation["lock_held"] = RxTask.map_lock.locked()
133+
134+
monkeypatch.setattr(RxTask, "stop", fake_stop)
135+
136+
error_message = _make_message(origin=origin, sid=sid, error="stream failed")
137+
returned_task = RxTask.find_or_create_task(error_message, fake_cell)
138+
139+
assert returned_task is None
140+
assert stop_invocation["called"] is True
141+
assert stop_invocation["lock_held"] is False

0 commit comments

Comments
 (0)