[sshutil][netutil] Move ssh/net utilities into separate packages

Change-Id: I418ca1e79166d67687892e946619bcdcca9b0d01
diff --git a/botanist/common.go b/botanist/common.go
index 2b08c78..9d39adc 100644
--- a/botanist/common.go
+++ b/botanist/common.go
@@ -5,34 +5,11 @@
 package botanist
 
 import (
-	"context"
 	"encoding/json"
 	"fmt"
 	"io/ioutil"
-	"net"
-	"time"
-
-	"fuchsia.googlesource.com/tools/netboot"
-	"fuchsia.googlesource.com/tools/retry"
 )
 
-// GetNodeAddress returns the UDP address corresponding to a given node, specifically
-// the netsvc or fuchsia address dependending on the value of `fuchsia`.
-func GetNodeAddress(ctx context.Context, nodename string, fuchsia bool) (*net.UDPAddr, error) {
-	// Retry, as the netstack might not yet be up.
-	var addr *net.UDPAddr
-	var err error
-	n := netboot.NewClient(time.Second)
-	err = retry.Retry(ctx, retry.WithMaxDuration(&retry.ZeroBackoff{}, time.Minute), func() error {
-		addr, err = n.Discover(nodename, fuchsia)
-		return err
-	}, nil)
-	if err != nil {
-		return nil, fmt.Errorf("cannot find node %q: %v", nodename, err)
-	}
-	return addr, nil
-}
-
 // DeviceProperties contains static properties of a hardware device.
 type DeviceProperties struct {
 	// Nodename is the hostname of the device that we want to boot on.
diff --git a/botanist/common_test.go b/botanist/common_test.go
index 94f5785..4b3c9ed 100644
--- a/botanist/common_test.go
+++ b/botanist/common_test.go
@@ -7,6 +7,8 @@
 	"io/ioutil"
 	"os"
 	"testing"
+
+	"fuchsia.googlesource.com/tools/sshutil"
 )
 
 func TestLoadDevicePropertiesSlice(t *testing.T) {
@@ -54,3 +56,81 @@
 		}
 	}
 }
+
+func TestSSHSignersFromDeviceProperties(t *testing.T) {
+	tests := []struct {
+		name        string
+		device1Keys []string
+		device2Keys []string
+		expectedLen int
+		expectErr   bool
+	}{
+		// Valid configs.
+		{"ValidSameKeyConfig", []string{"valid1"}, []string{"valid1"}, 1, false},
+		{"ValidDiffKeysWithDuplicateConfig", []string{"valid1", "valid2"}, []string{"valid1"}, 2, false},
+		{"ValidDiffKeysConfig", []string{"valid1"}, []string{"valid2"}, 2, false},
+		{"ValidEmptyKeysConfig", []string{}, []string{}, 0, false},
+		// Invalid configs.
+		{"InvalidKeyFileConfig", []string{"valid1"}, []string{"invalid"}, 0, true},
+		{"MissingKeyFileConfig", []string{"missing"}, []string{}, 0, true},
+	}
+
+	validKey1, err := sshutil.GeneratePrivateKey()
+	if err != nil {
+		t.Fatalf("Failed to generate private key: %s", err)
+	}
+	validKey2, err := sshutil.GeneratePrivateKey()
+	if err != nil {
+		t.Fatalf("Failed to generate private key: %s", err)
+	}
+	invalidKey := []byte("invalidKey")
+
+	keys := []struct {
+		name        string
+		keyContents []byte
+	}{
+		{"valid1", validKey1}, {"valid2", validKey2}, {"invalid", invalidKey},
+	}
+
+	keyNameToPath := make(map[string]string)
+	keyNameToPath["missing"] = "/path/to/nonexistent/key"
+	for _, key := range keys {
+		tmpfile, err := ioutil.TempFile(os.TempDir(), key.name)
+		if err != nil {
+			t.Fatalf("Failed to create test device properties file: %s", err)
+		}
+		defer os.Remove(tmpfile.Name())
+		if _, err := tmpfile.Write(key.keyContents); err != nil {
+			t.Fatalf("Failed to write to test device properties file: %s", err)
+		}
+		if err := tmpfile.Close(); err != nil {
+			t.Fatal(err)
+		}
+		keyNameToPath[key.name] = tmpfile.Name()
+	}
+
+	for _, test := range tests {
+		var keyPaths1 []string
+		for _, keyName := range test.device1Keys {
+			keyPaths1 = append(keyPaths1, keyNameToPath[keyName])
+		}
+		var keyPaths2 []string
+		for _, keyName := range test.device2Keys {
+			keyPaths2 = append(keyPaths2, keyNameToPath[keyName])
+		}
+		devices := []DeviceProperties{
+			DeviceProperties{"device1", &Config{}, keyPaths1},
+			DeviceProperties{"device2", &Config{}, keyPaths2},
+		}
+		signers, err := SSHSignersFromDeviceProperties(devices)
+		if test.expectErr && err == nil {
+			t.Errorf("Test%v: Expected errors; no errors found", test.name)
+		}
+		if !test.expectErr && err != nil {
+			t.Errorf("Test%v: Expected no errors; found error - %v", test.name, err)
+		}
+		if len(signers) != test.expectedLen {
+			t.Errorf("Test%v: Expected %d signers; found %d", test.name, test.expectedLen, len(signers))
+		}
+	}
+}
diff --git a/botanist/reboot.go b/botanist/reboot.go
index b8812cf..6013675 100644
--- a/botanist/reboot.go
+++ b/botanist/reboot.go
@@ -10,6 +10,8 @@
 
 	"fuchsia.googlesource.com/tools/botanist/pdu/amt"
 	"fuchsia.googlesource.com/tools/botanist/pdu/wol"
