// blob: ec4c7086208f8188b9a65331b7754eac5ac78979 [file] [log] [blame]
// Copyright 2018 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#pragma once
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <ostream>
#include <random>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <fbl/ref_counted.h>
#include <fbl/ref_ptr.h>

#include "src/connectivity/overnet/lib/environment/timer.h"
#include "src/connectivity/overnet/lib/labels/node_id.h"
#include "src/connectivity/overnet/lib/links/packet_link.h"
#include "src/connectivity/overnet/lib/vocabulary/slice.h"
namespace overnet {

// PacketNub runs a link-establishment handshake over a packet-oriented
// transport whose endpoints are identified by `Address`. It publishes a
// PacketLink (NubLink) for every peer that gets far enough through the
// handshake.
//
// Subclasses provide:
//   - the transport (SendTo),
//   - the router (GetRouter),
//   - the sink for new links (Publish).
// The embedder feeds received packets to Process() and starts outgoing
// connections with Initiate().
//
// Handshake as implemented below. A is the node with the smaller NodeId and
// B the one with the larger NodeId:
//   1. If B initiates, B announces itself with CallMeMaybe (resent with
//      backoff).
//   2. A sends Hello. It is resent with backoff when A called Initiate(), and
//      sent once when answering a CallMeMaybe.
//   3. B answers with a single HelloAck and enters AckingHello.
//   4. A enters SemiConnected, publishes its link and sends Connected
//      (resent with backoff).
//   5. B receives Connected, enters Connected and publishes its link.
//   6. Empty Connected packets that B receives are reflected back, so A also
//      reaches Connected.
// Packets that make no sense in the receiver's current state may be answered
// with GoAway, which makes the other side forget its link state.
//
// Template parameters:
//   Address     - transport address type; must be streamable (used in traces).
//   kMSS        - maximum segment size passed to each published PacketLink.
//   HashAddress - hash functor for Address (keys of links_).
//   EqAddress   - equality functor for Address (keys of links_).
template <class Address, uint32_t kMSS, class HashAddress = std::hash<Address>,
          class EqAddress = std::equal_to<Address>>
class PacketNub {
  // Op code carried in the first byte of every packet handled by Process().
  enum class PacketOp : uint8_t {
    Connected = 0,
    CallMeMaybe = 1,
    Hello = 2,
    HelloAck = 3,
    GoAway = 4,
  };

  friend std::ostream& operator<<(std::ostream& out, PacketOp op) {
    switch (op) {
      case PacketOp::Connected: {
        return out << "Connected";
      }
      case PacketOp::CallMeMaybe: {
        return out << "CallMeMaybe";
      }
      case PacketOp::Hello: {
        return out << "Hello";
      }
      case PacketOp::HelloAck: {
        return out << "HelloAck";
      }
      case PacketOp::GoAway: {
        return out << "GoAway";
      }
    }
    return out << "UnknownPacketOp(" << static_cast<int>(op) << ")";
  }

  // Handshake state of one peer. Declaration order is significant: Initiate()
  // compares states with `<` to pick the most advanced existing link.
  enum class LinkState : uint8_t {
    Initial,
    Announcing,
    SayingHello,
    AckingHello,
    SemiConnected,
    Connected,
  };

  friend std::ostream& operator<<(std::ostream& out, LinkState state) {
    switch (state) {
      case LinkState::Initial: {
        return out << "Initial";
      }
      case LinkState::Announcing: {
        return out << "Announcing";
      }
      case LinkState::SayingHello: {
        return out << "SayingHello";
      }
      case LinkState::AckingHello: {
        return out << "AckingHello";
      }
      case LinkState::SemiConnected: {
        return out << "SemiConnected";
      }
      case LinkState::Connected: {
        return out << "Connected";
      }
    }
    return out << "UnknownState(" << static_cast<int>(state) << ")";
  }

  struct LinkData;
  using LinkDataPtr = fbl::RefPtr<LinkData>;

  // The PacketLink handed to Publish() for a peer.
  //
  // Outgoing packets are routed back through the owning nub. When the link is
  // tombstoned or destroyed, it detaches itself from the nub's address table
  // and clears the LinkData's back-pointer to it.
  class NubLink final : public PacketLink {
   public:
    NubLink(PacketNub* nub, LinkDataPtr link, NodeId peer, uint64_t label)
        : PacketLink(nub->GetRouter(), peer, kMSS, label),
          nub_(nub),
          link_(std::move(link)) {}

