blob: 9a944d91cd46da846441053078e1787f4e93c72d [file] [log] [blame]
// Copyright 2019 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can
// found in the LICENSE file.
package main
import (
const (
	// OAuth2 scope for reading GCS.
	// See https://cloud.google.com/storage/docs/authentication#oauth-scopes
	gcsReadOnlyScope = "https://www.googleapis.com/auth/devstorage.read_only"

	// Host to which proxied requests are forwarded.
	// See https://cloud.google.com/storage/docs/request-endpoints
	gcsHost = "storage.googleapis.com"

	// We want to allow at most `maxNumRequests` requests per IP to be serviced
	// every `refreshWindowMs` milliseconds.
	//
	// In terms of the token bucket underlying rate.Limiter, this translates to
	// a new token refreshed every `tokenRefreshRate` with a burst size of
	// `tokenBurstSize`: this allows for a token pool in which we can check
	// Allow() and Wait() at the desired rates without consuming any tokens
	// reserved for servicing of other requests.
	// See https://godoc.org/golang.org/x/time/rate for more details.
	maxNumRequests        = 20
	refreshWindowMs       = 200
	tokenBudgetPerRequest = 2
	// Scale to a time.Duration before dividing: dividing the bare millisecond
	// count first would truncate whenever refreshWindowMs is not an exact
	// multiple of the token count (e.g. 100ms/60 would become 1ms).
	tokenRefreshRate = refreshWindowMs * time.Millisecond / (tokenBudgetPerRequest * maxNumRequests)
	tokenBurstSize   = 2 * maxNumRequests

	// Constants for retrying communication with GCS.
	retryBackoff  = 100 * time.Millisecond
	retryAttempts = 10
)
// Command-line configuration. init binds these to the -credentials, -port
// and -allowed flags respectively; all three start out empty.
var (
	credentialsFile, port, allowedAddrsFile string
)
// usage prints the gcsproxy command-line synopsis to stdout; init installs it
// as flag.Usage so it is shown for -h/-help and on flag parse errors.
func usage() {
fmt.Printf(`gcsproxy [flags]
Starts a proxy server that forwards requests to GCS with authentication.
func init() {
flag.Usage = usage
flag.StringVar(&credentialsFile, "credentials", "", "path to a credentials file in the Google Credentials File format; if none provided, default application credentials will be used.")
flag.StringVar(&port, "port", "", "port at which the server should listen.")
flag.StringVar(&allowedAddrsFile, "allowed", "", "a flat JSON list of remote addresses allowed to make requests of the proxy server")
func main() {
// For a graceful teardown in the event of a canceling signal.
ctx, cancel := context.WithCancel(context.Background())
signals := make(chan os.Signal)
defer close(signals)
signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
go func() {
select {
case <-signals:
if err := execute(ctx, credentialsFile, port, allowedAddrsFile); err != nil {
// Execute starts the proxy server.
func execute(ctx context.Context, credFile, port, addrsFile string) error {
if port == "" {
return fmt.Errorf("-port is required")
} else if allowedAddrsFile == "" {
return fmt.Errorf("-allowed is required")
credentials, err := findCredentials(ctx, credFile)
if err != nil {
return err
var client *http.Client
if credentials == nil {
client = http.DefaultClient
} else {
client = oauth2.NewClient(ctx, credentials.TokenSource)
var allowedAddrs []string
b, err := ioutil.ReadFile(addrsFile)
if err != nil {
return err
if err := json.Unmarshal(b, &allowedAddrs); err != nil {
return err
limiters := make(map[string]*rate.Limiter)
for _, addr := range allowedAddrs {
limiters[addr] = newLimiter()
redirect := redirectHandler{
client: client,
limiters: limiters,
mux := http.NewServeMux()
mux.Handle("/", redirect)
s := http.Server{
Addr: fmt.Sprintf(":%s", port),
Handler: mux,
errs := make(chan error)
defer close(errs)
go func() {
log.Printf("starting a GCS proxy server at localhost:%s", port)
errs <- s.ListenAndServe()
shutdown := func() error {
log.Printf("Shutting down GCS proxy server at localhost:%s", port)
return s.Shutdown(ctx)
select {
case <-ctx.Done():
return shutdown()
case err := <-errs:
return err
func newLimiter() *rate.Limiter {
limit := rate.Every(tokenRefreshRate)
return rate.NewLimiter(limit, tokenBurstSize)
// RedirectHandler is a simple handler that redirects requests to GCS.
type redirectHandler struct {
client *http.Client
limiters map[string]*rate.Limiter
func (h redirectHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if req == nil {
http.Error(w, "cannot handle nil request", http.StatusInternalServerError) // = 501
ctx := context.Background()
limiter, ok := h.getLimiter(req.RemoteAddr)
if !ok {
fmt.Sprintf("address %q is not authorized to make requests", req.RemoteAddr),
http.StatusForbidden, // = 403
if !limiter.Allow() {
if err := limiter.Wait(ctx); err != nil {
http.Error(w, fmt.Sprintf("rate-limiting error: %v", err), http.StatusInternalServerError)
req.Host = gcsHost
req.URL.Host = gcsHost
req.URL.Scheme = "https"
// It is an error to set this field in an HTTP client request
// See
req.RequestURI = ""
var resp *http.Response
backoff := retry.WithMaxRetries(retry.NewConstantBackoff(retryBackoff), retryAttempts)
err := retry.Retry(ctx, backoff, func() error {
var err error
resp, err = h.client.Do(req)
return err
}, nil)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
} else if resp == nil {
http.Error(w, "received a nil response", http.StatusInternalServerError)
for k, v := range resp.Header {
for _, s := range v {
w.Header().Add(k, s)
if _, err := io.Copy(w, resp.Body); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
// GetLimiter returns the limiter associated with a given address and whether
// the address is allowed to make requests.
func (h *redirectHandler) getLimiter(addr string) (*rate.Limiter, bool) {
var limiter *rate.Limiter
var ok bool
if limiter, ok = h.limiters[addr]; !ok {
hostname := strings.Split(addr, ":")[0]
// While the full address may have been in the `-allowed` list, its
// hostname may be; in that case, dynamically add a limiter for that address.
if _, ok = h.limiters[hostname]; ok {
limiter = newLimiter()
h.limiters[addr] = limiter
return limiter, ok
// FindCredentials returns the credentials in a provided file, or the default application credentials.
func findCredentials(ctx context.Context, credFile string) (*google.Credentials, error) {
if credFile == "" {
return nil, nil
contents, err := ioutil.ReadFile(credFile)
if err != nil {
return nil, err
return google.CredentialsFromJSON(ctx, contents, gcsReadOnlyScope)
// isIn reports whether s occurs in l. A nil or empty l contains nothing.
func isIn(s string, l []string) bool {
	for i := range l {
		if l[i] == s {
			return true
		}
	}
	return false
}