Merge pull request #263 from pkg/packet-embedded-in-ordered
Ensure packet responses are sent in the same order as their requests
Fixes #260
diff --git a/packet-manager.go b/packet-manager.go
index 2c1c656..c53ed56 100644
--- a/packet-manager.go
+++ b/packet-manager.go
@@ -7,30 +7,30 @@
)
// The goal of the packetManager is to keep the outgoing packets in the same
-// order as the incoming. This is due to some sftp clients requiring this
-// behavior (eg. winscp).
+// order as the incoming, as is required by section 7 of the RFC.
+
+type packetManager struct {
+ requests chan orderedPacket
+ responses chan orderedPacket
+ fini chan struct{}
+ incoming orderedPackets
+ outgoing orderedPackets
+ sender packetSender // connection object
+ working *sync.WaitGroup
+ packetCount uint32
+}
type packetSender interface {
sendPacket(encoding.BinaryMarshaler) error
}
-type packetManager struct {
- requests chan requestPacket
- responses chan responsePacket
- fini chan struct{}
- incoming requestPacketIDs
- outgoing responsePackets
- sender packetSender // connection object
- working *sync.WaitGroup
-}
-
func newPktMgr(sender packetSender) *packetManager {
s := &packetManager{
- requests: make(chan requestPacket, SftpServerWorkerCount),
- responses: make(chan responsePacket, SftpServerWorkerCount),
+ requests: make(chan orderedPacket, SftpServerWorkerCount),
+ responses: make(chan orderedPacket, SftpServerWorkerCount),
fini: make(chan struct{}),
- incoming: make([]uint32, 0, SftpServerWorkerCount),
- outgoing: make([]responsePacket, 0, SftpServerWorkerCount),
+ incoming: make([]orderedPacket, 0, SftpServerWorkerCount),
+ outgoing: make([]orderedPacket, 0, SftpServerWorkerCount),
sender: sender,
working: &sync.WaitGroup{},
}
@@ -38,31 +38,56 @@
return s
}
-type responsePackets []responsePacket
+//// packet ordering
+func (s *packetManager) newOrderId() uint32 {
+ s.packetCount++
+ return s.packetCount
+}
-func (r responsePackets) Sort() {
- sort.Slice(r, func(i, j int) bool {
- return r[i].id() < r[j].id()
+type orderedRequest struct {
+ requestPacket
+ orderid uint32
+}
+
+func (s *packetManager) newOrderedRequest(p requestPacket) orderedRequest {
+ return orderedRequest{requestPacket: p, orderid: s.newOrderId()}
+}
+func (p orderedRequest) orderId() uint32 { return p.orderid }
+func (p orderedRequest) setOrderId(oid uint32) { p.orderid = oid }
+
+type orderedResponse struct {
+ responsePacket
+ orderid uint32
+}
+
+func (s *packetManager) newOrderedResponse(p responsePacket, id uint32,
+) orderedResponse {
+ return orderedResponse{responsePacket: p, orderid: id}
+}
+func (p orderedResponse) orderId() uint32 { return p.orderid }
+func (p orderedResponse) setOrderId(oid uint32) { p.orderid = oid }
+
+type orderedPacket interface {
+ id() uint32
+ orderId() uint32
+}
+type orderedPackets []orderedPacket
+
+func (o orderedPackets) Sort() {
+ sort.Slice(o, func(i, j int) bool {
+ return o[i].orderId() < o[j].orderId()
})
}
-type requestPacketIDs []uint32
-
-func (r requestPacketIDs) Sort() {
- sort.Slice(r, func(i, j int) bool {
- return r[i] < r[j]
- })
-}
-
+//// packet registry
// register incoming packets to be handled
-// send id of 0 for packets without id
-func (s *packetManager) incomingPacket(pkt requestPacket) {
+func (s *packetManager) incomingPacket(pkt orderedRequest) {
s.working.Add(1)
- s.requests <- pkt // buffer == SftpServerWorkerCount
+ s.requests <- pkt
}
// register outgoing packets as being ready
-func (s *packetManager) readyPacket(pkt responsePacket) {
+func (s *packetManager) readyPacket(pkt orderedResponse) {
s.responses <- pkt
s.working.Done()
}
@@ -75,27 +100,26 @@
}
// Passed a worker function, returns a channel for incoming packets.
-// The goal is to process packets in the order they are received as is
-// requires by section 7 of the RFC, while maximizing throughput of file
-// transfers.
-func (s *packetManager) workerChan(runWorker func(chan requestPacket),
-) chan requestPacket {
+// Keeps packet responses in the order their requests were received while
+// maximizing throughput of file transfers.
+func (s *packetManager) workerChan(runWorker func(chan orderedRequest),
+) chan orderedRequest {
- rwChan := make(chan requestPacket, SftpServerWorkerCount)
+ rwChan := make(chan orderedRequest, SftpServerWorkerCount)
for i := 0; i < SftpServerWorkerCount; i++ {
runWorker(rwChan)
}
- cmdChan := make(chan requestPacket)
+ cmdChan := make(chan orderedRequest)
runWorker(cmdChan)
- pktChan := make(chan requestPacket, SftpServerWorkerCount)
+ pktChan := make(chan orderedRequest, SftpServerWorkerCount)
go func() {
// start with cmdChan
curChan := cmdChan
for pkt := range pktChan {
// on file open packet, switch to rwChan
- switch pkt.(type) {
+ switch pkt.requestPacket.(type) {
case *sshFxpOpenPacket:
curChan = rwChan
// on file close packet, switch back to cmdChan
@@ -122,17 +146,13 @@
for {
select {
case pkt := <-s.requests:
- debug("incoming id: %v", pkt.id())
- s.incoming = append(s.incoming, pkt.id())
- if len(s.incoming) > 1 {
- s.incoming.Sort()
- }
+ debug("incoming id (oid): %v (%v)", pkt.id(), pkt.orderId())
+ s.incoming = append(s.incoming, pkt)
+ s.incoming.Sort()
case pkt := <-s.responses:
- debug("outgoing pkt: %v", pkt.id())
+ debug("outgoing id (oid): %v (%v)", pkt.id(), pkt.orderId())
s.outgoing = append(s.outgoing, pkt)
- if len(s.outgoing) > 1 {
- s.outgoing.Sort()
- }
+ s.outgoing.Sort()
case <-s.fini:
return
}
@@ -150,10 +170,11 @@
}
out := s.outgoing[0]
in := s.incoming[0]
- // debug("incoming: %v", s.incoming)
- // debug("outgoing: %v", outfilter(s.outgoing))
- if in == out.id() {
- s.sender.sendPacket(out)
+ // debug("incoming: %v", ids(s.incoming))
+ // debug("outgoing: %v", ids(s.outgoing))
+ if in.orderId() == out.orderId() {
+ debug("Sending packet: %v", out.id())
+ s.sender.sendPacket(out.(encoding.BinaryMarshaler))
// pop off heads
copy(s.incoming, s.incoming[1:]) // shift left
s.incoming = s.incoming[:len(s.incoming)-1] // remove last
@@ -165,10 +186,17 @@
}
}
-//func outfilter(o []responsePacket) []uint32 {
-// res := make([]uint32, 0, len(o))
-// for _, v := range o {
-// res = append(res, v.id())
-// }
-// return res
-//}
+// func oids(o []orderedPacket) []uint32 {
+// res := make([]uint32, 0, len(o))
+// for _, v := range o {
+// res = append(res, v.orderId())
+// }
+// return res
+// }
+// func ids(o []orderedPacket) []uint32 {
+// res := make([]uint32, 0, len(o))
+// for _, v := range o {
+// res = append(res, v.id())
+// }
+// return res
+// }
diff --git a/packet-manager_test.go b/packet-manager_test.go
index 55f41c0..552d40e 100644
--- a/packet-manager_test.go
+++ b/packet-manager_test.go
@@ -21,7 +21,14 @@
return nil
}
-type fakepacket uint32
+type fakepacket struct {
+ reqid uint32
+ oid uint32
+}
+
+func fake(rid, order uint32) fakepacket {
+ return fakepacket{reqid: rid, oid: order}
+}
func (fakepacket) MarshalBinary() ([]byte, error) {
return []byte{}, nil
@@ -32,40 +39,51 @@
}
func (f fakepacket) id() uint32 {
- return uint32(f)
+ return f.reqid
}
type pair struct {
- in fakepacket
- out fakepacket
+ in, out fakepacket
+}
+
+type ordered_pair struct {
+ in orderedRequest
+ out orderedResponse
}
// basic test
var ttable1 = []pair{
- pair{fakepacket(0), fakepacket(0)},
- pair{fakepacket(1), fakepacket(1)},
- pair{fakepacket(2), fakepacket(2)},
- pair{fakepacket(3), fakepacket(3)},
+ pair{fake(0, 0), fake(0, 0)},
+ pair{fake(1, 1), fake(1, 1)},
+ pair{fake(2, 2), fake(2, 2)},
+ pair{fake(3, 3), fake(3, 3)},
}
// outgoing packets out of order
var ttable2 = []pair{
- pair{fakepacket(0), fakepacket(0)},
- pair{fakepacket(1), fakepacket(4)},
- pair{fakepacket(2), fakepacket(1)},
- pair{fakepacket(3), fakepacket(3)},
- pair{fakepacket(4), fakepacket(2)},
+ pair{fake(10, 0), fake(12, 2)},
+ pair{fake(11, 1), fake(11, 1)},
+ pair{fake(12, 2), fake(13, 3)},
+ pair{fake(13, 3), fake(10, 0)},
}
-// incoming packets out of order
+// request ids are not incremental
var ttable3 = []pair{
- pair{fakepacket(2), fakepacket(0)},
- pair{fakepacket(1), fakepacket(1)},
- pair{fakepacket(3), fakepacket(2)},
- pair{fakepacket(0), fakepacket(3)},
+ pair{fake(7, 0), fake(7, 0)},
+ pair{fake(1, 1), fake(1, 1)},
+ pair{fake(9, 2), fake(3, 3)},
+ pair{fake(3, 3), fake(9, 2)},
}
-var tables = [][]pair{ttable1, ttable2, ttable3}
+// request ids are all the same
+var ttable4 = []pair{
+ pair{fake(1, 0), fake(1, 0)},
+ pair{fake(1, 1), fake(1, 1)},
+ pair{fake(1, 2), fake(1, 3)},
+ pair{fake(1, 3), fake(1, 2)},
+}
+
+var tables = [][]pair{ttable1, ttable2, ttable3, ttable4}
func TestPacketManager(t *testing.T) {
sender := newTestSender()
@@ -73,30 +91,37 @@
for i := range tables {
table := tables[i]
+ ordered_pairs := make([]ordered_pair, 0, len(table))
for _, p := range table {
+ ordered_pairs = append(ordered_pairs, ordered_pair{
+ in: orderedRequest{p.in, p.in.oid},
+ out: orderedResponse{p.out, p.out.oid},
+ })
+ }
+ for _, p := range ordered_pairs {
s.incomingPacket(p.in)
}
- for _, p := range table {
+ for _, p := range ordered_pairs {
s.readyPacket(p.out)
}
- for i := 0; i < len(table); i++ {
+ for _, p := range table {
pkt := <-sender.sent
- id := pkt.(fakepacket).id()
- assert.Equal(t, id, uint32(i))
+ id := pkt.(orderedResponse).id()
+ assert.Equal(t, id, p.in.id())
}
}
s.close()
}
func (p sshFxpRemovePacket) String() string {
- return fmt.Sprintf("RmPct:%d", p.ID)
+ return fmt.Sprintf("RmPkt:%d", p.ID)
}
func (p sshFxpOpenPacket) String() string {
- return fmt.Sprintf("OpPct:%d", p.ID)
+ return fmt.Sprintf("OpPkt:%d", p.ID)
}
func (p sshFxpWritePacket) String() string {
- return fmt.Sprintf("WrPct:%d", p.ID)
+ return fmt.Sprintf("WrPkt:%d", p.ID)
}
func (p sshFxpClosePacket) String() string {
- return fmt.Sprintf("ClPct:%d", p.ID)
+ return fmt.Sprintf("ClPkt:%d", p.ID)
}
diff --git a/request-server.go b/request-server.go
index e589672..74ba46c 100644
--- a/request-server.go
+++ b/request-server.go
@@ -105,7 +105,7 @@
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var wg sync.WaitGroup
- runWorker := func(ch chan requestPacket) {
+ runWorker := func(ch chan orderedRequest) {
wg.Add(1)
go func() {
defer wg.Done()
@@ -142,7 +142,7 @@
}
}
- pktChan <- pkt
+ pktChan <- rs.pktMgr.newOrderedRequest(pkt)
}
close(pktChan) // shuts down sftpServerWorkers
@@ -159,11 +159,11 @@
}
func (rs *RequestServer) packetWorker(
- ctx context.Context, pktChan chan requestPacket,
+ ctx context.Context, pktChan chan orderedRequest,
) error {
for pkt := range pktChan {
var rpkt responsePacket
- switch pkt := pkt.(type) {
+ switch pkt := pkt.requestPacket.(type) {
case *sshFxInitPacket:
rpkt = sshFxVersionPacket{Version: sftpProtocolVersion}
case *sshFxpClosePacket:
@@ -208,7 +208,8 @@
return errors.Errorf("unexpected packet type %T", pkt)
}
- rs.sendPacket(rpkt)
+ rs.pktMgr.readyPacket(
+ rs.pktMgr.newOrderedResponse(rpkt, pkt.orderId()))
}
return nil
}
@@ -240,8 +241,3 @@
}
return path.Clean(p)
}
-
-// Wrap underlying connection methods to use packetManager
-func (rs *RequestServer) sendPacket(pkt responsePacket) {
- rs.pktMgr.readyPacket(pkt)
-}
diff --git a/server.go b/server.go
index a5dd2cd..0fac1b6 100644
--- a/server.go
+++ b/server.go
@@ -123,12 +123,11 @@
}
// Up to N parallel servers
-func (svr *Server) sftpServerWorker(pktChan chan requestPacket) error {
+func (svr *Server) sftpServerWorker(pktChan chan orderedRequest) error {
for pkt := range pktChan {
-
// readonly checks
readonly := true
- switch pkt := pkt.(type) {
+ switch pkt := pkt.requestPacket.(type) {
case notReadOnly:
readonly = false
case *sshFxpOpenPacket:
@@ -140,7 +139,9 @@
// If server is operating read-only and a write operation is requested,
// return permission denied
if !readonly && svr.readOnly {
- svr.sendPacket(statusFromError(pkt, syscall.EPERM))
+ svr.sendPacket(orderedResponse{
+ responsePacket: statusFromError(pkt, syscall.EPERM),
+ orderid: pkt.orderId()})
continue
}
@@ -151,9 +152,9 @@
return nil
}
-func handlePacket(s *Server, p requestPacket) error {
+func handlePacket(s *Server, p orderedRequest) error {
var rpkt responsePacket
- switch p := p.(type) {
+ switch p := p.requestPacket.(type) {
case *sshFxInitPacket:
rpkt = sshFxVersionPacket{Version: sftpProtocolVersion}
case *sshFxpStatPacket:
@@ -177,7 +178,6 @@
rpkt = statusFromError(p, err)
}
case *sshFxpFstatPacket:
- fmt.Println("fstat")
f, ok := s.getHandle(p.Handle)
var err error = syscall.EBADF
var info os.FileInfo
@@ -282,7 +282,7 @@
return errors.Errorf("unexpected packet type %T", p)
}
- s.sendPacket(rpkt)
+ s.pktMgr.readyPacket(s.pktMgr.newOrderedResponse(rpkt, p.orderId()))
return nil
}
@@ -290,7 +290,7 @@
// is stopped.
func (svr *Server) Serve() error {
var wg sync.WaitGroup
- runWorker := func(ch chan requestPacket) {
+ runWorker := func(ch chan orderedRequest) {
wg.Add(1)
go func() {
defer wg.Done()
@@ -327,7 +327,7 @@
}
}
- pktChan <- pkt
+ pktChan <- svr.pktMgr.newOrderedRequest(pkt)
}
close(pktChan) // shuts down sftpServerWorkers
@@ -341,12 +341,6 @@
return err // error from recvPacket
}
-// Wrap underlying connection methods to use packetManager
-func (svr *Server) sendPacket(pkt responsePacket) error {
- svr.pktMgr.readyPacket(pkt)
- return nil
-}
-
type ider interface {
id() uint32
}