blob: ab1738bb2db226ac937e01fc3ab0c855a81226fb [file] [log] [blame]
// Copyright 2019 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can
// found in the LICENSE file.
package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"io"
"io/ioutil"
"log"
"net/http"
"os"
"os/signal"
"strings"
"sync"
"syscall"
"time"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"golang.org/x/time/rate"
"go.fuchsia.dev/fuchsia/tools/lib/retry"
"go.fuchsia.dev/fuchsia/tools/lib/runner"
)
const (
	// OAuth2 scope for reading GCS.
	// See https://cloud.google.com/storage/docs/authentication.
	gcsReadOnlyScope = "https://www.googleapis.com/auth/devstorage.read_only"
	// See https://cloud.google.com/storage/docs/request-endpoints.
	gcsHost = "storage.googleapis.com"

	// We want to allow at most `maxNumRequests` per IP to be serviced every
	// `refreshWindowMs` milliseconds.
	// In terms of the token bucket underlying rate.Limiter, this translates to a
	// new token refreshed every `tokenRefreshRate` with a burst size of
	// `tokenBurstSize`: this allows for a token pool in which we can check Allow()
	// and Wait() at the desired rates without consuming any tokens reserved for
	// servicing of other requests.
	// See https://godoc.org/golang.org/x/time/rate#Limiter for more details.
	maxNumRequests        = 20
	refreshWindowMs       = 200
	tokenBudgetPerRequest = 2
	// Scale to a duration before dividing. Dividing the plain integers first
	// would truncate (possibly to zero) whenever the window is not an exact
	// multiple of the total token budget.
	tokenRefreshRate = refreshWindowMs * time.Millisecond / (tokenBudgetPerRequest * maxNumRequests)
	// The bucket holds one full window's worth of token budget.
	tokenBurstSize = tokenBudgetPerRequest * maxNumRequests

	// Constants for retrying communication with GCS.
	retryBackoff  = 100 * time.Millisecond
	retryAttempts = 10
)
// Destinations for the -credentials, -port and -allowed command-line flags,
// in that order; all default to the empty string.
var credentialsFile, port, allowedAddrsFile string
// usage prints a short description of gcsproxy followed by the defaults of
// every registered flag. Both parts go to the flag package's configured
// output (standard error unless changed), so they are never split across
// streams. It is installed as flag.Usage.
func usage() {
	fmt.Fprint(flag.CommandLine.Output(), `gcsproxy [flags] [subcommand]
Starts a proxy server that forwards requests to GCS with authentication.
If positional arguments are provided, they will be run as subprocess and
the lifetime of the server will be scoped to the lifetime of that process.
`)
	flag.PrintDefaults()
}
// init installs the custom usage message and registers the string flags that
// main passes on to execute.
func init() {
	flag.Usage = usage
	stringFlags := []struct {
		dest *string
		name string
		help string
	}{
		{&credentialsFile, "credentials", "path to a credentials file in the Google Credentials File format; if none provided, default application credentials will be used."},
		{&port, "port", "port at which the server should listen."},
		{&allowedAddrsFile, "allowed", "a flat JSON list of remote addresses allowed to make requests of the proxy server; if not provided, all addresses will be allowed"},
	}
	for _, f := range stringFlags {
		flag.StringVar(f.dest, f.name, "", f.help)
	}
}
// main parses flags and runs the proxy until it finishes or a SIGTERM/SIGINT
// cancels it. Any error from execute is fatal.
func main() {
	flag.Parse()

	// For a graceful teardown in the event of a canceling signal, cancel the
	// context handed to execute when SIGTERM or SIGINT arrives.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// signal.Notify does not block when delivering, so an unbuffered channel
	// can miss a signal that arrives before the goroutine below is receiving.
	signals := make(chan os.Signal, 1)
	signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
	// Stop delivery rather than closing the channel: closing a channel the
	// signal package may still send on would panic.
	defer signal.Stop(signals)

	go func() {
		select {
		case <-signals:
			cancel()
		case <-ctx.Done():
			// main is returning; exit so this goroutine does not outlive it.
		}
	}()

	// log.Fatal exits via os.Exit, which skips the deferred calls above; that
	// is harmless because the process is terminating anyway.
	if err := execute(ctx, credentialsFile, port, allowedAddrsFile, flag.Args()); err != nil {
		log.Fatal(err)
	}
}
// execute starts the proxy server on port and blocks until ctx is canceled,
// the server stops with an error, or the optional subprocess exits.
//
// credFile, when non-empty, names a Google credentials JSON file used to
// authenticate forwarded requests; otherwise http.DefaultClient is used.
// addrsFile, when non-empty, names a flat JSON list of remote addresses that
// may use the proxy; only those addresses (or their hosts) are served. When it
// is empty, every address is served. If subCmd is non-empty it is run as a
// subprocess, and its result is returned once it exits and the server has
// been shut down.
func execute(ctx context.Context, credFile, port, addrsFile string, subCmd []string) error {
	// Upper bound on how long in-flight requests may take to drain on shutdown.
	const shutdownTimeout = 5 * time.Second

	if port == "" {
		return fmt.Errorf("-port is required")
	}
	credentials, err := findCredentials(ctx, credFile)
	if err != nil {
		return err
	}
	client := http.DefaultClient
	if credentials != nil {
		client = oauth2.NewClient(ctx, credentials.TokenSource)
	}

	// Read from the addrsFile parameter, not the package-level flag variable,
	// so the function honors whatever its caller passes.
	var allowedAddrs []string
	if addrsFile != "" {
		b, err := ioutil.ReadFile(addrsFile)
		if err != nil {
			return err
		}
		if err := json.Unmarshal(b, &allowedAddrs); err != nil {
			return err
		}
	}
	limiters := new(sync.Map)
	for _, addr := range allowedAddrs {
		limiters.Store(addr, newLimiter())
	}
	redirect := &redirectHandler{
		client: client,
		// Restrict whenever an allow-list file was given, even if it decodes
		// to JSON null, rather than silently allowing everyone.
		restrictAddrs: addrsFile != "",
		limiters:      limiters,
	}
	mux := http.NewServeMux()
	mux.Handle("/", redirect)
	s := http.Server{
		Addr:    fmt.Sprintf(":%s", port),
		Handler: mux,
	}

	// Buffered for both possible senders (server and subprocess). Whichever
	// goroutine finishes second can still deliver without blocking forever.
	// The channel is intentionally never closed: closing it while a sender
	// is pending would panic.
	errs := make(chan error, 2)
	go func() {
		log.Printf("starting a GCS proxy server at localhost:%s", port)
		errs <- s.ListenAndServe()
	}()
	if len(subCmd) > 0 {
		r := runner.SubprocessRunner{
			Env: os.Environ(),
		}
		go func() {
			errs <- r.Run(ctx, subCmd, os.Stdout, os.Stderr)
		}()
	}

	shutdown := func() error {
		log.Printf("Shutting down GCS proxy server at localhost:%s", port)
		// ctx may already be canceled at this point. Shutting down with it
		// would abandon in-flight requests immediately, so use a fresh,
		// bounded context instead.
		shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
		defer cancel()
		return s.Shutdown(shutdownCtx)
	}
	select {
	case <-ctx.Done():
		return shutdown()
	case err := <-errs:
		// The triggering error takes precedence; still surface a failed shutdown.
		if shutdownErr := shutdown(); shutdownErr != nil {
			log.Printf("error shutting down GCS proxy server: %v", shutdownErr)
		}
		return err
	}
}
// newLimiter builds a fresh token-bucket limiter that regains one token every
// tokenRefreshRate and can hold up to tokenBurstSize tokens.
func newLimiter() *rate.Limiter {
	return rate.NewLimiter(rate.Every(tokenRefreshRate), tokenBurstSize)
}
// redirectHandler is an http.Handler that forwards incoming requests to GCS
// through client and relays the responses back, applying a per-address
// rate limiter and an optional address allow-list.
type redirectHandler struct {
	// client sends the forwarded requests upstream.
	client *http.Client
	// restrictAddrs, when true, limits service to remote addresses whose full
	// address or host part has an entry in limiters; otherwise every address
	// is served and gets a limiter created on first use.
	restrictAddrs bool
	// limiters maps an address string (an allow-listed entry, or a remote
	// address seen at request time) to its *rate.Limiter.
	limiters *sync.Map
}
// ServeHTTP forwards req to GCS over HTTPS using h.client and copies the
// upstream status, headers, and body back to w.
//
// Remote addresses rejected by getLimiter get a 403. Allowed addresses are
// throttled by their limiter. The upstream call is retried with a constant
// backoff when the client returns an error.
func (h *redirectHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	if req == nil {
		http.Error(w, "cannot handle nil request", http.StatusInternalServerError) // = 500
		return
	}
	// Bind rate-limit waits and upstream retries to the incoming request so
	// they stop as soon as the client goes away.
	ctx := req.Context()
	limiter, ok := h.getLimiter(req.RemoteAddr)
	if !ok {
		http.Error(
			w,
			fmt.Sprintf("address %q is not authorized to make requests", req.RemoteAddr),
			http.StatusForbidden, // = 403
		)
		return
	}
	if !limiter.Allow() {
		if err := limiter.Wait(ctx); err != nil {
			http.Error(w, fmt.Sprintf("rate-limiting error: %v", err), http.StatusInternalServerError)
			return
		}
	}

	req.Host = gcsHost
	req.URL.Host = gcsHost
	req.URL.Scheme = "https"
	// It is an error to set this field in an HTTP client request
	// See https://golang.org/pkg/net/http/#Request.
	req.RequestURI = ""

	var resp *http.Response
	backoff := retry.WithMaxRetries(retry.NewConstantBackoff(retryBackoff), retryAttempts)
	err := retry.Retry(ctx, backoff, func() error {
		var err error
		resp, err = h.client.Do(req)
		return err
	}, nil)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	if resp == nil {
		http.Error(w, "received a nil response", http.StatusInternalServerError)
		return
	}
	// Always release the upstream connection, whatever happens below.
	defer resp.Body.Close()

	for k, v := range resp.Header {
		for _, s := range v {
			w.Header().Add(k, s)
		}
	}
	w.WriteHeader(resp.StatusCode)
	// The status line has already been sent, so a copy failure cannot be
	// reported to the client with http.Error; log it instead.
	if _, err := io.Copy(w, resp.Body); err != nil {
		log.Printf("error copying GCS response body for %s: %v", req.URL, err)
	}
}
// getLimiter returns the limiter associated with the remote address addr
// (typically "host:port", as in http.Request.RemoteAddr) and whether addr is
// allowed to make requests. When the second result is false the limiter is
// nil.
//
// An address is allowed if it has its own entry in h.limiters, if its host
// part has an entry, or if h.restrictAddrs is false. In the latter two cases a
// limiter for the full address is created on first use.
//
// NOTE(review): because addr includes the client port, limiters are
// effectively per connection rather than per IP — confirm that is intended.
func (h *redirectHandler) getLimiter(addr string) (*rate.Limiter, bool) {
	if limiter, ok := h.limiters.Load(addr); ok {
		return limiter.(*rate.Limiter), true
	}
	// While the full address may not have been in the `-allowed` list, its
	// host may be; in that case, dynamically add a limiter for that address.
	host := addr
	if i := strings.LastIndex(host, ":"); i >= 0 {
		host = host[:i]
	}
	// IPv6 hosts appear bracketed, as in "[::1]:1234".
	host = strings.TrimSuffix(strings.TrimPrefix(host, "["), "]")
	if h.restrictAddrs {
		if _, hostOK := h.limiters.Load(host); !hostOK {
			return nil, false
		}
	}
	// LoadOrStore so that concurrent first requests from the same address
	// share a single limiter instead of overwriting each other's.
	limiter, _ := h.limiters.LoadOrStore(addr, newLimiter())
	return limiter.(*rate.Limiter), true
}
// findCredentials loads Google credentials, scoped for read-only GCS access,
// from the JSON file at credFile. When credFile is empty it returns nil
// credentials and a nil error; read and parse failures are returned as-is.
//
// NOTE(review): the -credentials flag help promises a fallback to default
// application credentials, but no such lookup happens here — confirm which
// behavior is intended.
func findCredentials(ctx context.Context, credFile string) (*google.Credentials, error) {
	if credFile != "" {
		data, err := ioutil.ReadFile(credFile)
		if err != nil {
			return nil, err
		}
		return google.CredentialsFromJSON(ctx, data, gcsReadOnlyScope)
	}
	return nil, nil
}
// isIn reports whether s equals any element of l. A nil or empty l contains
// nothing.
func isIn(s string, l []string) bool {
	for i := range l {
		if l[i] == s {
			return true
		}
	}
	return false
}