add HTTP pool implementation
diff --git a/http.go b/http.go
new file mode 100644
index 0000000..db3ac59
--- /dev/null
+++ b/http.go
@@ -0,0 +1,187 @@
+/*
+Copyright 2013 Google Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package groupcache
+
+import (
+ "fmt"
+ "hash/crc32"
+ "io/ioutil"
+ "net/http"
+ "net/url"
+ "strings"
+ "sync"
+
+ "code.google.com/p/goprotobuf/proto"
+
+ pb "github.com/golang/groupcache/groupcachepb"
+)
+
+// TODO: make this configurable?
+const defaultBasePath = "/_groupcache/"
+
+// HTTPPool implements PeerPicker for a pool of HTTP peers.
+type HTTPPool struct {
+ // Context optionally specifies a context for the server to use when it
+ // receives a request.
+ // If nil, the server uses a nil Context.
+ Context func(*http.Request) Context
+
+ // Transport optionally specifies an http.RoundTripper for the client
+ // to use when it makes a request.
+ // If nil, the client uses http.DefaultTransport.
+ Transport func(Context) http.RoundTripper
+
+ // base path including leading and trailing slash, e.g. "/_groupcache/"
+ basePath string
+
+ // this peer's base URL, e.g. "https://example.net:8000"
+ self string
+
+ mu sync.Mutex
+ peers []string
+}
+
+var httpPoolMade bool
+
+// NewHTTPPool initializes an HTTP pool of peers.
+// It registers itself as a PeerPicker and as an HTTP handler with the
+// http.DefaultServeMux.
+// The self argument be a valid base URL that points to the current server,
+// for example "http://example.net:8000".
+func NewHTTPPool(self string) *HTTPPool {
+ if httpPoolMade {
+ panic("groupcache: NewHTTPPool must be called only once")
+ }
+ httpPoolMade = true
+ p := &HTTPPool{basePath: defaultBasePath, self: self}
+ RegisterPeerPicker(func() PeerPicker { return p })
+ http.Handle(defaultBasePath, p)
+ return p
+}
+
+// Set updates the pool's list of peers.
+// Each peer value should be a valid base URL,
+// for example "http://example.net:8000".
+func (p *HTTPPool) Set(peers ...string) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ p.peers = append([]string{}, peers...)
+}
+
+func (p *HTTPPool) PickPeer(key string) (ProtoGetter, bool) {
+ // TODO: make checksum implementation pluggable
+ h := crc32.Checksum([]byte(key), crc32.IEEETable)
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ if len(p.peers) == 0 {
+ return nil, false
+ }
+ if peer := p.peers[int(h)%len(p.peers)]; peer != p.self {
+ // TODO: pre-build a slice of *httpGetter when Set()
+ // is called to avoid these two allocations.
+ return &httpGetter{p.Transport, peer + p.basePath}, true
+ }
+ return nil, false
+}
+
+func (p *HTTPPool) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ // Parse request.
+ if !strings.HasPrefix(r.URL.Path, p.basePath) {
+ panic("HTTPPool serving unexpected path: " + r.URL.Path)
+ }
+ parts := strings.SplitN(r.URL.Path[len(p.basePath):], "/", 2)
+ if len(parts) != 2 {
+ http.Error(w, "bad request", http.StatusBadRequest)
+ return
+ }
+ groupName, err := url.QueryUnescape(parts[0])
+ if err != nil {
+ http.Error(w, "decoding group: "+err.Error(), http.StatusBadRequest)
+ return
+ }
+ key, err := url.QueryUnescape(parts[1])
+ if err != nil {
+ http.Error(w, "decoding key: "+err.Error(), http.StatusBadRequest)
+ return
+ }
+
+ // Fetch the value for this group/key.
+ group := GetGroup(groupName)
+ if group == nil {
+ http.Error(w, "no such group: "+groupName, http.StatusNotFound)
+ return
+ }
+ var ctx Context
+ if p.Context != nil {
+ ctx = p.Context(r)
+ }
+ var value []byte
+ err = group.Get(ctx, key, AllocatingByteSliceSink(&value))
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+
+ // Write the value to the response body as a proto message.
+ body, err := proto.Marshal(&pb.GetResponse{Value: value})
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ w.Header().Set("Content-Type", "application/x-protobuf")
+ w.Write(body)
+}
+
+type httpGetter struct {
+ transport func(Context) http.RoundTripper
+ baseURL string
+}
+
+func (h *httpGetter) Get(context Context, in *pb.GetRequest, out *pb.GetResponse) error {
+ u := fmt.Sprintf(
+ "%v%v/%v",
+ h.baseURL,
+ url.QueryEscape(in.GetGroup()),
+ url.QueryEscape(in.GetKey()),
+ )
+ req, err := http.NewRequest("GET", u, nil)
+ if err != nil {
+ return err
+ }
+ tr := http.DefaultTransport
+ if h.transport != nil {
+ tr = h.transport(context)
+ }
+ res, err := tr.RoundTrip(req)
+ if err != nil {
+ return err
+ }
+ if res.StatusCode != http.StatusOK {
+ return fmt.Errorf("server returned: %v", res.Status)
+ }
+ defer res.Body.Close()
+ // TODO: avoid this garbage.
+ b, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ return fmt.Errorf("reading response body: %v", err)
+ }
+ err = proto.Unmarshal(b, out)
+ if err != nil {
+ return fmt.Errorf("decoding response body: %v", err)
+ }
+ return nil
+}
diff --git a/http_test.go b/http_test.go
new file mode 100644
index 0000000..279bcbf
--- /dev/null
+++ b/http_test.go
@@ -0,0 +1,166 @@
+/*
+Copyright 2013 Google Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package groupcache
+
+import (
+ "errors"
+ "flag"
+ "log"
+ "net"
+ "net/http"
+ "os"
+ "os/exec"
+ "strconv"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+)
+
+var (
+ peerAddrs = flag.String("test_peer_addrs", "", "Comma-separated list of peer addresses; used by TestHTTPPool")
+ peerIndex = flag.Int("test_peer_index", -1, "Index of which peer this child is; used by TestHTTPPool")
+ peerChild = flag.Bool("test_peer_child", false, "True if running as a child process; used by TestHTTPPool")
+)
+
+func TestHTTPPool(t *testing.T) {
+ if *peerChild {
+ beChildForTestHTTPPool()
+ os.Exit(0)
+ }
+
+ const (
+ nChild = 4
+ nGets = 100
+ )
+
+ var childAddr []string
+ for i := 0; i < nChild; i++ {
+ childAddr = append(childAddr, pickFreeAddr(t))
+ }
+
+ var cmds []*exec.Cmd
+ var wg sync.WaitGroup
+ for i := 0; i < nChild; i++ {
+ cmd := exec.Command(os.Args[0],
+ "--test.run=TestHTTPPool",
+ "--test_peer_child",
+ "--test_peer_addrs="+strings.Join(childAddr, ","),
+ "--test_peer_index="+strconv.Itoa(i),
+ )
+ cmds = append(cmds, cmd)
+ wg.Add(1)
+ if err := cmd.Start(); err != nil {
+ t.Fatal("failed to start child process: ", err)
+ }
+ go awaitAddrReady(t, childAddr[i], &wg)
+ }
+ defer func() {
+ for i := 0; i < nChild; i++ {
+ if cmds[i].Process != nil {
+ cmds[i].Process.Kill()
+ }
+ }
+ }()
+ wg.Wait()
+
+ // Use a dummy self address so that we don't handle gets in-process.
+ p := NewHTTPPool("should-be-ignored")
+ p.Set(addrToURL(childAddr)...)
+
+ // Dummy getter function. Gets should go to children only.
+ // The only time this process will handle a get is when the
+ // children can't be contacted for seome reason.
+ getter := GetterFunc(func(ctx Context, key string, dest Sink) error {
+ return errors.New("parent getter called; something's wrong")
+ })
+ g := NewGroup("httpPoolTest", 1<<20, getter)
+
+ for _, key := range testKeys(nGets) {
+ var value string
+ if err := g.Get(nil, key, StringSink(&value)); err != nil {
+ t.Fatal(err)
+ }
+ if suffix := ":" + key; !strings.HasSuffix(value, suffix) {
+ t.Errorf("Get(%q) = %q, want value ending in %q", key, value, suffix)
+ }
+ t.Logf("Get key=%q, value=%q (peer:key)", key, value)
+ }
+}
+
+func testKeys(n int) (keys []string) {
+ keys = make([]string, n)
+ for i := range keys {
+ keys[i] = strconv.Itoa(i)
+ }
+ return
+}
+
+func beChildForTestHTTPPool() {
+ addrs := strings.Split(*peerAddrs, ",")
+
+ p := NewHTTPPool("http://" + addrs[*peerIndex])
+ p.Set(addrToURL(addrs)...)
+
+ getter := GetterFunc(func(ctx Context, key string, dest Sink) error {
+ dest.SetString(strconv.Itoa(*peerIndex) + ":" + key)
+ return nil
+ })
+ NewGroup("httpPoolTest", 1<<20, getter)
+
+ log.Fatal(http.ListenAndServe(addrs[*peerIndex], p))
+}
+
+// This is racy. Another process could swoop in and steal the port between the
+// call to this function and the next listen call. Should be okay though.
+// The proper way would be to pass the l.File() as ExtraFiles to the child
+// process, and then close your copy once the child starts.
+func pickFreeAddr(t *testing.T) string {
+ l, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer l.Close()
+ return l.Addr().String()
+}
+
+func addrToURL(addr []string) []string {
+ url := make([]string, len(addr))
+ for i := range addr {
+ url[i] = "http://" + addr[i]
+ }
+ return url
+}
+
+func awaitAddrReady(t *testing.T, addr string, wg *sync.WaitGroup) {
+ defer wg.Done()
+ const max = 1 * time.Second
+ tries := 0
+ for {
+ tries++
+ c, err := net.Dial("tcp", addr)
+ if err == nil {
+ c.Close()
+ return
+ }
+ delay := time.Duration(tries) * 25 * time.Millisecond
+ if delay > max {
+ delay = max
+ }
+ time.Sleep(delay)
+ }
+}