| package ctxext |
| |
| import ( |
| "sync" |
| "time" |
| |
| context "golang.org/x/net/context" |
| ) |
| |
| // WithParents returns a Context that listens to all given |
| // parents. It effectively transforms the Context Tree into |
| // a Directed Acyclic Graph. This is useful when a context |
| // may be cancelled for more than one reason. For example, |
| // consider a database with the following Get function: |
| // |
| // func (db *DB) Get(ctx context.Context, ...) {} |
| // |
| // DB.Get may have to stop for two different contexts: |
| // * the caller's context (caller might cancel) |
| // * the database's context (might be shut down mid-request) |
| // |
| // WithParents saves the day by allowing us to "merge" contexts |
| // and continue on our merry contextual way: |
| // |
| // ctx = ctxext.WithParents(ctx, db.ctx) |
| // |
| // Passing related (mutually derived) contexts to WithParents is |
| // actually ok. The child is cancelled when any of its parents is |
| // cancelled, so if any of its parents are also related, the cancel |
| // propagation will reach the child via the shortest path. |
| func WithParents(ctxts ...context.Context) context.Context { |
| if len(ctxts) < 1 { |
| panic("no contexts provided") |
| } |
| |
| ctx := &errCtx{ |
| done: make(chan struct{}), |
| dead: earliestDeadline(ctxts), |
| } |
| |
| // listen to all contexts and use the first. |
| for _, c2 := range ctxts { |
| go func(pctx context.Context) { |
| select { |
| case <-ctx.Done(): // cancelled by another parent |
| return |
| case <-pctx.Done(): // this parent cancelled |
| // race: two parents may have cancelled at the same time. |
| // break tie with mutex (inside c.cancel) |
| ctx.cancel(pctx.Err()) |
| } |
| }(c2) |
| } |
| |
| return ctx |
| } |
| |
| func earliestDeadline(ctxts []context.Context) *time.Time { |
| var d1 *time.Time |
| for _, c := range ctxts { |
| if c == nil { |
| panic("given nil context.Context") |
| } |
| |
| // use earliest deadline. |
| d2, ok := c.Deadline() |
| if !ok { |
| continue |
| } |
| |
| if d1 == nil || (*d1).After(d2) { |
| d1 = &d2 |
| } |
| } |
| return d1 |
| } |
| |
| type errCtx struct { |
| dead *time.Time |
| done chan struct{} |
| err error |
| mu sync.RWMutex |
| } |
| |
| func (c *errCtx) cancel(err error) { |
| c.mu.Lock() |
| defer c.mu.Unlock() |
| |
| select { |
| case <-c.Done(): |
| return |
| default: |
| } |
| |
| c.err = err |
| close(c.done) // signal done to all |
| } |
| |
| func (c *errCtx) Done() <-chan struct{} { |
| return c.done |
| } |
| |
| func (c *errCtx) Err() error { |
| c.mu.Lock() |
| defer c.mu.Unlock() |
| return c.err |
| } |
| |
| func (c *errCtx) Value(key interface{}) interface{} { |
| return nil |
| } |
| |
| func (c *errCtx) Deadline() (deadline time.Time, ok bool) { |
| if c.dead == nil { |
| return |
| } |
| |
| return *c.dead, true |
| } |