#!/usr/bin/env python3
#
# Copyright 2022 The Fuchsia Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import asdict
from enum import StrEnum

from honeydew import errors
from honeydew.interfaces.device_classes.fuchsia_device import (
    FuchsiaDevice as HdFuchsiaDevice,
)
from honeydew.typing.wlan import (
    BssDescription,
    BssType,
    ChannelBandwidth,
    ClientStatusConnected,
    ClientStatusConnecting,
    ClientStatusIdle,
    ClientStatusResponse,
    CountryCode,
    Protection,
    QueryIfaceResponse,
    WlanChannel,
    WlanMacRole,
)
from mobly import signals

from antlion.controllers.fuchsia_lib.base_lib import BaseLib
from antlion.validation import MapValidator

# Keys of the "wlan.status" result dict that FuchsiaWlanLib.status() checks to
# recognize the idle and connecting client states.
STATUS_IDLE_KEY = "Idle"
STATUS_CONNECTING_KEY = "Connecting"

# The wlan facade reports the protection type as a string (serde serializes the
# enum by name), so map each known string to its Protection IntEnum member.
string_to_int_enum_map: dict[str, Protection] = {
    "Unknown": Protection.UNKNOWN,
    "Open": Protection.OPEN,
    "Wep": Protection.WEP,
    "Wpa1": Protection.WPA1,
    "Wpa1Wpa2PersonalTkipOnly": Protection.WPA1_WPA2_PERSONAL_TKIP_ONLY,
    "Wpa2PersonalTkipOnly": Protection.WPA2_PERSONAL_TKIP_ONLY,
    "Wpa1Wpa2Personal": Protection.WPA1_WPA2_PERSONAL,
    "Wpa2Personal": Protection.WPA2_PERSONAL,
    "Wpa2Wpa3Personal": Protection.WPA2_WPA3_PERSONAL,
    "Wpa3Personal": Protection.WPA3_PERSONAL,
    "Wpa2Enterprise": Protection.WPA2_ENTERPRISE,
    "Wpa3Enterprise": Protection.WPA3_ENTERPRISE,
}


class WlanFailure(signals.TestFailure):
    """Exception for SL4F commands executed by WLAN lib.

    Raised when an SL4F response carries a non-empty "error" field, or when a
    call routed through Honeydew raises `errors.Sl4fError`.
    """


class Command(StrEnum):
    """Sl4f Server Commands.

    Each member's value is the command string passed to `send_command` and
    included in failure messages.
    """

    SCAN_FOR_BSS_INFO = "wlan.scan_for_bss_info"
    CONNECT = "wlan.connect"
    DISCONNECT = "wlan.disconnect"
    STATUS = "wlan.status"
    GET_IFACE_ID_LIST = "wlan.get_iface_id_list"
    GET_PHY_ID_LIST = "wlan.get_phy_id_list"
    CREATE_IFACE = "wlan.create_iface"
    DESTROY_IFACE = "wlan.destroy_iface"
    # NOTE: the two commands below do not use the "wlan." prefix of the ones
    # above; presumably they are served by different SL4F facades.
    GET_COUNTRY = "wlan_phy.get_country"
    QUERY_IFACE = "wlan.query_iface"
    SET_REGION = "location_regulatory_region_facade.set_region"


