blob: cc14fb5ddba7dd35a858d5af519723c1fa324361 [file] [log] [blame]
package sftp
import (
"encoding"
"sort"
)
// --------------------------------------------------------------------
// Process with 2 branch select, listening to each channel.
// 0) start of loop
// Branch A
// 1) Wait for ids to come in and add them to id list.
// Branch B
// 1) Wait for a packet to come in.
// 2) Add it to the packet list.
// 3) The heads of each list are then compared and if they have the same ids
// the packet is sent out and the entries removed.
// 4) Goto step 3 until the lists are emptied or the ids don't match.
// 5) Goto step 0.
// --------------------------------------------------------------------
// packetSender is what the packet manager needs from the connection: a way
// to hand over one marshalable packet. maybeSendPackets calls it for every
// response that is released.
type packetSender interface {
// sendPacket is given a single packet and returns an error on failure.
sendPacket(encoding.BinaryMarshaler) error
}
// packetManager matches registered request ids against ready responses and
// passes a response to the sender only once its id is at the head of both
// sorted lists.
type packetManager struct {
// requests carries packets registered through incomingPacket to the worker.
requests chan requestPacket
// responses carries packets registered through readyPacket to the worker.
responses chan responsePacket
// fini is closed by close() to make the worker loop return.
fini chan struct{}
// incoming holds the ids of registered requests; the worker keeps it sorted.
incoming uint32s
// outgoing holds responses not yet sent; the worker keeps it sorted by id.
outgoing responsePackets
sender packetSender // connection object
}
// newPktMgr builds a packetManager that sends through sender and starts its
// worker goroutine before returning.
//
// Both channels are buffered to sftpServerWorkerCount, and the two pending
// lists start empty with that same capacity. The manager is returned by
// value: the copy shares the channels with the running worker, while the
// worker keeps operating on the pending lists of the local instance.
func newPktMgr(sender packetSender) packetManager {
	depth := sftpServerWorkerCount

	var mgr packetManager
	mgr.sender = sender
	mgr.fini = make(chan struct{})
	mgr.requests = make(chan requestPacket, depth)
	mgr.responses = make(chan responsePacket, depth)
	mgr.incoming = make([]uint32, 0, depth)
	mgr.outgoing = make([]responsePacket, 0, depth)

	go mgr.worker()
	return mgr
}
// for sorting/ordering incoming/outgoing

// uint32s is a list of request ids; its methods below implement
// sort.Interface with ascending numeric order.
type uint32s []uint32
// responsePackets is a list of response packets; its methods below implement
// sort.Interface, ordering packets by ascending id().
type responsePackets []responsePacket
// Len returns how many ids the list holds (sort.Interface).
func (ids uint32s) Len() int {
	return len(ids)
}

// Swap exchanges the ids stored at positions i and j (sort.Interface).
func (ids uint32s) Swap(i, j int) {
	held := ids[i]
	ids[i] = ids[j]
	ids[j] = held
}

// Less reports whether the id at i is numerically smaller than the id at j
// (sort.Interface).
func (ids uint32s) Less(i, j int) bool {
	return ids[j] > ids[i]
}
// Len returns how many packets the list holds (sort.Interface).
func (pkts responsePackets) Len() int {
	return len(pkts)
}

// Swap exchanges the packets stored at positions i and j (sort.Interface).
func (pkts responsePackets) Swap(i, j int) {
	held := pkts[i]
	pkts[i] = pkts[j]
	pkts[j] = held
}

// Less reports whether the packet at i has a smaller id than the packet at j
// (sort.Interface).
func (pkts responsePackets) Less(i, j int) bool {
	left, right := pkts[i].id(), pkts[j].id()
	return left < right
}
// incomingPacket registers a request packet to be handled by queueing it on
// the requests channel, from which the worker records its id.
// Send id of 0 for packets without id.
// The send blocks while the channel buffer is full.
func (s packetManager) incomingPacket(pkt requestPacket) {
s.requests <- pkt // buffer == sftpServerWorkerCount
}
// readyPacket registers an outgoing packet as being ready by queueing it on
// the responses channel; the worker sends it once its id reaches the head of
// both pending lists. The send blocks while the channel buffer is full.
func (s packetManager) readyPacket(pkt responsePacket) {
s.responses <- pkt
}
// close shuts down the packetManager worker by closing the fini channel,
// which makes the worker loop return. It must be called at most once:
// closing an already-closed channel panics.
func (s packetManager) close() {
close(s.fini)
}
// worker is the packet manager's event loop; newPktMgr runs it in its own
// goroutine. Each iteration waits for one of three events:
//   - fini is closed: the loop returns, leaving anything still queued unsent;
//   - a request arrives: its id is appended to s.incoming;
//   - a response arrives: the packet is appended to s.outgoing.
//
// After an append the affected list is re-sorted by id, and maybeSendPackets
// then releases whatever lines up at the heads of the two lists.
//
// NOTE(review): ordering is by numeric id rather than by arrival, so this
// appears to assume request ids only ever increase — confirm against callers.
func (s *packetManager) worker() {
	for {
		select {
		case <-s.fini:
			return

		case req := <-s.requests:
			debug("incoming id: %v", req.id())
			ids := append(s.incoming, req.id())
			if len(ids) >= 2 {
				sort.Sort(ids)
			}
			s.incoming = ids

		case resp := <-s.responses:
			debug("outgoing pkt: %v", resp.id())
			pending := append(s.outgoing, resp)
			if len(pending) >= 2 {
				sort.Sort(pending)
			}
			s.outgoing = pending
		}
		s.maybeSendPackets()
	}
}
// maybeSendPackets sends as many packets as are ready.
//
// A response is ready when its id equals the id at the head of s.incoming
// and it sits at the head of s.outgoing. Ready responses are passed to
// s.sender one after another, each time dropping the head of both lists.
// The loop stops as soon as either list is empty or the two heads disagree.
//
// A failing sendPacket does not stop the loop and the packet is still
// dequeued; the error is reported through debug instead of being discarded.
func (s *packetManager) maybeSendPackets() {
	for {
		if len(s.outgoing) == 0 || len(s.incoming) == 0 {
			debug("break! -- outgoing: %v; incoming: %v",
				len(s.outgoing), len(s.incoming))
			return
		}

		out := s.outgoing[0]
		if s.incoming[0] != out.id() {
			// The response for the oldest pending request is not ready yet.
			return
		}

		if err := s.sender.sendPacket(out); err != nil {
			debug("send packet %v: %v", out.id(), err)
		}

		// Pop both heads: shift everything left by one, then truncate.
		s.incoming = s.incoming[:copy(s.incoming, s.incoming[1:])]

		last := copy(s.outgoing, s.outgoing[1:])
		// Zero the vacated tail slot so the backing array does not keep a
		// stale reference to a packet that has already been sent.
		var none responsePacket
		s.outgoing[last] = none
		s.outgoing = s.outgoing[:last]
	}
}
// outfilter returns the ids of the given response packets, in the same
// order, as a fresh slice. The result is never nil, even for empty input.
func outfilter(o []responsePacket) []uint32 {
	ids := make([]uint32, len(o))
	for i := range o {
		ids[i] = o[i].id()
	}
	return ids
}