| /* |
| * |
| * Copyright 2016 gRPC authors. |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| * |
| */ |
| |
| package credentials |
| |
| import ( |
| "context" |
| "crypto/tls" |
| "net" |
| "strings" |
| "testing" |
| |
| "google.golang.org/grpc/internal" |
| "google.golang.org/grpc/internal/grpctest" |
| "google.golang.org/grpc/testdata" |
| ) |
| |
| type s struct { |
| grpctest.Tester |
| } |
| |
| func Test(t *testing.T) { |
| grpctest.RunSubTests(t, s{}) |
| } |
| |
| // A struct that implements AuthInfo interface but does not implement GetCommonAuthInfo() method. |
| type testAuthInfoNoGetCommonAuthInfoMethod struct{} |
| |
| func (ta testAuthInfoNoGetCommonAuthInfoMethod) AuthType() string { |
| return "testAuthInfoNoGetCommonAuthInfoMethod" |
| } |
| |
| // A struct that implements AuthInfo interface and implements CommonAuthInfo() method. |
| type testAuthInfo struct { |
| CommonAuthInfo |
| } |
| |
| func (ta testAuthInfo) AuthType() string { |
| return "testAuthInfo" |
| } |
| |
| func createTestContext(s SecurityLevel) context.Context { |
| auth := &testAuthInfo{CommonAuthInfo: CommonAuthInfo{SecurityLevel: s}} |
| ri := RequestInfo{ |
| Method: "testInfo", |
| AuthInfo: auth, |
| } |
| return internal.NewRequestInfoContext.(func(context.Context, RequestInfo) context.Context)(context.Background(), ri) |
| } |
| |
| func (s) TestCheckSecurityLevel(t *testing.T) { |
| testCases := []struct { |
| authLevel SecurityLevel |
| testLevel SecurityLevel |
| want bool |
| }{ |
| { |
| authLevel: PrivacyAndIntegrity, |
| testLevel: PrivacyAndIntegrity, |
| want: true, |
| }, |
| { |
| authLevel: IntegrityOnly, |
| testLevel: PrivacyAndIntegrity, |
| want: false, |
| }, |
| { |
| authLevel: IntegrityOnly, |
| testLevel: NoSecurity, |
| want: true, |
| }, |
| { |
| authLevel: Invalid, |
| testLevel: IntegrityOnly, |
| want: true, |
| }, |
| { |
| authLevel: Invalid, |
| testLevel: PrivacyAndIntegrity, |
| want: true, |
| }, |
| } |
| for _, tc := range testCases { |
| err := CheckSecurityLevel(createTestContext(tc.authLevel), tc.testLevel) |
| if tc.want && (err != nil) { |
| t.Fatalf("CheckSeurityLevel(%s, %s) returned failure but want success", tc.authLevel.String(), tc.testLevel.String()) |
| } else if !tc.want && (err == nil) { |
| t.Fatalf("CheckSeurityLevel(%s, %s) returned success but want failure", tc.authLevel.String(), tc.testLevel.String()) |
| |
| } |
| } |
| } |
| |
| func (s) TestCheckSecurityLevelNoGetCommonAuthInfoMethod(t *testing.T) { |
| auth := &testAuthInfoNoGetCommonAuthInfoMethod{} |
| ri := RequestInfo{ |
| Method: "testInfo", |
| AuthInfo: auth, |
| } |
| ctxWithRequestInfo := internal.NewRequestInfoContext.(func(context.Context, RequestInfo) context.Context)(context.Background(), ri) |
| if err := CheckSecurityLevel(ctxWithRequestInfo, PrivacyAndIntegrity); err != nil { |
| t.Fatalf("CheckSeurityLevel() returned failure but want success") |
| } |
| } |
| |
| func (s) TestTLSOverrideServerName(t *testing.T) { |
| expectedServerName := "server.name" |
| c := NewTLS(nil) |
| c.OverrideServerName(expectedServerName) |
| if c.Info().ServerName != expectedServerName { |
| t.Fatalf("c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName) |
| } |
| } |
| |
| func (s) TestTLSClone(t *testing.T) { |
| expectedServerName := "server.name" |
| c := NewTLS(nil) |
| c.OverrideServerName(expectedServerName) |
| cc := c.Clone() |
| if cc.Info().ServerName != expectedServerName { |
| t.Fatalf("cc.Info().ServerName = %v, want %v", cc.Info().ServerName, expectedServerName) |
| } |
| cc.OverrideServerName("") |
| if c.Info().ServerName != expectedServerName { |
| t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName) |
| } |
| |
| } |
| |
| type serverHandshake func(net.Conn) (AuthInfo, error) |
| |
| func (s) TestClientHandshakeReturnsAuthInfo(t *testing.T) { |
| tcs := []struct { |
| name string |
| address string |
| }{ |
| { |
| name: "localhost", |
| address: "localhost:0", |
| }, |
| { |
| name: "ipv4", |
| address: "127.0.0.1:0", |
| }, |
| { |
| name: "ipv6", |
| address: "[::1]:0", |
| }, |
| } |
| |
| for _, tc := range tcs { |
| t.Run(tc.name, func(t *testing.T) { |
| done := make(chan AuthInfo, 1) |
| lis := launchServerOnListenAddress(t, tlsServerHandshake, done, tc.address) |
| defer lis.Close() |
| lisAddr := lis.Addr().String() |
| clientAuthInfo := clientHandle(t, gRPCClientHandshake, lisAddr) |
| // wait until server sends serverAuthInfo or fails. |
| serverAuthInfo, ok := <-done |
| if !ok { |
| t.Fatalf("Error at server-side") |
| } |
| if !compare(clientAuthInfo, serverAuthInfo) { |
| t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr, clientAuthInfo, serverAuthInfo) |
| } |
| }) |
| } |
| } |
| |
| func (s) TestServerHandshakeReturnsAuthInfo(t *testing.T) { |
| done := make(chan AuthInfo, 1) |
| lis := launchServer(t, gRPCServerHandshake, done) |
| defer lis.Close() |
| clientAuthInfo := clientHandle(t, tlsClientHandshake, lis.Addr().String()) |
| // wait until server sends serverAuthInfo or fails. |
| serverAuthInfo, ok := <-done |
| if !ok { |
| t.Fatalf("Error at server-side") |
| } |
| if !compare(clientAuthInfo, serverAuthInfo) { |
| t.Fatalf("ServerHandshake(_) = %v, want %v.", serverAuthInfo, clientAuthInfo) |
| } |
| } |
| |
| func (s) TestServerAndClientHandshake(t *testing.T) { |
| done := make(chan AuthInfo, 1) |
| lis := launchServer(t, gRPCServerHandshake, done) |
| defer lis.Close() |
| clientAuthInfo := clientHandle(t, gRPCClientHandshake, lis.Addr().String()) |
| // wait until server sends serverAuthInfo or fails. |
| serverAuthInfo, ok := <-done |
| if !ok { |
| t.Fatalf("Error at server-side") |
| } |
| if !compare(clientAuthInfo, serverAuthInfo) { |
| t.Fatalf("AuthInfo returned by server: %v and client: %v aren't same", serverAuthInfo, clientAuthInfo) |
| } |
| } |
| |
| func compare(a1, a2 AuthInfo) bool { |
| if a1.AuthType() != a2.AuthType() { |
| return false |
| } |
| switch a1.AuthType() { |
| case "tls": |
| state1 := a1.(TLSInfo).State |
| state2 := a2.(TLSInfo).State |
| if state1.Version == state2.Version && |
| state1.HandshakeComplete == state2.HandshakeComplete && |
| state1.CipherSuite == state2.CipherSuite && |
| state1.NegotiatedProtocol == state2.NegotiatedProtocol { |
| return true |
| } |
| return false |
| default: |
| return false |
| } |
| } |
| |
| func launchServer(t *testing.T, hs serverHandshake, done chan AuthInfo) net.Listener { |
| return launchServerOnListenAddress(t, hs, done, "localhost:0") |
| } |
| |
| func launchServerOnListenAddress(t *testing.T, hs serverHandshake, done chan AuthInfo, address string) net.Listener { |
| lis, err := net.Listen("tcp", address) |
| if err != nil { |
| if strings.Contains(err.Error(), "bind: cannot assign requested address") || |
| strings.Contains(err.Error(), "socket: address family not supported by protocol") { |
| t.Skipf("no support for address %v", address) |
| } |
| t.Fatalf("Failed to listen: %v", err) |
| } |
| go serverHandle(t, hs, done, lis) |
| return lis |
| } |
| |
| // Is run in a separate goroutine. |
| func serverHandle(t *testing.T, hs serverHandshake, done chan AuthInfo, lis net.Listener) { |
| serverRawConn, err := lis.Accept() |
| if err != nil { |
| t.Errorf("Server failed to accept connection: %v", err) |
| close(done) |
| return |
| } |
| serverAuthInfo, err := hs(serverRawConn) |
| if err != nil { |
| t.Errorf("Server failed while handshake. Error: %v", err) |
| serverRawConn.Close() |
| close(done) |
| return |
| } |
| done <- serverAuthInfo |
| } |
| |
| func clientHandle(t *testing.T, hs func(net.Conn, string) (AuthInfo, error), lisAddr string) AuthInfo { |
| conn, err := net.Dial("tcp", lisAddr) |
| if err != nil { |
| t.Fatalf("Client failed to connect to %s. Error: %v", lisAddr, err) |
| } |
| defer conn.Close() |
| clientAuthInfo, err := hs(conn, lisAddr) |
| if err != nil { |
| t.Fatalf("Error on client while handshake. Error: %v", err) |
| } |
| return clientAuthInfo |
| } |
| |
| // Server handshake implementation in gRPC. |
| func gRPCServerHandshake(conn net.Conn) (AuthInfo, error) { |
| serverTLS, err := NewServerTLSFromFile(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) |
| if err != nil { |
| return nil, err |
| } |
| _, serverAuthInfo, err := serverTLS.ServerHandshake(conn) |
| if err != nil { |
| return nil, err |
| } |
| return serverAuthInfo, nil |
| } |
| |
| // Client handshake implementation in gRPC. |
| func gRPCClientHandshake(conn net.Conn, lisAddr string) (AuthInfo, error) { |
| clientTLS := NewTLS(&tls.Config{InsecureSkipVerify: true}) |
| _, authInfo, err := clientTLS.ClientHandshake(context.Background(), lisAddr, conn) |
| if err != nil { |
| return nil, err |
| } |
| return authInfo, nil |
| } |
| |
| func tlsServerHandshake(conn net.Conn) (AuthInfo, error) { |
| cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem")) |
| if err != nil { |
| return nil, err |
| } |
| serverTLSConfig := &tls.Config{Certificates: []tls.Certificate{cert}} |
| serverConn := tls.Server(conn, serverTLSConfig) |
| err = serverConn.Handshake() |
| if err != nil { |
| return nil, err |
| } |
| return TLSInfo{State: serverConn.ConnectionState(), CommonAuthInfo: CommonAuthInfo{SecurityLevel: PrivacyAndIntegrity}}, nil |
| } |
| |
| func tlsClientHandshake(conn net.Conn, _ string) (AuthInfo, error) { |
| clientTLSConfig := &tls.Config{InsecureSkipVerify: true} |
| clientConn := tls.Client(conn, clientTLSConfig) |
| if err := clientConn.Handshake(); err != nil { |
| return nil, err |
| } |
| return TLSInfo{State: clientConn.ConnectionState(), CommonAuthInfo: CommonAuthInfo{SecurityLevel: PrivacyAndIntegrity}}, nil |
| } |