blob: d15f3a1104cab58e283cfb1a06e3ff93b838231f [file] [log] [blame]
package sftp
import (
"context"
"io"
"path"
"path/filepath"
"strconv"
"sync"
"syscall"
"github.com/pkg/errors"
)
var maxTxPacket uint32 = 1 << 15
// Handlers contains the 4 SFTP server request handlers.
type Handlers struct {
FileGet FileReader
FilePut FileWriter
FileCmd FileCmder
FileList FileLister
}
// RequestServer abstracts the sftp protocol with an http request-like protocol
type RequestServer struct {
*serverConn
Handlers Handlers
pktMgr *packetManager
openRequests map[string]*Request
openRequestLock sync.RWMutex
handleCount int
}
// NewRequestServer creates/allocates/returns new RequestServer.
// Normally there there will be one server per user-session.
func NewRequestServer(rwc io.ReadWriteCloser, h Handlers) *RequestServer {
svrConn := &serverConn{
conn: conn{
Reader: rwc,
WriteCloser: rwc,
},
}
return &RequestServer{
serverConn: svrConn,
Handlers: h,
pktMgr: newPktMgr(svrConn),
openRequests: make(map[string]*Request),
}
}
// New Open packet/Request
func (rs *RequestServer) nextRequest(r *Request) string {
rs.openRequestLock.Lock()
defer rs.openRequestLock.Unlock()
rs.handleCount++
handle := strconv.Itoa(rs.handleCount)
r.handle = handle
rs.openRequests[handle] = r
return handle
}
// Returns Request from openRequests, bool is false if it is missing.
//
// The Requests in openRequests work essentially as open file descriptors that
// you can do different things with. What you are doing with it are denoted by
// the first packet of that type (read/write/etc).
func (rs *RequestServer) getRequest(handle string) (*Request, bool) {
rs.openRequestLock.RLock()
defer rs.openRequestLock.RUnlock()
r, ok := rs.openRequests[handle]
return r, ok
}
// Close the Request and clear from openRequests map
func (rs *RequestServer) closeRequest(handle string) error {
rs.openRequestLock.Lock()
defer rs.openRequestLock.Unlock()
if r, ok := rs.openRequests[handle]; ok {
delete(rs.openRequests, handle)
return r.close()
}
return syscall.EBADF
}
// Close the read/write/closer to trigger exiting the main server loop
func (rs *RequestServer) Close() error { return rs.conn.Close() }
// Serve requests for user session
func (rs *RequestServer) Serve() error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var wg sync.WaitGroup
runWorker := func(ch chan orderedRequest) {
wg.Add(1)
go func() {
defer wg.Done()
if err := rs.packetWorker(ctx, ch); err != nil {
rs.conn.Close() // shuts down recvPacket
}
}()
}
pktChan := rs.pktMgr.workerChan(runWorker)
var err error
var pkt requestPacket
var pktType uint8
var pktBytes []byte
for {
pktType, pktBytes, err = rs.recvPacket()
if err != nil {
break
}
pkt, err = makePacket(rxPacket{fxp(pktType), pktBytes})
if err != nil {
switch errors.Cause(err) {
case errUnknownExtendedPacket:
// do nothing
default:
debug("makePacket err: %v", err)
rs.conn.Close() // shuts down recvPacket
break
}
}
pktChan <- rs.pktMgr.newOrderedRequest(pkt)
}
close(pktChan) // shuts down sftpServerWorkers
wg.Wait() // wait for all workers to exit
// make sure all open requests are properly closed
// (eg. possible on dropped connections, client crashes, etc.)
for handle, req := range rs.openRequests {
if err != nil {
req.state.RLock()
writer := req.state.writerAt
reader := req.state.readerAt
req.state.RUnlock()
if t, ok := writer.(TransferError); ok {
debug("notify error: %v to writer: %v\n", err, writer)
t.TransferError(err)
}
if t, ok := reader.(TransferError); ok {
debug("notify error: %v to reader: %v\n", err, reader)
t.TransferError(err)
}
}
delete(rs.openRequests, handle)
req.close()
}
return err
}
func (rs *RequestServer) packetWorker(
ctx context.Context, pktChan chan orderedRequest,
) error {
for pkt := range pktChan {
if epkt, ok := pkt.requestPacket.(*sshFxpExtendedPacket); ok {
if epkt.SpecificPacket != nil {
pkt.requestPacket = epkt.SpecificPacket
}
}
var rpkt responsePacket
switch pkt := pkt.requestPacket.(type) {
case *sshFxInitPacket:
rpkt = sshFxVersionPacket{Version: sftpProtocolVersion, Extensions: sftpExtensions}
case *sshFxpClosePacket:
handle := pkt.getHandle()
rpkt = statusFromError(pkt, rs.closeRequest(handle))
case *sshFxpRealpathPacket:
rpkt = cleanPacketPath(pkt)
case *sshFxpOpendirPacket:
request := requestFromPacket(ctx, pkt)
rs.nextRequest(request)
rpkt = request.opendir(rs.Handlers, pkt)
case *sshFxpOpenPacket:
request := requestFromPacket(ctx, pkt)
rs.nextRequest(request)
rpkt = request.open(rs.Handlers, pkt)
case *sshFxpFstatPacket:
handle := pkt.getHandle()
request, ok := rs.getRequest(handle)
if !ok {
rpkt = statusFromError(pkt, syscall.EBADF)
} else {
request = NewRequest("Stat", request.Filepath)
rpkt = request.call(rs.Handlers, pkt)
}
case *sshFxpExtendedPacketPosixRename:
request := NewRequest("Rename", pkt.Oldpath)
request.Target = pkt.Newpath
rpkt = request.call(rs.Handlers, pkt)
case hasHandle:
handle := pkt.getHandle()
request, ok := rs.getRequest(handle)
if !ok {
rpkt = statusFromError(pkt, syscall.EBADF)
} else {
rpkt = request.call(rs.Handlers, pkt)
}
case hasPath:
request := requestFromPacket(ctx, pkt)
rpkt = request.call(rs.Handlers, pkt)
request.close()
default:
rpkt = statusFromError(pkt, ErrSSHFxOpUnsupported)
}
rs.pktMgr.readyPacket(
rs.pktMgr.newOrderedResponse(rpkt, pkt.orderID()))
}
return nil
}
// clean and return name packet for file
func cleanPacketPath(pkt *sshFxpRealpathPacket) responsePacket {
path := cleanPath(pkt.getPath())
return &sshFxpNamePacket{
ID: pkt.id(),
NameAttrs: []sshFxpNameAttr{{
Name: path,
LongName: path,
Attrs: emptyFileStat,
}},
}
}
// Makes sure we have a clean POSIX (/) absolute path to work with
func cleanPath(p string) string {
p = filepath.ToSlash(p)
if !filepath.IsAbs(p) {
p = "/" + p
}
return path.Clean(p)
}