acme: reduce the number of network round trips

Before this change, every JWS-signed request was preceded
by a HEAD request to fetch a fresh nonce.

The Client is now able to collect nonce values
from server responses and use them for future requests.
Additionally, this change also makes sure the client propagates
any error encountered during a fresh nonce fetch.

Fixes golang/go#18428.

Change-Id: I33d21b450351cf4d98e72ee6c8fa654e9554bf92
Reviewed-on: https://go-review.googlesource.com/36514
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
diff --git a/acme/acme.go b/acme/acme.go
index 8aafada..8619508 100644
--- a/acme/acme.go
+++ b/acme/acme.go
@@ -47,6 +47,10 @@
 const (
 	maxChainLen = 5       // max depth and breadth of a certificate chain
 	maxCertSize = 1 << 20 // max size of a certificate, in bytes
+
+	// Max number of collected nonces kept in memory.
+	// Expect usual peak of 1 or 2.
+	maxNonces = 100
 )
 
 // CertOption is an optional argument type for Client methods which manipulate
@@ -108,6 +112,9 @@
 
 	dirMu sync.Mutex // guards writes to dir
 	dir   *Directory // cached result of Client's Discover method
+
+	noncesMu sync.Mutex
+	nonces   map[string]struct{} // nonces collected from previous responses
 }
 
 // Discover performs ACME server discovery using c.DirectoryURL.
@@ -131,6 +138,7 @@
 		return Directory{}, err
 	}
 	defer res.Body.Close()
+	c.addNonce(res.Header)
 	if res.StatusCode != http.StatusOK {
 		return Directory{}, responseError(res)
 	}
@@ -192,7 +200,7 @@
 		req.NotAfter = now.Add(exp).Format(time.RFC3339)
 	}
 
-	res, err := postJWS(ctx, c.HTTPClient, c.Key, c.dir.CertURL, req)
+	res, err := c.postJWS(ctx, c.Key, c.dir.CertURL, req)
 	if err != nil {
 		return nil, "", err
 	}
@@ -267,7 +275,7 @@
 	if key == nil {
 		key = c.Key
 	}
-	res, err := postJWS(ctx, c.HTTPClient, key, c.dir.RevokeURL, body)
+	res, err := c.postJWS(ctx, key, c.dir.RevokeURL, body)
 	if err != nil {
 		return err
 	}
@@ -355,7 +363,7 @@
 		Resource:   "new-authz",
 		Identifier: authzID{Type: "dns", Value: domain},
 	}
-	res, err := postJWS(ctx, c.HTTPClient, c.Key, c.dir.AuthzURL, req)
+	res, err := c.postJWS(ctx, c.Key, c.dir.AuthzURL, req)
 	if err != nil {
 		return nil, err
 	}
@@ -413,7 +421,7 @@
 		Status:   "deactivated",
 		Delete:   true,
 	}
-	res, err := postJWS(ctx, c.HTTPClient, c.Key, url, req)
+	res, err := c.postJWS(ctx, c.Key, url, req)
 	if err != nil {
 		return err
 	}
@@ -519,7 +527,7 @@
 		Type:     chal.Type,
 		Auth:     auth,
 	}
-	res, err := postJWS(ctx, c.HTTPClient, c.Key, chal.URI, req)
+	res, err := c.postJWS(ctx, c.Key, chal.URI, req)
 	if err != nil {
 		return nil, err
 	}
@@ -652,7 +660,7 @@
 		req.Contact = acct.Contact
 		req.Agreement = acct.AgreedTerms
 	}
-	res, err := postJWS(ctx, c.HTTPClient, c.Key, url, req)
+	res, err := c.postJWS(ctx, c.Key, url, req)
 	if err != nil {
 		return nil, err
 	}
@@ -689,6 +697,78 @@
 	}, nil
 }
 
