[wlan][mesh] Join/Leave state machine for wlan-sme

WLAN-942

TEST=Filed WLAN-944 to track unit tests. Manually verified that
joining/leaving the mesh repeatedly works.

Change-Id: I1deb9413911b87f6d1e0724e5654ddc00a88598e
diff --git a/bin/wlan/wlanstack/src/station/mesh.rs b/bin/wlan/wlanstack/src/station/mesh.rs
index 643fa4c..96cece2 100644
--- a/bin/wlan/wlanstack/src/station/mesh.rs
+++ b/bin/wlan/wlanstack/src/station/mesh.rs
@@ -34,6 +34,7 @@
 
 impl mesh_sme::Tokens for Tokens {
     type JoinToken = oneshot::Sender<mesh_sme::JoinMeshResult>;
+    type LeaveToken = oneshot::Sender<mesh_sme::LeaveMeshResult>;
 }
 
 pub type Endpoint = fidl::endpoints::ServerEnd<fidl_sme::MeshSmeMarker>;
@@ -89,6 +90,7 @@
 fn handle_user_event(e: UserEvent<Tokens>) {
     match e {
         UserEvent::JoinMeshFinished { token, result } => token.send(result).unwrap_or_else(|_| ()),
+        UserEvent::LeaveMeshFinished { token, result } => token.send(result).unwrap_or_else(|_| ()),
     }
 }
 
@@ -117,6 +119,10 @@
             let code = await!(join_mesh(sme, config));
             responder.send(code)
         },
+        fidl_sme::MeshSmeRequest::Leave { responder } => {
+            let code = await!(leave_mesh(sme));
+            responder.send(code)
+        },
     }
 }
 
@@ -138,8 +144,26 @@
 fn convert_join_mesh_result(result: mesh_sme::JoinMeshResult) -> fidl_sme::JoinMeshResultCode {
     match result {
         mesh_sme::JoinMeshResult::Success => fidl_sme::JoinMeshResultCode::Success,
+        mesh_sme::JoinMeshResult::Canceled => fidl_sme::JoinMeshResultCode::Canceled,
         mesh_sme::JoinMeshResult::InternalError => fidl_sme::JoinMeshResultCode::InternalError,
         mesh_sme::JoinMeshResult::InvalidArguments => fidl_sme::JoinMeshResultCode::InvalidArguments,
         mesh_sme::JoinMeshResult::DfsUnsupported => fidl_sme::JoinMeshResultCode::DfsUnsupported,
     }
 }
+
+async fn leave_mesh(sme: Arc<Mutex<Sme>>) -> fidl_sme::LeaveMeshResultCode {
+    let (sender, receiver) = oneshot::channel();
+    sme.lock().unwrap().on_leave_command(sender);
+    let r = await!(receiver).unwrap_or_else(|_| {
+        error!("Responder for Leave Mesh command was dropped without sending a response");
+        mesh_sme::LeaveMeshResult::InternalError
+    });
+    convert_leave_mesh_result(r)
+}
+
+fn convert_leave_mesh_result(result: mesh_sme::LeaveMeshResult) -> fidl_sme::LeaveMeshResultCode {
+    match result {
+        mesh_sme::LeaveMeshResult::Success => fidl_sme::LeaveMeshResultCode::Success,
+        mesh_sme::LeaveMeshResult::InternalError => fidl_sme::LeaveMeshResultCode::InternalError,
+    }
+}
diff --git a/bin/wlan/wlantool/src/main.rs b/bin/wlan/wlantool/src/main.rs
index b9108ae..603839e 100644
--- a/bin/wlan/wlantool/src/main.rs
+++ b/bin/wlan/wlantool/src/main.rs
@@ -250,6 +250,11 @@
                 _ => { println!("{:?}", r ); },
             }
         },
+        opts::MeshCmd::Leave { iface_id } => {
+            let sme = await!(get_mesh_sme(wlan_svc, iface_id))?;
+            let r = await!(sme.leave());
+            println!("{:?}", r);
+        }
     }
     Ok(())
 }
diff --git a/bin/wlan/wlantool/src/opts.rs b/bin/wlan/wlantool/src/opts.rs
index 79ae07a..d935d23 100644
--- a/bin/wlan/wlantool/src/opts.rs
+++ b/bin/wlan/wlantool/src/opts.rs
@@ -252,4 +252,9 @@
         // TODO(porce): Expand to support PHY and CBW
         channel: u8,
     },
+    #[structopt(name = "leave")]
+    Leave {
+        #[structopt(raw(required = "true"))]
+        iface_id: u16,
+    }
 }
diff --git a/lib/rust/wlan-sme/src/mesh/mod.rs b/lib/rust/wlan-sme/src/mesh/mod.rs
index 72a5599..0ac253b 100644
--- a/lib/rust/wlan-sme/src/mesh/mod.rs
+++ b/lib/rust/wlan-sme/src/mesh/mod.rs
@@ -7,6 +7,7 @@
     fidl_fuchsia_wlan_mlme::{self as fidl_mlme, MlmeEvent},
     futures::channel::mpsc,
     log::{error},
