Split grpclb client load report test to deflake test. (#1206)

diff --git a/grpclb/grpclb_test.go b/grpclb/grpclb_test.go
index 38fc4a8..29c8909 100644
--- a/grpclb/grpclb_test.go
+++ b/grpclb/grpclb_test.go
@@ -44,6 +44,7 @@
 	"testing"
 	"time"
 
+	"github.com/golang/protobuf/proto"
 	"golang.org/x/net/context"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/codes"
@@ -384,7 +385,8 @@
 	creds := serverNameCheckCreds{
 		expected: besn,
 	}
-	ctx, _ := context.WithTimeout(context.Background(), 10*time.Second)
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
 	cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{
 		addrs: []string{tss.lbAddr},
 	})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds))
@@ -421,7 +423,8 @@
 	creds := serverNameCheckCreds{
 		expected: besn,
 	}
-	ctx, _ := context.WithTimeout(context.Background(), 10*time.Second)
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
 	cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{
 		addrs: []string{tss.lbAddr},
 	})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds))
@@ -471,7 +474,8 @@
 	creds := serverNameCheckCreds{
 		expected: besn,
 	}
-	ctx, _ := context.WithTimeout(context.Background(), 10*time.Second)
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
 	cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{
 		addrs: []string{tss.lbAddr},
 	})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds))
@@ -479,7 +483,8 @@
 		t.Fatalf("Failed to dial to the backend %v", err)
 	}
 	testC := testpb.NewTestServiceClient(cc)
-	ctx, _ = context.WithTimeout(context.Background(), 10*time.Millisecond)
+	ctx, cancel = context.WithTimeout(context.Background(), 10*time.Millisecond)
+	defer cancel()
 	if _, err := testC.EmptyCall(ctx, &testpb.Empty{}, grpc.FailFast(false)); grpc.Code(err) != codes.DeadlineExceeded {
 		t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, %s", testC, err, codes.DeadlineExceeded)
 	}
@@ -521,7 +526,8 @@
 	creds := serverNameCheckCreds{
 		expected: besn,
 	}
-	ctx, _ := context.WithTimeout(context.Background(), 10*time.Second)
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
 	cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{
 		addrs: []string{tss.lbAddr},
 	})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds))
@@ -578,7 +584,8 @@
 	creds := serverNameCheckCreds{
 		expected: besn,
 	}
-	ctx, _ := context.WithTimeout(context.Background(), 10*time.Second)
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
 	resolver := &testNameResolver{
 		addrs: lbAddrs[:2],
 	}
@@ -646,87 +653,14 @@
 	return false
 }
 
