[wlan][sme] Add timeouts for establishing RSNA

In the middle of establishing RSNA, if a key frame has been sent out and
no response is received from the authenticator for some time, retry
until max attempt. In addition, add a timeout for the overall handshake.

WLAN-627 #done
WLAN-920 #done

Test:
- Add logging and comment out the lines where we send EAPOL request to
  MLME, verify timeouts work as expected.
- fx run-test wlan-sme-tests

Change-Id: I2a2f1fd56c6be21d345eb2d1c299209d7d7e8f65
diff --git a/bin/wlan/wlanstack/src/station/client.rs b/bin/wlan/wlanstack/src/station/client.rs
index b8b865f..d9b2c25 100644
--- a/bin/wlan/wlanstack/src/station/client.rs
+++ b/bin/wlan/wlanstack/src/station/client.rs
@@ -6,7 +6,7 @@
 use fidl::{endpoints::RequestStream, endpoints::ServerEnd};
 use fidl_fuchsia_wlan_mlme::{self as fidl_mlme, MlmeEventStream, MlmeProxy};
 use fidl_fuchsia_wlan_sme::{self as fidl_sme, ClientSmeRequest};
-use futures::{Poll, prelude::*, select, stream::{self, FuturesUnordered}};
+use futures::{prelude::*, select, stream::FuturesUnordered};
 use futures::channel::mpsc;
 use log::{error, info};
 use pin_utils::pin_mut;
@@ -16,7 +16,6 @@
 use wlan_sme::client::{BssInfo, ConnectionAttemptId, ConnectResult,
                        ConnectPhyParams, DiscoveryError,
                        EssDiscoveryResult, EssInfo, InfoEvent, ScanTxnId};
-use wlan_sme::timer::TimeEntry;
 use fuchsia_zircon as zx;
 
 use fuchsia_cobalt::CobaltSender;
@@ -53,9 +52,8 @@
     -> Result<(), failure::Error>
     where S: Stream<Item = StatsRequest> + Unpin
 {
-    let (sme, mlme_stream, user_stream, info_stream) = Sme::new(device_info);
+    let (sme, mlme_stream, user_stream, info_stream, time_stream) = Sme::new(device_info);
     let sme = Arc::new(Mutex::new(sme));
-    let time_stream = stream::poll_fn::<TimeEntry<()>, _>(|_| Poll::Pending);
     let mlme_sme = super::serve_mlme_sme(
         proxy, event_stream, Arc::clone(&sme), mlme_stream, stats_requests, time_stream);
     let sme_fidl = serve_fidl(sme, new_fidl_clients, user_stream, info_stream, cobalt_sender);
diff --git a/bin/wlan/wlanstack/src/telemetry.rs b/bin/wlan/wlanstack/src/telemetry.rs
index 110ba38..6418abb 100644
--- a/bin/wlan/wlanstack/src/telemetry.rs
+++ b/bin/wlan/wlanstack/src/telemetry.rs
@@ -380,6 +380,7 @@
     ScanNotSupportedId = 4000,
     ScanInvalidArgsId = 4001,
     ScanInternalErrorId = 4002,
+    RsnaTimeout = 5000,
 }
 
 fn convert_connect_failure(result: &ConnectFailure) -> Option<ConnectionResultLabel> {
@@ -420,6 +421,7 @@
             RejectedEmergencyServicesNotSupported => AssocRejectedEmergencyServicesNotSupportedId,
             RefusedTemporarily => AssocRefusedTemporarilyId,
         },
+        ConnectFailure::RsnaTimeout => RsnaTimeout,
     };
 
     Some(result)
