blob: e7c5a37d1b172e3bffa90bb396f8163c6325c1fb [file] [log] [blame]
package fsutil
import (
"context"
"io"
"os"
"sync"
"github.com/pkg/errors"
"github.com/tonistiigi/fsutil/types"
"golang.org/x/sync/errgroup"
)
// bufPool recycles the 32 KiB scratch buffers that sendFile hands to
// io.CopyBuffer.
//
// NOTE(review): pooling a []byte value (rather than a *[]byte) makes every
// Put allocate a slice header (staticcheck SA6002); switching to *[]byte
// requires updating sendFile at the same time.
var bufPool = sync.Pool{
	New: func() interface{} {
		const copyBufSize = 32 << 10
		return make([]byte, copyBufSize)
	},
}
// Stream is the bidirectional message transport that Send runs over. Each
// SendMsg/RecvMsg call moves one message. Send serializes concurrent SendMsg
// calls itself before using the stream.
//
// NOTE(review): the method set mirrors a gRPC stream; presumably it is
// satisfied by grpc client/server streams — confirm with callers.
type Stream interface {
	Context() context.Context
	SendMsg(m interface{}) error
	RecvMsg(m interface{}) error
}
// Send transfers the tree fs to the peer on conn. It sends a PACKET_STAT for
// every entry (followed by an empty PACKET_STAT). As the peer asks for files
// with PACKET_REQ, it streams their contents as PACKET_DATA packets. It returns
// when the session ends: after answering the peer's PACKET_FIN, or on the first
// error.
//
// progressCb may be nil. Otherwise it receives the running total of sent packet
// sizes, plus a flag that is true only for the final call.
func Send(ctx context.Context, conn Stream, fs FS, progressCb func(int, bool)) error {
	// The walker and the send workers write to conn from different
	// goroutines, so SendMsg calls are serialized through syncStream.
	serialized := &syncStream{Stream: conn}

	// Requested files waiting for a send worker; the buffer lets the receive
	// loop run ahead of the workers.
	pipeline := make(chan *sendHandle, 128)

	snd := &sender{
		conn:         serialized,
		fs:           fs,
		files:        map[uint32]string{},
		progressCb:   progressCb,
		sendpipeline: pipeline,
	}
	return snd.run(ctx)
}
// sendHandle is one unit of work for the send workers: a file the peer has
// requested. Field order matters: queue builds it with a positional literal.
type sendHandle struct {
	id   uint32 // id the peer used in its PACKET_REQ
	path string // path of the file within the sender's FS
}
// sender holds the state of a single Send call.
type sender struct {
	// conn is the stream to the peer; Send sets it to a syncStream so that
	// SendMsg may be called from several goroutines.
	conn Stream

	// fs is the tree being transferred.
	fs FS

	// files maps the walk index of each entry that may carry data (see
	// fileCanRequestData) to its path. walk adds entries and queue removes
	// them, both while holding mu.
	files map[uint32]string
	mu    sync.RWMutex

	// progressCb, if non-nil, is reported the running total kept in
	// progressCurrent; both are used by updateProgress.
	progressCb      func(int, bool)
	progressCurrent int

	// sendpipeline carries requested files from the receive loop to the send
	// workers started by run.
	sendpipeline chan *sendHandle
}
// run drives one send session and returns the first error reported by any of
// its goroutines:
//   - one goroutine walks s.fs and sends a PACKET_STAT per entry (see walk);
//     if the walk fails it also tells the peer with a PACKET_ERR;
//   - sendWorkers goroutines take requested files off s.sendpipeline and
//     stream their contents (see sendFile);
//   - one goroutine reads packets from the peer. It queues each PACKET_REQ for
//     the workers and runs until the peer sends PACKET_ERR, or until it sends
//     PACKET_FIN, which is answered with a PACKET_FIN. It is the only writer
//     of s.sendpipeline and closes it on exit, which stops the workers.
//
// A final progress update with last=true is emitted when run returns.
func (s *sender) run(ctx context.Context) error {
	g, ctx := errgroup.WithContext(ctx)

	defer s.updateProgress(0, true)

	// After cancellation the workers stop consuming s.sendpipeline: they
	// return on a sendFile error, or on the next handle they receive. If
	// every worker has exited while the receive loop is blocked in queue on a
	// full pipeline, that channel send could never complete and g.Wait would
	// hang. Once ctx is done, discard whatever is queued. The loop ends when
	// the receive goroutine closes the channel, which happens before g.Wait
	// returns (and errgroup cancels ctx when Wait returns), so this goroutine
	// does not outlive the session.
	go func() {
		<-ctx.Done()
		for range s.sendpipeline {
		}
	}()

	g.Go(func() error {
		err := s.walk(ctx)
		if err != nil {
			// Best effort: tell the peer why the listing stopped. The walk
			// error itself is what run reports, so a send failure here is
			// deliberately ignored.
			_ = s.conn.SendMsg(&types.Packet{Type: types.PACKET_ERR, Data: []byte(err.Error())})
		}
		return err
	})

	const sendWorkers = 4
	for i := 0; i < sendWorkers; i++ {
		g.Go(func() error {
			for h := range s.sendpipeline {
				select {
				case <-ctx.Done():
					return ctx.Err()
				default:
				}
				if err := s.sendFile(h); err != nil {
					return err
				}
			}
			return nil
		})
	}

	g.Go(func() error {
		// Only this goroutine sends on s.sendpipeline, so it owns closing it.
		defer close(s.sendpipeline)
		for {
			// NOTE: RecvMsg does not take ctx, so cancellation is only
			// noticed between messages.
			select {
			case <-ctx.Done():
				return ctx.Err()
			default:
			}
			var p types.Packet
			if err := s.conn.RecvMsg(&p); err != nil {
				return err
			}
			switch p.Type {
			case types.PACKET_ERR:
				return errors.Errorf("error from receiver: %s", p.Data)
			case types.PACKET_REQ:
				if err := s.queue(p.ID); err != nil {
					return err
				}
			case types.PACKET_FIN:
				return s.conn.SendMsg(&types.Packet{Type: types.PACKET_FIN})
			}
		}
	})

	return g.Wait()
}
// updateProgress adds size to the running progress total and reports the new
// total to progressCb, passing last through unchanged. It does nothing when no
// callback was supplied.
//
// It is called concurrently: from the walk goroutine for every STAT packet,
// and from the send workers through fileSender.Write for every data packet.
// The counter update and the callback therefore both run under s.mu. This
// removes the data race on progressCurrent and ensures progressCb is never
// invoked concurrently and sees totals in order. Nothing that holds s.mu
// calls updateProgress, so this cannot self-deadlock.
func (s *sender) updateProgress(size int, last bool) {
	if s.progressCb == nil {
		return
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	s.progressCurrent += size
	s.progressCb(s.progressCurrent, last)
}
// queue hands the file registered under id to the send workers. An id is
// removed from s.files when it is queued. A repeated request, or an id that
// walk never registered, therefore fails with an "invalid file id" error. The
// channel send blocks while s.sendpipeline is full.
func (s *sender) queue(id uint32) error {
	s.mu.Lock()
	path, found := s.files[id]
	if found {
		delete(s.files, id)
	}
	s.mu.Unlock()

	if !found {
		return errors.Errorf("invalid file id %d", id)
	}
	s.sendpipeline <- &sendHandle{id: id, path: path}
	return nil
}
// sendFile streams the contents of h.path to the peer as PACKET_DATA packets
// tagged with h.id, then sends one empty PACKET_DATA for the same id. If the
// file cannot be opened, only the empty packet is sent and the open error is
// not reported.
//
// NOTE(review): swallowing the Open error looks deliberate (e.g. the file
// vanished after it was walked) — confirm before changing it.
func (s *sender) sendFile(h *sendHandle) error {
	f, openErr := s.fs.Open(h.path)
	if openErr == nil {
		defer f.Close()

		scratch := bufPool.Get().([]byte)
		defer bufPool.Put(scratch)

		dst := &fileSender{sender: s, id: h.id}
		if _, err := io.CopyBuffer(dst, f, scratch); err != nil {
			return err
		}
	}
	// Sent even when Open failed, so the peer always receives this packet
	// after the file's data.
	end := &types.Packet{ID: h.id, Type: types.PACKET_DATA}
	return s.conn.SendMsg(end)
}
// walk sends one PACKET_STAT per entry of s.fs, in walk order, followed by an
// empty PACKET_STAT that marks the end of the listing.
//
// Entries are numbered by their position in the walk, starting at 0. Entries
// for which fileCanRequestData holds are registered in s.files under that
// number, which is the id queue later looks up from the peer's PACKET_REQ.
// The size of every STAT packet is added to the progress total.
//
// It returns the first error from the walk, a missing *types.Stat, or a
// failed send.
func (s *sender) walk(ctx context.Context) error {
	var i uint32
	err := s.fs.Walk(ctx, func(path string, fi os.FileInfo, err error) error {
		if err != nil {
			return err
		}
		stat, ok := fi.Sys().(*types.Stat)
		if !ok {
			// err is nil here and errors.Wrapf(nil, ...) returns nil, so the
			// error must be created with Errorf; otherwise the entry would be
			// skipped silently.
			return errors.Errorf("invalid fileinfo without stat info: %s", path)
		}
		p := &types.Packet{
			Type: types.PACKET_STAT,
			Stat: stat,
		}
		// Only entries that can carry data are requestable, but i advances
		// for every entry so ids match each STAT packet's position.
		if fileCanRequestData(os.FileMode(stat.Mode)) {
			s.mu.Lock()
			s.files[i] = stat.Path
			s.mu.Unlock()
		}
		i++
		s.updateProgress(p.Size(), false)
		return errors.Wrapf(s.conn.SendMsg(p), "failed to send stat %s", path)
	})
	if err != nil {
		return err
	}
	return errors.Wrapf(s.conn.SendMsg(&types.Packet{Type: types.PACKET_STAT}), "failed to send last stat")
}
// fileCanRequestData reports whether an entry with mode m may have its
// contents requested, i.e. whether it is a regular file (no type bits set;
// permission and setuid/setgid/sticky bits are ignored).
func fileCanRequestData(m os.FileMode) bool {
	// Do not change this check casually: sender and receiver must agree on
	// it. If it ever has to change, version the protocol.
	return m.IsRegular()
}
// fileSender adapts a sender to io.Writer for one file. Each non-empty Write
// is sent to the peer as a single PACKET_DATA packet tagged with id.
type fileSender struct {
	sender *sender // stream and progress accounting to use
	id     uint32  // id under which the peer requested this file
}
// Write implements io.Writer. It sends dt as one PACKET_DATA packet for the
// file's id and adds that packet's size to the progress total. An empty dt
// sends nothing and reports 0 bytes written. On a send error, no bytes are
// reported as written.
func (w *fileSender) Write(dt []byte) (int, error) {
	n := len(dt)
	if n == 0 {
		return 0, nil
	}
	pkt := &types.Packet{Type: types.PACKET_DATA, ID: w.id, Data: dt}
	if err := w.sender.conn.SendMsg(pkt); err != nil {
		return 0, err
	}
	w.sender.updateProgress(pkt.Size(), false)
	return n, nil
}
// syncStream wraps a Stream so that SendMsg may be called from several
// goroutines at once; those calls are serialized by mu. RecvMsg and Context
// are promoted unchanged from the embedded Stream.
type syncStream struct {
	Stream
	mu sync.Mutex // serializes SendMsg
}
// SendMsg forwards m to the wrapped Stream while holding ss.mu, so that at
// most one SendMsg runs on the underlying stream at a time. The unlock is
// deferred so that a panicking underlying SendMsg cannot leave the mutex held
// and block every later sender.
func (ss *syncStream) SendMsg(m interface{}) error {
	ss.mu.Lock()
	defer ss.mu.Unlock()
	return ss.Stream.SendMsg(m)
}