blob: ca94d73dfad868226ab94adbdd7aa9dfd8e5469e [file] [log] [blame]
#!/usr/bin/env python3
#
# Copyright 2023 The Fuchsia Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import subprocess
import time
from dataclasses import dataclass
from typing import Any, BinaryIO, Mapping
from antlion import logger, signals
from antlion.net import wait_for_port
# Defaults shared by SSHConfig and SSHProvider.

# Standard sshd port. SSHProvider leaves it out of its log tag.
DEFAULT_SSH_PORT: int = 22

# Seconds a single remote command may run (default timeout_sec of run(),
# upload_file() and download_file()).
DEFAULT_SSH_TIMEOUT_SEC: int = 60

# Seconds allowed for establishing a connection; passed to ssh as the
# ConnectTimeout option.
DEFAULT_SSH_CONNECT_TIMEOUT_SEC: int = 90

# Seconds between keep-alive probes; passed to ssh as ServerAliveInterval.
DEFAULT_SSH_SERVER_ALIVE_INTERVAL: int = 30
class SSHResult:
"""Result of an SSH command."""
def __init__(
self,
process: subprocess.CompletedProcess[bytes]
| subprocess.CompletedProcess[str]
| subprocess.CalledProcessError,
) -> None:
if isinstance(process.stdout, bytes):
self._stdout_bytes = process.stdout
elif isinstance(process.stdout, str):
self._stdout = process.stdout
else:
raise TypeError(
"Expected process.stdout to be either bytes or str, "
f"got {type(process.stdout)}"
)
if isinstance(process.stderr, bytes):
self._stderr_bytes = process.stderr
elif isinstance(process.stderr, str):
self._stderr = process.stderr
else:
raise TypeError(
"Expected process.stderr to be either bytes or str, "
f"got {type(process.stderr)}"
)
self._exit_status = process.returncode
def __str__(self):
if self.exit_status == 0:
return self.stdout
return f'status {self.exit_status}, stdout: "{self.stdout}", stderr: "{self.stderr}"'
@property
def stdout(self) -> str:
if not hasattr(self, "_stdout"):
self._stdout = self._stdout_bytes.decode("utf-8", errors="replace")
return self._stdout
@property
def stdout_bytes(self) -> bytes:
if not hasattr(self, "_stdout_bytes"):
self._stdout_bytes = self._stdout.encode()
return self._stdout_bytes
@property
def stderr(self) -> str:
if not hasattr(self, "_stderr"):
self._stderr = self._stderr_bytes.decode("utf-8", errors="replace")
return self._stderr
@property
def exit_status(self) -> int:
return self._exit_status
class SSHError(signals.TestError):
    """Raised when an SSH command exits with a non-zero status code.

    Attributes:
        result: The SSHResult of the failed command.
    """

    def __init__(self, command: str, result: SSHResult):
        """
        Args:
            command: The command as it was requested by the caller.
            result: Outcome of the command; its string form is embedded in
                the error message.
        """
        message = f'SSH command "{command}" unexpectedly returned {result}'
        super().__init__(message)
        self.result = result
class SSHTimeout(signals.TestError):
    """Raised when an SSH command does not finish within its timeout."""

    def __init__(self, err: subprocess.TimeoutExpired):
        """
        Args:
            err: The TimeoutExpired raised by subprocess.run. Its command,
                timeout, and any partial output go into the message.
        """
        summary = f'SSH command "{err.cmd}" timed out after {err.timeout}s, '
        captured = f"stdout={err.stdout!r}, stderr={err.stderr!r}"
        super().__init__(summary + captured)
class SSHTransportError(signals.TestError):
    """Raised when SSH itself fails to deliver a command to the device.

    Unlike SSHError, the remote command did not get to run. For example, the
    host name could not be resolved, or the connection was refused or timed
    out.
    """