+	"fuchsia.googlesource.com/tools/sshutil"
+
 	"golang.org/x/crypto/ssh"
 )
 
@@ -75,7 +77,7 @@
 	}
 
 	ctx := context.Background()
-	client, err := SSHIntoNode(ctx, nodeName, config)
+	client, err := sshutil.ConnectToNode(ctx, nodeName, config)
 	if err != nil {
 		return err
 	}
diff --git a/botanist/ssh.go b/botanist/ssh.go
index 3e1858a..a4b04e5 100644
--- a/botanist/ssh.go
+++ b/botanist/ssh.go
@@ -1,119 +1,15 @@
 // Copyright 2018 The Fuchsia Authors. All rights reserved.
 // Use of this source code is governed by a BSD-style license that can be
 // found in the LICENSE file.
+
 package botanist
 
 import (
-	"context"
-	"crypto/rand"
-	"crypto/rsa"
-	"crypto/x509"
-	"encoding/pem"
-	"fmt"
 	"io/ioutil"
-	"net"
-	"time"
-
-	"fuchsia.googlesource.com/tools/retry"
 
 	"golang.org/x/crypto/ssh"
 )
 
-const (
-	// Default SSH server port.
-	SSHPort = 22
-
-	// Default RSA key size.
-	RSAKeySize = 2048
-
-	// The default timeout for IO operations.
-	defaultIOTimeout = 5 * time.Second
-)
-
-// GeneratePrivateKey generates a private SSH key.
-func GeneratePrivateKey() ([]byte, error) {
-	key, err := rsa.GenerateKey(rand.Reader, RSAKeySize)
-	if err != nil {
-		return nil, err
-	}
-	privateKey := &pem.Block{
-		Type:  "RSA PRIVATE KEY",
-		Bytes: x509.MarshalPKCS1PrivateKey(key),
-	}
-	buf := pem.EncodeToMemory(privateKey)
-
-	return buf, nil
-}
-
-func ConnectViaSSH(ctx context.Context, address net.Addr, config *ssh.ClientConfig) (*ssh.Client, error) {
-	network, err := network(address)
-	if err != nil {
-		return nil, err
-	}
-
-	var client *ssh.Client
-	// TODO: figure out optimal backoff time and number of retries
-	if err := retry.Retry(ctx, retry.WithMaxDuration(&retry.ZeroBackoff{}, 10*time.Second), func() error {
-		var err error
-		client, err = ssh.Dial(network, address.String(), config)
-		return err
-	}, nil); err != nil {
-		return nil, fmt.Errorf("cannot connect to address %q: %v", address, err)
-	}
-
-	return client, nil
-}
-
-// SSHIntoNode connects to the device with the given nodename.
-func SSHIntoNode(ctx context.Context, nodename string, config *ssh.ClientConfig) (*ssh.Client, error) {
-	addr, err := GetNodeAddress(ctx, nodename, true)
-	if err != nil {
-		return nil, err
-	}
-	addr.Port = SSHPort
-	return ConnectViaSSH(ctx, addr, config)
-}
-
-// DefaultSSHConfig returns a basic SSH client configuration.
-func DefaultSSHConfig(privateKey []byte) (*ssh.ClientConfig, error) {
-	signer, err := ssh.ParsePrivateKey(privateKey)
-	if err != nil {
-		return nil, err
-	}
-	return &ssh.ClientConfig{
-		User:            sshUser,
-		Auth:            []ssh.AuthMethod{ssh.PublicKeys(signer)},
-		Timeout:         defaultIOTimeout,
-		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
-	}, nil
-}
-
-// Returns the network to use to SSH into a device.
-func network(address net.Addr) (string, error) {
-	var ip *net.IP
-
-	// We need these type assertions because the net package (annoyingly) doesn't provide
-	// an interface for objects that have an IP address.
-	switch addr := address.(type) {
-	case *net.UDPAddr:
-		ip = &addr.IP
-	case *net.TCPAddr:
-		ip = &addr.IP
-	case *net.IPAddr:
-		ip = &addr.IP
-	default:
-		return "", fmt.Errorf("unsupported address type: %T", address)
-	}
-
-	if ip.To4() != nil {
-		return "tcp", nil // IPv4
-	}
-	if ip.To16() != nil {
-		return "tcp6", nil // IPv6
-	}
-	return "", fmt.Errorf("cannot infer network for IP address %s", ip.String())
-}
-
 // Returns the SSH signers associated with the key paths in the botanist config file if present.
 func SSHSignersFromDeviceProperties(properties []DeviceProperties) ([]ssh.Signer, error) {
 	processedKeys := make(map[string]bool)
diff --git a/botanist/ssh_test.go b/botanist/ssh_test.go
deleted file mode 100644
index 97af842..0000000
--- a/botanist/ssh_test.go
+++ /dev/null
@@ -1,128 +0,0 @@
-// Copyright 2018 The Fuchsia Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-package botanist
-
-import (
-	"io/ioutil"
-	"net"
-	"os"
-	"testing"
-)
-
-func TestNetwork(t *testing.T) {
-	tests := []struct {
-		id      int
-		addr    net.Addr
-		family  string
-		wantErr bool
-	}{
-		// Valid tcp addresses.
-		{1, &net.TCPAddr{IP: net.IPv4(1, 2, 3, 4)}, "tcp", false},
-		{2, &net.UDPAddr{IP: net.IPv4(5, 6, 7, 8)}, "tcp", false},
-		{3, &net.IPAddr{IP: net.IPv4(9, 10, 11, 12)}, "tcp", false},
-
-		// Valid tcp6 addresses.
-		{4, &net.TCPAddr{IP: net.IPv6loopback}, "tcp6", false},
-		{5, &net.UDPAddr{IP: net.ParseIP("2001:db8::1")}, "tcp6", false},
-		{6, &net.IPAddr{IP: net.IPv6linklocalallrouters}, "tcp6", false},
-
-		// Invalid IP addresses
-		{7, &net.TCPAddr{IP: net.IP("")}, "", true},
-		{8, &net.UDPAddr{IP: net.IP("123456")}, "", true},
-		{9, &net.IPAddr{IP: nil}, "", true},
-
-		// Invalid net.AddrType
-		{10, &net.UnixAddr{}, "", true},
-	}
-
-	for _, test := range tests {
-		n, err := network(test.addr)
-		if test.wantErr && err == nil {
-			t.Errorf("Test %d: got no error; want error", test.id)
-		} else if !test.wantErr && err != nil {
-			t.Errorf("Test %d: got error %q; want no error", test.id, err)
-		} else if n != test.family {
-			t.Errorf("Test %d: got %q; want %q", test.id, n, test.family)
-		}
-	}
-}
-
-func TestSSHSignersFromDeviceProperties(t *testing.T) {
-	tests := []struct {
-		name        string
-		device1Keys []string
-		device2Keys []string
-		expectedLen int
-		expectErr   bool
-	}{
-		// Valid configs.
-		{"ValidSameKeyConfig", []string{"valid1"}, []string{"valid1"}, 1, false},
-		{"ValidDiffKeysWithDuplicateConfig", []string{"valid1", "valid2"}, []string{"valid1"}, 2, false},
-		{"ValidDiffKeysConfig", []string{"valid1"}, []string{"valid2"}, 2, false},
-		{"ValidEmptyKeysConfig", []string{}, []string{}, 0, false},
-		// Invalid configs.
-		{"InvalidKeyFileConfig", []string{"valid1"}, []string{"invalid"}, 0, true},
-		{"MissingKeyFileConfig", []string{"missing"}, []string{}, 0, true},
-	}
-
-	validKey1, err := GeneratePrivateKey()
-	if err != nil {
-		t.Fatalf("Failed to generate private key: %s", err)
-	}
-	validKey2, err := GeneratePrivateKey()
-	if err != nil {
-		t.Fatalf("Failed to generate private key: %s", err)
-	}
-	invalidKey := []byte("invalidKey")
-
-	keys := []struct {
-		name        string
-		keyContents []byte
-	}{
-		{"valid1", validKey1}, {"valid2", validKey2}, {"invalid", invalidKey},
-	}
-
-	keyNameToPath := make(map[string]string)
-	keyNameToPath["missing"] = "/path/to/nonexistent/key"
-	for _, key := range keys {
-		tmpfile, err := ioutil.TempFile(os.TempDir(), key.name)
-		if err != nil {
-			t.Fatalf("Failed to create test device properties file: %s", err)
-		}
-		defer os.Remove(tmpfile.Name())
-		if _, err := tmpfile.Write(key.keyContents); err != nil {
-			t.Fatalf("Failed to write to test device properties file: %s", err)
-		}
-		if err := tmpfile.Close(); err != nil {
-			t.Fatal(err)
-		}
-		keyNameToPath[key.name] = tmpfile.Name()
-	}
-
-	for _, test := range tests {
-		var keyPaths1 []string
-		for _, keyName := range test.device1Keys {
-			keyPaths1 = append(keyPaths1, keyNameToPath[keyName])
-		}
-		var keyPaths2 []string
-		for _, keyName := range test.device2Keys {
-			keyPaths2 = append(keyPaths2, keyNameToPath[keyName])
-		}
-		devices := []DeviceProperties{
-			DeviceProperties{"device1", &Config{}, keyPaths1},
-			DeviceProperties{"device2", &Config{}, keyPaths2},
-		}
-		signers, err := SSHSignersFromDeviceProperties(devices)
-		if test.expectErr && err == nil {
-			t.Errorf("Test%v: Expected errors; no errors found", test.name)
-		}
-		if !test.expectErr && err != nil {
-			t.Errorf("Test%v: Expected no errors; found error - %v", test.name, err)
-		}
-		if len(signers) != test.expectedLen {
-			t.Errorf("Test%v: Expected %d signers; found %d", test.name, test.expectedLen, len(signers))
-		}
-	}
-}
diff --git a/cmd/botanist/run.go b/cmd/botanist/run.go
index a5f7137..1136b6b 100644
--- a/cmd/botanist/run.go
+++ b/cmd/botanist/run.go
@@ -17,7 +17,9 @@
 	"fuchsia.googlesource.com/tools/command"
 	"fuchsia.googlesource.com/tools/logger"
 	"fuchsia.googlesource.com/tools/netboot"
+	"fuchsia.googlesource.com/tools/netutil"
 	"fuchsia.googlesource.com/tools/runner"
+	"fuchsia.googlesource.com/tools/sshutil"
 
 	"github.com/google/subcommands"
 	"golang.org/x/crypto/ssh"
@@ -112,7 +114,7 @@
 		}
 	}()
 