+// postJWS signs the body with the given key and POSTs it to the provided url.
+// The body argument must be JSON-serializable.
+func (c *Client) postJWS(ctx context.Context, key crypto.Signer, url string, body interface{}) (*http.Response, error) {
+	nonce, err := c.popNonce(ctx, url)
+	if err != nil {
+		return nil, err
+	}
+	b, err := jwsEncodeJSON(body, key, nonce)
+	if err != nil {
+		return nil, err
+	}
+	res, err := ctxhttp.Post(ctx, c.HTTPClient, url, "application/jose+json", bytes.NewReader(b))
+	if err != nil {
+		return nil, err
+	}
+	c.addNonce(res.Header)
+	return res, nil
+}
+
+// popNonce returns a nonce value previously stored with c.addNonce
+// or fetches a fresh one from the given URL.
+func (c *Client) popNonce(ctx context.Context, url string) (string, error) {
+	c.noncesMu.Lock()
+	defer c.noncesMu.Unlock()
+	if len(c.nonces) == 0 {
+		return fetchNonce(ctx, c.HTTPClient, url)
+	}
+	var nonce string
+	for nonce = range c.nonces {
+		delete(c.nonces, nonce)
+		break
+	}
+	return nonce, nil
+}
+
+// addNonce stores a nonce value found in h (if any) for future use.
+func (c *Client) addNonce(h http.Header) {
+	v := nonceFromHeader(h)
+	if v == "" {
+		return
+	}
+	c.noncesMu.Lock()
+	defer c.noncesMu.Unlock()
+	if len(c.nonces) >= maxNonces {
+		return
+	}
+	if c.nonces == nil {
+		c.nonces = make(map[string]struct{})
+	}
+	c.nonces[v] = struct{}{}
+}
+
+func fetchNonce(ctx context.Context, client *http.Client, url string) (string, error) {
+	resp, err := ctxhttp.Head(ctx, client, url)
+	if err != nil {
+		return "", err
+	}
+	defer resp.Body.Close()
+	nonce := nonceFromHeader(resp.Header)
+	if nonce == "" {
+		if resp.StatusCode > 299 {
+			return "", responseError(resp)
+		}
+		return "", errors.New("acme: nonce not found")
+	}
+	return nonce, nil
+}
+
+func nonceFromHeader(h http.Header) string {
+	return h.Get("Replay-Nonce")
+}
+
 func responseCert(ctx context.Context, client *http.Client, res *http.Response, bundle bool) ([][]byte, error) {
 	b, err := ioutil.ReadAll(io.LimitReader(res.Body, maxCertSize+1))
 	if err != nil {
@@ -793,33 +873,6 @@
 	return chain, nil
 }
 
-// postJWS signs the body with the given key and POSTs it to the provided url.
-// The body argument must be JSON-serializable.
-func postJWS(ctx context.Context, client *http.Client, key crypto.Signer, url string, body interface{}) (*http.Response, error) {
-	nonce, err := fetchNonce(ctx, client, url)
-	if err != nil {
-		return nil, err
-	}
-	b, err := jwsEncodeJSON(body, key, nonce)
-	if err != nil {
-		return nil, err
-	}
-	return ctxhttp.Post(ctx, client, url, "application/jose+json", bytes.NewReader(b))
-}
-
-func fetchNonce(ctx context.Context, client *http.Client, url string) (string, error) {
-	resp, err := ctxhttp.Head(ctx, client, url)
-	if err != nil {
-		return "", nil
-	}
-	defer resp.Body.Close()
-	enc := resp.Header.Get("replay-nonce")
-	if enc == "" {
-		return "", errors.New("acme: nonce not found")
-	}
-	return enc, nil
-}
-
 // linkHeader returns URI-Reference values of all Link headers
 // with relation-type rel.
 // See https://tools.ietf.org/html/rfc5988#section-5 for details.
diff --git a/acme/acme_test.go b/acme/acme_test.go
index 4e618f2..1205dbb 100644
--- a/acme/acme_test.go
+++ b/acme/acme_test.go
@@ -45,6 +45,28 @@
 	}
 }
 
+type jwsHead struct {
+	Alg   string
+	Nonce string
+	JWK   map[string]string `json:"jwk"`
+}
+
+func decodeJWSHead(r *http.Request) (*jwsHead, error) {
+	var req struct{ Protected string }
+	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+		return nil, err
+	}
+	b, err := base64.RawURLEncoding.DecodeString(req.Protected)
+	if err != nil {
+		return nil, err
+	}
+	var head jwsHead
+	if err := json.Unmarshal(b, &head); err != nil {
+		return nil, err
+	}
+	return &head, nil
+}
+
 func TestDiscover(t *testing.T) {
 	const (
 		reg    = "https://example.com/acme/new-reg"
@@ -916,7 +938,30 @@
 	}
 }
 
