Merge pull request #282 from kardianos/kardianos-fixup-tests

sftp: update tests to finish and not panic on Windows
diff --git a/packet-manager.go b/packet-manager.go
index d47713b..2f3be10 100644
--- a/packet-manager.go
+++ b/packet-manager.go
@@ -176,8 +176,10 @@
 			s.sender.sendPacket(out.(encoding.BinaryMarshaler))
 			// pop off heads
 			copy(s.incoming, s.incoming[1:])            // shift left
+			s.incoming[len(s.incoming)-1] = nil         // clear last
 			s.incoming = s.incoming[:len(s.incoming)-1] // remove last
 			copy(s.outgoing, s.outgoing[1:])            // shift left
+			s.outgoing[len(s.outgoing)-1] = nil         // clear last
 			s.outgoing = s.outgoing[:len(s.outgoing)-1] // remove last
 		} else {
 			break
diff --git a/request-example.go b/request-example.go
index 4ef2e67..e5abd18 100644
--- a/request-example.go
+++ b/request-example.go
@@ -12,6 +12,7 @@
 	"path/filepath"
 	"sort"
 	"sync"
+	"syscall"
 	"time"
 )
 
@@ -29,6 +30,7 @@
 	if fs.mockErr != nil {
 		return nil, fs.mockErr
 	}
+	_ = r.WithContext(r.Context()) // initialize context for deadlock testing
 	fs.filesLock.Lock()
 	defer fs.filesLock.Unlock()
 	file, err := fs.fetch(r.Filepath)
@@ -48,6 +50,7 @@
 	if fs.mockErr != nil {
 		return nil, fs.mockErr
 	}
+	_ = r.WithContext(r.Context()) // initialize context for deadlock testing
 	fs.filesLock.Lock()
 	defer fs.filesLock.Unlock()
 	file, err := fs.fetch(r.Filepath)
@@ -69,6 +72,7 @@
 	if fs.mockErr != nil {
 		return fs.mockErr
 	}
