Skip to content

Commit d9b904d

Browse files
authored
{stream,transport}: fix connection overwriting when a client uses the same port to connect. (trpc-group#131)
The server's transport maintains an `addrToConn` map based only on the client's IP-Port. When a client connects to different ports on the server using the same local IP-port, it causes the connections in the server transport's `addrToConn` map to be overwritten, leading to a "Can't find conn by addr" error in the streaming service. So now, the combination of the client's IP-Port and the server's IP-Port is used as the `addrToConn` map key.
1 parent 369c60d commit d9b904d

File tree

11 files changed

+152
-35
lines changed

11 files changed

+152
-35
lines changed

codec_stream.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@ package trpc
1616
import (
1717
"errors"
1818
"fmt"
19-
"net"
2019
"os"
2120
"path"
2221
"sync"
2322

2423
"trpc.group/trpc-go/trpc-go/codec"
2524
"trpc.group/trpc-go/trpc-go/errs"
25+
"trpc.group/trpc-go/trpc-go/internal/addrutil"
2626
icodec "trpc.group/trpc-go/trpc-go/internal/codec"
2727
trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc"
2828

@@ -46,7 +46,7 @@ var (
4646

4747
// NewServerStreamCodec initializes and returns a ServerStreamCodec.
4848
func NewServerStreamCodec() *ServerStreamCodec {
49-
return &ServerStreamCodec{initMetas: make(map[net.Addr]map[uint32]*trpcpb.TrpcStreamInitMeta), m: &sync.RWMutex{}}
49+
return &ServerStreamCodec{initMetas: make(map[string]map[uint32]*trpcpb.TrpcStreamInitMeta), m: &sync.RWMutex{}}
5050
}
5151

5252
// NewClientStreamCodec initializes and returns a ClientStreamCodec.
@@ -58,7 +58,7 @@ func NewClientStreamCodec() *ClientStreamCodec {
5858
// Used for trpc server streaming codec.
5959
type ServerStreamCodec struct {
6060
m *sync.RWMutex
61-
initMetas map[net.Addr]map[uint32]*trpcpb.TrpcStreamInitMeta // addr->streamID->TrpcStreamInitMeta
61+
initMetas map[string]map[uint32]*trpcpb.TrpcStreamInitMeta // addr->streamID->TrpcStreamInitMeta
6262
}
6363

6464
// ClientStreamCodec is an implementation of codec.Codec.
@@ -372,7 +372,7 @@ func (s *ServerStreamCodec) decodeFeedbackFrame(msg codec.Msg, reqBuf []byte) ([
372372
// setInitMeta finds the InitMeta and sets the ServerRPCName by the server handler in the InitMeta.
373373
func (s *ServerStreamCodec) setInitMeta(msg codec.Msg) error {
374374
streamID := msg.StreamID()
375-
addr := msg.RemoteAddr()
375+
addr := addrutil.AddrToKey(msg.LocalAddr(), msg.RemoteAddr())
376376
s.m.RLock()
377377
defer s.m.RUnlock()
378378
if streamIDToInitMeta, ok := s.initMetas[addr]; ok {
@@ -388,7 +388,7 @@ func (s *ServerStreamCodec) setInitMeta(msg codec.Msg) error {
388388

389389
// deleteInitMeta deletes the cached info by msg.
390390
func (s *ServerStreamCodec) deleteInitMeta(msg codec.Msg) {
391-
addr := msg.RemoteAddr()
391+
addr := addrutil.AddrToKey(msg.LocalAddr(), msg.RemoteAddr())
392392
streamID := msg.StreamID()
393393
s.m.Lock()
394394
defer s.m.Unlock()
@@ -447,7 +447,7 @@ func (s *ServerStreamCodec) decodeInitFrame(msg codec.Msg, reqBuf []byte) ([]byt
447447
// storeInitMeta stores the InitMeta every time when a new frame is received.
448448
func (s *ServerStreamCodec) storeInitMeta(msg codec.Msg, initMeta *trpcpb.TrpcStreamInitMeta) {
449449
streamID := msg.StreamID()
450-
addr := msg.RemoteAddr()
450+
addr := addrutil.AddrToKey(msg.LocalAddr(), msg.RemoteAddr())
451451
s.m.Lock()
452452
defer s.m.Unlock()
453453
if _, ok := s.initMetas[addr]; ok {

codec_stream_test.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ package trpc_test
1515

1616
import (
1717
"context"
18+
"net"
1819
"testing"
1920

2021
"github.com/stretchr/testify/assert"
@@ -52,6 +53,12 @@ func TestStreamCodecInit(t *testing.T) {
5253
msg.WithStreamID(100)
5354
msg.WithCallerServiceName("trpc.app.server.service")
5455
msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello")
56+
laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:10000")
57+
assert.Nil(t, err)
58+
raddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:10001")
59+
assert.Nil(t, err)
60+
msg.WithLocalAddr(laddr)
61+
msg.WithRemoteAddr(raddr)
5562
initBuf, err := clientCodec.Encode(msg, nil)
5663
assert.Nil(t, err)
5764
assert.Equal(t, initResult, initBuf)
@@ -77,6 +84,8 @@ func TestStreamCodecInit(t *testing.T) {
7784
// server Decode
7885
serverCtx := context.Background()
7986
_, serverMsg := codec.WithNewMessage(serverCtx)
87+
serverMsg.WithLocalAddr(laddr)
88+
serverMsg.WithRemoteAddr(raddr)
8089
init, err := serverCodec.Decode(serverMsg, initResult)
8190
assert.Nil(t, err)
8291
assert.Nil(t, init)
@@ -138,6 +147,12 @@ func TestStreamCodecData(t *testing.T) {
138147
// server Decode
139148
serverCtx := context.Background()
140149
_, serverMsg := codec.WithNewMessage(serverCtx)
150+
laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:10000")
151+
assert.Nil(t, err)
152+
raddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:10001")
153+
assert.Nil(t, err)
154+
serverMsg.WithLocalAddr(laddr)
155+
serverMsg.WithRemoteAddr(raddr)
141156
init, err := serverCodec.Decode(serverMsg, initResult)
142157
assert.Nil(t, err)
143158
assert.Nil(t, init)
@@ -162,6 +177,8 @@ func TestStreamCodecData(t *testing.T) {
162177
// Server Decode
163178
serverCtx = context.Background()
164179
_, serverMsg = codec.WithNewMessage(serverCtx)
180+
serverMsg.WithLocalAddr(laddr)
181+
serverMsg.WithRemoteAddr(raddr)
165182
dataDecode, err := serverCodec.Decode(serverMsg, dataResult)
166183
assert.Nil(t, err)
167184
assert.Equal(t, dataDecode, data)
@@ -170,6 +187,8 @@ func TestStreamCodecData(t *testing.T) {
170187
// server Encode
171188
ctx = context.Background()
172189
_, encodeMsg := codec.WithNewMessage(ctx)
190+
encodeMsg.WithLocalAddr(laddr)
191+
encodeMsg.WithRemoteAddr(raddr)
173192
serverFrameHead := &trpc.FrameHead{
174193
FrameType: uint8(trpcpb.TrpcDataFrameType_TRPC_STREAM_FRAME),
175194
StreamFrameType: uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA),
@@ -223,6 +242,12 @@ func TestStreamCodecClose(t *testing.T) {
223242
// server Decode
224243
serverCtx := context.Background()
225244
_, serverMsg := codec.WithNewMessage(serverCtx)
245+
laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:10000")
246+
assert.Nil(t, err)
247+
raddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:10001")
248+
assert.Nil(t, err)
249+
serverMsg.WithLocalAddr(laddr)
250+
serverMsg.WithRemoteAddr(raddr)
226251
init, err := serverCodec.Decode(serverMsg, initResult)
227252
assert.Nil(t, err)
228253
assert.Nil(t, init)
@@ -249,6 +274,8 @@ func TestStreamCodecClose(t *testing.T) {
249274
// server Decode Close
250275
serverCtx = context.Background()
251276
_, serverMsg = codec.WithNewMessage(serverCtx)
277+
serverMsg.WithLocalAddr(laddr)
278+
serverMsg.WithRemoteAddr(raddr)
252279
closeDecode, err := serverCodec.Decode(serverMsg, closeResult)
253280
assert.Nil(t, err)
254281
assert.Nil(t, closeDecode)
@@ -257,6 +284,8 @@ func TestStreamCodecClose(t *testing.T) {
257284
// server encode Close
258285
ctx = context.Background()
259286
_, encodeMsg := codec.WithNewMessage(ctx)
287+
encodeMsg.WithLocalAddr(laddr)
288+
encodeMsg.WithRemoteAddr(raddr)
260289
serverFrameHead := &trpc.FrameHead{
261290
FrameType: uint8(trpcpb.TrpcDataFrameType_TRPC_STREAM_FRAME),
262291
StreamFrameType: uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE),
@@ -277,13 +306,17 @@ func TestStreamCodecClose(t *testing.T) {
277306
// Server decode error after encode close
278307
serverCtx = context.Background()
279308
_, serverMsg = codec.WithNewMessage(serverCtx)
309+
serverMsg.WithLocalAddr(laddr)
310+
serverMsg.WithRemoteAddr(raddr)
280311
closeDecode, err = serverCodec.Decode(serverMsg, closeResult)
281312
assert.NotNil(t, err)
282313
assert.Nil(t, closeDecode)
283314

284315
// Client decode close
285316
clientCtx := context.Background()
286317
_, clientMsg := codec.WithNewMessage(clientCtx)
318+
clientMsg.WithLocalAddr(laddr)
319+
clientMsg.WithRemoteAddr(raddr)
287320
CloseRsp, err := clientCodec.Decode(clientMsg, serverEncodeData)
288321
assert.Nil(t, err)
289322
assert.Nil(t, CloseRsp)
@@ -329,6 +362,12 @@ func TestStreamCodecReset(t *testing.T) {
329362
// Server decode Reset
330363
serverCtx := context.Background()
331364
_, serverMsg := codec.WithNewMessage(serverCtx)
365+
laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:10000")
366+
assert.Nil(t, err)
367+
raddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:10001")
368+
assert.Nil(t, err)
369+
serverMsg.WithLocalAddr(laddr)
370+
serverMsg.WithRemoteAddr(raddr)
332371
resetDecode, err := serverCodec.Decode(serverMsg, resetResult)
333372
assert.Nil(t, err)
334373
assert.Nil(t, resetDecode)
@@ -339,6 +378,8 @@ func TestStreamCodecReset(t *testing.T) {
339378
// server encode Close
340379
ctx = context.Background()
341380
_, encodeMsg := codec.WithNewMessage(ctx)
381+
encodeMsg.WithLocalAddr(laddr)
382+
encodeMsg.WithRemoteAddr(raddr)
342383
serverFrameHead := &trpc.FrameHead{
343384
FrameType: uint8(trpcpb.TrpcDataFrameType_TRPC_STREAM_FRAME),
344385
StreamFrameType: uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE),
@@ -361,6 +402,8 @@ func TestStreamCodecReset(t *testing.T) {
361402
// client Decode reset
362403
clientCtx := context.Background()
363404
_, clientMsg := codec.WithNewMessage(clientCtx)
405+
encodeMsg.WithLocalAddr(laddr)
406+
encodeMsg.WithRemoteAddr(raddr)
364407
resetRsp, err := clientCodec.Decode(clientMsg, serverEncodeData)
365408
assert.Nil(t, err)
366409
assert.Nil(t, resetRsp)
@@ -448,6 +491,12 @@ func TestFeedbackFrameType(t *testing.T) {
448491
// server Decode
449492
serverCtx := context.Background()
450493
_, serverMsg := codec.WithNewMessage(serverCtx)
494+
laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:10000")
495+
assert.Nil(t, err)
496+
raddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:10001")
497+
assert.Nil(t, err)
498+
serverMsg.WithLocalAddr(laddr)
499+
serverMsg.WithRemoteAddr(raddr)
451500
init, err := serverCodec.Decode(serverMsg, initResult)
452501
assert.Nil(t, err)
453502
assert.Nil(t, init)
@@ -485,6 +534,8 @@ func TestFeedbackFrameType(t *testing.T) {
485534
// server Decode feedback frame
486535
serverCtx = context.Background()
487536
_, serverMsg = codec.WithNewMessage(serverCtx)
537+
serverMsg.WithLocalAddr(laddr)
538+
serverMsg.WithRemoteAddr(raddr)
488539
dataDecode, err := serverCodec.Decode(serverMsg, encodeData)
489540
assert.Nil(t, dataDecode)
490541
assert.Nil(t, err)
@@ -529,6 +580,12 @@ func TestDecodeEncodeFail(t *testing.T) {
529580
// server Decode
530581
serverCtx := context.Background()
531582
_, serverMsg := codec.WithNewMessage(serverCtx)
583+
laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:10000")
584+
assert.Nil(t, err)
585+
raddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:10001")
586+
assert.Nil(t, err)
587+
serverMsg.WithLocalAddr(laddr)
588+
serverMsg.WithRemoteAddr(raddr)
532589
init, err := serverCodec.Decode(serverMsg, initResult)
533590
assert.Nil(t, err)
534591
assert.Nil(t, init)
@@ -605,6 +662,12 @@ func TestEncodeWithMetadata(t *testing.T) {
605662

606663
// Server Decode
607664
serverCtx, serverMsg := codec.WithNewMessage(context.Background())
665+
laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:10000")
666+
assert.Nil(t, err)
667+
raddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:10001")
668+
assert.Nil(t, err)
669+
serverMsg.WithLocalAddr(laddr)
670+
serverMsg.WithRemoteAddr(raddr)
608671
initRsp, err := serverCodec.Decode(serverMsg, initResult)
609672
assert.Nil(t, err)
610673
assert.Nil(t, initRsp)
@@ -644,6 +707,12 @@ func TestEncodeWithDyeing(t *testing.T) {
644707

645708
// Server Decode
646709
serverCtx, serverMsg := codec.WithNewMessage(context.Background())
710+
laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:10000")
711+
assert.Nil(t, err)
712+
raddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:10001")
713+
assert.Nil(t, err)
714+
serverMsg.WithLocalAddr(laddr)
715+
serverMsg.WithRemoteAddr(raddr)
647716
initRsp, err := serverCodec.Decode(serverMsg, initResult)
648717
assert.Nil(t, err)
649718
assert.Nil(t, initRsp)
@@ -682,6 +751,12 @@ func TestEncodeWithEnvTransfer(t *testing.T) {
682751

683752
// Server Decode
684753
serverCtx, serverMsg := codec.WithNewMessage(context.Background())
754+
laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:10000")
755+
assert.Nil(t, err)
756+
raddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:10001")
757+
assert.Nil(t, err)
758+
serverMsg.WithLocalAddr(laddr)
759+
serverMsg.WithRemoteAddr(raddr)
685760
initRsp, err := serverCodec.Decode(serverMsg, initResult)
686761
assert.Nil(t, err)
687762
assert.Nil(t, initRsp)

internal/addrutil/addrutil.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// Package addrutil provides some utility functions for net address.
2+
package addrutil
3+
4+
import (
5+
"net"
6+
"strings"
7+
)
8+
9+
// AddrToKey combines local and remote address into a string.
10+
func AddrToKey(local, remote net.Addr) string {
11+
return strings.Join([]string{local.Network(), local.String(), remote.String()}, "_")
12+
}

internal/addrutil/addrutil_test.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package addrutil_test
2+
3+
import (
4+
"net"
5+
"testing"
6+
7+
"github.com/stretchr/testify/require"
8+
"trpc.group/trpc-go/trpc-go/internal/addrutil"
9+
)
10+
11+
func TestAddrToKey(t *testing.T) {
12+
laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:10000")
13+
require.Nil(t, err)
14+
raddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:10001")
15+
require.Nil(t, err)
16+
key := addrutil.AddrToKey(laddr, raddr)
17+
require.Equal(t, key, laddr.Network()+"_"+laddr.String()+"_"+raddr.String())
18+
}

stream/server.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ func (s *serverStream) SendMsg(m interface{}) error {
5151
msg := codec.Message(s.ctx)
5252
ctx, newMsg := codec.WithCloneContextAndMessage(s.ctx)
5353
defer codec.PutBackMessage(newMsg)
54+
newMsg.WithLocalAddr(msg.LocalAddr())
5455
newMsg.WithRemoteAddr(msg.RemoteAddr())
5556
newMsg.WithStreamID(s.streamID)
5657
// Refer to the pb code generated by trpc.proto, common to each language, automatically generated code.
@@ -164,6 +165,7 @@ func (s *serverStream) CloseSend(closeType, ret int32, message string) error {
164165
oldMsg := codec.Message(s.ctx)
165166
ctx, msg := codec.WithCloneContextAndMessage(s.ctx)
166167
defer codec.PutBackMessage(msg)
168+
msg.WithLocalAddr(oldMsg.LocalAddr())
167169
msg.WithRemoteAddr(oldMsg.RemoteAddr())
168170
msg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_CLOSE, s.streamID))
169171
msg.WithStreamFrame(&trpcpb.TrpcStreamCloseMeta{
@@ -195,6 +197,7 @@ func (s *serverStream) feedback(w uint32) error {
195197
oldMsg := codec.Message(s.ctx)
196198
ctx, msg := codec.WithCloneContextAndMessage(s.ctx)
197199
defer codec.PutBackMessage(msg)
200+
msg.WithLocalAddr(oldMsg.LocalAddr())
198201
msg.WithRemoteAddr(oldMsg.RemoteAddr())
199202
msg.WithStreamID(s.streamID)
200203
msg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK, s.streamID))
@@ -364,6 +367,7 @@ func (sd *streamDispatcher) handleInit(ctx context.Context,
364367
// send init response packet.
365368
newCtx, newMsg := codec.WithCloneContextAndMessage(ctx)
366369
defer codec.PutBackMessage(newMsg)
370+
newMsg.WithLocalAddr(msg.LocalAddr())
367371
newMsg.WithRemoteAddr(msg.RemoteAddr())
368372
newMsg.WithStreamID(streamID)
369373
newMsg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT, ss.streamID))

transport/server_transport_stream.go

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ package transport
1616
import (
1717
"context"
1818
"fmt"
19-
"net"
2019

2120
"trpc.group/trpc-go/trpc-go/codec"
2221
"trpc.group/trpc-go/trpc-go/errs"
22+
"trpc.group/trpc-go/trpc-go/internal/addrutil"
2323
)
2424

2525
// serverStreamTransport implements ServerStreamTransport and keeps backward compatibility with the
@@ -36,10 +36,6 @@ func NewServerStreamTransport(opt ...ServerTransportOption) ServerStreamTranspor
3636
return &serverStreamTransport{s}
3737
}
3838

39-
func addrToKey(addr net.Addr) string {
40-
return fmt.Sprintf("%s//%s", addr.Network(), addr.String())
41-
}
42-
4339
// DefaultServerStreamTransport is the default ServerStreamTransport.
4440
var DefaultServerStreamTransport = NewServerStreamTransport()
4541

@@ -52,11 +48,13 @@ func (st *serverStreamTransport) ListenAndServe(ctx context.Context, opts ...Lis
5248
// Send is the method to send stream messages.
5349
func (st *serverStreamTransport) Send(ctx context.Context, req []byte) error {
5450
msg := codec.Message(ctx)
55-
addr := msg.RemoteAddr()
56-
if addr == nil {
57-
return errs.NewFrameError(errs.RetServerSystemErr, "Remote addr is invalid")
51+
raddr := msg.RemoteAddr()
52+
laddr := msg.LocalAddr()
53+
if raddr == nil || laddr == nil {
54+
return errs.NewFrameError(errs.RetServerSystemErr,
55+
fmt.Sprintf("Address is invalid, local: %s, remote: %s", laddr, raddr))
5856
}
59-
key := addrToKey(addr)
57+
key := addrutil.AddrToKey(laddr, raddr)
6058
st.serverTransport.m.RLock()
6159
tc, ok := st.serverTransport.addrToConn[key]
6260
st.serverTransport.m.RUnlock()
@@ -74,8 +72,7 @@ func (st *serverStreamTransport) Send(ctx context.Context, req []byte) error {
7472
// Close closes ServerStreamTransport, it also cleans up cached connections.
7573
func (st *serverStreamTransport) Close(ctx context.Context) {
7674
msg := codec.Message(ctx)
77-
addr := msg.RemoteAddr()
78-
key := addrToKey(addr)
75+
key := addrutil.AddrToKey(msg.LocalAddr(), msg.RemoteAddr())
7976
st.m.Lock()
8077
delete(st.serverTransport.addrToConn, key)
8178
st.m.Unlock()

0 commit comments

Comments
 (0)