Skip to content

Commit 6dda22e

Browse files
committed
Send server side RST stream on deadline execeeded
1 parent e0d191d commit 6dda22e

File tree

5 files changed

+116
-6
lines changed

5 files changed

+116
-6
lines changed

internal/transport/client_stream.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ func (s *ClientStream) Read(n int) (mem.BufferSlice, error) {
5959
return b, err
6060
}
6161

62-
// Close closes the stream and popagates err to any readers.
62+
// Close closes the stream and propagates err to any readers.
6363
func (s *ClientStream) Close(err error) {
6464
var (
6565
rst bool

internal/transport/http2_client.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1238,7 +1238,7 @@ func (t *http2Client) handleRSTStream(f *http2.RSTStreamFrame) {
12381238
if statusCode == codes.Canceled {
12391239
if d, ok := s.ctx.Deadline(); ok && !d.After(time.Now()) {
12401240
// Our deadline was already exceeded, and that was likely the cause
1241-
// of this cancellation. Alter the status code accordingly.
1241+
// of this cancellation. Alter the status code accordingly.
12421242
statusCode = codes.DeadlineExceeded
12431243
}
12441244
}

internal/transport/http2_server.go

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -598,6 +598,15 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade
598598
if len(t.activeStreams) == 1 {
599599
t.idle = time.Time{}
600600
}
601+
// Start a timer to close the stream on reaching the deadline.
602+
if timeoutSet {
603+
timer := time.AfterFunc(timeout, func() {
604+
t.closeStream(s, true, http2.ErrCodeCancel, false)
605+
})
606+
s.deadlineTimerCancel = func() {
607+
timer.Stop()
608+
}
609+
}
601610
t.mu.Unlock()
602611
if channelz.IsOn() {
603612
t.channelz.SocketMetrics.StreamsStarted.Add(1)
@@ -1268,6 +1277,9 @@ func (t *http2Server) Close(err error) {
12681277
channelz.RemoveEntry(t.channelz.ID)
12691278
// Cancel all active streams.
12701279
for _, s := range streams {
1280+
if s.deadlineTimerCancel != nil {
1281+
s.deadlineTimerCancel()
1282+
}
12711283
s.cancel()
12721284
}
12731285
}
@@ -1306,6 +1318,10 @@ func (t *http2Server) finishStream(s *ServerStream, rst bool, rstCode http2.ErrC
13061318
return
13071319
}
13081320

1321+
if s.deadlineTimerCancel != nil {
1322+
s.deadlineTimerCancel()
1323+
}
1324+
13091325
hdr.cleanup = &cleanupStream{
13101326
streamID: s.id,
13111327
rst: rst,
@@ -1324,7 +1340,13 @@ func (t *http2Server) closeStream(s *ServerStream, rst bool, rstCode http2.ErrCo
13241340
// called to interrupt the potential blocking on other goroutines.
13251341
s.cancel()
13261342

1327-
s.swapState(streamDone)
1343+
oldState := s.swapState(streamDone)
1344+
if oldState == streamDone {
1345+
return
1346+
}
1347+
if s.deadlineTimerCancel != nil {
1348+
s.deadlineTimerCancel()
1349+
}
13281350
t.deleteStream(s, eosReceived)
13291351

13301352
t.controlBuf.put(&cleanupStream{

internal/transport/server_stream.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,10 @@ import (
3434
type ServerStream struct {
3535
*Stream // Embed for common stream functionality.
3636

37-
st internalServerTransport
38-
ctxDone <-chan struct{} // closed at the end of stream. Cache of ctx.Done() (for performance)
39-
cancel context.CancelFunc // invoked at the end of stream to cancel ctx.
37+
st internalServerTransport
38+
ctxDone <-chan struct{} // closed at the end of stream. Cache of ctx.Done() (for performance)
39+
cancel context.CancelFunc // invoked at the end of stream to cancel ctx.
40+
deadlineTimerCancel func() // Invoked at the end of stream.
4041

4142
// Holds compressor names passed in grpc-accept-encoding metadata from the
4243
// client.

internal/transport/transport_test.go

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2961,3 +2961,90 @@ func (s) TestReadMessageHeaderMultipleBuffers(t *testing.T) {
29612961
t.Errorf("bytesRead = %d, want = %d", bytesRead, headerLen)
29622962
}
29632963
}
2964+
2965+
// Tests a scenario when the client doesn't send an RST frame when the
2966+
// configured deadline is reached. The test verifies that the server sends an
2967+
// RST stream only after the deadline is reached.
2968+
func (s) TestServerSendsRSTAfterDeadlineToMisbehavedClient(t *testing.T) {
2969+
server := setUpServerOnly(t, 0, &ServerConfig{}, suspended)
2970+
defer server.stop()
2971+
// Create a client that can override server stream quota.
2972+
mconn, err := net.Dial("tcp", server.lis.Addr().String())
2973+
if err != nil {
2974+
t.Fatalf("Clent failed to dial:%v", err)
2975+
}
2976+
defer mconn.Close()
2977+
if err := mconn.SetWriteDeadline(time.Now().Add(time.Second * 10)); err != nil {
2978+
t.Fatalf("Failed to set write deadline: %v", err)
2979+
}
2980+
if n, err := mconn.Write(clientPreface); err != nil || n != len(clientPreface) {
2981+
t.Fatalf("mconn.Write(clientPreface) = %d, %v, want %d, <nil>", n, err, len(clientPreface))
2982+
}
2983+
// rstTimeChan chan indicates that reader received a RSTStream from server.
2984+
rstTimeChan := make(chan time.Time, 1)
2985+
var mu sync.Mutex
2986+
framer := http2.NewFramer(mconn, mconn)
2987+
if err := framer.WriteSettings(); err != nil {
2988+
t.Fatalf("Error while writing settings: %v", err)
2989+
}
2990+
go func() { // Launch a reader for this misbehaving client.
2991+
for {
2992+
frame, err := framer.ReadFrame()
2993+
if err != nil {
2994+
return
2995+
}
2996+
switch frame := frame.(type) {
2997+
case *http2.PingFrame:
2998+
// Write ping ack back so that server's BDP estimation works right.
2999+
mu.Lock()
3000+
framer.WritePing(true, frame.Data)
3001+
mu.Unlock()
3002+
case *http2.RSTStreamFrame:
3003+
if frame.Header().StreamID != 1 || http2.ErrCode(frame.ErrCode) != http2.ErrCodeCancel {
3004+
t.Errorf("RST stream received with streamID: %d and code: %v, want streamID: 1 and code: http2.ErrCodeCancel", frame.Header().StreamID, http2.ErrCode(frame.ErrCode))
3005+
}
3006+
rstTimeChan <- time.Now()
3007+
return
3008+
default:
3009+
// Do nothing.
3010+
}
3011+
}
3012+
}()
3013+
// Create a stream.
3014+
var buf bytes.Buffer
3015+
henc := hpack.NewEncoder(&buf)
3016+
if err := henc.WriteField(hpack.HeaderField{Name: ":method", Value: "POST"}); err != nil {
3017+
t.Fatalf("Error while encoding header: %v", err)
3018+
}
3019+
if err := henc.WriteField(hpack.HeaderField{Name: ":path", Value: "foo"}); err != nil {
3020+
t.Fatalf("Error while encoding header: %v", err)
3021+
}
3022+
if err := henc.WriteField(hpack.HeaderField{Name: ":authority", Value: "localhost"}); err != nil {
3023+
t.Fatalf("Error while encoding header: %v", err)
3024+
}
3025+
if err := henc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"}); err != nil {
3026+
t.Fatalf("Error while encoding header: %v", err)
3027+
}
3028+
if err := henc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: "10m"}); err != nil {
3029+
t.Fatalf("Error while encoding header: %v", err)
3030+
}
3031+
mu.Lock()
3032+
startTime := time.Now()
3033+
if err := framer.WriteHeaders(http2.HeadersFrameParam{StreamID: 1, BlockFragment: buf.Bytes(), EndHeaders: true}); err != nil {
3034+
mu.Unlock()
3035+
t.Fatalf("Error while writing headers: %v", err)
3036+
}
3037+
mu.Unlock()
3038+
3039+
// Test server behavior for deadline expiration.
3040+
var rstTime time.Time
3041+
select {
3042+
case <-time.After(5 * time.Second):
3043+
t.Fatalf("Test timed out.")
3044+
case rstTime = <-rstTimeChan:
3045+
}
3046+
3047+
if got, want := rstTime.Sub(startTime), 10*time.Millisecond; got < want {
3048+
t.Fatalf("RST frame received earlier than expected by duration: %v", want-got)
3049+
}
3050+
}

0 commit comments

Comments
 (0)