blob: efb52b8c18ee0e2924bbcd1d833805a2157fe41e [file] [log] [blame]
// Copyright 2019 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can
// found in the LICENSE file.
package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"io"
"io/ioutil"
"log"
"net/http"
"os"
"os/signal"
"strings"
"sync"
"syscall"
"time"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"golang.org/x/time/rate"
"go.fuchsia.dev/fuchsia/tools/lib/retry"
)
const (
	// OAuth2 scope for reading GCS.
	// See https://cloud.google.com/storage/docs/authentication.
	gcsReadOnlyScope = "https://www.googleapis.com/auth/devstorage.read_only"

	// See https://cloud.google.com/storage/docs/request-endpoints.
	gcsHost = "storage.googleapis.com"

	// We want to allow at most `maxNumRequests` per IP to be serviced every
	// `refreshWindowMs` milliseconds.
	// In terms of the token bucket underlying rate.Limiter, this translates to a
	// new token refreshed every `tokenRefreshRate` with a burst size of
	// `tokenBurstSize`: this allows for a token pool in which we can check Allow()
	// and Wait() at the desired rates without consuming any tokens reserved for
	// servicing of other requests.
	// See https://godoc.org/golang.org/x/time/rate#Limiter for more details.
	maxNumRequests        = 20
	refreshWindowMs       = 200
	tokenBudgetPerRequest = 2
	// Scale to a time.Duration before dividing so the quotient is exact to the
	// nanosecond. Dividing the plain integers first would truncate to whole
	// milliseconds, and a quotient of 0 would make rate.Every return an
	// infinite rate, i.e. disable limiting entirely.
	tokenRefreshRate = refreshWindowMs * time.Millisecond / (tokenBudgetPerRequest * maxNumRequests)
	tokenBurstSize   = tokenBudgetPerRequest * maxNumRequests

	// Constants for retrying communication with GCS.
	retryBackoff  = 100 * time.Millisecond
	retryAttempts = 10
)
// credentialsFile holds the value of the -credentials flag.
var credentialsFile string

// port holds the value of the -port flag.
var port string

