Fix connectivity state transitions when dialing (#1596)

diff --git a/balancer/roundrobin/roundrobin_test.go b/balancer/roundrobin/roundrobin_test.go
index 0335e48..857dfde 100644
--- a/balancer/roundrobin/roundrobin_test.go
+++ b/balancer/roundrobin/roundrobin_test.go
@@ -111,13 +111,13 @@
 	// The first RPC should fail because there's no address.
 	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
 	defer cancel()
-	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil || grpc.Code(err) != codes.DeadlineExceeded {
+	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || grpc.Code(err) != codes.DeadlineExceeded {
 		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
 	}
 
 	r.NewAddress([]resolver.Address{{Addr: test.addresses[0]}})
 	// The second RPC should succeed.
-	if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
+	if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
 		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
 	}
 }
@@ -143,7 +143,7 @@
 	// The first RPC should fail because there's no address.
 	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
 	defer cancel()
-	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil || grpc.Code(err) != codes.DeadlineExceeded {
+	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || grpc.Code(err) != codes.DeadlineExceeded {
 		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
 	}
 
@@ -158,7 +158,7 @@
 	for si := 0; si < backendCount; si++ {
 		var connected bool
 		for i := 0; i < 1000; i++ {
-			if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
+			if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
 				t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
 			}
 			if p.Addr.String() == test.addresses[si] {
@@ -173,7 +173,7 @@
 	}
 
 	for i := 0; i < 3*backendCount; i++ {
-		if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
+		if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
 			t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
 		}
 		if p.Addr.String() != test.addresses[i%backendCount] {
@@ -202,13 +202,13 @@
 	// The first RPC should fail because there's no address.
 	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
 	defer cancel()
-	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil || grpc.Code(err) != codes.DeadlineExceeded {
+	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || grpc.Code(err) != codes.DeadlineExceeded {
 		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
 	}
 
 	r.NewAddress([]resolver.Address{{Addr: test.addresses[0]}})
 	// The second RPC should succeed.
-	if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
+	if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
 		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
 	}
 
@@ -248,7 +248,7 @@
 			defer wg.Done()
 			// This RPC blocks until cc is closed.
 			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
-			if _, err := testc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); grpc.Code(err) == codes.DeadlineExceeded {
+			if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); grpc.Code(err) == codes.DeadlineExceeded {
 				t.Errorf("RPC failed because of deadline after cc is closed; want error the client connection is closing")
 			}
 			cancel()
@@ -278,7 +278,7 @@
 	// The first RPC should fail because there's no address.
 	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
 	defer cancel()
-	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil || grpc.Code(err) != codes.DeadlineExceeded {
+	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || grpc.Code(err) != codes.DeadlineExceeded {
 		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
 	}
 
@@ -286,7 +286,7 @@
 	// The second RPC should succeed.
 	ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second)
 	defer cancel()
-	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err != nil {
+	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err != nil {
 		t.Fatalf("EmptyCall() = _, %v, want _, nil", err)
 	}
 
@@ -298,7 +298,7 @@
 		go func() {
 			defer wg.Done()
 			// This RPC blocks until NewAddress is called.
-			testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false))
+			testc.EmptyCall(context.Background(), &testpb.Empty{})
 		}()
 	}
 	time.Sleep(50 * time.Millisecond)
@@ -327,7 +327,7 @@
 	// The first RPC should fail because there's no address.
 	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
 	defer cancel()
