code that manages incoming/outgoing packet order
Makes sure that outgoing packets order matches incoming packets order.
This is not required by spec but some clients seem to require it (eg.
winscp).
diff --git a/packet-manager.go b/packet-manager.go
new file mode 100644
index 0000000..385ee99
--- /dev/null
+++ b/packet-manager.go
@@ -0,0 +1,124 @@
+package sftp
+
+import (
+ "encoding"
+ "sort"
+)
+
+// --------------------------------------------------------------------
+// Process with 2 branch select, listening to each channel.
+// 0) start of loop
+
+// Branch A
+// 1) Wait for ids to come in and add them to id list.
+
+// Branch B
+// 1) Wait for a packet comes in.
+// 2) Add it to the packet list.
+// 3) The heads of each list are then compared and if they have the same ids
+// the packet is sent out and the entries removed.
+// 4) Goto step 2 Until the lists are emptied or the ids don't match.
+// 5) Goto step 0.
+// --------------------------------------------------------------------
+
+type packetSender interface {
+ sendPacket(encoding.BinaryMarshaler) error
+}
+
+type packetManager struct {
+ requests chan requestPacket
+ responses chan responsePacket
+ fini chan struct{}
+ incoming []uint32
+ outgoing []responsePacket
+ sender packetSender // connection object
+}
+
+func newPktMgr(sender packetSender) packetManager {
+ s := packetManager{
+ requests: make(chan requestPacket, sftpServerWorkerCount),
+ responses: make(chan responsePacket, sftpServerWorkerCount),
+ fini: make(chan struct{}),
+ incoming: make([]uint32, 0, sftpServerWorkerCount),
+ outgoing: make([]responsePacket, 0, sftpServerWorkerCount),
+ sender: sender,
+ }
+ go s.worker()
+ return s
+}
+
+// register incoming packets to be handled
+// send id of 0 for packets without id
+func (s packetManager) incomingPacket(pkt requestPacket) {
+ s.requests <- pkt // buffer == sftpServerWorkerCount
+}
+
+// register outgoing packets as being ready
+func (s packetManager) readyPacket(pkt responsePacket) {
+ s.responses <- pkt
+}
+
+// shut down packetManager worker
+func (s packetManager) close() {
+ close(s.fini)
+}
+
+// process packets
+func (s *packetManager) worker() {
+ for {
+ select {
+ case pkt := <-s.requests:
+ debug("incoming id: %v", pkt.id())
+ s.incoming = append(s.incoming, pkt.id())
+ if len(s.incoming) > 1 {
+ sort.Slice(s.incoming, func(i, j int) bool {
+ return s.incoming[i] < s.incoming[j]
+ })
+ }
+ case pkt := <-s.responses:
+ debug("outgoing pkt: %v", pkt.id())
+ s.outgoing = append(s.outgoing, pkt)
+ if len(s.outgoing) > 1 {
+ sort.Slice(s.outgoing, func(i, j int) bool {
+ return s.outgoing[i].id() < s.outgoing[j].id()
+ })
+ }
+ case <-s.fini:
+ return
+ }
+ s.maybeSendPackets()
+ }
+}
+
+// send as many packets as are ready
+func (s *packetManager) maybeSendPackets() {
+ for {
+ if len(s.outgoing) == 0 || len(s.incoming) == 0 {
+ debug("break! -- outgoing: %v; incoming: %v",
+ len(s.outgoing), len(s.incoming))
+ break
+ }
+ out := s.outgoing[0]
+ in := s.incoming[0]
+ debug("incoming: %v", s.incoming)
+ debug("outgoing: %v", outfilter(s.outgoing))
+ if in == out.id() {
+ s.sender.sendPacket(out)
+ // pop off heads
+ copy(s.incoming, s.incoming[1:]) // shift left
+ s.incoming = s.incoming[:len(s.incoming)-1] // remove last
+ copy(s.outgoing, s.outgoing[1:]) // shift left
+ s.outgoing = s.outgoing[:len(s.outgoing)-1] // remove last
+ } else {
+ break
+ }
+ }
+}
+
+func outfilter(o []responsePacket) []uint32 {
+ res := make([]uint32, 0, len(o))
+ for _, v := range o {
+ res = append(res, v.id())
+ }
+ return res
+}
diff --git a/packet-manager_test.go b/packet-manager_test.go
new file mode 100644
index 0000000..e8dc9b9
--- /dev/null
+++ b/packet-manager_test.go
@@ -0,0 +1,81 @@
+package sftp
+
+import (
+ "encoding"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+type _sender struct {
+ sent chan encoding.BinaryMarshaler
+}
+
+func newsender() *_sender {
+ return &_sender{make(chan encoding.BinaryMarshaler)}
+}
+
+func (s _sender) sendPacket(p encoding.BinaryMarshaler) error {
+ s.sent <- p
+ return nil
+}
+
+type fakepacket uint32
+
+func (fakepacket) MarshalBinary() ([]byte, error) {
+ return []byte{}, nil
+}
+
+func (fakepacket) UnmarshalBinary([]byte) error {
+ return nil
+}
+
+func (f fakepacket) id() uint32 {
+ return uint32(f)
+}
+
+type pair struct {
+ in fakepacket
+ out fakepacket
+}
+
+var ttable1 = []pair{
+ pair{fakepacket(0), fakepacket(0)},
+ pair{fakepacket(1), fakepacket(1)},
+ pair{fakepacket(2), fakepacket(2)},
+ pair{fakepacket(3), fakepacket(3)},
+}
+
+var ttable2 = []pair{
+ pair{fakepacket(0), fakepacket(0)},
+ pair{fakepacket(1), fakepacket(4)},
+ pair{fakepacket(2), fakepacket(1)},
+ pair{fakepacket(3), fakepacket(3)},
+ pair{fakepacket(4), fakepacket(2)},
+}
+
+var tables = [][]pair{ttable1, ttable2}
+
+func TestPacketManager(t *testing.T) {
+ sender := newsender()
+ s := newPktMgr(sender)
+ // go func() {
+ // for _ = range s.workers {
+ // }
+ // }()
+ for i := range tables {
+ table := tables[i]
+ for _, p := range table {
+ s.incomingPacket(p.in)
+ }
+ for _, p := range table {
+ s.readyPacket(p.out)
+ }
+ for _, p := range table {
+ pkt := <-sender.sent
+ id := pkt.(fakepacket).id()
+ assert.Equal(t, id, p.in.id())
+ }
+ }
+ s.close()
+}