Merge pull request #903 from improbable-io/blocking-graceful-shutdown-fix
Make concurrent Server.GracefulStop calls all behave equivalently.
diff --git a/clientconn.go b/clientconn.go
index 11dce44..c81e389 100644
--- a/clientconn.go
+++ b/clientconn.go
@@ -684,7 +684,11 @@
}
ctx, cancel := context.WithTimeout(ac.ctx, timeout)
connectTime := time.Now()
- newTransport, err := transport.NewClientTransport(ctx, ac.addr.Addr, ac.dopts.copts)
+ sinfo := transport.TargetInfo{
+ Addr: ac.addr.Addr,
+ Metadata: ac.addr.Metadata, // e.g. grpclb's lb-token pairs; passed through to the transport
+ }
+ newTransport, err := transport.NewClientTransport(ctx, sinfo, ac.dopts.copts)
if err != nil {
cancel()
diff --git a/grpclb/grpclb.go b/grpclb/grpclb.go
index 27fd033..7813745 100644
--- a/grpclb/grpclb.go
+++ b/grpclb/grpclb.go
@@ -45,6 +45,7 @@
"google.golang.org/grpc"
lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1"
"google.golang.org/grpc/grpclog"
+ "google.golang.org/grpc/metadata"
"google.golang.org/grpc/naming"
)
@@ -184,9 +185,10 @@
)
for _, s := range servers {
// TODO: Support ExpirationInterval
+ md := metadata.Pairs("lb-token", s.LoadBalanceToken) // forwarded by the transport as a request header
addr := grpc.Address{
- Addr: fmt.Sprintf("%s:%d", s.IpAddress, s.Port),
- // TODO: include LoadBalanceToken in the Metadata
+ Addr: fmt.Sprintf("%s:%d", s.IpAddress, s.Port),
+ Metadata: &md, // must be *metadata.MD: that is the only Metadata type the transport reads
}
sl = append(sl, addrInfo{
addr: addr,
diff --git a/grpclb/grpclb_test.go b/grpclb/grpclb_test.go
index 8d12ebb..658e722 100644
--- a/grpclb/grpclb_test.go
+++ b/grpclb/grpclb_test.go
@@ -47,13 +47,16 @@
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
+ hwpb "google.golang.org/grpc/examples/helloworld/helloworld"
lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1"
+ "google.golang.org/grpc/metadata"
"google.golang.org/grpc/naming"
)
var (
- lbsn = "bar.com"
- besn = "foo.com"
+ lbsn = "bar.com"
+ besn = "foo.com"
+ lbToken = "iamatoken" // LoadBalanceToken given to the test backend; helloServer expects it
)
type testWatcher struct {
@@ -195,12 +198,29 @@
return nil
}
+// helloServer is a Greeter backend that checks the lb-token sent by the client.
+type helloServer struct{}
+
+func (s *helloServer) SayHello(ctx context.Context, in *hwpb.HelloRequest) (*hwpb.HelloReply, error) {
+ md, ok := metadata.FromContext(ctx)
+ if !ok {
+ return nil, grpc.Errorf(codes.Internal, "failed to receive metadata")
+ }
+ if tokens := md["lb-token"]; len(tokens) == 0 || tokens[0] != lbToken { // absent key: fail the RPC instead of panicking on [0]
+ return nil, grpc.Errorf(codes.Internal, "received unexpected metadata: %v", md)
+ }
+ return &hwpb.HelloReply{
+ Message: "Hello " + in.Name,
+ }, nil
+}
+
func startBackends(t *testing.T, sn string, lis ...net.Listener) (servers []*grpc.Server) {
for _, l := range lis {
creds := &serverNameCheckCreds{
sn: sn,
}
s := grpc.NewServer(grpc.Creds(creds))
+ hwpb.RegisterGreeterServer(s, &helloServer{}) // backends verify the lb-token in SayHello
servers = append(servers, s)
go func(s *grpc.Server, l net.Listener) {
s.Serve(l)
@@ -239,8 +259,9 @@
t.Fatalf("Failed to generate the port number %v", err)
}
be := &lbpb.Server{
- IpAddress: []byte(beAddr[0]),
- Port: int32(bePort),
+ IpAddress: []byte(beAddr[0]),
+ Port: int32(bePort),
+ LoadBalanceToken: lbToken, // helloServer rejects RPCs that do not carry this token
}
var bes []*lbpb.Server
bes = append(bes, be)
@@ -266,12 +287,9 @@
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
}
- // Issue an unimplemented RPC and expect codes.Unimplemented.
- var (
- req, reply lbpb.Duration
- )
- if err := grpc.Invoke(context.Background(), "/foo/bar", &req, &reply, cc); err == nil || grpc.Code(err) != codes.Unimplemented {
- t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want error code %s", err, codes.Unimplemented)
+ helloC := hwpb.NewGreeterClient(cc)
+ if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil { // errors if, among other things, the backend did not get lbToken
+ t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
}
cc.Close()
}
diff --git a/transport/http2_client.go b/transport/http2_client.go
index d88e6be..2b0f680 100644
--- a/transport/http2_client.go
+++ b/transport/http2_client.go
@@ -57,6 +57,7 @@
type http2Client struct {
target string // server name/addr
userAgent string
+ md interface{} // TargetInfo.Metadata; a *metadata.MD is written into outgoing stream headers
conn net.Conn // underlying communication channel
authInfo credentials.AuthInfo // auth info about the connection
nextID uint32 // the next stream ID to be used
@@ -145,9 +146,9 @@
// newHTTP2Client constructs a connected ClientTransport to addr based on HTTP2
// and starts to receive messages on it. Non-nil error returns if construction
// fails.
-func newHTTP2Client(ctx context.Context, addr string, opts ConnectOptions) (_ ClientTransport, err error) {
+func newHTTP2Client(ctx context.Context, addr TargetInfo, opts ConnectOptions) (_ ClientTransport, err error) {
scheme := "http"
- conn, err := dial(ctx, opts.Dialer, addr)
+ conn, err := dial(ctx, opts.Dialer, addr.Addr)
if err != nil {
return nil, connectionErrorf(true, err, "transport: %v", err)
}
@@ -160,7 +161,7 @@
var authInfo credentials.AuthInfo
if creds := opts.TransportCredentials; creds != nil {
scheme = "https"
- conn, authInfo, err = creds.ClientHandshake(ctx, addr, conn)
+ conn, authInfo, err = creds.ClientHandshake(ctx, addr.Addr, conn)
if err != nil {
// Credentials handshake errors are typically considered permanent
// to avoid retrying on e.g. bad certificates.
@@ -174,8 +175,9 @@
}
var buf bytes.Buffer
t := &http2Client{
- target: addr,
+ target: addr.Addr,
userAgent: ua,
+ md: addr.Metadata, // kept for the stream header encoding
conn: conn,
authInfo: authInfo,
// The client initiated stream id is odd starting from 1.
@@ -400,6 +402,16 @@
}
}
}
+ if md, ok := t.md.(*metadata.MD); ok && md != nil { // a typed-nil *metadata.MD passes the assertion; don't dereference it
+ for k, v := range *md {
+ if isReservedHeader(k) {
+ continue
+ }
+ for _, entry := range v {
+ t.hEnc.WriteField(hpack.HeaderField{Name: k, Value: entry})
+ }
+ }
+ }
first := true
// Sends the headers in a single batch even when they span multiple frames.
for !endHeaders {
diff --git a/transport/transport.go b/transport/transport.go
index 8f42dbc..413f749 100644
--- a/transport/transport.go
+++ b/transport/transport.go
@@ -361,7 +361,7 @@
return newHTTP2Server(conn, maxStreams, authInfo)
}
-// ConnectOptions covers all relevant options for dialing a server.
+// ConnectOptions covers all relevant options for communicating with the server.
type ConnectOptions struct {
// UserAgent is the application user agent.
UserAgent string
@@ -373,9 +373,15 @@
TransportCredentials credentials.TransportCredentials
}
+// TargetInfo contains the information of the target such as network address and metadata.
+type TargetInfo struct {
+ Addr string // network address to dial, e.g. "host:port"
+ Metadata interface{} // optional; if it is a *metadata.MD its pairs are sent as request headers
+}
+
// NewClientTransport establishes the transport with the required ConnectOptions
// and returns it to the caller.
-func NewClientTransport(ctx context.Context, target string, opts ConnectOptions) (ClientTransport, error) {
+func NewClientTransport(ctx context.Context, target TargetInfo, opts ConnectOptions) (ClientTransport, error) {
return newHTTP2Client(ctx, target, opts)
}
diff --git a/transport/transport_test.go b/transport/transport_test.go
index 01d95e4..81320e6 100644
--- a/transport/transport_test.go
+++ b/transport/transport_test.go
@@ -245,7 +245,10 @@
ct ClientTransport
connErr error
)
- ct, connErr = NewClientTransport(context.Background(), addr, ConnectOptions{})
+ target := TargetInfo{
+ Addr: addr, // Metadata left nil, so no extra headers are sent
+ }
+ ct, connErr = NewClientTransport(context.Background(), target, ConnectOptions{})
if connErr != nil {
t.Fatalf("failed to create transport: %v", connErr)
}