    ~NubLink() { Delist(); }

    void Emit(Slice packet) override {
      // Once delisted this link no longer sends anything.
      if (nub_ == nullptr) {
        return;
      }
      nub_->SendToLink(link_, std::move(packet));
    }

    void Tombstone() override {
      Delist();
      PacketLink::Tombstone();
    }

   private:
    // Unregisters this link. It removes the LinkData from the nub's address
    // table and clears LinkData::link. Safe to call more than once.
    void Delist() {
      if (nub_ == nullptr) {
        return;
      }
      nub_->RemoveLink(link_);
      nub_ = nullptr;
      link_->link = nullptr;
    }

    PacketNub* nub_;  // Null once delisted.
    const LinkDataPtr link_;
  };

  // Per-peer handshake state, shared by every address the peer is known by.
  struct LinkData : public fbl::RefCounted<LinkData> {
    ~LinkData() { assert(link == nullptr); }
    explicit LinkData(std::vector<Address> addresses)
        : addresses(std::move(addresses)) {}

    // Addresses the peer is known by. Packets go to all of them until a
    // preferred_address is known. RemoveLink() walks this list to drop the
    // peer's entries from links_, so it must cover every links_ key that maps
    // to this LinkData.
    std::vector<Address> addresses;
    // Source address of the first packet processed from this peer. Once set,
    // all sends go only to this address.
    Optional<Address> preferred_address;
    LinkState state = LinkState::Initial;
    Optional<NodeId> node_id;
    // Currently published link, or null if none is published.
    NubLink* link = nullptr;
    // Number of consecutive re-entries of `state` (-1 before the first
    // SetState()).
    int ticks = -1;
    // Pending resend timer, if any.
    Optional<Timeout> next_timeout;

    // Moves to state `st` and cancels any pending resend timer.
    //
    // Entering a different state resets the tick counter to 0. Re-entering
    // the same state (a resend) increments it.
    //
    // Returns the tick count, or Nothing once the counter reaches 5 (the
    // caller should then give up on the link). SemiConnected is exempt: its
    // counter is pinned at 5 so it never expires.
    Optional<int> SetState(LinkState st) {
      OVERNET_TRACE(DEBUG) << "SetState for " << AddrVecStr(addresses) << " to "
                           << st;
      next_timeout.Reset();
      if (state != st) {
        state = st;
        ticks = 0;
      } else {
        ticks++;
        if (ticks >= 5) {
          // Don't time out in semi-connected
          if (st == LinkState::SemiConnected) {
            ticks = 5;
          } else {
            return Nothing;
          }
        }
      }
      return ticks;
    }

    // Records the peer's node id, or checks it against a previously recorded
    // one, and then calls SetState(st).
    //
    // Returns Nothing if the peer's node id changed. When `node` is Nothing,
    // the node id must already be known.
    Optional<int> SetStateAndMaybeNode(LinkState st, Optional<NodeId> node) {
      if (node) {
        if (node_id) {
          if (*node_id != *node) {
            OVERNET_TRACE(DEBUG) << "Node id changed to " << *node;
            return Nothing;
          }
        } else {
          node_id = *node;
        }
      } else {
        assert(node_id);
      }
      return SetState(st);
    }
  };

 public:
  static constexpr inline auto kModule = Module::NUB;
  // Fixed on-wire sizes of CallMeMaybe and Hello packets. Each is an op byte,
  // then the sender's 64-bit node id, then zero padding.
  static constexpr size_t kCallMeMaybeSize = 256;
  static constexpr size_t kHelloSize = 256;
  // Lower bound of the randomized resend interval; see BackoffForTicks().
  static constexpr uint64_t kAnnounceResendMillis = 1000;

  PacketNub(Timer* timer, NodeId node) : timer_(timer), local_node_(node) {}
  virtual ~PacketNub() {}

  // Sends `slice` to `dest` on the underlying transport.
  virtual void SendTo(Address dest, Slice slice) = 0;
  // Router that published links are created against.
  virtual Router* GetRouter() = 0;
  // Receives each newly published link.
  virtual void Publish(LinkPtr<> link) = 0;

