xds: Exit from run() goroutine when resolver is closed. (#3882)

diff --git a/xds/internal/resolver/xds_resolver.go b/xds/internal/resolver/xds_resolver.go
index cdd103e..02bf19c 100644
--- a/xds/internal/resolver/xds_resolver.go
+++ b/xds/internal/resolver/xds_resolver.go
@@ -20,12 +20,12 @@
 package resolver
 
 import (
-	"context"
 	"fmt"
 
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/attributes"
 	"google.golang.org/grpc/internal/grpclog"
+	"google.golang.org/grpc/internal/grpcsync"
 	"google.golang.org/grpc/resolver"
 
 	xdsinternal "google.golang.org/grpc/xds/internal"
@@ -62,6 +62,7 @@
 	r := &xdsResolver{
 		target:   t,
 		cc:       cc,
+		closed:   grpcsync.NewEvent(),
 		updateCh: make(chan suWithError, 1),
 	}
 	r.logger = prefixLogger((r))
@@ -86,7 +87,8 @@
 		return nil, fmt.Errorf("xds: failed to create xds-client: %v", err)
 	}
 	r.client = client
-	r.ctx, r.cancelCtx = context.WithCancel(context.Background())
+
+	// Register a watch on the xdsClient for the user's dial target.
 	cancelWatch := r.client.WatchService(r.target.Endpoint, r.handleServiceUpdate)
 	r.logger.Infof("Watch started on resource name %v with xds-client %p", r.target.Endpoint, r.client)
 	r.cancelWatch = func() {
@@ -145,10 +147,9 @@
 // (which performs LDS/RDS queries for the same), and passes the received
 // updates to the ClientConn.
 type xdsResolver struct {
-	ctx       context.Context
-	cancelCtx context.CancelFunc
-	target    resolver.Target
-	cc        resolver.ClientConn
+	target resolver.Target
+	cc     resolver.ClientConn
+	closed *grpcsync.Event
 
 	logger *grpclog.PrefixLogger
 
@@ -176,7 +177,8 @@
 func (r *xdsResolver) run() {
 	for {
 		select {
-		case <-r.ctx.Done():
+		case <-r.closed.Done():
+			return
 		case update := <-r.updateCh:
 			if update.err != nil {
 				r.logger.Warningf("Watch error on resource %v from xds-client %p, %v", r.target.Endpoint, r.client, update.err)
@@ -214,7 +216,7 @@
 // the received update to the update channel, which is picked by the run
 // goroutine.
 func (r *xdsResolver) handleServiceUpdate(su xdsclient.ServiceUpdate, err error) {
-	if r.ctx.Err() != nil {
+	if r.closed.HasFired() {
 		// Do not pass updates to the ClientConn once the resolver is closed.
 		return
 	}
@@ -228,6 +230,6 @@
 func (r *xdsResolver) Close() {
 	r.cancelWatch()
 	r.client.Close()
-	r.cancelCtx()
+	r.closed.Fire()
 	r.logger.Infof("Shutdown")
 }
diff --git a/xds/internal/resolver/xds_resolver_test.go b/xds/internal/resolver/xds_resolver_test.go
index 98c6d6a..9d04e3c 100644
--- a/xds/internal/resolver/xds_resolver_test.go
+++ b/xds/internal/resolver/xds_resolver_test.go
@@ -30,6 +30,7 @@
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/internal"
 	"google.golang.org/grpc/internal/grpcrand"
+	"google.golang.org/grpc/internal/grpctest"
 	"google.golang.org/grpc/internal/testutils"
 	"google.golang.org/grpc/resolver"
 	"google.golang.org/grpc/serviceconfig"
@@ -58,7 +59,15 @@
 	target = resolver.Target{Endpoint: targetStr}
 )
 
-func TestRegister(t *testing.T) {
+type s struct {
+	grpctest.Tester
+}
+
+func Test(t *testing.T) {
+	grpctest.RunSubTests(t, s{})
+}
+
+func (s) TestRegister(t *testing.T) {
 	b := resolver.Get(xdsScheme)
 	if b == nil {
 		t.Errorf("scheme %v is not registered", xdsScheme)
@@ -119,7 +128,7 @@
 
 // TestResolverBuilder tests the xdsResolverBuilder's Build method with
 // different parameters.
-func TestResolverBuilder(t *testing.T) {
+func (s) TestResolverBuilder(t *testing.T) {
 	tests := []struct {
 		name          string
 		rbo           resolver.BuildOptions
@@ -262,7 +271,7 @@
 
 // TestXDSResolverWatchCallbackAfterClose tests the case where a service update
 // from the underlying xdsClient is received after the resolver is closed.
-func TestXDSResolverWatchCallbackAfterClose(t *testing.T) {
+func (s) TestXDSResolverWatchCallbackAfterClose(t *testing.T) {
 	xdsC := fakeclient.NewClient()
 	xdsR, tcc, cancel := testSetup(t, setupOpts{
 		config:        &validConfig,
@@ -286,7 +295,7 @@
 
 // TestXDSResolverBadServiceUpdate tests the case the xdsClient returns a bad
 // service update.
-func TestXDSResolverBadServiceUpdate(t *testing.T) {
+func (s) TestXDSResolverBadServiceUpdate(t *testing.T) {
 	xdsC := fakeclient.NewClient()
 	xdsR, tcc, cancel := testSetup(t, setupOpts{
 		config:        &validConfig,
@@ -313,7 +322,7 @@
 
 // TestXDSResolverGoodServiceUpdate tests the happy case where the resolver
 // gets a good service update from the xdsClient.
-func TestXDSResolverGoodServiceUpdate(t *testing.T) {
+func (s) TestXDSResolverGoodServiceUpdate(t *testing.T) {
 	xdsC := fakeclient.NewClient()
 	xdsR, tcc, cancel := testSetup(t, setupOpts{
 		config:        &validConfig,
@@ -372,7 +381,7 @@
 
 // TestXDSResolverUpdates tests the cases where the resolver gets a good update
 // after an error, and an error after the good update.
-func TestXDSResolverGoodUpdateAfterError(t *testing.T) {
+func (s) TestXDSResolverGoodUpdateAfterError(t *testing.T) {
 	xdsC := fakeclient.NewClient()
 	xdsR, tcc, cancel := testSetup(t, setupOpts{
 		config:        &validConfig,
@@ -423,7 +432,7 @@
 // TestXDSResolverResourceNotFoundError tests the cases where the resolver gets
 // a ResourceNotFoundError. It should generate a service config picking
 // weighted_target, but no child balancers.
-func TestXDSResolverResourceNotFoundError(t *testing.T) {
+func (s) TestXDSResolverResourceNotFoundError(t *testing.T) {
 	xdsC := fakeclient.NewClient()
 	xdsR, tcc, cancel := testSetup(t, setupOpts{
 		config:        &validConfig,