Skip to content

Commit cd930fa

Browse files
authored
swarm: return errors on filtered addresses when dialing (#2461)
1 parent 4005fe6 commit cd930fa

File tree

5 files changed

+96
-50
lines changed

5 files changed

+96
-50
lines changed

p2p/net/swarm/black_hole_detector.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ type blackHoleDetector struct {
178178
}
179179

180180
// FilterAddrs filters the peer's addresses removing black holed addresses
181-
func (d *blackHoleDetector) FilterAddrs(addrs []ma.Multiaddr) []ma.Multiaddr {
181+
func (d *blackHoleDetector) FilterAddrs(addrs []ma.Multiaddr) (valid []ma.Multiaddr, blackHoled []ma.Multiaddr) {
182182
hasUDP, hasIPv6 := false, false
183183
for _, a := range addrs {
184184
if !manet.IsPublicAddr(a) {
@@ -202,6 +202,7 @@ func (d *blackHoleDetector) FilterAddrs(addrs []ma.Multiaddr) []ma.Multiaddr {
202202
ipv6Res = d.ipv6.HandleRequest()
203203
}
204204

205+
blackHoled = make([]ma.Multiaddr, 0, len(addrs))
205206
return ma.FilterAddrs(
206207
addrs,
207208
func(a ma.Multiaddr) bool {
@@ -218,14 +219,16 @@ func (d *blackHoleDetector) FilterAddrs(addrs []ma.Multiaddr) []ma.Multiaddr {
218219
}
219220

220221
if udpRes == blackHoleResultBlocked && isProtocolAddr(a, ma.P_UDP) {
222+
blackHoled = append(blackHoled, a)
221223
return false
222224
}
223225
if ipv6Res == blackHoleResultBlocked && isProtocolAddr(a, ma.P_IP6) {
226+
blackHoled = append(blackHoled, a)
224227
return false
225228
}
226229
return true
227230
},
228-
)
231+
), blackHoled
229232
}
230233

231234
// RecordResult updates the state of the relevant `blackHoleFilter`s for addr

p2p/net/swarm/black_hole_detector_test.go

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ func TestBlackHoleDetectorInApplicableAddress(t *testing.T) {
8585
ma.StringCast("/ip4/192.168.1.5/udp/1234/quic-v1"),
8686
}
8787
for i := 0; i < 1000; i++ {
88-
filteredAddrs := bhd.FilterAddrs(addrs)
88+
filteredAddrs, _ := bhd.FilterAddrs(addrs)
8989
require.ElementsMatch(t, addrs, filteredAddrs)
9090
for j := 0; j < len(addrs); j++ {
9191
bhd.RecordResult(addrs[j], false)
@@ -101,20 +101,29 @@ func TestBlackHoleDetectorUDPDisabled(t *testing.T) {
101101
for i := 0; i < 100; i++ {
102102
bhd.RecordResult(publicAddr, false)
103103
}
104-
addrs := []ma.Multiaddr{publicAddr, privAddr}
105-
require.ElementsMatch(t, addrs, bhd.FilterAddrs(addrs))
104+
wantAddrs := []ma.Multiaddr{publicAddr, privAddr}
105+
wantRemovedAddrs := make([]ma.Multiaddr, 0)
106+
107+
gotAddrs, gotRemovedAddrs := bhd.FilterAddrs(wantAddrs)
108+
require.ElementsMatch(t, wantAddrs, gotAddrs)
109+
require.ElementsMatch(t, wantRemovedAddrs, gotRemovedAddrs)
106110
}
107111

108112
func TestBlackHoleDetectorIPv6Disabled(t *testing.T) {
109113
udpConfig := blackHoleConfig{Enabled: true, N: 10, MinSuccesses: 5}
110114
bhd := newBlackHoleDetector(udpConfig, blackHoleConfig{Enabled: false}, nil)
111115
publicAddr := ma.StringCast("/ip6/1::1/tcp/1234")
112116
privAddr := ma.StringCast("/ip6/::1/tcp/1234")
113-
addrs := []ma.Multiaddr{publicAddr, privAddr}
114117
for i := 0; i < 100; i++ {
115118
bhd.RecordResult(publicAddr, false)
116119
}
117-
require.ElementsMatch(t, addrs, bhd.FilterAddrs(addrs))
120+
121+
wantAddrs := []ma.Multiaddr{publicAddr, privAddr}
122+
wantRemovedAddrs := make([]ma.Multiaddr, 0)
123+
124+
gotAddrs, gotRemovedAddrs := bhd.FilterAddrs(wantAddrs)
125+
require.ElementsMatch(t, wantAddrs, gotAddrs)
126+
require.ElementsMatch(t, wantRemovedAddrs, gotRemovedAddrs)
118127
}
119128

120129
func TestBlackHoleDetectorProbes(t *testing.T) {
@@ -128,7 +137,7 @@ func TestBlackHoleDetectorProbes(t *testing.T) {
128137
bhd.RecordResult(udp6Addr, false)
129138
}
130139
for i := 1; i < 100; i++ {
131-
filteredAddrs := bhd.FilterAddrs(addrs)
140+
filteredAddrs, _ := bhd.FilterAddrs(addrs)
132141
if i%2 == 0 || i%3 == 0 {
133142
if len(filteredAddrs) == 0 {
134143
t.Fatalf("expected probe to be allowed irrespective of the state of other black hole filter")
@@ -145,7 +154,7 @@ func TestBlackHoleDetectorProbes(t *testing.T) {
145154
func TestBlackHoleDetectorAddrFiltering(t *testing.T) {
146155
udp6Pub := ma.StringCast("/ip6/1::1/udp/1234/quic-v1")
147156
udp6Pri := ma.StringCast("/ip6/::1/udp/1234/quic-v1")
148-
upd4Pub := ma.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1")
157+
udp4Pub := ma.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1")
149158
udp4Pri := ma.StringCast("/ip4/192.168.1.5/udp/1234/quic-v1")
150159
tcp6Pub := ma.StringCast("/ip6/1::1/tcp/1234/quic-v1")
151160
tcp6Pri := ma.StringCast("/ip6/::1/tcp/1234/quic-v1")
@@ -158,26 +167,35 @@ func TestBlackHoleDetectorAddrFiltering(t *testing.T) {
158167
ipv6: &blackHoleFilter{n: 100, minSuccesses: 10, name: "ipv6"},
159168
}
160169
for i := 0; i < 100; i++ {
161-
bhd.RecordResult(upd4Pub, !udpBlocked)
170+
bhd.RecordResult(udp4Pub, !udpBlocked)
162171
}
163172
for i := 0; i < 100; i++ {
164173
bhd.RecordResult(tcp6Pub, !ipv6Blocked)
165174
}
166175
return bhd
167176
}
168177

169-
allInput := []ma.Multiaddr{udp6Pub, udp6Pri, upd4Pub, udp4Pri, tcp6Pub, tcp6Pri,
178+
allInput := []ma.Multiaddr{udp6Pub, udp6Pri, udp4Pub, udp4Pri, tcp6Pub, tcp6Pri,
170179
tcp4Pub, tcp4Pri}
171180

172181
udpBlockedOutput := []ma.Multiaddr{udp6Pri, udp4Pri, tcp6Pub, tcp6Pri, tcp4Pub, tcp4Pri}
182+
udpPublicAddrs := []ma.Multiaddr{udp6Pub, udp4Pub}
173183
bhd := makeBHD(true, false)
174-
require.ElementsMatch(t, udpBlockedOutput, bhd.FilterAddrs(allInput))
184+
gotAddrs, gotRemovedAddrs := bhd.FilterAddrs(allInput)
185+
require.ElementsMatch(t, udpBlockedOutput, gotAddrs)
186+
require.ElementsMatch(t, udpPublicAddrs, gotRemovedAddrs)
175187

176-
ip6BlockedOutput := []ma.Multiaddr{udp6Pri, upd4Pub, udp4Pri, tcp6Pri, tcp4Pub, tcp4Pri}
188+
ip6BlockedOutput := []ma.Multiaddr{udp6Pri, udp4Pub, udp4Pri, tcp6Pri, tcp4Pub, tcp4Pri}
189+
ip6PublicAddrs := []ma.Multiaddr{udp6Pub, tcp6Pub}
177190
bhd = makeBHD(false, true)
178-
require.ElementsMatch(t, ip6BlockedOutput, bhd.FilterAddrs(allInput))
191+
gotAddrs, gotRemovedAddrs = bhd.FilterAddrs(allInput)
192+
require.ElementsMatch(t, ip6BlockedOutput, gotAddrs)
193+
require.ElementsMatch(t, ip6PublicAddrs, gotRemovedAddrs)
179194

180195
bothBlockedOutput := []ma.Multiaddr{udp6Pri, udp4Pri, tcp6Pri, tcp4Pub, tcp4Pri}
196+
bothPublicAddrs := []ma.Multiaddr{udp6Pub, tcp6Pub, udp4Pub}
181197
bhd = makeBHD(true, true)
182-
require.ElementsMatch(t, bothBlockedOutput, bhd.FilterAddrs(allInput))
198+
gotAddrs, gotRemovedAddrs = bhd.FilterAddrs(allInput)
199+
require.ElementsMatch(t, bothBlockedOutput, gotAddrs)
200+
require.ElementsMatch(t, bothPublicAddrs, gotRemovedAddrs)
183201
}

p2p/net/swarm/dial_worker.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,14 @@ loop:
165165
continue loop
166166
}
167167

168-
addrs, err := w.s.addrsForDial(req.ctx, w.peer)
168+
addrs, addrErrs, err := w.s.addrsForDial(req.ctx, w.peer)
169169
if err != nil {
170-
req.resch <- dialResponse{err: err}
170+
req.resch <- dialResponse{
171+
err: &DialError{
172+
Peer: w.peer,
173+
DialErrors: addrErrs,
174+
Cause: err,
175+
}}
171176
continue loop
172177
}
173178

@@ -179,8 +184,8 @@ loop:
179184
// create the pending request object
180185
pr := &pendRequest{
181186
req: req,
182-
err: &DialError{Peer: w.peer},
183187
addrs: make(map[string]struct{}, len(addrRanking)),
188+
err: &DialError{Peer: w.peer, DialErrors: addrErrs},
184189
}
185190
for _, adelay := range addrRanking {
186191
pr.addrs[string(adelay.Addr.Bytes())] = struct{}{}
@@ -221,6 +226,7 @@ loop:
221226

222227
if len(todial) == 0 && len(tojoin) == 0 {
223228
// all request applicable addrs have been dialed, we must have errored
229+
pr.err.Cause = ErrAllDialsFailed
224230
req.resch <- dialResponse{err: pr.err}
225231
continue loop
226232
}
@@ -371,6 +377,7 @@ func (w *dialWorker) dispatchError(ad *addrDial, err error) {
371377
if c != nil {
372378
pr.req.resch <- dialResponse{conn: c}
373379
} else {
380+
pr.err.Cause = ErrAllDialsFailed
374381
pr.req.resch <- dialResponse{err: pr.err}
375382
}
376383
delete(w.pendingRequests, pr)

p2p/net/swarm/swarm_dial.go

Lines changed: 39 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -280,10 +280,10 @@ func (s *Swarm) dialWorkerLoop(p peer.ID, reqch <-chan dialRequest) {
280280
w.loop()
281281
}
282282

283-
func (s *Swarm) addrsForDial(ctx context.Context, p peer.ID) ([]ma.Multiaddr, error) {
283+
func (s *Swarm) addrsForDial(ctx context.Context, p peer.ID) (goodAddrs []ma.Multiaddr, addrErrs []TransportError, err error) {
284284
peerAddrs := s.peers.Addrs(p)
285285
if len(peerAddrs) == 0 {
286-
return nil, ErrNoAddresses
286+
return nil, nil, ErrNoAddresses
287287
}
288288

289289
peerAddrsAfterTransportResolved := make([]ma.Multiaddr, 0, len(peerAddrs))
@@ -308,22 +308,22 @@ func (s *Swarm) addrsForDial(ctx context.Context, p peer.ID) ([]ma.Multiaddr, er
308308
Addrs: peerAddrsAfterTransportResolved,
309309
})
310310
if err != nil {
311-
return nil, err
311+
return nil, nil, err
312312
}
313313

314-
goodAddrs := s.filterKnownUndialables(p, resolved)
314+
goodAddrs = ma.Unique(resolved)
315+
goodAddrs, addrErrs = s.filterKnownUndialables(p, goodAddrs)
315316
if forceDirect, _ := network.GetForceDirectDial(ctx); forceDirect {
316317
goodAddrs = ma.FilterAddrs(goodAddrs, s.nonProxyAddr)
317318
}
318-
goodAddrs = ma.Unique(goodAddrs)
319319

320320
if len(goodAddrs) == 0 {
321-
return nil, ErrNoGoodAddresses
321+
return nil, addrErrs, ErrNoGoodAddresses
322322
}
323323

324324
s.peers.AddAddrs(p, goodAddrs, peerstore.TempAddrTTL)
325325

326-
return goodAddrs, nil
326+
return goodAddrs, addrErrs, nil
327327
}
328328

329329
func (s *Swarm) resolveAddrs(ctx context.Context, pi peer.AddrInfo) ([]ma.Multiaddr, error) {
@@ -402,11 +402,6 @@ func (s *Swarm) dialNextAddr(ctx context.Context, p peer.ID, addr ma.Multiaddr,
402402
return nil
403403
}
404404

405-
func (s *Swarm) canDial(addr ma.Multiaddr) bool {
406-
t := s.TransportForDialing(addr)
407-
return t != nil && t.CanDial(addr)
408-
}
409-
410405
func (s *Swarm) nonProxyAddr(addr ma.Multiaddr) bool {
411406
t := s.TransportForDialing(addr)
412407
return !t.Proxy()
@@ -418,7 +413,7 @@ func (s *Swarm) nonProxyAddr(addr ma.Multiaddr) bool {
418413
// addresses that we know to be our own, and addresses with a better tranport
419414
// available. This is an optimization to avoid wasting time on dials that we
420415
// know are going to fail or for which we have a better alternative.
421-
func (s *Swarm) filterKnownUndialables(p peer.ID, addrs []ma.Multiaddr) []ma.Multiaddr {
416+
func (s *Swarm) filterKnownUndialables(p peer.ID, addrs []ma.Multiaddr) (goodAddrs []ma.Multiaddr, addrErrs []TransportError) {
422417
lisAddrs, _ := s.InterfaceListenAddresses()
423418
var ourAddrs []ma.Multiaddr
424419
for _, addr := range lisAddrs {
@@ -431,27 +426,49 @@ func (s *Swarm) filterKnownUndialables(p peer.ID, addrs []ma.Multiaddr) []ma.Mul
431426
})
432427
}
433428

434-
// The order of these two filters is important. If we can only dial /webtransport,
435-
// we don't want to filter /webtransport addresses out because the peer had a /quic-v1
436-
// address
429+
addrErrs = make([]TransportError, 0, len(addrs))
437430

438-
// filter addresses we cannot dial
439-
addrs = ma.FilterAddrs(addrs, s.canDial)
431+
// The order of checking for transport and filtering low priority addrs is important. If we
432+
// can only dial /webtransport, we don't want to filter /webtransport addresses out because
433+
// the peer had a /quic-v1 address
434+
435+
// filter addresses with no transport
436+
addrs = ma.FilterAddrs(addrs, func(a ma.Multiaddr) bool {
437+
if s.TransportForDialing(a) == nil {
438+
addrErrs = append(addrErrs, TransportError{Address: a, Cause: ErrNoTransport})
439+
return false
440+
}
441+
return true
442+
})
440443

441444
// filter low priority addresses among the addresses we can dial
445+
// We don't return an error for these addresses
442446
addrs = filterLowPriorityAddresses(addrs)
443447

444448
// remove black holed addrs
445-
addrs = s.bhd.FilterAddrs(addrs)
449+
addrs, blackHoledAddrs := s.bhd.FilterAddrs(addrs)
450+
for _, a := range blackHoledAddrs {
451+
addrErrs = append(addrErrs, TransportError{Address: a, Cause: ErrDialRefusedBlackHole})
452+
}
446453

447454
return ma.FilterAddrs(addrs,
448-
func(addr ma.Multiaddr) bool { return !ma.Contains(ourAddrs, addr) },
455+
func(addr ma.Multiaddr) bool {
456+
if ma.Contains(ourAddrs, addr) {
457+
addrErrs = append(addrErrs, TransportError{Address: addr, Cause: ErrDialToSelf})
458+
return false
459+
}
460+
return true
461+
},
449462
// TODO: Consider allowing link-local addresses
450463
func(addr ma.Multiaddr) bool { return !manet.IsIP6LinkLocal(addr) },
451464
func(addr ma.Multiaddr) bool {
452-
return s.gater == nil || s.gater.InterceptAddrDial(p, addr)
465+
if s.gater != nil && !s.gater.InterceptAddrDial(p, addr) {
466+
addrErrs = append(addrErrs, TransportError{Address: addr, Cause: ErrGaterDisallowedConnection})
467+
return false
468+
}
469+
return true
453470
},
454-
)
471+
), addrErrs
455472
}
456473

457474
// limitedDial will start a dial to the given peer when

p2p/net/swarm/swarm_dial_test.go

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
"context"
66
"crypto/rand"
7+
"errors"
78
"net"
89
"sort"
910
"testing"
@@ -65,7 +66,7 @@ func TestAddrsForDial(t *testing.T) {
6566
ps.AddAddr(otherPeer, ma.StringCast("/dns4/example.com/tcp/1234/wss"), time.Hour)
6667

6768
ctx := context.Background()
68-
mas, err := s.addrsForDial(ctx, otherPeer)
69+
mas, _, err := s.addrsForDial(ctx, otherPeer)
6970
require.NoError(t, err)
7071

7172
require.NotZero(t, len(mas))
@@ -110,7 +111,7 @@ func TestDedupAddrsForDial(t *testing.T) {
110111
ps.AddAddr(otherPeer, ma.StringCast("/ip4/1.2.3.4/tcp/1234"), time.Hour)
111112

112113
ctx := context.Background()
113-
mas, err := s.addrsForDial(ctx, otherPeer)
114+
mas, _, err := s.addrsForDial(ctx, otherPeer)
114115
require.NoError(t, err)
115116

116117
require.Equal(t, 1, len(mas))
@@ -183,7 +184,7 @@ func TestAddrResolution(t *testing.T) {
183184

184185
tctx, cancel := context.WithTimeout(ctx, time.Millisecond*100)
185186
defer cancel()
186-
mas, err := s.addrsForDial(tctx, p1)
187+
mas, _, err := s.addrsForDial(tctx, p1)
187188
require.NoError(t, err)
188189

189190
require.Len(t, mas, 1)
@@ -241,7 +242,7 @@ func TestAddrResolutionRecursive(t *testing.T) {
241242
tctx, cancel := context.WithTimeout(ctx, time.Millisecond*100)
242243
defer cancel()
243244
s.Peerstore().AddAddrs(pi1.ID, pi1.Addrs, peerstore.TempAddrTTL)
244-
_, err = s.addrsForDial(tctx, p1)
245+
_, _, err = s.addrsForDial(tctx, p1)
245246
require.NoError(t, err)
246247

247248
addrs1 := s.Peerstore().Addrs(pi1.ID)
@@ -253,7 +254,7 @@ func TestAddrResolutionRecursive(t *testing.T) {
253254
require.NoError(t, err)
254255

255256
s.Peerstore().AddAddrs(pi2.ID, pi2.Addrs, peerstore.TempAddrTTL)
256-
_, err = s.addrsForDial(tctx, p2)
257+
_, _, err = s.addrsForDial(tctx, p2)
257258
// This never resolves to a good address
258259
require.Equal(t, ErrNoGoodAddresses, err)
259260

@@ -315,7 +316,7 @@ func TestAddrsForDialFiltering(t *testing.T) {
315316
t.Run(tc.name, func(t *testing.T) {
316317
s.Peerstore().ClearAddrs(p1)
317318
s.Peerstore().AddAddrs(p1, tc.input, peerstore.PermanentAddrTTL)
318-
result, err := s.addrsForDial(ctx, p1)
319+
result, _, err := s.addrsForDial(ctx, p1)
319320
require.NoError(t, err)
320321
sort.Slice(result, func(i, j int) bool { return bytes.Compare(result[i].Bytes(), result[j].Bytes()) < 0 })
321322
sort.Slice(tc.output, func(i, j int) bool { return bytes.Compare(tc.output[i].Bytes(), tc.output[j].Bytes()) < 0 })
@@ -366,10 +367,10 @@ func TestBlackHoledAddrBlocked(t *testing.T) {
366367
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
367368
defer cancel()
368369
conn, err := s.DialPeer(ctx, p)
369-
if conn != nil {
370-
t.Fatalf("expected dial to be blocked")
371-
}
372-
if err != ErrNoGoodAddresses {
370+
require.Nil(t, conn)
371+
var de *DialError
372+
if !errors.As(err, &de) {
373373
t.Fatalf("expected to receive an error of type *DialError, got %s of type %T", err, err)
374374
}
375+
require.Contains(t, de.DialErrors, TransportError{Address: addr, Cause: ErrDialRefusedBlackHole})
375376
}

0 commit comments

Comments
 (0)