diff --git a/lib/rust/wlan-sme/src/client/event.rs b/lib/rust/wlan-sme/src/client/event.rs
new file mode 100644
index 0000000..13eea39
--- /dev/null
+++ b/lib/rust/wlan-sme/src/client/event.rs
@@ -0,0 +1,33 @@
+// Copyright 2018 The Fuchsia Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+use eapol;
+use fuchsia_zircon::{self as zx, prelude::DurationNum};
+
+use crate::MacAddr;
+use crate::timer::TimeoutDuration;
+
+pub const ESTABLISHING_RSNA_TIMEOUT_SECONDS: i64 = 3;
+pub const KEY_FRAME_EXCHANGE_TIMEOUT_MILLIS: i64 = 200;
+pub const KEY_FRAME_EXCHANGE_MAX_ATTEMPTS: u32 = 3;
+
+#[derive(Debug, Clone)]
+pub enum Event {
+    EstablishingRsnaTimeout,
+    KeyFrameExchangeTimeout {
+        bssid: MacAddr,
+        sta_addr: MacAddr,
+        frame: eapol::KeyFrame,
+        attempt: u32
+    },
+}
+
+impl TimeoutDuration for Event {
+    fn timeout_duration(&self) -> zx::Duration {
+        match self {
+            Event::EstablishingRsnaTimeout => ESTABLISHING_RSNA_TIMEOUT_SECONDS.seconds(),
+            Event::KeyFrameExchangeTimeout { .. } => KEY_FRAME_EXCHANGE_TIMEOUT_MILLIS.millis(),
+        }
+    }
+}
\ No newline at end of file
diff --git a/lib/rust/wlan-sme/src/client/mod.rs b/lib/rust/wlan-sme/src/client/mod.rs
index eb6a6dc..76b354d 100644
--- a/lib/rust/wlan-sme/src/client/mod.rs
+++ b/lib/rust/wlan-sme/src/client/mod.rs
@@ -3,6 +3,7 @@
 // found in the LICENSE file.
 
 mod bss;
+mod event;
 mod rsn;
 mod scan;
 mod state;
@@ -18,14 +19,15 @@
 
 use super::{DeviceInfo, InfoStream, MlmeRequest, MlmeStream, Ssid};
 
+use self::bss::{get_best_bss, get_channel_map, get_standard_map, group_networks};
+use self::event::Event;
 use self::scan::{DiscoveryScan, JoinScan, JoinScanFailure, ScanResult, ScanScheduler};
 use self::rsn::get_rsna;
 use self::state::{ConnectCommand, State};
 
-use crate::client::bss::{get_best_bss, get_channel_map, get_standard_map, group_networks};
 use crate::clone_utils::clone_bss_desc;
 use crate::sink::{InfoSink, MlmeSink};
-use crate::timer::TimedEvent;
+use crate::timer::{self, TimedEvent};
 
 pub use self::bss::{BssInfo, EssInfo};
 pub use self::scan::{DiscoveryError};
@@ -46,8 +48,9 @@
     use std::sync::Arc;
 
     use crate::DeviceInfo;
-    use crate::client::{ConnectionAttemptId, Tokens};
+    use crate::client::{ConnectionAttemptId, event::Event, Tokens};
     use crate::sink::{InfoSink, MlmeSink, UnboundedSink};
+    use crate::timer::Timer;
 
     pub type UserSink<T> = UnboundedSink<super::UserEvent<T>>;
     pub struct Context<T: Tokens> {
@@ -55,6 +58,7 @@
         pub mlme_sink: MlmeSink,
         pub user_sink: UserSink<T>,
         pub info_sink: InfoSink,
+        pub(crate) timer: Timer<Event>,
         pub att_id: ConnectionAttemptId,
     }
 }
@@ -62,6 +66,7 @@
 use self::internal::*;
 
 pub type UserStream<T> = mpsc::UnboundedReceiver<UserEvent<T>>;
+pub type TimeStream = timer::TimeStream<Event>;
 
 #[derive(Debug, PartialEq)]
 pub struct ConnectPhyParams {
@@ -104,6 +109,7 @@
     JoinFailure(fidl_mlme::JoinResultCodes),
     AuthenticationFailure(fidl_mlme::AuthenticateResultCodes),
     AssociationFailure(fidl_mlme::AssociateResultCodes),
+    RsnaTimeout,
 }
 
 pub type EssDiscoveryResult = Result<Vec<EssInfo>, DiscoveryError>;