-func TestGRPCLBStatsUnary(t *testing.T) {
-	var (
-		countNormalRPC    = 66 // 1/3 succeeds, 1/3 dropped load balancing, 1/3 dropped rate limiting.
-		countFailedToSend = 30 // 1/3 fail to send, 1/3 dropped load balancing, 1/3 dropped rate limiting.
-	)
-	tss, cleanup, err := newLoadBalancer(3)
-	if err != nil {
-		t.Fatalf("failed to create new load balancer: %v", err)
+func checkStats(stats *lbpb.ClientStats, expected *lbpb.ClientStats) error {
+	if !proto.Equal(stats, expected) {
+		return fmt.Errorf("stats not equal: got %+v, want %+v", stats, expected)
 	}
-	defer cleanup()
-	tss.ls.sls = []*lbpb.ServerList{{
-		Servers: []*lbpb.Server{{
-			IpAddress:            tss.beIPs[0],
-			Port:                 int32(tss.bePorts[0]),
-			LoadBalanceToken:     lbToken,
-			DropForLoadBalancing: true,
-		}, {
-			IpAddress:           tss.beIPs[1],
-			Port:                int32(tss.bePorts[1]),
-			LoadBalanceToken:    lbToken,
-			DropForRateLimiting: true,
-		}, {
-			IpAddress:            tss.beIPs[2],
-			Port:                 int32(tss.bePorts[2]),
-			LoadBalanceToken:     lbToken,
-			DropForLoadBalancing: false,
-		}},
-	}}
-	tss.ls.intervals = []time.Duration{0}
-	tss.ls.statsDura = 100 * time.Millisecond
-	creds := serverNameCheckCreds{
-		expected: besn,
-	}
-	ctx, _ := context.WithTimeout(context.Background(), 10*time.Second)
-	cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{
-		addrs: []string{tss.lbAddr},
-	})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithPerRPCCredentials(failPreRPCCred{}))
-	if err != nil {
-		t.Fatalf("Failed to dial to the backend %v", err)
-	}
-	testC := testpb.NewTestServiceClient(cc)
-	// The first non-failfast RPC succeeds, all connections are up.
-	if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
-		t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
-	}
-	for i := 0; i < countNormalRPC-1; i++ {
-		testC.EmptyCall(context.Background(), &testpb.Empty{})
-	}
-	for i := 0; i < countFailedToSend; i++ {
-		grpc.Invoke(context.Background(), "failtosend", &testpb.Empty{}, nil, cc)
-	}
-	cc.Close()
-
-	time.Sleep(1 * time.Second)
-	tss.ls.mu.Lock()
-	if tss.ls.stats.NumCallsStarted != int64(countNormalRPC+countFailedToSend) {
-		t.Errorf("num calls started = %v, want %v+%v", tss.ls.stats.NumCallsStarted, countNormalRPC, countFailedToSend)
-	}
-	if tss.ls.stats.NumCallsFinished != int64(countNormalRPC+countFailedToSend) {
-		t.Errorf("num calls finished = %v, want %v+%v", tss.ls.stats.NumCallsFinished, countNormalRPC, countFailedToSend)
-	}
-	if tss.ls.stats.NumCallsFinishedWithDropForRateLimiting != int64(countNormalRPC+countFailedToSend)/3 {
-		t.Errorf("num calls drop rate limiting = %v, want (%v+%v)/3", tss.ls.stats.NumCallsFinishedWithDropForRateLimiting, countNormalRPC, countFailedToSend)
-	}
-	if tss.ls.stats.NumCallsFinishedWithDropForLoadBalancing != int64(countNormalRPC+countFailedToSend)/3 {
-		t.Errorf("num calls drop load balancing = %v, want (%v+%v)/3", tss.ls.stats.NumCallsFinishedWithDropForLoadBalancing, countNormalRPC, countFailedToSend)
-	}
-	if tss.ls.stats.NumCallsFinishedWithClientFailedToSend != int64(countFailedToSend)/3 {
-		t.Errorf("num calls failed to send = %v, want %v/3", tss.ls.stats.NumCallsFinishedWithClientFailedToSend, countFailedToSend)
-	}
-	if tss.ls.stats.NumCallsFinishedKnownReceived != int64(countNormalRPC)/3 {
-		t.Errorf("num calls known received = %v, want %v/3", tss.ls.stats.NumCallsFinishedKnownReceived, countNormalRPC)
-	}
-	tss.ls.mu.Unlock()
+	return nil
 }
 
