| /* |
| * |
| * Copyright 2014 gRPC authors. |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| * |
| */ |
| |
| package transport |
| |
| import ( |
| "bufio" |
| "bytes" |
| "encoding/binary" |
| "errors" |
| "fmt" |
| "io" |
| "math" |
| "net" |
| "net/http" |
| "reflect" |
| "strconv" |
| "strings" |
| "sync" |
| "testing" |
| "time" |
| |
| "golang.org/x/net/context" |
| "golang.org/x/net/http2" |
| "golang.org/x/net/http2/hpack" |
| "google.golang.org/grpc/codes" |
| "google.golang.org/grpc/keepalive" |
| "google.golang.org/grpc/status" |
| ) |
| |
// server is a minimal transport-level test server. It accepts connections on
// lis and hands every accepted transport to a testStreamHandler.
type server struct {
	// lis is the listener created by start and closed by stop.
	lis net.Listener
	// port is the port part of the listener address, set by start.
	port       string
	startedErr chan error // error (or nil) with server start value
	// mu guards conns and h.
	mu sync.Mutex
	// conns is the set of accepted server transports; stop sets it to nil.
	conns map[ServerTransport]bool
	// h is the handler created for the most recently accepted transport.
	h *testStreamHandler
}
| |
// Payloads exchanged between the test clients and the stream handlers.
var (
	// Small request/response pair used by default.
	expectedRequest  = []byte("ping")
	expectedResponse = []byte("pong")
	// Zero-filled payloads twice the size of initialWindowSize; the handlers
	// switch to them when the stream method is "foo.Large".
	expectedRequestLarge  = make([]byte, initialWindowSize*2)
	expectedResponseLarge = make([]byte, initialWindowSize*2)
	// Value sent as the "content-type" header by handleStreamInvalidHeaderField.
	expectedInvalidHeaderField = "invalid/content-type"
)
| |
// testStreamHandler bundles the server-side stream handlers used by the tests.
type testStreamHandler struct {
	// t is the server transport the handlers write their replies to.
	t *http2Server
	// notify, when non-nil, is closed by handleStreamAndNotify.
	notify chan struct{}
}
| |
// hType selects which stream handler server.start installs on each accepted
// transport.
type hType int

