Skip to content

Commit 1b27fe3

Browse files
committed
feat: enhance BasicHost shutdown
1 parent b9fa02f commit 1b27fe3

File tree

3 files changed

+55
-8
lines changed

3 files changed

+55
-8
lines changed

libp2p/host/basic_host.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,7 @@ def __init__(
256256
# Automatic identify coordination
257257
self._identify_inflight: set[ID] = set()
258258
self._identified_peers: set[ID] = set()
259+
self._identify_scopes: dict[ID, trio.CancelScope] = {}
259260
self._network.register_notifee(_IdentifyNotifee(self))
260261

261262
def get_id(self) -> ID:
@@ -842,6 +843,31 @@ async def disconnect(self, peer_id: ID) -> None:
842843
await self._network.close_peer(peer_id)
843844

844845
async def close(self) -> None:
846+
"""
847+
Close the host and its underlying network service.
848+
"""
849+
# Stop background services
850+
if self.mDNS is not None:
851+
self.mDNS.stop()
852+
853+
if self.bootstrap is not None:
854+
self.bootstrap.stop()
855+
856+
# Cleanup UPnP mappings if active
857+
if self.upnp and self.upnp.get_external_ip():
858+
try:
859+
logger.debug("Removing UPnP port mappings due to host closure")
860+
for addr in self.get_addrs():
861+
if port := addr.value_for_protocol("tcp"):
862+
await self.upnp.remove_port_mapping(int(port), "TCP")
863+
except Exception as e:
864+
logger.warning(f"Error removing UPnP mappings during close: {e}")
865+
866+
# Cancel inflight identify tasks
867+
for scope in list(self._identify_scopes.values()):
868+
scope.cancel()
869+
870+
# Close network
845871
await self._network.close()
846872

847873
def _schedule_identify(self, peer_id: ID, *, reason: str) -> None:
@@ -857,14 +883,28 @@ def _schedule_identify(self, peer_id: ID, *, reason: str) -> None:
857883
return
858884
if not self._should_identify_peer(peer_id):
859885
return
886+
860887
self._identify_inflight.add(peer_id)
861-
trio.lowlevel.spawn_system_task(self._identify_task_entry, peer_id, reason)
888+
889+
# Create a new cancel scope for this identify task
890+
cancel_scope = trio.CancelScope()
891+
self._identify_scopes[peer_id] = cancel_scope
892+
893+
trio.lowlevel.spawn_system_task(
894+
self._identify_task_entry, peer_id, reason, cancel_scope
895+
)
862896

863-
async def _identify_task_entry(self, peer_id: ID, reason: str) -> None:
897+
async def _identify_task_entry(
898+
self, peer_id: ID, reason: str, cancel_scope: trio.CancelScope
899+
) -> None:
864900
try:
865-
await self._identify_peer(peer_id, reason=reason)
901+
with cancel_scope:
902+
await self._identify_peer(peer_id, reason=reason)
866903
finally:
867904
self._identify_inflight.discard(peer_id)
905+
# Remove scope from tracking if it matches (to avoid race conditions)
906+
if self._identify_scopes.get(peer_id) is cancel_scope:
907+
self._identify_scopes.pop(peer_id, None)
868908

869909
def _has_cached_protocols(self, peer_id: ID) -> bool:
870910
"""

tests/core/host/test_basic_host.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,13 @@ async def fake_negotiate(comm, timeout):
5353
monkeypatch.setattr(host.multiselect, "negotiate", fake_negotiate)
5454

5555
# Now run the handler and expect StreamFailure
56-
with pytest.raises(
57-
StreamFailure, match="Failed to negotiate protocol: no protocol selected"
58-
):
59-
await host._swarm_stream_handler(net_stream)
56+
try:
57+
with pytest.raises(
58+
StreamFailure, match="Failed to negotiate protocol: no protocol selected"
59+
):
60+
await host._swarm_stream_handler(net_stream)
61+
finally:
62+
await host.close()
6063

6164
# Ensure reset was called since negotiation failed
6265
net_stream.reset.assert_awaited()

tests/utils/factories.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,11 @@ async def create_batch_and_listen(
530530
number, security_protocol=security_protocol, muxer_opt=muxer_opt
531531
) as swarms:
532532
hosts = tuple(BasicHost(swarm) for swarm in swarms)
533-
yield hosts
533+
try:
534+
yield hosts
535+
finally:
536+
for host in hosts:
537+
await host.close()
534538

535539

536540
class DummyRouter(IPeerRouting):

0 commit comments

Comments
 (0)