delay: make it possible to get taskqueue HTTP headers from inside a delay.Func (#82)

Make it possible to get the in-flight request from inside a delay.Func

Fixes #59
diff --git a/delay/delay.go b/delay/delay.go
index af4eb8c..21c90bf 100644
--- a/delay/delay.go
+++ b/delay/delay.go
@@ -74,6 +74,8 @@
 	queue = ""
 )
 
+type contextKey int
+
 var (
 	// registry of all delayed functions
 	funcs = make(map[string]*Function)
@@ -83,7 +85,11 @@
 	errorType   = reflect.TypeOf((*error)(nil)).Elem()
 
 	// errors
-	errFirstArg = errors.New("first argument must be context.Context")
+	errFirstArg         = errors.New("first argument must be context.Context")
+	errOutsideDelayFunc = errors.New("request headers are only available inside a delay.Func")
+
+	// context keys
+	headersContextKey contextKey = 0
 )
 
 // Func declares a new Function. The second argument must be a function with a
@@ -222,6 +228,15 @@
 	}, nil
 }
 
+// Request returns the special task-queue HTTP request headers for the current
+// task queue handler. Returns an error if called from outside a delay.Func.
+func RequestHeaders(c context.Context) (*taskqueue.RequestHeaders, error) {
+	if ret, ok := c.Value(headersContextKey).(*taskqueue.RequestHeaders); ok {
+		return ret, nil
+	}
+	return nil, errOutsideDelayFunc
+}
+
 var taskqueueAdder = taskqueue.Add // for testing
 
 func init() {
@@ -233,6 +248,8 @@
 func runFunc(c context.Context, w http.ResponseWriter, req *http.Request) {
 	defer req.Body.Close()
 
+	c = context.WithValue(c, headersContextKey, taskqueue.ParseRequestHeaders(req.Header))
+
 	var inv invocation
 	if err := gob.NewDecoder(req.Body).Decode(&inv); err != nil {
 		log.Errorf(c, "delay: failed decoding task payload: %v", err)
diff --git a/delay/delay_test.go b/delay/delay_test.go
index 1c37e79..3df2bf7 100644
--- a/delay/delay_test.go
+++ b/delay/delay_test.go
@@ -94,6 +94,14 @@
 			dupeWhich = 2
 		}
 	})
+
+	reqFuncRuns    = 0
+	reqFuncHeaders *taskqueue.RequestHeaders
+	reqFuncErr     error
+	reqFunc        = Func("req", func(c context.Context) {
+		reqFuncRuns++
+		reqFuncHeaders, reqFuncErr = RequestHeaders(c)
+	})
 )
 
 type fakeContext struct {
@@ -373,3 +381,48 @@
 		t.Errorf("dupeWhich = %d; want 2", dupeWhich)
 	}
 }
+
+func TestGetRequestHeadersFromContext(t *testing.T) {
+	c := newFakeContext()
+
+	// Outside a delay.Func should return an error.
+	headers, err := RequestHeaders(c.ctx)
+	if headers != nil {
+		t.Errorf("RequestHeaders outside Func, got %v, want nil", headers)
+	}
+	if err != errOutsideDelayFunc {
+		t.Errorf("RequestHeaders outside Func err, got %v, want %v", err, errOutsideDelayFunc)
+	}
+
+	// Fake out the adding of a task.
+	var task *taskqueue.Task
+	taskqueueAdder = func(_ context.Context, tk *taskqueue.Task, queue string) (*taskqueue.Task, error) {
+		if queue != "" {
+			t.Errorf(`Got queue %q, expected ""`, queue)
+		}
+		task = tk
+		return tk, nil
+	}
+
+	reqFunc.Call(c.ctx)
+
+	reqFuncRuns, reqFuncHeaders = 0, nil // reset state
+	// Simulate the Task Queue service.
+	req, err := http.NewRequest("POST", path, bytes.NewBuffer(task.Payload))
+	req.Header.Set("x-appengine-taskname", "foobar")
+	if err != nil {
+		t.Fatalf("Failed making http.Request: %v", err)
+	}
+	rw := httptest.NewRecorder()
+	runFunc(c.ctx, rw, req)
+
+	if reqFuncRuns != 1 {
+		t.Errorf("reqFuncRuns: got %d, want 1", reqFuncRuns)
+	}
+	if reqFuncHeaders.TaskName != "foobar" {
+		t.Errorf("reqFuncHeaders.TaskName: got %v, want 'foobar'", reqFuncHeaders.TaskName)
+	}
+	if reqFuncErr != nil {
+		t.Errorf("reqFuncErr: got %v, want nil", reqFuncErr)
+	}
+}
diff --git a/taskqueue/taskqueue.go b/taskqueue/taskqueue.go
index 5fed9b3..965c5ab 100644
--- a/taskqueue/taskqueue.go
+++ b/taskqueue/taskqueue.go
@@ -21,6 +21,7 @@
 	"fmt"
 	"net/http"
 	"net/url"
+	"strconv"
 	"time"
 
 	"github.com/golang/protobuf/proto"
