Merge pull request #584 from peterverraedt/configure-maxTxPacket

Add WithMaxTxPacket server option
diff --git a/packet.go b/packet.go
index cbaa90e..bfe6a3c 100644
--- a/packet.go
+++ b/packet.go
@@ -823,7 +823,7 @@
 // So, we need: uint32(length) + byte(type) + uint32(id) + uint32(data_length)
 const dataHeaderLen = 4 + 1 + 4 + 4
 
-func (p *sshFxpReadPacket) getDataSlice(alloc *allocator, orderID uint32) []byte {
+func (p *sshFxpReadPacket) getDataSlice(alloc *allocator, orderID uint32, maxTxPacket uint32) []byte {
 	dataLen := p.Len
 	if dataLen > maxTxPacket {
 		dataLen = maxTxPacket
diff --git a/request-server.go b/request-server.go
index 7a99db6..11047e6 100644
--- a/request-server.go
+++ b/request-server.go
@@ -10,7 +10,7 @@
 	"sync"
 )
 
-var maxTxPacket uint32 = 1 << 15
+const defaultMaxTxPacket uint32 = 1 << 15
 
 // Handlers contains the 4 SFTP server request handlers.
 type Handlers struct {
@@ -28,6 +28,7 @@
 	pktMgr *packetManager
 
 	startDirectory string
+	maxTxPacket    uint32
 
 	mu           sync.RWMutex
 	handleCount  int
@@ -57,6 +58,22 @@
 	}
 }
 
+// WithRSMaxTxPacket sets the maximum size of the payload returned to the client,
+// measured in bytes. The default value is 32768 bytes, and this option
+// can only be used to increase it. Setting this option to a larger value
+// should be safe, because the client decides the size of the requested payload.
+//
+// A size smaller than the default is silently ignored and the limit is left unchanged.
+func WithRSMaxTxPacket(size uint32) RequestServerOption {
+	return func(rs *RequestServer) {
+		if size < defaultMaxTxPacket {
+			return
+		}
+
+		rs.maxTxPacket = size
+	}
+}
+
 // NewRequestServer creates/allocates/returns new RequestServer.
 // Normally there will be one server per user-session.
 func NewRequestServer(rwc io.ReadWriteCloser, h Handlers, options ...RequestServerOption) *RequestServer {
@@ -73,6 +90,7 @@
 		pktMgr:     newPktMgr(svrConn),
 
 		startDirectory: "/",
+		maxTxPacket:    defaultMaxTxPacket,
 
 		openRequests: make(map[string]*Request),
 	}
@@ -260,7 +278,7 @@
 					Method:   "Stat",
 					Filepath: cleanPathWithBase(rs.startDirectory, request.Filepath),
 				}
-				rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
+				rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
 			}
 		case *sshFxpFsetstatPacket:
 			handle := pkt.getHandle()
@@ -272,7 +290,7 @@
 					Method:   "Setstat",
 					Filepath: cleanPathWithBase(rs.startDirectory, request.Filepath),
 				}
-				rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
+				rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
 			}
 		case *sshFxpExtendedPacketPosixRename:
 			request := &Request{
@@ -280,24 +298,24 @@
 				Filepath: cleanPathWithBase(rs.startDirectory, pkt.Oldpath),
 				Target:   cleanPathWithBase(rs.startDirectory, pkt.Newpath),
 			}
-			rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
+			rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
 		case *sshFxpExtendedPacketStatVFS:
 			request := &Request{
 				Method:   "StatVFS",
 				Filepath: cleanPathWithBase(rs.startDirectory, pkt.Path),
 			}
-			rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
+			rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
 		case hasHandle:
 			handle := pkt.getHandle()
 			request, ok := rs.getRequest(handle)
 			if !ok {
 				rpkt = statusFromError(pkt.id(), EBADF)
 			} else {
-				rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
+				rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
 			}
 		case hasPath:
 			request := requestFromPacket(ctx, pkt, rs.startDirectory)
