blob: 472b054f80fe9442a851eaadb10551e1cb40e146 [file] [log] [blame]
package retry
import (
"context"
"errors"
"fmt"
"testing"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// alwaysRetry is a ShouldRetry that reports every error as retryable,
// ignoring the error value entirely.
func alwaysRetry(_ error) bool {
	return true
}
// failer simulates a flaky operation for retry tests. Each call to run
// consumes one unit of attempts: the first `attempts` calls fail with a
// generic "failing" error, and every call after that returns finalErr
// (nil unless set, i.e. success).
type failer struct {
	attempts int   // failing calls remaining; decremented on every run, may go negative
	finalErr error // result of run once the failing calls are used up
}

// run performs one simulated attempt. attempts keeps decreasing even after it
// goes negative, so the number of calls made can be recovered by subtracting
// the current value from the initial one.
func (f *failer) run() error {
	f.attempts--
	if f.attempts >= 0 {
		return errors.New("failing")
	}
	return f.finalErr
}
// policyString renders bp as a short human-readable label of the form
// "baseDelay: <d>, maxDelay: <d>: <limit>", where <limit> is either
// "unlimited retries" (maxAttempts == 0) or "max N attempts".
func policyString(bp BackoffPolicy) string {
	limit := "unlimited retries"
	if bp.maxAttempts != 0 {
		limit = fmt.Sprintf("max %d attempts", bp.maxAttempts)
	}
	return fmt.Sprintf("baseDelay: %v, maxDelay: %v: %s", bp.baseDelay, bp.maxDelay, limit)
}
func TestRetries(t *testing.T) {
cases := []struct {
policy BackoffPolicy
sr ShouldRetry
attempts int
finalErr error
wantError bool
wantErrCode codes.Code
}{
{
policy: ExponentialBackoff(time.Millisecond, time.Millisecond, UnlimitedAttempts),
sr: alwaysRetry,
attempts: 5,
},
{
policy: ExponentialBackoff(time.Millisecond, time.Millisecond, 5),
sr: alwaysRetry,
attempts: 5,
finalErr: status.Error(codes.Unimplemented, "unimplemented!"),
wantError: true,
wantErrCode: codes.Unimplemented,
},
{
policy: ExponentialBackoff(time.Millisecond, time.Millisecond, 1),
sr: alwaysRetry,
wantError: true,
attempts: 1,
},
{
policy: ExponentialBackoff(time.Millisecond, time.Millisecond, 2),
sr: alwaysRetry,
wantError: true,
attempts: 2,
},
{
policy: ExponentialBackoff(time.Millisecond, time.Millisecond, 5),
sr: alwaysRetry,
attempts: 5,
},
}
ctx := context.Background()
for _, c := range cases {
f := failer{
attempts: 4,
finalErr: c.finalErr,
}
err := WithPolicy(context.WithValue(ctx, TimeAfterContextKey, func(time.Duration) <-chan time.Time {
c := make(chan time.Time)
close(c) // Reading from the closed channel will immediately succeed.
return c
}), c.sr, c.policy, f.run)
attempts := 4 - f.attempts
if attempts != c.attempts {
t.Errorf("%s: expected %d attempts, got %d", policyString(c.policy), c.attempts, attempts)
}
switch {
case c.wantError:
if err == nil {
t.Errorf("%s: want error, got no error", policyString(c.policy))
}
if s, ok := status.FromError(err); c.wantErrCode != 0 && (!ok || s.Code() != c.wantErrCode) {
t.Errorf("%s: want error with code %v, got %v", policyString(c.policy), c.wantErrCode.String(), err)
}
case err != nil:
t.Errorf("%s: want success, got error: %v", policyString(c.policy), err)
}
}
}