| # Copyright 2024 The Fuchsia Authors. All rights reserved. |
| # Use of this source code is governed by a BSD-style license that can be |
| # found in the LICENSE file. |
| """ |
| Tests for connecting to an access point. |
| """ |
| |
| import logging |
| |
| logger = logging.getLogger(__name__) |
| |
| import asyncio |
| import struct |
| from dataclasses import dataclass |
| from typing import Any, List |
| |
| import fidl_fuchsia_wlan_wlanix as fidl_wlanix |
| from antlion import utils |
| from antlion.controllers.access_point import setup_ap |
| from antlion.controllers.ap_lib.hostapd_constants import ( |
| AP_DEFAULT_CHANNEL_2G, |
| AP_SSID_LENGTH_2G, |
| ) |
| from antlion.controllers.ap_lib.hostapd_security import Security, SecurityMode |
| from antlion.controllers.ap_lib.hostapd_utils import generate_random_password |
| from fuchsia_controller_py import Channel |
| from fuchsia_controller_py.wrappers import AsyncAdapter, asyncmethod |
| from mobly import base_test, signals, test_runner |
| from mobly.asserts import assert_equal, assert_true, fail |
| from wlanix_testing import base_test |
| |
| |
class ConnectToApTest(AsyncAdapter, base_test.ConnectionBaseTestClass):
    """Connection tests that scan for and join an AP via the wlanix FIDL API.

    `pre_run` generates one test case per security configuration (open, WPA2
    and WPA3). Each case brings up an AP with a random SSID, looks up the
    client IfaceIndex over Nl80211, triggers a scan and waits for the "scan"
    multicast notification, then connects through the supplicant FIDL API and
    checks that the supplicant reports the COMPLETED state.
    """

    # Seconds to wait for the "scan" multicast group to deliver a message
    # after TriggerScan has been acknowledged.
    _SCAN_RESULT_TIMEOUT_S = 20

    def pre_run(self) -> None:
        """Generate one connection test per security configuration."""
        self.generate_tests(
            test_logic=self._test_logic,
            name_func=self.name_func,
            arg_sets=[
                (Security(security_mode=SecurityMode.OPEN),),
                (
                    Security(
                        security_mode=SecurityMode.WPA2,
                        password=generate_random_password(),
                    ),
                ),
                (
                    Security(
                        security_mode=SecurityMode.WPA3,
                        password=generate_random_password(),
                    ),
                ),
            ],
        )

    def name_func(self, security: Security) -> str:
        """Return the generated test-case name for `security`."""
        return f"test_successfully_connect_to_ap_{security.security_mode}"

    @asyncmethod
    async def _test_logic(self, security: Security) -> None:
        """Scan for and connect to an AP configured with `security`.

        Args:
            security: Security configuration applied to both the AP and the
                supplicant network. If it carries a password, that password is
                used as the SAE password for WPA3 and as the PSK passphrase
                for any other mode.

        Raises:
            signals.TestFailure: if no scan result arrives within
                `_SCAN_RESULT_TIMEOUT_S` seconds or network selection fails.
                The mobly assertion helpers also fail the test on unexpected
                responses.
        """
        ssid = utils.rand_ascii_str(AP_SSID_LENGTH_2G)
        # A Security without a password attribute is treated as password-less.
        password = getattr(security, "password", None)

        setup_ap(
            access_point=self.access_point(),
            profile_name="whirlwind",
            channel=AP_DEFAULT_CHANNEL_2G,
            ssid=ssid,
            security=security,
        )

        logger.info("Querying for IfaceIndex...")
        get_interface_message = fidl_wlanix.Nl80211Message(
            message_type=fidl_wlanix.Nl80211MessageType.MESSAGE,
            # fmt: off
            payload=[
                # Generic Netlink Header
                0x05,  # Command: GetInterface
                0x01,  # Version
                0x00, 0x00,  # Reserved
            ],
            # fmt: on
        )
        response_list = (
            (await self.nl80211_proxy.message(message=get_interface_message))
            .unwrap()
            .responses
        )
        iface_index = await read_iface_index_or_fail(response_list)
        logger.info("Using IfaceIndex %d for connection test", iface_index)

        logger.info("Triggering a scan on IfaceIndex %d", iface_index)
        with Nl80211MulticastServer() as scan_ctx:
            scan_queue = scan_ctx.message_queue

            # Subscribe to the "scan" multicast group before triggering the
            # scan; messages from that group land in `scan_queue`.
            self.nl80211_proxy.get_multicast(
                group="scan", multicast=scan_ctx.callback_channel.take()
            )

            trigger_scan_message = fidl_wlanix.Nl80211Message(
                message_type=fidl_wlanix.Nl80211MessageType.MESSAGE,
                # fmt: off
                payload=[
                    # Generic Netlink Header
                    0x21,  # Command: TriggerScan
                    0x01,  # Version
                    0x00, 0x00,  # Reserved
                    # Attribute: 4-byte header followed by a u32 value
                    0x08, 0x00,  # Length
                    0x03, 0x00,  # Type: IfaceIndex (little-endian)
                    *struct.pack("<I", iface_index),
                ],
                # fmt: on
            )
            response_list = (
                (await self.nl80211_proxy.message(message=trigger_scan_message))
                .unwrap()
                .responses
            )
            assert_equal(
                len(response_list),
                1,
                "Response from TriggerScan should contain a single ACK message.",
            )
            # NOTE(review): 3 is expected to be the numeric value of
            # Nl80211MessageType.ACK; confirm against the FIDL enum before
            # switching to the named constant.
            assert_equal(
                response_list[0].message_type,
                3,
                "Response should have been an ACK.",
            )

            # Wait for a multicast message to indicate the scan has completed.
            try:
                scan_message = await asyncio.wait_for(
                    scan_queue.get(), timeout=self._SCAN_RESULT_TIMEOUT_S
                )
            except asyncio.TimeoutError as e:
                # asyncio.TimeoutError is only an alias of the builtin
                # TimeoutError from Python 3.11 on; catching the asyncio name
                # reports the timeout as a test failure on older versions too.
                raise signals.TestFailure(
                    "Did not receive a scan result within "
                    f"{self._SCAN_RESULT_TIMEOUT_S} seconds"
                ) from e
            logger.info("Received nl80211 scan result signal")
            assert (
                scan_message.payload is not None
            ), "Received scan result indication without payload"
            assert_equal(
                scan_message.payload[0],
                34,  # Command: NewScanResults
                "Received unexpected scan result",
            )

            with SupplicantStaIfaceCallbackServer() as callback_ctx:
                state_change_queue = callback_ctx.state_change_queue
                self.supplicant_sta_iface_proxy.register_callback(
                    callback=callback_ctx.callback_channel.take()
                )

                proxy, server = Channel.create()
                self.supplicant_sta_iface_proxy.add_network(network=server.take())
                supplicant_sta_network_proxy = (
                    fidl_wlanix.SupplicantStaNetworkClient(proxy)
                )

                supplicant_sta_network_proxy.set_ssid(
                    ssid=list(ssid.encode("ascii"))
                )
                if password:
                    encoded_password = list(password.encode("ascii"))
                    if security.security_mode == SecurityMode.WPA3:
                        supplicant_sta_network_proxy.set_sae_password(
                            password=encoded_password
                        )
                    else:
                        supplicant_sta_network_proxy.set_psk_passphrase(
                            passphrase=encoded_password
                        )

                try:
                    (await supplicant_sta_network_proxy.select()).unwrap()
                except AssertionError as e:
                    raise signals.TestFailure(
                        f'Failed to connect to "{ssid}" with {security}'
                    ) from e
                logger.info('Successfully connected to "%s"!', ssid)

                # The first callback after selection must report COMPLETED,
                # and nothing else may have been queued alongside it.
                state_change = await state_change_queue.get()
                assert isinstance(
                    state_change,
                    fidl_wlanix.SupplicantStaIfaceCallbackOnStateChangedRequest,
                ), f"Expected OnStateChanged. Got {state_change!r}"
                assert_equal(
                    state_change.new_state,
                    fidl_wlanix.StaIfaceCallbackState.COMPLETED,
                )
                assert_true(
                    state_change_queue.empty(),
                    "Unexpectedly received additional callback messages.",
                )
