blob: 19e972380c984d03e8e8fab7c628f59162930eae [file] [log] [blame]
#!/usr/bin/env python3
#
# Copyright 2022 The Fuchsia Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import time
from dataclasses import dataclass
from pathlib import Path
from statistics import pstdev
from typing import Any, Dict, List, Optional, Tuple
from mobly import asserts, test_runner
from antlion import utils
from antlion.controllers.access_point import setup_ap
from antlion.controllers.ap_lib import hostapd_constants
from antlion.controllers.ap_lib.hostapd_security import Security, SecurityMode
from antlion.controllers.ap_lib.regulatory_channels import COUNTRY_CHANNELS
from antlion.controllers.iperf_client import IPerfClientOverAdb, IPerfClientOverSsh
from antlion.controllers.iperf_server import IPerfResult
from antlion.test_utils.abstract_devices.wlan_device import create_wlan_device
from antlion.test_utils.wifi import base_test
# Default pass/fail thresholds, overridable per test via the
# "channel_sweep_test_params" user param.
DEFAULT_MIN_THROUGHPUT = 0.0  # Mb/s
DEFAULT_MAX_STD_DEV = 1.0  # Mb/s

# Timing knobs. DEFAULT_TIME_TO_WAIT_FOR_IP_ADDR and
# TIME_TO_SLEEP_BETWEEN_RETRIES are used with time.time()/time.sleep(), so
# they are in seconds; DEFAULT_IPERF_TIMEOUT is passed as the iperf client's
# `timeout` (presumably seconds as well).
DEFAULT_IPERF_TIMEOUT = 30
DEFAULT_TIME_TO_WAIT_FOR_IP_ADDR = 30
TIME_TO_SLEEP_BETWEEN_RETRIES = 1

# Channels at or below this number are treated as 2.4 GHz when choosing
# which AP subnet to expect an address in.
MAX_2_4_CHANNEL = 14

# Length of the random hex password generated for WEP networks.
WEP_HEX_STRING_LENGTH = 10

# Marker size used when plotting throughput points with bokeh.
GRAPH_CIRCLE_SIZE = 10

# Units requested from IPerfResult when reading throughput.
MEGABITS_PER_SECOND = "Mbps"
@dataclass
class TestParams:
    """Parameters and pass/fail thresholds for a single channel sweep test."""

    # Country code for the DUT to set before running the test.
    country_code: str
    # Security type of the network to create. None represents an open network.
    security_mode: Optional[SecurityMode]
    # Channel for the AP to broadcast on.
    channel: int
    # Channel bandwidth in MHz for the AP to broadcast with.
    channel_bandwidth: int
    # Expected minimum receive throughput in Mb/s.
    expect_min_rx_throughput_mbps: float = DEFAULT_MIN_THROUGHPUT
    # Expected minimum transmit throughput in Mb/s.
    expect_min_tx_throughput_mbps: float = DEFAULT_MIN_THROUGHPUT
    # Expected maximum standard deviation of throughput in Mb/s.
    # TODO: Use this value
    expect_max_std_dev: float = DEFAULT_MAX_STD_DEV
@dataclass(frozen=True)
class ThroughputKey:
    """Groups throughput results that should be compared across channels.

    Results that share a country code, security mode, and channel bandwidth
    are collected under one key, so their throughput can be graphed and its
    spread checked across channels.

    Instances are used as dictionary keys, so the class must be hashable.
    A plain ``@dataclass`` (eq=True, frozen=False) sets ``__hash__`` to None,
    which makes instances unhashable. ``frozen=True`` generates a field-based
    ``__hash__`` instead.
    """

    country_code: str
    security_mode: "Optional[SecurityMode]"
    channel_bandwidth: int

    @staticmethod
    def from_test(test: "TestParams") -> "ThroughputKey":
        """Return the grouping key for the given test's parameters."""
        return ThroughputKey(
            country_code=test.country_code,
            security_mode=test.security_mode,
            channel_bandwidth=test.channel_bandwidth,
        )
