Merge pull request #903 from improbable-io/blocking-graceful-shutdown-fix

Make concurrent Server.GracefulStop calls all behave equivalently.
diff --git a/clientconn.go b/clientconn.go
index 11dce44..c81e389 100644
--- a/clientconn.go
+++ b/clientconn.go
@@ -684,7 +684,11 @@
 		}
 		ctx, cancel := context.WithTimeout(ac.ctx, timeout)
 		connectTime := time.Now()
-		newTransport, err := transport.NewClientTransport(ctx, ac.addr.Addr, ac.dopts.copts)
+		sinfo := transport.TargetInfo{
+			Addr:     ac.addr.Addr,
+			Metadata: ac.addr.Metadata,
+		}
+		newTransport, err := transport.NewClientTransport(ctx, sinfo, ac.dopts.copts)
 		if err != nil {
 			cancel()
 
diff --git a/grpclb/grpclb.go b/grpclb/grpclb.go
index 27fd033..7813745 100644
--- a/grpclb/grpclb.go
+++ b/grpclb/grpclb.go
@@ -45,6 +45,7 @@
 	"google.golang.org/grpc"
 	lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1"
 	"google.golang.org/grpc/grpclog"
+	"google.golang.org/grpc/metadata"
 	"google.golang.org/grpc/naming"
 )
 
