Skip to content

Commit f967422

Browse files
authored
transport: make the client send a RST_STREAM when it receives an END_STREAM from the server (grpc#8832)
Fixes grpc#835 This PR fixes the behavior of the client to send a RST_STREAM when it receives an END_STREAM from the server when the client-side of the stream is still open. It also adds tests for both the client and server side behaviors of sending RST_STREAM when they receive an END_STREAM from their peer. RELEASE NOTES: - transport: fix a bug in the client where it was failing to send a RST_STREAM upon receiving an END_STREAM from the server when the stream was still open
1 parent 99f36d4 commit f967422

File tree

2 files changed

+355
-1
lines changed

2 files changed

+355
-1
lines changed

internal/transport/http2_client.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1236,7 +1236,10 @@ func (t *http2Client) handleData(f *parsedDataFrame) {
12361236
// The server has closed the stream without sending trailers. Record that
12371237
// the read direction is closed, and set the status appropriately.
12381238
if f.StreamEnded() {
1239-
t.closeStream(s, io.EOF, false, http2.ErrCodeNo, status.New(codes.Internal, "server closed the stream without sending trailers"), nil, true)
1239+
// If client received END_STREAM from server while stream was still
1240+
// active, send RST_STREAM.
1241+
rstStream := s.getState() == streamActive
1242+
t.closeStream(s, io.EOF, rstStream, http2.ErrCodeNo, status.New(codes.Internal, "server closed the stream without sending trailers"), nil, true)
12401243
}
12411244
}
12421245

internal/transport/transport_test.go

Lines changed: 351 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3356,6 +3356,357 @@ func (s) TestServerSendsRSTAfterDeadlineToMisbehavedClient(t *testing.T) {
33563356
}
33573357
}
33583358