@dataclass
class ThroughputValue:
    """Throughput measured in both directions on a single channel.

    A direction is None when no measurement was recorded for it, e.g. when
    the DUT failed to associate with the network.
    """

    # Channel the measurement was taken on.
    channel: int
    # DUT -> iperf server throughput in Mb/s, or None if not measured.
    tx_throughput_mbps: Optional[float]
    # iperf server -> DUT throughput in Mb/s, or None if not measured.
    rx_throughput_mbps: Optional[float]
# Throughput results grouped by (country code, security mode, channel
# bandwidth); each group holds one ThroughputValue per channel tested.
ChannelThroughputMap = Dict[ThroughputKey, List[ThroughputValue]]
class ChannelSweepTest(base_test.WifiBaseTest):
    """Tests channel performance.

    pre_run generates one test per combination of security mode, channel, and
    channel bandwidth allowed for the United States. Each test:
    1. brings up an AP,
    2. associates the DUT, and
    3. measures iperf throughput in both directions.

    test_standard_deviation runs last and checks that throughput does not vary
    too much across channels.

    Testbed Requirement:
    * 1 x Fuchsia device (dut)
    * 1 x access point
    * 1 x Linux Machine used as IPerfServer

    Note: Performance tests should be done in isolated testbed.
    """

    def __init__(self, configs: List[Any]) -> None:
        super().__init__(configs)
        # Measurements grouped by ThroughputKey. Every channel test appends to
        # this; test_standard_deviation and write_graph read it.
        self.channel_throughput: ChannelThroughputMap = {}

    def pre_run(self):
        """Generate one run_channel_performance test per parameter combination.

        Optional per-test thresholds ("min_rx_throughput", "min_tx_throughput",
        "max_std_dev") are read from the "channel_sweep_test_params" user
        param. They are keyed by the generated test name and fall back to the
        module defaults.
        """
        tests: List[Tuple[TestParams]] = []
        sweep_params = self.user_params.get("channel_sweep_test_params", {})

        def generate_test_name(test: TestParams):
            security = test.security_mode if test.security_mode else "open"
            return f"test_{test.country_code}_{security}_channel_{test.channel}_{test.channel_bandwidth}mhz"

        for country_channels in [COUNTRY_CHANNELS["United States of America"]]:
            for security_mode in [
                None,
                SecurityMode.WEP,
                SecurityMode.WPA,
                SecurityMode.WPA2,
                SecurityMode.WPA_WPA2,
                SecurityMode.WPA3,
            ]:
                for channel, bandwidths in country_channels.allowed_channels.items():
                    for bandwidth in bandwidths:
                        test = TestParams(
                            country_code=country_channels.country_code,
                            security_mode=security_mode,
                            channel=channel,
                            channel_bandwidth=bandwidth,
                        )
                        # Look up this test's overrides once, not per field.
                        params = sweep_params.get(generate_test_name(test), {})
                        test.expect_min_rx_throughput_mbps = params.get(
                            "min_rx_throughput", DEFAULT_MIN_THROUGHPUT
                        )
                        test.expect_min_tx_throughput_mbps = params.get(
                            "min_tx_throughput", DEFAULT_MIN_THROUGHPUT
                        )
                        test.expect_max_std_dev = params.get(
                            "max_std_dev", DEFAULT_MAX_STD_DEV
                        )
                        tests.append((test,))

        self.generate_tests(self.run_channel_performance, generate_test_name, tests)

    def get_existing_test_names(self) -> List[str]:
        """Return test names ordered so test_standard_deviation runs last."""
        test_names: List[str] = super().get_existing_test_names()
        # Verify standard deviation last since it depends on the throughput results from
        # all other tests. sort() is stable and False sorts before True, so
        # every other test keeps its relative order.
        test_names.sort(key=lambda n: n == "test_standard_deviation")
        return test_names

    def setup_class(self):
        """Select the DUT, reset the AP, and start the iperf server/client."""
        super().setup_class()
        self.log = logging.getLogger()
        self.time_to_wait_for_ip_addr = self.user_params.get(
            "channel_sweep_test_params", {}
        ).get("time_to_wait_for_ip_addr", DEFAULT_TIME_TO_WAIT_FOR_IP_ADDR)

        device_type = self.user_params.get("dut", "fuchsia_devices")
        if device_type == "fuchsia_devices":
            self.fuchsia_device = self.fuchsia_devices[0]
            self.dut = create_wlan_device(self.fuchsia_device)
        elif device_type == "android_devices":
            # NOTE(review): run_channel_performance dereferences
            # self.fuchsia_device, which is only set for fuchsia DUTs, so the
            # android path appears unsupported there — confirm.
            self.dut = create_wlan_device(self.android_devices[0])
        else:
            raise ValueError(
                f'Invalid "dut" type specified in config: "{device_type}".'
                'Expected "fuchsia_devices" or "android_devices".'
            )

        self.android_devices = getattr(self, "android_devices", [])
        self.access_point = self.access_points[0]
        self.access_point.stop_all_aps()

        self.iperf_server = self.iperf_servers[0]
        self.iperf_server.start()

        if hasattr(self, "iperf_clients") and self.iperf_clients:
            self.iperf_client = self.iperf_clients[0]
        else:
            self.iperf_client = self.dut.create_iperf_client()

    def teardown_class(self):
        """Write throughput graphs, then run the base class teardown.

        The base teardown runs even if graph generation raises.
        """
        try:
            self.write_graph()
        finally:
            super().teardown_class()

    def setup_test(self):
        """Reset AP and DUT state before each test."""
        super().setup_test()
        # TODO(fxb/46417): Uncomment when wlanClearCountry is implemented, to
        # clean up any country code changes.
        # for fd in self.fuchsia_devices:
        #     phy_ids_response = fd.wlan_lib.wlanPhyIdList()
        #     if phy_ids_response.get('error'):
        #         raise ConnectionError(
        #             'Failed to retrieve phy ids from FuchsiaDevice (%s). '
        #             'Error: %s' % (fd.ip, phy_ids_response['error']))
        #     for id in phy_ids_response['result']:
        #         clear_country_response = fd.wlan_lib.wlanClearCountry(id)
        #         if clear_country_response.get('error'):
        #             raise EnvironmentError(
        #                 'Failed to reset country code on FuchsiaDevice (%s). '
        #                 'Error: %s' % (fd.ip, clear_country_response['error'])
        #             )
        self.access_point.stop_all_aps()
        for ad in self.android_devices:
            ad.droid.wakeLockAcquireBright()
            ad.droid.wakeUpNow()
        self.dut.wifi_toggle_state(True)
        self.dut.disconnect()

    def teardown_test(self):
        """Release Android wake locks, disconnect the DUT, and stop the AP."""
        for ad in self.android_devices:
            ad.droid.wakeLockRelease()
            ad.droid.goToSleepNow()
        self.dut.turn_location_off_and_scan_toggle_off()
        self.dut.disconnect()
        self.download_ap_logs()
        self.access_point.stop_all_aps()
        super().teardown_test()

    def setup_ap(
        self,
        channel: int,
        channel_bandwidth: int,
        security_profile: Optional[Security] = None,
    ) -> str:
        """Start network on AP with basic configuration.

        Args:
            channel: channel to use for network
            channel_bandwidth: channel bandwidth in mhz to use for network,
            security_profile: security type to use or None if open

        Returns:
            SSID of the newly created and running network

        Raises:
            ConnectionError if network is not started successfully.
        """
        ssid = utils.rand_ascii_str(hostapd_constants.AP_SSID_LENGTH_2G)
        try:
            setup_ap(
                access_point=self.access_point,
                profile_name="whirlwind",
                channel=channel,
                security=security_profile,
                force_wmm=True,
                ssid=ssid,
                vht_bandwidth=channel_bandwidth,
                setup_bridge=True,
            )
            self.log.info(
                f"Network (ssid: {ssid}) up on channel {channel} "
                f"w/ channel bandwidth {channel_bandwidth} MHz"
            )
            return ssid
        except Exception as err:
            raise ConnectionError(
                f"Failed to setup ap on channel: {channel}, "
                f"channel bandwidth: {channel_bandwidth} MHz. "
            ) from err

    def get_and_verify_iperf_address(self, channel, device, interface=None):
        """Get ip address from a devices interface and verify it belongs to
        expected subnet based on APs DHCP config.

        Polls until an address in the expected subnet appears or
        self.time_to_wait_for_ip_addr seconds elapse.

        Args:
            channel: int, channel network is running on, to determine subnet
            device: device to get ip address for
            interface (default: None): interface on device to get ip address.
                If None, uses device.test_interface.

        Returns:
            String, ip address of device on given interface (or test_interface)

        Raises:
            ConnectionError, if device does not have a valid ip address after
                all retries.
        """
        if channel <= MAX_2_4_CHANNEL:
            subnet = self.access_point._AP_2G_SUBNET_STR
        else:
            subnet = self.access_point._AP_5G_SUBNET_STR
        end_time = time.time() + self.time_to_wait_for_ip_addr
        while time.time() < end_time:
            if interface:
                device_addresses = device.get_interface_ip_addresses(interface)
            else:
                device_addresses = device.get_interface_ip_addresses(
                    device.test_interface
                )

            if device_addresses["ipv4_private"]:
                for ip_addr in device_addresses["ipv4_private"]:
                    if utils.ip_in_subnet(ip_addr, subnet):
                        return ip_addr
                    self.log.debug(
                        f"Device has an ip address ({ip_addr}), but it is not in subnet {subnet}"
                    )
            else:
                self.log.debug("Device does not have a valid ip address. Retrying.")
            time.sleep(TIME_TO_SLEEP_BETWEEN_RETRIES)
        raise ConnectionError("Device failed to get an ip address.")

    def get_iperf_throughput(
        self,
        iperf_server_address: str,
        iperf_client_address: str,
        reverse: bool = False,
    ) -> float:
        """Run iperf between client and server and get the throughput.

        Args:
            iperf_server_address: IP address of running iperf server
            iperf_client_address: IP address of iperf client (dut)
            reverse: If True, run traffic in reverse direction, from server to client.

        Returns:
            iperf throughput (Mb/s), or 0.0 if iperf produced no results file
            or no send rate.
        """
        if reverse:
            self.log.info(
                f"Running IPerf traffic from server ({iperf_server_address}) to "
                f"dut ({iperf_client_address})."
            )
            iperf_args, results_tag = "-i 1 -t 10 -R -J", "channel_sweep_rx"
        else:
            self.log.info(
                f"Running IPerf traffic from dut ({iperf_client_address}) to "
                f"server ({iperf_server_address})."
            )
            iperf_args, results_tag = "-i 1 -t 10 -J", "channel_sweep_tx"

        iperf_results_file = self.iperf_client.start(
            iperf_server_address,
            iperf_args,
            results_tag,
            timeout=DEFAULT_IPERF_TIMEOUT,
        )
        if not iperf_results_file:
            return 0.0
        iperf_results = IPerfResult(
            iperf_results_file, reporting_speed_units=MEGABITS_PER_SECOND
        )
        return iperf_results.avg_send_rate or 0.0

    def log_to_file_and_throughput_data(
        self,
        test: TestParams,
        tx_throughput: Optional[float],
        rx_throughput: Optional[float],
    ):
        """Write performance info to csv file and to throughput data.

        Appends a row to <log_path>/throughput.csv, writing the header first
        if the file does not exist yet, and records the measurement in
        self.channel_throughput under the test's ThroughputKey.

        Args:
            test: parameters of the test that produced the measurement
            tx_throughput: throughput from dut to iperf server in Mb/s, or
                None if not measured
            rx_throughput: throughput from iperf server to dut in Mb/s, or
                None if not measured
        """
        test_name = self.current_test_info.name
        log_file = Path(self.log_path) / "throughput.csv"
        self.log.info(f"Writing IPerf results for {test_name} to {log_file}")

        write_header = not log_file.is_file()
        # A single append-mode open creates the file if needed, avoiding the
        # FileExistsError an exclusive ("x") open could raise.
        with open(log_file, "a") as csv_file:
            if write_header:
                csv_file.write(
                    "country code,security,channel,channel bandwidth,tx throughput,rx throughput\n"
                )
            csv_file.write(
                f"{test.country_code},{test.security_mode},{test.channel},{test.channel_bandwidth},{tx_throughput},{rx_throughput}\n"
            )

        self.channel_throughput.setdefault(ThroughputKey.from_test(test), []).append(
            ThroughputValue(
                channel=test.channel,
                tx_throughput_mbps=tx_throughput,
                rx_throughput_mbps=rx_throughput,
            )
        )

    def write_graph(self):
        """Create graph html files from throughput data, plotting channel vs
        tx_throughput and channel vs rx_throughput.

        Writes one file per ThroughputKey. Does nothing if there is no iperf
        server (including when setup_class failed before creating one) or if
        bokeh is not installed.
        """
        # If performance measurement is skipped
        if not getattr(self, "iperf_server", None):
            return

        try:
            from bokeh.plotting import (  # type: ignore
                ColumnDataSource,
                figure,
                output_file,
                save,
            )
        except ImportError:
            self.log.warning(
                "bokeh is not installed: skipping creation of graphs. "
                "Note CSV files are still available. If graphs are "
                'desired, install antlion with the "bokeh" feature.'
            )
            return

        for key, throughput_values in self.channel_throughput.items():
            output_file_name = os.path.join(
                self.log_path,
                f"channel_throughput_{key.country_code}_{key.security_mode}_{key.channel_bandwidth}mhz.html",
            )
            output_file(output_file_name)

            channels = []
            tx_throughputs = []
            rx_throughputs = []
            for throughput in sorted(throughput_values, key=lambda t: t.channel):
                channels.append(str(throughput.channel))
                tx_throughputs.append(throughput.tx_throughput_mbps)
                rx_throughputs.append(throughput.rx_throughput_mbps)

            channel_vs_throughput_data = ColumnDataSource(
                data=dict(
                    channels=channels,
                    tx_throughput=tx_throughputs,
                    rx_throughput=rx_throughputs,
                )
            )
            TOOLTIPS = [
                ("Channel", "@channels"),
                ("TX_Throughput", "@tx_throughput"),
                ("RX_Throughput", "@rx_throughput"),
            ]
            channel_vs_throughput_graph = figure(
                title="Channels vs. Throughput",
                x_axis_label="Channels",
                x_range=channels,
                y_axis_label="Throughput",
                tooltips=TOOLTIPS,
            )
            channel_vs_throughput_graph.sizing_mode = "stretch_both"
            channel_vs_throughput_graph.title.align = "center"

            # Draw a line plus point markers for each direction: TX in blue,
            # then RX in red.
            for column, color, label in (
                ("tx_throughput", "blue", "TX_Throughput"),
                ("rx_throughput", "red", "RX_Throughput"),
            ):
                channel_vs_throughput_graph.line(
                    "channels",
                    column,
                    source=channel_vs_throughput_data,
                    line_width=2,
                    line_color=color,
                    legend_label=label,
                )
                channel_vs_throughput_graph.circle(
                    "channels",
                    column,
                    source=channel_vs_throughput_data,
                    size=GRAPH_CIRCLE_SIZE,
                    color=color,
                )

            channel_vs_throughput_graph.legend.location = "top_left"
            graph_file = save([channel_vs_throughput_graph])
            self.log.info(f"Saved graph to {graph_file}")

    def test_standard_deviation(self):
        """Verify throughputs don't deviate too much across channels.

        Results are grouped by country code, security mode, and channel
        bandwidth. For each group and each direction (TX and RX), the
        population standard deviation of throughput across channels must not
        exceed max_std_dev. That value is an absolute limit in Mb/s, read from
        the "channel_sweep_test_params" user param and defaulting to
        DEFAULT_MAX_STD_DEV. A direction with no recorded measurements (e.g.
        every association in the group failed) is skipped.

        Raises:
            TestFailure, if standard deviation of throughput exceeds max_std_dev
        """
        # If performance measurement is skipped
        if not self.iperf_server:
            return

        max_std_dev = self.user_params.get("channel_sweep_test_params", {}).get(
            "max_std_dev", DEFAULT_MAX_STD_DEV
        )
        self.log.info(
            "Verifying standard deviation across channels does not exceed max standard "
            f"deviation of {max_std_dev} Mb/s"
        )

        errors: List[str] = []
        for key, throughputs in self.channel_throughput.items():
            label = f"[{key.country_code} {key.security_mode} {key.channel_bandwidth}mhz]"
            tx_values = [
                t.tx_throughput_mbps
                for t in throughputs
                if t.tx_throughput_mbps is not None
            ]
            rx_values = [
                t.rx_throughput_mbps
                for t in throughputs
                if t.rx_throughput_mbps is not None
            ]
            for direction, values in (("TX", tx_values), ("RX", rx_values)):
                if not values:
                    # pstdev raises StatisticsError on empty input.
                    self.log.warning(
                        f"{label} No {direction} throughput measurements; "
                        "skipping standard deviation check"
                    )
                    continue
                std_dev = pstdev(values)
                if std_dev > max_std_dev:
                    errors.append(
                        f"{label} {direction} throughput standard deviation "
                        f"{std_dev} Mb/s exceeds expected max of {max_std_dev} Mb/s"
                    )

        if errors:
            error_message = "\n - ".join(errors)
            asserts.fail(
                f"Failed to meet standard deviation expectations:\n - {error_message}"
            )

    def run_channel_performance(self, test: TestParams):
        """Run a single channel performance test

        Log results to csv file and throughput data.

        1. Sets up network with test settings
        2. Associates DUT
        3. Runs traffic between DUT and iperf server (both directions)
        4. Logs channel, tx_throughput (Mb/s), and rx_throughput (Mb/s) to
           log file and throughput data.
        5. Checks throughput values against minimum throughput thresholds.

        Raises:
            TestFailure, if throughput (either direction) is less than
                that direction's minimum throughput threshold.
        """
        self.fuchsia_device.wlan_policy_controller.set_country_code(test.country_code)

        if test.security_mode:
            if test.security_mode == SecurityMode.WEP:
                password = utils.rand_hex_str(WEP_HEX_STRING_LENGTH)
            else:
                password = utils.rand_ascii_str(hostapd_constants.MIN_WPA_PSK_LENGTH)
            security_profile = Security(
                security_mode=test.security_mode, password=password
            )
            target_security = test.security_mode.default_target_security()
        else:
            password = None
            security_profile = None
            target_security = None

        ssid = self.setup_ap(test.channel, test.channel_bandwidth, security_profile)
        associated = self.dut.associate(
            ssid, target_pwd=password, target_security=target_security
        )
        if not associated:
            if self.iperf_server:
                # Record the failure so the CSV/graphs show a gap for this channel.
                self.log_to_file_and_throughput_data(test, None, None)
            asserts.fail(f"Device failed to associate to network {ssid}")
        self.log.info(f"DUT ({self.dut.identifier}) connected to network {ssid}.")

        if self.iperf_server:
            self.iperf_server.renew_test_interface_ip_address()
            self.log.info(
                "Getting ip address for iperf server. Will retry for "
                f"{self.time_to_wait_for_ip_addr} seconds."
            )
            iperf_server_address = self.get_and_verify_iperf_address(
                test.channel, self.iperf_server
            )
            self.log.info(
                "Getting ip address for DUT. Will retry for "
                f"{self.time_to_wait_for_ip_addr} seconds."
            )
            assert isinstance(
                self.iperf_client, (IPerfClientOverSsh, IPerfClientOverAdb)
            )
            iperf_client_address = self.get_and_verify_iperf_address(
                test.channel, self.fuchsia_device, self.iperf_client.test_interface
            )
            tx_throughput = self.get_iperf_throughput(
                iperf_server_address, iperf_client_address
            )
            rx_throughput = self.get_iperf_throughput(
                iperf_server_address, iperf_client_address, reverse=True
            )
            self.log_to_file_and_throughput_data(test, tx_throughput, rx_throughput)
            self.log.info(
                f"Throughput (tx, rx): ({tx_throughput} Mb/s, {rx_throughput} Mb/s), "
                "Minimum threshold (tx, rx): "
                f"({test.expect_min_tx_throughput_mbps} Mb/s, "
                f"{test.expect_min_rx_throughput_mbps} Mb/s)"
            )
            # Throughput must be at or above the minimum. The previous
            # assert_less(throughput, min) was inverted and failed every test
            # at the default minimum of 0.0.
            asserts.assert_greater_equal(
                tx_throughput,
                test.expect_min_tx_throughput_mbps,
                "tx throughput below the minimal threshold",
            )
            asserts.assert_greater_equal(
                rx_throughput,
                test.expect_min_rx_throughput_mbps,
                "rx throughput below the minimal threshold",
            )
# Entry point when this file is executed directly: hand control to Mobly's
# test runner.
if __name__ == "__main__":
    test_runner.main()