@@ -3356,6 +3356,357 @@ func (s) TestServerSendsRSTAfterDeadlineToMisbehavedClient(t *testing.T) {
33563356 }
33573357}
33583358
3359+ // Tests the scenario where the client sends a DATA frame without END_STREAM
3360+ // flag. The test verifies that the server responds with a RST stream when it
3361+ // tries to send trailers.
3362+ func (s ) TestServerSendsResetStreamOnEarlyTrailer (t * testing.T ) {
3363+ // Create a server that expects the client to send a "ping" request and
3364+ // responds with a "pong" response.
3365+ server := setUpServerOnly (t , 0 , & ServerConfig {BufferPool : mem .DefaultBufferPool ()}, normal )
3366+ defer server .stop ()
3367+
3368+ // Connect to the above server with a client that sends a DATA frame without
3369+ // END_STREAM. This simulates a scenario where the client has not
3370+ // half-closed when the server is done sending the response and trailers.
3371+ mconn , err := net .Dial ("tcp" , server .lis .Addr ().String ())
3372+ if err != nil {
3373+ t .Fatalf ("Clent failed to dial:%v" , err )
3374+ }
3375+ defer mconn .Close ()
3376+ if n , err := mconn .Write (clientPreface ); err != nil || n != len (clientPreface ) {
3377+ t .Fatalf ("mconn.Write(clientPreface) = %d, %v, want %d, <nil>" , n , err , len (clientPreface ))
3378+ }
3379+ framer := http2 .NewFramer (mconn , mconn )
3380+ if err := framer .WriteSettings (); err != nil {
3381+ t .Fatalf ("Error while writing settings: %v" , err )
3382+ }
3383+
3384+ seenResetFrame := make (chan struct {})
3385+ go func () { // Launch a reader for this client.
3386+ for {
3387+ frame , err := framer .ReadFrame ()
3388+ if err != nil {
3389+ return
3390+ }
3391+ switch frame := frame .(type ) {
3392+ case * http2.RSTStreamFrame :
3393+ const wantStreamID = 1
3394+ const wantErrCode = http2 .ErrCodeNo
3395+ if frame .Header ().StreamID != wantStreamID || http2 .ErrCode (frame .ErrCode ) != wantErrCode {
3396+ t .Errorf ("RST stream received with streamID: %d and code: %v, want streamID: %d and code: %v" , frame .Header ().StreamID , http2 .ErrCode (frame .ErrCode ), wantStreamID , wantErrCode )
3397+ }
3398+ close (seenResetFrame )
3399+ return
3400+ default :
3401+ // Do nothing.
3402+ }
3403+ }
3404+ }()
3405+
3406+ // Create a stream, sending headers first, followed by a DATA frame without
3407+ // END_STREAM.
3408+ var buf bytes.Buffer
3409+ henc := hpack .NewEncoder (& buf )
3410+ if err := henc .WriteField (hpack.HeaderField {Name : ":method" , Value : "POST" }); err != nil {
3411+ t .Fatalf ("Error while encoding header: %v" , err )
3412+ }
3413+ if err := henc .WriteField (hpack.HeaderField {Name : ":path" , Value : "foo" }); err != nil {
3414+ t .Fatalf ("Error while encoding header: %v" , err )
3415+ }
3416+ if err := henc .WriteField (hpack.HeaderField {Name : ":authority" , Value : "localhost" }); err != nil {
3417+ t .Fatalf ("Error while encoding header: %v" , err )
3418+ }
3419+ if err := henc .WriteField (hpack.HeaderField {Name : "content-type" , Value : "application/grpc" }); err != nil {
3420+ t .Fatalf ("Error while encoding header: %v" , err )
3421+ }
3422+ if err := framer .WriteHeaders (http2.HeadersFrameParam {StreamID : 1 , BlockFragment : buf .Bytes (), EndHeaders : true }); err != nil {
3423+ t .Fatalf ("Error while writing headers: %v" , err )
3424+ }
3425+ if err := framer .WriteData (1 , false , expectedRequest ); err != nil {
3426+ t .Fatalf ("Error while writing data: %v" , err )
3427+ }
3428+
3429+ select {
3430+ case <- time .After (defaultTestTimeout ):
3431+ t .Fatalf ("Test timed out when waiting for a RST frame from server" )
3432+ case <- seenResetFrame :
3433+ }
3434+ }
3435+
3436+ // setupRSTStreamOnEOSTest sets up a test scenario where a client and a manual
3437+ // server are connected.
3438+ //
3439+ // The server invokes the provided sendServerFrames function to send frames to
3440+ // the client (using the framer and the stream ID provided by the test). Callers
3441+ // should not read from the framer passed to this function, as the server will
3442+ // be reading from it to look for the RST_STREAM frame from the client.
3443+ //
3444+ // Returns the client stream created for the test and a function that will wait
3445+ // for the server to be done processing the test scenario.
3446+ func setupRSTStreamOnEOSTest (ctx context.Context , t * testing.T , sendServerFrames func (* testing.T , * http2.Framer , uint32 )) (* ClientStream , func ()) {
3447+ // Set up a listener for a manual server.
3448+ lis , err := net .Listen ("tcp" , "localhost:0" )
3449+ if err != nil {
3450+ t .Fatalf ("Failed to listen: %v" , err )
3451+ }
3452+ t .Cleanup (func () { lis .Close () })
3453+
3454+ // Set up a manual server.
3455+ seenHeadersFrame := make (chan struct {})
3456+ serverDone := make (chan struct {})
3457+ go func () {
3458+ defer close (serverDone )
3459+ conn , err := lis .Accept ()
3460+ if err != nil {
3461+ t .Errorf ("Server failed to accept connection: %v" , err )
3462+ return
3463+ }
3464+ defer conn .Close ()
3465+
3466+ // Read client preface.
3467+ if _ , err := io .ReadFull (conn , make ([]byte , len (clientPreface ))); err != nil {
3468+ t .Errorf ("Server failed to read client preface: %v" , err )
3469+ return
3470+ }
3471+
3472+ // Read client's initial SETTINGS frame.
3473+ framer := http2 .NewFramer (conn , conn )
3474+ frame , err := framer .ReadFrame ()
3475+ if err != nil {
3476+ t .Errorf ("Server failed to read client SETTINGS frame: %v" , err )
3477+ return
3478+ }
3479+ if _ , ok := frame .(* http2.SettingsFrame ); ! ok {
3480+ t .Errorf ("Server read unexpected frame of type %T, want *http2.SettingsFrame" , frame )
3481+ return
3482+ }
3483+
3484+ // Write server SETTINGS and ACK frame.
3485+ if err := framer .WriteSettings (); err != nil {
3486+ t .Errorf ("Server failed to write SETTINGS frame: %v" , err )
3487+ return
3488+ }
3489+ if err := framer .WriteSettingsAck (); err != nil {
3490+ t .Errorf ("Server failed to write SETTINGS ACK frame: %v" , err )
3491+ return
3492+ }
3493+
3494+ // Read client headers. Loop until we get a HEADERS frame, skipping
3495+ // any SETTINGS ACK frames.
3496+ var hframe * http2.HeadersFrame
3497+ for {
3498+ frame , err = framer .ReadFrame ()
3499+ if err != nil {
3500+ t .Errorf ("Server failed to read client headers: %v" , err )
3501+ return
3502+ }
3503+ if f , ok := frame .(* http2.HeadersFrame ); ok {
3504+ hframe = f
3505+ break
3506+ }
3507+ }
3508+ streamID := hframe .StreamID
3509+ close (seenHeadersFrame )
3510+
3511+ // Launch a reader goroutine to look for RST frame from the client.
3512+ readDone := make (chan struct {})
3513+ go func () {
3514+ defer close (readDone )
3515+ for {
3516+ frame , err := framer .ReadFrame ()
3517+ if err != nil {
3518+ t .Errorf ("Server reader goroutine failed to read frame: %v" , err )
3519+ return
3520+ }
3521+ switch frame := frame .(type ) {
3522+ case * http2.RSTStreamFrame :
3523+ const wantErrCode = http2 .ErrCodeNo
3524+ if frame .Header ().StreamID != streamID || http2 .ErrCode (frame .ErrCode ) != wantErrCode {
3525+ t .Errorf ("RST stream received with streamID: %d and code: %v, want streamID: %d and code: %v" , frame .Header ().StreamID , http2 .ErrCode (frame .ErrCode ), streamID , wantErrCode )
3526+ }
3527+ return
3528+ default :
3529+ // Do nothing.
3530+ }
3531+ }
3532+ }()
3533+
3534+ writeDone := make (chan struct {})
3535+ go func () {
3536+ defer close (writeDone )
3537+ sendServerFrames (t , framer , streamID )
3538+ }()
3539+
3540+ select {
3541+ case <- ctx .Done ():
3542+ t .Errorf ("Test timed out when waiting for a RST_STREAM frame from client" )
3543+ case <- readDone :
3544+ }
3545+ select {
3546+ case <- ctx .Done ():
3547+ t .Errorf ("Test timed out when waiting for server to send frames" )
3548+ case <- writeDone :
3549+ }
3550+ }()
3551+
3552+ // Set up a client.
3553+ copts := ConnectOptions {BufferPool : mem .DefaultBufferPool ()}
3554+ ct , err := NewHTTP2Client (ctx , ctx , resolver.Address {Addr : lis .Addr ().String ()}, copts , func (GoAwayReason ) {})
3555+ if err != nil {
3556+ t .Fatalf ("NewHTTP2Client failed: %v" , err )
3557+ }
3558+ t .Cleanup (func () { ct .Close (errors .New ("test cleanup: forcing close" )) })
3559+
3560+ // Create a stream.
3561+ stream , err := ct .NewStream (ctx , & CallHdr {}, nil )
3562+ if err != nil {
3563+ t .Fatalf ("NewStream failed: %v" , err )
3564+ }
3565+
3566+ // Wait for server to see client's headers.
3567+ select {
3568+ case <- ctx .Done ():
3569+ t .Fatalf ("Test timed out when waiting for server to see client's headers" )
3570+ case <- seenHeadersFrame :
3571+ }
3572+
3573+ waitForServerDone := func () {
3574+ select {
3575+ case <- ctx .Done ():
3576+ t .Fatalf ("Test timed out when waiting for server to be done" )
3577+ case <- serverDone :
3578+ }
3579+ }
3580+ return stream , waitForServerDone
3581+ }
3582+
3583+ // Tests the scenario where the server sets the END_STREAM flag in the HEADERS
3584+ // frame and verifies that the client responds with a RST stream.
3585+ func (s ) TestClientSendsRSTStream_InHeaders (t * testing.T ) {
3586+ serverFrames := func (t * testing.T , framer * http2.Framer , streamID uint32 ) {
3587+ var buf bytes.Buffer
3588+ henc := hpack .NewEncoder (& buf )
3589+ henc .WriteField (hpack.HeaderField {Name : "content-type" , Value : "application/grpc" })
3590+ if err := framer .WriteHeaders (http2.HeadersFrameParam {
3591+ StreamID : streamID ,
3592+ BlockFragment : buf .Bytes (),
3593+ EndHeaders : true ,
3594+ EndStream : true ,
3595+ }); err != nil {
3596+ t .Errorf ("Server failed to write headers: %v" , err )
3597+ }
3598+ }
3599+
3600+ ctx , cancel := context .WithTimeout (context .Background (), defaultTestTimeout )
3601+ defer cancel ()
3602+ stream , waitForServer := setupRSTStreamOnEOSTest (ctx , t , serverFrames )
3603+ defer waitForServer ()
3604+
3605+ if _ , err := stream .readTo (make ([]byte , 1 )); ! errors .Is (err , io .EOF ) {
3606+ t .Fatalf ("stream.readTo() got %v, want %v" , err , io .EOF )
3607+ }
3608+
3609+ // Ensure the stream is done before checking status.
3610+ <- stream .Done ()
3611+ if code := stream .Status ().Code (); code != codes .Unknown {
3612+ t .Fatalf ("stream.Status().Code() got %s, want %s" , code , codes .Unknown )
3613+ }
3614+ }
3615+
3616+ // Tests the scenario where the server sets the END_STREAM flag in the Trailers
3617+ // (HEADERS frame) and verifies that the client responds with a RST stream.
3618+ func (s ) TestClientSendsRSTStream_InTrailers (t * testing.T ) {
3619+ serverFrames := func (t * testing.T , framer * http2.Framer , streamID uint32 ) {
3620+ var buf bytes.Buffer
3621+ henc := hpack .NewEncoder (& buf )
3622+ henc .WriteField (hpack.HeaderField {Name : "content-type" , Value : "application/grpc" })
3623+ if err := framer .WriteHeaders (http2.HeadersFrameParam {
3624+ StreamID : streamID ,
3625+ BlockFragment : buf .Bytes (),
3626+ EndHeaders : true ,
3627+ EndStream : false ,
3628+ }); err != nil {
3629+ t .Errorf ("Server failed to write headers: %v" , err )
3630+ }
3631+ if err := framer .WriteData (streamID , false , expectedResponse ); err != nil {
3632+ t .Errorf ("Server failed to write data: %v" , err )
3633+ }
3634+ buf .Reset ()
3635+ henc = hpack .NewEncoder (& buf )
3636+ henc .WriteField (hpack.HeaderField {Name : "grpc-status" , Value : "0" })
3637+ if err := framer .WriteHeaders (http2.HeadersFrameParam {
3638+ StreamID : streamID ,
3639+ BlockFragment : buf .Bytes (),
3640+ EndHeaders : true ,
3641+ EndStream : true ,
3642+ }); err != nil {
3643+ t .Errorf ("Server failed to write trailers: %v" , err )
3644+ }
3645+ }
3646+
3647+ ctx , cancel := context .WithTimeout (context .Background (), defaultTestTimeout )
3648+ defer cancel ()
3649+ stream , waitForServer := setupRSTStreamOnEOSTest (ctx , t , serverFrames )
3650+ defer waitForServer ()
3651+
3652+ // Wait for the stream to be closed.
3653+ <- stream .Done ()
3654+ if code := stream .Status ().Code (); code != codes .OK {
3655+ t .Fatalf ("stream.Status().Code() got %s, want %s" , code , codes .OK )
3656+ }
3657+ }
3658+
3659+ // Tests the scenario where the server sets the END_STREAM flag in one of its
3660+ // DATA frames (before sending trailers), causing the client to send a
3661+ // RST_STREAM. The test verifies that the client can still read buffered data
3662+ // from the stream after this event.
3663+ func (s ) TestClientSendsRSTStream_ReadUnreadData (t * testing.T ) {
3664+ serverFrames := func (t * testing.T , framer * http2.Framer , streamID uint32 ) {
3665+ var buf bytes.Buffer
3666+ henc := hpack .NewEncoder (& buf )
3667+ henc .WriteField (hpack.HeaderField {Name : "content-type" , Value : "application/grpc" })
3668+ if err := framer .WriteHeaders (http2.HeadersFrameParam {
3669+ StreamID : streamID ,
3670+ BlockFragment : buf .Bytes (),
3671+ EndHeaders : true ,
3672+ EndStream : false ,
3673+ }); err != nil {
3674+ t .Errorf ("Server failed to write headers: %v" , err )
3675+ }
3676+ if err := framer .WriteData (streamID , true , expectedResponse ); err != nil {
3677+ t .Errorf ("Server failed to write data: %v" , err )
3678+ }
3679+ }
3680+
3681+ ctx , cancel := context .WithTimeout (context .Background (), defaultTestTimeout )
3682+ defer cancel ()
3683+ stream , waitForServer := setupRSTStreamOnEOSTest (ctx , t , serverFrames )
3684+ defer waitForServer ()
3685+
3686+ // Wait for the stream to match the state we expect (which is that it
3687+ // has sent a RST_STREAM, which means it has closed).
3688+ //
3689+ // If we read before the RST_STREAM is sent, we might race with the
3690+ // client receiving the EOS from the server, and the client might
3691+ // not have sent the RST_STREAM yet.
3692+ <- stream .Done ()
3693+
3694+ // Read the data.
3695+ gotData := make ([]byte , len (expectedResponse ))
3696+ if _ , err := stream .readTo (gotData ); err != nil {
3697+ t .Fatalf ("stream.readTo() got %v, want <nil>" , err )
3698+ }
3699+ if ! bytes .Equal (gotData , expectedResponse ) {
3700+ t .Fatalf ("stream.readTo() got %v, want %v" , gotData , expectedResponse )
3701+ }
3702+ if _ , err := stream .readTo (make ([]byte , 1 )); ! errors .Is (err , io .EOF ) {
3703+ t .Fatalf ("stream.readTo() got %v, want %v" , err , io .EOF )
3704+ }
3705+ if code := stream .Status ().Code (); code != codes .Internal {
3706+ t .Fatalf ("stream.Status().Code() got %s, want %s" , code , codes .Internal )
3707+ }
3708+ }
3709+
33593710// TestClientTransport_Handle1xxHeaders validates that 1xx HTTP status headers
33603711// are ignored and treated as a protocol error if END_STREAM is set.
33613712func (s ) TestClientTransport_Handle1xxHeaders (t * testing.T ) {
0 commit comments