xds: call xdsclient.New instead of getting xds_client from attributes (#4032)

diff --git a/xds/internal/balancer/cdsbalancer/cdsbalancer.go b/xds/internal/balancer/cdsbalancer/cdsbalancer.go
index b37cc73..5a869e0 100644
--- a/xds/internal/balancer/cdsbalancer/cdsbalancer.go
+++ b/xds/internal/balancer/cdsbalancer/cdsbalancer.go
@@ -22,7 +22,6 @@
 	"errors"
 	"fmt"
 
-	"google.golang.org/grpc/attributes"
 	"google.golang.org/grpc/balancer"
 	"google.golang.org/grpc/balancer/base"
 	"google.golang.org/grpc/connectivity"
@@ -37,7 +36,6 @@
 	"google.golang.org/grpc/xds/internal/balancer/edsbalancer"
 	"google.golang.org/grpc/xds/internal/client/bootstrap"
 
-	xdsinternal "google.golang.org/grpc/xds/internal"
 	xdsclient "google.golang.org/grpc/xds/internal/client"
 )
 
@@ -61,6 +59,7 @@
 		// not deal with subConns.
 		return builder.Build(cc, opts), nil
 	}
+	newXDSClient  = func() (xdsClientInterface, error) { return xdsclient.New() }
 	buildProvider = buildProviderFunc
 )
 
@@ -86,6 +85,13 @@
 	b.logger = prefixLogger((b))
 	b.logger.Infof("Created")
 
+	client, err := newXDSClient()
+	if err != nil {
+		b.logger.Errorf("failed to create xds-client: %v", err)
+		return nil
+	}
+	b.xdsClient = client
+
 	var creds credentials.TransportCredentials
 	switch {
 	case opts.DialCreds != nil:
@@ -141,7 +147,6 @@
 // watcher with the xdsClient, while a non-nil error causes it to cancel the
 // existing watch and propagate the error to the underlying edsBalancer.
 type ccUpdate struct {
-	client      xdsClientInterface
 	clusterName string
 	err         error
 }
@@ -196,16 +201,9 @@
 	if err := update.err; err != nil {
 		b.handleErrorFromUpdate(err, true)
 	}
-	if b.xdsClient == update.client && b.clusterToWatch == update.clusterName {
+	if b.clusterToWatch == update.clusterName {
 		return
 	}
-	if update.client != nil {
-		// Since the cdsBalancer doesn't own the xdsClient object, we don't have
-		// to bother about closing the old client here, but we still need to
-		// cancel the watch on the old client.
-		b.cancelWatch()
-		b.xdsClient = update.client
-	}
 	if update.clusterName != "" {
 		cancelWatch := b.xdsClient.WatchCluster(update.clusterName, b.handleClusterUpdate)
 		b.logger.Infof("Watch started on resource name %v with xds-client %p", update.clusterName, b.xdsClient)
@@ -351,7 +349,6 @@
 
 	}
 	ccState := balancer.ClientConnState{
-		ResolverState:  resolver.State{Attributes: attributes.New(xdsinternal.XDSClientID, b.xdsClient)},
 		BalancerConfig: lbCfg,
 	}
 	if err := b.edsLB.UpdateClientConnState(ccState); err != nil {
@@ -469,17 +466,7 @@
 		b.logger.Warningf("xds: no clusterName found in LoadBalancingConfig: %+v", lbCfg)
 		return balancer.ErrBadResolverState
 	}
-	client := state.ResolverState.Attributes.Value(xdsinternal.XDSClientID)
-	if client == nil {
-		b.logger.Warningf("xds: no xdsClient found in resolver state attributes")
-		return balancer.ErrBadResolverState
-	}
-	newClient, ok := client.(xdsClientInterface)
-	if !ok {
-		b.logger.Warningf("xds: unexpected xdsClient type: %T", client)
-		return balancer.ErrBadResolverState
-	}
-	b.updateCh.Put(&ccUpdate{client: newClient, clusterName: lbCfg.ClusterName})
+	b.updateCh.Put(&ccUpdate{clusterName: lbCfg.ClusterName})
 	return nil
 }
 
@@ -504,6 +491,7 @@
 // Close closes the cdsBalancer and the underlying edsBalancer.
 func (b *cdsBalancer) Close() {
 	b.closed.Fire()
+	b.xdsClient.Close()
 }
 
 // ccWrapper wraps the balancer.ClientConn that was passed in to the CDS
diff --git a/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go b/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go
index 972d16b..52c7f23 100644
--- a/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go
+++ b/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go
@@ -115,6 +115,10 @@
 func setupWithXDSCreds(t *testing.T) (*fakeclient.Client, *cdsBalancer, *testEDSBalancer, *xdstestutils.TestClientConn, func()) {
 	t.Helper()
 
+	xdsC := fakeclient.NewClient()
+	oldNewXDSClient := newXDSClient
+	newXDSClient = func() (xdsClientInterface, error) { return xdsC, nil }
+
 	builder := balancer.Get(cdsName)
 	if builder == nil {
 		t.Fatalf("balancer.Get(%q) returned nil", cdsName)
@@ -140,10 +144,8 @@
 		return edsB, nil
 	}
 
-	// Create a fake xDS client and push a ClientConnState update to the CDS
-	// balancer with a cluster name and the fake xDS client in the attributes.
-	xdsC := fakeclient.NewClient()
-	if err := cdsB.UpdateClientConnState(cdsCCS(clusterName, xdsC)); err != nil {
+	// Push a ClientConnState update to the CDS balancer with a cluster name.
+	if err := cdsB.UpdateClientConnState(cdsCCS(clusterName)); err != nil {
 		t.Fatalf("cdsBalancer.UpdateClientConnState failed with error: %v", err)
 	}
 
@@ -160,6 +162,7 @@
 	}
 
 	return xdsC, cdsB.(*cdsBalancer), edsB, tcc, func() {
+		newXDSClient = oldNewXDSClient
 		newEDSBalancer = oldEDSBalancerBuilder
 	}
 }
@@ -229,7 +232,7 @@
 	// returned to the CDS balancer, because we have overridden the
 	// newEDSBalancer function as part of test setup.
 	cdsUpdate := xdsclient.ClusterUpdate{ServiceName: serviceName}
-	wantCCS := edsCCS(serviceName, false, xdsC)
+	wantCCS := edsCCS(serviceName, false)
 	ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
 	defer ctxCancel()
 	if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil {
@@ -285,7 +288,7 @@
 	// newEDSBalancer function as part of test setup. No security config is
 	// passed to the CDS balancer as part of this update.
 	cdsUpdate := xdsclient.ClusterUpdate{ServiceName: serviceName}
-	wantCCS := edsCCS(serviceName, false, xdsC)
+	wantCCS := edsCCS(serviceName, false)
 	ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
 	defer ctxCancel()
 	if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil {
@@ -441,7 +444,7 @@
 	// create a new EDS balancer. The fake EDS balancer created above will be
 	// returned to the CDS balancer, because we have overridden the
 	// newEDSBalancer function as part of test setup.
-	wantCCS := edsCCS(serviceName, false, xdsC)
+	wantCCS := edsCCS(serviceName, false)
 	if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdateWithGoodSecurityCfg, nil}, wantCCS, edsB); err != nil {
 		t.Fatal(err)
 	}
@@ -475,7 +478,7 @@
 	// create a new EDS balancer. The fake EDS balancer created above will be
 	// returned to the CDS balancer, because we have overridden the
 	// newEDSBalancer function as part of test setup.
-	wantCCS := edsCCS(serviceName, false, xdsC)
+	wantCCS := edsCCS(serviceName, false)
 	ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
 	defer ctxCancel()
 	if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdateWithGoodSecurityCfg, nil}, wantCCS, edsB); err != nil {
@@ -506,7 +509,7 @@
 	// create a new EDS balancer. The fake EDS balancer created above will be
 	// returned to the CDS balancer, because we have overridden the
 	// newEDSBalancer function as part of test setup.
-	wantCCS := edsCCS(serviceName, false, xdsC)
+	wantCCS := edsCCS(serviceName, false)
 	ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
 	defer ctxCancel()
 	if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdateWithGoodSecurityCfg, nil}, wantCCS, edsB); err != nil {
@@ -556,7 +559,7 @@
 	// create a new EDS balancer. The fake EDS balancer created above will be
 	// returned to the CDS balancer, because we have overridden the
 	// newEDSBalancer function as part of test setup.
-	wantCCS := edsCCS(serviceName, false, xdsC)
+	wantCCS := edsCCS(serviceName, false)
 	ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
 	defer ctxCancel()
 	if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdateWithGoodSecurityCfg, nil}, wantCCS, edsB); err != nil {
@@ -633,7 +636,7 @@
 			RootInstanceName: "default1",
 		},
 	}