-func TestGRPCLBStatsStreaming(t *testing.T) {
-	var (
-		countNormalRPC    = 66 // 1/3 succeeds, 1/3 dropped load balancing, 1/3 dropped rate limiting.
-		countFailedToSend = 30 // 1/3 fail to send, 1/3 dropped load balancing, 1/3 dropped rate limiting.
-	)
+func runAndGetStats(t *testing.T, dropForLoadBalancing, dropForRateLimiting bool, runRPCs func(*grpc.ClientConn)) lbpb.ClientStats {
 	tss, cleanup, err := newLoadBalancer(3)
 	if err != nil {
 		t.Fatalf("failed to create new load balancer: %v", err)
@@ -734,81 +668,247 @@
 	defer cleanup()
 	tss.ls.sls = []*lbpb.ServerList{{
 		Servers: []*lbpb.Server{{
-			IpAddress:            tss.beIPs[0],
-			Port:                 int32(tss.bePorts[0]),
-			LoadBalanceToken:     lbToken,
-			DropForLoadBalancing: true,
-		}, {
-			IpAddress:           tss.beIPs[1],
-			Port:                int32(tss.bePorts[1]),
-			LoadBalanceToken:    lbToken,
-			DropForRateLimiting: true,
-		}, {
 			IpAddress:            tss.beIPs[2],
 			Port:                 int32(tss.bePorts[2]),
 			LoadBalanceToken:     lbToken,
-			DropForLoadBalancing: false,
+			DropForLoadBalancing: dropForLoadBalancing,
+			DropForRateLimiting:  dropForRateLimiting,
 		}},
 	}}
 	tss.ls.intervals = []time.Duration{0}
 	tss.ls.statsDura = 100 * time.Millisecond
-	creds := serverNameCheckCreds{
-		expected: besn,
-	}
-	ctx, _ := context.WithTimeout(context.Background(), 10*time.Second)
+	creds := serverNameCheckCreds{expected: besn}
+
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
 	cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(grpc.NewGRPCLBBalancer(&testNameResolver{
 		addrs: []string{tss.lbAddr},
 	})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds), grpc.WithPerRPCCredentials(failPreRPCCred{}))
 	if err != nil {
 		t.Fatalf("Failed to dial to the backend %v", err)
 	}
-	testC := testpb.NewTestServiceClient(cc)
-	// The first non-failfast RPC succeeds, all connections are up.
-	var stream testpb.TestService_FullDuplexCallClient
-	stream, err = testC.FullDuplexCall(context.Background(), grpc.FailFast(false))
-	if err != nil {
-		t.Fatalf("%v.FullDuplexCall(_, _) = _, %v, want _, <nil>", testC, err)
-	}
-	for {
-		if _, err = stream.Recv(); err == io.EOF {
-			break
+	defer cc.Close()
+
+	runRPCs(cc)
+	time.Sleep(1 * time.Second)
+	tss.ls.mu.Lock()
+	stats := tss.ls.stats
+	tss.ls.mu.Unlock()
+	return stats
+}
+
+const countRPC = 40
+
+func TestGRPCLBStatsUnarySuccess(t *testing.T) {
+	stats := runAndGetStats(t, false, false, func(cc *grpc.ClientConn) {
+		testC := testpb.NewTestServiceClient(cc)
+		// The first non-failfast RPC succeeds, all connections are up.
+		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
+			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
 		}
+		for i := 0; i < countRPC-1; i++ {
+			testC.EmptyCall(context.Background(), &testpb.Empty{})
+		}
+	})
+
+	if err := checkStats(&stats, &lbpb.ClientStats{
+		NumCallsStarted:               int64(countRPC),
+		NumCallsFinished:              int64(countRPC),
+		NumCallsFinishedKnownReceived: int64(countRPC),
+	}); err != nil {
+		t.Fatal(err)
 	}
-	for i := 0; i < countNormalRPC-1; i++ {
-		stream, err = testC.FullDuplexCall(context.Background())
-		if err == nil {
-			// Wait for stream to end if err is nil.
-			for {
-				if _, err = stream.Recv(); err == io.EOF {
+}
+
+func TestGRPCLBStatsUnaryDropLoadBalancing(t *testing.T) {
+	c := 0
+	stats := runAndGetStats(t, true, false, func(cc *grpc.ClientConn) {
+		testC := testpb.NewTestServiceClient(cc)
+		for {
+			c++
+			if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
+				if strings.Contains(err.Error(), "drops requests") {
 					break
 				}
 			}
 		}
-	}
-	for i := 0; i < countFailedToSend; i++ {
-		grpc.NewClientStream(context.Background(), &grpc.StreamDesc{}, cc, "failtosend")
-	}
-	cc.Close()
+		for i := 0; i < countRPC; i++ {
+			testC.EmptyCall(context.Background(), &testpb.Empty{})
+		}
+	})
 
-	time.Sleep(1 * time.Second)
-	tss.ls.mu.Lock()
-	if tss.ls.stats.NumCallsStarted != int64(countNormalRPC+countFailedToSend) {
-		t.Errorf("num calls started = %v, want %v+%v", tss.ls.stats.NumCallsStarted, countNormalRPC, countFailedToSend)
+	if err := checkStats(&stats, &lbpb.ClientStats{
+		NumCallsStarted:                          int64(countRPC + c),
+		NumCallsFinished:                         int64(countRPC + c),
+		NumCallsFinishedWithDropForLoadBalancing: int64(countRPC + 1),
+		NumCallsFinishedWithClientFailedToSend:   int64(c - 1),
+	}); err != nil {
+		t.Fatal(err)
 	}
-	if tss.ls.stats.NumCallsFinished != int64(countNormalRPC+countFailedToSend) {
-		t.Errorf("num calls finished = %v, want %v+%v", tss.ls.stats.NumCallsFinished, countNormalRPC, countFailedToSend)
+}
+
+func TestGRPCLBStatsUnaryDropRateLimiting(t *testing.T) {
+	c := 0
+	stats := runAndGetStats(t, false, true, func(cc *grpc.ClientConn) {
+		testC := testpb.NewTestServiceClient(cc)
+		for {
+			c++
+			if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
+				if strings.Contains(err.Error(), "drops requests") {
+					break
+				}
+			}
+		}
+		for i := 0; i < countRPC; i++ {
+			testC.EmptyCall(context.Background(), &testpb.Empty{})
+		}
+	})
+
+	if err := checkStats(&stats, &lbpb.ClientStats{
+		NumCallsStarted:                         int64(countRPC + c),
+		NumCallsFinished:                        int64(countRPC + c),
+		NumCallsFinishedWithDropForRateLimiting: int64(countRPC + 1),
+		NumCallsFinishedWithClientFailedToSend:  int64(c - 1),
+	}); err != nil {
+		t.Fatal(err)
 	}
-	if tss.ls.stats.NumCallsFinishedWithDropForRateLimiting != int64(countNormalRPC+countFailedToSend)/3 {
-		t.Errorf("num calls drop rate limiting = %v, want (%v+%v)/3", tss.ls.stats.NumCallsFinishedWithDropForRateLimiting, countNormalRPC, countFailedToSend)
+}
+
+func TestGRPCLBStatsUnaryFailedToSend(t *testing.T) {
+	stats := runAndGetStats(t, false, false, func(cc *grpc.ClientConn) {
+		testC := testpb.NewTestServiceClient(cc)
+		// The first non-failfast RPC succeeds, all connections are up.
+		if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}, grpc.FailFast(false)); err != nil {
+			t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, <nil>", testC, err)
+		}
+		for i := 0; i < countRPC-1; i++ {
+			grpc.Invoke(context.Background(), "failtosend", &testpb.Empty{}, nil, cc)
+		}
+	})
+
+	if err := checkStats(&stats, &lbpb.ClientStats{
+		NumCallsStarted:                        int64(countRPC),
+		NumCallsFinished:                       int64(countRPC),
+		NumCallsFinishedWithClientFailedToSend: int64(countRPC - 1),
+		NumCallsFinishedKnownReceived:          1,
+	}); err != nil {
+		t.Fatal(err)
 	}
-	if tss.ls.stats.NumCallsFinishedWithDropForLoadBalancing != int64(countNormalRPC+countFailedToSend)/3 {
-		t.Errorf("num calls drop load balancing = %v, want (%v+%v)/3", tss.ls.stats.NumCallsFinishedWithDropForLoadBalancing, countNormalRPC, countFailedToSend)
+}
+
+func TestGRPCLBStatsStreamingSuccess(t *testing.T) {
+	stats := runAndGetStats(t, false, false, func(cc *grpc.ClientConn) {
+		testC := testpb.NewTestServiceClient(cc)
+		// The first non-failfast RPC succeeds, all connections are up.
+		stream, err := testC.FullDuplexCall(context.Background(), grpc.FailFast(false))
+		if err != nil {
+			t.Fatalf("%v.FullDuplexCall(_, _) = _, %v, want _, <nil>", testC, err)
+		}
+		for {
+			if _, err = stream.Recv(); err == io.EOF {
+				break
+			}
+		}
+		for i := 0; i < countRPC-1; i++ {
+			stream, err = testC.FullDuplexCall(context.Background())
+			if err == nil {
+				// Wait for stream to end if err is nil.
+				for {
+					if _, err = stream.Recv(); err == io.EOF {
+						break
+					}
+				}
+			}
+		}
+	})
+
+	if err := checkStats(&stats, &lbpb.ClientStats{
+		NumCallsStarted:               int64(countRPC),
+		NumCallsFinished:              int64(countRPC),
+		NumCallsFinishedKnownReceived: int64(countRPC),
+	}); err != nil {
+		t.Fatal(err)
 	}
-	if tss.ls.stats.NumCallsFinishedWithClientFailedToSend != int64(countFailedToSend)/3 {
-		t.Errorf("num calls failed to send = %v, want %v/3", tss.ls.stats.NumCallsFinishedWithClientFailedToSend, countFailedToSend)
+}
+
+func TestGRPCLBStatsStreamingDropLoadBalancing(t *testing.T) {
+	c := 0
+	stats := runAndGetStats(t, true, false, func(cc *grpc.ClientConn) {
+		testC := testpb.NewTestServiceClient(cc)
+		for {
+			c++
+			if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
+				if strings.Contains(err.Error(), "drops requests") {
+					break
+				}
+			}
+		}
+		for i := 0; i < countRPC; i++ {
+			testC.FullDuplexCall(context.Background())
+		}
+	})
+
+	if err := checkStats(&stats, &lbpb.ClientStats{
+		NumCallsStarted:                          int64(countRPC + c),
+		NumCallsFinished:                         int64(countRPC + c),
+		NumCallsFinishedWithDropForLoadBalancing: int64(countRPC + 1),
+		NumCallsFinishedWithClientFailedToSend:   int64(c - 1),
+	}); err != nil {
+		t.Fatal(err)
 	}
-	if tss.ls.stats.NumCallsFinishedKnownReceived != int64(countNormalRPC)/3 {
-		t.Errorf("num calls known received = %v, want %v/3", tss.ls.stats.NumCallsFinishedKnownReceived, countNormalRPC)
+}
+
+func TestGRPCLBStatsStreamingDropRateLimiting(t *testing.T) {
+	c := 0
+	stats := runAndGetStats(t, false, true, func(cc *grpc.ClientConn) {
+		testC := testpb.NewTestServiceClient(cc)
+		for {
+			c++
+			if _, err := testC.EmptyCall(context.Background(), &testpb.Empty{}); err != nil {
+				if strings.Contains(err.Error(), "drops requests") {
+					break
+				}
+			}
+		}
+		for i := 0; i < countRPC; i++ {
+			testC.FullDuplexCall(context.Background())
+		}
+	})
+
+	if err := checkStats(&stats, &lbpb.ClientStats{
+		NumCallsStarted:                         int64(countRPC + c),
+		NumCallsFinished:                        int64(countRPC + c),
+		NumCallsFinishedWithDropForRateLimiting: int64(countRPC + 1),
+		NumCallsFinishedWithClientFailedToSend:  int64(c - 1),
+	}); err != nil {
+		t.Fatal(err)
 	}
-	tss.ls.mu.Unlock()
+}
+
+func TestGRPCLBStatsStreamingFailedToSend(t *testing.T) {
+	stats := runAndGetStats(t, false, false, func(cc *grpc.ClientConn) {
+		testC := testpb.NewTestServiceClient(cc)
+		// The first non-failfast RPC succeeds, all connections are up.
+		stream, err := testC.FullDuplexCall(context.Background(), grpc.FailFast(false))
+		if err != nil {
+			t.Fatalf("%v.FullDuplexCall(_, _) = _, %v, want _, <nil>", testC, err)
+		}
+		for {
+			if _, err = stream.Recv(); err == io.EOF {
+				break
+			}
+		}
+		for i := 0; i < countRPC-1; i++ {
+			grpc.NewClientStream(context.Background(), &grpc.StreamDesc{}, cc, "failtosend")
+		}
+	})
+
+	if err := checkStats(&stats, &lbpb.ClientStats{
+		NumCallsStarted:                        int64(countRPC),
+		NumCallsFinished:                       int64(countRPC),
+		NumCallsFinishedWithClientFailedToSend: int64(countRPC - 1),
+		NumCallsFinishedKnownReceived:          1,
+	}); err != nil {
+		t.Fatal(err)
+	}
 }