Merge pull request #283 from pkg/open-on-open
Open files/directories via the handlers as soon as Open/Opendir packets arrive, instead of lazily on the first read, write or readdir
diff --git a/request-example.go b/request-example.go
index 4ef2e67..e5abd18 100644
--- a/request-example.go
+++ b/request-example.go
@@ -12,6 +12,7 @@
"path/filepath"
"sort"
"sync"
+ "syscall"
"time"
)
@@ -29,6 +30,7 @@
if fs.mockErr != nil {
return nil, fs.mockErr
}
+ _ = r.WithContext(r.Context()) // initialize context for deadlock testing
fs.filesLock.Lock()
defer fs.filesLock.Unlock()
file, err := fs.fetch(r.Filepath)
@@ -48,6 +50,7 @@
if fs.mockErr != nil {
return nil, fs.mockErr
}
+ _ = r.WithContext(r.Context()) // initialize context for deadlock testing
fs.filesLock.Lock()
defer fs.filesLock.Unlock()
file, err := fs.fetch(r.Filepath)
@@ -69,6 +72,7 @@
if fs.mockErr != nil {
return fs.mockErr
}
+ _ = r.WithContext(r.Context()) // initialize context for deadlock testing
fs.filesLock.Lock()
defer fs.filesLock.Unlock()
switch r.Method {
@@ -129,11 +133,20 @@
if fs.mockErr != nil {
return nil, fs.mockErr
}
+ _ = r.WithContext(r.Context()) // initialize context for deadlock testing
fs.filesLock.Lock()
defer fs.filesLock.Unlock()
+ file, err := fs.fetch(r.Filepath)
+ if err != nil {
+ return nil, err
+ }
+
switch r.Method {
case "List":
+ if !file.IsDir() {
+ return nil, syscall.ENOTDIR
+ }
ordered_names := []string{}
for fn, _ := range fs.files {
if filepath.Dir(fn) == r.Filepath {
@@ -147,16 +160,8 @@
}
return listerat(list), nil
case "Stat":
- file, err := fs.fetch(r.Filepath)
- if err != nil {
- return nil, err
- }
return listerat([]os.FileInfo{file}), nil
case "Readlink":
- file, err := fs.fetch(r.Filepath)
- if err != nil {
- return nil, err
- }
if file.symlink != "" {
file, err = fs.fetch(file.symlink)
if err != nil {
diff --git a/request-server.go b/request-server.go
index 74ba46c..2e99720 100644
--- a/request-server.go
+++ b/request-server.go
@@ -3,7 +3,6 @@
import (
"context"
"io"
- "os"
"path"
"path/filepath"
"strconv"
@@ -56,37 +55,24 @@
defer rs.openRequestLock.Unlock()
rs.handleCount++
handle := strconv.Itoa(rs.handleCount)
+ r.handle = handle
rs.openRequests[handle] = r
return handle
}
-// Returns Request from openRequests, bool is false if it is missing
-// If the method is different, save/return a new Request w/ that Method.
+// Returns Request from openRequests, bool is false if it is missing.
//
// The Requests in openRequests work essentially as open file descriptors that
// you can do different things with. What you are doing with it are denoted by
-// the first packet of that type (read/write/etc). We create a new Request when
-// it changes to set the request.Method attribute in a thread safe way.
-func (rs *RequestServer) getRequest(handle, method string) (*Request, bool) {
+// the first packet of that type (read/write/etc).
+func (rs *RequestServer) getRequest(handle string) (*Request, bool) {
rs.openRequestLock.RLock()
+ defer rs.openRequestLock.RUnlock()
r, ok := rs.openRequests[handle]
- rs.openRequestLock.RUnlock()
- if !ok || r.Method == method {
- return r, ok
- }
- // if we make it here we need to replace the request
- rs.openRequestLock.Lock()
- defer rs.openRequestLock.Unlock()
- r, ok = rs.openRequests[handle]
- if !ok || r.Method == method { // re-check needed b/c lock race
- return r, ok
- }
- r = r.copy()
- r.Method = method
- rs.openRequests[handle] = r
return r, ok
}
+// Close the Request and clear from openRequests map
func (rs *RequestServer) closeRequest(handle string) error {
rs.openRequestLock.Lock()
defer rs.openRequestLock.Unlock()
@@ -173,28 +159,32 @@
rpkt = cleanPacketPath(pkt)
case *sshFxpOpendirPacket:
request := requestFromPacket(ctx, pkt)
- rpkt = request.call(rs.Handlers, pkt)
- if stat, ok := rpkt.(*sshFxpStatResponse); ok {
- if stat.info.IsDir() {
- handle := rs.nextRequest(request)
- rpkt = sshFxpHandlePacket{ID: pkt.id(), Handle: handle}
- } else {
- rpkt = statusFromError(pkt, &os.PathError{
- Path: request.Filepath, Err: syscall.ENOTDIR})
- }
- }
+ handle := rs.nextRequest(request)
+ rpkt = request.opendir(rs.Handlers, pkt)
+ if _, ok := rpkt.(*sshFxpHandlePacket); !ok {
+ // opendir failed: the client never receives this handle, so release it now
+ rs.closeRequest(handle)
+ }
case *sshFxpOpenPacket:
request := requestFromPacket(ctx, pkt)
- handle := rs.nextRequest(request)
- rpkt = sshFxpHandlePacket{ID: pkt.id(), Handle: handle}
- if pkt.hasPflags(ssh_FXF_CREAT) {
- if p := request.call(rs.Handlers, pkt); !statusOk(p) {
- rpkt = p // if error in write, return it
- }
+ handle := rs.nextRequest(request)
+ rpkt = request.open(rs.Handlers, pkt)
+ if _, ok := rpkt.(*sshFxpHandlePacket); !ok {
+ // open failed: the client never receives this handle, so release it now
+ rs.closeRequest(handle)
+ }
+ case *sshFxpFstatPacket:
+ handle := pkt.getHandle()
+ request, ok := rs.getRequest(handle)
+ if !ok {
+ rpkt = statusFromError(pkt, syscall.EBADF)
+ } else {
+ request = NewRequest("Stat", request.Filepath)
+ rpkt = request.call(rs.Handlers, pkt)
}
case hasHandle:
handle := pkt.getHandle()
- request, ok := rs.getRequest(handle, requestMethod(pkt))
+ request, ok := rs.getRequest(handle)
if !ok {
rpkt = statusFromError(pkt, syscall.EBADF)
} else {
@@ -214,12 +196,6 @@
return nil
}
-// True is responsePacket is an OK status packet
-func statusOk(rpkt responsePacket) bool {
- p, ok := rpkt.(sshFxpStatusPacket)
- return ok && p.StatusError.Code == ssh_FX_OK
-}
-
// clean and return name packet for file
func cleanPacketPath(pkt *sshFxpRealpathPacket) responsePacket {
path := cleanPath(pkt.getPath())
diff --git a/request-server_test.go b/request-server_test.go
index af68148..cb476b0 100644
--- a/request-server_test.go
+++ b/request-server_test.go
@@ -84,7 +84,7 @@
fh := p.svr.nextRequest(foo)
bh := p.svr.nextRequest(bar)
assert.Len(t, p.svr.openRequests, 2)
- _foo, ok := p.svr.getRequest(fh, "")
+ _foo, ok := p.svr.getRequest(fh)
assert.Equal(t, foo.Method, _foo.Method)
assert.Equal(t, foo.Filepath, _foo.Filepath)
assert.Equal(t, foo.Target, _foo.Target)
@@ -94,7 +94,7 @@
assert.NotNil(t, _foo.ctx)
assert.Equal(t, _foo.Context().Err(), nil, "context is still valid")
assert.True(t, ok)
- _, ok = p.svr.getRequest("zed", "")
+ _, ok = p.svr.getRequest("zed")
assert.False(t, ok)
p.svr.closeRequest(fh)
assert.Equal(t, _foo.Context().Err(), context.Canceled, "context is now canceled")
@@ -147,7 +147,7 @@
f, err := r.fetch("/foo")
if assert.Nil(t, err) {
assert.False(t, f.isdir)
- assert.Equal(t, f.content, []byte(""))
+ assert.Len(t, f.content, 0)
}
// lets test with an error
r.returnErr(os.ErrInvalid)
@@ -170,7 +170,7 @@
assert.Error(t, err)
}
-func TestRequestRead(t *testing.T) {
+func TestRequestJustRead(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
_, err := putTestFile(p.cli, "/foo", "hello")
@@ -187,21 +187,18 @@
assert.Equal(t, "hello", string(contents[0:5]))
}
-func TestRequestReadFail(t *testing.T) {
+func TestRequestOpenFail(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
rf, err := p.cli.Open("/foo")
- assert.Nil(t, err)
- contents := make([]byte, 5)
- n, err := rf.Read(contents)
- assert.Equal(t, n, 0)
assert.Exactly(t, os.ErrNotExist, err)
+ assert.Nil(t, rf)
}
-func TestRequestOpen(t *testing.T) {
+func TestRequestCreate(t *testing.T) {
p := clientRequestServerPair(t)
defer p.Close()
- fh, err := p.cli.Open("foo")
+ fh, err := p.cli.Create("foo")
assert.Nil(t, err)
err = fh.Close()
assert.Nil(t, err)
@@ -354,7 +351,9 @@
for i := 0; i < 100; i++ {
fname := fmt.Sprintf("/foo_%02d", i)
_, err := putTestFile(p.cli, fname, fname)
- assert.Nil(t, err)
+ if err != nil {
+ t.Fatal("expected no error, got:", err)
+ }
}
_, err := p.cli.ReadDir("/foo_01")
assert.Equal(t, &StatusError{Code: ssh_FX_FAILURE,
diff --git a/request.go b/request.go
index 8188c77..e694e5f 100644
--- a/request.go
+++ b/request.go
@@ -24,6 +24,7 @@
Flags uint32
Attrs []byte // convert to sub-struct
Target string // for renames and sym-links
+ handle string
// reader/writer/readdir from handlers
state state
// context lasts duration of request
@@ -155,7 +156,7 @@
switch r.Method {
case "Get":
return fileget(handlers.FileGet, r, pkt)
- case "Put", "Open":
+ case "Put":
return fileput(handlers.FilePut, r, pkt)
case "Setstat", "Rename", "Rmdir", "Mkdir", "Symlink", "Remove":
return filecmd(handlers.FileCmd, r, pkt)
@@ -169,37 +170,51 @@
}
}
-// file data for additional read/write packets
-func packetData(p requestPacket) (data []byte, offset int64, length uint32) {
- switch p := p.(type) {
- case *sshFxpReadPacket:
- length = p.Len
- offset = int64(p.Offset)
- case *sshFxpWritePacket:
- data = p.Data
- length = p.Length
- offset = int64(p.Offset)
+// open handles Open packets: it picks Put or Get from the pflags and obtains the
+// writer/reader from the matching handler now, so later packets on the handle reuse it.
+func (r *Request) open(h Handlers, pkt requestPacket) responsePacket {
+ flags := r.Pflags()
+ var err error
+ switch {
+ case flags.Write, flags.Append, flags.Creat, flags.Trunc:
+ r.Method = "Put"
+ r.state.writerAt, err = h.FilePut.Filewrite(r)
+ case flags.Read:
+ r.Method = "Get"
+ r.state.readerAt, err = h.FileGet.Fileread(r)
+ default:
+ return statusFromError(pkt, errors.New("bad file flags"))
}
- return
+ if err != nil {
+ return statusFromError(pkt, err)
+ }
+ return &sshFxpHandlePacket{ID: pkt.id(), Handle: r.handle}
+}
+
+// opendir handles Opendir packets: it sets the method to List and obtains the
+// lister from the FileList handler now, wrapping bare syscall.Errno values in an os.PathError.
+func (r *Request) opendir(h Handlers, pkt requestPacket) responsePacket {
+ var err error
+ r.Method = "List"
+ r.state.listerAt, err = h.FileList.Filelist(r)
+ if err != nil {
+ switch err.(type) {
+ case syscall.Errno:
+ err = &os.PathError{Path: r.Filepath, Err: err}
+ }
+ return statusFromError(pkt, err)
+ }
+ return &sshFxpHandlePacket{ID: pkt.id(), Handle: r.handle}
}
// wrap FileReader handler
func fileget(h FileReader, r *Request, pkt requestPacket) responsePacket {
- var err error
+ // readerAt is set only by open() when no write-type pflag is present; nil means the handle is not readable.
r.state.RLock()
reader := r.state.readerAt
r.state.RUnlock()
if reader == nil {
- r.state.Lock()
- if r.state.readerAt == nil {
- r.state.readerAt, err = h.Fileread(r)
- if err != nil {
- r.state.Unlock()
- return statusFromError(pkt, err)
- }
- }
- reader = r.state.readerAt
- r.state.Unlock()
+ return statusFromError(pkt, errors.New("unexpected read packet"))
}
_, offset, length := packetData(pkt)
@@ -218,28 +229,33 @@
// wrap FileWriter handler
func fileput(h FileWriter, r *Request, pkt requestPacket) responsePacket {
- var err error
+ // writerAt is set only by open() for write/append/creat/trunc pflags; nil means the handle is not writable.
r.state.RLock()
writer := r.state.writerAt
r.state.RUnlock()
if writer == nil {
- r.state.Lock()
- if r.state.writerAt == nil {
- r.state.writerAt, err = h.Filewrite(r)
- if err != nil {
- r.state.Unlock()
- return statusFromError(pkt, err)
- }
- }
- writer = r.state.writerAt
- r.state.Unlock()
+ return statusFromError(pkt, errors.New("unexpected write packet"))
}
data, offset, _ := packetData(pkt)
- _, err = writer.WriteAt(data, offset)
+ _, err := writer.WriteAt(data, offset)
return statusFromError(pkt, err)
}
+// file data for additional read/write packets
+func packetData(p requestPacket) (data []byte, offset int64, length uint32) {
+ switch p := p.(type) {
+ case *sshFxpReadPacket:
+ length = p.Len
+ offset = int64(p.Offset)
+ case *sshFxpWritePacket:
+ data = p.Data
+ length = p.Length
+ offset = int64(p.Offset)
+ }
+ return
+}
+
// wrap FileCmder handler
func filecmd(h FileCmder, r *Request, pkt requestPacket) responsePacket {
@@ -257,11 +273,7 @@
var err error
lister := r.getLister()
if lister == nil {
- lister, err = h.Filelist(r)
- if err != nil {
- return statusFromError(pkt, err)
- }
- r.setListerState(lister)
+ return statusFromError(pkt, errors.New("unexpected dir packet"))
}
offset := r.lsNext()
@@ -346,16 +358,10 @@
// init attributes of request object from packet data
func requestMethod(p requestPacket) (method string) {
switch p.(type) {
- case *sshFxpReadPacket:
- method = "Get"
- case *sshFxpWritePacket:
- method = "Put"
- case *sshFxpReaddirPacket:
- method = "List"
- case *sshFxpOpenPacket:
- method = "Open"
- case *sshFxpOpendirPacket:
- method = "Stat"
+ case *sshFxpReadPacket, *sshFxpWritePacket, *sshFxpOpenPacket:
+ // set in open() above
+ case *sshFxpOpendirPacket, *sshFxpReaddirPacket:
+ // set in opendir() above
case *sshFxpSetstatPacket, *sshFxpFsetstatPacket:
method = "Setstat"
case *sshFxpRenamePacket:
diff --git a/request_test.go b/request_test.go
index fd471d7..b3d7716 100644
--- a/request_test.go
+++ b/request_test.go
@@ -22,6 +22,7 @@
if t.err != nil {
return nil, t.err
}
+ _ = r.WithContext(r.Context()) // initialize context for deadlock testing
return bytes.NewReader(t.filecontents), nil
}
@@ -29,10 +30,12 @@
if t.err != nil {
return nil, t.err
}
+ _ = r.WithContext(r.Context()) // initialize context for deadlock testing
return io.WriterAt(t.output), nil
}
func (t *testHandler) Filecmd(r *Request) error {
+ _ = r.WithContext(r.Context()) // initialize context for deadlock testing
return t.err
}
@@ -40,6 +43,7 @@
if t.err != nil {
return nil, t.err
}
+ _ = r.WithContext(r.Context()) // initialize context for deadlock testing
f, err := os.Open(r.Filepath)
if err != nil {
return nil, err
@@ -56,11 +60,20 @@
var filecontents = []byte("file-data.")
+// XXX need a new helper for creating test requests that supports Open-ing
func testRequest(method string) *Request {
+ var flags uint32
+ switch method {
+ case "Get":
+ flags = flags | ssh_FXF_READ
+ case "Put":
+ flags = flags | ssh_FXF_WRITE
+ }
request := &Request{
Filepath: "./request_test.go",
Method: method,
Attrs: []byte("foo"),
+ Flags: flags,
Target: "foo",
state: state{RWMutex: new(sync.RWMutex)},
}
@@ -129,9 +142,12 @@
}
func (fakePacket) UnmarshalBinary(d []byte) error { return nil }
+// XXX can't just set method to Get, need to use Open to setup Get/Put
func TestRequestGet(t *testing.T) {
handlers := newTestHandlers()
request := testRequest("Get")
+ pkt := fakePacket{myid: 1}
+ request.open(handlers, pkt)
// req.length is 5, so we test reads in 5 byte chunks
for i, txt := range []string{"file-", "data."} {
pkt := &sshFxpReadPacket{ID: uint32(i), Handle: "a",
@@ -153,9 +169,11 @@
assert.Equal(t, rpkt, statusFromError(rpkt, cmdErr))
}
+// XXX can't just set method to Put; the writer must be set up first (normally done by open())
func TestRequestPut(t *testing.T) {
handlers := newTestHandlers()
request := testRequest("Put")
+ request.state.writerAt, _ = handlers.FilePut.Filewrite(request)
pkt := &sshFxpWritePacket{ID: 0, Handle: "a", Offset: 0, Length: 5,
Data: []byte("file-")}
rpkt := request.call(handlers, pkt)
@@ -189,28 +207,52 @@
assert.Equal(t, spkt.info.Name(), "request_test.go")
}
-func TestRequestInfoList(t *testing.T) { testInfoMethod(t, "List") }
-func TestRequestInfoReadlink(t *testing.T) { testInfoMethod(t, "Readlink") }
-func testInfoMethod(t *testing.T, method string) {
+func TestRequestInfoList(t *testing.T) {
handlers := newTestHandlers()
- request := testRequest(method)
+ request := testRequest("List")
+ request.handle = "1"
+ pkt := fakePacket{myid: 1}
+ rpkt := request.opendir(handlers, pkt)
+ hpkt, ok := rpkt.(*sshFxpHandlePacket)
+ if assert.True(t, ok) {
+ assert.Equal(t, hpkt.Handle, "1")
+ }
+ // the lister obtained by opendir must serve the following List call
+ pkt = fakePacket{myid: 2}
+ rpkt = request.call(handlers, pkt)
+ npkt, ok := rpkt.(*sshFxpNamePacket)
+ if assert.True(t, ok) {
+ assert.IsType(t, sshFxpNameAttr{}, npkt.NameAttrs[0])
+ assert.Equal(t, npkt.NameAttrs[0].Name, "request_test.go")
+ }
+}
+func TestRequestInfoReadlink(t *testing.T) {
+ handlers := newTestHandlers()
+ request := testRequest("Readlink")
pkt := fakePacket{myid: 1}
rpkt := request.call(handlers, pkt)
npkt, ok := rpkt.(*sshFxpNamePacket)
- assert.True(t, ok)
- assert.IsType(t, sshFxpNameAttr{}, npkt.NameAttrs[0])
- assert.Equal(t, npkt.NameAttrs[0].Name, "request_test.go")
+ if assert.True(t, ok) {
+ assert.IsType(t, sshFxpNameAttr{}, npkt.NameAttrs[0])
+ assert.Equal(t, npkt.NameAttrs[0].Name, "request_test.go")
+ }
}
func TestOpendirHandleReuse(t *testing.T) {
handlers := newTestHandlers()
request := testRequest("Stat")
+ request.handle = "1"
pkt := fakePacket{myid: 1}
rpkt := request.call(handlers, pkt)
assert.IsType(t, &sshFxpStatResponse{}, rpkt)
request.Method = "List"
pkt = fakePacket{myid: 2}
+ rpkt = request.opendir(handlers, pkt)
+ if assert.IsType(t, &sshFxpHandlePacket{}, rpkt) {
+ hpkt := rpkt.(*sshFxpHandlePacket)
+ assert.Equal(t, hpkt.Handle, "1")
+ }
rpkt = request.call(handlers, pkt)
assert.IsType(t, &sshFxpNamePacket{}, rpkt)
}