blob: 2941393204cc14f64699aa229801463d147e6dd7 [file] [log] [blame]
// Copyright 2020 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package sshutil
import (
"bytes"
"context"
"errors"
"io"
"net"
"strconv"
"sync/atomic"
"testing"
"golang.org/x/crypto/ssh"
"go.fuchsia.dev/fuchsia/tools/lib/retry"
)
func setUpClient(
ctx context.Context,
t *testing.T,
onNewChannel func(ssh.NewChannel),
onRequest func(ssh.Channel, *ssh.Request),
) (*Client, *sshServer) {
server, err := startSSHServer(ctx, onNewChannel, onRequest)
if err != nil {
t.Fatalf("failed to start ssh server: %s", err)
}
t.Cleanup(func() {
if err := server.stop(); err != nil {
t.Error(err)
}
})
client, err := NewClient(
ctx,
ConstantAddrResolver{
Addr: server.listener.Addr(),
},
server.clientConfig,
retry.NoRetries())
if err != nil {
t.Fatalf("failed to create client: %s", err)
}
t.Cleanup(client.Close)
return client, server
}
func TestReconnect(t *testing.T) {
ctx := context.Background()
t.Run("can run a command before and after reconnection", func(t *testing.T) {
var execCount int64
client, _ := setUpClient(
ctx,
t,
nil,
onExecRequest(func(cmd string, stdout io.Writer, stderr io.Writer) int {
expected := strconv.Itoa(int(atomic.AddInt64(&execCount, 1)))
if expected != cmd {
t.Fatalf("expected exec cmd to be %q, not %q", expected, cmd)
}
stdout.Write([]byte(expected))
stderr.Write([]byte(expected))
return 0
}),
)
// Check we can run a command before reconnecting.
var stdout bytes.Buffer
var stderr bytes.Buffer
if err := client.Run(ctx, []string{"1"}, &stdout, &stderr); err != nil {
t.Errorf("failed to run a command: %s", err)
}
if execCount != 1 {
t.Errorf("expected exec count to be 1, not %d", execCount)
}
if stdout.String() != "1" {
t.Errorf("expected stdout to be \"1\", not %q", stdout.String())
}
if stderr.String() != "1" {
t.Errorf("expected stderr to be \"1\", not %q", stdout.String())
}
client.Close()
<-client.DisconnectionListener()
if err := client.Reconnect(ctx); err != nil {
t.Errorf("failed to reconnect: %s", err)
}
// Check we can still run a command after reconnecting.
stdout.Reset()
stderr.Reset()
if err := client.Run(ctx, []string{"2"}, &stdout, &stderr); err != nil {
t.Errorf("failed to run a command: %s", err)
}
if execCount != 2 {
t.Errorf("expected exec count to be 2, not %d", execCount)
}
if stdout.String() != "2" {
t.Errorf("expected stdout to be \"2\", not %q", stdout.String())
}
if stderr.String() != "2" {
t.Errorf("expected stderr to be \"2\", not %q", stdout.String())
}
})
}
func TestCloseDuringConnection(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
serverConfig, clientConfig, err := genSSHConfig()
if err != nil {
t.Fatalf("failed to create ssh config: %s", err)
}
connected := make(chan struct{})
reconnected := make(chan struct{})
done := make(chan struct{})
defer close(done)
// Spawn a server goroutine that will accept a connection, but never read
// or write to the socket.
listener, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatalf("failed to listen on port: %s", err)
}
defer func() {
if err := listener.Close(); err != nil {
t.Error(err)
}
}()
serverErrs := make(chan error)
go func() {
// `NewClient` blocks until the first connection is made, so
// we'll accept then close the connection.
tcpConn, err := listener.Accept()
if err != nil {
serverErrs <- err
return
}
sshConn, incomingChannels, _, err := ssh.NewServerConn(tcpConn, serverConfig)
if err != nil {
serverErrs <- err
return
}
// Accept the keepalive channel.
newChannel := <-incomingChannels
if _, _, err := newChannel.Accept(); err != nil {
serverErrs <- err
return
}
<-connected
if err := sshConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
t.Error(err)
}
// Accept the reconnection attempt, but never respond to it.
tcpConn, err = listener.Accept()
if err != nil {
serverErrs <- err
return
}
close(reconnected)
// Wait for the test to complete before closing the
// connection. Otherwise we'll race with the OS
// observing the closed connection with the context to
// be canceled.
<-done
if err := tcpConn.Close(); err != nil {
t.Error(err)
}
}()
// Spawn a goroutine to establish the initial client.
type clientResult struct {
client *Client
err error
}
clientChan := make(chan clientResult)
go func() {
client, err := NewClient(
ctx,
ConstantAddrResolver{
Addr: listener.Addr(),
},
clientConfig,
retry.NoRetries(),
)
clientChan <- clientResult{
client: client,
err: err,
}
}()
// Wait for the connection to succeed.
var client *Client
select {
case err := <-serverErrs:
t.Fatalf("server failed to accept connection: %s", err)
case res := <-clientChan:
if res.err != nil {
t.Fatalf("client failed to connect: %s", err)
}
client = res.client
}
// Signal to the server that we connected so it can switch to waiting
// for the reconnection attempt.
close(connected)
reconnectCtx, reconnectCancel := context.WithCancel(ctx)
reconnectErrs := make(chan error)
// Spawn a reconnection attempt. This should never succeed
go func() {
reconnectErrs <- client.Reconnect(reconnectCtx)
}()
select {
case err := <-serverErrs:
t.Fatalf("server failed to accept reconnection: %s", err)
case err := <-reconnectErrs:
t.Fatalf("client failed to reconnect: %s", err)
case <-reconnected:
}
// Close the connection, which should not block.
clientClosed := make(chan struct{})
go func() {
client.Close()
close(clientClosed)
}()
// The close goroutine should succeed.
<-clientClosed
// The reconnection goroutine won't error out until the context is canceled.
reconnectCancel()
select {
case err := <-reconnectErrs:
if !errors.Is(err, context.Canceled) {
t.Fatalf("client reconnection should have been canceled, not %v", err)
}
}
}