  // Handles one packet received from `src` at time `received`.
  //
  // The packet either advances the handshake state machine for that peer or,
  // once the peer is Connected, is handed to the peer's published link.
  virtual void Process(TimeStamp received, Address src, Slice slice) {
    ScopedModule<PacketNub> in_nub(this);
    // Extract node id and op from slice... this code must be identical with
    // PacketLink.
    const uint8_t* const begin = slice.begin();
    const uint8_t* p = begin;
    const uint8_t* const end = slice.end();
    // end >= begin, so this cast is lossless; it avoids signed/unsigned
    // comparisons against the size_t packet-size constants below.
    const size_t packet_size = static_cast<size_t>(end - begin);
    OVERNET_TRACE(DEBUG) << "INCOMING: from:" << src << " " << received << " "
                         << slice;
    if (p == end) {
      OVERNET_TRACE(INFO) << "Short packet received (no op code)";
      return;
    }
    const PacketOp op = static_cast<PacketOp>(*p++);
    // Packs (op, state) into one integer so that both can be switched on at
    // once.
    auto op_state = [](PacketOp op, LinkState state) constexpr {
      return (static_cast<uint16_t>(op) << 8) | static_cast<uint16_t>(state);
    };
    // Transitions that change state without consuming the packet `continue`,
    // so the same packet is re-dispatched against the new state. Every other
    // path returns.
    while (true) {
      LinkDataPtr link = LinkForIncomingPacket(src);
      if (!link->preferred_address) {
        link->preferred_address = src;
      }
      uint64_t node_id = 0;
      OVERNET_TRACE(DEBUG) << " op=" << op << " state=" << link->state;
      switch (op_state(op, link->state)) {
        case op_state(PacketOp::GoAway, LinkState::AckingHello):
        case op_state(PacketOp::GoAway, LinkState::SemiConnected):
        case op_state(PacketOp::GoAway, LinkState::Connected):
          OVERNET_TRACE(TRACE) << "Forget " << src << " due to goaway";
          if (link->link != nullptr) {
            // NubLink::Tombstone() delists, which also calls RemoveLink().
            link->link->Tombstone();
          } else {
            // Drop every address of this peer, not just `src`.
            RemoveLink(link);
          }
          // Fully handled. A `break` would only leave the switch: the loop
          // would then create a fresh entry for `src` just to ignore this
          // GoAway.
          return;
        case op_state(PacketOp::Connected, LinkState::AckingHello):
          if (!link->SetState(LinkState::Connected)) {
            OVERNET_TRACE(DEBUG)
                << "Forget " << src << " couldn't set connected";
            RemoveLink(link);
            return;
          }
          link->next_timeout.Reset();
          BecomePublished(link);
          continue;
        case op_state(PacketOp::Connected, LinkState::SemiConnected):
          if (!link->SetState(LinkState::Connected)) {
            OVERNET_TRACE(DEBUG)
                << "Forget " << src << " couldn't set connected";
            RemoveLink(link);
            return;
          }
          link->next_timeout.Reset();
          continue;
        case op_state(PacketOp::Connected, LinkState::Connected):
          if (p == end && link->node_id && *link->node_id < local_node_) {
            // Empty connected packets get reflected to fully advance state
            // machine in the case of connecting when applications are idle.
            SendToLink(link, std::move(slice));
          } else if (link->link == nullptr) {
            // A Connected entry with no published link is stale. Forget it
            // and tell the peer to start over instead of dereferencing null.
            OVERNET_TRACE(INFO)
                << "Received packet op " << op << " from " << src
                << " on link state " << link->state
                << " without a published link: sending goaway.";
            RemoveLink(link);
            SendToLink(link, Slice::FromContainer(
                                 {static_cast<uint8_t>(PacketOp::GoAway)}));
          } else {
            link->link->Process(received, std::move(slice));
          }
          return;
        case op_state(PacketOp::CallMeMaybe, LinkState::SayingHello):
          if (link->next_timeout.has_value()) {
            OVERNET_TRACE(INFO)
                << "Received packet op " << op << " from " << src
                << " on link state " << link->state << ": ignoring.";
            return;
          }
          [[fallthrough]];
        case op_state(PacketOp::CallMeMaybe, LinkState::Initial):
          if (packet_size != kCallMeMaybeSize) {
            OVERNET_TRACE(INFO) << "Received a mis-sized CallMeMaybe packet";
          } else if (!ParseLE64(&p, end, &node_id)) {
            OVERNET_TRACE(INFO)
                << "Failed to parse node id from CallMeMaybe packet";
          } else if (!AllZeros(p, end)) {
            OVERNET_TRACE(INFO) << "CallMeMaybe padding should be all zeros";
          } else if (NodeId(node_id) == local_node_) {
            OVERNET_TRACE(INFO) << "CallMeMaybe received for local node";
          } else if (NodeId(node_id) < local_node_) {
            OVERNET_TRACE(INFO)
                << "CallMeMaybe received from a smaller numbered node id";
          } else {
            StartHello(link, NodeId(node_id), ResendBehavior::kNever);
          }
          return;
        case op_state(PacketOp::Hello, LinkState::Initial):
        case op_state(PacketOp::Hello, LinkState::AckingHello):
        case op_state(PacketOp::Hello, LinkState::Announcing):
          if (packet_size != kHelloSize) {
            OVERNET_TRACE(INFO) << "Received a mis-sized Hello packet";
          } else if (!ParseLE64(&p, end, &node_id)) {
            OVERNET_TRACE(INFO) << "Failed to parse node id from Hello packet";
          } else if (!AllZeros(p, end)) {
            OVERNET_TRACE(INFO) << "Hello padding should be all zeros";
          } else if (local_node_ < NodeId(node_id)) {
            OVERNET_TRACE(INFO)
                << "Hello received from a larger numbered node id";
          } else if (NodeId(node_id) == local_node_) {
            OVERNET_TRACE(INFO) << "Hello received for local node";
          } else {
            StartHelloAck(link, NodeId(node_id));
          }
          return;
        case op_state(PacketOp::HelloAck, LinkState::SayingHello):
          if (end != p) {
            OVERNET_TRACE(INFO) << "Received a mis-sized HelloAck packet";
          } else {
            // Must BecomePublished *AFTER* the state change, else the link
            // could be dropped immediately during publishing forcing an
            // undetected state change that gets wiped out.
            StartSemiConnected(link);
            BecomePublished(link);
          }
          return;
        case op_state(PacketOp::Connected, LinkState::Initial):
        case op_state(PacketOp::HelloAck, LinkState::Initial):
        case op_state(PacketOp::Connected, LinkState::Announcing):
        case op_state(PacketOp::HelloAck, LinkState::Announcing):
          OVERNET_TRACE(INFO)
              << "Received packet op " << op << " from " << src
              << " on link state " << link->state << ": sending goaway.";
          SendToLink(link, Slice::FromContainer(
                               {static_cast<uint8_t>(PacketOp::GoAway)}));
          return;
        default:
          OVERNET_TRACE(INFO)
              << "Received packet op " << op << " from " << src
              << " on link state " << link->state << ": ignoring.";
          return;
      }
    }
  }

