blob: a4330c4d2047469bbd70d27de291fdd629dd23f5 [file] [log] [blame]
package memberlist
import (
"bufio"
"bytes"
"encoding/binary"
"fmt"
"hash/crc32"
"io"
"net"
"time"
"github.com/armon/go-metrics"
"github.com/hashicorp/go-msgpack/codec"
)
// This is the minimum and maximum protocol version that we can
// _understand_. We're allowed to speak at any version within this
// range. This range is inclusive.
const (
// ProtocolVersionMin is the oldest protocol version we can understand.
ProtocolVersionMin uint8 = 1
// Version 3 added support for TCP pings but we kept the default
// protocol version at 2 to ease transition to this new feature.
// A memberlist speaking version 2 of the protocol will attempt
// to TCP ping another memberlist who understands version 3 or
// greater.
//
// Version 4 added support for nacks as part of indirect probes.
// A memberlist speaking version 2 of the protocol will expect
// nacks from another memberlist who understands version 4 or
// greater, and likewise nacks will be sent to memberlists who
// understand version 4 or greater.
//
// Version 5 added an optional CRC32 checksum header on packet
// messages (hasCrcMsg). rawSendMsgPacket only attaches it when the
// recipient advertises PMax >= 5; ingestPacket verifies it whenever
// a packet starts with that header.
ProtocolVersion2Compatible = 2
// ProtocolVersionMax is the newest protocol version we can understand.
ProtocolVersionMax = 5
)
// messageType is an integer ID of a type of message that can be received
// on network channels from other members. It travels as the first byte of
// each message (see handleCommand for packets and readStream for streams).
type messageType uint8
// The list of available message types. Values are assigned by iota and are
// sent on the wire, so existing entries must never be reordered or removed;
// new types may only be appended at the end.
const (
pingMsg messageType = iota // ping
indirectPingMsg // indirectPingReq: ask a peer to probe a target for us
ackRespMsg // ackResp
suspectMsg // suspect
aliveMsg // alive
deadMsg // dead
pushPullMsg // stream state exchange: pushPullHeader + pushNodeState entries
compoundMsg // several messages bundled into one packet
userMsg // User message, delivered to the Delegate rather than handled by memberlist
compressMsg // compress wrapper around another message
encryptMsg // encrypted stream payload (see encryptLocalState)
nackRespMsg // nackResp
hasCrcMsg // 4-byte big-endian CRC32 header, followed by the message it covers
errMsg // errResp relayed back over a stream
)
// compressionType is used to specify the compression algorithm; it is
// carried in the Algo field of a compress message.
type compressionType uint8
const (
lzwAlgo compressionType = iota // LZW; the only algorithm defined here
)
const (
MetaMaxSize = 512 // Maximum size for node meta data
compoundHeaderOverhead = 2 // Assumed header overhead of a compound message; subtracted from UDPBufferSize in sendMsg
compoundOverhead = 2 // Assumed overhead per entry in a compound message; passed to getBroadcasts by sendMsg
userMsgOverhead = 1 // NOTE(review): presumably the one-byte userMsg type prefix; not referenced in this file — confirm
blockingWarning = 10 * time.Millisecond // Warn if a UDP packet takes this long to process
maxPushStateBytes = 10 * 1024 * 1024 // Largest encrypted stream payload decryptRemoteState will read (10 MiB)
)
// ping request sent directly to node
type ping struct {
// SeqNo is echoed back in the ackResp so the sender can match it.
SeqNo uint32
// Node is sent so the target can verify they are
// the intended recipient. This is to protect against an agent
// restart with a new name. A receiver ignores pings whose Node is
// non-empty and differs from its configured name.
Node string
}
// indirect ping sent to an indirect node, asking it to ping Target on the
// sender's behalf (see handleIndirectPing)
type indirectPingReq struct {
SeqNo uint32 // sender's sequence number; echoed in the relayed ack or nack
Target []byte // target IP address bytes (interpreted with net.IP)
Port uint16 // target port; 0 (or protocol < 2) means use the receiver's BindPort
Node string // target node name, forwarded in the ping for verification
Nack bool // true if we'd like a nack back
}
// ack response is sent for a ping
type ackResp struct {
SeqNo uint32 // SeqNo of the ping being acknowledged
Payload []byte // optional data from the Ping delegate's AckPayload; nil for stream and relayed acks
}
// nack response is sent for an indirect ping when the pinger doesn't hear from
// the ping-ee within the configured timeout. This lets the original node know
// that the indirect ping attempt happened but didn't succeed.
type nackResp struct {
SeqNo uint32 // SeqNo from the originating indirectPingReq
}
// err response is sent over a stream to relay the error from the remote end
// (handleConn sends it when it fails to read an incoming stream)
type errResp struct {
Error string // text of the remote error
}
// suspect is broadcast when we suspect a node is dead
type suspect struct {
Incarnation uint32
Node string // name of the suspected node
From string // Include who is suspecting
}
// alive is broadcast when we know a node is alive.
// Overloaded for nodes joining
type alive struct {
Incarnation uint32
Node string
Addr []byte
Port uint16 // 0 (or protocol < 2) is replaced with BindPort on receipt
Meta []byte
// The versions of the protocol/delegate that are being spoken, order:
// pmin, pmax, pcur, dmin, dmax, dcur
Vsn []uint8
}
// dead is broadcast when we confirm a node is dead
// Overloaded for nodes leaving
type dead struct {
Incarnation uint32
Node string // name of the dead (or leaving) node
From string // Include who is declaring the node dead
}
// pushPullHeader is used to inform the
// other side how many states we are transferring
type pushPullHeader struct {
Nodes int // number of pushNodeState entries that follow
UserStateLen int // Encodes the byte length of user state
Join bool // Is this a join request or an anti-entropy run
}
// userMsgHeader is used to encapsulate a userMsg
type userMsgHeader struct {
UserMsgLen int // Encodes the byte length of the user message that follows
}
// pushNodeState is used for pushPullReq when we are
// transferring our node states
type pushNodeState struct {
Name string
Addr []byte
Port uint16
Meta []byte
Incarnation uint32
State nodeStateType
Vsn []uint8 // Protocol versions: pmin, pmax, pcur, dmin, dmax, dcur
}
// compress is used to wrap an underlying payload
// using a specified compression algorithm
type compress struct {
Algo compressionType // algorithm used to produce Buf
Buf []byte // compressed bytes of the wrapped message
}
// msgHandoff is used to transfer a message between goroutines
// (handleCommand queues it on m.handoff; packetHandler consumes it)
type msgHandoff struct {
msgType messageType // already-stripped type byte
buf []byte // message payload without the type byte
from net.Addr // sender address, for logging
}
// encryptionVersion maps the protocol version we currently speak onto a
// payload encryption version: protocol 1 uses encryption version 0, and
// every other protocol version uses encryption version 1.
func (m *Memberlist) encryptionVersion() encryptionVersion {
	if m.ProtocolVersion() == 1 {
		return 0
	}
	return 1
}
// streamListen runs until shutdownCh fires, taking each stream the transport
// delivers and servicing it on its own goroutine via handleConn.
func (m *Memberlist) streamListen() {
	for {
		var conn net.Conn
		select {
		case <-m.shutdownCh:
			return
		case conn = <-m.transport.StreamCh():
		}
		go m.handleConn(conn)
	}
}
// handleConn services a single inbound stream connection from the transport.
// The connection is always closed before this returns.
//
// After setting a TCPTimeout deadline, one message is read with readStream
// and dispatched on its type:
//   - userMsg: read and delivered via readUserMsg.
//   - pushPullMsg: the peer's state is read, ours is sent back, then the two
//     are merged.
//   - pingMsg: answered with an ackResp echoing the ping's SeqNo, unless the
//     ping names a different node.
//
// If readStream fails with anything other than io.EOF, an errResp holding
// the error text is written back so the dialing side can surface it.
func (m *Memberlist) handleConn(conn net.Conn) {
	m.logger.Printf("[DEBUG] memberlist: Stream connection %s", LogConn(conn))
	defer conn.Close()
	metrics.IncrCounter([]string{"memberlist", "tcp", "accept"}, 1)

	conn.SetDeadline(time.Now().Add(m.config.TCPTimeout))
	msgType, bufConn, dec, readErr := m.readStream(conn)
	if readErr != nil {
		if readErr == io.EOF {
			// Stream ended; no error reply is sent.
			return
		}
		m.logger.Printf("[ERR] memberlist: failed to receive: %s %s", readErr, LogConn(conn))
		reply, err := encode(errMsg, &errResp{Error: readErr.Error()})
		if err != nil {
			m.logger.Printf("[ERR] memberlist: Failed to encode error response: %s", err)
			return
		}
		if err := m.rawSendMsgStream(conn, reply.Bytes()); err != nil {
			m.logger.Printf("[ERR] memberlist: Failed to send error: %s %s", err, LogConn(conn))
		}
		return
	}

	switch msgType {
	case userMsg:
		if err := m.readUserMsg(bufConn, dec); err != nil {
			m.logger.Printf("[ERR] memberlist: Failed to receive user message: %s %s", err, LogConn(conn))
		}

	case pushPullMsg:
		join, remoteNodes, userState, err := m.readRemoteState(bufConn, dec)
		if err != nil {
			m.logger.Printf("[ERR] memberlist: Failed to read remote state: %s %s", err, LogConn(conn))
			return
		}
		// Reply with our own state before merging theirs.
		if err := m.sendLocalState(conn, join); err != nil {
			m.logger.Printf("[ERR] memberlist: Failed to push local state: %s %s", err, LogConn(conn))
			return
		}
		if err := m.mergeRemoteState(join, remoteNodes, userState); err != nil {
			m.logger.Printf("[ERR] memberlist: Failed push/pull merge: %s %s", err, LogConn(conn))
		}

	case pingMsg:
		var p ping
		if err := dec.Decode(&p); err != nil {
			m.logger.Printf("[ERR] memberlist: Failed to decode ping: %s %s", err, LogConn(conn))
			return
		}
		if p.Node != "" && p.Node != m.config.Name {
			m.logger.Printf("[WARN] memberlist: Got ping for unexpected node %s %s", p.Node, LogConn(conn))
			return
		}
		ack, err := encode(ackRespMsg, &ackResp{SeqNo: p.SeqNo})
		if err != nil {
			m.logger.Printf("[ERR] memberlist: Failed to encode ack: %s", err)
			return
		}
		if err := m.rawSendMsgStream(conn, ack.Bytes()); err != nil {
			m.logger.Printf("[ERR] memberlist: Failed to send ack: %s %s", err, LogConn(conn))
		}

	default:
		m.logger.Printf("[ERR] memberlist: Received invalid msgType (%d) %s", msgType, LogConn(conn))
	}
}
// packetListen runs until shutdownCh fires, feeding every packet the
// transport delivers into ingestPacket on this same goroutine.
func (m *Memberlist) packetListen() {
	for {
		select {
		case <-m.shutdownCh:
			return
		case pkt := <-m.transport.PacketCh():
			m.ingestPacket(pkt.Buf, pkt.From, pkt.Timestamp)
		}
	}
}
// ingestPacket is the entry point for every packet pulled off the transport.
//
// When encryption is configured the buffer is decrypted first. A failed
// decrypt drops the packet if GossipVerifyIncoming is set; otherwise the
// original bytes are processed as plaintext. A leading hasCrcMsg header
// carries a big-endian CRC32 of the remaining bytes, which is verified and
// stripped before the message is dispatched through handleCommand.
func (m *Memberlist) ingestPacket(buf []byte, from net.Addr, timestamp time.Time) {
	if m.config.EncryptionEnabled() {
		plain, err := decryptPayload(m.config.Keyring.GetKeys(), buf, nil)
		switch {
		case err == nil:
			buf = plain
		case m.config.GossipVerifyIncoming:
			m.logger.Printf("[ERR] memberlist: Decrypt packet failed: %v %s", err, LogAddress(from))
			return
		}
		// Decrypt failed with verification off: keep buf as plaintext.
	}

	checksummed := len(buf) >= 5 && messageType(buf[0]) == hasCrcMsg
	if !checksummed {
		m.handleCommand(buf, from, timestamp)
		return
	}

	body := buf[5:]
	got := crc32.ChecksumIEEE(body)
	want := binary.BigEndian.Uint32(buf[1:5])
	if got != want {
		m.logger.Printf("[WARN] memberlist: Got invalid checksum for UDP packet: %x, %x", got, want)
		return
	}
	m.handleCommand(body, from, timestamp)
}
// handleCommand dispatches a single packet-level message. The first byte of
// buf is the messageType; the rest is that message's payload.
//
// Wrapper types (compoundMsg, compressMsg) are unpacked and fed back through
// handleCommand. Ping, indirect ping, ack and nack are handled inline.
// Suspect, alive, dead and user messages are queued on m.handoff for
// packetHandler, so slower processing does not delay the listener; if that
// queue is full the message is dropped with a warning.
func (m *Memberlist) handleCommand(buf []byte, from net.Addr, timestamp time.Time) {
	// buf comes straight off the network: an empty packet, or a CRC-wrapped
	// packet of exactly 5 bytes (ingestPacket passes buf[5:]), reaches here
	// with nothing in it. Indexing buf[0] would panic on the packetListen
	// goroutine, and an unrecovered panic terminates the program.
	if len(buf) < 1 {
		m.logger.Printf("[ERR] memberlist: missing message type byte %s", LogAddress(from))
		return
	}

	msgType := messageType(buf[0])
	buf = buf[1:]

	switch msgType {
	case compoundMsg:
		m.handleCompound(buf, from, timestamp)
	case compressMsg:
		m.handleCompressed(buf, from, timestamp)
	case pingMsg:
		m.handlePing(buf, from)
	case indirectPingMsg:
		m.handleIndirectPing(buf, from)
	case ackRespMsg:
		m.handleAck(buf, from, timestamp)
	case nackRespMsg:
		m.handleNack(buf, from)
	case suspectMsg, aliveMsg, deadMsg, userMsg:
		// Non-blocking hand-off: never stall the listener on a full queue.
		select {
		case m.handoff <- msgHandoff{msgType, buf, from}:
		default:
			m.logger.Printf("[WARN] memberlist: handler queue full, dropping message (%d) %s", msgType, LogAddress(from))
		}
	default:
		m.logger.Printf("[ERR] memberlist: msg type (%d) not supported %s", msgType, LogAddress(from))
	}
}
// packetHandler is a long running goroutine that drains m.handoff and
// processes the suspect/alive/dead/user messages queued by handleCommand.
// Running separately from the packet listener keeps slow handlers from
// delaying ping/ack traffic. It returns when shutdownCh fires.
func (m *Memberlist) packetHandler() {
	for {
		var msg msgHandoff
		select {
		case <-m.shutdownCh:
			return
		case msg = <-m.handoff:
		}

		switch msg.msgType {
		case suspectMsg:
			m.handleSuspect(msg.buf, msg.from)
		case aliveMsg:
			m.handleAlive(msg.buf, msg.from)
		case deadMsg:
			m.handleDead(msg.buf, msg.from)
		case userMsg:
			m.handleUser(msg.buf, msg.from)
		default:
			m.logger.Printf("[ERR] memberlist: Message type (%d) not supported %s (packet handler)", msg.msgType, LogAddress(msg.from))
		}
	}
}
// handleCompound splits a compoundMsg payload into its parts and runs each
// part back through handleCommand. Undecodable payloads are logged and
// dropped; truncated parts are reported with a warning.
func (m *Memberlist) handleCompound(buf []byte, from net.Addr, timestamp time.Time) {
	truncated, parts, err := decodeCompoundMessage(buf)
	if err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to decode compound request: %s %s", err, LogAddress(from))
		return
	}

	if truncated > 0 {
		m.logger.Printf("[WARN] memberlist: Compound request had %d truncated messages %s", truncated, LogAddress(from))
	}

	for i := range parts {
		m.handleCommand(parts[i], from, timestamp)
	}
}
// handlePing answers a packet-level ping with an ackResp that echoes its
// SeqNo, attaching the Ping delegate's AckPayload when one is configured.
// Pings that name a different node are ignored with a warning.
func (m *Memberlist) handlePing(buf []byte, from net.Addr) {
	var req ping
	if err := decode(buf, &req); err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to decode ping request: %s %s", err, LogAddress(from))
		return
	}

	misaddressed := req.Node != "" && req.Node != m.config.Name
	if misaddressed {
		m.logger.Printf("[WARN] memberlist: Got ping for unexpected node '%s' %s", req.Node, LogAddress(from))
		return
	}

	reply := ackResp{SeqNo: req.SeqNo}
	if pd := m.config.Ping; pd != nil {
		reply.Payload = pd.AckPayload()
	}
	if err := m.encodeAndSendMsg(from.String(), ackRespMsg, &reply); err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to send ack: %s %s", err, LogAddress(from))
	}
}
// handleIndirectPing services an indirectPingReq by pinging the requested
// target on the requester's behalf.
//
// The outgoing ping uses a fresh local sequence number. An ack handler for
// that number (registered with ProbeTimeout) forwards an ackResp carrying
// the requester's original SeqNo back to the requester. When the request
// asked for a nack, a goroutine sends a nackResp to the requester after
// ProbeTimeout unless the ack handler ran first.
func (m *Memberlist) handleIndirectPing(buf []byte, from net.Addr) {
	var req indirectPingReq
	if err := decode(buf, &req); err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to decode indirect ping request: %s %s", err, LogAddress(from))
		return
	}

	// Protocol versions < 2 carry no port; fall back to our own bind port.
	if m.ProtocolVersion() < 2 || req.Port == 0 {
		req.Port = uint16(m.config.BindPort)
	}

	localSeq := m.nextSeqNo()
	outgoing := ping{SeqNo: localSeq, Node: req.Node}

	// Closed by the ack handler so the nack goroutine below stands down.
	acked := make(chan struct{})
	m.setAckHandler(localSeq, func(payload []byte, timestamp time.Time) {
		close(acked)
		relay := ackResp{SeqNo: req.SeqNo}
		if err := m.encodeAndSendMsg(from.String(), ackRespMsg, &relay); err != nil {
			m.logger.Printf("[ERR] memberlist: Failed to forward ack: %s %s", err, LogAddress(from))
		}
	}, m.config.ProbeTimeout)

	target := joinHostPort(net.IP(req.Target).String(), req.Port)
	if err := m.encodeAndSendMsg(target, pingMsg, &outgoing); err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to send ping: %s %s", err, LogAddress(from))
	}

	if !req.Nack {
		return
	}
	go func() {
		select {
		case <-acked:
		case <-time.After(m.config.ProbeTimeout):
			nack := nackResp{SeqNo: req.SeqNo}
			if err := m.encodeAndSendMsg(from.String(), nackRespMsg, &nack); err != nil {
				m.logger.Printf("[ERR] memberlist: Failed to send nack: %s %s", err, LogAddress(from))
			}
		}
	}()
}
// handleAck decodes an ackResp packet and hands it, together with the
// receive timestamp, to invokeAckHandler.
func (m *Memberlist) handleAck(buf []byte, from net.Addr, timestamp time.Time) {
	var resp ackResp
	err := decode(buf, &resp)
	if err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to decode ack response: %s %s", err, LogAddress(from))
		return
	}
	m.invokeAckHandler(resp, timestamp)
}
// handleNack decodes a nackResp packet and hands it to invokeNackHandler.
func (m *Memberlist) handleNack(buf []byte, from net.Addr) {
	var resp nackResp
	err := decode(buf, &resp)
	if err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to decode nack response: %s %s", err, LogAddress(from))
		return
	}
	m.invokeNackHandler(resp)
}
// handleSuspect decodes a suspect message and passes it to suspectNode.
func (m *Memberlist) handleSuspect(buf []byte, from net.Addr) {
	msg := suspect{}
	err := decode(buf, &msg)
	if err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to decode suspect message: %s %s", err, LogAddress(from))
		return
	}
	m.suspectNode(&msg)
}
// handleAlive decodes an alive message and passes it to aliveNode. A zero
// port (or any port, when we speak protocol < 2, which sent none) is
// replaced with our configured BindPort first.
func (m *Memberlist) handleAlive(buf []byte, from net.Addr) {
	msg := alive{}
	err := decode(buf, &msg)
	if err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to decode alive message: %s %s", err, LogAddress(from))
		return
	}

	if m.ProtocolVersion() < 2 || msg.Port == 0 {
		msg.Port = uint16(m.config.BindPort)
	}
	m.aliveNode(&msg, nil, false)
}
// handleDead decodes a dead message and passes it to deadNode.
func (m *Memberlist) handleDead(buf []byte, from net.Addr) {
	msg := dead{}
	err := decode(buf, &msg)
	if err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to decode dead message: %s %s", err, LogAddress(from))
		return
	}
	m.deadNode(&msg)
}
// handleUser forwards an incoming user message payload to the configured
// Delegate's NotifyMsg. Without a Delegate the payload is discarded.
func (m *Memberlist) handleUser(buf []byte, from net.Addr) {
	if delegate := m.config.Delegate; delegate != nil {
		delegate.NotifyMsg(buf)
	}
}
// handleCompressed unpacks a compressMsg payload and runs the inner message
// back through handleCommand. Payloads that fail to decompress are logged
// and dropped.
func (m *Memberlist) handleCompressed(buf []byte, from net.Addr, timestamp time.Time) {
	inner, err := decompressPayload(buf)
	if err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to decompress payload: %v %s", err, LogAddress(from))
		return
	}
	m.handleCommand(inner, from, timestamp)
}
// encodeAndSendMsg encodes msg under msgType and sends it to addr via
// sendMsg. It returns the first error from either step.
func (m *Memberlist) encodeAndSendMsg(addr string, msgType messageType, msg interface{}) error {
	encoded, err := encode(msgType, msg)
	if err != nil {
		return err
	}
	return m.sendMsg(addr, encoded.Bytes())
}
// sendMsg sends msg to addr as a packet, opportunistically bundling queued
// broadcasts into a compoundMsg when they fit in the remaining UDP buffer
// space. That space also accounts for encryption overhead when outgoing
// messages are encrypted. With nothing to piggyback, msg goes out alone.
func (m *Memberlist) sendMsg(addr string, msg []byte) error {
	room := m.config.UDPBufferSize - len(msg) - compoundHeaderOverhead
	if m.config.EncryptionEnabled() && m.config.GossipVerifyOutgoing {
		room -= encryptOverhead(m.encryptionVersion())
	}

	piggyback := m.getBroadcasts(compoundOverhead, room)
	if len(piggyback) == 0 {
		return m.rawSendMsgPacket(addr, nil, msg)
	}

	batch := append(make([][]byte, 0, 1+len(piggyback)), msg)
	batch = append(batch, piggyback...)
	return m.rawSendMsgPacket(addr, nil, makeCompoundMessage(batch).Bytes())
}
// rawSendMsgPacket writes msg to addr as one packet, applying only the
// configured transformations, in this order:
//  1. compression, if enabled and it actually shrinks the payload (a
//     compression failure is logged and the raw bytes are used);
//  2. a checksum header, if the destination node advertises PMax >= 5;
//  3. encryption with the primary key, if outgoing verification is on.
//
// node may be nil, in which case a lookup is attempted from addr. Returns
// the address-parse, encryption, or transport write error.
func (m *Memberlist) rawSendMsgPacket(addr string, node *Node, msg []byte) error {
	if m.config.EnableCompression {
		compressed, err := compressPayload(msg)
		switch {
		case err != nil:
			m.logger.Printf("[WARN] memberlist: Failed to compress payload: %v", err)
		case compressed.Len() < len(msg):
			msg = compressed.Bytes()
		}
	}

	// Without an explicit node, look one up so we know its protocol range.
	// NOTE(review): nodeMap is indexed by the bare host from addr; if nodeMap
	// is keyed by node name (declared elsewhere — confirm), this only finds a
	// node whose name is its IP address.
	if node == nil {
		host, _, err := net.SplitHostPort(addr)
		if err != nil {
			m.logger.Printf("[ERR] memberlist: Failed to parse address %q: %v", addr, err)
			return err
		}
		m.nodeLock.RLock()
		state, found := m.nodeMap[host]
		m.nodeLock.RUnlock()
		if found {
			node = &state.Node
		}
	}

	// Protocol >= 5 peers get a header *prepended*: the hasCrcMsg type byte,
	// then the big-endian IEEE CRC32 of the payload that follows.
	if node != nil && node.PMax >= 5 {
		framed := make([]byte, 5, 5+len(msg))
		framed[0] = byte(hasCrcMsg)
		binary.BigEndian.PutUint32(framed[1:], crc32.ChecksumIEEE(msg))
		msg = append(framed, msg...)
	}

	if m.config.EncryptionEnabled() && m.config.GossipVerifyOutgoing {
		var sealed bytes.Buffer
		key := m.config.Keyring.GetPrimaryKey()
		if err := encryptPayload(m.encryptionVersion(), key, msg, nil, &sealed); err != nil {
			m.logger.Printf("[ERR] memberlist: Encryption of message failed: %v", err)
			return err
		}
		msg = sealed.Bytes()
	}

	metrics.IncrCounter([]string{"memberlist", "udp", "sent"}, float32(len(msg)))
	_, err := m.transport.WriteTo(msg, addr)
	return err
}
// rawSendMsgStream writes sendBuf to conn as a single stream message, applying
// only the configured transformations: compression when EnableCompression is
// set (a compression failure is logged and the uncompressed bytes are sent),
// then encryption via encryptLocalState when outgoing verification is on.
//
// Returns the encryption or write error, or an error if conn accepted fewer
// bytes than requested.
func (m *Memberlist) rawSendMsgStream(conn net.Conn, sendBuf []byte) error {
	// Check if compression is enabled
	if m.config.EnableCompression {
		compBuf, err := compressPayload(sendBuf)
		if err != nil {
			// Best effort: fall back to the uncompressed payload. Uses the
			// same "[ERR]" level tag as every other error log in this file.
			m.logger.Printf("[ERR] memberlist: Failed to compress payload: %v", err)
		} else {
			sendBuf = compBuf.Bytes()
		}
	}

	// Check if encryption is enabled
	if m.config.EncryptionEnabled() && m.config.GossipVerifyOutgoing {
		crypt, err := m.encryptLocalState(sendBuf)
		if err != nil {
			m.logger.Printf("[ERR] memberlist: Failed to encrypt local state: %v", err)
			return err
		}
		sendBuf = crypt
	}

	// Write out the entire send buffer
	metrics.IncrCounter([]string{"memberlist", "tcp", "sent"}, float32(len(sendBuf)))
	n, err := conn.Write(sendBuf)
	if err != nil {
		return err
	}
	if n != len(sendBuf) {
		return fmt.Errorf("only %d of %d bytes written", n, len(sendBuf))
	}
	return nil
}
// sendUserMsg dials addr and streams sendBuf to it as a userMsg. The frame
// is the userMsg type byte, a msgpack-encoded userMsgHeader carrying the
// payload length, then the raw payload. The connection is closed on return.
func (m *Memberlist) sendUserMsg(addr string, sendBuf []byte) error {
	conn, err := m.transport.DialTimeout(addr, m.config.TCPTimeout)
	if err != nil {
		return err
	}
	defer conn.Close()

	var frame bytes.Buffer
	if err := frame.WriteByte(byte(userMsg)); err != nil {
		return err
	}

	hd := codec.MsgpackHandle{}
	if err := codec.NewEncoder(&frame, &hd).Encode(&userMsgHeader{UserMsgLen: len(sendBuf)}); err != nil {
		return err
	}

	if _, err := frame.Write(sendBuf); err != nil {
		return err
	}
	return m.rawSendMsgStream(conn, frame.Bytes())
}
// sendAndReceiveState dials addr, pushes our local state, then reads the
// peer's reply. It returns the peer's node states and user state on a
// pushPullMsg reply. An errMsg reply is surfaced as a "remote error", and
// any other message type is an error.
func (m *Memberlist) sendAndReceiveState(addr string, join bool) ([]pushNodeState, []byte, error) {
	conn, err := m.transport.DialTimeout(addr, m.config.TCPTimeout)
	if err != nil {
		return nil, nil, err
	}
	defer conn.Close()
	m.logger.Printf("[DEBUG] memberlist: Initiating push/pull sync with: %s", conn.RemoteAddr())
	metrics.IncrCounter([]string{"memberlist", "tcp", "connect"}, 1)

	if err = m.sendLocalState(conn, join); err != nil {
		return nil, nil, err
	}

	conn.SetDeadline(time.Now().Add(m.config.TCPTimeout))
	msgType, bufConn, dec, err := m.readStream(conn)
	if err != nil {
		return nil, nil, err
	}

	switch msgType {
	case errMsg:
		var resp errResp
		if err := dec.Decode(&resp); err != nil {
			return nil, nil, err
		}
		return nil, nil, fmt.Errorf("remote error: %v", resp.Error)
	case pushPullMsg:
		_, remoteNodes, userState, err := m.readRemoteState(bufConn, dec)
		return remoteNodes, userState, err
	default:
		return nil, nil, fmt.Errorf("received invalid msgType (%d), expected pushPullMsg (%d) %s", msgType, pushPullMsg, LogConn(conn))
	}
}
// sendLocalState writes our membership view (and any delegate user state)
// to conn as a pushPullMsg, after resetting the connection deadline to
// TCPTimeout.
//
// Frame layout: the pushPullMsg type byte, a msgpack pushPullHeader, one
// msgpack pushNodeState per known node, then the raw delegate user state.
func (m *Memberlist) sendLocalState(conn net.Conn, join bool) error {
	conn.SetDeadline(time.Now().Add(m.config.TCPTimeout))

	// Snapshot the node list while holding the read lock.
	m.nodeLock.RLock()
	localNodes := make([]pushNodeState, len(m.nodes))
	for i, n := range m.nodes {
		localNodes[i] = pushNodeState{
			Name:        n.Name,
			Addr:        n.Addr,
			Port:        n.Port,
			Incarnation: n.Incarnation,
			State:       n.State,
			Meta:        n.Meta,
			Vsn:         []uint8{n.PMin, n.PMax, n.PCur, n.DMin, n.DMax, n.DCur},
		}
	}
	m.nodeLock.RUnlock()

	var userData []byte
	if delegate := m.config.Delegate; delegate != nil {
		userData = delegate.LocalState(join)
	}

	var out bytes.Buffer
	if err := out.WriteByte(byte(pushPullMsg)); err != nil {
		return err
	}

	hd := codec.MsgpackHandle{}
	enc := codec.NewEncoder(&out, &hd)
	header := pushPullHeader{Nodes: len(localNodes), UserStateLen: len(userData), Join: join}
	if err := enc.Encode(&header); err != nil {
		return err
	}
	for i := range localNodes {
		if err := enc.Encode(&localNodes[i]); err != nil {
			return err
		}
	}

	if userData != nil {
		if _, err := out.Write(userData); err != nil {
			return err
		}
	}
	return m.rawSendMsgStream(conn, out.Bytes())
}
// encryptLocalState wraps sendBuf in an encrypted stream frame. The frame is
// a 5-byte header (the encryptMsg type byte followed by the big-endian
// length of the ciphertext), then the ciphertext produced with the primary
// key. The same 5 header bytes are handed to encryptPayload as its extra
// data argument (presumably authenticated data — confirm in encryptPayload).
func (m *Memberlist) encryptLocalState(sendBuf []byte) ([]byte, error) {
	vsn := m.encryptionVersion()

	var header [5]byte
	header[0] = byte(encryptMsg)
	binary.BigEndian.PutUint32(header[1:], uint32(encryptedLength(vsn, len(sendBuf))))

	var out bytes.Buffer
	out.Write(header[:])

	key := m.config.Keyring.GetPrimaryKey()
	if err := encryptPayload(vsn, key, sendBuf, header[:], &out); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}
