blob: 4ed58f8f58f1e8ff11349943f4a0e6e27dea174a [file] [log] [blame]
// Copyright 2018 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package sshutil
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"net"
"time"
"go.fuchsia.dev/fuchsia/tools/lib/logger"
"go.fuchsia.dev/fuchsia/tools/lib/retry"
"go.fuchsia.dev/fuchsia/tools/net/netutil"
"golang.org/x/crypto/ssh"
)
const (
// Default SSH server port.
SSHPort = 22
// Default RSA key size, in bits, used by GeneratePrivateKey.
RSAKeySize = 2048
// The allowed timeout for a single attempt at establishing an SSH
// connection.
connectAttemptTimeout = 10 * time.Second
// The allowed timeout to establish an ssh connection, possibly including
// many attempts.
totalConnectTimeout = 2 * time.Minute
// The user name placed in the SSH client configurations built by
// DefaultSSHConfigFromSigners.
sshUser = "fuchsia"
)
var (
// defaultConnectBackoff is the connection backoff for clients returned by
// Connect() and ConnectToNode(); ConnectDeprecated() retries with it as
// well. It is a retry.ZeroBackoff wrapped in retry.WithMaxDuration with
// totalConnectTimeout as the maximum duration.
// NOTE(review): ZeroBackoff presumably means no wait between attempts —
// confirm against the retry package.
defaultConnectBackoff = retry.WithMaxDuration(&retry.ZeroBackoff{}, totalConnectTimeout)
)
// ConnectionError is a catch-all error type signalling that an SSH client has
// become unresponsive. The underlying cause, if any, is carried in Err.
type ConnectionError struct {
	Err error
}

// Unwrap exposes the wrapped error so that errors.Is and errors.As can look
// through a ConnectionError.
func (e ConnectionError) Unwrap() error {
	return e.Err
}

