// Copyright 2019 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can
// found in the LICENSE file.

package main

import (
	"context"
	"encoding/json"
	"flag"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net/http"
	"os"
	"os/signal"
	"strings"
	"sync"
	"syscall"
	"time"

	"golang.org/x/oauth2"
	"golang.org/x/oauth2/google"
	"golang.org/x/time/rate"

	"go.fuchsia.dev/fuchsia/tools/lib/retry"
	"go.fuchsia.dev/fuchsia/tools/lib/runner"
)

const (
	// OAuth2 scope for reading GCS.
	// See https://cloud.google.com/storage/docs/authentication.
	gcsReadOnlyScope = "https://www.googleapis.com/auth/devstorage.read_only"

	// See https://cloud.google.com/storage/docs/request-endpoints.
	gcsHost = "storage.googleapis.com"

	// We want to allow at most `maxNumRequests` requests per IP to be serviced
	// every `refreshWindowMs` milliseconds.
	// In terms of the token bucket underlying rate.Limiter, this translates to a
	// new token refreshed every `tokenRefreshRate` with a burst size of
	// `tokenBurstSize`, each request being budgeted `tokenBudgetPerRequest`
	// tokens: this allows for a token pool in which we can check Allow() and
	// Wait() at the desired rates without consuming any tokens reserved for
	// servicing of other requests. With the values below this is
	// 200ms / (2 * 20) = one token every 5ms, with a burst of 2 * 20 = 40.
	//
	// tokenRefreshRate is divided as a time.Duration (nanoseconds) rather than
	// in whole milliseconds so the result cannot silently truncate when these
	// values are tuned (e.g. 200 / (2 * 30) would otherwise give 3ms, not 3.33ms).
	//
	// NOTE(review): ServeHTTP draws a single token per request, and getLimiter
	// keys limiters by the full remote address (host and port), so the
	// effective ceiling looks like tokenBurstSize requests per window per
	// connection rather than maxNumRequests per IP — confirm intent.
	// See https://godoc.org/golang.org/x/time/rate#Limiter for more details.
	maxNumRequests        = 20
	refreshWindowMs       = 200
	tokenBudgetPerRequest = 2
	tokenRefreshRate      = refreshWindowMs * time.Millisecond / (tokenBudgetPerRequest * maxNumRequests)
	tokenBurstSize        = tokenBudgetPerRequest * maxNumRequests

	// Constants for retrying communication with GCS.
	retryBackoff  = 100 * time.Millisecond
	retryAttempts = 10
)

// Command-line flag values; registered in init and populated by flag.Parse.
var (
	credentialsFile  string
	port             string
	allowedAddrsFile string
)

// usage prints the command synopsis followed by every registered flag and its
// default. It is installed as flag.Usage in init.
//
// Both the synopsis and the flag listing go to flag.CommandLine.Output() (stderr
// unless overridden), so they are never split across streams.
func usage() {
	out := flag.CommandLine.Output()
	fmt.Fprint(out, `gcsproxy [flags] [subcommand]

Starts a proxy server that forwards requests to GCS with authentication.
If positional arguments are provided, they will be run as subprocess and
the lifetime of the server will be scoped to the lifetime of that process.
`)
	fmt.Fprintln(out)
	flag.PrintDefaults()
}

func init() {
	flag.Usage = usage
	flag.StringVar(&credentialsFile, "credentials", "", "path to a credentials file in the Google Credentials File format; if none provided, default application credentials will be used.")
	flag.StringVar(&port, "port", "", "port at which the server should listen.")
	flag.StringVar(&allowedAddrsFile, "allowed", "", "a flat JSON list of remote addresses allowed to make requests of the proxy server; if not provided, all addresses will be allowed")
}

// main parses flags, arranges for SIGTERM/SIGINT to cancel the run, and starts
// the proxy via execute. Any error from execute is fatal.
func main() {
	flag.Parse()

	// For a graceful teardown in the event of a canceling signal.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// signal.Notify never blocks when delivering, so the channel needs a buffer
	// or a signal arriving before the goroutine below is receiving is dropped.
	signals := make(chan os.Signal, 1)
	signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
	// Stop delivery rather than closing the channel: the signal package could
	// still send on it, and a send on a closed channel panics.
	defer signal.Stop(signals)

	go func() {
		<-signals
		cancel()
	}()

	if err := execute(ctx, credentialsFile, port, allowedAddrsFile, flag.Args()); err != nil {
		log.Fatal(err)
	}
}

