Merge pull request #3 from saicheems/invoke

Replace CreateAPICall with Invoke, which calls the stub directly instead of returning a wrapped APICall.
diff --git a/api_callable.go b/api_callable.go
index 5defc50..f7544b0 100644
--- a/api_callable.go
+++ b/api_callable.go
@@ -8,8 +8,8 @@
 	"google.golang.org/grpc/codes"
 )
 
-// Represents a GRPC call stub.
-type APICall func(context.Context, interface{}) (interface{}, error)
+// APICall is a user-defined call stub; Invoke may call it more than once when retrying.
+type APICall func(context.Context) error
 
 // scaleDuration returns the product of a and mult.
 func scaleDuration(a time.Duration, mult float64) time.Duration {
@@ -17,59 +17,87 @@
 	return time.Duration(ns)
 }
 
-// stubWithRetry returns a wrapper for stub with an exponential backoff retry
-// mechanism based on the values provided in retrySettings.
-func stubWithRetry(stub APICall, retrySettings retrySettings) APICall {
-	return func(ctx context.Context, req interface{}) (resp interface{}, err error) {
-		backoffSettings := retrySettings.backoffSettings
-		// Forces ctx to expire after a deadline.
-		ctx, _ = context.WithTimeout(ctx, backoffSettings.totalTimeout)
-
-		delay := backoffSettings.delayTimeoutSettings.initial
-		timeout := backoffSettings.rpcTimeoutSettings.initial
-
-		for {
-			// If the deadline is exceeded...
-			if ctx.Err() != nil {
-				return nil, ctx.Err()
-			}
-			timeoutCtx, _ := context.WithTimeout(ctx, backoffSettings.rpcTimeoutSettings.max)
-			timeoutCtx, _ = context.WithTimeout(timeoutCtx, timeout)
-			resp, err = stub(timeoutCtx, req)
-			code := grpc.Code(err)
-			if code == codes.OK {
-				return resp, err
-			}
-			if !retrySettings.retryCodes[code] {
-				return nil, err
-			}
-			delayCtx, _ := context.WithTimeout(ctx, backoffSettings.delayTimeoutSettings.max)
-			delayCtx, _ = context.WithTimeout(delayCtx, delay)
-			<-delayCtx.Done()
-
-			delay = scaleDuration(delay, backoffSettings.delayTimeoutSettings.multiplier)
-			timeout = scaleDuration(timeout, backoffSettings.rpcTimeoutSettings.multiplier)
+// capDuration returns d, or limit when d exceeds it.
+func capDuration(d, limit time.Duration) time.Duration {
+	if d > limit {
+		return limit
+	}
+	return d
+}
+
+// invokeWithRetry calls stub using an exponential backoff retry mechanism
+// based on the values provided in retrySettings.
+//
+// All attempts, and the pauses between them, run under a child of ctx that
+// expires after backoffSettings.totalTimeout. Each attempt gets the current
+// RPC timeout, capped at rpcTimeoutSettings.max. An attempt failing with a code
+// in retrySettings.retryCodes is followed by a pause of the current delay,
+// capped at delayTimeoutSettings.max, after which both durations are scaled by
+// their multipliers. Any other non-OK code is returned immediately.
+func invokeWithRetry(ctx context.Context, stub APICall, retrySettings retrySettings) error {
+	backoffSettings := retrySettings.backoffSettings
+	// Forces ctx to expire after a deadline.
+	childCtx, cancel := context.WithTimeout(ctx, backoffSettings.totalTimeout)
+	// Release the deadline timer on return instead of leaking it until
+	// totalTimeout elapses.
+	defer cancel()
+	delay := backoffSettings.delayTimeoutSettings.initial
+	timeout := backoffSettings.rpcTimeoutSettings.initial
+	for {
+		// Give up once the total deadline has passed or ctx is done.
+		if childCtx.Err() != nil {
+			return childCtx.Err()
 		}
-		return
+		attemptCtx, cancelAttempt := context.WithTimeout(childCtx,
+			capDuration(timeout, backoffSettings.rpcTimeoutSettings.max))
+		err := stub(attemptCtx)
+		// Cancel explicitly rather than via defer so per-attempt timers do not
+		// pile up across retries; stub must be done with attemptCtx by now.
+		cancelAttempt()
+		code := grpc.Code(err)
+		if code == codes.OK {
+			return nil
+		}
+		if !retrySettings.retryCodes[code] {
+			return err
+		}
+		delayCtx, cancelDelay := context.WithTimeout(childCtx,
+			capDuration(delay, backoffSettings.delayTimeoutSettings.max))
+		<-delayCtx.Done()
+		cancelDelay()
+
+		delay = scaleDuration(delay, backoffSettings.delayTimeoutSettings.multiplier)
+		timeout = scaleDuration(timeout, backoffSettings.rpcTimeoutSettings.multiplier)
 	}
 }
 
-// stubWithTimeout returns a wrapper for stub with a timeout applied to its
-// context.
-func stubWithTimeout(stub APICall, timeout time.Duration) APICall {
-	return func(ctx context.Context, data interface{}) (interface{}, error) {
-		childCtx, _ := context.WithTimeout(ctx, timeout)
-		return stub(childCtx, data)
-	}
+// invokeWithTimeout calls stub with a timeout applied to its context.
+//
+// A non-positive timeout (such as the zero value left in callSettings when no
+// option sets one) applies no extra deadline, because context.WithTimeout
+// would otherwise hand stub an already-expired context.
+func invokeWithTimeout(ctx context.Context, stub APICall, timeout time.Duration) error {
+	if timeout <= 0 {
+		return stub(ctx)
+	}
+	childCtx, cancel := context.WithTimeout(ctx, timeout)
+	// Release the timer as soon as stub returns instead of when it fires.
+	defer cancel()
+	return stub(childCtx)
 }
 
-// CreateAPICall returns a wrapper for stub governed by the values provided in
-// settings.
-func CreateAPICall(stub APICall, opts ...CallOption) APICall {
+// Invoke calls stub with a context derived from ctx per the specified options.
+//
+// If the options configure any retry codes, stub is retried with exponential
+// backoff until it succeeds, fails with a non-retryable code, or the total
+// timeout (or ctx) expires; otherwise stub is called once under the configured
+// timeout. Contexts created for stub are cancelled once its call completes, so
+// stub must not keep using them after it returns.
+func Invoke(ctx context.Context, stub APICall, opts ...CallOption) error {
 	settings := &callSettings{}
 	callOptions(opts).Resolve(settings)
 	if len(settings.retrySettings.retryCodes) > 0 {
-		return stubWithRetry(stub, settings.retrySettings)
+		return invokeWithRetry(ctx, stub, settings.retrySettings)
 	}
-	return stubWithTimeout(stub, settings.timeout)
+	return invokeWithTimeout(ctx, stub, settings.timeout)
 }
diff --git a/api_callable_test.go b/api_callable_test.go
index 6352087..3462818 100644
--- a/api_callable_test.go
+++ b/api_callable_test.go
@@ -19,29 +19,31 @@
 	}
 )
 
-func TestCreateAPICallWithTimeout(t *testing.T) {
+func TestInvokeWithTimeout(t *testing.T) {
 	ctx := context.Background()
 	var ok bool
-	CreateAPICall(func(ctx context.Context, req interface{}) (interface{}, error) {
-		_, ok = ctx.Deadline()
-		return nil, nil
-	}, WithTimeout(10000*time.Millisecond))(ctx, nil)
-	if !ok {
-		t.Errorf("expected call to have an assigned timeout")
+	err := Invoke(ctx, func(childCtx context.Context) error {
+		_, ok = childCtx.Deadline()
+		return nil
+	}, WithTimeout(10000*time.Millisecond))
+	if err != nil || !ok {
+		t.Errorf("expected call to succeed with an assigned timeout, got err %v", err)
 	}
 }
 
-func TestCreateApiCallWithOKResponseWithTimeout(t *testing.T) {
+func TestInvokeWithOKResponseWithTimeout(t *testing.T) {
 	ctx := context.Background()
-	resp, err := CreateAPICall(func(ctx context.Context, req interface{}) (interface{}, error) {
-		return 42, nil
-	}, WithTimeout(10000*time.Millisecond))(ctx, nil)
-	if resp.(int) != 42 || err != nil {
-		t.Errorf("expected call to return (42, nil)")
+	var resp int
+	err := Invoke(ctx, func(childCtx context.Context) error {
+		resp = 42
+		return nil
+	}, WithTimeout(10000*time.Millisecond))
+	if err != nil || resp != 42 {
+		t.Errorf("expected call to return nil and set resp to 42; got err=%v, resp=%d", err, resp)
 	}
 }
 
-func TestCreateApiCallWithDeadlineAfterRetries(t *testing.T) {
+func TestInvokeWithDeadlineAfterRetries(t *testing.T) {
 	ctx := context.Background()
 	count := 0
 
@@ -52,33 +54,35 @@
 		450 * time.Millisecond,
 	}
 
-	_, err := CreateAPICall(func(ctx context.Context, req interface{}) (interface{}, error) {
-		t.Log("delta:", time.Now().Sub(now.Add(expectedTimeout[count])))
+	err := Invoke(ctx, func(childCtx context.Context) error {
+		t.Log("delta:", time.Since(now.Add(expectedTimeout[count])))
 		if !time.Now().After(now.Add(expectedTimeout[count])) {
 			t.Errorf("expected %s to pass before this call", expectedTimeout[count])
 		}
-		count += 1
-		<-ctx.Done()
-		return nil, grpc.Errorf(codes.DeadlineExceeded, "")
-	}, testCallSettings...)(ctx, nil)
+		count++
+		<-childCtx.Done()
+		return grpc.Errorf(codes.DeadlineExceeded, "")
+	}, testCallSettings...)
 	if count != 3 || err == nil {
-		t.Errorf("expected call to retry 3 times and return an error")
+		t.Errorf("expected call to retry 3 times and return an error; got %d calls, err=%v", count, err)
 	}
 }
 
-func TestCreateApiCallWithOKResponseAfterRetries(t *testing.T) {
+func TestInvokeWithOKResponseAfterRetries(t *testing.T) {
 	ctx := context.Background()
 	count := 0
 
-	resp, err := CreateAPICall(func(ctx context.Context, req interface{}) (interface{}, error) {
-		count += 1
+	var resp int
+	err := Invoke(ctx, func(childCtx context.Context) error {
+		count++
 		if count == 3 {
-			return 42, nil
+			resp = 42
+			return nil
 		}
-		<-ctx.Done()
-		return nil, grpc.Errorf(codes.DeadlineExceeded, "")
-	}, testCallSettings...)(ctx, nil)
-	if count != 3 || resp.(int) != 42 || err != nil {
-		t.Errorf("expected call to retry 3 times and return (42, nil)")
+		<-childCtx.Done()
+		return grpc.Errorf(codes.DeadlineExceeded, "")
+	}, testCallSettings...)
+	if count != 3 || resp != 42 || err != nil {
+		t.Errorf("expected call to retry 3 times, return nil, and set resp to 42; got %d calls, resp=%d, err=%v", count, resp, err)
 	}
 }