blob: 65ccbaa2b19bd1785dab189b859f646f57c84f09 [file] [log] [blame]
// Copyright 2020 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package sshutil
import (
"context"
"fmt"
"io"
"log"
"net"
"strings"
"sync"
"time"
"go.fuchsia.dev/fuchsia/tools/lib/retry"
"golang.org/x/crypto/ssh"
)
const (
// Delay between successive connection attempts; passed to the constant
// backoff used by connect.
connectInterval = 5 * time.Second
// Interval between keep-alive pings.
defaultKeepaliveInterval = 1 * time.Second
// Cancel the connection if we don't receive a response to a keep-alive
// ping within this amount of time.
defaultKeepaliveDeadline = defaultKeepaliveInterval + 15*time.Second
)
// Client is a wrapper around ssh that supports keepalive and auto-reconnection.
// TODO(fxb/48042): change all usage of sshutil to use this Client type instead
// of ssh.Client.
type Client struct {
// Underlying ssh client. disconnect closes it and resets it to nil, so a nil
// value means the client is disconnected. Methods in this file read and write
// it while holding mu.
*ssh.Client
// Address of the remote server, as passed to NewClient.
addr net.Addr
// SSH configuration the connection was established with.
config *ssh.ClientConfig
// Closed by Close to tell the keepalive goroutine to stop.
shuttingDown chan struct{}
// This mutex protects the following fields
mu sync.Mutex
// Network connection underneath the ssh client; keepalive sets I/O deadlines
// on it.
conn net.Conn
// Channels that disconnect closes to announce that the connection is gone.
disconnectionListeners []chan struct{}
}
// NewClient establishes an ssh connection to addr (see connect for the retry
// behavior) and returns the resulting Client. On success it also starts a
// background goroutine that sends a keep-alive ping every
// defaultKeepaliveInterval for as long as the client stays connected.
func NewClient(ctx context.Context, addr net.Addr, config *ssh.ClientConfig) (*Client, error) {
	c, err := connect(ctx, addr, config)
	if err != nil {
		return nil, err
	}

	// Hand out a fresh timer for every ping so that each one is allowed the
	// full deadline to be answered.
	pingTimeout := func() <-chan time.Time {
		return time.After(defaultKeepaliveDeadline)
	}
	go func() {
		ticker := time.NewTicker(defaultKeepaliveInterval)
		defer ticker.Stop()
		c.keepalive(ticker.C, pingTimeout)
	}()
	return c, nil
}
// connect continuously attempts to connect to a remote server, waiting
// connectInterval between attempts. It returns a Client wrapping the new ssh
// connection if successful, or the error reported by the retry helper
// otherwise (e.g. when the context is canceled).
func connect(ctx context.Context, addr net.Addr, config *ssh.ClientConfig) (*Client, error) {
	var (
		sshClient *ssh.Client
		tcpConn   net.Conn
	)
	// A single connection attempt; results are published to the enclosing
	// variables only when the attempt succeeds.
	attempt := func() error {
		log.Printf("trying to connect to %s...", addr)
		cl, nc, err := connectToSSH(ctx, addr, config)
		if err != nil {
			log.Printf("failed to connect, will try again in %s: %s", connectInterval, err)
			return err
		}
		sshClient, tcpConn = cl, nc
		log.Printf("connected to %s", addr)
		return nil
	}
	if err := retry.Retry(ctx, retry.NewConstantBackoff(connectInterval), attempt, nil); err != nil {
		return nil, err
	}

	c := &Client{
		Client:       sshClient,
		addr:         addr,
		config:       config,
		conn:         tcpConn,
		shuttingDown: make(chan struct{}),
	}
	return c, nil
}
// connectToSSH makes one attempt to reach the ssh server at addr. It dials a
// TCP connection (bounded by ctx and config.Timeout), then performs the ssh
// handshake over it. On success it returns both the ssh client and the
// underlying network connection; on failure both are nil.
//
// NOTE(review): the handshake itself is not passed ctx, so it does not appear
// to be interruptible by context cancellation — confirm whether that matters.
func connectToSSH(ctx context.Context, addr net.Addr, config *ssh.ClientConfig) (*ssh.Client, net.Conn, error) {
	dialer := net.Dialer{Timeout: config.Timeout}
	tcpConn, err := dialer.DialContext(ctx, "tcp", addr.String())
	if err != nil {
		return nil, nil, err
	}

	// The TCP connection is up; layer the ssh connection on top of it.
	sshConn, chans, reqs, err := ssh.NewClientConn(tcpConn, addr.String(), config)
	if err == nil {
		return ssh.NewClient(sshConn, chans, reqs), tcpConn, nil
	}

	// The handshake failed, so release the TCP connection. If even that
	// fails, report both errors while keeping the handshake error wrapped.
	if closeErr := tcpConn.Close(); closeErr != nil {
		return nil, nil, fmt.Errorf(
			"error closing connection: %v; original error: %w", closeErr, err)
	}
	return nil, nil, err
}
// makeSession opens a new ssh session whose stdout and stderr are wired to the
// given writers.
//
// It returns an error if the client is already disconnected, if the session
// cannot be created, or if ctx is done before the session is ready. In the
// last case a session that is created afterwards is closed rather than leaked.
func (c *Client) makeSession(ctx context.Context, stdout io.Writer, stderr io.Writer) (*Session, error) {
	// Temporarily grab the lock and make a copy of the client. This
	// prevents a long running `Run` command from blocking the keep-alive
	// goroutine.
	c.mu.Lock()
	client := c.Client
	c.mu.Unlock()
	if client == nil {
		return nil, fmt.Errorf("ssh is disconnected")
	}

	type result struct {
		session *Session
		err     error
	}
	// The channel is deliberately unbuffered: a send only completes if the
	// caller is still waiting for it. Combined with the ctx.Done() case in the
	// sender, this means the goroutine never blocks forever on the send, and
	// it knows for certain when nobody will take ownership of the session.
	ch := make(chan result)
	go func() {
		session, err := client.NewSession()
		if err != nil {
			select {
			case ch <- result{err: err}:
			case <-ctx.Done():
			}
			return
		}
		session.Stdout = stdout
		session.Stderr = stderr
		select {
		case ch <- result{session: &Session{session: session}}:
		case <-ctx.Done():
			// The caller gave up before the session was handed over, so
			// nobody else can close it; do it here to avoid leaking it.
			session.Close()
		}
	}()

	select {
	case r := <-ch:
		return r.session, r.err
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}
// Start launches command on the remote device without waiting for it to
// finish, sending its STDOUT and STDERR to the passed in io.Writers. The
// returned Session is owned by the caller; if the command could not be
// started, the session is closed and an error is returned instead.
func (c *Client) Start(ctx context.Context, command []string, stdout io.Writer, stderr io.Writer) (*Session, error) {
	s, err := c.makeSession(ctx, stdout, stderr)
	if err != nil {
		return nil, err
	}

	log.Printf("spawning: %s", command)
	startErr := s.Start(ctx, command)
	if startErr == nil {
		return s, nil
	}
	// The session is of no use to the caller if the command never started.
	s.Close()
	return nil, startErr
}
// Run executes command on the remote device and blocks until it completes or
// ctx is done, sending its STDOUT and STDERR to the passed in io.Writers. The
// session used for the command is always closed before Run returns.
func (c *Client) Run(ctx context.Context, command []string, stdout io.Writer, stderr io.Writer) error {
	s, err := c.makeSession(ctx, stdout, stderr)
	if err != nil {
		return err
	}
	defer s.Close()

	log.Printf("running: %s", command)
	runErr := s.Run(ctx, command)
	return runErr
}
// Close shuts the client down: it signals the keepalive goroutine to stop,
// closes the ssh connection, and notifies any registered disconnect
// listeners. It is safe to call Close more than once; later calls only repeat
// the (idempotent) disconnect.
func (c *Client) Close() {
	// Notify the keepalive goroutine we are shutting down. Closing an already
	// closed channel panics, so check first; mu serializes concurrent callers
	// so that exactly one of them performs the close.
	c.mu.Lock()
	select {
	case <-c.shuttingDown:
		// Already signaled by an earlier Close.
	default:
		close(c.shuttingDown)
	}
	c.mu.Unlock()

	c.disconnect()
}
// RegisterDisconnectListener arranges for ch to be closed when the ssh client
// disconnects. If the client is already disconnected, ch is closed right
// away.
func (c *Client) RegisterDisconnectListener(ch chan struct{}) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.Client != nil {
		c.disconnectionListeners = append(c.disconnectionListeners, ch)
		return
	}
	// Nothing left to wait for: signal the listener immediately.
	close(ch)
}
// disconnect closes the ssh client if it is still open, then closes every
// registered listener channel and empties the listener list. Calling it again
// afterwards is a no-op because both the client and the list have been reset.
func (c *Client) disconnect() {
	c.mu.Lock()
	defer c.mu.Unlock()

	if sshClient := c.Client; sshClient != nil {
		sshClient.Close()
		c.Client = nil
	}

	for i := range c.disconnectionListeners {
		close(c.disconnectionListeners[i])
	}
	// Start over with an empty list so no channel is ever closed twice.
	c.disconnectionListeners = make([]chan struct{}, 0)
}
// keepalive sends periodic keep-alive pings. If we don't do this, then we
// might not observe the server side disconnecting from us.
//
// A ping is sent immediately and then again whenever something is received on
// the `ticks` channel. After sending a ping, we call the `timeout` function
// and wait until either we receive a response or we receive something on the
// channel returned by `timeout`. A nil `timeout` means "never time out".
//
// keepalive returns once the client has been disconnected: because Close was
// called, because a ping failed, or because a ping timed out. In the latter
// two cases it disconnects the client itself.
func (c *Client) keepalive(ticks <-chan time.Time, timeout func() <-chan time.Time) {
	if timeout == nil {
		// Receiving from a nil channel blocks forever, which disables the
		// timeout case in the select below.
		timeout = func() <-chan time.Time {
			return nil
		}
	}

	for {
		c.mu.Lock()
		client := c.Client
		conn := c.conn
		c.mu.Unlock()

		// Exit early if the client's already been shut down.
		if client == nil {
			return
		}

		// SendRequest can actually hang if the server stops responding
		// in between receiving a keepalive and sending a response (see
		// fxb/47698). To protect against this, we'll emit events in a
		// separate goroutine so if we don't get one in the expected
		// time we'll disconnect.
		//
		// The channel is buffered so the goroutine can always deliver its
		// result and exit, even if we stopped listening.
		ch := make(chan error, 1)
		go func() {
			// ssh keepalive is not completely reliable. So in
			// addition to emitting it, we'll also set a tcp
			// deadline to timeout if we don't get a keepalive
			// response within some period of time.
			if err := conn.SetDeadline(time.Now().Add(defaultKeepaliveDeadline)); err != nil {
				// Not fatal by itself: the ping below and the timeout in
				// the caller still detect a dead connection.
				log.Printf("error setting keepalive deadline for %s: %s", c.addr, err)
			}

			// Try to emit a keepalive message. We use a unique
			// name to distinguish ourselves from the server-side
			// keepalive name to ease debugging. Any error is
			// reported back and causes a disconnect.
			_, _, err := client.SendRequest("keepalive@fuchsia.com", true, nil)
			ch <- err
		}()

		select {
		case <-c.shuttingDown:
			// Ignore the keepalive result if we are shutting down.
			c.disconnect()
			return
		case err := <-ch:
			// disconnect if we hit an error sending a keepalive.
			if err != nil {
				log.Printf("error sending keepalive to %s, disconnecting: %s", c.addr, err)
				c.disconnect()
				return
			}
		case <-timeout():
			log.Printf("timed out sending keepalive, disconnecting")
			c.disconnect()
			return
		}

		// Otherwise, sleep until the next poll cycle.
		select {
		case <-ticks:
		case <-c.shuttingDown:
			return
		}
	}
}
// Session is a wrapper around ssh.Session that allows operations to be canceled.
// Start, Wait and Run each take a context and return the context's error if
// it is done before the underlying ssh call completes.
type Session struct {
// The wrapped ssh session.
session *ssh.Session
}
// Close closes the underlying ssh session. Any error reported by the close is
// discarded.
func (s *Session) Close() {
	_ = s.session.Close()
}
// Start starts command on the remote side without waiting for it to finish.
// The elements of command are joined with single spaces (no quoting is
// applied) to form the command line.
//
// It returns the result of starting the command, or ctx.Err() if ctx is done
// first. In the latter case the underlying start call is left running in the
// background and its result is dropped.
func (s *Session) Start(ctx context.Context, command []string) error {
	// Buffered so the goroutine can finish even if nobody reads the result.
	done := make(chan error, 1)
	go func() {
		cmdline := strings.Join(command, " ")
		done <- s.session.Start(cmdline)
	}()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case err := <-done:
		return err
	}
}
// Wait blocks until the underlying ssh session's Wait call returns or ctx is
// done, whichever happens first.
//
// It returns the result of the wait, or ctx.Err() if ctx is done first. In the
// latter case the underlying wait keeps running in the background and its
// result is dropped.
func (s *Session) Wait(ctx context.Context) error {
	// Buffered so the goroutine can finish even if nobody reads the result.
	done := make(chan error, 1)
	go func() {
		waitErr := s.session.Wait()
		done <- waitErr
	}()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case err := <-done:
		return err
	}
}
// Run runs command on the remote side and waits for it to complete. The
// elements of command are joined with single spaces (no quoting is applied)
// to form the command line.
//
// It returns the result of the run, or ctx.Err() if ctx is done first. In the
// latter case the underlying run call is left running in the background and
// its result is dropped.
func (s *Session) Run(ctx context.Context, command []string) error {
	// Buffered so the goroutine can finish even if nobody reads the result.
	done := make(chan error, 1)
	go func() {
		cmdline := strings.Join(command, " ")
		done <- s.session.Run(cmdline)
	}()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case err := <-done:
		return err
	}
}