// execute starts the proxy server on the given port and blocks until it stops.
//
// credFile is handed to httpClient to build the client used to reach GCS.
// addrsFile, if non-empty, names a JSON file holding a flat list of remote
// addresses allowed to use the proxy; when it is empty every address is allowed.
// If subCmd is non-empty it is run as a subprocess. execute returns when ctx is
// canceled (returning the shutdown result) or when either the server or the
// subprocess finishes (returning that one's error); in every case the server
// is shut down first.
func execute(ctx context.Context, credFile, port, addrsFile string, subCmd []string) error {
	if port == "" {
		return fmt.Errorf("-port is required")
	}

	client, err := httpClient(ctx, credFile)
	if err != nil {
		return err
	}

	allowedAddrs, err := loadAllowedAddrs(addrsFile)
	if err != nil {
		return err
	}

	limiters := new(sync.Map)
	for _, addr := range allowedAddrs {
		limiters.Store(addr, newLimiter())
	}

	redirect := &redirectHandler{
		client: client,
		// Restrict exactly when an allow-list file was supplied, so a file that
		// decodes to JSON null cannot silently open the proxy to everyone.
		restrictAddrs: addrsFile != "",
		limiters:      limiters,
	}

	mux := http.NewServeMux()
	mux.Handle("/", redirect)
	s := http.Server{
		Addr:    fmt.Sprintf(":%s", port),
		Handler: mux,
	}

	// Room for one result from the server and one from the subprocess, so
	// whichever finishes after execute stops listening never blocks forever.
	errs := make(chan error, 2)
	go func() {
		log.Printf("starting a GCS proxy server at localhost:%s", port)
		errs <- s.ListenAndServe()
	}()

	if len(subCmd) > 0 {
		r := runner.SubprocessRunner{
			Env: os.Environ(),
		}
		go func() {
			errs <- r.Run(ctx, subCmd, os.Stdout, os.Stderr)
		}()
	}

	shutdown := func() error {
		// Upper bound on how long in-flight requests may take to drain.
		const shutdownTimeout = 10 * time.Second

		log.Printf("Shutting down GCS proxy server at localhost:%s", port)
		// ctx may already be canceled here; Shutdown with a canceled context
		// gives up on active connections instead of letting them finish.
		shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
		defer cancel()
		return s.Shutdown(shutdownCtx)
	}

	select {
	case <-ctx.Done():
		return shutdown()
	case err := <-errs:
		if serr := shutdown(); serr != nil {
			log.Printf("error shutting down GCS proxy server: %v", serr)
		}
		return err
	}
}

// loadAllowedAddrs reads path as a flat JSON list of remote addresses.
// An empty path yields a nil list and no error.
func loadAllowedAddrs(path string) ([]string, error) {
	if path == "" {
		return nil, nil
	}
	b, err := ioutil.ReadFile(path)
	if err != nil {
		return nil, err
	}
	var addrs []string
	if err := json.Unmarshal(b, &addrs); err != nil {
		return nil, fmt.Errorf("failed to parse allowed addresses file %q: %w", path, err)
	}
	return addrs, nil
}

// newLimiter returns a fresh token-bucket limiter that regains one token every
// tokenRefreshRate and holds at most tokenBurstSize tokens.
func newLimiter() *rate.Limiter {
	return rate.NewLimiter(rate.Every(tokenRefreshRate), tokenBurstSize)
}

// redirectHandler is an http.Handler that re-targets incoming requests at GCS,
// sends them with the client it was given, and applies per-address rate limits.
type redirectHandler struct {
	// limiters maps a remote address (string) to its *rate.Limiter.
	limiters *sync.Map
	// restrictAddrs, when set, limits service to addresses that either have an
	// entry in limiters or whose host has one.
	restrictAddrs bool
	// client issues the outbound requests to GCS.
	client *http.Client
}