  // Formats `addrs` as "{a, b, c}" for tracing.
  static std::string AddrVecStr(const std::vector<Address>& addrs) {
    std::ostringstream out;
    bool first = true;
    out << "{";
    for (const auto& addr : addrs) {
      if (!first) {
        out << ", ";
      }
      first = false;
      out << addr;
    }
    out << "}";
    return out.str();
  }

  // Starts connecting to peer `node`, reachable at any of `peer_addresses`.
  //
  // An existing link reachable via any of those addresses is reused: the one
  // in the most advanced state is chosen. Only links in Initial, or in
  // SayingHello with no resend pending when the local node has the smaller
  // id, are (re)started. The smaller-id side sends Hello; the larger-id side
  // announces itself with CallMeMaybe.
  void Initiate(std::vector<Address> peer_addresses, NodeId node) {
    ScopedModule<PacketNub> in_nub(this);
    LinkDataPtr link;
    for (const auto& peer : peer_addresses) {
      if (auto it = links_.find(peer); it != links_.end()) {
        if (!link || link->state < it->second->state) {
          link = it->second;
        }
      }
    }
    if (!link) {
      link = CreateLink(peer_addresses);
    } else {
      // Re-key the existing link under the new address set.
      for (const auto& old_addr : link->addresses) {
        links_.erase(old_addr);
      }
      for (const auto& addr : peer_addresses) {
        links_.emplace(addr, link);
      }
      // Keep link->addresses in step with links_. RemoveLink() and the resend
      // timer walk this list; a stale list would leave entries in links_ that
      // can never be removed.
      link->addresses = peer_addresses;
    }
    OVERNET_TRACE(INFO) << "Initiate peer=" << AddrVecStr(peer_addresses)
                        << " node=" << node << " state=" << link->state;
    assert(node != local_node_);
    if (link->state == LinkState::Initial ||
        (link->state == LinkState::SayingHello &&
         !link->next_timeout.has_value() && local_node_ < node)) {
      if (node < local_node_) {
        // To avoid duplicating links, we insist that lower indexed nodes
        // initiate the connection.
        StartAnnouncing(link, node);
      } else {
        StartHello(link, node, ResendBehavior::kResendable);
      }
    }
  }