| |
| |
@dataclass
class SupplicantStaIfaceCallbackContext:
    """Handles returned when entering a SupplicantStaIfaceCallbackServer."""

    # Queue of callback requests: state changes, disconnects and association
    # rejections, all sharing a single queue.
    state_change_queue: asyncio.Queue[
        fidl_wlanix.SupplicantStaIfaceCallbackOnStateChangedRequest
        | fidl_wlanix.SupplicantStaIfaceCallbackOnDisconnectedRequest
        | fidl_wlanix.SupplicantStaIfaceCallbackOnAssociationRejectedRequest
    ]
    # Client end of the callback channel; its `take()` handle is what gets
    # passed to `register_callback` on the supplicant STA iface proxy.
    callback_channel: Channel
| |
| |
class SupplicantStaIfaceCallbackServer(
    fidl_wlanix.SupplicantStaIfaceCallbackServer
):
    """SupplicantStaIfaceCallback FIDL server that queues every callback.

    OnStateChanged, OnDisconnected and OnAssociationRejected requests are all
    pushed, in the order the handlers run, onto one `state_change_queue`.

    Use as a context manager: entering creates a channel pair, serves the
    server end as a task on the running event loop and returns the queue plus
    the client end; exiting cancels the serving task.
    """

    def __init__(
        self,
        verbose: bool = True,
    ) -> None:
        """Create the callback queue.

        The FIDL server base class is initialized in `__enter__`, where the
        channel is created, rather than here.

        Args:
            verbose: if True, log each callback request as it arrives.
        """
        self.verbose = verbose
        self.state_change_queue: asyncio.Queue[
            fidl_wlanix.SupplicantStaIfaceCallbackOnStateChangedRequest
            | fidl_wlanix.SupplicantStaIfaceCallbackOnDisconnectedRequest
            | fidl_wlanix.SupplicantStaIfaceCallbackOnAssociationRejectedRequest
        ] = asyncio.Queue()
        # Assigned in __enter__; None until then so __exit__ never touches an
        # undefined attribute.
        self.server_task: asyncio.Task[Any] | None = None

    def on_state_changed(
        self,
        request: fidl_wlanix.SupplicantStaIfaceCallbackOnStateChangedRequest,
    ) -> None:
        """Queue an OnStateChanged request."""
        if self.verbose:
            logger.info("State changed: %s", request)
        self.state_change_queue.put_nowait(request)

    def on_disconnected(
        self,
        request: fidl_wlanix.SupplicantStaIfaceCallbackOnDisconnectedRequest,
    ) -> None:
        """Queue an OnDisconnected request."""
        if self.verbose:
            logger.info("Disconnected: %s", request)
        self.state_change_queue.put_nowait(request)

    def on_association_rejected(
        self,
        request: fidl_wlanix.SupplicantStaIfaceCallbackOnAssociationRejectedRequest,
    ) -> None:
        """Queue an OnAssociationRejected request."""
        if self.verbose:
            logger.info("Association rejected: %s", request)
        self.state_change_queue.put_nowait(request)

    def __enter__(self) -> SupplicantStaIfaceCallbackContext:
        """Start serving and return the callback queue and client channel.

        Must be called while an asyncio event loop is running, because the
        serving task is created on `asyncio.get_running_loop()`.
        """
        client, server = Channel.create()
        super().__init__(channel=server)
        self.server_task = asyncio.get_running_loop().create_task(self.serve())
        return SupplicantStaIfaceCallbackContext(
            state_change_queue=self.state_change_queue,
            callback_channel=client,
        )

    def __exit__(self, *args: Any, **kwargs: Any) -> None:
        """Cancel the serving task if one was started."""
        if self.server_task is not None:
            self.server_task.cancel()
            self.server_task = None
