blob: 845c7095e1f2bd97fc36d3b1a53e7719719f2fe5 [file] [log] [blame]
# Copyright 2024 The Fuchsia Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""
Tests for connecting to an access point.
"""
import logging
logger = logging.getLogger(__name__)
import asyncio
import struct
from dataclasses import dataclass
from typing import Any, List
import fidl_fuchsia_wlan_wlanix as fidl_wlanix
from antlion import utils
from antlion.controllers.access_point import setup_ap
from antlion.controllers.ap_lib.hostapd_constants import (
AP_DEFAULT_CHANNEL_2G,
AP_SSID_LENGTH_2G,
)
from antlion.controllers.ap_lib.hostapd_security import Security, SecurityMode
from antlion.controllers.ap_lib.hostapd_utils import generate_random_password
from fuchsia_controller_py import Channel
from fuchsia_controller_py.wrappers import AsyncAdapter, asyncmethod
from mobly import base_test, signals, test_runner
from mobly.asserts import assert_equal, assert_true, fail
from wlanix_testing import base_test
class ConnectToApTest(AsyncAdapter, base_test.ConnectionBaseTestClass):
def pre_run(self) -> None:
self.generate_tests(
test_logic=self._test_logic,
name_func=self.name_func,
arg_sets=[
(Security(security_mode=SecurityMode.OPEN),),
(
Security(
security_mode=SecurityMode.WPA2,
password=generate_random_password(),
),
),
(
Security(
security_mode=SecurityMode.WPA3,
password=generate_random_password(),
),
),
],
)
def name_func(self, security: Security) -> str:
return f"test_successfully_connect_to_ap_{security.security_mode}"
@asyncmethod
async def _test_logic(self, security: Security) -> None:
ssid = utils.rand_ascii_str(AP_SSID_LENGTH_2G)
password = getattr(security, "password", None)
setup_ap(
access_point=self.access_point(),
profile_name="whirlwind",
channel=AP_DEFAULT_CHANNEL_2G,
ssid=ssid,
security=security,
)
logger.info("Querying for IfaceIndex...")
get_interface_message = fidl_wlanix.Nl80211Message(
message_type=fidl_wlanix.Nl80211MessageType.MESSAGE,
# fmt: off
payload=[
# Generic Netlink Header
0x05, # Command: GetInterface
0x01, # Version
0x00, 0x00 # Reserved
],
# fmt: on
)
response_list = (
(await self.nl80211_proxy.message(message=get_interface_message))
.unwrap()
.responses
)
iface_index = await read_iface_index_or_fail(response_list)
logger.info("Using IfaceIndex %d for connection test", iface_index)
logger.info("Triggering a scan on IfaceIndex %d", iface_index)
with Nl80211MulticastServer() as ctx:
scan_queue = ctx.message_queue
scan_callback_channel = ctx.callback_channel
self.nl80211_proxy.get_multicast(
group="scan", multicast=scan_callback_channel.take()
)
trigger_scan_message = fidl_wlanix.Nl80211Message(
message_type=fidl_wlanix.Nl80211MessageType.MESSAGE,
# fmt: off
payload=[
# Generic Netlink Header
0x21, # Command: TriggerScan
0x01, # Version
0x00, 0x00, # Reserved
0x08, 0x00, # Length
0x03, 0x00, # Type: IfaceIndex (little-endian)
*list(struct.pack("<I", iface_index)),
],
# fmt: on
)
response_list = (
(await self.nl80211_proxy.message(message=trigger_scan_message))
.unwrap()
.responses
)
assert_equal(
len(response_list),
1,
"Response from TriggerScan should contain a single ACK message.",
)
assert_equal(
response_list[0].message_type,
3,
"Response should have been an ACK.",
)
# Wait for a multicast message to indicate the scan has completed.
try:
scan_message = await asyncio.wait_for(
scan_queue.get(), timeout=20
)
logger.info("Recieved nl80211 scan result signal")
assert (
scan_message.payload is not None
), "Received scan result indication without payload"
assert_equal(
scan_message.payload[0],
34, # Command: NewScanResults
"Received unexpected scan result",
)
except TimeoutError:
raise signals.TestFailure(
"Did not receive a scan result within 20 seconds"
)
with SupplicantStaIfaceCallbackServer() as ctx:
state_change_queue = ctx.state_change_queue
callback_channel = ctx.callback_channel
self.supplicant_sta_iface_proxy.register_callback(
callback=callback_channel.take()
)
proxy, server = Channel.create()
self.supplicant_sta_iface_proxy.add_network(network=server.take())
supplicant_sta_network_proxy = (
fidl_wlanix.SupplicantStaNetworkClient(proxy)
)
supplicant_sta_network_proxy.set_ssid(
ssid=list(ssid.encode("ascii"))
)
if password:
if security.security_mode == SecurityMode.WPA3:
supplicant_sta_network_proxy.set_sae_password(
password=list(password.encode("ascii"))
)
else:
supplicant_sta_network_proxy.set_psk_passphrase(
passphrase=list(password.encode("ascii"))
)
try:
(await supplicant_sta_network_proxy.select()).unwrap()
except AssertionError as e:
raise signals.TestFailure(
f'Failed to connect to "{ssid}" with {security}'
) from e
logger.info(f'Successfully connected to "{ssid}"!')
state_change = await state_change_queue.get()
assert isinstance(
state_change,
fidl_wlanix.SupplicantStaIfaceCallbackOnStateChangedRequest,
), f"Expected OnStateChanged. Got {state_change!r}"
assert_equal(
state_change.new_state,
fidl_wlanix.StaIfaceCallbackState.COMPLETED,
)
assert_true(
state_change_queue.empty(),
"Unexpectedly received additional callback messages.",
)
@dataclass
class SupplicantStaIfaceCallbackContext:
state_change_queue: asyncio.Queue[
fidl_wlanix.SupplicantStaIfaceCallbackOnStateChangedRequest
| fidl_wlanix.SupplicantStaIfaceCallbackOnDisconnectedRequest
| fidl_wlanix.SupplicantStaIfaceCallbackOnAssociationRejectedRequest
]
callback_channel: Channel
class SupplicantStaIfaceCallbackServer(
fidl_wlanix.SupplicantStaIfaceCallbackServer
):
def __init__(
self,
verbose: bool = True,
) -> None:
self.verbose = verbose
self.state_change_queue: asyncio.Queue[
fidl_wlanix.SupplicantStaIfaceCallbackOnStateChangedRequest
| fidl_wlanix.SupplicantStaIfaceCallbackOnDisconnectedRequest
| fidl_wlanix.SupplicantStaIfaceCallbackOnAssociationRejectedRequest
] = asyncio.Queue()
def on_state_changed(
self,
request: fidl_wlanix.SupplicantStaIfaceCallbackOnStateChangedRequest,
) -> None:
if self.verbose:
logger.info("State changed: %s", request)
self.state_change_queue.put_nowait(request)
def on_disconnected(
self,
request: fidl_wlanix.SupplicantStaIfaceCallbackOnDisconnectedRequest,
) -> None:
if self.verbose:
logger.info("Disconnected: %s", request)
self.state_change_queue.put_nowait(request)
def on_association_rejected(
self,
request: fidl_wlanix.SupplicantStaIfaceCallbackOnAssociationRejectedRequest,
) -> None:
if self.verbose:
logger.info("Association rejected: %s", request)
self.state_change_queue.put_nowait(request)
def __enter__(self) -> SupplicantStaIfaceCallbackContext:
client, server = Channel.create()
super().__init__(channel=server)
self.server_task = asyncio.get_running_loop().create_task(self.serve())
return SupplicantStaIfaceCallbackContext(
state_change_queue=self.state_change_queue,
callback_channel=client,
)
def __exit__(self, *args: Any, **kwargs: Any) -> None:
if self.server_task:
self.server_task.cancel()
@dataclass
class Nl80211MulticastServerContext:
message_queue: asyncio.Queue[fidl_wlanix.Nl80211Message]
callback_channel: Channel
class Nl80211MulticastServer(fidl_wlanix.Nl80211MulticastServer):
def __init__(self) -> None:
self.message_queue: asyncio.Queue[
fidl_wlanix.Nl80211Message
] = asyncio.Queue()
def message(
self,
request: fidl_wlanix.Nl80211MulticastMessageRequest,
) -> None:
if request.message is not None:
self.message_queue.put_nowait(request.message)
def __enter__(self) -> Nl80211MulticastServerContext:
client, server = Channel.create()
super().__init__(channel=server)
self.server_task = asyncio.get_running_loop().create_task(self.serve())
return Nl80211MulticastServerContext(
message_queue=self.message_queue,
callback_channel=client,
)
def __exit__(self, *args: Any, **kwargs: Any) -> None:
if self.server_task:
self.server_task.cancel()
async def read_iface_index_or_fail(
response_list: List[fidl_wlanix.Nl80211Message],
) -> int:
last_response_index = len(response_list) - 1
for response_index, response in enumerate(response_list):
if response.message_type == fidl_wlanix.Nl80211MessageType.DONE:
assert_equal(
response_index,
last_response_index,
"Nl80211 DONE message before end of response",
)
break
if response.message_type in [
fidl_wlanix.Nl80211MessageType.ERROR,
fidl_wlanix.Nl80211MessageType.OVERRUN,
]:
fail(
"Received an error Nl80211 message type: %s",
response.message_type,
)
if response.message_type in [
fidl_wlanix.Nl80211MessageType.NO_OP,
fidl_wlanix.Nl80211MessageType.ACK,
]:
fail(
"Received an unexpected Nl80211 message: %s",
response.message_type,
)
assert_equal(
response.message_type,
fidl_wlanix.Nl80211MessageType.MESSAGE,
"After filtering all other message types, a type other than MESSAGE was received.",
)
assert response.payload is not None, "MESSAGE must contain a payload"
formatted_response_payload = [
format(b, "#04x") for b in response.payload
]
assert_equal(
response.payload[0],
7,
f"Payload is not a NewInterface message: {formatted_response_payload}",
)
assert_equal(
response.payload[6],
3,
f"First attribute is not an IfaceIndex: {formatted_response_payload}",
)
return struct.unpack("<I", bytes(response.payload[8:12]))[0]
raise RuntimeError(
f"Did not find an iface index in the response list: {response_list}"
)
if __name__ == "__main__":
test_runner.main()