@dataclass
class SSHConfig:
    """SSH client config.

    Fields mirror ssh(1) command-line flags and ssh_config(5) options.
    full_command() turns them into an ssh argument list.
    """

    # SSH flags. See ssh(1) for full details.
    user: str
    host_name: str
    identity_file: str
    ssh_binary: str = "ssh"
    config_file: str = "/dev/null"
    port: int = DEFAULT_SSH_PORT

    # SSH options. See ssh_config(5) for full details.
    connect_timeout: int = DEFAULT_SSH_CONNECT_TIMEOUT_SEC
    server_alive_interval: int = DEFAULT_SSH_SERVER_ALIVE_INTERVAL
    strict_host_key_checking: bool = False
    user_known_hosts_file: str = "/dev/null"
    log_level: str = "ERROR"

    def full_command(self, command: str, force_tty: bool = False) -> list[str]:
        """Generate the complete command to execute command over SSH.

        Args:
            command: The command to run over SSH. It is passed to the remote
                shell verbatim. An empty or whitespace-only command adds no
                remote command argument.
            force_tty: Force pseudo-terminal allocation. This can be used to
                execute arbitrary screen-based programs on a remote machine,
                which can be very useful, e.g. when implementing menu services.

        Returns:
            Arguments composing the complete call to SSH.
        """
        args = [
            self.ssh_binary,
            # SSH flags
            "-i",
            self.identity_file,
            "-F",
            self.config_file,
            "-p",
            str(self.port),
            # SSH configuration options
            "-o",
            f"ConnectTimeout={self.connect_timeout}",
            "-o",
            f"ServerAliveInterval={self.server_alive_interval}",
            "-o",
            f'StrictHostKeyChecking={"yes" if self.strict_host_key_checking else "no"}',
            "-o",
            f"UserKnownHostsFile={self.user_known_hosts_file}",
            "-o",
            f"LogLevel={self.log_level}",
        ]
        if force_tty:
            # Multiple -t options force tty allocation, even if ssh has no local
            # tty. This is necessary for launching ssh with subprocess without
            # shell=True.
            args.append("-tt")
        args.append(f"{self.user}@{self.host_name}")
        # ssh joins every argument after the destination with single spaces
        # and hands the result to the remote shell. Passing the command as one
        # argument keeps it intact. Splitting it on whitespace would collapse
        # newlines, tabs and runs of spaces, including inside quotes.
        if command.strip():
            args.append(command)
        return args

    @staticmethod
    def _checked(name: str, value: Any, expected: type, description: str) -> Any:
        """Return value if its type is exactly `expected`, else raise ValueError."""
        # Exact type match (not isinstance) so that, e.g., a bool is rejected
        # where an int is required.
        if type(value) is not expected:
            raise ValueError(f"{name} must be {description}, got {value}")
        return value

    @staticmethod
    def from_config(config: Mapping[str, Any]) -> "SSHConfig":
        """Build an SSHConfig from a config mapping.

        Keys read:
            ssh_binary_path: defaults to the first "ssh" found on PATH.
            user, host, identity_file: required strings.
            port: integer, defaults to DEFAULT_SSH_PORT.
            ssh_config: string, defaults to "/dev/null".
            connect_timeout: integer, defaults to 30.

        Raises:
            ValueError: if a value is missing or has the wrong type, or no
                ssh binary is configured and none is found on PATH.
        """
        ssh_binary_path = config.get("ssh_binary_path", None)
        if ssh_binary_path is None:
            ssh_binary_path = shutil.which("ssh")
            if ssh_binary_path is None:
                raise ValueError(
                    "ssh_binary_path is not set and no ssh binary was found on PATH"
                )
        # Checks run in a fixed order so the first invalid key is reported.
        ssh_binary_path = SSHConfig._checked(
            "ssh_binary_path", ssh_binary_path, str, "a string"
        )
        user = SSHConfig._checked("user", config.get("user", None), str, "a string")
        host = SSHConfig._checked("host", config.get("host", None), str, "a string")
        port = SSHConfig._checked(
            "port", config.get("port", DEFAULT_SSH_PORT), int, "an integer"
        )
        identity_file = SSHConfig._checked(
            "identity_file", config.get("identity_file", None), str, "a string"
        )
        ssh_config = SSHConfig._checked(
            "ssh_config", config.get("ssh_config", "/dev/null"), str, "a string"
        )
        # NOTE(review): this default (30s) differs from the dataclass default
        # DEFAULT_SSH_CONNECT_TIMEOUT_SEC; kept unchanged for compatibility.
        connect_timeout = SSHConfig._checked(
            "connect_timeout", config.get("connect_timeout", 30), int, "an integer"
        )

        return SSHConfig(
            user=user,
            host_name=host,
            identity_file=identity_file,
            ssh_binary=ssh_binary_path,
            config_file=ssh_config,
            port=port,
            connect_timeout=connect_timeout,
        )