// Error returns the wrapped error's message unchanged, or a generic message
// when nothing is wrapped.
func (e ConnectionError) Error() string {
	if e.Err == nil {
		return "SSH connection error"
	}
	// ConnectionError is an umbrella type for every kind of SSH-related
	// failure, so no prefix would be meaningful across all of them; the
	// underlying message is passed through as-is.
	return e.Err.Error()
}
// IsConnectionError reports whether err, or any error in its Unwrap chain, is
// a ConnectionError. It exists only to spare callers the errors.As
// boilerplate.
func IsConnectionError(err error) bool {
	return errors.As(err, &ConnectionError{})
}
// GeneratePrivateKey creates a new RSA key of RSAKeySize bits and returns it
// PEM-encoded as a PKCS #1 "RSA PRIVATE KEY" block. An error is returned only
// if key generation fails.
func GeneratePrivateKey() ([]byte, error) {
	rsaKey, err := rsa.GenerateKey(rand.Reader, RSAKeySize)
	if err != nil {
		return nil, err
	}
	der := x509.MarshalPKCS1PrivateKey(rsaKey)
	return pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: der}), nil
}
// CheckConnection probes the client by sending a keepaliveOpenSSH global
// request that asks for a reply. It returns nil when the request succeeds;
// otherwise it returns the failure wrapped in a ConnectionError.
func CheckConnection(client *ssh.Client) error {
	_, _, err := client.Conn.SendRequest(keepaliveOpenSSH, true, nil)
	if err == nil {
		return nil
	}
	return ConnectionError{err}
}
// ConnectDeprecated establishes an SSH connection at the given remote address,
// retrying failed attempts according to defaultConnectBackoff. Each attempt is
// bounded by connectAttemptTimeout.
//
// If it fails to connect, it returns an error that unwraps as a
// ConnectionError. For non-timeout failures the underlying cause stays in the
// error chain, so callers can inspect it with errors.Is / errors.As. An
// address whose network cannot be inferred is rejected up front with the
// (non-ConnectionError) error from network().
//
// TODO(fxb/48042): Delete in favor of a method that returns sshutil.Client.
func ConnectDeprecated(ctx context.Context, raddr net.Addr, config *ssh.ClientConfig) (*ssh.Client, error) {
	// Named netName rather than network so the package-level network()
	// function is not shadowed for the rest of this body.
	netName, err := network(raddr)
	if err != nil {
		return nil, err
	}

	var client *ssh.Client
	// TODO: figure out optimal backoff time and number of retries
	startTime := time.Now()
	err = retry.Retry(ctx, defaultConnectBackoff, func() error {
		var dialErr error
		client, dialErr = dialWithTimeout(ctx, netName, raddr.String(), config, connectAttemptTimeout)
		return dialErr
	}, nil)
	if err != nil {
		var netErr net.Error
		if errors.As(err, &netErr) && netErr.Timeout() {
			// The exact time at which the timeout triggers is
			// nondeterministic; it'll be somewhere between
			// `totalConnectTimeout` and
			// `totalConnectTimeout + connectAttemptTimeout`. So we measure
			// the duration to improve accuracy.
			duration := time.Since(startTime).Truncate(time.Second)
			return nil, ConnectionError{fmt.Errorf("timed out trying to connect to ssh after %v", duration)}
		}
		// %w (not %v) keeps the cause reachable through
		// ConnectionError.Unwrap; the rendered message is unchanged.
		return nil, ConnectionError{fmt.Errorf("cannot connect to address %q: %w", raddr, err)}
	}
	return client, nil
}
// ssh.Dial can hang during authentication, the 'timeout' being set in the config only
// applying to establishment of the initial connection. This function is effectively
// ssh.Dial with the ability to set a deadline on the underlying connection.
//
// The TCP dial is bounded by config.Timeout and is abandoned early if ctx is
// cancelled. The SSH handshake that follows is bounded by `timeout`, enforced
// as a deadline on the raw connection; that deadline is cleared once the
// handshake succeeds. On success a keepalive goroutine is started for the
// returned client.
//
// See https://github.com/golang/go/issues/21941 for more details on the hang.
func dialWithTimeout(ctx context.Context, network, addr string, config *ssh.ClientConfig, timeout time.Duration) (*ssh.Client, error) {
	// Same as net.DialTimeout, except that cancelling ctx interrupts a dial
	// in progress. Per the net package docs, once the connection has been
	// established a later expiry of ctx does not affect it.
	dialer := net.Dialer{Timeout: config.Timeout}
	conn, err := dialer.DialContext(ctx, network, addr)
	if err != nil {
		return nil, err
	}

	// Bound the handshake: NewClientConn has no timeout of its own.
	if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil {
		conn.Close()
		return nil, err
	}
	c, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
	if err != nil {
		conn.Close()
		return nil, err
	}

	// The zero time removes the handshake deadline.
	if err := conn.SetDeadline(time.Time{}); err != nil {
		c.Close()
		return nil, err
	}

	client := ssh.NewClient(c, chans, reqs)
	go keepalive(ctx, conn, client)
	return client, nil
}
// ConnectToNodeDeprecated connects to the device with the given nodename. It
// looks the node's address up with netutil.GetNodeAddress, sets the port to
// SSHPort, and hands the result to ConnectDeprecated.
// TODO(fxb/48042): Delete in favor of a method that returns sshutil.Client.
func ConnectToNodeDeprecated(ctx context.Context, nodename string, config *ssh.ClientConfig) (*ssh.Client, error) {
	nodeAddr, lookupErr := netutil.GetNodeAddress(ctx, nodename, true)
	if lookupErr != nil {
		return nil, lookupErr
	}
	// The port is always overwritten with the SSH port before dialing.
	nodeAddr.Port = SSHPort
	return ConnectDeprecated(ctx, nodeAddr, config)
}
// ConnectToNode connects to the device with the given nodename. It looks the
// node's address up with netutil.GetNodeAddress, sets the port to SSHPort,
// and builds a Client via NewClient using defaultConnectBackoff.
func ConnectToNode(ctx context.Context, nodename string, config *ssh.ClientConfig) (*Client, error) {
	nodeAddr, lookupErr := netutil.GetNodeAddress(ctx, nodename, true)
	if lookupErr != nil {
		return nil, lookupErr
	}
	// The port is always overwritten with the SSH port before connecting.
	nodeAddr.Port = SSHPort
	return NewClient(ctx, nodeAddr, config, defaultConnectBackoff)
}
// DefaultSSHConfig returns a basic SSH client configuration that
// authenticates with the given private key. The key is parsed with
// ssh.ParsePrivateKey; a parse failure is returned unchanged.
func DefaultSSHConfig(privateKey []byte) (*ssh.ClientConfig, error) {
	keySigner, parseErr := ssh.ParsePrivateKey(privateKey)
	if parseErr != nil {
		return nil, parseErr
	}
	return DefaultSSHConfigFromSigners(keySigner)
}
// DefaultSSHConfigFromSigners returns a basic SSH client configuration that
// logs in as sshUser using public-key authentication with the given signers
// and uses connectAttemptTimeout as the connection timeout.
//
// Host keys are NOT verified: the configuration uses
// ssh.InsecureIgnoreHostKey. The returned error is always nil.
func DefaultSSHConfigFromSigners(signers ...ssh.Signer) (*ssh.ClientConfig, error) {
	publicKeyAuth := ssh.PublicKeys(signers...)
	cfg := &ssh.ClientConfig{
		User:            sshUser,
		Auth:            []ssh.AuthMethod{publicKeyAuth},
		Timeout:         connectAttemptTimeout,
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
	}
	return cfg, nil
}
// network returns the dial network to use to SSH into a device at the given
// address: "tcp" when the address holds an IPv4 address (including
// IPv4-mapped IPv6), "tcp6" for any other valid IPv6 address.
//
// Only *net.UDPAddr, *net.TCPAddr and *net.IPAddr are accepted; any other
// address type, or an address whose IP is neither 4 nor 16 bytes long, yields
// an error.
func network(address net.Addr) (string, error) {
	// The net package offers no common interface for "has an IP address", so
	// each concrete type has to be handled explicitly.
	var ip net.IP
	switch a := address.(type) {
	case *net.UDPAddr:
		ip = a.IP
	case *net.TCPAddr:
		ip = a.IP
	case *net.IPAddr:
		ip = a.IP
	default:
		return "", fmt.Errorf("unsupported address type: %T", address)
	}

	// To4 must be tried first: To16 also succeeds for IPv4 addresses.
	switch {
	case ip.To4() != nil:
		return "tcp", nil // IPv4
	case ip.To16() != nil:
		return "tcp6", nil // IPv6
	}
	return "", fmt.Errorf("cannot infer network for IP address %s", ip.String())
}
// keepalive pings the client once every defaultKeepaliveInterval for as long
// as the pings keep succeeding. On the first failed ping it logs the error,
// closes the client, and returns. ctx is used only for logging.
func keepalive(ctx context.Context, conn net.Conn, client *ssh.Client) {
	ticker := time.NewTicker(defaultKeepaliveInterval)
	defer ticker.Stop()

	for {
		<-ticker.C
		err := emitKeepalive(conn, client)
		if err == nil {
			continue
		}
		// The ping may have failed simply because the client was already
		// closed. That is fine: this Close will then fail too, and its
		// result is deliberately ignored.
		logger.Errorf(ctx, "ssh keepalive failed, closing client: %v", err)
		client.Close()
		return
	}
}
// emitKeepalive sends one keepalive ping on client and returns nil if it is
// answered, or an error that unwraps as a ConnectionError otherwise.
//
// Before pinging, it moves the I/O deadline of the underlying conn to
// defaultKeepaliveTimeout from now, so the ping cannot block forever. The
// deadline is not cleared afterwards; it stays in force for all I/O on conn
// until the next call moves it again.
// NOTE(review): this presumably relies on defaultKeepaliveTimeout being
// longer than the keepalive interval — confirm where those are defined.
func emitKeepalive(conn net.Conn, client *ssh.Client) error {
	// A connection whose deadline cannot be set is treated as broken, the
	// same way dialWithTimeout treats SetDeadline failures.
	if err := conn.SetDeadline(time.Now().Add(defaultKeepaliveTimeout)); err != nil {
		return ConnectionError{fmt.Errorf("setting keepalive deadline: %w", err)}
	}
	return CheckConnection(client)
}