delay: make it possible to get taskqueue HTTP headers from inside a delay.Func (#82)
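Make the HTTP headers of the in-flight task queue request available from inside a delay.Func via delay.RequestHeaders, and add taskqueue.ParseRequestHeaders to decode them.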
Fixes #59
diff --git a/delay/delay.go b/delay/delay.go
index af4eb8c..21c90bf 100644
--- a/delay/delay.go
+++ b/delay/delay.go
@@ -74,6 +74,8 @@
queue = ""
)
+// contextKey is an unexported type for the keys this package stores in a
+// context.Context; a private type keeps them from colliding with keys that
+// other packages define.
+type contextKey int
+
var (
// registry of all delayed functions
funcs = make(map[string]*Function)
@@ -83,7 +85,11 @@
errorType = reflect.TypeOf((*error)(nil)).Elem()
// errors
- errFirstArg = errors.New("first argument must be context.Context")
+ errFirstArg = errors.New("first argument must be context.Context")
+ errOutsideDelayFunc = errors.New("request headers are only available inside a delay.Func")
+
+ // context keys
+ headersContextKey contextKey = 0
)
// Func declares a new Function. The second argument must be a function with a
@@ -222,6 +228,15 @@
}, nil
}
+// RequestHeaders returns the special task queue HTTP request headers of the
+// request that is executing the current delay.Func. It returns an error if c
+// was not derived from the context passed to a delay.Func.
+func RequestHeaders(c context.Context) (*taskqueue.RequestHeaders, error) {
+	if ret, ok := c.Value(headersContextKey).(*taskqueue.RequestHeaders); ok {
+		return ret, nil
+	}
+	return nil, errOutsideDelayFunc
+}
+
var taskqueueAdder = taskqueue.Add // for testing
func init() {
@@ -233,6 +248,8 @@
func runFunc(c context.Context, w http.ResponseWriter, req *http.Request) {
defer req.Body.Close()
+ c = context.WithValue(c, headersContextKey, taskqueue.ParseRequestHeaders(req.Header))
+
var inv invocation
if err := gob.NewDecoder(req.Body).Decode(&inv); err != nil {
log.Errorf(c, "delay: failed decoding task payload: %v", err)
diff --git a/delay/delay_test.go b/delay/delay_test.go
index 1c37e79..3df2bf7 100644
--- a/delay/delay_test.go
+++ b/delay/delay_test.go
@@ -94,6 +94,14 @@
dupeWhich = 2
}
})
+
+ reqFuncRuns = 0
+ reqFuncHeaders *taskqueue.RequestHeaders
+ reqFuncErr error
+ reqFunc = Func("req", func(c context.Context) {
+ reqFuncRuns++
+ reqFuncHeaders, reqFuncErr = RequestHeaders(c)
+ })
)
type fakeContext struct {
@@ -373,3 +381,48 @@
t.Errorf("dupeWhich = %d; want 2", dupeWhich)
}
}
+
+func TestGetRequestHeadersFromContext(t *testing.T) {
+	c := newFakeContext()
+
+	// Outside a delay.Func should return an error.
+	headers, err := RequestHeaders(c.ctx)
+	if headers != nil {
+		t.Errorf("RequestHeaders outside Func, got %v, want nil", headers)
+	}
+	if err != errOutsideDelayFunc {
+		t.Errorf("RequestHeaders outside Func err, got %v, want %v", err, errOutsideDelayFunc)
+	}
+
+	// Fake out the adding of a task, restoring the previous adder afterwards
+	// so later tests are unaffected.
+	origAdder := taskqueueAdder
+	defer func() { taskqueueAdder = origAdder }()
+	var task *taskqueue.Task
+	taskqueueAdder = func(_ context.Context, tk *taskqueue.Task, queue string) (*taskqueue.Task, error) {
+		if queue != "" {
+			t.Errorf(`Got queue %q, expected ""`, queue)
+		}
+		task = tk
+		return tk, nil
+	}
+
+	reqFunc.Call(c.ctx)
+	if task == nil {
+		t.Fatal("reqFunc.Call did not add a task")
+	}
+
+	// Reset everything reqFunc writes so stale values cannot satisfy the
+	// checks below.
+	reqFuncRuns, reqFuncHeaders, reqFuncErr = 0, nil, nil
+
+	// Simulate the Task Queue service.
+	req, err := http.NewRequest("POST", path, bytes.NewBuffer(task.Payload))
+	if err != nil {
+		t.Fatalf("Failed making http.Request: %v", err)
+	}
+	req.Header.Set("x-appengine-taskname", "foobar")
+	rw := httptest.NewRecorder()
+	runFunc(c.ctx, rw, req)
+
+	if reqFuncRuns != 1 {
+		t.Errorf("reqFuncRuns: got %d, want 1", reqFuncRuns)
+	}
+	if reqFuncErr != nil {
+		t.Errorf("reqFuncErr: got %v, want nil", reqFuncErr)
+	}
+	if reqFuncHeaders == nil {
+		t.Fatal("reqFuncHeaders: got nil, want non-nil")
+	}
+	if reqFuncHeaders.TaskName != "foobar" {
+		t.Errorf("reqFuncHeaders.TaskName: got %q, want %q", reqFuncHeaders.TaskName, "foobar")
+	}
+}
diff --git a/taskqueue/taskqueue.go b/taskqueue/taskqueue.go
index 5fed9b3..965c5ab 100644
--- a/taskqueue/taskqueue.go
+++ b/taskqueue/taskqueue.go
@@ -21,6 +21,7 @@
"fmt"
"net/http"
"net/url"
+ "strconv"
"time"
"github.com/golang/protobuf/proto"
@@ -147,6 +148,48 @@
}
}
+// RequestHeaders holds the special HTTP request headers that App Engine sets
+// on requests delivered to push task handlers. The meaning of each field is
+// described at
+// https://cloud.google.com/appengine/docs/standard/go/taskqueue/push/creating-handlers#reading_request_headers
+type RequestHeaders struct {
+	// Read from X-AppEngine-QueueName and X-AppEngine-TaskName.
+	QueueName, TaskName string
+
+	// Read from X-AppEngine-TaskRetryCount and X-AppEngine-TaskExecutionCount.
+	TaskRetryCount, TaskExecutionCount int64
+
+	// Read from X-AppEngine-TaskETA; left as the zero Time when that header is
+	// absent or unparseable.
+	TaskETA time.Time
+
+	TaskPreviousResponse int    // X-AppEngine-TaskPreviousResponse
+	TaskRetryReason      string // X-AppEngine-TaskRetryReason
+	FailFast             bool   // set when X-AppEngine-FailFast is non-empty
+}
+
+// ParseRequestHeaders parses the special HTTP request headers available to push
+// task request handlers. This function silently ignores values of the wrong
+// format.
+//
+// X-AppEngine-TaskETA is read as seconds since the Unix epoch and may include a
+// fractional part (e.g. "1500000000.123456"); a missing, zero, or malformed
+// value leaves TaskETA as the zero Time.
+func ParseRequestHeaders(h http.Header) *RequestHeaders {
+	ret := &RequestHeaders{
+		QueueName:       h.Get("X-AppEngine-QueueName"),
+		TaskName:        h.Get("X-AppEngine-TaskName"),
+		TaskRetryReason: h.Get("X-AppEngine-TaskRetryReason"),
+		FailFast:        h.Get("X-AppEngine-FailFast") != "",
+	}
+
+	// Parse errors are deliberately ignored: an absent or non-numeric header
+	// yields 0.
+	ret.TaskRetryCount, _ = strconv.ParseInt(h.Get("X-AppEngine-TaskRetryCount"), 10, 64)
+	ret.TaskExecutionCount, _ = strconv.ParseInt(h.Get("X-AppEngine-TaskExecutionCount"), 10, 64)
+	ret.TaskPreviousResponse, _ = strconv.Atoi(h.Get("X-AppEngine-TaskPreviousResponse"))
+
+	// ParseInt rejects a fractional ETA outright, so parse it as a float. The
+	// bounds check rejects NaN and ±Inf (which ParseFloat accepts) and keeps
+	// the int64 conversion below in range.
+	eta, err := strconv.ParseFloat(h.Get("X-AppEngine-TaskETA"), 64)
+	if err == nil && eta != 0 && eta > -1<<62 && eta < 1<<62 {
+		// Split off whole seconds before scaling to nanoseconds: this avoids
+		// overflowing int64 and leaves the fraction exactly 0 for integral
+		// values.
+		secs := int64(eta)
+		nsecs := int64((eta - float64(secs)) * 1e9)
+		ret.TaskETA = time.Unix(secs, nsecs)
+	}
+
+	return ret
+}
+
var (
currentNamespace = http.CanonicalHeaderKey("X-AppEngine-Current-Namespace")
defaultNamespace = http.CanonicalHeaderKey("X-AppEngine-Default-Namespace")
diff --git a/taskqueue/taskqueue_test.go b/taskqueue/taskqueue_test.go
index 0c14015..d9eec50 100644
--- a/taskqueue/taskqueue_test.go
+++ b/taskqueue/taskqueue_test.go
@@ -7,8 +7,10 @@
import (
"errors"
"fmt"
+ "net/http"
"reflect"
"testing"
+ "time"
"google.golang.org/appengine"
"google.golang.org/appengine/internal"
@@ -114,3 +116,58 @@
t.Fatalf("Add: %v", err)
}
}
+
+func TestParseRequestHeaders(t *testing.T) {
+	tests := []struct {
+		Header http.Header
+		Want   RequestHeaders
+	}{
+		{
+			Header: http.Header{
+				"X-Appengine-Queuename":            {"foo"},
+				"X-Appengine-Taskname":             {"bar"},
+				"X-Appengine-Taskretrycount":       {"4294967297"}, // 2^32 + 1
+				"X-Appengine-Taskexecutioncount":   {"4294967298"}, // 2^32 + 2
+				"X-Appengine-Tasketa":              {"1500000000"},
+				"X-Appengine-Taskpreviousresponse": {"404"},
+				"X-Appengine-Taskretryreason":      {"baz"},
+				"X-Appengine-Failfast":             {"yes"},
+			},
+			Want: RequestHeaders{
+				QueueName:            "foo",
+				TaskName:             "bar",
+				TaskRetryCount:       4294967297,
+				TaskExecutionCount:   4294967298,
+				TaskETA:              time.Date(2017, time.July, 14, 2, 40, 0, 0, time.UTC),
+				TaskPreviousResponse: 404,
+				TaskRetryReason:      "baz",
+				FailFast:             true,
+			},
+		},
+		{
+			Header: http.Header{},
+			Want: RequestHeaders{
+				QueueName:            "",
+				TaskName:             "",
+				TaskRetryCount:       0,
+				TaskExecutionCount:   0,
+				TaskETA:              time.Time{},
+				TaskPreviousResponse: 0,
+				TaskRetryReason:      "",
+				FailFast:             false,
+			},
+		},
+	}
+
+	for idx, test := range tests {
+		got := *ParseRequestHeaders(test.Header)
+		// Compare instants with Equal: the parsed time may carry a different
+		// Location than Want, and UnixNano is undefined for the zero Time.
+		if !got.TaskETA.Equal(test.Want.TaskETA) {
+			t.Errorf("%d. ParseRequestHeaders got TaskETA %v, wanted %v", idx, got.TaskETA, test.Want.TaskETA)
+		}
+		// Times were checked above; clear them so DeepEqual covers the rest.
+		got.TaskETA = time.Time{}
+		test.Want.TaskETA = time.Time{}
+		if !reflect.DeepEqual(got, test.Want) {
+			t.Errorf("%d. ParseRequestHeaders got %+v, wanted %+v", idx, got, test.Want)
+		}
+	}
+}