Merge pull request #584 from peterverraedt/configure-maxTxPacket
Add WithMaxTxPacket server option
diff --git a/packet.go b/packet.go
index cbaa90e..bfe6a3c 100644
--- a/packet.go
+++ b/packet.go
@@ -823,7 +823,7 @@
// So, we need: uint32(length) + byte(type) + uint32(id) + uint32(data_length)
const dataHeaderLen = 4 + 1 + 4 + 4
-func (p *sshFxpReadPacket) getDataSlice(alloc *allocator, orderID uint32) []byte {
+func (p *sshFxpReadPacket) getDataSlice(alloc *allocator, orderID uint32, maxTxPacket uint32) []byte {
dataLen := p.Len
if dataLen > maxTxPacket {
dataLen = maxTxPacket
diff --git a/request-server.go b/request-server.go
index 7a99db6..11047e6 100644
--- a/request-server.go
+++ b/request-server.go
@@ -10,7 +10,7 @@
"sync"
)
-var maxTxPacket uint32 = 1 << 15
+const defaultMaxTxPacket uint32 = 1 << 15 // 32768 bytes: default cap on the data payload of one read response
// Handlers contains the 4 SFTP server request handlers.
type Handlers struct {
@@ -28,6 +28,7 @@
pktMgr *packetManager
startDirectory string
+ maxTxPacket uint32
mu sync.RWMutex
handleCount int
@@ -57,6 +58,22 @@
}
}
+// WithRSMaxTxPacket sets the maximum size, in bytes, of the data payload the
+// RequestServer returns to the client for a single read request.
+// The default is defaultMaxTxPacket (32768 bytes). A size below the default is
+// silently ignored (RequestServerOption cannot report an error), so this option
+// can only raise the limit. Raising it should be safe because the client picks
+// the read length; each read returns at most min(requested, maxTxPacket) bytes.
+func WithRSMaxTxPacket(size uint32) RequestServerOption {
+ return func(rs *RequestServer) {
+ if size < defaultMaxTxPacket {
+ return
+ }
+
+ rs.maxTxPacket = size
+ }
+}
+
// NewRequestServer creates/allocates/returns new RequestServer.
// Normally there will be one server per user-session.
func NewRequestServer(rwc io.ReadWriteCloser, h Handlers, options ...RequestServerOption) *RequestServer {
@@ -73,6 +90,7 @@
pktMgr: newPktMgr(svrConn),
startDirectory: "/",
+ maxTxPacket: defaultMaxTxPacket,
openRequests: make(map[string]*Request),
}
@@ -260,7 +278,7 @@
Method: "Stat",
Filepath: cleanPathWithBase(rs.startDirectory, request.Filepath),
}
- rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
+ rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
}
case *sshFxpFsetstatPacket:
handle := pkt.getHandle()
@@ -272,7 +290,7 @@
Method: "Setstat",
Filepath: cleanPathWithBase(rs.startDirectory, request.Filepath),
}
- rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
+ rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
}
case *sshFxpExtendedPacketPosixRename:
request := &Request{
@@ -280,24 +298,24 @@
Filepath: cleanPathWithBase(rs.startDirectory, pkt.Oldpath),
Target: cleanPathWithBase(rs.startDirectory, pkt.Newpath),
}
- rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
+ rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
case *sshFxpExtendedPacketStatVFS:
request := &Request{
Method: "StatVFS",
Filepath: cleanPathWithBase(rs.startDirectory, pkt.Path),
}
- rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
+ rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
case hasHandle:
handle := pkt.getHandle()
request, ok := rs.getRequest(handle)
if !ok {
rpkt = statusFromError(pkt.id(), EBADF)
} else {
- rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
+ rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
}
case hasPath:
request := requestFromPacket(ctx, pkt, rs.startDirectory)
- rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
+ rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
request.close()
default:
rpkt = statusFromError(pkt.id(), ErrSSHFxOpUnsupported)
diff --git a/request.go b/request.go
index cd652cf..e7c47a9 100644
--- a/request.go
+++ b/request.go
@@ -300,14 +300,14 @@
}
// called from worker to handle packet/request
-func (r *Request) call(handlers Handlers, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
+func (r *Request) call(handlers Handlers, pkt requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) responsePacket {
switch r.Method {
case "Get":
- return fileget(handlers.FileGet, r, pkt, alloc, orderID)
+ return fileget(handlers.FileGet, r, pkt, alloc, orderID, maxTxPacket)
case "Put":
- return fileput(handlers.FilePut, r, pkt, alloc, orderID)
+ return fileput(handlers.FilePut, r, pkt, alloc, orderID, maxTxPacket)
case "Open":
- return fileputget(handlers.FilePut, r, pkt, alloc, orderID)
+ return fileputget(handlers.FilePut, r, pkt, alloc, orderID, maxTxPacket)
case "Setstat", "Rename", "Rmdir", "Mkdir", "Link", "Symlink", "Remove", "PosixRename", "StatVFS":
return filecmd(handlers.FileCmd, r, pkt)
case "List":
@@ -392,13 +392,13 @@
}
// wrap FileReader handler
-func fileget(h FileReader, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
+func fileget(h FileReader, r *Request, pkt requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) responsePacket {
rd := r.getReaderAt()
if rd == nil {
return statusFromError(pkt.id(), errors.New("unexpected read packet"))
}
- data, offset, _ := packetData(pkt, alloc, orderID)
+ data, offset, _ := packetData(pkt, alloc, orderID, maxTxPacket)
n, err := rd.ReadAt(data, offset)
// only return EOF error if no data left to read
@@ -414,20 +414,20 @@
}
// wrap FileWriter handler
-func fileput(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
+func fileput(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) responsePacket {
wr := r.getWriterAt()
if wr == nil {
return statusFromError(pkt.id(), errors.New("unexpected write packet"))
}
- data, offset, _ := packetData(pkt, alloc, orderID)
+ data, offset, _ := packetData(pkt, alloc, orderID, maxTxPacket)
_, err := wr.WriteAt(data, offset)
return statusFromError(pkt.id(), err)
}
// wrap OpenFileWriter handler
-func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
+func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) responsePacket {
rw := r.getWriterAtReaderAt()
if rw == nil {
return statusFromError(pkt.id(), errors.New("unexpected write and read packet"))
@@ -435,7 +435,7 @@
switch p := pkt.(type) {
case *sshFxpReadPacket:
- data, offset := p.getDataSlice(alloc, orderID), int64(p.Offset)
+ data, offset := p.getDataSlice(alloc, orderID, maxTxPacket), int64(p.Offset)
n, err := rw.ReadAt(data, offset)
// only return EOF error if no data left to read
@@ -461,10 +461,10 @@
}
// file data for additional read/write packets
-func packetData(p requestPacket, alloc *allocator, orderID uint32) (data []byte, offset int64, length uint32) {
+func packetData(p requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) (data []byte, offset int64, length uint32) {
switch p := p.(type) {
case *sshFxpReadPacket:
- return p.getDataSlice(alloc, orderID), int64(p.Offset), p.Len
+ return p.getDataSlice(alloc, orderID, maxTxPacket), int64(p.Offset), p.Len
case *sshFxpWritePacket:
return p.Data, int64(p.Offset), p.Length
}
diff --git a/request_test.go b/request_test.go
index 92f7c2b..807833a 100644
--- a/request_test.go
+++ b/request_test.go
@@ -149,7 +149,7 @@
for i, txt := range []string{"file-", "data."} {
pkt := &sshFxpReadPacket{ID: uint32(i), Handle: "a",
Offset: uint64(i * 5), Len: 5}
- rpkt := request.call(handlers, pkt, nil, 0)
+ rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
dpkt := rpkt.(*sshFxpDataPacket)
assert.Equal(t, dpkt.id(), uint32(i))
assert.Equal(t, string(dpkt.Data), txt)
@@ -162,7 +162,7 @@
pkt := fakePacket{myid: 1}
cmdErr := errors.New("stat not supported")
handlers.returnError(cmdErr)
- rpkt := request.call(handlers, pkt, nil, 0)
+ rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
assert.Equal(t, rpkt, statusFromError(pkt.myid, cmdErr))
}
@@ -173,11 +173,11 @@
request.state.writerAt, _ = handlers.FilePut.Filewrite(request)
pkt := &sshFxpWritePacket{ID: 0, Handle: "a", Offset: 0, Length: 5,
Data: []byte("file-")}
- rpkt := request.call(handlers, pkt, nil, 0)
+ rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
checkOkStatus(t, rpkt)
pkt = &sshFxpWritePacket{ID: 1, Handle: "a", Offset: 5, Length: 5,
Data: []byte("data.")}
- rpkt = request.call(handlers, pkt, nil, 0)
+ rpkt = request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
checkOkStatus(t, rpkt)
assert.Equal(t, "file-data.", handlers.getOutString())
}
@@ -186,11 +186,11 @@
handlers := newTestHandlers()
request := testRequest("Mkdir")
pkt := fakePacket{myid: 1}
- rpkt := request.call(handlers, pkt, nil, 0)
+ rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
checkOkStatus(t, rpkt)
handlers.returnError(errTest)
- rpkt = request.call(handlers, pkt, nil, 0)
+ rpkt = request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
assert.Equal(t, rpkt, statusFromError(pkt.myid, errTest))
}
@@ -198,7 +198,7 @@
handlers := newTestHandlers()
request := testRequest("Stat")
pkt := fakePacket{myid: 1}
- rpkt := request.call(handlers, pkt, nil, 0)
+ rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
spkt, ok := rpkt.(*sshFxpStatResponse)
assert.True(t, ok)
assert.Equal(t, spkt.info.Name(), "request_test.go")
@@ -215,13 +215,13 @@
assert.Equal(t, hpkt.Handle, "1")
}
pkt = fakePacket{myid: 2}
- request.call(handlers, pkt, nil, 0)
+ request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
}
func TestRequestInfoReadlink(t *testing.T) {
handlers := newTestHandlers()
request := testRequest("Readlink")
pkt := fakePacket{myid: 1}
- rpkt := request.call(handlers, pkt, nil, 0)
+ rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
npkt, ok := rpkt.(*sshFxpNamePacket)
if assert.True(t, ok) {
assert.IsType(t, &sshFxpNameAttr{}, npkt.NameAttrs[0])
@@ -234,7 +234,7 @@
request := testRequest("Stat")
request.handle = "1"
pkt := fakePacket{myid: 1}
- rpkt := request.call(handlers, pkt, nil, 0)
+ rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
assert.IsType(t, &sshFxpStatResponse{}, rpkt)
request.Method = "List"
@@ -244,6 +244,6 @@
hpkt := rpkt.(*sshFxpHandlePacket)
assert.Equal(t, hpkt.Handle, "1")
}
- rpkt = request.call(handlers, pkt, nil, 0)
+ rpkt = request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
assert.IsType(t, &sshFxpNamePacket{}, rpkt)
}
diff --git a/server.go b/server.go
index acdc30e..fb474c4 100644
--- a/server.go
+++ b/server.go
@@ -34,6 +34,7 @@
openFilesLock sync.RWMutex
handleCount int
workDir string
+ maxTxPacket uint32
}
func (svr *Server) nextHandle(f *os.File) string {
@@ -86,6 +87,7 @@
debugStream: ioutil.Discard,
pktMgr: newPktMgr(svrConn),
openFiles: make(map[string]*os.File),
+ maxTxPacket: defaultMaxTxPacket,
}
for _, o := range options {
@@ -139,6 +141,24 @@
}
}
+// WithMaxTxPacket sets the maximum size, in bytes, of the data payload the
+// Server returns to the client for a single read request.
+// The default is defaultMaxTxPacket (32768 bytes). A size below the default is
+// rejected: the returned ServerOption reports an error and leaves the limit
+// unchanged. Raising it should be safe because the client picks the read
+// length; each read returns at most min(requested, maxTxPacket) bytes.
+func WithMaxTxPacket(size uint32) ServerOption {
+ return func(s *Server) error {
+ if size < defaultMaxTxPacket {
+ return errors.New("size must be greater than or equal to 32768")
+ }
+
+ s.maxTxPacket = size
+
+ return nil
+ }
+}
+
type rxPacket struct {
pktType fxp
pktBytes []byte
@@ -287,7 +307,7 @@
f, ok := s.getHandle(p.Handle)
if ok {
err = nil
- data := p.getDataSlice(s.pktMgr.alloc, orderID)
+ data := p.getDataSlice(s.pktMgr.alloc, orderID, s.maxTxPacket)
n, _err := f.ReadAt(data, int64(p.Offset))
if _err != nil && (_err != io.EOF || n == 0) {
err = _err
@@ -513,16 +533,16 @@
fs, err := p.unmarshalFileStat(p.Flags)
- if err == nil && (p.Flags & sshFileXferAttrSize) != 0 {
+ if err == nil && (p.Flags&sshFileXferAttrSize) != 0 {
err = os.Truncate(path, int64(fs.Size))
}
- if err == nil && (p.Flags & sshFileXferAttrPermissions) != 0 {
+ if err == nil && (p.Flags&sshFileXferAttrPermissions) != 0 {
err = os.Chmod(path, fs.FileMode())
}
- if err == nil && (p.Flags & sshFileXferAttrUIDGID) != 0 {
+ if err == nil && (p.Flags&sshFileXferAttrUIDGID) != 0 {
err = os.Chown(path, int(fs.UID), int(fs.GID))
}
- if err == nil && (p.Flags & sshFileXferAttrACmodTime) != 0 {
+ if err == nil && (p.Flags&sshFileXferAttrACmodTime) != 0 {
err = os.Chtimes(path, fs.AccessTime(), fs.ModTime())
}
@@ -541,16 +561,16 @@
fs, err := p.unmarshalFileStat(p.Flags)
- if err == nil && (p.Flags & sshFileXferAttrSize) != 0 {
+ if err == nil && (p.Flags&sshFileXferAttrSize) != 0 {
err = f.Truncate(int64(fs.Size))
}
- if err == nil && (p.Flags & sshFileXferAttrPermissions) != 0 {
+ if err == nil && (p.Flags&sshFileXferAttrPermissions) != 0 {
err = f.Chmod(fs.FileMode())
}
- if err == nil && (p.Flags & sshFileXferAttrUIDGID) != 0 {
+ if err == nil && (p.Flags&sshFileXferAttrUIDGID) != 0 {
err = f.Chown(int(fs.UID), int(fs.GID))
}
- if err == nil && (p.Flags & sshFileXferAttrACmodTime) != 0 {
+ if err == nil && (p.Flags&sshFileXferAttrACmodTime) != 0 {
type chtimer interface {
Chtimes(atime, mtime time.Time) error
}