| // Copyright 2021 The Fuchsia Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| package rpcutil |
| |
| import ( |
| "context" |
| "errors" |
| "fmt" |
| "testing" |
| |
| "google.golang.org/grpc/codes" |
| "google.golang.org/grpc/status" |
| ) |
| |
| func TestRequestWithRetries(t *testing.T) { |
| oldRetryDelay := retryDelay |
| defer func() { |
| retryDelay = oldRetryDelay |
| }() |
| retryDelay = 0 |
| |
| tests := []struct { |
| name string |
| errors []error |
| extraTransientCodes []codes.Code |
| expectErr bool |
| expectAttempts int |
| }{ |
| { |
| name: "returns nil if function passes", |
| errors: []error{ |
| nil, |
| }, |
| expectAttempts: 1, |
| }, |
| { |
| name: "retries on internal error", |
| errors: []error{ |
| newGRPCError(codes.Internal), |
| nil, |
| }, |
| expectAttempts: 2, |
| }, |
| { |
| name: "retries on timeout", |
| errors: []error{ |
| newGRPCError(codes.DeadlineExceeded), |
| nil, |
| }, |
| expectAttempts: 2, |
| }, |
| { |
| name: "retries on wrapped transient error", |
| errors: []error{ |
| fmt.Errorf("failure: %w", newGRPCError(codes.Internal)), |
| nil, |
| }, |
| expectAttempts: 2, |
| }, |
| { |
| name: "retries generic transient network failures", |
| errors: []error{ |
| status.Error(codes.Unknown, "net/http: TLS handshake timeout"), |
| nil, |
| }, |
| expectAttempts: 2, |
| }, |
| { |
| name: "does not retry non-transient generic network failures", |
| errors: []error{ |
| status.Error(codes.Unknown, "non-transient error"), |
| }, |
| expectAttempts: 1, |
| expectErr: true, |
| }, |
| { |
| name: "does not retry unrecognized failures", |
| errors: []error{ |
| errors.New("non-RPC error"), |
| }, |
| expectAttempts: 1, |
| expectErr: true, |
| }, |
| { |
| name: "does not retry non-transient failure", |
| errors: []error{ |
| newGRPCError(codes.InvalidArgument), |
| }, |
| expectAttempts: 1, |
| expectErr: true, |
| }, |
| { |
| name: "retries if error matches an extra transient code", |
| errors: []error{ |
| newGRPCError(codes.InvalidArgument), |
| newGRPCError(codes.OutOfRange), |
| nil, |
| }, |
| extraTransientCodes: []codes.Code{codes.OutOfRange, codes.InvalidArgument}, |
| expectAttempts: 3, |
| }, |
| { |
| name: "returns error if all attempts fail", |
| errors: multiplyError(newGRPCError(codes.Internal), maxAttempts), |
| expectAttempts: maxAttempts, |
| expectErr: true, |
| }, |
| } |
| |
| for _, test := range tests { |
| t.Run(test.name, func(t *testing.T) { |
| if test.expectAttempts != len(test.errors) { |
| // This requirement adds some redundancy but makes the tests |
| // more explicit and easier to read. |
| t.Fatalf("wantAttempts (%d) must be equal to the length of errors (%d)", |
| test.expectAttempts, len(test.errors)) |
| } |
| |
| attempts := 0 |
| err := RequestWithRetries(context.Background(), func() error { |
| defer func() { attempts++ }() |
| if attempts >= len(test.errors) { |
| t.Fatalf("Too few codes for attempt #%d", attempts) |
| } |
| return test.errors[attempts] |
| }, test.extraTransientCodes...) |
| |
| if (err != nil) != test.expectErr { |
| t.Fatalf("Unexpected error result: %s", err) |
| } |
| if test.expectAttempts != attempts { |
| t.Errorf("Expected %d attempts, got %d", test.expectAttempts, attempts) |
| } |
| }) |
| } |
| } |
| |
| func newGRPCError(code codes.Code) error { |
| return status.Error(code, code.String()) |
| } |
| |
// multiplyError builds a slice containing `n` copies of `err`. A
// non-positive `n` yields a nil slice.
func multiplyError(err error, n int) []error {
	if n <= 0 {
		return nil
	}
	copies := make([]error, n)
	for idx := range copies {
		copies[idx] = err
	}
	return copies
}