class FuchsiaWlanLib(BaseLib):
    """WLAN library that talks to a Fuchsia device over SL4F.

    Every public method has two paths: when a Honeydew device was supplied the
    call is delegated to `honeydew_fd.wlan`, otherwise the matching SL4F
    command is sent with `send_command` and the JSON result is validated and
    converted to the Honeydew WLAN types. SL4F errors from either path are
    surfaced as `WlanFailure`.
    """

    def __init__(self, addr: str, honeydew_fd: HdFuchsiaDevice | None = None) -> None:
        """Initializes the library.

        Args:
            addr: Address forwarded to `BaseLib`.
            honeydew_fd: Optional Honeydew device. When set, calls are routed
                through `honeydew_fd.wlan` instead of raw SL4F commands.
        """
        super().__init__(addr, "wlan")
        self.honeydew_fd = honeydew_fd

    def _check_response_error(
        self, cmd: Command, response_json: dict[str, object]
    ) -> object | None:
        """Helper method to process errors from SL4F calls.

        Args:
            cmd: SL4F command sent.
            response_json: Response from SL4F server.

        Returns:
            The value of the "result" field of the response, or None if the
            response has no such field.

        Raises:
            WlanFailure: The response has a non-empty "error" field.
        """
        resp = MapValidator(response_json)
        error = resp.get(str, "error", None)
        if error:
            message = f"SL4F call: {cmd} failed with Error: '{error}'."
            # We sometimes expect to catch WlanFailure so we include a log here for
            # when we do retries.
            self.log.debug(message)
            raise WlanFailure(message)
        return response_json.get("result")

    @staticmethod
    def _parse_wlan_channel(channel: dict[str, object]) -> WlanChannel:
        """Builds a WlanChannel from the "channel" dict of an SL4F result."""
        channel_map = MapValidator(channel)
        return WlanChannel(
            primary=channel_map.get(int, "primary"),
            cbw=ChannelBandwidth(channel_map.get(str, "cbw")),
            secondary80=channel_map.get(int, "secondary80"),
        )

    def scan_for_bss_info(self) -> dict[str, list[BssDescription]]:
        """Scans and returns BSS info

        Returns:
            A dict mapping each seen SSID to a list of BSS Description IE
            blocks, one for each BSS observed in the network

        Raises:
            WlanFailure: Sl4f run command failed.
            TypeError: The SL4F result does not have the expected shape.
        """
        if self.honeydew_fd:
            try:
                return self.honeydew_fd.wlan.scan_for_bss_info()
            except errors.Sl4fError as e:
                raise WlanFailure(
                    f"SL4F call {Command.SCAN_FOR_BSS_INFO} failed."
                ) from e
        else:
            resp = self.send_command(Command.SCAN_FOR_BSS_INFO)
            result = self._check_response_error(Command.SCAN_FOR_BSS_INFO, resp)

            if not isinstance(result, dict):
                raise TypeError(f'Expected "result" to be dict, got {type(result)}')

            ssid_bss_desc_map: dict[str, list[BssDescription]] = {}
            for ssid_key, bss_list in result.items():
                if not isinstance(bss_list, list):
                    raise TypeError(
                        f'Expected "bss_list" to be list, got {type(bss_list)}'
                    )

                # Create BssDescription type out of return values
                bss_descriptions: list[BssDescription] = []
                for bss in bss_list:
                    bss_map = MapValidator(bss)
                    bssid = bss_map.list("bssid").all(int)
                    ies = bss_map.list("ies").all(int)
                    wlan_channel = self._parse_wlan_channel(
                        bss_map.get(dict, "channel")
                    )

                    bss_descriptions.append(
                        BssDescription(
                            bssid=bssid,
                            bss_type=BssType(bss_map.get(str, "bss_type")),
                            beacon_period=bss_map.get(int, "beacon_period"),
                            capability_info=bss_map.get(int, "capability_info"),
                            ies=ies,
                            channel=wlan_channel,
                            rssi_dbm=bss_map.get(int, "rssi_dbm"),
                            snr_db=bss_map.get(int, "snr_db"),
                        )
                    )

                ssid_bss_desc_map[ssid_key] = bss_descriptions

            return ssid_bss_desc_map

    def connect(
        self, target_ssid: str, target_pwd: str | None, target_bss_desc: BssDescription
    ) -> bool:
        """Triggers a network connection

        Args:
            target_ssid: The network to connect to.
            target_pwd: The password for the network.
            target_bss_desc: The basic service set for target network.

        Returns:
            boolean indicating if the connection was successful

        Raises:
            WlanFailure: Sl4f run command failed.
            TypeError: The SL4F result is not a bool.
        """
        if self.honeydew_fd:
            try:
                return self.honeydew_fd.wlan.connect(
                    target_ssid, target_pwd, target_bss_desc
                )
            except errors.Sl4fError as e:
                raise WlanFailure(f"SL4F call {Command.CONNECT} failed.") from e
        else:
            # Only the raw SL4F path needs the serialized parameters, so the
            # recursive asdict() copy is skipped when Honeydew is used.
            method_params = {
                "target_ssid": target_ssid,
                "target_pwd": target_pwd,
                "target_bss_desc": asdict(target_bss_desc),
            }
            resp = self.send_command(Command.CONNECT, method_params)
            result = self._check_response_error(Command.CONNECT, resp)

            if not isinstance(result, bool):
                raise TypeError(f'Expected "result" to be bool, got {type(result)}')

            return result

    def disconnect(self) -> None:
        """Disconnect any current wifi connections

        Raises:
            WlanFailure: Sl4f run command failed.
        """
        if self.honeydew_fd:
            try:
                self.honeydew_fd.wlan.disconnect()
            except errors.Sl4fError as e:
                raise WlanFailure(f"SL4F call {Command.DISCONNECT} failed.") from e
        else:
            resp = self.send_command(Command.DISCONNECT)
            self._check_response_error(Command.DISCONNECT, resp)

    def create_iface(
        self, phy_id: int, role: WlanMacRole, sta_addr: str | None = None
    ) -> int:
        """Create a new WLAN interface.

        Args:
            phy_id: The phy id to create the interface on.
            role: The role of new interface.
            sta_addr: MAC address for softAP interface only.

        Returns:
            Iface id of newly created interface.

        Raises:
            WlanFailure: Sl4f run command failed.
            TypeError: The SL4F result is not an int.
        """
        if self.honeydew_fd:
            try:
                return self.honeydew_fd.wlan.create_iface(phy_id, role, sta_addr)
            except errors.Sl4fError as e:
                raise WlanFailure(f"SL4F call {Command.CREATE_IFACE} failed.") from e
        else:
            method_params = {
                "phy_id": phy_id,
                "role": role,
                "sta_addr": sta_addr,
            }
            resp = self.send_command(Command.CREATE_IFACE, method_params)
            result = self._check_response_error(Command.CREATE_IFACE, resp)

            if not isinstance(result, int):
                raise TypeError(f'Expected "result" to be int, got {type(result)}')

            return result

    def destroy_iface(self, iface_id: int) -> None:
        """Destroy WLAN interface by ID.

        Args:
            iface_id: The interface to destroy.

        Raises:
            WlanFailure: Sl4f run command failed.
        """
        if self.honeydew_fd:
            try:
                self.honeydew_fd.wlan.destroy_iface(iface_id)
            except errors.Sl4fError as e:
                raise WlanFailure(f"SL4F call {Command.DESTROY_IFACE} failed.") from e
        else:
            method_params = {"identifier": iface_id}
            resp = self.send_command(Command.DESTROY_IFACE, method_params)
            self._check_response_error(Command.DESTROY_IFACE, resp)

    def get_iface_id_list(self) -> list[int]:
        """Get list of wlan iface IDs on device.

        Returns:
            A list of wlan iface IDs that are present on the device.

        Raises:
            WlanFailure: Sl4f run command failed.
            TypeError: The SL4F result is not a list.
        """
        if self.honeydew_fd:
            try:
                return self.honeydew_fd.wlan.get_iface_id_list()
            except errors.Sl4fError as e:
                raise WlanFailure(
                    f"SL4F call {Command.GET_IFACE_ID_LIST} failed."
                ) from e
        else:
            resp = self.send_command(Command.GET_IFACE_ID_LIST)
            result = self._check_response_error(Command.GET_IFACE_ID_LIST, resp)

            if not isinstance(result, list):
                raise TypeError(f'Expected "result" to be list, got {type(result)}')

            return result

    def get_phy_id_list(self) -> list[int]:
        """Get list of phy ids on device.

        Returns:
            A list of phy ids that is present on the device.

        Raises:
            WlanFailure: Sl4f run command failed.
            TypeError: The SL4F result is not a list.
        """
        if self.honeydew_fd:
            try:
                return self.honeydew_fd.wlan.get_phy_id_list()
            except errors.Sl4fError as e:
                raise WlanFailure(f"SL4F call {Command.GET_PHY_ID_LIST} failed.") from e
        else:
            resp = self.send_command(Command.GET_PHY_ID_LIST)
            result = self._check_response_error(Command.GET_PHY_ID_LIST, resp)

            if not isinstance(result, list):
                raise TypeError(f'Expected "result" to be list, got {type(result)}')

            return result

    def status(self) -> ClientStatusResponse:
        """Request connection status

        Returns:
            ClientStatusResponse state summary and
            status of various networks connections.

        Raises:
            WlanFailure: Sl4f run command failed.
            TypeError: The SL4F result does not have the expected shape.
        """
        if self.honeydew_fd:
            try:
                return self.honeydew_fd.wlan.status()
            except errors.Sl4fError as e:
                raise WlanFailure(f"SL4F call {Command.STATUS} failed.") from e
        else:
            resp = self.send_command(Command.STATUS)
            result = self._check_response_error(Command.STATUS, resp)

            if not isinstance(result, dict):
                raise TypeError(f'Expected "result" to be dict, got {type(result)}')

            result_map = MapValidator(result)
            # Only one of these keys in result should be present.
            if STATUS_IDLE_KEY in result:
                return ClientStatusIdle()
            elif STATUS_CONNECTING_KEY in result:
                # Read the value with the same constant used for the
                # membership test so the two can never drift apart.
                connecting_ssid = result.get(STATUS_CONNECTING_KEY)
                if not isinstance(connecting_ssid, list):
                    raise TypeError(
                        f'Expected "connecting" to be list, got "{type(connecting_ssid)}"'
                    )
                return ClientStatusConnecting(ssid=connecting_ssid)
            else:
                connected_map = MapValidator(result_map.get(dict, "Connected"))
                channel_dict = connected_map.get(dict, "channel")
                bssid = connected_map.list("bssid").all(int)
                ssid = connected_map.list("ssid").all(int)
                protection = connected_map.get(str, "protection")

                channel = self._parse_wlan_channel(channel_dict)

                return ClientStatusConnected(
                    bssid=bssid,
                    ssid=ssid,
                    rssi_dbm=connected_map.get(int, "rssi_dbm"),
                    snr_db=connected_map.get(int, "snr_db"),
                    channel=channel,
                    # Protection strings missing from the map fall back to 0.
                    protection=Protection(string_to_int_enum_map.get(protection, 0)),
                )

    def get_country(self, phy_id: int) -> CountryCode:
        """Reads the currently configured country for `phy_id`.

        Args:
            phy_id: unsigned 16-bit integer.

        Returns:
            The currently configured country code from phy_id.

        Raises:
            WlanFailure: Sl4f run command failed.
            TypeError: The SL4F result is not a list.
        """
        if self.honeydew_fd:
            try:
                return self.honeydew_fd.wlan.get_country(phy_id)
            except errors.Sl4fError as e:
                raise WlanFailure(f"SL4F call {Command.GET_COUNTRY} failed.") from e
        else:
            method_params = {"phy_id": phy_id}
            resp = self.send_command(Command.GET_COUNTRY, method_params)
            result = self._check_response_error(Command.GET_COUNTRY, resp)

            if not isinstance(result, list):
                raise TypeError(f'Expected "result" to be list, got {type(result)}')

            # The result is a list of character codes; join them into the
            # country code string.
            set_code = "".join(map(chr, result))

            return CountryCode(set_code)

    def query_iface(self, iface_id: int) -> QueryIfaceResponse:
        """Retrieves interface info for given wlan iface id.

        Args:
            iface_id: The iface_id to query

        Returns:
            QueryIfaceResults from the SL4F server

        Raises:
            WlanFailure: Sl4f run command failed.
            TypeError: The SL4F result is not a dict.
        """
        if self.honeydew_fd:
            try:
                return self.honeydew_fd.wlan.query_iface(iface_id)
            except errors.Sl4fError as e:
                raise WlanFailure(f"SL4F call {Command.QUERY_IFACE} failed.") from e
        else:
            method_params = {"iface_id": iface_id}
            resp = self.send_command(Command.QUERY_IFACE, method_params)
            result = self._check_response_error(Command.QUERY_IFACE, resp)

            if not isinstance(result, dict):
                raise TypeError(f'Expected "result" to be dict, got {type(result)}')

            iface_results = MapValidator(result)
            sta_addr = iface_results.list("sta_addr")

            return QueryIfaceResponse(
                role=WlanMacRole(iface_results.get(str, "role")),
                id=iface_results.get(int, "id"),
                phy_id=iface_results.get(int, "phy_id"),
                phy_assigned_id=iface_results.get(int, "phy_assigned_id"),
                sta_addr=sta_addr.all(int),
            )

    def set_region(self, region_code: CountryCode) -> None:
        """Set regulatory region.

        Args:
            region_code: CountryCode which is a 2-byte ASCII string.

        Raises:
            WlanFailure: Sl4f run command failed.
        """
        if self.honeydew_fd:
            try:
                self.honeydew_fd.wlan.set_region(region_code)
            except errors.Sl4fError as e:
                raise WlanFailure(f"SL4F call {Command.SET_REGION} failed.") from e
        else:
            method_params = {"region": region_code.value}
            resp = self.send_command(Command.SET_REGION, method_params)
            self._check_response_error(Command.SET_REGION, resp)
