Merge pull request #263 from pkg/packet-embedded-in-ordered

ensure packet responses are sent in the same order as their requests

Fixes #260
diff --git a/packet-manager.go b/packet-manager.go
index 2c1c656..c53ed56 100644
--- a/packet-manager.go
+++ b/packet-manager.go
@@ -7,30 +7,33 @@
 )
 
 // The goal of the packetManager is to keep the outgoing packets in the same
-// order as the incoming. This is due to some sftp clients requiring this
-// behavior (eg. winscp).
+// order as the incoming, as is required by section 7 of the RFC.
+//
+// Each request gets an increasing order id (newOrderedRequest). A response is
+// only sent once its request has the lowest order id of all registered
+// requests still awaiting a response; request ids play no part in ordering.
+type packetManager struct {
+	requests    chan orderedPacket // requests registered via incomingPacket
+	responses   chan orderedPacket // responses handed in via readyPacket
+	fini        chan struct{}      // signals the ordering loop to return
+	incoming    orderedPackets     // requests whose response is unsent
+	outgoing    orderedPackets     // responses not yet sent
+	sender      packetSender       // connection object
+	working     *sync.WaitGroup    // requests still waiting for a response
+	packetCount uint32             // last order id returned by newOrderId
+}
 
 type packetSender interface {
 	sendPacket(encoding.BinaryMarshaler) error
 }
 
-type packetManager struct {
-	requests  chan requestPacket
-	responses chan responsePacket
-	fini      chan struct{}
-	incoming  requestPacketIDs
-	outgoing  responsePackets
-	sender    packetSender // connection object
-	working   *sync.WaitGroup
-}
-
 func newPktMgr(sender packetSender) *packetManager {
 	s := &packetManager{
-		requests:  make(chan requestPacket, SftpServerWorkerCount),
-		responses: make(chan responsePacket, SftpServerWorkerCount),
+		requests:  make(chan orderedPacket, SftpServerWorkerCount),
+		responses: make(chan orderedPacket, SftpServerWorkerCount),
 		fini:      make(chan struct{}),
-		incoming:  make([]uint32, 0, SftpServerWorkerCount),
-		outgoing:  make([]responsePacket, 0, SftpServerWorkerCount),
+		incoming:  make([]orderedPacket, 0, SftpServerWorkerCount),
+		outgoing:  make([]orderedPacket, 0, SftpServerWorkerCount),
 		sender:    sender,
 		working:   &sync.WaitGroup{},
 	}
@@ -38,31 +38,76 @@
 	return s
 }
 
-type responsePackets []responsePacket
+//// packet ordering
+
+// newOrderId returns the next order id: 1 on the first call, then one more
+// on each later call. packetCount is not synchronized, so this must only be
+// called from a single goroutine (each server's packet-reading loop).
+// NOTE(review): the uint32 counter wraps after 2^32 calls, after which the
+// orderId-based Sort below would misorder the packets in flight.
+func (s *packetManager) newOrderId() uint32 {
+	s.packetCount++
+	return s.packetCount
+}
 
