blob: 521390b68120182f2346bce85b494d1153abaed1 [file] [log] [blame]
package ctxext
import (
"sync"
"time"
context "golang.org/x/net/context"
)
// WithParents returns a Context that listens to all given
// parents. It effectively transforms the Context Tree into
// a Directed Acyclic Graph. This is useful when a context
// may be cancelled for more than one reason. For example,
// consider a database with the following Get function:
//
//	func (db *DB) Get(ctx context.Context, ...) {}
//
// DB.Get may have to stop for two different contexts:
//   - the caller's context (caller might cancel)
//   - the database's context (might be shut down mid-request)
//
// WithParents saves the day by allowing us to "merge" contexts
// and continue on our merry contextual way:
//
//	ctx = ctxext.WithParents(ctx, db.ctx)
//
// Passing related (mutually derived) contexts to WithParents is
// actually ok. The child is cancelled when any of its parents is
// cancelled, so if any of its parents are also related, the cancel
// propagation will reach the child via the shortest path.
//
// The returned Context is done as soon as any parent is done and
// reports that parent's Err. Its Deadline is the earliest deadline
// among the parents, if any of them has one. If a parent is already
// done when WithParents is called, the returned Context is done
// before WithParents returns.
//
// WithParents panics if called with no contexts, or if any of the
// given contexts is nil.
func WithParents(ctxts ...context.Context) context.Context {
	if len(ctxts) < 1 {
		panic("no contexts provided")
	}

	ctx := &errCtx{
		done: make(chan struct{}),
		dead: earliestDeadline(ctxts), // also rejects nil parents
	}

	// Fast path: a parent that is already done cancels the child right
	// away, so callers never observe a "live" child of a dead parent and
	// no watcher goroutines need to be started at all.
	for _, pctx := range ctxts {
		select {
		case <-pctx.Done():
			ctx.cancel(pctx.Err())
			return ctx
		default:
		}
	}

	// listen to all contexts and use the first.
	for _, pctx := range ctxts {
		// A nil Done channel means this parent can never be cancelled;
		// watching it would only park a goroutine for no benefit (and
		// forever, if none of the parents can be cancelled).
		if pctx.Done() == nil {
			continue
		}

		// Each watcher exits as soon as either the child or its own
		// parent is done.
		go func(pctx context.Context) {
			select {
			case <-ctx.Done(): // cancelled by another parent
				return
			case <-pctx.Done(): // this parent cancelled
				// race: two parents may have cancelled at the same time.
				// break tie with mutex (inside c.cancel)
				ctx.cancel(pctx.Err())
			}
		}(pctx)
	}
	return ctx
}
// earliestDeadline returns a pointer to the soonest deadline reported by
// any of the given contexts, or nil when none of them has a deadline.
// It panics if one of the contexts is nil.
func earliestDeadline(ctxts []context.Context) *time.Time {
	var earliest *time.Time
	for _, parent := range ctxts {
		if parent == nil {
			panic("given nil context.Context")
		}

		// Keep this parent's deadline only if it has one and it is
		// strictly sooner than the best seen so far.
		deadline, has := parent.Deadline()
		if has && (earliest == nil || deadline.Before(*earliest)) {
			earliest = &deadline
		}
	}
	return earliest
}
// errCtx is the Context implementation handed out by WithParents. It can
// be cancelled at most once; the error passed to the first cancel call is
// the one reported by Err from then on.
type errCtx struct {
	dead *time.Time    // deadline reported by Deadline; nil means "no deadline"
	done chan struct{} // closed by cancel to signal completion
	err  error         // set by cancel just before done is closed; guarded by mu
	mu   sync.RWMutex  // guards err and serialises concurrent cancel calls
}

// cancel marks the context as done with the given error. Only the first
// call has any effect: later calls see the closed done channel and return
// without touching err, so the first error wins and done is never closed
// twice.
func (c *errCtx) cancel(err error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	select {
	case <-c.Done():
		return // already cancelled
	default:
	}

	c.err = err
	close(c.done) // signal done to all
}

// Done returns the channel that is closed once the context is cancelled.
func (c *errCtx) Done() <-chan struct{} {
	return c.done
}

// Err returns the error the context was cancelled with, or nil if it has
// not been cancelled yet.
func (c *errCtx) Err() error {
	// Read-only access: a shared lock lets concurrent callers of Err
	// proceed in parallel while still excluding a concurrent cancel.
	c.mu.RLock()
	defer c.mu.RUnlock()
	return c.err
}

// Value always returns nil: this context carries no values of its own.
func (c *errCtx) Value(key interface{}) interface{} {
	return nil
}

// Deadline returns the stored deadline and true, or the zero time and
// false when no deadline was recorded for this context.
func (c *errCtx) Deadline() (deadline time.Time, ok bool) {
	if c.dead == nil {
		return time.Time{}, false
	}
	return *c.dead, true
}