balancer: set RPC metadata in address attributes, instead of Metadata field (#4041)

This metadata will be sent with all RPCs on the created SubConn
diff --git a/balancer/grpclb/grpclb_remote_balancer.go b/balancer/grpclb/grpclb_remote_balancer.go
index 8fdda09..08d326a 100644
--- a/balancer/grpclb/grpclb_remote_balancer.go
+++ b/balancer/grpclb/grpclb_remote_balancer.go
@@ -35,6 +35,7 @@
 	"google.golang.org/grpc/connectivity"
 	"google.golang.org/grpc/internal/backoff"
 	"google.golang.org/grpc/internal/channelz"
+	imetadata "google.golang.org/grpc/internal/metadata"
 	"google.golang.org/grpc/keepalive"
 	"google.golang.org/grpc/metadata"
 	"google.golang.org/grpc/resolver"
@@ -76,10 +77,7 @@
 			// net.SplitHostPort() will return too many colons error.
 			ipStr = fmt.Sprintf("[%s]", ipStr)
 		}
-		addr := resolver.Address{
-			Addr:     fmt.Sprintf("%s:%d", ipStr, s.Port),
-			Metadata: &md,
-		}
+		addr := imetadata.Set(resolver.Address{Addr: fmt.Sprintf("%s:%d", ipStr, s.Port)}, md)
 		if logger.V(2) {
 			logger.Infof("lbBalancer: server list entry[%d]: ipStr:|%s|, port:|%d|, load balancer token:|%v|",
 				i, ipStr, s.Port, s.LoadBalanceToken)
@@ -163,19 +161,19 @@
 	addrsSet := make(map[resolver.Address]struct{})
 	// Create new SubConns.
 	for _, addr := range backendAddrs {
-		addrWithoutMD := addr
-		addrWithoutMD.Metadata = nil
-		addrsSet[addrWithoutMD] = struct{}{}
-		lb.backendAddrsWithoutMetadata = append(lb.backendAddrsWithoutMetadata, addrWithoutMD)
+		addrWithoutAttrs := addr
+		addrWithoutAttrs.Attributes = nil
+		addrsSet[addrWithoutAttrs] = struct{}{}
+		lb.backendAddrsWithoutMetadata = append(lb.backendAddrsWithoutMetadata, addrWithoutAttrs)
 
-		if _, ok := lb.subConns[addrWithoutMD]; !ok {
+		if _, ok := lb.subConns[addrWithoutAttrs]; !ok {
 			// Use addrWithMD to create the SubConn.
 			sc, err := lb.cc.NewSubConn([]resolver.Address{addr}, opts)
 			if err != nil {
 				logger.Warningf("grpclb: failed to create new SubConn: %v", err)
 				continue
 			}
-			lb.subConns[addrWithoutMD] = sc // Use the addr without MD as key for the map.
+			lb.subConns[addrWithoutAttrs] = sc // Use the addr without MD as key for the map.
 			if _, ok := lb.scStates[sc]; !ok {
 				// Only set state of new sc to IDLE. The state could already be
 				// READY for cached SubConns.
diff --git a/balancer/grpclb/grpclb_test.go b/balancer/grpclb/grpclb_test.go
index dc94ca8..9cbb338 100644
--- a/balancer/grpclb/grpclb_test.go
+++ b/balancer/grpclb/grpclb_test.go
@@ -303,7 +303,7 @@
 	if !ok {
 		return nil, status.Error(codes.Internal, "failed to receive metadata")
 	}
-	if !s.fallback && (md == nil || md["lb-token"][0] != lbToken) {
+	if !s.fallback && (md == nil || len(md["lb-token"]) == 0 || md["lb-token"][0] != lbToken) {
 		return nil, status.Errorf(codes.Internal, "received unexpected metadata: %v", md)
 	}
 	grpc.SetTrailer(ctx, metadata.Pairs(testmdkey, s.addr))
diff --git a/balancer/grpclb/grpclb_util.go b/balancer/grpclb/grpclb_util.go
index 636725e..373f04b 100644
--- a/balancer/grpclb/grpclb_util.go
+++ b/balancer/grpclb/grpclb_util.go
@@ -124,16 +124,16 @@
 	if len(addrs) != 1 {
 		return nil, fmt.Errorf("grpclb calling NewSubConn with addrs of length %v", len(addrs))
 	}
-	addrWithoutMD := addrs[0]
-	addrWithoutMD.Metadata = nil
+	addrWithoutAttrs := addrs[0]
+	addrWithoutAttrs.Attributes = nil
 
 	ccc.mu.Lock()
 	defer ccc.mu.Unlock()
-	if entry, ok := ccc.subConnCache[addrWithoutMD]; ok {
+	if entry, ok := ccc.subConnCache[addrWithoutAttrs]; ok {
 		// If entry is in subConnCache, the SubConn was being deleted.
 		// cancel function will never be nil.
 		entry.cancel()
-		delete(ccc.subConnCache, addrWithoutMD)
+		delete(ccc.subConnCache, addrWithoutAttrs)
 		return entry.sc, nil
 	}
 
@@ -142,7 +142,7 @@
 		return nil, err
 	}
 
-	ccc.subConnToAddr[scNew] = addrWithoutMD
+	ccc.subConnToAddr[scNew] = addrWithoutAttrs
 	return scNew, nil
 }
 
diff --git a/internal/hierarchy/hierarchy.go b/internal/hierarchy/hierarchy.go
index 17185d9..a2f990f 100644
--- a/internal/hierarchy/hierarchy.go
+++ b/internal/hierarchy/hierarchy.go
@@ -23,7 +23,6 @@
 package hierarchy
 
 import (
-	"google.golang.org/grpc/attributes"
 	"google.golang.org/grpc/resolver"
 )
 
@@ -37,19 +36,12 @@
 	if attrs == nil {
 		return nil
 	}
-	path, ok := attrs.Value(pathKey).([]string)
-	if !ok {
-		return nil
-	}
+	path, _ := attrs.Value(pathKey).([]string)
 	return path
 }
 
 // Set overrides the hierarchical path in addr with path.
 func Set(addr resolver.Address, path []string) resolver.Address {
-	if addr.Attributes == nil {
-		addr.Attributes = attributes.New(pathKey, path)
-		return addr
-	}
 	addr.Attributes = addr.Attributes.WithValues(pathKey, path)
 	return addr
 }
diff --git a/internal/metadata/metadata.go b/internal/metadata/metadata.go
new file mode 100644
index 0000000..3022626
--- /dev/null
+++ b/internal/metadata/metadata.go
@@ -0,0 +1,50 @@
+/*
+ *
+ * Copyright 2020 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+// Package metadata contains functions to set and get metadata from addresses.
+//
+// This package is experimental.
+package metadata
+
+import (
+	"google.golang.org/grpc/metadata"
+	"google.golang.org/grpc/resolver"
+)
+
+type mdKeyType string
+
+const mdKey = mdKeyType("grpc.internal.address.metadata")
+
+// Get returns the metadata of addr.
+func Get(addr resolver.Address) metadata.MD {
+	attrs := addr.Attributes
+	if attrs == nil {
+		return nil
+	}
+	md, _ := attrs.Value(mdKey).(metadata.MD)
+	return md
+}
+
+// Set sets (overrides) the metadata in addr.
+//
+// When a SubConn is created with this address, the RPCs sent on it will all
+// have this metadata.
+func Set(addr resolver.Address, md metadata.MD) resolver.Address {
+	addr.Attributes = addr.Attributes.WithValues(mdKey, md)
+	return addr
+}
diff --git a/internal/metadata/metadata_test.go b/internal/metadata/metadata_test.go
new file mode 100644
index 0000000..68c2ca5
--- /dev/null
+++ b/internal/metadata/metadata_test.go
@@ -0,0 +1,86 @@
+/*
+ *
+ * Copyright 2020 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package metadata
+
+import (
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+	"google.golang.org/grpc/attributes"
+	"google.golang.org/grpc/metadata"
+	"google.golang.org/grpc/resolver"
+)
+
+func TestGet(t *testing.T) {
+	tests := []struct {
+		name string
+		addr resolver.Address
+		want metadata.MD
+	}{
+		{
+			name: "not set",
+			addr: resolver.Address{},
+			want: nil,
+		},
+		{
+			name: "not set",
+			addr: resolver.Address{
+				Attributes: attributes.New(mdKey, metadata.Pairs("k", "v")),
+			},
+			want: metadata.Pairs("k", "v"),
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			if got := Get(tt.addr); !cmp.Equal(got, tt.want) {
+				t.Errorf("Get() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestSet(t *testing.T) {
+	tests := []struct {
+		name string
+		addr resolver.Address
+		md   metadata.MD
+	}{
+		{
+			name: "unset before",
+			addr: resolver.Address{},
+			md:   metadata.Pairs("k", "v"),
+		},
+		{
+			name: "set before",
+			addr: resolver.Address{
+				Attributes: attributes.New(mdKey, metadata.Pairs("bef", "ore")),
+			},
+			md: metadata.Pairs("k", "v"),
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			newAddr := Set(tt.addr, tt.md)
+			newMD := Get(newAddr)
+			if !cmp.Equal(newMD, tt.md) {
+				t.Errorf("md after Set() = %v, want %v", newMD, tt.md)
+			}
+		})
+	}
+}
diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go
index fef365c..4778ed1 100644
--- a/internal/transport/http2_client.go
+++ b/internal/transport/http2_client.go
@@ -33,6 +33,7 @@
 	"golang.org/x/net/http2"
 	"golang.org/x/net/http2/hpack"
 	"google.golang.org/grpc/internal/grpcutil"
+	imetadata "google.golang.org/grpc/internal/metadata"
 	"google.golang.org/grpc/internal/transport/networktype"
 
 	"google.golang.org/grpc/codes"
@@ -60,7 +61,7 @@
 	cancel     context.CancelFunc
 	ctxDone    <-chan struct{} // Cache the ctx.Done() chan.
 	userAgent  string
-	md         interface{}
+	md         metadata.MD
 	conn       net.Conn // underlying communication channel
 	loopy      *loopyWriter
 	remoteAddr net.Addr
@@ -268,7 +269,6 @@
 		ctxDone:               ctx.Done(), // Cache Done chan.
 		cancel:                cancel,
 		userAgent:             opts.UserAgent,
-		md:                    addr.Metadata,
 		conn:                  conn,
 		remoteAddr:            conn.RemoteAddr(),
 		localAddr:             conn.LocalAddr(),
@@ -296,6 +296,12 @@
 		keepaliveEnabled:      keepaliveEnabled,
 		bufferPool:            newBufferPool(),
 	}
+
+	if md, ok := addr.Metadata.(*metadata.MD); ok {
+		t.md = *md
+	} else if md := imetadata.Get(addr); md != nil {
+		t.md = md
+	}
 	t.controlBuf = newControlBuffer(t.ctxDone)
 	if opts.InitialWindowSize >= defaultWindowSize {
 		t.initialWindowSize = opts.InitialWindowSize
@@ -512,14 +518,12 @@
 			}
 		}
 	}
-	if md, ok := t.md.(*metadata.MD); ok {
-		for k, vv := range *md {
-			if isReservedHeader(k) {
-				continue
-			}
-			for _, v := range vv {
-				headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
-			}
+	for k, vv := range t.md {
+		if isReservedHeader(k) {
+			continue
+		}
+		for _, v := range vv {
+			headerFields = append(headerFields, hpack.HeaderField{Name: k, Value: encodeMetadataHeader(k, v)})
 		}
 	}
 	return headerFields, nil
diff --git a/test/balancer_test.go b/test/balancer_test.go
index f0189cf..7af5c81 100644
--- a/test/balancer_test.go
+++ b/test/balancer_test.go
@@ -39,6 +39,7 @@
 	"google.golang.org/grpc/internal/balancerload"
 	"google.golang.org/grpc/internal/grpcsync"
 	"google.golang.org/grpc/internal/grpcutil"
+	imetadata "google.golang.org/grpc/internal/metadata"
 	"google.golang.org/grpc/internal/testutils"
 	"google.golang.org/grpc/metadata"
 	"google.golang.org/grpc/resolver"
@@ -543,6 +544,76 @@
 	}
 }
 
+// TestMetadataInAddressAttributes verifies that the metadata added to
+// address.Attributes will be sent with the RPCs.
+func (s) TestMetadataInAddressAttributes(t *testing.T) {
+	const (
+		testMDKey      = "test-md"
+		testMDValue    = "test-md-value"
+		mdBalancerName = "metadata-balancer"
+	)
+
+	// Register a stub balancer which adds metadata to the first address that it
+	// receives and then calls NewSubConn on it.
+	bf := stub.BalancerFuncs{
+		UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
+			addrs := ccs.ResolverState.Addresses
+			if len(addrs) == 0 {
+				return nil
+			}
+			// Only use the first address.
+			sc, err := bd.ClientConn.NewSubConn([]resolver.Address{
+				imetadata.Set(addrs[0], metadata.Pairs(testMDKey, testMDValue)),
+			}, balancer.NewSubConnOptions{})
+			if err != nil {
+				return err
+			}
+			sc.Connect()
+			return nil
+		},
+		UpdateSubConnState: func(bd *stub.BalancerData, sc balancer.SubConn, state balancer.SubConnState) {
+			bd.ClientConn.UpdateState(balancer.State{ConnectivityState: state.ConnectivityState, Picker: &aiPicker{result: balancer.PickResult{SubConn: sc}, err: state.ConnectionError}})
+		},
+	}
+	stub.Register(mdBalancerName, bf)
+	t.Logf("Registered balancer %s...", mdBalancerName)
+
+	testMDChan := make(chan []string, 1)
+	ss := &stubServer{
+		emptyCall: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
+			md, ok := metadata.FromIncomingContext(ctx)
+			if ok {
+				select {
+				case testMDChan <- md[testMDKey]:
+				case <-ctx.Done():
+					return nil, ctx.Err()
+				}
+			}
+			return &testpb.Empty{}, nil
+		},
+	}
+	if err := ss.Start(nil, grpc.WithDefaultServiceConfig(
+		fmt.Sprintf(`{ "loadBalancingConfig": [{"%v": {}}] }`, mdBalancerName),
+	)); err != nil {
+		t.Fatalf("Error starting endpoint server: %v", err)
+	}
+	defer ss.Stop()
+
+	// The RPC should succeed with the expected md.
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+	if _, err := ss.client.EmptyCall(ctx, &testpb.Empty{}); err != nil {
+		t.Fatalf("EmptyCall() = _, %v, want _, <nil>", err)
+	}
+	t.Log("Made an RPC which succeeded...")
+
+	// The server should receive the test metadata.
+	md1 := <-testMDChan
+	if len(md1) == 0 || md1[0] != testMDValue {
+		t.Fatalf("got md: %v, want %v", md1, []string{testMDValue})
+	}
+}
+
 // TestServersSwap creates two servers and verifies the client switches between
 // them when the name resolver reports the first and then the second.
 func (s) TestServersSwap(t *testing.T) {