package transport

import (
	"fmt"
	"sync"
	"time"

	"golang.org/x/net/context"

	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"

	"github.com/coreos/etcd/raft"
	"github.com/coreos/etcd/raft/raftpb"
	"github.com/docker/swarmkit/api"
	"github.com/docker/swarmkit/log"
	"github.com/docker/swarmkit/manager/state/raft/membership"
	"github.com/pkg/errors"
)

const (
	// GRPCMaxMsgSize is the max allowed gRPC message size for raft messages.
	GRPCMaxMsgSize = 4 << 20
)

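// peer tracks a single remote member of the raft cluster: its gRPC
// connection, a queue of outgoing raft messages, and whether it is
// currently considered active (reachable).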
type peer struct {
	id uint64

	tr *Transport

	msgc chan raftpb.Message

	ctx    context.Context
	cancel context.CancelFunc
	done   chan struct{}

	mu      sync.Mutex
	cc      *grpc.ClientConn
	addr    string
	newAddr string

	active       bool
	becameActive time.Time
}

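// newPeer dials the given address, creates a peer for the raft member with
// the given ID, and starts its message-sending loop.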
func newPeer(id uint64, addr string, tr *Transport) (*peer, error) {
	cc, err := tr.dial(addr)
	if err != nil {
		return nil, errors.Wrapf(err, "failed to create conn for %x with addr %s", id, addr)
	}
	ctx, cancel := context.WithCancel(tr.ctx)
	ctx = log.WithField(ctx, "peer_id", fmt.Sprintf("%x", id))
	p := &peer{
		id:     id,
		addr:   addr,
		cc:     cc,
		tr:     tr,
		ctx:    ctx,
		cancel: cancel,
		msgc:   make(chan raftpb.Message, 4096),
		done:   make(chan struct{}),
	}
	go p.run(ctx)
	return p, nil
}

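// send queues a raft message for delivery by the peer's run loop. It fails
// if the peer has been stopped or if the outgoing queue is full, in which
// case the member is also reported unreachable to raft.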
func (p *peer) send(m raftpb.Message) (err error) {
	p.mu.Lock()
	defer func() {
		if err != nil {
			p.active = false
			p.becameActive = time.Time{}
		}
		p.mu.Unlock()
	}()
	select {
	case <-p.ctx.Done():
		return p.ctx.Err()
	default:
	}
	select {
	case p.msgc <- m:
	case <-p.ctx.Done():
		return p.ctx.Err()
	default:
		p.tr.config.ReportUnreachable(p.id)
		return errors.Errorf("peer is unreachable")
	}
	return nil
}

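// update dials addr and, on success, immediately replaces the peer's current
// connection and address with the new ones.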
func (p *peer) update(addr string) error {
	p.mu.Lock()
	defer p.mu.Unlock()
	if p.addr == addr {
		return nil
	}
	cc, err := p.tr.dial(addr)
	if err != nil {
		return err
	}

	p.cc.Close()
	p.cc = cc
	p.addr = addr
	return nil
}

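// updateAddr records addr as a fallback address to switch to if sending on
// the current address fails; unlike update, it does not replace the
// connection immediately.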
func (p *peer) updateAddr(addr string) error {
	p.mu.Lock()
	defer p.mu.Unlock()
	if p.addr == addr {
		return nil
	}
	log.G(p.ctx).Debugf("peer %x updated to address %s, it will be used if the old one fails", p.id, addr)
	p.newAddr = addr
	return nil
}

func (p *peer) conn() *grpc.ClientConn {
	p.mu.Lock()
	defer p.mu.Unlock()
	return p.cc
}

func (p *peer) address() string {
	p.mu.Lock()
	defer p.mu.Unlock()
	return p.addr
}

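// resolveAddr asks this peer to resolve the address of the raft member with
// the given ID.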
func (p *peer) resolveAddr(ctx context.Context, id uint64) (string, error) {
	resp, err := api.NewRaftClient(p.conn()).ResolveAddress(ctx, &api.ResolveAddressRequest{RaftID: id})
	if err != nil {
		return "", errors.Wrap(err, "failed to resolve address")
	}
	return resp.Addr, nil
}

// raftMessageStructSize returns the raft message struct size (not including
// the payload size) for the given raftpb.Message.
// The payload is typically the snapshot or append entries.
func raftMessageStructSize(m *raftpb.Message) int {
	return (&api.ProcessRaftMessageRequest{Message: m}).Size() - len(m.Snapshot.Data)
}

// raftMessagePayloadSize returns the max allowable payload size based on
// GRPCMaxMsgSize and the struct size for the given raftpb.Message.
func raftMessagePayloadSize(m *raftpb.Message) int {
	return GRPCMaxMsgSize - raftMessageStructSize(m)
}

// splitSnapshotData splits a large raft message into smaller messages.
// Currently this means splitting Snapshot.Data into chunks whose size
// is dictated by GRPCMaxMsgSize.
func splitSnapshotData(ctx context.Context, m *raftpb.Message) []api.StreamRaftMessageRequest {
	var messages []api.StreamRaftMessageRequest
	if m.Type != raftpb.MsgSnap {
		return messages
	}

	// Get the size of the data to be split.
	size := len(m.Snapshot.Data)

	// Get the max payload size.
	payloadSize := raftMessagePayloadSize(m)

	// Split the snapshot into smaller messages.
	for snapDataIndex := 0; snapDataIndex < size; {
		chunkSize := size - snapDataIndex
		if chunkSize > payloadSize {
			chunkSize = payloadSize
		}

		raftMsg := *m

		// Sub-slice the snapshot data for this chunk.
		raftMsg.Snapshot.Data = m.Snapshot.Data[snapDataIndex : snapDataIndex+chunkSize]

		snapDataIndex += chunkSize

		// Add the message to the list of messages to be sent.
		msg := api.StreamRaftMessageRequest{Message: &raftMsg}
		messages = append(messages, msg)
	}

	return messages
}

// needsSplitting returns true if the message needs to be split before it can
// be streamed, i.e. it is a MsgSnap whose encoded size is larger than
// GRPCMaxMsgSize.
func needsSplitting(m *raftpb.Message) bool {
	raftMsg := api.ProcessRaftMessageRequest{Message: m}
	return m.Type == raftpb.MsgSnap && raftMsg.Size() > GRPCMaxMsgSize
}

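// sendProcessMessage sends a single raft message to the peer over the
// StreamRaftMessage RPC, splitting large snapshot messages into chunks, and
// falls back to the ProcessRaftMessage RPC if the receiver does not support
// streaming. Snapshot and reachability status are reported back to raft.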
func (p *peer) sendProcessMessage(ctx context.Context, m raftpb.Message) error {
	ctx, cancel := context.WithTimeout(ctx, p.tr.config.SendTimeout)
	defer cancel()

	var err error
	var stream api.Raft_StreamRaftMessageClient
	stream, err = api.NewRaftClient(p.conn()).StreamRaftMessage(ctx)

	if err == nil {
		// Split the message if needed.
		// Currently only supported for MsgSnap.
		var msgs []api.StreamRaftMessageRequest
		if needsSplitting(&m) {
			msgs = splitSnapshotData(ctx, &m)
		} else {
			raftMsg := api.StreamRaftMessageRequest{Message: &m}
			msgs = append(msgs, raftMsg)
		}

		// Stream the messages to the peer.
		for _, msg := range msgs {
			err = stream.Send(&msg)
			if err != nil {
				log.G(ctx).WithError(err).Error("error streaming message to peer")
				stream.CloseAndRecv()
				break
			}
		}

		// Finished sending all the messages.
		// Close the stream and receive the response.
		if err == nil {
			_, err = stream.CloseAndRecv()

			if err != nil {
				log.G(ctx).WithError(err).Error("error receiving response")
			}
		}
	} else {
		log.G(ctx).WithError(err).Error("error sending message to peer")
	}

	// Fall back to a regular RPC if the receiver doesn't support streaming.
	if grpc.Code(err) == codes.Unimplemented {
		_, err = api.NewRaftClient(p.conn()).ProcessRaftMessage(ctx, &api.ProcessRaftMessageRequest{Message: &m})
	}

	// Handle errors.
	if grpc.Code(err) == codes.NotFound && grpc.ErrorDesc(err) == membership.ErrMemberRemoved.Error() {
		p.tr.config.NodeRemoved()
	}
	if m.Type == raftpb.MsgSnap {
		if err != nil {
			p.tr.config.ReportSnapshot(m.To, raft.SnapshotFailure)
		} else {
			p.tr.config.ReportSnapshot(m.To, raft.SnapshotFinish)
		}
	}
	if err != nil {
		p.tr.config.ReportUnreachable(m.To)
		return err
	}
	return nil
}

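// healthCheckConn verifies that the Raft service on the given connection
// reports itself as SERVING.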
func healthCheckConn(ctx context.Context, cc *grpc.ClientConn) error {
	resp, err := api.NewHealthClient(cc).Check(ctx, &api.HealthCheckRequest{Service: "Raft"})
	if err != nil {
		return errors.Wrap(err, "failed to check health")
	}
	if resp.Status != api.HealthCheckResponse_SERVING {
		return errors.Errorf("health check returned status %s", resp.Status)
	}
	return nil
}

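// healthCheck runs a health check against the peer's current connection,
// bounded by the configured SendTimeout.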
func (p *peer) healthCheck(ctx context.Context) error {
	ctx, cancel := context.WithTimeout(ctx, p.tr.config.SendTimeout)
	defer cancel()
	return healthCheckConn(ctx, p.conn())
}

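// setActive marks the peer as active, recording the time it became active if
// it was previously inactive.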
func (p *peer) setActive() {
	p.mu.Lock()
	if !p.active {
		p.active = true
		p.becameActive = time.Now()
	}
	p.mu.Unlock()
}

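// setInactive marks the peer as inactive and clears its activation time.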
func (p *peer) setInactive() {
	p.mu.Lock()
	p.active = false
	p.becameActive = time.Time{}
	p.mu.Unlock()
}

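// activeTime returns the time at which the peer became active, or the zero
// time if it is currently inactive.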
func (p *peer) activeTime() time.Time {
	p.mu.Lock()
	defer p.mu.Unlock()
	return p.becameActive
}

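// drain sends any messages still queued after the run loop exits, giving up
// after 16 seconds.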
func (p *peer) drain() error {
	ctx, cancel := context.WithTimeout(context.Background(), 16*time.Second)
	defer cancel()
	for {
		select {
		case m, ok := <-p.msgc:
			if !ok {
				// all messages have been processed
				return nil
			}
			if err := p.sendProcessMessage(ctx, m); err != nil {
				return errors.Wrap(err, "send drain message")
			}
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}

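// handleAddressChange switches the peer to the fallback address recorded by
// updateAddr, if one is set and its connection passes a health check.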
func (p *peer) handleAddressChange(ctx context.Context) error {
	p.mu.Lock()
	newAddr := p.newAddr
	p.newAddr = ""
	p.mu.Unlock()
	if newAddr == "" {
		return nil
	}
	cc, err := p.tr.dial(newAddr)
	if err != nil {
		return err
	}
	ctx, cancel := context.WithTimeout(ctx, p.tr.config.SendTimeout)
	defer cancel()
	if err := healthCheckConn(ctx, cc); err != nil {
		cc.Close()
		return err
	}
	// There is a possibility of a race if the host changes addresses too
	// quickly, but that's unlikely and things should eventually settle.
	p.mu.Lock()
	p.cc.Close()
	p.cc = cc
	p.addr = newAddr
	p.tr.config.UpdateNode(p.id, p.addr)
	p.mu.Unlock()
	return nil
}

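// run is the peer's main loop: it takes messages off the queue and sends
// them, marking the peer active or inactive depending on the outcome, until
// the peer is stopped. On exit it drains any remaining messages.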
func (p *peer) run(ctx context.Context) {
	defer func() {
		p.mu.Lock()
		p.active = false
		p.becameActive = time.Time{}
		// At this point we can be sure that nobody will write to msgc.
		if p.msgc != nil {
			close(p.msgc)
		}
		p.mu.Unlock()
		if err := p.drain(); err != nil {
			log.G(ctx).WithError(err).Error("failed to drain message queue")
		}
		close(p.done)
	}()
	if err := p.healthCheck(ctx); err == nil {
		p.setActive()
	}
	for {
		select {
		case <-ctx.Done():
			return
		default:
		}

		select {
		case m := <-p.msgc:
			// We do not propagate the context here, because this operation
			// should either finish or time out on its own for raft to work
			// correctly.
			err := p.sendProcessMessage(context.Background(), m)
			if err != nil {
				log.G(ctx).WithError(err).Debugf("failed to send message %s", m.Type)
				p.setInactive()
				if err := p.handleAddressChange(ctx); err != nil {
					log.G(ctx).WithError(err).Error("failed to change address after failure")
				}
				continue
			}
			p.setActive()
		case <-ctx.Done():
			return
		}
	}
}

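// stop cancels the peer's context and waits for its run loop to finish.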
func (p *peer) stop() {
	p.cancel()
	<-p.done
}