3359+
// Tests the scenario where the client sends a DATA frame without END_STREAM
3360+
// flag. The test verifies that the server responds with a RST stream when it
3361+
// tries to send trailers.
3362+
func (s) TestServerSendsResetStreamOnEarlyTrailer(t *testing.T) {
3363+
// Create a server that expects the client to send a "ping" request and
3364+
// responds with a "pong" response.
3365+
server := setUpServerOnly(t, 0, &ServerConfig{BufferPool: mem.DefaultBufferPool()}, normal)
3366+
defer server.stop()
3367+
3368+
// Connect to the above server with a client that sends a DATA frame without
3369+
// END_STREAM. This simulates a scenario where the client has not
3370+
// half-closed when the server is done sending the response and trailers.
3371+
mconn, err := net.Dial("tcp", server.lis.Addr().String())
3372+
if err != nil {
3373+
t.Fatalf("Clent failed to dial:%v", err)
3374+
}
3375+
defer mconn.Close()
3376+
if n, err := mconn.Write(clientPreface); err != nil || n != len(clientPreface) {
3377+
t.Fatalf("mconn.Write(clientPreface) = %d, %v, want %d, <nil>", n, err, len(clientPreface))
3378+
}
3379+
framer := http2.NewFramer(mconn, mconn)
3380+
if err := framer.WriteSettings(); err != nil {
3381+
t.Fatalf("Error while writing settings: %v", err)
3382+
}
3383+
3384+
seenResetFrame := make(chan struct{})
3385+
go func() { // Launch a reader for this client.
3386+
for {
3387+
frame, err := framer.ReadFrame()
3388+
if err != nil {
3389+
return
3390+
}
3391+
switch frame := frame.(type) {
3392+
case *http2.RSTStreamFrame:
3393+
const wantStreamID = 1
3394+
const wantErrCode = http2.ErrCodeNo
3395+
if frame.Header().StreamID != wantStreamID || http2.ErrCode(frame.ErrCode) != wantErrCode {
3396+
t.Errorf("RST stream received with streamID: %d and code: %v, want streamID: %d and code: %v", frame.Header().StreamID, http2.ErrCode(frame.ErrCode), wantStreamID, wantErrCode)
3397+
}
3398+
close(seenResetFrame)
3399+
return
3400+
default:
3401+
// Do nothing.
3402+
}
3403+
}
3404+
}()
3405+
3406+
// Create a stream, sending headers first, followed by a DATA frame without
3407+
// END_STREAM.
3408+
var buf bytes.Buffer
3409+
henc := hpack.NewEncoder(&buf)
3410+
if err := henc.WriteField(hpack.HeaderField{Name: ":method", Value: "POST"}); err != nil {
3411+
t.Fatalf("Error while encoding header: %v", err)
3412+
}
3413+
if err := henc.WriteField(hpack.HeaderField{Name: ":path", Value: "foo"}); err != nil {
3414+
t.Fatalf("Error while encoding header: %v", err)
3415+
}
3416+
if err := henc.WriteField(hpack.HeaderField{Name: ":authority", Value: "localhost"}); err != nil {
3417+
t.Fatalf("Error while encoding header: %v", err)
3418+
}
3419+
if err := henc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"}); err != nil {
3420+
t.Fatalf("Error while encoding header: %v", err)
3421+
}
3422+
if err := framer.WriteHeaders(http2.HeadersFrameParam{StreamID: 1, BlockFragment: buf.Bytes(), EndHeaders: true}); err != nil {
3423+
t.Fatalf("Error while writing headers: %v", err)
3424+
}
3425+
if err := framer.WriteData(1, false, expectedRequest); err != nil {
3426+
t.Fatalf("Error while writing data: %v", err)
3427+
}
3428+
3429+
select {
3430+
case <-time.After(defaultTestTimeout):
3431+
t.Fatalf("Test timed out when waiting for a RST frame from server")
3432+
case <-seenResetFrame:
3433+
}
3434+
}
3435+
3436+
// setupRSTStreamOnEOSTest sets up a test scenario where a client and a manual
3437+
// server are connected.
3438+
//
3439+
// The server invokes the provided sendServerFrames function to send frames to
3440+
// the client (using the framer and the stream ID provided by the test). Callers
3441+
// should not read from the framer passed to this function, as the server will
3442+
// be reading from it to look for the RST_STREAM frame from the client.
3443+
//
3444+
// Returns the client stream created for the test and a function that will wait
3445+
// for the server to be done processing the test scenario.
3446+
func setupRSTStreamOnEOSTest(ctx context.Context, t *testing.T, sendServerFrames func(*testing.T, *http2.Framer, uint32)) (*ClientStream, func()) {
3447+
// Set up a listener for a manual server.
3448+
lis, err := net.Listen("tcp", "localhost:0")
3449+
if err != nil {
3450+
t.Fatalf("Failed to listen: %v", err)
3451+
}
3452+
t.Cleanup(func() { lis.Close() })
3453+
3454+
// Set up a manual server.
3455+
seenHeadersFrame := make(chan struct{})
3456+
serverDone := make(chan struct{})
3457+
go func() {
3458+
defer close(serverDone)
3459+
conn, err := lis.Accept()
3460+
if err != nil {
3461+
t.Errorf("Server failed to accept connection: %v", err)
3462+
return
3463+
}
3464+
defer conn.Close()
3465+
3466+
// Read client preface.
3467+
if _, err := io.ReadFull(conn, make([]byte, len(clientPreface))); err != nil {
3468+
t.Errorf("Server failed to read client preface: %v", err)
3469+
return
3470+
}
3471+
3472+
// Read client's initial SETTINGS frame.
3473+
framer := http2.NewFramer(conn, conn)
3474+
frame, err := framer.ReadFrame()
3475+
if err != nil {
3476+
t.Errorf("Server failed to read client SETTINGS frame: %v", err)
3477+
return
3478+
}
3479+
if _, ok := frame.(*http2.SettingsFrame); !ok {
3480+
t.Errorf("Server read unexpected frame of type %T, want *http2.SettingsFrame", frame)
3481+
return
3482+
}
3483+
3484+
// Write server SETTINGS and ACK frame.
3485+
if err := framer.WriteSettings(); err != nil {
3486+
t.Errorf("Server failed to write SETTINGS frame: %v", err)
3487+
return
3488+
}
3489+
if err := framer.WriteSettingsAck(); err != nil {
3490+
t.Errorf("Server failed to write SETTINGS ACK frame: %v", err)
3491+
return
3492+
}
3493+
3494+
// Read client headers. Loop until we get a HEADERS frame, skipping
3495+
// any SETTINGS ACK frames.
3496+
var hframe *http2.HeadersFrame
3497+
for {
3498+
frame, err = framer.ReadFrame()
3499+
if err != nil {
3500+
t.Errorf("Server failed to read client headers: %v", err)
3501+
return
3502+
}
3503+
if f, ok := frame.(*http2.HeadersFrame); ok {
3504+
hframe = f
3505+
break
3506+
}
3507+
}
3508+
streamID := hframe.StreamID
3509+
close(seenHeadersFrame)
3510+
3511+
// Launch a reader goroutine to look for RST frame from the client.
3512+
readDone := make(chan struct{})
3513+
go func() {
3514+
defer close(readDone)
3515+
for {
3516+
frame, err := framer.ReadFrame()
3517+
if err != nil {
3518+
t.Errorf("Server reader goroutine failed to read frame: %v", err)
3519+
return
3520+
}
3521+
switch frame := frame.(type) {
3522+
case *http2.RSTStreamFrame:
3523+
const wantErrCode = http2.ErrCodeNo
3524+
if frame.Header().StreamID != streamID || http2.ErrCode(frame.ErrCode) != wantErrCode {
3525+
t.Errorf("RST stream received with streamID: %d and code: %v, want streamID: %d and code: %v", frame.Header().StreamID, http2.ErrCode(frame.ErrCode), streamID, wantErrCode)
3526+
}
3527+
return
3528+
default:
3529+
// Do nothing.
3530+
}
3531+
}
3532+
}()
3533+
3534+
writeDone := make(chan struct{})
3535+
go func() {
3536+
defer close(writeDone)
3537+
sendServerFrames(t, framer, streamID)
3538+
}()
3539+
3540+
select {
3541+
case <-ctx.Done():
3542+
t.Errorf("Test timed out when waiting for a RST_STREAM frame from client")
3543+
case <-readDone:
3544+
}
3545+
select {
3546+
case <-ctx.Done():
3547+
t.Errorf("Test timed out when waiting for server to send frames")
3548+
case <-writeDone:
3549+
}
3550+
}()
3551+
3552+
// Set up a client.
3553+
copts := ConnectOptions{BufferPool: mem.DefaultBufferPool()}
3554+
ct, err := NewHTTP2Client(ctx, ctx, resolver.Address{Addr: lis.Addr().String()}, copts, func(GoAwayReason) {})
3555+
if err != nil {
3556+
t.Fatalf("NewHTTP2Client failed: %v", err)
3557+
}
3558+
t.Cleanup(func() { ct.Close(errors.New("test cleanup: forcing close")) })
3559+
3560+
// Create a stream.
3561+
stream, err := ct.NewStream(ctx, &CallHdr{}, nil)
3562+
if err != nil {
3563+
t.Fatalf("NewStream failed: %v", err)
3564+
}
3565+
3566+
// Wait for server to see client's headers.
3567+
select {
3568+
case <-ctx.Done():
3569+
t.Fatalf("Test timed out when waiting for server to see client's headers")
3570+
case <-seenHeadersFrame:
3571+
}
3572+
3573+
waitForServerDone := func() {
3574+
select {
3575+
case <-ctx.Done():
3576+
t.Fatalf("Test timed out when waiting for server to be done")
3577+
case <-serverDone:
3578+
}
3579+
}
3580+
return stream, waitForServerDone
3581+
}
3582+
3583+
// Tests the scenario where the server sets the END_STREAM flag in the HEADERS
3584+
// frame and verifies that the client responds with a RST stream.
3585+
func (s) TestClientSendsRSTStream_InHeaders(t *testing.T) {
3586+
serverFrames := func(t *testing.T, framer *http2.Framer, streamID uint32) {
3587+
var buf bytes.Buffer
3588+
henc := hpack.NewEncoder(&buf)
3589+
henc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
3590+
if err := framer.WriteHeaders(http2.HeadersFrameParam{
3591+
StreamID: streamID,
3592+
BlockFragment: buf.Bytes(),
3593+
EndHeaders: true,
3594+
EndStream: true,
3595+
}); err != nil {
3596+
t.Errorf("Server failed to write headers: %v", err)
3597+
}
3598+
}
3599+
3600+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
3601+
defer cancel()
3602+
stream, waitForServer := setupRSTStreamOnEOSTest(ctx, t, serverFrames)
3603+
defer waitForServer()
3604+
3605+
if _, err := stream.readTo(make([]byte, 1)); !errors.Is(err, io.EOF) {
3606+
t.Fatalf("stream.readTo() got %v, want %v", err, io.EOF)
3607+
}
3608+
3609+
// Ensure the stream is done before checking status.
3610+
<-stream.Done()
3611+
if code := stream.Status().Code(); code != codes.Unknown {
3612+
t.Fatalf("stream.Status().Code() got %s, want %s", code, codes.Unknown)
3613+
}
3614+
}
3615+
3616+
// Tests the scenario where the server sets the END_STREAM flag in the Trailers
3617+
// (HEADERS frame) and verifies that the client responds with a RST stream.
3618+
func (s) TestClientSendsRSTStream_InTrailers(t *testing.T) {
3619+
serverFrames := func(t *testing.T, framer *http2.Framer, streamID uint32) {
3620+
var buf bytes.Buffer
3621+
henc := hpack.NewEncoder(&buf)
3622+
henc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
3623+
if err := framer.WriteHeaders(http2.HeadersFrameParam{
3624+
StreamID: streamID,
3625+
BlockFragment: buf.Bytes(),
3626+
EndHeaders: true,
3627+
EndStream: false,
3628+
}); err != nil {
3629+
t.Errorf("Server failed to write headers: %v", err)
3630+
}
3631+
if err := framer.WriteData(streamID, false, expectedResponse); err != nil {
3632+
t.Errorf("Server failed to write data: %v", err)
3633+
}
3634+
buf.Reset()
3635+
henc = hpack.NewEncoder(&buf)
3636+
henc.WriteField(hpack.HeaderField{Name: "grpc-status", Value: "0"})
3637+
if err := framer.WriteHeaders(http2.HeadersFrameParam{
3638+
StreamID: streamID,
3639+
BlockFragment: buf.Bytes(),
3640+
EndHeaders: true,
3641+
EndStream: true,
3642+
}); err != nil {
3643+
t.Errorf("Server failed to write trailers: %v", err)
3644+
}
3645+
}
3646+
3647+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
3648+
defer cancel()
3649+
stream, waitForServer := setupRSTStreamOnEOSTest(ctx, t, serverFrames)
3650+
defer waitForServer()
3651+
3652+
// Wait for the stream to be closed.
3653+
<-stream.Done()
3654+
if code := stream.Status().Code(); code != codes.OK {
3655+
t.Fatalf("stream.Status().Code() got %s, want %s", code, codes.OK)
3656+
}
3657+
}
3658+
3659+
// Tests the scenario where the server sets the END_STREAM flag in one of its
3660+
// DATA frames (before sending trailers), causing the client to send a
3661+
// RST_STREAM. The test verifies that the client can still read buffered data
3662+
// from the stream after this event.
3663+
func (s) TestClientSendsRSTStream_ReadUnreadData(t *testing.T) {
3664+
serverFrames := func(t *testing.T, framer *http2.Framer, streamID uint32) {
3665+
var buf bytes.Buffer
3666+
henc := hpack.NewEncoder(&buf)
3667+
henc.WriteField(hpack.HeaderField{Name: "content-type", Value: "application/grpc"})
3668+
if err := framer.WriteHeaders(http2.HeadersFrameParam{
3669+
StreamID: streamID,
3670+
BlockFragment: buf.Bytes(),
3671+
EndHeaders: true,
3672+
EndStream: false,
3673+
}); err != nil {
3674+
t.Errorf("Server failed to write headers: %v", err)
3675+
}
3676+
if err := framer.WriteData(streamID, true, expectedResponse); err != nil {
3677+
t.Errorf("Server failed to write data: %v", err)
3678+
}
3679+
}
3680+
3681+
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
3682+
defer cancel()
3683+
stream, waitForServer := setupRSTStreamOnEOSTest(ctx, t, serverFrames)
3684+
defer waitForServer()
3685+
3686+
// Wait for the stream to match the state we expect (which is that it
3687+
// has sent a RST_STREAM, which means it has closed).
3688+
//
3689+
// If we read before the RST_STREAM is sent, we might race with the
3690+
// client receiving the EOS from the server, and the client might
3691+
// not have sent the RST_STREAM yet.
3692+
<-stream.Done()
3693+
3694+
// Read the data.
3695+
gotData := make([]byte, len(expectedResponse))
3696+
if _, err := stream.readTo(gotData); err != nil {
3697+
t.Fatalf("stream.readTo() got %v, want <nil>", err)
3698+
}
3699+
if !bytes.Equal(gotData, expectedResponse) {
3700+
t.Fatalf("stream.readTo() got %v, want %v", gotData, expectedResponse)
3701+
}
3702+
if _, err := stream.readTo(make([]byte, 1)); !errors.Is(err, io.EOF) {
3703+
t.Fatalf("stream.readTo() got %v, want %v", err, io.EOF)
3704+
}
3705+
if code := stream.Status().Code(); code != codes.Internal {
3706+
t.Fatalf("stream.Status().Code() got %s, want %s", code, codes.Internal)
3707+
}
3708+
}
3709+
33593710
// TestClientTransport_Handle1xxHeaders validates that 1xx HTTP status headers
33603711
// are ignored and treated as a protocol error if END_STREAM is set.
33613712
func (s) TestClientTransport_Handle1xxHeaders(t *testing.T) {

0 commit comments

Comments
 (0)