blob: 4e99b7118bef97266c9eb74c4050e135da644649 [file] [log] [blame]
#!/usr/bin/env python3
#
# Copyright 2023 The Fuchsia Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import subprocess
import time
from dataclasses import dataclass
from typing import Any, BinaryIO, List, Mapping, Optional, Union
from antlion import logger, signals
from antlion.net import wait_for_port
# Default sshd port. SSHProvider leaves the port out of its log tag when the
# configured port equals this value.
DEFAULT_SSH_PORT: int = 22
# Default per-command timeout, in seconds, for SSHProvider.run(),
# upload_file(), and download_file().
DEFAULT_SSH_TIMEOUT_SEC: int = 60
# Default ssh ConnectTimeout option, in seconds (also the default timeout for
# SSHProvider.wait_until_unreachable()).
DEFAULT_SSH_CONNECT_TIMEOUT_SEC: int = 90
# Default ssh ServerAliveInterval option, in seconds.
DEFAULT_SSH_SERVER_ALIVE_INTERVAL: int = 30
class SSHResult:
    """Outcome of running one command over SSH.

    Wraps a finished ``subprocess`` call, or the ``CalledProcessError`` raised
    for a non-zero exit, and exposes its exit status and captured output.
    stderr is decoded eagerly. stdout is kept as raw bytes and decoded to text
    on first access. Both decodings use UTF-8 and replace undecodable bytes
    with U+FFFD.
    """

    def __init__(
        self, process: Union[subprocess.CompletedProcess, subprocess.CalledProcessError]
    ) -> None:
        raw_out, raw_err, status = process.stdout, process.stderr, process.returncode
        self._raw_stdout = raw_out
        self._stderr = raw_err.decode("utf-8", errors="replace")
        self._exit_status: int = status

    def __str__(self) -> str:
        # A successful command renders as just its output. A failure renders
        # with its status and both streams, for use in error messages.
        if self.exit_status != 0:
            return (
                f'status {self.exit_status}, stdout: "{self.stdout}", '
                f'stderr: "{self.stderr}"'
            )
        return self.stdout

    @property
    def stdout(self) -> str:
        """Captured stdout as text; decoded once, then cached."""
        try:
            return self._stdout
        except AttributeError:
            self._stdout = self._raw_stdout.decode("utf-8", errors="replace")
            return self._stdout

    @property
    def stderr(self) -> str:
        """Captured stderr as text."""
        return self._stderr

    @property
    def exit_status(self) -> int:
        """Exit status reported by the ssh process."""
        return self._exit_status

    @property
    def raw_stdout(self) -> bytes:
        """Captured stdout exactly as received, without decoding."""
        return self._raw_stdout
class SSHError(signals.TestError):
    """Raised when an SSH command exits with a non-zero status code.

    Attributes:
        result: SSHResult describing the failed command (status and output).
    """

    def __init__(self, command: str, result: SSHResult):
        message = f'SSH command "{command}" unexpectedly returned {result}'
        super().__init__(message)
        self.result = result
class SSHTimeout(signals.TestError):
    """Raised when an SSH command does not finish before its timeout.

    The message includes the command, the timeout, and whatever stdout/stderr
    had been captured when the timeout fired.
    """

    def __init__(self, err: subprocess.TimeoutExpired):
        header = f'SSH command "{err.cmd}" timed out after {err.timeout}s'
        streams = f"stdout={err.stdout!r}, stderr={err.stderr!r}"
        super().__init__(f"{header}, {streams}")
class SSHTransportError(signals.TestError):
    """Failure to send an SSH command.

    SSHProvider raises this when ssh exits with status 255 and its stderr shows
    that the connection itself failed (hostname could not be resolved,
    connection timed out, or connection refused), as opposed to the remote
    command failing. SSHProvider retries commands that fail with this error.
    """