-func TestFetchNonce(t *testing.T) {
+func TestNonce_add(t *testing.T) {
+	var c Client
+	c.addNonce(http.Header{"Replay-Nonce": {"nonce"}})
+	c.addNonce(http.Header{"Replay-Nonce": {}})
+	c.addNonce(http.Header{"Replay-Nonce": {"nonce"}})
+
+	nonces := map[string]struct{}{"nonce": struct{}{}}
+	if !reflect.DeepEqual(c.nonces, nonces) {
+		t.Errorf("c.nonces = %q; want %q", c.nonces, nonces)
+	}
+}
+
+func TestNonce_addMax(t *testing.T) {
+	c := &Client{nonces: make(map[string]struct{})}
+	for i := 0; i < maxNonces; i++ {
+		c.nonces[fmt.Sprintf("%d", i)] = struct{}{}
+	}
+	c.addNonce(http.Header{"Replay-Nonce": {"nonce"}})
+	if n := len(c.nonces); n != maxNonces {
+		t.Errorf("len(c.nonces) = %d; want %d", n, maxNonces)
+	}
+}
+
+func TestNonce_fetch(t *testing.T) {
 	tests := []struct {
 		code  int
 		nonce string
@@ -949,6 +994,76 @@
 	}
 }
 
+func TestNonce_fetchError(t *testing.T) {
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusTooManyRequests)
+	}))
+	defer ts.Close()
+	_, err := fetchNonce(context.Background(), http.DefaultClient, ts.URL)
+	e, ok := err.(*Error)
+	if !ok {
+		t.Fatalf("err is %T; want *Error", err)
+	}
+	if e.StatusCode != http.StatusTooManyRequests {
+		t.Errorf("e.StatusCode = %d; want %d", e.StatusCode, http.StatusTooManyRequests)
+	}
+}
+
+func TestNonce_postJWS(t *testing.T) {
+	var count int
+	seen := make(map[string]bool)
+	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		count++
+		w.Header().Set("replay-nonce", fmt.Sprintf("nonce%d", count))
+		if r.Method == "HEAD" {
+			// We expect the client do a HEAD request
+			// but only to fetch the first nonce.
+			return
+		}
+		// Make client.Authorize happy; we're not testing its result.
+		defer func() {
+			w.WriteHeader(http.StatusCreated)
+			w.Write([]byte(`{"status":"valid"}`))
+		}()
+
+		head, err := decodeJWSHead(r)
+		if err != nil {
+			t.Errorf("decodeJWSHead: %v", err)
+			return
+		}
+		if head.Nonce == "" {
+			t.Error("head.Nonce is empty")
+			return
+		}
+		if seen[head.Nonce] {
+			t.Errorf("nonce is already used: %q", head.Nonce)
+		}
+		seen[head.Nonce] = true
+	}))
+	defer ts.Close()
+
+	client := Client{Key: testKey, dir: &Directory{AuthzURL: ts.URL}}
+	if _, err := client.Authorize(context.Background(), "example.com"); err != nil {
+		t.Errorf("client.Authorize 1: %v", err)
+	}
+	// The second call should not generate another extra HEAD request.
+	if _, err := client.Authorize(context.Background(), "example.com"); err != nil {
+		t.Errorf("client.Authorize 2: %v", err)
+	}
+
+	if count != 3 {
+		t.Errorf("total requests count: %d; want 3", count)
+	}
+	if n := len(client.nonces); n != 1 {
+		t.Errorf("len(client.nonces) = %d; want 1", n)
+	}
+	for k := range seen {
+		if _, exist := client.nonces[k]; exist {
+			t.Errorf("used nonce %q in client.nonces", k)
+		}
+	}
+}
+
 func TestLinkHeader(t *testing.T) {
 	h := http.Header{"Link": {
 		`<https://example.com/acme/new-authz>;rel="next"`,