blob: a543364c75335e549db5604885cbbc248568b1e8 [file] [log] [blame]
package flightcontrol
import (
"context"
"io"
"runtime"
"sort"
"sync"
"time"
"github.com/moby/buildkit/util/progress"
"github.com/pkg/errors"
)
// flightcontrol is like singleflight but with support for cancellation and
// nested progress reporting
var (
// errRetry is returned by call.wait when the call had already finished by
// the time the waiter arrived; Group.Do reacts to it by retrying.
errRetry = errors.Errorf("retry")
// errRetryTimeout is the error Group.Do wraps and returns once its backoff
// has reached the retry limit.
errRetryTimeout = errors.Errorf("exceeded retry timeout")
)
// contextKeyT is the private type of contextKey.
type contextKeyT string
// contextKey is the key for which call.Value returns the call's progressState.
var contextKey = contextKeyT("buildkit/util/flightcontrol.progress")
// Group is a flightcontrol synchronization group. It keeps track of the
// in-flight call registered for each key so that concurrent callers using the
// same key join that call. The map is created on first use, so the zero value
// is ready to use.
type Group struct {
mu sync.Mutex // protects m
m map[string]*call // lazily initialized
}
// Do executes a context function synchronized by the key.
//
// Callers that use the same key while a call is in flight join that call
// instead of starting fn again. When the call being joined has just finished
// (do reports errRetry), Do retries with an exponentially growing backoff and
// gives up with a wrapped errRetryTimeout once the backoff has reached three
// seconds. If ctx is done while Do is backing off, Do stops waiting and
// returns ctx.Err().
func (g *Group) Do(ctx context.Context, key string, fn func(ctx context.Context) (interface{}, error)) (v interface{}, err error) {
	var backoff time.Duration
	for {
		v, err = g.do(ctx, key, fn)
		if err == nil || !errors.Is(err, errRetry) {
			return v, err
		}
		// backoff logic
		if backoff >= 3*time.Second {
			err = errors.Wrapf(errRetryTimeout, "flightcontrol")
			return v, err
		}
		runtime.Gosched()
		if backoff == 0 {
			// The first retry happens right after yielding; only later retries sleep.
			backoff = time.Millisecond
			continue
		}
		// Sleep for the backoff, but wake up early if the caller gives up so a
		// cancelled caller is not kept blocked for the rest of the delay.
		timer := time.NewTimer(backoff)
		select {
		case <-ctx.Done():
			timer.Stop()
			return nil, ctx.Err()
		case <-timer.C:
		}
		backoff *= 2
	}
}
// do joins the in-flight call registered for key, creating and registering a
// new one first when there is none, and then waits on it with ctx.
func (g *Group) do(ctx context.Context, key string, fn func(ctx context.Context) (interface{}, error)) (interface{}, error) {
	g.mu.Lock()
	if g.m == nil {
		g.m = make(map[string]*call)
	}
	inflight, joined := g.m[key]
	if !joined {
		inflight = newCall(fn)
		g.m[key] = inflight
		// Once the call is ready, drop it from the map so later callers start a
		// fresh call, then tell late waiters that cleanup has happened.
		go func(c *call) {
			<-c.ready
			g.mu.Lock()
			delete(g.m, key)
			g.mu.Unlock()
			close(c.cleaned)
		}(inflight)
	}
	g.mu.Unlock()
	return inflight.wait(ctx)
}
// call is the state of one in-flight invocation of fn, shared by every waiter
// that joined it through Group.do.
type call struct {
// mu is held while ctxs is modified or read and while run stores result and err.
mu sync.Mutex
// result and err hold what fn returned; run stores them before closing ready.
result interface{}
err error
// ready is closed by run once fn has returned.
ready chan struct{}
// cleaned is closed after the call has been removed from the Group's map.
cleaned chan struct{}
// ctx is the shared context from which run derives the context passed to fn.
ctx *sharedContext
// ctxs collects the contexts of all waiters registered through wait.
ctxs []context.Context
// fn is the function being executed on behalf of the waiters.
fn func(ctx context.Context) (interface{}, error)
// once ensures run is executed a single time however many waiters join.
once sync.Once
// closeProgressWriter is the close function returned by progress.NewContext
// in newCall; run defers it.
closeProgressWriter func()
// progressState forwards progress to the writers of the registered waiters.
progressState *progressState
// progressCtx is the context returned by progress.NewContext in newCall.
progressCtx context.Context
}
// newCall builds the bookkeeping for one shared invocation of fn: the ready
// and cleaned channels, the shared context, and a progress context whose
// reader is drained by the call's progressState in a background goroutine.
// fn itself is not started here.
func newCall(fn func(ctx context.Context) (interface{}, error)) *call {
	reader, progressCtx, closeWriter := progress.NewContext(context.Background())
	c := &call{
		fn:                  fn,
		ready:               make(chan struct{}),
		cleaned:             make(chan struct{}),
		progressState:       newProgressState(),
		progressCtx:         progressCtx,
		closeProgressWriter: closeWriter,
	}
	c.ctx = newContext(c) // newSharedContext
	go c.progressState.run(reader) // TODO: remove this, wrap writer instead
	return c
}
// run executes fn with a cancellable context derived from the shared context,
// stores its return values on the call and closes ready to wake the waiters.
// On return the derived context is cancelled and closeProgressWriter is
// called.
func (c *call) run() {
	defer c.closeProgressWriter()

	fnCtx, stop := context.WithCancel(c.ctx)
	defer stop()

	res, fnErr := c.fn(fnCtx)

	c.mu.Lock()
	c.result, c.err = res, fnErr
	c.mu.Unlock()

	close(c.ready)
}
// wait registers ctx as a waiter on the call, makes sure fn has been started,
// and blocks until either fn has finished or ctx is done.
//
// It returns errRetry when the call had already finished on entry, ctx.Err()
// when ctx is done while the shared context is not, and otherwise the result
// and error stored by run.
func (c *call) wait(ctx context.Context) (v interface{}, err error) {
c.mu.Lock()
// detect case where caller has just returned, let it clean up before
select {
case <-c.ready: // could return if no error
c.mu.Unlock()
// block until the cleanup goroutine started in Group.do has removed this
// call from the map, so the retry creates a fresh call
<-c.cleaned
return nil, errRetry
default:
}
// when ok is true, subscribe the writer obtained from ctx to the shared
// progress state
pw, ok, ctx := progress.FromContext(ctx)
if ok {
c.progressState.add(pw)
}
// record the context returned by progress.FromContext as a waiter of this call
c.ctxs = append(c.ctxs, ctx)
c.mu.Unlock()
// start fn in the background; once guarantees run executes a single time
go c.once.Do(c.run)
select {
case <-ctx.Done():
select {
case <-c.ctx.Done():
// if this cancelled the last context, then wait for function to shut down
// and don't accept any more callers
<-c.ready
return c.result, c.err
default:
// the shared context is still live: unsubscribe this waiter's progress
// writer and report its own context error
if ok {
c.progressState.close(pw)
}
return nil, ctx.Err()
}
case <-c.ready:
return c.result, c.err // shared not implemented yet
}
}
// Deadline reports the deadline of the first registered waiter context that is
// not yet done and has a deadline. If there is no such context it returns the
// zero time and false.
func (c *call) Deadline() (deadline time.Time, ok bool) {
	c.mu.Lock()
	defer c.mu.Unlock()

	for i := range c.ctxs {
		waiter := c.ctxs[i]
		select {
		case <-waiter.Done():
			// finished waiters no longer influence the shared deadline
			continue
		default:
		}
		if dl, has := waiter.Deadline(); has {
			return dl, true
		}
	}
	return time.Time{}, false
}
// Done re-evaluates the waiter contexts through signal while holding the
// call's lock and returns the done channel of the shared context.
func (c *call) Done() <-chan struct{} {
	c.mu.Lock()
	defer c.mu.Unlock()

	c.ctx.signal()
	return c.ctx.done
}
// Err returns nil while the shared context is not done, and the error recorded
// on the shared context once it is.
func (c *call) Err() error {
	select {
	case <-c.ctx.Done():
	default:
		return nil
	}
	return c.ctx.err
}
// Value looks key up for the shared context.
//
// The package's own contextKey always yields the call's progressState. Any
// other key is looked up first in the call's progress context and then in the
// waiter contexts in registration order; contexts that are already done are
// skipped, and the first non-nil value wins. Every waiter is consulted, not
// only the first, so values stay reachable after the first waiter has been
// cancelled. Value returns nil when no live context has a value for key.
func (c *call) Value(key interface{}) interface{} {
	if key == contextKey {
		return c.progressState
	}
	c.mu.Lock()
	defer c.mu.Unlock()

	select {
	case <-c.progressCtx.Done():
	default:
		if v := c.progressCtx.Value(key); v != nil {
			return v
		}
	}
	for _, ctx := range c.ctxs {
		select {
		case <-ctx.Done():
			// a finished waiter no longer contributes values
		default:
			if v := ctx.Value(key); v != nil {
				return v
			}
		}
	}
	return nil
}
// sharedContext is the context built for a call by newContext. It embeds the
// call, so the Deadline, Done, Err and Value methods defined on call are
// promoted to it.
type sharedContext struct {
*call
// done is closed by signal once every registered waiter context is done.
done chan struct{}
// err is stored by signal right before done is closed.
err error
}
// newContext creates the shared context for c with a fresh, open done channel.
func newContext(c *call) *sharedContext {
	sc := new(sharedContext)
	sc.call = c
	sc.done = make(chan struct{})
	return sc
}
// signal closes done once every registered waiter context is done, recording
// the error of the last waiter in err. It does nothing if done is already
// closed or if any waiter is still active.
//
// call with lock
func (c *sharedContext) signal() {
	select {
	case <-c.done:
		return // already signalled
	default:
	}

	var last error
	for _, waiter := range c.ctxs {
		select {
		case <-waiter.Done():
			last = waiter.Err()
		default:
			// at least one waiter still needs the result
			return
		}
	}
	c.err = last
	close(c.done)
}
// rawProgressWriter is the pair of methods progressState needs from a progress
// writer: forwarding an existing progress item and closing the writer. Writers
// that do not implement it are ignored by progressState.add and
// progressState.close.
type rawProgressWriter interface {
WriteRawProgress(*progress.Progress) error
Close() error
}
// progressState keeps the most recent progress item per ID and forwards
// progress to every registered writer.
type progressState struct {
// mu is taken by run, add and close when they update the fields below.
mu sync.Mutex
// items maps a progress ID to the latest item read for it.
items map[string]*progress.Progress
// writers are the registered writers that receive forwarded progress.
writers []rawProgressWriter
// done is set by run once the progress reader has returned io.EOF.
done bool
}
// newProgressState returns a progressState with an empty item map and no
// registered writers.
func newProgressState() *progressState {
	ps := new(progressState)
	ps.items = map[string]*progress.Progress{}
	return ps
}
// run drains pr until it fails. Every item read is forwarded to all
// registered writers and remembered under its ID so that add can replay it to
// writers that join later. When the reader reports io.EOF the state is marked
// done and all registered writers are closed; any other read error stops the
// loop without closing them.
func (ps *progressState) run(pr progress.Reader) {
	for {
		batch, err := pr.Read(context.TODO())
		if err != nil {
			if errors.Is(err, io.EOF) {
				// Close the writers while still holding the lock: close may
				// rewrite ps.writers concurrently, so ranging over the slice
				// after unlocking would be a data race.
				ps.mu.Lock()
				ps.done = true
				for _, w := range ps.writers {
					w.Close()
				}
				ps.mu.Unlock()
			}
			return
		}
		ps.mu.Lock()
		for _, p := range batch {
			for _, w := range ps.writers {
				w.WriteRawProgress(p)
			}
			ps.items[p.ID] = p
		}
		ps.mu.Unlock()
	}
}
// add subscribes pw to the progress stream. Writers that do not implement
// rawProgressWriter are ignored. The items recorded so far are first replayed
// to the writer in timestamp order; the writer is then closed immediately if
// the stream is already done, or registered for future items otherwise.
func (ps *progressState) add(pw progress.Writer) {
	raw, isRaw := pw.(rawProgressWriter)
	if !isRaw {
		return
	}

	ps.mu.Lock()
	defer ps.mu.Unlock()

	replay := make([]*progress.Progress, 0, len(ps.items))
	for _, item := range ps.items {
		replay = append(replay, item)
	}
	sort.Slice(replay, func(a, b int) bool {
		return replay[a].Timestamp.Before(replay[b].Timestamp)
	})
	for i := range replay {
		raw.WriteRawProgress(replay[i])
	}

	if !ps.done {
		ps.writers = append(ps.writers, raw)
		return
	}
	raw.Close()
}
// close unsubscribes pw: the first registered writer equal to it is closed and
// removed from the list. Writers that do not implement rawProgressWriter, or
// that are not registered, are ignored.
func (ps *progressState) close(pw progress.Writer) {
	target, isRaw := pw.(rawProgressWriter)
	if !isRaw {
		return
	}

	ps.mu.Lock()
	defer ps.mu.Unlock()

	for idx := range ps.writers {
		if ps.writers[idx] != target {
			continue
		}
		ps.writers[idx].Close()
		ps.writers = append(ps.writers[:idx], ps.writers[idx+1:]...)
		return
	}
}