-			rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID)
+			rpkt = request.call(rs.Handlers, pkt, rs.pktMgr.alloc, orderID, rs.maxTxPacket)
 			request.close()
 		default:
 			rpkt = statusFromError(pkt.id(), ErrSSHFxOpUnsupported)
diff --git a/request.go b/request.go
index cd652cf..e7c47a9 100644
--- a/request.go
+++ b/request.go
@@ -300,14 +300,14 @@
 }
 
 // called from worker to handle packet/request
-func (r *Request) call(handlers Handlers, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
+func (r *Request) call(handlers Handlers, pkt requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) responsePacket {
 	switch r.Method {
 	case "Get":
-		return fileget(handlers.FileGet, r, pkt, alloc, orderID)
+		return fileget(handlers.FileGet, r, pkt, alloc, orderID, maxTxPacket)
 	case "Put":
-		return fileput(handlers.FilePut, r, pkt, alloc, orderID)
+		return fileput(handlers.FilePut, r, pkt, alloc, orderID, maxTxPacket)
 	case "Open":
-		return fileputget(handlers.FilePut, r, pkt, alloc, orderID)
+		return fileputget(handlers.FilePut, r, pkt, alloc, orderID, maxTxPacket)
 	case "Setstat", "Rename", "Rmdir", "Mkdir", "Link", "Symlink", "Remove", "PosixRename", "StatVFS":
 		return filecmd(handlers.FileCmd, r, pkt)
 	case "List":
@@ -392,13 +392,13 @@
 }
 
 // wrap FileReader handler
-func fileget(h FileReader, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
+func fileget(h FileReader, r *Request, pkt requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) responsePacket {
 	rd := r.getReaderAt()
 	if rd == nil {
 		return statusFromError(pkt.id(), errors.New("unexpected read packet"))
 	}
 
-	data, offset, _ := packetData(pkt, alloc, orderID)
+	data, offset, _ := packetData(pkt, alloc, orderID, maxTxPacket)
 
 	n, err := rd.ReadAt(data, offset)
 	// only return EOF error if no data left to read
@@ -414,20 +414,20 @@
 }
 
 // wrap FileWriter handler
-func fileput(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
+func fileput(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) responsePacket {
 	wr := r.getWriterAt()
 	if wr == nil {
 		return statusFromError(pkt.id(), errors.New("unexpected write packet"))
 	}
 
-	data, offset, _ := packetData(pkt, alloc, orderID)
+	data, offset, _ := packetData(pkt, alloc, orderID, maxTxPacket)
 
 	_, err := wr.WriteAt(data, offset)
 	return statusFromError(pkt.id(), err)
 }
 
 // wrap OpenFileWriter handler
-func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32) responsePacket {
+func fileputget(h FileWriter, r *Request, pkt requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) responsePacket {
 	rw := r.getWriterAtReaderAt()
 	if rw == nil {
 		return statusFromError(pkt.id(), errors.New("unexpected write and read packet"))
@@ -435,7 +435,7 @@
 
 	switch p := pkt.(type) {
 	case *sshFxpReadPacket:
-		data, offset := p.getDataSlice(alloc, orderID), int64(p.Offset)
+		data, offset := p.getDataSlice(alloc, orderID, maxTxPacket), int64(p.Offset)
 
 		n, err := rw.ReadAt(data, offset)
 		// only return EOF error if no data left to read
@@ -461,10 +461,10 @@
 }
 
 // file data for additional read/write packets
-func packetData(p requestPacket, alloc *allocator, orderID uint32) (data []byte, offset int64, length uint32) {
+func packetData(p requestPacket, alloc *allocator, orderID uint32, maxTxPacket uint32) (data []byte, offset int64, length uint32) {
 	switch p := p.(type) {
 	case *sshFxpReadPacket:
-		return p.getDataSlice(alloc, orderID), int64(p.Offset), p.Len
+		return p.getDataSlice(alloc, orderID, maxTxPacket), int64(p.Offset), p.Len
 	case *sshFxpWritePacket:
 		return p.Data, int64(p.Offset), p.Length
 	}
diff --git a/request_test.go b/request_test.go
index 92f7c2b..807833a 100644
--- a/request_test.go
+++ b/request_test.go
@@ -149,7 +149,7 @@
 	for i, txt := range []string{"file-", "data."} {
 		pkt := &sshFxpReadPacket{ID: uint32(i), Handle: "a",
 			Offset: uint64(i * 5), Len: 5}
-		rpkt := request.call(handlers, pkt, nil, 0)
+		rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
 		dpkt := rpkt.(*sshFxpDataPacket)
 		assert.Equal(t, dpkt.id(), uint32(i))
 		assert.Equal(t, string(dpkt.Data), txt)
@@ -162,7 +162,7 @@
 	pkt := fakePacket{myid: 1}
 	cmdErr := errors.New("stat not supported")
 	handlers.returnError(cmdErr)
-	rpkt := request.call(handlers, pkt, nil, 0)
+	rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
 	assert.Equal(t, rpkt, statusFromError(pkt.myid, cmdErr))
 }
 
@@ -173,11 +173,11 @@
 	request.state.writerAt, _ = handlers.FilePut.Filewrite(request)
 	pkt := &sshFxpWritePacket{ID: 0, Handle: "a", Offset: 0, Length: 5,
 		Data: []byte("file-")}
-	rpkt := request.call(handlers, pkt, nil, 0)
+	rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
 	checkOkStatus(t, rpkt)
 	pkt = &sshFxpWritePacket{ID: 1, Handle: "a", Offset: 5, Length: 5,
 		Data: []byte("data.")}
-	rpkt = request.call(handlers, pkt, nil, 0)
+	rpkt = request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
 	checkOkStatus(t, rpkt)
 	assert.Equal(t, "file-data.", handlers.getOutString())
 }
@@ -186,11 +186,11 @@
 	handlers := newTestHandlers()
 	request := testRequest("Mkdir")
 	pkt := fakePacket{myid: 1}
-	rpkt := request.call(handlers, pkt, nil, 0)
+	rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
 	checkOkStatus(t, rpkt)
 
 	handlers.returnError(errTest)
-	rpkt = request.call(handlers, pkt, nil, 0)
+	rpkt = request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
 	assert.Equal(t, rpkt, statusFromError(pkt.myid, errTest))
 }
 
@@ -198,7 +198,7 @@
 	handlers := newTestHandlers()
 	request := testRequest("Stat")
 	pkt := fakePacket{myid: 1}
-	rpkt := request.call(handlers, pkt, nil, 0)
+	rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
 	spkt, ok := rpkt.(*sshFxpStatResponse)
 	assert.True(t, ok)
 	assert.Equal(t, spkt.info.Name(), "request_test.go")
@@ -215,13 +215,13 @@
 		assert.Equal(t, hpkt.Handle, "1")
 	}
 	pkt = fakePacket{myid: 2}
-	request.call(handlers, pkt, nil, 0)
+	request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
 }
 func TestRequestInfoReadlink(t *testing.T) {
 	handlers := newTestHandlers()
 	request := testRequest("Readlink")
 	pkt := fakePacket{myid: 1}
-	rpkt := request.call(handlers, pkt, nil, 0)
+	rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
 	npkt, ok := rpkt.(*sshFxpNamePacket)
 	if assert.True(t, ok) {
 		assert.IsType(t, &sshFxpNameAttr{}, npkt.NameAttrs[0])
@@ -234,7 +234,7 @@
 	request := testRequest("Stat")
 	request.handle = "1"
 	pkt := fakePacket{myid: 1}
-	rpkt := request.call(handlers, pkt, nil, 0)
+	rpkt := request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
 	assert.IsType(t, &sshFxpStatResponse{}, rpkt)
 
 	request.Method = "List"
@@ -244,6 +244,6 @@
 		hpkt := rpkt.(*sshFxpHandlePacket)
 		assert.Equal(t, hpkt.Handle, "1")
 	}
-	rpkt = request.call(handlers, pkt, nil, 0)
+	rpkt = request.call(handlers, pkt, nil, 0, defaultMaxTxPacket)
 	assert.IsType(t, &sshFxpNamePacket{}, rpkt)
 }
diff --git a/server.go b/server.go
index acdc30e..fb474c4 100644
--- a/server.go
+++ b/server.go
@@ -34,6 +34,7 @@
 	openFilesLock sync.RWMutex
 	handleCount   int
 	workDir       string
+	maxTxPacket   uint32
 }
 
 func (svr *Server) nextHandle(f *os.File) string {
@@ -86,6 +87,7 @@
 		debugStream: ioutil.Discard,
 		pktMgr:      newPktMgr(svrConn),
 		openFiles:   make(map[string]*os.File),
+		maxTxPacket: defaultMaxTxPacket,
 	}
 
 	for _, o := range options {
@@ -139,6 +141,24 @@
 	}
 }
 
+// WithMaxTxPacket sets the maximum size of the payload returned to the client,
+// measured in bytes. The default value is 32768 bytes, and this option
+// can only be used to increase it. Setting this option to a larger value
+// should be safe, because the client decides the size of the requested payload.
+//
+// A size smaller than the default is rejected with an error and the limit is left unchanged.
+func WithMaxTxPacket(size uint32) ServerOption {
+	return func(s *Server) error {
+		if size < defaultMaxTxPacket {
+			return errors.New("size must be greater than or equal to 32768")
+		}
+
+		s.maxTxPacket = size
+
+		return nil
+	}
+}
+
 type rxPacket struct {
 	pktType  fxp
 	pktBytes []byte
@@ -287,7 +307,7 @@
 		f, ok := s.getHandle(p.Handle)
 		if ok {
 			err = nil
-			data := p.getDataSlice(s.pktMgr.alloc, orderID)
+			data := p.getDataSlice(s.pktMgr.alloc, orderID, s.maxTxPacket)
 			n, _err := f.ReadAt(data, int64(p.Offset))
 			if _err != nil && (_err != io.EOF || n == 0) {
 				err = _err
@@ -513,16 +533,16 @@
 
 	fs, err := p.unmarshalFileStat(p.Flags)
 
-	if err == nil && (p.Flags & sshFileXferAttrSize) != 0 {
+	if err == nil && (p.Flags&sshFileXferAttrSize) != 0 {
 		err = os.Truncate(path, int64(fs.Size))
 	}
-	if err == nil && (p.Flags & sshFileXferAttrPermissions) != 0 {
+	if err == nil && (p.Flags&sshFileXferAttrPermissions) != 0 {
 		err = os.Chmod(path, fs.FileMode())
 	}
-	if err == nil && (p.Flags & sshFileXferAttrUIDGID) != 0 {
+	if err == nil && (p.Flags&sshFileXferAttrUIDGID) != 0 {
 		err = os.Chown(path, int(fs.UID), int(fs.GID))
 	}
-	if err == nil && (p.Flags & sshFileXferAttrACmodTime) != 0 {
+	if err == nil && (p.Flags&sshFileXferAttrACmodTime) != 0 {
 		err = os.Chtimes(path, fs.AccessTime(), fs.ModTime())
 	}
 
@@ -541,16 +561,16 @@
 
 	fs, err := p.unmarshalFileStat(p.Flags)
 
-	if err == nil && (p.Flags & sshFileXferAttrSize) != 0 {
+	if err == nil && (p.Flags&sshFileXferAttrSize) != 0 {
 		err = f.Truncate(int64(fs.Size))
 	}
-	if err == nil && (p.Flags & sshFileXferAttrPermissions) != 0 {
+	if err == nil && (p.Flags&sshFileXferAttrPermissions) != 0 {
 		err = f.Chmod(fs.FileMode())
 	}
-	if err == nil && (p.Flags & sshFileXferAttrUIDGID) != 0 {
+	if err == nil && (p.Flags&sshFileXferAttrUIDGID) != 0 {
 		err = f.Chown(int(fs.UID), int(fs.GID))
 	}
-	if err == nil && (p.Flags & sshFileXferAttrACmodTime) != 0 {
+	if err == nil && (p.Flags&sshFileXferAttrACmodTime) != 0 {
 		type chtimer interface {
 			Chtimes(atime, mtime time.Time) error
 		}