[wlan][mesh] Implement StopRequest in MeshMlme

WLAN-942

TEST=Unit test included.

Change-Id: I4afe9feaaac6bf77c59fcfce26f3d27cbc1c6b9e
diff --git a/lib/wlan/mlme/include/wlan/mlme/mesh/mesh_mlme.h b/lib/wlan/mlme/include/wlan/mlme/mesh/mesh_mlme.h
index 96ce614..9b83026 100644
--- a/lib/wlan/mlme/include/wlan/mlme/mesh/mesh_mlme.h
+++ b/lib/wlan/mlme/include/wlan/mlme/mesh/mesh_mlme.h
@@ -33,6 +33,7 @@
 
     ::fuchsia::wlan::mlme::StartResultCodes Start(
         const MlmeMsg<::fuchsia::wlan::mlme::StartRequest>& req);
+    ::fuchsia::wlan::mlme::StopResultCodes Stop();
     void SendPeeringOpen(const MlmeMsg<::fuchsia::wlan::mlme::MeshPeeringOpenAction>& req);
     void SendPeeringConfirm(const MlmeMsg<::fuchsia::wlan::mlme::MeshPeeringConfirmAction>& req);
     void ConfigurePeering(const MlmeMsg<::fuchsia::wlan::mlme::MeshPeeringParams>& params);
@@ -62,14 +63,18 @@
 
     MacHeaderWriter CreateMacHeaderWriter();
 
+    struct MeshState {
+        HwmpState hwmp;
+        PathTable path_table;
+        DeDuplicator deduplicator;
+
+        explicit MeshState(fbl::unique_ptr<Timer> timer);
+    };
+
     DeviceInterface* const device_;
-    bool joined_ = false;
     Sequence seq_;
     uint32_t mesh_seq_ = 0;
-    std::unique_ptr<HwmpState> hwmp_;
-    PathTable path_table_;
-
-    DeDuplicator deduplicator_;
+    std::optional<MeshState> state_;
 };
 
 }  // namespace wlan
diff --git a/lib/wlan/mlme/include/wlan/mlme/service.h b/lib/wlan/mlme/include/wlan/mlme/service.h
index b348ec7..f390d0e 100644
--- a/lib/wlan/mlme/include/wlan/mlme/service.h
+++ b/lib/wlan/mlme/include/wlan/mlme/service.h
@@ -163,6 +163,7 @@
                                 const common::MacAddr& src, const common::MacAddr& dst);
 
 zx_status_t SendStartConfirm(DeviceInterface* device, ::fuchsia::wlan::mlme::StartResultCodes code);
+zx_status_t SendStopConfirm(DeviceInterface* device, ::fuchsia::wlan::mlme::StopResultCodes code);
 
 }  // namespace service
 
diff --git a/lib/wlan/mlme/mesh/mesh_mlme.cpp b/lib/wlan/mlme/mesh/mesh_mlme.cpp
index f416ba8..7b4a73a 100644
--- a/lib/wlan/mlme/mesh/mesh_mlme.cpp
+++ b/lib/wlan/mlme/mesh/mesh_mlme.cpp
@@ -74,20 +74,12 @@
     return BuildBeacon(c, buffer, tim_ele_offset);
 }
 
-MeshMlme::MeshMlme(DeviceInterface* device)
-    : device_(device), deduplicator_(kMaxReceivedFrameCacheSize) {}
+MeshMlme::MeshState::MeshState(fbl::unique_ptr<Timer> timer)
+    : hwmp(std::move(timer)), deduplicator(kMaxReceivedFrameCacheSize) {}
+
+MeshMlme::MeshMlme(DeviceInterface* device) : device_(device) {}
 
 zx_status_t MeshMlme::Init() {
-    fbl::unique_ptr<Timer> timer;
-    ObjectId timer_id;
-    timer_id.set_subtype(to_enum_type(ObjectSubtype::kTimer));
-    timer_id.set_target(to_enum_type(ObjectTarget::kHwmp));
-    zx_status_t status = device_->GetTimer(ToPortKey(PortKeyType::kMlme, timer_id.val()), &timer);
-    if (status != ZX_OK) {
-        errorf("[mesh-mlme] Failed to create the HWMP timer: %s\n", zx_status_get_string(status));
-        return status;
-    }
-    hwmp_ = std::make_unique<HwmpState>(std::move(timer));
     return ZX_OK;
 }
 
