Skip to content

Commit 1d2f20c

Browse files
authored
[2.7] Avoid self-message deadlock for local swarm result submission (#4186)
## Summary

- avoid synchronous self-message path when trainer submits learn result to itself (aggr == self.me)
- process local submission via _process_learn_result with local peer context, while keeping remote path unchanged
- add unit coverage to verify local self-aggregation submission does not call broadcast_and_wait

## Problem

PR #4141 fixed self-message deadlock in _scatter, but result submission in do_learn_task still used broadcast_and_wait(targets=[aggr]). When aggr == self.me with tensor streaming enabled, this can deadlock in synchronous self-message processing.

## Test Plan

- added focused unit test in tests/unit_test/app_common/ccwf/test_swarm_self_message_deadlock.py
- validated syntax locally for modified files
- full pytest not run in this environment (pytest not available)
1 parent 2a95c96 commit 1d2f20c

File tree

2 files changed

+107
-15
lines changed

2 files changed

+107
-15
lines changed

nvflare/app_common/ccwf/swarm_client_ctl.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -845,23 +845,31 @@ def do_learn_task(self, name: str, task_data: Shareable, fl_ctx: FLContext, abor
845845
time.sleep(self.request_to_submit_result_interval)
846846

847847
# send the result to the aggr
848-
self.log_info(fl_ctx, f"sending training result to aggregation client {aggr}")
848+
if aggr == self.me:
849+
# Avoid synchronous self-message path through CoreCell._send_direct_message.
850+
self.log_info(fl_ctx, "submitting training result locally (aggregation client is self)")
851+
engine = fl_ctx.get_engine()
852+
local_fl_ctx = fl_ctx.clone()
853+
local_fl_ctx.set_peer_context(engine.new_context())
854+
reply = self._process_learn_result(result, local_fl_ctx, abort_signal)
855+
else:
856+
self.log_info(fl_ctx, f"sending training result to aggregation client {aggr}")
849857

850-
task = Task(
851-
name=self.report_learn_result_task_name,
852-
data=result,
853-
timeout=int(self.learn_task_ack_timeout),
854-
secure=self.is_task_secure(fl_ctx),
855-
)
858+
task = Task(
859+
name=self.report_learn_result_task_name,
860+
data=result,
861+
timeout=int(self.learn_task_ack_timeout),
862+
secure=self.is_task_secure(fl_ctx),
863+
)
856864

857-
resp = self.broadcast_and_wait(
858-
task=task,
859-
targets=[aggr],
860-
min_responses=1,
861-
fl_ctx=fl_ctx,
862-
)
865+
resp = self.broadcast_and_wait(
866+
task=task,
867+
targets=[aggr],
868+
min_responses=1,
869+
fl_ctx=fl_ctx,
870+
)
863871

864-
reply = resp.get(aggr)
872+
reply = resp.get(aggr)
865873
if not reply:
866874
self.log_error(fl_ctx, f"failed to receive reply from aggregation client: {aggr}")
867875
self.update_status(action="receive_learn_result_reply", error=ReturnCode.EXECUTION_EXCEPTION)

tests/unit_test/app_common/ccwf/test_swarm_self_message_deadlock.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,17 @@
3434
import threading
3535
import time
3636
import unittest
37-
37+
from types import SimpleNamespace
38+
from unittest import mock
39+
40+
from nvflare.apis.fl_constant import ReturnCode as FLReturnCode
41+
from nvflare.apis.fl_context import FLContextManager
42+
from nvflare.apis.shareable import Shareable, make_reply
43+
from nvflare.apis.signal import Signal
44+
from nvflare.app_common.abstract.learnable import Learnable
45+
from nvflare.app_common.app_constant import AppConstants
46+
from nvflare.app_common.ccwf.common import Constant
47+
from nvflare.app_common.ccwf.swarm_client_ctl import SwarmClientController
3848
from nvflare.fuel.f3.cellnet.core_cell import CoreCell, Message, MessageHeaderKey, TargetMessage
3949
from nvflare.fuel.f3.cellnet.defs import ReturnCode
4050
from nvflare.fuel.utils.network_utils import get_open_ports
@@ -535,5 +545,79 @@ def blocking_handler(message: Message):
535545
self.assertTrue(deadlock_detected.is_set(), "Deadlock should be detected - tensor wait timed out")
536546

537547

548+
class TestSwarmResultSubmissionFix(unittest.TestCase):
549+
def test_local_submit_when_aggregator_is_self(self):
550+
class _DummyGatherer:
551+
def __init__(self, **kwargs):
552+
self.for_round = kwargs.get("for_round", 0)
553+
554+
class _DummyEngine:
555+
def __init__(self):
556+
self.submit_req_calls = 0
557+
558+
def send_aux_request(self, **kwargs):
559+
self.submit_req_calls += 1
560+
return {"site-1": make_reply(FLReturnCode.OK)}
561+
562+
def new_context(self):
563+
return FLContextManager(engine=self, identity_name="site-1", job_id="job").new_context()
564+
565+
engine = _DummyEngine()
566+
fl_ctx = FLContextManager(engine=engine, identity_name="site-1", job_id="job").new_context()
567+
abort_signal = Signal()
568+
569+
task_data = Shareable()
570+
task_data.set_header(AppConstants.CURRENT_ROUND, 1)
571+
task_data.set_header(Constant.AGGREGATOR, "site-1")
572+
573+
learn_result = make_reply(FLReturnCode.OK)
574+
575+
ctl = object.__new__(SwarmClientController)
576+
ctl.me = "site-1"
577+
ctl.is_trainer = True
578+
ctl.gatherer = None
579+
ctl.gatherer_waiter = threading.Event()
580+
ctl.metric_comparator = object()
581+
ctl.trainers = ["site-1"]
582+
ctl.learn_task_timeout = 10
583+
ctl.min_responses_required = 1
584+
ctl.wait_time_after_min_resps_received = 0
585+
ctl.aggregator = object()
586+
ctl.max_concurrent_submissions = 1
587+
ctl.request_to_submit_result_max_wait = 10
588+
ctl.request_to_submit_result_msg_timeout = 1
589+
ctl.request_to_submit_result_interval = 0
590+
ctl.request_to_submit_learn_result_task_name = "request_submit"
591+
ctl.report_learn_result_task_name = "report_result"
592+
ctl.learn_task_ack_timeout = 5
593+
ctl.shareable_generator = SimpleNamespace(shareable_to_learnable=lambda _task_data, _ctx: Learnable())
594+
ctl.get_config_prop = lambda key, default=None: ["site-1"] if key == Constant.CLIENTS else default
595+
ctl.execute_learn_task = lambda _task_data, _ctx, _abort_signal: learn_result
596+
ctl.is_task_secure = lambda _ctx: False
597+
ctl.update_status = lambda **kwargs: None
598+
ctl.fire_event = lambda *_args, **_kwargs: None
599+
ctl.log_info = lambda *_args, **_kwargs: None
600+
ctl.log_debug = lambda *_args, **_kwargs: None
601+
ctl.log_warning = lambda *_args, **_kwargs: None
602+
ctl.log_error = lambda *_args, **_kwargs: None
603+
ctl.broadcast_and_wait = mock.Mock(
604+
side_effect=AssertionError("broadcast_and_wait must not be called for local result submission")
605+
)
606+
ctl._process_learn_result = mock.Mock(return_value=make_reply(FLReturnCode.OK))
607+
608+
with mock.patch("nvflare.app_common.ccwf.swarm_client_ctl.Gatherer", _DummyGatherer):
609+
ctl.do_learn_task("train", task_data, fl_ctx, abort_signal)
610+
611+
ctl.broadcast_and_wait.assert_not_called()
612+
ctl._process_learn_result.assert_called_once()
613+
self.assertEqual(engine.submit_req_calls, 1, "submission permission request should still be sent once")
614+
615+
called_result, called_fl_ctx, called_abort_signal = ctl._process_learn_result.call_args[0]
616+
self.assertIs(called_result, learn_result)
617+
self.assertIs(called_abort_signal, abort_signal)
618+
self.assertIsNot(called_fl_ctx, fl_ctx)
619+
self.assertEqual(called_fl_ctx.get_peer_context().get_identity_name(), "site-1")
620+
621+
538622
if __name__ == "__main__":
539623
unittest.main()

0 commit comments

Comments (0)