| // Copyright 2023 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| package oauth2adapt |
| |
| import ( |
| "context" |
| "errors" |
| "net/http" |
| "testing" |
| |
| "cloud.google.com/go/auth" |
| "github.com/google/go-cmp/cmp" |
| "golang.org/x/oauth2" |
| "golang.org/x/oauth2/google" |
| ) |
| |
| func TestTokenProviderFromTokenSource(t *testing.T) { |
| tests := []struct { |
| name string |
| token *oauth2.Token |
| err error |
| }{ |
| { |
| name: "working token", |
| token: &oauth2.Token{AccessToken: "fakeToken", TokenType: "Basic"}, |
| err: nil, |
| }, |
| { |
| name: "coverts err", |
| err: &oauth2.RetrieveError{ |
| Body: []byte("some bytes"), |
| ErrorCode: "412", |
| Response: &http.Response{ |
| StatusCode: http.StatusTeapot, |
| }, |
| }, |
| }, |
| } |
| for _, tt := range tests { |
| t.Run(tt.name, func(t *testing.T) { |
| tp := TokenProviderFromTokenSource(tokenSource{ |
| token: tt.token, |
| err: tt.err, |
| }) |
| tok, err := tp.Token(context.Background()) |
| if tt.err != nil { |
| aErr := &auth.Error{} |
| if !errors.As(err, &aErr) { |
| t.Fatalf("error not of correct type: %T", err) |
| } |
| err := tt.err.(*oauth2.RetrieveError) |
| if !cmp.Equal(aErr.Body, err.Body) { |
| t.Errorf("got %s, want %s", aErr.Body, err.Body) |
| } |
| if !cmp.Equal(aErr.Err, err) { |
| t.Errorf("got %s, want %s", aErr.Err, err) |
| } |
| if !cmp.Equal(aErr.Response, err.Response) { |
| t.Errorf("got %s, want %s", aErr.Err, err) |
| } |
| return |
| } |
| if tok.Value != tt.token.AccessToken { |
| t.Errorf("got %q, want %q", tok.Value, tt.token.AccessToken) |
| } |
| if tok.Type != tt.token.TokenType { |
| t.Errorf("got %q, want %q", tok.Type, tt.token.TokenType) |
| } |
| }) |
| } |
| } |
| |
| func TestTokenSourceFromTokenProvider(t *testing.T) { |
| tests := []struct { |
| name string |
| token *auth.Token |
| err error |
| }{ |
| { |
| name: "working token", |
| token: &auth.Token{ |
| Value: "fakeToken", |
| Type: "Basic", |
| }, |
| err: nil, |
| }, |
| { |
| name: "coverts err", |
| err: &auth.Error{ |
| Body: []byte("some bytes"), |
| Response: &http.Response{ |
| StatusCode: http.StatusTeapot, |
| }, |
| }, |
| }, |
| } |
| for _, tt := range tests { |
| t.Run(tt.name, func(t *testing.T) { |
| ts := TokenSourceFromTokenProvider(tokenProvider{ |
| token: tt.token, |
| err: tt.err, |
| }) |
| tok, err := ts.Token() |
| if tt.err != nil { |
| // Should be able to be an auth.Error |
| aErr := &auth.Error{} |
| if !errors.As(err, &aErr) { |
| t.Fatalf("error not of correct type: %T", err) |
| } |
| err := tt.err.(*auth.Error) |
| if !cmp.Equal(aErr.Body, err.Body) { |
| t.Errorf("got %s, want %s", aErr.Body, err.Body) |
| } |
| if !cmp.Equal(aErr.Response, err.Response) { |
| t.Errorf("got %s, want %s", aErr.Err, err) |
| } |
| |
| // Should be able to be an oauth2.RetrieveError |
| rErr := &oauth2.RetrieveError{} |
| if !errors.As(err, &rErr) { |
| t.Fatalf("error not of correct type: %T", err) |
| } |
| if !cmp.Equal(rErr.Body, err.Body) { |
| t.Errorf("got %s, want %s", aErr.Body, err.Body) |
| } |
| if !cmp.Equal(rErr.Response, err.Response) { |
| t.Errorf("got %s, want %s", aErr.Err, err) |
| } |
| return |
| } |
| if tok.AccessToken != tt.token.Value { |
| t.Errorf("got %q, want %q", tok.AccessToken, tt.token.Value) |
| } |
| if tok.TokenType != tt.token.Type { |
| t.Errorf("got %q, want %q", tok.TokenType, tt.token.Type) |
| } |
| }) |
| } |
| } |
| |
| func TestAuthCredentialsFromOauth2Credentials(t *testing.T) { |
| ctx := context.Background() |
| inputCreds := &google.Credentials{ |
| ProjectID: "test_project", |
| TokenSource: tokenSource{token: &oauth2.Token{AccessToken: "token"}}, |
| JSON: []byte("json"), |
| UniverseDomainProvider: func() (string, error) { |
| return "domain", nil |
| }, |
| } |
| outCreds := AuthCredentialsFromOauth2Credentials(inputCreds) |
| |
| gotProject, err := outCreds.ProjectID(ctx) |
| if err != nil { |
| t.Fatalf("outCreds.ProjectID() = %v", err) |
| } |
| if want := inputCreds.ProjectID; gotProject != want { |
| t.Fatalf("got %q, want %q", gotProject, want) |
| } |
| |
| gotToken, err := outCreds.Token(ctx) |
| if err != nil { |
| t.Fatalf("outCreds.Token() = %v", err) |
| } |
| wantTok, err := inputCreds.TokenSource.Token() |
| if err != nil { |
| t.Fatalf("inputCreds.TokenSource.Token() = %v", err) |
| } |
| if gotToken.Value != wantTok.AccessToken { |
| t.Fatalf("got %q, want %q", gotToken.Value, wantTok.AccessToken) |
| } |
| |
| gotJSON := outCreds.JSON() |
| if want := inputCreds.JSON; !cmp.Equal(gotJSON, want) { |
| t.Fatalf("got %s, want %s", gotJSON, want) |
| } |
| |
| gotUD, err := outCreds.UniverseDomain(ctx) |
| if err != nil { |
| t.Fatalf("outCreds.UniverseDomain() = %v", err) |
| } |
| wantUD, err := inputCreds.GetUniverseDomain() |
| if err != nil { |
| t.Fatalf("inputCreds.GetUniverseDomain() = %v", err) |
| } |
| if gotUD != wantUD { |
| t.Fatalf("got %q, want %q", wantUD, wantUD) |
| } |
| } |
| |
| func TestOauth2CredentialsFromAuthCredentials(t *testing.T) { |
| ctx := context.Background() |
| inputCreds := auth.NewCredentials(&auth.CredentialsOptions{ |
| ProjectIDProvider: auth.CredentialsPropertyFunc(func(ctx context.Context) (string, error) { |
| return "project", nil |
| }), |
| TokenProvider: tokenProvider{token: &auth.Token{Value: "token"}}, |
| JSON: []byte("json"), |
| UniverseDomainProvider: auth.CredentialsPropertyFunc(func(ctx context.Context) (string, error) { |
| return "domain", nil |
| }), |
| }) |
| outCreds := Oauth2CredentialsFromAuthCredentials(inputCreds) |
| |
| wantProject, err := inputCreds.ProjectID(ctx) |
| if err != nil { |
| t.Fatalf("inputCreds.ProjectID() = %v", err) |
| } |
| if outCreds.ProjectID != wantProject { |
| t.Fatalf("got %q, want %q", outCreds.ProjectID, wantProject) |
| } |
| |
| gotToken, err := inputCreds.Token(ctx) |
| if err != nil { |
| t.Fatalf("inputCreds.Token() = %v", err) |
| } |
| wantTok, err := outCreds.TokenSource.Token() |
| if err != nil { |
| t.Fatalf("outCreds.TokenSource.Token() = %v", err) |
| } |
| if gotToken.Value != wantTok.AccessToken { |
| t.Fatalf("got %q, want %q", gotToken.Value, wantTok.AccessToken) |
| } |
| |
| wantJSON := inputCreds.JSON() |
| if !cmp.Equal(outCreds.JSON, wantJSON) { |
| t.Fatalf("got %s, want %s", outCreds.JSON, wantJSON) |
| } |
| |
| wantUD, err := inputCreds.UniverseDomain(ctx) |
| if err != nil { |
| t.Fatalf("outCreds.UniverseDomain() = %v", err) |
| } |
| gotUD, err := outCreds.GetUniverseDomain() |
| if err != nil { |
| t.Fatalf("inputCreds.GetUniverseDomain() = %v", err) |
| } |
| if gotUD != wantUD { |
| t.Fatalf("got %q, want %q", wantUD, wantUD) |
| } |
| } |
| |
| type tokenSource struct { |
| token *oauth2.Token |
| err error |
| } |
| |
| func (ts tokenSource) Token() (*oauth2.Token, error) { |
| if ts.err != nil { |
| return nil, ts.err |
| } |
| return &oauth2.Token{ |
| AccessToken: ts.token.AccessToken, |
| TokenType: ts.token.TokenType, |
| }, nil |
| } |
| |
| type tokenProvider struct { |
| token *auth.Token |
| err error |
| } |
| |
| func (tp tokenProvider) Token(context.Context) (*auth.Token, error) { |
| if tp.err != nil { |
| return nil, tp.err |
| } |
| return &auth.Token{ |
| Value: tp.token.Value, |
| Type: tp.token.Type, |
| }, nil |
| } |