@@ -147,6 +148,48 @@
 	}
 }
 
+// RequestHeaders are the special HTTP request headers available to push task
+// HTTP request handlers. These headers are set internally by App Engine.
+// See https://cloud.google.com/appengine/docs/standard/go/taskqueue/push/creating-handlers#reading_request_headers
+// for a description of the fields.
+type RequestHeaders struct {
+	QueueName          string
+	TaskName           string
+	TaskRetryCount     int64
+	TaskExecutionCount int64
+	TaskETA            time.Time
+
+	TaskPreviousResponse int
+	TaskRetryReason      string
+	FailFast             bool
+}
+
+// ParseRequestHeaders parses the special HTTP request headers available to push
+// task request handlers. This function silently ignores values of the wrong
+// format.
+func ParseRequestHeaders(h http.Header) *RequestHeaders {
+	ret := &RequestHeaders{
+		QueueName: h.Get("X-AppEngine-QueueName"),
+		TaskName:  h.Get("X-AppEngine-TaskName"),
+	}
+
+	ret.TaskRetryCount, _ = strconv.ParseInt(h.Get("X-AppEngine-TaskRetryCount"), 10, 64)
+	ret.TaskExecutionCount, _ = strconv.ParseInt(h.Get("X-AppEngine-TaskExecutionCount"), 10, 64)
+
+	etaSecs, _ := strconv.ParseInt(h.Get("X-AppEngine-TaskETA"), 10, 64)
+	if etaSecs != 0 {
+		ret.TaskETA = time.Unix(etaSecs, 0)
+	}
+
+	ret.TaskPreviousResponse, _ = strconv.Atoi(h.Get("X-AppEngine-TaskPreviousResponse"))
+	ret.TaskRetryReason = h.Get("X-AppEngine-TaskRetryReason")
+	if h.Get("X-AppEngine-FailFast") != "" {
+		ret.FailFast = true
+	}
+
+	return ret
+}
+
 var (
 	currentNamespace = http.CanonicalHeaderKey("X-AppEngine-Current-Namespace")
 	defaultNamespace = http.CanonicalHeaderKey("X-AppEngine-Default-Namespace")
diff --git a/taskqueue/taskqueue_test.go b/taskqueue/taskqueue_test.go
index 0c14015..d9eec50 100644
--- a/taskqueue/taskqueue_test.go
+++ b/taskqueue/taskqueue_test.go
@@ -7,8 +7,10 @@
 import (
 	"errors"
 	"fmt"
+	"net/http"
 	"reflect"
 	"testing"
+	"time"
 
 	"google.golang.org/appengine"
 	"google.golang.org/appengine/internal"
@@ -114,3 +116,58 @@
 		t.Fatalf("Add: %v", err)
 	}
 }
+
+func TestParseRequestHeaders(t *testing.T) {
+	tests := []struct {
+		Header http.Header
+		Want   RequestHeaders
+	}{
+		{
+			Header: map[string][]string{
+				"X-Appengine-Queuename":            []string{"foo"},
+				"X-Appengine-Taskname":             []string{"bar"},
+				"X-Appengine-Taskretrycount":       []string{"4294967297"}, // 2^32 + 1
+				"X-Appengine-Taskexecutioncount":   []string{"4294967298"}, // 2^32 + 2
+				"X-Appengine-Tasketa":              []string{"1500000000"},
+				"X-Appengine-Taskpreviousresponse": []string{"404"},
+				"X-Appengine-Taskretryreason":      []string{"baz"},
+				"X-Appengine-Failfast":             []string{"yes"},
+			},
+			Want: RequestHeaders{
+				QueueName:            "foo",
+				TaskName:             "bar",
+				TaskRetryCount:       4294967297,
+				TaskExecutionCount:   4294967298,
+				TaskETA:              time.Date(2017, time.July, 14, 2, 40, 0, 0, time.UTC),
+				TaskPreviousResponse: 404,
+				TaskRetryReason:      "baz",
+				FailFast:             true,
+			},
+		},
+		{
+			Header: map[string][]string{},
+			Want: RequestHeaders{
+				QueueName:            "",
+				TaskName:             "",
+				TaskRetryCount:       0,
+				TaskExecutionCount:   0,
+				TaskETA:              time.Time{},
+				TaskPreviousResponse: 0,
+				TaskRetryReason:      "",
+				FailFast:             false,
+			},
+		},
+	}
+
+	for idx, test := range tests {
+		got := *ParseRequestHeaders(test.Header)
+		if got.TaskETA.UnixNano() != test.Want.TaskETA.UnixNano() {
+			t.Errorf("%d. ParseRequestHeaders got TaskETA %v, wanted %v", idx, got.TaskETA, test.Want.TaskETA)
+		}
+		got.TaskETA = time.Time{}
+		test.Want.TaskETA = time.Time{}
+		if !reflect.DeepEqual(got, test.Want) {
+			t.Errorf("%d. ParseRequestHeaders got %v, wanted %v", idx, got, test.Want)
+		}
+	}
+}