blob: 97a558bdf53a33e76c60f4ae5bad6c5505673fa0 [file] [log] [blame]
// Copyright 2020 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package fuzz
import (
	"fmt"
	"io"
	"io/ioutil"
	"net"
	"os"
	"path"
	"strings"
	"sync"
	"time"

	"github.com/golang/glog"
	"github.com/pkg/sftp"
	"golang.org/x/crypto/ssh"
)
// startLocalSSHServer spins up a local, in-memory SSH server for testing,
// blocking until the server is initialized. Any startup error is returned
// from this function (wrapped, so errors.Is/errors.As see the cause). Any
// runtime errors are passed back on the returned error channel, which is
// closed when the server goroutine exits. The SSH server itself is real, but
// its subsystems (exec, sftp) are fakes for testing; the returned *fakeSftp is
// the in-memory filesystem the sftp subsystem serves.
func startLocalSSHServer() (*SSHConnector, <-chan error, *fakeSftp, error) {
	connCh := make(chan *SSHConnector, 1)
	errCh := make(chan error, 1)
	fakeFs := &fakeSftp{}
	go serveSSH(connCh, errCh, fakeFs)
	// serveSSH first reports either a startup error or nil to indicate
	// success; a connector is only sent on connCh in the success case.
	if err := <-errCh; err != nil {
		return nil, nil, nil, fmt.Errorf("error starting fake SSH server: %w", err)
	}
	return <-connCh, errCh, fakeFs, nil
}
// serveSSH runs the fake SSH server behind startLocalSSHServer.
//
// It first performs initialization (temp dir, TCP listener on 127.0.0.1,
// client and server keys). On failure it sends the error on errCh and
// returns. On success it sends a connector for the listening address on
// connCh, then sends nil on errCh.
//
// It then accepts a single client connection and hands each "session"
// channel to handleSSHSession in its own goroutine; other channel types are
// rejected. Runtime errors are sent on errCh.
//
// errCh is closed when this function returns, but only after every session
// goroutine has exited: those goroutines also send on errCh, and a send on a
// closed channel panics.
func serveSSH(connCh chan<- *SSHConnector, errCh chan<- error, fakeFs *fakeSftp) {
	// Note: Even though startLocalSSHServer is going to block on these
	// initialization steps, they are inside this goroutine because that is
	// the scope where the deferred cleanup belongs.
	defer close(errCh)
	dir, err := ioutil.TempDir("", "clusterfuchsia_test")
	if err != nil {
		errCh <- fmt.Errorf("error creating temp dir: %w", err)
		return
	}
	defer os.RemoveAll(dir)
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		errCh <- fmt.Errorf("error during Listen: %w", err)
		return
	}
	defer listener.Close()
	clientKey, err := createSSHKey()
	if err != nil {
		errCh <- fmt.Errorf("error creating client key: %w", err)
		return
	}
	pemFile := path.Join(dir, "ssh.key")
	if err := writeSSHPrivateKeyFile(clientKey, pemFile); err != nil {
		errCh <- fmt.Errorf("error writing ssh private key: %w", err)
		return
	}
	serverKey, err := createSSHKey()
	if err != nil {
		errCh <- fmt.Errorf("error creating server key: %w", err)
		return
	}
	signer, err := ssh.NewSignerFromKey(serverKey)
	if err != nil {
		errCh <- fmt.Errorf("error configuring server key: %w", err)
		return
	}
	config := &ssh.ServerConfig{
		PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions,
			error) {
			// TODO(fxbug.dev/45424): actually verify the key
			return &ssh.Permissions{}, nil
		},
	}
	config.AddHostKey(signer)

	// Return a suitable connector for the test to use
	addr := listener.Addr().(*net.TCPAddr)
	connCh <- &SSHConnector{Host: "127.0.0.1", Port: addr.Port, Key: pemFile}
	// Indicate initialization is complete
	errCh <- nil

	// Track session goroutines so errCh is not closed underneath them. This is
	// deferred before conn.Close, so on return the connection is torn down
	// first (which should end the sessions' request loops), then we wait for
	// the sessions, and only then does the deferred close(errCh) run.
	var sessions sync.WaitGroup
	defer sessions.Wait()

	conn, err := listener.Accept()
	if err != nil {
		errCh <- fmt.Errorf("error during Accept: %w", err)
		return
	}
	defer conn.Close()
	server, newChannels, reqCh, err := ssh.NewServerConn(conn, config)
	if err != nil {
		errCh <- fmt.Errorf("error during ssh handshake: %w", err)
		return
	}
	// We don't expect to get any global requests, and this channel must be serviced
	go ssh.DiscardRequests(reqCh)
	for newChannel := range newChannels {
		if newChannel.ChannelType() != "session" {
			newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
			continue
		}
		channel, requests, err := newChannel.Accept()
		if err != nil {
			errCh <- fmt.Errorf("could not accept channel: %w", err)
			return
		}
		sessions.Add(1)
		go func() {
			defer sessions.Done()
			handleSSHSession(server, requests, channel, errCh, fakeFs)
		}()
	}
}
// handleSSHSession services the requests arriving on one accepted "session"
// channel until its request channel is closed:
//   - "exec": the command is run through fakeExec. Its output is written to
//     the channel, followed by EOF (CloseWrite), an "exit-status" request
//     carrying the exit code, and a channel close.
//   - "subsystem": only "sftp" is accepted. It is served from fakeFs by a
//     pkg/sftp RequestServer until Serve returns.
//   - anything else, and any request whose payload cannot be decoded, is
//     rejected.
//
// The first failure is sent on errCh and ends the session. When this
// function returns, the whole server connection is closed.
func handleSSHSession(server *ssh.ServerConn, requests <-chan *ssh.Request,
	channel ssh.Channel, errCh chan<- error, fakeFs *fakeSftp) {
	defer server.Close() // TODO(fxbug.dev/47316): is this necessary / the right scope for this?
	for req := range requests {
		switch req.Type {
		case "exec":
			// The payload is a single SSH wire-format string (uint32
			// big-endian length, then the bytes). Decoding it with
			// ssh.Unmarshal, rather than reading only the low byte of the
			// length, handles commands of 256+ bytes and rejects short
			// payloads instead of panicking on an out-of-range slice.
			var execMsg struct{ Command string }
			if err := ssh.Unmarshal(req.Payload, &execMsg); err != nil {
				req.Reply(false, nil)
				continue
			}
			req.Reply(true, nil)
			output, exitCode := fakeExec(execMsg.Command)
			if _, err := channel.Write([]byte(output)); err != nil {
				errCh <- fmt.Errorf("write fail: %w", err)
				return
			}
			if err := channel.CloseWrite(); err != nil { // sends EOF
				errCh <- fmt.Errorf("error during closewrite: %w", err)
				return
			}
			// exit-status payload is a uint32 (big-endian); the fake's exit
			// codes fit in the low byte.
			if _, err := channel.SendRequest("exit-status", false,
				[]byte{0, 0, 0, exitCode}); err != nil {
				errCh <- fmt.Errorf("error during exit-status: %w", err)
				return
			}
			if err := channel.Close(); err != nil {
				errCh <- fmt.Errorf("error during Close: %w", err)
				return
			}
		case "subsystem":
			// Same single-string payload layout as "exec": the subsystem name.
			var subsystemMsg struct{ Name string }
			if err := ssh.Unmarshal(req.Payload, &subsystemMsg); err != nil || subsystemMsg.Name != "sftp" {
				req.Reply(false, nil)
				continue
			}
			req.Reply(true, nil)
			sftpServer := sftp.NewRequestServer(channel, makeSftpHandler(fakeFs))
			if err := sftpServer.Serve(); err != nil && err != io.EOF {
				errCh <- fmt.Errorf("error during sftp serve: %w", err)
				return
			}
		default:
			req.Reply(false, nil)
		}
	}
}
// fakeExec emulates a tiny shell for the fake SSH "exec" handler. The program
// name is everything before the first space. Only "echo" is understood: it
// prints the remainder of the command line (verbatim, including any extra
// spaces) plus a newline and exits 0. Every other program produces no output
// and exit code 127 ("command not found").
func fakeExec(cmdline string) (stdout string, exitCode byte) {
	program, rest := cmdline, ""
	if sp := strings.IndexByte(cmdline, ' '); sp >= 0 {
		program, rest = cmdline[:sp], cmdline[sp+1:]
	}
	if program != "echo" {
		return "", 127
	}
	return rest + "\n", 0
}
// The following is based on pkg/sftp/request-example.go:

