| #!/usr/bin/env python3 |
| # |
| # Copyright 2023 The Fuchsia Authors |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
import shlex
import shutil
import subprocess
import time
from dataclasses import dataclass
from typing import Any, BinaryIO, List, Mapping, Optional, Union

from antlion import logger, signals
from antlion.net import wait_for_port
| |
# Port sshd is expected to listen on unless a config overrides it.
DEFAULT_SSH_PORT: int = 22

# Seconds a remote command may run before it is considered timed out.
DEFAULT_SSH_TIMEOUT_SEC: int = 60

# Seconds allowed for establishing the SSH connection (ssh -o ConnectTimeout).
DEFAULT_SSH_CONNECT_TIMEOUT_SEC: int = 90

# Seconds between keep-alive probes sent to the server (ssh -o ServerAliveInterval).
DEFAULT_SSH_SERVER_ALIVE_INTERVAL: int = 30
| |
| |
class SSHResult:
    """Outcome of a command executed over SSH.

    Wraps either a CompletedProcess (zero exit) or the CalledProcessError
    raised for a non-zero exit, exposing the exit status plus the captured
    output streams. stderr is decoded eagerly; stdout is decoded on first
    access and cached.
    """

    def __init__(
        self, process: Union[subprocess.CompletedProcess, subprocess.CalledProcessError]
    ) -> None:
        self._exit_status: int = process.returncode
        self._raw_stdout = process.stdout
        self._stderr = process.stderr.decode("utf-8", errors="replace")

    def __str__(self):
        if self.exit_status != 0:
            return f'status {self.exit_status}, stdout: "{self.stdout}", stderr: "{self.stderr}"'
        return self.stdout

    @property
    def stdout(self) -> str:
        """stdout decoded as UTF-8, undecodable bytes replaced; cached."""
        try:
            return self._stdout
        except AttributeError:
            decoded = self._raw_stdout.decode("utf-8", errors="replace")
            self._stdout = decoded
            return decoded

    @property
    def stderr(self) -> str:
        """stderr decoded as UTF-8, undecodable bytes replaced."""
        return self._stderr

    @property
    def exit_status(self) -> int:
        """Exit status reported by the ssh process."""
        return self._exit_status

    @property
    def raw_stdout(self) -> bytes:
        """stdout exactly as captured, without decoding."""
        return self._raw_stdout
| |
| |
class SSHError(signals.TestError):
    """Raised when a command run over SSH exits with a non-zero status."""

    def __init__(self, command: str, result: SSHResult):
        message = f'SSH command "{command}" unexpectedly returned {result}'
        super().__init__(message)
        # Keep the full result so callers can inspect exit status and output.
        self.result = result
| |
| |
class SSHTimeout(signals.TestError):
    """Raised when an SSH command does not finish within its timeout."""

    def __init__(self, err: subprocess.TimeoutExpired):
        # Include whatever output was captured before the timeout, repr'd so
        # raw bytes (or None) are shown unambiguously.
        summary = f'SSH command "{err.cmd}" timed out after {err.timeout}s, '
        streams = f"stdout={err.stdout!r}, stderr={err.stderr!r}"
        super().__init__(summary + streams)
| |
| |
class SSHTransportError(signals.TestError):
    """Raised when SSH itself fails to deliver a command to the remote host.

    Signals a connection-level problem, such as an unresolvable host name or a
    refused or timed-out connection, rather than a failure of the remote
    command itself.
    """
