| /* |
| * |
| * Copyright 2019 gRPC authors. |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| * |
| */ |
| |
| package advancedtls |
| |
| import ( |
| "context" |
| "crypto/tls" |
| "crypto/x509" |
| "encoding/pem" |
| "errors" |
| "fmt" |
| "io/ioutil" |
| "net" |
| "reflect" |
| "syscall" |
| "testing" |
| |
| "google.golang.org/grpc/credentials" |
| "google.golang.org/grpc/security/advancedtls/testdata" |
| ) |
| |
// TestClientServerHandshake runs a table of client/server credential
// configurations through a TLS handshake over a loopback TCP connection.
// For each case it checks whether client credential creation, the server side
// (credential creation plus ServerHandshake), and the client handshake succeed
// or fail as expected. When both sides succeed, the TLS parameters reported by
// each side's AuthInfo are compared with compare().
func TestClientServerHandshake(t *testing.T) {
	// ------------------Load Client Trust Cert and Peer Cert-------------------
	clientTrustPool, err := readTrustCert(testdata.Path("client_trust_cert_1.pem"))
	if err != nil {
		t.Fatalf("Client is unable to load trust certs. Error: %v", err)
	}
	// getRootCAsForClient always returns the client trust pool loaded above.
	getRootCAsForClient := func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
		return &GetRootCAsResults{TrustCerts: clientTrustPool}, nil
	}
	// clientVerifyFuncGood is the client-side custom check. It requires a
	// non-empty ServerName. When verified chains are present, it also requires
	// a leaf certificate whose common name matches the server test certificate.
	clientVerifyFuncGood := func(params *VerificationFuncParams) (*VerificationResults, error) {
		if params.ServerName == "" {
			return nil, errors.New("client side server name should have a value")
		}
		// "foo.bar.com" is the common name on server certificate server_cert_1.pem.
		if len(params.VerifiedChains) > 0 && (params.Leaf == nil || params.Leaf.Subject.CommonName != "foo.bar.com") {
			return nil, errors.New("client side params parsing error")
		}

		return &VerificationResults{}, nil
	}
	// verifyFuncBad always rejects the peer. It is used on either side to force
	// a custom-verification failure.
	verifyFuncBad := func(params *VerificationFuncParams) (*VerificationResults, error) {
		return nil, fmt.Errorf("custom verification function failed")
	}
	clientPeerCert, err := tls.LoadX509KeyPair(testdata.Path("client_cert_1.pem"),
		testdata.Path("client_key_1.pem"))
	if err != nil {
		t.Fatalf("Client is unable to parse peer certificates. Error: %v", err)
	}
	// ------------------Load Server Trust Cert and Peer Cert-------------------
	serverTrustPool, err := readTrustCert(testdata.Path("server_trust_cert_1.pem"))
	if err != nil {
		t.Fatalf("Server is unable to load trust certs. Error: %v", err)
	}
	// getRootCAsForServer always returns the server trust pool loaded above.
	getRootCAsForServer := func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
		return &GetRootCAsResults{TrustCerts: serverTrustPool}, nil
	}
	// serverVerifyFunc is the server-side custom check. It requires an empty
	// ServerName. When verified chains are present, it also requires a leaf
	// certificate whose common name matches the client test certificate.
	serverVerifyFunc := func(params *VerificationFuncParams) (*VerificationResults, error) {
		if params.ServerName != "" {
			return nil, errors.New("server side server name should not have a value")
		}
		// "foo.bar.hoo.com" is the common name on client certificate client_cert_1.pem.
		if len(params.VerifiedChains) > 0 && (params.Leaf == nil || params.Leaf.Subject.CommonName != "foo.bar.hoo.com") {
			return nil, errors.New("server side params parsing error")
		}

		return &VerificationResults{}, nil
	}
	serverPeerCert, err := tls.LoadX509KeyPair(testdata.Path("server_cert_1.pem"),
		testdata.Path("server_key_1.pem"))
	if err != nil {
		t.Fatalf("Server is unable to parse peer certificates. Error: %v", err)
	}
	// getRootCAsForServerBad always fails, simulating a root-reload error.
	getRootCAsForServerBad := func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
		return nil, fmt.Errorf("bad root certificate reloading")
	}
	for _, test := range []struct {
		desc                       string
		clientCert                 []tls.Certificate
		clientGetCert              func(*tls.CertificateRequestInfo) (*tls.Certificate, error)
		clientRoot                 *x509.CertPool
		clientGetRoot              func(params *GetRootCAsParams) (*GetRootCAsResults, error)
		clientVerifyFunc           CustomVerificationFunc
		clientVType                VerificationType
		clientExpectCreateError    bool
		clientExpectHandshakeError bool
		serverMutualTLS            bool
		serverCert                 []tls.Certificate
		serverGetCert              func(*tls.ClientHelloInfo) (*tls.Certificate, error)
		serverRoot                 *x509.CertPool
		serverGetRoot              func(params *GetRootCAsParams) (*GetRootCAsResults, error)
		serverVerifyFunc           CustomVerificationFunc
		serverVType                VerificationType
		serverExpectError          bool
	}{
		// Client: nil setting
		// Server: only set serverCert with mutual TLS off
		// Expected Behavior: client credential creation fails; the server side
		// also fails because no handshake takes place.
		// Reason: if clientRoot, clientGetRoot and verifyFunc are all unset, the
		// client side provides no verification mechanism. This is rejected even
		// when vType is SkipVerification; clients must at least provide their
		// own verification logic.
		{
			desc:                    "Client has no trust cert; server sends peer cert",
			clientVType:             SkipVerification,
			clientExpectCreateError: true,
			serverCert:              []tls.Certificate{serverPeerCert},
			serverVType:             CertAndHostVerification,
			serverExpectError:       true,
		},
		// Client: nil setting except verifyFuncGood
		// Server: only set serverCert with mutual TLS off
		// Expected Behavior: success
		// Reason: with no trust roots configured, clientVerifyFuncGood is the
		// client's verification mechanism, and it accepts the server. No client
		// certificate is needed because mutual TLS is off.
		{
			desc:             "Client has no trust cert with verifyFuncGood; server sends peer cert",
			clientVerifyFunc: clientVerifyFuncGood,
			clientVType:      SkipVerification,
			serverCert:       []tls.Certificate{serverPeerCert},
			serverVType:      CertAndHostVerification,
		},
		// Client: only set clientRoot
		// Server: only set serverCert with mutual TLS off
		// Expected Behavior: server side failure and client handshake failure
		// Reason: the client side sets vType to CertAndHostVerification and so
		// performs the default hostname check. That check fails in this test
		// suite (presumably because the dialed address, localhost, does not
		// match the server certificate's name).
		{
			desc:                       "Client has root cert; server sends peer cert",
			clientRoot:                 clientTrustPool,
			clientVType:                CertAndHostVerification,
			clientExpectHandshakeError: true,
			serverCert:                 []tls.Certificate{serverPeerCert},
			serverVType:                CertAndHostVerification,
			serverExpectError:          true,
		},
		// Client: only set clientGetRoot
		// Server: only set serverCert with mutual TLS off
		// Expected Behavior: server side failure and client handshake failure
		// Reason: the client side sets vType to CertAndHostVerification and so
		// performs the default hostname check, which fails in this test suite.
		{
			desc:                       "Client sets reload root function; server sends peer cert",
			clientGetRoot:              getRootCAsForClient,
			clientVType:                CertAndHostVerification,
			clientExpectHandshakeError: true,
			serverCert:                 []tls.Certificate{serverPeerCert},
			serverVType:                CertAndHostVerification,
			serverExpectError:          true,
		},
		// Client: set clientGetRoot and clientVerifyFunc
		// Server: only set serverCert with mutual TLS off
		// Expected Behavior: success
		{
			desc:             "Client sets reload root function with verifyFuncGood; server sends peer cert",
			clientGetRoot:    getRootCAsForClient,
			clientVerifyFunc: clientVerifyFuncGood,
			clientVType:      CertVerification,
			serverCert:       []tls.Certificate{serverPeerCert},
			serverVType:      CertAndHostVerification,
		},
		// Client: set clientGetRoot and bad clientVerifyFunc function
		// Server: only set serverCert with mutual TLS off
		// Expected Behavior: server side failure and client handshake failure
		// Reason: the custom verification function always fails
		{
			desc:                       "Client sets reload root function with verifyFuncBad; server sends peer cert",
			clientGetRoot:              getRootCAsForClient,
			clientVerifyFunc:           verifyFuncBad,
			clientVType:                CertVerification,
			clientExpectHandshakeError: true,
			serverCert:                 []tls.Certificate{serverPeerCert},
			serverVType:                CertVerification,
			serverExpectError:          true,
		},
		// Client: set clientGetRoot and clientVerifyFunc
		// Server: nil setting
		// Expected Behavior: server side failure
		// Reason: the server side must set either serverCert or serverGetCert
		{
			desc:              "Client sets reload root function with verifyFuncGood; server sets nil",
			clientGetRoot:     getRootCAsForClient,
			clientVerifyFunc:  clientVerifyFuncGood,
			clientVType:       CertVerification,
			serverVType:       CertVerification,
			serverExpectError: true,
		},
		// Client: set clientGetRoot, clientVerifyFunc and clientCert
		// Server: set serverRoot and serverCert with mutual TLS on
		// Expected Behavior: success
		{
			desc:             "Client sets peer cert, reload root function with verifyFuncGood; server sets peer cert and root cert; mutualTLS",
			clientCert:       []tls.Certificate{clientPeerCert},
			clientGetRoot:    getRootCAsForClient,
			clientVerifyFunc: clientVerifyFuncGood,
			clientVType:      CertVerification,
			serverMutualTLS:  true,
			serverCert:       []tls.Certificate{serverPeerCert},
			serverRoot:       serverTrustPool,
			serverVType:      CertVerification,
		},
		// Client: set clientGetRoot, clientVerifyFunc and clientCert
		// Server: set serverCert, but none of serverRoot, serverGetRoot or
		// serverVerifyFunc, with mutual TLS on
		// Expected Behavior: server side failure
		// Reason: when mTLS is on, the server side must provide some
		// verification mechanism, even when vType is SkipVerification. Servers
		// must at least provide their own verification logic.
		{
			desc:                       "Client sets peer cert, reload root function with verifyFuncGood; server sets no verification; mutualTLS",
			clientCert:                 []tls.Certificate{clientPeerCert},
			clientGetRoot:              getRootCAsForClient,
			clientVerifyFunc:           clientVerifyFuncGood,
			clientVType:                CertVerification,
			clientExpectHandshakeError: true,
			serverMutualTLS:            true,
			serverCert:                 []tls.Certificate{serverPeerCert},
			serverVType:                SkipVerification,
			serverExpectError:          true,
		},
		// Client: set clientGetRoot, clientVerifyFunc and clientCert
		// Server: set serverGetRoot and serverCert with mutual TLS on
		// Expected Behavior: success
		{
			desc:             "Client sets peer cert, reload root function with verifyFuncGood; Server sets peer cert, reload root function; mutualTLS",
			clientCert:       []tls.Certificate{clientPeerCert},
			clientGetRoot:    getRootCAsForClient,
			clientVerifyFunc: clientVerifyFuncGood,
			clientVType:      CertVerification,
			serverMutualTLS:  true,
			serverCert:       []tls.Certificate{serverPeerCert},
			serverGetRoot:    getRootCAsForServer,
			serverVType:      CertVerification,
		},
		// Client: set clientGetRoot, clientVerifyFunc and clientCert
		// Server: set serverGetRoot returning error and serverCert with mutual
		// TLS on
		// Expected Behavior: server side failure
		// Reason: server side root reloading returns an error
		{
			desc:              "Client sets peer cert, reload root function with verifyFuncGood; Server sets peer cert, bad reload root function; mutualTLS",
			clientCert:        []tls.Certificate{clientPeerCert},
			clientGetRoot:     getRootCAsForClient,
			clientVerifyFunc:  clientVerifyFuncGood,
			clientVType:       CertVerification,
			serverMutualTLS:   true,
			serverCert:        []tls.Certificate{serverPeerCert},
			serverGetRoot:     getRootCAsForServerBad,
			serverVType:       CertVerification,
			serverExpectError: true,
		},
		// Client: set clientGetRoot, clientVerifyFunc and clientGetCert
		// Server: set serverGetRoot and serverGetCert with mutual TLS on
		// Expected Behavior: success
		{
			desc: "Client sets reload peer/root function with verifyFuncGood; Server sets reload peer/root function with verifyFuncGood; mutualTLS",
			clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
				return &clientPeerCert, nil
			},
			clientGetRoot:    getRootCAsForClient,
			clientVerifyFunc: clientVerifyFuncGood,
			clientVType:      CertVerification,
			serverMutualTLS:  true,
			serverGetCert: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
				return &serverPeerCert, nil
			},
			serverGetRoot:    getRootCAsForServer,
			serverVerifyFunc: serverVerifyFunc,
			serverVType:      CertVerification,
		},
		// Client: set everything, but send the server's peer cert, which the
		// server does not trust
		// Server: set serverGetRoot and serverGetCert with mutual TLS on
		// Expected Behavior: server side returns failure because of
		// certificate mismatch
		{
			desc: "Client sends wrong peer cert; Server sets reload peer/root function with verifyFuncGood; mutualTLS",
			clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
				return &serverPeerCert, nil
			},
			clientGetRoot:    getRootCAsForClient,
			clientVerifyFunc: clientVerifyFuncGood,
			clientVType:      CertVerification,
			serverMutualTLS:  true,
			serverGetCert: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
				return &serverPeerCert, nil
			},
			serverGetRoot:     getRootCAsForServer,
			serverVerifyFunc:  serverVerifyFunc,
			serverVType:       CertVerification,
			serverExpectError: true,
		},
		// Client: set everything, but with the server's trust roots, which do
		// not trust the server certificate
		// Server: set serverGetRoot and serverGetCert with mutual TLS on
		// Expected Behavior: server side and client side return failure due to
		// certificate mismatch and handshake failure
		{
			desc: "Client has wrong trust cert; Server sets reload peer/root function with verifyFuncGood; mutualTLS",
			clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
				return &clientPeerCert, nil
			},
			clientGetRoot:              getRootCAsForServer,
			clientVerifyFunc:           clientVerifyFuncGood,
			clientVType:                CertVerification,
			clientExpectHandshakeError: true,
			serverMutualTLS:            true,
			serverGetCert: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
				return &serverPeerCert, nil
			},
			serverGetRoot:     getRootCAsForServer,
			serverVerifyFunc:  serverVerifyFunc,
			serverVType:       CertVerification,
			serverExpectError: true,
		},
		// Client: set clientGetRoot, clientVerifyFunc and clientGetCert
		// Server: set everything, but send the client's peer cert, which the
		// client does not trust
		// Expected Behavior: server side and client side return failure due to
		// certificate mismatch and handshake failure
		{
			desc: "Client sets reload peer/root function with verifyFuncGood; Server sends wrong peer cert; mutualTLS",
			clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
				return &clientPeerCert, nil
			},
			clientGetRoot:    getRootCAsForClient,
			clientVerifyFunc: clientVerifyFuncGood,
			clientVType:      CertVerification,
			serverMutualTLS:  true,
			serverGetCert: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
				return &clientPeerCert, nil
			},
			serverGetRoot:     getRootCAsForServer,
			serverVerifyFunc:  serverVerifyFunc,
			serverVType:       CertVerification,
			serverExpectError: true,
		},
		// Client: set clientGetRoot, clientVerifyFunc and clientGetCert
		// Server: set everything, but with the client's trust roots, which do
		// not trust the client certificate
		// Expected Behavior: server side and client side return failure due to
		// certificate mismatch and handshake failure
		{
			desc: "Client sets reload peer/root function with verifyFuncGood; Server has wrong trust cert; mutualTLS",
			clientGetCert: func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
				return &clientPeerCert, nil
			},
			clientGetRoot:              getRootCAsForClient,
			clientVerifyFunc:           clientVerifyFuncGood,
			clientVType:                CertVerification,
			clientExpectHandshakeError: true,
			serverMutualTLS:            true,
			serverGetCert: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
				return &serverPeerCert, nil
			},
			serverGetRoot:     getRootCAsForClient,
			serverVerifyFunc:  serverVerifyFunc,
			serverVType:       CertVerification,
			serverExpectError: true,
		},
		// Client: set clientGetRoot, clientVerifyFunc and clientCert
		// Server: set serverGetRoot and serverCert, but with bad verifyFunc
		// Expected Behavior: server side and client side return failure because
		// the server's custom check fails
		{
			desc:                       "Client sets peer cert, reload root function with verifyFuncGood; Server sets bad custom check; mutualTLS",
			clientCert:                 []tls.Certificate{clientPeerCert},
			clientGetRoot:              getRootCAsForClient,
			clientVerifyFunc:           clientVerifyFuncGood,
			clientVType:                CertVerification,
			clientExpectHandshakeError: true,
			serverMutualTLS:            true,
			serverCert:                 []tls.Certificate{serverPeerCert},
			serverGetRoot:              getRootCAsForServer,
			serverVerifyFunc:           verifyFuncBad,
			serverVType:                CertVerification,
			serverExpectError:          true,
		},
	} {
		// Shadow the loop variable so the subtest closure captures a
		// per-iteration copy.
		test := test
		t.Run(test.desc, func(t *testing.T) {
			// done receives the server's AuthInfo on success. The server
			// goroutine closes it without sending when any server-side step
			// fails. The buffer of 1 lets the goroutine's single send complete
			// without waiting for the receiver.
			done := make(chan credentials.AuthInfo, 1)
			lis, err := net.Listen("tcp", "localhost:0")
			if err != nil {
				t.Fatalf("Failed to listen: %v", err)
			}
			// Start a server using ServerOptions in another goroutine.
			serverOptions := &ServerOptions{
				Certificates:   test.serverCert,
				GetCertificate: test.serverGetCert,
				RootCertificateOptions: RootCertificateOptions{
					RootCACerts: test.serverRoot,
					GetRootCAs:  test.serverGetRoot,
				},
				RequireClientCert: test.serverMutualTLS,
				VerifyPeer:        test.serverVerifyFunc,
				VType:             test.serverVType,
			}
			go func(done chan credentials.AuthInfo, lis net.Listener, serverOptions *ServerOptions) {
				serverRawConn, err := lis.Accept()
				if err != nil {
					close(done)
					return
				}
				serverTLS, err := NewServerCreds(serverOptions)
				if err != nil {
					serverRawConn.Close()
					close(done)
					return
				}
				_, serverAuthInfo, err := serverTLS.ServerHandshake(serverRawConn)
				if err != nil {
					serverRawConn.Close()
					close(done)
					return
				}
				done <- serverAuthInfo
			}(done, lis, serverOptions)
			defer lis.Close()
			// Start a client using ClientOptions and connect to the server.
			lisAddr := lis.Addr().String()
			conn, err := net.Dial("tcp", lisAddr)
			if err != nil {
				t.Fatalf("Client failed to connect to %s. Error: %v", lisAddr, err)
			}
			// Closing conn on return also unblocks the server's handshake when
			// the client bails out early (e.g. credential creation failed).
			defer conn.Close()
			clientOptions := &ClientOptions{
				Certificates:         test.clientCert,
				GetClientCertificate: test.clientGetCert,
				VerifyPeer:           test.clientVerifyFunc,
				RootCertificateOptions: RootCertificateOptions{
					RootCACerts: test.clientRoot,
					GetRootCAs:  test.clientGetRoot,
				},
				VType: test.clientVType,
			}
			clientTLS, newClientErr := NewClientCreds(clientOptions)
			if newClientErr != nil && test.clientExpectCreateError {
				return
			}
			if newClientErr != nil && !test.clientExpectCreateError ||
				newClientErr == nil && test.clientExpectCreateError {
				t.Fatalf("Expect error: %v, but err is %v",
					test.clientExpectCreateError, newClientErr)
			}
			_, clientAuthInfo, handshakeErr := clientTLS.ClientHandshake(context.Background(),
				lisAddr, conn)
			// Wait until the server sends serverAuthInfo or fails (closes done).
			serverAuthInfo, ok := <-done
			// NOTE(review): when a server failure is expected and observed, the
			// subtest returns here without checking clientExpectHandshakeError.
			if !ok && test.serverExpectError {
				return
			}
			if ok && test.serverExpectError || !ok && !test.serverExpectError {
				t.Fatalf("Server side error mismatch, got %v, want %v", !ok, test.serverExpectError)
			}
			if handshakeErr != nil && test.clientExpectHandshakeError {
				return
			}
			if handshakeErr != nil && !test.clientExpectHandshakeError ||
				handshakeErr == nil && test.clientExpectHandshakeError {
				t.Fatalf("Expect error: %v, but err is %v",
					test.clientExpectHandshakeError, handshakeErr)
			}
			// Both sides succeeded; they must agree on the negotiated TLS state.
			if !compare(clientAuthInfo, serverAuthInfo) {
				t.Fatalf("c.ClientHandshake(_, %v, _) = %v, want %v.", lisAddr,
					clientAuthInfo, serverAuthInfo)
			}
		})
	}
}
| |
// readTrustCert reads the PEM file at fileName, parses the first PEM block as
// an X.509 certificate, and returns a new CertPool containing only that
// certificate.
//
// It returns a non-nil error if the file cannot be read, if it contains no PEM
// block, or if the first block cannot be parsed as a certificate. On error the
// returned pool is nil.
func readTrustCert(fileName string) (*x509.CertPool, error) {
	trustData, err := ioutil.ReadFile(fileName)
	if err != nil {
		return nil, err
	}
	trustBlock, _ := pem.Decode(trustData)
	if trustBlock == nil {
		// err is nil here (ReadFile succeeded). Returning it would give callers
		// a nil pool together with a nil error, so report the problem
		// explicitly.
		return nil, fmt.Errorf("no PEM data found in %q", fileName)
	}
	trustCert, err := x509.ParseCertificate(trustBlock.Bytes)
	if err != nil {
		return nil, err
	}
	trustPool := x509.NewCertPool()
	trustPool.AddCert(trustCert)
	return trustPool, nil
}
| |
| func compare(a1, a2 credentials.AuthInfo) bool { |
| if a1.AuthType() != a2.AuthType() { |
| return false |
| } |
| switch a1.AuthType() { |
| case "tls": |
| state1 := a1.(credentials.TLSInfo).State |
| state2 := a2.(credentials.TLSInfo).State |
| if state1.Version == state2.Version && |
| state1.HandshakeComplete == state2.HandshakeComplete && |
| state1.CipherSuite == state2.CipherSuite && |
| state1.NegotiatedProtocol == state2.NegotiatedProtocol { |
| return true |
| } |
| return false |
| default: |
| return false |
| } |
| } |
| |
| func TestAdvancedTLSOverrideServerName(t *testing.T) { |
| expectedServerName := "server.name" |
| clientTrustPool, err := readTrustCert(testdata.Path("client_trust_cert_1.pem")) |
| if err != nil { |
| t.Fatalf("Client is unable to load trust certs. Error: %v", err) |
| } |
| clientOptions := &ClientOptions{ |
| RootCertificateOptions: RootCertificateOptions{ |
| RootCACerts: clientTrustPool, |
| }, |
| ServerNameOverride: expectedServerName, |
| } |
| c, err := NewClientCreds(clientOptions) |
| if err != nil { |
| t.Fatalf("Client is unable to create credentials. Error: %v", err) |
| } |
| c.OverrideServerName(expectedServerName) |
| if c.Info().ServerName != expectedServerName { |
| t.Fatalf("c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName) |
| } |
| } |
| |
| func TestTLSClone(t *testing.T) { |
| expectedServerName := "server.name" |
| clientTrustPool, err := readTrustCert(testdata.Path("client_trust_cert_1.pem")) |
| if err != nil { |
| t.Fatalf("Client is unable to load trust certs. Error: %v", err) |
| } |
| clientOptions := &ClientOptions{ |
| RootCertificateOptions: RootCertificateOptions{ |
| RootCACerts: clientTrustPool, |
| }, |
| ServerNameOverride: expectedServerName, |
| } |
| c, err := NewClientCreds(clientOptions) |
| if err != nil { |
| t.Fatalf("Failed to create new client: %v", err) |
| } |
| cc := c.Clone() |
| if cc.Info().ServerName != expectedServerName { |
| t.Fatalf("cc.Info().ServerName = %v, want %v", cc.Info().ServerName, expectedServerName) |
| } |
| cc.OverrideServerName("") |
| if c.Info().ServerName != expectedServerName { |
| t.Fatalf("Change in clone should not affect the original, "+ |
| "c.Info().ServerName = %v, want %v", c.Info().ServerName, expectedServerName) |
| } |
| |
| } |
| |
| func TestAppendH2ToNextProtos(t *testing.T) { |
| tests := []struct { |
| name string |
| ps []string |
| want []string |
| }{ |
| { |
| name: "empty", |
| ps: nil, |
| want: []string{"h2"}, |
| }, |
| { |
| name: "only h2", |
| ps: []string{"h2"}, |
| want: []string{"h2"}, |
| }, |
| { |
| name: "with h2", |
| ps: []string{"alpn", "h2"}, |
| want: []string{"alpn", "h2"}, |
| }, |
| { |
| name: "no h2", |
| ps: []string{"alpn"}, |
| want: []string{"alpn", "h2"}, |
| }, |
| } |
| for _, tt := range tests { |
| t.Run(tt.name, func(t *testing.T) { |
| if got := appendH2ToNextProtos(tt.ps); !reflect.DeepEqual(got, tt.want) { |
| t.Errorf("appendH2ToNextProtos() = %v, want %v", got, tt.want) |
| } |
| }) |
| } |
| } |
| |
// nonSyscallConn is a test double that embeds a net.Conn interface value.
// Only the methods of net.Conn are promoted, so it does not implement
// syscall.Conn.
type nonSyscallConn struct{ net.Conn }
| |
| func TestWrapSyscallConn(t *testing.T) { |
| sc := &syscallConn{} |
| nsc := &nonSyscallConn{} |
| |
| wrapConn := WrapSyscallConn(sc, nsc) |
| if _, ok := wrapConn.(syscall.Conn); !ok { |
| t.Errorf("returned conn (type %T) doesn't implement syscall.Conn, want implement", |
| wrapConn) |
| } |
| } |