-	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil || grpc.Code(err) != codes.DeadlineExceeded {
+	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || grpc.Code(err) != codes.DeadlineExceeded {
 		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
 	}
 
@@ -342,7 +342,7 @@
 	for si := 0; si < backendCount; si++ {
 		var connected bool
 		for i := 0; i < 1000; i++ {
-			if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
+			if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
 				t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
 			}
 			if p.Addr.String() == test.addresses[si] {
@@ -357,7 +357,7 @@
 	}
 
 	for i := 0; i < 3*backendCount; i++ {
-		if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
+		if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
 			t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
 		}
 		if p.Addr.String() != test.addresses[i%backendCount] {
@@ -371,7 +371,7 @@
 	// Loop until see server[backendCount-1] twice without seeing server[backendCount].
 	var targetSeen int
 	for i := 0; i < 1000; i++ {
-		if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
+		if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
 			t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
 		}
 		switch p.Addr.String() {
@@ -390,7 +390,7 @@
 		t.Fatal("Failed to see server[backendCount-1] twice without seeing server[backendCount]")
 	}
 	for i := 0; i < 3*backendCount; i++ {
-		if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
+		if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
 			t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
 		}
 		if p.Addr.String() != test.addresses[i%backendCount] {
@@ -420,7 +420,7 @@
 	// The first RPC should fail because there's no address.
 	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
 	defer cancel()
-	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); err == nil || grpc.Code(err) != codes.DeadlineExceeded {
+	if _, err := testc.EmptyCall(ctx, &testpb.Empty{}); err == nil || grpc.Code(err) != codes.DeadlineExceeded {
 		t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
 	}
 
@@ -435,7 +435,7 @@
 	for si := 0; si < backendCount; si++ {
 		var connected bool
 		for i := 0; i < 1000; i++ {
-			if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
+			if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
 				t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
 			}
 			if p.Addr.String() == test.addresses[si] {
@@ -450,7 +450,7 @@
 	}
 
 	for i := 0; i < 3*backendCount; i++ {
-		if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false), grpc.Peer(&p)); err != nil {
+		if _, err := testc.EmptyCall(context.Background(), &testpb.Empty{}, grpc.Peer(&p)); err != nil {
 			t.Fatalf("EmptyCall() = _, %v, want _, DeadlineExceeded", err)
 		}
 		if p.Addr.String() != test.addresses[i%backendCount] {
diff --git a/balancer_conn_wrappers.go b/balancer_conn_wrappers.go
index 8ec7492..7ab1e78 100644
--- a/balancer_conn_wrappers.go
+++ b/balancer_conn_wrappers.go
@@ -252,7 +252,7 @@
 		acbw.ac = ac
 		ac.acbw = acbw
 		if acState != connectivity.Idle {
-			ac.connect(false)
+			ac.connect()
 		}
 	}
 }
@@ -260,7 +260,7 @@
 func (acbw *acBalancerWrapper) Connect() {
 	acbw.mu.Lock()
 	defer acbw.mu.Unlock()
-	acbw.ac.connect(false)
+	acbw.ac.connect()
 }
 
 func (acbw *acBalancerWrapper) getAddrConn() *addrConn {
diff --git a/clientconn.go b/clientconn.go
index 61ac1a0..5f5aac4 100644
--- a/clientconn.go
+++ b/clientconn.go
@@ -709,7 +709,7 @@
 // It does nothing if the ac is not IDLE.
 // TODO(bar) Move this to the addrConn section.
 // This was part of resetAddrConn, keep it here to make the diff look clean.
-func (ac *addrConn) connect(block bool) error {
+func (ac *addrConn) connect() error {
 	ac.mu.Lock()
 	if ac.state == connectivity.Shutdown {
 		ac.mu.Unlock()
@@ -723,32 +723,18 @@
 	ac.cc.handleSubConnStateChange(ac.acbw, ac.state)
 	ac.mu.Unlock()
 
-	if block {
+	// Start a goroutine connecting to the server asynchronously.
+	go func() {
 		if err := ac.resetTransport(); err != nil {
+			grpclog.Warningf("Failed to dial %s: %v; please retry.", ac.addrs[0].Addr, err)
 			if err != errConnClosing {
+				// Keep this ac in cc.conns, to get the reason it's torn down.
 				ac.tearDown(err)
 			}
-			if e, ok := err.(transport.ConnectionError); ok && !e.Temporary() {
-				return e.Origin()
-			}
-			return err
+			return
 		}
-		// Start to monitor the error status of transport.
-		go ac.transportMonitor()
-	} else {
-		// Start a goroutine connecting to the server asynchronously.
-		go func() {
-			if err := ac.resetTransport(); err != nil {
-				grpclog.Warningf("Failed to dial %s: %v; please retry.", ac.addrs[0].Addr, err)
-				if err != errConnClosing {
-					// Keep this ac in cc.conns, to get the reason it's torn down.
-					ac.tearDown(err)
-				}
-				return
-			}
-			ac.transportMonitor()
-		}()
-	}
+		ac.transportMonitor()
+	}()
 	return nil
 }
 
@@ -909,6 +895,7 @@
 
 // resetTransport recreates a transport to the address for ac.  The old
 // transport will close itself on error or when the clientconn is closed.
+//
 // TODO(bar) make sure all state transitions are valid.
 func (ac *addrConn) resetTransport() error {
 	ac.mu.Lock()
@@ -916,8 +903,6 @@
 		ac.mu.Unlock()
 		return errConnClosing
 	}
-	ac.state = connectivity.TransientFailure
-	ac.cc.handleSubConnStateChange(ac.acbw, ac.state)
 	if ac.ready != nil {
 		close(ac.ready)
 		ac.ready = nil
@@ -941,8 +926,10 @@
 			return errConnClosing
 		}
 		ac.printf("connecting")
-		ac.state = connectivity.Connecting
-		ac.cc.handleSubConnStateChange(ac.acbw, ac.state)
+		if ac.state != connectivity.Connecting {
+			ac.state = connectivity.Connecting
+			ac.cc.handleSubConnStateChange(ac.acbw, ac.state)
+		}
 		// copy ac.addrs in case of race
 		addrsIter := make([]resolver.Address, len(ac.addrs))
 		copy(addrsIter, ac.addrs)
@@ -1037,6 +1024,13 @@
 			ac.adjustParams(t.GetGoAwayReason())
 		default:
 		}
+		ac.mu.Lock()
+		// Set connectivity state to TransientFailure before calling
+		// resetTransport. Transition READY->CONNECTING is not valid.
+		ac.state = connectivity.TransientFailure
+		ac.cc.handleSubConnStateChange(ac.acbw, ac.state)
+		ac.curAddr = resolver.Address{}
+		ac.mu.Unlock()
 		if err := ac.resetTransport(); err != nil {
 			ac.mu.Lock()
 			ac.printf("transport exiting: %v", err)
@@ -1108,7 +1102,7 @@
 	ac.mu.Unlock()
 	// Trigger idle ac to connect.
 	if idle {
-		ac.connect(false)
+		ac.connect()
 	}
 	return nil, false
 }
diff --git a/pickfirst_test.go b/pickfirst_test.go
index f288754..e58b342 100644
--- a/pickfirst_test.go
+++ b/pickfirst_test.go
@@ -128,7 +128,7 @@
 		go func() {
 			defer wg.Done()
 			// This RPC blocks until NewAddress is called.
-			Invoke(context.Background(), "/foo/bar", &req, &reply, cc, FailFast(false))
+			Invoke(context.Background(), "/foo/bar", &req, &reply, cc)
 		}()
 	}
 	time.Sleep(50 * time.Millisecond)
@@ -165,7 +165,7 @@
 		go func() {
 			defer wg.Done()
 			// This RPC blocks until NewAddress is called.
-			Invoke(context.Background(), "/foo/bar", &req, &reply, cc, FailFast(false))
+			Invoke(context.Background(), "/foo/bar", &req, &reply, cc)
 		}()
 	}
 	time.Sleep(50 * time.Millisecond)