// decryptRemoteState reads the remainder of an encrypted stream frame from
// bufConn and decrypts it. The caller has already consumed the encryptMsg
// type byte, so it is re-inserted here to rebuild the 5-byte header that
// is passed to decryptPayload alongside the ciphertext.
//
// Frames advertising more than maxPushStateBytes are rejected before any
// of the body is read, so a peer cannot force a huge download.
func (m *Memberlist) decryptRemoteState(bufConn io.Reader) ([]byte, error) {
	frame := bytes.NewBuffer([]byte{byte(encryptMsg)})
	if _, err := io.CopyN(frame, bufConn, 4); err != nil {
		return nil, err
	}

	size := binary.BigEndian.Uint32(frame.Bytes()[1:5])
	if size > maxPushStateBytes {
		return nil, fmt.Errorf("Remote node state is larger than limit (%d)", size)
	}

	if _, err := io.CopyN(frame, bufConn, int64(size)); err != nil {
		return nil, err
	}

	raw := frame.Bytes()
	return decryptPayload(m.config.Keyring.GetKeys(), raw[5:], raw[:5])
}
// readStream reads the leading message from a stream connection, decrypting
// and decompressing it if necessary, and returns its type together with a
// reader and msgpack decoder positioned at the message body.
//
// An encrypted stream is rejected when encryption is not configured. A
// plaintext stream is rejected when encryption is configured with
// GossipVerifyIncoming. If the stream closes before the type byte arrives,
// the error is io.EOF; handleConn relies on that to skip sending an error
// reply.
func (m *Memberlist) readStream(conn net.Conn) (messageType, io.Reader, *codec.Decoder, error) {
	var bufConn io.Reader = bufio.NewReader(conn)

	// io.ReadFull never reports success with zero bytes read (a bare Read
	// may), and it still returns io.EOF when nothing at all was received.
	var typeByte [1]byte
	if _, err := io.ReadFull(bufConn, typeByte[:]); err != nil {
		return 0, nil, nil, err
	}
	msgType := messageType(typeByte[0])

	if msgType == encryptMsg {
		if !m.config.EncryptionEnabled() {
			return 0, nil, nil,
				fmt.Errorf("Remote state is encrypted and encryption is not configured")
		}
		plain, err := m.decryptRemoteState(bufConn)
		if err != nil {
			return 0, nil, nil, err
		}
		// The plaintext must contain at least the inner type byte; indexing
		// an empty slice would panic this goroutine and crash the process.
		if len(plain) == 0 {
			return 0, nil, nil, fmt.Errorf("decrypted remote state is empty")
		}
		msgType = messageType(plain[0])
		bufConn = bytes.NewReader(plain[1:])
	} else if m.config.EncryptionEnabled() && m.config.GossipVerifyIncoming {
		return 0, nil, nil,
			fmt.Errorf("Encryption is configured but remote state is not encrypted")
	}

	hd := codec.MsgpackHandle{}
	dec := codec.NewDecoder(bufConn, &hd)

	if msgType == compressMsg {
		var c compress
		if err := dec.Decode(&c); err != nil {
			return 0, nil, nil, err
		}
		decomp, err := decompressBuffer(&c)
		if err != nil {
			return 0, nil, nil, err
		}
		// Reachable without authentication when encryption is off, so an
		// empty decompressed payload must be rejected rather than indexed.
		if len(decomp) == 0 {
			return 0, nil, nil, fmt.Errorf("decompressed payload is empty")
		}
		msgType = messageType(decomp[0])
		bufConn = bytes.NewReader(decomp[1:])
		dec = codec.NewDecoder(bufConn, &hd)
	}

	return msgType, bufConn, dec, nil
}
// readRemoteState decodes a push/pull body from a stream. The body is a
// pushPullHeader, then header.Nodes pushNodeState entries, then
// header.UserStateLen raw bytes of delegate user state.
//
// Returns the header's Join flag, the decoded nodes, and the user state
// (nil when UserStateLen <= 0). In the returned nodes, a port of 0 is
// replaced with our BindPort, as is every port when we speak protocol < 2,
// which sent none. Returns an error for a negative node count or any
// decode/read failure.
func (m *Memberlist) readRemoteState(bufConn io.Reader, dec *codec.Decoder) (bool, []pushNodeState, []byte, error) {
	// Upper bound on how many entries are preallocated from the advertised
	// count; beyond this the slice grows only as entries actually decode.
	const maxNodePrealloc = 1024

	var header pushPullHeader
	if err := dec.Decode(&header); err != nil {
		return false, nil, nil, err
	}

	// The header is peer-supplied. A negative count would make make() panic
	// (crashing the process), and a huge one would force an enormous
	// allocation before a single entry has arrived.
	if header.Nodes < 0 {
		return false, nil, nil, fmt.Errorf("invalid remote node count %d", header.Nodes)
	}
	prealloc := header.Nodes
	if prealloc > maxNodePrealloc {
		prealloc = maxNodePrealloc
	}

	remoteNodes := make([]pushNodeState, 0, prealloc)
	for i := 0; i < header.Nodes; i++ {
		var node pushNodeState
		if err := dec.Decode(&node); err != nil {
			return false, nil, nil, err
		}
		remoteNodes = append(remoteNodes, node)
	}

	var userBuf []byte
	if header.UserStateLen > 0 {
		// NOTE(review): UserStateLen is also peer-supplied; on unencrypted
		// streams nothing bounds it (encrypted ones are capped at
		// maxPushStateBytes by decryptRemoteState). Consider a cap.
		userBuf = make([]byte, header.UserStateLen)
		n, err := io.ReadAtLeast(bufConn, userBuf, header.UserStateLen)
		if err == nil && n != header.UserStateLen {
			err = fmt.Errorf(
				"Failed to read full user state (%d / %d)",
				n, header.UserStateLen)
		}
		if err != nil {
			return false, nil, nil, err
		}
	}

	// For proto versions < 2, there is no port provided. Mask old
	// behavior by using the configured port.
	legacyProtocol := m.ProtocolVersion() < 2
	for idx := range remoteNodes {
		if legacyProtocol || remoteNodes[idx].Port == 0 {
			remoteNodes[idx].Port = uint16(m.config.BindPort)
		}
	}

	return header.Join, remoteNodes, userBuf, nil
}
// mergeRemoteState folds a peer's pushed state into ours. It proceeds in
// order:
//  1. verifyProtocol checks compatibility with the remote nodes.
//  2. On a join with a Merge delegate configured, NotifyMerge may veto the
//     merge by returning an error.
//  3. mergeState merges the membership.
//  4. Any user state is handed to the Delegate.
//
// If verifyProtocol or NotifyMerge returns an error, that error is returned
// and nothing is merged.
func (m *Memberlist) mergeRemoteState(join bool, remoteNodes []pushNodeState, userBuf []byte) error {
	if err := m.verifyProtocol(remoteNodes); err != nil {
		return err
	}

	// Invoke the merge delegate if any
	if join && m.config.Merge != nil {
		nodes := make([]*Node, len(remoteNodes))
		for idx, n := range remoteNodes {
			// Vsn comes from the remote peer and is expected to hold six
			// entries (pmin, pmax, pcur, dmin, dmax, dcur). Indexing a
			// shorter slice directly would panic, so copy what is present
			// and leave missing versions as zero.
			var vsn [6]uint8
			copy(vsn[:], n.Vsn)
			nodes[idx] = &Node{
				Name: n.Name,
				Addr: n.Addr,
				Port: n.Port,
				Meta: n.Meta,
				PMin: vsn[0],
				PMax: vsn[1],
				PCur: vsn[2],
				DMin: vsn[3],
				DMax: vsn[4],
				DCur: vsn[5],
			}
		}
		if err := m.config.Merge.NotifyMerge(nodes); err != nil {
			return err
		}
	}

	// Merge the membership state
	m.mergeState(remoteNodes)

	// Invoke the delegate for user state
	if userBuf != nil && m.config.Delegate != nil {
		m.config.Delegate.MergeRemoteState(userBuf, join)
	}
	return nil
}
// readUserMsg decodes a userMsgHeader from dec, then reads exactly
// UserMsgLen payload bytes from bufConn and hands them to the Delegate's
// NotifyMsg (if a Delegate is configured). A non-positive length means
// there is no payload, and nothing is delivered.
func (m *Memberlist) readUserMsg(bufConn io.Reader, dec *codec.Decoder) error {
	var header userMsgHeader
	if err := dec.Decode(&header); err != nil {
		return err
	}
	if header.UserMsgLen <= 0 {
		return nil
	}

	payload := make([]byte, header.UserMsgLen)
	n, err := io.ReadAtLeast(bufConn, payload, header.UserMsgLen)
	if err == nil && n != header.UserMsgLen {
		err = fmt.Errorf("Failed to read full user message (%d / %d)", n, header.UserMsgLen)
	}
	if err != nil {
		return err
	}

	if delegate := m.config.Delegate; delegate != nil {
		delegate.NotifyMsg(payload)
	}
	return nil
}
// sendPingAndWaitForAck makes a stream connection to the given address, sends
// a ping, and waits for an ack. All of this is done as a series of blocking
// operations bounded by deadline. The bool return is true only if we were
// able to round trip the ping to the other node.
//
// A dial failure is reported as (false, nil): if the node is actually dead
// that is the expected outcome and should not be logged as an error. Any
// failure after the connection is established is returned as an error,
// including an unexpected reply type or a mismatched sequence number.
func (m *Memberlist) sendPingAndWaitForAck(addr string, ping ping, deadline time.Time) (bool, error) {
	conn, err := m.transport.DialTimeout(addr, time.Until(deadline))
	if err != nil {
		// If the node is actually dead we expect this to fail, so we
		// shouldn't spam the logs with it. After this point, errors
		// with the connection are real, unexpected errors and should
		// get propagated up.
		return false, nil
	}
	defer conn.Close()
	conn.SetDeadline(deadline)

	out, err := encode(pingMsg, &ping)
	if err != nil {
		return false, err
	}

	if err = m.rawSendMsgStream(conn, out.Bytes()); err != nil {
		return false, err
	}

	msgType, _, dec, err := m.readStream(conn)
	if err != nil {
		return false, err
	}

	if msgType != ackRespMsg {
		return false, fmt.Errorf("Unexpected msgType (%d) from ping %s", msgType, LogConn(conn))
	}

	var ack ackResp
	if err = dec.Decode(&ack); err != nil {
		return false, err
	}

	if ack.SeqNo != ping.SeqNo {
		// Three verbs for three arguments; the original format string had
		// only two, which fails go vet and renders a %!(EXTRA ...) suffix.
		return false, fmt.Errorf("Sequence number from ack (%d) doesn't match ping (%d) %s", ack.SeqNo, ping.SeqNo, LogConn(conn))
	}

	return true, nil
}