blob: c413c7abf32e5e36b9da24bdfe3a2ea4dd3d9e4d [file] [log] [blame]
// Copyright 2020 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package sshutil
import (
"context"
"errors"
"fmt"
"io"
"net"
"strings"
"sync"
"time"
"go.fuchsia.dev/fuchsia/tools/lib/logger"
"go.fuchsia.dev/fuchsia/tools/lib/retry"
"go.fuchsia.dev/fuchsia/tools/net/sshutil/constants"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
)
// Keepalive tuning knobs and the request name used for keepalive pings.
const (
// Interval between keepalive pings.
defaultKeepaliveInterval = 1 * time.Second
// Cancel the connection if we don't receive a response to a keepalive
// ping within this amount of time. This is the ping interval plus five
// seconds.
defaultKeepaliveTimeout = defaultKeepaliveInterval + 5*time.Second
// A conventionally used global request name for checking the status of a client
// connection to an OpenSSH server.
keepaliveOpenSSH = "keepalive@openssh.com"
)
// Conn is a wrapper around ssh that supports keepalive and auto-reconnection.
//
// NOTE(review): only keepalive and disconnect handling are visible in this
// file; no reconnection logic appears here - confirm where (or whether) it
// is implemented.
type Conn struct {
// The underlying ssh client. disconnect closes it and sets it to nil.
*ssh.Client
// Remote address that was dialed; also used in log messages.
addr net.Addr
// ssh client configuration that was used to establish the connection.
config *ssh.ClientConfig
// Closed by Close to tell the keepalive goroutine to stop.
shuttingDown chan struct{}
// This mutex protects the following fields. The code in this file also
// holds it while reading or clearing Client.
mu sync.Mutex
// Channels that disconnect closes to announce that the connection is gone.
disconnectionListeners []chan struct{}
}
// newConn connects to addr (see connect) and, once connected, starts a
// background goroutine that sends a keepalive ping every
// defaultKeepaliveInterval until the connection goes away.
func newConn(ctx context.Context, addr net.Addr, config *ssh.ClientConfig, backoff retry.Backoff) (*Conn, error) {
	c, err := connect(ctx, addr, config, backoff)
	if err != nil {
		return nil, err
	}

	// The keepalive goroutine must outlive `ctx`, so it gets a fresh
	// background context. Only the logger (if any) is carried over so that
	// keepalive messages end up in the same place as the caller's logs.
	bgCtx := context.Background()
	if l := logger.LoggerFromContext(ctx); l != nil {
		bgCtx = logger.WithLogger(bgCtx, l)
	}

	go func() {
		ticker := time.NewTicker(defaultKeepaliveInterval)
		defer ticker.Stop()
		c.keepalive(bgCtx, ticker.C, func() <-chan time.Time {
			return time.After(defaultKeepaliveTimeout)
		})
	}()

	return c, nil
}
// connect keeps trying to reach the remote server, pacing attempts with
// backoff, and returns a Conn wrapping the ssh client once an attempt
// succeeds. If retrying ends with an error (e.g. the context is canceled) that
// error is returned wrapped in a ConnectionError; timeouts get a dedicated
// message that includes how long we spent trying.
func connect(ctx context.Context, addr net.Addr, config *ssh.ClientConfig, backoff retry.Backoff) (*Conn, error) {
	began := time.Now()

	var sshClient *ssh.Client
	attempt := func() error {
		logger.Debugf(ctx, "trying to connect to %s...", addr)
		var attemptErr error
		sshClient, attemptErr = connectToSSH(ctx, addr, config)
		if attemptErr != nil {
			return attemptErr
		}
		logger.Debugf(ctx, "connected to %s", addr)
		return nil
	}

	err := retry.Retry(ctx, backoff, attempt, nil)
	if err == nil {
		return &Conn{
			Client:       sshClient,
			addr:         addr,
			config:       config,
			shuttingDown: make(chan struct{}),
		}, nil
	}

	// Distinguish timeouts from every other failure so callers get a
	// recognizable message.
	var netErr net.Error
	if errors.As(err, &netErr) && netErr.Timeout() {
		elapsed := time.Since(began).Truncate(time.Second)
		return nil, ConnectionError{fmt.Errorf("%s after %v: %w", constants.TimedOutConnectingMsg, elapsed, err)}
	}
	return nil, ConnectionError{fmt.Errorf("cannot connect to address %q: %w", addr, err)}
}
// connectToSSH makes a single attempt to open a TCP connection to addr and
// complete an ssh handshake over it. When config.Timeout is non-zero it bounds
// the whole attempt (dial and handshake). If the context ends first, the
// context's error is returned, possibly wrapped together with an error from
// closing the TCP connection.
func connectToSSH(ctx context.Context, addr net.Addr, config *ssh.ClientConfig) (*ssh.Client, error) {
// Update the context with the ssh connection timeout, if specified.
if config.Timeout != 0 {
var cancel func()
ctx, cancel = context.WithTimeout(ctx, config.Timeout)
defer cancel()
}
d := net.Dialer{}
conn, err := d.DialContext(ctx, "tcp", addr.String())
if err != nil {
// DialContext maps context errors to custom non-exported error
// types, so even if the operation failed due to a context error it
// might not return a context error. See
// https://github.com/golang/go/blob/b4652028d48f42506cfd10c1763c6d7e8b22cb7b/src/net/net.go#L420
// So we convert back to a context error to provide a more consistent
// interface for callers of this method.
//
// There is a potential race condition where the context might be
// canceled after DialContext exits but before we hit this line, in
// which case we would actually return the wrong error. But that's
// probably not a big deal because the fact that the context was
// canceled implies that we should be giving up on, and ignoring the
// results of, ongoing operations anyway.
if ctx.Err() != nil {
err = ctx.Err()
}
return nil, err
}
// We made a TCP connection, now establish an SSH connection over it.
//
// We can hang if the server accepts a connection but never replies to the
// ssh handshake. To handle this case, we'll establish the connection in a
// goroutine, and wait for it to complete or the context to be canceled.
type result struct {
client *ssh.Client
err error
}
// Buffered so the handshake goroutine can always deliver its result and
// exit, even if we have already returned because the context ended.
ch := make(chan result, 1)
go func() {
clientConn, chans, reqs, err := ssh.NewClientConn(conn, addr.String(), config)
if err != nil {
// The handshake failed, so the TCP connection is of no further use.
if closeErr := conn.Close(); closeErr != nil {
err = fmt.Errorf("error closing connection: %v; original error: %w", closeErr, err)
}
ch <- result{err: err}
return
}
ch <- result{client: ssh.NewClient(clientConn, chans, reqs)}
}()
select {
case r := <-ch:
return r.client, r.err
case <-ctx.Done():
// Give up on the handshake. Closing the TCP connection should also make
// the pending handshake in the goroutine above fail, letting it exit.
err = ctx.Err()
if closeErr := conn.Close(); closeErr != nil {
err = fmt.Errorf("error closing connection: %v; original error: %w", closeErr, err)
}
return nil, err
}
}
// makeSession opens a new ssh session on the current client and wires its
// stdout and stderr to the given writers.
//
// It returns a ConnectionError if the connection has already been torn down
// or if the session cannot be created, and ctx.Err() if ctx is done before
// the session is ready. In the latter case a session that is created after
// we stopped waiting is closed in the background instead of being leaked.
func (c *Conn) makeSession(ctx context.Context, stdout io.Writer, stderr io.Writer) (*Session, error) {
	// Temporarily grab the lock and make a copy of the client. This
	// prevents a long running `Run` command from blocking the keepalive
	// goroutine.
	c.mu.Lock()
	client := c.Client
	c.mu.Unlock()
	if client == nil {
		return nil, ConnectionError{fmt.Errorf("ssh is disconnected")}
	}

	type result struct {
		session *Session
		err     error
	}
	// Use a buffered channel to ensure that sending the first element doesn't
	// block and cause the goroutine to leak in the case where the context gets
	// cancelled before we receive on the channel.
	ch := make(chan result, 1)
	go func() {
		session, err := client.NewSession()
		if err != nil {
			ch <- result{err: err}
			return
		}
		session.Stdout = stdout
		session.Stderr = stderr
		ch <- result{session: &Session{session: session}}
	}()

	select {
	case r := <-ch:
		if r.err != nil {
			return nil, ConnectionError{fmt.Errorf("failed to start ssh session: %w", r.err)}
		}
		return r.session, nil
	case <-ctx.Done():
		// NewSession may still succeed after we give up. Nobody else holds
		// a reference to that session, so collect it here and close it
		// rather than leaving it open for the lifetime of the connection.
		go func() {
			if r := <-ch; r.session != nil {
				r.session.Close()
			}
		}()
		return nil, ctx.Err()
	}
}
// Start launches a command on the remote device without waiting for it to
// finish. The command's STDOUT and STDERR are written to the passed in
// io.Writers. On success the caller owns the returned session; if the command
// cannot be started the session is closed before the error is returned.
func (c *Conn) Start(ctx context.Context, command []string, stdout io.Writer, stderr io.Writer) (*Session, error) {
	s, err := c.makeSession(ctx, stdout, stderr)
	if err != nil {
		return nil, err
	}

	logger.Debugf(ctx, "starting over ssh: %s", command)
	startErr := s.Start(ctx, command)
	if startErr == nil {
		return s, nil
	}
	s.Close()
	return nil, startErr
}
// Run executes a command on the remote device, waits for it to finish, and
// writes its STDOUT and STDERR to the passed in io.Writers.
//
// A command that exits without reporting an exit code is returned wrapped in
// a ConnectionError; other failures are returned as-is. Failures are logged
// unless the context has already ended.
func (c *Conn) Run(ctx context.Context, command []string, stdout io.Writer, stderr io.Writer) error {
	s, err := c.makeSession(ctx, stdout, stderr)
	if err != nil {
		return err
	}
	defer s.Close()

	logger.Debugf(ctx, "running over ssh: %v", command)
	runErr := s.Run(ctx, command)
	if runErr == nil {
		logger.Debugf(ctx, "successfully ran over ssh: %v", command)
		return nil
	}
	if ctx.Err() != nil {
		// Don't bother logging the error if the context was canceled.
		return runErr
	}

	// Pick a message and severity based on how the command failed.
	var (
		msg   string
		level logger.LogLevel
	)
	switch exitErr := runErr.(type) {
	case *ssh.ExitMissingError:
		msg, level = "ssh command failed with no exit code", logger.DebugLevel
		runErr = ConnectionError{runErr}
	case *ssh.ExitError:
		msg, level = fmt.Sprintf("ssh command failed with exit code %d", exitErr.ExitStatus()), logger.DebugLevel
	default:
		msg, level = fmt.Sprintf("ssh command failed with error: %v", runErr), logger.ErrorLevel
	}
	logger.Logf(ctx, level, "%s: %v", msg, command)
	return runErr
}
// Close the ssh client connections.
//
// Close may be called more than once and from several goroutines at the same
// time. The check-and-close of shuttingDown happens while holding c.mu, so two
// callers can never both see the channel as still open and close it twice,
// which would panic.
func (c *Conn) Close() {
	c.mu.Lock()
	select {
	case <-c.shuttingDown:
		// An earlier Close already signalled shutdown; nothing to do.
	default:
		// Notify the keepalive goroutine we are shutting down.
		close(c.shuttingDown)
	}
	// Release the lock before calling disconnect, which acquires c.mu itself.
	c.mu.Unlock()

	c.disconnect()
}
// RegisterDisconnectListener registers ch to be closed when the ssh client is
// disconnected. If the client is already gone, ch is closed right away.
func (c *Conn) RegisterDisconnectListener(ch chan struct{}) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.Client != nil {
		c.disconnectionListeners = append(c.disconnectionListeners, ch)
		return
	}
	close(ch)
}
// NewSFTPClient creates an SFTP client on top of whichever ssh.Client is
// currently in use, or returns an error if the connection has been closed.
// The SFTP client will become unresponsive if the ssh connection is closed
// and/or refreshed.
func (c *Conn) NewSFTPClient() (*sftp.Client, error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if client := c.Client; client != nil {
		return sftp.NewClient(client)
	}
	return nil, errors.New("ssh connection is closed, cannot create new SFTP client")
}
// disconnect closes the ssh client (if there still is one), clears it, and
// closes every registered disconnection listener. The listener list is reset
// afterwards, so calling disconnect again does not close a listener twice.
func (c *Conn) disconnect() {
	c.mu.Lock()
	defer c.mu.Unlock()

	if client := c.Client; client != nil {
		client.Close()
		c.Client = nil
	}

	for i := range c.disconnectionListeners {
		close(c.disconnectionListeners[i])
	}
	c.disconnectionListeners = make([]chan struct{}, 0)
}
// Send periodic keepalives. If we don't do this, then we might not observe
// the server side disconnecting from us.
// A keepalive ping is sent whenever we receive something on the `ticks`
// channel.
// After sending a ping, we call the `timeout` function and wait until either we
// receive a response or we receive something on the channel returned by
// `timeout`.
// A nil `timeout` is replaced by a function that returns a nil channel; a nil
// channel never becomes ready in a select, so in that case there is no time
// limit on waiting for a response.
// The loop ends when `shuttingDown` is closed, when the client has already
// been cleared, or after a failed or timed-out ping (both of which
// disconnect). `ctx` is only used for logging.
func (c *Conn) keepalive(ctx context.Context, ticks <-chan time.Time, timeout func() <-chan time.Time) {
if timeout == nil {
timeout = func() <-chan time.Time {
return nil
}
}
for {
// Sleep until the next poll cycle or until the client is closed.
select {
case <-ticks:
case <-c.shuttingDown:
return
}
// Take a snapshot of the client under the lock so that the ping below
// runs without holding it.
c.mu.Lock()
client := c.Client
c.mu.Unlock()
// Exit early if the client's already been shut down.
if client == nil {
return
}
// SendRequest can actually hang if the server stops responding
// in between receiving a keepalive and sending a response (see
// fxbug.dev/47698). To protect against this, we'll emit events in a
// separate goroutine so if we don't get one in the expected
// time we'll disconnect.
//
// The channel is buffered so that goroutine can always deliver its
// result and exit, even if this loop has stopped waiting for it.
ch := make(chan error, 1)
go func() {
// Send the conventional OpenSSH keepalive global request and ask
// for a reply. The outcome is handed back to the loop, which
// disconnects on any error.
_, _, err := client.SendRequest(keepaliveOpenSSH, true, nil)
ch <- err
}()
sendTime := time.Now()
select {
case <-c.shuttingDown:
// Ignore the keepalive result if we are shutting down.
// There is no return here: the loop comes back around and exits on
// its next pass, either via shuttingDown or because disconnect has
// cleared the client.
c.disconnect()
case err := <-ch:
// disconnect if we hit an error sending a keepalive.
if err != nil {
// Ignore a spurious error if we tried to send
// a keepalive while the connection was closed
// out from under us.
select {
case <-c.shuttingDown:
default:
logger.Debugf(
ctx,
"error sending keepalive to %s, disconnecting: %s",
c.addr,
err,
)
}
c.disconnect()
return
}
// A nil error means the ping was answered; wait for the next tick.
case <-timeout():
timeoutDuration := time.Since(sendTime)
logger.Debugf(ctx, "ssh keepalive timed out after %.3fs, disconnecting", timeoutDuration.Seconds())
c.disconnect()
return
}
}
}
// Session is a wrapper around ssh.Session that allows operations to be canceled.
type Session struct {
// The wrapped ssh session that the methods of Session delegate to.
session *ssh.Session
}
// Close closes the underlying ssh session. Close has no return value, so any
// error reported by the underlying close is discarded.
func (s *Session) Close() {
	_ = s.session.Close()
}
// Start starts the given command on the session without waiting for it to
// finish. The command's elements are joined with single spaces. If ctx is done
// before the command has been started, ctx.Err() is returned instead.
func (s *Session) Start(ctx context.Context, command []string) error {
	return runSessionOp(ctx, func() error {
		return s.session.Start(strings.Join(command, " "))
	})
}

// Wait waits for the session's command to finish and returns its result. If
// ctx is done first, ctx.Err() is returned instead.
func (s *Session) Wait(ctx context.Context) error {
	return runSessionOp(ctx, s.session.Wait)
}

// Run starts the given command on the session and waits for it to finish. The
// command's elements are joined with single spaces. If ctx is done first,
// ctx.Err() is returned instead.
func (s *Session) Run(ctx context.Context, command []string) error {
	return runSessionOp(ctx, func() error {
		return s.session.Run(strings.Join(command, " "))
	})
}

// runSessionOp runs the blocking call op in its own goroutine and returns its
// error, or ctx.Err() if ctx is done before op returns. op itself is not
// interrupted when ctx ends; the result channel is buffered so that the
// goroutine can still deliver its result and exit once op eventually returns.
func runSessionOp(ctx context.Context, op func() error) error {
	done := make(chan error, 1)
	go func() {
		done <- op()
	}()
	select {
	case err := <-done:
		return err
	case <-ctx.Done():
		return ctx.Err()
	}
}