blob: 059deab71177e0b82d22c549268d2ea722513fcb [file] [log] [blame]
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impersonate
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
"time"
"golang.org/x/oauth2"
)
// user builds a TokenSource that impersonates c.Subject by having
// c.TargetPrincipal sign JWT assertions on the subject's behalf.
//
// Every entry of c.Delegates is normalized with formatIAMServiceAccountName.
// c.Scopes is copied so later changes to the caller's slice do not affect the
// source. When isStaticToken is true, one token is fetched immediately and
// wrapped in a static source; any error from that fetch is returned.
// Otherwise the source is wrapped in oauth2.ReuseTokenSource so tokens are
// minted lazily and cached until they expire.
//
// ctx is accepted for signature compatibility but is not used here.
func user(ctx context.Context, c CredentialsConfig, client *http.Client, lifetime time.Duration, isStaticToken bool) (oauth2.TokenSource, error) {
	delegateNames := make([]string, 0, len(c.Delegates))
	for _, d := range c.Delegates {
		delegateNames = append(delegateNames, formatIAMServiceAccountName(d))
	}
	scopeCopy := make([]string, len(c.Scopes))
	copy(scopeCopy, c.Scopes)

	ts := userTokenSource{
		client:          client,
		targetPrincipal: c.TargetPrincipal,
		subject:         c.Subject,
		scopes:          scopeCopy,
		lifetime:        lifetime,
		delegates:       delegateNames,
	}

	if !isStaticToken {
		return oauth2.ReuseTokenSource(nil, ts), nil
	}
	tok, err := ts.Token()
	if err != nil {
		return nil, err
	}
	return oauth2.StaticTokenSource(tok), nil
}
type (
	// claimSet is the JWT payload sent to the IAM Credentials signJwt
	// endpoint. Field order and JSON tags fix the marshaled payload layout.
	claimSet struct {
		Iss   string `json:"iss"`             // issuing service account
		Scope string `json:"scope,omitempty"` // space-separated scopes
		Sub   string `json:"sub,omitempty"`   // user being impersonated
		Aud   string `json:"aud"`             // token endpoint URL
		Iat   int64  `json:"iat"`             // issued-at, Unix seconds
		Exp   int64  `json:"exp"`             // expiry, Unix seconds
	}

	// signJWTRequest is the request body for the signJwt call.
	signJWTRequest struct {
		// Payload is the JSON-encoded claimSet to be signed.
		Payload string `json:"payload"`
		// Delegates is the optional delegation chain; omitted when empty.
		Delegates []string `json:"delegates,omitempty"`
	}

	// signJWTResponse is the decoded response of the signJwt call.
	signJWTResponse struct {
		// KeyID identifies the key that signed the JWT.
		KeyID string `json:"keyId"`
		// SignedJWT holds the generated header, the caller-supplied
		// payload and the signature, produced with the key named by the
		// header's `kid` field.
		SignedJWT string `json:"signedJwt"`
	}

	// exchangeTokenResponse is the decoded response of the token endpoint
	// after a signed JWT has been exchanged.
	exchangeTokenResponse struct {
		AccessToken string `json:"access_token"`
		TokenType   string `json:"token_type"`
		ExpiresIn   int64  `json:"expires_in"` // lifetime in seconds
	}

	// userTokenSource mints access tokens for subject by having
	// targetPrincipal sign JWT assertions through the IAM Credentials API.
	userTokenSource struct {
		client          *http.Client // used for both signJwt and token calls
		targetPrincipal string       // service account that signs the JWT
		subject         string       // user to impersonate
		scopes          []string     // requested OAuth2 scopes
		lifetime        time.Duration
		delegates       []string // pre-formatted delegation chain
	}
)
// Token implements oauth2.TokenSource. It has the target principal sign a
// JWT asserting the impersonated subject, then exchanges that JWT for an
// access token. An error from either step is returned unchanged.
func (u userTokenSource) Token() (*oauth2.Token, error) {
	assertion, signErr := u.signJWT()
	if signErr != nil {
		return nil, signErr
	}
	tok, exchangeErr := u.exchangeToken(assertion)
	return tok, exchangeErr
}
// signJWT asks the IAM Credentials signJwt endpoint to sign a JWT with
// u.targetPrincipal's key and returns the signed JWT string.
//
// The claims name u.targetPrincipal as issuer and u.subject as subject. The
// scope claim is the space-joined u.scopes. The audience is the same
// "<oauth2Endpoint>/token" URL that exchangeToken posts to. Issued-at is now,
// and expiry is now+u.lifetime. Any u.delegates are sent as the delegation
// chain.
//
// It returns an error if marshaling fails, the HTTP call fails, the body
// cannot be read, the status is not 2xx (the error includes the size-limited
// body), or the response JSON cannot be parsed.
func (u userTokenSource) signJWT() (string, error) {
	now := time.Now()
	exp := now.Add(u.lifetime)
	claims := claimSet{
		Iss:   u.targetPrincipal,
		Scope: strings.Join(u.scopes, " "),
		Sub:   u.subject,
		Aud:   fmt.Sprintf("%s/token", oauth2Endpoint),
		Iat:   now.Unix(),
		Exp:   exp.Unix(),
	}
	payloadBytes, err := json.Marshal(claims)
	if err != nil {
		return "", fmt.Errorf("impersonate: unable to marshal claims: %v", err)
	}
	signJWTReq := signJWTRequest{
		Payload:   string(payloadBytes),
		Delegates: u.delegates,
	}
	bodyBytes, err := json.Marshal(signJWTReq)
	if err != nil {
		return "", fmt.Errorf("impersonate: unable to marshal request: %v", err)
	}
	reqURL := fmt.Sprintf("%s/v1/%s:signJwt", iamCredentailsEndpoint, formatIAMServiceAccountName(u.targetPrincipal))
	req, err := http.NewRequest("POST", reqURL, bytes.NewReader(bodyBytes))
	if err != nil {
		return "", fmt.Errorf("impersonate: unable to create request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")
	rawResp, err := u.client.Do(req)
	if err != nil {
		return "", fmt.Errorf("impersonate: unable to sign JWT: %v", err)
	}
	// The caller of Client.Do owns the response body and must close it.
	// Otherwise the underlying connection cannot be reused or released.
	defer rawResp.Body.Close()
	// Cap the read at 1 MiB so an unexpected response cannot exhaust memory.
	body, err := ioutil.ReadAll(io.LimitReader(rawResp.Body, 1<<20))
	if err != nil {
		return "", fmt.Errorf("impersonate: unable to read body: %v", err)
	}
	if c := rawResp.StatusCode; c < 200 || c > 299 {
		return "", fmt.Errorf("impersonate: status code %d: %s", c, body)
	}
	var signJWTResp signJWTResponse
	if err := json.Unmarshal(body, &signJWTResp); err != nil {
		return "", fmt.Errorf("impersonate: unable to parse response: %v", err)
	}
	return signJWTResp.SignedJWT, nil
}
// exchangeToken trades signedJWT for an OAuth2 access token. It POSTs a form
// to "<oauth2Endpoint>/token" with grant_type "assertion" and the JWT-bearer
// assertion_type.
//
// The expiry of the returned token is measured from a timestamp taken just
// before the request is sent. It therefore never overstates the real
// lifetime.
//
// It returns an error if the HTTP call fails, the body cannot be read, the
// status is not 2xx (the error includes the size-limited body), or the
// response JSON cannot be parsed.
func (u userTokenSource) exchangeToken(signedJWT string) (*oauth2.Token, error) {
	now := time.Now()
	v := url.Values{}
	v.Set("grant_type", "assertion")
	v.Set("assertion_type", "http://oauth.net/grant_type/jwt/1.0/bearer")
	v.Set("assertion", signedJWT)
	rawResp, err := u.client.PostForm(fmt.Sprintf("%s/token", oauth2Endpoint), v)
	if err != nil {
		return nil, fmt.Errorf("impersonate: unable to exchange token: %v", err)
	}
	// The caller of PostForm owns the response body and must close it.
	// Otherwise the underlying connection cannot be reused or released.
	defer rawResp.Body.Close()
	// Cap the read at 1 MiB so an unexpected response cannot exhaust memory.
	body, err := ioutil.ReadAll(io.LimitReader(rawResp.Body, 1<<20))
	if err != nil {
		return nil, fmt.Errorf("impersonate: unable to read body: %v", err)
	}
	if c := rawResp.StatusCode; c < 200 || c > 299 {
		return nil, fmt.Errorf("impersonate: status code %d: %s", c, body)
	}
	var tokenResp exchangeTokenResponse
	if err := json.Unmarshal(body, &tokenResp); err != nil {
		return nil, fmt.Errorf("impersonate: unable to parse response: %v", err)
	}
	return &oauth2.Token{
		AccessToken: tokenResp.AccessToken,
		TokenType:   tokenResp.TokenType,
		Expiry:      now.Add(time.Second * time.Duration(tokenResp.ExpiresIn)),
	}, nil
}