@dataclass
class SSHConfig:
    """SSH client config.

    Fields listed under "SSH flags" become ssh(1) command-line flags. Fields
    listed under "SSH options" become ``-o Key=Value`` ssh_config(5) options.
    See full_command() for the exact mapping.
    """

    # SSH flags. See ssh(1) for full details.
    user: str
    host_name: str
    identity_file: str
    ssh_binary: str = "ssh"
    config_file: str = "/dev/null"
    port: int = DEFAULT_SSH_PORT

    # SSH options. See ssh_config(5) for full details.
    connect_timeout: int = DEFAULT_SSH_CONNECT_TIMEOUT_SEC
    server_alive_interval: int = DEFAULT_SSH_SERVER_ALIVE_INTERVAL
    strict_host_key_checking: bool = False
    user_known_hosts_file: str = "/dev/null"
    log_level: str = "ERROR"

    def full_command(self, command: str, force_tty: bool = False) -> List[str]:
        """Generate the complete command to execute command over SSH.

        Args:
            command: The command to run over SSH. It is passed to ssh as a
                single argument, so its whitespace (including newlines and
                whitespace inside quoted arguments) reaches the remote shell
                unchanged. A blank command adds no remote-command argument.
            force_tty: Force pseudo-terminal allocation. This can be used to
                execute arbitrary screen-based programs on a remote machine,
                which can be very useful, e.g. when implementing menu services.

        Returns:
            Arguments composing the complete call to SSH.
        """
        host_key_checking = "yes" if self.strict_host_key_checking else "no"
        args: List[str] = [
            self.ssh_binary,
            # SSH flags
            "-i",
            self.identity_file,
            "-F",
            self.config_file,
            "-p",
            str(self.port),
            # SSH configuration options
            "-o",
            f"ConnectTimeout={self.connect_timeout}",
            "-o",
            f"ServerAliveInterval={self.server_alive_interval}",
            "-o",
            f"StrictHostKeyChecking={host_key_checking}",
            "-o",
            f"UserKnownHostsFile={self.user_known_hosts_file}",
            "-o",
            f"LogLevel={self.log_level}",
        ]
        if force_tty:
            # Multiple -t options force tty allocation, even if ssh has no local
            # tty. This is necessary for launching ssh with subprocess without
            # shell=True.
            args.append("-tt")
        args.append(f"{self.user}@{self.host_name}")
        # ssh joins every argument after the destination with single spaces
        # before handing it to the remote shell. Splitting `command` on
        # whitespace here would only collapse runs of whitespace, which
        # corrupts quoted arguments and turns newlines into spaces. Pass it
        # through intact instead.
        if command.strip():
            args.append(command)
        return args

    @classmethod
    def from_config(cls, config: Mapping[str, Any]) -> "SSHConfig":
        """Build an SSHConfig from a configuration mapping.

        Recognized keys:
            ssh_binary_path: Path to ssh; defaults to ``ssh`` found on PATH.
            user: Remote user name (required).
            host: Remote host name (required).
            port: sshd port; defaults to DEFAULT_SSH_PORT.
            identity_file: Private key path (required).
            ssh_config: ssh config file; defaults to "/dev/null".
            connect_timeout: ConnectTimeout in seconds; defaults to 30.

        Raises:
            ValueError: if a value has the wrong type, a required key is
                missing, or no ssh binary is given or found on PATH.
        """

        def get_str(key: str, default: Any = None) -> str:
            value = config.get(key, default)
            if not isinstance(value, str):
                raise ValueError(f"{key} must be a string, got {value}")
            return value

        def get_int(key: str, default: Any = None) -> int:
            value = config.get(key, default)
            # bool is a subclass of int; reject it so e.g. `port: true` is not
            # silently accepted as port 1.
            if isinstance(value, bool) or not isinstance(value, int):
                raise ValueError(f"{key} must be an integer, got {value}")
            return value

        ssh_binary_path = config.get("ssh_binary_path", None)
        if ssh_binary_path is None:
            ssh_binary_path = shutil.which("ssh")
            if ssh_binary_path is None:
                raise ValueError(
                    "ssh_binary_path is not set and no ssh binary was found in PATH"
                )
        if not isinstance(ssh_binary_path, str):
            raise ValueError(f"ssh_binary_path must be a string, got {ssh_binary_path}")

        # Validated in the same order as the keys are documented above.
        user = get_str("user")
        host = get_str("host")
        port = get_int("port", DEFAULT_SSH_PORT)
        identity_file = get_str("identity_file")
        ssh_config = get_str("ssh_config", "/dev/null")
        # NOTE(review): this default (30) differs from the dataclass field
        # default DEFAULT_SSH_CONNECT_TIMEOUT_SEC; kept as-is for compatibility.
        connect_timeout = get_int("connect_timeout", 30)

        return cls(
            user=user,
            host_name=host,
            identity_file=identity_file,
            ssh_binary=ssh_binary_path,
            config_file=ssh_config,
            port=port,
            connect_timeout=connect_timeout,
        )
