blob: 8927823d09da85523828cd94eb4aaaf7dd26829c [file] [log] [blame]
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package testutils
import (
"context"
"errors"
"fmt"
"testing"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/resolver"
)
// TestSubConnsCount is the number of TestSubConns initialized as part of
// package init. TestClientConn.NewSubConn hands these entries out by index,
// so this value bounds how many SubConns a single TestClientConn can draw
// from the pre-built pool.
const TestSubConnsCount = 16
// testingLogger wraps the logging methods from testing.T. TestClientConn
// depends only on this subset, so any value with matching Log and Logf
// methods can act as its logger.
type testingLogger interface {
// Log records its arguments as a log entry.
Log(args ...interface{})
// Logf records a log entry built from a format string and arguments.
Logf(format string, args ...interface{})
}
// TestSubConns contains a list of SubConns to be used in tests. The package
// init in this file appends TestSubConnsCount entries whose ids are "sc0",
// "sc1", ... in slice order.
var TestSubConns []*TestSubConn
// init populates TestSubConns with TestSubConnsCount pre-built SubConns. The
// entry at index i is given the id "sc<i>".
func init() {
	for idx := 0; idx < TestSubConnsCount; idx++ {
		sc := &TestSubConn{
			id: fmt.Sprintf("sc%d", idx),
			// One slot of buffer lets Connect record a call without blocking.
			ConnectCh: make(chan struct{}, 1),
		}
		TestSubConns = append(TestSubConns, sc)
	}
}
// TestSubConn implements the SubConn interface, to be used in tests.
type TestSubConn struct {
// id is the human-readable name returned by String.
id string
// ConnectCh receives a value when Connect is called. Connect sends without
// blocking, so the signal is dropped if the channel cannot accept it.
ConnectCh chan struct{}
}
// UpdateAddresses is a no-op; the supplied addresses are ignored.
func (tsc *TestSubConn) UpdateAddresses([]resolver.Address) {}
// Connect records the call by attempting a non-blocking send on ConnectCh.
// If the channel cannot accept the value immediately (for example because an
// earlier signal has not been consumed yet), the send is skipped so that
// Connect never blocks.
func (tsc *TestSubConn) Connect() {
select {
case tsc.ConnectCh <- struct{}{}:
default:
// Channel is not ready to receive; drop this signal instead of blocking.
}
}
// GetOrBuildProducer is a no-op. The builder argument is ignored and both
// results (the Producer and the func()) are nil.
func (tsc *TestSubConn) GetOrBuildProducer(balancer.ProducerBuilder) (balancer.Producer, func()) {
return nil, nil
}
// String implements fmt.Stringer by returning the SubConn's id, so that log
// and error messages print a human-friendly name.
func (tsc *TestSubConn) String() string {
return tsc.id
}
// TestClientConn is a mock balancer.ClientConn used in tests. Calls made on
// it are published on the exported channels so that tests can observe them.
// The channel capacities are those set up by NewTestClientConn.
type TestClientConn struct {
logger testingLogger // receives the log lines emitted by the mock's methods.
NewSubConnAddrsCh chan []resolver.Address // addresses passed to NewSubConn; sent without blocking, so values are dropped while the buffer is full.
NewSubConnCh chan balancer.SubConn // SubConns returned by NewSubConn; sent without blocking, so values are dropped while the buffer is full.
RemoveSubConnCh chan balancer.SubConn // SubConns passed to RemoveSubConn; sent without blocking, so values are dropped while the buffer is full.
UpdateAddressesAddrsCh chan []resolver.Address // addresses passed to UpdateAddresses(); sent without blocking, so values are dropped while the buffer is full.
NewPickerCh chan balancer.Picker // the last picker updated; UpdateState discards any unread one first.
NewStateCh chan connectivity.State // the last state; UpdateState discards any unread one first.
ResolveNowCh chan resolver.ResolveNowOptions // the last ResolveNow(); any unread one is discarded first.
subConnIdx int // index into TestSubConns of the next SubConn that NewSubConn returns.
}
// NewTestClientConn creates a TestClientConn that logs through t and has all
// of its notification channels allocated.
func NewTestClientConn(t *testing.T) *TestClientConn {
	tcc := &TestClientConn{logger: t}

	// SubConn creation/removal notifications keep up to ten pending entries.
	tcc.NewSubConnAddrsCh = make(chan []resolver.Address, 10)
	tcc.NewSubConnCh = make(chan balancer.SubConn, 10)
	tcc.RemoveSubConnCh = make(chan balancer.SubConn, 10)

	// The remaining channels hold a single pending entry each.
	tcc.UpdateAddressesAddrsCh = make(chan []resolver.Address, 1)
	tcc.NewPickerCh = make(chan balancer.Picker, 1)
	tcc.NewStateCh = make(chan connectivity.State, 1)
	tcc.ResolveNowCh = make(chan resolver.ResolveNowOptions, 1)

	return tcc
}
// NewSubConn hands out the next pre-built entry of TestSubConns and
// publishes the call on NewSubConnAddrsCh (the addresses) and NewSubConnCh
// (the returned SubConn).
//
// Both notifications are sent without blocking: if a channel's buffer is
// full the value is dropped rather than stalling the caller.
//
// The pool is finite. Once this TestClientConn has handed out every entry of
// TestSubConns, NewSubConn logs and returns a descriptive error (with a nil
// SubConn) instead of indexing past the end of the slice, and nothing is
// published on the channels.
func (tcc *TestClientConn) NewSubConn(a []resolver.Address, o balancer.NewSubConnOptions) (balancer.SubConn, error) {
	if tcc.subConnIdx < 0 || tcc.subConnIdx >= len(TestSubConns) {
		err := fmt.Errorf("testClientConn: NewSubConn(%v, %+v): all %d test SubConns have already been handed out", a, o, len(TestSubConns))
		tcc.logger.Logf("%v", err)
		return nil, err
	}

	sc := TestSubConns[tcc.subConnIdx]
	tcc.subConnIdx++
	tcc.logger.Logf("testClientConn: NewSubConn(%v, %+v) => %s", a, o, sc)

	select {
	case tcc.NewSubConnAddrsCh <- a:
	default:
		// Buffer full; drop the notification instead of blocking.
	}
	select {
	case tcc.NewSubConnCh <- sc:
	default:
		// Buffer full; drop the notification instead of blocking.
	}
	return sc, nil
}
// RemoveSubConn logs the call and publishes sc on RemoveSubConnCh. The send
// is non-blocking: if the channel cannot accept the value, the notification
// is dropped. Nothing else is done to sc.
func (tcc *TestClientConn) RemoveSubConn(sc balancer.SubConn) {
tcc.logger.Logf("testClientConn: RemoveSubConn(%s)", sc)
select {
case tcc.RemoveSubConnCh <- sc:
default:
// Channel full; drop the notification instead of blocking.
}
}
// UpdateAddresses logs the call and publishes addrs on
// UpdateAddressesAddrsCh. The send is non-blocking: if the channel cannot
// accept the value, the notification is dropped. sc is only used for the log
// line.
func (tcc *TestClientConn) UpdateAddresses(sc balancer.SubConn, addrs []resolver.Address) {
tcc.logger.Logf("testClientConn: UpdateAddresses(%v, %+v)", sc, addrs)
select {
case tcc.UpdateAddressesAddrsCh <- addrs:
default:
// Channel full; drop the notification instead of blocking.
}
}
// UpdateState logs the call and publishes the new connectivity state on
// NewStateCh and the new picker on NewPickerCh. For each channel, any value
// that has not been read yet is discarded first, so a reader sees the most
// recent update rather than a stale one.
func (tcc *TestClientConn) UpdateState(bs balancer.State) {
tcc.logger.Logf("testClientConn: UpdateState(%v)", bs)
// Discard a pending, unread state (if any) to make room for the new one.
select {
case <-tcc.NewStateCh:
default:
}
// NOTE(review): this is a plain blocking send; it presumably relies on a
// single caller at a time so that the slot freed above is still free.
tcc.NewStateCh <- bs.ConnectivityState
// Same replace-the-pending-value sequence for the picker.
select {
case <-tcc.NewPickerCh:
default:
}
tcc.NewPickerCh <- bs.Picker
}
// ResolveNow publishes the options of the call on ResolveNowCh. Any value
// that has not been read yet is discarded first, so the channel holds the
// options of the most recent call. It does not panic and does not log.
func (tcc *TestClientConn) ResolveNow(o resolver.ResolveNowOptions) {
// Discard a pending, unread value (if any) to make room for the new one.
select {
case <-tcc.ResolveNowCh:
default:
}
tcc.ResolveNowCh <- o
}
// Target panics. It is not implemented by this mock and always panics with
// the message "not implemented".
func (tcc *TestClientConn) Target() string {
panic("not implemented")
}
// WaitForErrPicker blocks until the next picker is pushed to this ClientConn
// and checks that it is an error picker, i.e. that its Pick call fails.
//
// It returns nil when the received picker's Pick returns an error. It returns
// a non-nil error if ctx is done before a picker arrives, or if the received
// picker picks successfully. Only a single picker is examined.
func (tcc *TestClientConn) WaitForErrPicker(ctx context.Context) error {
	var p balancer.Picker
	select {
	case p = <-tcc.NewPickerCh:
	case <-ctx.Done():
		return errors.New("timeout when waiting for an error picker")
	}

	if _, err := p.Pick(balancer.PickInfo{}); err != nil {
		// The pick failed, which is exactly what an error picker does.
		return nil
	}
	return fmt.Errorf("balancer returned a picker which is not an error picker")
}
// WaitForPickerWithErr keeps receiving pickers pushed to this ClientConn
// until one of them fails a Pick with an error whose message equals
// want.Error(), in which case it returns nil.
//
// If ctx is done first, it returns an error that includes want and the
// result of the most recent Pick (or a "received no picker" placeholder if
// no picker arrived at all).
func (tcc *TestClientConn) WaitForPickerWithErr(ctx context.Context, want error) error {
	lastErr := errors.New("received no picker")
	for {
		var p balancer.Picker
		select {
		case p = <-tcc.NewPickerCh:
		case <-ctx.Done():
			return fmt.Errorf("timeout when waiting for an error picker with %v; last picker error: %v", want, lastErr)
		}

		_, lastErr = p.Pick(balancer.PickInfo{})
		if lastErr == nil {
			// Not an error picker; wait for the next one.
			continue
		}
		// Errors are matched by message, not by identity.
		if lastErr.Error() == want.Error() {
			return nil
		}
	}
}
// WaitForConnectivityState keeps receiving states pushed to this ClientConn
// until one equals want, in which case it returns nil.
//
// If ctx is done first, it returns an error that includes want and the most
// recently received state. Before any state has been received the "last
// state" is the placeholder value -1.
func (tcc *TestClientConn) WaitForConnectivityState(ctx context.Context, want connectivity.State) error {
	for last := connectivity.State(-1); ; {
		select {
		case <-ctx.Done():
			return fmt.Errorf("timeout when waiting for state to be %s; last state: %s", want, last)
		case last = <-tcc.NewStateCh:
			if last == want {
				return nil
			}
		}
	}
}
// WaitForRoundRobinPicker keeps receiving pickers pushed to this ClientConn
// until one of them passes IsRoundRobin for the SubConns in want, in which
// case it returns nil.
//
// For every received picker, a pending entry on the state channel (if there
// is one) is consumed as well; the picker is skipped unless that state is
// READY. When no state is pending, the picker is treated as READY.
//
// A picker is also rejected if any of its Pick calls fails; the last such
// pick error takes precedence over the IsRoundRobin error when reporting.
// If ctx is done first, the returned error includes want and the last
// recorded rejection reason.
func (tcc *TestClientConn) WaitForRoundRobinPicker(ctx context.Context, want ...balancer.SubConn) error {
	lastErr := errors.New("received no picker")
	for {
		var picker balancer.Picker
		select {
		case picker = <-tcc.NewPickerCh:
		case <-ctx.Done():
			return fmt.Errorf("timeout when waiting for round robin picker with %v; last error: %v", want, lastErr)
		}

		// Consume the state that accompanies this picker, if one is pending.
		state := connectivity.Ready
		select {
		case state = <-tcc.NewStateCh:
		default:
		}
		if state != connectivity.Ready {
			lastErr = fmt.Errorf("received state %v instead of ready", state)
			continue
		}

		// pickErr remembers the most recent Pick failure seen by next.
		var pickErr error
		next := func() balancer.SubConn {
			res, err := picker.Pick(balancer.PickInfo{})
			switch {
			case err != nil:
				pickErr = err
			case res.Done != nil:
				// Complete the pick so the picker sees the call as finished.
				res.Done(balancer.DoneInfo{})
			}
			return res.SubConn
		}

		rrErr := IsRoundRobin(want, next)
		switch {
		case pickErr != nil:
			lastErr = pickErr
		case rrErr != nil:
			lastErr = rrErr
		default:
			return nil
		}
	}
}
// WaitForPicker keeps receiving pickers pushed to this ClientConn and hands
// each one to f until f returns nil, in which case WaitForPicker returns nil.
//
// If ctx is done first, the returned error includes the last error returned
// by f (or a "received no picker" placeholder if no picker arrived at all).
func (tcc *TestClientConn) WaitForPicker(ctx context.Context, f func(balancer.Picker) error) error {
	lastErr := errors.New("received no picker")
	for {
		select {
		case <-ctx.Done():
			return fmt.Errorf("timeout when waiting for picker; last error: %v", lastErr)
		case picker := <-tcc.NewPickerCh:
			if lastErr = f(picker); lastErr == nil {
				return nil
			}
		}
	}
}
// IsRoundRobin checks whether successive return values of f cycle through
// the elements of want in round-robin fashion. The order within a cycle is
// not prescribed. want may contain duplicates, which makes the check a
// weighted round-robin check.
//
// Step 1. The first len(want) values returned by f must form a permutation
// of want (not necessarily in the same order). E.g. if want is {a,a,b}, the
// check fails if f returns:
//   - {a,a,a}: the third a is returned before b
//   - {a,b,b}: the second b is returned before the second a
//
// An error from this step reports the values of the first cycle up to and
// including the offending one.
//
// Step 2. The values that follow must repeat exactly the permutation observed
// in step 1; two further cycles are checked. E.g. if want is {a,a,b}, the
// check fails if f returns:
//   - {a,b,a,b,a,a}: it satisfies step 1, but the second cycle does not
//     repeat the first one.
//
// An error from this step reports the first cycle plus everything returned
// afterwards up to and including the offending value.
//
// f is called at most 3*len(want) times. An empty want trivially passes
// without calling f.
func IsRoundRobin(want []balancer.SubConn, f func() balancer.SubConn) error {
	// remaining[sc] counts how many more times sc may appear in the first
	// cycle; duplicates in want raise the count (weighted round robin).
	remaining := make(map[balancer.SubConn]int)
	for _, sc := range want {
		remaining[sc]++
	}

	// First cycle: verify the values form a permutation of want, and record
	// their order so later cycles can be compared against it.
	firstCycle := make([]balancer.SubConn, 0, len(want))
	for len(firstCycle) < len(want) {
		sc := f()
		firstCycle = append(firstCycle, sc)
		remaining[sc]--
		if remaining[sc] < 0 {
			// sc is either not in want or was returned too many times.
			return fmt.Errorf("non-roundrobin want: %v, result: %v", want, firstCycle)
		}
	}

	// Following cycles: two more passes must replay the first cycle exactly.
	var later []balancer.SubConn
	for n := 0; n < 2*len(firstCycle); n++ {
		sc := f()
		later = append(later, sc)
		if expected := firstCycle[n%len(firstCycle)]; expected != sc {
			return fmt.Errorf("non-roundrobin, first iter: %v, second iter: %v", firstCycle, later)
		}
	}
	return nil
}
// SubConnFromPicker adapts a picker into a plain "next SubConn" function.
// Nothing is cached: each call of the returned function performs a fresh
// Pick on p and returns the SubConn of that pick's result.
func SubConnFromPicker(p balancer.Picker) func() balancer.SubConn {
	next := func() balancer.SubConn {
		// The pick error is deliberately ignored; whatever SubConn the
		// result carries is returned as-is.
		res, _ := p.Pick(balancer.PickInfo{})
		return res.SubConn
	}
	return next
}
// ErrTestConstPicker is a predefined error for use with the test const
// picker, e.g. as the value of TestConstPicker.Err. Its message is
// "const picker error".
var ErrTestConstPicker = errors.New("const picker error")
// TestConstPicker is a const picker for tests: every Pick yields the same
// outcome, determined by the two fields below.
type TestConstPicker struct {
// Err, when non-nil, is returned by every Pick and takes precedence over SC.
Err error
// SC is the SubConn returned by every Pick when Err is nil.
SC balancer.SubConn
}
// Pick returns the picker's fixed outcome and ignores info. When Err is set
// the result is an empty PickResult together with Err; otherwise the result
// carries SC and the error is nil.
func (tcp *TestConstPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) {
	var res balancer.PickResult
	if tcp.Err == nil {
		res.SubConn = tcp.SC
	}
	return res, tcp.Err
}