// fakeFile is one entry (file or directory) in the fake SFTP filesystem.
// name is the entry's full path and is what lookups match on; content holds
// file data and is unused for directories. Methods are not synchronized.
type fakeFile struct {
	name    string
	content string
	isDir   bool
}

// fakeFile implements os.FileInfo so it can be returned from listings.
var _ os.FileInfo = (*fakeFile)(nil)

// Name returns the last element of the entry's path.
func (f *fakeFile) Name() string { return path.Base(f.name) }

// Size returns the length of the content in bytes.
func (f *fakeFile) Size() int64 { return int64(len(f.content)) }

// Mode reports 0644 permissions, with os.ModeDir set for directories.
func (f *fakeFile) Mode() os.FileMode {
	mode := os.FileMode(0644)
	if f.IsDir() {
		mode |= os.ModeDir
	}
	return mode
}

// ModTime always reports the current time; the fake does not track mtimes.
func (f *fakeFile) ModTime() time.Time { return time.Now() }

// IsDir reports whether the entry is a directory.
func (f *fakeFile) IsDir() bool { return f.isDir }

// Sys returns nil; there is no underlying data source.
func (f *fakeFile) Sys() interface{} { return nil }

// WriteAt implements io.WriterAt. It writes p into the content starting at
// byte offset off, overwriting existing bytes and extending the content when
// the write goes past the end. Any gap before off is zero-filled. Writes that
// arrive out of order (or overlap) therefore land in the right place, instead
// of being appended in arrival order. A negative offset is an error.
func (f *fakeFile) WriteAt(p []byte, off int64) (n int, err error) {
	if off < 0 {
		return 0, fmt.Errorf("WriteAt %q: negative offset %d", f.name, off)
	}
	// Like a real file, an empty write does not extend the content.
	if len(p) == 0 {
		return 0, nil
	}
	buf := []byte(f.content)
	if end := off + int64(len(p)); end > int64(len(buf)) {
		buf = append(buf, make([]byte, end-int64(len(buf)))...)
	}
	copy(buf[off:], p)
	f.content = string(buf)
	return len(p), nil
}
// fakeSftp is the in-memory filesystem backing the fake SFTP subsystem; a
// single value serves as every sftp.Handlers slot. Files and directories are
// all kept in one flat slice, each identified by its full path
// (fakeFile.name); directories are entries with isDir set. The struct has no
// locking of its own.
type fakeSftp struct {
files []*fakeFile
}
// listerat adapts a fixed slice of os.FileInfo to the sftp.ListerAt
// interface.
type listerat []os.FileInfo