@@ -170,11 +176,12 @@
 }
 
 impl<T: Tokens> ClientSme<T> {
-    pub fn new(info: DeviceInfo) -> (Self, MlmeStream, UserStream<T>, InfoStream) {
+    pub fn new(info: DeviceInfo) -> (Self, MlmeStream, UserStream<T>, InfoStream, TimeStream) {
         let device_info = Arc::new(info);
         let (mlme_sink, mlme_stream) = mpsc::unbounded();
         let (user_sink, user_stream) = mpsc::unbounded();
         let (info_sink, info_stream) = mpsc::unbounded();
+        let (timer, time_stream) = timer::create_timer();
         (
             ClientSme {
                 state: Some(State::Idle),
@@ -184,12 +191,14 @@
                     user_sink: UserSink::new(user_sink),
                     info_sink: InfoSink::new(info_sink),
                     device_info,
+                    timer,
                     att_id: 0,
                 },
             },
             mlme_stream,
             user_stream,
             info_stream,
+            time_stream,
         )
     }
 
@@ -246,7 +255,7 @@
 }
 
 impl<T: Tokens> super::Station for ClientSme<T> {
-    type Event = ();
+    type Event = Event;
 
     fn on_mlme_event(&mut self, event: MlmeEvent) {
         self.state = self.state.take().map(|state| match event {
@@ -341,8 +350,13 @@
         });
     }
 
-    fn on_timeout(&mut self, _timed_event: TimedEvent<()>) {
-        unimplemented!();
+    fn on_timeout(&mut self, timed_event: TimedEvent<Event>) {
+        self.state = self.state.take().map(|state| match timed_event.event {
+            event @ Event::EstablishingRsnaTimeout
+            | event @ Event::KeyFrameExchangeTimeout { .. } => {
+                state.handle_timeout(timed_event.id, event, &mut self.context)
+            },
+        });
     }
 }
 
@@ -380,7 +394,7 @@
 
     #[test]
     fn status_connecting_to() {
-        let (mut sme, _mlme_stream, _user_stream, _info_stream) = create_sme();
+        let (mut sme, _mlme_stream, _user_stream, _info_stream, _time_stream) = create_sme();
         assert_eq!(Status{ connected_to: None, connecting_to: None },
                    sme.status());
 
@@ -432,7 +446,7 @@
 
     #[test]
     fn connecting_password_supplied_for_protected_network() {
-        let (mut sme, mut mlme_stream, _user_stream, _info_stream) = create_sme();
+        let (mut sme, mut mlme_stream, _user_stream, _info_stream, _time_stream) = create_sme();
         assert_eq!(Status{ connected_to: None, connecting_to: None },
                    sme.status());
 
@@ -479,7 +493,7 @@
 
     #[test]
     fn connecting_password_supplied_for_unprotected_network() {
-        let (mut sme, mut mlme_stream, mut user_stream, _info_stream) = create_sme();
+        let (mut sme, mut mlme_stream, mut user_stream, _info_stream, _time_stream) = create_sme();
         assert_eq!(Status{ connected_to: None, connecting_to: None },
                    sme.status());
 
@@ -537,7 +551,7 @@
 
     #[test]
     fn connecting_no_password_supplied_for_protected_network() {
-        let (mut sme, mut mlme_stream, mut user_stream, _info_stream) = create_sme();
+        let (mut sme, mut mlme_stream, mut user_stream, _info_stream, _time_stream) = create_sme();
         assert_eq!(Status{ connected_to: None, connecting_to: None },
                    sme.status());
 
@@ -593,7 +607,7 @@
 
     #[test]
     fn connecting_generates_info_events() {
-        let (mut sme, _mlme_stream, _user_stream, mut info_stream) = create_sme();
+        let (mut sme, _mlme_stream, _user_stream, mut info_stream, _time_stream) = create_sme();
 
         sme.on_connect_command(b"foo".to_vec(), vec![], 10,
                                ConnectPhyParams { phy: None, cbw: None });
@@ -622,7 +636,8 @@
         type ConnectToken = i32;
     }
 
-    fn create_sme() -> (ClientSme<FakeTokens>, MlmeStream, UserStream<FakeTokens>, InfoStream) {
+    fn create_sme() -> (ClientSme<FakeTokens>, MlmeStream, UserStream<FakeTokens>, InfoStream,
+                        TimeStream) {
         ClientSme::new(DeviceInfo {
             addr: CLIENT_ADDR,
             bands: vec![],
diff --git a/lib/rust/wlan-sme/src/client/state.rs b/lib/rust/wlan-sme/src/client/state.rs
index 5e0293d..1bdb7c8 100644
--- a/lib/rust/wlan-sme/src/client/state.rs
+++ b/lib/rust/wlan-sme/src/client/state.rs
@@ -8,21 +8,31 @@
 use wlan_rsn::rsna::{self, SecAssocUpdate, SecAssocStatus};
 
 use super::bss::convert_bss_description;
-use crate::phy_selection::{derive_phy_cbw};
 use super::{ConnectFailure, ConnectPhyParams, ConnectResult, InfoEvent, Status, Tokens};
 use super::rsn::Rsna;
 
 use crate::MlmeRequest;
-use crate::client::{Context, report_connect_finished};
+use crate::client::{Context, event::{self, Event}, report_connect_finished};
 use crate::clone_utils::clone_bss_desc;
+use crate::phy_selection::{derive_phy_cbw};
 use crate::sink::MlmeSink;
+use crate::timer::EventId;
 
 const DEFAULT_JOIN_FAILURE_TIMEOUT: u32 = 20; // beacon intervals
 const DEFAULT_AUTH_FAILURE_TIMEOUT: u32 = 20; // beacon intervals
 
 #[derive(Debug, PartialEq)]
 pub enum LinkState<T: Tokens> {
-    EstablishingRsna(Option<T::ConnectToken>, Rsna),
+    EstablishingRsna {
+        token: Option<T::ConnectToken>,
+        rsna: Rsna,
+        // Timeout for the total duration RSNA may take to complete.
+        rsna_timeout: Option<EventId>,
+        // Timeout waiting to receive a key frame from the Authenticator. This timeout is None at
+        // the beginning of the RSNA when no frame has been exchanged yet, or at the end of the
+        // RSNA when all the key frames have finished exchanging.
+        resp_timeout: Option<EventId>,
+    },
     LinkUp(Option<Rsna>)
 }
 
@@ -39,6 +49,9 @@
     Established,
     Failed(ConnectResult),
     Unchanged,
+    Progressed {
+        new_resp_timeout: Option<EventId>,
+    },
 }
 
 #[derive(Debug, PartialEq)]
@@ -122,10 +135,17 @@
                                     context.info_sink.send(
                                         InfoEvent::RsnaStarted { att_id: context.att_id });
 
+                                    let rsna_timeout = Some(context.timer.schedule(
+                                        Event::EstablishingRsnaTimeout));
                                     State::Associated {
                                         bss: cmd.bss,
                                         last_rssi: None,
-                                        link_state: LinkState::EstablishingRsna(cmd.token, rsna),
+                                        link_state: LinkState::EstablishingRsna {
+                                            token: cmd.token,
+                                            rsna,
+                                            rsna_timeout,
+                                            resp_timeout: None,
+                                        },
                                         params: cmd.params,
                                     }
                                 }
@@ -156,7 +176,7 @@
                 MlmeEvent::DisassociateInd{ .. } => {
                     let (token, mut rsna) = match link_state {
                         LinkState::LinkUp(rsna) => (None, rsna),
-                        LinkState::EstablishingRsna(token, rsna) => (token, Some(rsna)),
+                        LinkState::EstablishingRsna{ token, rsna, .. } => (token, Some(rsna)),
                     };
                     // Client is disassociating. The ESS-SA must be kept alive but reset.
                     if let Some(rsna) = &mut rsna {
@@ -173,7 +193,7 @@
                     to_associating_state(cmd, &context.mlme_sink)
                 },
                 MlmeEvent::DeauthenticateInd{ ind } => {
-                    if let LinkState::EstablishingRsna(token, _) = link_state {
+                    if let LinkState::EstablishingRsna{ token, .. } = link_state {
                         let connect_result = deauth_code_to_connect_result(ind.reason_code);
                         report_connect_finished(token, &context, connect_result, None);
                     }
@@ -188,8 +208,9 @@
                     }
                 },
                 MlmeEvent::EapolInd{ ref ind } if bss.rsn.is_some() => match link_state {
-                    LinkState::EstablishingRsna(token, mut rsna) => {
-                        match process_eapol_ind(&context.mlme_sink, &mut rsna, &ind) {
+                    LinkState::EstablishingRsna{ token, mut rsna, rsna_timeout,
+                                                 mut resp_timeout } => {
+                        match process_eapol_ind(context, &mut rsna, &ind) {
                             RsnaStatus::Established => {
                                 context.mlme_sink.send(MlmeRequest::SetCtrlPort(
                                     fidl_mlme::SetControlledPortRequest {
@@ -210,13 +231,23 @@
                                 State::Idle
                             },
                             RsnaStatus::Unchanged => {
-                                let link_state = LinkState::EstablishingRsna(token, rsna);
+                                let link_state = LinkState::EstablishingRsna {
+                                    token, rsna, rsna_timeout, resp_timeout };
                                 State::Associated { bss, last_rssi, link_state, params, }
                             },
+                            RsnaStatus::Progressed { new_resp_timeout } => {
+                                cancel(&mut resp_timeout);
+                                if let Some(id) = new_resp_timeout {
+                                    resp_timeout.replace(id);
+                                }
+                                let link_state = LinkState::EstablishingRsna {
+                                    token, rsna, rsna_timeout, resp_timeout};
+                                State::Associated { bss, last_rssi, link_state, params, }
+                            }
                         }
                     },
                     LinkState::LinkUp(Some(mut rsna)) => {
-                        match process_eapol_ind(&context.mlme_sink, &mut rsna, &ind) {
+                        match process_eapol_ind(context, &mut rsna, &ind) {
                             RsnaStatus::Unchanged => {},
                             // Once re-keying is supported, the RSNA can fail in LinkUp as well
                             // and cause deauthentication.
@@ -232,6 +263,58 @@
         }
     }
 
+    pub fn handle_timeout(self, event_id: EventId, event: Event, context: &mut Context<T>) -> Self {
+        match self {
+            State::Associated { bss, last_rssi, link_state, params } => match link_state {
+                LinkState::EstablishingRsna { token, rsna, mut rsna_timeout, mut resp_timeout } => {
+                    match event {
+                        Event::EstablishingRsnaTimeout if triggered(&rsna_timeout,
+                                                                    event_id) => {
+                            error!("timeout establishing RSNA; deauthenticating");
+                            cancel(&mut rsna_timeout);
+                            report_connect_finished(token, &context, ConnectResult::Failed,
+                                                    Some(ConnectFailure::RsnaTimeout));
+                            send_deauthenticate_request(bss, &context.mlme_sink);
+                            State::Idle
+                        },
+                        Event::KeyFrameExchangeTimeout { bssid, sta_addr, frame, attempt } => {
+                            if !triggered(&resp_timeout, event_id) {
+                                let link_state = LinkState::EstablishingRsna {
+                                    token, rsna, rsna_timeout, resp_timeout, };
+                                return State::Associated { bss, last_rssi, link_state, params }
+                            }
+
+                            if attempt < event::KEY_FRAME_EXCHANGE_MAX_ATTEMPTS {
+                                warn!("timeout waiting for key frame for attempt {}; retrying",
+                                      attempt);
+                                let id = send_eapol_frame(context, bssid, sta_addr, frame,
+                                                          attempt + 1);
+                                resp_timeout.replace(id);
+                                let link_state = LinkState::EstablishingRsna {
+                                    token, rsna, rsna_timeout, resp_timeout, };
+                                State::Associated { bss, last_rssi, link_state, params }
+                            } else {
+                                error!("timeout waiting for key frame for last attempt; deauth");
+                                cancel(&mut resp_timeout);
+                                report_connect_finished(token, &context, ConnectResult::Failed,
+                                                        Some(ConnectFailure::RsnaTimeout));
+                                send_deauthenticate_request(bss, &context.mlme_sink);
+                                State::Idle
+                            }
+                        },
+                        _ => {
+                            let link_state = LinkState::EstablishingRsna {
+                                token, rsna, rsna_timeout, resp_timeout };
+                            State::Associated { bss, last_rssi, link_state, params }
+                        },
+                    }
+                },
+                _ => State::Associated { bss, last_rssi, link_state, params },
+            },
+            _ => self,
+        }
+    }
+
     pub fn connect(self, cmd: ConnectCommand<T::ConnectToken>, context: &mut Context<T>) -> Self {
         self.disconnect_internal(context);
 
@@ -291,9 +374,11 @@
                     connecting_to: Some(cmd.bss.ssid.clone()),
                 }
             },
-            State::Associated { bss, link_state: LinkState::EstablishingRsna(..), .. } => Status {
-                connected_to: None,
-                connecting_to: Some(bss.ssid.clone()),
+            State::Associated { bss, link_state: LinkState::EstablishingRsna { .. }, .. } => {
+                Status {
+                    connected_to: None,
+                    connecting_to: Some(bss.ssid.clone()),
+                }
             },
             State::Associated { bss, link_state: LinkState::LinkUp(..), .. } => Status {
                 connected_to: Some(convert_bss_description(bss)),
@@ -303,6 +388,15 @@
     }
 }
 
+fn triggered(id: &Option<EventId>, received_id: EventId) -> bool {
+    id.map_or(false, |id| id == received_id)
+}
+
+fn cancel(event_id: &mut Option<EventId>) {
+    let _ = event_id.take();
+}
+
+
 fn deauth_code_to_connect_result(reason_code: fidl_mlme::ReasonCode) -> ConnectResult {
     match reason_code {
         fidl_mlme::ReasonCode::InvalidAuthentication
@@ -311,8 +405,9 @@
     }
 }
 
-fn process_eapol_ind(mlme_sink: &MlmeSink, rsna: &mut Rsna, ind: &fidl_mlme::EapolIndication)
-    -> RsnaStatus
+fn process_eapol_ind<T: Tokens>(context: &mut Context<T>, rsna: &mut Rsna,
+                                ind: &fidl_mlme::EapolIndication)
+                                -> RsnaStatus
 {
     let mic_size = rsna.negotiated_rsne.mic_size;
     let eapol_pdu = &ind.data[..];
@@ -325,24 +420,29 @@
     };
 
     let mut update_sink = rsna::UpdateSink::default();
-    if let Err(e) = rsna.supplicant.on_eapol_frame(&mut update_sink, &eapol_frame) {
-        error!("error processing EAPOL key frame: {}", e);
-        return RsnaStatus::Unchanged;
+    match rsna.supplicant.on_eapol_frame(&mut update_sink, &eapol_frame) {
+        Err(e) => {
+            error!("error processing EAPOL key frame: {}", e);
+            return RsnaStatus::Unchanged;
+        }
+        Ok(_) if update_sink.is_empty() => return RsnaStatus::Unchanged,
+        _ => (),
     }
 
     let bssid = ind.src_addr;
     let sta_addr = ind.dst_addr;
+    let mut new_resp_timeout = None;
     for update in update_sink {
         match update {
             // ESS Security Association requests to send an EAPOL frame.
             // Forward EAPOL frame to MLME.
             SecAssocUpdate::TxEapolKeyFrame(frame) => {
-                send_eapol_frame(mlme_sink, bssid, sta_addr, frame)
+                new_resp_timeout.replace(send_eapol_frame(context, bssid, sta_addr, frame, 1));
             },
             // ESS Security Association derived a new key.
             // Configure key in MLME.
             SecAssocUpdate::Key(key) => {
-                send_keys(mlme_sink, bssid, key)
+                send_keys(&context.mlme_sink, bssid, key)
             },
             // Received a status update.
             // TODO(hahnr): Rework this part.
@@ -363,20 +463,30 @@
         }
     }
 
-    RsnaStatus::Unchanged
+    RsnaStatus::Progressed { new_resp_timeout }
 }
 
-fn send_eapol_frame(mlme_sink: &MlmeSink, bssid: [u8; 6], sta_addr: [u8; 6], frame: eapol::KeyFrame)
+fn send_eapol_frame<T: Tokens>(context: &mut Context<T>, bssid: [u8; 6], sta_addr: [u8; 6],
+                               frame: eapol::KeyFrame, attempt: u32)
+                               -> EventId
 {
+    let resp_timeout_id = context.timer.schedule(Event::KeyFrameExchangeTimeout {
+        bssid,
+        sta_addr,
+        frame: frame.clone(),
+        attempt
+    });
+
     let mut buf = Vec::with_capacity(frame.len());
     frame.as_bytes(false, &mut buf);
-    mlme_sink.send(MlmeRequest::Eapol(
+    context.mlme_sink.send(MlmeRequest::Eapol(
         fidl_mlme::EapolRequest {
             src_addr: sta_addr,
             dst_addr: bssid,
             data: buf,
         }
     ));
+    resp_timeout_id
 }
 
 fn send_keys(mlme_sink: &MlmeSink, bssid: [u8; 6], key: Key)
@@ -477,8 +587,8 @@
         MockSupplicantController,
 
     };
-    use crate::client::{InfoSink, UserEvent, UserStream, UserSink};
-    use crate::{DeviceInfo, InfoStream, MlmeStream, Ssid, test_utils};
+    use crate::client::{InfoSink, TimeStream, UserEvent, UserStream, UserSink};
+    use crate::{DeviceInfo, InfoStream, MlmeStream, Ssid, test_utils, timer};
 
     #[derive(Debug, PartialEq)]
     struct FakeTokens;
@@ -671,12 +781,8 @@
         let state = State::Associating::<FakeTokens> { cmd: connect_command_one() };
 
         // (mlme->sme) Send an unsuccessful AssociateConf
-        let assoc_conf = MlmeEvent::AssociateConf {
-            resp: fidl_mlme::AssociateConfirm {
-                result_code: fidl_mlme::AssociateResultCodes::RefusedReasonUnspecified,
-                association_id: 0,
-            }
-        };
+        let assoc_conf = create_assoc_conf(
+            fidl_mlme::AssociateResultCodes::RefusedReasonUnspecified);
         let state = state.on_mlme_event(assoc_conf, &mut h.context);
         assert_eq!(idle_state(), state);
 
@@ -779,7 +885,7 @@
 
         match state {
             State::Associated { link_state, .. } => match link_state {
-                LinkState::EstablishingRsna(..) => (), // expected path
+                LinkState::EstablishingRsna { .. } => (), // expected path
                 _ => panic!("expect link state to still be establishing RSNA"),
             },
             _ => panic!("expect state to still be associated"),
@@ -806,7 +912,7 @@
 
         match state {
             State::Associated { link_state, .. } => match link_state {
-                LinkState::EstablishingRsna(..) => (), // expected path
+                LinkState::EstablishingRsna { .. } => (), // expected path
                 _ => panic!("expect link state to still be establishing RSNA"),
             },
             _ => panic!("expect state to still be associated"),
@@ -842,6 +948,64 @@
     }
 
     #[test]
+    fn overall_timeout_while_establishing_rsna() {
+        let mut h = TestHelper::new();
+        let (supplicant, _suppl_mock) = mock_supplicant();
+        let command = connect_command_rsna(supplicant);
+        let bssid = command.bss.bssid.clone();
+        let token = command.token.unwrap();
+
+        // Start in an "Associating" state
+        let state = State::Associating::<FakeTokens> { cmd: command };
+        let assoc_conf = create_assoc_conf(fidl_mlme::AssociateResultCodes::Success);
+        let state = state.on_mlme_event(assoc_conf, &mut h.context);
+
+        let (_, timed_event) = h.time_stream.try_next().unwrap().expect("expect timed event");
+        match timed_event.event {
+            Event::EstablishingRsnaTimeout => (), // expected path
+            _ => panic!("expect EstablishingRsnaTimeout timeout event"),
+        }
+
+        expect_stream_empty(&mut h.mlme_stream, "unexpected event in mlme stream");
+
+        let _state = state.handle_timeout(timed_event.id, timed_event.event, &mut h.context);
+
+        expect_deauth_req(&mut h.mlme_stream, bssid, fidl_mlme::ReasonCode::StaLeaving);
+        expect_connect_result(&mut h.user_stream, token, ConnectResult::Failed);
+    }
+
+    #[test]
+    fn key_frame_exchange_timeout_while_establishing_rsna() {
+        let mut h = TestHelper::new();
+        let (supplicant, suppl_mock) = mock_supplicant();
+        let command = connect_command_rsna(supplicant);
+        let bssid = command.bss.bssid.clone();
+        let token = command.token.unwrap();
+        let state = establishing_rsna_state(command);
+
+        // (mlme->sme) Send an EapolInd, mock supplication with key frame
+        let update = SecAssocUpdate::TxEapolKeyFrame(test_utils::eapol_key_frame());
+        let mut state = on_eapol_ind(state, &mut h, bssid, &suppl_mock, vec![update]);
+
+        for i in 1..=3 {
+            println!("send eapol attempt: {}", i);
+            expect_eapol_req(&mut h.mlme_stream, bssid);
+            expect_stream_empty(&mut h.mlme_stream, "unexpected event in mlme stream");
+
+            let (_, timed_event) = h.time_stream.try_next().unwrap().expect("expect timed event");
+            match timed_event.event {
+                Event::KeyFrameExchangeTimeout { attempt, .. } => assert_eq!(attempt, i),
+                _ => panic!("expect EstablishingRsnaTimeout timeout event"),
+            }
+            state = state.handle_timeout(timed_event.id, timed_event.event, &mut h.context);
+        }
+
+        expect_deauth_req(&mut h.mlme_stream, bssid, fidl_mlme::ReasonCode::StaLeaving);
+        expect_connect_result(&mut h.user_stream, token, ConnectResult::Failed);
+    }
+
+
+    #[test]
     fn connect_while_link_up() {
         let mut h = TestHelper::new();
         let state = link_up_state(connect_command_one().bss);
@@ -942,6 +1106,7 @@
         mlme_stream: MlmeStream,
         user_stream: UserStream<FakeTokens>,
         info_stream: InfoStream,
+        time_stream: TimeStream,
         context: Context<FakeTokens>,
     }
 
@@ -950,14 +1115,16 @@
             let (mlme_sink, mlme_stream) = mpsc::unbounded();
             let (user_sink, user_stream) = mpsc::unbounded();
             let (info_sink, info_stream) = mpsc::unbounded();
+            let (timer, time_stream) = timer::create_timer();
             let context = Context {
                 device_info: Arc::new(fake_device_info()),
                 mlme_sink: MlmeSink::new(mlme_sink),
                 user_sink: UserSink::new(user_sink),
                 info_sink: InfoSink::new(info_sink),
+                timer,
                 att_id: 0,
             };
-            TestHelper { mlme_stream, user_stream, info_stream, context }
+            TestHelper { mlme_stream, user_stream, info_stream, time_stream, context }
         }
     }
 
@@ -1196,7 +1363,12 @@
         State::Associated {
             bss: cmd.bss,
             last_rssi: None,
-            link_state: LinkState::EstablishingRsna(cmd.token, rsna),
+            link_state: LinkState::EstablishingRsna {
+                token: cmd.token,
+                rsna,
+                rsna_timeout: None,
+                resp_timeout: None,
+            },
             params: ConnectPhyParams { phy: None, cbw: None },
         }
     }