package transport
import (
"fmt"
"sync"
"time"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"github.com/coreos/etcd/raft"
"github.com/coreos/etcd/raft/raftpb"
"github.com/docker/swarmkit/api"
"github.com/docker/swarmkit/log"
"github.com/docker/swarmkit/manager/state/raft/membership"
"github.com/pkg/errors"
"google.golang.org/grpc/status"
)
const (
// GRPCMaxMsgSize is the max allowed gRPC message size for raft messages.
GRPCMaxMsgSize = 4 << 20
)
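// peer tracks a single remote raft member: it owns the gRPC connection to
// that member and a buffered queue of outgoing raft messages delivered by
// its run loop.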
type peer struct {
id uint64
tr *Transport
msgc chan raftpb.Message
ctx context.Context
cancel context.CancelFunc
done chan struct{}
mu sync.Mutex
cc *grpc.ClientConn
addr string
newAddr string
active bool
becameActive time.Time
}
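// newPeer dials addr, attaches the peer ID to the logging context and starts
// the peer's run loop in a background goroutine.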
func newPeer(id uint64, addr string, tr *Transport) (*peer, error) {
cc, err := tr.dial(addr)
if err != nil {
return nil, errors.Wrapf(err, "failed to create conn for %x with addr %s", id, addr)
}
ctx, cancel := context.WithCancel(tr.ctx)
ctx = log.WithField(ctx, "peer_id", fmt.Sprintf("%x", id))
p := &peer{
id: id,
addr: addr,
cc: cc,
tr: tr,
ctx: ctx,
cancel: cancel,
msgc: make(chan raftpb.Message, 4096),
done: make(chan struct{}),
}
go p.run(ctx)
return p, nil
}
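// send queues a raft message for delivery to this peer. It fails fast: if the
// peer's context is done it returns the context error, and if the queue is
// full it reports the peer as unreachable to raft and returns an error.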
func (p *peer) send(m raftpb.Message) (err error) {
p.mu.Lock()
defer func() {
if err != nil {
p.active = false
p.becameActive = time.Time{}
}
p.mu.Unlock()
}()
select {
case <-p.ctx.Done():
return p.ctx.Err()
default:
}
select {
case p.msgc <- m:
case <-p.ctx.Done():
return p.ctx.Err()
default:
p.tr.config.ReportUnreachable(p.id)
return errors.Errorf("peer is unreachable")
}
return nil
}
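// update dials addr and immediately replaces the peer's connection with the
// new one, closing the old connection. It is a no-op if the address is
// unchanged.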
func (p *peer) update(addr string) error {
p.mu.Lock()
defer p.mu.Unlock()
if p.addr == addr {
return nil
}
cc, err := p.tr.dial(addr)
if err != nil {
return err
}
p.cc.Close()
p.cc = cc
p.addr = addr
return nil
}
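// updateAddr records addr as a fallback address. The connection is only
// switched to it by handleAddressChange after a send failure.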
func (p *peer) updateAddr(addr string) error {
p.mu.Lock()
defer p.mu.Unlock()
if p.addr == addr {
return nil
}
log.G(p.ctx).Debugf("peer %x updated to address %s, which will be used if the current address fails", p.id, addr)
p.newAddr = addr
return nil
}
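// conn returns the peer's current gRPC client connection.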
func (p *peer) conn() *grpc.ClientConn {
p.mu.Lock()
defer p.mu.Unlock()
return p.cc
}
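// address returns the peer's current address.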
func (p *peer) address() string {
p.mu.Lock()
defer p.mu.Unlock()
return p.addr
}
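// resolveAddr asks this peer to resolve the address of the raft member with
// the given ID.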
func (p *peer) resolveAddr(ctx context.Context, id uint64) (string, error) {
resp, err := api.NewRaftClient(p.conn()).ResolveAddress(ctx, &api.ResolveAddressRequest{RaftID: id})
if err != nil {
return "", errors.Wrap(err, "failed to resolve address")
}
return resp.Addr, nil
}
// raftMessageStructSize returns the size of the raft message struct (not
// including the payload) for the given raftpb.Message. The payload is
// typically the snapshot or append entries.
func raftMessageStructSize(m *raftpb.Message) int {
return (&api.ProcessRaftMessageRequest{Message: m}).Size() - len(m.Snapshot.Data)
}
// raftMessagePayloadSize returns the max allowable payload size based on
// GRPCMaxMsgSize and the struct size for the given raftpb.Message.
func raftMessagePayloadSize(m *raftpb.Message) int {
return GRPCMaxMsgSize - raftMessageStructSize(m)
}
// splitSnapshotData splits a large raft message into smaller messages.
// Currently this means splitting Snapshot.Data into chunks whose size is
// dictated by the payload size derived from GRPCMaxMsgSize.
func splitSnapshotData(ctx context.Context, m *raftpb.Message) []api.StreamRaftMessageRequest {
var messages []api.StreamRaftMessageRequest
if m.Type != raftpb.MsgSnap {
return messages
}
// get the size of the data to be split.
size := len(m.Snapshot.Data)
// Get the max payload size.
payloadSize := raftMessagePayloadSize(m)
// split the snapshot into smaller messages.
for snapDataIndex := 0; snapDataIndex < size; {
chunkSize := size - snapDataIndex
if chunkSize > payloadSize {
chunkSize = payloadSize
}
raftMsg := *m
// sub-slice for this snapshot chunk.
raftMsg.Snapshot.Data = m.Snapshot.Data[snapDataIndex : snapDataIndex+chunkSize]
snapDataIndex += chunkSize
// add message to the list of messages to be sent.
msg := api.StreamRaftMessageRequest{Message: &raftMsg}
messages = append(messages, msg)
}
return messages
}
// needsSplitting returns true if the message must be split before it can be
// streamed, i.e. it is a MsgSnap whose encoded size exceeds GRPCMaxMsgSize.
func needsSplitting(m *raftpb.Message) bool {
raftMsg := api.ProcessRaftMessageRequest{Message: m}
return m.Type == raftpb.MsgSnap && raftMsg.Size() > GRPCMaxMsgSize
}
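// sendProcessMessage delivers a single raft message to the peer. It prefers
// the StreamRaftMessage RPC, splitting large snapshots into chunks, and falls
// back to ProcessRaftMessage if the receiver does not support streaming. It
// reports snapshot success or failure and unreachability back to raft.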
func (p *peer) sendProcessMessage(ctx context.Context, m raftpb.Message) error {
ctx, cancel := context.WithTimeout(ctx, p.tr.config.SendTimeout)
defer cancel()
var err error
var stream api.Raft_StreamRaftMessageClient
stream, err = api.NewRaftClient(p.conn()).StreamRaftMessage(ctx)
if err == nil {
// Split the message if needed.
// Currently only supported for MsgSnap.
var msgs []api.StreamRaftMessageRequest
if needsSplitting(&m) {
msgs = splitSnapshotData(ctx, &m)
} else {
raftMsg := api.StreamRaftMessageRequest{Message: &m}
msgs = append(msgs, raftMsg)
}
// Stream
for _, msg := range msgs {
err = stream.Send(&msg)
if err != nil {
log.G(ctx).WithError(err).Error("error streaming message to peer")
stream.CloseAndRecv()
break
}
}
// Finished sending all the messages.
// Close and receive response.
if err == nil {
_, err = stream.CloseAndRecv()
if err != nil {
log.G(ctx).WithError(err).Error("error receiving response")
}
}
} else {
log.G(ctx).WithError(err).Error("error sending message to peer")
}
// Fall back to a regular RPC if the receiver doesn't support streaming.
s, _ := status.FromError(err)
if s.Code() == codes.Unimplemented {
log.G(ctx).Info("sending message to raft peer using ProcessRaftMessage()")
_, err = api.NewRaftClient(p.conn()).ProcessRaftMessage(ctx, &api.ProcessRaftMessageRequest{Message: &m})
}
// Handle errors.
s, _ = status.FromError(err)
if s.Code() == codes.NotFound && s.Message() == membership.ErrMemberRemoved.Error() {
p.tr.config.NodeRemoved()
}
if m.Type == raftpb.MsgSnap {
if err != nil {
p.tr.config.ReportSnapshot(m.To, raft.SnapshotFailure)
} else {
p.tr.config.ReportSnapshot(m.To, raft.SnapshotFinish)
}
}
if err != nil {
p.tr.config.ReportUnreachable(m.To)
return err
}
return nil
}
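// healthCheckConn verifies that the Raft service on the given connection
// reports a SERVING status.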
func healthCheckConn(ctx context.Context, cc *grpc.ClientConn) error {
resp, err := api.NewHealthClient(cc).Check(ctx, &api.HealthCheckRequest{Service: "Raft"})
if err != nil {
return errors.Wrap(err, "failed to check health")
}
if resp.Status != api.HealthCheckResponse_SERVING {
return errors.Errorf("health check returned status %s", resp.Status)
}
return nil
}
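// healthCheck runs a health check against the peer's current connection,
// bounded by the configured send timeout.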
func (p *peer) healthCheck(ctx context.Context) error {
ctx, cancel := context.WithTimeout(ctx, p.tr.config.SendTimeout)
defer cancel()
return healthCheckConn(ctx, p.conn())
}
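// setActive marks the peer as active and records when it became so.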
func (p *peer) setActive() {
p.mu.Lock()
if !p.active {
p.active = true
p.becameActive = time.Now()
}
p.mu.Unlock()
}
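// setInactive marks the peer as inactive and clears its activation time.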
func (p *peer) setInactive() {
p.mu.Lock()
p.active = false
p.becameActive = time.Time{}
p.mu.Unlock()
}
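// activeTime returns the time the peer became active, or the zero time if it
// is currently inactive.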
func (p *peer) activeTime() time.Time {
p.mu.Lock()
defer p.mu.Unlock()
return p.becameActive
}
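// drain delivers any messages left in the queue after the run loop exits,
// bounded by a 16 second timeout.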
func (p *peer) drain() error {
ctx, cancel := context.WithTimeout(context.Background(), 16*time.Second)
defer cancel()
for {
select {
case m, ok := <-p.msgc:
if !ok {
// all messages processed
return nil
}
if err := p.sendProcessMessage(ctx, m); err != nil {
return errors.Wrap(err, "send drain message")
}
case <-ctx.Done():
return ctx.Err()
}
}
}
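// handleAddressChange switches the peer to the fallback address recorded by
// updateAddr, if any, after verifying the new connection with a health check,
// and reports the new address via UpdateNode.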
func (p *peer) handleAddressChange(ctx context.Context) error {
p.mu.Lock()
newAddr := p.newAddr
p.newAddr = ""
p.mu.Unlock()
if newAddr == "" {
return nil
}
cc, err := p.tr.dial(newAddr)
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(ctx, p.tr.config.SendTimeout)
defer cancel()
if err := healthCheckConn(ctx, cc); err != nil {
cc.Close()
return err
}
// There is a possibility of a race if the host changes its address too
// quickly, but that's unlikely and things should eventually settle.
p.mu.Lock()
p.cc.Close()
p.cc = cc
p.addr = newAddr
p.tr.config.UpdateNode(p.id, p.addr)
p.mu.Unlock()
return nil
}
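// run is the peer's send loop. After an initial health check it delivers
// queued messages one at a time, marking the peer inactive and trying the
// fallback address when a send fails. On exit it closes the queue, drains any
// remaining messages and signals done.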
func (p *peer) run(ctx context.Context) {
defer func() {
p.mu.Lock()
p.active = false
p.becameActive = time.Time{}
// at this point we can be sure that nobody will write to msgc
if p.msgc != nil {
close(p.msgc)
}
p.mu.Unlock()
if err := p.drain(); err != nil {
log.G(ctx).WithError(err).Error("failed to drain message queue")
}
close(p.done)
}()
if err := p.healthCheck(ctx); err == nil {
p.setActive()
}
for {
select {
case <-ctx.Done():
return
default:
}
select {
case m := <-p.msgc:
// We do not propagate the context here because this operation should either
// finish or time out on its own for raft to work correctly.
err := p.sendProcessMessage(context.Background(), m)
if err != nil {
log.G(ctx).WithError(err).Debugf("failed to send message %s", m.Type)
p.setInactive()
if err := p.handleAddressChange(ctx); err != nil {
log.G(ctx).WithError(err).Error("failed to change address after failure")
}
continue
}
p.setActive()
case <-ctx.Done():
return
}
}
}
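// stop cancels the peer's context and blocks until its run loop has exited.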
func (p *peer) stop() {
p.cancel()
<-p.done
}