ClientHandshake should get the dialing endpoint as the authority (#1607)
diff --git a/clientconn.go b/clientconn.go
index 21541b6..61ac1a0 100644
--- a/clientconn.go
+++ b/clientconn.go
@@ -431,13 +431,16 @@
if cc.dopts.bs == nil {
cc.dopts.bs = DefaultBackoffConfig
}
+ cc.parsedTarget = parseTarget(cc.target)
creds := cc.dopts.copts.TransportCredentials
if creds != nil && creds.Info().ServerName != "" {
cc.authority = creds.Info().ServerName
} else if cc.dopts.insecure && cc.dopts.copts.Authority != "" {
cc.authority = cc.dopts.copts.Authority
} else {
- cc.authority = target
+ // Use endpoint from "scheme://authority/endpoint" as the default
+ // authority for ClientConn.
+ cc.authority = cc.parsedTarget.Endpoint
}
if cc.dopts.scChan != nil && !scSet {
@@ -541,10 +544,11 @@
ctx context.Context
cancel context.CancelFunc
- target string
- authority string
- dopts dialOptions
- csMgr *connectivityStateManager
+ target string
+ parsedTarget resolver.Target
+ authority string
+ dopts dialOptions
+ csMgr *connectivityStateManager
customBalancer bool // If this is true, switching balancer will be disabled.
balancerBuildOpts balancer.BuildOptions
@@ -953,8 +957,9 @@
}
ac.mu.Unlock()
sinfo := transport.TargetInfo{
- Addr: addr.Addr,
- Metadata: addr.Metadata,
+ Addr: addr.Addr,
+ Metadata: addr.Metadata,
+ Authority: ac.cc.authority,
}
newTransport, err := transport.NewClientTransport(ac.cc.ctx, sinfo, copts, timeout)
if err != nil {
diff --git a/credentials/credentials.go b/credentials/credentials.go
index c6575b5..90b6a61 100644
--- a/credentials/credentials.go
+++ b/credentials/credentials.go
@@ -127,15 +127,15 @@
}
}
-func (c *tlsCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (_ net.Conn, _ AuthInfo, err error) {
+func (c *tlsCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (_ net.Conn, _ AuthInfo, err error) {
// use local cfg to avoid clobbering ServerName if using multiple endpoints
cfg := cloneTLSConfig(c.config)
if cfg.ServerName == "" {
- colonPos := strings.LastIndex(addr, ":")
+ colonPos := strings.LastIndex(authority, ":")
if colonPos == -1 {
- colonPos = len(addr)
+ colonPos = len(authority)
}
- cfg.ServerName = addr[:colonPos]
+ cfg.ServerName = authority[:colonPos]
}
conn := tls.Client(rawConn, cfg)
errChannel := make(chan error, 1)
diff --git a/resolver_conn_wrapper.go b/resolver_conn_wrapper.go
index 7bb843e..53fb77a 100644
--- a/resolver_conn_wrapper.go
+++ b/resolver_conn_wrapper.go
@@ -58,12 +58,11 @@
// builder for this scheme. It then builds the resolver and starts the
// monitoring goroutine for it.
func newCCResolverWrapper(cc *ClientConn) (*ccResolverWrapper, error) {
- target := parseTarget(cc.target)
- grpclog.Infof("dialing to target with scheme: %q", target.Scheme)
+ grpclog.Infof("dialing to target with scheme: %q", cc.parsedTarget.Scheme)
- rb := resolver.Get(target.Scheme)
+ rb := resolver.Get(cc.parsedTarget.Scheme)
if rb == nil {
- return nil, fmt.Errorf("could not get resolver for scheme: %q", target.Scheme)
+ return nil, fmt.Errorf("could not get resolver for scheme: %q", cc.parsedTarget.Scheme)
}
ccr := &ccResolverWrapper{
@@ -74,7 +73,7 @@
}
var err error
- ccr.resolver, err = rb.Build(target, ccr, resolver.BuildOption{})
+ ccr.resolver, err = rb.Build(cc.parsedTarget, ccr, resolver.BuildOption{})
if err != nil {
return nil, err
}
diff --git a/test/end2end_test.go b/test/end2end_test.go
index 6f0c0fa..be9eb27 100644
--- a/test/end2end_test.go
+++ b/test/end2end_test.go
@@ -4287,6 +4287,69 @@
}
}
+type authorityCheckCreds struct {
+ got string
+}
+
+func (c *authorityCheckCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
+ return rawConn, nil, nil
+}
+func (c *authorityCheckCreds) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
+ c.got = authority
+ return rawConn, nil, nil
+}
+func (c *authorityCheckCreds) Info() credentials.ProtocolInfo {
+ return credentials.ProtocolInfo{}
+}
+func (c *authorityCheckCreds) Clone() credentials.TransportCredentials {
+ return c
+}
+func (c *authorityCheckCreds) OverrideServerName(s string) error {
+ return nil
+}
+
+// This test makes sure that the authority client handshake gets is the endpoint
+// in dial target, not the resolved ip address.
+func TestCredsHandshakeAuthority(t *testing.T) {
+ const testAuthority = "test.auth.ori.ty"
+
+ lis, err := net.Listen("tcp", "localhost:0")
+ if err != nil {
+ t.Fatalf("Failed to listen: %v", err)
+ }
+ cred := &authorityCheckCreds{}
+ s := grpc.NewServer()
+ go s.Serve(lis)
+ defer s.Stop()
+
+ r, rcleanup := manual.GenerateAndRegisterManualResolver()
+ defer rcleanup()
+
+ cc, err := grpc.Dial(r.Scheme()+":///"+testAuthority, grpc.WithTransportCredentials(cred))
+ if err != nil {
+ t.Fatalf("grpc.Dial(%q) = %v", lis.Addr().String(), err)
+ }
+ defer cc.Close()
+ r.NewAddress([]resolver.Address{{Addr: lis.Addr().String()}})
+
+ ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
+ defer cancel()
+ for {
+ s := cc.GetState()
+ if s == connectivity.Ready {
+ break
+ }
+ if !cc.WaitForStateChange(ctx, s) {
+ // ctx got timeout or canceled.
+ t.Fatalf("ClientConn is not ready after 100 ms")
+ }
+ }
+
+ if cred.got != testAuthority {
+ t.Fatalf("client creds got authority: %q, want: %q", cred.got, testAuthority)
+ }
+}
+
func TestFlowControlLogicalRace(t *testing.T) {
// Test for a regression of https://github.com/grpc/grpc-go/issues/632,
// and other flow control bugs.
diff --git a/transport/http2_client.go b/transport/http2_client.go
index 5985666..f665ef0 100644
--- a/transport/http2_client.go
+++ b/transport/http2_client.go
@@ -44,7 +44,6 @@
type http2Client struct {
ctx context.Context
cancel context.CancelFunc
- target string // server name/addr
userAgent string
md interface{}
conn net.Conn // underlying communication channel
@@ -175,7 +174,7 @@
)
if creds := opts.TransportCredentials; creds != nil {
scheme = "https"
- conn, authInfo, err = creds.ClientHandshake(connectCtx, addr.Addr, conn)
+ conn, authInfo, err = creds.ClientHandshake(connectCtx, addr.Authority, conn)
if err != nil {
// Credentials handshake errors are typically considered permanent
// to avoid retrying on e.g. bad certificates.
@@ -210,7 +209,6 @@
t := &http2Client{
ctx: ctx,
cancel: cancel,
- target: addr.Addr,
userAgent: opts.UserAgent,
md: addr.Metadata,
conn: conn,
diff --git a/transport/transport.go b/transport/transport.go
index 44d48a9..e0651c4 100644
--- a/transport/transport.go
+++ b/transport/transport.go
@@ -513,8 +513,9 @@
// TargetInfo contains the information of the target such as network address and metadata.
type TargetInfo struct {
- Addr string
- Metadata interface{}
+ Addr string
+ Metadata interface{}
+ Authority string
}
// NewClientTransport establishes the transport with the required ConnectOptions