class SSHProvider:
    """Device-specific provider for SSH clients.

    Constructing an instance blocks until the device answers an SSH command.
    """

    def __init__(self, config: SSHConfig) -> None:
        """
        Args:
            config: SSH client config

        Raises:
            TimeoutError: if wait_until_reachable() fails for any reason; the
                original exception is chained as the cause.
        """
        logger_tag = f"ssh | {config.host_name}"
        if config.port != DEFAULT_SSH_PORT:
            logger_tag += f":{config.port}"

        self.log = logger.create_tagged_trace_logger(logger_tag)
        self.config = config

        try:
            self.wait_until_reachable()
        except Exception as e:
            raise TimeoutError("sshd is unreachable") from e
        self.log.info("sshd is reachable")

    def wait_until_reachable(self) -> None:
        """Wait for the device to become reachable via SSH.

        First waits for the SSH port via wait_for_port(), then repeatedly runs
        `echo` over SSH until it succeeds or connect_timeout expires.

        Raises:
            TimeoutError: connect_timeout has expired without a successful SSH
                connection to the device
            SSHTransportError: SSH is available on the device but
                connect_timeout has expired and SSH fails to run
            SSHTimeout: SSH is available on the device but connect_timeout has
                expired and SSH takes too long to run a command
        """
        timeout_sec = self.config.connect_timeout
        deadline = time.time() + timeout_sec
        wait_for_port(self.config.host_name, self.config.port, timeout_sec=timeout_sec)

        while True:
            try:
                self._run("echo", timeout_sec, force_tty=False, stdin=None)
                return
            except SSHTransportError:
                # Repeat if necessary; _run() can exit prematurely by receiving
                # SSH transport errors. These errors can be caused by sshd not
                # being fully initialized yet.
                if time.time() >= deadline:
                    raise
                # Errors such as "Connection refused" come back almost
                # immediately; pause so retries do not spawn ssh processes
                # back-to-back.
                time.sleep(1)

    def wait_until_unreachable(
        self, interval_sec: int = 1, timeout_sec: int = DEFAULT_SSH_CONNECT_TIMEOUT_SEC
    ) -> None:
        """Wait for the device to become unreachable via SSH.

        Args:
            interval_sec: Seconds to wait between unreachability attempts
            timeout_sec: Seconds to wait until raising TimeoutError

        Raises:
            TimeoutError: when timeout_sec has expired without an unsuccessful
                SSH connection to the device
        """
        deadline = time.time() + timeout_sec
        while True:
            try:
                wait_for_port(
                    self.config.host_name, self.config.port, timeout_sec=interval_sec
                )
            except TimeoutError:
                # The port did not open within interval_sec: unreachable.
                return
            # Only give up once the deadline has passed (this comparison was
            # previously inverted, raising on the first still-reachable probe).
            if time.time() >= deadline:
                raise TimeoutError(
                    f"Connection to {self.config.host_name} is still reachable "
                    f"after {timeout_sec}s"
                )
            # Port is still open; space probes interval_sec apart.
            time.sleep(interval_sec)

    def run(
        self,
        command: str,
        timeout_sec: int = DEFAULT_SSH_TIMEOUT_SEC,
        connect_retries: int = 3,
        force_tty: bool = False,
    ) -> SSHResult:
        """Run a command on the device then exit.

        Args:
            command: String to send to the device.
            timeout_sec: Seconds to wait for the command to complete.
            connect_retries: Amount of times to retry connect on fail.
            force_tty: Force pseudo-terminal allocation.

        Raises:
            SSHError: if the SSH command returns a non-zero status code
            SSHTransportError: if SSH fails to run the command
            SSHTimeout: if there is no response within timeout_sec

        Returns:
            SSHResults from the executed command.
        """
        return self._run_with_retry(
            command, timeout_sec, connect_retries, force_tty, stdin=None
        )

    def _run_with_retry(
        self,
        command: str,
        timeout_sec: int,
        connect_retries: int,
        force_tty: bool,
        stdin: Optional[BinaryIO],
    ) -> SSHResult:
        """Run command, retrying up to connect_retries times on transport errors.

        Only SSHTransportError is retried; SSHError and SSHTimeout propagate
        immediately. If stdin is seekable, it is rewound to its starting
        position before each retry so a retried attempt sees the full input.

        Raises:
            ValueError: if connect_retries is not positive.
            SSHTransportError: the last transport error if every attempt failed.
        """
        err: Exception = ValueError("connect_retries cannot be 0")
        start_pos = stdin.tell() if stdin is not None and stdin.seekable() else None
        for attempt in range(connect_retries):
            if attempt > 0 and start_pos is not None:
                # A failed attempt may already have consumed part of stdin.
                stdin.seek(start_pos)
            try:
                return self._run(command, timeout_sec, force_tty, stdin)
            except SSHTransportError as e:
                err = e
                self.log.warning(f"Connect failed: {e}")
        raise err

    def _run(
        self, command: str, timeout_sec: int, force_tty: bool, stdin: Optional[BinaryIO]
    ) -> SSHResult:
        """Run command once over SSH and translate failures into SSH exceptions.

        Exit status 255 is reported by ssh itself; it is classified as an
        SSHTransportError when stderr names a connection-level failure. Any
        other non-zero status is an SSHError.
        """
        full_command = self.config.full_command(command, force_tty)
        self.log.debug(
            f'Running "{command}" (full command: "{" ".join(full_command)}")'
        )
        try:
            process = subprocess.run(
                full_command,
                capture_output=True,
                timeout=timeout_sec,
                check=True,
                stdin=stdin,
            )
        except subprocess.CalledProcessError as e:
            if e.returncode == 255:
                stderr = e.stderr.decode("utf-8", errors="replace")
                if (
                    "Name or service not known" in stderr
                    or "Host does not exist" in stderr
                ):
                    raise SSHTransportError(
                        f"Hostname {self.config.host_name} cannot be resolved to an address"
                    ) from e
                if "Connection timed out" in stderr:
                    # The TCP connect is bounded by ssh's ConnectTimeout
                    # option, not by the command timeout.
                    raise SSHTransportError(
                        f"Failed to establish a connection to {self.config.host_name} "
                        f"within {self.config.connect_timeout}s"
                    ) from e
                if "Connection refused" in stderr:
                    raise SSHTransportError(
                        f"Connection refused by {self.config.host_name}"
                    ) from e
            raise SSHError(command, SSHResult(e)) from e
        except subprocess.TimeoutExpired as e:
            raise SSHTimeout(e) from e
        return SSHResult(process)

    def upload_file(
        self,
        local_path: str,
        remote_path: str,
        timeout_sec: int = DEFAULT_SSH_TIMEOUT_SEC,
        connect_retries: int = 3,
    ) -> None:
        """Upload a file to the device.

        The file is streamed over ssh's stdin into `cat > remote_path`.

        Args:
            local_path: Path to the file to upload
            remote_path: Path on the remote device to place the uploaded file.
                NOTE(review): interpolated unquoted into a remote shell command.
            timeout_sec: Seconds to wait for the command to complete.
            connect_retries: Amount of times to retry connect on fail.

        Raises:
            SSHError: if the SSH upload returns a non-zero status code
            SSHTransportError: if SSH fails to run the upload command
            SSHTimeout: if there is no response within timeout_sec
        """
        with open(local_path, "rb") as file:
            self._run_with_retry(
                f"cat > {remote_path}",
                timeout_sec,
                connect_retries,
                force_tty=False,
                stdin=file,
            )

    def download_file(
        self,
        remote_path: str,
        local_path: str,
        timeout_sec: int = DEFAULT_SSH_TIMEOUT_SEC,
        connect_retries: int = 3,
    ) -> None:
        """Download a file from the device.

        Runs `cat remote_path` over SSH and writes the raw output bytes to
        local_path. The local file is only written once the command succeeds,
        so a failed download leaves any existing local file untouched.

        Args:
            remote_path: Path on the remote device to download.
                NOTE(review): interpolated unquoted into a remote shell command.
            local_path: Path on the host to the place the downloaded file.
            timeout_sec: Seconds to wait for the command to complete.
            connect_retries: Amount of times to retry connect on fail.

        Raises:
            SSHError: if the SSH command returns a non-zero status code
            SSHTransportError: if SSH fails to run the command
            SSHTimeout: if there is no response within timeout_sec
        """
        result = self._run_with_retry(
            f"cat {remote_path}",
            timeout_sec,
            connect_retries,
            force_tty=False,
            stdin=None,
        )
        with open(local_path, "wb") as file:
            file.write(result.raw_stdout)