| // Copyright 2019 The Fuchsia Authors. All rights reserved. |
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
| |
| package main |
| |
| import ( |
| "context" |
| "encoding/json" |
| "flag" |
| "fmt" |
| "io" |
| "io/ioutil" |
| "log" |
| "net/http" |
| "os" |
| "os/signal" |
| "strings" |
| "sync" |
| "syscall" |
| "time" |
| |
| "golang.org/x/oauth2" |
| "golang.org/x/oauth2/google" |
| "golang.org/x/time/rate" |
| |
| "go.fuchsia.dev/fuchsia/tools/lib/retry" |
| "go.fuchsia.dev/fuchsia/tools/lib/runner" |
| ) |
| |
const (
	// OAuth2 scope for reading GCS.
	// See https://cloud.google.com/storage/docs/authentication.
	gcsReadOnlyScope = "https://www.googleapis.com/auth/devstorage.read_only"

	// See https://cloud.google.com/storage/docs/request-endpoints.
	gcsHost = "storage.googleapis.com"

	// We want to allow at most `maxNumRequests` per IP to be serviced every
	// `refreshWindowMs` milliseconds, budgeting `tokenBudgetPerRequest` tokens
	// for each request.
	// In terms of the token bucket underlying rate.Limiter, this translates to
	// one new token every `tokenRefreshRate`, with a burst size of
	// `tokenBurstSize` (one full window's worth of tokens): this allows for a
	// token pool in which we can check Allow() and Wait() at the desired rates
	// without consuming any tokens reserved for servicing of other requests.
	// See https://godoc.org/golang.org/x/time/rate#Limiter for more details.
	maxNumRequests        = 20
	refreshWindowMs       = 200
	tokenBudgetPerRequest = 2
	// Scale to a Duration before dividing so the interval is not truncated to
	// whole milliseconds: e.g. a 100ms window would otherwise yield 2ms instead
	// of 2.5ms, and any window shorter than the token budget would yield 0,
	// which rate.Every turns into an unlimited rate.
	tokenRefreshRate = refreshWindowMs * time.Millisecond / (tokenBudgetPerRequest * maxNumRequests)
	tokenBurstSize   = tokenBudgetPerRequest * maxNumRequests

	// Constants for retrying communication with GCS.
	retryBackoff  = 100 * time.Millisecond
	retryAttempts = 10
)
| |
// credentialsFile receives the value of the -credentials flag.
var credentialsFile string

// port receives the value of the -port flag.
var port string