-	wantCCS := edsCCS(serviceName, false, xdsC)
+	wantCCS := edsCCS(serviceName, false)
 	ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
 	defer ctxCancel()
 	if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil {
diff --git a/xds/internal/balancer/cdsbalancer/cdsbalancer_test.go b/xds/internal/balancer/cdsbalancer/cdsbalancer_test.go
index 2257929..ccd1699 100644
--- a/xds/internal/balancer/cdsbalancer/cdsbalancer_test.go
+++ b/xds/internal/balancer/cdsbalancer/cdsbalancer_test.go
@@ -34,7 +34,6 @@
 	"google.golang.org/grpc/internal/testutils"
 	"google.golang.org/grpc/resolver"
 	"google.golang.org/grpc/serviceconfig"
-	xdsinternal "google.golang.org/grpc/xds/internal"
 	"google.golang.org/grpc/xds/internal/balancer/edsbalancer"
 	xdsclient "google.golang.org/grpc/xds/internal/client"
 	xdstestutils "google.golang.org/grpc/xds/internal/testutils"
@@ -172,7 +171,7 @@
 
 // cdsCCS is a helper function to construct a good update passed from the
 // xdsResolver to the cdsBalancer.
-func cdsCCS(cluster string, xdsClient interface{}) balancer.ClientConnState {
+func cdsCCS(cluster string) balancer.ClientConnState {
 	const cdsLBConfig = `{
       "loadBalancingConfig":[
         {
@@ -186,7 +185,6 @@
 	return balancer.ClientConnState{
 		ResolverState: resolver.State{
 			ServiceConfig: internal.ParseServiceConfigForTesting.(func(string) *serviceconfig.ParseResult)(jsonSC),
-			Attributes:    attributes.New(xdsinternal.XDSClientID, xdsClient),
 		},
 		BalancerConfig: &lbConfig{ClusterName: clusterName},
 	}
@@ -194,22 +192,25 @@
 
 // edsCCS is a helper function to construct a good update passed from the
 // cdsBalancer to the edsBalancer.
-func edsCCS(service string, enableLRS bool, xdsClient interface{}) balancer.ClientConnState {
+func edsCCS(service string, enableLRS bool) balancer.ClientConnState {
 	lbCfg := &edsbalancer.EDSConfig{EDSServiceName: service}
 	if enableLRS {
 		lbCfg.LrsLoadReportingServerName = new(string)
 	}
 	return balancer.ClientConnState{
-		ResolverState:  resolver.State{Attributes: attributes.New(xdsinternal.XDSClientID, xdsClient)},
 		BalancerConfig: lbCfg,
 	}
 }
 
 // setup creates a cdsBalancer and an edsBalancer (and overrides the
 // newEDSBalancer function to return it), and also returns a cleanup function.
-func setup(t *testing.T) (*cdsBalancer, *testEDSBalancer, *xdstestutils.TestClientConn, func()) {
+func setup(t *testing.T) (*fakeclient.Client, *cdsBalancer, *testEDSBalancer, *xdstestutils.TestClientConn, func()) {
 	t.Helper()
 
+	xdsC := fakeclient.NewClient()
+	oldNewXDSClient := newXDSClient
+	newXDSClient = func() (xdsClientInterface, error) { return xdsC, nil }
+
 	builder := balancer.Get(cdsName)
 	if builder == nil {
 		t.Fatalf("balancer.Get(%q) returned nil", cdsName)
@@ -224,8 +225,9 @@
 		return edsB, nil
 	}
 
-	return cdsB.(*cdsBalancer), edsB, tcc, func() {
+	return xdsC, cdsB.(*cdsBalancer), edsB, tcc, func() {
 		newEDSBalancer = oldEDSBalancerBuilder
+		newXDSClient = oldNewXDSClient
 	}
 }
 
@@ -234,9 +236,8 @@
 func setupWithWatch(t *testing.T) (*fakeclient.Client, *cdsBalancer, *testEDSBalancer, *xdstestutils.TestClientConn, func()) {
 	t.Helper()
 
-	xdsC := fakeclient.NewClient()
-	cdsB, edsB, tcc, cancel := setup(t)
-	if err := cdsB.UpdateClientConnState(cdsCCS(clusterName, xdsC)); err != nil {
+	xdsC, cdsB, edsB, tcc, cancel := setup(t)
+	if err := cdsB.UpdateClientConnState(cdsCCS(clusterName)); err != nil {
 		t.Fatalf("cdsBalancer.UpdateClientConnState failed with error: %v", err)
 	}
 
@@ -256,8 +257,6 @@
 // cdsBalancer with different inputs and verifies that the CDS watch API on the
 // provided xdsClient is invoked appropriately.
 func (s) TestUpdateClientConnState(t *testing.T) {
-	xdsC := fakeclient.NewClient()
-
 	tests := []struct {
 		name        string
 		ccs         balancer.ClientConnState
@@ -275,35 +274,15 @@
 			wantErr: balancer.ErrBadResolverState,
 		},
 		{
-			name: "no-xdsClient-in-attributes",
-			ccs: balancer.ClientConnState{
-				ResolverState: resolver.State{
-					Attributes: attributes.New("key", "value"),
-				},
-				BalancerConfig: &lbConfig{ClusterName: clusterName},
-			},
-			wantErr: balancer.ErrBadResolverState,
-		},
-		{
-			name: "bad-xdsClient-in-attributes",
-			ccs: balancer.ClientConnState{
-				ResolverState: resolver.State{
-					Attributes: attributes.New(xdsinternal.XDSClientID, "value"),
-				},
-				BalancerConfig: &lbConfig{ClusterName: clusterName},
-			},
-			wantErr: balancer.ErrBadResolverState,
-		},
-		{
 			name:        "happy-good-case",
-			ccs:         cdsCCS(clusterName, xdsC),
+			ccs:         cdsCCS(clusterName),
 			wantCluster: clusterName,
 		},
 	}
 
 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {
-			cdsB, _, _, cancel := setup(t)
+			xdsC, cdsB, _, _, cancel := setup(t)
 			defer func() {
 				cancel()
 				cdsB.Close()
@@ -340,7 +319,7 @@
 	}()
 
 	// This is the same clientConn update sent in setupWithWatch().
-	if err := cdsB.UpdateClientConnState(cdsCCS(clusterName, xdsC)); err != nil {
+	if err := cdsB.UpdateClientConnState(cdsCCS(clusterName)); err != nil {
 		t.Fatalf("cdsBalancer.UpdateClientConnState failed with error: %v", err)
 	}
 	// The above update should not result in a new watch being registered.
@@ -370,12 +349,12 @@
 		{
 			name:      "happy-case-with-lrs",
 			cdsUpdate: xdsclient.ClusterUpdate{ServiceName: serviceName, EnableLRS: true},
-			wantCCS:   edsCCS(serviceName, true, xdsC),
+			wantCCS:   edsCCS(serviceName, true),
 		},
 		{
 			name:      "happy-case-without-lrs",
 			cdsUpdate: xdsclient.ClusterUpdate{ServiceName: serviceName},
-			wantCCS:   edsCCS(serviceName, false, xdsC),
+			wantCCS:   edsCCS(serviceName, false),
 		},
 	}
 
@@ -443,7 +422,7 @@
 	// returned to the CDS balancer, because we have overridden the
 	// newEDSBalancer function as part of test setup.
 	cdsUpdate := xdsclient.ClusterUpdate{ServiceName: serviceName}
-	wantCCS := edsCCS(serviceName, false, xdsC)
+	wantCCS := edsCCS(serviceName, false)
 	if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil {
 		t.Fatal(err)
 	}
@@ -528,7 +507,7 @@
 	// returned to the CDS balancer, because we have overridden the
 	// newEDSBalancer function as part of test setup.
 	cdsUpdate := xdsclient.ClusterUpdate{ServiceName: serviceName}
-	wantCCS := edsCCS(serviceName, false, xdsC)
+	wantCCS := edsCCS(serviceName, false)
 	if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil {
 		t.Fatal(err)
 	}
@@ -577,7 +556,7 @@
 	// returned to the CDS balancer, because we have overridden the
 	// newEDSBalancer function as part of test setup.
 	cdsUpdate := xdsclient.ClusterUpdate{ServiceName: serviceName}
-	wantCCS := edsCCS(serviceName, false, xdsC)
+	wantCCS := edsCCS(serviceName, false)
 	ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
 	defer ctxCancel()
 	if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil {
@@ -609,7 +588,7 @@
 	// returned to the CDS balancer, because we have overridden the
 	// newEDSBalancer function as part of test setup.
 	cdsUpdate := xdsclient.ClusterUpdate{ServiceName: serviceName}
-	wantCCS := edsCCS(serviceName, false, xdsC)
+	wantCCS := edsCCS(serviceName, false)
 	ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout)
 	defer ctxCancel()
 	if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil {
@@ -640,7 +619,7 @@
 
 	// Make sure that the UpdateClientConnState() method on the CDS balancer
 	// returns error.
-	if err := cdsB.UpdateClientConnState(cdsCCS(clusterName, fakeclient.NewClient())); err != errBalancerClosed {
+	if err := cdsB.UpdateClientConnState(cdsCCS(clusterName)); err != errBalancerClosed {
 		t.Fatalf("UpdateClientConnState() after close returned %v, want %v", err, errBalancerClosed)
 	}
 
diff --git a/xds/internal/balancer/edsbalancer/eds.go b/xds/internal/balancer/edsbalancer/eds.go
index 9312b05..1358dd7 100644
--- a/xds/internal/balancer/edsbalancer/eds.go
+++ b/xds/internal/balancer/edsbalancer/eds.go
@@ -43,6 +43,7 @@
 	newEDSBalancer = func(cc balancer.ClientConn, enqueueState func(priorityType, balancer.State), xdsClient *xdsClientWrapper, logger *grpclog.PrefixLogger) edsBalancerImplInterface {
 		return newEDSBalancerImpl(cc, enqueueState, xdsClient, logger)
 	}
+	newXDSClient = func() (xdsClientInterface, error) { return xdsclient.New() }
 )
 
 func init() {
@@ -61,7 +62,14 @@
 		childPolicyUpdate: buffer.NewUnbounded(),
 	}
 	x.logger = prefixLogger((x))
-	x.client = newXDSClientWrapper(x.handleEDSUpdate, x.logger)
+
+	client, err := newXDSClient()
+	if err != nil {
+		x.logger.Errorf("xds: failed to create xds-client: %v", err)
+		return nil
+	}
+
+	x.client = newXDSClientWrapper(client, x.handleEDSUpdate, x.logger)
 	x.edsImpl = newEDSBalancer(x.cc, x.enqueueChildBalancerState, x.client, x.logger)
 	x.logger.Infof("Created")
 	go x.run()
@@ -177,7 +185,7 @@
 			return
 		}
 
-		if err := x.client.handleUpdate(cfg, u.ResolverState.Attributes); err != nil {
+		if err := x.client.handleUpdate(cfg); err != nil {
 			x.logger.Warningf("failed to update xds clients: %v", err)
 		}
 
diff --git a/xds/internal/balancer/edsbalancer/eds_test.go b/xds/internal/balancer/edsbalancer/eds_test.go
index 085310d..eeeaaca 100644
--- a/xds/internal/balancer/edsbalancer/eds_test.go
+++ b/xds/internal/balancer/edsbalancer/eds_test.go
@@ -30,9 +30,6 @@
 	"github.com/golang/protobuf/jsonpb"
 	wrapperspb "github.com/golang/protobuf/ptypes/wrappers"
 	"github.com/google/go-cmp/cmp"
-	"google.golang.org/grpc/attributes"
-	xdsinternal "google.golang.org/grpc/xds/internal"
-
 	"google.golang.org/grpc/balancer"
 	"google.golang.org/grpc/connectivity"
 	"google.golang.org/grpc/internal/grpclog"
@@ -192,15 +189,20 @@
 // edsLB, creates fake version of them and makes them available on the provided
 // channels. The returned cancel function should be called by the test for
 // cleanup.
-func setup(edsLBCh *testutils.Channel) func() {
+func setup(edsLBCh *testutils.Channel) (*fakeclient.Client, func()) {
+	xdsC := fakeclient.NewClientWithName(testBalancerNameFooBar)
+	oldNewXDSClient := newXDSClient
+	newXDSClient = func() (xdsClientInterface, error) { return xdsC, nil }
+
 	origNewEDSBalancer := newEDSBalancer
 	newEDSBalancer = func(cc balancer.ClientConn, enqueue func(priorityType, balancer.State), _ *xdsClientWrapper, logger *grpclog.PrefixLogger) edsBalancerImplInterface {
 		edsLB := newFakeEDSBalancer(cc)
 		defer func() { edsLBCh.Send(edsLB) }()
 		return edsLB
 	}
-	return func() {
+	return xdsC, func() {
 		newEDSBalancer = origNewEDSBalancer
+		newXDSClient = oldNewXDSClient
 	}
 }
 
@@ -261,9 +263,8 @@
 //   This time around, we expect no new xdsClient or edsLB to be created.
 //   Instead, we expect the existing edsLB to receive the new child policy.
 func (s) TestXDSConnfigChildPolicyUpdate(t *testing.T) {
-	xdsC := fakeclient.NewClientWithName(testBalancerNameFooBar)
 	edsLBCh := testutils.NewChannel()
-	cancel := setup(edsLBCh)
+	xdsC, cancel := setup(edsLBCh)
 	defer cancel()
 
 	builder := balancer.Get(edsName)
@@ -275,7 +276,6 @@
 	defer edsB.Close()
 
 	edsB.UpdateClientConnState(balancer.ClientConnState{
-		ResolverState: resolver.State{Attributes: attributes.New(xdsinternal.XDSClientID, xdsC)},
 		BalancerConfig: &EDSConfig{
 			ChildPolicy: &loadBalancingConfig{
 				Name:   fakeBalancerA,
@@ -298,7 +298,6 @@
 	})
 
 	edsB.UpdateClientConnState(balancer.ClientConnState{
-		ResolverState: resolver.State{Attributes: attributes.New(xdsinternal.XDSClientID, xdsC)},
 		BalancerConfig: &EDSConfig{
 			ChildPolicy: &loadBalancingConfig{
 				Name:   fakeBalancerB,
@@ -316,9 +315,8 @@
 // TestXDSSubConnStateChange verifies if the top-level edsBalancer passes on
 // the subConnStateChange to appropriate child balancers.
 func (s) TestXDSSubConnStateChange(t *testing.T) {
-	xdsC := fakeclient.NewClientWithName(testBalancerNameFooBar)
 	edsLBCh := testutils.NewChannel()
-	cancel := setup(edsLBCh)
+	xdsC, cancel := setup(edsLBCh)
 	defer cancel()
 
 	builder := balancer.Get(edsName)
@@ -330,7 +328,6 @@
 	defer edsB.Close()
 
 	edsB.UpdateClientConnState(balancer.ClientConnState{
-		ResolverState:  resolver.State{Attributes: attributes.New(xdsinternal.XDSClientID, xdsC)},
 		BalancerConfig: &EDSConfig{EDSServiceName: testEDSClusterName},
 	})
 
@@ -357,9 +354,8 @@
 // If it's connection error, nothing will happen. This will need to change to
 // handle fallback.
 func (s) TestErrorFromXDSClientUpdate(t *testing.T) {
-	xdsC := fakeclient.NewClientWithName(testBalancerNameFooBar)
 	edsLBCh := testutils.NewChannel()
-	cancel := setup(edsLBCh)
+	xdsC, cancel := setup(edsLBCh)
 	defer cancel()
 
 	builder := balancer.Get(edsName)
@@ -371,7 +367,6 @@
 	defer edsB.Close()
 
 	if err := edsB.UpdateClientConnState(balancer.ClientConnState{
-		ResolverState:  resolver.State{Attributes: attributes.New(xdsinternal.XDSClientID, xdsC)},
 		BalancerConfig: &EDSConfig{EDSServiceName: testEDSClusterName},
 	}); err != nil {
 		t.Fatal(err)
@@ -421,9 +416,8 @@
 // If it's connection error, nothing will happen. This will need to change to
 // handle fallback.
 func (s) TestErrorFromResolver(t *testing.T) {
-	xdsC := fakeclient.NewClientWithName(testBalancerNameFooBar)
 	edsLBCh := testutils.NewChannel()
-	cancel := setup(edsLBCh)
+	xdsC, cancel := setup(edsLBCh)
 	defer cancel()
 
 	builder := balancer.Get(edsName)
@@ -435,7 +429,6 @@
 	defer edsB.Close()
 
 	if err := edsB.UpdateClientConnState(balancer.ClientConnState{
-		ResolverState:  resolver.State{Attributes: attributes.New(xdsinternal.XDSClientID, xdsC)},
 		BalancerConfig: &EDSConfig{EDSServiceName: testEDSClusterName},
 	}); err != nil {
 		t.Fatal(err)
diff --git a/xds/internal/balancer/edsbalancer/xds_client_wrapper.go b/xds/internal/balancer/edsbalancer/xds_client_wrapper.go
index 1646601..7890228 100644
--- a/xds/internal/balancer/edsbalancer/xds_client_wrapper.go
+++ b/xds/internal/balancer/edsbalancer/xds_client_wrapper.go
@@ -19,12 +19,9 @@
 package edsbalancer
 
 import (
-	"fmt"
 	"sync"
 
-	"google.golang.org/grpc/attributes"
 	"google.golang.org/grpc/internal/grpclog"
-	xdsinternal "google.golang.org/grpc/xds/internal"
 	xdsclient "google.golang.org/grpc/xds/internal/client"
 	"google.golang.org/grpc/xds/internal/client/load"
 )
@@ -126,50 +123,15 @@
 //
 // The given callbacks won't be called until the underlying xds_client is
 // working and sends updates.
-func newXDSClientWrapper(newEDSUpdate func(xdsclient.EndpointsUpdate, error), logger *grpclog.PrefixLogger) *xdsClientWrapper {
+func newXDSClientWrapper(xdsClient xdsClientInterface, newEDSUpdate func(xdsclient.EndpointsUpdate, error), logger *grpclog.PrefixLogger) *xdsClientWrapper {
 	return &xdsClientWrapper{
 		logger:       logger,
 		newEDSUpdate: newEDSUpdate,
+		xdsClient:    xdsClient,
 		loadWrapper:  &loadStoreWrapper{},
 	}
 }
 
-// updateXDSClient sets xdsClient in wrapper to the correct one based on the
-// attributes and service config.
-//
-// If client is found in attributes, it will be used, but we also need to decide
-// whether to close the old client.
-// - if old client was created locally (balancerName is not ""), close it and
-// replace it
-// - if old client was from previous attributes, only replace it, but don't
-// close it
-//
-// If client is not found in attributes, will need to create a new one only if
-// the balancerName (from bootstrap file or from service config) changed.
-// - if balancer names are the same, do nothing, and return false
-// - if balancer names are different, create new one, and return true
-func (c *xdsClientWrapper) updateXDSClient(attr *attributes.Attributes) (bool, error) {
-	if attr == nil {
-		return false, fmt.Errorf("unexported nil attributes, want attributes with xdsClient")
-	}
-	// TODO: change the way xdsClient is retrieved from attributes. One option
-	// is to add helper functions.
-	//
-	// Or, since xdsClient will become a singleton, this can just call
-	// xdsclient.New() instead. And if we decide to do this, do it in Build
-	// instead of when handling updates.
-	clientFromAttr, _ := attr.Value(xdsinternal.XDSClientID).(xdsClientInterface)
-	if clientFromAttr == nil {
-		return false, fmt.Errorf("no xdsClient found in attributes")
-	}
-
-	if c.xdsClient == clientFromAttr {
-		return false, nil
-	}
-	c.xdsClient = clientFromAttr
-	return true, nil
-}
-
 // startEndpointsWatch starts the EDS watch. Caller can call this when the
 // xds_client is updated, or the edsServiceName is updated.
 //
@@ -221,16 +183,9 @@
 
 // handleUpdate applies the service config and attributes updates to the client,
 // including updating the xds_client to use, and updating the EDS name to watch.
-func (c *xdsClientWrapper) handleUpdate(config *EDSConfig, attr *attributes.Attributes) error {
-	clientChanged, err := c.updateXDSClient(attr)
-	if err != nil {
-		return err
-	}
-
-	// Need to restart EDS watch when one of the following happens:
-	// - the xds_client is updated
-	// - the xds_client didn't change, but the edsServiceName changed
-	if clientChanged || c.edsServiceName != config.EDSServiceName {
+func (c *xdsClientWrapper) handleUpdate(config *EDSConfig) error {
+	// Need to restart EDS watch when the edsServiceName changed
+	if c.edsServiceName != config.EDSServiceName {
 		c.edsServiceName = config.EDSServiceName
 		c.startEndpointsWatch()
 		// TODO: this update for the LRS service name is too early. It should
@@ -266,6 +221,7 @@
 
 func (c *xdsClientWrapper) close() {
 	c.cancelWatch()
+	c.xdsClient.Close()
 }
 
 // equalStringPointers returns true if
diff --git a/xds/internal/balancer/edsbalancer/xds_client_wrapper_test.go b/xds/internal/balancer/edsbalancer/xds_client_wrapper_test.go
index 38162d1..60faf6f 100644
--- a/xds/internal/balancer/edsbalancer/xds_client_wrapper_test.go
+++ b/xds/internal/balancer/edsbalancer/xds_client_wrapper_test.go
@@ -25,9 +25,7 @@
 	"testing"
 
 	"github.com/google/go-cmp/cmp"
-	"google.golang.org/grpc/attributes"
 	"google.golang.org/grpc/internal/testutils"
-	xdsinternal "google.golang.org/grpc/xds/internal"
 	xdsclient "google.golang.org/grpc/xds/internal/client"
 	"google.golang.org/grpc/xds/internal/testutils/fakeclient"
 )
@@ -75,13 +73,13 @@
 func (s) TestClientWrapperWatchEDS(t *testing.T) {
 	xdsC := fakeclient.NewClientWithName(testBalancerNameFooBar)
 
-	cw := newXDSClientWrapper(nil, nil)
+	cw := newXDSClientWrapper(xdsC, nil, nil)
 	defer cw.close()
 	t.Logf("Started xDS client wrapper for endpoint %s...", testServiceName)
 
 	// Update with an non-empty edsServiceName should trigger an EDS watch for
 	// the same.
-	cw.handleUpdate(&EDSConfig{EDSServiceName: "foobar-1"}, attributes.New(xdsinternal.XDSClientID, xdsC))
+	cw.handleUpdate(&EDSConfig{EDSServiceName: "foobar-1"})
 	if err := verifyExpectedRequests(xdsC, "foobar-1"); err != nil {
 		t.Fatal(err)
 	}
@@ -90,7 +88,7 @@
 	// name to another, and make sure a new watch is registered. The previously
 	// registered watch will be cancelled, which will result in an EDS request
 	// with no resource names being sent to the server.
-	cw.handleUpdate(&EDSConfig{EDSServiceName: "foobar-2"}, attributes.New(xdsinternal.XDSClientID, xdsC))
+	cw.handleUpdate(&EDSConfig{EDSServiceName: "foobar-2"})
 	if err := verifyExpectedRequests(xdsC, "", "foobar-2"); err != nil {
 		t.Fatal(err)
 	}
@@ -112,11 +110,11 @@
 		edsRespChan.Send(&edsUpdate{resp: update, err: err})
 	}
 
-	cw := newXDSClientWrapper(newEDS, nil)
+	xdsC := fakeclient.NewClient()
+	cw := newXDSClientWrapper(xdsC, newEDS, nil)
 	defer cw.close()
 
-	xdsC := fakeclient.NewClient()
-	cw.handleUpdate(&EDSConfig{EDSServiceName: testEDSClusterName}, attributes.New(xdsinternal.XDSClientID, xdsC))
+	cw.handleUpdate(&EDSConfig{EDSServiceName: testEDSClusterName})
 
 	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
 	defer cancel()
@@ -145,43 +143,3 @@
 		t.Fatalf("want update {nil, %v}, got %+v", watchErr, update)
 	}
 }
-
-// TestClientWrapperGetsXDSClientInAttributes verfies the case where the
-// clientWrapper receives the xdsClient to use in the attributes section of the
-// update.
-func (s) TestClientWrapperGetsXDSClientInAttributes(t *testing.T) {
-	cw := newXDSClientWrapper(nil, nil)
-	defer cw.close()
-
-	xdsC1 := fakeclient.NewClient()
-	cw.handleUpdate(&EDSConfig{EDSServiceName: testEDSClusterName}, attributes.New(xdsinternal.XDSClientID, xdsC1))
-
-	// Verify that the eds watch is registered for the expected resource name.
-	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
-	defer cancel()
-	gotCluster, err := xdsC1.WaitForWatchEDS(ctx)
-	if err != nil {
-		t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err)
-	}
-	if gotCluster != testEDSClusterName {
-		t.Fatalf("xdsClient.WatchEndpoints() called with cluster: %v, want %v", gotCluster, testEDSClusterName)
-	}
-
-	// Pass a new client in the attributes. Verify that the watch is
-	// re-registered on the new client, and that the old client is not closed
-	// (because clientWrapper only closes clients that it creates, it does not
-	// close client that are passed through attributes).
-	xdsC2 := fakeclient.NewClient()
-	cw.handleUpdate(&EDSConfig{EDSServiceName: testEDSClusterName}, attributes.New(xdsinternal.XDSClientID, xdsC2))
-	gotCluster, err = xdsC2.WaitForWatchEDS(ctx)
-	if err != nil {
-		t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err)
-	}
-	if gotCluster != testEDSClusterName {
-		t.Fatalf("xdsClient.WatchEndpoints() called with cluster: %v, want %v", gotCluster, testEDSClusterName)
-	}
-
-	if err := xdsC1.WaitForClose(ctx); err != context.DeadlineExceeded {
-		t.Fatalf("clientWrapper closed xdsClient received in attributes")
-	}
-}
diff --git a/xds/internal/balancer/edsbalancer/xds_lrs_test.go b/xds/internal/balancer/edsbalancer/xds_lrs_test.go
index 955f544..a9108a2 100644
--- a/xds/internal/balancer/edsbalancer/xds_lrs_test.go
+++ b/xds/internal/balancer/edsbalancer/xds_lrs_test.go
@@ -22,10 +22,7 @@
 	"context"
 	"testing"
 
-	"google.golang.org/grpc/attributes"
 	"google.golang.org/grpc/balancer"
-	"google.golang.org/grpc/resolver"
-	xdsinternal "google.golang.org/grpc/xds/internal"
 	"google.golang.org/grpc/xds/internal/testutils/fakeclient"
 )
 
@@ -33,6 +30,11 @@
 // stream when the lbConfig passed to it contains a valid value for the LRS
 // server (empty string).
 func (s) TestXDSLoadReporting(t *testing.T) {
+	xdsC := fakeclient.NewClient()
+	oldNewXDSClient := newXDSClient
+	newXDSClient = func() (xdsClientInterface, error) { return xdsC, nil }
+	defer func() { newXDSClient = oldNewXDSClient }()
+
 	builder := balancer.Get(edsName)
 	cc := newNoopTestClientConn()
 	edsB, ok := builder.Build(cc, balancer.BuildOptions{}).(*edsBalancer)
@@ -41,9 +43,7 @@
 	}
 	defer edsB.Close()
 
-	xdsC := fakeclient.NewClient()
 	if err := edsB.UpdateClientConnState(balancer.ClientConnState{
-		ResolverState: resolver.State{Attributes: attributes.New(xdsinternal.XDSClientID, xdsC)},
 		BalancerConfig: &EDSConfig{
 			EDSServiceName:             testEDSClusterName,
 			LrsLoadReportingServerName: new(string),
diff --git a/xds/internal/balancer/lrs/balancer.go b/xds/internal/balancer/lrs/balancer.go
index f8e7673..7bf672b 100644
--- a/xds/internal/balancer/lrs/balancer.go
+++ b/xds/internal/balancer/lrs/balancer.go
@@ -24,12 +24,11 @@
 	"fmt"
 	"sync"
 
-	"google.golang.org/grpc/attributes"
 	"google.golang.org/grpc/balancer"
 	"google.golang.org/grpc/internal/grpclog"
 	"google.golang.org/grpc/serviceconfig"
 	"google.golang.org/grpc/xds/internal"
-	xdsinternal "google.golang.org/grpc/xds/internal"
+	xdsclient "google.golang.org/grpc/xds/internal/client"
 	"google.golang.org/grpc/xds/internal/client/load"
 )
 
@@ -37,6 +36,8 @@
 	balancer.Register(&lrsBB{})
 }
 
+var newXDSClient = func() (xdsClientInterface, error) { return xdsclient.New() }
+
 const lrsBalancerName = "lrs_experimental"
 
 type lrsBB struct{}
@@ -46,9 +47,16 @@
 		cc:        cc,
 		buildOpts: opts,
 	}
-	b.client = newXDSClientWrapper()
 	b.logger = prefixLogger(b)
 	b.logger.Infof("Created")
+
+	client, err := newXDSClient()
+	if err != nil {
+		b.logger.Errorf("failed to create xds-client: %v", err)
+		return nil
+	}
+	b.client = newXDSClientWrapper(client)
+
 	return b
 }
 
@@ -80,7 +88,7 @@
 	// Update load reporting config or xds client. This needs to be done before
 	// updating the child policy because we need the loadStore from the updated
 	// client to be passed to the ccWrapper.
-	if err := b.client.update(newConfig, s.ResolverState.Attributes); err != nil {
+	if err := b.client.update(newConfig); err != nil {
 		return err
 	}
 
@@ -219,34 +227,21 @@
 	loadWrapper *loadStoreWrapper
 }
 
-func newXDSClientWrapper() *xdsClientWrapper {
+func newXDSClientWrapper(c xdsClientInterface) *xdsClientWrapper {
 	return &xdsClientWrapper{
+		c:           c,
 		loadWrapper: &loadStoreWrapper{},
 	}
 }
 
 // update checks the config and xdsclient, and decides whether it needs to
 // restart the load reporting stream.
-func (w *xdsClientWrapper) update(newConfig *lbConfig, attr *attributes.Attributes) error {
+func (w *xdsClientWrapper) update(newConfig *lbConfig) error {
 	var (
 		restartLoadReport           bool
 		updateLoadClusterAndService bool
 	)
 
-	if attr == nil {
-		return fmt.Errorf("lrs: failed to get xdsClient from attributes: attributes is nil")
-	}
-	clientFromAttr, _ := attr.Value(xdsinternal.XDSClientID).(xdsClientInterface)
-	if clientFromAttr == nil {
-		return fmt.Errorf("lrs: failed to get xdsClient from attributes: xdsClient not found in attributes")
-	}
-
-	if w.c != clientFromAttr {
-		// xds client is different, restart.
-		restartLoadReport = true
-		w.c = clientFromAttr
-	}
-
 	// ClusterName is different, restart. ClusterName is from ClusterName and
 	// EdsServiceName.
 	if w.clusterName != newConfig.ClusterName {
@@ -301,4 +296,5 @@
 		w.cancelLoadReport()
 		w.cancelLoadReport = nil
 	}
+	w.c.Close()
 }
diff --git a/xds/internal/balancer/lrs/balancer_test.go b/xds/internal/balancer/lrs/balancer_test.go
index 38dd573..0794a77 100644
--- a/xds/internal/balancer/lrs/balancer_test.go
+++ b/xds/internal/balancer/lrs/balancer_test.go
@@ -25,7 +25,6 @@
 	"time"
 
 	"github.com/google/go-cmp/cmp"
-	"google.golang.org/grpc/attributes"
 	"google.golang.org/grpc/balancer"
 	"google.golang.org/grpc/balancer/roundrobin"
 	"google.golang.org/grpc/connectivity"
@@ -53,16 +52,19 @@
 // stream when the lbConfig passed to it contains a valid value for the LRS
 // server (empty string).
 func TestLoadReporting(t *testing.T) {
+	xdsC := fakeclient.NewClient()
+	oldNewXDSClient := newXDSClient
+	newXDSClient = func() (xdsClientInterface, error) { return xdsC, nil }
+	defer func() { newXDSClient = oldNewXDSClient }()
+
 	builder := balancer.Get(lrsBalancerName)
 	cc := testutils.NewTestClientConn(t)
 	lrsB := builder.Build(cc, balancer.BuildOptions{})
 	defer lrsB.Close()
 
-	xdsC := fakeclient.NewClient()
 	if err := lrsB.UpdateClientConnState(balancer.ClientConnState{
 		ResolverState: resolver.State{
-			Addresses:  testBackendAddrs,
-			Attributes: attributes.New(xdsinternal.XDSClientID, xdsC),
+			Addresses: testBackendAddrs,
 		},
 		BalancerConfig: &lbConfig{
 			ClusterName:                testClusterName,
diff --git a/xds/internal/internal.go b/xds/internal/internal.go
index 6924314..5c12e6f 100644
--- a/xds/internal/internal.go
+++ b/xds/internal/internal.go
@@ -24,13 +24,6 @@
 	"strings"
 )
 
-type clientID string
-
-// XDSClientID is the attributes key used to pass the address of the xdsClient
-// object shared between the resolver and the balancer. The xdsClient object is
-// created by the resolver and passed to the balancer.
-const XDSClientID = clientID("xdsClientID")
-
 // LocalityID is xds.Locality without XXX fields, so it can be used as map
 // keys.
 //
diff --git a/xds/internal/resolver/xds_resolver.go b/xds/internal/resolver/xds_resolver.go
index 667e7ae..3e12bb5 100644
--- a/xds/internal/resolver/xds_resolver.go
+++ b/xds/internal/resolver/xds_resolver.go
@@ -22,23 +22,17 @@
 import (
 	"fmt"
 
-	"google.golang.org/grpc/attributes"
 	"google.golang.org/grpc/internal/grpclog"
 	"google.golang.org/grpc/internal/grpcsync"
 	"google.golang.org/grpc/resolver"
 
-	xdsinternal "google.golang.org/grpc/xds/internal"
 	xdsclient "google.golang.org/grpc/xds/internal/client"
 )
 
 const xdsScheme = "xds"
 
 // For overriding in unittests.
-var (
-	newXDSClient = func() (xdsClientInterface, error) {
-		return xdsclient.New()
-	}
-)
+var newXDSClient = func() (xdsClientInterface, error) { return xdsclient.New() }
 
 func init() {
 	resolver.Register(&xdsResolverBuilder{})
@@ -163,7 +157,6 @@
 			r.logger.Infof("Received update on resource %v from xds-client %p, generated service config: %v", r.target.Endpoint, r.client, sc)
 			r.cc.UpdateState(resolver.State{
 				ServiceConfig: r.cc.ParseServiceConfig(sc),
-				Attributes:    attributes.New(xdsinternal.XDSClientID, r.client),
 			})
 		}
 	}
diff --git a/xds/internal/resolver/xds_resolver_test.go b/xds/internal/resolver/xds_resolver_test.go
index 6e2715b..3f92ffd 100644
--- a/xds/internal/resolver/xds_resolver_test.go
+++ b/xds/internal/resolver/xds_resolver_test.go
@@ -31,7 +31,6 @@
 	"google.golang.org/grpc/internal/testutils"
 	"google.golang.org/grpc/resolver"
 	"google.golang.org/grpc/serviceconfig"
-	xdsinternal "google.golang.org/grpc/xds/internal"
 	_ "google.golang.org/grpc/xds/internal/balancer/cdsbalancer" // To parse LB config
 	"google.golang.org/grpc/xds/internal/client"
 	xdsclient "google.golang.org/grpc/xds/internal/client"
@@ -310,9 +309,6 @@
 			t.Fatalf("ClientConn.UpdateState returned error: %v", err)
 		}
 		rState := gotState.(resolver.State)
-		if gotClient := rState.Attributes.Value(xdsinternal.XDSClientID); gotClient != xdsC {
-			t.Fatalf("ClientConn.UpdateState got xdsClient: %v, want %v", gotClient, xdsC)
-		}
 		if err := rState.ServiceConfig.Err; err != nil {
 			t.Fatalf("ClientConn.UpdateState received error in service config: %v", rState.ServiceConfig.Err)
 		}
@@ -368,9 +364,6 @@
 		t.Fatalf("ClientConn.UpdateState returned error: %v", err)
 	}
 	rState := gotState.(resolver.State)
-	if gotClient := rState.Attributes.Value(xdsinternal.XDSClientID); gotClient != xdsC {
-		t.Fatalf("ClientConn.UpdateState got xdsClient: %v, want %v", gotClient, xdsC)
-	}
 	if err := rState.ServiceConfig.Err; err != nil {
 		t.Fatalf("ClientConn.UpdateState received error in service config: %v", rState.ServiceConfig.Err)
 	}
@@ -419,11 +412,6 @@
 		t.Fatalf("ClientConn.UpdateState returned error: %v", err)
 	}
 	rState := gotState.(resolver.State)
-	// This update shouldn't have xds-client in it, because it doesn't pick an
-	// xds balancer.
-	if gotClient := rState.Attributes.Value(xdsinternal.XDSClientID); gotClient != nil {
-		t.Fatalf("ClientConn.UpdateState got xdsClient: %v, want <nil>", gotClient)
-	}
 	wantParsedConfig := internal.ParseServiceConfigForTesting.(func(string) *serviceconfig.ParseResult)("{}")
 	if !internal.EqualServiceConfigForTesting(rState.ServiceConfig.Config, wantParsedConfig.Config) {
 		t.Error("ClientConn.UpdateState got wrong service config")