blob: cdffafc56e10d684bb5870aa22327cafad83c929 [file] [log] [blame]
// Copyright 2021 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package rpcutil
import (
"context"
"errors"
"fmt"
"testing"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// TestRequestWithRetries checks which errors RequestWithRetries retries, and
// that it gives up with an error once every allowed attempt has failed.
//
// Each case lists, in order, the error returned by the request function on
// successive attempts (nil meaning success). expectAttempts must equal the
// number of listed errors, so each case spells out exactly how many times the
// function is expected to be called.
func TestRequestWithRetries(t *testing.T) {
	// Remove the wait between attempts so the retry cases run instantly, and
	// restore the package-level value when the test finishes.
	oldRetryDelay := retryDelay
	defer func() {
		retryDelay = oldRetryDelay
	}()
	retryDelay = 0

	tests := []struct {
		name                string
		errors              []error
		extraTransientCodes []codes.Code
		expectErr           bool
		expectAttempts      int
	}{
		{
			name: "returns nil if function passes",
			errors: []error{
				nil,
			},
			expectAttempts: 1,
		},
		{
			name: "retries on internal error",
			errors: []error{
				newGRPCError(codes.Internal),
				nil,
			},
			expectAttempts: 2,
		},
		{
			name: "retries on timeout",
			errors: []error{
				newGRPCError(codes.DeadlineExceeded),
				nil,
			},
			expectAttempts: 2,
		},
		{
			name: "retries on wrapped transient error",
			errors: []error{
				fmt.Errorf("failure: %w", newGRPCError(codes.Internal)),
				nil,
			},
			expectAttempts: 2,
		},
		{
			name: "retries generic transient network failures",
			errors: []error{
				status.Error(codes.Unknown, "net/http: TLS handshake timeout"),
				nil,
			},
			expectAttempts: 2,
		},
		{
			name: "does not retry non-transient generic network failures",
			errors: []error{
				status.Error(codes.Unknown, "non-transient error"),
			},
			expectAttempts: 1,
			expectErr:      true,
		},
		{
			name: "does not retry unrecognized failures",
			errors: []error{
				errors.New("non-RPC error"),
			},
			expectAttempts: 1,
			expectErr:      true,
		},
		{
			name: "does not retry non-transient failure",
			errors: []error{
				newGRPCError(codes.InvalidArgument),
			},
			expectAttempts: 1,
			expectErr:      true,
		},
		{
			name: "retries if error matches an extra transient code",
			errors: []error{
				newGRPCError(codes.InvalidArgument),
				newGRPCError(codes.OutOfRange),
				nil,
			},
			extraTransientCodes: []codes.Code{codes.OutOfRange, codes.InvalidArgument},
			expectAttempts:      3,
		},
		{
			name:           "returns error if all attempts fail",
			errors:         multiplyError(newGRPCError(codes.Internal), maxAttempts),
			expectAttempts: maxAttempts,
			expectErr:      true,
		},
	}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			if test.expectAttempts != len(test.errors) {
				// This requirement adds some redundancy but makes the tests
				// more explicit and easier to read.
				t.Fatalf("expectAttempts (%d) must be equal to the length of errors (%d)",
					test.expectAttempts, len(test.errors))
			}
			attempts := 0
			err := RequestWithRetries(context.Background(), func() error {
				// Count the attempt even if this call fails the test below.
				defer func() { attempts++ }()
				// NOTE(review): t.Fatalf is only valid on the test goroutine;
				// this assumes RequestWithRetries invokes the callback
				// synchronously — confirm against its implementation.
				if attempts >= len(test.errors) {
					t.Fatalf("Too few errors listed for attempt #%d", attempts+1)
				}
				return test.errors[attempts]
			}, test.extraTransientCodes...)
			// %v (not %s) so a nil error prints as "<nil>" rather than "%!s(<nil>)".
			if (err != nil) != test.expectErr {
				t.Fatalf("Unexpected error result (expectErr=%t): %v", test.expectErr, err)
			}
			if test.expectAttempts != attempts {
				t.Errorf("Expected %d attempts, got %d", test.expectAttempts, attempts)
			}
		})
	}
}
// newGRPCError builds a gRPC status error with the given code, using the
// code's own name (code.String()) as the error message.
func newGRPCError(code codes.Code) error {
	message := code.String()
	return status.Error(code, message)
}
// multiplyError returns a slice of length n in which every element is err.
// When n is zero or negative, the result is a nil slice.
func multiplyError(err error, n int) []error {
	var repeated []error
	for i := 0; i < n; i++ {
		repeated = append(repeated, err)
	}
	return repeated
}