| package sftp |
| |
| import ( |
| "sync" |
| |
| "github.com/stretchr/testify/assert" |
| |
| "bytes" |
| "errors" |
| "io" |
| "os" |
| "testing" |
| ) |
| |
// testHandler is a canned handler used by these tests. Each of its handler
// methods returns err immediately when err is non-nil.
type testHandler struct {
	// filecontents is the payload handed out by Fileread.
	filecontents []byte
	// output is the destination handed out by Filewrite.
	output io.WriterAt
	// err, when set, is returned by every handler method in place of real
	// work; use it to simulate file-related failures.
	err error
}
| |
| func (t *testHandler) Fileread(r Request) (io.ReaderAt, error) { |
| if t.err != nil { |
| return nil, t.err |
| } |
| return bytes.NewReader(t.filecontents), nil |
| } |
| |
| func (t *testHandler) Filewrite(r Request) (io.WriterAt, error) { |
| if t.err != nil { |
| return nil, t.err |
| } |
| return io.WriterAt(t.output), nil |
| } |
| |
| func (t *testHandler) Filecmd(r Request) error { |
| if t.err != nil { |
| return t.err |
| } |
| return nil |
| } |
| |
| func (t *testHandler) Fileinfo(r Request) ([]os.FileInfo, error) { |
| if t.err != nil { |
| return nil, t.err |
| } |
| f, err := os.Open(r.Filepath) |
| if err != nil { |
| return nil, err |
| } |
| fi, err := f.Stat() |
| if err != nil { |
| return nil, err |
| } |
| return []os.FileInfo{fi}, nil |
| } |
| |
// fakefile is a fixed-size in-memory buffer; *fakefile implements
// io.WriterAt so it can stand in as the Put destination. Keep its length
// equal to len(filecontents) so a complete upload fills it exactly.
type fakefile [10]byte
| |
// filecontents is the canned payload: Fileread serves it and testRequest
// queues it as two packets.
var filecontents = []byte("file-data.")
| |
| func testRequest(method string) Request { |
| request := Request{ |
| Filepath: "./request_test.go", |
| Method: method, |
| Attrs: []byte("foo"), |
| Target: "foo", |
| packets: make(chan packet_data, sftpServerWorkerCount), |
| state: &state{}, |
| stateLock: &sync.RWMutex{}, |
| } |
| for _, p := range []packet_data{ |
| packet_data{id: 1, data: filecontents[:5], length: 5}, |
| packet_data{id: 2, data: filecontents[5:], length: 5, offset: 5}} { |
| request.packets <- p |
| } |
| return request |
| } |
| |
| func (ff *fakefile) WriteAt(p []byte, off int64) (int, error) { |
| n := copy(ff[off:], p) |
| return n, nil |
| } |
| |
| func (ff fakefile) string() string { |
| b := make([]byte, len(ff)) |
| copy(b, ff[:]) |
| return string(b) |
| } |
| |
| func newTestHandlers() Handlers { |
| handler := &testHandler{ |
| filecontents: filecontents, |
| output: &fakefile{}, |
| err: nil, |
| } |
| return Handlers{ |
| FileGet: handler, |
| FilePut: handler, |
| FileCmd: handler, |
| FileInfo: handler, |
| } |
| } |
| |
| func (h Handlers) getOutString() string { |
| handler := h.FilePut.(*testHandler) |
| return handler.output.(*fakefile).string() |
| } |
| |
// errTest is the sentinel error that returnError injects into the test
// handler.
var errTest = errors.New("test error")
| |
| func (h *Handlers) returnError() { |
| handler := h.FilePut.(*testHandler) |
| handler.err = errTest |
| } |
| |
| func statusOk(t *testing.T, p interface{}) { |
| if pkt, ok := p.(*sshFxpStatusPacket); ok { |
| assert.Equal(t, pkt.StatusError.Code, uint32(ssh_FX_OK)) |
| } |
| } |
| |
| func TestRequestGet(t *testing.T) { |
| handlers := newTestHandlers() |
| request := testRequest("Get") |
| // req.length is 5, so we test reads in 5 byte chunks |
| for i, txt := range []string{"file-", "data."} { |
| pkt, err := request.handle(handlers) |
| assert.Nil(t, err) |
| dpkt := pkt.(*sshFxpDataPacket) |
| assert.Equal(t, dpkt.id(), uint32(i+1)) |
| assert.Equal(t, string(dpkt.Data), txt) |
| } |
| } |
| |
| func TestRequestPut(t *testing.T) { |
| handlers := newTestHandlers() |
| request := testRequest("Put") |
| pkt, err := request.handle(handlers) |
| assert.Nil(t, err) |
| statusOk(t, pkt) |
| pkt, err = request.handle(handlers) |
| assert.Nil(t, err) |
| statusOk(t, pkt) |
| assert.Equal(t, "file-data.", handlers.getOutString()) |
| } |
| |
| func TestRequestCmdr(t *testing.T) { |
| handlers := newTestHandlers() |
| request := testRequest("Mkdir") |
| pkt, err := request.handle(handlers) |
| assert.Nil(t, err) |
| statusOk(t, pkt) |
| |
| handlers.returnError() |
| pkt, err = request.handle(handlers) |
| assert.Nil(t, pkt) |
| assert.Equal(t, err, errTest) |
| } |
| |
| func TestRequestInfoList(t *testing.T) { testInfoMethod(t, "List") } |
| func TestRequestInfoReadlink(t *testing.T) { testInfoMethod(t, "Readlink") } |
| func TestRequestInfoStat(t *testing.T) { |
| handlers := newTestHandlers() |
| request := testRequest("Stat") |
| pkt, err := request.handle(handlers) |
| assert.Nil(t, err) |
| spkt, ok := pkt.(*sshFxpStatResponse) |
| assert.True(t, ok) |
| assert.Equal(t, spkt.info.Name(), "request_test.go") |
| } |
| |
| func testInfoMethod(t *testing.T, method string) { |
| handlers := newTestHandlers() |
| request := testRequest(method) |
| pkt, err := request.handle(handlers) |
| assert.Nil(t, err) |
| npkt, ok := pkt.(*sshFxpNamePacket) |
| assert.True(t, ok) |
| assert.IsType(t, sshFxpNameAttr{}, npkt.NameAttrs[0]) |
| assert.Equal(t, npkt.NameAttrs[0].Name, "request_test.go") |
| } |