blob: 13d68293ae4c7a63fea528eafc3e1b9810246910 [file] [log] [blame]
package dispatcher
import (
"fmt"
"net"
"strconv"
"sync"
"time"
"github.com/docker/go-events"
"github.com/docker/go-metrics"
"github.com/docker/swarmkit/api"
"github.com/docker/swarmkit/api/equality"
"github.com/docker/swarmkit/ca"
"github.com/docker/swarmkit/log"
"github.com/docker/swarmkit/manager/drivers"
"github.com/docker/swarmkit/manager/state/store"
"github.com/docker/swarmkit/protobuf/ptypes"
"github.com/docker/swarmkit/remotes"
"github.com/docker/swarmkit/watch"
gogotypes "github.com/gogo/protobuf/types"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/grpc/transport"
)
const (
// DefaultHeartBeatPeriod is used for setting the default value in the cluster
// config, and when the cluster config is missing.
DefaultHeartBeatPeriod = 5 * time.Second
// defaultHeartBeatEpsilon is the HeartbeatEpsilon used by DefaultConfig; it
// is passed to the node store together with the heartbeat period.
defaultHeartBeatEpsilon = 500 * time.Millisecond
// defaultGracePeriodMultiplier is the GracePeriodMultiplier used by
// DefaultConfig.
defaultGracePeriodMultiplier = 3
// defaultRateLimitPeriod is the RateLimitPeriod used by DefaultConfig: how
// often a node with the same ID may try to register a new session.
defaultRateLimitPeriod = 8 * time.Second
// maxBatchItems is the threshold of queued writes that should
// trigger an actual transaction to commit them to the shared store.
maxBatchItems = 10000
// maxBatchInterval needs to strike a balance between keeping
// latency low, and realizing opportunities to combine many writes
// into a single transaction. A fraction of a second feels about
// right.
maxBatchInterval = 100 * time.Millisecond
// modificationBatchLimit caps how many task modifications the Tasks and
// Assignments streams accumulate before sending a message.
modificationBatchLimit = 100
// batchingWaitTime is how long the Tasks and Assignments streams wait after
// the most recent modification before sending the accumulated batch.
batchingWaitTime = 100 * time.Millisecond
// defaultNodeDownPeriod specifies the default time period we
// wait before moving tasks assigned to down nodes to ORPHANED
// state.
defaultNodeDownPeriod = 24 * time.Hour
)
var (
// ErrNodeAlreadyRegistered returned if node with same ID was already
// registered with this dispatcher.
ErrNodeAlreadyRegistered = errors.New("node already registered")
// ErrNodeNotRegistered returned if node with such ID wasn't registered
// with this dispatcher.
ErrNodeNotRegistered = errors.New("node not registered")
// ErrSessionInvalid returned when the session in use is no longer valid.
// The node should re-register and start a new session.
ErrSessionInvalid = errors.New("session invalid")
// ErrNodeNotFound returned when the Node doesn't exist in raft.
ErrNodeNotFound = errors.New("node not found")
// schedulingDelayTimer is the scheduling delay metric. It is created and
// registered in init and updated by processUpdates when a task reports
// RUNNING.
schedulingDelayTimer metrics.Timer
)
// init creates the "swarm_dispatcher" metrics namespace, attaches the
// scheduling delay timer to it and registers the namespace globally.
func init() {
	namespace := metrics.NewNamespace("swarm", "dispatcher", nil)
	const help = "Scheduling delay is the time a task takes to go from NEW to RUNNING state."
	schedulingDelayTimer = namespace.NewTimer("scheduling_delay", help)
	metrics.Register(namespace)
}
// Config is configuration for Dispatcher. For default you should use
// DefaultConfig.
type Config struct {
// HeartbeatPeriod is the base heartbeat period given to the node store.
// Run overrides it with the cluster spec's dispatcher heartbeat period
// when one is set.
HeartbeatPeriod time.Duration
// HeartbeatEpsilon is passed to the node store along with HeartbeatPeriod.
// NOTE(review): presumably the jitter applied to the period; confirm in nodeStore.
HeartbeatEpsilon time.Duration
// RateLimitPeriod specifies how often node with same ID can try to register
// new session.
RateLimitPeriod time.Duration
// GracePeriodMultiplier is passed to the node store along with the
// heartbeat period. NOTE(review): presumably the number of missed periods
// tolerated before expiry; confirm in nodeStore.
GracePeriodMultiplier int
}
// DefaultConfig returns a newly allocated Config populated with the package
// defaults. Each call returns a distinct value that callers may modify.
func DefaultConfig() *Config {
	cfg := new(Config)
	cfg.HeartbeatPeriod = DefaultHeartBeatPeriod
	cfg.HeartbeatEpsilon = defaultHeartBeatEpsilon
	cfg.RateLimitPeriod = defaultRateLimitPeriod
	cfg.GracePeriodMultiplier = defaultGracePeriodMultiplier
	return cfg
}
// Cluster is an interface that represents the raft cluster;
// manager/state/raft.Node implements it. The interface exists only to make
// unit-testing easier.
type Cluster interface {
// GetMemberlist returns the current raft members; used to build the
// manager list sent to agents.
GetMemberlist() map[uint64]*api.RaftMember
// SubscribePeers returns a channel of peer-list events ([]*api.Peer) and a
// function that cancels the subscription.
SubscribePeers() (chan events.Event, func())
// MemoryStore returns the store the dispatcher reads from and writes to.
MemoryStore() *store.MemoryStore
}
// nodeUpdate provides a new status and/or description to apply to a node
// object. A nil field means "leave that part of the node unchanged" when the
// update is applied by processUpdates.
type nodeUpdate struct {
status *api.NodeStatus
description *api.NodeDescription
}
// clusterUpdate is an object that stores an update to the cluster that should trigger
// a new session message. These are pointers to indicate the difference between
// "there is no update" and "update this to nil"
type clusterUpdate struct {
// managerUpdate is set when the raft peer list changes.
managerUpdate *[]*api.WeightedPeer
// bootstrapKeyUpdate and rootCAUpdate are set when the cluster object changes.
bootstrapKeyUpdate *[]*api.EncryptionKey
rootCAUpdate *[]byte
}
// Dispatcher is responsible for dispatching tasks and tracking agent health.
type Dispatcher struct {
// mu guards lastSeenManagers, networkBootstrapKeys, lastSeenRootCert and
// ctx/cancel; Run also holds it while changing config.HeartbeatPeriod.
mu sync.Mutex
// wg is incremented by Run once it is set up; Stop waits on it.
wg sync.WaitGroup
// nodes tracks registered nodes, their sessions and heartbeat expiry.
nodes *nodeStore
store *store.MemoryStore
lastSeenManagers []*api.WeightedPeer
networkBootstrapKeys []*api.EncryptionKey
lastSeenRootCert []byte
config *Config
cluster Cluster
// ctx and cancel are set by Run; the dispatcher is considered running
// while ctx is non-nil and not cancelled.
ctx context.Context
cancel context.CancelFunc
// clusterUpdateQueue fans clusterUpdate values out to Session streams.
clusterUpdateQueue *watch.Queue
dp *drivers.DriverProvider
securityConfig *ca.SecurityConfig
// taskUpdates holds pending task status updates, flushed by processUpdates.
taskUpdates map[string]*api.TaskStatus // indexed by task ID
taskUpdatesLock sync.Mutex
// nodeUpdates holds pending node updates, flushed by processUpdates.
nodeUpdates map[string]nodeUpdate // indexed by node ID
nodeUpdatesLock sync.Mutex
// downNodes tracks nodes marked down; when an entry expires the node's
// tasks are moved to ORPHANED.
downNodes *nodeStore
// processUpdatesTrigger asks Run to flush pending updates immediately.
processUpdatesTrigger chan struct{}
// for waiting for the next task/node batch update
processUpdatesLock sync.Mutex
processUpdatesCond *sync.Cond
}
// New returns a Dispatcher bound to the given cluster (usually a raft.Node).
// The returned dispatcher does nothing until Run is called.
func New(cluster Cluster, c *Config, dp *drivers.DriverProvider, securityConfig *ca.SecurityConfig) *Dispatcher {
	d := new(Dispatcher)
	d.dp = dp
	// Registered nodes use the configured heartbeat settings.
	d.nodes = newNodeStore(c.HeartbeatPeriod, c.HeartbeatEpsilon, c.GracePeriodMultiplier, c.RateLimitPeriod)
	// Down nodes use defaultNodeDownPeriod as their expiry period.
	d.downNodes = newNodeStore(defaultNodeDownPeriod, 0, 1, 0)
	d.store = cluster.MemoryStore()
	d.cluster = cluster
	// A single buffered slot is enough: one pending trigger requests a flush.
	d.processUpdatesTrigger = make(chan struct{}, 1)
	d.config = c
	d.securityConfig = securityConfig
	d.processUpdatesCond = sync.NewCond(&d.processUpdatesLock)
	return d
}
// getWeightedPeers converts the cluster's current raft member list into
// weighted manager peers. All peers get the same default weight. The result
// is nil when there are no members.
func getWeightedPeers(cluster Cluster) []*api.WeightedPeer {
	var peers []*api.WeightedPeer
	for _, member := range cluster.GetMemberlist() {
		peer := &api.Peer{NodeID: member.NodeID, Addr: member.Addr}
		// TODO(stevvooe): Calculate weight of manager selection based on
		// cluster-level observations, such as number of connections and
		// load.
		peers = append(peers, &api.WeightedPeer{Peer: peer, Weight: remotes.DefaultObservationWeight})
	}
	return peers
}
// Run runs dispatcher tasks which should be run on leader dispatcher.
// Dispatcher can be stopped with cancelling ctx or calling Stop().
//
// Run returns an error immediately if the dispatcher is already running or
// the cluster watch cannot be established. Otherwise it blocks, flushing
// queued task/node updates and publishing manager, network bootstrap key and
// root CA changes to sessions, until its context is cancelled. It then
// returns nil.
func (d *Dispatcher) Run(ctx context.Context) error {
	d.mu.Lock()
	if d.isRunning() {
		d.mu.Unlock()
		return errors.New("dispatcher is already running")
	}

	// Reset the pending update queues only once we know no other Run is
	// active. Resetting them before this check would silently discard
	// updates already queued by a live dispatcher.
	d.taskUpdatesLock.Lock()
	d.taskUpdates = make(map[string]*api.TaskStatus)
	d.taskUpdatesLock.Unlock()

	d.nodeUpdatesLock.Lock()
	d.nodeUpdates = make(map[string]nodeUpdate)
	d.nodeUpdatesLock.Unlock()

	ctx = log.WithModule(ctx, "dispatcher")
	if err := d.markNodesUnknown(ctx); err != nil {
		log.G(ctx).Errorf(`failed to move all nodes to "unknown" state: %v`, err)
	}
	configWatcher, cancel, err := store.ViewAndWatch(
		d.store,
		func(readTx store.ReadTx) error {
			clusters, err := store.FindClusters(readTx, store.ByName(store.DefaultClusterName))
			if err != nil {
				return err
			}
			if len(clusters) != 1 {
				return nil
			}
			cluster := clusters[0]
			heartbeatPeriod, err := gogotypes.DurationFromProto(cluster.Spec.Dispatcher.HeartbeatPeriod)
			if err == nil && heartbeatPeriod > 0 {
				d.config.HeartbeatPeriod = heartbeatPeriod
				// Push the period into the node store as well. The watch loop
				// below only calls updatePeriod when an event carries a period
				// that differs from d.config. Without this call the node store
				// would keep the period it was created with, while d.config
				// reports the cluster's value.
				d.nodes.updatePeriod(d.config.HeartbeatPeriod, d.config.HeartbeatEpsilon, d.config.GracePeriodMultiplier)
			}
			if cluster.NetworkBootstrapKeys != nil {
				d.networkBootstrapKeys = cluster.NetworkBootstrapKeys
			}
			d.lastSeenRootCert = cluster.RootCA.CACert
			return nil
		},
		api.EventUpdateCluster{},
	)
	if err != nil {
		d.mu.Unlock()
		return err
	}
	// set queue here to guarantee that Close will close it
	d.clusterUpdateQueue = watch.NewQueue()

	peerWatcher, peerCancel := d.cluster.SubscribePeers()
	defer peerCancel()
	d.lastSeenManagers = getWeightedPeers(d.cluster)

	defer cancel()
	d.ctx, d.cancel = context.WithCancel(ctx)
	ctx = d.ctx
	d.wg.Add(1)
	defer d.wg.Done()
	d.mu.Unlock()

	// publishManagers records the new manager list and fans it out to
	// sessions.
	publishManagers := func(peers []*api.Peer) {
		var mgrs []*api.WeightedPeer
		for _, p := range peers {
			mgrs = append(mgrs, &api.WeightedPeer{
				Peer:   p,
				Weight: remotes.DefaultObservationWeight,
			})
		}
		d.mu.Lock()
		d.lastSeenManagers = mgrs
		d.mu.Unlock()
		d.clusterUpdateQueue.Publish(clusterUpdate{managerUpdate: &mgrs})
	}

	batchTimer := time.NewTimer(maxBatchInterval)
	defer batchTimer.Stop()

	for {
		select {
		case ev := <-peerWatcher:
			publishManagers(ev.([]*api.Peer))
		case <-d.processUpdatesTrigger:
			d.processUpdates(ctx)
			batchTimer.Reset(maxBatchInterval)
		case <-batchTimer.C:
			d.processUpdates(ctx)
			batchTimer.Reset(maxBatchInterval)
		case v := <-configWatcher:
			cluster := v.(api.EventUpdateCluster)
			d.mu.Lock()
			if cluster.Cluster.Spec.Dispatcher.HeartbeatPeriod != nil {
				// ignore error, since Spec has passed validation before
				heartbeatPeriod, _ := gogotypes.DurationFromProto(cluster.Cluster.Spec.Dispatcher.HeartbeatPeriod)
				if heartbeatPeriod != d.config.HeartbeatPeriod {
					// only call d.nodes.updatePeriod when heartbeatPeriod changes
					d.config.HeartbeatPeriod = heartbeatPeriod
					d.nodes.updatePeriod(d.config.HeartbeatPeriod, d.config.HeartbeatEpsilon, d.config.GracePeriodMultiplier)
				}
			}
			d.lastSeenRootCert = cluster.Cluster.RootCA.CACert
			d.networkBootstrapKeys = cluster.Cluster.NetworkBootstrapKeys
			d.mu.Unlock()
			d.clusterUpdateQueue.Publish(clusterUpdate{
				bootstrapKeyUpdate: &cluster.Cluster.NetworkBootstrapKeys,
				rootCAUpdate:       &cluster.Cluster.RootCA.CACert,
			})
		case <-ctx.Done():
			return nil
		}
	}
}
// Stop cancels the dispatcher's context, cleans the registered-node store,
// wakes goroutines waiting for an update batch, closes the cluster update
// queue and waits for Run to return. It returns an error if the dispatcher
// is not running.
func (d *Dispatcher) Stop() error {
	d.mu.Lock()
	running := d.isRunning()
	if running {
		d.cancel()
	}
	d.mu.Unlock()
	if !running {
		return errors.New("dispatcher is already stopped")
	}

	d.nodes.Clean()

	// Waiters re-check the context under processUpdatesLock before calling
	// Wait, so none can start waiting after the cancel above; wake the ones
	// already waiting.
	d.processUpdatesLock.Lock()
	d.processUpdatesCond.Broadcast()
	d.processUpdatesLock.Unlock()

	d.clusterUpdateQueue.Close()
	d.wg.Wait()
	return nil
}
// isRunningLocked returns the dispatcher's context if it is running, or an
// Aborted gRPC status error otherwise. Despite its name, it acquires d.mu
// itself, so callers must not already hold it.
func (d *Dispatcher) isRunningLocked() (context.Context, error) {
	d.mu.Lock()
	defer d.mu.Unlock()
	if d.isRunning() {
		return d.ctx, nil
	}
	return nil, status.Errorf(codes.Aborted, "dispatcher is stopped")
}
// markNodesUnknown is called by Run at startup. It walks every node in the
// store inside a single store batch:
//   - nodes already in DOWN state are left unchanged in the store but added to
//     downNodes, whose expiry moves their tasks to ORPHANED;
//   - all other nodes are set to UNKNOWN and added to the node store with an
//     expiry that marks them DOWN if they never heartbeat.
//
// Per-node failures are logged and do not abort the batch. Only a failure to
// list nodes or of the batch itself is returned.
func (d *Dispatcher) markNodesUnknown(ctx context.Context) error {
log := log.G(ctx).WithField("method", "(*Dispatcher).markNodesUnknown")
var nodes []*api.Node
var err error
d.store.View(func(tx store.ReadTx) {
nodes, err = store.FindNodes(tx, store.All)
})
if err != nil {
return errors.Wrap(err, "failed to get list of nodes")
}
err = d.store.Batch(func(batch *store.Batch) error {
for _, n := range nodes {
err := batch.Update(func(tx store.Tx) error {
// check if node is still here
node := store.GetNode(tx, n.ID)
if node == nil {
return nil
}
// do not try to resurrect down nodes
if node.Status.State == api.NodeStatus_DOWN {
nodeCopy := node
// On expiry, orphan the node's tasks and stop tracking it.
expireFunc := func() {
if err := d.moveTasksToOrphaned(nodeCopy.ID); err != nil {
log.WithError(err).Error(`failed to move all tasks to "ORPHANED" state`)
}
d.downNodes.Delete(nodeCopy.ID)
}
d.downNodes.Add(nodeCopy, expireFunc)
return nil
}
node.Status.State = api.NodeStatus_UNKNOWN
node.Status.Message = `Node moved to "unknown" state due to leadership change in cluster`
nodeID := node.ID
// If the node never heartbeats to this dispatcher, mark it DOWN.
expireFunc := func() {
log := log.WithField("node", nodeID)
log.Debug("heartbeat expiration for unknown node")
if err := d.markNodeNotReady(nodeID, api.NodeStatus_DOWN, `heartbeat failure for node in "unknown" state`); err != nil {
log.WithError(err).Error(`failed deregistering node after heartbeat expiration for node in "unknown" state`)
}
}
// NOTE(review): the node is added to the node store before the store
// write; if UpdateNode fails it stays registered as unknown.
if err := d.nodes.AddUnknown(node, expireFunc); err != nil {
return errors.Wrap(err, `adding node in "unknown" state to node store failed`)
}
if err := store.UpdateNode(tx, node); err != nil {
return errors.Wrap(err, "update failed")
}
return nil
})
if err != nil {
log.WithField("node", n.ID).WithError(err).Error(`failed to move node to "unknown" state`)
}
}
return nil
})
return err
}
// isRunning reports whether Run has installed a context that has not been
// cancelled yet. The callers in this file hold d.mu while calling it.
func (d *Dispatcher) isRunning() bool {
	if d.ctx == nil {
		return false
	}
	// Per the context contract, Err is nil exactly while Done is not closed.
	return d.ctx.Err() == nil
}
// markNodeReady updates the description of a node, updates its address, and sets status to READY
// this is used during registration when a new node description is provided
// and during node updates when the node description changes
//
// The update is only queued in d.nodeUpdates; it is written to the store by
// processUpdates. The node is removed from downNodes right away. Once
// maxBatchItems node updates are pending, an immediate flush is requested.
// The call then blocks until the next broadcast on processUpdatesCond
// (issued by processUpdates after a batch, and by Stop), or returns ctx.Err()
// if ctx is done first.
// NOTE(review): the wakeup is not tied to the batch that contains this
// update, so an unrelated batch's broadcast can release the wait early.
func (d *Dispatcher) markNodeReady(ctx context.Context, nodeID string, description *api.NodeDescription, addr string) error {
d.nodeUpdatesLock.Lock()
d.nodeUpdates[nodeID] = nodeUpdate{
status: &api.NodeStatus{
State: api.NodeStatus_READY,
Addr: addr,
},
description: description,
}
numUpdates := len(d.nodeUpdates)
d.nodeUpdatesLock.Unlock()
// Node is marked ready. Remove the node from down nodes if it
// is there.
d.downNodes.Delete(nodeID)
if numUpdates >= maxBatchItems {
select {
case d.processUpdatesTrigger <- struct{}{}:
case <-ctx.Done():
return ctx.Err()
}
}
// Wait until the node update batch happens before unblocking register.
d.processUpdatesLock.Lock()
defer d.processUpdatesLock.Unlock()
// Checking ctx under processUpdatesLock pairs with Stop, which cancels
// first and then broadcasts under the same lock, so a waiter cannot miss
// the shutdown broadcast.
select {
case <-ctx.Done():
return ctx.Err()
default:
}
d.processUpdatesCond.Wait()
return nil
}
// nodeIPFromContext returns the host part of the remote address of the node
// that made the gRPC call carried by ctx.
func nodeIPFromContext(ctx context.Context) (string, error) {
	nodeInfo, err := ca.RemoteNode(ctx)
	if err != nil {
		return "", err
	}
	host, _, splitErr := net.SplitHostPort(nodeInfo.RemoteAddr)
	if splitErr != nil {
		return "", errors.Wrap(splitErr, "unable to get ip from addr:port")
	}
	return host, nil
}
// register registers nodeID with this dispatcher and returns its new session
// ID.
//
// It requires a running dispatcher, enforces the per-node registration rate
// limit, and requires the node to exist in the store. It then waits for the
// node's READY status (with the given description and the caller's remote
// IP) to be flushed before adding the node to the heartbeat store. When
// heartbeats stop, the node is marked DOWN.
func (d *Dispatcher) register(ctx context.Context, nodeID string, description *api.NodeDescription) (string, error) {
	// Refuse registrations until the dispatcher is ready to accept them.
	dctx, err := d.isRunningLocked()
	if err != nil {
		return "", err
	}
	if err = d.nodes.CheckRateLimit(nodeID); err != nil {
		return "", err
	}

	// TODO(stevvooe): Validate node specification.
	var node *api.Node
	d.store.View(func(tx store.ReadTx) {
		node = store.GetNode(tx, nodeID)
	})
	if node == nil {
		return "", ErrNodeNotFound
	}

	// An unknown remote IP is not fatal: the READY update simply carries an
	// empty address.
	addr, ipErr := nodeIPFromContext(ctx)
	if ipErr != nil {
		log.G(ctx).WithError(ipErr).Debug("failed to get remote node IP")
	}
	if err = d.markNodeReady(dctx, nodeID, description, addr); err != nil {
		return "", err
	}

	onExpire := func() {
		log.G(ctx).Debug("heartbeat expiration")
		if expErr := d.markNodeNotReady(nodeID, api.NodeStatus_DOWN, "heartbeat failure"); expErr != nil {
			log.G(ctx).WithError(expErr).Errorf("failed deregistering node after heartbeat expiration")
		}
	}
	registered := d.nodes.Add(node, onExpire)

	// NOTE(stevvooe): We need be a little careful with re-registration. The
	// current implementation just matches the node id and then gives away the
	// sessionID. If we ever want to use sessionID as a secret, which we may
	// want to, this is giving away the keys to the kitchen.
	//
	// The right behavior is going to be informed by identity. Basically, each
	// time a node registers, we invalidate the session and issue a new
	// session, once identity is proven. This will cause misbehaved agents to
	// be kicked when multiple connections are made.
	return registered.SessionID, nil
}
// UpdateTaskStatus updates status of task. Node should send such updates
// on every status change of its tasks.
//
// Each update is validated against the store: nil statuses and tasks that no
// longer exist are skipped. An update for a task assigned to another node
// fails the whole request with PermissionDenied. Valid updates are queued and
// committed asynchronously by processUpdates. Once maxBatchItems task updates
// are pending, an immediate flush is requested.
func (d *Dispatcher) UpdateTaskStatus(ctx context.Context, r *api.UpdateTaskStatusRequest) (*api.UpdateTaskStatusResponse, error) {
	nodeInfo, err := ca.RemoteNode(ctx)
	if err != nil {
		return nil, err
	}
	nodeID := nodeInfo.NodeID
	fields := logrus.Fields{
		"node.id":      nodeID,
		"node.session": r.SessionID,
		"method":       "(*Dispatcher).UpdateTaskStatus",
	}
	if nodeInfo.ForwardedBy != nil {
		fields["forwarder.id"] = nodeInfo.ForwardedBy.NodeID
	}
	log := log.G(ctx).WithFields(fields)

	dctx, err := d.isRunningLocked()
	if err != nil {
		return nil, err
	}

	if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
		return nil, err
	}

	// Validate all updates against one read snapshot instead of opening a
	// read transaction per update.
	validTaskUpdates := make([]*api.UpdateTaskStatusRequest_TaskStatusUpdate, 0, len(r.Updates))
	var validationErr error
	d.store.View(func(tx store.ReadTx) {
		for _, u := range r.Updates {
			if u.Status == nil {
				log.WithField("task.id", u.TaskID).Warn("task report has nil status")
				continue
			}
			t := store.GetTask(tx, u.TaskID)
			if t == nil {
				// Task may have been deleted
				log.WithField("task.id", u.TaskID).Debug("cannot find target task in store")
				continue
			}
			if t.NodeID != nodeID {
				validationErr = status.Errorf(codes.PermissionDenied, "cannot update a task not assigned this node")
				log.WithField("task.id", u.TaskID).Error(validationErr)
				return
			}
			validTaskUpdates = append(validTaskUpdates, u)
		}
	})
	if validationErr != nil {
		return nil, validationErr
	}

	// Enqueue task updates; later updates for the same task overwrite earlier
	// ones that have not been flushed yet.
	d.taskUpdatesLock.Lock()
	for _, u := range validTaskUpdates {
		d.taskUpdates[u.TaskID] = u.Status
	}
	numUpdates := len(d.taskUpdates)
	d.taskUpdatesLock.Unlock()

	if numUpdates >= maxBatchItems {
		select {
		case d.processUpdatesTrigger <- struct{}{}:
		case <-dctx.Done():
		}
	}
	// Return an empty response message rather than a nil pointer, so
	// marshalling and any interceptors never see a nil response on success.
	return &api.UpdateTaskStatusResponse{}, nil
}
// processUpdates flushes all queued task and node updates to the store in a
// single batch. It then wakes every goroutine waiting on processUpdatesCond
// (see markNodeReady). Run calls it periodically and whenever a flush is
// triggered.
func (d *Dispatcher) processUpdates(ctx context.Context) {
	// Wake waiters on every call, including calls with nothing to flush.
	// A waiter enqueues its update first and only then starts waiting, so a
	// batch can commit that update and broadcast in between. If empty calls
	// never broadcast, such a waiter could block until Stop on a quiet
	// cluster. Broadcasting on every periodic call bounds the wait to one
	// batch interval.
	defer d.processUpdatesCond.Broadcast()

	var (
		taskUpdates map[string]*api.TaskStatus
		nodeUpdates map[string]nodeUpdate
	)
	d.taskUpdatesLock.Lock()
	if len(d.taskUpdates) != 0 {
		taskUpdates = d.taskUpdates
		d.taskUpdates = make(map[string]*api.TaskStatus)
	}
	d.taskUpdatesLock.Unlock()

	d.nodeUpdatesLock.Lock()
	if len(d.nodeUpdates) != 0 {
		nodeUpdates = d.nodeUpdates
		d.nodeUpdates = make(map[string]nodeUpdate)
	}
	d.nodeUpdatesLock.Unlock()

	if len(taskUpdates) == 0 && len(nodeUpdates) == 0 {
		return
	}

	log := log.G(ctx).WithFields(logrus.Fields{
		"method": "(*Dispatcher).processUpdates",
	})

	err := d.store.Batch(func(batch *store.Batch) error {
		for taskID, status := range taskUpdates {
			err := batch.Update(func(tx store.Tx) error {
				logger := log.WithField("task.id", taskID)
				task := store.GetTask(tx, taskID)
				if task == nil {
					// Task may have been deleted
					logger.Debug("cannot find target task in store")
					return nil
				}

				logger = logger.WithField("state.transition", fmt.Sprintf("%v->%v", task.Status.State, status.State))

				if task.Status == *status {
					logger.Debug("task status identical, ignoring")
					return nil
				}

				if task.Status.State > status.State {
					logger.Debug("task status invalid transition")
					return nil
				}

				// Update scheduling delay metric for running tasks.
				// We use the status update time on the leader to calculate the scheduling delay.
				// Because of this, the recorded scheduling delay will be an overestimate and include
				// the network delay between the worker and the leader.
				// This is not ideal, but its a known overestimation, rather than using the status update time
				// from the worker node, which may cause unknown incorrect results due to possible clock skew.
				//
				// Only record the transition into RUNNING, and only when the
				// update carries an AppliedAt timestamp. Otherwise repeated
				// RUNNING reports would be counted again, and a missing
				// timestamp would be read as the Unix epoch.
				if status.State == api.TaskStateRunning && task.Status.State < api.TaskStateRunning && status.AppliedAt != nil {
					start := time.Unix(status.AppliedAt.GetSeconds(), int64(status.AppliedAt.GetNanos()))
					schedulingDelayTimer.UpdateSince(start)
				}

				task.Status = *status
				task.Status.AppliedBy = d.securityConfig.ClientTLSCreds.NodeID()
				task.Status.AppliedAt = ptypes.MustTimestampProto(time.Now())
				if err := store.UpdateTask(tx, task); err != nil {
					logger.WithError(err).Error("failed to update task status")
					return nil
				}
				logger.Debug("dispatcher committed status update to store")
				return nil
			})
			if err != nil {
				log.WithError(err).Error("dispatcher task update transaction failed")
			}
		}

		for nodeID, nodeUpdate := range nodeUpdates {
			err := batch.Update(func(tx store.Tx) error {
				logger := log.WithField("node.id", nodeID)
				node := store.GetNode(tx, nodeID)
				if node == nil {
					logger.Errorf("node unavailable")
					return nil
				}

				if nodeUpdate.status != nil {
					node.Status.State = nodeUpdate.status.State
					node.Status.Message = nodeUpdate.status.Message
					// An empty address means "unknown"; keep the stored one.
					if nodeUpdate.status.Addr != "" {
						node.Status.Addr = nodeUpdate.status.Addr
					}
				}
				if nodeUpdate.description != nil {
					node.Description = nodeUpdate.description
				}

				if err := store.UpdateNode(tx, node); err != nil {
					logger.WithError(err).Error("failed to update node status")
					return nil
				}
				logger.Debug("node status updated")
				return nil
			})
			if err != nil {
				log.WithError(err).Error("dispatcher node update transaction failed")
			}
		}

		return nil
	})
	if err != nil {
		log.WithError(err).Error("dispatcher batch failed")
	}
}
// Tasks is a stream of tasks state for node. Each message contains full list
// of tasks which should be run on node, if task is not present in that list,
// it should be terminated.
//
// After validating the caller's session, it sends a snapshot of the node's
// tasks whose state is ASSIGNED or later. It then watches task
// create/update/delete events for the node and sends a fresh snapshot after
// each batch of modifications. It returns when the session becomes invalid,
// a send fails, the stream context ends, or the dispatcher stops.
func (d *Dispatcher) Tasks(r *api.TasksRequest, stream api.Dispatcher_TasksServer) error {
nodeInfo, err := ca.RemoteNode(stream.Context())
if err != nil {
return err
}
nodeID := nodeInfo.NodeID
// dctx is the dispatcher's context; the stream ends when it is cancelled.
dctx, err := d.isRunningLocked()
if err != nil {
return err
}
fields := logrus.Fields{
"node.id": nodeID,
"node.session": r.SessionID,
"method": "(*Dispatcher).Tasks",
}
if nodeInfo.ForwardedBy != nil {
fields["forwarder.id"] = nodeInfo.ForwardedBy.NodeID
}
log.G(stream.Context()).WithFields(fields).Debug("")
if _, err = d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
return err
}
// tasksMap holds the node's current tasks keyed by task ID. It is seeded
// from the initial view and then kept up to date from watch events.
tasksMap := make(map[string]*api.Task)
nodeTasks, cancel, err := store.ViewAndWatch(
d.store,
func(readTx store.ReadTx) error {
tasks, err := store.FindTasks(readTx, store.ByNodeID(nodeID))
if err != nil {
return err
}
for _, t := range tasks {
tasksMap[t.ID] = t
}
return nil
},
api.EventCreateTask{Task: &api.Task{NodeID: nodeID},
Checks: []api.TaskCheckFunc{api.TaskCheckNodeID}},
api.EventUpdateTask{Task: &api.Task{NodeID: nodeID},
Checks: []api.TaskCheckFunc{api.TaskCheckNodeID}},
api.EventDeleteTask{Task: &api.Task{NodeID: nodeID},
Checks: []api.TaskCheckFunc{api.TaskCheckNodeID}},
)
if err != nil {
return err
}
defer cancel()
for {
// Re-validate the session before every snapshot, so a stale session
// stops receiving tasks on this stream.
if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
return err
}
var tasks []*api.Task
for _, t := range tasksMap {
// dispatcher only sends tasks that have been assigned to a node
if t != nil && t.Status.State >= api.TaskStateAssigned {
tasks = append(tasks, t)
}
}
if err := stream.Send(&api.TasksMessage{Tasks: tasks}); err != nil {
return err
}
// bursty events should be processed in batches and sent out as a snapshot.
// batchingTimeout stays nil (so it never fires) until the first event that
// reaches the timer code below; every such event restarts the timer.
var (
modificationCnt int
batchingTimer *time.Timer
batchingTimeout <-chan time.Time
)
batchingLoop:
for modificationCnt < modificationBatchLimit {
select {
case event := <-nodeTasks:
switch v := event.(type) {
case api.EventCreateTask:
tasksMap[v.Task.ID] = v.Task
modificationCnt++
case api.EventUpdateTask:
if oldTask, exists := tasksMap[v.Task.ID]; exists {
// States ASSIGNED and below are set by the orchestrator/scheduler,
// not the agent, so tasks in these states need to be sent to the
// agent even if nothing else has changed.
if equality.TasksEqualStable(oldTask, v.Task) && v.Task.Status.State > api.TaskStateAssigned {
// this update should not trigger action at agent
tasksMap[v.Task.ID] = v.Task
// skip both the modification count and the timer reset
continue
}
}
tasksMap[v.Task.ID] = v.Task
modificationCnt++
case api.EventDeleteTask:
delete(tasksMap, v.Task.ID)
modificationCnt++
}
if batchingTimer != nil {
batchingTimer.Reset(batchingWaitTime)
} else {
batchingTimer = time.NewTimer(batchingWaitTime)
batchingTimeout = batchingTimer.C
}
case <-batchingTimeout:
break batchingLoop
case <-stream.Context().Done():
return stream.Context().Err()
case <-dctx.Done():
return dctx.Err()
}
}
if batchingTimer != nil {
batchingTimer.Stop()
}
}
}
// Assignments is a stream of assignments for a node. Each message contains
// either full list of tasks and secrets for the node, or an incremental update.
//
// The first message is COMPLETE and built from the node's tasks in the store.
// Later messages are INCREMENTAL, sent after each batch of task update/delete
// events that changed the assignment set. Every message's AppliesTo carries
// the previous message's ResultsIn sequence number, so the agent can detect
// gaps.
func (d *Dispatcher) Assignments(r *api.AssignmentsRequest, stream api.Dispatcher_AssignmentsServer) error {
nodeInfo, err := ca.RemoteNode(stream.Context())
if err != nil {
return err
}
nodeID := nodeInfo.NodeID
// dctx is the dispatcher's context; the stream ends when it is cancelled.
dctx, err := d.isRunningLocked()
if err != nil {
return err
}
fields := logrus.Fields{
"node.id": nodeID,
"node.session": r.SessionID,
"method": "(*Dispatcher).Assignments",
}
if nodeInfo.ForwardedBy != nil {
fields["forwarder.id"] = nodeInfo.ForwardedBy.NodeID
}
log := log.G(stream.Context()).WithFields(fields)
log.Debug("")
if _, err = d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
return err
}
var (
sequence int64
appliesTo string
assignments = newAssignmentSet(log, d.dp)
)
// sendMessage stamps msg with the next sequence number, chains it to the
// previous message via AppliesTo, and sends it.
sendMessage := func(msg api.AssignmentsMessage, assignmentType api.AssignmentsMessage_Type) error {
sequence++
msg.AppliesTo = appliesTo
msg.ResultsIn = strconv.FormatInt(sequence, 10)
appliesTo = msg.ResultsIn
msg.Type = assignmentType
return stream.Send(&msg)
}
// TODO(aaronl): Also send node secrets that should be exposed to
// this node.
nodeTasks, cancel, err := store.ViewAndWatch(
d.store,
func(readTx store.ReadTx) error {
tasks, err := store.FindTasks(readTx, store.ByNodeID(nodeID))
if err != nil {
return err
}
for _, t := range tasks {
assignments.addOrUpdateTask(readTx, t)
}
return nil
},
api.EventUpdateTask{Task: &api.Task{NodeID: nodeID},
Checks: []api.TaskCheckFunc{api.TaskCheckNodeID}},
api.EventDeleteTask{Task: &api.Task{NodeID: nodeID},
Checks: []api.TaskCheckFunc{api.TaskCheckNodeID}},
)
if err != nil {
return err
}
defer cancel()
if err := sendMessage(assignments.message(), api.AssignmentsMessage_COMPLETE); err != nil {
return err
}
for {
// Check for session expiration
if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
return err
}
// bursty events should be processed in batches and sent out together
var (
modificationCnt int
batchingTimer *time.Timer
batchingTimeout <-chan time.Time
)
// oneModification counts a change to the assignment set and (re)starts
// the batching timer.
oneModification := func() {
modificationCnt++
if batchingTimer != nil {
batchingTimer.Reset(batchingWaitTime)
} else {
batchingTimer = time.NewTimer(batchingWaitTime)
batchingTimeout = batchingTimer.C
}
}
// The batching loop waits for batchingWaitTime (100 ms) after the most
// recent change, or until modificationBatchLimit is reached. The
// worst case latency is modificationBatchLimit * batchingWaitTime,
// which is 10 seconds.
batchingLoop:
for modificationCnt < modificationBatchLimit {
select {
case event := <-nodeTasks:
switch v := event.(type) {
// We don't monitor EventCreateTask because tasks are
// never created in the ASSIGNED state. First tasks are
// created by the orchestrator, then the scheduler moves
// them to ASSIGNED. If this ever changes, we will need
// to monitor task creations as well.
case api.EventUpdateTask:
d.store.View(func(readTx store.ReadTx) {
if assignments.addOrUpdateTask(readTx, v.Task) {
oneModification()
}
})
case api.EventDeleteTask:
if assignments.removeTask(v.Task) {
oneModification()
}
// TODO(aaronl): For node secrets, we'll need to handle
// EventCreateSecret.
}
case <-batchingTimeout:
break batchingLoop
case <-stream.Context().Done():
return stream.Context().Err()
case <-dctx.Done():
return dctx.Err()
}
}
if batchingTimer != nil {
batchingTimer.Stop()
}
// Only send when the assignment set actually changed.
if modificationCnt > 0 {
if err := sendMessage(assignments.message(), api.AssignmentsMessage_INCREMENTAL); err != nil {
return err
}
}
}
}
// moveTasksToOrphaned marks as ORPHANED every task assigned to nodeID whose
// state is between ASSIGNED and RUNNING (inclusive). It is the expiry action
// for entries in the down-node store. The first store error aborts the batch
// and is returned.
func (d *Dispatcher) moveTasksToOrphaned(nodeID string) error {
	// Tasks running on an unreachable node need to be marked as orphaned,
	// since we have no idea whether they are still running. This only applies
	// to tasks that could have made progress since the agent became
	// unreachable (assigned<->running). Tasks in a final state (e.g.
	// rejected) *cannot* have made progress, so there is no point in
	// orphaning them.
	orphanable := func(state api.TaskState) bool {
		return state >= api.TaskStateAssigned && state <= api.TaskStateRunning
	}

	return d.store.Batch(func(batch *store.Batch) error {
		var (
			tasks []*api.Task
			err   error
		)
		d.store.View(func(tx store.ReadTx) {
			tasks, err = store.FindTasks(tx, store.ByNodeID(nodeID))
		})
		if err != nil {
			return err
		}

		for _, task := range tasks {
			taskID := task.ID
			if err := batch.Update(func(tx store.Tx) error {
				// Re-read the task inside the write transaction. The copy from
				// the snapshot above may be stale, and writing it back could
				// clobber newer state or fail with a version conflict. Also
				// skip writes for tasks that need no change.
				t := store.GetTask(tx, taskID)
				if t == nil || t.NodeID != nodeID || !orphanable(t.Status.State) {
					return nil
				}
				t.Status.State = api.TaskStateOrphaned
				return store.UpdateTask(tx, t)
			}); err != nil {
				return err
			}
		}
		return nil
	})
}
// markNodeNotReady moves node id to a non-READY state (such as DOWN or
// DISCONNECTED) with the given message.
//
// It adds the node to the down-node store, whose expiry orphans the node's
// tasks. It queues the status change for the next update batch, keeping any
// description already queued for the node. Finally it removes the node from
// the registered-node store. It fails if the dispatcher is stopped, the node
// is not in the store, or the node was not registered locally.
func (d *Dispatcher) markNodeNotReady(id string, state api.NodeStatus_State, message string) error {
	dctx, err := d.isRunningLocked()
	if err != nil {
		return err
	}

	// Node is down: track it so its tasks can be orphaned later.
	var node *api.Node
	d.store.View(func(readTx store.ReadTx) {
		node = store.GetNode(readTx, id)
	})
	if node == nil {
		return fmt.Errorf("could not find node %s while trying to add to down nodes store", id)
	}
	d.downNodes.Add(node, func() {
		if orphanErr := d.moveTasksToOrphaned(id); orphanErr != nil {
			log.G(dctx).WithError(orphanErr).Error(`failed to move all tasks to "ORPHANED" state`)
		}
		d.downNodes.Delete(id)
	})

	newStatus := &api.NodeStatus{State: state, Message: message}

	d.nodeUpdatesLock.Lock()
	// Carry over a description queued by an earlier markNodeReady, so a quick
	// ready -> not-ready sequence does not lose it.
	pendingDescription := d.nodeUpdates[id].description
	d.nodeUpdates[id] = nodeUpdate{status: newStatus, description: pendingDescription}
	pending := len(d.nodeUpdates)
	d.nodeUpdatesLock.Unlock()

	if pending >= maxBatchItems {
		select {
		case d.processUpdatesTrigger <- struct{}{}:
		case <-dctx.Done():
		}
	}

	if d.nodes.Delete(id) == nil {
		return errors.Errorf("node %s is not found in local storage", id)
	}
	return nil
}
// Heartbeat is heartbeat method for nodes. It returns new TTL in response.
// Node should send new heartbeat earlier than now + TTL, otherwise it will
// be deregistered from dispatcher and its status will be updated to NodeStatus_DOWN
//
// The response is returned together with any error from the node store.
func (d *Dispatcher) Heartbeat(ctx context.Context, r *api.HeartbeatRequest) (*api.HeartbeatResponse, error) {
	nodeInfo, err := ca.RemoteNode(ctx)
	if err != nil {
		return nil, err
	}
	resp := &api.HeartbeatResponse{}
	resp.Period, err = d.nodes.Heartbeat(nodeInfo.NodeID, r.SessionID)
	return resp, err
}
// getManagers returns the most recently published manager list, read under d.mu.
func (d *Dispatcher) getManagers() []*api.WeightedPeer {
	d.mu.Lock()
	managers := d.lastSeenManagers
	d.mu.Unlock()
	return managers
}
// getNetworkBootstrapKeys returns the last seen cluster network bootstrap
// keys, read under d.mu.
func (d *Dispatcher) getNetworkBootstrapKeys() []*api.EncryptionKey {
	d.mu.Lock()
	keys := d.networkBootstrapKeys
	d.mu.Unlock()
	return keys
}
// getRootCACert returns the last seen cluster root CA certificate, read
// under d.mu.
func (d *Dispatcher) getRootCACert() []byte {
	d.mu.Lock()
	cert := d.lastSeenRootCert
	d.mu.Unlock()
	return cert
}
// Session is a stream which controls agent connection.
// Each message contains list of backup Managers with weights. Also there is
// a special boolean field Disconnect which if true indicates that node should
// reconnect to another Manager immediately.
//
// If the request's session is not valid, the node is registered and gets a
// new session ID. Otherwise the node is re-marked READY with the reported
// description. An initial SessionMessage is sent, and a new one follows
// whenever a cluster update is published or the node object changes. If the
// node's registration signals Disconnect, or the dispatcher stops, a final
// message is sent. The transport is then closed and the node is marked
// DISCONNECTED.
func (d *Dispatcher) Session(r *api.SessionRequest, stream api.Dispatcher_SessionServer) error {
ctx := stream.Context()
nodeInfo, err := ca.RemoteNode(ctx)
if err != nil {
return err
}
nodeID := nodeInfo.NodeID
dctx, err := d.isRunningLocked()
if err != nil {
return err
}
var sessionID string
if _, err := d.nodes.GetWithSession(nodeID, r.SessionID); err != nil {
// register the node.
sessionID, err = d.register(ctx, nodeID, r.Description)
if err != nil {
return err
}
} else {
sessionID = r.SessionID
// get the node IP addr
addr, err := nodeIPFromContext(stream.Context())
if err != nil {
log.G(ctx).WithError(err).Debug("failed to get remote node IP")
}
// update the node description
if err := d.markNodeReady(dctx, nodeID, r.Description, addr); err != nil {
return err
}
}
fields := logrus.Fields{
"node.id": nodeID,
"node.session": sessionID,
"method": "(*Dispatcher).Session",
}
if nodeInfo.ForwardedBy != nil {
fields["forwarder.id"] = nodeInfo.ForwardedBy.NodeID
}
log := log.G(ctx).WithFields(fields)
var nodeObj *api.Node
nodeUpdates, cancel, err := store.ViewAndWatch(d.store, func(readTx store.ReadTx) error {
nodeObj = store.GetNode(readTx, nodeID)
return nil
}, api.EventUpdateNode{Node: &api.Node{ID: nodeID},
Checks: []api.NodeCheckFunc{api.NodeCheckID}},
)
if cancel != nil {
defer cancel()
}
// A failed watch is only logged. nodeUpdates is then presumably nil, and
// receiving from a nil channel blocks forever, so node-change messages are
// simply never sent.
if err != nil {
log.WithError(err).Error("ViewAndWatch Node failed")
}
if _, err = d.nodes.GetWithSession(nodeID, sessionID); err != nil {
return err
}
clusterUpdatesCh, clusterCancel := d.clusterUpdateQueue.Watch()
defer clusterCancel()
if err := stream.Send(&api.SessionMessage{
SessionID: sessionID,
Node: nodeObj,
Managers: d.getManagers(),
NetworkBootstrapKeys: d.getNetworkBootstrapKeys(),
RootCA: d.getRootCACert(),
}); err != nil {
return err
}
// disconnectNode is a helper forcibly shutdown connection
disconnectNode := func() error {
// force disconnect by shutting down the stream.
transportStream, ok := transport.StreamFromContext(stream.Context())
if ok {
// if we have the transport stream, we can signal a disconnect
// in the client.
if err := transportStream.ServerTransport().Close(); err != nil {
log.WithError(err).Error("session end")
}
}
if err := d.markNodeNotReady(nodeID, api.NodeStatus_DISCONNECTED, "node is currently trying to find new manager"); err != nil {
log.WithError(err).Error("failed to remove node")
}
// still return an abort if the transport closure was ineffective.
return status.Errorf(codes.Aborted, "node must disconnect")
}
for {
// After each message send, we need to check the nodes sessionID hasn't
// changed. If it has, we will shut down the stream and make the node
// re-register.
node, err := d.nodes.GetWithSession(nodeID, sessionID)
if err != nil {
return err
}
// Fields left nil by the select are filled from the dispatcher's
// last seen values below.
var (
disconnect bool
mgrs []*api.WeightedPeer
netKeys []*api.EncryptionKey
rootCert []byte
)
select {
case ev := <-clusterUpdatesCh:
update := ev.(clusterUpdate)
if update.managerUpdate != nil {
mgrs = *update.managerUpdate
}
if update.bootstrapKeyUpdate != nil {
netKeys = *update.bootstrapKeyUpdate
}
if update.rootCAUpdate != nil {
rootCert = *update.rootCAUpdate
}
case ev := <-nodeUpdates:
nodeObj = ev.(api.EventUpdateNode).Node
case <-stream.Context().Done():
return stream.Context().Err()
case <-node.Disconnect:
disconnect = true
case <-dctx.Done():
disconnect = true
}
if mgrs == nil {
mgrs = d.getManagers()
}
if netKeys == nil {
netKeys = d.getNetworkBootstrapKeys()
}
if rootCert == nil {
rootCert = d.getRootCACert()
}
if err := stream.Send(&api.SessionMessage{
SessionID: sessionID,
Node: nodeObj,
Managers: mgrs,
NetworkBootstrapKeys: netKeys,
RootCA: rootCert,
}); err != nil {
return err
}
if disconnect {
return disconnectNode()
}
}
}