[overnet] Compress acks better
Test: Ran overnet unittests
Change-Id: I51c7b4e43c8f7d181ab5f7b4cb864bcab612f78c
diff --git a/garnet/lib/overnet/BUILD.gn b/garnet/lib/overnet/BUILD.gn
index 7296971..8cebda7 100644
--- a/garnet/lib/overnet/BUILD.gn
+++ b/garnet/lib/overnet/BUILD.gn
@@ -88,6 +88,7 @@
"datagram_stream:receive_mode_fuzzer",
"links:packet_nub_fuzzer",
"packet_protocol:packet_protocol_fuzzer",
+ "protocol:ack_frame_fuzzer",
"protocol:overnet_encoding_fuzzer",
"protocol:overnet_decoding_fuzzer",
"protocol:routable_message_fuzzer",
diff --git a/garnet/lib/overnet/packet_protocol/packet_protocol.cc b/garnet/lib/overnet/packet_protocol/packet_protocol.cc
index 6a8ad68..2eb74b0 100644
--- a/garnet/lib/overnet/packet_protocol/packet_protocol.cc
+++ b/garnet/lib/overnet/packet_protocol/packet_protocol.cc
@@ -628,14 +628,13 @@
TimeDelta::FromMicroseconds(ack.ack_delay_us()));
// Fail any nacked packets.
- // Nacks are received in descending order of sequence number. We iterate the
- // callbacks here in reverse order then so that the OLDEST nacked message is
- // the most likely to be sent first. This has the important consequence that
- // if the packet was a fragment of a large message that was rejected due to
- // buffering, the earlier pieces (that are more likely to fit) are
- // retransmitted first.
- for (auto it = ack.nack_seqs().rbegin(); it != ack.nack_seqs().rend(); ++it) {
- ack_processor.Nack(*it, Status::Unavailable());
+ // Iteration is from oldest packet to newest, such that the OLDEST nacked
+ // message is the most likely to be sent first. This has the important
+ // consequence that if the packet was a fragment of a large message that was
+ // rejected due to buffering, the earlier pieces (that are more likely to fit)
+ // are retransmitted first.
+ for (auto nack_seq : ack.nack_seqs()) {
+ ack_processor.Nack(nack_seq, Status::Unavailable());
}
// Clear out outstanding packet references, propagating acks.
diff --git a/garnet/lib/overnet/protocol/BUILD.gn b/garnet/lib/overnet/protocol/BUILD.gn
index a2a08dc..42535e4 100644
--- a/garnet/lib/overnet/protocol/BUILD.gn
+++ b/garnet/lib/overnet/protocol/BUILD.gn
@@ -50,6 +50,16 @@
]
}
+fuzz_target("ack_frame_fuzzer") {
+ testonly = true
+ sources = [
+ "ack_frame_fuzzer.cc",
+ ]
+ deps = [
+ ":ack_frame",
+ ]
+}
+
# coding
source_set("coding") {
sources = [
diff --git a/garnet/lib/overnet/protocol/ack_frame.cc b/garnet/lib/overnet/protocol/ack_frame.cc
index 45b2e6c..3b23fce 100644
--- a/garnet/lib/overnet/protocol/ack_frame.cc
+++ b/garnet/lib/overnet/protocol/ack_frame.cc
@@ -9,44 +9,25 @@
namespace overnet {
AckFrame::Writer::Writer(const AckFrame* ack_frame)
- : ack_frame_(ack_frame),
- ack_to_seq_length_(varint::WireSizeFor(ack_frame_->ack_to_seq_)),
- delay_and_flags_length_(
- varint::WireSizeFor(ack_frame_->DelayAndFlags())) {
- wire_length_ = ack_to_seq_length_ + delay_and_flags_length_;
- nack_length_.reserve(ack_frame_->nack_seqs_.size());
- uint64_t base = ack_frame_->ack_to_seq_;
- for (auto n : ack_frame_->nack_seqs_) {
- auto enc = base - n;
- auto l = varint::WireSizeFor(enc);
- wire_length_ += l;
- nack_length_.push_back(l);
- base = n;
- }
- assert(ack_frame->WrittenLength() == wire_length_);
-}
+ : ack_frame_(ack_frame), wire_length_(ack_frame_->WrittenLength()) {}  // Keeps a pointer to the frame and caches its encoded size once; Write() asserts that it emits exactly wire_length_ bytes.
uint64_t AckFrame::WrittenLength() const {
uint64_t wire_length =
varint::WireSizeFor(ack_to_seq_) + varint::WireSizeFor(DelayAndFlags());
- uint64_t base = ack_to_seq_;
- for (auto n : nack_seqs_) {
- wire_length += varint::WireSizeFor(base - n);
- base = n;
+ for (const auto block : blocks_) {  // After the two header varints, every block adds two more: its ack count and its nack count.
+ wire_length += varint::WireSizeFor(block.acks);  // Size of the block's ack-count varint.
+ wire_length += varint::WireSizeFor(block.nacks);  // Size of the block's nack-count varint.
}
return wire_length;
}
uint8_t* AckFrame::Writer::Write(uint8_t* out) const {
uint8_t* p = out;
- p = varint::Write(ack_frame_->ack_to_seq_, ack_to_seq_length_, p);
- p = varint::Write(ack_frame_->DelayAndFlags(), delay_and_flags_length_, p);
- uint64_t base = ack_frame_->ack_to_seq_;
- for (size_t i = 0; i < nack_length_.size(); i++) {
- auto n = ack_frame_->nack_seqs_[i];
- auto enc = base - n;
- p = varint::Write(enc, nack_length_[i], p);
- base = n;
+ p = varint::Write(ack_frame_->ack_to_seq_, p);
+ p = varint::Write(ack_frame_->DelayAndFlags(), p);
+ for (const auto block : ack_frame_->blocks_) {
+ p = varint::Write(block.acks, p);
+ p = varint::Write(block.nacks, p);
}
assert(p == out + wire_length_);
return p;
@@ -82,18 +63,31 @@
frame.partial_ = is_partial;
uint64_t base = ack_to_seq;
while (bytes != end) {
- uint64_t offset;
- if (!varint::Read(&bytes, end, &offset)) {
+ uint64_t acks, nacks;
+ if (!varint::Read(&bytes, end, &acks)) {
return StatusOr<AckFrame>(StatusCode::INVALID_ARGUMENT,
- "Failed to read nack offset from ack frame");
+ "Failed to read ack count from ack frame");
}
- if (offset >= base) {
+ if (!varint::Read(&bytes, end, &nacks)) {
return StatusOr<AckFrame>(StatusCode::INVALID_ARGUMENT,
- "Failed to read nack");
+ "Failed to read nack count from ack frame");
}
- const uint64_t seq = base - offset;
- frame.AddNack(seq);
- base = seq;
+ if (acks >= base) {
+ return StatusOr<AckFrame>(StatusCode::INVALID_ARGUMENT,
+ "Failed to read nack (too many acks)");
+ }
+ if (nacks > base - acks) {
+ return StatusOr<AckFrame>(StatusCode::INVALID_ARGUMENT,
+ "Failed to read nack (too many nacks)");
+ }
+ if (nacks == 0) {
+ return StatusOr<AckFrame>(StatusCode::INVALID_ARGUMENT,
+ "Nack count cannot be zero");
+ }
+ base -= acks;
+ base -= nacks;
+ frame.blocks_.push_back(Block{acks, nacks});
+ frame.last_nack_ = base + 1;
}
return StatusOr<AckFrame>(std::move(frame));
}
diff --git a/garnet/lib/overnet/protocol/ack_frame.h b/garnet/lib/overnet/protocol/ack_frame.h
index 4f5e5a8..0a7738a 100644
--- a/garnet/lib/overnet/protocol/ack_frame.h
+++ b/garnet/lib/overnet/protocol/ack_frame.h
@@ -8,12 +8,22 @@
#include <tuple>
#include <vector>
#include "garnet/lib/overnet/environment/trace.h"
+#include "garnet/lib/overnet/protocol/varint.h"
#include "garnet/lib/overnet/vocabulary/slice.h"
#include "garnet/lib/overnet/vocabulary/status.h"
namespace overnet {
class AckFrame {
+ struct Block {  // One run of the ack/nack history, counted downwards from ack_to_seq_: `acks` acknowledged packets followed by `nacks` nacked packets.
+ uint64_t acks;  // Number of acknowledged packets at the top of the run; zero when the nack run starts immediately (e.g. ack_to_seq_ itself is nacked).
+ uint64_t nacks;  // Number of nacked packets directly below those acks; Parse rejects a block whose nack count is zero.
+ bool operator==(const Block& other) const {  // Field-wise equality; lets std::vector<Block> be compared inside AckFrame's operator==.
+ return acks == other.acks && nacks == other.nacks;
+ }
+ };
+ using Brit = std::vector<Block>::const_reverse_iterator;  // Reverse iterator over blocks_; NackSeqs uses it to visit the lowest-numbered block first.
+
public:
class Writer {
public:
@@ -24,23 +34,21 @@
private:
const AckFrame* const ack_frame_;
- const uint8_t ack_to_seq_length_;
- const uint8_t delay_and_flags_length_;
- std::vector<uint8_t> nack_length_;
- size_t wire_length_;
+ const size_t wire_length_;
};
AckFrame(uint64_t ack_to_seq, uint64_t ack_delay_us)
- : ack_to_seq_(ack_to_seq), ack_delay_us_(ack_delay_us) {
+ : partial_(false), ack_to_seq_(ack_to_seq), ack_delay_us_(ack_delay_us) {  // Builds a frame with no nacks; partial_ starts false and, in the code shown here, is only raised by Parse and AdjustForMSS.
assert(ack_to_seq_ > 0);
}
AckFrame(uint64_t ack_to_seq, uint64_t ack_delay_us,
std::initializer_list<uint64_t> nack_seqs)
- : ack_to_seq_(ack_to_seq), ack_delay_us_(ack_delay_us) {
+ : partial_(false), ack_to_seq_(ack_to_seq), ack_delay_us_(ack_delay_us) {  // Same as the two-argument constructor, then records every listed nack.
assert(ack_to_seq_ > 0);
- for (auto n : nack_seqs)
+ for (auto n : nack_seqs) {  // AddNack asserts that each sequence number is below the previous one, so the list must be in descending order.
AddNack(n);
+ }
}
AckFrame(const AckFrame&) = delete;
@@ -50,57 +58,142 @@
: partial_(other.partial_),
ack_to_seq_(other.ack_to_seq_),
ack_delay_us_(other.ack_delay_us_),
- nack_seqs_(std::move(other.nack_seqs_)) {}
+ blocks_(std::move(other.blocks_)),
+ last_nack_(other.last_nack_) {}
AckFrame& operator=(AckFrame&& other) {
partial_ = other.partial_;
ack_to_seq_ = other.ack_to_seq_;
ack_delay_us_ = other.ack_delay_us_;
- nack_seqs_ = std::move(other.nack_seqs_);
+ blocks_ = std::move(other.blocks_);  // Takes over the block list from `other`.
+ last_nack_ = other.last_nack_;  // last_nack_ (the lowest nacked sequence number recorded so far) must travel with blocks_ so iteration and further AddNack calls stay consistent.
return *this;
}
void AddNack(uint64_t seq) {
assert(ack_to_seq_ > 0);
assert(seq <= ack_to_seq_);
- if (!nack_seqs_.empty()) {
- assert(seq < nack_seqs_.back());
+ assert(seq > 0);  // Sequence number zero can never be nacked.
+ if (!blocks_.empty()) {
+ assert(seq < last_nack_);  // Nacks must be added in strictly descending sequence order.
+ if (seq == last_nack_ - 1) {  // Directly below the previous nack: extend the current run instead of opening a new block.
+ blocks_.back().nacks++;
+ } else {
+ blocks_.emplace_back(Block{last_nack_ - seq - 1, 1});  // Gap: the packets strictly between seq and the previous nack become this new block's acks.
+ }
+ } else {
+ blocks_.emplace_back(Block{ack_to_seq_ - seq, 1});  // First nack: packets seq+1..ack_to_seq_ are the leading acks (none when seq == ack_to_seq_).
}
- nack_seqs_.push_back(seq);
+ last_nack_ = seq;  // Remember the lowest nack seen; the next call measures its gap from here.
}
static StatusOr<AckFrame> Parse(Slice slice);
friend bool operator==(const AckFrame& a, const AckFrame& b) {
- return std::tie(a.ack_to_seq_, a.ack_delay_us_, a.nack_seqs_) ==
- std::tie(b.ack_to_seq_, b.ack_delay_us_, b.nack_seqs_);
+ if (std::tie(a.ack_to_seq_, a.ack_delay_us_, a.blocks_, a.partial_) !=  // Compares every always-valid field, now including the partial flag.
+ std::tie(b.ack_to_seq_, b.ack_delay_us_, b.blocks_, b.partial_)) {
+ return false;
+ }
+ if (!a.blocks_.empty()) {  // last_nack_ only carries meaning once at least one block exists, so it is compared only then (b has the same blocks at this point).
+ return a.last_nack_ == b.last_nack_;
+ }
+ return true;
}
uint64_t ack_to_seq() const { return ack_to_seq_; }
uint64_t ack_delay_us() const { return ack_delay_us_; }
bool partial() const { return partial_; }
- const std::vector<uint64_t>& nack_seqs() const { return nack_seqs_; }
+
+ class NackSeqs {  // Lightweight view that yields a frame's nacked sequence numbers in ascending order (oldest first) without materializing them.
+ public:
+ NackSeqs(const AckFrame* ack_frame) : ack_frame_(ack_frame) {}  // Stores the pointer only; nothing is copied from the frame.
+
+ class Iterator {  // Minimal iterator: provides just !=, ++ and *, which is what range-for needs.
+ public:
+ Iterator(Brit brit, uint64_t base) : brit_(brit), base_(base) {}  // brit: block being walked (lowest block first); base: sequence number of that block's lowest nack.
+
+ bool operator!=(const Iterator& other) const {  // Positions differ if any of block, base sequence or in-block offset differs.
+ return brit_ != other.brit_ || base_ != other.base_ ||
+ nack_ != other.nack_;
+ }
+
+ void operator++() {  // Moves to the next nacked sequence number.
+ nack_++;
+ if (nack_ == brit_->nacks) {  // This block's nacks are used up.
+ base_ += brit_->nacks + brit_->acks;  // Skip the block's nacks and the acks above them to reach the next block's lowest nack.
+ ++brit_;  // Reverse iterator: advancing moves towards blocks_.front(), i.e. towards ack_to_seq_.
+ nack_ = 0;
+ }
+ }
+
+ uint64_t operator*() const { return base_ + nack_; }  // The current nacked sequence number.
+
+ private:
+ Brit brit_;  // Block currently being walked.
+ uint64_t base_;  // Sequence number of the lowest nack in that block.
+ uint64_t nack_ = 0;  // Offset of the current nack within the block.
+ };
+
+ std::vector<uint64_t> AsVector() const {  // Copies all nacked sequence numbers, ascending, into a vector (the unit tests use this).
+ std::vector<uint64_t> out;
+ for (auto n : *this) {
+ out.push_back(n);
+ }
+ return out;
+ }
+
+ Iterator begin() const {
+ if (ack_frame_->blocks_.empty()) {  // No nacks recorded: hand back end() so the range is empty.
+ return end();
+ }
+ return Iterator(ack_frame_->blocks_.rbegin(), ack_frame_->last_nack_);  // Start at the last block, whose lowest nack is last_nack_.
+ }
+ Iterator end() const {
+ return Iterator(ack_frame_->blocks_.rend(), ack_frame_->ack_to_seq_ + 1);  // Sentinel: after the topmost block is consumed, base_ has advanced to ack_to_seq_ + 1.
+ }
+
+ private:
+ const AckFrame* ack_frame_;  // Non-owning pointer to the frame being viewed.
+ };
+ NackSeqs nack_seqs() const { return NackSeqs(this); }  // Returns a view over the nacked sequence numbers; it holds a raw pointer to this frame, so it must not outlive the frame.
// Move ack_to_seq back in time such that the total ack frame will fit
// within mss. DelayFn is a function uint64_t -> TimeDelta that returns the
// ack delay (in microseconds) for a given sequence number.
template <class DelayFn>
void AdjustForMSS(uint32_t mss, DelayFn delay_fn) {
- while (!nack_seqs_.empty() && WrittenLength() > mss) {
+ while (!blocks_.empty() && WrittenLength() > mss) {  // Each pass shortens or removes the first block, so the encoding shrinks until it fits or no blocks remain.
partial_ = true;
- if (ack_to_seq_ != nack_seqs_[0]) {
+ auto& block0_acks = blocks_[0].acks;  // References into the first block (the one nearest ack_to_seq_); re-bound on every pass.
+ auto& block0_nacks = blocks_[0].nacks;
+ if (block0_acks > 0) {  // First choice: give up leading acks, which only lowers ack_to_seq_ and leaves every nack position unchanged.
+ auto new_acks = varint::SmallerRecordedNumber(block0_acks);  // Largest ack count whose varint is shorter (0 if none is).
OVERNET_TRACE(DEBUG) << "Trim too long ack (" << WrittenLength()
<< " > " << mss << " by moving ack " << ack_to_seq_
- << " to first nack " << nack_seqs_[0];
- ack_to_seq_ = nack_seqs_[0];
+ << " to shorter first ack block length "
+ << (ack_to_seq_ - block0_acks + new_acks);
+ ack_to_seq_ -= (block0_acks - new_acks);  // Lower ack_to_seq_ by exactly the number of acks dropped.
+ block0_acks = new_acks;
} else {
- OVERNET_TRACE(DEBUG)
- << "Trim too long ack (" << WrittenLength() << " > " << mss
- << " by trimming first nack " << nack_seqs_[0];
- nack_seqs_.erase(nack_seqs_.begin());
- ack_to_seq_--;
+ assert(block0_nacks > 0);  // With no leading acks the block must still hold at least one nack.
+ auto new_nacks = varint::SmallerRecordedNumber(block0_nacks);  // Largest nack count whose varint is shorter (0 if none is).
+ if (new_nacks == 0) {  // The count cannot be encoded any shorter: drop the whole block.
+ OVERNET_TRACE(DEBUG)
+ << "Trim too long ack (" << WrittenLength() << " > " << mss
+ << " by eliminating first block and moving first ack to "
+ << (ack_to_seq_ - block0_nacks + new_nacks);
+ ack_to_seq_ -= (block0_nacks - new_nacks);  // Must happen before the erase below, which invalidates block0_nacks.
+ blocks_.erase(blocks_.begin());  // The next block (if any) now counts down from the lowered ack_to_seq_.
+ } else {
+ OVERNET_TRACE(DEBUG)
+ << "Trim too long ack (" << WrittenLength() << " > " << mss
+ << " by moving ack " << ack_to_seq_
+ << " to shorter first nack block length "
+ << (ack_to_seq_ - block0_nacks + new_nacks);
+ ack_to_seq_ -= (block0_nacks - new_nacks);  // Drop the highest-numbered nacks of the run; the lower ones stay in place.
+ block0_nacks = new_nacks;
+ }  // NOTE(review): after this change delay_fn is never called, so ack_delay_us_ still describes the original ack_to_seq_ once it has been lowered; confirm that dropping the refresh is intended.
}
- ack_delay_us_ = delay_fn(ack_to_seq_);
}
}
@@ -111,16 +204,17 @@
// Flag indicating that this ack is only a partial acknowledgement, and
// there's more to come.
bool partial_ = false;
+
// All messages with sequence number prior to ack_to_seq_ are implicitly
// acknowledged.
uint64_t ack_to_seq_;
// How long between receiving ack_delay_seq_ and generating this data
// structure.
uint64_t ack_delay_us_;
- // All messages contained in nack_seqs_ need to be resent.
- // NOTE: it's assumed that nack_seqs_ is in-order descending and all value are
- // less than or equal to ack_to_seq_.
- std::vector<uint64_t> nack_seqs_;
+ // Working back from ack_to_seq_ we record blocks. Each block holds some
+ // number of acks followed by some number of nacks.
+ std::vector<Block> blocks_;
+ // Lowest nacked sequence number recorded so far; only meaningful while
+ // blocks_ is non-empty. Zero-initialized so that the move constructor and
+ // move assignment, which copy this field unconditionally, never read an
+ // indeterminate value from a frame that has no nacks.
+ uint64_t last_nack_ = 0;
};
std::ostream& operator<<(std::ostream& out, const AckFrame& ack_frame);
diff --git a/garnet/lib/overnet/protocol/ack_frame_fuzzer.cc b/garnet/lib/overnet/protocol/ack_frame_fuzzer.cc
new file mode 100644
index 0000000..a05849b
--- /dev/null
+++ b/garnet/lib/overnet/protocol/ack_frame_fuzzer.cc
@@ -0,0 +1,27 @@
+// Copyright 2019 The Fuchsia Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "garnet/lib/overnet/protocol/ack_frame.h"
+
+using namespace overnet;
+
+extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {  // Fuzz entry point: parses arbitrary bytes as an AckFrame and, when that succeeds, checks that the frame survives a write/re-parse round trip.
+ auto input = Slice::FromCopiedBuffer(data, size);
+ // Parse data as an ack frame.
+ auto status = AckFrame::Parse(input);
+ if (status.is_ok()) {
+ if (status->ack_to_seq() < 100000) {  // Walk the nack list only for modest ack_to_seq values; presumably this bounds the time spent per input (TODO confirm).
+ for (auto n : status->nack_seqs()) {
+ [](auto) {}(n);  // No-op sink: just forces every sequence number to be produced by the iterator.
+ }
+ }
+ // If parsed ok: rewrite.
+ Slice written = Slice::FromWriters(AckFrame::Writer(status.get()));
+ // Re-parsing the written bytes must succeed and yield an equal frame.
+ auto status2 = AckFrame::Parse(written);
+ assert(status2.is_ok());
+ assert(*status == *status2);
+ }
+ return 0; // Non-zero return values are reserved for future use.
+}
diff --git a/garnet/lib/overnet/protocol/ack_frame_test.cc b/garnet/lib/overnet/protocol/ack_frame_test.cc
index 76e4a7f..d670e3b 100644
--- a/garnet/lib/overnet/protocol/ack_frame_test.cc
+++ b/garnet/lib/overnet/protocol/ack_frame_test.cc
@@ -21,7 +21,7 @@
auto v = Encode(h);
EXPECT_EQ(expect, v);
auto p = AckFrame::Parse(Slice::FromCopiedBuffer(v.data(), v.size()));
- EXPECT_TRUE(p.is_ok());
+ EXPECT_TRUE(p.is_ok()) << p.AsStatus();
EXPECT_EQ(h, *p.get());
}
@@ -33,7 +33,8 @@
TEST(AckFrame, OneNack) {
AckFrame h(5, 10);
h.AddNack(2);
- RoundTrip(h, {5, 20, 3});
+ RoundTrip(h, {5, 20, 3, 1});  // Wire form: ack_to_seq 5, the delay/flags byte, then one block of 3 acks (seqs 5..3) followed by 1 nack (seq 2).
+ EXPECT_EQ(h.nack_seqs().AsVector(), std::vector<uint64_t>({2}));  // The nack view reports exactly the one nacked sequence number.
}
TEST(AckFrame, ThreeNacks) {
@@ -41,7 +42,52 @@
h.AddNack(4);
h.AddNack(3);
h.AddNack(2);
- RoundTrip(h, {5, 84, 1, 1, 1});
+ RoundTrip(h, {5, 84, 1, 3});
+ EXPECT_EQ(h.nack_seqs().AsVector(), std::vector<uint64_t>({2, 3, 4}));
+}
+
+TEST(AckFrame, TwoBlocks) {  // Two separated nack runs (15..13 and 5..1) must be stored and encoded as two (acks, nacks) blocks.
+ AckFrame h(20, 42);
+ h.AddNack(15);
+ h.AddNack(14);
+ h.AddNack(13);
+ h.AddNack(5);
+ h.AddNack(4);
+ h.AddNack(3);
+ h.AddNack(2);
+ h.AddNack(1);
+ RoundTrip(h, {20, 84, 5, 3, 7, 5});  // ack_to_seq 20, the delay/flags byte, block (5 acks: 20..16, 3 nacks: 15..13), block (7 acks: 12..6, 5 nacks: 5..1).
+ EXPECT_EQ(h.nack_seqs().AsVector(),
+ std::vector<uint64_t>({1, 2, 3, 4, 5, 13, 14, 15}));  // Iteration is ascending even though the nacks were added in descending order.
+}
+
+TEST(AckFrame, FuzzedExamples) {  // Fixed byte sequences (judging by the test name, taken from fuzzer runs; TODO confirm) replayed through parse, iterate, write and re-parse.
+ auto test = [](std::initializer_list<uint8_t> bytes) {  // Runs one input; a parse failure is only logged, not treated as a test failure.
+ auto input = Slice::FromContainer(bytes);
+ std::cerr << "Test: " << input << "\n";
+ if (auto p = AckFrame::Parse(input); p.is_ok()) {
+ {
+ auto ns = p->nack_seqs();
+ const auto begin = ns.begin();
+ const auto end = ns.end();
+ for (auto it = begin; it != end; ++it) {  // Explicit iterator loop: exercises begin(), end(), != and ++ directly.
+ [](auto) {}(*it);  // No-op sink: the value is only computed, not checked.
+ }
+ }
+ std::cerr << "Parsed: " << *p << "\n";
+ auto written = Slice::FromWriters(AckFrame::Writer(p.get()));
+ auto p2 = AckFrame::Parse(written);
+ EXPECT_TRUE(p2.is_ok());  // Whatever was parsed must be writable and parseable again...
+ EXPECT_EQ(*p, *p2);  // ...and must compare equal to the first parse.
+ } else {
+ std::cerr << "Parse error: " << p.AsStatus() << "\n";
+ }
+ };
+ test({0x0a, 0x0a, 0x00, 0x00});
+ test({0xc1, 0xe0, 0x00, 0x2d});
+ test({0x80, 0xcd, 0xcd, 0xcd, 0xcd, 0x2b, 0x00, 0x2f, 0xcd, 0xcd, 0xf9, 0xe4,
+ 0x00, 0x51});
+ test({0x65, 0x01, 0x01, 0x02});
}
} // namespace ack_frame_test
diff --git a/garnet/lib/overnet/protocol/varint.cc b/garnet/lib/overnet/protocol/varint.cc
index b95af61..8709499 100644
--- a/garnet/lib/overnet/protocol/varint.cc
+++ b/garnet/lib/overnet/protocol/varint.cc
@@ -30,6 +30,28 @@
return 10;
}
+uint64_t SmallerRecordedNumber(uint64_t x) {  // Returns the largest n < x whose varint encoding is shorter than x's, or 0 when no shorter encoding exists.
+ if (x < (1ull << 7))  // Smallest size class: nothing shorter is available, so the answer is 0.
+ return 0;
+ if (x < (1ull << 14))  // Every later step is 7 bits wide and returns the largest value of the size class below (assumes 7 payload bits per varint byte; confirm against WireSizeFor).
+ return (1ull << 7) - 1;
+ if (x < (1ull << 21))
+ return (1ull << 14) - 1;
+ if (x < (1ull << 28))
+ return (1ull << 21) - 1;
+ if (x < (1ull << 35))
+ return (1ull << 28) - 1;
+ if (x < (1ull << 42))
+ return (1ull << 35) - 1;
+ if (x < (1ull << 49))
+ return (1ull << 42) - 1;
+ if (x < (1ull << 56))
+ return (1ull << 49) - 1;
+ if (x < (1ull << 63))
+ return (1ull << 56) - 1;
+ return (1ull << 63) - 1;  // x >= 2^63 is the largest size class; step down to the top of the class below.
+}
+
uint64_t MaximumLengthWithPrefix(uint64_t x) {
assert(x > 0);
uint64_t r = x - WireSizeFor(x);
diff --git a/garnet/lib/overnet/protocol/varint.h b/garnet/lib/overnet/protocol/varint.h
index df2cb80..3bdf52a 100644
--- a/garnet/lib/overnet/protocol/varint.h
+++ b/garnet/lib/overnet/protocol/varint.h
@@ -17,6 +17,11 @@
// convenience
uint8_t* Write(uint64_t x, uint8_t wire_length, uint8_t* dst);
+// Convenience overload of Write for callers that have not pre-computed the wire length of x.
+inline uint8_t* Write(uint64_t x, uint8_t* dst) {
+ return Write(x, WireSizeFor(x), dst);  // Computes the length with WireSizeFor, then forwards to the three-argument Write and returns its result.
+}
+
namespace impl {
bool ReadFallback(const uint8_t** bytes, const uint8_t* end, uint64_t* result);
}
@@ -35,6 +40,11 @@
// such that the total length does not exceed fit_to?
uint64_t MaximumLengthWithPrefix(uint64_t fit_to);
+// Returns the largest number n, n < x, such that the number of bytes to record
+// n is less than the number of bytes to record x (or zero if this is not
+// possible)
+uint64_t SmallerRecordedNumber(uint64_t x);
+
} // namespace varint
} // namespace overnet