// allowedAddrsFile holds the value of the -allowed flag.
var allowedAddrsFile string
// usage prints a one-line synopsis and description of gcsproxy followed by
// the registered flags and their defaults. Everything goes to the flag
// package's output writer (standard error unless reconfigured), so the
// description and the flag list end up on the same stream.
func usage() {
	fmt.Fprint(flag.CommandLine.Output(), `gcsproxy [flags]
Starts a proxy server that forwards requests to GCS with authentication.
`)
	flag.PrintDefaults()
}
// init installs the custom usage message and registers gcsproxy's flags.
//
// The -credentials help text describes what execute actually does when the
// flag is empty: it falls back to http.DefaultClient, which sends requests
// without any credentials.
func init() {
	flag.Usage = usage
	flag.StringVar(&credentialsFile, "credentials", "", "path to a credentials file in the Google Credentials File format; if none provided, requests are forwarded to GCS without authentication.")
	flag.StringVar(&port, "port", "", "port at which the server should listen.")
	flag.StringVar(&allowedAddrsFile, "allowed", "", "path to a file containing a flat JSON list of remote addresses (host:port, or just host) allowed to make requests of the proxy server")
}
// main parses the command-line flags and runs the proxy until it fails or the
// process receives SIGTERM or SIGINT.
func main() {
	flag.Parse()

	// For a graceful teardown in the event of a canceling signal.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// The signal package never blocks when delivering to this channel, so it
	// must be buffered or a signal arriving at the wrong moment is dropped.
	signals := make(chan os.Signal, 1)
	signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
	// Unregister instead of closing: closing a channel the signal package
	// may still send on would panic.
	defer signal.Stop(signals)
	go func() {
		select {
		case <-signals:
			cancel()
		case <-ctx.Done():
			// main is returning; let this goroutine exit.
		}
	}()

	if err := execute(ctx, credentialsFile, port, allowedAddrsFile); err != nil {
		log.Fatal(err)
	}
}
// Execute starts the proxy server.
func execute(ctx context.Context, credFile, port, addrsFile string) error {
if port == "" {
return fmt.Errorf("-port is required")
} else if allowedAddrsFile == "" {
return fmt.Errorf("-allowed is required")
}
credentials, err := findCredentials(ctx, credFile)
if err != nil {
return err
}
var client *http.Client
if credentials == nil {
client = http.DefaultClient
} else {
client = oauth2.NewClient(ctx, credentials.TokenSource)
}
var allowedAddrs []string
b, err := ioutil.ReadFile(addrsFile)
if err != nil {
return err
}
if err := json.Unmarshal(b, &allowedAddrs); err != nil {
return err
}
limiters := new(sync.Map)
for _, addr := range allowedAddrs {
limiters.Store(addr, newLimiter())
}
redirect := &redirectHandler{
client: client,
limiters: limiters,
}
mux := http.NewServeMux()
mux.Handle("/", redirect)
s := http.Server{
Addr: fmt.Sprintf(":%s", port),
Handler: mux,
}
errs := make(chan error)
defer close(errs)
go func() {
log.Printf("starting a GCS proxy server at localhost:%s", port)
errs <- s.ListenAndServe()
}()
shutdown := func() error {
log.Printf("Shutting down GCS proxy server at localhost:%s", port)
return s.Shutdown(ctx)
}
select {
case <-ctx.Done():
return shutdown()
case err := <-errs:
shutdown()
return err
}
}
// newLimiter builds a token-bucket limiter that gains one token every
// tokenRefreshRate and can hold at most tokenBurstSize tokens.
func newLimiter() *rate.Limiter {
	return rate.NewLimiter(rate.Every(tokenRefreshRate), tokenBurstSize)
}
// redirectHandler is an http.Handler that proxies requests to GCS. It
// rewrites each incoming request to target gcsHost and relays the upstream
// response itself, rather than answering with an HTTP redirect.
type redirectHandler struct {
	// client sends the rewritten requests to GCS.
	client *http.Client
	// limiters is a map of string: *rate.Limiter.
	// Keys are the entries of the `-allowed` list, plus full remote addresses
	// that getLimiter adds on demand. A remote address found in neither form
	// (full address or host) is not authorized.
	limiters *sync.Map
}
// ServeHTTP forwards req to GCS over HTTPS and streams the response back to w.
//
// Requests from remote addresses that getLimiter does not allow are rejected
// with 403 Forbidden. Allowed requests are rate-limited per limiter, and the
// upstream call is retried up to retryAttempts times with a constant
// retryBackoff between attempts. Rate-limit waits and retries are bound to the
// incoming request's context, so they stop when the client goes away.
func (h *redirectHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	if req == nil {
		http.Error(w, "cannot handle nil request", http.StatusInternalServerError) // = 500
		return
	}
	ctx := req.Context()

	limiter, ok := h.getLimiter(req.RemoteAddr)
	if !ok {
		http.Error(
			w,
			fmt.Sprintf("address %q is not authorized to make requests", req.RemoteAddr),
			http.StatusForbidden, // = 403
		)
		return
	}
	if !limiter.Allow() {
		if err := limiter.Wait(ctx); err != nil {
			http.Error(w, fmt.Sprintf("rate-limiting error: %v", err), http.StatusInternalServerError)
			return
		}
	}

	// Handlers should not modify the request they are given, so build the
	// outgoing request from a shallow copy with its own URL.
	outReq := req.WithContext(ctx)
	outURL := *req.URL
	outURL.Host = gcsHost
	outURL.Scheme = "https"
	outReq.URL = &outURL
	outReq.Host = gcsHost
	// It is an error to set this field in an HTTP client request
	// See https://golang.org/pkg/net/http/#Request.
	outReq.RequestURI = ""

	// NOTE(review): a request body is consumed (and closed) by the first
	// attempt, so retries are only meaningful for body-less requests such as
	// GET/HEAD. Confirm the proxy is never used for uploads.
	var resp *http.Response
	backoff := retry.WithMaxRetries(retry.NewConstantBackoff(retryBackoff), retryAttempts)
	err := retry.Retry(ctx, backoff, func() error {
		var err error
		resp, err = h.client.Do(outReq)
		return err
	}, nil)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	if resp == nil {
		http.Error(w, "received a nil response", http.StatusInternalServerError)
		return
	}
	// Closing the body releases the upstream connection for reuse.
	defer resp.Body.Close()

	header := w.Header()
	for k, vs := range resp.Header {
		for _, v := range vs {
			header.Add(k, v)
		}
	}
	w.WriteHeader(resp.StatusCode)
	// The status line has already been sent, so a failure here can no longer
	// be reported to the client with http.Error; log it instead.
	if _, err := io.Copy(w, resp.Body); err != nil {
		log.Printf("error copying GCS response for %s: %v", req.URL, err)
	}
}
// getLimiter returns the rate limiter for the remote address addr (normally
// "host:port", as in http.Request.RemoteAddr) and whether addr may make
// requests at all. When it reports false, the returned limiter is nil.
//
// An address is allowed if either the full address or just its host is a key
// in h.limiters. In the host-only case, a limiter for the full address is
// created on first use and cached under addr.
//
// NOTE(review): because these limiters are keyed by host:port, every new
// client connection (new ephemeral port) gets its own token budget and its
// own h.limiters entry, which is never evicted. The constants' comment
// describes per-IP limiting; confirm which behavior is intended.
func (h *redirectHandler) getLimiter(addr string) (*rate.Limiter, bool) {
	if limiter, ok := h.limiters.Load(addr); ok {
		return limiter.(*rate.Limiter), true
	}

	// Strip the port. For IPv6 literals ("[::1]:8080"), also strip the
	// brackets so the host can match an allowlist entry such as "::1".
	host := addr
	if i := strings.LastIndex(addr, ":"); i >= 0 {
		host = strings.TrimSuffix(strings.TrimPrefix(addr[:i], "["), "]")
	}
	// While the full address may not have been in the `-allowed` list, its
	// host may be; in that case, dynamically add a limiter for that address.
	if _, ok := h.limiters.Load(host); !ok {
		return nil, false
	}
	// LoadOrStore makes concurrent first requests from the same address share
	// a single limiter instead of overwriting each other's.
	limiter, _ := h.limiters.LoadOrStore(addr, newLimiter())
	return limiter.(*rate.Limiter), true
}
// findCredentials loads Google credentials, scoped for read-only GCS access,
// from the JSON file at credFile.
//
// An empty credFile yields (nil, nil); default application credentials are
// NOT looked up. execute treats nil credentials as "use http.DefaultClient".
// Errors from reading or parsing the file are returned unchanged.
func findCredentials(ctx context.Context, credFile string) (*google.Credentials, error) {
	if credFile != "" {
		data, readErr := ioutil.ReadFile(credFile)
		if readErr != nil {
			return nil, readErr
		}
		return google.CredentialsFromJSON(ctx, data, gcsReadOnlyScope)
	}
	return nil, nil
}
// isIn reports whether s is equal to any element of l. A nil or empty l
// contains nothing.
func isIn(s string, l []string) bool {
	for i := range l {
		if l[i] == s {
			return true
		}
	}
	return false
}