class SSHProvider:
    """Device-specific provider for SSH clients."""

    def __init__(self, config: SSHConfig) -> None:
        """
        Args:
            config: SSH client config

        Raises:
            TimeoutError: if waiting for the device to become reachable over
                SSH fails for any reason (the original error is chained).
        """
        logger_tag = f"ssh | {config.host_name}"
        if config.port != DEFAULT_SSH_PORT:
            logger_tag += f":{config.port}"

        self.log = logger.create_tagged_trace_logger(logger_tag)
        self.config = config

        try:
            self.wait_until_reachable()
            self.log.info("sshd is reachable")
        except Exception as e:
            raise TimeoutError("sshd is unreachable") from e

    def wait_until_reachable(self) -> None:
        """Wait for the device to become reachable via SSH.

        Raises:
            TimeoutError: connect_timeout has expired without a successful SSH
                connection to the device
            SSHTransportError: SSH is available on the device but
                connect_timeout has expired and SSH fails to run
            SSHTimeout: SSH is available on the device but connect_timeout has
                expired and SSH takes too long to run a command
        """
        timeout_sec = self.config.connect_timeout
        deadline = time.time() + timeout_sec

        wait_for_port(self.config.host_name, self.config.port, timeout_sec=timeout_sec)

        while True:
            try:
                self._run("echo", timeout_sec, False, None)
                return
            except SSHTransportError:
                # _run() can exit early with SSH transport errors, e.g. while
                # sshd is not fully initialized yet; retry until the deadline.
                if time.time() >= deadline:
                    raise

    def wait_until_unreachable(
        self, interval_sec: int = 1, timeout_sec: int = DEFAULT_SSH_CONNECT_TIMEOUT_SEC
    ) -> None:
        """Wait for the device to become unreachable via SSH.

        Args:
            interval_sec: Seconds to wait between unreachability attempts
            timeout_sec: Seconds to wait until raising TimeoutError

        Raises:
            TimeoutError: when timeout_sec has expired without an unsuccessful
                SSH connection to the device
        """
        deadline = time.time() + timeout_sec
        while True:
            try:
                wait_for_port(
                    self.config.host_name, self.config.port, timeout_sec=interval_sec
                )
            except TimeoutError:
                # The port did not accept connections within interval_sec.
                return

            # Still reachable. Give up only once the deadline has passed (the
            # original check was inverted and raised on the first probe).
            if time.time() >= deadline:
                raise TimeoutError(
                    f"Connection to {self.config.host_name} is still reachable "
                    f"after {timeout_sec}s"
                )
            # wait_for_port presumably returns as soon as the port is open, so
            # pause here to honor interval_sec instead of spinning.
            time.sleep(interval_sec)

    def run(
        self,
        command: str,
        timeout_sec: int = DEFAULT_SSH_TIMEOUT_SEC,
        connect_retries: int = 3,
        force_tty: bool = False,
    ) -> SSHResult:
        """Run a command on the device then exit.

        Args:
            command: String to send to the device.
            timeout_sec: Seconds to wait for the command to complete.
            connect_retries: Amount of times to retry connect on fail.
            force_tty: Force pseudo-terminal allocation.

        Raises:
            SSHError: if the SSH command returns a non-zero status code
            SSHTransportError: if SSH fails to run the command
            SSHTimeout: if there is no response within timeout_sec

        Returns:
            SSHResults from the executed command.
        """
        return self._run_with_retry(
            command, timeout_sec, connect_retries, force_tty, stdin=None
        )

    def _run_with_retry(
        self,
        command: str,
        timeout_sec: int,
        connect_retries: int,
        force_tty: bool,
        stdin: BinaryIO | None,
    ) -> SSHResult:
        """Call _run(), retrying up to connect_retries times on transport errors.

        Raises:
            ValueError: if connect_retries is not positive.
            SSHTransportError: the last transport error, once retries run out.
            SSHError, SSHTimeout: propagated from _run() without retrying.
        """
        err: Exception = ValueError("connect_retries cannot be 0")
        for _ in range(connect_retries):
            try:
                return self._run(command, timeout_sec, force_tty, stdin)
            except SSHTransportError as e:
                err = e
                self.log.warning(f"Connect failed: {e}")
        raise err

    def _run(
        self, command: str, timeout_sec: int, force_tty: bool, stdin: BinaryIO | None
    ) -> SSHResult:
        """Run command once over SSH and classify failures.

        Raises:
            SSHTransportError: ssh exited with 255 and stderr matches a known
                resolution/connection failure.
            SSHError: any other non-zero exit status.
            SSHTimeout: the ssh process exceeded timeout_sec.
        """
        full_command = self.config.full_command(command, force_tty)
        self.log.debug(
            f'Running "{command}" (full command: "{" ".join(full_command)}")'
        )
        try:
            process = subprocess.run(
                full_command,
                capture_output=True,
                timeout=timeout_sec,
                check=True,
                stdin=stdin,
            )
        except subprocess.CalledProcessError as e:
            # ssh uses exit status 255 for its own errors; inspect stderr to
            # tell transport failures apart from a failing remote command.
            if e.returncode == 255:
                stderr = e.stderr.decode("utf-8", errors="replace")
                if (
                    "Name or service not known" in stderr
                    or "Host does not exist" in stderr
                ):
                    raise SSHTransportError(
                        f"Hostname {self.config.host_name} cannot be resolved to an address"
                    ) from e
                if "Connection timed out" in stderr:
                    # Connection setup is bounded by ssh's ConnectTimeout
                    # option, not by the command timeout.
                    raise SSHTransportError(
                        f"Failed to establish a connection to {self.config.host_name} within {self.config.connect_timeout}s"
                    ) from e
                if "Connection refused" in stderr:
                    raise SSHTransportError(
                        f"Connection refused by {self.config.host_name}"
                    ) from e
            raise SSHError(command, SSHResult(e)) from e
        except subprocess.TimeoutExpired as e:
            raise SSHTimeout(e) from e

        return SSHResult(process)

    def upload_file(
        self,
        local_path: str,
        remote_path: str,
        timeout_sec: int = DEFAULT_SSH_TIMEOUT_SEC,
        connect_retries: int = 3,
    ) -> None:
        """Upload a file to the device.

        The local file is streamed to `cat > remote_path` over SSH.

        Args:
            local_path: Path to the file to upload
            remote_path: Path on the remote device to place the uploaded file.
            timeout_sec: Seconds to wait for the command to complete.
            connect_retries: Amount of times to retry connect on fail.

        Raises:
            SSHError: if the SSH upload returns a non-zero status code
            SSHTransportError: if SSH fails to run the upload command
            SSHTimeout: if there is no response within timeout_sec
        """
        with open(local_path, "rb") as file:
            self._run_with_retry(
                f"cat > {remote_path}",
                timeout_sec,
                connect_retries,
                force_tty=False,
                stdin=file,
            )

    def download_file(
        self,
        remote_path: str,
        local_path: str,
        timeout_sec: int = DEFAULT_SSH_TIMEOUT_SEC,
        connect_retries: int = 3,
    ) -> None:
        """Download a file from the device.

        The remote file is read with `cat remote_path` over SSH and its raw
        bytes are written to local_path.

        Args:
            remote_path: Path on the remote device to download.
            local_path: Path on the host to the place the downloaded file.
            timeout_sec: Seconds to wait for the command to complete.
            connect_retries: Amount of times to retry connect on fail.

        Raises:
            SSHError: if the SSH command returns a non-zero status code
            SSHTransportError: if SSH fails to run the command
            SSHTimeout: if there is no response within timeout_sec
        """
        result = self._run_with_retry(
            f"cat {remote_path}",
            timeout_sec,
            connect_retries,
            force_tty=False,
            stdin=None,
        )
        # Write only after the transfer succeeded, so that a failed download
        # does not truncate an existing local file.
        with open(local_path, "wb") as file:
            file.write(result.stdout_bytes)