| // Copyright 2021 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package impersonate |
| |
| import ( |
| "bytes" |
| "context" |
| "encoding/json" |
| "fmt" |
| "io" |
| "io/ioutil" |
| "net/http" |
| "net/url" |
| "strings" |
| "time" |
| |
| "golang.org/x/oauth2" |
| ) |
| |
| func user(ctx context.Context, c CredentialsConfig, client *http.Client, lifetime time.Duration, isStaticToken bool) (oauth2.TokenSource, error) { |
| u := userTokenSource{ |
| client: client, |
| targetPrincipal: c.TargetPrincipal, |
| subject: c.Subject, |
| lifetime: lifetime, |
| } |
| u.delegates = make([]string, len(c.Delegates)) |
| for i, v := range c.Delegates { |
| u.delegates[i] = formatIAMServiceAccountName(v) |
| } |
| u.scopes = make([]string, len(c.Scopes)) |
| copy(u.scopes, c.Scopes) |
| if isStaticToken { |
| tok, err := u.Token() |
| if err != nil { |
| return nil, err |
| } |
| return oauth2.StaticTokenSource(tok), nil |
| } |
| return oauth2.ReuseTokenSource(nil, u), nil |
| } |
| |
| type claimSet struct { |
| Iss string `json:"iss"` |
| Scope string `json:"scope,omitempty"` |
| Sub string `json:"sub,omitempty"` |
| Aud string `json:"aud"` |
| Iat int64 `json:"iat"` |
| Exp int64 `json:"exp"` |
| } |
| |
| type signJWTRequest struct { |
| Payload string `json:"payload"` |
| Delegates []string `json:"delegates,omitempty"` |
| } |
| |
| type signJWTResponse struct { |
| // KeyID is the key used to sign the JWT. |
| KeyID string `json:"keyId"` |
| // SignedJwt contains the automatically generated header; the |
| // client-supplied payload; and the signature, which is generated using |
| // the key referenced by the `kid` field in the header. |
| SignedJWT string `json:"signedJwt"` |
| } |
| |
| type exchangeTokenResponse struct { |
| AccessToken string `json:"access_token"` |
| TokenType string `json:"token_type"` |
| ExpiresIn int64 `json:"expires_in"` |
| } |
| |
| type userTokenSource struct { |
| client *http.Client |
| |
| targetPrincipal string |
| subject string |
| scopes []string |
| lifetime time.Duration |
| delegates []string |
| } |
| |
| func (u userTokenSource) Token() (*oauth2.Token, error) { |
| signedJWT, err := u.signJWT() |
| if err != nil { |
| return nil, err |
| } |
| return u.exchangeToken(signedJWT) |
| } |
| |
| func (u userTokenSource) signJWT() (string, error) { |
| now := time.Now() |
| exp := now.Add(u.lifetime) |
| claims := claimSet{ |
| Iss: u.targetPrincipal, |
| Scope: strings.Join(u.scopes, " "), |
| Sub: u.subject, |
| Aud: fmt.Sprintf("%s/token", oauth2Endpoint), |
| Iat: now.Unix(), |
| Exp: exp.Unix(), |
| } |
| payloadBytes, err := json.Marshal(claims) |
| if err != nil { |
| return "", fmt.Errorf("impersonate: unable to marshal claims: %v", err) |
| } |
| signJWTReq := signJWTRequest{ |
| Payload: string(payloadBytes), |
| Delegates: u.delegates, |
| } |
| |
| bodyBytes, err := json.Marshal(signJWTReq) |
| if err != nil { |
| return "", fmt.Errorf("impersonate: unable to marshal request: %v", err) |
| } |
| reqURL := fmt.Sprintf("%s/v1/%s:signJwt", iamCredentailsEndpoint, formatIAMServiceAccountName(u.targetPrincipal)) |
| req, err := http.NewRequest("POST", reqURL, bytes.NewReader(bodyBytes)) |
| if err != nil { |
| return "", fmt.Errorf("impersonate: unable to create request: %v", err) |
| } |
| req.Header.Set("Content-Type", "application/json") |
| rawResp, err := u.client.Do(req) |
| if err != nil { |
| return "", fmt.Errorf("impersonate: unable to sign JWT: %v", err) |
| } |
| body, err := ioutil.ReadAll(io.LimitReader(rawResp.Body, 1<<20)) |
| if err != nil { |
| return "", fmt.Errorf("impersonate: unable to read body: %v", err) |
| } |
| if c := rawResp.StatusCode; c < 200 || c > 299 { |
| return "", fmt.Errorf("impersonate: status code %d: %s", c, body) |
| } |
| |
| var signJWTResp signJWTResponse |
| if err := json.Unmarshal(body, &signJWTResp); err != nil { |
| return "", fmt.Errorf("impersonate: unable to parse response: %v", err) |
| } |
| return signJWTResp.SignedJWT, nil |
| } |
| |
| func (u userTokenSource) exchangeToken(signedJWT string) (*oauth2.Token, error) { |
| now := time.Now() |
| v := url.Values{} |
| v.Set("grant_type", "assertion") |
| v.Set("assertion_type", "http://oauth.net/grant_type/jwt/1.0/bearer") |
| v.Set("assertion", signedJWT) |
| rawResp, err := u.client.PostForm(fmt.Sprintf("%s/token", oauth2Endpoint), v) |
| if err != nil { |
| return nil, fmt.Errorf("impersonate: unable to exchange token: %v", err) |
| } |
| body, err := ioutil.ReadAll(io.LimitReader(rawResp.Body, 1<<20)) |
| if err != nil { |
| return nil, fmt.Errorf("impersonate: unable to read body: %v", err) |
| } |
| if c := rawResp.StatusCode; c < 200 || c > 299 { |
| return nil, fmt.Errorf("impersonate: status code %d: %s", c, body) |
| } |
| |
| var tokenResp exchangeTokenResponse |
| if err := json.Unmarshal(body, &tokenResp); err != nil { |
| return nil, fmt.Errorf("impersonate: unable to parse response: %v", err) |
| } |
| |
| return &oauth2.Token{ |
| AccessToken: tokenResp.AccessToken, |
| TokenType: tokenResp.TokenType, |
| Expiry: now.Add(time.Second * time.Duration(tokenResp.ExpiresIn)), |
| }, nil |
| } |