+    std::mem,
     wlan_common::channel::{Channel, Cbw},
     crate::{
         clone_utils,
@@ -27,6 +28,7 @@
 // trait that enables us to group them into a single generic parameter.
 pub trait Tokens {
     type JoinToken;
+    type LeaveToken;
 }
 
 mod internal {
@@ -36,14 +38,56 @@
 
 pub type UserStream<T> = mpsc::UnboundedReceiver<UserEvent<T>>;
 
+// A list of pending join/leave requests to be maintained in the intermediate
+// 'Joining'/'Leaving' states where we are waiting for a reply from MLME and cannot
+// serve the requests immediately.
+struct PendingRequests<T: Tokens> {
+    leave: Vec<T::LeaveToken>,
+    join: Option<(T::JoinToken, Config)>,
+}
+
+impl<T: Tokens> PendingRequests<T> {
+    pub fn new() -> Self {
+        PendingRequests { leave: Vec::new(), join: None }
+    }
+
+    pub fn enqueue_leave(&mut self, user_sink: &UserSink<T>, token: T::LeaveToken) {
+        self.replace_join_request(user_sink, None);
+        self.leave.push(token);
+    }
+
+    pub fn enqueue_join(&mut self, user_sink: &UserSink<T>, token: T::JoinToken, config: Config) {
+        self.replace_join_request(user_sink, Some((token, config)));
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.leave.is_empty() && self.join.is_none()
+    }
+
+    fn replace_join_request(
+        &mut self,
+        user_sink: &UserSink<T>,
+        req: Option<(T::JoinToken, Config)>)
+    {
+        if let Some((old_token, _)) = mem::replace(&mut self.join, req) {
+            report_join_finished(user_sink, old_token, JoinMeshResult::Canceled);
+        }
+    }
+}
+
 enum State<T: Tokens> {
     Idle,
     Joining {
         token: T::JoinToken,
         config: Config,
+        pending: PendingRequests<T>,
     },
     Joined {
         config: Config,
+    },
+    Leaving {
+        config: Config,
+        pending: PendingRequests<T>,
     }
 }
 
@@ -65,11 +109,18 @@
 #[derive(Clone, Copy, Debug, PartialEq)]
 pub enum JoinMeshResult {
     Success,
+    Canceled,
     InternalError,
     InvalidArguments,
     DfsUnsupported,
 }
 
+#[derive(Clone, Copy, Debug)]
+pub enum LeaveMeshResult {
+    Success,
+    InternalError,
+}
+
 // A message from the Mesh node to a user or a group of listeners
 #[derive(Debug)]
 pub enum UserEvent<T: Tokens> {
@@ -77,28 +128,78 @@
         token: T::JoinToken,
         result: JoinMeshResult,
     },
+    LeaveMeshFinished {
+        token: T::LeaveToken,
+        result: LeaveMeshResult,
+    }
 }
 
 impl<T: Tokens> MeshSme<T> {
-    pub fn on_join_command(&mut self, token: T::JoinToken, config: Config){
+    pub fn on_join_command(&mut self, token: T::JoinToken, config: Config) {
+        if let Err(result) = validate_config(&config) {
+            report_join_finished(&self.user_sink, token, result);
+            return;
+        }
         self.state = Some(match self.state.take().unwrap() {
             State::Idle => {
-                if let Err(result) = validate_config(&config) {
-                    report_join_finished(&self.user_sink, token, result);
-                    State::Idle
-                } else {
-                    self.mlme_sink.send(MlmeRequest::Start(create_start_request(&config)));
-                    State::Joining { token, config }
-                }
+                self.mlme_sink.send(MlmeRequest::Start(create_start_request(&config)));
+                State::Joining { token, pending: PendingRequests::new(), config }
             },
-            s@ State::Joining { .. } | s@ State::Joined { .. } => {
-                // TODO(gbonik): Leave then re-join
-                error!("cannot join mesh because already joined or joining");
-                report_join_finished(&self.user_sink, token, JoinMeshResult::InternalError);
-                s
+            State::Joining { token: cur_token, config: cur_config, mut pending } => {
+                pending.enqueue_join(&self.user_sink, token, config);
+                State::Joining { token: cur_token, config: cur_config, pending }
+            },
+            State::Joined { config: cur_config } => {
+                self.mlme_sink.send(MlmeRequest::Stop(create_stop_request()));
+                let mut pending = PendingRequests::new();
+                pending.enqueue_join(&self.user_sink, token, config);
+                State::Leaving { config: cur_config, pending }
+            },
+            State::Leaving { config: cur_config, mut pending } => {
+                pending.enqueue_join(&self.user_sink, token, config);
+                State::Leaving { config: cur_config, pending }
             }
         });
     }
+
+    pub fn on_leave_command(&mut self, token: T::LeaveToken) {
+        self.state = Some(match self.state.take().unwrap() {
+            State::Idle => {
+                report_leave_finished(&self.user_sink, token, LeaveMeshResult::Success);
+                State::Idle
+            },
+            State::Joining { token: cur_token, config, mut pending } => {
+                pending.enqueue_leave(&self.user_sink, token);
+                State::Joining { token: cur_token, pending, config }
+            },
+            State::Joined { config } => {
+                self.mlme_sink.send(MlmeRequest::Stop(create_stop_request()));
+                let mut pending = PendingRequests::new();
+                pending.enqueue_leave(&self.user_sink, token);
+                State::Leaving { config, pending }
+            },
+            State::Leaving { config, mut pending } => {
+                pending.enqueue_leave(&self.user_sink, token);
+                State::Leaving { config, pending }
+            }
+        });
+    }
+}
+
+fn on_back_to_idle<T: Tokens>(
+    pending: PendingRequests<T>,
+    user_sink: &UserSink<T>,
+    mlme_sink: &MlmeSink
+) -> State<T> {
+    for token in pending.leave {
+        report_leave_finished(user_sink, token, LeaveMeshResult::Success);
+    }
+    if let Some((token, config)) = pending.join {
+        mlme_sink.send(MlmeRequest::Start(create_start_request(&config)));
+        State::Joining { token, config, pending: PendingRequests::new() }
+    } else {
+        State::Idle
+    }
 }
 
 fn validate_config(config: &Config) -> Result<(), JoinMeshResult> {
@@ -131,25 +232,37 @@
     }
 }
 
+fn create_stop_request() -> fidl_mlme::StopRequest {
+    fidl_mlme::StopRequest { ssid: vec![], }
+}
+
 impl<T: Tokens> super::Station for MeshSme<T> {
     type Event = ();
 
     fn on_mlme_event(&mut self, event: MlmeEvent) {
         self.state = Some(match self.state.take().unwrap() {
             State::Idle => State::Idle,
-            State::Joining { token, config } => match event {
+            State::Joining { token, pending, config } => match event {
                 MlmeEvent::StartConf { resp } => match resp.result_code {
                     fidl_mlme::StartResultCodes::Success => {
                         report_join_finished(&self.user_sink, token, JoinMeshResult::Success);
-                        State::Joined { config }
+                        if pending.is_empty() {
+                            State::Joined { config }
+                        } else {
+                            // If there are any pending join/leave commands that arrived while we
+                            // were waiting for 'Start' to complete, then start leaving immediately,
+                            // and then process the pending commands once the 'Stop' call completes.
+                            self.mlme_sink.send(MlmeRequest::Stop(create_stop_request()));
+                            State::Leaving { config, pending }
+                        }
                     },
                     other => {
                         error!("failed to join mesh: {:?}", other);
                         report_join_finished(&self.user_sink, token, JoinMeshResult::InternalError);
-                        State::Idle
+                        on_back_to_idle(pending, &self.user_sink, &self.mlme_sink)
                     }
                 },
-                _ => State::Joining { token, config },
+                _ => State::Joining { token, pending, config },
             },
             State::Joined { config } => match event {
                 MlmeEvent::IncomingMpOpenAction { action } => {
@@ -187,6 +300,25 @@
                 },
                 _ => State::Joined { config },
             },
+            State::Leaving { config, pending } => match event {
+                MlmeEvent::StopConf { resp } => match resp.result_code {
+                    fidl_mlme::StopResultCodes::Success =>
+                        on_back_to_idle(pending, &self.user_sink, &self.mlme_sink),
+                    other => {
+                        error!("failed to leave mesh: {:?}", other);
+                        for token in pending.leave {
+                            report_leave_finished(
+                                    &self.user_sink, token, LeaveMeshResult::InternalError);
+                        }
+                        if let Some((token, _)) = pending.join {
+                            report_join_finished(
+                                    &self.user_sink, token, JoinMeshResult::InternalError);
+                        }
+                        State::Joined { config }
+                    }
+                },
+                _ => State::Leaving { config, pending }
+            }
         });
     }
 
@@ -223,6 +355,12 @@
     user_sink.send(UserEvent::JoinMeshFinished { token, result });
 }
 
+fn report_leave_finished<T: Tokens>(user_sink: &UserSink<T>, token: T::LeaveToken,
+                                    result: LeaveMeshResult)
+{
+    user_sink.send(UserEvent::LeaveMeshFinished { token, result });
+}
+
 impl<T: Tokens> MeshSme<T> {
     pub fn new(device_info: DeviceInfo) -> (Self, crate::MlmeStream, UserStream<T>) {
         let (mlme_sink, mlme_stream) = mpsc::unbounded();
diff --git a/lib/wlan/fidl/sme.fidl b/lib/wlan/fidl/sme.fidl
index cef530f..ec3c043 100644
--- a/lib/wlan/fidl/sme.fidl
+++ b/lib/wlan/fidl/sme.fidl
@@ -122,6 +122,12 @@
     DFS_UNSUPPORTED = 4;
 };
 
+enum LeaveMeshResultCode {
+    SUCCESS = 0;
+    INTERNAL_ERROR = 1;
+};
+
 interface MeshSme {
     Join(MeshConfig config) -> (JoinMeshResultCode code);
+    Leave() -> (LeaveMeshResultCode code);
 };