blob: 5aa58958a66bfb61e0924d7d55f8c36bf6391429 [file] [log] [blame]
// Copyright 2021 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package rpcutil
import (
"context"
"errors"
"os"
"strings"
"time"
"go.chromium.org/luci/common/retry"
"go.chromium.org/luci/common/retry/transient"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// retryDelay is the delay handed to the retry backoff policy built in
// RequestWithRetries. It is a package-level variable rather than a constant so
// that unit tests can override it and avoid sleeping during tests.
var retryDelay = 2 * time.Second
// maxAttempts is the maximum number of times a request is tried, counting the
// first attempt (the retry policy is configured with maxAttempts - 1 retries).
const maxAttempts = 4
// RequestWithRetries invokes f, a callback that performs an RPC request, and
// re-invokes it under an exponential backoff policy while it keeps failing
// with an error that isRetryable classifies as transient.
//
// An error is considered transient when it unwraps to a gRPC status whose code
// is one of a built-in set of transient codes or one of the codes listed in
// `extraTransientCodes`, or whose code is Unknown and whose message contains a
// known timeout string. Any other error, including one that does not unwrap to
// a gRPC status at all, is returned immediately without retrying.
//
// The backoff policy is configured with retryDelay as its delay, a multiplier
// of 1.5 and maxAttempts - 1 retries. A transient error that is returned to
// the caller carries the transient tag.
//
// TODO(olivernewman): Once we're on Go 1.18, use generics to have this function
// pass through the callback's return value, which will help simplify callsites:
//
// RequestWithRetries[T any](Context, func() (T, error), ...codes.Code) (T, error)
func RequestWithRetries(ctx context.Context, f func() error, extraTransientCodes ...codes.Code) error {
	// Builds a fresh backoff iterator each time the retry machinery asks for
	// one; retryDelay is read at that point so test overrides take effect.
	newBackoff := func() retry.Iterator {
		backoff := &retry.ExponentialBackoff{Multiplier: 1.5}
		backoff.Limited = retry.Limited{
			Delay:   retryDelay,
			Retries: maxAttempts - 1,
		}
		return backoff
	}

	// A single attempt: only errors tagged as transient are retried by the
	// transient.Only policy, so tag exactly the ones we consider retryable.
	attempt := func() error {
		err := f()
		if err == nil {
			return nil
		}
		if !isRetryable(err, extraTransientCodes) {
			return err
		}
		return transient.Tag.Apply(err)
	}

	return retry.Retry(ctx, transient.Only(newBackoff), attempt, nil)
}
// isRetryable reports whether err should be treated as transient and retried.
//
// It returns true when err unwraps (via StatusFromError) to a gRPC status and
// either:
//   - the status code is DataLoss, DeadlineExceeded, Internal or Unavailable,
//     or is listed in extraTransientCodes; or
//   - the status code is Unknown and the status message contains one of a
//     small set of known timeout strings.
//
// It returns false for every other error, including errors that do not unwrap
// to a gRPC status.
//
// extraTransientCodes is only read. It is deliberately not appended to:
// appending to a caller-supplied slice can write into the caller's backing
// array when that slice has spare capacity.
func isRetryable(err error, extraTransientCodes []codes.Code) bool {
	st, ok := StatusFromError(err)
	if !ok {
		return false
	}
	code := st.Code()

	// Codes that are always considered transient.
	switch code {
	case codes.DataLoss, codes.DeadlineExceeded, codes.Internal, codes.Unavailable:
		return true
	}

	// Codes the caller additionally asked us to treat as transient.
	for _, c := range extraTransientCodes {
		if code == c {
			return true
		}
	}

	if code != codes.Unknown {
		return false
	}

	// The grpc library converts all unknown errors to its Error type, which
	// only exposes the underlying error as a string rather than properly
	// wrapping it. Hence why string comparison is necessary.
	transientSubstrings := []string{
		"net/http: TLS handshake timeout",
		os.ErrDeadlineExceeded.Error(),
	}
	msg := st.Message()
	for _, s := range transientSubstrings {
		if strings.Contains(msg, s) {
			return true
		}
	}
	return false
}
// StatusFromError returns the GRPC status of the underlying cause of the error
// if it unwraps to a GRPC error, along with a true boolean. It returns false if
// the underlying cause of the error is not a GRPC error.
//
// TODO(olivernewman): This is only necessary because `status.FromError()`
// doesn't unwrap errors. Delete this and use `status.FromError()` directly if
// and when it learns how to unwrap errors.
func StatusFromError(err error) (*status.Status, bool) {
for ; err != nil; err = errors.Unwrap(err) {
if s, ok := status.FromError(err); ok {
return s, true
}
}
return nil, false
}