blob: 2d0e424ea95484fae5eb7dd4a69dc9be91977093 [file] [log] [blame]
// Copyright 2017 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
package ssh
import (
"context"
"io"
"net"
"os"
"strings"
"time"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
)
// User is the default user name for SSH connections made by this package's
// callers.
const User = "fuchsia"

// Port is the well-known TCP port of an SSH server.
const Port = 22
// Connection is an established SSH connection to one remote host, created by
// NewConnection.
//
// It keeps both the high-level client, used to open sessions and SFTP
// subsystems, and the raw ssh.Conn it was built from, so that the transport
// can be torn down directly when a remote command must be aborted.
type Connection struct {
	client *ssh.Client // opens sessions / SFTP clients on top of conn
	conn   ssh.Conn    // raw SSH transport; closing it ends every session
}
// NewConnection dials addr over IPv6 TCP ("tcp6") and performs an SSH
// handshake, authenticating as user with the private key in key (anything
// ssh.ParsePrivateKey accepts).
//
// The remote host key is NOT verified. Both the TCP dial and the SSH
// handshake are bounded by a five-minute timeout; once the handshake
// completes, the connection has no deadline.
//
// ssh.NewClientConn is used instead of ssh.Dial so that the underlying
// ssh.Conn is retained and can be closed directly if the remote side hangs
// (see Run).
func NewConnection(addr string, user string, key []byte) (*Connection, error) {
	signer, err := ssh.ParsePrivateKey(key)
	if err != nil {
		return nil, err
	}
	config := &ssh.ClientConfig{
		User: user,
		Auth: []ssh.AuthMethod{
			ssh.PublicKeys(signer),
		},
		Timeout: 5 * time.Minute,
		// TODO: verify the remote host key instead of accepting any key.
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
	}
	conn, err := net.DialTimeout("tcp6", addr, config.Timeout)
	if err != nil {
		return nil, err
	}
	// ClientConfig.Timeout only bounds the TCP dial. Without a deadline, a
	// peer that accepts the TCP connection but never completes the SSH
	// handshake would block NewClientConn forever.
	if err := conn.SetDeadline(time.Now().Add(config.Timeout)); err != nil {
		conn.Close()
		return nil, err
	}
	c, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
	if err != nil {
		return nil, err
	}
	// Clear the handshake deadline so long-running sessions are not cut off.
	if err := conn.SetDeadline(time.Time{}); err != nil {
		c.Close()
		return nil, err
	}
	return &Connection{
		client: ssh.NewClient(c, chans, reqs),
		conn:   c,
	}, nil
}
// RunCommand runs a command on the remote host in a fresh session and blocks
// until it exits, returning the session's Run error.
//
// The remote command line is name followed by arg, joined with single spaces
// and terminated by a newline; arguments are not quoted or escaped. The
// session is wired to this process's standard input, output and error.
func (c *Connection) RunCommand(name string, arg ...string) error {
	s, err := c.client.NewSession()
	if err != nil {
		return err
	}
	defer s.Close()

	s.Stdin, s.Stdout, s.Stderr = os.Stdin, os.Stdout, os.Stderr

	cmdline := strings.Join(append([]string{name}, arg...), " ")
	return s.Run(cmdline + "\n")
}
// Run runs a command on the remote host in a fresh session and waits for it
// to finish. The command line is name followed by arg, joined with single
// spaces and terminated by a newline (arguments are not quoted). The session
// is wired to this process's standard input, output and error.
//
// If ctx is done before the command finishes, the whole underlying SSH
// connection is closed. That aborts every session on c, not just this one.
// In that case Run returns ctx.Err() instead of the resulting transport
// error, so callers can tell cancellation apart from a command failure.
// ctx must not be nil.
func (c *Connection) Run(ctx context.Context, name string, arg ...string) error {
	session, err := c.client.NewSession()
	if err != nil {
		return err
	}
	defer session.Close()
	session.Stdin = os.Stdin
	session.Stdout = os.Stdout
	session.Stderr = os.Stderr
	args := append([]string{name}, arg...)
	if err := session.Start(strings.Join(args, " ") + "\n"); err != nil {
		return err
	}

	done := make(chan struct{})
	// aborted reports whether the watcher closed the connection. Receiving
	// from it below also guarantees the watcher has exited before Run
	// returns, so it cannot tear down the connection afterwards.
	aborted := make(chan bool)
	go func() {
		select {
		case <-ctx.Done():
			// Closing the underlying connection unblocks Wait even if the
			// remote side has hung.
			c.conn.Close()
			aborted <- true
		case <-done:
			aborted <- false
		}
	}()
	err = session.Wait()
	close(done)
	if <-aborted && err != nil {
		return ctx.Err()
	}
	return err
}
// Cmd is a command prepared to run on the remote host, created by
// Connection.Command. Set the exported fields as needed, then call Start
// and Wait.
type Cmd struct {
	// Args is the command name followed by its arguments. Start joins them
	// with single spaces to form the remote command line.
	Args []string

	// Stdin, Stdout and Stderr are attached to the remote session by Start.
	Stdin          io.Reader
	Stdout, Stderr io.Writer

	session  *ssh.Session    // session the command runs in
	ctx      context.Context // nil means none; when done, Start's watcher closes the session
	waitDone chan struct{}   // closed by Wait to stop Start's watcher; nil if no watcher
}
// Command opens a new session on c and returns a Cmd that will run name with
// the given arguments in it. The command is not started; call Start on the
// result. ctx may be nil, meaning the command cannot be cancelled through a
// context.
func (c *Connection) Command(ctx context.Context, name string, arg ...string) (*Cmd, error) {
	args := append([]string{name}, arg...)

	s, err := c.client.NewSession()
	if err != nil {
		return nil, err
	}

	cmd := &Cmd{Args: args, ctx: ctx}
	cmd.session = s
	return cmd, nil
}
// Start launches c.Args on the remote host without waiting for it to
// finish. It attaches c.Stdin, c.Stdout and c.Stderr to the session first.
//
// If c has a context and it is already done, nothing is launched: the
// session is closed and the context's error is returned. Otherwise, a
// watcher goroutine closes the session when the context becomes done. Wait
// stops that goroutine.
func (c *Cmd) Start() error {
	if c.ctx != nil {
		// Don't launch a remote command that would be cancelled right away;
		// closing the session after exec may not stop the remote process.
		if err := c.ctx.Err(); err != nil {
			c.session.Close()
			return err
		}
	}
	c.session.Stdin = c.Stdin
	c.session.Stdout = c.Stdout
	c.session.Stderr = c.Stderr
	if err := c.session.Start(strings.Join(c.Args, " ")); err != nil {
		return err
	}
	if c.ctx != nil {
		c.waitDone = make(chan struct{})
		go func() {
			select {
			case <-c.ctx.Done():
				c.session.Close()
			case <-c.waitDone:
			}
		}()
	}
	return nil
}
// Wait blocks until the command launched by Start exits. It then stops
// Start's cancellation watcher, if one was started, and closes the session.
// It returns the session's Wait error. Any error from closing the session
// is ignored.
func (c *Cmd) Wait() error {
	waitErr := c.session.Wait()

	if done := c.waitDone; done != nil {
		close(done)
	}

	_ = c.session.Close() // best effort; the session may already be closed
	return waitErr
}
// Upload copies everything read from src into the file dst on the remote
// host over SFTP. It creates the file with sftp's Client.Create, which
// truncates an existing file.
//
// The SFTP client and the remote file are always closed. If the copy itself
// succeeded, an error from closing the remote file is returned: a failed
// close can mean the data was not fully stored on the remote side.
func (c *Connection) Upload(dst string, src io.Reader) (err error) {
	client, err := sftp.NewClient(c.client)
	if err != nil {
		return err
	}
	defer client.Close()
	remoteFile, err := client.Create(dst)
	if err != nil {
		return err
	}
	defer func() {
		// Report the close error only if nothing failed earlier.
		if cerr := remoteFile.Close(); err == nil {
			err = cerr
		}
	}()
	_, err = io.Copy(remoteFile, src)
	return err
}
// UploadFile copies the local file at path src to dst on the remote host
// using Upload. The local file is closed before UploadFile returns.
func (c *Connection) UploadFile(dst string, src string) error {
	in, openErr := os.Open(src)
	if openErr != nil {
		return openErr
	}
	defer in.Close()

	return c.Upload(dst, in)
}