| |
| |
@dataclass
class SSHConfig:
    """SSH client config.

    Bundles the ssh(1) flags and ssh_config(5) options needed to build a
    complete ssh invocation against a single host; see full_command().
    """

    # SSH flags. See ssh(1) for full details.
    user: str
    host_name: str
    identity_file: str

    ssh_binary: str = "ssh"
    config_file: str = "/dev/null"
    port: int = DEFAULT_SSH_PORT

    # SSH options. See ssh_config(5) for full details.
    connect_timeout: int = DEFAULT_SSH_CONNECT_TIMEOUT_SEC
    server_alive_interval: int = DEFAULT_SSH_SERVER_ALIVE_INTERVAL
    strict_host_key_checking: bool = False
    user_known_hosts_file: str = "/dev/null"
    log_level: str = "ERROR"

    def full_command(self, command: str, force_tty: bool = False) -> List[str]:
        """Generate the complete command to execute command over SSH.

        Args:
            command: The command to run over SSH. It is forwarded to ssh as a
                single argument so quoting and whitespace inside it reach the
                remote shell unchanged. An empty or whitespace-only command
                adds no command argument at all.
            force_tty: Force pseudo-terminal allocation. This can be used to
                execute arbitrary screen-based programs on a remote machine,
                which can be very useful, e.g. when implementing menu services.

        Returns:
            Arguments composing the complete call to SSH.
        """
        optional_flags = []
        if force_tty:
            # Multiple -t options force tty allocation, even if ssh has no local
            # tty. This is necessary for launching ssh with subprocess without
            # shell=True.
            optional_flags.append("-tt")

        # ssh joins every argument after the destination with single spaces
        # before sending the result to the remote shell. Splitting `command` on
        # whitespace first would collapse runs of spaces (and turn newlines into
        # spaces) inside quoted arguments such as `echo "a  b"`, so forward it
        # verbatim as one argument instead.
        remote_command = [command] if command.strip() else []

        host_key_checking = "yes" if self.strict_host_key_checking else "no"
        return (
            [
                self.ssh_binary,
                # SSH flags
                "-i",
                self.identity_file,
                "-F",
                self.config_file,
                "-p",
                str(self.port),
                # SSH configuration options
                "-o",
                f"ConnectTimeout={self.connect_timeout}",
                "-o",
                f"ServerAliveInterval={self.server_alive_interval}",
                "-o",
                f"StrictHostKeyChecking={host_key_checking}",
                "-o",
                f"UserKnownHostsFile={self.user_known_hosts_file}",
                "-o",
                f"LogLevel={self.log_level}",
            ]
            + optional_flags
            + [f"{self.user}@{self.host_name}"]
            + remote_command
        )

    @staticmethod
    def from_config(config: Mapping[str, Any]) -> "SSHConfig":
        """Build an SSHConfig from a mapping of configuration values.

        Recognized keys:
            ssh_binary_path (str): ssh executable; defaults to the "ssh"
                found on PATH.
            user (str): required.
            host (str): required; stored as host_name.
            identity_file (str): required.
            port (int): defaults to DEFAULT_SSH_PORT.
            ssh_config (str): passed to ssh -F; defaults to "/dev/null".
            connect_timeout (int): defaults to 30.

        Raises:
            ValueError: if a value has the wrong type (types must match
                exactly, so e.g. a bool is rejected where an int is expected),
                a required key is missing, or no ssh binary was configured and
                none could be found on PATH.
        """

        def typed(key: str, expected: type, default: Any = None) -> Any:
            # Read config[key] and require its type to be exactly `expected`.
            # An exact match (rather than isinstance) keeps bools out of int
            # fields; a missing required key yields None and is rejected too.
            value = config.get(key, default)
            if type(value) is not expected:
                kind = "an integer" if expected is int else "a string"
                raise ValueError(f"{key} must be {kind}, got {value}")
            return value

        ssh_binary_path = config.get("ssh_binary_path", None)
        if ssh_binary_path is None:
            ssh_binary_path = shutil.which("ssh")
            if ssh_binary_path is None:
                raise ValueError(
                    "ssh_binary_path is not set and no ssh binary was found on PATH"
                )
        if type(ssh_binary_path) is not str:
            raise ValueError(f"ssh_binary_path must be a string, got {ssh_binary_path}")

        user = typed("user", str)
        host = typed("host", str)
        port = typed("port", int, DEFAULT_SSH_PORT)
        identity_file = typed("identity_file", str)
        ssh_config = typed("ssh_config", str, "/dev/null")
        # NOTE(review): this default (30s) is shorter than the dataclass default
        # DEFAULT_SSH_CONNECT_TIMEOUT_SEC; confirm which one is intended.
        connect_timeout = typed("connect_timeout", int, 30)

        return SSHConfig(
            user=user,
            host_name=host,
            identity_file=identity_file,
            ssh_binary=ssh_binary_path,
            config_file=ssh_config,
            port=port,
            connect_timeout=connect_timeout,
        )