| |
| |
@dataclass
class Nl80211MulticastServerContext:
    """Handles returned when entering an Nl80211MulticastServer."""

    # Receives every multicast Nl80211Message that arrives with a non-None
    # message.
    message_queue: asyncio.Queue[fidl_wlanix.Nl80211Message]
    # Client end of the multicast channel; its `take()` handle is what gets
    # passed as `multicast=` to `get_multicast` on the Nl80211 proxy.
    callback_channel: Channel
| |
| |
class Nl80211MulticastServer(fidl_wlanix.Nl80211MulticastServer):
    """Nl80211Multicast FIDL server that queues every received message.

    Use as a context manager: entering creates a channel pair, serves the
    server end as a task on the running event loop and returns the queue plus
    the client end; exiting cancels the serving task.
    """

    def __init__(self) -> None:
        """Create the message queue.

        The FIDL server base class is initialized in `__enter__`, where the
        channel is created, rather than here.
        """
        self.message_queue: asyncio.Queue[
            fidl_wlanix.Nl80211Message
        ] = asyncio.Queue()
        # Assigned in __enter__; None until then so __exit__ never touches an
        # undefined attribute.
        self.server_task: asyncio.Task[Any] | None = None

    def message(
        self,
        request: fidl_wlanix.Nl80211MulticastMessageRequest,
    ) -> None:
        """Queue the carried Nl80211Message; requests without one are dropped."""
        if request.message is not None:
            self.message_queue.put_nowait(request.message)

    def __enter__(self) -> Nl80211MulticastServerContext:
        """Start serving and return the message queue and client channel.

        Must be called while an asyncio event loop is running, because the
        serving task is created on `asyncio.get_running_loop()`.
        """
        client, server = Channel.create()
        super().__init__(channel=server)
        self.server_task = asyncio.get_running_loop().create_task(self.serve())
        return Nl80211MulticastServerContext(
            message_queue=self.message_queue,
            callback_channel=client,
        )

    def __exit__(self, *args: Any, **kwargs: Any) -> None:
        """Cancel the serving task if one was started."""
        if self.server_task is not None:
            self.server_task.cancel()
            self.server_task = None
