blob: 34d0d17f38e9e36e84b4aac188fead8b328345ea [file] [log] [blame]
// Copyright 2022 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package tunnel
import (
"context"
"fmt"
"os"
"os/exec"
"strings"
"testing"
)
// TestGenerateDefaultSSHConfig checks that GenerateSSHConfig renders the
// supplied template with the remote host, device IP, repository port and
// tunnel ports. It also checks that DebugLoggingSSHConfig is appended when
// verbose is set, and that empty output and protected ports are rejected.
// It then renders DefaultSSHConfigTemplate and checks which ports receive
// RemoteForward entries.
func TestGenerateDefaultSSHConfig(t *testing.T) {
	tests := []struct {
		template          string
		remote            string
		deviceIP          string
		repoPort          int
		tunnelPorts       []int
		verbose           bool
		expectedSSHConfig string
		// expectedErrMsg is the exact error text GenerateSSHConfig must
		// return; an empty string means the call must succeed.
		expectedErrMsg string
	}{
		{
			template:          "{{.Remote}} {{.DeviceIP}}",
			remote:            "fake-remote-hostname",
			deviceIP:          "fake-IP-address",
			repoPort:          8083,
			verbose:           false,
			expectedSSHConfig: "fake-remote-hostname fake-IP-address",
		},
		{
			template:          "{{.DeviceIP}}{{.DeviceIP}}{{.DeviceIP}}{{.DeviceIP}}",
			remote:            "fake-remote-hostname",
			deviceIP:          "fake-IP-address",
			repoPort:          8083,
			verbose:           false,
			expectedSSHConfig: "fake-IP-addressfake-IP-addressfake-IP-addressfake-IP-address",
		},
		{
			template:          "{{.Remote}}",
			remote:            "fake-remote-hostname",
			deviceIP:          "fake-IP-address",
			repoPort:          8083,
			verbose:           false,
			expectedSSHConfig: "fake-remote-hostname",
		},
		{
			template:          "{{.Remote}} {{.DeviceIP}}",
			remote:            "fake-remote-hostname",
			deviceIP:          "fake-IP-address",
			repoPort:          8083,
			verbose:           true,
			expectedSSHConfig: "fake-remote-hostname fake-IP-address" + DebugLoggingSSHConfig,
		},
		{
			template:          "{{.DeviceIP}}{{.DeviceIP}}{{.DeviceIP}}{{.DeviceIP}}",
			remote:            "fake-remote-hostname",
			deviceIP:          "fake-IP-address",
			repoPort:          8083,
			verbose:           true,
			expectedSSHConfig: "fake-IP-addressfake-IP-addressfake-IP-addressfake-IP-address" + DebugLoggingSSHConfig,
		},
		{
			template:          "{{.Remote}}",
			remote:            "fake-remote-hostname",
			deviceIP:          "fake-IP-address",
			repoPort:          8083,
			verbose:           true,
			expectedSSHConfig: "fake-remote-hostname" + DebugLoggingSSHConfig,
		},
		{
			// TunnelPath is not a parameter of GenerateSSHConfig; it is
			// expected to render as the package-level sshControlPath.
			template:          "{{.TunnelPath}}",
			remote:            "fake-remote-hostname",
			deviceIP:          "fake-IP-address",
			repoPort:          8083,
			verbose:           false,
			expectedSSHConfig: sshControlPath,
		},
		{
			template:       "",
			remote:         "fake-remote-hostname",
			deviceIP:       "fake-IP-address",
			repoPort:       8083,
			verbose:        false,
			expectedErrMsg: "empty SSH config generated",
		},
		{
			template:          "{{.RepoPort}}",
			remote:            "fake-remote-hostname",
			deviceIP:          "fake-IP-address",
			repoPort:          9999,
			tunnelPorts:       []int{},
			verbose:           false,
			expectedSSHConfig: "9999",
		},
		{
			template:          "{{range .TunnelPorts}}{{.}},{{end}}",
			remote:            "fake-remote-hostname",
			deviceIP:          "fake-IP-address",
			repoPort:          8083,
			tunnelPorts:       []int{1234, 8887, 8888, 9999},
			verbose:           false,
			expectedSSHConfig: "1234,8887,9999,",
		},
		{
			template:       "{{.TunnelPath}}",
			remote:         "fake-remote-hostname",
			deviceIP:       "fake-IP-address",
			repoPort:       8083,
			tunnelPorts:    []int{22, 8888, 21, 9999},
			verbose:        false,
			expectedErrMsg: "Cannot create SSH config with protected ports: 22, 21",
		},
	}
	for i, test := range tests {
		sshConfig, err := GenerateSSHConfig(test.template, test.remote, test.deviceIP, test.repoPort, test.tunnelPorts, test.verbose)
		switch {
		case err != nil && err.Error() != test.expectedErrMsg:
			t.Errorf("case %d: unexpected error calling GenerateSSHConfig: %s", i, err)
		case err == nil && test.expectedErrMsg != "":
			// Without this branch a case expecting an error would pass
			// silently when GenerateSSHConfig succeeds.
			t.Errorf("case %d: GenerateSSHConfig succeeded, want error %q", i, test.expectedErrMsg)
		}
		if got := string(sshConfig); got != test.expectedSSHConfig {
			t.Errorf("case %d: GenerateSSHConfig() got %q, want %q", i, got, test.expectedSSHConfig)
		}
	}

	repoPort := 8083
	tunnelPorts := []int{8888, 9013, 1234, 8083, 9865}
	expectedTunnelPorts := []int{9013, 1234, 9865, 8888}
	// The repo port appears in tunnelPorts but must not be forwarded again
	// as an extra tunnel port.
	unexpectedTunnelPorts := []int{repoPort}
	realConfig, err := GenerateSSHConfig(DefaultSSHConfigTemplate, "fake-remote-hostname", "fake-IP-address", repoPort, tunnelPorts, false)
	if err != nil {
		// Every later check inspects realConfig, so there is nothing useful left to test.
		t.Fatalf("Error calling GenerateSSHConfig: %s", err)
	}
	forwardLine := func(port int) string {
		return fmt.Sprintf("RemoteForward %d [fake-IP-address]:%d", port, port)
	}
	for _, port := range expectedTunnelPorts {
		if !strings.Contains(string(realConfig), forwardLine(port)) {
			t.Errorf("Expected to find %d forwarded in the ssh config", port)
		}
	}
	for _, port := range unexpectedTunnelPorts {
		if strings.Contains(string(realConfig), forwardLine(port)) {
			t.Errorf("Expected NOT to find %d forwarded as an extra port in the ssh config: %s", port, string(realConfig))
		}
	}
}
// TestStart checks that Cmd builds an ssh invocation with the given binary
// path, -F config file and remote host. The remote command is expected to
// echo the tunnel messages and then run "sleep infinity".
func TestStart(t *testing.T) {
	tests := []struct {
		sshPath       string
		sshConfigPath string
		remote        string
		expectedArgs  []string
	}{
		{
			sshPath:       "fake/path/to/ssh",
			sshConfigPath: "fake/ssh/config/path",
			remote:        "fake-remote-hostname",
			expectedArgs: []string{
				"fake/path/to/ssh",
				"-F",
				"fake/ssh/config/path",
				"fake-remote-hostname",
				"-n",
				"echo",
				"Tunnel is established",
				"&&",
				"echo",
				ffxAddRemoteTargetMessage,
				"&&",
				"sleep",
				"infinity",
			},
		},
	}
	for _, test := range tests {
		cmd, err := Cmd(test.sshPath, test.sshConfigPath, test.remote)
		if err != nil {
			// cmd is not usable on error; skip to avoid a nil dereference below.
			t.Errorf("Error calling Cmd: %s", err)
			continue
		}
		if cmd.Path != test.sshPath {
			t.Errorf("Got command path %s, want %s", cmd.Path, test.sshPath)
		}
		if len(cmd.Args) != len(test.expectedArgs) {
			t.Errorf("Got %d args, want %d args", len(cmd.Args), len(test.expectedArgs))
		}
		for i, expectedArg := range test.expectedArgs {
			if i >= len(cmd.Args) {
				t.Errorf("Expected arg %s but found no arg", expectedArg)
				continue
			}
			if cmd.Args[i] != expectedArg {
				t.Errorf("Got arg %s, want %s", cmd.Args[i], expectedArg)
			}
		}
		// Checked outside the loop so it also fires when expectedArgs is empty.
		if len(cmd.Args) > len(test.expectedArgs) {
			t.Errorf("Extra unexpected args found: %s", strings.Join(cmd.Args[len(test.expectedArgs):], ", "))
		}
	}
}
// helperCommandForTestTunnel is a stand-in for exec.Command that tests assign
// to ExecCommand. Instead of running command, it re-runs the current test
// binary restricted to TestFakeSSH. The command and its arguments are passed
// after a "--" separator. The helper-process pattern follows the Go standard
// library's os/exec tests (exec_test.go).
//
// GO_WANT_HELPER_PROCESS=1 is added to the child's environment so that
// TestFakeSSH acts as the fake process instead of returning immediately.
func helperCommandForTestTunnel(command string, s ...string) *exec.Cmd {
	cs := make([]string, 0, len(s)+3)
	// Anchor the pattern so only TestFakeSSH runs, not every test whose name
	// merely contains "TestFakeSSH".
	cs = append(cs, "-test.run=^TestFakeSSH$", "--", command)
	cs = append(cs, s...)
	cmd := exec.Command(os.Args[0], cs...)
	cmd.Env = append(os.Environ(), "GO_WANT_HELPER_PROCESS=1")
	return cmd
}
// TestFakeSSH isn't a real test. It is the entry point of the helper process
// spawned by helperCommandForTestTunnel to mock calling SSH, and it does
// nothing unless GO_WANT_HELPER_PROCESS=1 is set.
//
// In helper mode it reads the arguments after "--". It expects the command,
// which must be "/fake/ssh", followed by a mock remote host name and that
// host's remote arguments. It then dispatches to the handler for that host.
// It exits 2 for a missing or unknown command or a missing host, and 1 for
// an unknown host. Otherwise it exits 0 once the handler returns; handlers
// may also exit with their own status.
func TestFakeSSH(*testing.T) {
	if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" {
		return
	}
	// Drop everything up to and including the "--" separator.
	args := os.Args
	for len(args) > 0 && args[0] != "--" {
		args = args[1:]
	}
	if len(args) > 0 {
		args = args[1:]
	}
	if len(args) == 0 {
		fmt.Fprintf(os.Stderr, "No command\n")
		os.Exit(2)
	}
	cmd, args := args[0], args[1:]
	if cmd != "/fake/ssh" {
		fmt.Fprintf(os.Stderr, "Unknown command %q\n", cmd)
		os.Exit(2)
	}
	if len(args) == 0 {
		fmt.Fprintf(os.Stderr, "No mock host\n")
		os.Exit(2)
	}
	remote, remoteArgs := args[0], args[1:]
	switch remote {
	case "happy-exit":
		handleHappyExit(remoteArgs)
	case "no-existing-ports":
		handleNoExistingPorts(remoteArgs)
	default:
		fmt.Fprintf(os.Stderr, "Unknown mock host: %s", remote)
		os.Exit(1)
	}
	// Exit explicitly rather than via `defer os.Exit(0)`: a deferred exit
	// also runs while a panic unwinds, which would turn a crashing helper
	// into a successful exit. Exiting here also keeps the testing framework's
	// own output (such as the final PASS line) out of the fake command's stdout.
	os.Exit(0)
}
// handleHappyExit fakes the remote side of ssh for the "happy-exit" mock
// host, a remote with existing sessions and ports to clean up. Some remote
// commands get canned output first: a ps line showing an sshd "notty"
// session, or an ss line showing a listener on :8022. It then checks that
// args exactly match the expected invocation for that command.
//
// It exits 2 for a missing or unrecognised remote command and 3 when args
// differ from the expected invocation. Otherwise it returns normally.
func handleHappyExit(args []string) {
	if len(args) == 0 {
		fmt.Fprintf(os.Stderr, "No remote command\n")
		os.Exit(2)
	}
	var expected []string
	switch args[0] {
	case "ps `pgrep -u $USER sshd` | grep notty":
		expected = []string{args[0]}
		fmt.Printf("1999999 ? S 0:00 sshd: devuser@notty")
	case "ss -ln | grep :8022":
		fmt.Printf("tcp LISTEN 0 128 [::]:8022 [::]:*")
		expected = []string{"ss -ln | grep :8022"}
	case "-O":
		// sshControlPath is defined elsewhere in the package.
		expected = []string{"-O", "exit", "-S", sshControlPath}
	default:
		fmt.Fprintf(os.Stderr, "Unknown remote command: %s", args[0])
		os.Exit(2)
	}
	// Compare lengths first. Otherwise extra args would index past the end
	// of expected and panic, and missing trailing args would go unnoticed.
	if len(args) != len(expected) {
		fmt.Fprintf(os.Stderr, "Mismatched arg count. Got %d, expected %d", len(args), len(expected))
		fmt.Fprintf(os.Stderr, "Full args got %s, want %s", args, expected)
		os.Exit(3)
	}
	for i := range args {
		if args[i] != expected[i] {
			fmt.Fprintf(os.Stderr,
				"Mismatched args index %d. Got %s, expected %s",
				i, args[i], expected[i])
			fmt.Fprintf(os.Stderr, "Full args got %s, want %s",
				args, expected)
			os.Exit(3)
		}
	}
}
// handleNoExistingPorts fakes the remote side of ssh for the
// "no-existing-ports" mock host, a remote with nothing listening on :8022.
// The port check prints nothing and exits 1. For "-O" it checks that args
// exactly match the expected invocation.
//
// It exits 2 for a missing or unrecognised remote command and 3 when args
// differ from the expected invocation. Otherwise it returns normally.
func handleNoExistingPorts(args []string) {
	if len(args) == 0 {
		fmt.Fprintf(os.Stderr, "No remote command\n")
		os.Exit(2)
	}
	var expected []string
	switch args[0] {
	case "ss -ln | grep :8022":
		// Return no text, but the rc = 1.
		os.Exit(1)
	case "-O":
		// sshControlPath is defined elsewhere in the package.
		expected = []string{"-O", "exit", "-S", sshControlPath}
	default:
		fmt.Fprintf(os.Stderr, "Unknown remote command: %s",
			args[0])
		os.Exit(2)
	}
	// Compare lengths first. Otherwise extra args would index past the end
	// of expected and panic, and missing trailing args would go unnoticed.
	if len(args) != len(expected) {
		fmt.Fprintf(os.Stderr, "Mismatched arg count. Got %d, expected %d", len(args), len(expected))
		fmt.Fprintf(os.Stderr, "Full args got %s, expected %s", args, expected)
		os.Exit(3)
	}
	for i := range args {
		if args[i] != expected[i] {
			fmt.Fprintf(os.Stderr,
				"Mismatched args index %d. Got %s, expected %s",
				i, args[i], expected[i])
			fmt.Fprintf(os.Stderr, "Full args got %s, expected %s",
				args, expected)
			os.Exit(3)
		}
	}
}
// TestCleanupTunnelHappyPath runs CleanupTunnel against the "happy-exit" mock
// host, with ExecCommand routed to the fake ssh helper process. It checks
// that the result reports both the existing port forwarding and the remote
// sshd cleanup.
func TestCleanupTunnelHappyPath(t *testing.T) {
	ctx := context.Background()
	// Restore whatever ExecCommand held before, not an assumed default.
	origExecCommand := ExecCommand
	ExecCommand = helperCommandForTestTunnel
	defer func() { ExecCommand = origExecCommand }()

	result, err := CleanupTunnel(ctx, "/fake/ssh", "happy-exit")
	if err != nil {
		t.Errorf("CleanupTunnel: %s", err)
	}
	foundPorts := "Existing port forwarding found on happy-exit"
	if !strings.Contains(result, foundPorts) {
		t.Errorf("Result does not contain %q. Found %q", foundPorts, result)
	}
	cleanedRemotePorts := "Existing port forwarding found on happy-exit, Cleaning up sshd sessions remotely"
	if !strings.Contains(result, cleanedRemotePorts) {
		t.Errorf("Result does not contain %q. Found %q", cleanedRemotePorts, result)
	}
}
// TestCleanupTunnelNoExistingPorts runs CleanupTunnel against the
// "no-existing-ports" mock host, whose fake port check finds nothing. It
// expects an empty result and no error.
func TestCleanupTunnelNoExistingPorts(t *testing.T) {
	ctx := context.Background()
	// Restore whatever ExecCommand held before, not an assumed default.
	origExecCommand := ExecCommand
	ExecCommand = helperCommandForTestTunnel
	defer func() { ExecCommand = origExecCommand }()

	result, err := CleanupTunnel(ctx, "/fake/ssh", "no-existing-ports")
	if err != nil {
		t.Errorf("CleanupTunnel: %s", err)
	}
	if want := ""; result != want {
		t.Errorf("CleanupTunnel returned %q, want %q", result, want)
	}
}