| // Copyright 2019 The Fuchsia Authors. All rights reserved. |
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
| |
| package main |
| |
| import ( |
| "context" |
| "encoding/json" |
| "flag" |
| "fmt" |
| "io" |
| "io/ioutil" |
| "log" |
| "net/http" |
| "os" |
| "os/signal" |
| "strings" |
| "syscall" |
| "time" |
| |
| "golang.org/x/oauth2" |
| "golang.org/x/oauth2/google" |
| "golang.org/x/time/rate" |
| |
| "go.fuchsia.dev/fuchsia/tools/lib/retry" |
| ) |
| |
const (
	// OAuth2 scope for reading GCS.
	// See https://cloud.google.com/storage/docs/authentication.
	gcsReadOnlyScope = "https://www.googleapis.com/auth/devstorage.read_only"

	// See https://cloud.google.com/storage/docs/request-endpoints.
	gcsHost = "storage.googleapis.com"

	// We want to allow at most `maxNumRequests` requests per IP to be serviced
	// every `refreshWindowMs` milliseconds.
	// In terms of the token bucket underlying rate.Limiter, each request is
	// budgeted `tokenBudgetPerRequest` tokens, so the bucket holds
	// `tokenBurstSize` tokens and one new token is refreshed every
	// `tokenRefreshRate`. The spare budget lets rate-limiter checks for one
	// request proceed without consuming tokens reserved for servicing other
	// requests.
	// See https://godoc.org/golang.org/x/time/rate#Limiter for more details.
	maxNumRequests        = 20
	refreshWindowMs       = 200
	tokenBudgetPerRequest = 2
	// Derived from the per-request budget so the two values cannot drift apart.
	tokenBurstSize = tokenBudgetPerRequest * maxNumRequests
	// Scale to a time.Duration before dividing. Dividing the untyped
	// millisecond count first would truncate to whole milliseconds (e.g.
	// 200/60 = 3ms instead of ~3.33ms) whenever the window is not an exact
	// multiple of the burst size.
	tokenRefreshRate = refreshWindowMs * time.Millisecond / tokenBurstSize

	// Constants for retrying communication with GCS.
	retryBackoff  = 100 * time.Millisecond
	retryAttempts = 10
)
| |
// credentialsFile holds the value of the -credentials flag.
var credentialsFile string

// port holds the value of the -port flag.
var port string