-func (r responsePackets) Sort() {
-	sort.Slice(r, func(i, j int) bool {
-		return r[i].id() < r[j].id()
+// orderedRequest is a request tagged with its order id.
+type orderedRequest struct {
+	requestPacket
+	orderid uint32
+}
+
+// newOrderedRequest tags p with the next order id.
+func (s *packetManager) newOrderedRequest(p requestPacket) orderedRequest {
+	return orderedRequest{requestPacket: p, orderid: s.newOrderId()}
+}
+
+// orderId and setOrderId read and write the order id. setOrderId needs a
+// pointer receiver: with a value receiver it only modified a copy.
+func (p orderedRequest) orderId() uint32        { return p.orderid }
+func (p *orderedRequest) setOrderId(oid uint32) { p.orderid = oid }
+
+// orderedResponse is a response tagged with the order id of its request.
+type orderedResponse struct {
+	responsePacket
+	orderid uint32
+}
+
+// newOrderedResponse tags p with id, the order id of the request p answers.
+func (s *packetManager) newOrderedResponse(p responsePacket, id uint32,
+) orderedResponse {
+	return orderedResponse{responsePacket: p, orderid: id}
+}
+
+// orderId and setOrderId read and write the order id. setOrderId needs a
+// pointer receiver so that the write is not lost on a copy.
+func (p orderedResponse) orderId() uint32        { return p.orderid }
+func (p *orderedResponse) setOrderId(oid uint32) { p.orderid = oid }
+
+// orderedPacket is implemented by orderedRequest and orderedResponse.
+type orderedPacket interface {
+	id() uint32
+	orderId() uint32
+}
+
+// orderedPackets is a list of packets ordered by Sort.
+type orderedPackets []orderedPacket
+
+// Sort orders the packets by ascending order id; request ids are ignored.
+func (o orderedPackets) Sort() {
+	sort.Slice(o, func(i, j int) bool {
+		return o[i].orderId() < o[j].orderId()
 	})
 }
 
-type requestPacketIDs []uint32
-
-func (r requestPacketIDs) Sort() {
-	sort.Slice(r, func(i, j int) bool {
-		return r[i] < r[j]
-	})
-}
-
+//// packet registry
 // register incoming packets to be handled
-// send id of 0 for packets without id
-func (s *packetManager) incomingPacket(pkt requestPacket) {
+func (s *packetManager) incomingPacket(pkt orderedRequest) {
 	s.working.Add(1)
-	s.requests <- pkt // buffer == SftpServerWorkerCount
+	s.requests <- pkt
 }
 
 // register outgoing packets as being ready
-func (s *packetManager) readyPacket(pkt responsePacket) {
+func (s *packetManager) readyPacket(pkt orderedResponse) {
 	s.responses <- pkt
 	s.working.Done()
 }
@@ -75,27 +100,26 @@
 }
 
 // Passed a worker function, returns a channel for incoming packets.
-// The goal is to process packets in the order they are received as is
-// requires by section 7 of the RFC, while maximizing throughput of file
-// transfers.
-func (s *packetManager) workerChan(runWorker func(chan requestPacket),
-) chan requestPacket {
+// Requests are spread over several workers to maximize file transfer
+// throughput; the packetManager still sends the responses in request order.
+func (s *packetManager) workerChan(runWorker func(chan orderedRequest),
+) chan orderedRequest {
 
-	rwChan := make(chan requestPacket, SftpServerWorkerCount)
+	rwChan := make(chan orderedRequest, SftpServerWorkerCount)
 	for i := 0; i < SftpServerWorkerCount; i++ {
 		runWorker(rwChan)
 	}
 
-	cmdChan := make(chan requestPacket)
+	cmdChan := make(chan orderedRequest)
 	runWorker(cmdChan)
 
-	pktChan := make(chan requestPacket, SftpServerWorkerCount)
+	pktChan := make(chan orderedRequest, SftpServerWorkerCount)
 	go func() {
 		// start with cmdChan
 		curChan := cmdChan
 		for pkt := range pktChan {
 			// on file open packet, switch to rwChan
-			switch pkt.(type) {
+			switch pkt.requestPacket.(type) {
 			case *sshFxpOpenPacket:
 				curChan = rwChan
 			// on file close packet, switch back to cmdChan
@@ -122,17 +146,14 @@
 	for {
 		select {
 		case pkt := <-s.requests:
-			debug("incoming id: %v", pkt.id())
-			s.incoming = append(s.incoming, pkt.id())
-			if len(s.incoming) > 1 {
-				s.incoming.Sort()
-			}
+			debug("incoming id (oid): %v (%v)", pkt.id(), pkt.orderId())
+			s.incoming = append(s.incoming, pkt)
+			// both lists stay sorted so their heads hold the lowest order ids
+			s.incoming.Sort()
 		case pkt := <-s.responses:
-			debug("outgoing pkt: %v", pkt.id())
+			debug("outgoing id (oid): %v (%v)", pkt.id(), pkt.orderId())
 			s.outgoing = append(s.outgoing, pkt)
-			if len(s.outgoing) > 1 {
-				s.outgoing.Sort()
-			}
+			s.outgoing.Sort()
 		case <-s.fini:
 			return
 		}
@@ -150,10 +170,15 @@
 		}
 		out := s.outgoing[0]
 		in := s.incoming[0]
-		// 		debug("incoming: %v", s.incoming)
-		// 		debug("outgoing: %v", outfilter(s.outgoing))
-		if in == out.id() {
-			s.sender.sendPacket(out)
+		// debug("incoming: %v", ids(s.incoming))
+		// debug("outgoing: %v", ids(s.outgoing))
+		if in.orderId() == out.orderId() {
+			debug("Sending packet: %v", out.id())
+			// Log a failed send instead of dropping the error silently; the
+			// heads are still popped below so later responses keep flowing.
+			if err := s.sender.sendPacket(out.(encoding.BinaryMarshaler)); err != nil {
+				debug("Error sending packet: %v", err)
+			}
 			// pop off heads
 			copy(s.incoming, s.incoming[1:])            // shift left
 			s.incoming = s.incoming[:len(s.incoming)-1] // remove last
@@ -165,10 +186,18 @@
 	}
 }
 
-//func outfilter(o []responsePacket) []uint32 {
-//	res := make([]uint32, 0, len(o))
-//	for _, v := range o {
-//		res = append(res, v.id())
-//	}
-//	return res
-//}
+// Debug helpers listing the order ids / request ids of a packet list.
+// func oids(o []orderedPacket) []uint32 {
+// 	res := make([]uint32, 0, len(o))
+// 	for _, v := range o {
+// 		res = append(res, v.orderId())
+// 	}
+// 	return res
+// }
+// func ids(o []orderedPacket) []uint32 {
+// 	res := make([]uint32, 0, len(o))
+// 	for _, v := range o {
+// 		res = append(res, v.id())
+// 	}
+// 	return res
+// }
diff --git a/packet-manager_test.go b/packet-manager_test.go
index 55f41c0..552d40e 100644
--- a/packet-manager_test.go
+++ b/packet-manager_test.go
@@ -21,7 +21,17 @@
 	return nil
 }
 
-type fakepacket uint32
+// fakepacket is a stand-in request/response: reqid is its request id and
+// oid the order id the test wraps it with.
+type fakepacket struct {
+	reqid uint32
+	oid   uint32
+}
+
+// fake returns a fakepacket with request id rid and order id order.
+func fake(rid, order uint32) fakepacket {
+	return fakepacket{reqid: rid, oid: order}
+}
 
 func (fakepacket) MarshalBinary() ([]byte, error) {
 	return []byte{}, nil
@@ -32,40 +39,54 @@
 }
 
 func (f fakepacket) id() uint32 {
-	return uint32(f)
+	return f.reqid
 }
 
+// pair holds the i-th request registered (in) and the i-th response readied
+// (out); out does not have to be the response to in.
 type pair struct {
-	in  fakepacket
-	out fakepacket
+	in, out fakepacket
+}
+
+// ordered_pair is a pair wrapped in the packet manager's ordered types.
+type ordered_pair struct {
+	in  orderedRequest
+	out orderedResponse
 }
 
 // basic test
 var ttable1 = []pair{
-	pair{fakepacket(0), fakepacket(0)},
-	pair{fakepacket(1), fakepacket(1)},
-	pair{fakepacket(2), fakepacket(2)},
-	pair{fakepacket(3), fakepacket(3)},
+	pair{fake(0, 0), fake(0, 0)},
+	pair{fake(1, 1), fake(1, 1)},
+	pair{fake(2, 2), fake(2, 2)},
+	pair{fake(3, 3), fake(3, 3)},
 }
 
 // outgoing packets out of order
 var ttable2 = []pair{
-	pair{fakepacket(0), fakepacket(0)},
-	pair{fakepacket(1), fakepacket(4)},
-	pair{fakepacket(2), fakepacket(1)},
-	pair{fakepacket(3), fakepacket(3)},
-	pair{fakepacket(4), fakepacket(2)},
+	pair{fake(10, 0), fake(12, 2)},
+	pair{fake(11, 1), fake(11, 1)},
+	pair{fake(12, 2), fake(13, 3)},
+	pair{fake(13, 3), fake(10, 0)},
 }
 
-// incoming packets out of order
+// request ids are not incremental
 var ttable3 = []pair{
-	pair{fakepacket(2), fakepacket(0)},
-	pair{fakepacket(1), fakepacket(1)},
-	pair{fakepacket(3), fakepacket(2)},
-	pair{fakepacket(0), fakepacket(3)},
+	pair{fake(7, 0), fake(7, 0)},
+	pair{fake(1, 1), fake(1, 1)},
+	pair{fake(9, 2), fake(3, 3)},
+	pair{fake(3, 3), fake(9, 2)},
 }
 
-var tables = [][]pair{ttable1, ttable2, ttable3}
+// request ids are all the same
+var ttable4 = []pair{
+	pair{fake(1, 0), fake(1, 0)},
+	pair{fake(1, 1), fake(1, 1)},
+	pair{fake(1, 2), fake(1, 3)},
+	pair{fake(1, 3), fake(1, 2)},
+}
+
+var tables = [][]pair{ttable1, ttable2, ttable3, ttable4}
 
 func TestPacketManager(t *testing.T) {
 	sender := newTestSender()
@@ -73,30 +91,38 @@
 
 	for i := range tables {
 		table := tables[i]
+		orderedPairs := make([]ordered_pair, 0, len(table))
 		for _, p := range table {
+			orderedPairs = append(orderedPairs, ordered_pair{
+				in:  orderedRequest{p.in, p.in.oid},
+				out: orderedResponse{p.out, p.out.oid},
+			})
+		}
+		for _, p := range orderedPairs {
 			s.incomingPacket(p.in)
 		}
-		for _, p := range table {
+		for _, p := range orderedPairs {
 			s.readyPacket(p.out)
 		}
-		for i := 0; i < len(table); i++ {
+		// responses must be sent in the order their requests were registered
+		for _, p := range table {
 			pkt := <-sender.sent
-			id := pkt.(fakepacket).id()
-			assert.Equal(t, id, uint32(i))
+			id := pkt.(orderedResponse).id()
+			assert.Equal(t, p.in.id(), id)
 		}
 	}
 	s.close()
 }
 
 func (p sshFxpRemovePacket) String() string {
-	return fmt.Sprintf("RmPct:%d", p.ID)
+	return fmt.Sprintf("RmPkt:%d", p.ID)
 }
 func (p sshFxpOpenPacket) String() string {
-	return fmt.Sprintf("OpPct:%d", p.ID)
+	return fmt.Sprintf("OpPkt:%d", p.ID)
 }
 func (p sshFxpWritePacket) String() string {
-	return fmt.Sprintf("WrPct:%d", p.ID)
+	return fmt.Sprintf("WrPkt:%d", p.ID)
 }
 func (p sshFxpClosePacket) String() string {
-	return fmt.Sprintf("ClPct:%d", p.ID)
+	return fmt.Sprintf("ClPkt:%d", p.ID)
 }
diff --git a/request-server.go b/request-server.go
index e589672..74ba46c 100644
--- a/request-server.go
+++ b/request-server.go
@@ -105,7 +105,7 @@
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 	var wg sync.WaitGroup
-	runWorker := func(ch chan requestPacket) {
+	runWorker := func(ch chan orderedRequest) {
 		wg.Add(1)
 		go func() {
 			defer wg.Done()
@@ -142,7 +142,7 @@
 			}
 		}
 
-		pktChan <- pkt
+		pktChan <- rs.pktMgr.newOrderedRequest(pkt)
 	}
 
 	close(pktChan) // shuts down sftpServerWorkers
@@ -159,11 +159,11 @@
 }
 
 func (rs *RequestServer) packetWorker(
-	ctx context.Context, pktChan chan requestPacket,
+	ctx context.Context, pktChan chan orderedRequest,
 ) error {
 	for pkt := range pktChan {
 		var rpkt responsePacket
-		switch pkt := pkt.(type) {
+		switch pkt := pkt.requestPacket.(type) {
 		case *sshFxInitPacket:
 			rpkt = sshFxVersionPacket{Version: sftpProtocolVersion}
 		case *sshFxpClosePacket:
@@ -208,7 +208,9 @@
 			return errors.Errorf("unexpected packet type %T", pkt)
 		}
 
-		rs.sendPacket(rpkt)
+		// the packet manager sends rpkt after all earlier responses
+		rs.pktMgr.readyPacket(
+			rs.pktMgr.newOrderedResponse(rpkt, pkt.orderId()))
 	}
 	return nil
 }
@@ -240,8 +241,3 @@
 	}
 	return path.Clean(p)
 }
-
-// Wrap underlying connection methods to use packetManager
-func (rs *RequestServer) sendPacket(pkt responsePacket) {
-	rs.pktMgr.readyPacket(pkt)
-}
diff --git a/server.go b/server.go
index a5dd2cd..0fac1b6 100644
--- a/server.go
+++ b/server.go
@@ -123,12 +123,12 @@
 }
 
 // Up to N parallel servers
-func (svr *Server) sftpServerWorker(pktChan chan requestPacket) error {
+// Each worker handles the requests received on pktChan until it is closed.
+func (svr *Server) sftpServerWorker(pktChan chan orderedRequest) error {
 	for pkt := range pktChan {
-
 		// readonly checks
 		readonly := true
-		switch pkt := pkt.(type) {
+		switch pkt := pkt.requestPacket.(type) {
 		case notReadOnly:
 			readonly = false
 		case *sshFxpOpenPacket:
@@ -140,7 +139,11 @@
 		// If server is operating read-only and a write operation is requested,
 		// return permission denied
 		if !readonly && svr.readOnly {
-			svr.sendPacket(statusFromError(pkt, syscall.EPERM))
+			// Hand the reply to the packet manager like every other response:
+			// svr.sendPacket writes straight to the connection, and the manager
+			// would then wait forever for this request's response.
+			svr.pktMgr.readyPacket(svr.pktMgr.newOrderedResponse(
+				statusFromError(pkt, syscall.EPERM), pkt.orderId()))
 			continue
 		}
 
@@ -151,9 +152,9 @@
 	return nil
 }
 
-func handlePacket(s *Server, p requestPacket) error {
+func handlePacket(s *Server, p orderedRequest) error {
 	var rpkt responsePacket
-	switch p := p.(type) {
+	switch p := p.requestPacket.(type) {
 	case *sshFxInitPacket:
 		rpkt = sshFxVersionPacket{Version: sftpProtocolVersion}
 	case *sshFxpStatPacket:
@@ -177,7 +178,6 @@
 			rpkt = statusFromError(p, err)
 		}
 	case *sshFxpFstatPacket:
-		fmt.Println("fstat")
 		f, ok := s.getHandle(p.Handle)
 		var err error = syscall.EBADF
 		var info os.FileInfo
@@ -282,7 +282,8 @@
 		return errors.Errorf("unexpected packet type %T", p)
 	}
 
-	s.sendPacket(rpkt)
+	// the packet manager sends rpkt after all earlier responses
+	s.pktMgr.readyPacket(s.pktMgr.newOrderedResponse(rpkt, p.orderId()))
 	return nil
 }
 
@@ -290,7 +290,7 @@
 // is stopped.
 func (svr *Server) Serve() error {
 	var wg sync.WaitGroup
-	runWorker := func(ch chan requestPacket) {
+	runWorker := func(ch chan orderedRequest) {
 		wg.Add(1)
 		go func() {
 			defer wg.Done()
@@ -327,7 +327,7 @@
 			}
 		}
 
-		pktChan <- pkt
+		pktChan <- svr.pktMgr.newOrderedRequest(pkt)
 	}
 
 	close(pktChan) // shuts down sftpServerWorkers
@@ -341,12 +341,6 @@
 	return err // error from recvPacket
 }
 
-// Wrap underlying connection methods to use packetManager
-func (svr *Server) sendPacket(pkt responsePacket) error {
-	svr.pktMgr.readyPacket(pkt)
-	return nil
-}
-
 type ider interface {
 	id() uint32
 }