| package fsutil |
| |
| import ( |
| "context" |
| "io" |
| "os" |
| "sync" |
| |
| "github.com/pkg/errors" |
| "github.com/tonistiigi/fsutil/types" |
| "golang.org/x/sync/errgroup" |
| ) |
| |
// bufPool hands out reusable 32 KiB byte slices; sendFile borrows one as the
// scratch buffer for io.CopyBuffer and returns it when the copy is done.
var bufPool = sync.Pool{
	New: func() interface{} {
		const size = 32 << 10 // 32 KiB
		return make([]byte, size)
	},
}
| |
// Stream is the message transport the sender talks over. It can receive a
// message into a caller-supplied value, send a message, and expose the
// context associated with the stream.
type Stream interface {
	// RecvMsg reads the next message into m.
	RecvMsg(m interface{}) error
	// SendMsg writes the message m.
	SendMsg(m interface{}) error
	// Context returns the context associated with the stream.
	Context() context.Context
}
| |
| func Send(ctx context.Context, conn Stream, fs FS, progressCb func(int, bool)) error { |
| s := &sender{ |
| conn: &syncStream{Stream: conn}, |
| fs: fs, |
| files: make(map[uint32]string), |
| progressCb: progressCb, |
| sendpipeline: make(chan *sendHandle, 128), |
| } |
| return s.run(ctx) |
| } |
| |
// sendHandle describes one file queued for transfer.
type sendHandle struct {
	// id is the file id taken from the receiver's request packet; it tags
	// every data packet sent for this file.
	id uint32
	// path is the path handed to FS.Open to read the file's contents.
	path string
}
| |
// sender holds the state of a single Send call.
type sender struct {
	// conn carries every packet exchanged with the receiver.
	conn Stream
	// fs is the filesystem that is walked and whose files are opened.
	fs FS
	// files maps the ids assigned during the walk to the paths of entries
	// whose data may be requested; queue removes an entry once it is requested.
	files map[uint32]string
	// mu guards files.
	mu sync.RWMutex
	// progressCb, when non-nil, is called with the running total and a flag
	// marking the final call.
	progressCb func(int, bool)
	// progressCurrent is the running total of packet sizes reported so far.
	progressCurrent int
	// sendpipeline carries requested files from the receive loop to the
	// goroutines that stream their contents.
	sendpipeline chan *sendHandle
}
| |
| func (s *sender) run(ctx context.Context) error { |
| g, ctx := errgroup.WithContext(ctx) |
| |
| defer s.updateProgress(0, true) |
| |
| g.Go(func() error { |
| err := s.walk(ctx) |
| if err != nil { |
| s.conn.SendMsg(&types.Packet{Type: types.PACKET_ERR, Data: []byte(err.Error())}) |
| } |
| return err |
| }) |
| |
| for i := 0; i < 4; i++ { |
| g.Go(func() error { |
| for h := range s.sendpipeline { |
| select { |
| case <-ctx.Done(): |
| return ctx.Err() |
| default: |
| } |
| if err := s.sendFile(h); err != nil { |
| return err |
| } |
| } |
| return nil |
| }) |
| } |
| |
| g.Go(func() error { |
| defer close(s.sendpipeline) |
| |
| for { |
| select { |
| case <-ctx.Done(): |
| return ctx.Err() |
| default: |
| } |
| var p types.Packet |
| if err := s.conn.RecvMsg(&p); err != nil { |
| return err |
| } |
| switch p.Type { |
| case types.PACKET_ERR: |
| return errors.Errorf("error from receiver: %s", p.Data) |
| case types.PACKET_REQ: |
| if err := s.queue(p.ID); err != nil { |
| return err |
| } |
| case types.PACKET_FIN: |
| return s.conn.SendMsg(&types.Packet{Type: types.PACKET_FIN}) |
| } |
| } |
| }) |
| |
| return g.Wait() |
| } |
| |
| func (s *sender) updateProgress(size int, last bool) { |
| if s.progressCb != nil { |
| s.progressCurrent += size |
| s.progressCb(s.progressCurrent, last) |
| } |
| } |
| |
| func (s *sender) queue(id uint32) error { |
| s.mu.Lock() |
| p, ok := s.files[id] |
| if !ok { |
| s.mu.Unlock() |
| return errors.Errorf("invalid file id %d", id) |
| } |
| delete(s.files, id) |
| s.mu.Unlock() |
| s.sendpipeline <- &sendHandle{id, p} |
| return nil |
| } |
| |
| func (s *sender) sendFile(h *sendHandle) error { |
| f, err := s.fs.Open(h.path) |
| if err == nil { |
| defer f.Close() |
| buf := bufPool.Get().([]byte) |
| defer bufPool.Put(buf) |
| if _, err := io.CopyBuffer(&fileSender{sender: s, id: h.id}, f, buf); err != nil { |
| return err |
| } |
| } |
| return s.conn.SendMsg(&types.Packet{ID: h.id, Type: types.PACKET_DATA}) |
| } |
| |
| func (s *sender) walk(ctx context.Context) error { |
| var i uint32 = 0 |
| err := s.fs.Walk(ctx, func(path string, fi os.FileInfo, err error) error { |
| if err != nil { |
| return err |
| } |
| stat, ok := fi.Sys().(*types.Stat) |
| if !ok { |
| return errors.Wrapf(err, "invalid fileinfo without stat info: %s", path) |
| } |
| |
| p := &types.Packet{ |
| Type: types.PACKET_STAT, |
| Stat: stat, |
| } |
| if fileCanRequestData(os.FileMode(stat.Mode)) { |
| s.mu.Lock() |
| s.files[i] = stat.Path |
| s.mu.Unlock() |
| } |
| i++ |
| s.updateProgress(p.Size(), false) |
| return errors.Wrapf(s.conn.SendMsg(p), "failed to send stat %s", path) |
| }) |
| if err != nil { |
| return err |
| } |
| return errors.Wrapf(s.conn.SendMsg(&types.Packet{Type: types.PACKET_STAT}), "failed to send last stat") |
| } |
| |
// fileCanRequestData reports whether an entry with mode m has contents that
// can be requested, i.e. none of the file-type bits are set in m.
func fileCanRequestData(m os.FileMode) bool {
	// Avoid changing this predicate: it needs to match between sender and
	// receiver. Version it if a change is ever needed.
	return m.IsRegular()
}
| |
// fileSender is the write target used when copying a file: each Write is
// forwarded over the sender's connection as a data packet.
type fileSender struct {
	// sender provides the connection and the progress accounting.
	sender *sender
	// id is the file id stamped on every packet.
	id uint32
}
| |
| func (fs *fileSender) Write(dt []byte) (int, error) { |
| if len(dt) == 0 { |
| return 0, nil |
| } |
| p := &types.Packet{Type: types.PACKET_DATA, ID: fs.id, Data: dt} |
| if err := fs.sender.conn.SendMsg(p); err != nil { |
| return 0, err |
| } |
| fs.sender.updateProgress(p.Size(), false) |
| return len(dt), nil |
| } |
| |
// syncStream wraps a Stream and uses mu to serialize calls to SendMsg.
// RecvMsg and Context are the embedded Stream's methods, unchanged.
type syncStream struct {
	Stream
	// mu is held for the duration of each SendMsg call.
	mu sync.Mutex
}
| |
| func (ss *syncStream) SendMsg(m interface{}) error { |
| ss.mu.Lock() |
| err := ss.Stream.SendMsg(m) |
| ss.mu.Unlock() |
| return err |
| } |