blob: 58f4b8897ee04263b673ee3d912b91c5169c12fe [file] [log] [blame]
// Copyright 2020 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package sshutil
import (
"bytes"
"context"
"errors"
"io"
"net"
"testing"
"time"
"go.fuchsia.dev/fuchsia/tools/lib/retry"
"golang.org/x/crypto/ssh"
)
// testTimeout is how long the tests in this file wait for an expected
// asynchronous event (a ping, a disconnect, a returned error) before
// reporting a failure.
const testTimeout = 1 * time.Second
// setUpConn starts a test SSH server that uses the given channel and request
// handlers, connects to it with retry.NoRetries(), and returns both the
// resulting Conn and the server.
//
// Any failure aborts the calling test via t.Fatalf. The server's stop method
// and the conn's Close method are registered with t.Cleanup, so callers do not
// have to tear them down themselves.
func setUpConn(
	ctx context.Context,
	t *testing.T,
	onNewChannel func(ssh.NewChannel),
	onRequest func(*ssh.Request),
) (*Conn, *sshServer) {
	srv, startErr := startSSHServer(onNewChannel, onRequest)
	if startErr != nil {
		t.Fatalf("failed to start ssh server: %v", startErr)
	}
	t.Cleanup(srv.stop)

	client, connectErr := connect(ctx, srv.addr, srv.clientConfig, retry.NoRetries())
	if connectErr != nil {
		t.Fatalf("failed to create conn: %v", connectErr)
	}
	// Cleanups run last-in-first-out, so the conn is closed before the
	// server is stopped.
	t.Cleanup(client.Close)

	return client, srv
}
// assertChannelClosed waits up to testTimeout for a receive on ch to succeed
// (which happens immediately once ch is closed). If the timeout elapses first,
// it records a test error with errorMessage; it does not abort the test.
func assertChannelClosed(t *testing.T, ch chan struct{}, errorMessage string) {
	// Attribute failures to the calling test line rather than this helper.
	t.Helper()
	select {
	case <-ch:
	case <-time.After(testTimeout):
		// errorMessage is a plain message, not a format string, so it must
		// not go through Errorf (any '%' in it would be misinterpreted).
		t.Error(errorMessage)
	}
}
// TestKeepalive covers the keepalive loop of a Conn: a ping is sent when the
// tick channel fires, the conn is disconnected when a ping times out or fails,
// and the loop returns once the conn is closed.
func TestKeepalive(t *testing.T) {
	ctx := context.Background()

	// ackRequest is the server-side request handler shared by the subtests
	// that don't inspect requests: it answers each one positively with an
	// empty payload.
	ackRequest := func(req *ssh.Request) {
		req.Reply(true, []byte{})
	}

	t.Run("sends pings when timer fires", func(t *testing.T) {
		pings := make(chan *ssh.Request, 1)
		onRequest := func(req *ssh.Request) {
			if !req.WantReply {
				t.Errorf("keepalive pings must have WantReply set")
			}
			pings <- req
			req.Reply(true, []byte{})
		}
		conn, _ := setUpConn(ctx, t, nil, onRequest)

		// Each value sent on ticks prompts the keepalive loop to ping.
		ticks := make(chan time.Time)
		go conn.keepalive(ctx, ticks, nil)
		ticks <- time.Now()

		select {
		case <-time.After(testTimeout):
			t.Errorf("didn't receive keepalive ping after trigger sent")
		case <-pings:
		}
	})

	t.Run("disconnects conn if keepalive times out", func(t *testing.T) {
		conn, _ := setUpConn(ctx, t, nil, ackRequest)

		disconnected := make(chan struct{})
		conn.RegisterDisconnectListener(disconnected)

		// A value already waiting on the timeout channel makes the timeout
		// handling kick in for the next keepalive ping.
		timeouts := make(chan time.Time, 1)
		timeouts <- time.Now()
		timeoutSource := func() <-chan time.Time {
			return timeouts
		}

		ticks := make(chan time.Time)
		go conn.keepalive(ctx, ticks, timeoutSource)
		ticks <- time.Now()

		assertChannelClosed(t, disconnected, "keepalive failure should have disconnected the conn")
	})

	t.Run("disconnects conn if keepalive fails", func(t *testing.T) {
		conn, server := setUpConn(ctx, t, nil, ackRequest)

		disconnected := make(chan struct{})
		conn.RegisterDisconnectListener(disconnected)

		// With the server stopped, the very first keepalive request is
		// expected to fail right away.
		server.stop()

		loopDone := make(chan struct{})
		ticks := make(chan time.Time)
		go func() {
			conn.keepalive(ctx, ticks, nil)
			close(loopDone)
		}()
		ticks <- time.Now()

		assertChannelClosed(t, disconnected, "a keepalive failure didn't disconnect the conn")
		assertChannelClosed(t, loopDone, "a keepalive failure didn't terminate the keepalive goroutine")
	})

	t.Run("stops sending when conn is closed", func(t *testing.T) {
		conn, _ := setUpConn(ctx, t, nil, ackRequest)

		// No tick channel is supplied: the loop must exit purely because
		// the conn gets closed.
		loopDone := make(chan struct{})
		go func() {
			conn.keepalive(ctx, nil, nil)
			close(loopDone)
		}()

		disconnected := make(chan struct{})
		conn.RegisterDisconnectListener(disconnected)

		conn.Close()

		assertChannelClosed(t, disconnected, "conn.Close() didn't disconnect the conn")
		assertChannelClosed(t, loopDone, "conn.Close() didn't terminate the keepalive goroutine")
	})
}
// TestRun covers running commands over SSH: the exit status and output of
// successful and failing commands, and prompt exit when the context is
// canceled (during the handshake or while creating a session) or when the
// server becomes unavailable.
func TestRun(t *testing.T) {
	ctx := context.Background()

	t.Run("runs a command", func(t *testing.T) {
		// Set up a server that will respond to a command:
		//
		// "pass": with "pass stdout" as STDOUT, "pass stderr" as STDERR.
		// "fail": with "fail stdout" as STDOUT, "fail stderr" as STDERR.
		client, _ := setUpClient(
			ctx,
			t,
			onNewExecChannel(func(cmd string, stdout io.Writer, stderr io.Writer) int {
				switch cmd {
				case "pass":
					stdout.Write([]byte("pass stdout"))
					stderr.Write([]byte("pass stderr"))
					return 0
				case "fail":
					stdout.Write([]byte("fail stdout"))
					stderr.Write([]byte("fail stderr"))
					return 1
				default:
					t.Errorf("unexpected command %q", cmd)
					return 255
				}
			}),
			nil,
		)

		// check runs cmd and verifies its exit status and captured output.
		check := func(cmd string, expectedExitStatus int, expectedStdout string, expectedStderr string) {
			var stdout bytes.Buffer
			var stderr bytes.Buffer
			err := client.Run(ctx, []string{cmd}, &stdout, &stderr)

			switch {
			case expectedExitStatus == 0:
				if err != nil {
					t.Errorf("command %q failed: %v", cmd, err)
				}
			case err == nil:
				// A command that should exit non-zero must not be
				// reported as a success.
				t.Errorf("command %q succeeded, expected exit status %d", cmd, expectedExitStatus)
			default:
				var exitErr *ssh.ExitError
				if !errors.As(err, &exitErr) || exitErr.ExitStatus() != expectedExitStatus {
					t.Errorf("command %q failed: %v", cmd, err)
				}
			}

			actualStdout := stdout.String()
			actualStderr := stderr.String()
			if expectedStdout != actualStdout {
				t.Errorf("expected stdout for %q to be %q, not %q", cmd, expectedStdout, actualStdout)
			}
			if expectedStderr != actualStderr {
				t.Errorf("expected stderr for %q to be %q, not %q", cmd, expectedStderr, actualStderr)
			}
		}

		check("pass", 0, "pass stdout", "pass stderr")
		check("fail", 1, "fail stdout", "fail stderr")
	})

	t.Run("exits early if context canceled during handshake", func(t *testing.T) {
		accepted := make(chan struct{})
		done := make(chan struct{})
		defer close(done)

		// Spawn a server goroutine that will accept a connection, but never read
		// or write to the socket.
		listener, err := net.Listen("tcp", ":0")
		if err != nil {
			t.Fatalf("failed to listen on port: %v", err)
		}
		defer listener.Close()

		// Buffered so the goroutine can always deliver its error and exit,
		// even if the test has already stopped waiting for it.
		serverErrs := make(chan error, 1)
		go func() {
			conn, err := listener.Accept()
			if err != nil {
				serverErrs <- err
				return
			}
			close(accepted)

			// Wait for the test to complete before closing the
			// connection. Otherwise the OS observing the closed
			// connection would race with the context being canceled.
			<-done
			conn.Close()
		}()

		_, clientConfig, err := genSSHConfig()
		if err != nil {
			t.Fatalf("failed to create ssh config: %v", err)
		}

		// In order to test that we can break out of a stuck client, we
		// can use either a context.WithCancel or context.WithDeadline.
		// Since these both have similar cancelation mechanisms, we can
		// use either to verify we can interrupt the connection. Since
		// it's tricky to pick the right deadline values in a test,
		// we'll use a context.WithCancel.
		connectCtx, cancel := context.WithCancel(ctx)
		defer cancel()

		// Spawn an ssh client goroutine, that will connect to the server, and err
		// out when the connection is canceled. The channel is buffered so the
		// goroutine never blocks on the send if the test has timed out.
		connectErrs := make(chan error, 1)
		go func() {
			client, err := connect(connectCtx, listener.Addr(), clientConfig, retry.NoRetries())
			if client != nil {
				client.Close()
			}
			connectErrs <- err
		}()

		// Wait for the connection to be accepted.
		select {
		case <-time.After(testTimeout):
			t.Fatalf("server didn't accept the connection in time")
		case err := <-serverErrs:
			if err != nil {
				t.Errorf("server failed to accept connection: %v", err)
			}
		case <-accepted:
		}

		// Now that we know the connection has been accepted, we can
		// cancel the context to cause the `connect()` function to err out.
		cancel()

		// Wait for the connection to be canceled.
		select {
		case <-time.After(testTimeout):
			t.Errorf("canceling the context should cause connect() to exit")
		case err := <-connectErrs:
			if !errors.Is(err, context.Canceled) {
				t.Errorf("context was canceled but connect() returned wrong error: %v", err)
			}
		}
	})

	t.Run("exits early if context canceled while creating session", func(t *testing.T) {
		// By not passing an `onNewChannel` function we ensure that the command
		// will hang until the context is canceled.
		conn, _ := setUpConn(ctx, t, nil, nil)

		runCtx, cancel := context.WithCancel(ctx)

		// Buffered so the goroutine can exit even if the test times out.
		errs := make(chan error, 1)
		go func() {
			errs <- conn.Run(runCtx, []string{"foo"}, nil, nil)
		}()

		cancel()

		select {
		case <-time.After(testTimeout):
			t.Errorf("canceling the context should cause Run() to exit")
		case err := <-errs:
			if !errors.Is(err, context.Canceled) {
				t.Errorf("context was canceled but Run() returned wrong error: %v", err)
			}
		}
	})

	t.Run("exits early if session creation fails", func(t *testing.T) {
		conn, server := setUpConn(ctx, t, nil, nil)
		server.stop()

		// Buffered so the goroutine can exit even if the test times out.
		errs := make(chan error, 1)
		go func() {
			errs <- conn.Run(ctx, []string{"foo"}, nil, nil)
		}()

		select {
		case <-time.After(testTimeout):
			t.Errorf("Run() should exit if the server becomes unavailable")
		case err := <-errs:
			if err == nil {
				t.Errorf("Run() should return an error if the server is unavailable")
			}
		}
	})
}