// ListAt copies entries starting at offset into ls and returns how many were
// copied. It returns io.EOF when offset is at or past the end, or when fewer
// entries remain than ls can hold (i.e. this batch is the last one).
func (l listerat) ListAt(ls []os.FileInfo, offset int64) (int, error) {
	if int64(len(l)) <= offset {
		return 0, io.EOF
	}
	copied := copy(ls, l[offset:])
	var err error
	if copied < len(ls) {
		err = io.EOF
	}
	return copied, err
}
// getFile returns the first entry whose full path equals filePath, or
// os.ErrNotExist when no entry matches.
func (s *fakeSftp) getFile(filePath string) (*fakeFile, error) {
	for i := range s.files {
		if candidate := s.files[i]; candidate.name == filePath {
			return candidate, nil
		}
	}
	return nil, os.ErrNotExist
}
// Fileread is the FileGet handler: it logs the request and serves the
// content of the entry at r.Filepath, propagating the lookup error when no
// such entry exists.
func (s *fakeSftp) Fileread(r *sftp.Request) (io.ReaderAt, error) {
	glog.Infof("sftp read: %v", r)
	file, err := s.getFile(r.Filepath)
	if err == nil {
		return strings.NewReader(file.content), nil
	}
	return nil, err
}
// Filewrite is the FilePut handler: it returns the entry at r.Filepath as
// the write target, creating it if needed.
//
// The enclosing directory must already exist as a directory entry. Paths
// whose parent is "." are exempt from this check. An existing regular file
// is reused and truncated, rather than shadowed by a second entry with the
// same path. A duplicate would never be seen, because lookups return the
// first match, so reads would keep returning the old content. Writing to an
// existing directory is an error.
// NOTE(review): open flags are not inspected, so every write-open is treated
// like a create/truncate.
func (s *fakeSftp) Filewrite(r *sftp.Request) (io.WriterAt, error) {
	glog.Infof("sftp write: %v", r)
	// Enforce that directory must exist
	enclosingDir := path.Dir(r.Filepath)
	if enclosingDir != "." {
		if dir, err := s.getFile(enclosingDir); err != nil || !dir.isDir {
			glog.Errorf("directory doesn't exist: %q", enclosingDir)
			return nil, os.ErrNotExist
		}
	}
	if existing, err := s.getFile(r.Filepath); err == nil {
		if existing.isDir {
			glog.Errorf("cannot write to directory: %q", r.Filepath)
			return nil, fmt.Errorf("%q is a directory", r.Filepath)
		}
		existing.content = ""
		return existing, nil
	}
	f := &fakeFile{name: r.Filepath}
	s.files = append(s.files, f)
	return f, nil
}
// Filecmd is the FileCmd handler. Only "Mkdir" is supported. It fails with
// os.ErrExist if the path is already present. It fails with os.ErrNotExist
// unless the parent path is an existing directory entry. Otherwise it records
// a new directory entry. Any other method is reported as unsupported.
func (s *fakeSftp) Filecmd(r *sftp.Request) error {
	glog.Infof("sftp cmd: %v", r)
	if r.Method != "Mkdir" {
		return fmt.Errorf("unsupported: %v", r)
	}
	if _, lookupErr := s.getFile(r.Filepath); lookupErr == nil {
		return os.ErrExist
	}
	parent, lookupErr := s.getFile(path.Dir(r.Filepath))
	if lookupErr != nil || !parent.isDir {
		return os.ErrNotExist
	}
	s.files = append(s.files, &fakeFile{name: r.Filepath, isDir: true})
	return nil
}
// Filelist is the FileList handler. The entry at r.Filepath must exist,
// otherwise the lookup error is returned.
//   - "Stat" returns just that entry.
//   - "List" on a directory returns its immediate children (entries whose
//     path.Dir is the directory's path), excluding the directory itself.
//   - "List" on a regular file returns just that file.
//
// Any other method is reported as unsupported.
func (s *fakeSftp) Filelist(r *sftp.Request) (sftp.ListerAt, error) {
	glog.Infof("sftp list: %v", r)
	f, err := s.getFile(r.Filepath)
	if err != nil {
		return nil, err
	}
	switch r.Method {
	case "Stat":
		return listerat([]os.FileInfo{f}), nil
	case "List":
		if !f.isDir {
			return listerat([]os.FileInfo{f}), nil
		}
		var children []os.FileInfo
		for _, candidate := range s.files {
			// Only list files/dirs in the immediate directory. "/" and "."
			// are their own path.Dir, so without the name check such a
			// directory would list itself as a child.
			if candidate.name != f.name && path.Dir(candidate.name) == f.name {
				children = append(children, candidate)
			}
		}
		return listerat(children), nil
	default:
		return nil, fmt.Errorf("unsupported: %v", r)
	}
}
// makeSftpHandler installs the same fakeSftp in every sftp.Handlers slot, so
// reads, writes, commands and listings are all served from one in-memory
// filesystem.
func makeSftpHandler(handler *fakeSftp) sftp.Handlers {
	var handlers sftp.Handlers
	handlers.FileGet = handler
	handlers.FilePut = handler
	handlers.FileCmd = handler
	handlers.FileList = handler
	return handlers
}