blob: 192261ced2c8dd2988988e333160ab871105fa6b [file] [log] [blame]
// Copyright 2020 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
// This library is intended only for use by tests that want to spin up a local
// ssh server to test ssh client code. It uses password authentication rather
// than key-based authentication, making it insecure and inappropriate for
// production use.
package sshutil
import (
"context"
"crypto/rand"
"crypto/rsa"
"errors"
"fmt"
"io"
"log"
"net"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"
)
// testServerUser is the only user name the test server's password callback
// accepts; the generated client config logs in as this user.
const testServerUser = "testuser"
// sshServer is a minimal ssh server for tests. It is created by
// startSSHServer, which launches it via start(), and is shut down with stop().
type sshServer struct {
	// The server's listener, set by start().
	listener net.Listener
	// The configuration that clients can use to connect to the server.
	clientConfig *ssh.ClientConfig
	// The configuration used by the server when accepting new connections.
	serverConfig *ssh.ServerConfig
	// onNewChannel, if non-nil, is a callback that gets called when the server
	// receives a new channel, before the server accepts that channel.
	onNewChannel func(ssh.NewChannel)
	// onRequest, if non-nil, is a callback that gets called when the server
	// receives a new out-of-band request. The ssh.Channel argument is the
	// channel the request arrived on, or nil for connection-level (global)
	// requests.
	onRequest func(ssh.Channel, *ssh.Request)
	// g tracks all the current goroutines that are able to serve connections,
	// or launch new goroutines that themselves are able to serve connections.
	g *errgroup.Group
}
// start opens the server's listener on an ephemeral loopback port, which sets
// the server's address (s.listener.Addr()), and launches a goroutine that
// accepts new connections until stop() is called.
//
// Every goroutine that serves a connection, or that can launch such
// goroutines, is tracked by s.g, whose context is derived from ctx. The first
// error returned by any of them cancels that context, which makes the
// remaining request loops exit.
func (s *sshServer) start(ctx context.Context) error {
	// Bind to the loopback interface only: this server authenticates with a
	// password and is meant for local tests, so it must not be reachable from
	// other hosts. Port 0 lets the kernel pick any available port so we don't
	// collide with another process.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return err
	}
	s.listener = listener
	s.g, ctx = errgroup.WithContext(ctx)
	// This goroutine is capable of launching new server goroutines, so the
	// server can't be considered shut down if this goroutine is still running.
	s.g.Go(func() error {
		for {
			// Fails (with net.ErrClosed) once stop() closes the listener.
			tcpConn, err := listener.Accept()
			if err != nil {
				return err
			}
			// Performs the ssh handshake; on failure NewServerConn closes
			// tcpConn itself. A handshake error ends this accept loop.
			conn, incomingChannels, incomingRequests, err := ssh.NewServerConn(tcpConn, s.serverConfig)
			if err != nil {
				return err
			}
			// Inner group for this connection. Its context is canceled when the
			// server's context is, or when any goroutine serving this
			// connection returns an error.
			g, ctx := errgroup.WithContext(ctx)
			// Connection-level (global) requests are not tied to a channel, so
			// they are served with a nil ssh.Channel.
			g.Go(func() error {
				return s.serveRequests(ctx, nil, incomingRequests)
			})
			g.Go(func() error {
				for {
					select {
					case <-ctx.Done():
						return ctx.Err()
					case newChannel, ok := <-incomingChannels:
						if !ok {
							return nil
						}
						if fn := s.onNewChannel; fn != nil {
							fn(newChannel)
						}
						// Each channel is accepted and served in its own
						// goroutine, and closed when its request loop exits.
						g.Go(func() error {
							ch, incomingRequests, err := newChannel.Accept()
							if err != nil {
								return err
							}
							defer func() {
								// io.EOF means the channel was already closed.
								if err := ch.Close(); err != nil && !errors.Is(err, io.EOF) {
									log.Printf("ch.Close() = %s", err)
								}
							}()
							return s.serveRequests(ctx, ch, incomingRequests)
						})
					}
				}
			})
			// Register the connection's lifetime with the outer group so that
			// stop() waits for it, and close the connection once it is done.
			s.g.Go(func() error {
				// This might err out if the client is closed first, so don't bother
				// checking the return value.
				defer func() {
					_ = conn.Close()
				}()
				return g.Wait()
			})
		}
	})
	return nil
}
// stop shuts down the server. It closes the listener so that no new
// connections are accepted, then waits for every goroutine tracked by s.g to
// exit.
//
// net.ErrClosed is the expected result of the accept loop once the listener
// is closed, so it is not reported; any other error from the server's
// goroutines is returned. A listener close failure is only logged.
func (s *sshServer) stop() error {
	closeErr := s.listener.Close()
	if closeErr != nil && !errors.Is(closeErr, net.ErrClosed) {
		log.Printf("listener.Close() = %s", closeErr)
	}

	// Block until we know that no new handshakes can occur, and that any
	// existing connections can no longer be served.
	waitErr := s.g.Wait()
	if errors.Is(waitErr, net.ErrClosed) {
		return nil
	}
	return waitErr
}
// serveRequests hands every request received on incomingRequests, together
// with ch, to the onRequest callback when one is set. It returns nil once
// incomingRequests is closed, or ctx.Err() if ctx is canceled first.
func (s *sshServer) serveRequests(ctx context.Context, ch ssh.Channel, incomingRequests <-chan *ssh.Request) error {
	done := ctx.Done()
	for {
		var (
			req  *ssh.Request
			open bool
		)
		select {
		case <-done:
			return ctx.Err()
		case req, open = <-incomingRequests:
		}
		if !open {
			return nil
		}
		if handler := s.onRequest; handler != nil {
			handler(ch, req)
		}
	}
}
// startSSHServer generates fresh server/client ssh configs, starts an ssh
// server at any available port, and returns it. onNewChannel and onRequest
// (either may be nil) become the server's callbacks. The caller is
// responsible for calling stop() on the returned server.
//
// An error is returned if generating the configs or starting the listener
// fails; no server is started in that case.
func startSSHServer(ctx context.Context, onNewChannel func(ssh.NewChannel), onRequest func(ssh.Channel, *ssh.Request)) (*sshServer, error) {
	serverConfig, clientConfig, err := genSSHConfig()
	if err != nil {
		// Previously this error was overwritten below, which could start a
		// server with nil configs.
		return nil, err
	}
	server := &sshServer{
		clientConfig: clientConfig,
		serverConfig: serverConfig,
		onNewChannel: onNewChannel,
		onRequest:    onRequest,
	}
	if err := server.start(ctx); err != nil {
		return nil, err
	}
	return server, nil
}
// genSSHConfig returns a matching server/client configuration pair.
//
// The server config accepts a single authentication attempt, only for
// testServerUser with a freshly generated random password, and presents a
// newly generated 2048-bit RSA host key. The client config logs in as
// testServerUser with that password and skips host key verification
// (ssh.InsecureIgnoreHostKey), so it is only suitable for tests.
func genSSHConfig() (*ssh.ServerConfig, *ssh.ClientConfig, error) {
	clientPassword, err := genPassword(40)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to generate password: %w", err)
	}
	serverConfig := &ssh.ServerConfig{
		MaxAuthTries: 1,
		PasswordCallback: func(metadata ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
			if metadata.User() != testServerUser || string(password) != clientPassword {
				return nil, errors.New("invalid user/password combination")
			}
			return nil, nil
		},
	}
	serverKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return nil, nil, fmt.Errorf("error generating keypair: %w", err)
	}
	signer, err := ssh.NewSignerFromKey(serverKey)
	if err != nil {
		// Wrapped with context like the other failures in this function.
		return nil, nil, fmt.Errorf("error creating host key signer: %w", err)
	}
	serverConfig.AddHostKey(signer)
	clientConfig := &ssh.ClientConfig{
		User:            testServerUser,
		Auth:            []ssh.AuthMethod{ssh.Password(clientPassword)},
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
	}
	return serverConfig, clientConfig, nil
}
// genPassword reads length random bytes from crypto/rand and returns them
// hex-encoded, so the password is 2*length characters long. The read error,
// if any, is returned unchanged.
func genPassword(length int) (string, error) {
	secret := make([]byte, length)
	_, err := rand.Read(secret)
	if err != nil {
		return "", err
	}
	return fmt.Sprintf("%x", secret), nil
}
// onExecRequest returns a callback suitable for sshServer.onRequest (the
// onRequest argument of startSSHServer) that handles "exec" requests by
// calling f.
//
// For an "exec" request the callback decodes the command from the payload,
// replies with success, and calls f with the command, the channel as stdout
// and the channel's stderr stream. f's return value is sent to the client as
// an "exit-status" request (converted to uint32, so negative values wrap),
// and the channel is then closed. keepaliveFuchsia requests are ignored. Any
// other request type, or a failure to decode the payload or to send the reply
// or exit status, results in a panic.
//
// NOTE(review): ch is nil for connection-level requests (sshServer serves
// those with a nil channel), so an "exec" arriving outside a channel would
// panic on the nil channel.
func onExecRequest(f func(cmd string, stdout io.Writer, stderr io.Writer) int) func(ssh.Channel, *ssh.Request) {
	return func(ch ssh.Channel, req *ssh.Request) {
		switch req.Type {
		case "exec":
			// The exec payload is a single ssh string holding the command line.
			var execMsg struct{ Command string }
			if err := ssh.Unmarshal(req.Payload, &execMsg); err != nil {
				log.Panicf("failed to unmarshal payload: %s", err)
			}
			if err := req.Reply(true, nil); err != nil {
				log.Panicf("failed to send reply: %s", err)
			}
			exitStatus := f(execMsg.Command, ch, ch.Stderr())
			exitMsg := struct {
				ExitStatus uint32
			}{ExitStatus: uint32(exitStatus)}
			if _, err := ch.SendRequest("exit-status", false, ssh.Marshal(&exitMsg)); err != nil {
				log.Panicf("failed to send exit status: %s", err)
			}
			// io.EOF means the channel was already closed; like the server's own
			// channel cleanup, don't log that case.
			if err := ch.Close(); err != nil && !errors.Is(err, io.EOF) {
				log.Printf("failed to close channel: %s", err)
			}
		case keepaliveFuchsia:
			// Ignore.
		default:
			log.Panicf("unexpected request type: %s", req.Type)
		}
	}
}