@@ -95,6 +87,9 @@
     if (auto start_req = msg.As<wlan_mlme::StartRequest>()) {
         auto code = Start(*start_req);
         return service::SendStartConfirm(device_, code);
+    } else if (auto stop_req = msg.As<wlan_mlme::StopRequest>()) {
+        auto code = Stop();
+        return service::SendStopConfirm(device_, code);
     } else if (auto mp_open = msg.As<wlan_mlme::MeshPeeringOpenAction>()) {
         SendPeeringOpen(*mp_open);
         return ZX_OK;
@@ -110,10 +105,20 @@
 }
 
 wlan_mlme::StartResultCodes MeshMlme::Start(const MlmeMsg<wlan_mlme::StartRequest>& req) {
-    if (joined_) { return wlan_mlme::StartResultCodes::BSS_ALREADY_STARTED_OR_JOINED; }
+    if (state_) { return wlan_mlme::StartResultCodes::BSS_ALREADY_STARTED_OR_JOINED; }
+
+    fbl::unique_ptr<Timer> timer;
+    ObjectId timer_id;
+    timer_id.set_subtype(to_enum_type(ObjectSubtype::kTimer));
+    timer_id.set_target(to_enum_type(ObjectTarget::kHwmp));
+    zx_status_t status = device_->GetTimer(ToPortKey(PortKeyType::kMlme, timer_id.val()), &timer);
+    if (status != ZX_OK) {
+        errorf("[mesh-mlme] Failed to create the HWMP timer: %s\n", zx_status_get_string(status));
+        return wlan_mlme::StartResultCodes::INTERNAL_ERROR;
+    }
 
     wlan_channel_t channel = GetChannel(req.body()->channel);
-    zx_status_t status = device_->SetChannel(channel);
+    status = device_->SetChannel(channel);
     if (status != ZX_OK) {
         errorf("[mesh-mlme] failed to set channel to %s: %s\n", common::ChanStr(channel).c_str(),
                zx_status_get_string(status));
@@ -139,10 +144,26 @@
     }
 
     device_->SetStatus(ETHMAC_STATUS_ONLINE);
-    joined_ = true;
+    state_.emplace(std::move(timer));
     return wlan_mlme::StartResultCodes::SUCCESS;
 }
 
