blob: f30ebc6d6332483f068885909ade2ce363740c74 [file] [log] [blame]
/*
*
* Copyright 2014 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package transport
import (
"bufio"
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"math"
"net"
"net/http"
"reflect"
"strconv"
"strings"
"sync"
"testing"
"time"
"golang.org/x/net/context"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/status"
)
// server is the test-side transport server: it owns the listener and keeps
// track of the server transports created for the connections it accepts.
type server struct {
// lis is the listener created by start and closed by stop.
lis net.Listener
// port is the port component of the listener address, filled in by start.
port string
startedErr chan error // error (or nil) with server start value
// mu is held by start and stop while they touch conns and h.
mu sync.Mutex
// conns holds every transport created so far; stop sets it to nil, which
// start treats as the signal to close a freshly created transport and exit.
conns map[ServerTransport]bool
// h is the stream handler built for the most recently accepted connection.
h *testStreamHandler
}
// Payloads and header values shared by the stream handlers and the tests.
var (
// expectedRequest and expectedResponse are the small request/response pair.
expectedRequest = []byte("ping")
expectedResponse = []byte("pong")
// The large pair is zero-filled and twice initialWindowSize bytes long; the
// handlers select it for streams whose method is "foo.Large".
expectedRequestLarge = make([]byte, initialWindowSize*2)
expectedResponseLarge = make([]byte, initialWindowSize*2)
// expectedInvalidHeaderField is the content-type value that
// handleStreamInvalidHeaderField sends back to the client.
expectedInvalidHeaderField = "invalid/content-type"
)
// testStreamHandler groups the per-stream handlers used by the test server
// together with the server transport they write to.
type testStreamHandler struct {
// t is the server transport used to write data, status and raw frames.
t *http2Server
// notify, when non-nil, is closed by handleStreamAndNotify.
notify chan struct{}
}
// hType selects which stream handler server.start installs on each accepted
// connection.
type hType int
const (
// normal runs handleStream (it is handled by the default branch in start).
normal hType = iota
// suspended installs a handler that does nothing with the stream.
suspended
// notifyCall runs handleStreamAndNotify.
notifyCall
// misbehaved runs handleStreamMisbehave.
misbehaved
// encodingRequiredStatus runs handleStreamEncodingRequiredStatus.
encodingRequiredStatus
// invalidHeaderField runs handleStreamInvalidHeaderField.
invalidHeaderField
// delayRead runs handleStreamDelayRead.
delayRead
// delayWrite runs handleStreamDelayWrite.
delayWrite
// pingpong runs handleStreamPingPong.
pingpong
)
// handleStreamAndNotify signals the arrival of a stream by closing h.notify.
// It does nothing when no notification channel is installed. The close is
// performed on a separate goroutine and is skipped when a receive on the
// channel already succeeds.
func (h *testStreamHandler) handleStreamAndNotify(s *Stream) {
	if h.notify != nil {
		go func() {
			select {
			case <-h.notify:
				// Receive succeeded, so there is nothing left to close.
			default:
				close(h.notify)
			}
		}()
	}
}
// handleStream reads one request from s, checks it against the expected
// payload and answers with the matching response followed by an OK status.
// Streams whose method is "foo.Large" use the large request/response pair.
//
// server.start invokes this handler on its own goroutine, so a mismatch is
// reported with t.Errorf followed by a return: t.Fatalf may only be called
// from the goroutine running the test function.
func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) {
	req := expectedRequest
	resp := expectedResponse
	if s.Method() == "foo.Large" {
		req = expectedRequestLarge
		resp = expectedResponseLarge
	}
	p := make([]byte, len(req))
	if _, err := s.Read(p); err != nil {
		// A failed read is not reported; the handler simply gives up.
		return
	}
	if !bytes.Equal(p, req) {
		t.Errorf("handleStream got %v, want %v", p, req)
		return
	}
	// send a response back to the client.
	h.t.Write(s, nil, resp, &Options{})
	// send the trailer to end the stream.
	h.t.WriteStatus(s, status.New(codes.OK, ""))
}
// handleStreamPingPong echoes ten length-prefixed messages back to the client.
// Each message is read as a 5-byte header (one byte followed by a 4-byte
// big-endian payload length) and then the payload itself; the reply carries a
// zero first byte, the same length and the same payload.
//
// server.start invokes this handler on its own goroutine, so read failures
// are reported with t.Errorf followed by a return: t.Fatalf may only be called
// from the goroutine running the test function.
func (h *testStreamHandler) handleStreamPingPong(t *testing.T, s *Stream) {
	header := make([]byte, 5)
	for i := 0; i < 10; i++ {
		if _, err := s.Read(header); err != nil {
			t.Errorf("Error on server while reading data header: %v", err)
			return
		}
		sz := binary.BigEndian.Uint32(header[1:])
		msg := make([]byte, int(sz))
		if _, err := s.Read(msg); err != nil {
			t.Errorf("Error on server while reading message: %v", err)
			return
		}
		// Size the reply from the payload slice so the arithmetic is done in
		// int rather than in uint32.
		buf := make([]byte, 5+len(msg))
		buf[0] = byte(0)
		binary.BigEndian.PutUint32(buf[1:], sz)
		copy(buf[5:], msg)
		h.t.Write(s, nil, buf, &Options{})
	}
}
// handleStreamMisbehave queues data frames directly on the server transport's
// control buffer until at least initialWindowSize bytes have been queued.
// Frames are http2MaxFrameLen bytes long except for the last one: for method
// "foo.Connection" it is sized to land exactly on initialWindowSize, otherwise
// it is one byte larger so that the total exceeds initialWindowSize.
//
// server.start invokes this handler on its own goroutine, so a failed type
// assertion is reported with t.Errorf followed by a return: t.Fatalf may only
// be called from the goroutine running the test function.
func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream) {
	conn, ok := s.ServerTransport().(*http2Server)
	if !ok {
		t.Errorf("Failed to convert %v to *http2Server", s.ServerTransport())
		return
	}
	var sent int
	p := make([]byte, http2MaxFrameLen)
	for sent < initialWindowSize {
		n := initialWindowSize - sent
		// The last message may be smaller than http2MaxFrameLen
		if n <= http2MaxFrameLen {
			if s.Method() == "foo.Connection" {
				// Violate connection level flow control window of client but do not
				// violate any stream level windows.
				p = make([]byte, n)
			} else {
				// Violate stream level flow control window of client.
				p = make([]byte, n+1)
			}
		}
		conn.controlBuf.put(&dataFrame{s.id, false, p, func() {}})
		sent += len(p)
	}
}
// handleStreamEncodingRequiredStatus ends the stream with encodingTestStatus.
// The http2 framer does not accept a raw newline, so the status has to be
// encoded before it is written.
func (h *testStreamHandler) handleStreamEncodingRequiredStatus(t *testing.T, s *Stream) {
	st := encodingTestStatus
	h.t.WriteStatus(s, st)
}
// handleStreamInvalidHeaderField queues a header frame for s whose only field
// is a content-type set to expectedInvalidHeaderField. The frame does not end
// the stream.
func (h *testStreamHandler) handleStreamInvalidHeaderField(t *testing.T, s *Stream) {
	fields := []hpack.HeaderField{
		{Name: "content-type", Value: expectedInvalidHeaderField},
	}
	frame := &headerFrame{
		streamID:  s.id,
		hf:        fields,
		endStream: false,
	}
	h.t.controlBuf.put(frame)
}
// handleStreamDelayRead behaves like handleStream but sleeps for two seconds
// before reading the request, giving the client time to start sending first.
// Streams whose method is "foo.Large" use the large request/response pair.
//
// server.start invokes this handler on its own goroutine, so failures are
// reported with t.Errorf followed by a return: t.Fatalf may only be called
// from the goroutine running the test function.
func (h *testStreamHandler) handleStreamDelayRead(t *testing.T, s *Stream) {
	req := expectedRequest
	resp := expectedResponse
	if s.Method() == "foo.Large" {
		req = expectedRequestLarge
		resp = expectedResponseLarge
	}
	p := make([]byte, len(req))
	// Wait before reading. Give time to client to start sending
	// before server starts reading.
	time.Sleep(2 * time.Second)
	if _, err := s.Read(p); err != nil {
		t.Errorf("s.Read(_) = _, %v, want _, <nil>", err)
		return
	}
	if !bytes.Equal(p, req) {
		t.Errorf("handleStream got %v, want %v", p, req)
		return
	}
	// send a response back to the client.
	h.t.Write(s, nil, resp, &Options{})
	// send the trailer to end the stream.
	h.t.WriteStatus(s, status.New(codes.OK, ""))
}
// handleStreamDelayWrite behaves like handleStream but sleeps for two seconds
// between reading the request and writing the response, giving the client time
// to start reading first. Streams whose method is "foo.Large" use the large
// request/response pair.
//
// server.start invokes this handler on its own goroutine, so failures are
// reported with t.Errorf followed by a return: t.Fatalf may only be called
// from the goroutine running the test function.
func (h *testStreamHandler) handleStreamDelayWrite(t *testing.T, s *Stream) {
	req := expectedRequest
	resp := expectedResponse
	if s.Method() == "foo.Large" {
		req = expectedRequestLarge
		resp = expectedResponseLarge
	}
	p := make([]byte, len(req))
	if _, err := s.Read(p); err != nil {
		t.Errorf("s.Read(_) = _, %v, want _, <nil>", err)
		return
	}
	if !bytes.Equal(p, req) {
		t.Errorf("handleStream got %v, want %v", p, req)
		return
	}
	// Wait before sending. Give time to client to start reading
	// before server starts sending.
	time.Sleep(2 * time.Second)
	h.t.Write(s, nil, resp, &Options{})
	// send the trailer to end the stream.
	h.t.WriteStatus(s, status.New(codes.OK, ""))
}
// start listens on localhost at the given port (port 0 lets the system pick
// one), reports the outcome of that set-up on s.startedErr, and then accepts
// connections until Accept fails or the server has been stopped. Every
// accepted connection gets its own server transport whose streams are
// dispatched to the handler selected by ht.
//
// start blocks in the accept loop, so it is meant to run on its own goroutine;
// other goroutines should call s.wait before using s.lis or s.port.
func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hType) {
	var err error
	// "localhost:0" is what strconv.Itoa yields for port 0, so no special case
	// is needed for the ephemeral-port request.
	s.lis, err = net.Listen("tcp", "localhost:"+strconv.Itoa(port))
	if err != nil {
		s.startedErr <- fmt.Errorf("failed to listen: %v", err)
		return
	}
	_, p, err := net.SplitHostPort(s.lis.Addr().String())
	if err != nil {
		// Nobody will call stop after a failed start, so release the listener here.
		s.lis.Close()
		s.startedErr <- fmt.Errorf("failed to parse listener address: %v", err)
		return
	}
	s.port = p
	s.conns = make(map[ServerTransport]bool)
	s.startedErr <- nil
	for {
		conn, err := s.lis.Accept()
		if err != nil {
			return
		}
		st, err := NewServerTransport("http2", conn, serverConfig)
		if err != nil {
			// No transport took ownership of the connection; closing it here
			// keeps it from leaking (a redundant close would be harmless).
			conn.Close()
			return
		}
		s.mu.Lock()
		if s.conns == nil {
			// stop has already run: discard the new transport and exit.
			s.mu.Unlock()
			st.Close()
			return
		}
		s.conns[st] = true
		h := &testStreamHandler{t: st.(*http2Server)}
		s.h = h
		s.mu.Unlock()
		// Pick the per-stream handler. Apart from notifyCall and suspended,
		// each handler runs on its own goroutine.
		var handle func(*Stream)
		switch ht {
		case notifyCall:
			handle = h.handleStreamAndNotify
		case suspended:
			handle = func(*Stream) {} // Do nothing to handle the stream.
		case misbehaved:
			handle = func(s *Stream) { go h.handleStreamMisbehave(t, s) }
		case encodingRequiredStatus:
			handle = func(s *Stream) { go h.handleStreamEncodingRequiredStatus(t, s) }
		case invalidHeaderField:
			handle = func(s *Stream) { go h.handleStreamInvalidHeaderField(t, s) }
		case delayRead:
			handle = func(s *Stream) { go h.handleStreamDelayRead(t, s) }
		case delayWrite:
			handle = func(s *Stream) { go h.handleStreamDelayWrite(t, s) }
		case pingpong:
			handle = func(s *Stream) { go h.handleStreamPingPong(t, s) }
		default:
			handle = func(s *Stream) { go h.handleStream(t, s) }
		}
		// The second callback returns the stream context unchanged.
		go st.HandleStreams(handle, func(ctx context.Context, _ string) context.Context {
			return ctx
		})
	}
}
// wait blocks until start has reported its outcome on s.startedErr or until
// timeout elapses. It fails the test when start reported an error or when the
// timeout expires first.
func (s *server) wait(t *testing.T, timeout time.Duration) {
	expired := time.After(timeout)
	select {
	case <-expired:
		t.Fatalf("Timed out after %v waiting for server to be ready", timeout)
	case startErr := <-s.startedErr:
		if startErr != nil {
			t.Fatal(startErr)
		}
	}
}
// stop closes the listener and every transport recorded in s.conns, then sets
// s.conns to nil so that start discards any transport it creates afterwards.
func (s *server) stop() {
	s.lis.Close()
	s.mu.Lock()
	defer s.mu.Unlock()
	for tr := range s.conns {
		tr.Close()
	}
	s.conns = nil
}
// setUp starts a test server limited to maxStreams streams and connects a
// client transport to it using default connect options.
func setUp(t *testing.T, port int, maxStreams uint32, ht hType) (*server, ClientTransport) {
	serverConfig := &ServerConfig{MaxStreams: maxStreams}
	return setUpWithOptions(t, port, serverConfig, ht, ConnectOptions{})
}
// setUpWithOptions starts a test server with the given configuration and
// handler type, waits up to two seconds for it to come up, and connects a
// client transport to it with copts. It fails the test when the server does
// not start or the client transport cannot be created; in the latter case the
// server is stopped first so its listener and accept loop are not left behind.
func setUpWithOptions(t *testing.T, port int, serverConfig *ServerConfig, ht hType, copts ConnectOptions) (*server, ClientTransport) {
	srv := &server{startedErr: make(chan error, 1)}
	go srv.start(t, port, serverConfig, ht)
	srv.wait(t, 2*time.Second)
	target := TargetInfo{
		Addr: "localhost:" + srv.port,
	}
	ct, connErr := NewClientTransport(context.Background(), target, copts)
	if connErr != nil {
		srv.stop()
		t.Fatalf("failed to create transport: %v", connErr)
	}
	return srv, ct
}
// setUpWithNoPingServer connects a client transport to a bare TCP listener
// that accepts a single connection and never reads from or writes to it. The
// accepted connection is delivered on done; done is closed instead when the
// accept fails. The test fails when listening or dialing fails.
func setUpWithNoPingServer(t *testing.T, copts ConnectOptions, done chan net.Conn) ClientTransport {
	lis, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatalf("Failed to listen: %v", err)
	}
	// Launch a non responsive server.
	go func() {
		defer lis.Close()
		accepted, acceptErr := lis.Accept()
		if acceptErr != nil {
			t.Errorf("Error at server-side while accepting: %v", acceptErr)
			close(done)
			return
		}
		done <- accepted
	}()
	target := TargetInfo{Addr: lis.Addr().String()}
	tr, err := NewClientTransport(context.Background(), target, copts)
	if err == nil {
		return tr
	}
	// Server clean-up.
	lis.Close()
	if conn, ok := <-done; ok {
		conn.Close()
	}
	t.Fatalf("Failed to dial: %v", err)
	return tr
}
// TestInflightStreamClosing ensures that closing in-flight stream
// sends StreamError to concurrent stream reader.
func TestInflightStreamClosing(t *testing.T) {
serverConfig := &ServerConfig{}
server, client := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{})
defer server.stop()
defer client.Close()
stream, err := client.NewStream(context.Background(), &CallHdr{})
if err != nil {
t.Fatalf("Client failed to create RPC request: %v", err)
}
donec := make(chan struct{})
serr := StreamError{Desc: "client connection is closing"}
go func() {
defer close(donec)
if _, err := stream.Read(make([]byte, defaultWindowSize)); err != serr {
t.Errorf("unexpected Stream error %v, expected %v", err, serr)
}
}()
// should unblock concurrent stream.Read
client.CloseStream(stream, serr)
// wait for stream.Read error
timeout := time.NewTimer(5 * time.Second)
select {
case <-donec:
if !timeout.Stop() {
<-timeout.C
}
case <-timeout.C:
t.Fatalf("Test timed out, expected a StreamError.")
}
}
// TestMaxConnectionIdle checks that the server sends a GoAway to an idle
// client, i.e. one that makes no RPC calls for MaxConnectionIdle.
func TestMaxConnectionIdle(t *testing.T) {
	idleConfig := &ServerConfig{
		KeepaliveParams: keepalive.ServerParameters{
			MaxConnectionIdle: 2 * time.Second,
		},
	}
	srv, client := setUpWithOptions(t, 0, idleConfig, suspended, ConnectOptions{})
	defer srv.stop()
	defer client.Close()
	stream, err := client.NewStream(context.Background(), &CallHdr{Flush: true})
	if err != nil {
		t.Fatalf("Client failed to create RPC request: %v", err)
	}
	stream.mu.Lock()
	stream.rstStream = true
	stream.mu.Unlock()
	client.CloseStream(stream, nil)
	// Wait for the server to see the closed stream and, with no new RPCs being
	// made, to send a GoAway.
	timeout := time.NewTimer(4 * time.Second)
	select {
	case <-timeout.C:
		t.Fatalf("Test timed out, expected a GoAway from the server.")
	case <-client.GoAway():
		if !timeout.Stop() {
			<-timeout.C
		}
	}
}
// TestMaxConnectionIdleNegative checks that the server does not send a GoAway
// to a client that keeps a stream open (a non-idle client) within four seconds,
// even though MaxConnectionIdle is two seconds.
func TestMaxConnectionIdleNegative(t *testing.T) {
	idleConfig := &ServerConfig{
		KeepaliveParams: keepalive.ServerParameters{
			MaxConnectionIdle: 2 * time.Second,
		},
	}
	srv, client := setUpWithOptions(t, 0, idleConfig, suspended, ConnectOptions{})
	defer srv.stop()
	defer client.Close()
	if _, err := client.NewStream(context.Background(), &CallHdr{Flush: true}); err != nil {
		t.Fatalf("Client failed to create RPC request: %v", err)
	}
	timeout := time.NewTimer(4 * time.Second)
	select {
	case <-timeout.C:
		// No GoAway arrived: this is the expected outcome.
	case <-client.GoAway():
		if !timeout.Stop() {
			<-timeout.C
		}
		t.Fatalf("A non-idle client received a GoAway.")
	}
}
// TestMaxConnectionAge tests that a server will send GoAway after a duration
// of MaxConnectionAge, even while the client has an open stream.
func TestMaxConnectionAge(t *testing.T) {
	serverConfig := &ServerConfig{
		KeepaliveParams: keepalive.ServerParameters{
			MaxConnectionAge: 2 * time.Second,
		},
	}
	server, client := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{})
	defer server.stop()
	defer client.Close()
	if _, err := client.NewStream(context.Background(), &CallHdr{}); err != nil {
		t.Fatalf("Client failed to create stream: %v", err)
	}
	// Wait for max-age logic to send GoAway.
	timeout := time.NewTimer(4 * time.Second)
	select {
	case <-client.GoAway():
		if !timeout.Stop() {
			<-timeout.C
		}
	case <-timeout.C:
		t.Fatalf("Test timed out, expected a GoAway from the server.")
	}
}
// TestKeepaliveServer checks that the server closes a connection whose client
// does not respond to keepalive pings. The unresponsive client is a raw TCP
// connection that never writes anything.
func TestKeepaliveServer(t *testing.T) {
	kaConfig := &ServerConfig{
		KeepaliveParams: keepalive.ServerParameters{
			Time:    2 * time.Second,
			Timeout: 1 * time.Second,
		},
	}
	srv, c := setUpWithOptions(t, 0, kaConfig, suspended, ConnectOptions{})
	defer srv.stop()
	defer c.Close()
	rawConn, err := net.Dial("tcp", srv.lis.Addr().String())
	if err != nil {
		t.Fatalf("Failed to dial: %v", err)
	}
	defer rawConn.Close()
	// Bound the reads below so that a failure cannot block the test forever.
	rawConn.SetReadDeadline(time.Now().Add(10 * time.Second))
	// Wait for keepalive logic to close the connection.
	time.Sleep(4 * time.Second)
	buf := make([]byte, 24)
	// Keep reading until a read fails; the failure must be io.EOF.
	for err == nil {
		_, err = rawConn.Read(buf)
	}
	if err != io.EOF {
		t.Fatalf("client.Read(_) = _,%v, want io.EOF", err)
	}
}
// TestKeepaliveServerNegative checks that the server keeps the connection open
// when the client is a real client transport, which answers keepalive pings.
func TestKeepaliveServerNegative(t *testing.T) {
	kaConfig := &ServerConfig{
		KeepaliveParams: keepalive.ServerParameters{
			Time:    2 * time.Second,
			Timeout: 1 * time.Second,
		},
	}
	srv, client := setUpWithOptions(t, 0, kaConfig, suspended, ConnectOptions{})
	defer srv.stop()
	defer client.Close()
	// Give keepalive logic some time by sleeping.
	time.Sleep(4 * time.Second)
	// The client transport must still be in the reachable state.
	clientTr := client.(*http2Client)
	clientTr.mu.Lock()
	state := clientTr.state
	clientTr.mu.Unlock()
	if state != reachable {
		t.Fatalf("Test failed: Expected server-client connection to be healthy.")
	}
}
func TestKeepaliveClientClosesIdleTransport(t *testing.T) {
done := make(chan net.Conn, 1)
tr := setUpWithNoPingServer(t, ConnectOptions{KeepaliveParams: keepalive.ClientParameters{
Time: 2 * time.Second, // Keepalive time = 2 sec.
Timeout: 1 * time.Second, // Keepalive timeout = 1 sec.
PermitWithoutStream: true, // Run keepalive even with no RPCs.
}}, done)
defer tr.Close()
conn, ok := <-done
if !ok {
t.Fatalf("Server didn't return connection object")
}
defer conn.Close()
// Sleep for keepalive to close the connection.
time.Sleep(4 * time.Second)
// Assert that the connection was closed.
ct := tr.(*http2Client)
ct.mu.Lock()
defer ct.mu.Unlock()
if ct.state == reachable {
t.Fatalf("Test Failed: Expected client transport to have closed.")
}
}
// TestKeepaliveClientStaysHealthyOnIdleTransport checks that a client whose
// keepalive parameters do not set PermitWithoutStream stays in the reachable
// state on a connection with no streams, even though the server never
// responds.
func TestKeepaliveClientStaysHealthyOnIdleTransport(t *testing.T) {
	done := make(chan net.Conn, 1)
	tr := setUpWithNoPingServer(t, ConnectOptions{KeepaliveParams: keepalive.ClientParameters{
		Time:    2 * time.Second, // Keepalive time = 2 sec.
		Timeout: 1 * time.Second, // Keepalive timeout = 1 sec.
	}}, done)
	defer tr.Close()
	conn, ok := <-done
	if !ok {
		t.Fatalf("Server didn't return connection object")
	}
	defer conn.Close()
	// Give keepalive some time.
	time.Sleep(4 * time.Second)
	// Assert that the connection is still healthy.
	ct := tr.(*http2Client)
	ct.mu.Lock()
	defer ct.mu.Unlock()
	if ct.state != reachable {
		t.Fatalf("Test failed: Expected client transport to be healthy.")
	}
}
func TestKeepaliveClientClosesWithActiveStreams(t *testing.T) {
done := make(chan net.Conn, 1)
tr := setUpWithNoPingServer(t, ConnectOptions{KeepaliveParams: keepalive.ClientParameters{
Time: 2 * time.Second, // Keepalive time = 2 sec.
Timeout: 1 * time.Second, // Keepalive timeout = 1 sec.
}}, done)
defer tr.Close()
conn, ok := <-done
if !ok {
t.Fatalf("Server didn't return connection object")
}
defer conn.Close()
// Create a stream.
_, err := tr.NewStream(context.Background(), &CallHdr{Flush: true})
if err != nil {
t.Fatalf("Failed to create a new stream: %v", err)
}
// Give keepalive some time.
time.Sleep(4 * time.Second)
// Assert that transport was closed.
ct := tr.(*http2Client)
ct.mu.Lock()
defer ct.mu.Unlock()
if ct.state == reachable {
t.Fatalf("Test failed: Expected client transport to have closed.")
}
}
func TestKeepaliveClientStaysHealthyWithResponsiveServer(t *testing.T) {
s, tr := setUpWithOptions(t, 0, &ServerConfig{MaxStreams: math.MaxUint32}, normal, ConnectOptions{KeepaliveParams: keepalive.ClientParameters{
Time: 2 * time.Second, // Keepalive time = 2 sec.
Timeout: 1 * time.Second, // Keepalive timeout = 1 sec.
PermitWithoutStream: true, // Run keepalive even with no RPCs.
}})
defer s.stop()
defer tr.Close()
// Give keep alive some time.
time.Sleep(4 * time.Second)
// Assert that transport is healthy.
ct := tr.(*http2Client)
ct.mu.Lock()
defer ct.mu.Unlock()
if ct.state != reachable {
t.Fatalf("Test failed: Expected client transport to be healthy.")
}
}
// TestKeepaliveServerEnforcementWithAbusiveClientNoRPC checks that a client
// with no streams whose keepalive interval (50ms) is far below the server's
// enforced MinTime (2s) receives a GoAway within two seconds and is no longer
// reachable shortly afterwards.
func TestKeepaliveServerEnforcementWithAbusiveClientNoRPC(t *testing.T) {
	policyConfig := &ServerConfig{
		KeepalivePolicy: keepalive.EnforcementPolicy{
			MinTime: 2 * time.Second,
		},
	}
	abusiveOpts := ConnectOptions{
		KeepaliveParams: keepalive.ClientParameters{
			Time:                50 * time.Millisecond,
			Timeout:             50 * time.Millisecond,
			PermitWithoutStream: true,
		},
	}
	srv, client := setUpWithOptions(t, 0, policyConfig, normal, abusiveOpts)
	defer srv.stop()
	defer client.Close()
	timeout := time.NewTimer(2 * time.Second)
	select {
	case <-timeout.C:
		t.Fatalf("Test failed: Expected a GoAway from server.")
	case <-client.GoAway():
		if !timeout.Stop() {
			<-timeout.C
		}
	}
	// Leave some time for the connection state to change after the GoAway.
	time.Sleep(500 * time.Millisecond)
	ct := client.(*http2Client)
	ct.mu.Lock()
	state := ct.state
	ct.mu.Unlock()
	if state == reachable {
		t.Fatalf("Test failed: Expected the connection to be closed.")
	}
}
// TestKeepaliveServerEnforcementWithAbusiveClientWithRPC checks that a client
// with an open stream whose keepalive interval (50ms) is far below the
// server's enforced MinTime (2s) receives a GoAway within two seconds and is
// no longer reachable shortly afterwards.
func TestKeepaliveServerEnforcementWithAbusiveClientWithRPC(t *testing.T) {
	serverConfig := &ServerConfig{
		KeepalivePolicy: keepalive.EnforcementPolicy{
			MinTime: 2 * time.Second,
		},
	}
	clientOptions := ConnectOptions{
		KeepaliveParams: keepalive.ClientParameters{
			Time:    50 * time.Millisecond,
			Timeout: 50 * time.Millisecond,
		},
	}
	server, client := setUpWithOptions(t, 0, serverConfig, suspended, clientOptions)
	defer server.stop()
	defer client.Close()
	if _, err := client.NewStream(context.Background(), &CallHdr{Flush: true}); err != nil {
		// Include the cause so a failure here can be diagnosed.
		t.Fatalf("Client failed to create stream: %v", err)
	}
	timeout := time.NewTimer(2 * time.Second)
	select {
	case <-client.GoAway():
		if !timeout.Stop() {
			<-timeout.C
		}
	case <-timeout.C:
		t.Fatalf("Test failed: Expected a GoAway from server.")
	}
	// Leave some time for the connection state to change after the GoAway.
	time.Sleep(500 * time.Millisecond)
	ct := client.(*http2Client)
	ct.mu.Lock()
	defer ct.mu.Unlock()
	if ct.state == reachable {
		t.Fatalf("Test failed: Expected the connection to be closed.")
	}
}
// TestKeepaliveServerEnforcementWithObeyingClientNoRPC checks that a client
// with no streams whose keepalive interval (101ms) respects the server's
// MinTime (100ms, with PermitWithoutStream) is still reachable after two
// seconds.
func TestKeepaliveServerEnforcementWithObeyingClientNoRPC(t *testing.T) {
	policyConfig := &ServerConfig{
		KeepalivePolicy: keepalive.EnforcementPolicy{
			MinTime:             100 * time.Millisecond,
			PermitWithoutStream: true,
		},
	}
	obeyingOpts := ConnectOptions{
		KeepaliveParams: keepalive.ClientParameters{
			Time:                101 * time.Millisecond,
			Timeout:             50 * time.Millisecond,
			PermitWithoutStream: true,
		},
	}
	srv, client := setUpWithOptions(t, 0, policyConfig, normal, obeyingOpts)
	defer srv.stop()
	defer client.Close()
	// Give keepalive enough time.
	time.Sleep(2 * time.Second)
	// Assert that connection is healthy.
	ct := client.(*http2Client)
	ct.mu.Lock()
	state := ct.state
	ct.mu.Unlock()
	if state != reachable {
		t.Fatalf("Test failed: Expected connection to be healthy.")
	}
}
// TestKeepaliveServerEnforcementWithObeyingClientWithRPC checks that a client
// with an open stream whose keepalive interval (101ms) respects the server's
// MinTime (100ms) is still reachable after two seconds.
func TestKeepaliveServerEnforcementWithObeyingClientWithRPC(t *testing.T) {
	serverConfig := &ServerConfig{
		KeepalivePolicy: keepalive.EnforcementPolicy{
			MinTime: 100 * time.Millisecond,
		},
	}
	clientOptions := ConnectOptions{
		KeepaliveParams: keepalive.ClientParameters{
			Time:    101 * time.Millisecond,
			Timeout: 50 * time.Millisecond,
		},
	}
	server, client := setUpWithOptions(t, 0, serverConfig, suspended, clientOptions)
	defer server.stop()
	defer client.Close()
	if _, err := client.NewStream(context.Background(), &CallHdr{Flush: true}); err != nil {
		// Include the cause so a failure here can be diagnosed.
		t.Fatalf("Client failed to create stream: %v", err)
	}
	// Give keepalive enough time.
	time.Sleep(2 * time.Second)
	// Assert that connection is healthy.
	ct := client.(*http2Client)
	ct.mu.Lock()
	defer ct.mu.Unlock()
	if ct.state != reachable {
		t.Fatalf("Test failed: Expected connection to be healthy.")
	}
}
// TestClientSendAndReceive opens two streams (expecting ids 1 and 3), sends
// the small request on the first one and checks that the response is the
// expected payload followed by io.EOF.
//
// Cleanup is deferred so that the transport and the server are also released
// when the test bails out early through t.Fatalf.
func TestClientSendAndReceive(t *testing.T) {
	server, ct := setUp(t, 0, math.MaxUint32, normal)
	defer server.stop()
	defer ct.Close()
	callHdr := &CallHdr{
		Host:   "localhost",
		Method: "foo.Small",
	}
	s1, err1 := ct.NewStream(context.Background(), callHdr)
	if err1 != nil {
		t.Fatalf("failed to open stream: %v", err1)
	}
	if s1.id != 1 {
		t.Fatalf("wrong stream id: %d", s1.id)
	}
	s2, err2 := ct.NewStream(context.Background(), callHdr)
	if err2 != nil {
		t.Fatalf("failed to open stream: %v", err2)
	}
	if s2.id != 3 {
		t.Fatalf("wrong stream id: %d", s2.id)
	}
	opts := Options{
		Last:  true,
		Delay: false,
	}
	// io.EOF from Write is tolerated here, as in the other tests of this file.
	if err := ct.Write(s1, nil, expectedRequest, &opts); err != nil && err != io.EOF {
		t.Fatalf("failed to send data: %v", err)
	}
	p := make([]byte, len(expectedResponse))
	_, recvErr := s1.Read(p)
	if recvErr != nil || !bytes.Equal(p, expectedResponse) {
		t.Fatalf("Error: %v, want <nil>; Result: %v, want %v", recvErr, p, expectedResponse)
	}
	// The stream must be finished after the single response.
	_, recvErr = s1.Read(p)
	if recvErr != io.EOF {
		t.Fatalf("Error: %v; want <EOF>", recvErr)
	}
}
// TestClientErrorNotify stops the server concurrently and waits for the client
// transport to signal the failure on its Error channel before closing it.
func TestClientErrorNotify(t *testing.T) {
	srv, ct := setUp(t, 0, math.MaxUint32, normal)
	defer ct.Close()
	go srv.stop()
	// ct.reader should detect the error and activate ct.Error().
	<-ct.Error()
}
// performOneRPC runs a single best-effort "foo.Small" RPC on ct: it opens a
// stream, writes the small request and, when the write succeeded or returned
// io.EOF, reads twice from the stream. All errors are ignored because the
// transport may go away at any point.
func performOneRPC(ct ClientTransport) {
	s, err := ct.NewStream(context.Background(), &CallHdr{
		Host:   "localhost",
		Method: "foo.Small",
	})
	if err != nil {
		return
	}
	opts := Options{Last: true, Delay: false}
	if err := ct.Write(s, []byte{}, expectedRequest, &opts); err != nil && err != io.EOF {
		return
	}
	time.Sleep(5 * time.Millisecond)
	// The reads below may fail because the underlying transport is gone;
	// their results are deliberately discarded.
	p := make([]byte, len(expectedResponse))
	// Read response
	s.Read(p)
	// Read io.EOF
	s.Read(p)
}
// TestClientMix fires 1000 RPCs, one every 10ms, each on its own goroutine,
// while the server is stopped after five seconds and the client transport is
// closed as soon as it reports an error.
func TestClientMix(t *testing.T) {
	s, ct := setUp(t, 0, math.MaxUint32, normal)
	// Stop the server in the middle of the run.
	go func() {
		time.Sleep(5 * time.Second)
		s.stop()
	}()
	// Close the client transport once it has noticed a failure.
	go func() {
		<-ct.Error()
		ct.Close()
	}()
	for launched := 0; launched < 1000; launched++ {
		time.Sleep(10 * time.Millisecond)
		go performOneRPC(ct)
	}
}
// TestLargeMessage exchanges the large request/response pair on two concurrent
// streams and expects each stream to finish with io.EOF.
//
// A goroutine whose NewStream call fails returns right after reporting the
// error instead of going on with a nil stream. Cleanup is deferred so it also
// runs if the test exits early.
func TestLargeMessage(t *testing.T) {
	server, ct := setUp(t, 0, math.MaxUint32, normal)
	defer server.stop()
	defer ct.Close()
	callHdr := &CallHdr{
		Host:   "localhost",
		Method: "foo.Large",
	}
	var wg sync.WaitGroup
	for i := 0; i < 2; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			s, err := ct.NewStream(context.Background(), callHdr)
			if err != nil {
				t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
				return
			}
			if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF {
				t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
			}
			p := make([]byte, len(expectedResponseLarge))
			if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) {
				// The payload is deliberately not printed: it is twice
				// initialWindowSize bytes long.
				t.Errorf("s.Read(_) = _, %v, want _, <nil>", err)
			}
			if _, err = s.Read(p); err != io.EOF {
				t.Errorf("Failed to complete the stream %v; want <EOF>", err)
			}
		}()
	}
	wg.Wait()
}
// TestLargeMessageWithDelayRead exchanges the large request/response pair on
// two concurrent streams against the delayRead handler, with the client also
// waiting two seconds before it reads the response.
//
// A goroutine whose NewStream call fails returns right after reporting the
// error instead of going on with a nil stream. Cleanup is deferred so it also
// runs if the test exits early.
func TestLargeMessageWithDelayRead(t *testing.T) {
	server, ct := setUp(t, 0, math.MaxUint32, delayRead)
	defer server.stop()
	defer ct.Close()
	callHdr := &CallHdr{
		Host:   "localhost",
		Method: "foo.Large",
	}
	var wg sync.WaitGroup
	for i := 0; i < 2; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			s, err := ct.NewStream(context.Background(), callHdr)
			if err != nil {
				t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
				return
			}
			if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF {
				t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
			}
			p := make([]byte, len(expectedResponseLarge))
			// Give time to server to begin sending before client starts reading.
			time.Sleep(2 * time.Second)
			if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) {
				t.Errorf("s.Read(_) = _, %v, want _, <nil>", err)
			}
			if _, err = s.Read(p); err != io.EOF {
				t.Errorf("Failed to complete the stream %v; want <EOF>", err)
			}
		}()
	}
	wg.Wait()
}
// TestLargeMessageDelayWrite exchanges the large request/response pair on two
// concurrent streams against the delayWrite handler, with the client also
// waiting two seconds before it sends the request.
//
// A goroutine whose NewStream call fails returns right after reporting the
// error instead of going on with a nil stream. Cleanup is deferred so it also
// runs if the test exits early.
func TestLargeMessageDelayWrite(t *testing.T) {
	server, ct := setUp(t, 0, math.MaxUint32, delayWrite)
	defer server.stop()
	defer ct.Close()
	callHdr := &CallHdr{
		Host:   "localhost",
		Method: "foo.Large",
	}
	var wg sync.WaitGroup
	for i := 0; i < 2; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			s, err := ct.NewStream(context.Background(), callHdr)
			if err != nil {
				t.Errorf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
				return
			}
			// Give time to server to start reading before client starts sending.
			time.Sleep(2 * time.Second)
			if err := ct.Write(s, []byte{}, expectedRequestLarge, &Options{Last: true, Delay: false}); err != nil && err != io.EOF {
				t.Errorf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
			}
			p := make([]byte, len(expectedResponseLarge))
			if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponseLarge) {
				// The payload is deliberately not printed: it is twice
				// initialWindowSize bytes long.
				t.Errorf("s.Read(_) = _, %v, want _, <nil>", err)
			}
			if _, err = s.Read(p); err != io.EOF {
				t.Errorf("Failed to complete the stream %v; want <EOF>", err)
			}
		}()
	}
	wg.Wait()
}
// TestGracefulClose checks that after GracefulClose every new stream is
// rejected with ErrStreamDrain while a stream opened beforehand can still
// complete its request/response exchange.
//
// Cleanup is deferred so that the transport and the server are also released
// when the test bails out early through t.Fatalf.
func TestGracefulClose(t *testing.T) {
	server, ct := setUp(t, 0, math.MaxUint32, normal)
	defer server.stop()
	defer ct.Close()
	callHdr := &CallHdr{
		Host:   "localhost",
		Method: "foo.Small",
	}
	s, err := ct.NewStream(context.Background(), callHdr)
	if err != nil {
		t.Fatalf("%v.NewStream(_, _) = _, %v, want _, <nil>", ct, err)
	}
	if err = ct.GracefulClose(); err != nil {
		t.Fatalf("%v.GracefulClose() = %v, want <nil>", ct, err)
	}
	var wg sync.WaitGroup
	// Expect the failure for all the follow-up streams because ct has been closed gracefully.
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			if _, err := ct.NewStream(context.Background(), callHdr); err != ErrStreamDrain {
				t.Errorf("%v.NewStream(_, _) = _, %v, want _, %v", ct, err, ErrStreamDrain)
			}
		}()
	}
	opts := Options{
		Last:  true,
		Delay: false,
	}
	// The stream which was created before graceful close can still proceed.
	if err := ct.Write(s, nil, expectedRequest, &opts); err != nil && err != io.EOF {
		t.Fatalf("%v.Write(_, _, _) = %v, want <nil>", ct, err)
	}
	p := make([]byte, len(expectedResponse))
	if _, err := s.Read(p); err != nil || !bytes.Equal(p, expectedResponse) {
		t.Fatalf("s.Read(_) = %v, %v, want %v, <nil>", p, err, expectedResponse)
	}
	if _, err = s.Read(p); err != io.EOF {
		t.Fatalf("Failed to complete the stream %v; want <EOF>", err)
	}
	wg.Wait()
}
// TestLargeMessageSuspension writes a message eight times initialWindowSize
// long to a server whose handler never reads, on a stream with a one-second
// deadline, and expects the write to fail with a DeadlineExceeded stream
// error.
//
// Cleanup is deferred so that the transport and the server are also released
// when the test bails out early through t.Fatalf.
func TestLargeMessageSuspension(t *testing.T) {
	server, ct := setUp(t, 0, math.MaxUint32, suspended)
	defer server.stop()
	defer ct.Close()
	callHdr := &CallHdr{
		Host:   "localhost",
		Method: "foo.Large",
	}
	// The stream context expires after one second, which bounds the write below.
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	s, err := ct.NewStream(ctx, callHdr)
	if err != nil {
		t.Fatalf("failed to open stream: %v", err)
	}
	// Write should not be done successfully due to flow control.
	msg := make([]byte, initialWindowSize*8)
	err = ct.Write(s, nil, msg, &Options{Last: true, Delay: false})
	expectedErr := streamErrorf(codes.DeadlineExceeded, "%v", context.DeadlineExceeded)
	if err != expectedErr {
		t.Fatalf("Write got %v, want %v", err, expectedErr)
	}
}
// TestMaxStreams starts a server limited to one stream and checks that, once
// the client has learned that limit, the single open stream uses up the whole
// streams quota, and that closing it makes exactly one unit of quota available
// again so that a new stream can be opened.
//
// The five-second limit for learning the server's setting is armed once,
// outside the polling loop; re-creating it on every iteration would let the
// 5ms tick always win, so the limit could never fire. Cleanup is deferred so
// it also runs when the test bails out early through t.Fatalf.
func TestMaxStreams(t *testing.T) {
	server, ct := setUp(t, 0, 1, suspended)
	defer server.stop()
	defer ct.Close()
	callHdr := &CallHdr{
		Host:   "localhost",
		Method: "foo.Large",
	}
	// Have a pending stream which takes all streams quota.
	s, err := ct.NewStream(context.Background(), callHdr)
	if err != nil {
		t.Fatalf("Failed to open stream: %v", err)
	}
	cc, ok := ct.(*http2Client)
	if !ok {
		t.Fatalf("Failed to convert %v to *http2Client", ct)
	}
	done := make(chan struct{})  // closed by the ticker goroutine on timeout
	ch := make(chan int)         // one value per polling tick
	ready := make(chan struct{}) // closed by the test to stop the ticker goroutine
	go func() {
		deadline := time.After(5 * time.Second)
		for {
			select {
			case <-time.After(5 * time.Millisecond):
				select {
				case ch <- 0:
				case <-deadline:
					close(done)
					return
				case <-ready:
					return
				}
			case <-deadline:
				close(done)
				return
			case <-ready:
				return
			}
		}
	}()
	for {
		select {
		case <-ch:
		case <-done:
			t.Fatalf("Client has not received the max stream setting in 5 seconds.")
		}
		cc.mu.Lock()
		// cc.maxStreams should be equal to 1 after having received settings frame from
		// server.
		if cc.maxStreams == 1 {
			cc.mu.Unlock()
			select {
			case <-cc.streamsQuota.acquire():
				t.Fatalf("streamsQuota.acquire() becomes readable mistakenly.")
			default:
				cc.streamsQuota.mu.Lock()
				quota := cc.streamsQuota.quota
				cc.streamsQuota.mu.Unlock()
				if quota != 0 {
					t.Fatalf("streamsQuota.quota got non-zero quota mistakenly.")
				}
			}
			break
		}
		cc.mu.Unlock()
	}
	close(ready)
	// Close the pending stream so that the streams quota becomes available for the next new stream.
	ct.CloseStream(s, nil)
	select {
	case i := <-cc.streamsQuota.acquire():
		if i != 1 {
			t.Fatalf("streamsQuota.acquire() got %d quota, want 1.", i)
		}
		// Hand the quota back so that the NewStream call below can take it.
		cc.streamsQuota.add(i)
	default:
		t.Fatalf("streamsQuota.acquire() is not readable.")
	}
	if _, err := ct.NewStream(context.Background(), callHdr); err != nil {
		t.Fatalf("Failed to open stream: %v", err)
	}
}
// TestServerContextCanceledOnClosedConnection checks that closing the client
// transport cancels the context of the corresponding server-side stream: the
// context must be done within five seconds with context.Canceled as its error.
func TestServerContextCanceledOnClosedConnection(t *testing.T) {
server, ct := setUp(t, 0, math.MaxUint32, suspended)
callHdr := &CallHdr{
Host: "localhost",
Method: "foo",
}
var sc *http2Server
// Wait until the server transport is setup.
for {
server.mu.Lock()
if len(server.conns) == 0 {
server.mu.Unlock()
time.Sleep(time.Millisecond)
continue
}
// Keep a reference to the server transport recorded in server.conns.
for k := range server.conns {
var ok bool
sc, ok = k.(*http2Server)
if !ok {
t.Fatalf("Failed to convert %v to *http2Server", k)
}
}
server.mu.Unlock()
break
}
cc, ok := ct.(*http2Client)
if !ok {
t.Fatalf("Failed to convert %v to *http2Client", ct)
}
s, err := ct.NewStream(context.Background(), callHdr)
if err != nil {
t.Fatalf("Failed to open stream: %v", err)
}
// Queue one data frame of http2MaxFrameLen bytes for the new stream directly
// on the client's control buffer.
cc.controlBuf.put(&dataFrame{s.id, false, make([]byte, http2MaxFrameLen), func() {}})
// Loop until the server side stream is created.
var ss *Stream
for {
time.Sleep(time.Second)
sc.mu.Lock()
if len(sc.activeStreams) == 0 {
sc.mu.Unlock()
continue
}
ss = sc.activeStreams[s.id]
sc.mu.Unlock()
break
}
// Closing the client transport is what should cancel the server-side context.
cc.Close()
select {
case <-ss.Context().Done():
if ss.Context().Err() != context.Canceled {
t.Fatalf("ss.Context().Err() got %v, want %v", ss.Context().Err(), context.Canceled)
}
case <-time.After(5 * time.Second):
t.Fatalf("Failed to cancel the context of the sever side stream.")
}
server.stop()
}
// TestClientConnDecoupledFromApplicationRead checks that the server can still
// send data on a second stream after it has written defaultWindowSize bytes on
// a first stream that the client application has not read yet. The client's
// stream and connection windows are both set to defaultWindowSize.
func TestClientConnDecoupledFromApplicationRead(t *testing.T) {
connectOptions := ConnectOptions{
InitialWindowSize: defaultWindowSize,
InitialConnWindowSize: defaultWindowSize,
}
server, client := setUpWithOptions(t, 0, &ServerConfig{}, notifyCall, connectOptions)
defer server.stop()
defer client.Close()
// Wait for the server to record the transport created for the client.
waitWhileTrue(t, func() (bool, error) {
server.mu.Lock()
defer server.mu.Unlock()
if len(server.conns) == 0 {
return true, fmt.Errorf("timed-out while waiting for connection to be created on the server")
}
return false, nil
})
var st *http2Server
server.mu.Lock()
for k := range server.conns {
st = k.(*http2Server)
}
// Install the channel that handleStreamAndNotify closes when a stream
// reaches the server.
notifyChan := make(chan struct{})
server.h.notify = notifyChan
server.mu.Unlock()
cstream1, err := client.NewStream(context.Background(), &CallHdr{Flush: true})
if err != nil {
t.Fatalf("Client failed to create first stream. Err: %v", err)
}
// Block until the server has been handed the first stream.
<-notifyChan
var sstream1 *Stream
// Access stream on the server.
st.mu.Lock()
for _, v := range st.activeStreams {
if v.id == cstream1.id {
sstream1 = v
}
}
st.mu.Unlock()
if sstream1 == nil {
t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream1.id)
}
// Exhaust client's connection window.
if err := st.Write(sstream1, []byte{}, make([]byte, defaultWindowSize), &Options{}); err != nil {
t.Fatalf("Server failed to write data. Err: %v", err)
}
// Install a fresh notification channel for the second stream.
notifyChan = make(chan struct{})
server.mu.Lock()
server.h.notify = notifyChan
server.mu.Unlock()
// Create another stream on client.
cstream2, err := client.NewStream(context.Background(), &CallHdr{Flush: true})
if err != nil {
t.Fatalf("Client failed to create second stream. Err: %v", err)
}
// Block until the server has been handed the second stream.
<-notifyChan
var sstream2 *Stream
st.mu.Lock()
for _, v := range st.activeStreams {
if v.id == cstream2.id {
sstream2 = v
}
}
st.mu.Unlock()
if sstream2 == nil {
t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream2.id)
}
// Server should be able to send data on the new stream, even though the client hasn't read anything on the first stream.
if err := st.Write(sstream2, []byte{}, make([]byte, defaultWindowSize), &Options{}); err != nil {
t.Fatalf("Server failed to write data. Err: %v", err)
}
// Client should be able to read data on second stream.
if _, err := cstream2.Read(make([]byte, defaultWindowSize)); err != nil {
t.Fatalf("_.Read(_) = _, %v, want _, <nil>", err)
}
// Client should be able to read data on first stream.
if _, err := cstream1.Read(make([]byte, defaultWindowSize)); err != nil {
t.Fatalf("_.Read(_) = _, %v, want _, <nil>", err)
}
}
// TestServerConnDecoupledFromApplicationRead checks that, after the client
// has sent a full window of data on one stream that the server application
// has not read, the client can still open a second stream and send a window
// of data on it. It then verifies that writing past the second stream's
// window results in a flow-control error status on that stream, and that the
// data buffered on the first stream is still readable on the server.
func TestServerConnDecoupledFromApplicationRead(t *testing.T) {
	serverConfig := &ServerConfig{
		InitialWindowSize:     defaultWindowSize,
		InitialConnWindowSize: defaultWindowSize,
	}
	server, client := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{})
	defer server.stop()
	defer client.Close()
	waitWhileTrue(t, func() (bool, error) {
		server.mu.Lock()
		defer server.mu.Unlock()
		if len(server.conns) == 0 {
			return true, fmt.Errorf("timed-out while waiting for connection to be created on the server")
		}
		return false, nil
	})
	var st *http2Server
	server.mu.Lock()
	for k := range server.conns {
		st = k.(*http2Server)
	}
	server.mu.Unlock()
	cstream1, err := client.NewStream(context.Background(), &CallHdr{Flush: true})
	if err != nil {
		t.Fatalf("Failed to create 1st stream. Err: %v", err)
	}
	// Exhaust server's connection window.
	if err := client.Write(cstream1, nil, make([]byte, defaultWindowSize), &Options{Last: true}); err != nil {
		t.Fatalf("Client failed to write data. Err: %v", err)
	}
	// Client should be able to create another stream and send data on it.
	cstream2, err := client.NewStream(context.Background(), &CallHdr{Flush: true})
	if err != nil {
		t.Fatalf("Failed to create 2nd stream. Err: %v", err)
	}
	if err := client.Write(cstream2, nil, make([]byte, defaultWindowSize), &Options{}); err != nil {
		t.Fatalf("Client failed to write data. Err: %v", err)
	}
	// Get the streams on server.
	waitWhileTrue(t, func() (bool, error) {
		st.mu.Lock()
		defer st.mu.Unlock()
		if len(st.activeStreams) != 2 {
			return true, fmt.Errorf("timed-out while waiting for server to have created the streams")
		}
		return false, nil
	})
	// Look the first stream up by the id the client actually assigned instead
	// of assuming it is 1, and fail cleanly if the server does not have it.
	st.mu.Lock()
	sstream1 := st.activeStreams[cstream1.id]
	st.mu.Unlock()
	if sstream1 == nil {
		t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream1.id)
	}
	// Trying to write more on a max-ed out stream should result in a RST_STREAM from the server.
	ct := client.(*http2Client)
	ct.controlBuf.put(&dataFrame{cstream2.id, true, make([]byte, 1), func() {}})
	code := http2ErrConvTab[http2.ErrCodeFlowControl]
	waitWhileTrue(t, func() (bool, error) {
		cstream2.mu.Lock()
		defer cstream2.mu.Unlock()
		if cstream2.status.Code() != code {
			return true, fmt.Errorf("want code = %v, got %v", code, cstream2.status.Code())
		}
		return false, nil
	})
	// Reading from the stream on server should succeed.
	if _, err := sstream1.Read(make([]byte, defaultWindowSize)); err != nil {
		t.Fatalf("_.Read(_) = %v, want <nil>", err)
	}
	// The first stream was written with Last: true, so the next read must hit EOF.
	if _, err := sstream1.Read(make([]byte, 1)); err != io.EOF {
		t.Fatalf("_.Read(_) = %v, want io.EOF", err)
	}
}
// TestServerWithMisbehavedClient drives a client that bypasses its own
// flow-control accounting by pushing data frames straight into the control
// buffer, and checks that the stream ends with io.EOF and the status code
// mapped from http2.ErrCodeFlowControl once the stream window is exceeded.
func TestServerWithMisbehavedClient(t *testing.T) {
	server, ct := setUp(t, 0, math.MaxUint32, suspended)
	callHdr := &CallHdr{
		Host:   "localhost",
		Method: "foo",
	}
	var sc *http2Server
	// Wait until the server transport is setup.
	for {
		server.mu.Lock()
		if len(server.conns) == 0 {
			server.mu.Unlock()
			time.Sleep(time.Millisecond)
			continue
		}
		for k := range server.conns {
			var ok bool
			sc, ok = k.(*http2Server)
			if !ok {
				t.Fatalf("Failed to convert %v to *http2Server", k)
			}
		}
		server.mu.Unlock()
		break
	}
	cc, ok := ct.(*http2Client)
	if !ok {
		t.Fatalf("Failed to convert %v to *http2Client", ct)
	}
	// Test server behavior for violation of stream flow control window size restriction.
	s, err := ct.NewStream(context.Background(), callHdr)
	if err != nil {
		t.Fatalf("Failed to open stream: %v", err)
	}
	var sent int
	// Drain the stream flow control window
	cc.controlBuf.put(&dataFrame{s.id, false, make([]byte, http2MaxFrameLen), func() {}})
	sent += http2MaxFrameLen
	// Wait until the server creates the corresponding stream and receive some data.
	var ss *Stream
	for {
		time.Sleep(time.Millisecond)
		sc.mu.Lock()
		if len(sc.activeStreams) == 0 {
			sc.mu.Unlock()
			continue
		}
		ss = sc.activeStreams[s.id]
		sc.mu.Unlock()
		ss.fc.mu.Lock()
		if ss.fc.pendingData > 0 {
			ss.fc.mu.Unlock()
			break
		}
		ss.fc.mu.Unlock()
	}
	// Snapshot the inbound flow-control counters under their own mutexes, the
	// same way the rest of this file accesses them.
	ss.fc.mu.Lock()
	streamPendingData, streamPendingUpdate := ss.fc.pendingData, ss.fc.pendingUpdate
	ss.fc.mu.Unlock()
	sc.fc.mu.Lock()
	connPendingData, connPendingUpdate := sc.fc.pendingData, sc.fc.pendingUpdate
	sc.fc.mu.Unlock()
	if streamPendingData != http2MaxFrameLen || streamPendingUpdate != 0 || connPendingData != 0 || connPendingUpdate != 0 {
		t.Fatalf("Server mistakenly updates inbound flow control params: got %d, %d, %d, %d; want %d, %d, %d, %d", streamPendingData, streamPendingUpdate, connPendingData, connPendingUpdate, http2MaxFrameLen, 0, 0, 0)
	}
	// Keep sending until the server inbound window is drained for that stream.
	for sent <= initialWindowSize {
		cc.controlBuf.put(&dataFrame{s.id, false, make([]byte, 1), func() {}})
		sent++
	}
	// Server sent a resetStream for s already.
	code := http2ErrConvTab[http2.ErrCodeFlowControl]
	if _, err := s.Read(make([]byte, 1)); err != io.EOF {
		t.Fatalf("%v got err %v want <EOF>", s, err)
	}
	// Read the final status under the stream mutex.
	s.mu.Lock()
	gotStatus := s.status
	s.mu.Unlock()
	if gotStatus.Code() != code {
		t.Fatalf("%v got status %v; want Code=%v", s, gotStatus, code)
	}
	ct.CloseStream(s, nil)
	ct.Close()
	server.stop()
}
// TestClientWithMisbehavedServer reads from a stream without sending window
// updates until the read fails, then checks the client's inbound flow-control
// counters, that the read ended with io.EOF, and that the stream status is
// codes.Internal.
func TestClientWithMisbehavedServer(t *testing.T) {
	// Turn off BDP estimation so that the server can
	// violate stream window.
	connectOptions := ConnectOptions{
		InitialWindowSize: initialWindowSize,
	}
	server, ct := setUpWithOptions(t, 0, &ServerConfig{}, misbehaved, connectOptions)
	callHdr := &CallHdr{
		Host:   "localhost",
		Method: "foo.Stream",
	}
	conn, ok := ct.(*http2Client)
	if !ok {
		t.Fatalf("Failed to convert %v to *http2Client", ct)
	}
	// Test the logic for the violation of stream flow control window size restriction.
	s, err := ct.NewStream(context.Background(), callHdr)
	if err != nil {
		t.Fatalf("Failed to open stream: %v", err)
	}
	d := make([]byte, 1)
	if err := ct.Write(s, nil, d, &Options{Last: true, Delay: false}); err != nil && err != io.EOF {
		t.Fatalf("Failed to write: %v", err)
	}
	// Read without window update: go through the underlying reader directly so
	// that the transportReader's window handler is not invoked.
	for {
		p := make([]byte, http2MaxFrameLen)
		if _, err = s.trReader.(*transportReader).reader.Read(p); err != nil {
			break
		}
	}
	// Snapshot the inbound flow-control counters under their own mutexes, the
	// same way the rest of this file accesses them.
	s.fc.mu.Lock()
	streamPendingData, streamPendingUpdate := s.fc.pendingData, s.fc.pendingUpdate
	s.fc.mu.Unlock()
	conn.fc.mu.Lock()
	connPendingData, connPendingUpdate := conn.fc.pendingData, conn.fc.pendingUpdate
	conn.fc.mu.Unlock()
	if streamPendingData <= initialWindowSize || streamPendingUpdate != 0 || connPendingData != 0 || connPendingUpdate != 0 {
		// The last expected value is exactly 0, matching the condition above.
		t.Fatalf("Client mistakenly updates inbound flow control params: got %d, %d, %d, %d; want >%d, %d, %d, %d", streamPendingData, streamPendingUpdate, connPendingData, connPendingUpdate, initialWindowSize, 0, 0, 0)
	}
	if err != io.EOF {
		t.Fatalf("Got err %v, want <EOF>", err)
	}
	s.mu.Lock()
	gotStatus := s.status
	s.mu.Unlock()
	if gotStatus.Code() != codes.Internal {
		t.Fatalf("Got s.status %v, want s.status.Code()=Internal", gotStatus)
	}
	conn.CloseStream(s, err)
	ct.Close()
	server.stop()
}
// encodingTestStatus is the status TestEncodingRequiredStatus expects to find
// on the client stream. Its message is a single newline.
// NOTE(review): presumably the encodingRequiredStatus handler sends this
// status from the server side; the handler is not visible here - confirm.
var encodingTestStatus = status.New(codes.Internal, "\n")

// TestEncodingRequiredStatus sends one request and checks that the client
// stream ends with io.EOF and carries a status deeply equal to
// encodingTestStatus.
func TestEncodingRequiredStatus(t *testing.T) {
	server, ct := setUp(t, 0, math.MaxUint32, encodingRequiredStatus)
	// Deferred in this order so the client transport is closed before the
	// server is stopped, and so both are released on every exit path,
	// including t.Fatalf.
	defer server.stop()
	defer ct.Close()
	callHdr := &CallHdr{
		Host:   "localhost",
		Method: "foo",
	}
	s, err := ct.NewStream(context.Background(), callHdr)
	if err != nil {
		// A stream that cannot be opened means nothing was verified; report it
		// instead of letting the test pass silently.
		t.Fatalf("Failed to open stream: %v", err)
	}
	opts := Options{
		Last:  true,
		Delay: false,
	}
	if err := ct.Write(s, nil, expectedRequest, &opts); err != nil && err != io.EOF {
		t.Fatalf("Failed to write the request: %v", err)
	}
	p := make([]byte, http2MaxFrameLen)
	if _, err := s.trReader.(*transportReader).Read(p); err != io.EOF {
		t.Fatalf("Read got error %v, want %v", err, io.EOF)
	}
	if !reflect.DeepEqual(s.Status(), encodingTestStatus) {
		t.Fatalf("stream with status %v, want %v", s.Status(), encodingTestStatus)
	}
}
// TestInvalidHeaderField sends one request and checks that reading the
// response fails with a StreamError whose code is codes.FailedPrecondition
// and whose text contains expectedInvalidHeaderField.
func TestInvalidHeaderField(t *testing.T) {
	server, ct := setUp(t, 0, math.MaxUint32, invalidHeaderField)
	// Deferred in this order so the client transport is closed before the
	// server is stopped, and so both are released on every exit path,
	// including t.Fatalf.
	defer server.stop()
	defer ct.Close()
	callHdr := &CallHdr{
		Host:   "localhost",
		Method: "foo",
	}
	s, err := ct.NewStream(context.Background(), callHdr)
	if err != nil {
		// A stream that cannot be opened means nothing was verified; report it
		// instead of letting the test pass silently.
		t.Fatalf("Failed to open stream: %v", err)
	}
	opts := Options{
		Last:  true,
		Delay: false,
	}
	if err := ct.Write(s, nil, expectedRequest, &opts); err != nil && err != io.EOF {
		t.Fatalf("Failed to write the request: %v", err)
	}
	p := make([]byte, http2MaxFrameLen)
	_, err = s.trReader.(*transportReader).Read(p)
	if se, ok := err.(StreamError); !ok || se.Code != codes.FailedPrecondition || !strings.Contains(err.Error(), expectedInvalidHeaderField) {
		t.Fatalf("Read got error %v, want error with code %s and contains %q", err, codes.FailedPrecondition, expectedInvalidHeaderField)
	}
}
// TestStreamContext checks that a Stream stored in a context with
// newContextWithStream is returned unchanged by StreamFromContext.
func TestStreamContext(t *testing.T) {
	expectedStream := &Stream{}
	ctx := newContextWithStream(context.Background(), expectedStream)
	s, ok := StreamFromContext(ctx)
	if !ok || expectedStream != s {
		// Name the function that was actually called in the failure message.
		t.Fatalf("StreamFromContext(%v) = %v, %t, want: %v, true", ctx, s, ok, expectedStream)
	}
}
// TestIsReservedHeader runs isReservedHeader over a table of header names and
// reports every entry whose result differs from the expected one.
func TestIsReservedHeader(t *testing.T) {
	type testCase struct {
		h    string
		want bool
	}
	cases := []testCase{
		{h: "", want: false}, // but should be rejected earlier
		{h: "foo", want: false},
		{h: "content-type", want: true},
		{h: "grpc-message-type", want: true},
		{h: "grpc-encoding", want: true},
		{h: "grpc-message", want: true},
		{h: "grpc-status", want: true},
		{h: "grpc-timeout", want: true},
		{h: "te", want: true},
	}
	for i := range cases {
		tc := cases[i]
		if got := isReservedHeader(tc.h); got != tc.want {
			t.Errorf("isReservedHeader(%q) = %v; want %v", tc.h, got, tc.want)
		}
	}
}
// TestContextErr checks that ContextErr maps each context error to a
// StreamError with the matching code and the context error's own message.
func TestContextErr(t *testing.T) {
	type testCase struct {
		errIn  error       // input
		errOut StreamError // expected output
	}
	cases := []testCase{
		{
			errIn:  context.DeadlineExceeded,
			errOut: StreamError{codes.DeadlineExceeded, context.DeadlineExceeded.Error()},
		},
		{
			errIn:  context.Canceled,
			errOut: StreamError{codes.Canceled, context.Canceled.Error()},
		},
	}
	for _, tc := range cases {
		if got := ContextErr(tc.errIn); got != tc.errOut {
			t.Fatalf("ContextErr{%v} = %v \nwant %v", tc.errIn, got, tc.errOut)
		}
	}
}
// max returns the larger of the two int32 values.
func max(a, b int32) int32 {
	if b >= a {
		return b
	}
	return a
}
// windowSizeConfig groups the four initial flow-control window sizes used by
// the window-size accounting tests.
type windowSizeConfig struct {
	// Initial stream-level and connection-level windows for the server.
	serverStream, serverConn int32
	// Initial stream-level and connection-level windows for the client.
	clientStream, clientConn int32
}
func TestAccountCheckWindowSizeWithLargeWindow(t *testing.T) {
wc := windowSizeConfig{
serverStream: 10 * 1024 * 1024,
serverConn: 12 * 1024 * 1024,
clientStream: 6 * 1024 * 1024,
clientConn: 8 * 1024 * 1024,
}
testAccountCheckWindowSize(t, wc)
}
func TestAccountCheckWindowSizeWithSmallWindow(t *testing.T) {
wc := windowSizeConfig{
serverStream: defaultWindowSize,
// Note this is smaller than initialConnWindowSize which is the current default.
serverConn: defaultWindowSize,
clientStream: defaultWindowSize,
clientConn: defaultWindowSize,
}
testAccountCheckWindowSize(t, wc)
}
// testAccountCheckWindowSize starts a server and a client configured with the
// window sizes in wc, opens one stream, and checks that each side's inbound
// flow-control limits match its own configuration while each side's send
// quotas match the peer's configuration.
func testAccountCheckWindowSize(t *testing.T, wc windowSizeConfig) {
	serverConfig := &ServerConfig{
		InitialWindowSize:     wc.serverStream,
		InitialConnWindowSize: wc.serverConn,
	}
	connectOptions := ConnectOptions{
		InitialWindowSize:     wc.clientStream,
		InitialConnWindowSize: wc.clientConn,
	}
	server, client := setUpWithOptions(t, 0, serverConfig, suspended, connectOptions)
	defer server.stop()
	defer client.Close()
	// Wait for server conns to be populated with new server transport.
	waitWhileTrue(t, func() (bool, error) {
		server.mu.Lock()
		defer server.mu.Unlock()
		if len(server.conns) == 0 {
			return true, fmt.Errorf("timed out waiting for server transport to be created")
		}
		return false, nil
	})
	var st *http2Server
	server.mu.Lock()
	for k := range server.conns {
		st = k.(*http2Server)
	}
	server.mu.Unlock()
	ct := client.(*http2Client)
	cstream, err := client.NewStream(context.Background(), &CallHdr{Flush: true})
	if err != nil {
		t.Fatalf("Failed to create stream. Err: %v", err)
	}
	// Wait for server to receive headers.
	waitWhileTrue(t, func() (bool, error) {
		st.mu.Lock()
		defer st.mu.Unlock()
		if len(st.activeStreams) == 0 {
			return true, fmt.Errorf("timed out waiting for server to receive headers")
		}
		return false, nil
	})
	// Sleeping to make sure the settings are applied in case of negative test.
	time.Sleep(time.Second)
	waitWhileTrue(t, func() (bool, error) {
		st.fc.mu.Lock()
		lim := st.fc.limit
		st.fc.mu.Unlock()
		if lim != uint32(serverConfig.InitialConnWindowSize) {
			return true, fmt.Errorf("Server transport flow control window size: got %v, want %v", lim, serverConfig.InitialConnWindowSize)
		}
		return false, nil
	})
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	serverSendQuota, err := wait(ctx, nil, nil, nil, st.sendQuotaPool.acquire())
	// Release the context before inspecting err so it is not left pending when
	// the check below aborts the test.
	cancel()
	if err != nil {
		t.Fatalf("Error while acquiring sendQuota on server. Err: %v", err)
	}
	// Hand the acquired quota back to the pool before checking its value.
	st.sendQuotaPool.add(serverSendQuota)
	if serverSendQuota != int(connectOptions.InitialConnWindowSize) {
		t.Fatalf("Server send quota(%v) not equal to client's window size(%v) on conn.", serverSendQuota, connectOptions.InitialConnWindowSize)
	}
	st.mu.Lock()
	ssq := st.streamSendQuota
	st.mu.Unlock()
	if ssq != uint32(connectOptions.InitialWindowSize) {
		t.Fatalf("Server stream send quota(%v) not equal to client's window size(%v) on stream.", ssq, connectOptions.InitialWindowSize)
	}
	ct.fc.mu.Lock()
	limit := ct.fc.limit
	ct.fc.mu.Unlock()
	if limit != uint32(connectOptions.InitialConnWindowSize) {
		t.Fatalf("Client transport flow control window size is %v, want %v", limit, connectOptions.InitialConnWindowSize)
	}
	ctx, cancel = context.WithTimeout(context.Background(), time.Second)
	clientSendQuota, err := wait(ctx, nil, nil, nil, ct.sendQuotaPool.acquire())
	cancel()
	if err != nil {
		t.Fatalf("Error while acquiring sendQuota on client. Err: %v", err)
	}
	ct.sendQuotaPool.add(clientSendQuota)
	if clientSendQuota != int(serverConfig.InitialConnWindowSize) {
		t.Fatalf("Client send quota(%v) not equal to server's window size(%v) on conn.", clientSendQuota, serverConfig.InitialConnWindowSize)
	}
	ct.mu.Lock()
	ssq = ct.streamSendQuota
	ct.mu.Unlock()
	if ssq != uint32(serverConfig.InitialWindowSize) {
		t.Fatalf("Client stream send quota(%v) not equal to server's window size(%v) on stream.", ssq, serverConfig.InitialWindowSize)
	}
	cstream.fc.mu.Lock()
	limit = cstream.fc.limit
	cstream.fc.mu.Unlock()
	if limit != uint32(connectOptions.InitialWindowSize) {
		t.Fatalf("Client stream flow control window size is %v, want %v", limit, connectOptions.InitialWindowSize)
	}
	var sstream *Stream
	st.mu.Lock()
	for _, v := range st.activeStreams {
		sstream = v
	}
	st.mu.Unlock()
	// Fail with a message rather than a nil dereference if the server-side
	// stream is no longer registered.
	if sstream == nil {
		t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream.id)
	}
	sstream.fc.mu.Lock()
	limit = sstream.fc.limit
	sstream.fc.mu.Unlock()
	if limit != uint32(serverConfig.InitialWindowSize) {
		t.Fatalf("Server stream flow control window size is %v, want %v", limit, serverConfig.InitialWindowSize)
	}
}
// TestAccountCheckExpandingWindow checks flow-control accounting on both
// sides after sending and receiving large messages. It exchanges ten large
// length-prefixed messages over one stream, then polls until every inbound
// flow-control window has zero delta and zero pending data and each side's
// send quota equals the peer's limit minus its pending update.
func TestAccountCheckExpandingWindow(t *testing.T) {
	server, client := setUp(t, 0, 0, pingpong)
	defer server.stop()
	defer client.Close()
	waitWhileTrue(t, func() (bool, error) {
		server.mu.Lock()
		defer server.mu.Unlock()
		if len(server.conns) == 0 {
			return true, fmt.Errorf("timed out while waiting for server transport to be created")
		}
		return false, nil
	})
	var st *http2Server
	server.mu.Lock()
	for k := range server.conns {
		st = k.(*http2Server)
	}
	server.mu.Unlock()
	ct := client.(*http2Client)
	cstream, err := client.NewStream(context.Background(), &CallHdr{Flush: true})
	if err != nil {
		t.Fatalf("Failed to create stream. Err: %v", err)
	}
	// Build one message framed with a 5-byte header: a zero byte followed by
	// the big-endian uint32 payload length.
	msgSize := 65535 * 16 * 2
	msg := make([]byte, msgSize)
	buf := make([]byte, msgSize+5)
	buf[0] = byte(0)
	binary.BigEndian.PutUint32(buf[1:], uint32(msgSize))
	copy(buf[5:], msg)
	opts := Options{}
	header := make([]byte, 5)
	for i := 1; i <= 10; i++ {
		if err := ct.Write(cstream, nil, buf, &opts); err != nil {
			t.Fatalf("Error on client while writing message: %v", err)
		}
		if _, err := cstream.Read(header); err != nil {
			t.Fatalf("Error on client while reading data frame header: %v", err)
		}
		sz := binary.BigEndian.Uint32(header[1:])
		recvMsg := make([]byte, int(sz))
		if _, err := cstream.Read(recvMsg); err != nil {
			t.Fatalf("Error on client while reading data: %v", err)
		}
		if len(recvMsg) != len(msg) {
			t.Fatalf("Length of message received by client: %v, want: %v", len(recvMsg), len(msg))
		}
	}
	var sstream *Stream
	st.mu.Lock()
	for _, v := range st.activeStreams {
		sstream = v
	}
	st.mu.Unlock()
	// Fail with a message rather than a nil dereference if the server-side
	// stream is no longer registered.
	if sstream == nil {
		t.Fatalf("Didn't find stream corresponding to client cstream.id: %v on the server", cstream.id)
	}
	waitWhileTrue(t, func() (bool, error) {
		// Check that pendingData and delta on flow control windows on both sides are 0.
		cstream.fc.mu.Lock()
		if cstream.fc.delta != 0 {
			cstream.fc.mu.Unlock()
			return true, fmt.Errorf("delta on flow control window of client stream is non-zero")
		}
		if cstream.fc.pendingData != 0 {
			cstream.fc.mu.Unlock()
			return true, fmt.Errorf("pendingData on flow control window of client stream is non-zero")
		}
		cstream.fc.mu.Unlock()
		sstream.fc.mu.Lock()
		if sstream.fc.delta != 0 {
			sstream.fc.mu.Unlock()
			return true, fmt.Errorf("delta on flow control window of server stream is non-zero")
		}
		if sstream.fc.pendingData != 0 {
			sstream.fc.mu.Unlock()
			return true, fmt.Errorf("pendingData on flow control window of server stream is non-zero")
		}
		sstream.fc.mu.Unlock()
		ct.fc.mu.Lock()
		if ct.fc.delta != 0 {
			ct.fc.mu.Unlock()
			return true, fmt.Errorf("delta on flow control window of client transport is non-zero")
		}
		if ct.fc.pendingData != 0 {
			ct.fc.mu.Unlock()
			return true, fmt.Errorf("pendingData on flow control window of client transport is non-zero")
		}
		ct.fc.mu.Unlock()
		st.fc.mu.Lock()
		if st.fc.delta != 0 {
			st.fc.mu.Unlock()
			return true, fmt.Errorf("delta on flow control window of server transport is non-zero")
		}
		if st.fc.pendingData != 0 {
			st.fc.mu.Unlock()
			return true, fmt.Errorf("pendingData on flow control window of server transport is non-zero")
		}
		st.fc.mu.Unlock()
		// Check flow control window on client stream is equal to out flow on server stream.
		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
		serverStreamSendQuota, err := wait(ctx, nil, nil, nil, sstream.sendQuotaPool.acquire())
		cancel()
		if err != nil {
			return true, fmt.Errorf("error while acquiring server stream send quota. Err: %v", err)
		}
		// Acquired quota is always handed back so the check can be retried.
		sstream.sendQuotaPool.add(serverStreamSendQuota)
		cstream.fc.mu.Lock()
		clientEst := cstream.fc.limit - cstream.fc.pendingUpdate
		cstream.fc.mu.Unlock()
		if uint32(serverStreamSendQuota) != clientEst {
			return true, fmt.Errorf("server stream outflow: %v, estimated by client: %v", serverStreamSendQuota, clientEst)
		}
		// Check flow control window on server stream is equal to out flow on client stream.
		ctx, cancel = context.WithTimeout(context.Background(), time.Second)
		clientStreamSendQuota, err := wait(ctx, nil, nil, nil, cstream.sendQuotaPool.acquire())
		cancel()
		if err != nil {
			return true, fmt.Errorf("error while acquiring client stream send quota. Err: %v", err)
		}
		cstream.sendQuotaPool.add(clientStreamSendQuota)
		sstream.fc.mu.Lock()
		serverEst := sstream.fc.limit - sstream.fc.pendingUpdate
		sstream.fc.mu.Unlock()
		if uint32(clientStreamSendQuota) != serverEst {
			return true, fmt.Errorf("client stream outflow: %v. estimated by server: %v", clientStreamSendQuota, serverEst)
		}
		// Check flow control window on client transport is equal to out flow of server transport.
		ctx, cancel = context.WithTimeout(context.Background(), time.Second)
		serverTrSendQuota, err := wait(ctx, nil, nil, nil, st.sendQuotaPool.acquire())
		cancel()
		if err != nil {
			return true, fmt.Errorf("error while acquiring server transport send quota. Err: %v", err)
		}
		st.sendQuotaPool.add(serverTrSendQuota)
		ct.fc.mu.Lock()
		clientEst = ct.fc.limit - ct.fc.pendingUpdate
		ct.fc.mu.Unlock()
		if uint32(serverTrSendQuota) != clientEst {
			return true, fmt.Errorf("server transport outflow: %v, estimated by client: %v", serverTrSendQuota, clientEst)
		}
		// Check flow control window on server transport is equal to out flow of client transport.
		ctx, cancel = context.WithTimeout(context.Background(), time.Second)
		clientTrSendQuota, err := wait(ctx, nil, nil, nil, ct.sendQuotaPool.acquire())
		cancel()
		if err != nil {
			return true, fmt.Errorf("error while acquiring client transport send quota. Err: %v", err)
		}
		ct.sendQuotaPool.add(clientTrSendQuota)
		st.fc.mu.Lock()
		serverEst = st.fc.limit - st.fc.pendingUpdate
		st.fc.mu.Unlock()
		if uint32(clientTrSendQuota) != serverEst {
			// serverEst is the server's estimate, so say so in the message.
			return true, fmt.Errorf("client transport outflow: %v, estimated by server: %v", clientTrSendQuota, serverEst)
		}
		return false, nil
	})
}
// waitWhileTrue polls condition every 50ms for as long as it reports true.
// It returns as soon as condition reports false. If condition is still
// reporting true once five seconds have elapsed, the test is failed with the
// error most recently returned by condition.
func waitWhileTrue(t *testing.T, condition func() (bool, error)) {
	timer := time.NewTimer(time.Second * 5)
	// The timer is local to this call, so stopping it on return is enough;
	// there is no need to drain its channel.
	defer timer.Stop()
	for {
		wait, err := condition()
		if !wait {
			return
		}
		select {
		case <-timer.C:
			// Pass the error as a value rather than as a format string: the
			// message may contain '%' and the error may be nil.
			t.Fatal(err)
		default:
			time.Sleep(50 * time.Millisecond)
		}
	}
}
// writeHeaders is the signature of the helpers that write response headers
// carrying an HTTP status through an http2.Framer. The arguments are, in
// order, the framer to write with, the stream ID to write on, and the HTTP
// status to send.
type writeHeaders func(*http2.Framer, uint32, int) error
// writeOneHeader writes a single HEADERS frame on stream sid whose only field
// is ":status" set to httpStatus. The frame ends both the header block and
// the stream. It returns any error from encoding the field or writing the
// frame.
func writeOneHeader(framer *http2.Framer, sid uint32, httpStatus int) error {
	var buf bytes.Buffer
	henc := hpack.NewEncoder(&buf)
	if err := henc.WriteField(hpack.HeaderField{Name: ":status", Value: fmt.Sprint(httpStatus)}); err != nil {
		return err
	}
	return framer.WriteHeaders(http2.HeadersFrameParam{
		StreamID:      sid,
		BlockFragment: buf.Bytes(),
		EndStream:     true,
		EndHeaders:    true,
	})
}
// writeTwoHeaders writes two HEADERS frames on stream sid. The first carries
// ":status" set to http.StatusOK and leaves the stream open; the second
// carries ":status" set to httpStatus and ends the stream. Both frames are
// encoded with the same hpack encoder. It returns the first error from
// encoding a field or writing a frame.
func writeTwoHeaders(framer *http2.Framer, sid uint32, httpStatus int) error {
	var buf bytes.Buffer
	henc := hpack.NewEncoder(&buf)
	if err := henc.WriteField(hpack.HeaderField{
		Name:  ":status",
		Value: fmt.Sprint(http.StatusOK),
	}); err != nil {
		return err
	}
	if err := framer.WriteHeaders(http2.HeadersFrameParam{
		StreamID:      sid,
		BlockFragment: buf.Bytes(),
		EndHeaders:    true,
	}); err != nil {
		return err
	}
	// Reuse the buffer for the second header block.
	buf.Reset()
	if err := henc.WriteField(hpack.HeaderField{
		Name:  ":status",
		Value: fmt.Sprint(httpStatus),
	}); err != nil {
		return err
	}
	return framer.WriteHeaders(http2.HeadersFrameParam{
		StreamID:      sid,
		BlockFragment: buf.Bytes(),
		EndStream:     true,
		EndHeaders:    true,
	})
}
// httpServer is a minimal test server that accepts one connection and answers
// the first request headers it reads with a configurable HTTP status.
type httpServer struct {
// conn is the accepted connection; it is set by start and closed by cleanUp.
conn net.Conn
// httpStatus is the status passed to wh when replying.
httpStatus int
// wh writes the response headers for the stream being answered.
wh writeHeaders
}
// start launches a goroutine that accepts one connection from lis, consumes
// the client preface, writes a SETTINGS ack, reads frames until a HEADERS
// frame arrives, and replies on that stream by calling s.wh with
// s.httpStatus. Every failure is reported through t.Errorf. The accepted
// connection is stored in s.conn and closed when the goroutine exits.
func (s *httpServer) start(t *testing.T, lis net.Listener) {
	// Launch an HTTP server to send back header with httpStatus.
	go func() {
		var err error
		s.conn, err = lis.Accept()
		if err != nil {
			t.Errorf("Error accepting connection: %v", err)
			return
		}
		defer s.conn.Close()
		// Read preface sent by client.
		if _, err = io.ReadFull(s.conn, make([]byte, len(http2.ClientPreface))); err != nil {
			t.Errorf("Error at server-side while reading preface from client. Err: %v", err)
			return
		}
		reader := bufio.NewReaderSize(s.conn, http2IOBufSize)
		writer := bufio.NewWriterSize(s.conn, http2IOBufSize)
		framer := http2.NewFramer(writer, reader)
		if err = framer.WriteSettingsAck(); err != nil {
			t.Errorf("Error at server-side while sending Settings ack. Err: %v", err)
			return
		}
		var sid uint32
		// Read frames until a header is received.
		for {
			frame, rerr := framer.ReadFrame()
			if rerr != nil {
				t.Errorf("Error at server-side while reading frame. Err: %v", rerr)
				return
			}
			if hframe, ok := frame.(*http2.HeadersFrame); ok {
				sid = hframe.Header().StreamID
				break
			}
		}
		if err = s.wh(framer, sid, s.httpStatus); err != nil {
			t.Errorf("Error at server-side while writing headers. Err: %v", err)
			return
		}
		// Everything written so far sits in the buffered writer; a failed
		// flush means the client never sees the response, so report it.
		if err = writer.Flush(); err != nil {
			t.Errorf("Error at server-side while flushing headers. Err: %v", err)
		}
	}()
}
// cleanUp closes the server's connection if one has been accepted.
func (s *httpServer) cleanUp() {
	if s.conn == nil {
		return
	}
	s.conn.Close()
}
// setUpHTTPStatusTest starts an httpServer on a local TCP listener that will
// reply with httpStatus using wh, connects an HTTP/2 client transport to it
// and opens a stream with method "bogus/method". It returns that stream and a
// cleanUp function that closes the listener, the server connection and the
// client. Set-up failures abort the test via t.Fatalf.
func setUpHTTPStatusTest(t *testing.T, httpStatus int, wh writeHeaders) (stream *Stream, cleanUp func()) {
var (
err error
lis net.Listener
server *httpServer
client ClientTransport
)
// cleanUp captures the variables above, so it releases whichever of them
// have been set by the time it is called.
cleanUp = func() {
if lis != nil {
lis.Close()
}
if server != nil {
server.cleanUp()
}
if client != nil {
client.Close()
}
}
// t.Fatalf exits through runtime.Goexit, which still runs deferred calls, so
// a failed set-up step (err != nil) releases whatever was created before it.
defer func() {
if err != nil {
cleanUp()
}
}()
lis, err = net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Failed to listen. Err: %v", err)
}
server = &httpServer{
httpStatus: httpStatus,
wh: wh,
}
server.start(t, lis)
client, err = newHTTP2Client(context.Background(), TargetInfo{Addr: lis.Addr().String()}, ConnectOptions{})
if err != nil {
t.Fatalf("Error creating client. Err: %v", err)
}
stream, err = client.NewStream(context.Background(), &CallHdr{Method: "bogus/method", Flush: true})
if err != nil {
t.Fatalf("Error creating stream at client-side. Err: %v", err)
}
return
}
// TestHTTPToGRPCStatusMapping runs the status-mapping check once for every
// HTTP status that has an entry in httpStatusConvTab, with the status sent in
// a single header frame.
func TestHTTPToGRPCStatusMapping(t *testing.T) {
	for httpStatus := range httpStatusConvTab {
		testHTTPToGRPCStatusMapping(t, httpStatus, writeOneHeader)
	}
}
// testHTTPToGRPCStatusMapping has the test server reply with httpStatus using
// wh, and checks that reading from the client stream fails with a StreamError
// whose code is the httpStatusConvTab entry for httpStatus.
func testHTTPToGRPCStatusMapping(t *testing.T, httpStatus int, wh writeHeaders) {
	stream, cleanUp := setUpHTTPStatusTest(t, httpStatus, wh)
	defer cleanUp()

	wantCode := httpStatusConvTab[httpStatus]
	_, readErr := stream.Read(make([]byte, 8))
	if readErr == nil {
		t.Fatalf("Stream.Read(_) unexpectedly returned no error. Expected stream error with code %v", wantCode)
	}
	streamErr, isStreamErr := readErr.(StreamError)
	if !isStreamErr {
		t.Fatalf("err.(Type) = %T, want StreamError", readErr)
	}
	if streamErr.Code != wantCode {
		t.Fatalf("Want error code: %v, got: %v", wantCode, streamErr.Code)
	}
}
// TestHTTPStatusOKAndMissingGRPCStatus has the test server reply with HTTP
// 200 in a single header frame, and checks that the client read ends with
// io.EOF and the stream's status code is codes.Unknown.
func TestHTTPStatusOKAndMissingGRPCStatus(t *testing.T) {
	stream, cleanUp := setUpHTTPStatusTest(t, http.StatusOK, writeOneHeader)
	defer cleanUp()

	if _, err := stream.Read(make([]byte, 8)); err != io.EOF {
		t.Fatalf("stream.Read(_) = _, %v, want _, io.EOF", err)
	}

	stream.mu.Lock()
	defer stream.mu.Unlock()
	if got, want := stream.status.Code(), codes.Unknown; got != want {
		t.Fatalf("Status code of stream: %v, want: %v", got, want)
	}
}
// TestHTTPStatusNottOKAndMissingGRPCStatusInSecondHeader runs the
// status-mapping check for HTTP 401 when that status arrives in the second of
// two header frames.
func TestHTTPStatusNottOKAndMissingGRPCStatusInSecondHeader(t *testing.T) {
	const secondHeaderStatus = http.StatusUnauthorized
	testHTTPToGRPCStatusMapping(t, secondHeaderStatus, writeTwoHeaders)
}
// If any error occurs on a call to Stream.Read, future calls
// should continue to return that same error.
//
// The test feeds a hand-built Stream one message carrying an error, then a
// message without an error and a message with a different error, and checks
// that every Read returns zero bytes and the first error.
func TestReadGivesSameErrorAfterAnyErrorOccurs(t *testing.T) {
	testRecvBuffer := newRecvBuffer()
	s := &Stream{
		ctx:         context.Background(),
		goAway:      make(chan struct{}),
		buf:         testRecvBuffer,
		requestRead: func(int) {},
	}
	s.trReader = &transportReader{
		reader: &recvBufferReader{
			ctx:    s.ctx,
			goAway: s.goAway,
			recv:   s.buf,
		},
		windowHandler: func(int) {},
	}
	testData := make([]byte, 1)
	testData[0] = 5
	testErr := errors.New("test error")
	s.write(recvMsg{data: testData, err: testErr})
	inBuf := make([]byte, 1)
	actualCount, actualErr := s.Read(inBuf)
	if actualCount != 0 {
		t.Errorf("actualCount, _ := s.Read(_) differs; want 0; got %v", actualCount)
	}
	// Guard against a nil error so a regression fails the test instead of
	// panicking on actualErr.Error().
	if actualErr == nil || actualErr.Error() != testErr.Error() {
		t.Errorf("_ , actualErr := s.Read(_) differs; want actualErr.Error() to be %v; got %v", testErr.Error(), actualErr)
	}
	s.write(recvMsg{data: testData, err: nil})
	s.write(recvMsg{data: testData, err: errors.New("different error from first")})
	for i := 0; i < 2; i++ {
		inBuf := make([]byte, 1)
		actualCount, actualErr := s.Read(inBuf)
		if actualCount != 0 {
			t.Errorf("actualCount, _ := s.Read(_) differs; want %v; got %v", 0, actualCount)
		}
		if actualErr == nil || actualErr.Error() != testErr.Error() {
			t.Errorf("_ , actualErr := s.Read(_) differs; want actualErr.Error() to be %v; got %v", testErr.Error(), actualErr)
		}
	}
}