blob: 385ee998a9cc6d8d93e42964942d6a6b73401dda [file] [log] [blame]
package sftp
import (
"encoding"
"sort"
)
// --------------------------------------------------------------------
// The worker is a 2-branch select loop, listening on each channel.
// 0) Start of loop.
// Branch A
// 1) Wait for a request id to come in and add it to the id list.
// Branch B
// 1) Wait for a response packet to come in.
// 2) Add it to the packet list.
// After either branch
// 3) The heads of each list are compared and, if they have the same id,
// the packet is sent out and both entries are removed.
// 4) Goto step 3 until either list is empty or the ids don't match.
// 5) Goto step 0.
// --------------------------------------------------------------------
// packetSender is the connection-side interface used by packetManager to
// write a packet out.
type packetSender interface {
// sendPacket transmits one marshalable packet and reports any failure.
sendPacket(encoding.BinaryMarshaler) error
}
// packetManager matches registered request ids against ready response
// packets and passes a response to sender when the heads of both queues
// carry the same id.
type packetManager struct {
requests chan requestPacket // request packets whose ids get registered
responses chan responsePacket // response packets that are ready to go out
fini chan struct{} // closed to make the worker return
// incoming and outgoing are re-sorted by ascending id after every append.
incoming []uint32 // ids of registered requests
outgoing []responsePacket // responses waiting for their turn
sender packetSender // connection object
}
// newPktMgr builds a packetManager that delivers through sender and starts
// its worker goroutine before returning.
//
// Both channels and both queues are sized with sftpServerWorkerCount. The
// worker runs against this function's local value while the caller gets a
// copy; the two share the same channels.
func newPktMgr(sender packetSender) packetManager {
	var mgr packetManager
	mgr.sender = sender
	mgr.fini = make(chan struct{})
	mgr.requests = make(chan requestPacket, sftpServerWorkerCount)
	mgr.responses = make(chan responsePacket, sftpServerWorkerCount)
	mgr.incoming = make([]uint32, 0, sftpServerWorkerCount)
	mgr.outgoing = make([]responsePacket, 0, sftpServerWorkerCount)

	go mgr.worker()
	return mgr
}
// incomingPacket registers a request packet with the worker so that its
// response can later be matched by id.
// Packets without an id should report an id of 0.
// The send blocks while the requests buffer is full.
func (s packetManager) incomingPacket(pkt requestPacket) {
s.requests <- pkt // buffer == sftpServerWorkerCount
}
// readyPacket hands a finished response packet to the worker, which sends
// it once its id matches the head of the registered request ids.
// The send blocks while the responses buffer is full.
func (s packetManager) readyPacket(pkt responsePacket) {
s.responses <- pkt
}
// close shuts down the packetManager worker by closing fini. The worker
// returns when it observes the closed channel, without sending any packets
// that are still queued.
// NOTE(review): fini is closed unconditionally, so a second call would
// panic; callers presumably invoke this exactly once -- confirm.
func (s packetManager) close() {
close(s.fini)
}
// worker is the manager's event loop; it runs until fini is closed.
//
// Every request id and every response packet that arrives is appended to
// its queue, the queue is put back into ascending id order, and then as
// many responses as are ready get sent.
func (s *packetManager) worker() {
	for {
		select {
		case <-s.fini:
			return

		case resp := <-s.responses:
			debug("outgoing pkt: %v", resp.id())
			s.outgoing = append(s.outgoing, resp)
			// A single element is trivially ordered.
			if queue := s.outgoing; len(queue) >= 2 {
				sort.Slice(queue, func(a, b int) bool {
					return queue[a].id() < queue[b].id()
				})
			}

		case req := <-s.requests:
			debug("incoming id: %v", req.id())
			s.incoming = append(s.incoming, req.id())
			if ids := s.incoming; len(ids) >= 2 {
				sort.Slice(ids, func(a, b int) bool {
					return ids[a] < ids[b]
				})
			}
		}

		// Reached after either queue changed; never after fini.
		s.maybeSendPackets()
	}
}
// maybeSendPackets sends as many packets as are ready.
//
// A response is ready when it is at the head of the outgoing queue and its
// id equals the id at the head of the incoming queue. The loop stops as
// soon as either queue is empty or the two heads disagree.
//
// A failed send is reported through debug and the entry is still dequeued,
// so one bad packet cannot stall every response queued behind it.
func (s *packetManager) maybeSendPackets() {
	for {
		if len(s.outgoing) == 0 || len(s.incoming) == 0 {
			debug("break! -- outgoing: %v; incoming: %v",
				len(s.outgoing), len(s.incoming))
			return
		}
		out := s.outgoing[0]
		in := s.incoming[0]
		debug("incoming: %v", s.incoming)
		debug("outgoing: %v", outfilter(s.outgoing))
		if in != out.id() {
			// The next expected response has not arrived yet.
			return
		}
		if err := s.sender.sendPacket(out); err != nil {
			debug("sendPacket failed for id %v: %v", in, err)
		}

		// Pop both heads: shift left, then drop the last element.
		s.incoming = s.incoming[:copy(s.incoming, s.incoming[1:])]

		last := copy(s.outgoing, s.outgoing[1:])
		// Zero the vacated slot so the backing array does not keep an
		// already-sent packet reachable.
		var none responsePacket
		s.outgoing[last] = none
		s.outgoing = s.outgoing[:last]
	}
}
// outfilter returns the ids of the given response packets, in the same
// order. The result is always non-nil and has one entry per packet.
func outfilter(o []responsePacket) []uint32 {
	ids := make([]uint32, len(o))
	for i := range o {
		ids[i] = o[i].id()
	}
	return ids
}