blob: 9a98ae3574f17b7fdae54d393ecf406f05257c8f [file] [log] [blame]
package libtrust
import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"net"
"os"
"path"
"sync"
)
// ClientKeyManager manages client keys on the filesystem.
//
// It holds the set of authorized client public keys loaded from a key set
// file and/or a directory of key files, together with the TLS configurations
// that have been registered through RegisterTLSConfig.
type ClientKeyManager struct {
// key is the private key passed to GenerateCACertPool alongside the
// client keys when a TLS configuration is registered.
key PrivateKey
// clientFile is the path of the authorized keys file read with
// LoadKeySetFile; an empty string skips that step.
clientFile string
// clientDir is the directory whose non-directory entries are each read
// with LoadPublicKeyFile.
clientDir string
// clientLock guards clients.
clientLock sync.RWMutex
// clients is the most recently loaded list of authorized public keys.
clients []PublicKey
// configLock guards configs.
configLock sync.Mutex
// configs collects every TLS configuration handed to RegisterTLSConfig.
configs []*tls.Config
}
// NewClientKeyManager builds a manager for the given trust key and performs
// an initial load of the authorized client keys from clientFile and
// clientDir. Any error from that initial load is returned and no manager is
// produced.
func NewClientKeyManager(trustKey PrivateKey, clientFile, clientDir string) (*ClientKeyManager, error) {
	manager := new(ClientKeyManager)
	manager.key = trustKey
	manager.clientFile = clientFile
	manager.clientDir = clientDir

	err := manager.loadKeys()
	if err != nil {
		return nil, err
	}

	// TODO Start watching file and directory
	return manager, nil
}
// loadKeys rebuilds the list of authorized client keys and stores it on the
// manager under the client write lock.
//
// Keys come from two places: the key set file named by clientFile (skipped
// when the name is empty) and every non-directory entry of clientDir. A
// missing clientDir is not an error; any other failure aborts the load and
// leaves the previously stored list untouched.
func (c *ClientKeyManager) loadKeys() (err error) {
	var authorized []PublicKey

	// Start with the keys from the authorized keys file, if one is configured.
	if c.clientFile != "" {
		if authorized, err = LoadKeySetFile(c.clientFile); err != nil {
			return fmt.Errorf("unable to load authorized keys: %s", err)
		}
	}

	// Then append one key per regular entry of the authorized keys directory.
	entries, err := ioutil.ReadDir(c.clientDir)
	if err != nil && !os.IsNotExist(err) {
		return fmt.Errorf("unable to open authorized keys directory: %s", err)
	}
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		keyPath := path.Join(c.clientDir, entry.Name())
		key, loadErr := LoadPublicKeyFile(keyPath)
		if loadErr != nil {
			return fmt.Errorf("unable to load authorized key file: %s", loadErr)
		}
		authorized = append(authorized, key)
	}

	// Publish the complete list in one step.
	c.clientLock.Lock()
	c.clients = authorized
	c.clientLock.Unlock()

	return nil
}
// RegisterTLSConfig registers a tls configuration with the manager.
//
// The configuration's ClientCAs field is set to a pool generated by
// GenerateCACertPool from the manager's key and its current client keys, and
// the configuration is appended to the manager's list of registered
// configurations so that later key changes may be reflected in it.
//
// An error is returned, and the configuration is neither modified nor
// recorded, when the CA pool cannot be generated.
func (c *ClientKeyManager) RegisterTLSConfig(tlsConfig *tls.Config) error {
	// Hold the read lock only while the client list is being read, and
	// release it before inspecting the error so that a failure can never
	// leave the lock held.
	c.clientLock.RLock()
	certPool, err := GenerateCACertPool(c.key, c.clients)
	c.clientLock.RUnlock()
	if err != nil {
		return fmt.Errorf("CA pool generation error: %s", err)
	}

	tlsConfig.ClientCAs = certPool

	c.configLock.Lock()
	c.configs = append(c.configs, tlsConfig)
	c.configLock.Unlock()

	return nil
}
// NewIdentityAuthTLSConfig creates a tls.Config for the server to use for
// libtrust identity authentication for the domain specified.
//
// The returned configuration requires and verifies client certificates
// against the CA pool installed by clients.RegisterTLSConfig, and presents a
// server certificate produced by GenerateSelfSignedServerCert for trustKey.
// The certificate covers the host part of addr (as an IP or a domain name)
// plus domain.
//
// addr must be in "host:port" form; an error is returned when it cannot be
// parsed, when the certificate cannot be generated, or when the
// configuration cannot be registered with clients.
func NewIdentityAuthTLSConfig(trustKey PrivateKey, clients *ClientKeyManager, addr string, domain string) (*tls.Config, error) {
	// Generate cert
	ips, domains, err := parseAddr(addr)
	if err != nil {
		return nil, err
	}
	// add domain that it expects clients to use
	domains = append(domains, domain)
	x509Cert, err := GenerateSelfSignedServerCert(trustKey, domains, ips)
	if err != nil {
		return nil, fmt.Errorf("certificate generation error: %s", err)
	}

	tlsConfig := newTLSConfig()
	tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
	tlsConfig.Certificates = []tls.Certificate{{
		Certificate: [][]byte{x509Cert.Raw},
		PrivateKey:  trustKey.CryptoPrivateKey(),
		Leaf:        x509Cert,
	}}

	// Register last: the manager keeps a reference to every registered
	// configuration, so registering before the fallible steps above would
	// leave a discarded configuration behind whenever one of them failed.
	if err := clients.RegisterTLSConfig(tlsConfig); err != nil {
		return nil, err
	}
	return tlsConfig, nil
}
// NewCertAuthTLSConfig creates a tls.Config for the server to use for
// certificate authentication.
//
// certPath and keyPath name the PEM files holding the server certificate and
// its private key. When caPath is non-empty, the PEM certificates it contains
// become the client CA pool and client certificates are required and
// verified; when caPath is empty, no client authentication is configured.
//
// An error is returned when the key pair cannot be loaded, when the CA file
// cannot be read, or when the CA file contains no parseable certificate.
func NewCertAuthTLSConfig(caPath, certPath, keyPath string) (*tls.Config, error) {
	tlsConfig := newTLSConfig()
	cert, err := tls.LoadX509KeyPair(certPath, keyPath)
	if err != nil {
		return nil, fmt.Errorf("Couldn't load X509 key pair (%s, %s): %s. Key encrypted?", certPath, keyPath, err)
	}
	tlsConfig.Certificates = []tls.Certificate{cert}

	// Verify client certificates against a CA?
	if caPath == "" {
		return tlsConfig, nil
	}

	pemData, err := ioutil.ReadFile(caPath)
	if err != nil {
		return nil, fmt.Errorf("Couldn't read CA certificate: %s", err)
	}
	certPool := x509.NewCertPool()
	// AppendCertsFromPEM reports false when nothing could be parsed. An
	// empty pool combined with RequireAndVerifyClientCert would reject every
	// client, so surface the problem instead of returning such a config.
	if !certPool.AppendCertsFromPEM(pemData) {
		return nil, fmt.Errorf("Couldn't parse any CA certificate from %s", caPath)
	}
	tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
	tlsConfig.ClientCAs = certPool
	return tlsConfig, nil
}
// newTLSConfig returns a fresh base tls.Config shared by the server
// configurations in this file: it advertises "http/1.1" as the only
// application protocol and sets TLS 1.0 as the minimum accepted version.
func newTLSConfig() *tls.Config {
	cfg := new(tls.Config)
	cfg.NextProtos = []string{"http/1.1"}
	// Avoid fallback on insecure SSL protocols
	cfg.MinVersion = tls.VersionTLS10
	return cfg
}
// parseAddr parses a "host:port" address into the IPs and domains it names.
//
// The host part is returned as a single-element IP slice when it parses as
// an IP address, and as a single-element domain slice otherwise. An address
// with an empty host (for example ":2376") yields neither, so that callers
// never receive an empty domain name. An error is returned when addr is not
// a valid "host:port" pair.
func parseAddr(addr string) ([]net.IP, []string, error) {
	host, _, err := net.SplitHostPort(addr)
	if err != nil {
		return nil, nil, err
	}
	// No host means there is nothing to put in a certificate.
	if host == "" {
		return nil, nil, nil
	}
	if ip := net.ParseIP(host); ip != nil {
		return []net.IP{ip}, nil, nil
	}
	return nil, []string{host}, nil
}