blob: 56989411e946f4c394af96983f972f24fa727c60 [file] [log] [blame]
// Copyright 2016 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package gax
import (
"context"
"errors"
"testing"
"time"
)
// canceledContext is a context that has already been canceled by the time any
// test runs, so ctx.Err() on it is non-nil and ctx.Done() is closed.
var canceledContext = func() context.Context {
	c, stop := context.WithCancel(context.Background())
	stop()
	return c
}()
// recordSleeper is a test implementation of sleeper. It never actually waits.
// Each call increments the counter, and the call then reports ctx's current
// error, so it fails only when ctx is already done.
type recordSleeper int

// sleep records one sleep request, ignoring the requested duration, and
// returns ctx.Err().
func (s *recordSleeper) sleep(ctx context.Context, _ time.Duration) error {
	err := ctx.Err()
	*s++
	return err
}
// boolRetryer is a test Retryer with a fixed decision. true means "retry
// immediately" after every error, and false means "never retry".
type boolRetryer bool

// Retry ignores err and returns a zero delay together with r's fixed decision.
func (r boolRetryer) Retry(err error) (time.Duration, bool) {
	shouldRetry := bool(r)
	return 0, shouldRetry
}
func TestInvokeSuccess(t *testing.T) {
apiCall := func(context.Context, CallSettings) error { return nil }
var sp recordSleeper
err := invoke(context.Background(), apiCall, CallSettings{}, sp.sleep)
if err != nil {
t.Errorf("found error %s, want nil", err)
}
if sp != 0 {
t.Errorf("slept %d times, should not have slept since the call succeeded", int(sp))
}
}
// TestInvokeNoRetry checks that with zero-value CallSettings (no retry
// configured), invoke hands back the call's error and never sleeps.
func TestInvokeNoRetry(t *testing.T) {
	apiErr := errors.New("foo error")
	apiCall := func(context.Context, CallSettings) error { return apiErr }
	var sp recordSleeper
	err := invoke(context.Background(), apiCall, CallSettings{}, sp.sleep)
	// errors.Is instead of == keeps the check valid if invoke ever wraps the
	// error it returns. %v prints a nil error as "<nil>" instead of a
	// %!s formatting artifact.
	if !errors.Is(err, apiErr) {
		t.Errorf("found error %v, want %v", err, apiErr)
	}
	if sp != 0 {
		t.Errorf("slept %d times, should not have slept since retry is not specified", int(sp))
	}
}
// TestInvokeNilRetry checks that a retry option whose factory returns a nil
// Retryer behaves like no retry at all. invoke should hand back the call's
// error and never sleep.
func TestInvokeNilRetry(t *testing.T) {
	apiErr := errors.New("foo error")
	apiCall := func(context.Context, CallSettings) error { return apiErr }
	var settings CallSettings
	WithRetry(func() Retryer { return nil }).Resolve(&settings)
	var sp recordSleeper
	err := invoke(context.Background(), apiCall, settings, sp.sleep)
	// errors.Is tolerates wrapping. %v renders a nil error readably.
	if !errors.Is(err, apiErr) {
		t.Errorf("found error %v, want %v", err, apiErr)
	}
	if sp != 0 {
		t.Errorf("slept %d times, should not have slept since the Retryer is nil", int(sp))
	}
}
// TestInvokeNeverRetry checks that when a Retryer is configured but always
// declines to retry, invoke hands back the call's error and never sleeps.
func TestInvokeNeverRetry(t *testing.T) {
	apiErr := errors.New("foo error")
	apiCall := func(context.Context, CallSettings) error { return apiErr }
	var settings CallSettings
	WithRetry(func() Retryer { return boolRetryer(false) }).Resolve(&settings)
	var sp recordSleeper
	err := invoke(context.Background(), apiCall, settings, sp.sleep)
	// errors.Is tolerates wrapping. %v renders a nil error readably.
	if !errors.Is(err, apiErr) {
		t.Errorf("found error %v, want %v", err, apiErr)
	}
	// A Retryer *is* specified here; it just never asks for a retry.
	if sp != 0 {
		t.Errorf("slept %d times, should not have slept since the Retryer never retries", int(sp))
	}
}
// TestInvokeRetry checks that with a Retryer that always asks to retry, invoke
// keeps calling until the call succeeds on attempt number target. It also
// checks that the sleeper is used target-1 times along the way.
func TestInvokeRetry(t *testing.T) {
	const target = 3
	var (
		attempts int
		settings CallSettings
		sleeps   recordSleeper
	)
	apiErr := errors.New("foo error")
	// Fail on every attempt before the target-th one, then succeed.
	flaky := func(context.Context, CallSettings) error {
		attempts++
		if attempts >= target {
			return nil
		}
		return apiErr
	}
	WithRetry(func() Retryer { return boolRetryer(true) }).Resolve(&settings)
	if err := invoke(context.Background(), flaky, settings, sleeps.sleep); err != nil {
		t.Errorf("found error %s, want nil, call should have succeeded after %d tries", err, target)
	}
	if want := target - 1; int(sleeps) != want {
		t.Errorf("retried %d times, want %d", int(sleeps), want)
	}
}
// TestInvokeRetryTimeout checks that invoke gives up with context.Canceled
// when its context is already canceled. The call always fails and the
// Retryer always asks to retry. recordSleeper.sleep returns ctx.Err(), so the
// canceled context makes the first backoff fail. NOTE(review): this assumes
// invoke propagates the sleeper's error rather than looping; confirm against
// invoke.
func TestInvokeRetryTimeout(t *testing.T) {
	apiErr := errors.New("foo error")
	apiCall := func(context.Context, CallSettings) error { return apiErr }
	var settings CallSettings
	WithRetry(func() Retryer { return boolRetryer(true) }).Resolve(&settings)
	var sp recordSleeper
	err := invoke(canceledContext, apiCall, settings, sp.sleep)
	// Context errors are commonly wrapped, so match with errors.Is.
	if !errors.Is(err, context.Canceled) {
		t.Errorf("found error %v, want %v", err, context.Canceled)
	}
}