blob: 3b6247c52aa49b9e96aab3e73fad71c724610121 [file] [log] [blame]
// Copyright 2020 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package sshutil
import (
"bytes"
"context"
"io"
"strconv"
"sync/atomic"
"testing"
"go.fuchsia.dev/fuchsia/tools/lib/retry"
"golang.org/x/crypto/ssh"
)
// setUpClient starts a fake SSH server and connects a Client to it.
//
// onNewChannel and onRequest are passed through to startSSHServer as its
// handlers; a nil handler is passed through unchanged. The client is created
// with retry.NoRetries(), so a failed connection is not retried. Any setup
// failure aborts the calling test.
//
// Both the server and the client are registered with t.Cleanup. Cleanups run
// in last-in, first-out order, so the client is closed before the server is
// stopped. Callers may still close the client themselves, e.g. to exercise
// reconnection.
func setUpClient(
	ctx context.Context,
	t *testing.T,
	onNewChannel func(ssh.NewChannel),
	onRequest func(*ssh.Request),
) (*Client, *sshServer) {
	// Attribute setup failures to the calling test's line, not this helper.
	t.Helper()

	server, err := startSSHServer(onNewChannel, onRequest)
	if err != nil {
		t.Fatalf("failed to start ssh server: %v", err)
	}
	t.Cleanup(server.stop)

	client, err := NewClient(ctx, server.addr, server.clientConfig, retry.NoRetries())
	if err != nil {
		t.Fatalf("failed to create client: %v", err)
	}
	t.Cleanup(client.Close)

	return client, server
}
// TestReconnect verifies that a Client can be closed, observed as
// disconnected through a registered disconnect listener, and reconnected.
// Commands must run correctly both before and after the reconnection.
func TestReconnect(t *testing.T) {
	ctx := context.Background()

	t.Run("can run a command before and after reconnection", func(t *testing.T) {
		// execCount is incremented by the server-side exec handler via
		// atomic.AddInt64. The handler is presumably invoked off the test
		// goroutine, so every read below uses atomic.LoadInt64 to stay
		// race-free.
		var execCount int64
		client, _ := setUpClient(
			ctx,
			t,
			onNewExecChannel(func(cmd string, stdout io.Writer, stderr io.Writer) int {
				// The n-th exec is expected to carry the command "n".
				expected := strconv.Itoa(int(atomic.AddInt64(&execCount, 1)))
				if expected != cmd {
					// t.Fatalf (FailNow) may only be called from the test
					// goroutine. Record the failure and report a non-zero exit
					// status so the client-side Run also fails.
					t.Errorf("expected exec cmd to be %q, not %q", expected, cmd)
					return 1
				}
				// Echo the command on both streams so the client can verify
				// that stdout and stderr are each plumbed through.
				stdout.Write([]byte(expected))
				stderr.Write([]byte(expected))
				return 0
			}),
			nil,
		)

		// runAndCheck runs the command want and checks three things: the
		// server has now executed wantCount commands in total, and want was
		// echoed back on both stdout and stderr.
		runAndCheck := func(want string, wantCount int64) {
			t.Helper()
			var stdout, stderr bytes.Buffer
			if err := client.Run(ctx, []string{want}, &stdout, &stderr); err != nil {
				t.Errorf("failed to run a command: %v", err)
			}
			if got := atomic.LoadInt64(&execCount); got != wantCount {
				t.Errorf("expected exec count to be %d, not %d", wantCount, got)
			}
			if got := stdout.String(); got != want {
				t.Errorf("expected stdout to be %q, not %q", want, got)
			}
			if got := stderr.String(); got != want {
				t.Errorf("expected stderr to be %q, not %q", want, got)
			}
		}

		// Check we can run a command before reconnecting.
		runAndCheck("1", 1)

		disconnects := make(chan struct{})
		client.RegisterDisconnectListener(disconnects)
		client.Close()
		assertChannelClosed(t, disconnects, "close should have disconnected the client")

		if err := client.Reconnect(ctx); err != nil {
			// Nothing below can succeed without a live connection.
			t.Fatalf("failed to reconnect: %v", err)
		}

		// Check we can still run a command after reconnecting.
		runAndCheck("2", 2)
	})
}