-	addr, err := botanist.GetNodeAddress(ctx, nodename, false)
+	addr, err := netutil.GetNodeAddress(ctx, nodename, false)
 	if err != nil {
 		return err
 	}
@@ -134,11 +136,11 @@
 		if err != nil {
 			return err
 		}
-		config, err := botanist.DefaultSSHConfig(privKey)
+		config, err := sshutil.DefaultSSHConfig(privKey)
 		if err != nil {
 			return err
 		}
-		client, err := botanist.SSHIntoNode(ctx, nodename, config)
+		client, err := sshutil.ConnectToNode(ctx, nodename, config)
 		if err != nil {
 			return err
 		}
@@ -221,7 +223,7 @@
 	}
 	var privKeys [][]byte
 	if len(privKeyPaths) == 0 {
-		p, err := botanist.GeneratePrivateKey()
+		p, err := sshutil.GeneratePrivateKey()
 		if err != nil {
 			return err
 		}
diff --git a/cmd/botanist/zedboot.go b/cmd/botanist/zedboot.go
index dd942da..68e9b11 100644
--- a/cmd/botanist/zedboot.go
+++ b/cmd/botanist/zedboot.go
@@ -24,6 +24,7 @@
 	"fuchsia.googlesource.com/tools/command"
 	"fuchsia.googlesource.com/tools/logger"
 	"fuchsia.googlesource.com/tools/netboot"
+	"fuchsia.googlesource.com/tools/netutil"
 	"fuchsia.googlesource.com/tools/retry"
 	"fuchsia.googlesource.com/tools/runner"
 	"fuchsia.googlesource.com/tools/runtests"
@@ -236,7 +237,7 @@
 
 	var addrs []*net.UDPAddr
 	for _, node := range nodes {
-		addr, err := botanist.GetNodeAddress(ctx, node.Nodename, false)
+		addr, err := netutil.GetNodeAddress(ctx, node.Nodename, false)
 		if err != nil {
 			return err
 		}
diff --git a/cmd/testrunner/tester.go b/cmd/testrunner/tester.go
index cac7021..6cbf431 100644
--- a/cmd/testrunner/tester.go
+++ b/cmd/testrunner/tester.go
@@ -10,8 +10,8 @@
 	"io"
 	"path"
 
-	"fuchsia.googlesource.com/tools/botanist"
 	"fuchsia.googlesource.com/tools/runner"
+	"fuchsia.googlesource.com/tools/sshutil"
 	"fuchsia.googlesource.com/tools/testsharder"
 	"golang.org/x/crypto/ssh"
 )
@@ -55,11 +55,11 @@
 }
 
 func NewSSHTester(nodename string, sshKey []byte) (*SSHTester, error) {
-	config, err := botanist.DefaultSSHConfig(sshKey)
+	config, err := sshutil.DefaultSSHConfig(sshKey)
 	if err != nil {
 		return nil, fmt.Errorf("failed to create an SSH client config: %v", err)
 	}
-	client, err := botanist.SSHIntoNode(context.Background(), nodename, config)
+	client, err := sshutil.ConnectToNode(context.Background(), nodename, config)
 	if err != nil {
 		return nil, fmt.Errorf("failed to connect to node %q: %v", nodename, err)
 	}
diff --git a/netutil/netutil.go b/netutil/netutil.go
new file mode 100644
index 0000000..fc3ead3
--- /dev/null
+++ b/netutil/netutil.go
@@ -0,0 +1,31 @@
+// Copyright 2018 The Fuchsia Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+package netutil
+
+import (
+	"context"
+	"fmt"
+	"net"
+	"time"
+
+	"fuchsia.googlesource.com/tools/netboot"
+	"fuchsia.googlesource.com/tools/retry"
+)
+
+// GetNodeAddress returns the UDP address corresponding to a given node, specifically
+// the netsvc or fuchsia address dependending on the value of `fuchsia`.
+func GetNodeAddress(ctx context.Context, nodename string, fuchsia bool) (*net.UDPAddr, error) {
+	// Retry, as the netstack might not yet be up.
+	var addr *net.UDPAddr
+	var err error
+	n := netboot.NewClient(time.Second)
+	err = retry.Retry(ctx, retry.WithMaxDuration(&retry.ZeroBackoff{}, time.Minute), func() error {
+		addr, err = n.Discover(nodename, fuchsia)
+		return err
+	}, nil)
+	if err != nil {
+		return nil, fmt.Errorf("cannot find node %q: %v", nodename, err)
+	}
+	return addr, nil
+}
diff --git a/sshutil/sshutil.go b/sshutil/sshutil.go
new file mode 100644
index 0000000..038ffff
--- /dev/null
+++ b/sshutil/sshutil.go
@@ -0,0 +1,117 @@
+// Copyright 2018 The Fuchsia Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+package sshutil
+
+import (
+	"context"
+	"crypto/rand"
+	"crypto/rsa"
+	"crypto/x509"
+	"encoding/pem"
+	"fmt"
+	"net"
+	"time"
+
+	"fuchsia.googlesource.com/tools/netutil"
+	"fuchsia.googlesource.com/tools/retry"
+
+	"golang.org/x/crypto/ssh"
+)
+
+const (
+	// Default SSH server port.
+	SSHPort = 22
+
+	// Default RSA key size.
+	RSAKeySize = 2048
+
+	// The default timeout for IO operations.
+	defaultIOTimeout = 5 * time.Second
+
+	sshUser = "fuchsia"
+)
+
+// GeneratePrivateKey generates a private SSH key.
+func GeneratePrivateKey() ([]byte, error) {
+	key, err := rsa.GenerateKey(rand.Reader, RSAKeySize)
+	if err != nil {
+		return nil, err
+	}
+	privateKey := &pem.Block{
+		Type:  "RSA PRIVATE KEY",
+		Bytes: x509.MarshalPKCS1PrivateKey(key),
+	}
+	buf := pem.EncodeToMemory(privateKey)
+
+	return buf, nil
+}
+
+func Connect(ctx context.Context, address net.Addr, config *ssh.ClientConfig) (*ssh.Client, error) {
+	network, err := network(address)
+	if err != nil {
+		return nil, err
+	}
+
+	var client *ssh.Client
+	// TODO: figure out optimal backoff time and number of retries
+	if err := retry.Retry(ctx, retry.WithMaxDuration(&retry.ZeroBackoff{}, 10*time.Second), func() error {
+		var err error
+		client, err = ssh.Dial(network, address.String(), config)
+		return err
+	}, nil); err != nil {
+		return nil, fmt.Errorf("cannot connect to address %q: %v", address, err)
+	}
+
+	return client, nil
+}
+
+// ConnectToNode connects to the device with the given nodename.
+func ConnectToNode(ctx context.Context, nodename string, config *ssh.ClientConfig) (*ssh.Client, error) {
+	addr, err := netutil.GetNodeAddress(ctx, nodename, true)
+	if err != nil {
+		return nil, err
+	}
+	addr.Port = SSHPort
+	return Connect(ctx, addr, config)
+}
+
+// DefaultSSHConfig returns a basic SSH client configuration.
+func DefaultSSHConfig(privateKey []byte) (*ssh.ClientConfig, error) {
+	signer, err := ssh.ParsePrivateKey(privateKey)
+	if err != nil {
+		return nil, err
+	}
+	return &ssh.ClientConfig{
+		User:            sshUser,
+		Auth:            []ssh.AuthMethod{ssh.PublicKeys(signer)},
+		Timeout:         defaultIOTimeout,
+		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
+	}, nil
+}
+
+// Returns the network to use to SSH into a device.
+func network(address net.Addr) (string, error) {
+	var ip *net.IP
+
+	// We need these type assertions because the net package (annoyingly) doesn't provide
+	// an interface for objects that have an IP address.
+	switch addr := address.(type) {
+	case *net.UDPAddr:
+		ip = &addr.IP
+	case *net.TCPAddr:
+		ip = &addr.IP
+	case *net.IPAddr:
+		ip = &addr.IP
+	default:
+		return "", fmt.Errorf("unsupported address type: %T", address)
+	}
+
+	if ip.To4() != nil {
+		return "tcp", nil // IPv4
+	}
+	if ip.To16() != nil {
+		return "tcp6", nil // IPv6
+	}
+	return "", fmt.Errorf("cannot infer network for IP address %s", ip.String())
+}
diff --git a/sshutil/sshutil_test.go b/sshutil/sshutil_test.go
new file mode 100644
index 0000000..458098b
--- /dev/null
+++ b/sshutil/sshutil_test.go
@@ -0,0 +1,48 @@
+// Copyright 2018 The Fuchsia Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+package sshutil
+
+import (
+	"net"
+	"testing"
+)
+
+func TestNetwork(t *testing.T) {
+	tests := []struct {
+		id      int
+		addr    net.Addr
+		family  string
+		wantErr bool
+	}{
+		// Valid tcp addresses.
+		{1, &net.TCPAddr{IP: net.IPv4(1, 2, 3, 4)}, "tcp", false},
+		{2, &net.UDPAddr{IP: net.IPv4(5, 6, 7, 8)}, "tcp", false},
+		{3, &net.IPAddr{IP: net.IPv4(9, 10, 11, 12)}, "tcp", false},
+
+		// Valid tcp6 addresses.
+		{4, &net.TCPAddr{IP: net.IPv6loopback}, "tcp6", false},
+		{5, &net.UDPAddr{IP: net.ParseIP("2001:db8::1")}, "tcp6", false},
+		{6, &net.IPAddr{IP: net.IPv6linklocalallrouters}, "tcp6", false},
+
+		// Invalid IP addresses
+		{7, &net.TCPAddr{IP: net.IP("")}, "", true},
+		{8, &net.UDPAddr{IP: net.IP("123456")}, "", true},
+		{9, &net.IPAddr{IP: nil}, "", true},
+
+		// Invalid net.AddrType
+		{10, &net.UnixAddr{}, "", true},
+	}
+
+	for _, test := range tests {
+		n, err := network(test.addr)
+		if test.wantErr && err == nil {
+			t.Errorf("Test %d: got no error; want error", test.id)
+		} else if !test.wantErr && err != nil {
+			t.Errorf("Test %d: got error %q; want no error", test.id, err)
+		} else if n != test.family {
+			t.Errorf("Test %d: got %q; want %q", test.id, n, test.family)
+		}
+	}
+}