blob: 2a3e1d848ead9a144b3037b8f63ca5854e80eef4 [file] [log] [blame]
// Copyright 2018 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
mod bss;
mod event;
mod rsn;
mod scan;
mod state;
#[cfg(test)]
pub mod test_utils;
use fidl_fuchsia_wlan_common as fidl_common;
use fidl_fuchsia_wlan_mlme::{self as fidl_mlme, MlmeEvent, ScanRequest};
use fidl_fuchsia_wlan_sme as fidl_sme;
use futures::channel::{mpsc, oneshot};
use log::error;
use std::collections::HashMap;
use std::sync::Arc;
use wlan_common::RadioConfig;
use super::{DeviceInfo, InfoStream, MlmeRequest, MlmeStream, Ssid};
use self::bss::{get_best_bss, get_channel_map, get_standard_map, group_networks};
use self::event::Event;
use self::rsn::get_rsna;
use self::scan::{DiscoveryScan, JoinScan, JoinScanFailure, ScanResult, ScanScheduler};
use self::state::{ConnectCommand, State};
use crate::clone_utils::clone_bss_desc;
use crate::responder::Responder;
use crate::sink::{InfoSink, MlmeSink};
use crate::timer::{self, TimedEvent};
pub use self::bss::{BssInfo, EssInfo};
pub use self::scan::DiscoveryError;
// This is necessary to trick the private-in-public checker.
// A private module is not allowed to include private types in its interface,
// even though the module itself is private and will never be exported.
// As a workaround, we add another private module with public types.
mod internal {
use std::sync::Arc;
use crate::client::{event::Event, ConnectionAttemptId};
use crate::sink::{InfoSink, MlmeSink};
use crate::timer::Timer;
use crate::DeviceInfo;
pub struct Context {
pub device_info: Arc<DeviceInfo>,
pub mlme_sink: MlmeSink,
pub info_sink: InfoSink,
pub(crate) timer: Timer<Event>,
pub att_id: ConnectionAttemptId,
}
}
use self::internal::*;
pub type TimeStream = timer::TimeStream<Event>;
pub struct ConnectConfig {
responder: Responder<ConnectResult>,
password: Vec<u8>,
radio_cfg: RadioConfig,
}
// An automatically increasing sequence number that uniquely identifies a logical
// connection attempt. For example, a new connection attempt can be triggered
// by a DisassociateInd message from the MLME.
pub type ConnectionAttemptId = u64;
pub type ScanTxnId = u64;
pub struct ClientSme {
state: Option<State>,
scan_sched: ScanScheduler<Responder<EssDiscoveryResult>, ConnectConfig>,
context: Context,
}
#[derive(Clone, Debug, PartialEq)]
pub enum ConnectResult {
Success,
Canceled,
Failed,
BadCredentials,
}
#[derive(Debug, PartialEq)]
pub enum ConnectFailure {
NoMatchingBssFound,
ScanFailure(fidl_mlme::ScanResultCodes),
JoinFailure(fidl_mlme::JoinResultCodes),
AuthenticationFailure(fidl_mlme::AuthenticateResultCodes),
AssociationFailure(fidl_mlme::AssociateResultCodes),
RsnaTimeout,
}
pub type EssDiscoveryResult = Result<Vec<EssInfo>, DiscoveryError>;
#[derive(Debug, PartialEq)]
pub enum InfoEvent {
ConnectStarted,
ConnectFinished {
result: ConnectResult,
failure: Option<ConnectFailure>,
},
MlmeScanStart {
txn_id: ScanTxnId,
},
MlmeScanEnd {
txn_id: ScanTxnId,
},
ScanDiscoveryFinished {
bss_count: usize,
ess_count: usize,
num_bss_by_standard: HashMap<Standard, usize>,
num_bss_by_channel: HashMap<u8, usize>,
},
AssociationStarted {
att_id: ConnectionAttemptId,
},
AssociationSuccess {
att_id: ConnectionAttemptId,
},
RsnaStarted {
att_id: ConnectionAttemptId,
},
RsnaEstablished {
att_id: ConnectionAttemptId,
},
}
#[derive(Clone, Debug, PartialEq)]
pub struct Status {
pub connected_to: Option<BssInfo>,
pub connecting_to: Option<Ssid>,
}
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum Standard {
B,
G,
A,
N,
Ac,
}
impl ClientSme {
pub fn new(info: DeviceInfo) -> (Self, MlmeStream, InfoStream, TimeStream) {
let device_info = Arc::new(info);
let (mlme_sink, mlme_stream) = mpsc::unbounded();
let (info_sink, info_stream) = mpsc::unbounded();
let (timer, time_stream) = timer::create_timer();
(
ClientSme {
state: Some(State::Idle),
scan_sched: ScanScheduler::new(Arc::clone(&device_info)),
context: Context {
mlme_sink: MlmeSink::new(mlme_sink),
info_sink: InfoSink::new(info_sink),
device_info,
timer,
att_id: 0,
},
},
mlme_stream,
info_stream,
time_stream,
)
}
pub fn on_connect_command(
&mut self,
req: fidl_sme::ConnectRequest,
) -> oneshot::Receiver<ConnectResult> {
let (responder, receiver) = Responder::new();
self.context.info_sink.send(InfoEvent::ConnectStarted);
let (canceled_token, req) = self.scan_sched.enqueue_scan_to_join(JoinScan {
ssid: req.ssid,
token: ConnectConfig {
responder,
password: req.password,
radio_cfg: RadioConfig::from_fidl(req.radio_cfg),
},
scan_type: req.scan_type,
});
// If the new scan replaced an existing pending JoinScan, notify the existing transaction
if let Some(token) = canceled_token {
report_connect_finished(
Some(token.responder),
&self.context,
ConnectResult::Canceled,
None,
);
}
self.send_scan_request(req);
receiver
}
pub fn on_disconnect_command(&mut self) {
self.state = self.state.take().map(|state| state.disconnect(&self.context));
}
pub fn on_scan_command(
&mut self,
scan_type: fidl_common::ScanType,
) -> oneshot::Receiver<EssDiscoveryResult> {
let (responder, receiver) = Responder::new();
let scan = DiscoveryScan::new(responder, scan_type);
let req = self.scan_sched.enqueue_scan_to_discover(scan);
self.send_scan_request(req);
receiver
}
pub fn status(&self) -> Status {
let status = self.state.as_ref().expect("expected state to be always present").status();
if status.connecting_to.is_some() {
status
} else {
// If the association machine is not connecting to a network, but the scanner
// has a queued 'JoinScan', include the SSID we are trying to connect to
Status {
connecting_to: self.scan_sched.get_join_scan().map(|s| s.ssid.clone()),
..status
}
}
}
fn send_scan_request(&mut self, req: Option<ScanRequest>) {
if let Some(req) = req {
self.context.info_sink.send(InfoEvent::MlmeScanStart { txn_id: req.txn_id });
self.context.mlme_sink.send(MlmeRequest::Scan(req));
}
}
}
impl super::Station for ClientSme {
type Event = Event;
fn on_mlme_event(&mut self, event: MlmeEvent) {
self.state = self.state.take().map(|state| match event {
MlmeEvent::OnScanResult { result } => {
self.scan_sched.on_mlme_scan_result(result);
state
}
MlmeEvent::OnScanEnd { end } => {
self.context.info_sink.send(InfoEvent::MlmeScanEnd { txn_id: end.txn_id });
let (result, request) = self.scan_sched.on_mlme_scan_end(end);
self.send_scan_request(request);
match result {
ScanResult::None => state,
ScanResult::JoinScanFinished { token, result: Ok(bss_list) } => {
match get_best_bss(&bss_list) {
Some(best_bss) => {
match get_rsna(
&self.context.device_info,
&token.password,
&best_bss,
) {
Ok(rsna) => {
let cmd = ConnectCommand {
bss: Box::new(clone_bss_desc(&best_bss)),
responder: Some(token.responder),
rsna,
radio_cfg: token.radio_cfg,
};
state.connect(cmd, &mut self.context)
}
Err(err) => {
error!("cannot join BSS {:02X?} {}", best_bss.bssid, err);
report_connect_finished(
Some(token.responder),
&self.context,
ConnectResult::Failed,
None,
);
state
}
}
}
None => {
error!("no matching BSS found");
report_connect_finished(
Some(token.responder),
&self.context,
ConnectResult::Failed,
Some(ConnectFailure::NoMatchingBssFound),
);
state
}
}
}
ScanResult::JoinScanFinished { token, result: Err(e) } => {
error!("cannot join network because scan failed: {:?}", e);
let (result, failure) = match e {
JoinScanFailure::Canceled => (ConnectResult::Canceled, None),
JoinScanFailure::ScanFailed(code) => {
(ConnectResult::Failed, Some(ConnectFailure::ScanFailure(code)))
}
};
report_connect_finished(
Some(token.responder),
&self.context,
result,
failure,
);
state
}
ScanResult::DiscoveryFinished { tokens, result } => {
let result = match result {
Ok(bss_list) => {
let bss_count = bss_list.len();
let ess_list = group_networks(&bss_list);
let ess_count = ess_list.len();
let num_bss_by_standard = get_standard_map(&bss_list);
let num_bss_by_channel = get_channel_map(&bss_list);
self.context.info_sink.send(InfoEvent::ScanDiscoveryFinished {
bss_count,
ess_count,
num_bss_by_standard,
num_bss_by_channel,
});
Ok(ess_list)
}
Err(e) => Err(e),
};
for responder in tokens {
responder.respond(result.clone());
}
state
}
}
}
other => state.on_mlme_event(other, &mut self.context),
});
}
fn on_timeout(&mut self, timed_event: TimedEvent<Event>) {
self.state = self.state.take().map(|state| match timed_event.event {
event @ Event::EstablishingRsnaTimeout
| event @ Event::KeyFrameExchangeTimeout { .. } => {
state.handle_timeout(timed_event.id, event, &mut self.context)
}
});
}
}
fn report_connect_finished(
responder: Option<Responder<ConnectResult>>,
context: &Context,
result: ConnectResult,
failure: Option<ConnectFailure>,
) {
if let Some(responder) = responder {
responder.respond(result.clone());
}
context.info_sink.send(InfoEvent::ConnectFinished { result, failure });
}
#[cfg(test)]
mod tests {
use super::*;
use fidl_fuchsia_wlan_mlme as fidl_mlme;
use std::error::Error;
use wlan_common::RadioConfig;
use super::test_utils::{
expect_info_event, fake_protected_bss_description, fake_unprotected_bss_description,
};
use crate::Station;
const CLIENT_ADDR: [u8; 6] = [0x7A, 0xE7, 0x76, 0xD9, 0xF2, 0x67];
#[test]
fn status_connecting_to() {
let (mut sme, _mlme_stream, _info_stream, _time_stream) = create_sme();
assert_eq!(Status { connected_to: None, connecting_to: None }, sme.status());
// Issue a connect command and expect the status to change appropriately.
// We also check that the association machine state is still disconnected
// to make sure that the status comes from the scanner.
let _recv = sme.on_connect_command(connect_req(b"foo".to_vec(), vec![]));
assert_eq!(None, sme.state.as_ref().unwrap().status().connecting_to);
assert_eq!(
Status { connected_to: None, connecting_to: Some(b"foo".to_vec()) },
sme.status()
);
// Push a fake scan result into SME. We should still be connecting to "foo",
// but the status should now come from the state machine and not from the scanner.
sme.on_mlme_event(MlmeEvent::OnScanResult {
result: fidl_mlme::ScanResult {
txn_id: 1,
bss: fake_unprotected_bss_description(b"foo".to_vec()),
},
});
sme.on_mlme_event(MlmeEvent::OnScanEnd {
end: fidl_mlme::ScanEnd { txn_id: 1, code: fidl_mlme::ScanResultCodes::Success },
});
assert_eq!(Some(b"foo".to_vec()), sme.state.as_ref().unwrap().status().connecting_to);
assert_eq!(
Status { connected_to: None, connecting_to: Some(b"foo".to_vec()) },
sme.status()
);
// Even if we scheduled a scan to connect to another network "bar", we should
// still report that we are connecting to "foo".
let _recv2 = sme.on_connect_command(connect_req(b"bar".to_vec(), vec![]));
assert_eq!(
Status { connected_to: None, connecting_to: Some(b"foo".to_vec()) },
sme.status()
);
// Simulate that joining "foo" failed. We should now be connecting to "bar".
sme.on_mlme_event(MlmeEvent::JoinConf {
resp: fidl_mlme::JoinConfirm {
result_code: fidl_mlme::JoinResultCodes::JoinFailureTimeout,
},
});
assert_eq!(
Status { connected_to: None, connecting_to: Some(b"bar".to_vec()) },
sme.status()
);
}
#[test]
fn connecting_password_supplied_for_protected_network() {
let (mut sme, mut mlme_stream, _info_stream, _time_stream) = create_sme();
assert_eq!(Status { connected_to: None, connecting_to: None }, sme.status());
// Issue a connect command and verify that connecting_to status is changed for upper
// layer (but not underlying state machine) and a scan request is sent to MLME.
let req = connect_req(b"foo".to_vec(), b"somepass".to_vec());
let _recv = sme.on_connect_command(req);
assert_eq!(None, sme.state.as_ref().unwrap().status().connecting_to);
assert_eq!(
Status { connected_to: None, connecting_to: Some(b"foo".to_vec()) },
sme.status()
);
if let Ok(Some(MlmeRequest::Scan(..))) = mlme_stream.try_next() {
// expected path; nothing to do
} else {
panic!("expect scan request to MLME");
}
// Simulate scan end and verify that underlying state machine's status is changed,
// and a join request is sent to MLME.
sme.on_mlme_event(MlmeEvent::OnScanResult {
result: fidl_mlme::ScanResult {
txn_id: 1,
bss: fake_protected_bss_description(b"foo".to_vec()),
},
});
sme.on_mlme_event(MlmeEvent::OnScanEnd {
end: fidl_mlme::ScanEnd { txn_id: 1, code: fidl_mlme::ScanResultCodes::Success },
});
assert_eq!(Some(b"foo".to_vec()), sme.state.as_ref().unwrap().status().connecting_to);
assert_eq!(
Status { connected_to: None, connecting_to: Some(b"foo".to_vec()) },
sme.status()
);
if let Ok(Some(MlmeRequest::Join(..))) = mlme_stream.try_next() {
// expected path; nothing to do
} else {
panic!("expect join request to MLME");
}
}
#[test]
fn connecting_password_supplied_for_unprotected_network() {
let (mut sme, mut mlme_stream, _info_stream, _time_stream) = create_sme();
assert_eq!(Status { connected_to: None, connecting_to: None }, sme.status());
let req = connect_req(b"foo".to_vec(), b"somepass".to_vec());
let mut connect_fut = sme.on_connect_command(req);
assert_eq!(None, sme.state.as_ref().unwrap().status().connecting_to);
assert_eq!(
Status { connected_to: None, connecting_to: Some(b"foo".to_vec()) },
sme.status()
);
// Push a fake scan result into SME. We should not attempt to connect
// because a password was supplied for unprotected network. So both the
// SME client and underlying state machine should report not connecting
// anymore.
sme.on_mlme_event(MlmeEvent::OnScanResult {
result: fidl_mlme::ScanResult {
txn_id: 1,
bss: fake_unprotected_bss_description(b"foo".to_vec()),
},
});
sme.on_mlme_event(MlmeEvent::OnScanEnd {
end: fidl_mlme::ScanEnd { txn_id: 1, code: fidl_mlme::ScanResultCodes::Success },
});
assert_eq!(None, sme.state.as_ref().unwrap().status().connecting_to);
assert_eq!(Status { connected_to: None, connecting_to: None }, sme.status());
// No join request should be sent to MLME
loop {
match mlme_stream.try_next() {
Ok(event) => match event {
Some(MlmeRequest::Join(..)) => panic!("unexpected join request to MLME"),
None => break,
_ => (),
},
Err(e) => {
assert_eq!(e.description(), "receiver channel is empty");
break;
}
}
}
// User should get a message that connection failed
assert_eq!(Ok(Some(ConnectResult::Failed)), connect_fut.try_recv());
}
#[test]
fn connecting_no_password_supplied_for_protected_network() {
let (mut sme, mut mlme_stream, _info_stream, _time_stream) = create_sme();
assert_eq!(Status { connected_to: None, connecting_to: None }, sme.status());
let req = connect_req(b"foo".to_vec(), vec![]);
let mut connect_fut = sme.on_connect_command(req);
assert_eq!(None, sme.state.as_ref().unwrap().status().connecting_to);
assert_eq!(
Status { connected_to: None, connecting_to: Some(b"foo".to_vec()) },
sme.status()
);
// Push a fake scan result into SME. We should not attempt to connect
// because no password was supplied for a protected network.
sme.on_mlme_event(MlmeEvent::OnScanResult {
result: fidl_mlme::ScanResult {
txn_id: 1,
bss: fake_protected_bss_description(b"foo".to_vec()),
},
});
sme.on_mlme_event(MlmeEvent::OnScanEnd {
end: fidl_mlme::ScanEnd { txn_id: 1, code: fidl_mlme::ScanResultCodes::Success },
});
assert_eq!(None, sme.state.as_ref().unwrap().status().connecting_to);
assert_eq!(Status { connected_to: None, connecting_to: None }, sme.status());
// No join request should be sent to MLME
loop {
match mlme_stream.try_next() {
Ok(event) => match event {
Some(MlmeRequest::Join(..)) => panic!("unexpected join request sent to MLME"),
None => break,
_ => (),
},
Err(e) => {
assert_eq!(e.description(), "receiver channel is empty");
break;
}
}
}
// User should get a message that connection failed
assert_eq!(Ok(Some(ConnectResult::Failed)), connect_fut.try_recv());
}
#[test]
fn connecting_generates_info_events() {
let (mut sme, _mlme_stream, mut info_stream, _time_stream) = create_sme();
let _recv = sme.on_connect_command(connect_req(b"foo".to_vec(), vec![]));
expect_info_event(&mut info_stream, InfoEvent::ConnectStarted);
expect_info_event(&mut info_stream, InfoEvent::MlmeScanStart { txn_id: 1 });
sme.on_mlme_event(MlmeEvent::OnScanResult {
result: fidl_mlme::ScanResult {
txn_id: 1,
bss: fake_unprotected_bss_description(b"foo".to_vec()),
},
});
sme.on_mlme_event(MlmeEvent::OnScanEnd {
end: fidl_mlme::ScanEnd { txn_id: 1, code: fidl_mlme::ScanResultCodes::Success },
});
expect_info_event(&mut info_stream, InfoEvent::MlmeScanEnd { txn_id: 1 });
expect_info_event(&mut info_stream, InfoEvent::AssociationStarted { att_id: 1 });
}
fn connect_req(ssid: Ssid, password: Vec<u8>) -> fidl_sme::ConnectRequest {
fidl_sme::ConnectRequest {
ssid,
password,
radio_cfg: RadioConfig::default().to_fidl(),
scan_type: fidl_common::ScanType::Passive,
}
}
fn create_sme() -> (ClientSme, MlmeStream, InfoStream, TimeStream) {
ClientSme::new(DeviceInfo { addr: CLIENT_ADDR, bands: vec![] })
}
}