| |
| |
class SSHProvider:
    """Device-specific provider for SSH clients.

    Runs commands and transfers files on the host described by an SSHConfig
    by invoking the configured ssh binary once per operation.
    """

    # Pause between attempts in wait_until_reachable() so a host that is still
    # refusing connections is not probed by a tight loop of ssh processes.
    _REACHABLE_RETRY_INTERVAL_SEC: float = 1.0

    def __init__(self, config: SSHConfig) -> None:
        """
        Args:
            config: SSH client config

        Raises:
            TimeoutError: if the device could not be reached over SSH; the
                underlying failure from wait_until_reachable() is chained as
                the cause.
        """
        logger_tag = f"ssh | {config.host_name}"
        if config.port != DEFAULT_SSH_PORT:
            logger_tag += f":{config.port}"

        self.log = logger.create_tagged_trace_logger(logger_tag)
        self.config = config

        try:
            self.wait_until_reachable()
        except Exception as e:
            raise TimeoutError(
                f"sshd is unreachable at {config.host_name}:{config.port}"
            ) from e
        self.log.info("sshd is reachable")

    def wait_until_reachable(self) -> None:
        """Wait for the device to become reachable via SSH.

        First waits for the SSH port to accept connections, then runs a
        trivial command until it succeeds, pausing briefly after each attempt
        that fails with a transport error.

        Raises:
            TimeoutError: connect_timeout has expired without a successful SSH
                connection to the device
            SSHTransportError: SSH is available on the device but
                connect_timeout has expired and SSH fails to run
            SSHTimeout: SSH is available on the device but connect_timeout has
                expired and SSH takes too long to run a command
            SSHError: the probe command exited with a non-transport error
                (not retried)
        """
        timeout_sec = self.config.connect_timeout
        # Monotonic clock so wall-clock adjustments cannot stretch or cut the wait.
        deadline = time.monotonic() + timeout_sec
        wait_for_port(self.config.host_name, self.config.port, timeout_sec=timeout_sec)

        while True:
            try:
                self._run("echo", timeout_sec, False, None)
                return
            except SSHTransportError:
                # Repeat if necessary; _run() can exit prematurely by receiving
                # SSH transport errors. These errors can be caused by sshd not
                # being fully initialized yet.
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    raise
                time.sleep(min(self._REACHABLE_RETRY_INTERVAL_SEC, remaining))

    def wait_until_unreachable(
        self, interval_sec: int = 1, timeout_sec: int = DEFAULT_SSH_CONNECT_TIMEOUT_SEC
    ) -> None:
        """Wait for the device to become unreachable via SSH.

        Args:
            interval_sec: Seconds to wait between unreachability attempts
            timeout_sec: Seconds to wait until raising TimeoutError

        Raises:
            TimeoutError: when timeout_sec has expired without an unsuccessful
                SSH connection to the device
        """
        deadline = time.monotonic() + timeout_sec

        while True:
            try:
                wait_for_port(
                    self.config.host_name, self.config.port, timeout_sec=interval_sec
                )
            except TimeoutError:
                # The port did not accept a connection within interval_sec.
                return

            # The port is still open: give up only once the deadline has passed.
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                raise TimeoutError(
                    f"Connection to {self.config.host_name} is still reachable "
                    f"after {timeout_sec}s"
                )
            # wait_for_port presumably returns as soon as the port is open, so
            # pause before probing again rather than spinning.
            time.sleep(min(interval_sec, remaining))

    def run(
        self,
        command: str,
        timeout_sec: int = DEFAULT_SSH_TIMEOUT_SEC,
        connect_retries: int = 3,
        force_tty: bool = False,
    ) -> SSHResult:
        """Run a command on the device then exit.

        Args:
            command: String to send to the device.
            timeout_sec: Seconds to wait for the command to complete.
            connect_retries: Amount of times to retry connect on fail.
            force_tty: Force pseudo-terminal allocation.

        Raises:
            SSHError: if the SSH command returns a non-zero status code
            SSHTransportError: if SSH fails to run the command
            SSHTimeout: if there is no response within timeout_sec
            ValueError: if connect_retries is less than 1

        Returns:
            SSHResults from the executed command.
        """
        return self._run_with_retry(
            command, timeout_sec, connect_retries, force_tty, stdin=None
        )

    def _run_with_retry(
        self,
        command: str,
        timeout_sec: int,
        connect_retries: int,
        force_tty: bool,
        stdin: Optional[BinaryIO],
    ) -> SSHResult:
        """Run command via _run(), retrying only on SSHTransportError.

        Args:
            command: Command to run on the device.
            timeout_sec: Seconds each attempt may take.
            connect_retries: Total number of attempts; must be at least 1.
            force_tty: Force pseudo-terminal allocation.
            stdin: Optional binary stream fed to ssh's standard input.

        Raises:
            ValueError: if connect_retries is less than 1.
            SSHTransportError: if every attempt failed to reach the host.
            SSHError, SSHTimeout: propagated from _run() without retrying.
        """
        if connect_retries < 1:
            raise ValueError(
                f"connect_retries must be at least 1, got {connect_retries}"
            )

        for attempt in range(1, connect_retries + 1):
            try:
                return self._run(command, timeout_sec, force_tty, stdin)
            except SSHTransportError as e:
                self.log.warning(
                    f"Connect failed (attempt {attempt}/{connect_retries}): {e}"
                )
                if attempt == connect_retries:
                    raise
        # Unreachable: the final attempt either returns or re-raises above.
        raise AssertionError("unreachable")

    def _run(
        self, command: str, timeout_sec: int, force_tty: bool, stdin: Optional[BinaryIO]
    ) -> SSHResult:
        """Run command once over SSH, classifying failures.

        Returns:
            SSHResult for a zero exit status.

        Raises:
            SSHTransportError: ssh exited with status 255 and its stderr shows
                a recognized connection-level failure.
            SSHError: any other non-zero exit status.
            SSHTimeout: the process did not finish within timeout_sec.
        """
        full_command = self.config.full_command(command, force_tty)
        self.log.debug(
            f'Running "{command}" (full command: "{" ".join(full_command)}")'
        )
        try:
            process = subprocess.run(
                full_command,
                capture_output=True,
                timeout=timeout_sec,
                check=True,
                stdin=stdin,
            )
        except subprocess.CalledProcessError as e:
            # ssh(1) exits with 255 when ssh itself fails rather than relaying
            # the remote command's status; map recognizable connection failures
            # to SSHTransportError so callers can retry them.
            if e.returncode == 255:
                message = self._transport_error_message(
                    e.stderr.decode("utf-8", errors="replace")
                )
                if message is not None:
                    raise SSHTransportError(message) from e
            raise SSHError(command, SSHResult(e)) from e
        except subprocess.TimeoutExpired as e:
            raise SSHTimeout(e) from e

        return SSHResult(process)

    def _transport_error_message(self, stderr: str) -> Optional[str]:
        """Describe a connection-level failure found in ssh's stderr, or None."""
        host = self.config.host_name
        if "Name or service not known" in stderr or "Host does not exist" in stderr:
            return f"Hostname {host} cannot be resolved to an address"
        if "Connection timed out" in stderr:
            # Establishing the connection is bounded by ssh's ConnectTimeout
            # option (config.connect_timeout), not by timeout_sec, which bounds
            # the whole subprocess and surfaces as SSHTimeout instead.
            return (
                f"Failed to establish a connection to {host} within "
                f"{self.config.connect_timeout}s"
            )
        if "Connection refused" in stderr:
            return f"Connection refused by {host}"
        if "No route to host" in stderr:
            return f"No route to host {host}"
        return None

    def upload_file(
        self,
        local_path: str,
        remote_path: str,
        timeout_sec: int = DEFAULT_SSH_TIMEOUT_SEC,
        connect_retries: int = 3,
    ) -> None:
        """Upload a file to the device.

        The local file is streamed to `cat > remote_path` on the device.
        remote_path is shell-quoted (assumes the remote shell follows POSIX
        quoting rules).

        Args:
            local_path: Path to the file to upload
            remote_path: Path on the remote device to place the uploaded file.
            timeout_sec: Seconds to wait for the command to complete.
            connect_retries: Amount of times to retry connect on fail.

        Raises:
            SSHError: if the SSH upload returns a non-zero status code
            SSHTransportError: if SSH fails to run the upload command
            SSHTimeout: if there is no response within timeout_sec
        """
        with open(local_path, "rb") as file:
            self._run_with_retry(
                f"cat > {shlex.quote(remote_path)}",
                timeout_sec,
                connect_retries,
                force_tty=False,
                stdin=file,
            )

    def download_file(
        self,
        remote_path: str,
        local_path: str,
        timeout_sec: int = DEFAULT_SSH_TIMEOUT_SEC,
        connect_retries: int = 3,
    ) -> None:
        """Download a file from the device.

        Runs `cat remote_path` on the device and writes the captured bytes to
        local_path. The whole file is held in memory first, and local_path is
        only created or overwritten once the remote read has succeeded.
        remote_path is shell-quoted (assumes the remote shell follows POSIX
        quoting rules).

        Args:
            remote_path: Path on the remote device to download.
            local_path: Path on the host to place the downloaded file.
            timeout_sec: Seconds to wait for the command to complete.
            connect_retries: Amount of times to retry connect on fail.

        Raises:
            SSHError: if the SSH command returns a non-zero status code
            SSHTransportError: if SSH fails to run the command
            SSHTimeout: if there is no response within timeout_sec
        """
        result = self._run_with_retry(
            f"cat {shlex.quote(remote_path)}",
            timeout_sec,
            connect_retries,
            force_tty=False,
            stdin=None,
        )
        with open(local_path, "wb") as file:
            file.write(result.raw_stdout)