+	_ = r.WithContext(r.Context()) // initialize context for deadlock testing
 	fs.filesLock.Lock()
 	defer fs.filesLock.Unlock()
 	switch r.Method {
@@ -129,11 +133,20 @@
 	if fs.mockErr != nil {
 		return nil, fs.mockErr
 	}
+	_ = r.WithContext(r.Context()) // initialize context for deadlock testing
 	fs.filesLock.Lock()
 	defer fs.filesLock.Unlock()
 
+	file, err := fs.fetch(r.Filepath)
+	if err != nil {
+		return nil, err
+	}
+
 	switch r.Method {
 	case "List":
+		if !file.IsDir() {
+			return nil, syscall.ENOTDIR
+		}
 		ordered_names := []string{}
 		for fn, _ := range fs.files {
 			if filepath.Dir(fn) == r.Filepath {
@@ -147,16 +160,8 @@
 		}
 		return listerat(list), nil
 	case "Stat":
-		file, err := fs.fetch(r.Filepath)
-		if err != nil {
-			return nil, err
-		}
 		return listerat([]os.FileInfo{file}), nil
 	case "Readlink":
-		file, err := fs.fetch(r.Filepath)
-		if err != nil {
-			return nil, err
-		}
 		if file.symlink != "" {
 			file, err = fs.fetch(file.symlink)
 			if err != nil {
diff --git a/request-server.go b/request-server.go
index 74ba46c..2e99720 100644
--- a/request-server.go
+++ b/request-server.go
@@ -3,7 +3,6 @@
 import (
 	"context"
 	"io"
-	"os"
 	"path"
 	"path/filepath"
 	"strconv"
@@ -56,37 +55,24 @@
 	defer rs.openRequestLock.Unlock()
 	rs.handleCount++
 	handle := strconv.Itoa(rs.handleCount)
+	r.handle = handle
 	rs.openRequests[handle] = r
 	return handle
 }
 
-// Returns Request from openRequests, bool is false if it is missing
-// If the method is different, save/return a new Request w/ that Method.
+// getRequest returns the Request stored in openRequests under handle; the bool is false if it is missing.
 //
 // The Requests in openRequests work essentially as open file descriptors that
 // you can do different things with. What you are doing with it are denoted by
-// the first packet of that type (read/write/etc). We create a new Request when
-// it changes to set the request.Method attribute in a thread safe way.
-func (rs *RequestServer) getRequest(handle, method string) (*Request, bool) {
+// the first packet of that type (read/write/etc).
+func (rs *RequestServer) getRequest(handle string) (*Request, bool) {
 	rs.openRequestLock.RLock()
+	defer rs.openRequestLock.RUnlock()
 	r, ok := rs.openRequests[handle]
-	rs.openRequestLock.RUnlock()
-	if !ok || r.Method == method {
-		return r, ok
-	}
-	// if we make it here we need to replace the request
-	rs.openRequestLock.Lock()
-	defer rs.openRequestLock.Unlock()
-	r, ok = rs.openRequests[handle]
-	if !ok || r.Method == method { // re-check needed b/c lock race
-		return r, ok
-	}
-	r = r.copy()
-	r.Method = method
-	rs.openRequests[handle] = r
 	return r, ok
 }
 
+// closeRequest closes the Request for handle and clears it from the openRequests map.
 func (rs *RequestServer) closeRequest(handle string) error {
 	rs.openRequestLock.Lock()
 	defer rs.openRequestLock.Unlock()
@@ -173,28 +159,24 @@
 			rpkt = cleanPacketPath(pkt)
 		case *sshFxpOpendirPacket:
 			request := requestFromPacket(ctx, pkt)
-			rpkt = request.call(rs.Handlers, pkt)
-			if stat, ok := rpkt.(*sshFxpStatResponse); ok {
-				if stat.info.IsDir() {
-					handle := rs.nextRequest(request)
-					rpkt = sshFxpHandlePacket{ID: pkt.id(), Handle: handle}
-				} else {
-					rpkt = statusFromError(pkt, &os.PathError{
-						Path: request.Filepath, Err: syscall.ENOTDIR})
-				}
-			}
+			rs.nextRequest(request)
+			rpkt = request.opendir(rs.Handlers, pkt)
 		case *sshFxpOpenPacket:
 			request := requestFromPacket(ctx, pkt)
-			handle := rs.nextRequest(request)
-			rpkt = sshFxpHandlePacket{ID: pkt.id(), Handle: handle}
-			if pkt.hasPflags(ssh_FXF_CREAT) {
-				if p := request.call(rs.Handlers, pkt); !statusOk(p) {
-					rpkt = p // if error in write, return it
-				}
+			rs.nextRequest(request)
+			rpkt = request.open(rs.Handlers, pkt)
+		case *sshFxpFstatPacket:
+			handle := pkt.getHandle()
+			request, ok := rs.getRequest(handle)
+			if !ok {
+				rpkt = statusFromError(pkt, syscall.EBADF)
+			} else {
+				request = NewRequest("Stat", request.Filepath)
+				rpkt = request.call(rs.Handlers, pkt)
 			}
 		case hasHandle:
 			handle := pkt.getHandle()
-			request, ok := rs.getRequest(handle, requestMethod(pkt))
+			request, ok := rs.getRequest(handle)
 			if !ok {
 				rpkt = statusFromError(pkt, syscall.EBADF)
 			} else {
@@ -214,12 +196,6 @@
 	return nil
 }
 
-// True is responsePacket is an OK status packet
-func statusOk(rpkt responsePacket) bool {
-	p, ok := rpkt.(sshFxpStatusPacket)
-	return ok && p.StatusError.Code == ssh_FX_OK
-}
-
 // clean and return name packet for file
 func cleanPacketPath(pkt *sshFxpRealpathPacket) responsePacket {
 	path := cleanPath(pkt.getPath())
diff --git a/request-server_test.go b/request-server_test.go
index 2f11789..9507692 100644
--- a/request-server_test.go
+++ b/request-server_test.go
@@ -85,7 +85,7 @@
 	fh := p.svr.nextRequest(foo)
 	bh := p.svr.nextRequest(bar)
 	assert.Len(t, p.svr.openRequests, 2)
-	_foo, ok := p.svr.getRequest(fh, "")
+	_foo, ok := p.svr.getRequest(fh)
 	assert.Equal(t, foo.Method, _foo.Method)
 	assert.Equal(t, foo.Filepath, _foo.Filepath)
 	assert.Equal(t, foo.Target, _foo.Target)
@@ -95,7 +95,7 @@
 	assert.NotNil(t, _foo.ctx)
 	assert.Equal(t, _foo.Context().Err(), nil, "context is still valid")
 	assert.True(t, ok)
-	_, ok = p.svr.getRequest("zed", "")
+	_, ok = p.svr.getRequest("zed")
 	assert.False(t, ok)
 	p.svr.closeRequest(fh)
 	assert.Equal(t, _foo.Context().Err(), context.Canceled, "context is now canceled")
@@ -148,7 +148,7 @@
 	f, err := r.fetch("/foo")
 	if assert.Nil(t, err) {
 		assert.False(t, f.isdir)
-		assert.Equal(t, f.content, []byte(""))
+		assert.Len(t, f.content, 0)
 	}
 	// lets test with an error
 	r.returnErr(os.ErrInvalid)
@@ -171,7 +171,7 @@
 	assert.Error(t, err)
 }
 
-func TestRequestRead(t *testing.T) {
+func TestRequestJustRead(t *testing.T) {
 	p := clientRequestServerPair(t)
 	defer p.Close()
 	_, err := putTestFile(p.cli, "/foo", "hello")
@@ -188,21 +188,18 @@
 	assert.Equal(t, "hello", string(contents[0:5]))
 }
 
-func TestRequestReadFail(t *testing.T) {
+func TestRequestOpenFail(t *testing.T) {
 	p := clientRequestServerPair(t)
 	defer p.Close()
 	rf, err := p.cli.Open("/foo")
-	assert.Nil(t, err)
-	contents := make([]byte, 5)
-	n, err := rf.Read(contents)
-	assert.Equal(t, n, 0)
 	assert.Exactly(t, os.ErrNotExist, err)
+	assert.Nil(t, rf)
 }
 
-func TestRequestOpen(t *testing.T) {
+func TestRequestCreate(t *testing.T) {
 	p := clientRequestServerPair(t)
 	defer p.Close()
-	fh, err := p.cli.Open("foo")
+	fh, err := p.cli.Create("foo")
 	assert.Nil(t, err)
 	err = fh.Close()
 	assert.Nil(t, err)
@@ -355,7 +352,9 @@
 	for i := 0; i < 100; i++ {
 		fname := fmt.Sprintf("/foo_%02d", i)
 		_, err := putTestFile(p.cli, fname, fname)
-		assert.Nil(t, err)
+		if err != nil {
+			t.Fatal("expected no error, got:", err)
+		}
 	}
 	_, err := p.cli.ReadDir("/foo_01")
 	assert.Equal(t, &StatusError{Code: ssh_FX_FAILURE,
diff --git a/request.go b/request.go
index 8188c77..e694e5f 100644
--- a/request.go
+++ b/request.go
@@ -24,6 +24,7 @@
 	Flags    uint32
 	Attrs    []byte // convert to sub-struct
 	Target   string // for renames and sym-links
+	handle   string
 	// reader/writer/readdir from handlers
 	state state
 	// context lasts duration of request
@@ -155,7 +156,7 @@
 	switch r.Method {
 	case "Get":
 		return fileget(handlers.FileGet, r, pkt)
-	case "Put", "Open":
+	case "Put":
 		return fileput(handlers.FilePut, r, pkt)
 	case "Setstat", "Rename", "Rmdir", "Mkdir", "Symlink", "Remove":
 		return filecmd(handlers.FileCmd, r, pkt)
@@ -169,37 +170,47 @@
 	}
 }
 
-// file data for additional read/write packets
-func packetData(p requestPacket) (data []byte, offset int64, length uint32) {
-	switch p := p.(type) {
-	case *sshFxpReadPacket:
-		length = p.Len
-		offset = int64(p.Offset)
-	case *sshFxpWritePacket:
-		data = p.Data
-		length = p.Length
-		offset = int64(p.Offset)
+// open sets Method to "Put" or "Get" from the open flags, stores the handler's writer/reader, and replies with the handle or an error status.
+func (r *Request) open(h Handlers, pkt requestPacket) responsePacket {
+	flags := r.Pflags()
+	var err error
+	switch {
+	case flags.Write, flags.Append, flags.Creat, flags.Trunc:
+		r.Method = "Put"
+		r.state.writerAt, err = h.FilePut.Filewrite(r)
+	case flags.Read:
+		r.Method = "Get"
+		r.state.readerAt, err = h.FileGet.Fileread(r)
+	default:
+		return statusFromError(pkt, errors.New("bad file flags"))
 	}
-	return
+	if err != nil {
+		return statusFromError(pkt, err)
+	}
+	return &sshFxpHandlePacket{ID: pkt.id(), Handle: r.handle}
+}
+func (r *Request) opendir(h Handlers, pkt requestPacket) responsePacket {
+	var err error
+	r.Method = "List"
+	r.state.listerAt, err = h.FileList.Filelist(r)
+	if err != nil {
+		switch err.(type) {
+		case syscall.Errno:
+			err = &os.PathError{Path: r.Filepath, Err: err}
+		}
+		return statusFromError(pkt, err)
+	}
+	return &sshFxpHandlePacket{ID: pkt.id(), Handle: r.handle}
 }
 
 // wrap FileReader handler
 func fileget(h FileReader, r *Request, pkt requestPacket) responsePacket {
-	var err error
+	//fmt.Println("fileget", r)
 	r.state.RLock()
 	reader := r.state.readerAt
 	r.state.RUnlock()
 	if reader == nil {
-		r.state.Lock()
-		if r.state.readerAt == nil {
-			r.state.readerAt, err = h.Fileread(r)
-			if err != nil {
-				r.state.Unlock()
-				return statusFromError(pkt, err)
-			}
-		}
-		reader = r.state.readerAt
-		r.state.Unlock()
+		return statusFromError(pkt, errors.New("unexpected read packet"))
 	}
 
 	_, offset, length := packetData(pkt)
@@ -218,28 +229,33 @@
 
 // wrap FileWriter handler
 func fileput(h FileWriter, r *Request, pkt requestPacket) responsePacket {
-	var err error
+	//fmt.Println("fileput", r)
 	r.state.RLock()
 	writer := r.state.writerAt
 	r.state.RUnlock()
 	if writer == nil {
-		r.state.Lock()
-		if r.state.writerAt == nil {
-			r.state.writerAt, err = h.Filewrite(r)
-			if err != nil {
-				r.state.Unlock()
-				return statusFromError(pkt, err)
-			}
-		}
-		writer = r.state.writerAt
-		r.state.Unlock()
+		return statusFromError(pkt, errors.New("unexpected write packet"))
 	}
 
 	data, offset, _ := packetData(pkt)
-	_, err = writer.WriteAt(data, offset)
+	_, err := writer.WriteAt(data, offset)
 	return statusFromError(pkt, err)
 }
 
+// packetData extracts offset and length from a read packet, plus data from a write packet; other packets yield zero values.
+func packetData(p requestPacket) (data []byte, offset int64, length uint32) {
+	switch p := p.(type) {
+	case *sshFxpReadPacket:
+		length = p.Len
+		offset = int64(p.Offset)
+	case *sshFxpWritePacket:
+		data = p.Data
+		length = p.Length
+		offset = int64(p.Offset)
+	}
+	return
+}
+
 // wrap FileCmder handler
 func filecmd(h FileCmder, r *Request, pkt requestPacket) responsePacket {
 
@@ -257,11 +273,7 @@
 	var err error
 	lister := r.getLister()
 	if lister == nil {
-		lister, err = h.Filelist(r)
-		if err != nil {
-			return statusFromError(pkt, err)
-		}
-		r.setListerState(lister)
+		return statusFromError(pkt, errors.New("unexpected dir packet"))
 	}
 
 	offset := r.lsNext()
@@ -346,16 +358,10 @@
 // init attributes of request object from packet data
 func requestMethod(p requestPacket) (method string) {
 	switch p.(type) {
-	case *sshFxpReadPacket:
-		method = "Get"
-	case *sshFxpWritePacket:
-		method = "Put"
-	case *sshFxpReaddirPacket:
-		method = "List"
-	case *sshFxpOpenPacket:
-		method = "Open"
-	case *sshFxpOpendirPacket:
-		method = "Stat"
+	case *sshFxpReadPacket, *sshFxpWritePacket, *sshFxpOpenPacket:
+		// Method is set by Request.open, not here
+	case *sshFxpOpendirPacket, *sshFxpReaddirPacket:
+		// Method is set by Request.opendir, not here
 	case *sshFxpSetstatPacket, *sshFxpFsetstatPacket:
 		method = "Setstat"
 	case *sshFxpRenamePacket:
diff --git a/request_test.go b/request_test.go
index fd471d7..b3d7716 100644
--- a/request_test.go
+++ b/request_test.go
@@ -22,6 +22,7 @@
 	if t.err != nil {
 		return nil, t.err
 	}
+	_ = r.WithContext(r.Context()) // initialize context for deadlock testing
 	return bytes.NewReader(t.filecontents), nil
 }
 
@@ -29,10 +30,12 @@
 	if t.err != nil {
 		return nil, t.err
 	}
+	_ = r.WithContext(r.Context()) // initialize context for deadlock testing
 	return io.WriterAt(t.output), nil
 }
 
 func (t *testHandler) Filecmd(r *Request) error {
+	_ = r.WithContext(r.Context()) // initialize context for deadlock testing
 	return t.err
 }
 
@@ -40,6 +43,7 @@
 	if t.err != nil {
 		return nil, t.err
 	}
+	_ = r.WithContext(r.Context()) // initialize context for deadlock testing
 	f, err := os.Open(r.Filepath)
 	if err != nil {
 		return nil, err
@@ -56,11 +60,20 @@
 
 var filecontents = []byte("file-data.")
 
+// XXX need a new helper for creating test requests that supports Open-ing
 func testRequest(method string) *Request {
+	var flags uint32
+	switch method {
+	case "Get":
+		flags = flags | ssh_FXF_READ
+	case "Put":
+		flags = flags | ssh_FXF_WRITE
+	}
 	request := &Request{
 		Filepath: "./request_test.go",
 		Method:   method,
 		Attrs:    []byte("foo"),
+		Flags:    flags,
 		Target:   "foo",
 		state:    state{RWMutex: new(sync.RWMutex)},
 	}
@@ -129,9 +142,12 @@
 }
 func (fakePacket) UnmarshalBinary(d []byte) error { return nil }
 
+// XXX can't just set method to Get, need to call open first so the reader is set up
 func TestRequestGet(t *testing.T) {
 	handlers := newTestHandlers()
 	request := testRequest("Get")
+	pkt := fakePacket{myid: 1}
+	request.open(handlers, pkt)
 	// req.length is 5, so we test reads in 5 byte chunks
 	for i, txt := range []string{"file-", "data."} {
 		pkt := &sshFxpReadPacket{ID: uint32(i), Handle: "a",
@@ -153,9 +169,11 @@
 	assert.Equal(t, rpkt, statusFromError(rpkt, cmdErr))
 }
 
+// XXX can't just set method to Put, need to set up the writer first (normally done by open)
 func TestRequestPut(t *testing.T) {
 	handlers := newTestHandlers()
 	request := testRequest("Put")
+	request.state.writerAt, _ = handlers.FilePut.Filewrite(request)
 	pkt := &sshFxpWritePacket{ID: 0, Handle: "a", Offset: 0, Length: 5,
 		Data: []byte("file-")}
 	rpkt := request.call(handlers, pkt)
@@ -189,28 +207,46 @@
 	assert.Equal(t, spkt.info.Name(), "request_test.go")
 }
 
-func TestRequestInfoList(t *testing.T)     { testInfoMethod(t, "List") }
-func TestRequestInfoReadlink(t *testing.T) { testInfoMethod(t, "Readlink") }
-func testInfoMethod(t *testing.T, method string) {
+func TestRequestInfoList(t *testing.T) {
 	handlers := newTestHandlers()
-	request := testRequest(method)
+	request := testRequest("List")
+	request.handle = "1"
+	pkt := fakePacket{myid: 1}
+	rpkt := request.opendir(handlers, pkt)
+	hpkt, ok := rpkt.(*sshFxpHandlePacket)
+	if assert.True(t, ok) {
+		assert.Equal(t, hpkt.Handle, "1")
+	}
+	pkt = fakePacket{myid: 2}
+	request.call(handlers, pkt)
+}
+func TestRequestInfoReadlink(t *testing.T) {
+	handlers := newTestHandlers()
+	request := testRequest("Readlink")
 	pkt := fakePacket{myid: 1}
 	rpkt := request.call(handlers, pkt)
 	npkt, ok := rpkt.(*sshFxpNamePacket)
-	assert.True(t, ok)
-	assert.IsType(t, sshFxpNameAttr{}, npkt.NameAttrs[0])
-	assert.Equal(t, npkt.NameAttrs[0].Name, "request_test.go")
+	if assert.True(t, ok) {
+		assert.IsType(t, sshFxpNameAttr{}, npkt.NameAttrs[0])
+		assert.Equal(t, npkt.NameAttrs[0].Name, "request_test.go")
+	}
 }
 
 func TestOpendirHandleReuse(t *testing.T) {
 	handlers := newTestHandlers()
 	request := testRequest("Stat")
+	request.handle = "1"
 	pkt := fakePacket{myid: 1}
 	rpkt := request.call(handlers, pkt)
 	assert.IsType(t, &sshFxpStatResponse{}, rpkt)
 
 	request.Method = "List"
 	pkt = fakePacket{myid: 2}
+	rpkt = request.opendir(handlers, pkt)
+	if assert.IsType(t, &sshFxpHandlePacket{}, rpkt) {
+		hpkt := rpkt.(*sshFxpHandlePacket)
+		assert.Equal(t, hpkt.Handle, "1")
+	}
 	rpkt = request.call(handlers, pkt)
 	assert.IsType(t, &sshFxpNamePacket{}, rpkt)
 }