// allowedAddrsFile receives the value of the -allowed flag.
var allowedAddrsFile string
| |
// usage prints the gcsproxy help text followed by the defaults of every
// registered flag. Output goes to the flag package's configured destination
// (stderr unless overridden), the same place flag parse errors are reported.
func usage() {
	out := flag.CommandLine.Output()
	fmt.Fprint(out, `gcsproxy [flags] [subcommand]

Starts a proxy server that forwards requests to GCS with authentication.
If positional arguments are provided, they will be run as a subprocess and
the lifetime of the server will be scoped to the lifetime of that process.

Flags:
`)
	// Replacing flag.Usage suppresses the default flag listing, so print it
	// explicitly.
	flag.PrintDefaults()
}
| |
| func init() { |
| flag.Usage = usage |
| flag.StringVar(&credentialsFile, "credentials", "", "path to a credentials file in the Google Credentials File format; if none provided, default application credentials will be used.") |
| flag.StringVar(&port, "port", "", "port at which the server should listen.") |
| flag.StringVar(&allowedAddrsFile, "allowed", "", "a flat JSON list of remote addresses allowed to make requests of the proxy server; if not provided, all addresses will be allowed") |
| } |
| |
| func main() { |
| flag.Parse() |
| |
| // For a graceful teardown in the event of a canceling signal. |
| ctx, cancel := context.WithCancel(context.Background()) |
| signals := make(chan os.Signal) |
| defer close(signals) |
| signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT) |
| |
| go func() { |
| select { |
| case <-signals: |
| cancel() |
| } |
| }() |
| |
| if err := execute(ctx, credentialsFile, port, allowedAddrsFile, flag.Args()); err != nil { |
| log.Fatal(err) |
| } |
| } |
| |
| // Execute starts the proxy server. |
| func execute(ctx context.Context, credFile, port, addrsFile string, subCmd []string) error { |
| if port == "" { |
| return fmt.Errorf("-port is required") |
| } |
| |
| client, err := httpClient(ctx, credFile) |
| if err != nil { |
| return err |
| } |
| |
| var allowedAddrs []string |
| if allowedAddrsFile != "" { |
| b, err := ioutil.ReadFile(allowedAddrsFile) |
| if err != nil { |
| return err |
| } |
| if err := json.Unmarshal(b, &allowedAddrs); err != nil { |
| return err |
| } |
| } |
| |
| limiters := new(sync.Map) |
| for _, addr := range allowedAddrs { |
| limiters.Store(addr, newLimiter()) |
| } |
| |
| redirect := &redirectHandler{ |
| client: client, |
| restrictAddrs: allowedAddrs != nil, |
| limiters: limiters, |
| } |
| |
| mux := http.NewServeMux() |
| mux.Handle("/", redirect) |
| s := http.Server{ |
| Addr: fmt.Sprintf(":%s", port), |
| Handler: mux, |
| } |
| |
| errs := make(chan error) |
| go func() { |
| log.Printf("starting a GCS proxy server at localhost:%s", port) |
| errs <- s.ListenAndServe() |
| }() |
| |
| if len(subCmd) > 0 { |
| r := runner.SubprocessRunner{ |
| Env: os.Environ(), |
| } |
| go func() { |
| errs <- r.Run(ctx, subCmd, os.Stdout, os.Stderr) |
| }() |
| } |
| |
| shutdown := func() error { |
| log.Printf("Shutting down GCS proxy server at localhost:%s", port) |
| return s.Shutdown(ctx) |
| } |
| |
| select { |
| case <-ctx.Done(): |
| return shutdown() |
| case err := <-errs: |
| shutdown() |
| return err |
| } |
| } |
| |
| func newLimiter() *rate.Limiter { |
| limit := rate.Every(tokenRefreshRate) |
| return rate.NewLimiter(limit, tokenBurstSize) |
| } |
| |
// redirectHandler is an http.Handler that forwards each incoming request to
// GCS and relays the response back to the caller.
type redirectHandler struct {
	// client issues the forwarded requests to GCS.
	client *http.Client

	// restrictAddrs, when set, limits service to remote addresses (or their
	// hostnames) that already have an entry in limiters.
	restrictAddrs bool

	// limiters maps an address string to the *rate.Limiter governing it.
	limiters *sync.Map
}
| |
| func (h *redirectHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { |
| if req == nil { |
| http.Error(w, "cannot handle nil request", http.StatusInternalServerError) // = 501 |
| return |
| } |
| ctx := context.Background() |
| |
| limiter, ok := h.getLimiter(req.RemoteAddr) |
| if !ok { |
| http.Error( |
| w, |
| fmt.Sprintf("address %q is not authorized to make requests", req.RemoteAddr), |
| http.StatusForbidden, // = 403 |
| ) |
| return |
| } |
| |
| if !limiter.Allow() { |
| if err := limiter.Wait(ctx); err != nil { |
| http.Error(w, fmt.Sprintf("rate-limiting error: %v", err), http.StatusInternalServerError) |
| } |
| } |
| |
| req.Host = gcsHost |
| req.URL.Host = gcsHost |
| req.URL.Scheme = "https" |
| // It is an error to set this field in an HTTP client request |
| // See https://golang.org/pkg/net/http/#Request. |
| req.RequestURI = "" |
| |
| var resp *http.Response |
| backoff := retry.WithMaxAttempts(retry.NewConstantBackoff(retryBackoff), retryAttempts) |
| err := retry.Retry(ctx, backoff, func() error { |
| var err error |
| resp, err = h.client.Do(req) |
| return err |
| }, nil) |
| if err != nil { |
| http.Error(w, err.Error(), http.StatusInternalServerError) |
| return |
| } else if resp == nil { |
| http.Error(w, "received a nil response", http.StatusInternalServerError) |
| } |
| |
| for k, v := range resp.Header { |
| for _, s := range v { |
| w.Header().Add(k, s) |
| } |
| } |
| w.WriteHeader(resp.StatusCode) |
| |
| if _, err := io.Copy(w, resp.Body); err != nil { |
| http.Error(w, err.Error(), http.StatusInternalServerError) |
| return |
| } |
| } |
| |
| // GetLimiter returns the limiter associated with a given address and whether |
| // the address is allowed to make requests. |
| func (h *redirectHandler) getLimiter(addr string) (*rate.Limiter, bool) { |
| var limiter interface{} |
| var ok bool |
| if limiter, ok = h.limiters.Load(addr); !ok { |
| // While the full address may have been in the `-allowed` list, its |
| // hostname may be; in that case, dynamically add a limiter for that address. |
| hostname := strings.Split(addr, ":")[0] |
| _, hostOK := h.limiters.Load(hostname) |
| |
| if !h.restrictAddrs || hostOK { |
| limiter = newLimiter() |
| h.limiters.Store(addr, limiter) |
| ok = true |
| } |
| } |
| return limiter.(*rate.Limiter), ok |
| } |
| |
| // Returns an HTTP client with the credentials to read from GCS. If no |
| // credential file is supplied, then the default application credentials |
| // will be used. |
| func httpClient(ctx context.Context, credFile string) (*http.Client, error) { |
| var creds *google.Credentials |
| var err error |
| if credFile == "" { |
| creds, err = google.FindDefaultCredentials(ctx, gcsReadOnlyScope) |
| if err != nil { |
| return nil, fmt.Errorf("failed to find default credentials: %w", err) |
| } |
| } else { |
| contents, err := ioutil.ReadFile(credFile) |
| if err != nil { |
| return nil, err |
| } |
| creds, err = google.CredentialsFromJSON(ctx, contents, gcsReadOnlyScope) |
| if err != nil { |
| return nil, fmt.Errorf("failed to derive the derive credentials from supplied file: %w", err) |
| } |
| } |
| return oauth2.NewClient(ctx, creds.TokenSource), nil |
| } |
| |
// isIn reports whether s appears as an element of l. A nil or empty l
// contains nothing.
func isIn(s string, l []string) bool {
	for i := range l {
		if l[i] == s {
			return true
		}
	}
	return false
}