feat(v2): add WithTimeout option (#259)
diff --git a/v2/call_option.go b/v2/call_option.go
index e092005..c52e03f 100644
--- a/v2/call_option.go
+++ b/v2/call_option.go
@@ -218,6 +218,14 @@
s.Path = p.p
}
+type timeoutOpt struct {
+ t time.Duration
+}
+
+func (t timeoutOpt) Resolve(s *CallSettings) {
+ s.timeout = t.t
+}
+
// WithPath applies a Path override to the HTTP-based APICall.
//
// This is for internal use only.
@@ -230,6 +238,15 @@
return grpcOpt(append([]grpc.CallOption(nil), opt...))
}
+// WithTimeout is a convenience option for setting a context.WithTimeout on the
+// singular context.Context used for **all** APICall attempts. Calculated from
+// the start of the first APICall attempt.
+// If the context.Context provided to Invoke already has a Deadline set, that
+// will always be respected over the deadline calculated using this option.
+func WithTimeout(t time.Duration) CallOption {
+ return &timeoutOpt{t: t}
+}
+
// CallSettings allow fine-grained control over how calls are made.
type CallSettings struct {
// Retry returns a Retryer to be used to control retry logic of a method call.
@@ -241,4 +258,8 @@
// Path is an HTTP override for an APICall.
Path string
+
+ // Timeout defines the amount of time that Invoke has to complete.
+ // Unexported so it cannot be changed by the code in an APICall.
+ timeout time.Duration
}
diff --git a/v2/call_option_test.go b/v2/call_option_test.go
index 03a8609..925b768 100644
--- a/v2/call_option_test.go
+++ b/v2/call_option_test.go
@@ -129,3 +129,14 @@
}
}
}
+
+func TestWithTimeout(t *testing.T) {
+ settings := CallSettings{}
+ to := 10 * time.Second
+
+ WithTimeout(to).Resolve(&settings)
+
+ if settings.timeout != to {
+ t.Errorf("got %v, want %v", settings.timeout, to)
+ }
+}
diff --git a/v2/invoke.go b/v2/invoke.go
index 9fcc299..721d1af 100644
--- a/v2/invoke.go
+++ b/v2/invoke.go
@@ -68,6 +68,16 @@
// invoke implements Invoke, taking an additional sleeper argument for testing.
func invoke(ctx context.Context, call APICall, settings CallSettings, sp sleeper) error {
var retryer Retryer
+
+ // Only use the value provided via WithTimeout if the context doesn't
+ // already have a deadline. This is important for backwards compatibility if
+ // the user already set a deadline on the context given to Invoke.
+ if _, ok := ctx.Deadline(); !ok && settings.timeout != 0 {
+ c, cc := context.WithTimeout(ctx, settings.timeout)
+ defer cc()
+ ctx = c
+ }
+
for {
err := call(ctx, settings)
if err == nil {
diff --git a/v2/invoke_test.go b/v2/invoke_test.go
index 7fc4bf1..cef2167 100644
--- a/v2/invoke_test.go
+++ b/v2/invoke_test.go
@@ -202,3 +202,65 @@
t.Errorf("found error %s, want %s", err, context.Canceled)
}
}
+
+func TestInvokeWithTimeout(t *testing.T) {
+ // Dummy APICall that sleeps for the given amount of time. This simulates an
+ // APICall executing, allowing us to verify which deadline was respected,
+ // that which is already set on the Context, or the one calculated using the
+ // WithTimeout option's value.
+ sleepingCall := func(sleep time.Duration) APICall {
+ return func(ctx context.Context, _ CallSettings) error {
+ time.Sleep(sleep)
+ return ctx.Err()
+ }
+ }
+
+ bg := context.Background()
+ preset, pcc := context.WithTimeout(bg, 10*time.Millisecond)
+ defer pcc()
+
+ for _, tst := range []struct {
+ name string
+ timeout time.Duration
+ sleep time.Duration
+ ctx context.Context
+ want error
+ }{
+ {
+ name: "success",
+ timeout: 10 * time.Millisecond,
+ sleep: 1 * time.Millisecond,
+ ctx: bg,
+ want: nil,
+ },
+ {
+ name: "respect_context_deadline",
+ timeout: 1 * time.Millisecond,
+ sleep: 3 * time.Millisecond,
+ ctx: preset,
+ want: nil,
+ },
+ {
+ name: "with_timeout_deadline_exceeded",
+ timeout: 1 * time.Millisecond,
+ sleep: 3 * time.Millisecond,
+ ctx: bg,
+ want: context.DeadlineExceeded,
+ },
+ } {
+ t.Run(tst.name, func(t *testing.T) {
+ // Recording sleep isn't really necessary since there is
+ // no retry here, but we need a sleeper so might as well.
+ var sp recordSleeper
+ var settings CallSettings
+
+ WithTimeout(tst.timeout).Resolve(&settings)
+
+ err := invoke(tst.ctx, sleepingCall(tst.sleep), settings, sp.sleep)
+
+ if err != tst.want {
+ t.Errorf("found error %v, want %v", err, tst.want)
+ }
+ })
+ }
+}