Merge pull request #439 from pkg/feature/ReadFromWithConcurrency
Export a ReadFromWithConcurrency function that permits ensuring concurrency usage.
diff --git a/client.go b/client.go
index 377c367..aa17a39 100644
--- a/client.go
+++ b/client.go
@@ -387,27 +387,11 @@
// Stat returns a FileInfo structure describing the file specified by path 'p'.
// If 'p' is a symbolic link, the returned FileInfo structure describes the referent file.
func (c *Client) Stat(p string) (os.FileInfo, error) {
- id := c.nextID()
- typ, data, err := c.sendPacket(nil, &sshFxpStatPacket{
- ID: id,
- Path: p,
- })
+ fs, err := c.stat(p)
if err != nil {
return nil, err
}
- switch typ {
- case sshFxpAttrs:
- sid, data := unmarshalUint32(data)
- if sid != id {
- return nil, &unexpectedIDErr{id, sid}
- }
- attr, _ := unmarshalAttrs(data)
- return fileInfoFromStat(attr, path.Base(p)), nil
- case sshFxpStatus:
- return nil, normaliseError(unmarshalStatus(id, data))
- default:
- return nil, unimplementedPacketErr(typ)
- }
+ return fileInfoFromStat(fs, path.Base(p)), nil
}
// Lstat returns a FileInfo structure describing the file specified by path 'p'.
@@ -638,6 +622,30 @@
}
}
+// stat sends an SSH_FXP_STAT request for p and returns the decoded attributes.
+// Stat is a thin wrapper over this; fstat is the handle-based counterpart.
+// The parameter is named p (as in Stat) so it does not shadow the imported path package.
+func (c *Client) stat(p string) (*FileStat, error) {
+ id := c.nextID()
+ typ, data, err := c.sendPacket(nil, &sshFxpStatPacket{
+ ID: id,
+ Path: p,
+ })
+ if err != nil {
+ return nil, err
+ }
+ switch typ {
+ case sshFxpAttrs:
+ sid, data := unmarshalUint32(data)
+ // Reject a reply that does not answer this request.
+ if sid != id {
+ return nil, &unexpectedIDErr{id, sid}
+ }
+ attr, _ := unmarshalAttrs(data)
+ return attr, nil
+ case sshFxpStatus:
+ return nil, normaliseError(unmarshalStatus(id, data))
+ default:
+ return nil, unimplementedPacketErr(typ)
+ }
+}
+
func (c *Client) fstat(handle string) (*FileStat, error) {
id := c.nextID()
typ, data, err := c.sendPacket(nil, &sshFxpFstatPacket{
@@ -1160,23 +1168,19 @@
}
// For concurrency, we want to guess how many concurrent workers we should use.
- var fileSize uint64
+ var fileStat *FileStat
if f.c.useFstat {
- fileStat, err := f.c.fstat(f.handle)
- if err != nil {
- return 0, err
- }
- fileSize = fileStat.Size
+ fileStat, err = f.c.fstat(f.handle)
} else {
- fi, err := f.c.Stat(f.path)
- if err != nil {
- return 0, err
- }
- fileSize = uint64(fi.Size())
+ fileStat, err = f.c.stat(f.path)
+ }
+ if err != nil {
+ return 0, err
}
- if fileSize <= uint64(f.c.maxPacket) {
- // We should be able to handle this in one Read.
+ fileSize := fileStat.Size
+ if fileSize <= uint64(f.c.maxPacket) || !isRegular(fileStat.Mode) {
+ // only regular files are guaranteed to return (full read) xor (partial read, next error)
return f.writeToSequential(w)
}
@@ -1187,6 +1191,10 @@
// Now that concurrency64 is saturated to an int value, we know this assignment cannot possibly overflow.
concurrency := int(concurrency64)
+ chunkSize := f.c.maxPacket
+ pool := newBufPool(concurrency, chunkSize)
+ resPool := newResChanPool(concurrency)
+
cancel := make(chan struct{})
var wg sync.WaitGroup
defer func() {
@@ -1200,7 +1208,6 @@
type writeWork struct {
b []byte
- n int
off int64
err error
@@ -1209,7 +1216,10 @@
writeCh := make(chan writeWork)
type readWork struct {
- off int64
+ id uint32
+ res chan result
+ off int64
+
cur, next chan writeWork
}
readCh := make(chan readWork)
@@ -1219,49 +1229,78 @@
defer close(readCh)
off := f.offset
- chunkSize := int64(f.c.maxPacket)
cur := writeCh
for {
+ id := f.c.nextID()
+ res := resPool.Get()
+
next := make(chan writeWork)
readWork := readWork{
- off: off,
+ id: id,
+ res: res,
+ off: off,
+
cur: cur,
next: next,
}
+ f.c.dispatchRequest(res, &sshFxpReadPacket{
+ ID: id,
+ Handle: f.handle,
+ Offset: uint64(off),
+ Len: uint32(chunkSize),
+ })
+
select {
case readCh <- readWork:
case <-cancel:
return
}
- off += chunkSize
+ off += int64(chunkSize)
cur = next
}
}()
- pool := newBufPool(concurrency, f.c.maxPacket)
-
wg.Add(concurrency)
for i := 0; i < concurrency; i++ {
// Map_i: each worker gets readWork, and does the Read into a buffer at the given offset.
go func() {
defer wg.Done()
- ch := make(chan result, 1) // reusable channel
-
for readWork := range readCh {
- b := pool.Get()
+ var b []byte
+ var n int
- n, err := f.readChunkAt(ch, b, readWork.off)
- if n < 0 {
- panic("sftp.File: returned negative count from readChunkAt")
+ s := <-readWork.res
+ resPool.Put(readWork.res)
+
+ err := s.err
+ if err == nil {
+ switch s.typ {
+ case sshFxpStatus:
+ err = normaliseError(unmarshalStatus(readWork.id, s.data))
+
+ case sshFxpData:
+ sid, data := unmarshalUint32(s.data)
+ if readWork.id != sid {
+ err = &unexpectedIDErr{readWork.id, sid}
+
+ } else {
+ l, data := unmarshalUint32(data)
+ b = pool.Get()[:l]
+ n = copy(b, data[:l])
+ b = b[:n]
+ }
+
+ default:
+ err = unimplementedPacketErr(s.typ)
+ }
}
writeWork := writeWork{
b: b,
- n: n,
off: readWork.off,
err: err,
@@ -1290,10 +1329,10 @@
}
// Because writes are serialized, this will always be the last successfully read byte.
- f.offset = packet.off + int64(packet.n)
+ f.offset = packet.off + int64(len(packet.b))
- if packet.n > 0 {
- n, err := w.Write(packet.b[:packet.n])
+ if len(packet.b) > 0 {
+ n, err := w.Write(packet.b)
written += int64(n)
if err != nil {
return written, err
diff --git a/internal/encoding/ssh/filexfer/attrs.go b/internal/encoding/ssh/filexfer/attrs.go
index 5ce942e..1d4bb79 100644
--- a/internal/encoding/ssh/filexfer/attrs.go
+++ b/internal/encoding/ssh/filexfer/attrs.go
@@ -34,6 +34,57 @@
ExtendedAttributes []ExtendedAttribute
}
+// GetSize returns the Size field and a bool that is true if and only if the value is valid/defined.
+func (a *Attributes) GetSize() (size uint64, ok bool) {
+ return a.Size, a.Flags&AttrSize != 0
+}
+
+// SetSize is a convenience function that sets the Size field,
+// and marks the field as valid/defined in Flags.
+func (a *Attributes) SetSize(size uint64) {
+ a.Flags |= AttrSize
+ a.Size = size
+}
+
+// GetUIDGID returns the UID and GID fields and a bool that is true if and only if the values are valid/defined.
+func (a *Attributes) GetUIDGID() (uid, gid uint32, ok bool) {
+ return a.UID, a.GID, a.Flags&AttrUIDGID != 0
+}
+
+// SetUIDGID is a convenience function that sets the UID and GID fields,
+// and marks the fields as valid/defined in Flags.
+func (a *Attributes) SetUIDGID(uid, gid uint32) {
+ a.Flags |= AttrUIDGID
+ a.UID = uid
+ a.GID = gid
+}
+
+// GetPermissions returns the Permissions field and a bool that is true if and only if the value is valid/defined.
+func (a *Attributes) GetPermissions() (perms FileMode, ok bool) {
+ return a.Permissions, a.Flags&AttrPermissions != 0
+}
+
+// SetPermissions is a convenience function that sets the Permissions field,
+// and marks the field as valid/defined in Flags.
+func (a *Attributes) SetPermissions(perms FileMode) {
+ a.Flags |= AttrPermissions
+ a.Permissions = perms
+}
+
+// GetACModTime returns the ATime and MTime fields and a bool that is true if and only if the values are valid/defined.
+// Both times share the single AttrACModTime flag, so ok covers ATime and MTime together.
+func (a *Attributes) GetACModTime() (atime, mtime uint32, ok bool) {
+ return a.ATime, a.MTime, a.Flags&AttrACModTime != 0
+}
+
+// SetACModTime is a convenience function that sets the ATime and MTime fields,
+// and marks the fields as valid/defined in Flags.
+func (a *Attributes) SetACModTime(atime, mtime uint32) {
+ a.Flags |= AttrACModTime
+ a.ATime = atime
+ a.MTime = mtime
+}
+
// Len returns the number of bytes a would marshal into.
func (a *Attributes) Len() int {
length := 4
@@ -98,7 +149,7 @@
// MarshalBinary returns a as the binary encoding of a.
func (a *Attributes) MarshalBinary() ([]byte, error) {
- buf := NewBuffer(make([]byte, a.Len()))
+ buf := NewBuffer(make([]byte, 0, a.Len()))
a.MarshalInto(buf)
return buf.Bytes(), nil
}
@@ -107,10 +158,20 @@
//
// NOTE: The values of fields not covered in the a.Flags are explicitly undefined.
func (a *Attributes) UnmarshalFrom(b *Buffer) (err error) {
- if a.Flags, err = b.ConsumeUint32(); err != nil {
+ flags, err := b.ConsumeUint32()
+ if err != nil {
return err
}
+ return a.XXX_UnmarshalByFlags(flags, b)
+}
+
+// XXX_UnmarshalByFlags stores the given flags into a.Flags, and uses them to determine which fields to decode from b.
+// DO NOT USE THIS: it is an anti-corruption function to implement existing internal usage in pkg/sftp.
+// This function is not a part of any compatibility promise.
+func (a *Attributes) XXX_UnmarshalByFlags(flags uint32, b *Buffer) (err error) {
+ a.Flags = flags
+
// Short-circuit dummy attributes.
if a.Flags == 0 {
return nil
@@ -192,7 +253,7 @@
// MarshalBinary returns e as the binary encoding of e.
func (e *ExtendedAttribute) MarshalBinary() ([]byte, error) {
- buf := NewBuffer(make([]byte, e.Len()))
+ buf := NewBuffer(make([]byte, 0, e.Len()))
e.MarshalInto(buf)
return buf.Bytes(), nil
}
@@ -239,7 +300,7 @@
// MarshalBinary returns e as the binary encoding of e.
func (e *NameEntry) MarshalBinary() ([]byte, error) {
- buf := NewBuffer(make([]byte, e.Len()))
+ buf := NewBuffer(make([]byte, 0, e.Len()))
e.MarshalInto(buf)
return buf.Bytes(), nil
}
diff --git a/internal/encoding/ssh/filexfer/attrs_test.go b/internal/encoding/ssh/filexfer/attrs_test.go
index a4504a4..c03015c 100644
--- a/internal/encoding/ssh/filexfer/attrs_test.go
+++ b/internal/encoding/ssh/filexfer/attrs_test.go
@@ -113,48 +113,50 @@
t.Run(tt.name, func(t *testing.T) {
attr.Flags = tt.flags
- buf := new(Buffer)
- attr.MarshalInto(buf)
+ buf, err := attr.MarshalBinary()
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
- if got, want := buf.Bytes(), tt.encoded; !bytes.Equal(got, want) {
- t.Fatalf("MarshalInto() = %X, but wanted %X", got, want)
+ if !bytes.Equal(buf, tt.encoded) {
+ t.Fatalf("MarshalBinary() = %X, but wanted %X", buf, tt.encoded)
}
attr = Attributes{}
- if err := attr.UnmarshalFrom(buf); err != nil {
+ if err := attr.UnmarshalBinary(buf); err != nil {
t.Fatal("unexpected error:", err)
}
if attr.Flags != tt.flags {
- t.Errorf("UnmarshalFrom(): Flags was %x, but wanted %x", attr.Flags, tt.flags)
+ t.Errorf("UnmarshalBinary(): Flags was %x, but wanted %x", attr.Flags, tt.flags)
}
if attr.Flags&AttrSize != 0 && attr.Size != size {
- t.Errorf("UnmarshalFrom(): Size was %x, but wanted %x", attr.Size, size)
+ t.Errorf("UnmarshalBinary(): Size was %x, but wanted %x", attr.Size, size)
}
if attr.Flags&AttrUIDGID != 0 {
if attr.UID != uid {
- t.Errorf("UnmarshalFrom(): UID was %x, but wanted %x", attr.UID, uid)
+ t.Errorf("UnmarshalBinary(): UID was %x, but wanted %x", attr.UID, uid)
}
if attr.GID != gid {
- t.Errorf("UnmarshalFrom(): GID was %x, but wanted %x", attr.GID, gid)
+ t.Errorf("UnmarshalBinary(): GID was %x, but wanted %x", attr.GID, gid)
}
}
if attr.Flags&AttrPermissions != 0 && attr.Permissions != perms {
- t.Errorf("UnmarshalFrom(): Permissions was %#v, but wanted %#v", attr.Permissions, perms)
+ t.Errorf("UnmarshalBinary(): Permissions was %#v, but wanted %#v", attr.Permissions, perms)
}
if attr.Flags&AttrACModTime != 0 {
if attr.ATime != atime {
- t.Errorf("UnmarshalFrom(): ATime was %x, but wanted %x", attr.ATime, atime)
+ t.Errorf("UnmarshalBinary(): ATime was %x, but wanted %x", attr.ATime, atime)
}
if attr.MTime != mtime {
- t.Errorf("UnmarshalFrom(): MTime was %x, but wanted %x", attr.MTime, mtime)
+ t.Errorf("UnmarshalBinary(): MTime was %x, but wanted %x", attr.MTime, mtime)
}
}
@@ -162,11 +164,11 @@
extAttrs := attr.ExtendedAttributes
if count := len(extAttrs); count != 1 {
- t.Fatalf("UnmarshalFrom(): len(ExtendedAttributes) was %d, but wanted %d", count, 1)
+ t.Fatalf("UnmarshalBinary(): len(ExtendedAttributes) was %d, but wanted %d", count, 1)
}
if got := extAttrs[0]; got != extAttr {
- t.Errorf("UnmarshalFrom(): ExtendedAttributes[0] was %#v, but wanted %#v", got, extAttr)
+ t.Errorf("UnmarshalBinary(): ExtendedAttributes[0] was %#v, but wanted %#v", got, extAttr)
}
}
})
@@ -189,8 +191,10 @@
},
}
- buf := new(Buffer)
- e.MarshalInto(buf)
+ buf, err := e.MarshalBinary()
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
want := []byte{
0x00, 0x00, 0x00, 0x03, 'f', 'o', 'o',
@@ -199,13 +203,13 @@
0x87, 0x65, 0x43, 0x21,
}
- if got := buf.Bytes(); !bytes.Equal(got, want) {
- t.Fatalf("MarshalInto() = %X, but wanted %X", got, want)
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalBinary() = %X, but wanted %X", buf, want)
}
*e = NameEntry{}
- if err := e.UnmarshalFrom(buf); err != nil {
+ if err := e.UnmarshalBinary(buf); err != nil {
t.Fatal("unexpected error:", err)
}
diff --git a/internal/encoding/ssh/filexfer/extended_packets.go b/internal/encoding/ssh/filexfer/extended_packets.go
index 85ccce0..6b7b2ce 100644
--- a/internal/encoding/ssh/filexfer/extended_packets.go
+++ b/internal/encoding/ssh/filexfer/extended_packets.go
@@ -51,6 +51,11 @@
Data ExtendedData
}
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *ExtendedPacket) Type() PacketType {
+ return PacketTypeExtended
+}
+
// MarshalPacket returns p as a two-part binary encoding of p.
//
// The Data is marshaled into binary, and returned as the payload.
@@ -97,6 +102,11 @@
Data ExtendedData
}
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *ExtendedReplyPacket) Type() PacketType {
+ return PacketTypeExtendedReply
+}
+
// MarshalPacket returns p as a two-part binary encoding of p.
//
// The Data is marshaled into binary, and returned as the payload.
diff --git a/internal/encoding/ssh/filexfer/extended_packets_test.go b/internal/encoding/ssh/filexfer/extended_packets_test.go
index 668ef57..0860773 100644
--- a/internal/encoding/ssh/filexfer/extended_packets_test.go
+++ b/internal/encoding/ssh/filexfer/extended_packets_test.go
@@ -42,7 +42,7 @@
ExtendedRequest: extendedRequest,
}
- data, err := ComposePacket(p.MarshalPacket(id, nil))
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
if err != nil {
t.Fatal("unexpected error:", err)
}
@@ -54,14 +54,14 @@
0x00, 0x00, 0x00, 11, 'f', 'o', 'o', '@', 'e', 'x', 'a', 'm', 'p', 'l', 'e',
}
- if !bytes.Equal(data, want) {
- t.Fatalf("Marshal() = %X, but wanted %X", data, want)
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
}
*p = ExtendedPacket{}
// UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
- if err := p.UnmarshalPacketBody(NewBuffer(data[9:])); err != nil {
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
t.Fatal("unexpected error:", err)
}
@@ -86,7 +86,7 @@
},
}
- data, err := ComposePacket(p.MarshalPacket(id, nil))
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
if err != nil {
t.Fatal("unexpected error:", err)
}
@@ -99,8 +99,8 @@
0x27,
}
- if !bytes.Equal(data, want) {
- t.Fatalf("Marshal() = %X, but wanted %X", data, want)
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
}
*p = ExtendedPacket{
@@ -108,7 +108,7 @@
}
// UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
- if err := p.UnmarshalPacketBody(NewBuffer(data[9:])); err != nil {
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
t.Fatal("unexpected error:", err)
}
@@ -116,17 +116,17 @@
t.Errorf("UnmarshalPacketBody(): ExtendedRequest was %q, but expected %q", p.ExtendedRequest, extendedRequest)
}
- if data, ok := p.Data.(*testExtendedData); !ok {
- t.Errorf("UnmarshalPacketBody(): Data was type %T, but expected %T", p.Data, data)
+ if buf, ok := p.Data.(*testExtendedData); !ok {
+ t.Errorf("UnmarshalPacketBody(): Data was type %T, but expected %T", p.Data, buf)
- } else if data.value != value {
- t.Errorf("UnmarshalPacketBody(): Data.value was %#x, but expected %#x", data.value, value)
+ } else if buf.value != value {
+ t.Errorf("UnmarshalPacketBody(): Data.value was %#x, but expected %#x", buf.value, value)
}
*p = ExtendedPacket{}
// UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
- if err := p.UnmarshalPacketBody(NewBuffer(data[9:])); err != nil {
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
t.Fatal("unexpected error:", err)
}
@@ -136,11 +136,11 @@
wantBuffer := []byte{0x27}
- if data, ok := p.Data.(*Buffer); !ok {
- t.Errorf("UnmarshalPacketBody(): Data was type %T, but expected %T", p.Data, data)
+ if buf, ok := p.Data.(*Buffer); !ok {
+ t.Errorf("UnmarshalPacketBody(): Data was type %T, but expected %T", p.Data, buf)
- } else if !bytes.Equal(data.b, wantBuffer) {
- t.Errorf("UnmarshalPacketBody(): Data was %X, but expected %X", data.b, wantBuffer)
+ } else if !bytes.Equal(buf.b, wantBuffer) {
+ t.Errorf("UnmarshalPacketBody(): Data was %X, but expected %X", buf.b, wantBuffer)
}
}
@@ -153,7 +153,7 @@
p := &ExtendedReplyPacket{}
- data, err := ComposePacket(p.MarshalPacket(id, nil))
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
if err != nil {
t.Fatal("unexpected error:", err)
}
@@ -164,14 +164,14 @@
0x00, 0x00, 0x00, 42,
}
- if !bytes.Equal(data, want) {
- t.Fatalf("Marshal() = %X, but wanted %X", data, want)
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
}
*p = ExtendedReplyPacket{}
// UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
- if err := p.UnmarshalPacketBody(NewBuffer(data[9:])); err != nil {
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
t.Fatal("unexpected error:", err)
}
}
@@ -190,7 +190,7 @@
},
}
- data, err := ComposePacket(p.MarshalPacket(id, nil))
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
if err != nil {
t.Fatal("unexpected error:", err)
}
@@ -202,8 +202,8 @@
0x27,
}
- if !bytes.Equal(data, want) {
- t.Fatalf("Marshal() = %X, but wanted %X", data, want)
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
}
*p = ExtendedReplyPacket{
@@ -211,30 +211,30 @@
}
// UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
- if err := p.UnmarshalPacketBody(NewBuffer(data[9:])); err != nil {
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
t.Fatal("unexpected error:", err)
}
- if data, ok := p.Data.(*testExtendedData); !ok {
- t.Errorf("UnmarshalPacketBody(): Data was type %T, but expected %T", p.Data, data)
+ if buf, ok := p.Data.(*testExtendedData); !ok {
+ t.Errorf("UnmarshalPacketBody(): Data was type %T, but expected %T", p.Data, buf)
- } else if data.value != value {
- t.Errorf("UnmarshalPacketBody(): Data.value was %#x, but expected %#x", data.value, value)
+ } else if buf.value != value {
+ t.Errorf("UnmarshalPacketBody(): Data.value was %#x, but expected %#x", buf.value, value)
}
*p = ExtendedReplyPacket{}
// UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
- if err := p.UnmarshalPacketBody(NewBuffer(data[9:])); err != nil {
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
t.Fatal("unexpected error:", err)
}
wantBuffer := []byte{0x27}
- if data, ok := p.Data.(*Buffer); !ok {
- t.Errorf("UnmarshalPacketBody(): Data was type %T, but expected %T", p.Data, data)
+ if buf, ok := p.Data.(*Buffer); !ok {
+ t.Errorf("UnmarshalPacketBody(): Data was type %T, but expected %T", p.Data, buf)
- } else if !bytes.Equal(data.b, wantBuffer) {
- t.Errorf("UnmarshalPacketBody(): Data was %X, but expected %X", data.b, wantBuffer)
+ } else if !bytes.Equal(buf.b, wantBuffer) {
+ t.Errorf("UnmarshalPacketBody(): Data was %X, but expected %X", buf.b, wantBuffer)
}
}
diff --git a/internal/encoding/ssh/filexfer/extensions.go b/internal/encoding/ssh/filexfer/extensions.go
index 23980b1..11c0b99 100644
--- a/internal/encoding/ssh/filexfer/extensions.go
+++ b/internal/encoding/ssh/filexfer/extensions.go
@@ -22,7 +22,7 @@
// MarshalBinary returns e as the binary encoding of e.
func (e *ExtensionPair) MarshalBinary() ([]byte, error) {
- buf := NewBuffer(make([]byte, e.Len()))
+ buf := NewBuffer(make([]byte, 0, e.Len()))
e.MarshalInto(buf)
return buf.Bytes(), nil
}
diff --git a/internal/encoding/ssh/filexfer/extensions_test.go b/internal/encoding/ssh/filexfer/extensions_test.go
index c145c03..453265b 100644
--- a/internal/encoding/ssh/filexfer/extensions_test.go
+++ b/internal/encoding/ssh/filexfer/extensions_test.go
@@ -16,9 +16,10 @@
Data: data,
}
- buf := new(Buffer)
-
- pair.MarshalInto(buf)
+ buf, err := pair.MarshalBinary()
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
want := []byte{
0x00, 0x00, 0x00, 3,
@@ -27,18 +28,18 @@
'1',
}
- if got := buf.Bytes(); !bytes.Equal(got, want) {
- t.Errorf("ExtensionPair.MarshalInto() = %X, but wanted %X", got, want)
+ if !bytes.Equal(buf, want) {
+ t.Errorf("ExtensionPair.MarshalBinary() = %X, but wanted %X", buf, want)
}
*pair = ExtensionPair{}
- if err := pair.UnmarshalFrom(buf); err != nil {
+ if err := pair.UnmarshalBinary(buf); err != nil {
t.Fatal("unexpected error:", err)
}
if pair.Name != name {
- t.Errorf("ExtensionPair.UnmarshalFrom(): Name was %q, but expected %q", pair.Name, name)
+ t.Errorf("ExtensionPair.UnmarshalBinary(): Name was %q, but expected %q", pair.Name, name)
}
if pair.Data != data {
diff --git a/internal/encoding/ssh/filexfer/filexfer.go b/internal/encoding/ssh/filexfer/filexfer.go
index e8b2c34..1e5abf7 100644
--- a/internal/encoding/ssh/filexfer/filexfer.go
+++ b/internal/encoding/ssh/filexfer/filexfer.go
@@ -1,8 +1,11 @@
// Package filexfer implements the wire encoding for secsh-filexfer as described in https://tools.ietf.org/html/draft-ietf-secsh-filexfer-02
package filexfer
-// Packet defines the behavior of an SFTP packet.
-type Packet interface {
+// PacketMarshaller narrowly defines packets that will only be transmitted.
+//
+// ExtendedPacket types will often only implement this interface,
+// since decoding the whole packet body of an ExtendedPacket can only be done dependent on the ExtendedRequest field.
+type PacketMarshaller interface {
// MarshalPacket is the primary intended way to encode a packet.
// The request-id for the packet is set from reqid.
//
@@ -15,9 +18,26 @@
//
// It shall encode in the first 4-bytes of the header the proper length of the rest of the header+payload.
MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error)
+}
+
+// Packet defines the behavior of a full generic SFTP packet.
+//
+// InitPacket and VersionPacket are not generic SFTP packets, and instead implement (Un)MarshalBinary.
+//
+// ExtendedPacket types should not implement this interface,
+// since decoding the whole packet body of an ExtendedPacket can only be done dependent on the ExtendedRequest field.
+type Packet interface {
+ PacketMarshaller
+
+ // Type returns the SSH_FXP_xy value associated with the specific packet.
+ Type() PacketType
// UnmarshalPacketBody decodes a packet body from the given Buffer.
// It is assumed that the common header values of the length, type and request-id have already been consumed.
+ //
+ // Implementations should not alias the given Buffer.
+ // Instead, they may treat an already-populated internal buffer as a hint,
+ // and copy into that buffer when it has sufficient length.
UnmarshalPacketBody(buf *Buffer) error
}
diff --git a/internal/encoding/ssh/filexfer/fx.go b/internal/encoding/ssh/filexfer/fx.go
index 3eeac65..48f8698 100644
--- a/internal/encoding/ssh/filexfer/fx.go
+++ b/internal/encoding/ssh/filexfer/fx.go
@@ -61,12 +61,22 @@
StatusV6NoMatchingByteRangeLock
)
-func (f Status) Error() string {
- return f.String()
+func (s Status) Error() string {
+ return s.String()
}
-func (f Status) String() string {
- switch f {
+// Is returns true if the target is the same Status code,
+// or target is a non-nil *StatusPacket with the same Status code.
+// This lets errors.Is match a Status against either representation.
+func (s Status) Is(target error) bool {
+ if target, ok := target.(*StatusPacket); ok {
+ // A typed-nil *StatusPacket is a non-nil error value, so errors.Is will call us with it; guard the dereference.
+ return target != nil && target.StatusCode == s
+ }
+
+ return s == target
+}
+
+func (s Status) String() string {
+ switch s {
case StatusOK:
return "SSH_FX_OK"
case StatusEOF:
@@ -132,6 +142,6 @@
case StatusV6NoMatchingByteRangeLock:
return "SSH_FX_NO_MATCHING_BYTE_RANGE_LOCK"
default:
- return fmt.Sprintf("SSH_FX_UNKNOWN(%d)", f)
+ return fmt.Sprintf("SSH_FX_UNKNOWN(%d)", s)
}
}
diff --git a/internal/encoding/ssh/filexfer/fx_test.go b/internal/encoding/ssh/filexfer/fx_test.go
index 4d321d9..3e8db1d 100644
--- a/internal/encoding/ssh/filexfer/fx_test.go
+++ b/internal/encoding/ssh/filexfer/fx_test.go
@@ -2,6 +2,7 @@
import (
"bufio"
+ "errors"
"regexp"
"strconv"
"strings"
@@ -82,3 +83,20 @@
t.Fatal("unexpected error:", err)
}
}
+
+// TestStatusIs checks that Status.Is matches both bare Status values and *StatusPacket targets,
+// and rejects targets carrying a different status code.
+func TestStatusIs(t *testing.T) {
+ status := StatusFailure
+
+ if !errors.Is(status, StatusFailure) {
+ t.Error("errors.Is(StatusFailure, StatusFailure) != true")
+ }
+ if !errors.Is(status, &StatusPacket{StatusCode: StatusFailure}) {
+ t.Error("errors.Is(StatusFailure, StatusPacket{StatusFailure}) != true")
+ }
+ if errors.Is(status, StatusOK) {
+ t.Error("errors.Is(StatusFailure, StatusOK) == true")
+ }
+ if errors.Is(status, &StatusPacket{StatusCode: StatusOK}) {
+ t.Error("errors.Is(StatusFailure, StatusPacket{StatusOK}) == true")
+ }
+}
diff --git a/internal/encoding/ssh/filexfer/handle_packets.go b/internal/encoding/ssh/filexfer/handle_packets.go
index 33670dc..a142771 100644
--- a/internal/encoding/ssh/filexfer/handle_packets.go
+++ b/internal/encoding/ssh/filexfer/handle_packets.go
@@ -5,6 +5,11 @@
Handle string
}
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *ClosePacket) Type() PacketType {
+ return PacketTypeClose
+}
+
// MarshalPacket returns p as a two-part binary encoding of p.
func (p *ClosePacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
buf := NewBuffer(b)
@@ -36,6 +41,11 @@
Len uint32
}
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *ReadPacket) Type() PacketType {
+ return PacketTypeRead
+}
+
// MarshalPacket returns p as a two-part binary encoding of p.
func (p *ReadPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
buf := NewBuffer(b)
@@ -78,6 +88,11 @@
Data []byte
}
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *WritePacket) Type() PacketType {
+ return PacketTypeWrite
+}
+
// MarshalPacket returns p as a two-part binary encoding of p.
func (p *WritePacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
buf := NewBuffer(b)
@@ -97,6 +112,14 @@
// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
// It is assumed that the uint32(request-id) has already been consumed.
+//
+// If p.Data is already populated, and of sufficient length to hold the data,
+// then this will copy the data into that byte slice.
+//
+// If p.Data has a length insufficient to hold the data,
+// then this will make a new slice of sufficient length, and copy the data into that.
+//
+// This means this _does not_ alias any of the data buffer that is passed in.
func (p *WritePacket) UnmarshalPacketBody(buf *Buffer) (err error) {
if p.Handle, err = buf.ConsumeString(); err != nil {
return err
@@ -106,10 +129,17 @@
return err
}
- if p.Data, err = buf.ConsumeByteSlice(); err != nil {
+ data, err := buf.ConsumeByteSlice()
+ if err != nil {
return err
}
+ if len(p.Data) < len(data) {
+ p.Data = make([]byte, len(data))
+ }
+
+ n := copy(p.Data, data)
+ p.Data = p.Data[:n]
return nil
}
@@ -118,6 +148,11 @@
Handle string
}
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *FStatPacket) Type() PacketType {
+ return PacketTypeFStat
+}
+
// MarshalPacket returns p as a two-part binary encoding of p.
func (p *FStatPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
buf := NewBuffer(b)
@@ -148,6 +183,11 @@
Attrs Attributes
}
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *FSetstatPacket) Type() PacketType {
+ return PacketTypeFSetstat
+}
+
// MarshalPacket returns p as a two-part binary encoding of p.
func (p *FSetstatPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
buf := NewBuffer(b)
@@ -179,6 +219,11 @@
Handle string
}
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *ReadDirPacket) Type() PacketType {
+ return PacketTypeReadDir
+}
+
// MarshalPacket returns p as a two-part binary encoding of p.
func (p *ReadDirPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
buf := NewBuffer(b)
diff --git a/internal/encoding/ssh/filexfer/handle_packets_test.go b/internal/encoding/ssh/filexfer/handle_packets_test.go
index 10fdc53..8cc394e 100644
--- a/internal/encoding/ssh/filexfer/handle_packets_test.go
+++ b/internal/encoding/ssh/filexfer/handle_packets_test.go
@@ -17,7 +17,7 @@
Handle: "somehandle",
}
- data, err := ComposePacket(p.MarshalPacket(id, nil))
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
if err != nil {
t.Fatal("unexpected error:", err)
}
@@ -29,14 +29,14 @@
0x00, 0x00, 0x00, 10, 's', 'o', 'm', 'e', 'h', 'a', 'n', 'd', 'l', 'e',
}
- if !bytes.Equal(data, want) {
- t.Fatalf("Marshal() = %X, but wanted %X", data, want)
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
}
*p = ClosePacket{}
// UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
- if err := p.UnmarshalPacketBody(NewBuffer(data[9:])); err != nil {
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
t.Fatal("unexpected error:", err)
}
@@ -61,7 +61,7 @@
Len: length,
}
- data, err := ComposePacket(p.MarshalPacket(id, nil))
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
if err != nil {
t.Fatal("unexpected error:", err)
}
@@ -75,14 +75,14 @@
0xFE, 0xDC, 0xBA, 0x98,
}
- if !bytes.Equal(data, want) {
- t.Fatalf("Marshal() = %X, but wanted %X", data, want)
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
}
*p = ReadPacket{}
// UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
- if err := p.UnmarshalPacketBody(NewBuffer(data[9:])); err != nil {
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
t.Fatal("unexpected error:", err)
}
@@ -116,7 +116,7 @@
Data: payload,
}
- data, err := ComposePacket(p.MarshalPacket(id, nil))
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
if err != nil {
t.Fatal("unexpected error:", err)
}
@@ -130,14 +130,14 @@
0x00, 0x00, 0x00, 0x06, 'f', 'o', 'o', 'b', 'a', 'r',
}
- if !bytes.Equal(data, want) {
- t.Fatalf("Marshal() = %X, but wanted %X", data, want)
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
}
*p = WritePacket{}
// UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
- if err := p.UnmarshalPacketBody(NewBuffer(data[9:])); err != nil {
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
t.Fatal("unexpected error:", err)
}
@@ -166,7 +166,7 @@
Handle: "somehandle",
}
- data, err := ComposePacket(p.MarshalPacket(id, nil))
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
if err != nil {
t.Fatal("unexpected error:", err)
}
@@ -178,14 +178,14 @@
0x00, 0x00, 0x00, 10, 's', 'o', 'm', 'e', 'h', 'a', 'n', 'd', 'l', 'e',
}
- if !bytes.Equal(data, want) {
- t.Fatalf("Marshal() = %X, but wanted %X", data, want)
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
}
*p = FStatPacket{}
// UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
- if err := p.UnmarshalPacketBody(NewBuffer(data[9:])); err != nil {
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
t.Fatal("unexpected error:", err)
}
@@ -211,7 +211,7 @@
},
}
- data, err := ComposePacket(p.MarshalPacket(id, nil))
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
if err != nil {
t.Fatal("unexpected error:", err)
}
@@ -225,14 +225,14 @@
0x87, 0x65, 0x43, 0x21,
}
- if !bytes.Equal(data, want) {
- t.Fatalf("Marshal() = %X, but wanted %X", data, want)
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
}
*p = FSetstatPacket{}
// UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
- if err := p.UnmarshalPacketBody(NewBuffer(data[9:])); err != nil {
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
t.Fatal("unexpected error:", err)
}
@@ -253,7 +253,7 @@
Handle: "somehandle",
}
- data, err := ComposePacket(p.MarshalPacket(id, nil))
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
if err != nil {
t.Fatal("unexpected error:", err)
}
@@ -265,14 +265,14 @@
0x00, 0x00, 0x00, 10, 's', 'o', 'm', 'e', 'h', 'a', 'n', 'd', 'l', 'e',
}
- if !bytes.Equal(data, want) {
- t.Fatalf("Marshal() = %X, but wanted %X", data, want)
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
}
*p = ReadDirPacket{}
// UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
- if err := p.UnmarshalPacketBody(NewBuffer(data[9:])); err != nil {
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
t.Fatal("unexpected error:", err)
}
diff --git a/internal/encoding/ssh/filexfer/init_packets.go b/internal/encoding/ssh/filexfer/init_packets.go
index b3e99dc..b0bc6f5 100644
--- a/internal/encoding/ssh/filexfer/init_packets.go
+++ b/internal/encoding/ssh/filexfer/init_packets.go
@@ -64,7 +64,7 @@
}
b := NewBuffer(make([]byte, 4, 4+size))
- b.AppendUint8(uint8(PacketTypeInit))
+ b.AppendUint8(uint8(PacketTypeVersion))
b.AppendUint32(p.Version)
for _, ext := range p.Extensions {
diff --git a/internal/encoding/ssh/filexfer/init_packets_test.go b/internal/encoding/ssh/filexfer/init_packets_test.go
index 0cab752..e7605f9 100644
--- a/internal/encoding/ssh/filexfer/init_packets_test.go
+++ b/internal/encoding/ssh/filexfer/init_packets_test.go
@@ -18,7 +18,7 @@
},
}
- data, err := p.MarshalBinary()
+ buf, err := p.MarshalBinary()
if err != nil {
t.Fatal("unexpected error:", err)
}
@@ -31,14 +31,14 @@
0x00, 0x00, 0x00, 1, '1',
}
- if !bytes.Equal(data, want) {
- t.Fatalf("Marshal() = %X, but wanted %X", data, want)
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalBinary() = %X, but wanted %X", buf, want)
}
*p = InitPacket{}
// UnmarshalBinary assumes the uint32(length) + uint8(type) have already been consumed.
- if err := p.UnmarshalBinary(data[5:]); err != nil {
+ if err := p.UnmarshalBinary(buf[5:]); err != nil {
t.Fatal("unexpected error:", err)
}
@@ -72,27 +72,27 @@
},
}
- data, err := p.MarshalBinary()
+ buf, err := p.MarshalBinary()
if err != nil {
t.Fatal("unexpected error:", err)
}
want := []byte{
0x00, 0x00, 0x00, 17,
- 1,
+ 2,
0x00, 0x00, 0x00, version,
0x00, 0x00, 0x00, 3, 'f', 'o', 'o',
0x00, 0x00, 0x00, 1, '1',
}
- if !bytes.Equal(data, want) {
- t.Fatalf("Marshal() = %X, but wanted %X", data, want)
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalBinary() = %X, but wanted %X", buf, want)
}
*p = VersionPacket{}
// UnmarshalBinary assumes the uint32(length) + uint8(type) have already been consumed.
- if err := p.UnmarshalBinary(data[5:]); err != nil {
+ if err := p.UnmarshalBinary(buf[5:]); err != nil {
t.Fatal("unexpected error:", err)
}
diff --git a/internal/encoding/ssh/filexfer/open_packets.go b/internal/encoding/ssh/filexfer/open_packets.go
index b0e25c2..1358711 100644
--- a/internal/encoding/ssh/filexfer/open_packets.go
+++ b/internal/encoding/ssh/filexfer/open_packets.go
@@ -17,6 +17,11 @@
Attrs Attributes
}
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *OpenPacket) Type() PacketType {
+ return PacketTypeOpen
+}
+
// MarshalPacket returns p as a two-part binary encoding of p.
func (p *OpenPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
buf := NewBuffer(b)
@@ -54,6 +59,11 @@
Path string
}
+// Type returns the SSH_FXP_xy value associated with this packet type.
+func (p *OpenDirPacket) Type() PacketType {
+ return PacketTypeOpenDir
+}
+
// MarshalPacket returns p as a two-part binary encoding of p.
func (p *OpenDirPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
buf := NewBuffer(b)
diff --git a/internal/encoding/ssh/filexfer/open_packets_test.go b/internal/encoding/ssh/filexfer/open_packets_test.go
index 8637cd0..560c8b4 100644
--- a/internal/encoding/ssh/filexfer/open_packets_test.go
+++ b/internal/encoding/ssh/filexfer/open_packets_test.go
@@ -23,7 +23,7 @@
},
}
- data, err := ComposePacket(p.MarshalPacket(id, nil))
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
if err != nil {
t.Fatal("unexpected error:", err)
}
@@ -38,14 +38,14 @@
0x87, 0x65, 0x43, 0x21,
}
- if !bytes.Equal(data, want) {
- t.Fatalf("Marshal() = %X, but wanted %X", data, want)
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
}
*p = OpenPacket{}
// UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
- if err := p.UnmarshalPacketBody(NewBuffer(data[9:])); err != nil {
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
t.Fatal("unexpected error:", err)
}
@@ -78,7 +78,7 @@
Path: path,
}
- data, err := ComposePacket(p.MarshalPacket(id, nil))
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
if err != nil {
t.Fatal("unexpected error:", err)
}
@@ -90,14 +90,14 @@
0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o',
}
- if !bytes.Equal(data, want) {
- t.Fatalf("Marshal() = %X, but wanted %X", data, want)
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
}
*p = OpenDirPacket{}
// UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
- if err := p.UnmarshalPacketBody(NewBuffer(data[9:])); err != nil {
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
t.Fatal("unexpected error:", err)
}
diff --git a/internal/encoding/ssh/filexfer/openssh/fsync.go b/internal/encoding/ssh/filexfer/openssh/fsync.go
new file mode 100644
index 0000000..7ecfb0c
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/openssh/fsync.go
@@ -0,0 +1,73 @@
+package openssh
+
+import (
+ sshfx "github.com/pkg/sftp/internal/encoding/ssh/filexfer"
+)
+
+const extensionFSync = "fsync@openssh.com"
+
+// RegisterExtensionFSync registers the "fsync@openssh.com" extended packet with the encoding/ssh/filexfer package.
+func RegisterExtensionFSync() {
+ sshfx.RegisterExtendedPacketType(extensionFSync, func() sshfx.ExtendedData {
+ return new(FSyncExtendedPacket)
+ })
+}
+
+// ExtensionFSync returns an ExtensionPair suitable to append into an sshfx.InitPacket or sshfx.VersionPacket.
+func ExtensionFSync() *sshfx.ExtensionPair {
+ return &sshfx.ExtensionPair{
+ Name: extensionFSync,
+ Data: "1",
+ }
+}
+
+// FSyncExtendedPacket defines the fsync@openssh.com extended packet.
+type FSyncExtendedPacket struct {
+ Handle string
+}
+
+// Type returns the SSH_FXP_EXTENDED packet type.
+func (ep *FSyncExtendedPacket) Type() sshfx.PacketType {
+ return sshfx.PacketTypeExtended
+}
+
+// MarshalPacket returns ep as a two-part binary encoding of the full extended packet.
+func (ep *FSyncExtendedPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
+ p := &sshfx.ExtendedPacket{
+ ExtendedRequest: extensionFSync,
+
+ Data: ep,
+ }
+ return p.MarshalPacket(reqid, b)
+}
+
+// MarshalInto encodes ep into the binary encoding of the fsync@openssh.com extended packet-specific data.
+func (ep *FSyncExtendedPacket) MarshalInto(buf *sshfx.Buffer) {
+ buf.AppendString(ep.Handle)
+}
+
+// MarshalBinary encodes ep into the binary encoding of the fsync@openssh.com extended packet-specific data.
+//
+// NOTE: This _only_ encodes the packet-specific data, it does not encode the full extended packet.
+func (ep *FSyncExtendedPacket) MarshalBinary() ([]byte, error) {
+ // string(handle)
+ size := 4 + len(ep.Handle)
+
+ buf := sshfx.NewBuffer(make([]byte, 0, size))
+ ep.MarshalInto(buf)
+ return buf.Bytes(), nil
+}
+
+// UnmarshalFrom decodes the fsync@openssh.com extended packet-specific data from buf.
+func (ep *FSyncExtendedPacket) UnmarshalFrom(buf *sshfx.Buffer) (err error) {
+ if ep.Handle, err = buf.ConsumeString(); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// UnmarshalBinary decodes the fsync@openssh.com extended packet-specific data into ep.
+func (ep *FSyncExtendedPacket) UnmarshalBinary(data []byte) (err error) {
+ return ep.UnmarshalFrom(sshfx.NewBuffer(data))
+}
diff --git a/internal/encoding/ssh/filexfer/openssh/fsync_test.go b/internal/encoding/ssh/filexfer/openssh/fsync_test.go
new file mode 100644
index 0000000..f9e878f
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/openssh/fsync_test.go
@@ -0,0 +1,62 @@
+package openssh
+
+import (
+ "bytes"
+ "testing"
+
+ sshfx "github.com/pkg/sftp/internal/encoding/ssh/filexfer"
+)
+
+var _ sshfx.PacketMarshaller = &FSyncExtendedPacket{}
+
+func init() {
+ RegisterExtensionFSync()
+}
+
+func TestFSyncExtendedPacket(t *testing.T) {
+ const (
+ id = 42
+ handle = "somehandle"
+ )
+
+ ep := &FSyncExtendedPacket{
+ Handle: handle,
+ }
+
+ data, err := sshfx.ComposePacket(ep.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 40,
+ 200,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 17, 'f', 's', 'y', 'n', 'c', '@', 'o', 'p', 'e', 'n', 's', 's', 'h', '.', 'c', 'o', 'm',
+ 0x00, 0x00, 0x00, 10, 's', 'o', 'm', 'e', 'h', 'a', 'n', 'd', 'l', 'e',
+ }
+
+ if !bytes.Equal(data, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", data, want)
+ }
+
+ var p sshfx.ExtendedPacket
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(sshfx.NewBuffer(data[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.ExtendedRequest != extensionFSync {
+ t.Errorf("UnmarshalPacketBody(): ExtendedRequest was %q, but expected %q", p.ExtendedRequest, extensionFSync)
+ }
+
+ ep, ok := p.Data.(*FSyncExtendedPacket)
+ if !ok {
+ t.Fatalf("UnmarshalPacketBody(): Data was type %T, but expected *FSyncExtendedPacket", p.Data)
+ }
+
+ if ep.Handle != handle {
+ t.Errorf("UnmarshalPacketBody(): Handle was %q, but expected %q", ep.Handle, handle)
+ }
+}
diff --git a/internal/encoding/ssh/filexfer/openssh/hardlink.go b/internal/encoding/ssh/filexfer/openssh/hardlink.go
index 7f7d5e0..17c3499 100644
--- a/internal/encoding/ssh/filexfer/openssh/hardlink.go
+++ b/internal/encoding/ssh/filexfer/openssh/hardlink.go
@@ -27,6 +27,11 @@
NewPath string
}
+// Type returns the SSH_FXP_EXTENDED packet type.
+func (ep *HardlinkExtendedPacket) Type() sshfx.PacketType {
+ return sshfx.PacketTypeExtended
+}
+
// MarshalPacket returns ep as a two-part binary encoding of the full extended packet.
func (ep *HardlinkExtendedPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
p := &sshfx.ExtendedPacket{
@@ -50,7 +55,7 @@
// string(oldpath) + string(newpath)
size := 4 + len(ep.OldPath) + 4 + len(ep.NewPath)
- buf := sshfx.NewBuffer(make([]byte, size))
+ buf := sshfx.NewBuffer(make([]byte, 0, size))
ep.MarshalInto(buf)
return buf.Bytes(), nil
}
diff --git a/internal/encoding/ssh/filexfer/openssh/hardlink_test.go b/internal/encoding/ssh/filexfer/openssh/hardlink_test.go
new file mode 100644
index 0000000..5d3be06
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/openssh/hardlink_test.go
@@ -0,0 +1,69 @@
+package openssh
+
+import (
+ "bytes"
+ "testing"
+
+ sshfx "github.com/pkg/sftp/internal/encoding/ssh/filexfer"
+)
+
+var _ sshfx.PacketMarshaller = &HardlinkExtendedPacket{}
+
+func init() {
+ RegisterExtensionHardlink()
+}
+
+func TestHardlinkExtendedPacket(t *testing.T) {
+ const (
+ id = 42
+ oldpath = "/foo"
+ newpath = "/bar"
+ )
+
+ ep := &HardlinkExtendedPacket{
+ OldPath: oldpath,
+ NewPath: newpath,
+ }
+
+ data, err := sshfx.ComposePacket(ep.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 45,
+ 200,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 20, 'h', 'a', 'r', 'd', 'l', 'i', 'n', 'k', '@', 'o', 'p', 'e', 'n', 's', 's', 'h', '.', 'c', 'o', 'm',
+ 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o',
+ 0x00, 0x00, 0x00, 4, '/', 'b', 'a', 'r',
+ }
+
+ if !bytes.Equal(data, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", data, want)
+ }
+
+ var p sshfx.ExtendedPacket
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(sshfx.NewBuffer(data[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.ExtendedRequest != extensionHardlink {
+ t.Errorf("UnmarshalPacketBody(): ExtendedRequest was %q, but expected %q", p.ExtendedRequest, extensionHardlink)
+ }
+
+ ep, ok := p.Data.(*HardlinkExtendedPacket)
+ if !ok {
+ t.Fatalf("UnmarshalPacketBody(): Data was type %T, but expected *HardlinkExtendedPacket", p.Data)
+ }
+
+ if ep.OldPath != oldpath {
+ t.Errorf("UnmarshalPacketBody(): OldPath was %q, but expected %q", ep.OldPath, oldpath)
+ }
+
+ if ep.NewPath != newpath {
+ t.Errorf("UnmarshalPacketBody(): NewPath was %q, but expected %q", ep.NewPath, newpath)
+ }
+}
diff --git a/internal/encoding/ssh/filexfer/openssh/posix-rename.go b/internal/encoding/ssh/filexfer/openssh/posix-rename.go
index d19a580..a3d3de5 100644
--- a/internal/encoding/ssh/filexfer/openssh/posix-rename.go
+++ b/internal/encoding/ssh/filexfer/openssh/posix-rename.go
@@ -27,6 +27,11 @@
NewPath string
}
+// Type returns the SSH_FXP_EXTENDED packet type.
+func (ep *PosixRenameExtendedPacket) Type() sshfx.PacketType {
+ return sshfx.PacketTypeExtended
+}
+
// MarshalPacket returns ep as a two-part binary encoding of the full extended packet.
func (ep *PosixRenameExtendedPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
p := &sshfx.ExtendedPacket{
@@ -50,7 +55,7 @@
// string(oldpath) + string(newpath)
size := 4 + len(ep.OldPath) + 4 + len(ep.NewPath)
- buf := sshfx.NewBuffer(make([]byte, size))
+ buf := sshfx.NewBuffer(make([]byte, 0, size))
ep.MarshalInto(buf)
return buf.Bytes(), nil
}
diff --git a/internal/encoding/ssh/filexfer/openssh/posix-rename_test.go b/internal/encoding/ssh/filexfer/openssh/posix-rename_test.go
new file mode 100644
index 0000000..6bdb10d
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/openssh/posix-rename_test.go
@@ -0,0 +1,69 @@
+package openssh
+
+import (
+ "bytes"
+ "testing"
+
+ sshfx "github.com/pkg/sftp/internal/encoding/ssh/filexfer"
+)
+
+var _ sshfx.PacketMarshaller = &PosixRenameExtendedPacket{}
+
+func init() {
+ RegisterExtensionPosixRename()
+}
+
+func TestPosixRenameExtendedPacket(t *testing.T) {
+ const (
+ id = 42
+ oldpath = "/foo"
+ newpath = "/bar"
+ )
+
+ ep := &PosixRenameExtendedPacket{
+ OldPath: oldpath,
+ NewPath: newpath,
+ }
+
+ data, err := sshfx.ComposePacket(ep.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 49,
+ 200,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 24, 'p', 'o', 's', 'i', 'x', '-', 'r', 'e', 'n', 'a', 'm', 'e', '@', 'o', 'p', 'e', 'n', 's', 's', 'h', '.', 'c', 'o', 'm',
+ 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o',
+ 0x00, 0x00, 0x00, 4, '/', 'b', 'a', 'r',
+ }
+
+ if !bytes.Equal(data, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", data, want)
+ }
+
+ var p sshfx.ExtendedPacket
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(sshfx.NewBuffer(data[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.ExtendedRequest != extensionPosixRename {
+ t.Errorf("UnmarshalPacketBody(): ExtendedRequest was %q, but expected %q", p.ExtendedRequest, extensionPosixRename)
+ }
+
+ ep, ok := p.Data.(*PosixRenameExtendedPacket)
+ if !ok {
+ t.Fatalf("UnmarshalPacketBody(): Data was type %T, but expected *PosixRenameExtendedPacket", p.Data)
+ }
+
+ if ep.OldPath != oldpath {
+ t.Errorf("UnmarshalPacketBody(): OldPath was %q, but expected %q", ep.OldPath, oldpath)
+ }
+
+ if ep.NewPath != newpath {
+ t.Errorf("UnmarshalPacketBody(): NewPath was %q, but expected %q", ep.NewPath, newpath)
+ }
+}
diff --git a/internal/encoding/ssh/filexfer/openssh/statvfs.go b/internal/encoding/ssh/filexfer/openssh/statvfs.go
index a58fc6a..3e9015f 100644
--- a/internal/encoding/ssh/filexfer/openssh/statvfs.go
+++ b/internal/encoding/ssh/filexfer/openssh/statvfs.go
@@ -26,6 +26,11 @@
Path string
}
+// Type returns the SSH_FXP_EXTENDED packet type.
+func (ep *StatVFSExtendedPacket) Type() sshfx.PacketType {
+ return sshfx.PacketTypeExtended
+}
+
// MarshalPacket returns ep as a two-part binary encoding of the full extended packet.
func (ep *StatVFSExtendedPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
p := &sshfx.ExtendedPacket{
@@ -47,7 +52,7 @@
func (ep *StatVFSExtendedPacket) MarshalBinary() ([]byte, error) {
size := 4 + len(ep.Path) // string(path)
- buf := sshfx.NewBuffer(make([]byte, size))
+ buf := sshfx.NewBuffer(make([]byte, 0, size))
ep.MarshalInto(buf)
@@ -90,6 +95,11 @@
Path string
}
+// Type returns the SSH_FXP_EXTENDED packet type.
+func (ep *FStatVFSExtendedPacket) Type() sshfx.PacketType {
+ return sshfx.PacketTypeExtended
+}
+
// MarshalPacket returns ep as a two-part binary encoding of the full extended packet.
func (ep *FStatVFSExtendedPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
p := &sshfx.ExtendedPacket{
@@ -111,7 +121,7 @@
func (ep *FStatVFSExtendedPacket) MarshalBinary() ([]byte, error) {
size := 4 + len(ep.Path) // string(path)
- buf := sshfx.NewBuffer(make([]byte, size))
+ buf := sshfx.NewBuffer(make([]byte, 0, size))
ep.MarshalInto(buf)
@@ -154,6 +164,11 @@
MaxNameLength uint64 /* f_namemax: maximum filename length */
}
+// Type returns the SSH_FXP_EXTENDED_REPLY packet type.
+func (ep *StatVFSExtendedReplyPacket) Type() sshfx.PacketType {
+ return sshfx.PacketTypeExtendedReply
+}
+
// MarshalPacket returns ep as a two-part binary encoding of the full extended reply packet.
func (ep *StatVFSExtendedReplyPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
p := &sshfx.ExtendedReplyPacket{
@@ -162,6 +177,14 @@
return p.MarshalPacket(reqid, b)
}
+// UnmarshalPacketBody decodes ep from the body of a full extended reply packet, whose (length, type, request-id) were already consumed.
+func (ep *StatVFSExtendedReplyPacket) UnmarshalPacketBody(buf *sshfx.Buffer) (err error) {
+ p := &sshfx.ExtendedReplyPacket{
+ Data: ep,
+ }
+ return p.UnmarshalPacketBody(buf)
+}
+
// MarshalInto encodes ep into the binary encoding of the (f)statvfs@openssh.com extended reply packet-specific data.
func (ep *StatVFSExtendedReplyPacket) MarshalInto(buf *sshfx.Buffer) {
buf.AppendUint64(ep.BlockSize)
@@ -183,7 +206,7 @@
func (ep *StatVFSExtendedReplyPacket) MarshalBinary() ([]byte, error) {
size := 11 * 8 // 11 × uint64(various)
- b := sshfx.NewBuffer(make([]byte, size))
+ b := sshfx.NewBuffer(make([]byte, 0, size))
ep.MarshalInto(b)
return b.Bytes(), nil
}
diff --git a/internal/encoding/ssh/filexfer/openssh/statvfs_test.go b/internal/encoding/ssh/filexfer/openssh/statvfs_test.go
new file mode 100644
index 0000000..014aa63
--- /dev/null
+++ b/internal/encoding/ssh/filexfer/openssh/statvfs_test.go
@@ -0,0 +1,239 @@
+package openssh
+
+import (
+ "bytes"
+ "testing"
+
+ sshfx "github.com/pkg/sftp/internal/encoding/ssh/filexfer"
+)
+
+var _ sshfx.PacketMarshaller = &StatVFSExtendedPacket{}
+
+func init() {
+ RegisterExtensionStatVFS()
+}
+
+func TestStatVFSExtendedPacket(t *testing.T) {
+ const (
+ id = 42
+ path = "/foo"
+ )
+
+ ep := &StatVFSExtendedPacket{
+ Path: path,
+ }
+
+ data, err := sshfx.ComposePacket(ep.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 36,
+ 200,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 19, 's', 't', 'a', 't', 'v', 'f', 's', '@', 'o', 'p', 'e', 'n', 's', 's', 'h', '.', 'c', 'o', 'm',
+ 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o',
+ }
+
+ if !bytes.Equal(data, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", data, want)
+ }
+
+ var p sshfx.ExtendedPacket
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(sshfx.NewBuffer(data[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.ExtendedRequest != extensionStatVFS {
+ t.Errorf("UnmarshalPacketBody(): ExtendedRequest was %q, but expected %q", p.ExtendedRequest, extensionStatVFS)
+ }
+
+ ep, ok := p.Data.(*StatVFSExtendedPacket)
+ if !ok {
+ t.Fatalf("UnmarshalPacketBody(): Data was type %T, but expected *StatVFSExtendedPacket", p.Data)
+ }
+
+ if ep.Path != path {
+ t.Errorf("UnmarshalPacketBody(): Path was %q, but expected %q", ep.Path, path)
+ }
+}
+
+var _ sshfx.PacketMarshaller = &FStatVFSExtendedPacket{}
+
+func init() {
+ RegisterExtensionFStatVFS()
+}
+
+func TestFStatVFSExtendedPacket(t *testing.T) {
+ const (
+ id = 42
+ path = "/foo"
+ )
+
+ ep := &FStatVFSExtendedPacket{
+ Path: path,
+ }
+
+ data, err := sshfx.ComposePacket(ep.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 37,
+ 200,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 20, 'f', 's', 't', 'a', 't', 'v', 'f', 's', '@', 'o', 'p', 'e', 'n', 's', 's', 'h', '.', 'c', 'o', 'm',
+ 0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o',
+ }
+
+ if !bytes.Equal(data, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", data, want)
+ }
+
+ var p sshfx.ExtendedPacket
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(sshfx.NewBuffer(data[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ if p.ExtendedRequest != extensionFStatVFS {
+ t.Errorf("UnmarshalPacketBody(): ExtendedRequest was %q, but expected %q", p.ExtendedRequest, extensionFStatVFS)
+ }
+
+ ep, ok := p.Data.(*FStatVFSExtendedPacket)
+ if !ok {
+ t.Fatalf("UnmarshalPacketBody(): Data was type %T, but expected *FStatVFSExtendedPacket", p.Data)
+ }
+
+ if ep.Path != path {
+ t.Errorf("UnmarshalPacketBody(): Path was %q, but expected %q", ep.Path, path)
+ }
+}
+
+var _ sshfx.Packet = &StatVFSExtendedReplyPacket{}
+
+func TestStatVFSExtendedReplyPacket(t *testing.T) {
+ const (
+ id = 42
+ path = "/foo"
+ )
+
+ const (
+ BlockSize = uint64(iota + 13)
+ FragmentSize
+ Blocks
+ BlocksFree
+ BlocksAvail
+ Files
+ FilesFree
+ FilesAvail
+ FilesystemID
+ MountFlags
+ MaxNameLength
+ )
+
+ ep := &StatVFSExtendedReplyPacket{
+ BlockSize: BlockSize,
+ FragmentSize: FragmentSize,
+ Blocks: Blocks,
+ BlocksFree: BlocksFree,
+ BlocksAvail: BlocksAvail,
+ Files: Files,
+ FilesFree: FilesFree,
+ FilesAvail: FilesAvail,
+ FilesystemID: FilesystemID,
+ MountFlags: MountFlags,
+ MaxNameLength: MaxNameLength,
+ }
+
+ data, err := sshfx.ComposePacket(ep.MarshalPacket(id, nil))
+ if err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ want := []byte{
+ 0x00, 0x00, 0x00, 93,
+ 201,
+ 0x00, 0x00, 0x00, 42,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 13,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 14,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 15,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 16,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 17,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 18,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 19,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 20,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 21,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 22,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 23,
+ }
+
+ if !bytes.Equal(data, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", data, want)
+ }
+
+ *ep = StatVFSExtendedReplyPacket{}
+
+ p := sshfx.ExtendedReplyPacket{
+ Data: ep,
+ }
+
+ // UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
+ if err := p.UnmarshalPacketBody(sshfx.NewBuffer(data[9:])); err != nil {
+ t.Fatal("unexpected error:", err)
+ }
+
+ ep, ok := p.Data.(*StatVFSExtendedReplyPacket)
+ if !ok {
+ t.Fatalf("UnmarshalPacketBody(): Data was type %T, but expected *StatVFSExtendedReplyPacket", p.Data)
+ }
+
+ if ep.BlockSize != BlockSize {
+ t.Errorf("UnmarshalPacketBody(): BlockSize was %d, but expected %d", ep.BlockSize, BlockSize)
+ }
+
+ if ep.FragmentSize != FragmentSize {
+ t.Errorf("UnmarshalPacketBody(): FragmentSize was %d, but expected %d", ep.FragmentSize, FragmentSize)
+ }
+
+ if ep.Blocks != Blocks {
+ t.Errorf("UnmarshalPacketBody(): Blocks was %d, but expected %d", ep.Blocks, Blocks)
+ }
+
+ if ep.BlocksFree != BlocksFree {
+ t.Errorf("UnmarshalPacketBody(): BlocksFree was %d, but expected %d", ep.BlocksFree, BlocksFree)
+ }
+
+ if ep.BlocksAvail != BlocksAvail {
+ t.Errorf("UnmarshalPacketBody(): BlocksAvail was %d, but expected %d", ep.BlocksAvail, BlocksAvail)
+ }
+
+ if ep.Files != Files {
+ t.Errorf("UnmarshalPacketBody(): Files was %d, but expected %d", ep.Files, Files)
+ }
+
+ if ep.FilesFree != FilesFree {
+ t.Errorf("UnmarshalPacketBody(): FilesFree was %d, but expected %d", ep.FilesFree, FilesFree)
+ }
+
+ if ep.FilesAvail != FilesAvail {
+ t.Errorf("UnmarshalPacketBody(): FilesAvail was %d, but expected %d", ep.FilesAvail, FilesAvail)
+ }
+
+ if ep.FilesystemID != FilesystemID {
+ t.Errorf("UnmarshalPacketBody(): FilesystemID was %d, but expected %d", ep.FilesystemID, FilesystemID)
+ }
+
+ if ep.MountFlags != MountFlags {
+ t.Errorf("UnmarshalPacketBody(): MountFlags was %d, but expected %d", ep.MountFlags, MountFlags)
+ }
+
+ if ep.MaxNameLength != MaxNameLength {
+ t.Errorf("UnmarshalPacketBody(): MaxNameLength was %d, but expected %d", ep.MaxNameLength, MaxNameLength)
+ }
+}
diff --git a/internal/encoding/ssh/filexfer/packets.go b/internal/encoding/ssh/filexfer/packets.go
index d597f49..72afa41 100644
--- a/internal/encoding/ssh/filexfer/packets.go
+++ b/internal/encoding/ssh/filexfer/packets.go
@@ -65,12 +65,17 @@
//
// Defined in https://tools.ietf.org/html/draft-ietf-secsh-filexfer-02#section-3
type RawPacket struct {
- Type PacketType
- RequestID uint32
+ PacketType PacketType
+ RequestID uint32
Data Buffer
}
+// Type returns the Type field defining the SSH_FXP_xy type for this packet.
+func (p *RawPacket) Type() PacketType {
+ return p.PacketType
+}
+
// Reset clears the pointers and reference-semantic variables of RawPacket,
// releasing underlying resources, and making them and the RawPacket suitable to be reused,
// so long as no other references have been kept.
@@ -87,7 +92,7 @@
buf = NewMarshalBuffer(0)
}
- buf.StartPacket(p.Type, reqid)
+ buf.StartPacket(p.PacketType, reqid)
return buf.Packet(p.Data.Bytes())
}
@@ -110,7 +115,7 @@
return err
}
- p.Type = PacketType(typ)
+ p.PacketType = PacketType(typ)
if p.RequestID, err = buf.ConsumeUint32(); err != nil {
return err
@@ -227,6 +232,11 @@
Request Packet
}
+// Type returns the SSH_FXP_xy value associated with the underlying packet.
+func (p *RequestPacket) Type() PacketType {
+ return p.Request.Type()
+}
+
// Reset clears the pointers and reference-semantic variables in RequestPacket,
// releasing underlying resources, and making them and the RequestPacket suitable to be reused,
// so long as no other references have been kept.
diff --git a/internal/encoding/ssh/filexfer/packets_test.go b/internal/encoding/ssh/filexfer/packets_test.go
index 8658e04..1600920 100644
--- a/internal/encoding/ssh/filexfer/packets_test.go
+++ b/internal/encoding/ssh/filexfer/packets_test.go
@@ -13,8 +13,8 @@
)
p := &RawPacket{
- Type: PacketTypeStatus,
- RequestID: id,
+ PacketType: PacketTypeStatus,
+ RequestID: id,
Data: Buffer{
b: []byte{
0x00, 0x00, 0x00, 0x01,
@@ -24,7 +24,7 @@
},
}
- data, err := p.MarshalBinary()
+ buf, err := p.MarshalBinary()
if err != nil {
t.Fatal("unexpected error:", err)
}
@@ -38,18 +38,18 @@
0x00, 0x00, 0x00, 2, 'e', 'n',
}
- if !bytes.Equal(data, want) {
- t.Errorf("RawPacket.Marshal() = %X, but wanted %X", data, want)
+ if !bytes.Equal(buf, want) {
+ t.Errorf("RawPacket.MarshalBinary() = %X, but wanted %X", buf, want)
}
*p = RawPacket{}
- if err := p.ReadFrom(bytes.NewReader(data), nil, DefaultMaxPacketLength); err != nil {
+ if err := p.ReadFrom(bytes.NewReader(buf), nil, DefaultMaxPacketLength); err != nil {
t.Fatal("unexpected error:", err)
}
- if p.Type != PacketTypeStatus {
- t.Errorf("RawPacket.UnmarshalBinary(): Type was %v, but expected %v", p.Type, PacketTypeStat)
+ if p.PacketType != PacketTypeStatus {
+ t.Errorf("RawPacket.UnmarshalBinary(): Type was %v, but expected %v", p.PacketType, PacketTypeStatus)
}
if p.RequestID != uint32(id) {
@@ -70,15 +70,15 @@
resp.UnmarshalPacketBody(&p.Data)
if resp.StatusCode != StatusEOF {
- t.Errorf("UnmarshalPacketBody(RawPacket.Data): StatusCode was %v, but expected %v", resp.StatusCode, StatusEOF)
+ t.Errorf("UnmarshalPacketBody(): StatusCode was %v, but expected %v", resp.StatusCode, StatusEOF)
}
if resp.ErrorMessage != errMsg {
- t.Errorf("UnmarshalPacketBody(RawPacket.Data): ErrorMessage was %q, but expected %q", resp.ErrorMessage, errMsg)
+ t.Errorf("UnmarshalPacketBody(): ErrorMessage was %q, but expected %q", resp.ErrorMessage, errMsg)
}
if resp.LanguageTag != langTag {
- t.Errorf("UnmarshalPacketBody(RawPacket.Data): LanguageTag was %q, but expected %q", resp.LanguageTag, langTag)
+ t.Errorf("UnmarshalPacketBody(): LanguageTag was %q, but expected %q", resp.LanguageTag, langTag)
}
}
@@ -95,7 +95,7 @@
},
}
- data, err := p.MarshalBinary()
+ buf, err := p.MarshalBinary()
if err != nil {
t.Fatal("unexpected error:", err)
}
@@ -107,13 +107,13 @@
0x00, 0x00, 0x00, 3, 'f', 'o', 'o',
}
- if !bytes.Equal(data, want) {
- t.Errorf("RequestPacket.Marshal() = %X, but wanted %X", data, want)
+ if !bytes.Equal(buf, want) {
+ t.Errorf("RequestPacket.MarshalBinary() = %X, but wanted %X", buf, want)
}
*p = RequestPacket{}
- if err := p.ReadFrom(bytes.NewReader(data), nil, DefaultMaxPacketLength); err != nil {
+ if err := p.ReadFrom(bytes.NewReader(buf), nil, DefaultMaxPacketLength); err != nil {
t.Fatal("unexpected error:", err)
}
diff --git a/internal/encoding/ssh/filexfer/path_packets.go b/internal/encoding/ssh/filexfer/path_packets.go
index f7085f3..e6f692d 100644
--- a/internal/encoding/ssh/filexfer/path_packets.go
+++ b/internal/encoding/ssh/filexfer/path_packets.go
@@ -5,6 +5,11 @@
Path string
}
+// Type returns PacketTypeLStat, identifying this as an SSH_FXP_LSTAT packet.
+func (p *LStatPacket) Type() PacketType {
+ return PacketTypeLStat
+}
+
// MarshalPacket returns p as a two-part binary encoding of p.
func (p *LStatPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
buf := NewBuffer(b)
@@ -35,6 +40,11 @@
Attrs Attributes
}
+// Type returns PacketTypeSetstat, identifying this as an SSH_FXP_SETSTAT packet.
+func (p *SetstatPacket) Type() PacketType {
+ return PacketTypeSetstat
+}
+
// MarshalPacket returns p as a two-part binary encoding of p.
func (p *SetstatPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
buf := NewBuffer(b)
@@ -66,6 +76,11 @@
Path string
}
+// Type returns PacketTypeRemove, identifying this as an SSH_FXP_REMOVE packet.
+func (p *RemovePacket) Type() PacketType {
+ return PacketTypeRemove
+}
+
// MarshalPacket returns p as a two-part binary encoding of p.
func (p *RemovePacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
buf := NewBuffer(b)
@@ -96,6 +111,11 @@
Attrs Attributes
}
+// Type returns PacketTypeMkdir, identifying this as an SSH_FXP_MKDIR packet.
+func (p *MkdirPacket) Type() PacketType {
+ return PacketTypeMkdir
+}
+
// MarshalPacket returns p as a two-part binary encoding of p.
func (p *MkdirPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
buf := NewBuffer(b)
@@ -127,6 +147,11 @@
Path string
}
+// Type returns PacketTypeRmdir, identifying this as an SSH_FXP_RMDIR packet.
+func (p *RmdirPacket) Type() PacketType {
+ return PacketTypeRmdir
+}
+
// MarshalPacket returns p as a two-part binary encoding of p.
func (p *RmdirPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
buf := NewBuffer(b)
@@ -156,6 +181,11 @@
Path string
}
+// Type returns PacketTypeRealPath, identifying this as an SSH_FXP_REALPATH packet.
+func (p *RealPathPacket) Type() PacketType {
+ return PacketTypeRealPath
+}
+
// MarshalPacket returns p as a two-part binary encoding of p.
func (p *RealPathPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
buf := NewBuffer(b)
@@ -185,6 +215,11 @@
Path string
}
+// Type returns PacketTypeStat, identifying this as an SSH_FXP_STAT packet.
+func (p *StatPacket) Type() PacketType {
+ return PacketTypeStat
+}
+
// MarshalPacket returns p as a two-part binary encoding of p.
func (p *StatPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
buf := NewBuffer(b)
@@ -215,6 +250,11 @@
NewPath string
}
+// Type returns PacketTypeRename, identifying this as an SSH_FXP_RENAME packet.
+func (p *RenamePacket) Type() PacketType {
+ return PacketTypeRename
+}
+
// MarshalPacket returns p as a two-part binary encoding of p.
func (p *RenamePacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
buf := NewBuffer(b)
@@ -250,6 +290,11 @@
Path string
}
+// Type returns PacketTypeReadLink, identifying this as an SSH_FXP_READLINK packet.
+func (p *ReadLinkPacket) Type() PacketType {
+ return PacketTypeReadLink
+}
+
// MarshalPacket returns p as a two-part binary encoding of p.
func (p *ReadLinkPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
buf := NewBuffer(b)
@@ -284,6 +329,11 @@
TargetPath string
}
+// Type returns PacketTypeSymlink, identifying this as an SSH_FXP_SYMLINK packet.
+func (p *SymlinkPacket) Type() PacketType {
+ return PacketTypeSymlink
+}
+
// MarshalPacket returns p as a two-part binary encoding of p.
func (p *SymlinkPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
buf := NewBuffer(b)
diff --git a/internal/encoding/ssh/filexfer/path_packets_test.go b/internal/encoding/ssh/filexfer/path_packets_test.go
index 852145e..4cff582 100644
--- a/internal/encoding/ssh/filexfer/path_packets_test.go
+++ b/internal/encoding/ssh/filexfer/path_packets_test.go
@@ -17,7 +17,7 @@
Path: path,
}
- data, err := ComposePacket(p.MarshalPacket(id, nil))
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
if err != nil {
t.Fatal("unexpected error:", err)
}
@@ -29,14 +29,14 @@
0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o',
}
- if !bytes.Equal(data, want) {
- t.Fatalf("Marshal() = %X, but wanted %X", data, want)
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
}
*p = LStatPacket{}
// UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
- if err := p.UnmarshalPacketBody(NewBuffer(data[9:])); err != nil {
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
t.Fatal("unexpected error:", err)
}
@@ -62,7 +62,7 @@
},
}
- data, err := ComposePacket(p.MarshalPacket(id, nil))
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
if err != nil {
t.Fatal("unexpected error:", err)
}
@@ -76,14 +76,14 @@
0x87, 0x65, 0x43, 0x21,
}
- if !bytes.Equal(data, want) {
- t.Fatalf("Marshal() = %X, but wanted %X", data, want)
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
}
*p = SetstatPacket{}
// UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
- if err := p.UnmarshalPacketBody(NewBuffer(data[9:])); err != nil {
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
t.Fatal("unexpected error:", err)
}
@@ -112,7 +112,7 @@
Path: path,
}
- data, err := ComposePacket(p.MarshalPacket(id, nil))
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
if err != nil {
t.Fatal("unexpected error:", err)
}
@@ -124,14 +124,14 @@
0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o',
}
- if !bytes.Equal(data, want) {
- t.Fatalf("Marshal() = %X, but wanted %X", data, want)
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
}
*p = RemovePacket{}
// UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
- if err := p.UnmarshalPacketBody(NewBuffer(data[9:])); err != nil {
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
t.Fatal("unexpected error:", err)
}
@@ -157,7 +157,7 @@
},
}
- data, err := ComposePacket(p.MarshalPacket(id, nil))
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
if err != nil {
t.Fatal("unexpected error:", err)
}
@@ -171,14 +171,14 @@
0x87, 0x65, 0x43, 0x21,
}
- if !bytes.Equal(data, want) {
- t.Fatalf("Marshal() = %X, but wanted %X", data, want)
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
}
*p = MkdirPacket{}
// UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
- if err := p.UnmarshalPacketBody(NewBuffer(data[9:])); err != nil {
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
t.Fatal("unexpected error:", err)
}
@@ -207,7 +207,7 @@
Path: path,
}
- data, err := ComposePacket(p.MarshalPacket(id, nil))
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
if err != nil {
t.Fatal("unexpected error:", err)
}
@@ -219,14 +219,14 @@
0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o',
}
- if !bytes.Equal(data, want) {
- t.Fatalf("Marshal() = %X, but wanted %X", data, want)
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
}
*p = RmdirPacket{}
// UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
- if err := p.UnmarshalPacketBody(NewBuffer(data[9:])); err != nil {
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
t.Fatal("unexpected error:", err)
}
@@ -247,7 +247,7 @@
Path: path,
}
- data, err := ComposePacket(p.MarshalPacket(id, nil))
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
if err != nil {
t.Fatal("unexpected error:", err)
}
@@ -259,14 +259,14 @@
0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o',
}
- if !bytes.Equal(data, want) {
- t.Fatalf("Marshal() = %X, but wanted %X", data, want)
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
}
*p = RealPathPacket{}
// UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
- if err := p.UnmarshalPacketBody(NewBuffer(data[9:])); err != nil {
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
t.Fatal("unexpected error:", err)
}
@@ -287,7 +287,7 @@
Path: path,
}
- data, err := ComposePacket(p.MarshalPacket(id, nil))
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
if err != nil {
t.Fatal("unexpected error:", err)
}
@@ -299,14 +299,14 @@
0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o',
}
- if !bytes.Equal(data, want) {
- t.Fatalf("Marshal() = %X, but wanted %X", data, want)
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
}
*p = StatPacket{}
// UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
- if err := p.UnmarshalPacketBody(NewBuffer(data[9:])); err != nil {
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
t.Fatal("unexpected error:", err)
}
@@ -329,7 +329,7 @@
NewPath: newpath,
}
- data, err := ComposePacket(p.MarshalPacket(id, nil))
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
if err != nil {
t.Fatal("unexpected error:", err)
}
@@ -342,14 +342,14 @@
0x00, 0x00, 0x00, 4, '/', 'b', 'a', 'r',
}
- if !bytes.Equal(data, want) {
- t.Fatalf("Marshal() = %X, but wanted %X", data, want)
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
}
*p = RenamePacket{}
// UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
- if err := p.UnmarshalPacketBody(NewBuffer(data[9:])); err != nil {
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
t.Fatal("unexpected error:", err)
}
@@ -374,7 +374,7 @@
Path: path,
}
- data, err := ComposePacket(p.MarshalPacket(id, nil))
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
if err != nil {
t.Fatal("unexpected error:", err)
}
@@ -386,14 +386,14 @@
0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o',
}
- if !bytes.Equal(data, want) {
- t.Fatalf("Marshal() = %X, but wanted %X", data, want)
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
}
*p = ReadLinkPacket{}
// UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
- if err := p.UnmarshalPacketBody(NewBuffer(data[9:])); err != nil {
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
t.Fatal("unexpected error:", err)
}
@@ -416,7 +416,7 @@
TargetPath: targetpath,
}
- data, err := ComposePacket(p.MarshalPacket(id, nil))
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
if err != nil {
t.Fatal("unexpected error:", err)
}
@@ -429,14 +429,14 @@
0x00, 0x00, 0x00, 4, '/', 'f', 'o', 'o',
}
- if !bytes.Equal(data, want) {
- t.Fatalf("Marshal() = %X, but wanted %X", data, want)
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
}
*p = SymlinkPacket{}
// UnmarshalPacketBody assumes the (length, type, request-id) have already been consumed.
- if err := p.UnmarshalPacketBody(NewBuffer(data[9:])); err != nil {
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
t.Fatal("unexpected error:", err)
}
diff --git a/internal/encoding/ssh/filexfer/permissions.go b/internal/encoding/ssh/filexfer/permissions.go
index c5050ba..2fe63d5 100644
--- a/internal/encoding/ssh/filexfer/permissions.go
+++ b/internal/encoding/ssh/filexfer/permissions.go
@@ -86,5 +86,29 @@
}
}
+ if m&ModeSetUID != 0 {
+ if buf[3] == 'x' {
+ buf[3] = 's'
+ } else {
+ buf[3] = 'S'
+ }
+ }
+
+ if m&ModeSetGID != 0 {
+ if buf[6] == 'x' {
+ buf[6] = 's'
+ } else {
+ buf[6] = 'S'
+ }
+ }
+
+ if m&ModeSticky != 0 {
+ if buf[9] == 'x' {
+ buf[9] = 't'
+ } else {
+ buf[9] = 'T'
+ }
+ }
+
return string(buf[:])
}
diff --git a/internal/encoding/ssh/filexfer/response_packets.go b/internal/encoding/ssh/filexfer/response_packets.go
index 831b443..7a9b3ea 100644
--- a/internal/encoding/ssh/filexfer/response_packets.go
+++ b/internal/encoding/ssh/filexfer/response_packets.go
@@ -1,5 +1,9 @@
package filexfer
+import (
+ "fmt"
+)
+
// StatusPacket defines the SSH_FXP_STATUS packet.
//
// Specified in https://tools.ietf.org/html/draft-ietf-secsh-filexfer-02#section-7
@@ -9,6 +13,30 @@
LanguageTag string
}
+// Error implements error, reporting the StatusCode and, when non-empty, the quoted ErrorMessage.
+func (p *StatusPacket) Error() string {
+ if p.ErrorMessage == "" {
+ return "sftp: " + p.StatusCode.String()
+ }
+
+ return fmt.Sprintf("sftp: %q (%s)", p.ErrorMessage, p.StatusCode)
+}
+
+// Is reports whether target is a *StatusPacket with the same StatusCode, or is itself
+// equal to p.StatusCode (e.g. a constant such as StatusFailure); used by errors.Is.
+func (p *StatusPacket) Is(target error) bool {
+ if target, ok := target.(*StatusPacket); ok {
+ return p.StatusCode == target.StatusCode
+ }
+
+ return p.StatusCode == target
+}
+
+// Type returns PacketTypeStatus, identifying this as an SSH_FXP_STATUS packet.
+func (p *StatusPacket) Type() PacketType {
+ return PacketTypeStatus
+}
+
// MarshalPacket returns p as a two-part binary encoding of p.
func (p *StatusPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
buf := NewBuffer(b)
@@ -51,6 +79,11 @@
Handle string
}
+// Type returns PacketTypeHandle, identifying this as an SSH_FXP_HANDLE packet.
+func (p *HandlePacket) Type() PacketType {
+ return PacketTypeHandle
+}
+
// MarshalPacket returns p as a two-part binary encoding of p.
func (p *HandlePacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
buf := NewBuffer(b)
@@ -80,6 +113,11 @@
Data []byte
}
+// Type returns PacketTypeData, identifying this as an SSH_FXP_DATA packet.
+func (p *DataPacket) Type() PacketType {
+ return PacketTypeData
+}
+
// MarshalPacket returns p as a two-part binary encoding of p.
func (p *DataPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
buf := NewBuffer(b)
@@ -96,11 +134,26 @@
// UnmarshalPacketBody unmarshals the packet body from the given Buffer.
// It is assumed that the uint32(request-id) has already been consumed.
+//
+// If len(p.Data) is at least the length of the received data, the data is
+// copied into p.Data, which is then resliced to exactly that length.
+//
+// Otherwise a new slice of exactly the required length is allocated and filled.
+// Only len(p.Data) is checked, not cap(p.Data), so spare capacity is not reused.
+//
+// Either way the data is copied out of buf rather than aliasing buf's memory.
func (p *DataPacket) UnmarshalPacketBody(buf *Buffer) (err error) {
- if p.Data, err = buf.ConsumeByteSlice(); err != nil {
+ data, err := buf.ConsumeByteSlice()
+ if err != nil {
return err
}
+ if len(p.Data) < len(data) {
+ p.Data = make([]byte, len(data))
+ }
+
+ n := copy(p.Data, data)
+ p.Data = p.Data[:n]
return nil
}
@@ -109,6 +162,11 @@
Entries []*NameEntry
}
+// Type returns PacketTypeName, identifying this as an SSH_FXP_NAME packet.
+func (p *NamePacket) Type() PacketType {
+ return PacketTypeName
+}
+
// MarshalPacket returns p as a two-part binary encoding of p.
func (p *NamePacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
buf := NewBuffer(b)
@@ -159,6 +217,11 @@
Attrs Attributes
}
+// Type returns PacketTypeAttrs, identifying this as an SSH_FXP_ATTRS packet.
+func (p *AttrsPacket) Type() PacketType {
+ return PacketTypeAttrs
+}
+
// MarshalPacket returns p as a two-part binary encoding of p.
func (p *AttrsPacket) MarshalPacket(reqid uint32, b []byte) (header, payload []byte, err error) {
buf := NewBuffer(b)
diff --git a/internal/encoding/ssh/filexfer/response_packets_test.go b/internal/encoding/ssh/filexfer/response_packets_test.go
index 23e42bb..9468665 100644
--- a/internal/encoding/ssh/filexfer/response_packets_test.go
+++ b/internal/encoding/ssh/filexfer/response_packets_test.go
@@ -2,9 +2,31 @@
import (
"bytes"
+ "errors"
"testing"
)
+func TestStatusPacketIs(t *testing.T) {
+ status := &StatusPacket{
+ StatusCode: StatusFailure,
+ ErrorMessage: "error message",
+ LanguageTag: "language tag",
+ }
+ // status must match its own code in both forms, and must not match any other code.
+ if !errors.Is(status, StatusFailure) {
+ t.Error("errors.Is(StatusFailure, StatusFailure) != true")
+ }
+ if !errors.Is(status, &StatusPacket{StatusCode: StatusFailure}) {
+ t.Error("errors.Is(StatusFailure, StatusPacket{StatusFailure}) != true")
+ }
+ if errors.Is(status, StatusOK) {
+ t.Error("errors.Is(StatusFailure, StatusOK) == true")
+ }
+ if errors.Is(status, &StatusPacket{StatusCode: StatusOK}) {
+ t.Error("errors.Is(StatusFailure, StatusPacket{StatusOK}) == true")
+ }
+}
+
var _ Packet = &StatusPacket{}
func TestStatusPacket(t *testing.T) {
@@ -21,7 +43,7 @@
LanguageTag: languageTag,
}
- data, err := ComposePacket(p.MarshalPacket(id, nil))
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
if err != nil {
t.Fatal("unexpected error:", err)
}
@@ -35,14 +57,14 @@
0x00, 0x00, 0x00, 9, 'x', '-', 'e', 'x', 'a', 'm', 'p', 'l', 'e',
}
- if !bytes.Equal(data, want) {
- t.Fatalf("Marshal() = %X, but wanted %X", data, want)
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
}
*p = StatusPacket{}
// UnmarshalBinary assumes the uint32(length) + uint8(type) have already been consumed.
- if err := p.UnmarshalPacketBody(NewBuffer(data[9:])); err != nil {
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
t.Fatal("unexpected error:", err)
}
@@ -71,7 +93,7 @@
Handle: "somehandle",
}
- data, err := ComposePacket(p.MarshalPacket(id, nil))
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
if err != nil {
t.Fatal("unexpected error:", err)
}
@@ -83,14 +105,14 @@
0x00, 0x00, 0x00, 10, 's', 'o', 'm', 'e', 'h', 'a', 'n', 'd', 'l', 'e',
}
- if !bytes.Equal(data, want) {
- t.Fatalf("Marshal() = %X, but wanted %X", data, want)
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
}
*p = HandlePacket{}
// UnmarshalBinary assumes the uint32(length) + uint8(type) have already been consumed.
- if err := p.UnmarshalPacketBody(NewBuffer(data[9:])); err != nil {
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
t.Fatal("unexpected error:", err)
}
@@ -112,7 +134,7 @@
Data: payload,
}
- data, err := ComposePacket(p.MarshalPacket(id, nil))
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
if err != nil {
t.Fatal("unexpected error:", err)
}
@@ -124,14 +146,14 @@
0x00, 0x00, 0x00, 6, 'f', 'o', 'o', 'b', 'a', 'r',
}
- if !bytes.Equal(data, want) {
- t.Fatalf("Marshal() = %X, but wanted %X", data, want)
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
}
*p = DataPacket{}
// UnmarshalBinary assumes the uint32(length) + uint8(type) have already been consumed.
- if err := p.UnmarshalPacketBody(NewBuffer(data[9:])); err != nil {
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
t.Fatal("unexpected error:", err)
}
@@ -171,7 +193,7 @@
},
}
- data, err := ComposePacket(p.MarshalPacket(id, nil))
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
if err != nil {
t.Fatal("unexpected error:", err)
}
@@ -191,14 +213,14 @@
0x87, 0x65, 0x43, 0x02,
}
- if !bytes.Equal(data, want) {
- t.Fatalf("Marshal() = %X, but wanted %X", data, want)
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
}
*p = NamePacket{}
// UnmarshalBinary assumes the uint32(length) + uint8(type) have already been consumed.
- if err := p.UnmarshalPacketBody(NewBuffer(data[9:])); err != nil {
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
t.Fatal("unexpected error:", err)
}
@@ -240,7 +262,7 @@
},
}
- data, err := ComposePacket(p.MarshalPacket(id, nil))
+ buf, err := ComposePacket(p.MarshalPacket(id, nil))
if err != nil {
t.Fatal("unexpected error:", err)
}
@@ -253,14 +275,14 @@
0x87, 0x65, 0x43, 0x21,
}
- if !bytes.Equal(data, want) {
- t.Fatalf("Marshal() = %X, but wanted %X", data, want)
+ if !bytes.Equal(buf, want) {
+ t.Fatalf("MarshalPacket() = %X, but wanted %X", buf, want)
}
*p = AttrsPacket{}
// UnmarshalBinary assumes the uint32(length) + uint8(type) have already been consumed.
- if err := p.UnmarshalPacketBody(NewBuffer(data[9:])); err != nil {
+ if err := p.UnmarshalPacketBody(NewBuffer(buf[9:])); err != nil {
t.Fatal("unexpected error:", err)
}
diff --git a/pool.go b/pool.go
index c074637..563f64b 100644
--- a/pool.go
+++ b/pool.go
@@ -1,5 +1,7 @@
package sftp
+// bufPool provides a pool of byte-slices to be reused in various parts of the package.
+// It is safe to use concurrently through a pointer.
type bufPool struct {
ch chan []byte
blen int
@@ -13,17 +15,66 @@
}
func (p *bufPool) Get() []byte {
- select {
- case b := <-p.ch:
- return b
- default:
- return make([]byte, p.blen)
+ if p == nil {
+ // Unlike Put, Get has no nil default: a nil pool has no blen to allocate with.
+ panic("sftp: bufPool.Get called on nil *bufPool")
}
+ // Reuse a pooled buffer when one is available; otherwise allocate a fresh one.
+ for {
+ select {
+ case b := <-p.ch:
+ if cap(b) < p.blen {
+ // just in case: throw away any buffer with insufficient capacity.
+ continue
+ }
+
+ return b[:p.blen]
+
+ default:
+ return make([]byte, p.blen)
+ }
+ }
}
func (p *bufPool) Put(b []byte) {
+ if p == nil {
+ // functional default: no reuse.
+ return
+ }
+
+ if cap(b) < p.blen || cap(b) > p.blen*2 {
+ // DO NOT reuse buffers with insufficient capacity.
+ // This could cause panics when resizing to p.blen.
+
+ // DO NOT reuse buffers with excessive capacity.
+ // This could cause memory leaks.
+ return
+ }
+ // Non-blocking send: when the pool is already full, b is simply dropped.
select {
case p.ch <- b:
default:
}
}
+
+// resChanPool is a bounded free list of reusable chan result values.
+// The pool itself is a buffered channel whose capacity bounds how many idle channels it keeps.
+type resChanPool chan chan result
+
+// newResChanPool returns a resChanPool that retains at most depth idle channels.
+func newResChanPool(depth int) resChanPool {
+ return make(chan chan result, depth)
+}
+
+// Get returns an idle pooled channel if one is available,
+// otherwise a new chan result with a buffer of one.
+func (p resChanPool) Get() chan result {
+ select {
+ case ch := <-p:
+ return ch
+ default:
+ return make(chan result, 1)
+ }
+}
+
+// Put offers ch back to the pool without blocking; if the pool is full, ch is dropped.
+// NOTE(review): ch is not drained here; callers presumably only return empty channels — confirm.
+func (p resChanPool) Put(ch chan result) {
+ select {
+ case p <- ch:
+ default:
+ }
+}
diff --git a/stat_plan9.go b/stat_plan9.go
index 25074fe..418f121 100644
--- a/stat_plan9.go
+++ b/stat_plan9.go
@@ -41,6 +41,11 @@
return 0, false
}
+// isRegular reports whether the file-type bits of mode (mode&S_IFMT) equal S_IFREG, i.e. a regular file.
+func isRegular(mode uint32) bool {
+ return mode&S_IFMT == syscall.S_IFREG
+}
+
// toFileMode converts sftp filemode bits to the os.FileMode specification
func toFileMode(mode uint32) os.FileMode {
var fm = os.FileMode(mode & 0777)
diff --git a/stat_posix.go b/stat_posix.go
index 71080ac..98b60e7 100644
--- a/stat_posix.go
+++ b/stat_posix.go
@@ -43,6 +43,11 @@
return 0, false
}
+// isRegular reports whether the file-type bits of mode (mode&S_IFMT) equal S_IFREG, i.e. a regular file.
+func isRegular(mode uint32) bool {
+ return mode&S_IFMT == syscall.S_IFREG
+}
+
// toFileMode converts sftp filemode bits to the os.FileMode specification
func toFileMode(mode uint32) os.FileMode {
var fm = os.FileMode(mode & 0777)