// ServeHTTP forwards req to GCS and streams the response back to w.
//
// Requests from addresses rejected by getLimiter get a 403. Otherwise the
// request waits on that address's limiter, is re-targeted at gcsHost over
// HTTPS, and is sent with h.client. Errors returned by the client are retried
// with a constant backoff; if they persist the caller gets a 500.
func (h *redirectHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	if req == nil {
		http.Error(w, "cannot handle nil request", http.StatusInternalServerError) // = 500
		return
	}
	// Bind rate-limit waits and upstream retries to the incoming request so
	// they are abandoned when the client goes away.
	ctx := req.Context()

	limiter, ok := h.getLimiter(req.RemoteAddr)
	if !ok {
		http.Error(
			w,
			fmt.Sprintf("address %q is not authorized to make requests", req.RemoteAddr),
			http.StatusForbidden, // = 403
		)
		return
	}

	// Wait takes a token immediately if one is available and otherwise blocks
	// until one is (or ctx is done), so no separate Allow() check is needed.
	if err := limiter.Wait(ctx); err != nil {
		http.Error(w, fmt.Sprintf("rate-limiting error: %v", err), http.StatusInternalServerError)
		return
	}

	req.Host = gcsHost
	req.URL.Host = gcsHost
	req.URL.Scheme = "https"
	// It is an error to set this field in an HTTP client request
	// See https://golang.org/pkg/net/http/#Request.
	req.RequestURI = ""

	// NOTE(review): the same req is re-sent on every attempt. The client closes
	// a request body after sending it, so retries are only sound for bodyless
	// requests (e.g. GET/HEAD) — confirm that is all this proxy serves.
	var resp *http.Response
	backoff := retry.WithMaxAttempts(retry.NewConstantBackoff(retryBackoff), retryAttempts)
	err := retry.Retry(ctx, backoff, func() error {
		var err error
		resp, err = h.client.Do(req)
		return err
	}, nil)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	if resp == nil {
		http.Error(w, "received a nil response", http.StatusInternalServerError)
		return
	}
	defer resp.Body.Close()

	for k, v := range resp.Header {
		for _, s := range v {
			w.Header().Add(k, s)
		}
	}
	w.WriteHeader(resp.StatusCode)

	// Status and headers are already on the wire, so an error status can no
	// longer reach the client; record the failure instead.
	if _, err := io.Copy(w, resp.Body); err != nil {
		log.Printf("error copying GCS response for %s: %v", req.URL.Path, err)
	}
}

// getLimiter returns the limiter associated with addr and whether addr may make
// requests. When addr is not allowed, the returned limiter is nil.
//
// addr (expected in "host:port" form, as in http.Request.RemoteAddr) is first
// looked up exactly. Failing that, a limiter is created and stored for addr
// when no address restriction is in place or when addr's host has an entry.
//
// NOTE(review): entries keyed by the full address (including the client port)
// are never removed, so the map grows with each new client connection.
func (h *redirectHandler) getLimiter(addr string) (*rate.Limiter, bool) {
	if l, ok := h.limiters.Load(addr); ok {
		return l.(*rate.Limiter), true
	}

	// While the full address may not have been in the `-allowed` list, its
	// host may be; in that case, dynamically add a limiter for that address.
	if h.restrictAddrs {
		if _, hostOK := h.limiters.Load(hostOf(addr)); !hostOK {
			return nil, false
		}
	}

	// LoadOrStore makes concurrent first requests from the same address share
	// one limiter instead of each installing (and overwriting) its own.
	l, _ := h.limiters.LoadOrStore(addr, newLimiter())
	return l.(*rate.Limiter), true
}

// hostOf returns the host part of a "host:port" address, with the brackets of
// an IPv6 literal removed ("[::1]:80" -> "::1"). An address without a colon is
// returned unchanged.
func hostOf(addr string) string {
	host := addr
	if i := strings.LastIndex(addr, ":"); i >= 0 {
		host = addr[:i]
	}
	return strings.TrimSuffix(strings.TrimPrefix(host, "["), "]")
}

// httpClient returns an HTTP client carrying credentials scoped to read from
// GCS (gcsReadOnlyScope). If credFile is empty, the default application
// credentials are used; otherwise credFile is read as a JSON credentials file.
// ctx is passed through to the credential lookup and to oauth2.NewClient.
func httpClient(ctx context.Context, credFile string) (*http.Client, error) {
	var creds *google.Credentials
	if credFile == "" {
		var err error
		creds, err = google.FindDefaultCredentials(ctx, gcsReadOnlyScope)
		if err != nil {
			return nil, fmt.Errorf("failed to find default credentials: %w", err)
		}
	} else {
		contents, err := ioutil.ReadFile(credFile)
		if err != nil {
			return nil, fmt.Errorf("failed to read credentials file %q: %w", credFile, err)
		}
		creds, err = google.CredentialsFromJSON(ctx, contents, gcsReadOnlyScope)
		if err != nil {
			return nil, fmt.Errorf("failed to derive credentials from supplied file: %w", err)
		}
	}
	return oauth2.NewClient(ctx, creds.TokenSource), nil
}

// isIn reports whether s is an element of l. A nil or empty l contains nothing.
func isIn(s string, l []string) bool {
	for i := range l {
		if l[i] == s {
			return true
		}
	}
	return false
}
