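[overnet] Compress acks better

Encode ack frames as a list of (ack count, nack count) blocks working
back from ack_to_seq, instead of one varint delta per nacked sequence
number. Also add an ack_frame fuzzer and unit tests for the new format.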

Test: Ran overnet unittests

Change-Id: I51c7b4e43c8f7d181ab5f7b4cb864bcab612f78c
diff --git a/garnet/lib/overnet/BUILD.gn b/garnet/lib/overnet/BUILD.gn
index 7296971..8cebda7 100644
--- a/garnet/lib/overnet/BUILD.gn
+++ b/garnet/lib/overnet/BUILD.gn
@@ -88,6 +88,7 @@
     "datagram_stream:receive_mode_fuzzer",
     "links:packet_nub_fuzzer",
     "packet_protocol:packet_protocol_fuzzer",
+    "protocol:ack_frame_fuzzer",
     "protocol:overnet_encoding_fuzzer",
     "protocol:overnet_decoding_fuzzer",
     "protocol:routable_message_fuzzer",
diff --git a/garnet/lib/overnet/packet_protocol/packet_protocol.cc b/garnet/lib/overnet/packet_protocol/packet_protocol.cc
index 6a8ad68..2eb74b0 100644
--- a/garnet/lib/overnet/packet_protocol/packet_protocol.cc
+++ b/garnet/lib/overnet/packet_protocol/packet_protocol.cc
@@ -628,14 +628,13 @@
                              TimeDelta::FromMicroseconds(ack.ack_delay_us()));
 
   // Fail any nacked packets.
-  // Nacks are received in descending order of sequence number. We iterate the
-  // callbacks here in reverse order then so that the OLDEST nacked message is
-  // the most likely to be sent first. This has the important consequence that
-  // if the packet was a fragment of a large message that was rejected due to
-  // buffering, the earlier pieces (that are more likely to fit) are
-  // retransmitted first.
-  for (auto it = ack.nack_seqs().rbegin(); it != ack.nack_seqs().rend(); ++it) {
-    ack_processor.Nack(*it, Status::Unavailable());
+  // Iteration is from oldest packet to newest, such that the OLDEST nacked
+  // message is the most likely to be sent first. This has the important
+  // consequence that if the packet was a fragment of a large message that was
+  // rejected due to buffering, the earlier pieces (that are more likely to fit)
+  // are retransmitted first.
+  for (auto nack_seq : ack.nack_seqs()) {
+    ack_processor.Nack(nack_seq, Status::Unavailable());
   }
 
   // Clear out outstanding packet references, propagating acks.
diff --git a/garnet/lib/overnet/protocol/BUILD.gn b/garnet/lib/overnet/protocol/BUILD.gn
index a2a08dc..42535e4 100644
--- a/garnet/lib/overnet/protocol/BUILD.gn
+++ b/garnet/lib/overnet/protocol/BUILD.gn
@@ -50,6 +50,16 @@
   ]
 }
 
+fuzz_target("ack_frame_fuzzer") {  # Parse/re-encode fuzzer for AckFrame.
+  testonly = true
+  sources = [
+    "ack_frame_fuzzer.cc",
+  ]
+  deps = [
+    ":ack_frame",
+  ]
+}
+
 # coding
 source_set("coding") {
     sources = [
diff --git a/garnet/lib/overnet/protocol/ack_frame.cc b/garnet/lib/overnet/protocol/ack_frame.cc
index 45b2e6c..3b23fce 100644
--- a/garnet/lib/overnet/protocol/ack_frame.cc
+++ b/garnet/lib/overnet/protocol/ack_frame.cc
@@ -9,44 +9,25 @@
 namespace overnet {
 
 AckFrame::Writer::Writer(const AckFrame* ack_frame)
-    : ack_frame_(ack_frame),
-      ack_to_seq_length_(varint::WireSizeFor(ack_frame_->ack_to_seq_)),
-      delay_and_flags_length_(
-          varint::WireSizeFor(ack_frame_->DelayAndFlags())) {
-  wire_length_ = ack_to_seq_length_ + delay_and_flags_length_;
-  nack_length_.reserve(ack_frame_->nack_seqs_.size());
-  uint64_t base = ack_frame_->ack_to_seq_;
-  for (auto n : ack_frame_->nack_seqs_) {
-    auto enc = base - n;
-    auto l = varint::WireSizeFor(enc);
-    wire_length_ += l;
-    nack_length_.push_back(l);
-    base = n;
-  }
-  assert(ack_frame->WrittenLength() == wire_length_);
-}
+    : ack_frame_(ack_frame), wire_length_(ack_frame_->WrittenLength()) {}
 
 uint64_t AckFrame::WrittenLength() const {
   uint64_t wire_length =
       varint::WireSizeFor(ack_to_seq_) + varint::WireSizeFor(DelayAndFlags());
-  uint64_t base = ack_to_seq_;
-  for (auto n : nack_seqs_) {
-    wire_length += varint::WireSizeFor(base - n);
-    base = n;
+  for (const auto block : blocks_) {  // each block is sized as two varints
+    wire_length += varint::WireSizeFor(block.acks);
+    wire_length += varint::WireSizeFor(block.nacks);
   }
   return wire_length;
 }
 
 uint8_t* AckFrame::Writer::Write(uint8_t* out) const {
   uint8_t* p = out;
-  p = varint::Write(ack_frame_->ack_to_seq_, ack_to_seq_length_, p);
-  p = varint::Write(ack_frame_->DelayAndFlags(), delay_and_flags_length_, p);
-  uint64_t base = ack_frame_->ack_to_seq_;
-  for (size_t i = 0; i < nack_length_.size(); i++) {
-    auto n = ack_frame_->nack_seqs_[i];
-    auto enc = base - n;
-    p = varint::Write(enc, nack_length_[i], p);
-    base = n;
+  p = varint::Write(ack_frame_->ack_to_seq_, p);
+  p = varint::Write(ack_frame_->DelayAndFlags(), p);
+  for (const auto block : ack_frame_->blocks_) {
+    p = varint::Write(block.acks, p);
+    p = varint::Write(block.nacks, p);
   }
   assert(p == out + wire_length_);
   return p;
@@ -82,18 +63,31 @@
   frame.partial_ = is_partial;
   uint64_t base = ack_to_seq;
   while (bytes != end) {
-    uint64_t offset;
-    if (!varint::Read(&bytes, end, &offset)) {
+    uint64_t acks, nacks;
+    if (!varint::Read(&bytes, end, &acks)) {
       return StatusOr<AckFrame>(StatusCode::INVALID_ARGUMENT,
-                                "Failed to read nack offset from ack frame");
+                                "Failed to read ack count from ack frame");
     }
-    if (offset >= base) {
+    if (!varint::Read(&bytes, end, &nacks)) {
       return StatusOr<AckFrame>(StatusCode::INVALID_ARGUMENT,
-                                "Failed to read nack");
+                                "Failed to read nack count from ack frame");
     }
-    const uint64_t seq = base - offset;
-    frame.AddNack(seq);
-    base = seq;
+    if (acks >= base) {
+      return StatusOr<AckFrame>(StatusCode::INVALID_ARGUMENT,
+                                "Failed to read nack (too many acks)");
+    }
+    if (nacks > base - acks) {
+      return StatusOr<AckFrame>(StatusCode::INVALID_ARGUMENT,
+                                "Failed to read nack (too many nacks)");
+    }
+    if (nacks == 0) {
+      return StatusOr<AckFrame>(StatusCode::INVALID_ARGUMENT,
+                                "Nack count cannot be zero");
+    }
+    base -= acks;
+    base -= nacks;
+    frame.blocks_.push_back(Block{acks, nacks});
+    frame.last_nack_ = base + 1;
   }
   return StatusOr<AckFrame>(std::move(frame));
 }
diff --git a/garnet/lib/overnet/protocol/ack_frame.h b/garnet/lib/overnet/protocol/ack_frame.h
index 4f5e5a8..0a7738a 100644
--- a/garnet/lib/overnet/protocol/ack_frame.h
+++ b/garnet/lib/overnet/protocol/ack_frame.h
@@ -8,12 +8,22 @@
 #include <tuple>
 #include <vector>
 #include "garnet/lib/overnet/environment/trace.h"
+#include "garnet/lib/overnet/protocol/varint.h"
 #include "garnet/lib/overnet/vocabulary/slice.h"
 #include "garnet/lib/overnet/vocabulary/status.h"
 
 namespace overnet {
 
 class AckFrame {
+  struct Block {  // One run of acked seqs followed by one run of nacked seqs.
+    uint64_t acks;   // acked seqs skipped before the nack run (working back)
+    uint64_t nacks;  // length of the nack run; Parse rejects a zero count
+    bool operator==(const Block& other) const {
+      return acks == other.acks && nacks == other.nacks;
+    }
+  };
+  using Brit = std::vector<Block>::const_reverse_iterator;
+
  public:
   class Writer {
    public:
@@ -24,23 +34,21 @@
 
    private:
     const AckFrame* const ack_frame_;
-    const uint8_t ack_to_seq_length_;
-    const uint8_t delay_and_flags_length_;
-    std::vector<uint8_t> nack_length_;
-    size_t wire_length_;
+    const size_t wire_length_;
   };
 
   AckFrame(uint64_t ack_to_seq, uint64_t ack_delay_us)
-      : ack_to_seq_(ack_to_seq), ack_delay_us_(ack_delay_us) {
+      : partial_(false), ack_to_seq_(ack_to_seq), ack_delay_us_(ack_delay_us) {
     assert(ack_to_seq_ > 0);
   }
 
   AckFrame(uint64_t ack_to_seq, uint64_t ack_delay_us,
            std::initializer_list<uint64_t> nack_seqs)
-      : ack_to_seq_(ack_to_seq), ack_delay_us_(ack_delay_us) {
+      : partial_(false), ack_to_seq_(ack_to_seq), ack_delay_us_(ack_delay_us) {
     assert(ack_to_seq_ > 0);
-    for (auto n : nack_seqs)
+    for (auto n : nack_seqs) {
       AddNack(n);
+    }
   }
 
   AckFrame(const AckFrame&) = delete;
@@ -50,57 +58,142 @@
       : partial_(other.partial_),
         ack_to_seq_(other.ack_to_seq_),
         ack_delay_us_(other.ack_delay_us_),
-        nack_seqs_(std::move(other.nack_seqs_)) {}
+        blocks_(std::move(other.blocks_)),
+        last_nack_(other.last_nack_) {}
 
   AckFrame& operator=(AckFrame&& other) {
     partial_ = other.partial_;
     ack_to_seq_ = other.ack_to_seq_;
     ack_delay_us_ = other.ack_delay_us_;
-    nack_seqs_ = std::move(other.nack_seqs_);
+    blocks_ = std::move(other.blocks_);
+    last_nack_ = other.last_nack_;
     return *this;
   }
 
   void AddNack(uint64_t seq) {
     assert(ack_to_seq_ > 0);
     assert(seq <= ack_to_seq_);
-    if (!nack_seqs_.empty()) {
-      assert(seq < nack_seqs_.back());
+    assert(seq > 0);
+    if (!blocks_.empty()) {
+      assert(seq < last_nack_);  // nacks must be added in descending order
+      if (seq == last_nack_ - 1) {  // contiguous: extend the current nack run
+        blocks_.back().nacks++;
+      } else {  // gap of (last_nack_ - seq - 1) acked seqs starts a new block
+        blocks_.emplace_back(Block{last_nack_ - seq - 1, 1});
+      }
+    } else {  // first block: acks counts the seqs in (seq, ack_to_seq_]
+      blocks_.emplace_back(Block{ack_to_seq_ - seq, 1});
     }
-    nack_seqs_.push_back(seq);
+    last_nack_ = seq;  // lowest nacked seq recorded so far
   }
 
   static StatusOr<AckFrame> Parse(Slice slice);
 
   friend bool operator==(const AckFrame& a, const AckFrame& b) {
-    return std::tie(a.ack_to_seq_, a.ack_delay_us_, a.nack_seqs_) ==
-           std::tie(b.ack_to_seq_, b.ack_delay_us_, b.nack_seqs_);
+    if (std::tie(a.ack_to_seq_, a.ack_delay_us_, a.blocks_, a.partial_) !=
+        std::tie(b.ack_to_seq_, b.ack_delay_us_, b.blocks_, b.partial_)) {
+      return false;
+    }
+    if (!a.blocks_.empty()) {  // last_nack_ is only meaningful with blocks
+      return a.last_nack_ == b.last_nack_;
+    }
+    return true;
   }
 
   uint64_t ack_to_seq() const { return ack_to_seq_; }
   uint64_t ack_delay_us() const { return ack_delay_us_; }
   bool partial() const { return partial_; }
-  const std::vector<uint64_t>& nack_seqs() const { return nack_seqs_; }
+
+  class NackSeqs {  // Iterable view over a frame's nacked sequence numbers.
+   public:
+    NackSeqs(const AckFrame* ack_frame) : ack_frame_(ack_frame) {}
+
+    class Iterator {  // Walks blocks_ in reverse, yielding ascending seqs.
+     public:
+      Iterator(Brit brit, uint64_t base) : brit_(brit), base_(base) {}
+
+      bool operator!=(const Iterator& other) const {
+        return brit_ != other.brit_ || base_ != other.base_ ||
+               nack_ != other.nack_;
+      }
+
+      void operator++() {
+        nack_++;
+        if (nack_ == brit_->nacks) {  // finished this block's nack run
+          base_ += brit_->nacks + brit_->acks;  // skip nacks, then acked gap
+          ++brit_;
+          nack_ = 0;
+        }
+      }
+
+      uint64_t operator*() const { return base_ + nack_; }
+
+     private:
+      Brit brit_;          // current block, iterating blocks_ back to front
+      uint64_t base_;      // lowest nacked seq of the current block
+      uint64_t nack_ = 0;  // offset within the current block's nack run
+    };
+
+    std::vector<uint64_t> AsVector() const {  // copies out every nacked seq
+      std::vector<uint64_t> out;
+      for (auto n : *this) {
+        out.push_back(n);
+      }
+      return out;
+    }
+
+    Iterator begin() const {
+      if (ack_frame_->blocks_.empty()) {  // no nacks: empty range
+        return end();
+      }
+      return Iterator(ack_frame_->blocks_.rbegin(), ack_frame_->last_nack_);
+    }
+    Iterator end() const {  // sentinel: (rend, ack_to_seq_ + 1)
+      return Iterator(ack_frame_->blocks_.rend(), ack_frame_->ack_to_seq_ + 1);
+    }
+
+   private:
+    const AckFrame* ack_frame_;  // not owned; must outlive this view
+  };
+  NackSeqs nack_seqs() const { return NackSeqs(this); }
 
   // Move ack_to_seq back in time such that the total ack frame will fit
   // within mss. DelayFn is a function uint64_t -> TimeDelta that returns the
   // ack delay (in microseconds) for a given sequence number.
   template <class DelayFn>
   void AdjustForMSS(uint32_t mss, DelayFn delay_fn) {
-    while (!nack_seqs_.empty() && WrittenLength() > mss) {
+    while (!blocks_.empty() && WrittenLength() > mss) {
       partial_ = true;
-      if (ack_to_seq_ != nack_seqs_[0]) {
+      auto& block0_acks = blocks_[0].acks;  // references into blocks_[0]
+      auto& block0_nacks = blocks_[0].nacks;
+      if (block0_acks > 0) {  // first shrink the leading ack run
+        auto new_acks = varint::SmallerRecordedNumber(block0_acks);
         OVERNET_TRACE(DEBUG) << "Trim too long ack (" << WrittenLength()
                              << " > " << mss << " by moving ack " << ack_to_seq_
-                             << " to first nack " << nack_seqs_[0];
-        ack_to_seq_ = nack_seqs_[0];
+                             << " to shorter first ack block length "
+                             << (ack_to_seq_ - block0_acks + new_acks);
+        ack_to_seq_ -= (block0_acks - new_acks);  // drop the newest acks
+        block0_acks = new_acks;
       } else {
-        OVERNET_TRACE(DEBUG)
-            << "Trim too long ack (" << WrittenLength() << " > " << mss
-            << " by trimming first nack " << nack_seqs_[0];
-        nack_seqs_.erase(nack_seqs_.begin());
-        ack_to_seq_--;
+        assert(block0_nacks > 0);
+        auto new_nacks = varint::SmallerRecordedNumber(block0_nacks);
+        if (new_nacks == 0) {  // no shorter encoding: drop the whole block
+          OVERNET_TRACE(DEBUG)
+              << "Trim too long ack (" << WrittenLength() << " > " << mss
+              << " by eliminating first block and moving first ack to "
+              << (ack_to_seq_ - block0_nacks + new_nacks);
+          ack_to_seq_ -= (block0_nacks - new_nacks);
+          blocks_.erase(blocks_.begin());  // block0_* refs now dangle
+        } else {
+          OVERNET_TRACE(DEBUG)
+              << "Trim too long ack (" << WrittenLength() << " > " << mss
+              << " by moving ack " << ack_to_seq_
+              << " to shorter first nack block length "
+              << (ack_to_seq_ - block0_nacks + new_nacks);
+          ack_to_seq_ -= (block0_nacks - new_nacks);
+          block0_nacks = new_nacks;
+        }  // NOTE(review): delay_fn is never called now - confirm intended.
       }
-      ack_delay_us_ = delay_fn(ack_to_seq_);
     }
   }
 
@@ -111,16 +204,17 @@
   // Flag indicating that this ack is only a partial acknowledgement, and
   // there's more to come.
   bool partial_ = false;
+
   // All messages with sequence number prior to ack_to_seq_ are implicitly
   // acknowledged.
   uint64_t ack_to_seq_;
   // How long between receiving ack_delay_seq_ and generating this data
   // structure.
   uint64_t ack_delay_us_;
-  // All messages contained in nack_seqs_ need to be resent.
-  // NOTE: it's assumed that nack_seqs_ is in-order descending and all value are
-  // less than or equal to ack_to_seq_.
-  std::vector<uint64_t> nack_seqs_;
+  // Working back from ack_to_seq_: each block is a run of acks, then of nacks.
+  std::vector<Block> blocks_;
+  // Lowest nacked seq. Initialized because moves read it even with no blocks.
+  uint64_t last_nack_ = 0;
 };
 
 std::ostream& operator<<(std::ostream& out, const AckFrame& ack_frame);
diff --git a/garnet/lib/overnet/protocol/ack_frame_fuzzer.cc b/garnet/lib/overnet/protocol/ack_frame_fuzzer.cc
new file mode 100644
index 0000000..a05849b
--- /dev/null
+++ b/garnet/lib/overnet/protocol/ack_frame_fuzzer.cc
@@ -0,0 +1,27 @@
+// Copyright 2019 The Fuchsia Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#include "garnet/lib/overnet/protocol/ack_frame.h"
+
+using namespace overnet;
+
+extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
+  auto input = Slice::FromCopiedBuffer(data, size);
+  // Parse data as an ack frame.
+  auto status = AckFrame::Parse(input);
+  if (status.is_ok()) {
+    if (status->ack_to_seq() < 100000) {  // presumably bounds the nack walk
+      for (auto n : status->nack_seqs()) {
+        [](auto) {}(n);  // discard the value; only iteration is exercised
+      }
+    }
+    // If parsed ok: re-encode the frame.
+    Slice written = Slice::FromWriters(AckFrame::Writer(status.get()));
+    // Re-parsing the written bytes should yield an equal frame.
+    auto status2 = AckFrame::Parse(written);
+    assert(status2.is_ok());
+    assert(*status == *status2);
+  }
+  return 0;  // Non-zero return values are reserved for future use.
+}
diff --git a/garnet/lib/overnet/protocol/ack_frame_test.cc b/garnet/lib/overnet/protocol/ack_frame_test.cc
index 76e4a7f..d670e3b 100644
--- a/garnet/lib/overnet/protocol/ack_frame_test.cc
+++ b/garnet/lib/overnet/protocol/ack_frame_test.cc
@@ -21,7 +21,7 @@
   auto v = Encode(h);
   EXPECT_EQ(expect, v);
   auto p = AckFrame::Parse(Slice::FromCopiedBuffer(v.data(), v.size()));
-  EXPECT_TRUE(p.is_ok());
+  EXPECT_TRUE(p.is_ok()) << p.AsStatus();
   EXPECT_EQ(h, *p.get());
 }
 
@@ -33,7 +33,8 @@
 TEST(AckFrame, OneNack) {
   AckFrame h(5, 10);
   h.AddNack(2);
-  RoundTrip(h, {5, 20, 3});
+  RoundTrip(h, {5, 20, 3, 1});
+  EXPECT_EQ(h.nack_seqs().AsVector(), std::vector<uint64_t>({2}));
 }
 
 TEST(AckFrame, ThreeNacks) {
@@ -41,7 +42,52 @@
   h.AddNack(4);
   h.AddNack(3);
   h.AddNack(2);
-  RoundTrip(h, {5, 84, 1, 1, 1});
+  RoundTrip(h, {5, 84, 1, 3});
+  EXPECT_EQ(h.nack_seqs().AsVector(), std::vector<uint64_t>({2, 3, 4}));
+}
+
+TEST(AckFrame, TwoBlocks) {
+  AckFrame h(20, 42);
+  h.AddNack(15);  // first block: 5 acks (16..20), then nacks 15, 14, 13
+  h.AddNack(14);
+  h.AddNack(13);
+  h.AddNack(5);  // second block: 7 acks (6..12), then nacks 5 down to 1
+  h.AddNack(4);
+  h.AddNack(3);
+  h.AddNack(2);
+  h.AddNack(1);
+  RoundTrip(h, {20, 84, 5, 3, 7, 5});  // seq, delay/flags, then block pairs
+  EXPECT_EQ(h.nack_seqs().AsVector(),
+            std::vector<uint64_t>({1, 2, 3, 4, 5, 13, 14, 15}));
+}
+
+TEST(AckFrame, FuzzedExamples) {
+  auto test = [](std::initializer_list<uint8_t> bytes) {
+    auto input = Slice::FromContainer(bytes);
+    std::cerr << "Test: " << input << "\n";
+    if (auto p = AckFrame::Parse(input); p.is_ok()) {
+      {  // walk every nacked seq from begin() to end()
+        auto ns = p->nack_seqs();
+        const auto begin = ns.begin();
+        const auto end = ns.end();
+        for (auto it = begin; it != end; ++it) {
+          [](auto) {}(*it);  // discard the value
+        }
+      }
+      std::cerr << "Parsed: " << *p << "\n";
+      auto written = Slice::FromWriters(AckFrame::Writer(p.get()));
+      auto p2 = AckFrame::Parse(written);  // re-encoded bytes must re-parse
+      EXPECT_TRUE(p2.is_ok());
+      EXPECT_EQ(*p, *p2);
+    } else {  // rejected input is not a failure here; it is only logged
+      std::cerr << "Parse error: " << p.AsStatus() << "\n";
+    }
+  };
+  test({0x0a, 0x0a, 0x00, 0x00});  // NOTE: presumably fuzzer-found inputs
+  test({0xc1, 0xe0, 0x00, 0x2d});
+  test({0x80, 0xcd, 0xcd, 0xcd, 0xcd, 0x2b, 0x00, 0x2f, 0xcd, 0xcd, 0xf9, 0xe4,
+        0x00, 0x51});
+  test({0x65, 0x01, 0x01, 0x02});
 }
 
 }  // namespace ack_frame_test
diff --git a/garnet/lib/overnet/protocol/varint.cc b/garnet/lib/overnet/protocol/varint.cc
index b95af61..8709499 100644
--- a/garnet/lib/overnet/protocol/varint.cc
+++ b/garnet/lib/overnet/protocol/varint.cc
@@ -30,6 +30,28 @@
   return 10;
 }
 
+uint64_t SmallerRecordedNumber(uint64_t x) {
+  if (x < (1ull << 7))  // lowest size class: nothing shorter exists
+    return 0;
+  if (x < (1ull << 14))  // otherwise return the largest value one class down
+    return (1ull << 7) - 1;
+  if (x < (1ull << 21))  // class boundaries step by 7 bits
+    return (1ull << 14) - 1;
+  if (x < (1ull << 28))
+    return (1ull << 21) - 1;
+  if (x < (1ull << 35))
+    return (1ull << 28) - 1;
+  if (x < (1ull << 42))
+    return (1ull << 35) - 1;
+  if (x < (1ull << 49))
+    return (1ull << 42) - 1;
+  if (x < (1ull << 56))
+    return (1ull << 49) - 1;
+  if (x < (1ull << 63))
+    return (1ull << 56) - 1;
+  return (1ull << 63) - 1;  // x >= 2^63
+}
+
 uint64_t MaximumLengthWithPrefix(uint64_t x) {
   assert(x > 0);
   uint64_t r = x - WireSizeFor(x);
diff --git a/garnet/lib/overnet/protocol/varint.h b/garnet/lib/overnet/protocol/varint.h
index df2cb80..3bdf52a 100644
--- a/garnet/lib/overnet/protocol/varint.h
+++ b/garnet/lib/overnet/protocol/varint.h
@@ -17,6 +17,11 @@
 // convenience
 uint8_t* Write(uint64_t x, uint8_t wire_length, uint8_t* dst);
 
+// Convenience overload of Write: computes the wire length with WireSizeFor(x).
+inline uint8_t* Write(uint64_t x, uint8_t* dst) {
+  return Write(x, WireSizeFor(x), dst);
+}
+
 namespace impl {
 bool ReadFallback(const uint8_t** bytes, const uint8_t* end, uint64_t* result);
 }
@@ -35,6 +40,11 @@
 // such that the total length does not exceed fit_to?
 uint64_t MaximumLengthWithPrefix(uint64_t fit_to);
 
+// Returns the largest number n, n < x, that takes fewer bytes to record than
+// x does. If no such n exists (the implementation treats every x < 128 this
+// way), returns 0.
+uint64_t SmallerRecordedNumber(uint64_t x);
+
 }  // namespace varint
 
 }  // namespace overnet