| // Copyright 2023 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| package auth |
| |
| import ( |
| "context" |
| "fmt" |
| "io" |
| "net/http" |
| "net/http/httptest" |
| "net/url" |
| "testing" |
| "time" |
| ) |
| |
// day is the length of one 24-hour day; the expiry tests use it both as the
// advertised token lifetime and as the tolerance window for the parsed expiry.
const day = time.Hour * 24
| |
| func newOpts(url string) *Options3LO { |
| return &Options3LO{ |
| ClientID: "CLIENT_ID", |
| ClientSecret: "CLIENT_SECRET", |
| RedirectURL: "REDIRECT_URL", |
| Scopes: []string{"scope1", "scope2"}, |
| AuthURL: url + "/auth", |
| TokenURL: url + "/token", |
| AuthStyle: StyleInHeader, |
| RefreshToken: "OLD_REFRESH_TOKEN", |
| } |
| } |
| |
| func Test3LO_URLUnsafe(t *testing.T) { |
| ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| if got, want := r.Header.Get("Authorization"), "Basic Q0xJRU5UX0lEJTNGJTNGOkNMSUVOVF9TRUNSRVQlM0YlM0Y="; got != want { |
| t.Errorf("Authorization header = %q; want %q", got, want) |
| } |
| |
| w.Header().Set("Content-Type", "application/x-www-form-urlencoded") |
| w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer")) |
| })) |
| defer ts.Close() |
| conf := newOpts(ts.URL) |
| conf.ClientID = "CLIENT_ID??" |
| conf.ClientSecret = "CLIENT_SECRET??" |
| _, _, err := conf.exchange(context.Background(), "exchange-code") |
| if err != nil { |
| t.Error(err) |
| } |
| } |
| |
| func Test3LO_StandardExchange(t *testing.T) { |
| ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| if r.URL.String() != "/token" { |
| t.Errorf("Unexpected exchange request URL %q", r.URL) |
| } |
| headerAuth := r.Header.Get("Authorization") |
| if want := "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ="; headerAuth != want { |
| t.Errorf("Unexpected authorization header %q, want %q", headerAuth, want) |
| } |
| headerContentType := r.Header.Get("Content-Type") |
| if headerContentType != "application/x-www-form-urlencoded" { |
| t.Errorf("Unexpected Content-Type header %q", headerContentType) |
| } |
| body, err := io.ReadAll(r.Body) |
| if err != nil { |
| t.Errorf("Failed reading request body: %s.", err) |
| } |
| if string(body) != "code=exchange-code&grant_type=authorization_code&redirect_uri=REDIRECT_URL" { |
| t.Errorf("Unexpected exchange payload; got %q", body) |
| } |
| w.Header().Set("Content-Type", "application/x-www-form-urlencoded") |
| w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer")) |
| })) |
| defer ts.Close() |
| conf := newOpts(ts.URL) |
| tok, _, err := conf.exchange(context.Background(), "exchange-code") |
| if err != nil { |
| t.Error(err) |
| } |
| if !tok.IsValid() { |
| t.Fatalf("Token invalid. Got: %#v", tok) |
| } |
| if tok.Value != "90d64460d14870c08c81352a05dedd3465940a7c" { |
| t.Errorf("Unexpected access token, %#v.", tok.Value) |
| } |
| if tok.Type != "bearer" { |
| t.Errorf("Unexpected token type, %#v.", tok.Type) |
| } |
| scope := tok.Metadata["scope"].([]string) |
| if scope[0] != "user" { |
| t.Errorf("Unexpected value for scope: %v", scope) |
| } |
| } |
| |
| func Test3LO_ExchangeCustomParams(t *testing.T) { |
| ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| if r.URL.String() != "/token" { |
| t.Errorf("Unexpected exchange request URL, %v is found.", r.URL) |
| } |
| headerAuth := r.Header.Get("Authorization") |
| if headerAuth != "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=" { |
| t.Errorf("Unexpected authorization header, %v is found.", headerAuth) |
| } |
| headerContentType := r.Header.Get("Content-Type") |
| if headerContentType != "application/x-www-form-urlencoded" { |
| t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType) |
| } |
| body, err := io.ReadAll(r.Body) |
| if err != nil { |
| t.Errorf("Failed reading request body: %s.", err) |
| } |
| if string(body) != "code=exchange-code&foo=bar&grant_type=authorization_code&redirect_uri=REDIRECT_URL" { |
| t.Errorf("Unexpected exchange payload, %v is found.", string(body)) |
| } |
| w.Header().Set("Content-Type", "application/x-www-form-urlencoded") |
| w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer")) |
| })) |
| defer ts.Close() |
| conf := newOpts(ts.URL) |
| conf.URLParams = url.Values{} |
| conf.URLParams.Set("foo", "bar") |
| |
| tok, _, err := conf.exchange(context.Background(), "exchange-code") |
| if err != nil { |
| t.Error(err) |
| } |
| if !tok.IsValid() { |
| t.Fatalf("Token invalid. Got: %#v", tok) |
| } |
| if tok.Value != "90d64460d14870c08c81352a05dedd3465940a7c" { |
| t.Errorf("Unexpected access token, %#v.", tok.Value) |
| } |
| if tok.Type != "bearer" { |
| t.Errorf("Unexpected token type, %#v.", tok.Type) |
| } |
| scope := tok.Metadata["scope"].([]string) |
| if scope[0] != "user" { |
| t.Errorf("Unexpected value for scope: %v", scope) |
| } |
| } |
| |
| func Test3LO_ExchangeJSONResponse(t *testing.T) { |
| ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| if r.URL.String() != "/token" { |
| t.Errorf("Unexpected exchange request URL, %v is found.", r.URL) |
| } |
| headerAuth := r.Header.Get("Authorization") |
| if headerAuth != "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=" { |
| t.Errorf("Unexpected authorization header, %v is found.", headerAuth) |
| } |
| headerContentType := r.Header.Get("Content-Type") |
| if headerContentType != "application/x-www-form-urlencoded" { |
| t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType) |
| } |
| body, err := io.ReadAll(r.Body) |
| if err != nil { |
| t.Errorf("Failed reading request body: %s.", err) |
| } |
| if string(body) != "code=exchange-code&grant_type=authorization_code&redirect_uri=REDIRECT_URL" { |
| t.Errorf("Unexpected exchange payload, %v is found.", string(body)) |
| } |
| w.Header().Set("Content-Type", "application/json") |
| w.Write([]byte(`{"access_token": "90d64460d14870c08c81352a05dedd3465940a7c", "scope": "user", "token_type": "bearer", "expires_in": 86400}`)) |
| })) |
| defer ts.Close() |
| conf := newOpts(ts.URL) |
| tok, _, err := conf.exchange(context.Background(), "exchange-code") |
| if err != nil { |
| t.Error(err) |
| } |
| if !tok.IsValid() { |
| t.Fatalf("Token invalid. Got: %#v", tok) |
| } |
| if tok.Value != "90d64460d14870c08c81352a05dedd3465940a7c" { |
| t.Errorf("Unexpected access token, %#v.", tok.Value) |
| } |
| if tok.Type != "bearer" { |
| t.Errorf("Unexpected token type, %#v.", tok.Type) |
| } |
| scope := tok.Metadata["scope"].(string) |
| if scope != "user" { |
| t.Errorf("Unexpected value for scope: %v", scope) |
| } |
| expiresIn := tok.Metadata["expires_in"] |
| if expiresIn != float64(86400) { |
| t.Errorf("Unexpected non-numeric value for expires_in: %v", expiresIn) |
| } |
| } |
| |
| func Test3LO_ExchangeJSONResponseExpiry(t *testing.T) { |
| seconds := int32(day.Seconds()) |
| for _, c := range []struct { |
| name string |
| expires string |
| want bool |
| nullExpires bool |
| }{ |
| {"normal", fmt.Sprintf(`"expires_in": %d`, seconds), true, false}, |
| {"null", `"expires_in": null`, true, true}, |
| {"wrong_type", `"expires_in": false`, false, false}, |
| {"wrong_type2", `"expires_in": {}`, false, false}, |
| {"wrong_value", `"expires_in": "zzz"`, false, false}, |
| } { |
| t.Run(c.name, func(t *testing.T) { |
| test3LOExchangeJSONResponseExpiry(t, c.expires, c.want, c.nullExpires) |
| }) |
| } |
| } |
| |
| func test3LOExchangeJSONResponseExpiry(t *testing.T, exp string, want, nullExpires bool) { |
| ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| w.Header().Set("Content-Type", "application/json") |
| w.Write([]byte(fmt.Sprintf(`{"access_token": "90d", "scope": "user", "token_type": "bearer", %s}`, exp))) |
| })) |
| defer ts.Close() |
| conf := newOpts(ts.URL) |
| t1 := time.Now().Add(day) |
| tok, _, err := conf.exchange(context.Background(), "exchange-code") |
| t2 := t1.Add(day) |
| |
| if got := (err == nil); got != want { |
| if want { |
| t.Errorf("unexpected error: got %v", err) |
| } else { |
| t.Errorf("unexpected success") |
| } |
| } |
| if !want { |
| return |
| } |
| if !tok.IsValid() { |
| t.Fatalf("Token invalid. Got: %#v", tok) |
| } |
| expiry := tok.Expiry |
| |
| if nullExpires && expiry.IsZero() { |
| return |
| } |
| if expiry.Before(t1) || expiry.After(t2) { |
| t.Errorf("Unexpected value for Expiry: %v (should be between %v and %v)", expiry, t1, t2) |
| } |
| } |
| |
| func Test3LO_ExchangeBadResponse(t *testing.T) { |
| ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| w.Header().Set("Content-Type", "application/json") |
| w.Write([]byte(`{"scope": "user", "token_type": "bearer"}`)) |
| })) |
| defer ts.Close() |
| conf := newOpts(ts.URL) |
| _, _, err := conf.exchange(context.Background(), "code") |
| if err == nil { |
| t.Error("expected error from missing access_token") |
| } |
| } |
| |
| func Test3LO_ExchangeBadResponseType(t *testing.T) { |
| ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| w.Header().Set("Content-Type", "application/json") |
| w.Write([]byte(`{"access_token":123, "scope": "user", "token_type": "bearer"}`)) |
| })) |
| defer ts.Close() |
| conf := newOpts(ts.URL) |
| _, _, err := conf.exchange(context.Background(), "exchange-code") |
| if err == nil { |
| t.Error("expected error from non-string access_token") |
| } |
| } |
| |
| func Test3LO_RefreshTokenReplacement(t *testing.T) { |
| ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| w.Header().Set("Content-Type", "application/json") |
| w.Write([]byte(`{"access_token":"ACCESS_TOKEN", "scope": "user", "token_type": "bearer", "refresh_token": "NEW_REFRESH_TOKEN"}`)) |
| })) |
| defer ts.Close() |
| opts := newOpts(ts.URL) |
| tp, err := New3LOTokenProvider(opts) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if _, err := tp.Token(context.Background()); err != nil { |
| t.Errorf("got err = %v; want none", err) |
| return |
| } |
| innerTP := tp.(*cachedTokenProvider).tp.(*tokenProvider3LO) |
| if want := "NEW_REFRESH_TOKEN"; innerTP.refreshToken != want { |
| t.Errorf("RefreshToken = %q; want %q", innerTP.refreshToken, want) |
| } |
| } |
| |
| func Test3LO_RefreshTokenPreservation(t *testing.T) { |
| ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| w.Header().Set("Content-Type", "application/json") |
| w.Write([]byte(`{"access_token":"ACCESS_TOKEN", "scope": "user", "token_type": "bearer"}`)) |
| })) |
| defer ts.Close() |
| opts := newOpts(ts.URL) |
| const oldRefreshToken = "OLD_REFRESH_TOKEN" |
| tp, err := New3LOTokenProvider(opts) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if _, err := tp.Token(context.Background()); err != nil { |
| t.Errorf("got err = %v; want none", err) |
| return |
| } |
| innerTP := tp.(*cachedTokenProvider).tp.(*tokenProvider3LO) |
| if innerTP.refreshToken != oldRefreshToken { |
| t.Errorf("RefreshToken = %q; want %q", innerTP.refreshToken, oldRefreshToken) |
| } |
| } |
| |
| func Test3LO_AuthHandlerExchangeSuccess(t *testing.T) { |
| authhandler := func(authCodeURL string) (string, string, error) { |
| if authCodeURL == "testAuthCodeURL?client_id=testClientID&response_type=code&scope=pubsub&state=testState" { |
| return "testCode", "testState", nil |
| } |
| return "", "", fmt.Errorf("invalid authCodeURL: %q", authCodeURL) |
| } |
| |
| ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| r.ParseForm() |
| if r.Form.Get("code") == "testCode" { |
| w.Header().Set("Content-Type", "application/json") |
| w.Write([]byte(`{ |
| "access_token": "90d64460d14870c08c81352a05dedd3465940a7c", |
| "scope": "pubsub", |
| "token_type": "bearer", |
| "expires_in": 3600 |
| }`)) |
| } |
| })) |
| defer ts.Close() |
| |
| opts := &Options3LO{ |
| ClientID: "testClientID", |
| Scopes: []string{"pubsub"}, |
| AuthURL: "testAuthCodeURL", |
| TokenURL: ts.URL, |
| AuthStyle: StyleInHeader, |
| AuthHandlerOpts: &AuthorizationHandlerOptions{ |
| State: "testState", |
| Handler: authhandler, |
| }, |
| } |
| |
| tp, err := New3LOTokenProvider(opts) |
| if err != nil { |
| t.Fatal(err) |
| } |
| tok, err := tp.Token(context.Background()) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if !tok.IsValid() { |
| t.Errorf("got invalid token: %v", tok) |
| } |
| if got, want := tok.Value, "90d64460d14870c08c81352a05dedd3465940a7c"; got != want { |
| t.Errorf("access token = %q; want %q", got, want) |
| } |
| if got, want := tok.Type, "bearer"; got != want { |
| t.Errorf("token type = %q; want %q", got, want) |
| } |
| if got := tok.Expiry.IsZero(); got { |
| t.Errorf("token expiry is zero = %v, want false", got) |
| } |
| scope := tok.Metadata["scope"].(string) |
| if got, want := scope, "pubsub"; got != want { |
| t.Errorf("scope = %q; want %q", got, want) |
| } |
| } |
| |
| func Test3LO_AuthHandlerExchangeStateMismatch(t *testing.T) { |
| authhandler := func(authCodeURL string) (string, string, error) { |
| return "testCode", "testStateMismatch", nil |
| } |
| |
| ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| w.Header().Set("Content-Type", "application/json") |
| w.Write([]byte(`{ |
| "access_token": "90d64460d14870c08c81352a05dedd3465940a7c", |
| "scope": "pubsub", |
| "token_type": "bearer", |
| "expires_in": 3600 |
| }`)) |
| })) |
| defer ts.Close() |
| |
| opts := &Options3LO{ |
| ClientID: "testClientID", |
| Scopes: []string{"pubsub"}, |
| AuthURL: "testAuthCodeURL", |
| TokenURL: ts.URL, |
| AuthStyle: StyleInParams, |
| AuthHandlerOpts: &AuthorizationHandlerOptions{ |
| State: "testState", |
| Handler: authhandler, |
| }, |
| } |
| tp, err := New3LOTokenProvider(opts) |
| if err != nil { |
| t.Fatal(err) |
| } |
| _, err = tp.Token(context.Background()) |
| if wantErr := "auth: state mismatch in 3-legged-OAuth flow"; err == nil || err.Error() != wantErr { |
| t.Errorf("err = %q; want %q", err, wantErr) |
| } |
| } |
| |
| func Test3LO_PKCEExchangeWithSuccess(t *testing.T) { |
| authhandler := func(authCodeURL string) (string, string, error) { |
| if authCodeURL == "testAuthCodeURL?client_id=testClientID&code_challenge=codeChallenge&code_challenge_method=plain&response_type=code&scope=pubsub&state=testState" { |
| return "testCode", "testState", nil |
| } |
| return "", "", fmt.Errorf("invalid authCodeURL: %q", authCodeURL) |
| } |
| |
| ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| r.ParseForm() |
| if r.Form.Get("code") == "testCode" && r.Form.Get("code_verifier") == "codeChallenge" { |
| w.Header().Set("Content-Type", "application/json") |
| w.Write([]byte(`{ |
| "access_token": "90d64460d14870c08c81352a05dedd3465940a7c", |
| "scope": "pubsub", |
| "token_type": "bearer", |
| "expires_in": 3600 |
| }`)) |
| } |
| })) |
| defer ts.Close() |
| |
| opts := &Options3LO{ |
| ClientID: "testClientID", |
| Scopes: []string{"pubsub"}, |
| AuthURL: "testAuthCodeURL", |
| TokenURL: ts.URL, |
| AuthStyle: StyleInParams, |
| AuthHandlerOpts: &AuthorizationHandlerOptions{ |
| State: "testState", |
| Handler: authhandler, |
| PKCEOpts: &PKCEOptions{ |
| Challenge: "codeChallenge", |
| ChallengeMethod: "plain", |
| Verifier: "codeChallenge", |
| }, |
| }, |
| } |
| |
| tp, err := New3LOTokenProvider(opts) |
| if err != nil { |
| t.Fatal(err) |
| } |
| tok, err := tp.Token(context.Background()) |
| if err != nil { |
| t.Fatal(err) |
| } |
| if !tok.IsValid() { |
| t.Errorf("got invalid token: %v", tok) |
| } |
| if got, want := tok.Value, "90d64460d14870c08c81352a05dedd3465940a7c"; got != want { |
| t.Errorf("access token = %q; want %q", got, want) |
| } |
| if got, want := tok.Type, "bearer"; got != want { |
| t.Errorf("token type = %q; want %q", got, want) |
| } |
| if got := tok.Expiry.IsZero(); got { |
| t.Errorf("token expiry is zero = %v, want false", got) |
| } |
| scope := tok.Metadata["scope"].(string) |
| if got, want := scope, "pubsub"; got != want { |
| t.Errorf("scope = %q; want %q", got, want) |
| } |
| } |
| |
| func Test3LO_Validate(t *testing.T) { |
| tests := []struct { |
| name string |
| opts *Options3LO |
| }{ |
| { |
| name: "missing options", |
| }, |
| { |
| name: "missing client ID", |
| opts: &Options3LO{ |
| ClientSecret: "client_secret", |
| AuthURL: "auth_url", |
| TokenURL: "token_url", |
| AuthStyle: StyleInHeader, |
| RefreshToken: "refreshing", |
| }, |
| }, |
| { |
| name: "missing client secret", |
| opts: &Options3LO{ |
| ClientID: "client_id", |
| AuthURL: "auth_url", |
| TokenURL: "token_url", |
| AuthStyle: StyleInHeader, |
| RefreshToken: "refreshing", |
| }, |
| }, |
| { |
| name: "missing auth URL", |
| opts: &Options3LO{ |
| ClientID: "client_id", |
| ClientSecret: "client_secret", |
| TokenURL: "token_url", |
| AuthStyle: StyleInHeader, |
| RefreshToken: "refreshing", |
| }, |
| }, |
| { |
| name: "missing token URL", |
| opts: &Options3LO{ |
| ClientID: "client_id", |
| ClientSecret: "client_secret", |
| AuthURL: "auth_url", |
| AuthStyle: StyleInHeader, |
| RefreshToken: "refreshing", |
| }, |
| }, |
| { |
| name: "missing auth style", |
| opts: &Options3LO{ |
| ClientID: "client_id", |
| ClientSecret: "client_secret", |
| AuthURL: "auth_url", |
| TokenURL: "token_url", |
| RefreshToken: "refreshing", |
| }, |
| }, |
| { |
| name: "missing refresh token", |
| opts: &Options3LO{ |
| ClientID: "client_id", |
| ClientSecret: "client_secret", |
| AuthURL: "auth_url", |
| TokenURL: "token_url", |
| AuthStyle: StyleInHeader, |
| }, |
| }, |
| } |
| for _, tt := range tests { |
| t.Run(tt.name, func(t *testing.T) { |
| if _, err := New3LOTokenProvider(tt.opts); err == nil { |
| t.Error("got nil, want an error") |
| } |
| }) |
| } |
| } |