| |
| |
async def read_iface_index_or_fail(
    response_list: List[fidl_wlanix.Nl80211Message],
) -> int:
    """Return the IfaceIndex carried by a GetInterface response.

    Walks `response_list` in order. The first MESSAGE response must be a
    NewInterface (command 7) whose first attribute is IfaceIndex (type 3),
    and its value is decoded and returned. The function is async only so
    callers can `await` it; it performs no awaits itself.

    Args:
        response_list: Nl80211 responses to a GetInterface request.

    Returns:
        The IfaceIndex, decoded as a little-endian unsigned 32-bit integer
        from payload bytes 8-11.

    Raises:
        RuntimeError: if no MESSAGE response appears before a DONE message or
            the end of the list.

    The test is failed via mobly's `fail`/`assert_*` helpers if a DONE message
    is not last, an ERROR/OVERRUN/NO_OP/ACK message is received, or the first
    MESSAGE is not a well-formed NewInterface carrying an IfaceIndex.
    """
    # Generic Netlink header (4 bytes) + attribute header (4 bytes) + u32.
    min_payload_len = 12
    last_response_index = len(response_list) - 1
    for response_index, response in enumerate(response_list):
        message_type = response.message_type
        if message_type == fidl_wlanix.Nl80211MessageType.DONE:
            assert_equal(
                response_index,
                last_response_index,
                "Nl80211 DONE message before end of response",
            )
            break

        # mobly's fail() signature is fail(msg, extras=None), so the message
        # type has to be formatted into `msg` itself rather than passed as a
        # %-style argument.
        if message_type in (
            fidl_wlanix.Nl80211MessageType.ERROR,
            fidl_wlanix.Nl80211MessageType.OVERRUN,
        ):
            fail(f"Received an error Nl80211 message type: {message_type}")

        if message_type in (
            fidl_wlanix.Nl80211MessageType.NO_OP,
            fidl_wlanix.Nl80211MessageType.ACK,
        ):
            fail(f"Received an unexpected Nl80211 message: {message_type}")

        assert_equal(
            message_type,
            fidl_wlanix.Nl80211MessageType.MESSAGE,
            "After filtering all other message types, a type other than MESSAGE was received.",
        )
        assert response.payload is not None, "MESSAGE must contain a payload"

        payload = response.payload
        formatted_response_payload = [format(b, "#04x") for b in payload]
        # Guard the fixed-offset reads below so a truncated payload fails the
        # test with a readable message instead of an IndexError/struct.error.
        assert_true(
            len(payload) >= min_payload_len,
            f"Payload too short to contain an IfaceIndex: {formatted_response_payload}",
        )
        assert_equal(
            payload[0],
            7,
            f"Payload is not a NewInterface message: {formatted_response_payload}",
        )
        # Low byte of the first attribute's little-endian type field.
        assert_equal(
            payload[6],
            3,
            f"First attribute is not an IfaceIndex: {formatted_response_payload}",
        )

        return struct.unpack("<I", bytes(payload[8:12]))[0]

    raise RuntimeError(
        f"Did not find an iface index in the response list: {response_list}"
    )
| |
| |
if __name__ == "__main__":
    # Run the tests defined in this module with Mobly's test runner.
    test_runner.main()