  // True if `peer` maps to a link that currently has a published NubLink.
  bool HasConnectionTo(Address peer) const {
    auto it = links_.find(peer);
    return it != links_.end() && it->second->link != nullptr;
  }

 private:
  // True if every byte in [begin, end) is zero.
  static constexpr bool AllZeros(const uint8_t* begin, const uint8_t* end) {
    for (const uint8_t* p = begin; p != end; ++p) {
      if (*p != 0)
        return false;
    }
    return true;
  }

  // Returns a deadline at a random delay from now.
  //
  // The delay is drawn uniformly from [initial_millis, upper]. `upper` is
  // initial_millis grown by roughly 10% (integer arithmetic) ticks + 1 times.
  // initial_millis must be large enough for that growth to be non-zero.
  TimeStamp BackoffForTicks(uint64_t initial_millis, int ticks) {
    assert(initial_millis);
    uint64_t millis = initial_millis;
    for (int i = 0; i <= ticks; i++) {
      millis = 11 * millis / 10;
    }
    assert(millis != initial_millis);
    return timer_->Now() +
           TimeDelta::FromMilliseconds(std::uniform_int_distribution<uint64_t>(
               initial_millis, millis)(rng_));
  }

  // Returns the link registered for `address`, creating a fresh Initial one if
  // there is none.
  LinkDataPtr LinkForIncomingPacket(Address address) {
    if (auto it = links_.find(address); it != links_.end()) {
      return it->second;
    }
    return CreateLink({address});
  }

  // Creates a link for `addrs` and registers it under each address. Addresses
  // that already map to another link are left unchanged, because emplace does
  // not overwrite.
  LinkDataPtr CreateLink(std::vector<Address> addrs) {
    auto link = fbl::MakeRefCounted<LinkData>(std::move(addrs));
    for (const auto& addr : link->addresses) {
      links_.emplace(addr, link);
    }
    return link;
  }

  enum class ResendBehavior {
    kNever,
    kResendable,
  };

  friend std::ostream& operator<<(std::ostream& out, ResendBehavior resend) {
    switch (resend) {
      case ResendBehavior::kNever:
        return out << "Never";
      case ResendBehavior::kResendable:
        return out << "Resendable";
    }
    abort();
  }

  // Sends `slice` to the peer's preferred address if known, otherwise to all
  // of its addresses.
  void SendToLink(LinkDataPtr link, Slice slice) {
    if (link->preferred_address) {
      OVERNET_TRACE(DEBUG) << "SendTo addr=" << *link->preferred_address
                           << " slice=" << slice;
      SendTo(*link->preferred_address, std::move(slice));
    } else {
      for (const auto& addr : link->addresses) {
        OVERNET_TRACE(DEBUG) << "SendTo addr=" << addr << " slice=" << slice;
        SendTo(addr, slice);
      }
    }
  }

  // Removes every links_ entry, keyed by one of link->addresses, that still
  // maps to `link`.
  void RemoveLink(LinkDataPtr link) {
    for (const auto& addr : link->addresses) {
      if (auto it = links_.find(addr);
          it != links_.end() && it->second == link) {
        links_.erase(it);
      }
    }
  }

