blob: 28629de82958500dd8e3058d5adfb5eaa7f3b03d [file] [log] [blame]
package sftp
import (
"io"
"os"
"path"
"path/filepath"
"sync"
"syscall"
"github.com/pkg/errors"
)
// MaxFilelist is the maximum number of directory entries requested from a
// ListerAt in a single readdir batch; it sizes the FileInfo buffer that is
// handed to ListAt.
var MaxFilelist = int64(100)
// Request holds the data and per-handle state for an incoming SFTP service
// request.
type Request struct {
	// Method names the operation, e.g. Get, Put, Setstat, Stat, Rename,
	// Remove, Rmdir, Mkdir, List, Readlink or Symlink; Filepath is the path
	// the operation applies to.
	Method, Filepath string
	// Flags and Attrs are filled from Setstat packets; Attrs is kept as the
	// raw, undecoded attribute bytes.
	Flags uint32
	Attrs []byte // TODO: decode into a structured type
	// Target is the destination path for renames and symlinks.
	Target string

	// stateLock guards state, which caches the reader/writer/lister returned
	// by the handlers along with the current readdir offset.
	stateLock *sync.RWMutex
	state     *state
}
// state caches the objects obtained from the handlers so that later packets
// for the same request can reuse them, plus the current readdir position.
// It is read and written under the owning Request's stateLock.
type state struct {
// writerAt is the writer cached by fileput (via setFileState).
writerAt io.WriterAt
// readerAt is the reader cached by fileget.
readerAt io.ReaderAt
// listerAt is the lister cached by filelist.
listerAt ListerAt
// lsoffset is the offset passed to the next ListAt call; lsInc advances it
// by the number of entries each batch returned.
lsoffset int64
}
// packet_data is the subset of an incoming packet that the file handlers
// need: the request id plus, for read and write packets, the payload, length
// and offset (filled in by packetData).
// NOTE(review): an underscore name is not idiomatic Go, but the type is used
// throughout this file, so renaming it is left for a dedicated change.
type packet_data struct {
// _id is the packet's request id, exposed through the id method.
_id uint32
// data is the write payload (write packets only).
data []byte
// length is the requested read length or the write length.
length uint32
// offset is the file offset of the read or write.
offset int64
}
// id returns the request id of the packet this data was extracted from.
func (d packet_data) id() uint32 { return d._id }
// requestFromPacket builds a Request for a path-bearing packet. The method
// name comes from requestMethod and the path from the packet (NewRequest
// cleans it). Setstat packets additionally copy Flags and the raw Attrs
// bytes; the type assertion panics if Attrs does not hold a []byte. Rename
// and Symlink packets record their cleaned destination path in Target.
func requestFromPacket(pkt hasPath) *Request {
	req := NewRequest(requestMethod(pkt), pkt.getPath())
	switch typed := pkt.(type) {
	case *sshFxpSetstatPacket:
		req.Flags = typed.Flags
		req.Attrs = typed.Attrs.([]byte)
	case *sshFxpRenamePacket:
		req.Target = cleanPath(typed.Newpath)
	case *sshFxpSymlinkPacket:
		req.Target = cleanPath(typed.Linkpath)
	}
	return req
}
// newRequest returns an empty Request whose state and state lock are
// allocated, so that the state accessors can be used right away.
func newRequest() *Request {
	return &Request{
		state:     new(state),
		stateLock: new(sync.RWMutex),
	}
}
// NewRequest creates a Request for the given method and path. The path is
// normalized with cleanPath before being stored in Filepath.
func NewRequest(method, path string) *Request {
	r := newRequest()
	r.Method, r.Filepath = method, cleanPath(path)
	return r
}
// lsNext returns the offset at which the next ListAt call should start,
// reading it under the state read lock.
func (r *Request) lsNext() int64 {
	mu := r.stateLock
	mu.RLock()
	defer mu.RUnlock()
	return r.state.lsoffset
}
// lsInc advances the readdir offset by delta (normally the number of entries
// the last ListAt call returned), under the state write lock.
func (r *Request) lsInc(offset int64) {
	mu := r.stateLock
	mu.Lock()
	defer mu.Unlock()
	r.state.lsoffset += offset
}
// setFileState stores s in the request state under the write lock, picking
// the slot from s's dynamic type: an io.WriterAt becomes the cached writer,
// an io.ReaderAt the cached reader, a ListerAt the cached lister, and an
// int64 replaces the readdir offset. Any other value (including nil) is
// ignored.
//
// NOTE(review): a type switch takes the first matching case, so a value that
// implements both io.WriterAt and io.ReaderAt (for example *os.File) is
// always stored as the writer, and a ListerAt that also implements either
// interface is stored as a writer/reader. Callers caching a reader or lister
// whose concrete type may satisfy several of these interfaces should assign
// the specific state field directly instead.
func (r *Request) setFileState(s interface{}) {
r.stateLock.Lock()
defer r.stateLock.Unlock()
// Case order matters: io.WriterAt is tested first (see NOTE above).
switch s := s.(type) {
case io.WriterAt:
r.state.writerAt = s
case io.ReaderAt:
r.state.readerAt = s
case ListerAt:
r.state.listerAt = s
case int64:
r.state.lsoffset = s
}
}
// getWriter returns the cached writer, or nil if none has been stored.
func (r *Request) getWriter() io.WriterAt {
	mu := r.stateLock
	mu.RLock()
	defer mu.RUnlock()
	return r.state.writerAt
}
// getReader returns the cached reader, or nil if none has been stored.
func (r *Request) getReader() io.ReaderAt {
	mu := r.stateLock
	mu.RLock()
	defer mu.RUnlock()
	return r.state.readerAt
}
// getLister returns the cached lister, or nil if none has been stored.
func (r *Request) getLister() ListerAt {
	mu := r.stateLock
	mu.RLock()
	defer mu.RUnlock()
	return r.state.listerAt
}
// close releases whatever the handlers left cached on the request: the
// reader, the writer and the lister are each closed if they implement
// io.Closer. Close errors are discarded because close has no way to report
// them to its callers.
func (r *Request) close() {
	if c, ok := r.getReader().(io.Closer); ok {
		c.Close()
	}
	if c, ok := r.getWriter().(io.Closer); ok {
		c.Close()
	}
	// Listers can hold resources too (e.g. an open directory); without this
	// a closable ListerAt returned by FileLister.Filelist would never be
	// released.
	if c, ok := r.getLister().(io.Closer); ok {
		c.Close()
	}
}
// call handles one packet for this request and returns the response. The
// packet's id/offset/length/payload are extracted with packetData, and the
// request is dispatched on r.Method to the matching handler wrapper: Get to
// FileGet, Put to FilePut, the path-only commands to FileCmd, and
// List/Stat/Readlink to FileList. Any other method produces an error status.
func (r *Request) call(handlers Handlers, pkt requestPacket) responsePacket {
	data := packetData(pkt)
	switch r.Method {
	case "Get":
		return fileget(handlers.FileGet, r, data)
	case "Put": // an "Append" method could be routed here for append-only writes
		return fileput(handlers.FilePut, r, data)
	case "Setstat", "Rename", "Rmdir", "Mkdir", "Symlink", "Remove":
		return filecmd(handlers.FileCmd, r, data)
	case "List", "Stat", "Readlink":
		return filelist(handlers.FileList, r, data)
	}
	return statusFromError(pkt,
		errors.Errorf("unexpected method: %s", r.Method))
}
// packetData copies the fields the file handlers need out of p: always the
// request id, plus the length and offset for read packets, and the payload,
// length and offset for write packets. Other packet types carry only the id.
func packetData(p requestPacket) packet_data {
	out := packet_data{_id: p.id()}
	switch pkt := p.(type) {
	case *sshFxpWritePacket:
		out.data, out.length, out.offset = pkt.Data, pkt.Length, int64(pkt.Offset)
	case *sshFxpReadPacket:
		out.length, out.offset = pkt.Len, int64(pkt.Offset)
	}
	return out
}
// fileget serves a read ("Get") packet by wrapping the FileReader handler.
//
// On the first read for a request, h.Fileread supplies an io.ReaderAt, which
// is cached in the request state so later reads reuse it. Up to
// clamp(pd.length, maxTxPacket) bytes are read at pd.offset and returned in a
// data packet. io.EOF is reported as an error only when nothing was read; a
// short read that hits EOF returns the bytes it did get.
func fileget(h FileReader, r *Request, pd packet_data) responsePacket {
	reader := r.getReader()
	if reader == nil {
		var err error
		reader, err = h.Fileread(r)
		if err != nil {
			return statusFromError(pd, err)
		}
		// Cache the reader in its own slot instead of calling setFileState:
		// setFileState's type switch tests io.WriterAt before io.ReaderAt, so
		// a reader that also implements WriteAt (such as *os.File) would be
		// stored as the writer. readerAt would then stay nil, and every later
		// read would call Fileread again, leaking the earlier readers.
		r.stateLock.Lock()
		r.state.readerAt = reader
		r.stateLock.Unlock()
	}
	data := make([]byte, clamp(pd.length, maxTxPacket))
	n, err := reader.ReadAt(data, pd.offset)
	// Only surface EOF when there is no data left to return.
	if err != nil && (err != io.EOF || n == 0) {
		return statusFromError(pd, err)
	}
	return &sshFxpDataPacket{
		ID:     pd.id(),
		Length: uint32(n),
		Data:   data[:n],
	}
}
// fileput serves a write ("Put") packet by wrapping the FileWriter handler.
// On the first write for a request, h.Filewrite supplies an io.WriterAt that
// is cached in the request state. pd.data is written at pd.offset, and the
// response is an OK status or the write error.
func fileput(h FileWriter, r *Request, pd packet_data) responsePacket {
	w := r.getWriter()
	if w == nil {
		opened, err := h.Filewrite(r)
		if err != nil {
			return statusFromError(pd, err)
		}
		r.setFileState(opened)
		w = opened
	}
	if _, err := w.WriteAt(pd.data, pd.offset); err != nil {
		return statusFromError(pd, err)
	}
	ok := StatusError{Code: ssh_FX_OK}
	return &sshFxpStatusPacket{ID: pd.id(), StatusError: ok}
}
// filecmd serves the path-only commands that call routes here (Setstat,
// Rename, Rmdir, Mkdir, Symlink, Remove) by passing the request to
// h.Filecmd. It answers with an OK status, or with the handler's error.
func filecmd(h FileCmder, r *Request, pd packet_data) responsePacket {
	if err := h.Filecmd(r); err != nil {
		return statusFromError(pd, err)
	}
	ok := StatusError{Code: ssh_FX_OK}
	return &sshFxpStatusPacket{ID: pd.id(), StatusError: ok}
}
// filelist serves "List", "Stat" and "Readlink" requests by wrapping the
// FileLister handler.
//
// The first call for a request obtains a ListerAt from h.Filelist and caches
// it in the request state. Each call asks the lister for up to MaxFilelist
// entries, starting at the request's readdir offset, and advances that
// offset by the number of entries returned. The batch is then turned into a
// response according to r.Method:
//
//   - "List": a name packet with one entry per FileInfo, or an EOF status
//     when the batch is empty.
//   - "Stat": a stat response for the first entry, or an ENOENT status when
//     the batch is empty.
//   - "Readlink": a name packet whose name is the first entry's Name(), or an
//     ENOENT status when the batch is empty.
//
// A lister error other than io.EOF is reported as a status packet. Any other
// method yields an "unexpected method" status.
func filelist(h FileLister, r *Request, pd packet_data) responsePacket {
	lister := r.getLister()
	if lister == nil {
		var err error
		lister, err = h.Filelist(r)
		if err != nil {
			return statusFromError(pd, err)
		}
		// Cache the lister in its own slot instead of calling setFileState:
		// that helper's type switch tests io.WriterAt and io.ReaderAt first,
		// so a lister that also implements either would be filed under the
		// wrong field. listerAt would then stay nil, and every readdir call
		// would invoke Filelist again.
		r.stateLock.Lock()
		r.state.listerAt = lister
		r.stateLock.Unlock()
	}

	offset := r.lsNext()
	finfo := make([]os.FileInfo, MaxFilelist)
	n, err := lister.ListAt(finfo, offset)
	r.lsInc(int64(n))
	// io.EOF is not treated as a failure below; the n == 0 checks decide what
	// to report when the lister has nothing (more) to return.
	finfo = finfo[:n] // drop unfilled slots so the code below never sees nil entries

	switch r.Method {
	case "List":
		if err != nil && err != io.EOF {
			return statusFromError(pd, err)
		}
		if n == 0 {
			return statusFromError(pd, io.EOF)
		}
		dirname := filepath.ToSlash(path.Base(r.Filepath))
		ret := &sshFxpNamePacket{
			ID:        pd.id(),
			NameAttrs: make([]sshFxpNameAttr, 0, n),
		}
		for _, fi := range finfo {
			ret.NameAttrs = append(ret.NameAttrs, sshFxpNameAttr{
				Name:     fi.Name(),
				LongName: runLs(dirname, fi),
				Attrs:    []interface{}{fi},
			})
		}
		return ret
	case "Stat":
		if err != nil && err != io.EOF {
			return statusFromError(pd, err)
		}
		if n == 0 {
			err = &os.PathError{Op: "stat", Path: r.Filepath,
				Err: syscall.ENOENT}
			return statusFromError(pd, err)
		}
		return &sshFxpStatResponse{
			ID:   pd.id(),
			info: finfo[0],
		}
	case "Readlink":
		if err != nil && err != io.EOF {
			return statusFromError(pd, err)
		}
		if n == 0 {
			err = &os.PathError{Op: "readlink", Path: r.Filepath,
				Err: syscall.ENOENT}
			return statusFromError(pd, err)
		}
		filename := finfo[0].Name()
		return &sshFxpNamePacket{
			ID: pd.id(),
			NameAttrs: []sshFxpNameAttr{{
				Name:     filename,
				LongName: filename,
				Attrs:    emptyFileStat,
			}},
		}
	default:
		err = errors.Errorf("unexpected method: %s", r.Method)
		return statusFromError(pd, err)
	}
}
// updateMethod switches the request's Method to match a follow-up
// handle-based packet: read packets become "Get", write packets "Put", and
// readdir packets "List". Any other packet type returns an error and leaves
// Method unchanged.
func (r *Request) updateMethod(p hasHandle) error {
	var method string
	switch p.(type) {
	case *sshFxpReadPacket:
		method = "Get"
	case *sshFxpWritePacket:
		method = "Put"
	case *sshFxpReaddirPacket:
		method = "List"
	default:
		return errors.Errorf("unexpected packet type %T", p)
	}
	r.Method = method
	return nil
}
// requestMethod maps a path-bearing packet to the Request.Method name used
// for dispatch. Open and Opendir both map to "Open", and Stat and Lstat both
// map to "Stat". A packet type not listed here yields the empty string.
func requestMethod(p hasPath) (method string) {
	switch p.(type) {
	case *sshFxpOpenPacket, *sshFxpOpendirPacket:
		return "Open"
	case *sshFxpStatPacket, *sshFxpLstatPacket:
		return "Stat"
	case *sshFxpSetstatPacket:
		return "Setstat"
	case *sshFxpRenamePacket:
		return "Rename"
	case *sshFxpSymlinkPacket:
		return "Symlink"
	case *sshFxpReadlinkPacket:
		return "Readlink"
	case *sshFxpRemovePacket:
		return "Remove"
	case *sshFxpMkdirPacket:
		return "Mkdir"
	case *sshFxpRmdirPacket:
		return "Rmdir"
	}
	return ""
}