// allowedAddrsFile holds the value of the -allowed flag.
var allowedAddrsFile string
| |
// usage prints a short description of gcsproxy followed by the defaults and
// help text of every registered flag. It is installed as flag.Usage in init
// and, like the flag package's own usage function, writes to
// flag.CommandLine.Output() (stderr unless redirected).
func usage() {
	fmt.Fprint(flag.CommandLine.Output(), `gcsproxy [flags]

Starts a proxy server that forwards requests to GCS with authentication.
`)
	flag.PrintDefaults()
}
| |
| func init() { |
| flag.Usage = usage |
| flag.StringVar(&credentialsFile, "credentials", "", "path to a credentials file in the Google Credentials File format; if none provided, default application credentials will be used.") |
| flag.StringVar(&port, "port", "", "port at which the server should listen.") |
| flag.StringVar(&allowedAddrsFile, "allowed", "", "a flat JSON list of remote addresses allowed to make requests of the proxy server") |
| } |
| |
| func main() { |
| flag.Parse() |
| |
| // For a graceful teardown in the event of a canceling signal. |
| ctx, cancel := context.WithCancel(context.Background()) |
| signals := make(chan os.Signal) |
| defer close(signals) |
| signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT) |
| |
| go func() { |
| select { |
| case <-signals: |
| cancel() |
| } |
| }() |
| |
| if err := execute(ctx, credentialsFile, port, allowedAddrsFile); err != nil { |
| log.Fatal(err) |
| } |
| } |
| |
| // Execute starts the proxy server. |
| func execute(ctx context.Context, credFile, port, addrsFile string) error { |
| if port == "" { |
| return fmt.Errorf("-port is required") |
| } else if allowedAddrsFile == "" { |
| return fmt.Errorf("-allowed is required") |
| } |
| |
| credentials, err := findCredentials(ctx, credFile) |
| if err != nil { |
| return err |
| } |
| |
| var client *http.Client |
| if credentials == nil { |
| client = http.DefaultClient |
| } else { |
| client = oauth2.NewClient(ctx, credentials.TokenSource) |
| } |
| |
| var allowedAddrs []string |
| b, err := ioutil.ReadFile(addrsFile) |
| if err != nil { |
| return err |
| } |
| if err := json.Unmarshal(b, &allowedAddrs); err != nil { |
| return err |
| } |
| |
| limiters := make(map[string]*rate.Limiter) |
| for _, addr := range allowedAddrs { |
| limiters[addr] = newLimiter() |
| } |
| |
| redirect := redirectHandler{ |
| client: client, |
| limiters: limiters, |
| } |
| |
| mux := http.NewServeMux() |
| mux.Handle("/", redirect) |
| s := http.Server{ |
| Addr: fmt.Sprintf(":%s", port), |
| Handler: mux, |
| } |
| |
| errs := make(chan error) |
| defer close(errs) |
| go func() { |
| log.Printf("starting a GCS proxy server at localhost:%s", port) |
| errs <- s.ListenAndServe() |
| }() |
| |
| shutdown := func() error { |
| log.Printf("Shutting down GCS proxy server at localhost:%s", port) |
| return s.Shutdown(ctx) |
| } |
| |
| select { |
| case <-ctx.Done(): |
| return shutdown() |
| case err := <-errs: |
| shutdown() |
| return err |
| } |
| } |
| |
| func newLimiter() *rate.Limiter { |
| limit := rate.Every(tokenRefreshRate) |
| return rate.NewLimiter(limit, tokenBurstSize) |
| } |
| |
| // RedirectHandler is a simple handler that redirects requests to GCS. |
| type redirectHandler struct { |
| client *http.Client |
| limiters map[string]*rate.Limiter |
| } |
| |
| func (h redirectHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { |
| if req == nil { |
| http.Error(w, "cannot handle nil request", http.StatusInternalServerError) // = 501 |
| return |
| } |
| ctx := context.Background() |
| |
| limiter, ok := h.getLimiter(req.RemoteAddr) |
| if !ok { |
| http.Error( |
| w, |
| fmt.Sprintf("address %q is not authorized to make requests", req.RemoteAddr), |
| http.StatusForbidden, // = 403 |
| ) |
| return |
| } |
| |
| if !limiter.Allow() { |
| if err := limiter.Wait(ctx); err != nil { |
| http.Error(w, fmt.Sprintf("rate-limiting error: %v", err), http.StatusInternalServerError) |
| } |
| } |
| |
| req.Host = gcsHost |
| req.URL.Host = gcsHost |
| req.URL.Scheme = "https" |
| // It is an error to set this field in an HTTP client request |
| // See https://golang.org/pkg/net/http/#Request. |
| req.RequestURI = "" |
| |
| var resp *http.Response |
| backoff := retry.WithMaxRetries(retry.NewConstantBackoff(retryBackoff), retryAttempts) |
| err := retry.Retry(ctx, backoff, func() error { |
| var err error |
| resp, err = h.client.Do(req) |
| return err |
| }, nil) |
| if err != nil { |
| http.Error(w, err.Error(), http.StatusInternalServerError) |
| return |
| } else if resp == nil { |
| http.Error(w, "received a nil response", http.StatusInternalServerError) |
| } |
| |
| for k, v := range resp.Header { |
| for _, s := range v { |
| w.Header().Add(k, s) |
| } |
| } |
| |
| if _, err := io.Copy(w, resp.Body); err != nil { |
| http.Error(w, err.Error(), http.StatusInternalServerError) |
| return |
| } |
| } |
| |
| // GetLimiter returns the limiter associated with a given address and whether |
| // the address is allowed to make requests. |
| func (h *redirectHandler) getLimiter(addr string) (*rate.Limiter, bool) { |
| var limiter *rate.Limiter |
| var ok bool |
| if limiter, ok = h.limiters[addr]; !ok { |
| hostname := strings.Split(addr, ":")[0] |
| // While the full address may have been in the `-allowed` list, its |
| // hostname may be; in that case, dynamically add a limiter for that address. |
| if _, ok = h.limiters[hostname]; ok { |
| limiter = newLimiter() |
| h.limiters[addr] = limiter |
| } |
| } |
| return limiter, ok |
| } |
| |
| // FindCredentials returns the credentials in a provided file, or the default application credentials. |
| func findCredentials(ctx context.Context, credFile string) (*google.Credentials, error) { |
| if credFile == "" { |
| return nil, nil |
| } |
| |
| contents, err := ioutil.ReadFile(credFile) |
| if err != nil { |
| return nil, err |
| } |
| return google.CredentialsFromJSON(ctx, contents, gcsReadOnlyScope) |
| } |
| |
// isIn reports whether s is equal to at least one element of l. A nil or
// empty l never contains s.
func isIn(s string, l []string) bool {
	found := false
	for i := 0; i < len(l) && !found; i++ {
		found = l[i] == s
	}
	return found
}