+wlan_mlme::StopResultCodes MeshMlme::Stop() {
+    if (!state_) { return wlan_mlme::StopResultCodes::BSS_ALREADY_STOPPED; };
+
+    // TODO(gbonik): call clear_assoc for all peers once we have a list of peers
+
+    zx_status_t status = device_->EnableBeaconing(nullptr);
+    if (status != ZX_OK) {
+        errorf("[mesh-mlme] failed to disable beaconing: %s\n", zx_status_get_string(status));
+        return wlan_mlme::StopResultCodes::INTERNAL_ERROR;
+    }
+
+    device_->SetStatus(0);
+    state_.reset();
+    return wlan_mlme::StopResultCodes::SUCCESS;
+}
+
 void MeshMlme::SendPeeringOpen(const MlmeMsg<wlan_mlme::MeshPeeringOpenAction>& req) {
     auto packet = GetWlanPacket(kMaxMeshMgmtFrameSize);
     if (packet == nullptr) { return; }
@@ -218,6 +239,8 @@
 }
 
 void MeshMlme::HandleEthTx(EthFrame&& frame) {
+    if (!state_) { return; }
+
     auto packet = GetWlanPacket(GetDataFrameBufferSize(frame.body_len()));
     if (packet == nullptr) { return; }
     BufferWriter w(*packet);
@@ -234,10 +257,10 @@
             w.WriteValue(frame.hdr()->src);
         }
     } else {
-        auto proxy_info = path_table_.GetProxyInfo(frame.hdr()->dest);
+        auto proxy_info = state_->path_table.GetProxyInfo(frame.hdr()->dest);
         auto mesh_dest = proxy_info == nullptr ? frame.hdr()->dest : proxy_info->mesh_target;
 
-        auto path = path_table_.GetPath(mesh_dest);
+        auto path = state_->path_table.GetPath(mesh_dest);
         if (path == nullptr) {
             // TODO(gbonik): buffer the frame
             TriggerPathDiscovery(mesh_dest);
@@ -263,6 +286,8 @@
 }
 
 zx_status_t MeshMlme::HandleAnyWlanFrame(fbl::unique_ptr<Packet> pkt) {
+    if (!state_) { return ZX_OK; }
+
     if (auto possible_mgmt_frame = MgmtFrameView<>::CheckType(pkt.get())) {
         if (auto mgmt_frame = possible_mgmt_frame.CheckLength()) {
             return HandleAnyMgmtFrame(mgmt_frame.IntoOwned(std::move(pkt)));
@@ -313,14 +338,17 @@
 }
 
 void MeshMlme::HandleMeshAction(const MgmtFrameHeader& mgmt, BufferReader* r) {
+    ZX_ASSERT(state_);
+
     auto mesh_action_header = r->Read<MeshActionHeader>();
     if (mesh_action_header == nullptr) { return; }
 
     switch (mesh_action_header->mesh_action) {
     case action::kHwmpMeshPathSelection: {
         // TODO(gbonik): pass the actual airtime metric
-        auto packets_to_tx = HandleHwmpAction(r->ReadRemaining(), mgmt.addr2, self_addr(), 100,
-                                              CreateMacHeaderWriter(), hwmp_.get(), &path_table_);
+        auto packets_to_tx =
+            HandleHwmpAction(r->ReadRemaining(), mgmt.addr2, self_addr(), 100,
+                             CreateMacHeaderWriter(), &state_->hwmp, &state_->path_table);
         while (!packets_to_tx.is_empty()) {
             SendMgmtFrame(packets_to_tx.Dequeue());
         }
@@ -340,9 +368,11 @@
 }
 
 void MeshMlme::TriggerPathDiscovery(const common::MacAddr& target) {
+    ZX_ASSERT(state_);
+
     PacketQueue packets_to_tx;
     zx_status_t status = InitiatePathDiscovery(target, self_addr(), CreateMacHeaderWriter(),
-                                               hwmp_.get(), path_table_, &packets_to_tx);
+                                               &state_->hwmp, state_->path_table, &packets_to_tx);
     if (status != ZX_OK) {
         errorf("[mesh-mlme] Failed to initiate path discovery: %s\n", zx_status_get_string(status));
         return;
@@ -408,7 +438,9 @@
     // TODO(gbonik): drop frames from non-peers
 
     // Drop if duplicate
-    if (deduplicator_.DeDuplicate(GetMeshSrcAddr(*header), header->mesh_ctrl->seq)) { return; }
+    if (state_->deduplicator.DeDuplicate(GetMeshSrcAddr(*header), header->mesh_ctrl->seq)) {
+        return;
+    }
 
     if (ShouldDeliverData(header->mac_header)) { DeliverData(*header, *packet, r.ReadBytes()); }
 
@@ -463,7 +495,7 @@
     if (header.mac_header.addr4 != nullptr) {
         // Individually addressed frame: addr3 is the mesh destination
         if (header.mac_header.fixed->addr3 == self_addr()) { return {}; }
-        auto path = path_table_.GetPath(header.mac_header.fixed->addr3);
+        auto path = state_->path_table.GetPath(header.mac_header.fixed->addr3);
         if (path == nullptr) { return {}; }
         return {path->next_hop};
     } else {
@@ -488,11 +520,13 @@
 }
 
 zx_status_t MeshMlme::HandleTimeout(const ObjectId id) {
+    if (!state_) { return ZX_OK; }
+
     switch (id.target()) {
     case to_enum_type(ObjectTarget::kHwmp): {
         PacketQueue packets_to_tx;
-        zx_status_t status = HandleHwmpTimeout(self_addr(), CreateMacHeaderWriter(), hwmp_.get(),
-                                               path_table_, &packets_to_tx);
+        zx_status_t status = HandleHwmpTimeout(self_addr(), CreateMacHeaderWriter(), &state_->hwmp,
+                                               state_->path_table, &packets_to_tx);
         if (status != ZX_OK) {
             errorf("[mesh-mlme] Failed to rearm the HWMP timer: %s\n",
                    zx_status_get_string(status));
diff --git a/lib/wlan/mlme/service.cpp b/lib/wlan/mlme/service.cpp
index 224c4a8..9cdc24a 100644
--- a/lib/wlan/mlme/service.cpp
+++ b/lib/wlan/mlme/service.cpp
@@ -151,5 +151,11 @@
     return SendServiceMsg(device, &msg, fuchsia_wlan_mlme_MLMEStartConfOrdinal);
 }
 
+zx_status_t SendStopConfirm(DeviceInterface* device, wlan_mlme::StopResultCodes code) {
+    wlan_mlme::StopConfirm msg;
+    msg.result_code = code;
+    return SendServiceMsg(device, &msg, fuchsia_wlan_mlme_MLMEStopConfOrdinal);
+}
+
 }  // namespace service
 }  // namespace wlan
diff --git a/lib/wlan/mlme/tests/mesh_mlme_unittest.cpp b/lib/wlan/mlme/tests/mesh_mlme_unittest.cpp
index 1ba4238..c0b2f0a 100644
--- a/lib/wlan/mlme/tests/mesh_mlme_unittest.cpp
+++ b/lib/wlan/mlme/tests/mesh_mlme_unittest.cpp
@@ -14,24 +14,58 @@
 
 namespace wlan {
 
+struct MeshMlmeTest : public ::testing::Test {
+    MeshMlmeTest() : mlme(&device) {
+        device.state->set_address(common::MacAddr("11:11:11:11:11:11"));
+        mlme.Init();
+    }
+
+    wlan_mlme::StartResultCodes JoinMesh() {
+        wlan_mlme::StartRequest join;
+        zx_status_t status =
+            mlme.HandleMlmeMsg(MlmeMsg<wlan_mlme::StartRequest>(std::move(join), 123));
+        EXPECT_EQ(ZX_OK, status);
+
+        auto msgs = device.GetServiceMsgs<wlan_mlme::StartConfirm>();
+        EXPECT_EQ(msgs.size(), 1ULL);
+        return msgs[0].body()->result_code;
+    }
+
+    wlan_mlme::StopResultCodes LeaveMesh() {
+        wlan_mlme::StopRequest leave;
+        zx_status_t status =
+            mlme.HandleMlmeMsg(MlmeMsg<wlan_mlme::StopRequest>(std::move(leave), 123));
+        EXPECT_EQ(ZX_OK, status);
+
+        auto msgs = device.GetServiceMsgs<wlan_mlme::StopConfirm>();
+        EXPECT_EQ(msgs.size(), 1ULL);
+        return msgs[0].body()->result_code;
+    }
+
+    MockDevice device;
+    MeshMlme mlme;
+};
+
 static fbl::unique_ptr<Packet> MakeWlanPacket(Span<const uint8_t> bytes) {
     auto packet = GetWlanPacket(bytes.size());
     memcpy(packet->data(), bytes.data(), bytes.size());
     return packet;
 }
 
-static void JoinMesh(MeshMlme* mlme) {
-    wlan_mlme::StartRequest join;
-    zx_status_t status =
-        mlme->HandleMlmeMsg(MlmeMsg<wlan_mlme::StartRequest>(std::move(join), 123));
-    EXPECT_EQ(ZX_OK, status);
+TEST_F(MeshMlmeTest, JoinLeave) {
+    EXPECT_EQ(LeaveMesh(), wlan_mlme::StopResultCodes::BSS_ALREADY_STOPPED);
+    EXPECT_EQ(JoinMesh(), wlan_mlme::StartResultCodes::SUCCESS);
+    EXPECT_TRUE(device.beaconing_enabled);
+    EXPECT_EQ(JoinMesh(), wlan_mlme::StartResultCodes::BSS_ALREADY_STARTED_OR_JOINED);
+    EXPECT_EQ(LeaveMesh(), wlan_mlme::StopResultCodes::SUCCESS);
+    EXPECT_FALSE(device.beaconing_enabled);
+    EXPECT_EQ(LeaveMesh(), wlan_mlme::StopResultCodes::BSS_ALREADY_STOPPED);
+    EXPECT_EQ(JoinMesh(), wlan_mlme::StartResultCodes::SUCCESS);
+    EXPECT_TRUE(device.beaconing_enabled);
 }
 
-TEST(MeshMlme, HandleMpmOpen) {
-    MockDevice device;
-    MeshMlme mlme(&device);
-    mlme.Init();
-    JoinMesh(&mlme);
+TEST_F(MeshMlmeTest, HandleMpmOpen) {
+    EXPECT_EQ(JoinMesh(), wlan_mlme::StartResultCodes::SUCCESS);
 
     // clang-format off
     const uint8_t frame[] = {
@@ -69,12 +103,8 @@
     }
 }
 
-TEST(MeshMlme, DeliverProxiedData) {
-    MockDevice device;
-    device.state->set_address(common::MacAddr("11:11:11:11:11:11"));
-    MeshMlme mlme(&device);
-    mlme.Init();
-    JoinMesh(&mlme);
+TEST_F(MeshMlmeTest, DeliverProxiedData) {
+    EXPECT_EQ(JoinMesh(), wlan_mlme::StartResultCodes::SUCCESS);
 
     // Simulate receiving a data frame
     zx_status_t status = mlme.HandleFramePacket(test_utils::MakeWlanPacket({
@@ -122,12 +152,51 @@
     EXPECT_RANGES_EQ(expected, eth_frames[0]);
 }
 
-TEST(MeshMlme, HandlePreq) {
-    MockDevice device;
-    device.state->set_address(common::MacAddr("11:11:11:11:11:11"));
-    MeshMlme mlme(&device);
-    mlme.Init();
-    JoinMesh(&mlme);
+TEST_F(MeshMlmeTest, DoNotDeliverWhenNotJoined) {
+    auto packet = [] (uint8_t mesh_seq) {
+        return test_utils::MakeWlanPacket({
+            // clang-format off
+            // Data header
+            0x88, 0x03, // fc: qos data, 4-address, no ht ctl
+            0x00, 0x00, // duration
+            0x11, 0x11, 0x11, 0x11, 0x11, 0x11, // addr1
+            0x22, 0x22, 0x22, 0x22, 0x22, 0x22, // addr2
+            0x11, 0x11, 0x11, 0x11, 0x11, 0x11, // addr3: mesh da = ra
+            0x00, 0x00, // seq ctl
+            0x44, 0x44, 0x44, 0x44, 0x44, 0x44, // addr4
+            0x00, 0x01, // qos ctl: mesh control present
+            // Mesh control
+            0x00, 0x20, // flags, ttl
+            mesh_seq, 0xbb, 0xcc, 0xdd, // seq
+            // LLC header
+            0xaa, 0xaa, 0x03, // dsap ssap ctrl
+            0x00, 0x00, 0x00, // oui
+            0x12, 0x34, // protocol id
+            // Payload
+            0xde, 0xad, 0xbe, 0xef,
+            // clang-format on
+        });
+    };
+
+    // Receive a frame while not joined: expect it to be dropped
+    EXPECT_EQ(mlme.HandleFramePacket(packet(1)), ZX_OK);
+    EXPECT_TRUE(device.GetEthPackets().empty());
+
+    EXPECT_EQ(JoinMesh(), wlan_mlme::StartResultCodes::SUCCESS);
+
+    // Receive a frame while joined: expect it to be delivered
+    EXPECT_EQ(mlme.HandleFramePacket(packet(2)), ZX_OK);
+    EXPECT_EQ(device.GetEthPackets().size(), 1u);
+
+    EXPECT_EQ(LeaveMesh(), wlan_mlme::StopResultCodes::SUCCESS);
+
+    // Again, receive a frame while not joined: expect it to be dropped
+    EXPECT_EQ(mlme.HandleFramePacket(packet(3)), ZX_OK);
+    EXPECT_TRUE(device.GetEthPackets().empty());
+}
+
+TEST_F(MeshMlmeTest, HandlePreq) {
+    EXPECT_EQ(JoinMesh(), wlan_mlme::StartResultCodes::SUCCESS);
 
     zx_status_t status = mlme.HandleFramePacket(test_utils::MakeWlanPacket({
         // clang-format off
@@ -170,12 +239,8 @@
     EXPECT_EQ(packet.data()[26], 131);  // prep element
 }
 
-TEST(MeshMlme, DeliverDuplicateData) {
-    MockDevice device;
-    device.state->set_address(common::MacAddr("11:11:11:11:11:11"));
-    MeshMlme mlme(&device);
-    mlme.Init();
-    JoinMesh(&mlme);
+TEST_F(MeshMlmeTest, DeliverDuplicateData) {
+    EXPECT_EQ(JoinMesh(), wlan_mlme::StartResultCodes::SUCCESS);
 
     auto mesh_packet = [](uint8_t addr, uint8_t seq, uint8_t data) {
         // clang-format off
@@ -272,12 +337,8 @@
     }
 }
 
-TEST(MeshMlme, DataForwarding) {
-    MockDevice device;
-    device.state->set_address(common::MacAddr("11:11:11:11:11:11"));
-    MeshMlme mlme(&device);
-    mlme.Init();
-    JoinMesh(&mlme);
+TEST_F(MeshMlmeTest, DataForwarding) {
+    EXPECT_EQ(JoinMesh(), wlan_mlme::StartResultCodes::SUCCESS);
 
     // Receive a PREP to establish a path to 33:33:33:33:33:33 via 22:22:22:22:22:22
     zx_status_t status = mlme.HandleFramePacket(test_utils::MakeWlanPacket({