remove vestigial packets channel from request
Request object had a packet channel where it stored packets for
read/write/readdir handling. It was required in an earlier version of
the code (that was refactored before merging) and never got cleaned up.
This removes that channel and just passes the packet through the methods
instead. This is much simpler and would have eliminated the chance of
issue #195. It also looks like it should allow for a few more
simplifications as well.
diff --git a/request-server.go b/request-server.go
index 057834c..a622113 100644
--- a/request-server.go
+++ b/request-server.go
@@ -144,7 +144,7 @@
} else {
request = requestFromPacket(
&sshFxpStatPacket{ID: pkt.id(), Path: request.Filepath})
- rpkt = request.callHandler(rs.Handlers)
+ rpkt = request.call(rs.Handlers, pkt)
}
case *sshFxpFsetstatPacket:
handle := pkt.getHandle()
@@ -156,20 +156,23 @@
&sshFxpSetstatPacket{ID: pkt.id(), Path: request.Filepath,
Flags: pkt.Flags, Attrs: pkt.Attrs,
})
- rpkt = request.callHandler(rs.Handlers)
+ rpkt = request.call(rs.Handlers, pkt)
}
case hasHandle:
handle := pkt.getHandle()
request, ok := rs.getRequest(handle)
- request.update(pkt)
- if !ok {
+ uerr := request.updateMethod(pkt)
+ if !ok || uerr != nil {
+ if uerr == nil {
+ uerr = syscall.EBADF
+ }
rpkt = statusFromError(pkt, syscall.EBADF)
} else {
- rpkt = request.callHandler(rs.Handlers)
+ rpkt = request.call(rs.Handlers, pkt)
}
case hasPath:
request := requestFromPacket(pkt)
- rpkt = request.callHandler(rs.Handlers)
+ rpkt = request.call(rs.Handlers, pkt)
default:
return errors.Errorf("unexpected packet type %T", pkt)
}
diff --git a/request.go b/request.go
index b0eaba7..28629de 100644
--- a/request.go
+++ b/request.go
@@ -23,9 +23,6 @@
Flags uint32
Attrs []byte // convert to sub-struct
Target string // for renames and sym-links
- // packet data
- pkt_id uint32
- packets chan packet_data
// reader/writer/readdir from handlers
stateLock *sync.RWMutex
state *state
@@ -53,7 +50,6 @@
func requestFromPacket(pkt hasPath) *Request {
method := requestMethod(pkt)
request := NewRequest(method, pkt.getPath())
- request.pkt_id = pkt.id()
switch p := pkt.(type) {
case *sshFxpSetstatPacket:
request.Flags = p.Flags
@@ -66,17 +62,16 @@
return request
}
-// NewRequest creates a new Request object.
-func NewRequest(method, path string) *Request {
- request := &Request{Method: method, Filepath: cleanPath(path)}
- request.packets = make(chan packet_data, SftpServerWorkerCount)
- request.state = &state{}
- request.stateLock = &sync.RWMutex{}
- return request
+func newRequest() *Request {
+ return &Request{state: &state{}, stateLock: &sync.RWMutex{}}
}
-func (r *Request) id() uint32 {
- return r.pkt_id
+// NewRequest creates a new Request object.
+func NewRequest(method, path string) *Request {
+ request := newRequest()
+ request.Method = method
+ request.Filepath = cleanPath(path)
+ return request
}
// Returns current offset for file list
@@ -139,38 +134,43 @@
}
}
-// push packet_data into fifo
-func (r *Request) pushPacket(pd packet_data) {
- r.packets <- pd
-}
-
-// pop packet_data into fifo
-func (r *Request) popPacket() packet_data {
- return <-r.packets
-}
-
// called from worker to handle packet/request
-func (r *Request) callHandler(handlers Handlers) responsePacket {
+func (r *Request) call(handlers Handlers, pkt requestPacket) responsePacket {
+ pd := packetData(pkt)
switch r.Method {
case "Get":
- return fileget(handlers.FileGet, r)
+ return fileget(handlers.FileGet, r, pd)
case "Put": // add "Append" to this to handle append only file writes
- return fileput(handlers.FilePut, r)
+ return fileput(handlers.FilePut, r, pd)
case "Setstat", "Rename", "Rmdir", "Mkdir", "Symlink", "Remove":
- return filecmd(handlers.FileCmd, r)
+ return filecmd(handlers.FileCmd, r, pd)
case "List", "Stat", "Readlink":
- return filelist(handlers.FileList, r)
+ return filelist(handlers.FileList, r, pd)
default:
- return statusFromError(r,
+ return statusFromError(pkt,
errors.Errorf("unexpected method: %s", r.Method))
}
}
+// file data for additional read/write packets
+func packetData(p requestPacket) packet_data {
+ pd := packet_data{_id: p.id()}
+ switch p := p.(type) {
+ case *sshFxpReadPacket:
+ pd.length = p.Len
+ pd.offset = int64(p.Offset)
+ case *sshFxpWritePacket:
+ pd.data = p.Data
+ pd.length = p.Length
+ pd.offset = int64(p.Offset)
+ }
+ return pd
+}
+
// wrap FileReader handler
-func fileget(h FileReader, r *Request) responsePacket {
+func fileget(h FileReader, r *Request, pd packet_data) responsePacket {
var err error
reader := r.getReader()
- pd := r.popPacket()
if reader == nil {
reader, err = h.Fileread(r)
if err != nil {
@@ -193,18 +193,17 @@
}
// wrap FileWriter handler
-func fileput(h FileWriter, r *Request) responsePacket {
+func fileput(h FileWriter, r *Request, pd packet_data) responsePacket {
var err error
writer := r.getWriter()
if writer == nil {
writer, err = h.Filewrite(r)
if err != nil {
- return statusFromError(r, err)
+ return statusFromError(pd, err)
}
r.setFileState(writer)
}
- pd := r.popPacket()
_, err = writer.WriteAt(pd.data, pd.offset)
if err != nil {
return statusFromError(pd, err)
@@ -217,26 +216,26 @@
}
// wrap FileCmder handler
-func filecmd(h FileCmder, r *Request) responsePacket {
+func filecmd(h FileCmder, r *Request, pd packet_data) responsePacket {
err := h.Filecmd(r)
if err != nil {
- return statusFromError(r, err)
+ return statusFromError(pd, err)
}
return &sshFxpStatusPacket{
- ID: r.pkt_id,
+ ID: pd.id(),
StatusError: StatusError{
Code: ssh_FX_OK,
}}
}
// wrap FileLister handler
-func filelist(h FileLister, r *Request) responsePacket {
+func filelist(h FileLister, r *Request, pd packet_data) responsePacket {
var err error
lister := r.getLister()
if lister == nil {
lister, err = h.Filelist(r)
if err != nil {
- return statusFromError(r, err)
+ return statusFromError(pd, err)
}
r.setFileState(lister)
}
@@ -250,7 +249,6 @@
switch r.Method {
case "List":
- pd := r.popPacket()
if err != nil && err != io.EOF {
return statusFromError(pd, err)
}
@@ -270,29 +268,29 @@
return ret
case "Stat":
if err != nil && err != io.EOF {
- return statusFromError(r, err)
+ return statusFromError(pd, err)
}
if n == 0 {
err = &os.PathError{Op: "stat", Path: r.Filepath,
Err: syscall.ENOENT}
- return statusFromError(r, err)
+ return statusFromError(pd, err)
}
return &sshFxpStatResponse{
- ID: r.pkt_id,
+ ID: pd.id(),
info: finfo[0],
}
case "Readlink":
if err != nil && err != io.EOF {
- return statusFromError(r, err)
+ return statusFromError(pd, err)
}
if n == 0 {
err = &os.PathError{Op: "readlink", Path: r.Filepath,
Err: syscall.ENOENT}
- return statusFromError(r, err)
+ return statusFromError(pd, err)
}
filename := finfo[0].Name()
return &sshFxpNamePacket{
- ID: r.pkt_id,
+ ID: pd.id(),
NameAttrs: []sshFxpNameAttr{{
Name: filename,
LongName: filename,
@@ -301,29 +299,22 @@
}
default:
err = errors.Errorf("unexpected method: %s", r.Method)
- return statusFromError(r, err)
+ return statusFromError(pd, err)
}
}
// file data for additional read/write packets
-func (r *Request) update(p hasHandle) error {
- pd := packet_data{_id: p.id()}
+func (r *Request) updateMethod(p hasHandle) error {
switch p := p.(type) {
case *sshFxpReadPacket:
r.Method = "Get"
- pd.length = p.Len
- pd.offset = int64(p.Offset)
case *sshFxpWritePacket:
r.Method = "Put"
- pd.data = p.Data
- pd.length = p.Length
- pd.offset = int64(p.Offset)
case *sshFxpReaddirPacket:
r.Method = "List"
default:
return errors.Errorf("unexpected packet type %T", p)
}
- r.pushPacket(pd)
return nil
}
diff --git a/request_test.go b/request_test.go
index 7681627..c90a6ed 100644
--- a/request_test.go
+++ b/request_test.go
@@ -62,15 +62,9 @@
Method: method,
Attrs: []byte("foo"),
Target: "foo",
- packets: make(chan packet_data, SftpServerWorkerCount),
state: &state{},
stateLock: &sync.RWMutex{},
}
- for _, p := range []packet_data{
- {_id: 1, data: filecontents[:5], length: 5},
- {_id: 2, data: filecontents[5:], length: 5, offset: 5}} {
- request.packets <- p
- }
return request
}
@@ -117,14 +111,30 @@
}
}
+// fake/test packet
+type fakePacket struct {
+ myid uint32
+ handle string
+}
+
+func (f fakePacket) id() uint32 {
+ return f.myid
+}
+
+func (f fakePacket) getHandle() string {
+ return f.handle
+}
+func (fakePacket) UnmarshalBinary(d []byte) error { return nil }
+
func TestRequestGet(t *testing.T) {
handlers := newTestHandlers()
request := testRequest("Get")
// req.length is 5, so we test reads in 5 byte chunks
for i, txt := range []string{"file-", "data."} {
- pkt := request.callHandler(handlers)
- dpkt := pkt.(*sshFxpDataPacket)
- assert.Equal(t, dpkt.id(), uint32(i+1))
+ pkt := &sshFxpReadPacket{uint32(i), "a", uint64(i * 5), 5}
+ rpkt := request.call(handlers, pkt)
+ dpkt := rpkt.(*sshFxpDataPacket)
+ assert.Equal(t, dpkt.id(), uint32(i))
assert.Equal(t, string(dpkt.Data), txt)
}
}
@@ -132,22 +142,25 @@
func TestRequestPut(t *testing.T) {
handlers := newTestHandlers()
request := testRequest("Put")
- pkt := request.callHandler(handlers)
- statusOk(t, pkt)
- pkt = request.callHandler(handlers)
- statusOk(t, pkt)
+ pkt := &sshFxpWritePacket{0, "a", 0, 5, []byte("file-")}
+ rpkt := request.call(handlers, pkt)
+ statusOk(t, rpkt)
+ pkt = &sshFxpWritePacket{1, "a", 5, 5, []byte("data.")}
+ rpkt = request.call(handlers, pkt)
+ statusOk(t, rpkt)
assert.Equal(t, "file-data.", handlers.getOutString())
}
func TestRequestCmdr(t *testing.T) {
handlers := newTestHandlers()
request := testRequest("Mkdir")
- pkt := request.callHandler(handlers)
- statusOk(t, pkt)
+ pkt := fakePacket{myid: 1}
+ rpkt := request.call(handlers, pkt)
+ statusOk(t, rpkt)
handlers.returnError()
- pkt = request.callHandler(handlers)
- assert.Equal(t, pkt, statusFromError(pkt, errTest))
+ rpkt = request.call(handlers, pkt)
+ assert.Equal(t, rpkt, statusFromError(rpkt, errTest))
}
func TestRequestInfoList(t *testing.T) { testInfoMethod(t, "List") }
@@ -155,8 +168,9 @@
func TestRequestInfoStat(t *testing.T) {
handlers := newTestHandlers()
request := testRequest("Stat")
- pkt := request.callHandler(handlers)
- spkt, ok := pkt.(*sshFxpStatResponse)
+ pkt := fakePacket{myid: 1}
+ rpkt := request.call(handlers, pkt)
+ spkt, ok := rpkt.(*sshFxpStatResponse)
assert.True(t, ok)
assert.Equal(t, spkt.info.Name(), "request_test.go")
}
@@ -164,8 +178,9 @@
func testInfoMethod(t *testing.T, method string) {
handlers := newTestHandlers()
request := testRequest(method)
- pkt := request.callHandler(handlers)
- npkt, ok := pkt.(*sshFxpNamePacket)
+ pkt := fakePacket{myid: 1}
+ rpkt := request.call(handlers, pkt)
+ npkt, ok := rpkt.(*sshFxpNamePacket)
assert.True(t, ok)
assert.IsType(t, sshFxpNameAttr{}, npkt.NameAttrs[0])
assert.Equal(t, npkt.NameAttrs[0].Name, "request_test.go")