99from contextlib import suppress
1010from http import HTTPStatus
1111from http .cookies import SimpleCookie
12- from itertools import cycle , islice
12+ from itertools import chain , cycle , islice
1313from time import monotonic
1414from types import TracebackType
1515from typing import (
5050)
5151from .client_proto import ResponseHandler
5252from .client_reqrep import ClientRequest , Fingerprint , _merge_ssl_params
53- from .helpers import ceil_timeout , is_ip_address , noop , sentinel
54- from .locks import EventResultOrError
53+ from .helpers import (
54+ ceil_timeout ,
55+ is_ip_address ,
56+ noop ,
57+ sentinel ,
58+ set_exception ,
59+ set_result ,
60+ )
5561from .resolver import DefaultResolver
5662
5763try :
@@ -840,7 +846,9 @@ def __init__(
840846
841847 self ._use_dns_cache = use_dns_cache
842848 self ._cached_hosts = _DNSCacheTable (ttl = ttl_dns_cache )
843- self ._throttle_dns_events : Dict [Tuple [str , int ], EventResultOrError ] = {}
849+ self ._throttle_dns_futures : Dict [
850+ Tuple [str , int ], Set ["asyncio.Future[None]" ]
851+ ] = {}
844852 self ._family = family
845853 self ._local_addr_infos = aiohappyeyeballs .addr_to_addr_infos (local_addr )
846854 self ._happy_eyeballs_delay = happy_eyeballs_delay
@@ -849,8 +857,8 @@ def __init__(
849857
850858 def close (self ) -> Awaitable [None ]:
851859 """Close all ongoing DNS calls."""
852- for ev in self ._throttle_dns_events .values ():
853- ev .cancel ()
860+ for fut in chain . from_iterable ( self ._throttle_dns_futures .values () ):
861+ fut .cancel ()
854862
855863 for t in self ._resolve_host_tasks :
856864 t .cancel ()
@@ -918,18 +926,35 @@ async def _resolve_host(
918926 await trace .send_dns_cache_hit (host )
919927 return result
920928
929+ futures : Set ["asyncio.Future[None]" ]
921930 #
922931 # If multiple connectors are resolving the same host, we wait
923932 # for the first one to resolve and then use the result for all of them.
924- # We use a throttle event to ensure that we only resolve the host once
933+ # We use a throttle to ensure that we only resolve the host once
925934 # and then use the result for all the waiters.
926935 #
936+ if key in self ._throttle_dns_futures :
937+ # get futures early, before any await (#4014)
938+ futures = self ._throttle_dns_futures [key ]
939+ future : asyncio .Future [None ] = self ._loop .create_future ()
940+ futures .add (future )
941+ if traces :
942+ for trace in traces :
943+ await trace .send_dns_cache_hit (host )
944+ try :
945+ await future
946+ finally :
947+ futures .discard (future )
948+ return self ._cached_hosts .next_addrs (key )
949+
950+ # update dict early, before any await (#4014)
951+ self ._throttle_dns_futures [key ] = futures = set ()
927952 # In this case we need to create a task to ensure that we can shield
928953 # the task from cancellation as cancelling this lookup should not cancel
929954 # the underlying lookup or else the cancel event will get broadcast to
930955 # all the waiters across all connections.
931956 #
932- coro = self ._resolve_host_with_throttle (key , host , port , traces )
957+ coro = self ._resolve_host_with_throttle (key , host , port , futures , traces )
933958 loop = asyncio .get_running_loop ()
934959 if sys .version_info >= (3 , 12 ):
935960 # Optimization for Python 3.12, try to send immediately
@@ -957,42 +982,39 @@ async def _resolve_host_with_throttle(
957982 key : Tuple [str , int ],
958983 host : str ,
959984 port : int ,
985+ futures : Set ["asyncio.Future[None]" ],
960986 traces : Optional [Sequence ["Trace" ]],
961987 ) -> List [ResolveResult ]:
962- """Resolve host with a dns events throttle."""
963- if key in self ._throttle_dns_events :
964- # get event early, before any await (#4014)
965- event = self ._throttle_dns_events [key ]
988+ """Resolve host and set result for all waiters.
989+
990+ This method must be run in a task and shielded from cancellation
991+ to avoid cancelling the underlying lookup.
992+ """
993+ if traces :
994+ for trace in traces :
995+ await trace .send_dns_cache_miss (host )
996+ try :
966997 if traces :
967998 for trace in traces :
968- await trace .send_dns_cache_hit (host )
969- await event .wait ()
970- else :
971- # update dict early, before any await (#4014)
972- self ._throttle_dns_events [key ] = EventResultOrError (self ._loop )
999+ await trace .send_dns_resolvehost_start (host )
1000+
1001+ addrs = await self ._resolver .resolve (host , port , family = self ._family )
9731002 if traces :
9741003 for trace in traces :
975- await trace .send_dns_cache_miss (host )
976- try :
977-
978- if traces :
979- for trace in traces :
980- await trace .send_dns_resolvehost_start (host )
981-
982- addrs = await self ._resolver .resolve (host , port , family = self ._family )
983- if traces :
984- for trace in traces :
985- await trace .send_dns_resolvehost_end (host )
1004+ await trace .send_dns_resolvehost_end (host )
9861005
987- self ._cached_hosts .add (key , addrs )
988- self ._throttle_dns_events [key ].set ()
989- except BaseException as e :
990- # any DNS exception, independently of the implementation
991- # is set for the waiters to raise the same exception.
992- self ._throttle_dns_events [key ].set (exc = e )
993- raise
994- finally :
995- self ._throttle_dns_events .pop (key )
1006+ self ._cached_hosts .add (key , addrs )
1007+ for fut in futures :
1008+ set_result (fut , None )
1009+ except BaseException as e :
1010+ # any DNS exception is set for the waiters to raise the same exception.
1011+ # This coro is always run in task that is shielded from cancellation so
1012+ # we should never be propagating cancellation here.
1013+ for fut in futures :
1014+ set_exception (fut , e )
1015+ raise
1016+ finally :
1017+ self ._throttle_dns_futures .pop (key )
9961018
9971019 return self ._cached_hosts .next_addrs (key )
9981020
0 commit comments