| package ca |
| |
| import ( |
| "context" |
| "crypto/tls" |
| "crypto/x509" |
| "crypto/x509/pkix" |
| "net" |
| "strings" |
| "sync" |
| |
| "github.com/pkg/errors" |
| "google.golang.org/grpc/credentials" |
| ) |
| |
// alpnProtoStr lists the ALPN protocol identifiers advertised on TLS configs
// managed by this package. gRPC runs over HTTP/2, whose ALPN id is "h2".
var alpnProtoStr = []string{"h2"}
| |
| // MutableTLSCreds is the credentials required for authenticating a connection using TLS. |
| type MutableTLSCreds struct { |
| // Mutex for the tls config |
| sync.Mutex |
| // TLS configuration |
| config *tls.Config |
| // TLS Credentials |
| tlsCreds credentials.TransportCredentials |
| // store the subject for easy access |
| subject pkix.Name |
| } |
| |
| // Info implements the credentials.TransportCredentials interface |
| func (c *MutableTLSCreds) Info() credentials.ProtocolInfo { |
| return credentials.ProtocolInfo{ |
| SecurityProtocol: "tls", |
| SecurityVersion: "1.2", |
| } |
| } |
| |
| // Clone returns new MutableTLSCreds created from underlying *tls.Config. |
| // It panics if validation of underlying config fails. |
| func (c *MutableTLSCreds) Clone() credentials.TransportCredentials { |
| c.Lock() |
| newCfg, err := NewMutableTLS(c.config.Clone()) |
| if err != nil { |
| panic("validation error on Clone") |
| } |
| c.Unlock() |
| return newCfg |
| } |
| |
| // OverrideServerName overrides *tls.Config.ServerName. |
| func (c *MutableTLSCreds) OverrideServerName(name string) error { |
| c.Lock() |
| c.config.ServerName = name |
| c.Unlock() |
| return nil |
| } |
| |
| // GetRequestMetadata implements the credentials.TransportCredentials interface |
| func (c *MutableTLSCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) { |
| return nil, nil |
| } |
| |
| // RequireTransportSecurity implements the credentials.TransportCredentials interface |
| func (c *MutableTLSCreds) RequireTransportSecurity() bool { |
| return true |
| } |
| |
// ClientHandshake implements the credentials.TransportCredentials interface.
//
// It performs a client-side TLS handshake over rawConn using a private copy
// of the currently loaded config. If that config has no ServerName, the part
// of addr before its last ':' (or all of addr if it has none) is used as the
// ServerName for this connection only; the shared config is left untouched.
//
// If ctx is done before the handshake completes, rawConn is closed and
// ctx.Err() is returned. Any handshake error likewise closes rawConn. On
// success the returned AuthInfo carries the negotiated connection state,
// matching what ServerHandshake returns.
func (c *MutableTLSCreds) ClientHandshake(ctx context.Context, addr string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
	// Snapshot the config under the lock and work on the copy. Writing
	// ServerName into the shared config would pin every later connection to
	// the first address dialed, and would race with in-flight handshakes,
	// which read their config without holding the lock.
	c.Lock()
	cfg := c.config.Clone()
	c.Unlock()

	if cfg.ServerName == "" {
		colonPos := strings.LastIndex(addr, ":")
		if colonPos == -1 {
			colonPos = len(addr)
		}
		cfg.ServerName = addr[:colonPos]
	}

	conn := tls.Client(rawConn, cfg)

	// Run the handshake in a goroutine so ctx can abort the wait. The channel
	// is buffered so the goroutine can always deliver its result and exit even
	// after we stop listening; closing rawConn on cancellation unblocks it.
	errChannel := make(chan error, 1)
	go func() {
		errChannel <- conn.Handshake()
	}()

	var err error
	select {
	case err = <-errChannel:
	case <-ctx.Done():
		err = ctx.Err()
	}
	if err != nil {
		rawConn.Close()
		return nil, nil, err
	}
	return conn, credentials.TLSInfo{State: conn.ConnectionState()}, nil
}
| |
// ServerHandshake implements the credentials.TransportCredentials interface.
// It runs a server-side TLS handshake on rawConn with the config loaded at
// call time. On failure rawConn is closed and the handshake error returned;
// on success the AuthInfo carries the negotiated connection state.
func (c *MutableTLSCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
	// Only the pointer read needs the lock; the handshake itself runs unlocked.
	c.Lock()
	current := c.config
	c.Unlock()

	serverConn := tls.Server(rawConn, current)
	if handshakeErr := serverConn.Handshake(); handshakeErr != nil {
		rawConn.Close()
		return nil, nil, handshakeErr
	}

	state := serverConn.ConnectionState()
	return serverConn, credentials.TLSInfo{State: state}, nil
}
| |
// loadNewTLSConfig replaces the currently loaded TLS config with a new one.
//
// newConfig is validated with GetAndValidateCertificateSubject first; on
// error nothing is changed. Like NewMutableTLS, it overwrites
// newConfig.NextProtos with the package ALPN list. Without that, a reload
// (e.g. after certificate rotation) would silently drop the "h2" setting
// applied at construction. The config and cached subject are then swapped
// together under the lock.
func (c *MutableTLSCreds) loadNewTLSConfig(newConfig *tls.Config) error {
	if newConfig == nil {
		return errors.New("invalid configuration: nil TLS config")
	}

	newSubject, err := GetAndValidateCertificateSubject(newConfig.Certificates)
	if err != nil {
		return err
	}

	// Set before publishing: once stored in c.config the pointer is handed to
	// handshakes that read it without holding the lock.
	newConfig.NextProtos = alpnProtoStr

	c.Lock()
	defer c.Unlock()
	c.subject = newSubject
	c.config = newConfig

	return nil
}
| |
// Config returns the current underlying TLS config.
// The pointer is returned as-is, not copied: mutating it affects these
// credentials, and a later reload may replace it with a different config.
func (c *MutableTLSCreds) Config() *tls.Config {
	c.Lock()
	current := c.config
	c.Unlock()
	return current
}
| |
| // Role returns the OU for the certificate encapsulated in this TransportCredentials |
| func (c *MutableTLSCreds) Role() string { |
| c.Lock() |
| defer c.Unlock() |
| |
| return c.subject.OrganizationalUnit[0] |
| } |
| |
| // Organization returns the O for the certificate encapsulated in this TransportCredentials |
| func (c *MutableTLSCreds) Organization() string { |
| c.Lock() |
| defer c.Unlock() |
| |
| return c.subject.Organization[0] |
| } |
| |
| // NodeID returns the CN for the certificate encapsulated in this TransportCredentials |
| func (c *MutableTLSCreds) NodeID() string { |
| c.Lock() |
| defer c.Unlock() |
| |
| return c.subject.CommonName |
| } |
| |
// NewMutableTLS uses c to construct a mutable TransportCredentials based on TLS.
//
// c must be non-nil and hold at least one certificate. The first entry of
// c.Certificates whose leaf parses as X.509 must have a subject with a
// CommonName, an Organization and an OrganizationalUnit (see
// GetAndValidateCertificateSubject).
//
// The returned credentials keep c itself, not a copy. NextProtos is
// overwritten with the package ALPN list, and methods such as
// OverrideServerName later modify c in place.
func NewMutableTLS(c *tls.Config) (*MutableTLSCreds, error) {
	if c == nil {
		return nil, errors.New("invalid configuration: nil TLS config")
	}
	if len(c.Certificates) < 1 {
		return nil, errors.New("invalid configuration: needs at least one certificate")
	}

	subject, err := GetAndValidateCertificateSubject(c.Certificates)
	if err != nil {
		return nil, err
	}

	// Build the wrapped gRPC credentials from c before NextProtos is
	// overridden, preserving the original construction order.
	tc := &MutableTLSCreds{config: c, tlsCreds: credentials.NewTLS(c), subject: subject}
	tc.config.NextProtos = alpnProtoStr

	return tc, nil
}
| |
// GetAndValidateCertificateSubject is a helper method to retrieve and validate the subject
// from the x509 certificate underlying a tls.Certificate.
//
// certs is walked in order. Entries with an empty chain, or whose leaf (first
// DER block) fails to parse, are skipped. The first entry whose leaf parses
// decides the result, and later entries are not consulted. Its subject must
// have at least one OrganizationalUnit, at least one Organization and a
// non-empty CommonName; otherwise an error naming the missing part is
// returned. If no entry parses, an error is returned as well.
func GetAndValidateCertificateSubject(certs []tls.Certificate) (pkix.Name, error) {
	for i := range certs {
		chain := certs[i].Certificate
		// An entry with no DER blocks has no leaf to inspect. Skip it instead
		// of panicking on chain[0].
		if len(chain) == 0 {
			continue
		}
		x509Cert, err := x509.ParseCertificate(chain[0])
		if err != nil {
			continue
		}
		if len(x509Cert.Subject.OrganizationalUnit) < 1 {
			return pkix.Name{}, errors.New("no OU found in certificate subject")
		}

		if len(x509Cert.Subject.Organization) < 1 {
			return pkix.Name{}, errors.New("no organization found in certificate subject")
		}
		if x509Cert.Subject.CommonName == "" {
			return pkix.Name{}, errors.New("no valid subject names found for TLS configuration")
		}

		return x509Cert.Subject, nil
	}

	return pkix.Name{}, errors.New("no valid certificates found for TLS configuration")
}