blob: b92dccb101258531d1acb3b905df2b4171a86e62 [file] [log] [blame]
package memberlist
import (
"bufio"
"bytes"
"encoding/binary"
"fmt"
"io"
"net"
"time"
"github.com/armon/go-metrics"
"github.com/hashicorp/go-msgpack/codec"
)
// ProtocolVersionMin and ProtocolVersionMax are the minimum and maximum
// protocol versions that we can _understand_. We're allowed to speak at any
// version within this range. This range is inclusive.
const (
ProtocolVersionMin uint8 = 1
// Version 3 added support for TCP pings, but we kept the default
// protocol version at 2 to ease the transition to this new feature.
// A memberlist speaking version 2 of the protocol will attempt
// to TCP ping another memberlist that understands version 3 or
// greater.
ProtocolVersion2Compatible = 2
ProtocolVersionMax = 3
)
// messageType is an integer ID of a type of message that can be received
// on network channels from other members. It is read from the first byte
// of a message by handleCommand (UDP) and readTCP (TCP).
type messageType uint8
// The list of available message types. Values are assigned by iota starting
// at zero, so the declaration order below fixes the numeric ID of each type.
const (
pingMsg messageType = iota
indirectPingMsg
ackRespMsg
suspectMsg
aliveMsg
deadMsg
pushPullMsg
compoundMsg
userMsg // User message, not handled by us
compressMsg
encryptMsg
)
// compressionType is used to specify the compression algorithm.
type compressionType uint8
// Available compression algorithms; lzwAlgo is the only one declared here.
const (
lzwAlgo compressionType = iota
)
// Size limits and overhead estimates used by the network layer.
const (
MetaMaxSize = 512 // Maximum size for node meta data
compoundHeaderOverhead = 2 // Assumed header overhead
compoundOverhead = 2 // Assumed overhead per entry in compoundHeader
udpBufSize = 65536 // Size of the buffer allocated for each UDP read
udpRecvBuf = 2 * 1024 * 1024 // Initial size requested for the UDP receive buffer
udpSendBuf = 1400 // Byte budget used when piggybacking broadcasts on a UDP message
userMsgOverhead = 1
blockingWarning = 10 * time.Millisecond // Warn if a UDP packet takes this long to process
maxPushStateBytes = 10 * 1024 * 1024 // Largest encrypted TCP payload we agree to read
)
// ping is the request sent directly to a node.
type ping struct {
SeqNo uint32
// Node is sent so the target can verify they are
// the intended recipient. This is to protect against an agent
// restart with a new name.
Node string
}
// indirectPingReq is the indirect ping request sent to an intermediary
// node, which pings Target:Port on the sender's behalf and relays the ack.
type indirectPingReq struct {
SeqNo uint32
Target []byte
Port uint16
Node string
}
// ackResp is the ack response sent for a ping. SeqNo echoes the sequence
// number of the ping being acknowledged.
type ackResp struct {
SeqNo uint32
Payload []byte
}
// suspect is broadcast when we suspect a node is dead.
type suspect struct {
Incarnation uint32
Node string
From string // Include who is suspecting
}
// alive is broadcast when we know a node is alive.
// It is overloaded for nodes joining.
type alive struct {
Incarnation uint32
Node string
Addr []byte
Port uint16
Meta []byte
// The versions of the protocol/delegate that are being spoken, order:
// pmin, pmax, pcur, dmin, dmax, dcur
Vsn []uint8
}
// dead is broadcast when we confirm a node is dead.
// It is overloaded for nodes leaving.
type dead struct {
Incarnation uint32
Node string
From string // Include who is suspecting
}
// pushPullHeader is used to inform the other side how many
// states we are transferring.
type pushPullHeader struct {
Nodes int
UserStateLen int // Encodes the byte length of user state
Join bool // Is this a join request or an anti-entropy run
}
// userMsgHeader is used to encapsulate a userMsg sent over TCP.
type userMsgHeader struct {
UserMsgLen int // Encodes the byte length of the user message
}
// pushNodeState is used for pushPullReq when we are
// transferring out node states.
type pushNodeState struct {
Name string
Addr []byte
Port uint16
Meta []byte
Incarnation uint32
State nodeStateType
Vsn []uint8 // Protocol versions, order: pmin, pmax, pcur, dmin, dmax, dcur
}
// compress is used to wrap an underlying payload
// using a specified compression algorithm.
type compress struct {
Algo compressionType
Buf []byte
}
// msgHandoff is used to transfer a message between goroutines: handleCommand
// queues one on the handoff channel and udpHandler consumes it.
type msgHandoff struct {
msgType messageType
buf []byte
from net.Addr
}
// encryptionVersion maps the protocol version currently being spoken onto
// the encryption version to use: protocol version 1 selects encryption
// version 0, every other protocol version selects encryption version 1.
func (m *Memberlist) encryptionVersion() encryptionVersion {
	if m.ProtocolVersion() == 1 {
		return 0
	}
	return 1
}
// setUDPRecvBuf is used to resize the UDP receive window. The function
// attempts to set the read buffer to `udpRecvBuf` and halves the requested
// size after every failure until a request succeeds.
//
// The search stops once the candidate size reaches zero, so a connection on
// which SetReadBuffer can never succeed (for example one that is already
// closed) no longer spins this loop forever; in that case the buffer is
// simply left at whatever size it already had.
func setUDPRecvBuf(c *net.UDPConn) {
	for size := udpRecvBuf; size > 0; size /= 2 {
		if err := c.SetReadBuffer(size); err == nil {
			return
		}
	}
}
// tcpListen is the TCP accept loop. Every accepted connection is served on
// its own goroutine by handleConn. An accept error ends the loop when the
// memberlist is shutting down; otherwise it is logged and accepting resumes.
func (m *Memberlist) tcpListen() {
	for {
		conn, err := m.tcpListener.AcceptTCP()
		if err == nil {
			go m.handleConn(conn)
			continue
		}

		if m.shutdown {
			return
		}
		m.logger.Printf("[ERR] memberlist: Error accepting TCP connection: %s", err)
	}
}
// handleConn services one inbound TCP connection and closes it when done.
// After reading the stream preamble via readTCP it dispatches on the message
// type:
//   - userMsg: the user message is read (readUserMsg delivers it);
//   - pushPullMsg: the remote state is read, our local state is sent back,
//     and the remote state is then merged;
//   - pingMsg: the ping is decoded, checked against our node name and
//     answered with an ack carrying the same sequence number.
//
// Any other message type is logged as invalid. All failures are logged and
// end the handling of the connection.
func (m *Memberlist) handleConn(conn *net.TCPConn) {
	m.logger.Printf("[DEBUG] memberlist: TCP connection %s", LogConn(conn))
	defer conn.Close()
	metrics.IncrCounter([]string{"memberlist", "tcp", "accept"}, 1)

	// The whole exchange is bounded by the configured TCP timeout.
	conn.SetDeadline(time.Now().Add(m.config.TCPTimeout))

	kind, stream, decoder, err := m.readTCP(conn)
	if err != nil {
		m.logger.Printf("[ERR] memberlist: failed to receive: %s %s", err, LogConn(conn))
		return
	}

	if kind == userMsg {
		if err := m.readUserMsg(stream, decoder); err != nil {
			m.logger.Printf("[ERR] memberlist: Failed to receive user message: %s %s", err, LogConn(conn))
		}
		return
	}

	if kind == pushPullMsg {
		join, remoteNodes, userState, err := m.readRemoteState(stream, decoder)
		if err != nil {
			m.logger.Printf("[ERR] memberlist: Failed to read remote state: %s %s", err, LogConn(conn))
			return
		}

		// Reply with our own state before merging what we received.
		if err = m.sendLocalState(conn, join); err != nil {
			m.logger.Printf("[ERR] memberlist: Failed to push local state: %s %s", err, LogConn(conn))
			return
		}

		if err = m.mergeRemoteState(join, remoteNodes, userState); err != nil {
			m.logger.Printf("[ERR] memberlist: Failed push/pull merge: %s %s", err, LogConn(conn))
		}
		return
	}

	if kind != pingMsg {
		m.logger.Printf("[ERR] memberlist: Received invalid msgType (%d) %s", kind, LogConn(conn))
		return
	}

	var req ping
	if err := decoder.Decode(&req); err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to decode TCP ping: %s %s", err, LogConn(conn))
		return
	}

	// A ping that names a node must name us.
	if req.Node != "" && req.Node != m.config.Name {
		m.logger.Printf("[WARN] memberlist: Got ping for unexpected node %s %s", req.Node, LogConn(conn))
		return
	}

	reply := ackResp{SeqNo: req.SeqNo}
	out, err := encode(ackRespMsg, &reply)
	if err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to encode TCP ack: %s", err)
		return
	}

	if err := m.rawSendMsgTCP(conn, out.Bytes()); err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to send TCP ack: %s %s", err, LogConn(conn))
	}
}
// udpListen is the UDP receive loop. Each iteration allocates a fresh
// buffer, reads one packet, records the reception time and hands the packet
// to ingestPacket. A read error ends the loop when the memberlist is
// shutting down; otherwise it is logged and reading resumes.
func (m *Memberlist) udpListen() {
	var lastPacket time.Time
	for {
		// Do a check for potentially blocking operations: the time since the
		// previous packet was received approximates how long it took to
		// process. The clock is sampled once so the value that is compared
		// is also the value that is logged.
		if !lastPacket.IsZero() {
			if elapsed := time.Since(lastPacket); elapsed > blockingWarning {
				m.logger.Printf(
					"[DEBUG] memberlist: Potential blocking operation. Last command took %v",
					elapsed)
			}
		}

		// Create a new buffer
		// TODO: Use Sync.Pool eventually
		buf := make([]byte, udpBufSize)

		// Read a packet
		n, addr, err := m.udpListener.ReadFrom(buf)
		if err != nil {
			if m.shutdown {
				return
			}
			m.logger.Printf("[ERR] memberlist: Error reading UDP packet: %s", err)
			continue
		}

		// Capture the reception time of the packet as close to the
		// system calls as possible.
		lastPacket = time.Now()

		// Check the length. Report the number of bytes actually received;
		// len(buf) is always udpBufSize and says nothing about the packet.
		if n < 1 {
			m.logger.Printf("[ERR] memberlist: UDP packet too short (%d bytes) %s",
				n, LogAddress(addr))
			continue
		}

		// Ingest this packet
		metrics.IncrCounter([]string{"memberlist", "udp", "received"}, float32(n))
		m.ingestPacket(buf[:n], addr, lastPacket)
	}
}
// ingestPacket is the entry point for a raw UDP packet. When encryption is
// enabled the packet is decrypted with the keyring's keys first, and a packet
// that fails to decrypt is logged and dropped. The resulting buffer is then
// passed to handleCommand together with the sender and reception time.
func (m *Memberlist) ingestPacket(buf []byte, from net.Addr, timestamp time.Time) {
	if !m.config.EncryptionEnabled() {
		m.handleCommand(buf, from, timestamp)
		return
	}

	plain, err := decryptPayload(m.config.Keyring.GetKeys(), buf, nil)
	if err != nil {
		m.logger.Printf("[ERR] memberlist: Decrypt packet failed: %v %s", err, LogAddress(from))
		return
	}
	m.handleCommand(plain, from, timestamp)
}
// handleCommand dispatches a single decoded UDP message. The first byte of
// buf is the message type and the remainder is the message body.
//
// Compound and compressed messages are unpacked (which calls back into this
// function), pings, indirect pings and acks are handled inline, and suspect,
// alive, dead and user messages are queued on the handoff channel for
// udpHandler. If that queue is full the message is dropped with a warning
// rather than blocking the caller.
//
// An empty buffer is logged and dropped instead of panicking on buf[0]: this
// function is also called with decrypted payloads, compound parts and
// decompressed payloads, whose lengths are not checked by the callers.
func (m *Memberlist) handleCommand(buf []byte, from net.Addr, timestamp time.Time) {
	if len(buf) < 1 {
		m.logger.Printf("[ERR] memberlist: missing message type byte %s", LogAddress(from))
		return
	}

	// Decode the message type
	msgType := messageType(buf[0])
	buf = buf[1:]

	// Switch on the msgType
	switch msgType {
	case compoundMsg:
		m.handleCompound(buf, from, timestamp)
	case compressMsg:
		m.handleCompressed(buf, from, timestamp)

	case pingMsg:
		m.handlePing(buf, from)
	case indirectPingMsg:
		m.handleIndirectPing(buf, from)
	case ackRespMsg:
		m.handleAck(buf, from, timestamp)

	case suspectMsg, aliveMsg, deadMsg, userMsg:
		// Non-blocking send: never stall the caller on a full queue.
		select {
		case m.handoff <- msgHandoff{msgType, buf, from}:
		default:
			m.logger.Printf("[WARN] memberlist: UDP handler queue full, dropping message (%d) %s", msgType, LogAddress(from))
		}

	default:
		m.logger.Printf("[ERR] memberlist: UDP msg type (%d) not supported %s", msgType, LogAddress(from))
	}
}
// udpHandler drains the handoff channel and processes the queued suspect,
// alive, dead and user messages. It is decoupled from the listener to avoid
// blocking the listener, which may cause ping/ack messages to be delayed.
// The loop exits when shutdownCh becomes readable.
func (m *Memberlist) udpHandler() {
	for {
		select {
		case <-m.shutdownCh:
			return

		case msg := <-m.handoff:
			switch msg.msgType {
			case suspectMsg:
				m.handleSuspect(msg.buf, msg.from)
			case aliveMsg:
				m.handleAlive(msg.buf, msg.from)
			case deadMsg:
				m.handleDead(msg.buf, msg.from)
			case userMsg:
				m.handleUser(msg.buf, msg.from)
			default:
				m.logger.Printf("[ERR] memberlist: UDP msg type (%d) not supported %s (handler)", msg.msgType, LogAddress(msg.from))
			}
		}
	}
}
// handleCompound splits a compound message into its parts and feeds each
// part back through handleCommand. A decode failure drops the whole message;
// truncated parts are only reported with a warning.
func (m *Memberlist) handleCompound(buf []byte, from net.Addr, timestamp time.Time) {
	truncated, parts, err := decodeCompoundMessage(buf)
	if err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to decode compound request: %s %s", err, LogAddress(from))
		return
	}

	if truncated > 0 {
		m.logger.Printf("[WARN] memberlist: Compound request had %d truncated messages %s", truncated, LogAddress(from))
	}

	for i := range parts {
		m.handleCommand(parts[i], from, timestamp)
	}
}
// handlePing answers a direct UDP ping with an ack that carries the same
// sequence number. A ping that names a node other than us is logged and
// ignored. When a Ping delegate is configured, its AckPayload is attached
// to the ack.
func (m *Memberlist) handlePing(buf []byte, from net.Addr) {
	var req ping
	if err := decode(buf, &req); err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to decode ping request: %s %s", err, LogAddress(from))
		return
	}

	// An empty node name means the sender did not ask for verification.
	if req.Node != "" && req.Node != m.config.Name {
		m.logger.Printf("[WARN] memberlist: Got ping for unexpected node '%s' %s", req.Node, LogAddress(from))
		return
	}

	reply := ackResp{SeqNo: req.SeqNo}
	if m.config.Ping != nil {
		reply.Payload = m.config.Ping.AckPayload()
	}

	err := m.encodeAndSendMsg(from, ackRespMsg, &reply)
	if err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to send ack: %s %s", err, LogAddress(from))
	}
}
// handleIndirectPing pings a target on behalf of the requester. It sends a
// ping with a locally generated sequence number to the requested target and
// registers an ack handler that, when invoked, sends the requester an ack
// carrying the requester's original sequence number.
func (m *Memberlist) handleIndirectPing(buf []byte, from net.Addr) {
	var req indirectPingReq
	if err := decode(buf, &req); err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to decode indirect ping request: %s %s", err, LogAddress(from))
		return
	}

	// For proto versions < 2, there is no port provided. Mask old
	// behavior by using the configured port
	if m.ProtocolVersion() < 2 || req.Port == 0 {
		req.Port = uint16(m.config.BindPort)
	}

	// Build the outgoing ping under our own sequence number.
	seq := m.nextSeqNo()
	probe := ping{SeqNo: seq, Node: req.Node}
	target := &net.UDPAddr{IP: req.Target, Port: int(req.Port)}

	// The handler relays an ack to the original requester; the arguments it
	// is called with are not used.
	relayAck := func(payload []byte, timestamp time.Time) {
		reply := ackResp{SeqNo: req.SeqNo}
		if err := m.encodeAndSendMsg(from, ackRespMsg, &reply); err != nil {
			m.logger.Printf("[ERR] memberlist: Failed to forward ack: %s %s", err, LogAddress(from))
		}
	}
	m.setAckHandler(seq, relayAck, m.config.ProbeTimeout)

	if err := m.encodeAndSendMsg(target, pingMsg, &probe); err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to send ping: %s %s", err, LogAddress(from))
	}
}
// handleAck decodes an ack response and passes it, together with the time
// the packet was received, to invokeAckHandler. Undecodable acks are logged
// and dropped.
func (m *Memberlist) handleAck(buf []byte, from net.Addr, timestamp time.Time) {
	var resp ackResp
	err := decode(buf, &resp)
	if err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to decode ack response: %s %s", err, LogAddress(from))
		return
	}

	m.invokeAckHandler(resp, timestamp)
}
// handleSuspect decodes a suspect message and passes it to suspectNode.
// Undecodable messages are logged and dropped.
func (m *Memberlist) handleSuspect(buf []byte, from net.Addr) {
	var msg suspect
	err := decode(buf, &msg)
	if err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to decode suspect message: %s %s", err, LogAddress(from))
		return
	}

	m.suspectNode(&msg)
}
// handleAlive decodes an alive message, fills in the port when it is
// missing, and passes the message to aliveNode. Undecodable messages are
// logged and dropped.
func (m *Memberlist) handleAlive(buf []byte, from net.Addr) {
	var msg alive
	err := decode(buf, &msg)
	if err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to decode alive message: %s %s", err, LogAddress(from))
		return
	}

	// Protocol versions below 2 carry no port, so fall back to the
	// configured bind port to mask the old behavior.
	if m.ProtocolVersion() < 2 || msg.Port == 0 {
		msg.Port = uint16(m.config.BindPort)
	}

	m.aliveNode(&msg, nil, false)
}
// handleDead decodes a dead message and passes it to deadNode.
// Undecodable messages are logged and dropped.
func (m *Memberlist) handleDead(buf []byte, from net.Addr) {
	var msg dead
	err := decode(buf, &msg)
	if err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to decode dead message: %s %s", err, LogAddress(from))
		return
	}

	m.deadNode(&msg)
}
// handleUser passes incoming user data to the configured delegate via
// NotifyMsg. Without a delegate the message is silently discarded.
func (m *Memberlist) handleUser(buf []byte, from net.Addr) {
	if delegate := m.config.Delegate; delegate != nil {
		delegate.NotifyMsg(buf)
	}
}
// handleCompressed unpacks a compressed message and feeds the decompressed
// payload back through handleCommand. A payload that cannot be decompressed
// is logged and dropped.
func (m *Memberlist) handleCompressed(buf []byte, from net.Addr, timestamp time.Time) {
	inner, err := decompressPayload(buf)
	if err != nil {
		m.logger.Printf("[ERR] memberlist: Failed to decompress payload: %v %s", err, LogAddress(from))
		return
	}

	// The decompressed payload is itself a complete message.
	m.handleCommand(inner, from, timestamp)
}
// encodeAndSendMsg encodes msg under the given message type and sends the
// result to the destination with sendMsg. It returns the first error from
// either step.
func (m *Memberlist) encodeAndSendMsg(to net.Addr, msgType messageType, msg interface{}) error {
	out, err := encode(msgType, msg)
	if err != nil {
		return err
	}
	return m.sendMsg(to, out.Bytes())
}
// sendMsg sends a UDP message to another host. It opportunistically piggy
// backs pending broadcasts: whatever fits in the remaining send budget is
// bundled with msg into one compound message. With nothing to piggy back,
// msg is sent on its own.
func (m *Memberlist) sendMsg(to net.Addr, msg []byte) error {
	// Budget left once this message, the compound header and (when
	// enabled) the encryption overhead are accounted for.
	budget := udpSendBuf - len(msg) - compoundHeaderOverhead
	if m.config.EncryptionEnabled() {
		budget -= encryptOverhead(m.encryptionVersion())
	}

	piggyback := m.getBroadcasts(compoundOverhead, budget)
	if len(piggyback) == 0 {
		// Fast path: send the message unmodified.
		return m.rawSendMsgUDP(to, msg)
	}

	// Our message goes first, followed by the piggy-backed broadcasts.
	batch := make([][]byte, 0, 1+len(piggyback))
	batch = append(batch, msg)
	batch = append(batch, piggyback...)

	compound := makeCompoundMessage(batch)
	return m.rawSendMsgUDP(to, compound.Bytes())
}
// rawSendMsgUDP writes msg to the given address over UDP without adding any
// piggy-backed broadcasts. When compression is enabled the compressed form
// is used only if it is smaller, and a compression failure just falls back
// to the uncompressed bytes. When encryption is enabled the payload is
// encrypted with the primary key, and an encryption failure aborts the send.
func (m *Memberlist) rawSendMsgUDP(to net.Addr, msg []byte) error {
	if m.config.EnableCompression {
		packed, err := compressPayload(msg)
		switch {
		case err != nil:
			m.logger.Printf("[WARN] memberlist: Failed to compress payload: %v", err)
		case packed.Len() < len(msg):
			// Only use compression if it reduced the size
			msg = packed.Bytes()
		}
	}

	if m.config.EncryptionEnabled() {
		var sealed bytes.Buffer
		key := m.config.Keyring.GetPrimaryKey()
		if err := encryptPayload(m.encryptionVersion(), key, msg, nil, &sealed); err != nil {
			m.logger.Printf("[ERR] memberlist: Encryption of message failed: %v", err)
			return err
		}
		msg = sealed.Bytes()
	}

	metrics.IncrCounter([]string{"memberlist", "udp", "sent"}, float32(len(msg)))
	_, err := m.udpListener.WriteTo(msg, to)
	return err
}
// rawSendMsgTCP writes sendBuf to an established connection. When
// compression is enabled the compressed form replaces the buffer (a
// compression failure is logged and the original bytes are kept). When
// encryption is enabled the buffer is wrapped by encryptLocalState, and an
// encryption failure aborts the send. A short write is reported as an error.
func (m *Memberlist) rawSendMsgTCP(conn net.Conn, sendBuf []byte) error {
	if m.config.EnableCompression {
		if packed, err := compressPayload(sendBuf); err == nil {
			sendBuf = packed.Bytes()
		} else {
			m.logger.Printf("[ERROR] memberlist: Failed to compress payload: %v", err)
		}
	}

	if m.config.EncryptionEnabled() {
		sealed, err := m.encryptLocalState(sendBuf)
		if err != nil {
			m.logger.Printf("[ERROR] memberlist: Failed to encrypt local state: %v", err)
			return err
		}
		sendBuf = sealed
	}

	// Write out the entire send buffer
	metrics.IncrCounter([]string{"memberlist", "tcp", "sent"}, float32(len(sendBuf)))
	written, err := conn.Write(sendBuf)
	if err != nil {
		return err
	}
	if written != len(sendBuf) {
		return fmt.Errorf("only %d of %d bytes written", written, len(sendBuf))
	}
	return nil
}
// sendTCPUserMsg dials the given address over TCP and sends sendBuf as a
// user message. The message consists of the userMsg type byte, a
// msgpack-encoded userMsgHeader holding the payload length, and the raw
// payload. The dial is bounded by the configured TCP timeout.
func (m *Memberlist) sendTCPUserMsg(to net.Addr, sendBuf []byte) error {
	dialer := net.Dialer{Timeout: m.config.TCPTimeout}
	conn, err := dialer.Dial("tcp", to.String())
	if err != nil {
		return err
	}
	defer conn.Close()

	// Assemble the full message in memory before sending it.
	out := bytes.NewBuffer(nil)
	if err = out.WriteByte(byte(userMsg)); err != nil {
		return err
	}

	// Header announcing the length of the user payload.
	handle := codec.MsgpackHandle{}
	hdr := userMsgHeader{UserMsgLen: len(sendBuf)}
	if err = codec.NewEncoder(out, &handle).Encode(&hdr); err != nil {
		return err
	}

	if _, err = out.Write(sendBuf); err != nil {
		return err
	}

	return m.rawSendMsgTCP(conn, out.Bytes())
}
// sendAndReceiveState initiates a push/pull exchange over TCP with a remote
// node: it connects, sends our local state, and then reads the remote state
// from the reply. It returns the remote node states and the remote user
// state, or an error if any step fails or the reply is not a pushPullMsg.
func (m *Memberlist) sendAndReceiveState(addr []byte, port uint16, join bool) ([]pushNodeState, []byte, error) {
	// Attempt to connect
	remote := net.TCPAddr{IP: addr, Port: int(port)}
	dialer := net.Dialer{Timeout: m.config.TCPTimeout}
	conn, err := dialer.Dial("tcp", remote.String())
	if err != nil {
		return nil, nil, err
	}
	defer conn.Close()

	m.logger.Printf("[DEBUG] memberlist: Initiating push/pull sync with: %s", conn.RemoteAddr())
	metrics.IncrCounter([]string{"memberlist", "tcp", "connect"}, 1)

	// Push first, then wait for the peer's state.
	if err = m.sendLocalState(conn, join); err != nil {
		return nil, nil, err
	}

	// Fresh deadline for the read half of the exchange.
	conn.SetDeadline(time.Now().Add(m.config.TCPTimeout))
	kind, stream, decoder, err := m.readTCP(conn)
	if err != nil {
		return nil, nil, err
	}

	// Quit if not push/pull
	if kind != pushPullMsg {
		return nil, nil, fmt.Errorf("received invalid msgType (%d), expected pushPullMsg (%d) %s", kind, pushPullMsg, LogConn(conn))
	}

	_, nodes, userState, err := m.readRemoteState(stream, decoder)
	return nodes, userState, err
}
// sendLocalState sends our local state over a TCP connection. The message
// consists of the pushPullMsg type byte, a msgpack-encoded pushPullHeader,
// one msgpack-encoded pushNodeState per known node, and finally the raw
// delegate user state (if any). The node list is snapshotted under the node
// read lock; the connection deadline is reset to the configured TCP timeout.
func (m *Memberlist) sendLocalState(conn net.Conn, join bool) error {
	// Setup a deadline
	conn.SetDeadline(time.Now().Add(m.config.TCPTimeout))

	// Snapshot the node list while holding the read lock.
	m.nodeLock.RLock()
	snapshot := make([]pushNodeState, len(m.nodes))
	for i, n := range m.nodes {
		snapshot[i] = pushNodeState{
			Name:        n.Name,
			Addr:        n.Addr,
			Port:        n.Port,
			Meta:        n.Meta,
			Incarnation: n.Incarnation,
			State:       n.State,
			Vsn: []uint8{
				n.PMin, n.PMax, n.PCur,
				n.DMin, n.DMax, n.DCur,
			},
		}
	}
	m.nodeLock.RUnlock()

	// Get the delegate state
	var userData []byte
	if delegate := m.config.Delegate; delegate != nil {
		userData = delegate.LocalState(join)
	}

	// Assemble the full message in memory before sending it.
	out := bytes.NewBuffer(nil)
	handle := codec.MsgpackHandle{}
	enc := codec.NewEncoder(out, &handle)
	hdr := pushPullHeader{
		Nodes:        len(snapshot),
		UserStateLen: len(userData),
		Join:         join,
	}

	// Type byte, then header, then every node state.
	if err := out.WriteByte(byte(pushPullMsg)); err != nil {
		return err
	}
	if err := enc.Encode(&hdr); err != nil {
		return err
	}
	for i := range snapshot {
		if err := enc.Encode(&snapshot[i]); err != nil {
			return err
		}
	}

	// The user state trails the encoded node states as raw bytes.
	if userData != nil {
		if _, err := out.Write(userData); err != nil {
			return err
		}
	}

	return m.rawSendMsgTCP(conn, out.Bytes())
}
// encryptLocalState wraps sendBuf for an encrypted TCP stream. The result
// is a 5-byte header (the encryptMsg type byte followed by the big-endian
// uint32 length of the encrypted payload) followed by the ciphertext. The
// header is also passed to encryptPayload as its additional-data argument.
func (m *Memberlist) encryptLocalState(sendBuf []byte) ([]byte, error) {
	vsn := m.encryptionVersion()
	sealedLen := encryptedLength(vsn, len(sendBuf))

	// Header: type byte + payload size.
	header := make([]byte, 5)
	header[0] = byte(encryptMsg)
	binary.BigEndian.PutUint32(header[1:], uint32(sealedLen))

	var out bytes.Buffer
	out.Write(header)

	// Append the encrypted cipher text after the header.
	key := m.config.Keyring.GetPrimaryKey()
	if err := encryptPayload(vsn, key, sendBuf, header, &out); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}