  // Enters handshake state `state` and sends one packet of `packet_size`
  // bytes, filled in by `packet_writer`. `node`, if present, is recorded as
  // (or checked against) the peer's node id.
  //
  // For kResendable, a backed-off timer re-enters the same state, and so
  // resends, for as long as the link is still registered in links_. The link
  // is dropped when SetStateAndMaybeNode() fails: too many resends, or a
  // node-id mismatch.
  template <class F>
  void StartSimpleState(LinkDataPtr link, Optional<NodeId> node,
                        LinkState state, ResendBehavior resend,
                        size_t packet_size, F packet_writer) {
    OVERNET_TRACE(DEBUG) << "StartState: addrs=" << AddrVecStr(link->addresses)
                         << " node=" << node << " linkstate=" << state
                         << " resend=" << resend
                         << " packet_size=" << packet_size;
    const Optional<int> ticks_or_nothing =
        link->SetStateAndMaybeNode(state, node);
    if (!ticks_or_nothing) {
      OVERNET_TRACE(TRACE) << "Forget " << AddrVecStr(link->addresses)
                           << " due to age";
      RemoveLink(link);
      return;
    }
    const int ticks = *ticks_or_nothing;
    SendToLink(link, Slice::WithInitializer(packet_size, packet_writer));
    switch (resend) {
      case ResendBehavior::kResendable:
        link->next_timeout.Reset(
            timer_, BackoffForTicks(kAnnounceResendMillis, ticks),
            StatusCallback(
                ALLOCATED_CALLBACK,
                [this, link, node, state, resend, packet_size,
                 packet_writer](const Status& status) {
                  ScopedModule<PacketNub> in_nub(this);
                  if (status.is_error()) {
                    return;
                  }
                  // Only resend while this LinkData is still the live entry
                  // for at least one of its addresses.
                  bool is_current = false;
                  for (const auto& addr : link->addresses) {
                    if (auto it = links_.find(addr);
                        it != links_.end() && it->second == link) {
                      is_current = true;
                      break;
                    }
                  }
                  if (!is_current) {
                    return;
                  }
                  StartSimpleState(link, node, state, resend, packet_size,
                                   packet_writer);
                }));
        break;
      case ResendBehavior::kNever:
        link->next_timeout.Reset();
        break;
    }
  }

  // Announcing: sends a zero-padded CallMeMaybe carrying the local node id,
  // resent with backoff.
  void StartAnnouncing(LinkDataPtr link, NodeId node) {
    StartSimpleState(link, node, LinkState::Announcing,
                     ResendBehavior::kResendable, kCallMeMaybeSize,
                     [local_node = local_node_](uint8_t* p) {
                       memset(p, 0, kCallMeMaybeSize);
                       *p++ = static_cast<uint8_t>(PacketOp::CallMeMaybe);
                       p = local_node.Write(p);
                     });
  }

  // SayingHello: sends a zero-padded Hello carrying the local node id.
  void StartHello(LinkDataPtr link, NodeId node, ResendBehavior resend) {
    StartSimpleState(link, node, LinkState::SayingHello, resend, kHelloSize,
                     [local_node = local_node_](uint8_t* p) {
                       memset(p, 0, kHelloSize);
                       *p++ = static_cast<uint8_t>(PacketOp::Hello);
                       p = local_node.Write(p);
                     });
  }

  // AckingHello: sends a single one-byte HelloAck (never resent).
  void StartHelloAck(LinkDataPtr link, NodeId node) {
    StartSimpleState(
        link, node, LinkState::AckingHello, ResendBehavior::kNever, 1,
        [](uint8_t* p) { *p = static_cast<uint8_t>(PacketOp::HelloAck); });
  }

  // SemiConnected: sends a one-byte Connected packet, resent with backoff.
  // The peer's node id must already be known.
  void StartSemiConnected(LinkDataPtr link) {
    StartSimpleState(
        link, Nothing, LinkState::SemiConnected, ResendBehavior::kResendable, 1,
        [](uint8_t* p) { *p = static_cast<uint8_t>(PacketOp::Connected); });
  }

  // Creates the NubLink for a peer whose node id is known and hands it to
  // Publish(). link->link is a non-owning back-pointer that NubLink clears
  // when it delists.
  void BecomePublished(LinkDataPtr link) {
    assert(link->link == nullptr);
    assert(link->node_id);
    link->link = new NubLink(this, link, *link->node_id, next_label_++);
    Publish(LinkPtr<>(link->link));
  }

  Timer* const timer_;
  const NodeId local_node_;
  // Known peer addresses; several addresses may share one LinkData.
  std::unordered_map<Address, LinkDataPtr, HashAddress, EqAddress> links_;
  // Source of resend jitter (default-seeded).
  std::mt19937_64 rng_;
  // Label given to the next published NubLink.
  uint64_t next_label_ = 1;
};

}  // namespace overnet