| /* |
| * |
| * Copyright 2016, Google Inc. |
| * All rights reserved. |
| * |
| * Redistribution and use in source and binary forms, with or without |
| * modification, are permitted provided that the following conditions are |
| * met: |
| * |
| * * Redistributions of source code must retain the above copyright |
| * notice, this list of conditions and the following disclaimer. |
| * * Redistributions in binary form must reproduce the above |
| * copyright notice, this list of conditions and the following disclaimer |
| * in the documentation and/or other materials provided with the |
| * distribution. |
| * * Neither the name of Google Inc. nor the names of its |
| * contributors may be used to endorse or promote products derived from |
| * this software without specific prior written permission. |
| * |
| * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS |
| * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT |
| * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR |
| * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT |
| * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, |
| * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT |
| * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, |
| * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY |
| * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT |
| * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
| * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
| * |
| */ |
| |
| package credentials |
| |
| import ( |
| "crypto/tls" |
| "net" |
| "testing" |
| |
| "golang.org/x/net/context" |
| ) |
| |
| func TestTLSOverrideServerName(t *testing.T) { |
| expectedServerName := "server.name" |
| c := NewTLS(nil) |
| c.OverrideServerName(expectedServerName) |
| if c.Info().ServerName != expectedServerName { |
| t.Fatalf("c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName) |
| } |
| } |
| |
| func TestTLSClone(t *testing.T) { |
| expectedServerName := "server.name" |
| c := NewTLS(nil) |
| c.OverrideServerName(expectedServerName) |
| cc := c.Clone() |
| if cc.Info().ServerName != expectedServerName { |
| t.Fatalf("cc.Info().ServerName = %v, want %v", cc.Info().ServerName, expectedServerName) |
| } |
| cc.OverrideServerName("") |
| if c.Info().ServerName != expectedServerName { |
| t.Fatalf("Change in clone should not affect the original, c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName) |
| } |
| |
| } |
| |
// tlsDir is the relative path prefix under which the handshake helpers in
// this file look up the server certificate and key (server1.pem, server1.key).
const tlsDir = "../test/testdata/"

// serverHandshake is the signature of a server-side handshake used by the
// tests: it takes the accepted connection and returns the resulting AuthInfo
// or an error.
type serverHandshake func(net.Conn) (AuthInfo, error)
| |
| func TestClientHandshakeReturnsAuthInfo(t *testing.T) { |
| done := make(chan AuthInfo, 1) |
| lis := launchServer(t, tlsServerHandshake, done) |
| defer lis.Close() |
| lisAddr := lis.Addr().String() |
| clientAuthInfo := clientHandle(t, gRPCClientHandshake, lisAddr) |
| // wait until server sends serverAuthInfo or fails. |
| serverAuthInfo, ok := <-done |
| if !ok { |
| t.Fatalf("Error at server-side") |
| } |
| if !compare(clientAuthInfo, serverAuthInfo) { |
| t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr, clientAuthInfo, serverAuthInfo) |
| } |
| } |
| |
| func TestServerHandshakeReturnsAuthInfo(t *testing.T) { |
| done := make(chan AuthInfo, 1) |
| lis := launchServer(t, gRPCServerHandshake, done) |
| defer lis.Close() |
| clientAuthInfo := clientHandle(t, tlsClientHandshake, lis.Addr().String()) |
| // wait until server sends serverAuthInfo or fails. |
| serverAuthInfo, ok := <-done |
| if !ok { |
| t.Fatalf("Error at server-side") |
| } |
| if !compare(clientAuthInfo, serverAuthInfo) { |
| t.Fatalf("ServerHandshake(_) = %v, want %v.", serverAuthInfo, clientAuthInfo) |
| } |
| } |
| |
| func TestServerAndClientHandshake(t *testing.T) { |
| done := make(chan AuthInfo, 1) |
| lis := launchServer(t, gRPCServerHandshake, done) |
| defer lis.Close() |
| clientAuthInfo := clientHandle(t, gRPCClientHandshake, lis.Addr().String()) |
| // wait until server sends serverAuthInfo or fails. |
| serverAuthInfo, ok := <-done |
| if !ok { |
| t.Fatalf("Error at server-side") |
| } |
| if !compare(clientAuthInfo, serverAuthInfo) { |
| t.Fatalf("AuthInfo returned by server: %v and client: %v aren't same", serverAuthInfo, clientAuthInfo) |
| } |
| } |
| |
| func compare(a1, a2 AuthInfo) bool { |
| if a1.AuthType() != a2.AuthType() { |
| return false |
| } |
| switch a1.AuthType() { |
| case "tls": |
| state1 := a1.(TLSInfo).State |
| state2 := a2.(TLSInfo).State |
| if state1.Version == state2.Version && |
| state1.HandshakeComplete == state2.HandshakeComplete && |
| state1.CipherSuite == state2.CipherSuite && |
| state1.NegotiatedProtocol == state2.NegotiatedProtocol { |
| return true |
| } |
| return false |
| default: |
| return false |
| } |
| } |
| |
| func launchServer(t *testing.T, hs serverHandshake, done chan AuthInfo) net.Listener { |
| lis, err := net.Listen("tcp", "localhost:0") |
| if err != nil { |
| t.Fatalf("Failed to listen: %v", err) |
| } |
| go serverHandle(t, hs, done, lis) |
| return lis |
| } |
| |
| // Is run in a seperate goroutine. |
| func serverHandle(t *testing.T, hs serverHandshake, done chan AuthInfo, lis net.Listener) { |
| serverRawConn, err := lis.Accept() |
| if err != nil { |
| t.Errorf("Server failed to accept connection: %v", err) |
| close(done) |
| return |
| } |
| serverAuthInfo, err := hs(serverRawConn) |
| if err != nil { |
| t.Errorf("Server failed while handshake. Error: %v", err) |
| serverRawConn.Close() |
| close(done) |
| return |
| } |
| done <- serverAuthInfo |
| } |
| |
| func clientHandle(t *testing.T, hs func(net.Conn, string) (AuthInfo, error), lisAddr string) AuthInfo { |
| conn, err := net.Dial("tcp", lisAddr) |
| if err != nil { |
| t.Fatalf("Client failed to connect to %s. Error: %v", lisAddr, err) |
| } |
| defer conn.Close() |
| clientAuthInfo, err := hs(conn, lisAddr) |
| if err != nil { |
| t.Fatalf("Error on client while handshake. Error: %v", err) |
| } |
| return clientAuthInfo |
| } |
| |
| // Server handshake implementation in gRPC. |
| func gRPCServerHandshake(conn net.Conn) (AuthInfo, error) { |
| serverTLS, err := NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key") |
| if err != nil { |
| return nil, err |
| } |
| _, serverAuthInfo, err := serverTLS.ServerHandshake(conn) |
| if err != nil { |
| return nil, err |
| } |
| return serverAuthInfo, nil |
| } |
| |
| // Client handshake implementation in gRPC. |
| func gRPCClientHandshake(conn net.Conn, lisAddr string) (AuthInfo, error) { |
| clientTLS := NewTLS(&tls.Config{InsecureSkipVerify: true}) |
| _, authInfo, err := clientTLS.ClientHandshake(context.Background(), lisAddr, conn) |
| if err != nil { |
| return nil, err |
| } |
| return authInfo, nil |
| } |
| |
| func tlsServerHandshake(conn net.Conn) (AuthInfo, error) { |
| cert, err := tls.LoadX509KeyPair(tlsDir+"server1.pem", tlsDir+"server1.key") |
| if err != nil { |
| return nil, err |
| } |
| serverTLSConfig := &tls.Config{Certificates: []tls.Certificate{cert}} |
| serverConn := tls.Server(conn, serverTLSConfig) |
| err = serverConn.Handshake() |
| if err != nil { |
| return nil, err |
| } |
| return TLSInfo{State: serverConn.ConnectionState()}, nil |
| } |
| |
| func tlsClientHandshake(conn net.Conn, _ string) (AuthInfo, error) { |
| clientTLSConfig := &tls.Config{InsecureSkipVerify: true} |
| clientConn := tls.Client(conn, clientTLSConfig) |
| if err := clientConn.Handshake(); err != nil { |
| return nil, err |
| } |
| return TLSInfo{State: clientConn.ConnectionState()}, nil |
| } |