grpclb should connect to the second balancer (#1181)
grpclb needs to connect to the second resolved balancer address when the first balancer disconnects.
If grpclb gets two resolved addresses, balancer1 and balancer2, then when balancer1 disconnects, grpclb should automatically start to use balancer2.
diff --git a/grpclb/grpclb.go b/grpclb/grpclb.go
index b699a52..ea065fa 100644
--- a/grpclb/grpclb.go
+++ b/grpclb/grpclb.go
@@ -111,7 +111,7 @@
rand *rand.Rand
}
-func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan remoteBalancerInfo) error {
+func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan []remoteBalancerInfo) error {
updates, err := w.Next()
if err != nil {
return err
@@ -121,10 +121,6 @@
if b.done {
return grpc.ErrClientConnClosing
}
- var bAddr remoteBalancerInfo
- if len(b.rbs) > 0 {
- bAddr = b.rbs[0]
- }
for _, update := range updates {
switch update.Op {
case naming.Add:
@@ -173,21 +169,11 @@
}
// TODO: Fall back to the basic round-robin load balancing if the resulting address is
// not a load balancer.
- if len(b.rbs) > 0 {
- // For simplicity, always use the first one now. May revisit this decision later.
- if b.rbs[0] != bAddr {
- select {
- case <-ch:
- default:
- }
- // Pick a random one from the list, instead of always using the first one.
- if l := len(b.rbs); l > 1 {
- tmpIdx := b.rand.Intn(l - 1)
- b.rbs[0], b.rbs[tmpIdx] = b.rbs[tmpIdx], b.rbs[0]
- }
- ch <- b.rbs[0]
- }
+ select {
+ case <-ch:
+ default:
}
+ ch <- b.rbs
return nil
}
@@ -261,7 +247,7 @@
func (b *balancer) callRemoteBalancer(lbc lbpb.LoadBalancerClient, seq int) (retry bool) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- stream, err := lbc.BalanceLoad(ctx, grpc.FailFast(false))
+ stream, err := lbc.BalanceLoad(ctx)
if err != nil {
grpclog.Printf("Failed to perform RPC to the remote balancer %v", err)
return
@@ -340,32 +326,98 @@
}
b.w = w
b.mu.Unlock()
- balancerAddrCh := make(chan remoteBalancerInfo, 1)
+ balancerAddrsCh := make(chan []remoteBalancerInfo, 1)
// Spawn a goroutine to monitor the name resolution of remote load balancer.
go func() {
for {
- if err := b.watchAddrUpdates(w, balancerAddrCh); err != nil {
+ if err := b.watchAddrUpdates(w, balancerAddrsCh); err != nil {
grpclog.Printf("grpc: the naming watcher stops working due to %v.\n", err)
- close(balancerAddrCh)
+ close(balancerAddrsCh)
return
}
}
}()
// Spawn a goroutine to talk to the remote load balancer.
go func() {
- var cc *grpc.ClientConn
- for {
- rb, ok := <-balancerAddrCh
+ var (
+ cc *grpc.ClientConn
+ // ccError is closed when there is an error in the current cc.
+ // A new rb should be picked from rbs and connected.
+ ccError chan struct{}
+ rb *remoteBalancerInfo
+ rbs []remoteBalancerInfo
+ rbIdx int
+ )
+
+ defer func() {
+ if ccError != nil {
+ select {
+ case <-ccError:
+ default:
+ close(ccError)
+ }
+ }
if cc != nil {
cc.Close()
}
- if !ok {
- // b is closing.
- return
+ }()
+
+ for {
+ var ok bool
+ select {
+ case rbs, ok = <-balancerAddrsCh:
+ if !ok {
+ return
+ }
+ foundIdx := -1
+ if rb != nil {
+ for i, trb := range rbs {
+ if trb == *rb {
+ foundIdx = i
+ break
+ }
+ }
+ }
+ if foundIdx >= 0 {
+ if foundIdx >= 1 {
+ // Move the address in use to the beginning of the list.
+ b.rbs[0], b.rbs[foundIdx] = b.rbs[foundIdx], b.rbs[0]
+ rbIdx = 0
+ }
+ continue // If found, don't dial new cc.
+ } else if len(rbs) > 0 {
+ // Pick a random one from the list, instead of always using the first one.
+ if l := len(rbs); l > 1 && rb != nil {
+ tmpIdx := b.rand.Intn(l - 1)
+ b.rbs[0], b.rbs[tmpIdx] = b.rbs[tmpIdx], b.rbs[0]
+ }
+ rbIdx = 0
+ rb = &rbs[0]
+ } else {
+ // foundIdx < 0 && len(rbs) <= 0.
+ rb = nil
+ }
+ case <-ccError:
+ ccError = nil
+ if rbIdx < len(rbs)-1 {
+ rbIdx++
+ rb = &rbs[rbIdx]
+ } else {
+ rb = nil
+ }
+ }
+
+ if rb == nil {
+ continue
+ }
+
+ if cc != nil {
+ cc.Close()
}
// Talk to the remote load balancer to get the server list.
var err error
creds := config.DialCreds
+ ccError = make(chan struct{})
if creds == nil {
cc, err = grpc.Dial(rb.addr, grpc.WithInsecure())
} else {
@@ -379,22 +431,24 @@
}
if err != nil {
grpclog.Printf("Failed to setup a connection to the remote balancer %v: %v", rb.addr, err)
- return
+ close(ccError)
+ continue
}
b.mu.Lock()
b.seq++ // tick when getting a new balancer address
seq := b.seq
b.next = 0
b.mu.Unlock()
- go func(cc *grpc.ClientConn) {
+ go func(cc *grpc.ClientConn, ccError chan struct{}) {
lbc := lbpb.NewLoadBalancerClient(cc)
- for {
- if retry := b.callRemoteBalancer(lbc, seq); !retry {
- cc.Close()
- return
- }
+ b.callRemoteBalancer(lbc, seq)
+ cc.Close()
+ select {
+ case <-ccError:
+ default:
+ close(ccError)
}
- }(cc)
+ }(cc, ccError)
}
}()
return nil
diff --git a/grpclb/grpclb_test.go b/grpclb/grpclb_test.go
index ba7824c..f6115b2 100644
--- a/grpclb/grpclb_test.go
+++ b/grpclb/grpclb_test.go
@@ -99,24 +99,26 @@
}
type testNameResolver struct {
- w *testWatcher
- addr string
+ w *testWatcher // assigned by Resolve; nil until Resolve has been called
+ addrs []string // Resolve queues one naming.Add update per address
}
func (r *testNameResolver) Resolve(target string) (naming.Watcher, error) {
r.w = &testWatcher{
- update: make(chan *naming.Update, 1),
+ update: make(chan *naming.Update, len(r.addrs)),
side: make(chan int, 1),
readDone: make(chan int),
}
- r.w.side <- 1
- r.w.update <- &naming.Update{
- Op: naming.Add,
- Addr: r.addr,
- Metadata: &Metadata{
- AddrType: GRPCLB,
- ServerName: lbsn,
- },
+ r.w.side <- len(r.addrs)
+ for _, addr := range r.addrs {
+ r.w.update <- &naming.Update{
+ Op: naming.Add,
+ Addr: addr,
+ Metadata: &Metadata{
+ AddrType: GRPCLB,
+ ServerName: lbsn,
+ },
+ }
}
go func() {
<-r.w.readDone
@@ -124,6 +126,12 @@
return r.w, nil
}
+func (r *testNameResolver) inject(updates []*naming.Update) { // forwards updates to the watcher held in r.w
+ if r.w != nil { // r.w is assigned in Resolve; while it is still nil the updates are silently dropped
+ r.w.inject(updates)
+ }
+}
+
type serverNameCheckCreds struct {
expected string
sn string
@@ -212,6 +220,7 @@
}
type helloServer struct {
+ addr string // appended to the reply message by SayHello
}
func (s *helloServer) SayHello(ctx context.Context, in *hwpb.HelloRequest) (*hwpb.HelloReply, error) {
@@ -223,17 +232,17 @@
return nil, grpc.Errorf(codes.Internal, "received unexpected metadata: %v", md)
}
return &hwpb.HelloReply{
- Message: "Hello " + in.Name,
+ Message: "Hello " + in.Name + " for " + s.addr,
}, nil
}
-func startBackends(t *testing.T, sn string, lis ...net.Listener) (servers []*grpc.Server) {
+func startBackends(sn string, lis ...net.Listener) (servers []*grpc.Server) {
for _, l := range lis {
creds := &serverNameCheckCreds{
sn: sn,
}
s := grpc.NewServer(grpc.Creds(creds))
- hwpb.RegisterGreeterServer(s, &helloServer{})
+ hwpb.RegisterGreeterServer(s, &helloServer{addr: l.Addr().String()})
servers = append(servers, s)
go func(s *grpc.Server, l net.Listener) {
s.Serve(l)
@@ -248,32 +257,86 @@
}
}
-func TestGRPCLB(t *testing.T) {
- // Start a backend.
- beLis, err := net.Listen("tcp", "localhost:0")
- if err != nil {
- t.Fatalf("Failed to listen %v", err)
+type testServers struct { // handles to the servers started by newLoadBalancer
+ lbAddr string // address the load balancer listener is bound to
+ ls *remoteBalancer // balancer service registered on lb
+ lb *grpc.Server // gRPC server that serves ls
+ beIPs []net.IP // IP of each backend listener, index-aligned with bePorts
+ bePorts []int // port of each backend listener
+}
+
+func newLoadBalancer(numberOfBackends int) (tss *testServers, cleanup func(), err error) { // starts numberOfBackends backends plus one load balancer; cleanup stops them all
+ var (
+ beListeners []net.Listener
+ ls *remoteBalancer
+ lb *grpc.Server
+ beIPs []net.IP
+ bePorts []int
+ )
+ for i := 0; i < numberOfBackends; i++ {
+ // Start a backend.
+ beLis, e := net.Listen("tcp", "localhost:0")
+ if e != nil {
+ err = fmt.Errorf("Failed to listen %v", e) // report e: the named result err is still nil here
+ return
+ }
+ beIPs = append(beIPs, beLis.Addr().(*net.TCPAddr).IP)
+
+ beAddr := strings.Split(beLis.Addr().String(), ":")
+ bePort, _ := strconv.Atoi(beAddr[1])
+ bePorts = append(bePorts, bePort)
+
+ beListeners = append(beListeners, beLis)
}
- beAddr := strings.Split(beLis.Addr().String(), ":")
- bePort, err := strconv.Atoi(beAddr[1])
- backends := startBackends(t, besn, beLis)
- defer stopBackends(backends)
+ backends := startBackends(besn, beListeners...)
// Start a load balancer.
lbLis, err := net.Listen("tcp", "localhost:0")
if err != nil {
- t.Fatalf("Failed to create the listener for the load balancer %v", err)
+ err = fmt.Errorf("Failed to create the listener for the load balancer %v", err)
+ return // NOTE(review): cleanup is still nil here, so the backends started above are not stopped on this path
}
lbCreds := &serverNameCheckCreds{
sn: lbsn,
}
- lb := grpc.NewServer(grpc.Creds(lbCreds))
+ lb = grpc.NewServer(grpc.Creds(lbCreds))
if err != nil {
- t.Fatalf("Failed to generate the port number %v", err)
+ err = fmt.Errorf("Failed to generate the port number %v", err)
+ return
}
+ ls = newRemoteBalancer(nil, nil)
+ lbpb.RegisterLoadBalancerServer(lb, ls)
+ go func() {
+ lb.Serve(lbLis)
+ }()
+
+ tss = &testServers{
+ lbAddr: lbLis.Addr().String(),
+ ls: ls,
+ lb: lb,
+ beIPs: beIPs,
+ bePorts: bePorts,
+ }
+ cleanup = func() { // deferred calls run in reverse order: the balancer stops first, then the backends
+ defer stopBackends(backends)
+ defer func() {
+ ls.stop()
+ lb.Stop()
+ }()
+ }
+ return
+}
+
+func TestGRPCLB(t *testing.T) {
+ tss, cleanup, err := newLoadBalancer(1)
+ if err != nil {
+ t.Fatalf("failed to create new load balancer: %v", err)
+ }
+ defer cleanup()
+
be := &lbpb.Server{
- IpAddress: beLis.Addr().(*net.TCPAddr).IP,
- Port: int32(bePort),
+ IpAddress: tss.beIPs[0],
+ Port: int32(tss.bePorts[0]),
LoadBalanceToken: lbToken,
}
var bes []*lbpb.Server
@@ -281,23 +344,14 @@
sl := &lbpb.ServerList{
Servers: bes,
}
- sls := []*lbpb.ServerList{sl}
- intervals := []time.Duration{0}
- ls := newRemoteBalancer(sls, intervals)
- lbpb.RegisterLoadBalancerServer(lb, ls)
- go func() {
- lb.Serve(lbLis)
- }()
- defer func() {
- ls.stop()
- lb.Stop()
- }()
+ tss.ls.sls = []*lbpb.ServerList{sl}
+ tss.ls.intervals = []time.Duration{0}
creds := serverNameCheckCreds{
expected: besn,
}
ctx, _ := context.WithTimeout(context.Background(), 10*time.Second)
cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(Balancer(&testNameResolver{
- addr: lbLis.Addr().String(),
+ addrs: []string{tss.lbAddr},
})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds))
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
@@ -310,65 +364,31 @@
}
func TestDropRequest(t *testing.T) {
- // Start 2 backends.
- beLis1, err := net.Listen("tcp", "localhost:0")
+ tss, cleanup, err := newLoadBalancer(2)
if err != nil {
- t.Fatalf("Failed to listen %v", err)
+ t.Fatalf("failed to create new load balancer: %v", err)
}
- beAddr1 := strings.Split(beLis1.Addr().String(), ":")
- bePort1, err := strconv.Atoi(beAddr1[1])
-
- beLis2, err := net.Listen("tcp", "localhost:0")
- if err != nil {
- t.Fatalf("Failed to listen %v", err)
- }
- beAddr2 := strings.Split(beLis2.Addr().String(), ":")
- bePort2, err := strconv.Atoi(beAddr2[1])
-
- backends := startBackends(t, besn, beLis1, beLis2)
- defer stopBackends(backends)
-
- // Start a load balancer.
- lbLis, err := net.Listen("tcp", "localhost:0")
- if err != nil {
- t.Fatalf("Failed to create the listener for the load balancer %v", err)
- }
- lbCreds := &serverNameCheckCreds{
- sn: lbsn,
- }
- lb := grpc.NewServer(grpc.Creds(lbCreds))
- if err != nil {
- t.Fatalf("Failed to generate the port number %v", err)
- }
- sls := []*lbpb.ServerList{{
+ defer cleanup()
+ tss.ls.sls = []*lbpb.ServerList{{
Servers: []*lbpb.Server{{
- IpAddress: beLis1.Addr().(*net.TCPAddr).IP,
- Port: int32(bePort1),
+ IpAddress: tss.beIPs[0],
+ Port: int32(tss.bePorts[0]),
LoadBalanceToken: lbToken,
DropRequest: true,
}, {
- IpAddress: beLis2.Addr().(*net.TCPAddr).IP,
- Port: int32(bePort2),
+ IpAddress: tss.beIPs[1],
+ Port: int32(tss.bePorts[1]),
LoadBalanceToken: lbToken,
DropRequest: false,
}},
}}
- intervals := []time.Duration{0}
- ls := newRemoteBalancer(sls, intervals)
- lbpb.RegisterLoadBalancerServer(lb, ls)
- go func() {
- lb.Serve(lbLis)
- }()
- defer func() {
- ls.stop()
- lb.Stop()
- }()
+ tss.ls.intervals = []time.Duration{0}
creds := serverNameCheckCreds{
expected: besn,
}
ctx, _ := context.WithTimeout(context.Background(), 10*time.Second)
cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(Balancer(&testNameResolver{
- addr: lbLis.Addr().String(),
+ addrs: []string{tss.lbAddr},
})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds))
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
@@ -395,31 +415,14 @@
}
func TestDropRequestFailedNonFailFast(t *testing.T) {
- // Start a backend.
- beLis, err := net.Listen("tcp", "localhost:0")
+ tss, cleanup, err := newLoadBalancer(1)
if err != nil {
- t.Fatalf("Failed to listen %v", err)
+ t.Fatalf("failed to create new load balancer: %v", err)
}
- beAddr := strings.Split(beLis.Addr().String(), ":")
- bePort, err := strconv.Atoi(beAddr[1])
- backends := startBackends(t, besn, beLis)
- defer stopBackends(backends)
-
- // Start a load balancer.
- lbLis, err := net.Listen("tcp", "localhost:0")
- if err != nil {
- t.Fatalf("Failed to create the listener for the load balancer %v", err)
- }
- lbCreds := &serverNameCheckCreds{
- sn: lbsn,
- }
- lb := grpc.NewServer(grpc.Creds(lbCreds))
- if err != nil {
- t.Fatalf("Failed to generate the port number %v", err)
- }
+ defer cleanup()
be := &lbpb.Server{
- IpAddress: beLis.Addr().(*net.TCPAddr).IP,
- Port: int32(bePort),
+ IpAddress: tss.beIPs[0],
+ Port: int32(tss.bePorts[0]),
LoadBalanceToken: lbToken,
DropRequest: true,
}
@@ -428,23 +431,14 @@
sl := &lbpb.ServerList{
Servers: bes,
}
- sls := []*lbpb.ServerList{sl}
- intervals := []time.Duration{0}
- ls := newRemoteBalancer(sls, intervals)
- lbpb.RegisterLoadBalancerServer(lb, ls)
- go func() {
- lb.Serve(lbLis)
- }()
- defer func() {
- ls.stop()
- lb.Stop()
- }()
+ tss.ls.sls = []*lbpb.ServerList{sl}
+ tss.ls.intervals = []time.Duration{0}
creds := serverNameCheckCreds{
expected: besn,
}
ctx, _ := context.WithTimeout(context.Background(), 10*time.Second)
cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(Balancer(&testNameResolver{
- addr: lbLis.Addr().String(),
+ addrs: []string{tss.lbAddr},
})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds))
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
@@ -458,31 +452,14 @@
}
func TestServerExpiration(t *testing.T) {
- // Start a backend.
- beLis, err := net.Listen("tcp", "localhost:0")
+ tss, cleanup, err := newLoadBalancer(1)
if err != nil {
- t.Fatalf("Failed to listen %v", err)
+ t.Fatalf("failed to create new load balancer: %v", err)
}
- beAddr := strings.Split(beLis.Addr().String(), ":")
- bePort, err := strconv.Atoi(beAddr[1])
- backends := startBackends(t, besn, beLis)
- defer stopBackends(backends)
-
- // Start a load balancer.
- lbLis, err := net.Listen("tcp", "localhost:0")
- if err != nil {
- t.Fatalf("Failed to create the listener for the load balancer %v", err)
- }
- lbCreds := &serverNameCheckCreds{
- sn: lbsn,
- }
- lb := grpc.NewServer(grpc.Creds(lbCreds))
- if err != nil {
- t.Fatalf("Failed to generate the port number %v", err)
- }
+ defer cleanup()
be := &lbpb.Server{
- IpAddress: beLis.Addr().(*net.TCPAddr).IP,
- Port: int32(bePort),
+ IpAddress: tss.beIPs[0],
+ Port: int32(tss.bePorts[0]),
LoadBalanceToken: lbToken,
}
var bes []*lbpb.Server
@@ -504,21 +481,14 @@
var intervals []time.Duration
intervals = append(intervals, 0)
intervals = append(intervals, 500*time.Millisecond)
- ls := newRemoteBalancer(sls, intervals)
- lbpb.RegisterLoadBalancerServer(lb, ls)
- go func() {
- lb.Serve(lbLis)
- }()
- defer func() {
- ls.stop()
- lb.Stop()
- }()
+ tss.ls.sls = sls
+ tss.ls.intervals = intervals
creds := serverNameCheckCreds{
expected: besn,
}
ctx, _ := context.WithTimeout(context.Background(), 10*time.Second)
cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(Balancer(&testNameResolver{
- addr: lbLis.Addr().String(),
+ addrs: []string{tss.lbAddr},
})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds))
if err != nil {
t.Fatalf("Failed to dial to the backend %v", err)
@@ -539,3 +509,90 @@
}
cc.Close()
}
+
+// When the balancer in use disconnects, grpclb should connect to the next address from the resolved balancer address list.
+func TestBalancerDisconnects(t *testing.T) {
+ var (
+ lbAddrs []string
+ lbs []*grpc.Server
+ )
+ for i := 0; i < 3; i++ { // three load balancers, each created with its own single backend
+ tss, cleanup, err := newLoadBalancer(1)
+ if err != nil {
+ t.Fatalf("failed to create new load balancer: %v", err)
+ }
+ defer cleanup() // runs when the test function returns, not at the end of this iteration
+
+ be := &lbpb.Server{
+ IpAddress: tss.beIPs[0],
+ Port: int32(tss.bePorts[0]),
+ LoadBalanceToken: lbToken,
+ }
+ var bes []*lbpb.Server
+ bes = append(bes, be)
+ sl := &lbpb.ServerList{
+ Servers: bes,
+ }
+ tss.ls.sls = []*lbpb.ServerList{sl}
+ tss.ls.intervals = []time.Duration{0}
+
+ lbAddrs = append(lbAddrs, tss.lbAddr)
+ lbs = append(lbs, tss.lb)
+ }
+
+ creds := serverNameCheckCreds{
+ expected: besn,
+ }
+ ctx, _ := context.WithTimeout(context.Background(), 10*time.Second) // NOTE(review): the cancel func is discarded
+ resolver := &testNameResolver{
+ addrs: lbAddrs[:2], // lbAddrs[2] is withheld here and injected later in the test
+ }
+ cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(Balancer(resolver)), grpc.WithBlock(), grpc.WithTransportCredentials(&creds))
+ if err != nil {
+ t.Fatalf("Failed to dial to the backend %v", err)
+ }
+ helloC := hwpb.NewGreeterClient(cc)
+ var message string
+ if resp, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil {
+ t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
+ } else {
+ message = resp.Message
+ }
+ // The initial resolver update contains lbs[0] and lbs[1].
+ // When lbs[0] is stopped, lbs[1] should be used.
+ lbs[0].Stop()
+ for { // NOTE(review): no deadline; this polls until the reply changes or an RPC fails
+ if resp, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil {
+ t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
+ } else if resp.Message != message {
+ // A new backend server should receive the request.
+ // The response contains the backend address, so the message should be different from the previous one.
+ message = resp.Message
+ break
+ }
+ time.Sleep(100 * time.Millisecond)
+ }
+ // Inject an update to add lbs[2] to the resolved addresses.
+ resolver.inject([]*naming.Update{
+ {Op: naming.Add,
+ Addr: lbAddrs[2],
+ Metadata: &Metadata{
+ AddrType: GRPCLB,
+ ServerName: lbsn,
+ },
+ },
+ })
+ // Stop lbs[1]. Now lbs[0] and lbs[1] are both stopped, so lbs[2] should be used.
+ lbs[1].Stop()
+ for { // NOTE(review): no deadline; this polls until the reply changes or an RPC fails
+ if resp, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil {
+ t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
+ } else if resp.Message != message {
+ // A new backend server should receive the request.
+ // The response contains the backend address, so the message should be different from the previous one.
+ break
+ }
+ time.Sleep(100 * time.Millisecond)
+ }
+ cc.Close()
+}