// decryptRemoteState reads and decrypts an encrypted TCP payload. The caller
// has already consumed the encryptMsg type byte, so this reads the 4-byte
// big-endian payload length, rejects payloads larger than maxPushStateBytes,
// reads the ciphertext and decrypts it with the keyring's keys. The 5-byte
// header (type byte + length) is passed to decryptPayload as its
// additional-data argument.
func (m *Memberlist) decryptRemoteState(bufConn io.Reader) ([]byte, error) {
	// Rebuild the header, starting with the type byte already consumed.
	frame := bytes.NewBuffer(nil)
	frame.WriteByte(byte(encryptMsg))
	if _, err := io.CopyN(frame, bufConn, 4); err != nil {
		return nil, err
	}

	// Ensure we aren't asked to download too much. This is to guard against
	// an attack vector where a huge amount of state is sent
	payloadLen := binary.BigEndian.Uint32(frame.Bytes()[1:5])
	if payloadLen > maxPushStateBytes {
		return nil, fmt.Errorf("Remote node state is larger than limit (%d)", payloadLen)
	}

	// Read in the rest of the payload
	if _, err := io.CopyN(frame, bufConn, int64(payloadLen)); err != nil {
		return nil, err
	}

	raw := frame.Bytes()
	header, cipherText := raw[:5], raw[5:]
	return decryptPayload(m.config.Keyring.GetKeys(), cipherText, header)
}
// readTCP is used to read the start of a TCP stream. It decrypts and
// decompresses the stream if necessary.
//
// It returns the message type of the innermost message, a reader positioned
// just after that type byte, and a msgpack decoder bound to that reader.
//
// Errors are returned when the leading type byte cannot be read, when the
// stream's encryption does not match the local configuration (encrypted but
// encryption disabled, or plaintext but encryption enabled), when
// decryption or decompression fails, or when a decrypted or decompressed
// payload is empty. The empty-payload checks turn what used to be an
// index-out-of-range panic on payload[0] into an ordinary error.
func (m *Memberlist) readTCP(conn net.Conn) (messageType, io.Reader, *codec.Decoder, error) {
	// Created a buffered reader
	var bufConn io.Reader = bufio.NewReader(conn)

	// Read the message type. io.ReadFull keeps reading until the byte has
	// arrived or an error occurs, whereas a bare Read may return (0, nil).
	var typeByte [1]byte
	if _, err := io.ReadFull(bufConn, typeByte[:]); err != nil {
		return 0, nil, nil, err
	}
	msgType := messageType(typeByte[0])

	// Check if the message is encrypted
	if msgType == encryptMsg {
		if !m.config.EncryptionEnabled() {
			return 0, nil, nil,
				fmt.Errorf("Remote state is encrypted and encryption is not configured")
		}

		plain, err := m.decryptRemoteState(bufConn)
		if err != nil {
			return 0, nil, nil, err
		}
		if len(plain) == 0 {
			return 0, nil, nil,
				fmt.Errorf("Decrypted remote state is empty")
		}

		// Reset message type and bufConn
		msgType = messageType(plain[0])
		bufConn = bytes.NewReader(plain[1:])
	} else if m.config.EncryptionEnabled() {
		return 0, nil, nil,
			fmt.Errorf("Encryption is configured but remote state is not encrypted")
	}

	// Get the msgPack decoders
	hd := codec.MsgpackHandle{}
	dec := codec.NewDecoder(bufConn, &hd)

	// Check if we have a compressed message
	if msgType == compressMsg {
		var c compress
		if err := dec.Decode(&c); err != nil {
			return 0, nil, nil, err
		}
		decomp, err := decompressBuffer(&c)
		if err != nil {
			return 0, nil, nil, err
		}
		if len(decomp) == 0 {
			return 0, nil, nil,
				fmt.Errorf("Decompressed remote state is empty")
		}

		// Reset the message type
		msgType = messageType(decomp[0])

		// Create a new bufConn
		bufConn = bytes.NewReader(decomp[1:])

		// Create a new decoder
		dec = codec.NewDecoder(bufConn, &hd)
	}

	return msgType, bufConn, dec, nil
}
// readRemoteState is used to read the remote state from a connection. It
// decodes the push/pull header, then header.Nodes node states, then reads
// header.UserStateLen bytes of raw user state from bufConn.
//
// It returns the header's Join flag, the decoded node states and the user
// state (nil when the header announces none).
//
// The header comes from the remote side, so the node count is treated as
// untrusted: a negative count is rejected with an error (it would otherwise
// panic in make), and the up-front allocation is capped so that a bogus
// huge count costs nothing until states are actually decoded.
func (m *Memberlist) readRemoteState(bufConn io.Reader, dec *codec.Decoder) (bool, []pushNodeState, []byte, error) {
	// Read the push/pull header
	var header pushPullHeader
	if err := dec.Decode(&header); err != nil {
		return false, nil, nil, err
	}
	if header.Nodes < 0 {
		return false, nil, nil, fmt.Errorf(
			"Invalid node count in push/pull header (%d)", header.Nodes)
	}

	// Allocate space for the transfer, growing past the cap on demand.
	const maxPrealloc = 1024
	capHint := header.Nodes
	if capHint > maxPrealloc {
		capHint = maxPrealloc
	}
	remoteNodes := make([]pushNodeState, 0, capHint)

	// Try to decode all the states
	for i := 0; i < header.Nodes; i++ {
		var state pushNodeState
		if err := dec.Decode(&state); err != nil {
			return false, nil, nil, err
		}
		remoteNodes = append(remoteNodes, state)
	}

	// Read the remote user state into a buffer. io.ReadFull either fills
	// the buffer completely or returns an error, so no separate short-read
	// check is needed.
	var userBuf []byte
	if header.UserStateLen > 0 {
		userBuf = make([]byte, header.UserStateLen)
		if _, err := io.ReadFull(bufConn, userBuf); err != nil {
			return false, nil, nil, err
		}
	}

	// For proto versions < 2, there is no port provided. Mask old
	// behavior by using the configured port
	for idx := range remoteNodes {
		if m.ProtocolVersion() < 2 || remoteNodes[idx].Port == 0 {
			remoteNodes[idx].Port = uint16(m.config.BindPort)
		}
	}

	return header.Join, remoteNodes, userBuf, nil
}
// mergeRemoteState is used to merge the remote state with our local state.
//
// The remote nodes are first checked with verifyProtocol. For a join with a
// Merge delegate configured, the delegate is given the remote nodes and may
// cancel the merge by returning an error. The membership state is then
// merged, and the user state (if any) is passed to the Delegate.
//
// The version fields of the nodes given to the Merge delegate are filled in
// only when the remote state carries all six version entries; otherwise they
// are left at zero instead of indexing past the end of remote-supplied data.
// NOTE(review): verifyProtocol is not visible here; confirm whether it
// already rejects short Vsn slices.
func (m *Memberlist) mergeRemoteState(join bool, remoteNodes []pushNodeState, userBuf []byte) error {
	if err := m.verifyProtocol(remoteNodes); err != nil {
		return err
	}

	// Invoke the merge delegate if any
	if join && m.config.Merge != nil {
		nodes := make([]*Node, len(remoteNodes))
		for idx, n := range remoteNodes {
			node := &Node{
				Name: n.Name,
				Addr: n.Addr,
				Port: n.Port,
				Meta: n.Meta,
			}
			// Order: pmin, pmax, pcur, dmin, dmax, dcur.
			if len(n.Vsn) >= 6 {
				node.PMin = n.Vsn[0]
				node.PMax = n.Vsn[1]
				node.PCur = n.Vsn[2]
				node.DMin = n.Vsn[3]
				node.DMax = n.Vsn[4]
				node.DCur = n.Vsn[5]
			}
			nodes[idx] = node
		}
		if err := m.config.Merge.NotifyMerge(nodes); err != nil {
			return err
		}
	}

	// Merge the membership state
	m.mergeState(remoteNodes)

	// Invoke the delegate for user state
	if userBuf != nil && m.config.Delegate != nil {
		m.config.Delegate.MergeRemoteState(userBuf, join)
	}
	return nil
}
// readUserMsg is used to decode a userMsg from a TCP stream. It decodes the
// msgpack userMsgHeader from dec, reads header.UserMsgLen bytes of payload
// from bufConn and, when a delegate is configured, passes the payload to
// its NotifyMsg. A header announcing a non-positive length is ignored.
//
// It returns an error when the header cannot be decoded or the payload
// cannot be read in full.
func (m *Memberlist) readUserMsg(bufConn io.Reader, dec *codec.Decoder) error {
	// Read the user message header
	var header userMsgHeader
	if err := dec.Decode(&header); err != nil {
		return err
	}
	if header.UserMsgLen <= 0 {
		return nil
	}

	// Read the user message into a buffer. io.ReadFull either fills the
	// buffer completely or returns an error, so no separate short-read
	// check is needed.
	userBuf := make([]byte, header.UserMsgLen)
	if _, err := io.ReadFull(bufConn, userBuf); err != nil {
		return err
	}

	if d := m.config.Delegate; d != nil {
		d.NotifyMsg(userBuf)
	}
	return nil
}
// sendPingAndWaitForAck makes a TCP connection to the given address, sends
// a ping, and waits for an ack. All of this is done as a series of blocking
// operations bounded by the given deadline. The bool return parameter is
// true if we were able to round trip a ping to the other node.
//
// A failure to connect is reported as (false, nil); every later failure is
// returned as an error.
func (m *Memberlist) sendPingAndWaitForAck(destAddr net.Addr, ping ping, deadline time.Time) (bool, error) {
	dialer := net.Dialer{Deadline: deadline}
	conn, dialErr := dialer.Dial("tcp", destAddr.String())
	if dialErr != nil {
		// If the node is actually dead we expect this to fail, so we
		// shouldn't spam the logs with it. After this point, errors
		// with the connection are real, unexpected errors and should
		// get propagated up.
		return false, nil
	}
	defer conn.Close()
	conn.SetDeadline(deadline)

	// Send the ping.
	out, err := encode(pingMsg, &ping)
	if err != nil {
		return false, err
	}
	if err := m.rawSendMsgTCP(conn, out.Bytes()); err != nil {
		return false, err
	}

	// Wait for the reply, which must be an ack.
	kind, _, decoder, err := m.readTCP(conn)
	if err != nil {
		return false, err
	}
	if kind != ackRespMsg {
		return false, fmt.Errorf("Unexpected msgType (%d) from TCP ping %s", kind, LogConn(conn))
	}

	var reply ackResp
	if err := decoder.Decode(&reply); err != nil {
		return false, err
	}

	// The ack must answer this ping.
	if reply.SeqNo != ping.SeqNo {
		return false, fmt.Errorf("Sequence number from ack (%d) doesn't match ping (%d) from TCP ping %s", reply.SeqNo, ping.SeqNo, LogConn(conn))
	}
	return true, nil
}