const (
	// normal uses handleStream (the default case in server.start).
	normal hType = iota
	// suspended installs a handler that does nothing with the stream.
	suspended
	// notifyCall uses handleStreamAndNotify.
	notifyCall
	// misbehaved uses handleStreamMisbehave.
	misbehaved
	// encodingRequiredStatus uses handleStreamEncodingRequiredStatus.
	encodingRequiredStatus
	// invalidHeaderField uses handleStreamInvalidHeaderField.
	invalidHeaderField
	// delayRead uses handleStreamDelayRead.
	delayRead
	// delayWrite uses handleStreamDelayWrite.
	delayWrite
	// pingpong uses handleStreamPingPong.
	pingpong
)
| |
| func (h *testStreamHandler) handleStreamAndNotify(s *Stream) { |
| if h.notify == nil { |
| return |
| } |
| go func() { |
| select { |
| case <-h.notify: |
| default: |
| close(h.notify) |
| } |
| }() |
| } |
| |
| func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) { |
| req := expectedRequest |
| resp := expectedResponse |
| if s.Method() == "foo.Large" { |
| req = expectedRequestLarge |
| resp = expectedResponseLarge |
| } |
| p := make([]byte, len(req)) |
| _, err := s.Read(p) |
| if err != nil { |
| return |
| } |
| if !bytes.Equal(p, req) { |
| t.Fatalf("handleStream got %v, want %v", p, req) |
| } |
| // send a response back to the client. |
| h.t.Write(s, nil, resp, &Options{}) |
| // send the trailer to end the stream. |
| h.t.WriteStatus(s, status.New(codes.OK, "")) |
| } |
| |
| func (h *testStreamHandler) handleStreamPingPong(t *testing.T, s *Stream) { |
| header := make([]byte, 5) |
| for i := 0; i < 10; i++ { |
| if _, err := s.Read(header); err != nil { |
| t.Fatalf("Error on server while reading data header: %v", err) |
| } |
| sz := binary.BigEndian.Uint32(header[1:]) |
| msg := make([]byte, int(sz)) |
| if _, err := s.Read(msg); err != nil { |
| t.Fatalf("Error on server while reading message: %v", err) |
| } |
| buf := make([]byte, sz+5) |
| buf[0] = byte(0) |
| binary.BigEndian.PutUint32(buf[1:], uint32(sz)) |
| copy(buf[5:], msg) |
| h.t.Write(s, nil, buf, &Options{}) |
| } |
| } |
| |
| func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream) { |
| conn, ok := s.ServerTransport().(*http2Server) |
| if !ok { |
| t.Fatalf("Failed to convert %v to *http2Server", s.ServerTransport()) |
| } |
| var sent int |
| p := make([]byte, http2MaxFrameLen) |
| for sent < initialWindowSize { |
| n := initialWindowSize - sent |
| // The last message may be smaller than http2MaxFrameLen |
| if n <= http2MaxFrameLen { |
| if s.Method() == "foo.Connection" { |
| // Violate connection level flow control window of client but do not |
| // violate any stream level windows. |
| p = make([]byte, n) |
| } else { |
| // Violate stream level flow control window of client. |
| p = make([]byte, n+1) |
| } |
| } |
| conn.controlBuf.put(&dataFrame{s.id, false, p, func() {}}) |
| sent += len(p) |
| } |
| } |
| |
// handleStreamEncodingRequiredStatus ends the stream immediately by writing
// encodingTestStatus as its status, without reading or writing any data.
// NOTE(review): encodingTestStatus is defined elsewhere; presumably its
// message contains characters that must be encoded — confirm.
func (h *testStreamHandler) handleStreamEncodingRequiredStatus(t *testing.T, s *Stream) {
	// raw newline is not accepted by http2 framer so it must be encoded.
	h.t.WriteStatus(s, encodingTestStatus)
}
| |
| func (h *testStreamHandler) handleStreamInvalidHeaderField(t *testing.T, s *Stream) { |
| headerFields := []hpack.HeaderField{} |
| headerFields = append(headerFields, hpack.HeaderField{Name: "content-type", Value: expectedInvalidHeaderField}) |
| h.t.controlBuf.put(&headerFrame{ |
| streamID: s.id, |
| hf: headerFields, |
| endStream: false, |
| }) |
| } |
| |
| func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *Stream) { |
| req := expectedRequest |
| resp := expectedResponse |
| if s.Method() == "foo.Large" { |
| req = expectedRequestLarge |
| resp = expectedResponseLarge |
| } |
| p := make([]byte, len(req)) |
| |
| // Wait before reading. Give time to client to start sending |
| // before server starts reading. |
| time.Sleep(2 * time.Second) |
| _, err := s.Read(p) |
| if err != nil { |
| t.Fatalf("s.Read(_) = _, %v, want _, <nil>", err) |
| return |
| } |
| |
| if !bytes.Equal(p, req) { |
| t.Fatalf("handleStream got %v, want %v", p, req) |
| } |
| // send a response back to the client. |
| h.t.Write(s, nil, resp, &Options{}) |
| // send the trailer to end the stream. |
| h.t.WriteStatus(s, status.New(codes.OK, "")) |
| } |
| |
| func (h *testStreamHandler) handleStreamDelayWrite(t *testing.T, s *Stream) { |
| req := expectedRequest |
| resp := expectedResponse |
| if s.Method() == "foo.Large" { |
| req = expectedRequestLarge |
| resp = expectedResponseLarge |
| } |
| p := make([]byte, len(req)) |
| _, err := s.Read(p) |
| if err != nil { |
| t.Fatalf("s.Read(_) = _, %v, want _, <nil>", err) |
| return |
| } |
| if !bytes.Equal(p, req) { |
| t.Fatalf("handleStream got %v, want %v", p, req) |
| } |
| |
| // Wait before sending. Give time to client to start reading |
| // before server starts sending. |
| time.Sleep(2 * time.Second) |
| h.t.Write(s, nil, resp, &Options{}) |
| // send the trailer to end the stream. |
| h.t.WriteStatus(s, status.New(codes.OK, "")) |
| } |
| |
| // start starts server. Other goroutines should block on s.readyChan for further operations. |
| func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hType) { |
| var err error |
| if port == 0 { |
| s.lis, err = net.Listen("tcp", "localhost:0") |
| } else { |
| s.lis, err = net.Listen("tcp", "localhost:"+strconv.Itoa(port)) |
| } |
| if err != nil { |
| s.startedErr <- fmt.Errorf("failed to listen: %v", err) |
| return |
| } |
| _, p, err := net.SplitHostPort(s.lis.Addr().String()) |
| if err != nil { |
| s.startedErr <- fmt.Errorf("failed to parse listener address: %v", err) |
| return |
| } |
| s.port = p |
| s.conns = make(map[ServerTransport]bool) |
| s.startedErr <- nil |
| for { |
| conn, err := s.lis.Accept() |
| if err != nil { |
| return |
| } |
| transport, err := NewServerTransport("http2", conn, serverConfig) |
| if err != nil { |
| return |
| } |
| s.mu.Lock() |
| if s.conns == nil { |
| s.mu.Unlock() |
| transport.Close() |
| return |
| } |
| s.conns[transport] = true |
| h := &testStreamHandler{t: transport.(*http2Server)} |
| s.h = h |
| s.mu.Unlock() |
| switch ht { |
| case notifyCall: |
| go transport.HandleStreams(h.handleStreamAndNotify, |
| func(ctx context.Context, _ string) context.Context { |
| return ctx |
| }) |
| case suspended: |
| go transport.HandleStreams(func(*Stream) {}, // Do nothing to handle the stream. |
| func(ctx context.Context, method string) context.Context { |
| return ctx |
| }) |
| case misbehaved: |
| go transport.HandleStreams(func(s *Stream) { |
| go h.handleStreamMisbehave(t, s) |
| }, func(ctx context.Context, method string) context.Context { |
| return ctx |
| }) |
| case encodingRequiredStatus: |
| go transport.HandleStreams(func(s *Stream) { |
| go h.handleStreamEncodingRequiredStatus(t, s) |
| }, func(ctx context.Context, method string) context.Context { |
| return ctx |
| }) |
| case invalidHeaderField: |
| go transport.HandleStreams(func(s *Stream) { |
| go h.handleStreamInvalidHeaderField(t, s) |
| }, func(ctx context.Context, method string) context.Context { |
| return ctx |
| }) |
| case delayRead: |
| go transport.HandleStreams(func(s *Stream) { |
| go h.handleStreamDelayRead(t, s) |
| }, func(ctx context.Context, method string) context.Context { |
| return ctx |
| }) |
| case delayWrite: |
| go transport.HandleStreams(func(s *Stream) { |
| go h.handleStreamDelayWrite(t, s) |
| }, func(ctx context.Context, method string) context.Context { |
| return ctx |
| }) |
| case pingpong: |
| go transport.HandleStreams(func(s *Stream) { |
| go h.handleStreamPingPong(t, s) |
| }, func(ctx context.Context, method string) context.Context { |
| return ctx |
| }) |
| default: |
| go transport.HandleStreams(func(s *Stream) { |
| go h.handleStream(t, s) |
| }, func(ctx context.Context, method string) context.Context { |
| return ctx |
| }) |
| } |
| } |
| } |
| |
| func (s *server) wait(t *testing.T, timeout time.Duration) { |
| select { |
| case err := <-s.startedErr: |
| if err != nil { |
| t.Fatal(err) |
| } |
| case <-time.After(timeout): |
| t.Fatalf("Timed out after %v waiting for server to be ready", timeout) |
| } |
| } |
| |
| func (s *server) stop() { |
| s.lis.Close() |
| s.mu.Lock() |
| for c := range s.conns { |
| c.Close() |
| } |
| s.conns = nil |
| s.mu.Unlock() |
| } |
| |
| func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, ClientTransport) { |
| return setUpWithOptions(t, port, &ServerConfig{MaxStreams: maxStreams}, ht, ConnectOptions{}) |
| } |
| |
| func setUpWithOptions(t *testing.T, port int, serverConfig *ServerConfig, ht hType, copts ConnectOptions) (*server, ClientTransport) { |
| server := &server{startedErr: make(chan error, 1)} |
| go server.start(t, port, serverConfig, ht) |
| server.wait(t, 2*time.Second) |
| addr := "localhost:" + server.port |
| var ( |
| ct ClientTransport |
| connErr error |
| ) |
| target := TargetInfo{ |
| Addr: addr, |
| } |
| ct, connErr = NewClientTransport(context.Background(), target, copts, 2*time.Second) |
| if connErr != nil { |
| t.Fatalf("failed to create transport: %v", connErr) |
| } |
| return server, ct |
| } |
| |
| func setUpWithNoPingServer(t *testing.T, copts ConnectOptions, done chan net.Conn) ClientTransport { |
| lis, err := net.Listen("tcp", "localhost:0") |
| if err != nil { |
| t.Fatalf("Failed to listen: %v", err) |
| } |
| // Launch a non responsive server. |
| go func() { |
| defer lis.Close() |
| conn, err := lis.Accept() |
| if err != nil { |
| t.Errorf("Error at server-side while accepting: %v", err) |
| close(done) |
| return |
| } |
| done <- conn |
| }() |
| tr, err := NewClientTransport(context.Background(), TargetInfo{Addr: lis.Addr().String()}, copts, 2*time.Second) |
| if err != nil { |
| // Server clean-up. |
| lis.Close() |
| if conn, ok := <-done; ok { |
| conn.Close() |
| } |
| t.Fatalf("Failed to dial: %v", err) |
| } |
| return tr |
| } |
| |
| // TestInflightStreamClosing ensures that closing in-flight stream |
| // sends StreamError to concurrent stream reader. |
| func TestInflightStreamClosing(t *testing.T) { |
| serverConfig := &ServerConfig{} |
| server, client := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) |
| defer server.stop() |
| defer client.Close() |
| |
| stream, err := client.NewStream(context.Background(), &CallHdr{}) |
| if err != nil { |
| t.Fatalf("Client failed to create RPC request: %v", err) |
| } |
| |
| donec := make(chan struct{}) |
| serr := StreamError{Desc: "client connection is closing"} |
| go func() { |
| defer close(donec) |
| if _, err := stream.Read(make([]byte, defaultWindowSize)); err != serr { |
| t.Errorf("unexpected Stream error %v, expected %v", err, serr) |
| } |
| }() |
| |
| // should unblock concurrent stream.Read |
| client.CloseStream(stream, serr) |
| |
| // wait for stream.Read error |
| timeout := time.NewTimer(5 * time.Second) |
| select { |
| case <-donec: |
| if !timeout.Stop() { |
| <-timeout.C |
| } |
| case <-timeout.C: |
| t.Fatalf("Test timed out, expected a StreamError.") |
| } |
| } |
| |
| // TestMaxConnectionIdle tests that a server will send GoAway to a idle client. |
| // An idle client is one who doesn't make any RPC calls for a duration of |
| // MaxConnectionIdle time. |
| func TestMaxConnectionIdle(t *testing.T) { |
| serverConfig := &ServerConfig{ |
| KeepaliveParams: keepalive.ServerParameters{ |
| MaxConnectionIdle: 2 * time.Second, |
| }, |
| } |
| server, client := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) |
| defer server.stop() |
| defer client.Close() |
| stream, err := client.NewStream(context.Background(), &CallHdr{Flush: true}) |
| if err != nil { |
| t.Fatalf("Client failed to create RPC request: %v", err) |
| } |
| stream.mu.Lock() |
| stream.rstStream = true |
| stream.mu.Unlock() |
| client.CloseStream(stream, nil) |
| // wait for server to see that closed stream and max-age logic to send goaway after no new RPCs are mode |
| timeout := time.NewTimer(time.Second * 4) |
| select { |
| case <-client.GoAway(): |
| if !timeout.Stop() { |
| <-timeout.C |
| } |
| case <-timeout.C: |
| t.Fatalf("Test timed out, expected a GoAway from the server.") |
| } |
| } |
| |
| // TestMaxConenctionIdleNegative tests that a server will not send GoAway to a non-idle(busy) client. |
| func TestMaxConnectionIdleNegative(t *testing.T) { |
| serverConfig := &ServerConfig{ |
| KeepaliveParams: keepalive.ServerParameters{ |
| MaxConnectionIdle: 2 * time.Second, |
| }, |
| } |
| server, client := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) |
| defer server.stop() |
| defer client.Close() |
| _, err := client.NewStream(context.Background(), &CallHdr{Flush: true}) |
| if err != nil { |
| t.Fatalf("Client failed to create RPC request: %v", err) |
| } |
| timeout := time.NewTimer(time.Second * 4) |
| select { |
| case <-client.GoAway(): |
| if !timeout.Stop() { |
| <-timeout.C |
| } |
| t.Fatalf("A non-idle client received a GoAway.") |
| case <-timeout.C: |
| } |
| |
| } |
| |
| // TestMaxConnectionAge tests that a server will send GoAway after a duration of MaxConnectionAge. |
| func TestMaxConnectionAge(t *testing.T) { |
| serverConfig := &ServerConfig{ |
| KeepaliveParams: keepalive.ServerParameters{ |
| MaxConnectionAge: 2 * time.Second, |
| }, |
| } |
| server, client := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) |
| defer server.stop() |
| defer client.Close() |
| _, err := client.NewStream(context.Background(), &CallHdr{}) |
| if err != nil { |
| t.Fatalf("Client failed to create stream: %v", err) |
| } |
| // Wait for max-age logic to send GoAway. |
| timeout := time.NewTimer(4 * time.Second) |
| select { |
| case <-client.GoAway(): |
| if !timeout.Stop() { |
| <-timeout.C |
| } |
| case <-timeout.C: |
| t.Fatalf("Test timer out, expected a GoAway from the server.") |
| } |
| } |
| |
| // TestKeepaliveServer tests that a server closes connection with a client that doesn't respond to keepalive pings. |
| func TestKeepaliveServer(t *testing.T) { |
| serverConfig := &ServerConfig{ |
| KeepaliveParams: keepalive.ServerParameters{ |
| Time: 2 * time.Second, |
| Timeout: 1 * time.Second, |
| }, |
| } |
| server, c := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) |
| defer server.stop() |
| defer c.Close() |
| client, err := net.Dial("tcp", server.lis.Addr().String()) |
| if err != nil { |
| t.Fatalf("Failed to dial: %v", err) |
| } |
| defer client.Close() |
| // Set read deadline on client conn so that it doesn't block forever in errorsome cases. |
| client.SetReadDeadline(time.Now().Add(10 * time.Second)) |
| // Wait for keepalive logic to close the connection. |
| time.Sleep(4 * time.Second) |
| b := make([]byte, 24) |
| for { |
| _, err = client.Read(b) |
| if err == nil { |
| continue |
| } |
| if err != io.EOF { |
| t.Fatalf("client.Read(_) = _,%v, want io.EOF", err) |
| } |
| break |
| } |
| } |
| |
| // TestKeepaliveServerNegative tests that a server doesn't close connection with a client that responds to keepalive pings. |
| func TestKeepaliveServerNegative(t *testing.T) { |
| serverConfig := &ServerConfig{ |
| KeepaliveParams: keepalive.ServerParameters{ |
| Time: 2 * time.Second, |
| Timeout: 1 * time.Second, |
| }, |
| } |
| server, client := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) |
| defer server.stop() |
| defer client.Close() |
| // Give keepalive logic some time by sleeping. |
| time.Sleep(4 * time.Second) |
| // Assert that client is still active. |
| clientTr := client.(*http2Client) |
| clientTr.mu.Lock() |
| defer clientTr.mu.Unlock() |
| if clientTr.state != reachable { |
| t.Fatalf("Test failed: Expected server-client connection to be healthy.") |
| } |
| } |
| |
| func TestKeepaliveClientClosesIdleTransport(t *testing.T) { |
| done := make(chan net.Conn, 1) |
| tr := setUpWithNoPingServer(t, ConnectOptions{KeepaliveParams: keepalive.ClientParameters{ |
| Time: 2 * time.Second, // Keepalive time = 2 sec. |
| Timeout: 1 * time.Second, // Keepalive timeout = 1 sec. |
| PermitWithoutStream: true, // Run keepalive even with no RPCs. |
| }}, done) |
| defer tr.Close() |
| conn, ok := <-done |
| if !ok { |
| t.Fatalf("Server didn't return connection object") |
| } |
| defer conn.Close() |
| // Sleep for keepalive to close the connection. |
| time.Sleep(4 * time.Second) |
| // Assert that the connection was closed. |
| ct := tr.(*http2Client) |
| ct.mu.Lock() |
| defer ct.mu.Unlock() |
| if ct.state == reachable { |
| t.Fatalf("Test Failed: Expected client transport to have closed.") |
| } |
| } |
| |
| func TestKeepaliveClientStaysHealthyOnIdleTransport(t *testing.T) { |
| done := make(chan net.Conn, 1) |
| tr := setUpWithNoPingServer(t, ConnectOptions{KeepaliveParams: keepalive.ClientParameters{ |
| Time: 2 * time.Second, // Keepalive time = 2 sec. |
| Timeout: 1 * time.Second, // Keepalive timeout = 1 sec. |
| }}, done) |
| defer tr.Close() |
| conn, ok := <-done |
| if !ok { |
| t.Fatalf("server didn't reutrn connection object") |
| } |
| defer conn.Close() |
| // Give keepalive some time. |
| time.Sleep(4 * time.Second) |
| // Assert that connections is still healthy. |
| ct := tr.(*http2Client) |
| ct.mu.Lock() |
| defer ct.mu.Unlock() |
| if ct.state != reachable { |
| t.Fatalf("Test failed: Expected client transport to be healthy.") |
| } |
| } |
| |
| func TestKeepaliveClientClosesWithActiveStreams(t *testing.T) { |
| done := make(chan net.Conn, 1) |
| tr := setUpWithNoPingServer(t, ConnectOptions{KeepaliveParams: keepalive.ClientParameters{ |
| Time: 2 * time.Second, // Keepalive time = 2 sec. |
| Timeout: 1 * time.Second, // Keepalive timeout = 1 sec. |
| }}, done) |
| defer tr.Close() |
| conn, ok := <-done |
| if !ok { |
| t.Fatalf("Server didn't return connection object") |
| } |
| defer conn.Close() |
| // Create a stream. |
| _, err := tr.NewStream(context.Background(), &CallHdr{Flush: true}) |
| if err != nil { |
| t.Fatalf("Failed to create a new stream: %v", err) |
| } |
| // Give keepalive some time. |
| time.Sleep(4 * time.Second) |
| // Assert that transport was closed. |
| ct := tr.(*http2Client) |
| ct.mu.Lock() |
| defer ct.mu.Unlock() |
| if ct.state == reachable { |
| t.Fatalf("Test failed: Expected client transport to have closed.") |
| } |
| } |
| |
| func TestKeepaliveClientStaysHealthyWithResponsiveServer(t *testing.T) { |
| s, tr := setUpWithOptions(t, 0, &ServerConfig{MaxStreams: math.MaxUint32}, normal, ConnectOptions{KeepaliveParams: keepalive.ClientParameters{ |
| Time: 2 * time.Second, // Keepalive time = 2 sec. |
| Timeout: 1 * time.Second, // Keepalive timeout = 1 sec. |
| PermitWithoutStream: true, // Run keepalive even with no RPCs. |
| }}) |
| defer s.stop() |
| defer tr.Close() |
| // Give keep alive some time. |
| time.Sleep(4 * time.Second) |
| // Assert that transport is healthy. |
| ct := tr.(*http2Client) |
| ct.mu.Lock() |
| defer ct.mu.Unlock() |
| if ct.state != reachable { |
| t.Fatalf("Test failed: Expected client transport to be healthy.") |
| } |
| } |
| |
| func TestKeepaliveServerEnforcementWithAbusiveClientNoRPC(t *testing.T) { |
| serverConfig := &ServerConfig{ |
| KeepalivePolicy: keepalive.EnforcementPolicy{ |
| MinTime: 2 * time.Second, |
| }, |
| } |
| clientOptions := ConnectOptions{ |
| KeepaliveParams: keepalive.ClientParameters{ |
| Time: 50 * time.Millisecond, |
| Timeout: 50 * time.Millisecond, |
| PermitWithoutStream: true, |
| }, |
| } |
| server, client := setUpWithOptions(t, 0, serverConfig, normal, clientOptions) |
| defer server.stop() |
| defer client.Close() |
| |
| timeout := time.NewTimer(2 * time.Second) |
| select { |
| case <-client.GoAway(): |
| if !timeout.Stop() { |
| <-timeout.C |
| } |
| case <-timeout.C: |
| t.Fatalf("Test failed: Expected a GoAway from server.") |
| } |
| time.Sleep(500 * time.Millisecond) |
| ct := client.(*http2Client) |
| ct.mu.Lock() |
| defer ct.mu.Unlock() |
| if ct.state == reachable { |
| t.Fatalf("Test failed: Expected the connection to be closed.") |
| } |
| } |
| |
| func TestKeepaliveServerEnforcementWithAbusiveClientWithRPC(t *testing.T) { |
| serverConfig := &ServerConfig{ |
| KeepalivePolicy: keepalive.EnforcementPolicy{ |
| MinTime: 2 * time.Second, |
| }, |
| } |
| clientOptions := ConnectOptions{ |
| KeepaliveParams: keepalive.ClientParameters{ |
| Time: 50 * time.Millisecond, |
| Timeout: 50 * time.Millisecond, |
| }, |
| } |
| server, client := setUpWithOptions(t, 0, serverConfig, suspended, clientOptions) |
| defer server.stop() |
| defer client.Close() |
| |
| if _, err := client.NewStream(context.Background(), &CallHdr{Flush: true}); err != nil { |
| t.Fatalf("Client failed to create stream.") |
| } |
| timeout := time.NewTimer(2 * time.Second) |
| select { |
| case <-client.GoAway(): |
| if !timeout.Stop() { |
| <-timeout.C |
| } |
| case <-timeout.C: |
| t.Fatalf("Test failed: Expected a GoAway from server.") |
| } |
| time.Sleep(500 * time.Millisecond) |
| ct := client.(*http2Client) |
| ct.mu.Lock() |
| defer ct.mu.Unlock() |
| if ct.state == reachable { |
| t.Fatalf("Test failed: Expected the connection to be closed.") |
| } |
| } |
| |
| func TestKeepaliveServerEnforcementWithObeyingClientNoRPC(t *testing.T) { |
| serverConfig := &ServerConfig{ |
| KeepalivePolicy: keepalive.EnforcementPolicy{ |
| MinTime: 100 * time.Millisecond, |
| PermitWithoutStream: true, |
| }, |
| } |
| clientOptions := ConnectOptions{ |
| KeepaliveParams: keepalive.ClientParameters{ |
| Time: 101 * time.Millisecond, |
| Timeout: 50 * time.Millisecond, |
| PermitWithoutStream: true, |
| }, |
| } |
| server, client := setUpWithOptions(t, 0, serverConfig, normal, clientOptions) |
| defer server.stop() |
| defer client.Close() |
| |
| // Give keepalive enough time. |
| time.Sleep(2 * time.Second) |
| // Assert that connection is healthy. |
| ct := client.(*http2Client) |
| ct.mu.Lock() |
| defer ct.mu.Unlock() |
| if ct.state != reachable { |
| t.Fatalf("Test failed: Expected connection to be healthy.") |
| } |
| } |
| |
| func TestKeepaliveServerEnforcementWithObeyingClientWithRPC(t *testing.T) { |
| serverConfig := &ServerConfig{ |
| KeepalivePolicy: keepalive.EnforcementPolicy{ |
| MinTime: 100 * time.Millisecond, |
| }, |
| } |
| clientOptions := ConnectOptions{ |
| KeepaliveParams: keepalive.ClientParameters{ |
| Time: 101 * time.Millisecond, |
| Timeout: 50 * time.Millisecond, |
| }, |
| } |
| server, client := setUpWithOptions(t, 0, serverConfig, suspended, clientOptions) |
| defer server.stop() |
| defer client.Close() |
| |
| if _, err := client.NewStream(context.Background(), &CallHdr{Flush: true}); err != nil { |
| t.Fatalf("Client failed to create stream.") |
| } |
| |
| // Give keepalive enough time. |
| time.Sleep(2 * time.Second) |
| // Assert that connection is healthy. |
| ct := client.(*http2Client) |
| ct.mu.Lock() |
| defer ct.mu.Unlock() |
| if ct.state != reachable { |
| t.Fatalf("Test failed: Expected connection to be healthy.") |
| } |
| } |
| |
| func TestClientSendAndReceive(t *testing.T) { |
| server, ct := setUp(t, 0, math.MaxUint32, normal) |
| callHdr := &CallHdr{ |
| Host: "localhost", |
| Method: "foo.Small", |
| } |
| s1, err1 := ct.NewStream(context.Background(), callHdr) |
| if err1 != nil { |
| t.Fatalf("failed to open stream: %v", err1) |
| } |
| if s1.id != 1 { |
| t.Fatalf("wrong stream id: %d", s1.id) |
| } |
| s2, err2 := ct.NewStream(context.Background(), callHdr) |
| if err2 != nil { |
| t.Fatalf("failed to open stream: %v", err2) |
| } |
| if s2.id != 3 { |
| t.Fatalf("wrong stream id: %d", s2.id) |
| } |
| opts := Options{ |
| Last: true, |
| Delay: false, |
| } |
| if err := ct.Write(s1, nil, expectedRequest, &opts); err != nil && err != io.EOF { |
| t.Fatalf("failed to send data: %v", err) |
| } |
| p := make([]byte, len(expectedResponse)) |
| _, recvErr := s1.Read(p) |
| if recvErr != nil || !bytes.Equal(p, expectedResponse) { |
| t.Fatalf("Error: %v, want <nil>; Result: %v, want %v", recvErr, p, expectedResponse) |
| } |
| _, recvErr = s1.Read(p) |
| if recvErr != io.EOF { |
| t.Fatalf("Error: %v; want <EOF>", recvErr) |
| } |
| ct.Close() |
| server.stop() |
| } |
| |
| func TestClientErrorNotify(t *testing.T) { |
| server, ct := setUp(t, 0, math.MaxUint32, normal) |
| go server.stop() |
| // ct.reader should detect the error and activate ct.Error(). |
| <-ct.Error() |
| ct.Close() |
| } |
| |
| func performOneRPC(ct ClientTransport) { |
| callHdr := &CallHdr{ |
| Host: "localhost", |
| Method: "foo.Small", |
| } |
| s, err := ct.NewStream(context.Background(), callHdr) |
| if err != nil { |
| return |
| } |
| opts := Options{ |
| Last: true, |
| Delay: false, |
| } |
| if err := ct.Write(s, []byte{}, expectedRequest, &opts); err == nil || err == io.EOF { |
| time.Sleep(5 * time.Millisecond) |
| // The following s.Recv()'s could error out because the |
| // underlying transport is gone. |
| // |
| // Read response |
| p := make([]byte, len(expectedResponse)) |
| s.Read(p) |
| // Read io.EOF |
| s.Read(p) |
| } |
| } |
| |
| func TestClientMix(t *testing.T) { |
| s, ct := setUp(t, 0, math.MaxUint32, normal) |
| go func(s *server) { |
| time.Sleep(5 * time.Second) |
| s.stop() |
| }(s) |
| go func(ct ClientTransport) { |
| <-ct.Error() |
| ct.Close() |
| }(ct) |
| for i := 0; i < 1000; i++ { |
| time.Sleep(10 * time.Millisecond) |
| go performOneRPC(ct) |
| } |
| } |
| |
| func TestLargeMessage(t *testing.T) { |
| server, ct := setUp(t, 0, math.MaxUint32, normal) |
| callHdr := &CallHdr{ |
| Host: "localhost", |
| Method: "foo.Large", |
| } |
| var wg sync.WaitGroup |
| for i := 0; i < 2; i++ { |
| wg.Add(1) |
| go func() { |
| defer wg.Done() |
| s, err := ct.NewStream(context.Background(), callHdr) |
| if err != nil { |
| t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err) |
| } |
| if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF { |
| t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err) |
| } |
| p := make([]byte, len(expectedResponseLarge)) |
| if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) { |
| t.Errorf("s.Read(%v) = _, %v, want %v, <nil>", err, p, expectedResponse) |
| } |
| if _, err = s.Read(p); err != io.EOF { |
| t.Errorf("Failed to complete the stream %v; want <EOF>", err) |
| } |
| }() |
| } |
| wg.Wait() |
| ct.Close() |
| server.stop() |
| } |
| |
| func TestLargeMessageWithDelayRead(t *testing.T) { |
| server, ct := setUp(t, 0, math.MaxUint32, delayRead) |
| callHdr := &CallHdr{ |
| Host: "localhost", |
| Method: "foo.Large", |
| } |
| var wg sync.WaitGroup |
| for i := 0; i < 2; i++ { |
| wg.Add(1) |
| go func() { |
| defer wg.Done() |
| s, err := ct.NewStream(context.Background(), callHdr) |
| if err != nil { |
| t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err) |
| } |
| if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF { |
| t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err) |
| } |
| p := make([]byte, len(expectedResponseLarge)) |
| |
| // Give time to server to begin sending before client starts reading. |
| time.Sleep(2 * time.Second) |
| if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) { |
| t.Errorf("s.Read(_) = _, %v, want _, <nil>", err) |
| } |
| if _, err = s.Read(p); err != io.EOF { |
| t.Errorf("Failed to complete the stream %v; want <EOF>", err) |
| } |
| }() |
| } |
| wg.Wait() |
| ct.Close() |
| server.stop() |
| } |
| |
| func TestLargeMessageDelayWrite(t *testing.T) { |
| server, ct := setUp(t, 0, math.MaxUint32, delayWrite) |
| callHdr := &CallHdr{ |
| Host: "localhost", |
| Method: "foo.Large", |
| } |
| var wg sync.WaitGroup |
| for i := 0; i < 2; i++ { |
| wg.Add(1) |
| go func() { |
| defer wg.Done() |
| s, err := ct.NewStream(context.Background(), callHdr) |
| if err != nil { |
| t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err) |
| } |
| |
| // Give time to server to start reading before client starts sending. |
| time.Sleep(2 * time.Second) |
| if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF { |
| t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err) |
| } |
| p := make([]byte, len(expectedResponseLarge)) |
| if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) { |
| t.Errorf("io.ReadFull(%v) = _, %v, want %v, <nil>", err, p, expectedResponse) |
| } |
| if _, err = s.Read(p); err != io.EOF { |
| t.Errorf("Failed to complete the stream %v; want <EOF>", err) |
| } |
| }() |
| } |
| wg.Wait() |
| ct.Close() |
| server.stop() |
| } |
| |
| func TestGracefulClose(t *testing.T) { |
| server, ct := setUp(t, 0, math.MaxUint32, normal) |
| callHdr := &CallHdr{ |
| Host: "localhost", |
| Method: "foo.Small", |
| } |
| s, err := ct.NewStream(context.Background(), callHdr) |
| if err != nil { |
| t.Fatalf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err) |
| } |
| if err = ct.GracefulClose(); err != nil { |
| t.Fatalf("%v.GracefulClose() = %v, want <nil>", ct, err) |
| } |
| var wg sync.WaitGroup |
| // Expect the failure for all the follow-up streams because ct has been closed gracefully. |
| for i := 0; i < 100; i++ { |
| wg.Add(1) |
| go func() { |
| defer wg.Done() |
| if _, err := ct.NewStream(context.Background(), callHdr); err != ErrStreamDrain { |
| t.Errorf("%v.NewStream(_, _) = _, %v, want _, %v", ct, err, ErrStreamDrain) |
| } |
| }() |
| } |
| opts := Options{ |
| Last: true, |
| Delay: false, |
| } |
| // The stream which was created before graceful close can still proceed. |
| if err := ct.Write(s, nil, expectedRequest, &opts); err != nil && err != io.EOF { |
| t.Fatalf("%v.Write(_, _, _) = %v, want <nil>", ct, err) |
| } |
| p := make([]byte, len(expectedResponse)) |
| if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponse) { |
| t.Fatalf("s.Read(%v) = _, %v, want %v, <nil>", err, p, expectedResponse) |
| } |
| if _, err = s.Read(p); err != io.EOF { |
| t.Fatalf("Failed to complete the stream %v; want <EOF>", err) |
| } |
| wg.Wait() |
| ct.Close() |
| server.stop() |
| } |
| |
| func TestLargeMessageSuspension(t *testing.T) { |
| server, ct := setUp(t, 0, math.MaxUint32, suspended) |
| callHdr := &CallHdr{ |
| Host: "localhost", |
| Method: "foo.Large", |
| } |
| // Set a long enough timeout for writing a large message out. |
| ctx, cancel := context.WithTimeout(context.Background(), time.Second) |
| defer cancel() |
| s, err := ct.NewStream(ctx, callHdr) |
| if err != nil { |
| t.Fatalf("failed to open stream: %v", err) |
| } |
| // Write should not be done successfully due to flow control. |
| msg := make([]byte, initialWindowSize*8) |
| err = ct.Write(s, nil, msg, &Options{Last: true, Delay: false}) |
| expectedErr := streamErrorf(codes.DeadlineExceeded, "%v", context.DeadlineExceeded) |
| if err != expectedErr { |
| t.Fatalf("Write got %v, want %v", err, expectedErr) |
| } |
| ct.Close() |
| server.stop() |
| } |
| |
| func TestMaxStreams(t *testing.T) { |
| server, ct := setUp(t, 0, 1, suspended) |
| callHdr := &CallHdr{ |
| Host: "localhost", |
| Method: "foo.Large", |
| } |
| // Have a pending stream which takes all streams quota. |
| s, err := ct.NewStream(context.Background(), callHdr) |
| if err != nil { |
| t.Fatalf("Failed to open stream: %v", err) |
| } |
| cc, ok := ct.(*http2Client) |
| if !ok { |
| t.Fatalf("Failed to convert %v to *http2Client", ct) |
| } |
| done := make(chan struct{}) |
| ch := make(chan int) |
| ready := make(chan struct{}) |
| go func() { |
| for { |
| select { |
| case <-time.After(5 * time.Millisecond): |
| select { |
| case ch <- 0: |
| case <-ready: |
| return |
| } |
| case <-time.After(5 * time.Second): |
| close(done) |
| return |
| case <-ready: |
| return |
| } |
| } |
| }() |
| for { |
| select { |
| case <-ch: |
| case <-done: |
| t.Fatalf("Client has not received the max stream setting in 5 seconds.") |
| } |
| cc.mu.Lock() |
| // cc.maxStreams should be equal to 1 after having received settings frame from |
| // server. |
| if cc.maxStreams == 1 { |
| cc.mu.Unlock() |
| select { |
| case <-cc.streamsQuota.acquire(): |
| t.Fatalf("streamsQuota.acquire() becomes readable mistakenly.") |
| default: |
| cc.streamsQuota.mu.Lock() |
| quota := cc.streamsQuota.quota |
| cc.streamsQuota.mu.Unlock() |
| if quota != 0 { |
| t.Fatalf("streamsQuota.quota got non-zero quota mistakenly.") |
| } |
| } |
| break |
| } |
| cc.mu.Unlock() |
| } |
| close(ready) |
| // Close the pending stream so that the streams quota becomes available for the next new stream. |
| ct.CloseStream(s, nil) |
| select { |
| case i := <-cc.streamsQuota.acquire(): |
| if i != 1 { |
| t.Fatalf("streamsQuota.acquire() got %d quota, want 1.", i) |
| } |
| cc.streamsQuota.add(i) |
| default: |
| t.Fatalf("streamsQuota.acquire() is not readable.") |
| } |
| if _, err := ct.NewStream(context.Background(), callHdr); err != nil { |
| t.Fatalf("Failed to open stream: %v", err) |
| } |
| ct.Close() |
| server.stop() |
| } |
| |
// TestServerContextCanceledOnClosedConnection checks that closing the client
// transport cancels the context of the corresponding server-side stream with
// context.Canceled within five seconds.
func TestServerContextCanceledOnClosedConnection(t *testing.T) {
	server, ct := setUp(t, 0, math.MaxUint32, suspended)
	callHdr := &CallHdr{
		Host:   "localhost",
		Method: "foo",
	}
	var sc *http2Server
	// Wait until the server transport is setup.
	for {
		server.mu.Lock()
		if len(server.conns) == 0 {
			server.mu.Unlock()
			time.Sleep(time.Millisecond)
			continue
		}
		for k := range server.conns {
			var ok bool
			sc, ok = k.(*http2Server)
			if !ok {
				t.Fatalf("Failed to convert %v to *http2Server", k)
			}
		}
		server.mu.Unlock()
		break
	}
	cc, ok := ct.(*http2Client)
	if !ok {
		t.Fatalf("Failed to convert %v to *http2Client", ct)
	}
	s, err := ct.NewStream(context.Background(), callHdr)
	if err != nil {
		t.Fatalf("Failed to open stream: %v", err)
	}
	// Queue one max-size data frame for the new stream directly on the
	// client's control buffer.
	cc.controlBuf.put(&dataFrame{s.id, false, make([]byte, http2MaxFrameLen), func() {}})
	// Loop until the server side stream is created.
	var ss *Stream
	for {
		time.Sleep(time.Second)
		sc.mu.Lock()
		if len(sc.activeStreams) == 0 {
			sc.mu.Unlock()
			continue
		}
		ss = sc.activeStreams[s.id]
		sc.mu.Unlock()
		break
	}
	// Closing the client transport is the event under test.
	cc.Close()
	select {
	case <-ss.Context().Done():
		if ss.Context().Err() != context.Canceled {
			t.Fatalf("ss.Context().Err() got %v, want %v", ss.Context().Err(), context.Canceled)
		}
	case <-time.After(5 * time.Second):
		t.Fatalf("Failed to cancel the context of the sever side stream.")
	}
	server.stop()
}
| |
// TestClientConnDecoupledFromApplicationRead checks that the server can send a
// full default window of data on a second stream even though the client
// application has not yet read anything from the first stream, and that the
// client can afterwards read from both streams.
func TestClientConnDecoupledFromApplicationRead(t *testing.T) {
	connectOptions := ConnectOptions{
		InitialWindowSize:     defaultWindowSize,
		InitialConnWindowSize: defaultWindowSize,
	}
	server, client := setUpWithOptions(t, 0, &ServerConfig{}, notifyCall, connectOptions)
	defer server.stop()
	defer client.Close()

	// Wait for the server to register the connection.
	waitWhileTrue(t, func() (bool, error) {
		server.mu.Lock()
		defer server.mu.Unlock()

		if len(server.conns) == 0 {
			return true, fmt.Errorf("timed-out while waiting for connection to be created on the server")
		}
		return false, nil
	})

	var st *http2Server
	server.mu.Lock()
	for k := range server.conns {
		st = k.(*http2Server)
	}
	// Install the channel the test blocks on below after creating a stream.
	notifyChan := make(chan struct{})
	server.h.notify = notifyChan
	server.mu.Unlock()
	cstream1, err := client.NewStream(context.Background(), &CallHdr{Flush: true})
	if err != nil {
		t.Fatalf("Client failed to create first stream. Err: %v", err)
	}

	<-notifyChan
	var sstream1 *Stream
	// Access stream on the server.
	st.mu.Lock()
	for _, v := range st.activeStreams {
		if v.id == cstream1.id {
			sstream1 = v
		}
	}
	st.mu.Unlock()
	if sstream1 == nil {
		t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream1.id)
	}
	// Exhaust client's connection window.
	if err := st.Write(sstream1, []byte{}, make([]byte, defaultWindowSize), &Options{}); err != nil {
		t.Fatalf("Server failed to write data. Err: %v", err)
	}
	// Use a fresh channel for the second stream's notification.
	notifyChan = make(chan struct{})
	server.mu.Lock()
	server.h.notify = notifyChan
	server.mu.Unlock()
	// Create another stream on client.
	cstream2, err := client.NewStream(context.Background(), &CallHdr{Flush: true})
	if err != nil {
		t.Fatalf("Client failed to create second stream. Err: %v", err)
	}
	<-notifyChan
	var sstream2 *Stream
	st.mu.Lock()
	for _, v := range st.activeStreams {
		if v.id == cstream2.id {
			sstream2 = v
		}
	}
	st.mu.Unlock()
	if sstream2 == nil {
		t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream2.id)
	}
	// Server should be able to send data on the new stream, even though the client hasn't read anything on the first stream.
	if err := st.Write(sstream2, []byte{}, make([]byte, defaultWindowSize), &Options{}); err != nil {
		t.Fatalf("Server failed to write data. Err: %v", err)
	}

	// Client should be able to read data on second stream.
	if _, err := cstream2.Read(make([]byte, defaultWindowSize)); err != nil {
		t.Fatalf("_.Read(_) = _, %v, want _, <nil>", err)
	}

	// Client should be able to read data on first stream.
	if _, err := cstream1.Read(make([]byte, defaultWindowSize)); err != nil {
		t.Fatalf("_.Read(_) = _, %v, want _, <nil>", err)
	}
}
| |
// TestServerConnDecoupledFromApplicationRead checks that, after the client has
// written a full default window on one stream, it can still open a second
// stream and write on it. It also checks that overrunning the second stream
// yields the status code mapped from http2.ErrCodeFlowControl, and that the
// server can still read the first stream's data followed by io.EOF.
func TestServerConnDecoupledFromApplicationRead(t *testing.T) {
	serverConfig := &ServerConfig{
		InitialWindowSize:     defaultWindowSize,
		InitialConnWindowSize: defaultWindowSize,
	}
	server, client := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{})
	defer server.stop()
	defer client.Close()
	// Wait for the server to register the connection.
	waitWhileTrue(t, func() (bool, error) {
		server.mu.Lock()
		defer server.mu.Unlock()

		if len(server.conns) == 0 {
			return true, fmt.Errorf("timed-out while waiting for connection to be created on the server")
		}
		return false, nil
	})
	var st *http2Server
	server.mu.Lock()
	for k := range server.conns {
		st = k.(*http2Server)
	}
	server.mu.Unlock()
	cstream1, err := client.NewStream(context.Background(), &CallHdr{Flush: true})
	if err != nil {
		t.Fatalf("Failed to create 1st stream. Err: %v", err)
	}
	// Exhaust server's connection window.
	if err := client.Write(cstream1, nil, make([]byte, defaultWindowSize), &Options{Last: true}); err != nil {
		t.Fatalf("Client failed to write data. Err: %v", err)
	}
	//Client should be able to create another stream and send data on it.
	cstream2, err := client.NewStream(context.Background(), &CallHdr{Flush: true})
	if err != nil {
		t.Fatalf("Failed to create 2nd stream. Err: %v", err)
	}
	if err := client.Write(cstream2, nil, make([]byte, defaultWindowSize), &Options{}); err != nil {
		t.Fatalf("Client failed to write data. Err: %v", err)
	}
	// Get the streams on server.
	waitWhileTrue(t, func() (bool, error) {
		st.mu.Lock()
		defer st.mu.Unlock()

		if len(st.activeStreams) != 2 {
			return true, fmt.Errorf("timed-out while waiting for server to have created the streams")
		}
		return false, nil
	})
	var sstream1 *Stream
	st.mu.Lock()
	for _, v := range st.activeStreams {
		// NOTE(review): hard-codes stream id 1 rather than cstream1.id;
		// presumably the first client stream always gets id 1 — confirm.
		if v.id == 1 {
			sstream1 = v
		}
	}
	st.mu.Unlock()
	// Trying to write more on a max-ed out stream should result in a RST_STREAM from the server.
	ct := client.(*http2Client)
	// The extra byte is queued directly on the control buffer instead of
	// going through client.Write.
	ct.controlBuf.put(&dataFrame{cstream2.id, true, make([]byte, 1), func() {}})
	code := http2ErrConvTab[http2.ErrCodeFlowControl]
	// Poll until the client stream's status carries the flow-control code.
	waitWhileTrue(t, func() (bool, error) {
		cstream2.mu.Lock()
		defer cstream2.mu.Unlock()
		if cstream2.status.Code() != code {
			return true, fmt.Errorf("want code = %v, got %v", code, cstream2.status.Code())
		}
		return false, nil
	})
	// Reading from the stream on server should succeed.
	if _, err := sstream1.Read(make([]byte, defaultWindowSize)); err != nil {
		t.Fatalf("_.Read(_) = %v, want <nil>", err)
	}

	// The first stream was written with Last set, so the next read must hit EOF.
	if _, err := sstream1.Read(make([]byte, 1)); err != io.EOF {
		t.Fatalf("_.Read(_) = %v, want io.EOF", err)
	}

}
| |
// TestServerWithMisbehavedClient has the client push data frames straight onto
// its control buffer until it has sent more than initialWindowSize bytes on
// one stream. It then expects the stream to end with io.EOF and the status
// code mapped from http2.ErrCodeFlowControl.
func TestServerWithMisbehavedClient(t *testing.T) {
	server, ct := setUp(t, 0, math.MaxUint32, suspended)
	callHdr := &CallHdr{
		Host:   "localhost",
		Method: "foo",
	}
	var sc *http2Server
	// Wait until the server transport is setup.
	for {
		server.mu.Lock()
		if len(server.conns) == 0 {
			server.mu.Unlock()
			time.Sleep(time.Millisecond)
			continue
		}
		for k := range server.conns {
			var ok bool
			sc, ok = k.(*http2Server)
			if !ok {
				t.Fatalf("Failed to convert %v to *http2Server", k)
			}
		}
		server.mu.Unlock()
		break
	}
	cc, ok := ct.(*http2Client)
	if !ok {
		t.Fatalf("Failed to convert %v to *http2Client", ct)
	}
	// Test server behavior for violation of stream flow control window size restriction.
	s, err := ct.NewStream(context.Background(), callHdr)
	if err != nil {
		t.Fatalf("Failed to open stream: %v", err)
	}
	// sent counts the bytes queued on the stream so far.
	var sent int
	// Drain the stream flow control window
	cc.controlBuf.put(&dataFrame{s.id, false, make([]byte, http2MaxFrameLen), func() {}})
	sent += http2MaxFrameLen
	// Wait until the server creates the corresponding stream and receive some data.
	var ss *Stream
	for {
		time.Sleep(time.Millisecond)
		sc.mu.Lock()
		if len(sc.activeStreams) == 0 {
			sc.mu.Unlock()
			continue
		}
		ss = sc.activeStreams[s.id]
		sc.mu.Unlock()
		ss.fc.mu.Lock()
		if ss.fc.pendingData > 0 {
			ss.fc.mu.Unlock()
			break
		}
		ss.fc.mu.Unlock()
	}
	// NOTE(review): the flow-control fields below are read without holding
	// ss.fc.mu / sc.fc.mu, unlike the polling loop above — confirm this is
	// race-free.
	if ss.fc.pendingData != http2MaxFrameLen || ss.fc.pendingUpdate != 0 || sc.fc.pendingData != 0 || sc.fc.pendingUpdate != 0 {
		t.Fatalf("Server mistakenly updates inbound flow control params: got %d, %d, %d, %d; want %d, %d, %d, %d", ss.fc.pendingData, ss.fc.pendingUpdate, sc.fc.pendingData, sc.fc.pendingUpdate, http2MaxFrameLen, 0, 0, 0)
	}
	// Keep sending until the server inbound window is drained for that stream.
	for sent <= initialWindowSize {
		cc.controlBuf.put(&dataFrame{s.id, false, make([]byte, 1), func() {}})
		sent++
	}
	// Server sent a resetStream for s already.
	code := http2ErrConvTab[http2.ErrCodeFlowControl]
	if _, err := s.Read(make([]byte, 1)); err != io.EOF {
		t.Fatalf("%v got err %v want <EOF>", s, err)
	}
	if s.status.Code() != code {
		t.Fatalf("%v got status %v; want Code=%v", s, s.status, code)
	}

	ct.CloseStream(s, nil)
	ct.Close()
	server.stop()
}
| |
| func TestClientWithMisbehavedServer(t *testing.T) { |
| // Turn off BDP estimation so that the server can |
| // violate stream window. |
| connectOptions := ConnectOptions{ |
| InitialWindowSize: initialWindowSize, |
| } |
| server, ct := setUpWithOptions(t, 0, &ServerConfig{}, misbehaved, connectOptions) |
| callHdr := &CallHdr{ |
| Host: "localhost", |
| Method: "foo.Stream", |
| } |
| conn, ok := ct.(*http2Client) |
| if !ok { |
| t.Fatalf("Failed to convert %v to *http2Client", ct) |
| } |
| // Test the logic for the violation of stream flow control window size restriction. |
| s, err := ct.NewStream(context.Background(), callHdr) |
| if err != nil { |
| t.Fatalf("Failed to open stream: %v", err) |
| } |
| d := make([]byte, 1) |
| if err := ct.Write(s, nil, d, &Options{Last: true, Delay: false}); err != nil && err != io.EOF { |
| t.Fatalf("Failed to write: %v", err) |
| } |
| // Read without window update. |
| for { |
| p := make([]byte, http2MaxFrameLen) |
| if _, err = s.trReader.(*transportReader).reader.Read(p); err != nil { |
| break |
| } |
| } |
| if s.fc.pendingData <= initialWindowSize || s.fc.pendingUpdate != 0 || conn.fc.pendingData != 0 || conn.fc.pendingUpdate != 0 { |
| t.Fatalf("Client mistakenly updates inbound flow control params: got %d, %d, %d, %d; want >%d, %d, %d, >%d", s.fc.pendingData, s.fc.pendingUpdate, conn.fc.pendingData, conn.fc.pendingUpdate, initialWindowSize, 0, 0, 0) |
| } |
| |
| if err != io.EOF { |
| t.Fatalf("Got err %v, want <EOF>", err) |
| } |
| if s.status.Code() != codes.Internal { |
| t.Fatalf("Got s.status %v, want s.status.Code()=Internal", s.status) |
| } |
| |
| conn.CloseStream(s, err) |
| ct.Close() |
| server.stop() |
| } |
| |
// encodingTestStatus is the status TestEncodingRequiredStatus expects to read
// back from the stream; its message is a single newline character.
var encodingTestStatus = status.New(codes.Internal, "\n")
| |
| func TestEncodingRequiredStatus(t *testing.T) { |
| server, ct := setUp(t, 0, math.MaxUint32, encodingRequiredStatus) |
| callHdr := &CallHdr{ |
| Host: "localhost", |
| Method: "foo", |
| } |
| s, err := ct.NewStream(context.Background(), callHdr) |
| if err != nil { |
| return |
| } |
| opts := Options{ |
| Last: true, |
| Delay: false, |
| } |
| if err := ct.Write(s, nil, expectedRequest, &opts); err != nil && err != io.EOF { |
| t.Fatalf("Failed to write the request: %v", err) |
| } |
| p := make([]byte, http2MaxFrameLen) |
| if _, err := s.trReader.(*transportReader).Read(p); err != io.EOF { |
| t.Fatalf("Read got error %v, want %v", err, io.EOF) |
| } |
| if !reflect.DeepEqual(s.Status(), encodingTestStatus) { |
| t.Fatalf("stream with status %v, want %v", s.Status(), encodingTestStatus) |
| } |
| ct.Close() |
| server.stop() |
| } |
| |
| func TestInvalidHeaderField(t *testing.T) { |
| server, ct := setUp(t, 0, math.MaxUint32, invalidHeaderField) |
| callHdr := &CallHdr{ |
| Host: "localhost", |
| Method: "foo", |
| } |
| s, err := ct.NewStream(context.Background(), callHdr) |
| if err != nil { |
| return |
| } |
| opts := Options{ |
| Last: true, |
| Delay: false, |
| } |
| if err := ct.Write(s, nil, expectedRequest, &opts); err != nil && err != io.EOF { |
| t.Fatalf("Failed to write the request: %v", err) |
| } |
| p := make([]byte, http2MaxFrameLen) |
| _, err = s.trReader.(*transportReader).Read(p) |
| if se, ok := err.(StreamError); !ok || se.Code != codes.FailedPrecondition || !strings.Contains(err.Error(), expectedInvalidHeaderField) { |
| t.Fatalf("Read got error %v, want error with code %s and contains %q", err, codes.FailedPrecondition, expectedInvalidHeaderField) |
| } |
| ct.Close() |
| server.stop() |
| } |
| |
| func TestStreamContext(t *testing.T) { |
| expectedStream := &Stream{} |
| ctx := newContextWithStream(context.Background(), expectedStream) |
| s, ok := StreamFromContext(ctx) |
| if !ok || expectedStream != s { |
| t.Fatalf("GetStreamFromContext(%v) = %v, %t, want: %v, true", ctx, s, ok, expectedStream) |
| } |
| } |
| |
| func TestIsReservedHeader(t *testing.T) { |
| tests := []struct { |
| h string |
| want bool |
| }{ |
| {"", false}, // but should be rejected earlier |
| {"foo", false}, |
| {"content-type", true}, |
| {"grpc-message-type", true}, |
| {"grpc-encoding", true}, |
| {"grpc-message", true}, |
| {"grpc-status", true}, |
| {"grpc-timeout", true}, |
| {"te", true}, |
| } |
| for _, tt := range tests { |
| got := isReservedHeader(tt.h) |
| if got != tt.want { |
| t.Errorf("isReservedHeader(%q) = %v; want %v", tt.h, got, tt.want) |
| } |
| } |
| } |
| |
| func TestContextErr(t *testing.T) { |
| for _, test := range []struct { |
| // input |
| errIn error |
| // outputs |
| errOut StreamError |
| }{ |
| {context.DeadlineExceeded, StreamError{codes.DeadlineExceeded, context.DeadlineExceeded.Error()}}, |
| {context.Canceled, StreamError{codes.Canceled, context.Canceled.Error()}}, |
| } { |
| err := ContextErr(test.errIn) |
| if err != test.errOut { |
| t.Fatalf("ContextErr{%v} = %v \nwant %v", test.errIn, err, test.errOut) |
| } |
| } |
| } |
| |
// max returns the larger of a and b.
func max(a, b int32) int32 {
	if b >= a {
		return b
	}
	return a
}
| |
// windowSizeConfig groups the initial flow-control window sizes that
// testAccountCheckWindowSize applies to the server and client transports.
type windowSizeConfig struct {
	// Stream-level and connection-level window sizes for the server.
	serverStream, serverConn int32
	// Stream-level and connection-level window sizes for the client.
	clientStream, clientConn int32
}
| |
| func TestAccountCheckWindowSizeWithLargeWindow(t *testing.T) { |
| wc := windowSizeConfig{ |
| serverStream: 10 * 1024 * 1024, |
| serverConn: 12 * 1024 * 1024, |
| clientStream: 6 * 1024 * 1024, |
| clientConn: 8 * 1024 * 1024, |
| } |
| testAccountCheckWindowSize(t, wc) |
| } |
| |
| func TestAccountCheckWindowSizeWithSmallWindow(t *testing.T) { |
| wc := windowSizeConfig{ |
| serverStream: defaultWindowSize, |
| // Note this is smaller than initialConnWindowSize which is the current default. |
| serverConn: defaultWindowSize, |
| clientStream: defaultWindowSize, |
| clientConn: defaultWindowSize, |
| } |
| testAccountCheckWindowSize(t, wc) |
| } |
| |
| func testAccountCheckWindowSize(t *testing.T, wc windowSizeConfig) { |
| serverConfig := &ServerConfig{ |
| InitialWindowSize: wc.serverStream, |
| InitialConnWindowSize: wc.serverConn, |
| } |
| connectOptions := ConnectOptions{ |
| InitialWindowSize: wc.clientStream, |
| InitialConnWindowSize: wc.clientConn, |
| } |
| server, client := setUpWithOptions(t, 0, serverConfig, suspended, connectOptions) |
| defer server.stop() |
| defer client.Close() |
| |
| // Wait for server conns to be populated with new server transport. |
| waitWhileTrue(t, func() (bool, error) { |
| server.mu.Lock() |
| defer server.mu.Unlock() |
| if len(server.conns) == 0 { |
| return true, fmt.Errorf("timed out waiting for server transport to be created") |
| } |
| return false, nil |
| }) |
| var st *http2Server |
| server.mu.Lock() |
| for k := range server.conns { |
| st = k.(*http2Server) |
| } |
| server.mu.Unlock() |
| ct := client.(*http2Client) |
| cstream, err := client.NewStream(context.Background(), &CallHdr{Flush: true}) |
| if err != nil { |
| t.Fatalf("Failed to create stream. Err: %v", err) |
| } |
| // Wait for server to receive headers. |
| waitWhileTrue(t, func() (bool, error) { |
| st.mu.Lock() |
| defer st.mu.Unlock() |
| if len(st.activeStreams) == 0 { |
| return true, fmt.Errorf("timed out waiting for server to receive headers") |
| } |
| return false, nil |
| }) |
| // Sleeping to make sure the settings are applied in case of negative test. |
| time.Sleep(time.Second) |
| |
| waitWhileTrue(t, func() (bool, error) { |
| st.fc.mu.Lock() |
| lim := st.fc.limit |
| st.fc.mu.Unlock() |
| if lim != uint32(serverConfig.InitialConnWindowSize) { |
| return true, fmt.Errorf("Server transport flow control window size: got %v, want %v", lim, serverConfig.InitialConnWindowSize) |
| } |
| return false, nil |
| }) |
| |
| ctx, cancel := context.WithTimeout(context.Background(), time.Second) |
| serverSendQuota, err := wait(ctx, context.Background(), nil, nil, st.sendQuotaPool.acquire()) |
| if err != nil { |
| t.Fatalf("Error while acquiring sendQuota on server. Err: %v", err) |
| } |
| cancel() |
| st.sendQuotaPool.add(serverSendQuota) |
| if serverSendQuota != int(connectOptions.InitialConnWindowSize) { |
| t.Fatalf("Server send quota(%v) not equal to client's window size(%v) on conn.", serverSendQuota, connectOptions.InitialConnWindowSize) |
| } |
| st.mu.Lock() |
| ssq := st.streamSendQuota |
| st.mu.Unlock() |
| if ssq != uint32(connectOptions.InitialWindowSize) { |
| t.Fatalf("Server stream send quota(%v) not equal to client's window size(%v) on stream.", ssq, connectOptions.InitialWindowSize) |
| } |
| ct.fc.mu.Lock() |
| limit := ct.fc.limit |
| ct.fc.mu.Unlock() |
| if limit != uint32(connectOptions.InitialConnWindowSize) { |
| t.Fatalf("Client transport flow control window size is %v, want %v", limit, connectOptions.InitialConnWindowSize) |
| } |
| ctx, cancel = context.WithTimeout(context.Background(), time.Second) |
| clientSendQuota, err := wait(ctx, context.Background(), nil, nil, ct.sendQuotaPool.acquire()) |
| if err != nil { |
| t.Fatalf("Error while acquiring sendQuota on client. Err: %v", err) |
| } |
| cancel() |
| ct.sendQuotaPool.add(clientSendQuota) |
| if clientSendQuota != int(serverConfig.InitialConnWindowSize) { |
| t.Fatalf("Client send quota(%v) not equal to server's window size(%v) on conn.", clientSendQuota, serverConfig.InitialConnWindowSize) |
| } |
| ct.mu.Lock() |
| ssq = ct.streamSendQuota |
| ct.mu.Unlock() |
| if ssq != uint32(serverConfig.InitialWindowSize) { |
| t.Fatalf("Client stream send quota(%v) not equal to server's window size(%v) on stream.", ssq, serverConfig.InitialWindowSize) |
| } |
| cstream.fc.mu.Lock() |
| limit = cstream.fc.limit |
| cstream.fc.mu.Unlock() |
| if limit != uint32(connectOptions.InitialWindowSize) { |
| t.Fatalf("Client stream flow control window size is %v, want %v", limit, connectOptions.InitialWindowSize) |
| } |
| var sstream *Stream |
| st.mu.Lock() |
| for _, v := range st.activeStreams { |
| sstream = v |
| } |
| st.mu.Unlock() |
| sstream.fc.mu.Lock() |
| limit = sstream.fc.limit |
| sstream.fc.mu.Unlock() |
| if limit != uint32(serverConfig.InitialWindowSize) { |
| t.Fatalf("Server stream flow control window size is %v, want %v", limit, serverConfig.InitialWindowSize) |
| } |
| } |
| |
| // Check accounting on both sides after sending and receiving large messages. |
| func TestAccountCheckExpandingWindow(t *testing.T) { |
| server, client := setUp(t, 0, 0, pingpong) |
| defer server.stop() |
| defer client.Close() |
| waitWhileTrue(t, func() (bool, error) { |
| server.mu.Lock() |
| defer server.mu.Unlock() |
| if len(server.conns) == 0 { |
| return true, fmt.Errorf("timed out while waiting for server transport to be created") |
| } |
| return false, nil |
| }) |
| var st *http2Server |
| server.mu.Lock() |
| for k := range server.conns { |
| st = k.(*http2Server) |
| } |
| server.mu.Unlock() |
| ct := client.(*http2Client) |
| cstream, err := client.NewStream(context.Background(), &CallHdr{Flush: true}) |
| if err != nil { |
| t.Fatalf("Failed to create stream. Err: %v", err) |
| } |
| |
| msgSize := 65535 * 16 * 2 |
| msg := make([]byte, msgSize) |
| buf := make([]byte, msgSize+5) |
| buf[0] = byte(0) |
| binary.BigEndian.PutUint32(buf[1:], uint32(msgSize)) |
| copy(buf[5:], msg) |
| opts := Options{} |
| header := make([]byte, 5) |
| for i := 1; i <= 10; i++ { |
| if err := ct.Write(cstream, nil, buf, &opts); err != nil { |
| t.Fatalf("Error on client while writing message: %v", err) |
| } |
| if _, err := cstream.Read(header); err != nil { |
| t.Fatalf("Error on client while reading data frame header: %v", err) |
| } |
| sz := binary.BigEndian.Uint32(header[1:]) |
| recvMsg := make([]byte, int(sz)) |
| if _, err := cstream.Read(recvMsg); err != nil { |
| t.Fatalf("Error on client while reading data: %v", err) |
| } |
| if len(recvMsg) != len(msg) { |
| t.Fatalf("Length of message received by client: %v, want: %v", len(recvMsg), len(msg)) |
| } |
| } |
| var sstream *Stream |
| st.mu.Lock() |
| for _, v := range st.activeStreams { |
| sstream = v |
| } |
| st.mu.Unlock() |
| |
| waitWhileTrue(t, func() (bool, error) { |
| // Check that pendingData and delta on flow control windows on both sides are 0. |
| cstream.fc.mu.Lock() |
| if cstream.fc.delta != 0 { |
| cstream.fc.mu.Unlock() |
| return true, fmt.Errorf("delta on flow control window of client stream is non-zero") |
| } |
| if cstream.fc.pendingData != 0 { |
| cstream.fc.mu.Unlock() |
| return true, fmt.Errorf("pendingData on flow control window of client stream is non-zero") |
| } |
| cstream.fc.mu.Unlock() |
| sstream.fc.mu.Lock() |
| if sstream.fc.delta != 0 { |
| sstream.fc.mu.Unlock() |
| return true, fmt.Errorf("delta on flow control window of server stream is non-zero") |
| } |
| if sstream.fc.pendingData != 0 { |
| sstream.fc.mu.Unlock() |
| return true, fmt.Errorf("pendingData on flow control window of sercer stream is non-zero") |
| } |
| sstream.fc.mu.Unlock() |
| ct.fc.mu.Lock() |
| if ct.fc.delta != 0 { |
| ct.fc.mu.Unlock() |
| return true, fmt.Errorf("delta on flow control window of client transport is non-zero") |
| } |
| if ct.fc.pendingData != 0 { |
| ct.fc.mu.Unlock() |
| return true, fmt.Errorf("pendingData on flow control window of client transport is non-zero") |
| } |
| ct.fc.mu.Unlock() |
| st.fc.mu.Lock() |
| if st.fc.delta != 0 { |
| st.fc.mu.Unlock() |
| return true, fmt.Errorf("delta on flow control window of server transport is non-zero") |
| } |
| if st.fc.pendingData != 0 { |
| st.fc.mu.Unlock() |
| return true, fmt.Errorf("pendingData on flow control window of server transport is non-zero") |
| } |
| st.fc.mu.Unlock() |
| |
| // Check flow conrtrol window on client stream is equal to out flow on server stream. |
| ctx, cancel := context.WithTimeout(context.Background(), time.Second) |
| serverStreamSendQuota, err := wait(ctx, context.Background(), nil, nil, sstream.sendQuotaPool.acquire()) |
| cancel() |
| if err != nil { |
| return true, fmt.Errorf("error while acquiring server stream send quota. Err: %v", err) |
| } |
| sstream.sendQuotaPool.add(serverStreamSendQuota) |
| cstream.fc.mu.Lock() |
| clientEst := cstream.fc.limit - cstream.fc.pendingUpdate |
| cstream.fc.mu.Unlock() |
| if uint32(serverStreamSendQuota) != clientEst { |
| return true, fmt.Errorf("server stream outflow: %v, estimated by client: %v", serverStreamSendQuota, clientEst) |
| } |
| |
| // Check flow control window on server stream is equal to out flow on client stream. |
| ctx, cancel = context.WithTimeout(context.Background(), time.Second) |
| clientStreamSendQuota, err := wait(ctx, context.Background(), nil, nil, cstream.sendQuotaPool.acquire()) |
| cancel() |
| if err != nil { |
| return true, fmt.Errorf("error while acquiring client stream send quota. Err: %v", err) |
| } |
| cstream.sendQuotaPool.add(clientStreamSendQuota) |
| sstream.fc.mu.Lock() |
| serverEst := sstream.fc.limit - sstream.fc.pendingUpdate |
| sstream.fc.mu.Unlock() |
| if uint32(clientStreamSendQuota) != serverEst { |
| return true, fmt.Errorf("client stream outflow: %v. estimated by server: %v", clientStreamSendQuota, serverEst) |
| } |
| |
| // Check flow control window on client transport is equal to out flow of server transport. |
| ctx, cancel = context.WithTimeout(context.Background(), time.Second) |
| serverTrSendQuota, err := wait(ctx, context.Background(), nil, nil, st.sendQuotaPool.acquire()) |
| cancel() |
| if err != nil { |
| return true, fmt.Errorf("error while acquring server transport send quota. Err: %v", err) |
| } |
| st.sendQuotaPool.add(serverTrSendQuota) |
| ct.fc.mu.Lock() |
| clientEst = ct.fc.limit - ct.fc.pendingUpdate |
| ct.fc.mu.Unlock() |
| if uint32(serverTrSendQuota) != clientEst { |
| return true, fmt.Errorf("server transport outflow: %v, estimated by client: %v", serverTrSendQuota, clientEst) |
| } |
| |
| // Check flow control window on server transport is equal to out flow of client transport. |
| ctx, cancel = context.WithTimeout(context.Background(), time.Second) |
| clientTrSendQuota, err := wait(ctx, context.Background(), nil, nil, ct.sendQuotaPool.acquire()) |
| cancel() |
| if err != nil { |
| return true, fmt.Errorf("error while acquiring client transport send quota. Err: %v", err) |
| } |
| ct.sendQuotaPool.add(clientTrSendQuota) |
| st.fc.mu.Lock() |
| serverEst = st.fc.limit - st.fc.pendingUpdate |
| st.fc.mu.Unlock() |
| if uint32(clientTrSendQuota) != serverEst { |
| return true, fmt.Errorf("client transport outflow: %v, estimated by client: %v", clientTrSendQuota, serverEst) |
| } |
| |
| return false, nil |
| }) |
| |
| } |
| |
// waitWhileTrue polls condition every 50 milliseconds until it reports false.
// If condition still reports true once five seconds have elapsed, the test is
// failed with the error condition returned on that call.
//
// The error is passed as an argument to a "%v" verb rather than used as the
// format string, so a '%' in its text is printed literally and a nil error
// cannot cause a panic.
func waitWhileTrue(t *testing.T, condition func() (bool, error)) {
	timer := time.NewTimer(5 * time.Second)
	// The timer is local, so stopping it on exit is all the cleanup needed.
	defer timer.Stop()
	for {
		wait, err := condition()
		if !wait {
			return
		}
		select {
		case <-timer.C:
			t.Fatalf("%v", err)
		default:
			time.Sleep(50 * time.Millisecond)
		}
	}
}
| |
| // A function of type writeHeaders writes out |
| // http status with the given stream ID using the given framer. |
| type writeHeaders func(*http2.Framer, uint32, int) error |
| |
| func writeOneHeader(framer *http2.Framer, sid uint32, httpStatus int) error { |
| var buf bytes.Buffer |
| henc := hpack.NewEncoder(&buf) |
| henc.WriteField(hpack.HeaderField{Name: ":status", Value: fmt.Sprint(httpStatus)}) |
| return framer.WriteHeaders(http2.HeadersFrameParam{ |
| StreamID: sid, |
| BlockFragment: buf.Bytes(), |
| EndStream: true, |
| EndHeaders: true, |
| }) |
| } |
| |
| func writeTwoHeaders(framer *http2.Framer, sid uint32, httpStatus int) error { |
| var buf bytes.Buffer |
| henc := hpack.NewEncoder(&buf) |
| henc.WriteField(hpack.HeaderField{ |
| Name: ":status", |
| Value: fmt.Sprint(http.StatusOK), |
| }) |
| if err := framer.WriteHeaders(http2.HeadersFrameParam{ |
| StreamID: sid, |
| BlockFragment: buf.Bytes(), |
| EndHeaders: true, |
| }); err != nil { |
| return err |
| } |
| buf.Reset() |
| henc.WriteField(hpack.HeaderField{ |
| Name: ":status", |
| Value: fmt.Sprint(httpStatus), |
| }) |
| return framer.WriteHeaders(http2.HeadersFrameParam{ |
| StreamID: sid, |
| BlockFragment: buf.Bytes(), |
| EndStream: true, |
| EndHeaders: true, |
| }) |
| } |
| |
// httpServer is a minimal test server: its start method accepts a single
// connection and answers the first HEADERS frame by calling wh with httpStatus.
type httpServer struct {
	// conn is the connection accepted by start; it is nil until then.
	conn net.Conn
	// httpStatus is the status code handed to wh.
	httpStatus int
	// wh writes the response header frame(s).
	wh writeHeaders
}
| |
| func (s *httpServer) start(t *testing.T, lis net.Listener) { |
| // Launch an HTTP server to send back header with httpStatus. |
| go func() { |
| var err error |
| s.conn, err = lis.Accept() |
| if err != nil { |
| t.Errorf("Error accepting connection: %v", err) |
| return |
| } |
| defer s.conn.Close() |
| // Read preface sent by client. |
| if _, err = io.ReadFull(s.conn, make([]byte, len(http2.ClientPreface))); err != nil { |
| t.Errorf("Error at server-side while reading preface from cleint. Err: %v", err) |
| return |
| } |
| reader := bufio.NewReaderSize(s.conn, defaultWriteBufSize) |
| writer := bufio.NewWriterSize(s.conn, defaultReadBufSize) |
| framer := http2.NewFramer(writer, reader) |
| if err = framer.WriteSettingsAck(); err != nil { |
| t.Errorf("Error at server-side while sending Settings ack. Err: %v", err) |
| return |
| } |
| var sid uint32 |
| // Read frames until a header is received. |
| for { |
| frame, err := framer.ReadFrame() |
| if err != nil { |
| t.Errorf("Error at server-side while reading frame. Err: %v", err) |
| return |
| } |
| if hframe, ok := frame.(*http2.HeadersFrame); ok { |
| sid = hframe.Header().StreamID |
| break |
| } |
| } |
| if err = s.wh(framer, sid, s.httpStatus); err != nil { |
| t.Errorf("Error at server-side while writing headers. Err: %v", err) |
| return |
| } |
| writer.Flush() |
| }() |
| } |
| |
| func (s *httpServer) cleanUp() { |
| if s.conn != nil { |
| s.conn.Close() |
| } |
| } |
| |
| func setUpHTTPStatusTest(t *testing.T, httpStatus int, wh writeHeaders) (stream *Stream, cleanUp func()) { |
| var ( |
| err error |
| lis net.Listener |
| server *httpServer |
| client ClientTransport |
| ) |
| cleanUp = func() { |
| if lis != nil { |
| lis.Close() |
| } |
| if server != nil { |
| server.cleanUp() |
| } |
| if client != nil { |
| client.Close() |
| } |
| } |
| defer func() { |
| if err != nil { |
| cleanUp() |
| } |
| }() |
| lis, err = net.Listen("tcp", "localhost:0") |
| if err != nil { |
| t.Fatalf("Failed to listen. Err: %v", err) |
| } |
| server = &httpServer{ |
| httpStatus: httpStatus, |
| wh: wh, |
| } |
| server.start(t, lis) |
| client, err = newHTTP2Client(context.Background(), TargetInfo{Addr: lis.Addr().String()}, ConnectOptions{}, 2*time.Second) |
| if err != nil { |
| t.Fatalf("Error creating client. Err: %v", err) |
| } |
| stream, err = client.NewStream(context.Background(), &CallHdr{Method: "bogus/method", Flush: true}) |
| if err != nil { |
| t.Fatalf("Error creating stream at client-side. Err: %v", err) |
| } |
| return |
| } |
| |
| func TestHTTPToGRPCStatusMapping(t *testing.T) { |
| for k := range httpStatusConvTab { |
| testHTTPToGRPCStatusMapping(t, k, writeOneHeader) |
| } |
| } |
| |
| func testHTTPToGRPCStatusMapping(t *testing.T, httpStatus int, wh writeHeaders) { |
| stream, cleanUp := setUpHTTPStatusTest(t, httpStatus, wh) |
| defer cleanUp() |
| want := httpStatusConvTab[httpStatus] |
| buf := make([]byte, 8) |
| _, err := stream.Read(buf) |
| if err == nil { |
| t.Fatalf("Stream.Read(_) unexpectedly returned no error. Expected stream error with code %v", want) |
| } |
| serr, ok := err.(StreamError) |
| if !ok { |
| t.Fatalf("err.(Type) = %T, want StreamError", err) |
| } |
| if want != serr.Code { |
| t.Fatalf("Want error code: %v, got: %v", want, serr.Code) |
| } |
| } |
| |
| func TestHTTPStatusOKAndMissingGRPCStatus(t *testing.T) { |
| stream, cleanUp := setUpHTTPStatusTest(t, http.StatusOK, writeOneHeader) |
| defer cleanUp() |
| buf := make([]byte, 8) |
| _, err := stream.Read(buf) |
| if err != io.EOF { |
| t.Fatalf("stream.Read(_) = _, %v, want _, io.EOF", err) |
| } |
| want := codes.Unknown |
| stream.mu.Lock() |
| defer stream.mu.Unlock() |
| if stream.status.Code() != want { |
| t.Fatalf("Status code of stream: %v, want: %v", stream.status.Code(), want) |
| } |
| } |
| |
| func TestHTTPStatusNottOKAndMissingGRPCStatusInSecondHeader(t *testing.T) { |
| testHTTPToGRPCStatusMapping(t, http.StatusUnauthorized, writeTwoHeaders) |
| } |
| |
| // If any error occurs on a call to Stream.Read, future calls |
| // should continue to return that same error. |
| func TestReadGivesSameErrorAfterAnyErrorOccurs(t *testing.T) { |
| testRecvBuffer := newRecvBuffer() |
| s := &Stream{ |
| ctx: context.Background(), |
| goAway: make(chan struct{}), |
| buf: testRecvBuffer, |
| requestRead: func(int) {}, |
| } |
| s.trReader = &transportReader{ |
| reader: &recvBufferReader{ |
| ctx: s.ctx, |
| goAway: s.goAway, |
| recv: s.buf, |
| }, |
| windowHandler: func(int) {}, |
| } |
| testData := make([]byte, 1) |
| testData[0] = 5 |
| testErr := errors.New("test error") |
| s.write(recvMsg{data: testData, err: testErr}) |
| |
| inBuf := make([]byte, 1) |
| actualCount, actualErr := s.Read(inBuf) |
| if actualCount != 0 { |
| t.Errorf("actualCount, _ := s.Read(_) differs; want 0; got %v", actualCount) |
| } |
| if actualErr.Error() != testErr.Error() { |
| t.Errorf("_ , actualErr := s.Read(_) differs; want actualErr.Error() to be %v; got %v", testErr.Error(), actualErr.Error()) |
| } |
| |
| s.write(recvMsg{data: testData, err: nil}) |
| s.write(recvMsg{data: testData, err: errors.New("different error from first")}) |
| |
| for i := 0; i < 2; i++ { |
| inBuf := make([]byte, 1) |
| actualCount, actualErr := s.Read(inBuf) |
| if actualCount != 0 { |
| t.Errorf("actualCount, _ := s.Read(_) differs; want %v; got %v", 0, actualCount) |
| } |
| if actualErr.Error() != testErr.Error() { |
| t.Errorf("_ , actualErr := s.Read(_) differs; want actualErr.Error() to be %v; got %v", testErr.Error(), actualErr.Error()) |
| } |
| } |
| } |