@@ -184,9 +185,10 @@
 	)
 	for _, s := range servers {
 		// TODO: Support ExpirationInterval
+		md := metadata.Pairs("lb-token", s.LoadBalanceToken)
 		addr := grpc.Address{
-			Addr: fmt.Sprintf("%s:%d", s.IpAddress, s.Port),
-			// TODO: include LoadBalanceToken in the Metadata
+			Addr:     fmt.Sprintf("%s:%d", s.IpAddress, s.Port),
+			Metadata: &md,
 		}
 		sl = append(sl, addrInfo{
 			addr: addr,
diff --git a/grpclb/grpclb_test.go b/grpclb/grpclb_test.go
index 8d12ebb..658e722 100644
--- a/grpclb/grpclb_test.go
+++ b/grpclb/grpclb_test.go
@@ -47,13 +47,16 @@
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/credentials"
+	hwpb "google.golang.org/grpc/examples/helloworld/helloworld"
 	lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1"
+	"google.golang.org/grpc/metadata"
 	"google.golang.org/grpc/naming"
 )
 
 var (
-	lbsn = "bar.com"
-	besn = "foo.com"
+	lbsn    = "bar.com"
+	besn    = "foo.com"
+	lbToken = "iamatoken"
 )
 
 type testWatcher struct {
@@ -195,12 +198,29 @@
 	return nil
 }
 
+// helloServer is the Greeter implementation registered on every test backend.
+type helloServer struct{}
+
+// SayHello greets in.Name, but only if the incoming metadata carries "lb-token" equal to lbToken.
+func (s *helloServer) SayHello(ctx context.Context, in *hwpb.HelloRequest) (*hwpb.HelloReply, error) {
+	md, ok := metadata.FromContext(ctx)
+	if !ok {
+		return nil, grpc.Errorf(codes.Internal, "failed to receive metadata")
+	}
+	// Check the length before indexing: a missing key (or nil md) yields a nil slice and [0] would panic.
+	if tokens := md["lb-token"]; len(tokens) == 0 || tokens[0] != lbToken {
+		return nil, grpc.Errorf(codes.Internal, "received unexpected metadata: %v", md)
+	}
+	return &hwpb.HelloReply{Message: "Hello " + in.Name}, nil
+}
+
 func startBackends(t *testing.T, sn string, lis ...net.Listener) (servers []*grpc.Server) {
 	for _, l := range lis {
 		creds := &serverNameCheckCreds{
 			sn: sn,
 		}
 		s := grpc.NewServer(grpc.Creds(creds))
+		hwpb.RegisterGreeterServer(s, &helloServer{})
 		servers = append(servers, s)
 		go func(s *grpc.Server, l net.Listener) {
 			s.Serve(l)
@@ -239,8 +259,9 @@
 		t.Fatalf("Failed to generate the port number %v", err)
 	}
 	be := &lbpb.Server{
-		IpAddress: []byte(beAddr[0]),
-		Port:      int32(bePort),
+		IpAddress:        []byte(beAddr[0]),
+		Port:             int32(bePort),
+		LoadBalanceToken: lbToken,
 	}
 	var bes []*lbpb.Server
 	bes = append(bes, be)
@@ -266,12 +287,9 @@
 	if err != nil {
 		t.Fatalf("Failed to dial to the backend %v", err)
 	}
-	// Issue an unimplemented RPC and expect codes.Unimplemented.
-	var (
-		req, reply lbpb.Duration
-	)
-	if err := grpc.Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || grpc.Code(err) != codes.Unimplemented {
-		t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want error code %s", err, codes.Unimplemented)
+	helloC := hwpb.NewGreeterClient(cc)
+	if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil {
+		t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
 	}
 	cc.Close()
 }
diff --git a/transport/http2_client.go b/transport/http2_client.go
index d88e6be..2b0f680 100644
--- a/transport/http2_client.go
+++ b/transport/http2_client.go
@@ -57,6 +57,7 @@
 type http2Client struct {
 	target    string // server name/addr
 	userAgent string
+	md        interface{}
 	conn      net.Conn             // underlying communication channel
 	authInfo  credentials.AuthInfo // auth info about the connection
 	nextID    uint32               // the next stream ID to be used
@@ -145,9 +146,9 @@
 // newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
 // and starts to receive messages on it. Non-nil error returns if construction
 // fails.
-func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ ClientTransport, err error) {
+func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (_ ClientTransport, err error) {
 	scheme := "http"
-	conn, err := dial(ctx, opts.Dialer, addr)
+	conn, err := dial(ctx, opts.Dialer, addr.Addr)
 	if err != nil {
 		return nil, connectionErrorf(true, err, "transport: %v", err)
 	}
@@ -160,7 +161,7 @@
 	var authInfo credentials.AuthInfo
 	if creds := opts.TransportCredentials; creds != nil {
 		scheme = "https"
-		conn, authInfo, err = creds.ClientHandshake(ctx, addr, conn)
+		conn, authInfo, err = creds.ClientHandshake(ctx, addr.Addr, conn)
 		if err != nil {
 			// Credentials handshake errors are typically considered permanent
 			// to avoid retrying on e.g. bad certificates.
@@ -174,8 +175,9 @@
 	}
 	var buf bytes.Buffer
 	t := &http2Client{
-		target:    addr,
+		target:    addr.Addr,
 		userAgent: ua,
+		md:        addr.Metadata,
 		conn:      conn,
 		authInfo:  authInfo,
 		// The client initiated stream id is odd starting from 1.
@@ -400,6 +402,16 @@
 			}
 		}
 	}
+	if md, ok := t.md.(*metadata.MD); ok {
+		for k, v := range *md {
+			if isReservedHeader(k) {
+				continue
+			}
+			for _, entry := range v {
+				t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry})
+			}
+		}
+	}
 	first := true
 	// Sends the headers in a single batch even when they span multiple frames.
 	for !endHeaders {
diff --git a/transport/transport.go b/transport/transport.go
index 8f42dbc..413f749 100644
--- a/transport/transport.go
+++ b/transport/transport.go
@@ -361,7 +361,7 @@
 	return newHTTP2Server(conn, maxStreams, authInfo)
 }
 
-// ConnectOptions covers all relevant options for dialing a server.
+// ConnectOptions covers all relevant options for communicating with the server.
 type ConnectOptions struct {
 	// UserAgent is the application user agent.
 	UserAgent string
@@ -373,9 +373,15 @@
 	TransportCredentials credentials.TransportCredentials
 }
 
+// TargetInfo contains the information of the target such as network address and metadata.
+type TargetInfo struct {
+	Addr     string      // address passed to the dialer and to the credentials handshake
+	Metadata interface{} // optional; a *metadata.MD has its non-reserved entries written as header fields
+}
+
 // NewClientTransport establishes the transport with the required ConnectOptions
 // and returns it to the caller.
-func NewClientTransport(ctx context.Context, target string, opts ConnectOptions) (ClientTransport, error) {
+func NewClientTransport(ctx context.Context, target TargetInfo, opts ConnectOptions) (ClientTransport, error) {
 	return newHTTP2Client(ctx, target, opts)
 }
 
diff --git a/transport/transport_test.go b/transport/transport_test.go
index 01d95e4..81320e6 100644
--- a/transport/transport_test.go
+++ b/transport/transport_test.go
@@ -245,7 +245,10 @@
 		ct      ClientTransport
 		connErr error
 	)
-	ct, connErr = NewClientTransport(context.Background(), addr, ConnectOptions{})
+	target := TargetInfo{
+		Addr: addr,
+	}
+	ct, connErr = NewClientTransport(context.Background(), target, ConnectOptions{})
 	if connErr != nil {
 		t.Fatalf("failed to create transport: %v", connErr)
 	}