| /* |
| Copyright 2013 Google Inc. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| */ |
| |
| package groupcache |
| |
| import ( |
| "fmt" |
| "hash/crc32" |
| "io/ioutil" |
| "net/http" |
| "net/url" |
| "strings" |
| "sync" |
| |
| "code.google.com/p/goprotobuf/proto" |
| |
| pb "github.com/golang/groupcache/groupcachepb" |
| ) |
| |
// defaultBasePath is the URL path prefix, with both a leading and a
// trailing slash, under which NewHTTPPool registers the pool's handler.
// TODO: make this configurable?
const (
	defaultBasePath = "/_groupcache/"
)
| |
| // HTTPPool implements PeerPicker for a pool of HTTP peers. |
| type HTTPPool struct { |
| // Context optionally specifies a context for the server to use when it |
| // receives a request. |
| // If nil, the server uses a nil Context. |
| Context func(*http.Request) Context |
| |
| // Transport optionally specifies an http.RoundTripper for the client |
| // to use when it makes a request. |
| // If nil, the client uses http.DefaultTransport. |
| Transport func(Context) http.RoundTripper |
| |
| // base path including leading and trailing slash, e.g. "/_groupcache/" |
| basePath string |
| |
| // this peer's base URL, e.g. "https://example.net:8000" |
| self string |
| |
| mu sync.Mutex |
| peers []string |
| } |
| |
| var httpPoolMade bool |
| |
| // NewHTTPPool initializes an HTTP pool of peers. |
| // It registers itself as a PeerPicker and as an HTTP handler with the |
| // http.DefaultServeMux. |
| // The self argument be a valid base URL that points to the current server, |
| // for example "http://example.net:8000". |
| func NewHTTPPool(self string) *HTTPPool { |
| if httpPoolMade { |
| panic("groupcache: NewHTTPPool must be called only once") |
| } |
| httpPoolMade = true |
| p := &HTTPPool{basePath: defaultBasePath, self: self} |
| RegisterPeerPicker(func() PeerPicker { return p }) |
| http.Handle(defaultBasePath, p) |
| return p |
| } |
| |
| // Set updates the pool's list of peers. |
| // Each peer value should be a valid base URL, |
| // for example "http://example.net:8000". |
| func (p *HTTPPool) Set(peers ...string) { |
| p.mu.Lock() |
| defer p.mu.Unlock() |
| p.peers = append([]string{}, peers...) |
| } |
| |
| func (p *HTTPPool) PickPeer(key string) (ProtoGetter, bool) { |
| // TODO: make checksum implementation pluggable |
| h := crc32.Checksum([]byte(key), crc32.IEEETable) |
| p.mu.Lock() |
| defer p.mu.Unlock() |
| if len(p.peers) == 0 { |
| return nil, false |
| } |
| n := int(h) |
| if n < 0 { |
| n *= -1 |
| } |
| if peer := p.peers[n%len(p.peers)]; peer != p.self { |
| // TODO: pre-build a slice of *httpGetter when Set() |
| // is called to avoid these two allocations. |
| return &httpGetter{p.Transport, peer + p.basePath}, true |
| } |
| return nil, false |
| } |
| |
| func (p *HTTPPool) ServeHTTP(w http.ResponseWriter, r *http.Request) { |
| // Parse request. |
| if !strings.HasPrefix(r.URL.Path, p.basePath) { |
| panic("HTTPPool serving unexpected path: " + r.URL.Path) |
| } |
| parts := strings.SplitN(r.URL.Path[len(p.basePath):], "/", 2) |
| if len(parts) != 2 { |
| http.Error(w, "bad request", http.StatusBadRequest) |
| return |
| } |
| groupName, err := url.QueryUnescape(parts[0]) |
| if err != nil { |
| http.Error(w, "decoding group: "+err.Error(), http.StatusBadRequest) |
| return |
| } |
| key, err := url.QueryUnescape(parts[1]) |
| if err != nil { |
| http.Error(w, "decoding key: "+err.Error(), http.StatusBadRequest) |
| return |
| } |
| |
| // Fetch the value for this group/key. |
| group := GetGroup(groupName) |
| if group == nil { |
| http.Error(w, "no such group: "+groupName, http.StatusNotFound) |
| return |
| } |
| var ctx Context |
| if p.Context != nil { |
| ctx = p.Context(r) |
| } |
| var value []byte |
| err = group.Get(ctx, key, AllocatingByteSliceSink(&value)) |
| if err != nil { |
| http.Error(w, err.Error(), http.StatusInternalServerError) |
| return |
| } |
| |
| // Write the value to the response body as a proto message. |
| body, err := proto.Marshal(&pb.GetResponse{Value: value}) |
| if err != nil { |
| http.Error(w, err.Error(), http.StatusInternalServerError) |
| return |
| } |
| w.Header().Set("Content-Type", "application/x-protobuf") |
| w.Write(body) |
| } |
| |
| type httpGetter struct { |
| transport func(Context) http.RoundTripper |
| baseURL string |
| } |
| |
| func (h *httpGetter) Get(context Context, in *pb.GetRequest, out *pb.GetResponse) error { |
| u := fmt.Sprintf( |
| "%v%v/%v", |
| h.baseURL, |
| url.QueryEscape(in.GetGroup()), |
| url.QueryEscape(in.GetKey()), |
| ) |
| req, err := http.NewRequest("GET", u, nil) |
| if err != nil { |
| return err |
| } |
| tr := http.DefaultTransport |
| if h.transport != nil { |
| tr = h.transport(context) |
| } |
| res, err := tr.RoundTrip(req) |
| if err != nil { |
| return err |
| } |
| if res.StatusCode != http.StatusOK { |
| return fmt.Errorf("server returned: %v", res.Status) |
| } |
| defer res.Body.Close() |
| // TODO: avoid this garbage. |
| b, err := ioutil.ReadAll(res.Body) |
| if err != nil { |
| return fmt.Errorf("reading response body: %v", err) |
| } |
| err = proto.Unmarshal(b, out) |
| if err != nil { |
| return fmt.Errorf("decoding response body: %v", err) |
| } |
| return nil |
| } |