| # Copyright 2022 The Fuchsia Authors |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
import collections
import os
import re
import shlex
import shutil
import tempfile
import threading
import time
import uuid

from antlion import logger
from antlion.controllers.utils_lib import host_utils
from antlion.controllers.utils_lib.ssh import formatter
from antlion.libs.proc import job
from antlion.runner import CompletedProcess, Runner
| |
| |
class Error(Exception):
    """Raised when an ssh operation, such as establishing a connection, fails."""
| |
| |
class CommandError(Exception):
    """An error occurred with the command.

    Attributes:
        result: The results of the ssh command that had the error.
    """

    def __init__(self, result):
        """
        Args:
            result: The result of the ssh command that created the problem.
                It must expose ``command``, ``stdout`` and ``stderr``, which
                are used by __str__.
        """
        # Forward result to Exception so self.args is populated. Without this,
        # args is empty and unpickling (which calls CommandError(*args)) fails.
        super().__init__(result)
        self.result = result

    def __str__(self):
        return "cmd: %s\nstdout: %s\nstderr: %s" % (
            self.result.command,
            self.result.stdout,
            self.result.stderr,
        )
| |
| |
| _Tunnel = collections.namedtuple("_Tunnel", ["local_port", "remote_port", "proc"]) |
| |
| |
class SshConnection(Runner):
    """Provides a connection to a remote machine through ssh.

    Provides the ability to connect to a remote machine and execute a command
    on it. When a command is run, the connection first tries to establish a
    persistent (master) connection that later ssh invocations share through a
    control socket. If the persistent connection fails, it falls back to a
    normal ssh connection.
    """

    @property
    def socket_path(self):
        """Returns: The os path to the master socket file.

        Raises:
            AttributeError: If setup_master_ssh() has not created the master
                ssh temp directory yet.
        """
        if self._master_ssh_tempdir is None:
            raise AttributeError(
                "socket_path is not available yet; run setup_master_ssh() first"
            )
        return os.path.join(self._master_ssh_tempdir, "socket")

    def __init__(self, settings):
        """
        Args:
            settings: The ssh settings to use for this connection. It is passed
                to the ssh command formatter, and its hostname is used in log
                lines.
        """
        self._settings = settings
        self._formatter = formatter.SshFormatter()
        self._lock = threading.Lock()
        self._master_ssh_proc = None
        self._master_ssh_tempdir: str | None = None
        self._tunnels = list()

        def log_line(msg):
            return f"[SshConnection | {self._settings.hostname}] {msg}"

        self.log = logger.create_logger(log_line)

    def __enter__(self):
        return self

    def __exit__(self, _, __, ___):
        self.close()

    def __del__(self):
        # If __init__ raised before the connection state was assigned, there is
        # nothing to release, and close() would fail with AttributeError.
        # _tunnels is assigned after _master_ssh_proc and _master_ssh_tempdir.
        if getattr(self, "_tunnels", None) is None:
            return
        self.close()

    def setup_master_ssh(self, timeout_sec: int = 5):
        """Sets up the master ssh connection.

        Sets up the initial master ssh connection if it has not already been
        started. If a previous master connection is down (its socket is gone or
        its process has exited), it is torn down and a new one is started.

        Args:
            timeout_sec: The time to wait for the master ssh connection to
                be made.

        Raises:
            Error: When the master ssh socket does not appear within
                timeout_sec, or the master ssh process exits before creating
                it.
        """
        with self._lock:
            if self._master_ssh_proc is not None:
                socket_path = self.socket_path
                if (
                    not os.path.exists(socket_path)
                    or self._master_ssh_proc.poll() is not None
                ):
                    self.log.debug(
                        "Master ssh connection to %s is down.", self._settings.hostname
                    )
                    self._cleanup_master_ssh()

            if self._master_ssh_proc is None:
                # Create a shared socket in a temp location.
                self._master_ssh_tempdir = tempfile.mkdtemp(prefix="ssh-master")

                # Setup flags and options for running the master ssh
                # -N: Do not execute a remote command.
                # ControlMaster: Spawn a master connection.
                # ControlPath: The master connection socket path.
                # BatchMode: Disable interactive prompts (ssh_config BatchMode).
                extra_flags = {"-N": None}
                extra_options = {
                    "ControlMaster": True,
                    "ControlPath": self.socket_path,
                    "BatchMode": True,
                }

                # Construct the command and start it.
                master_cmd = self._formatter.format_ssh_local_command(
                    self._settings, extra_flags=extra_flags, extra_options=extra_options
                )
                self.log.info("Starting master ssh connection.")
                try:
                    self._master_ssh_proc = job.run_async(master_cmd)
                except Exception:
                    # Do not leak the temp directory created above.
                    self._cleanup_master_ssh()
                    raise

            end_time = time.time() + timeout_sec

            while time.time() < end_time:
                if os.path.exists(self.socket_path):
                    break
                if self._master_ssh_proc.poll() is not None:
                    # The master process has exited without creating its
                    # socket, so waiting out the rest of the timeout cannot
                    # succeed.
                    self._cleanup_master_ssh()
                    raise Error(
                        "Master ssh connection exited before creating its socket."
                    )
                time.sleep(0.2)
            else:
                self._cleanup_master_ssh()
                raise Error("Master ssh connection timed out.")

    @staticmethod
    def _command_to_str(command: str | list[str]) -> str:
        """Returns command as a single string for the remote shell.

        A string is returned unchanged. A list is treated as an argv, and each
        element is shell-quoted. This avoids interpolating the Python list repr
        into the remote command.
        """
        if isinstance(command, str):
            return command
        return shlex.join(command)

    def run(
        self,
        command: str | list[str],
        timeout_sec: int | None = 60,
        ignore_status: bool = False,
        env: dict[str, str] | None = None,
        io_encoding: str = "utf-8",
        attempts: int = 2,
    ) -> CompletedProcess:
        """Runs a remote command over ssh.

        Will ssh to a remote host and run a command. This method will
        block until the remote command is finished.

        Args:
            command: The command to execute over ssh. A string is passed to
                the remote shell as-is. A list is treated as an argv, and each
                element is shell-quoted.
            timeout_sec: seconds to wait for command to finish.
            ignore_status: True to ignore the exit code of the remote
                subprocess. Note that if you do ignore status codes,
                you should handle non-zero exit codes explicitly.
            env: environment variables to setup on the remote host.
            io_encoding: unicode encoding of command output.
            attempts: Number of attempts before giving up when ssh fails for
                an unrecognized reason. Must be at least 1.

        Returns:
            Results of the ssh command. The connection marker line, and
            anything printed before it, are removed from stdout.

        Raises:
            TypeError: When attempts is less than 1.
            job.TimeoutError: When the remote command took too long to execute.
            job.Error: When the command exits with a non-zero status and
                ignore_status is False.
            Error: When ssh fails to connect. This covers DNS failure, timeout,
                permission denied, unknown host, and unrecognized failures
                that persist after all attempts.
        """
        if attempts < 1:
            raise TypeError("attempts must be a positive, non-zero integer")
        if env is None:
            env = {}

        try:
            self.setup_master_ssh(self._settings.connect_timeout)
        except Error:
            self.log.warning(
                "Failed to create master ssh connection, using "
                "normal ssh connection."
            )

        extra_options: dict[str, str | bool] = {"BatchMode": True}
        if self._master_ssh_proc is not None:
            extra_options["ControlPath"] = self.socket_path

        # Echo a unique marker before the command. Seeing it in stdout proves
        # the remote shell ran; its absence means ssh itself failed.
        identifier = str(uuid.uuid4())
        marker = f"CONNECTED: {identifier}"
        full_command = f'echo "{marker}"; {self._command_to_str(command)}'

        terminal_command = self._formatter.format_command(
            full_command, env, self._settings, extra_options=extra_options
        )

        dns_retry_count = 2
        while True:
            result = job.run(
                terminal_command,
                ignore_status=True,
                timeout_sec=timeout_sec,
                io_encoding=io_encoding,
            )
            output = result.stdout

            # Check for a connected message to prevent false negatives.
            valid_connection = re.search(
                f"^{re.escape(marker)}", output, flags=re.MULTILINE
            )
            if valid_connection:
                # Drop everything up to and including the marker line.
                # Anything printed before it (e.g. by login scripts) is not
                # output of the command.
                line_end = output.find("\n", valid_connection.end())
                start = len(output) if line_end == -1 else line_end + 1
                result._raw_stdout = output[start:].encode(io_encoding)
                result._stdout_str = None

                if result.exit_status and not ignore_status:
                    raise job.Error(result)
                return result

            error_string = result.stderr

            had_dns_failure = result.exit_status == 255 and re.search(
                r"^ssh: .*: Name or service not known", error_string, flags=re.MULTILINE
            )
            if not had_dns_failure:
                break
            dns_retry_count -= 1
            if not dns_retry_count:
                raise Error("DNS failed to find host.", result)
            self.log.debug("Failed to connect to host, retrying...")

        # The trailing carriage return is optional: ssh may or may not emit it.
        had_timeout = re.search(
            r"^ssh: connect to host .* port .*: Connection timed out\r?$",
            error_string,
            flags=re.MULTILINE,
        )
        if had_timeout:
            raise Error("Ssh timed out.", result)

        if "Permission denied" in error_string:
            raise Error("Permission denied.", result)

        unknown_host = re.search(
            r"ssh: Could not resolve hostname .*: Name or service not known",
            error_string,
            flags=re.MULTILINE,
        )
        if unknown_host:
            raise Error("Unknown host.", result)

        self.log.error(f"An unknown error has occurred. Job result: {result}")
        ping_output = job.run(
            f"ping {self._settings.hostname} -c 3 -w 1", ignore_status=True
        )
        self.log.error(f"Ping result: {ping_output}")
        if attempts > 1:
            # Start over with a fresh master connection, and propagate the
            # outcome of the retry (success or its own exception) to the caller.
            self._cleanup_master_ssh()
            return self.run(
                command, timeout_sec, ignore_status, env, io_encoding, attempts - 1
            )
        raise Error("The job failed for unknown reasons.", result)

    def run_async(self, command, env=None) -> CompletedProcess:
        """Starts up a background command over ssh.

        Will ssh to a remote host and startup a command. This method will
        block until there is confirmation that the remote command has started.
        The command's stdin, stdout and stderr are redirected to /dev/null.

        Args:
            command: The command to execute over ssh. Can be either a string
                or a list (treated as an argv; each element is shell-quoted).
            env: A dictionary of environment variables to setup on the remote
                host.

        Returns:
            The result of the command that launched the background job. Its
            stdout is the remote shell's `$!` (the background job's PID).

        Raises:
            The same exceptions as run().
        """
        return self.run(
            f"({self._command_to_str(command)}) < /dev/null > /dev/null 2>&1 & echo -n $!",
            env=env,
        )

    def close(self):
        """Clean up open connections to remote host.

        Stops the master ssh connection, if any, and closes every tunnel
        created by create_ssh_tunnel().
        """
        self._cleanup_master_ssh()
        while self._tunnels:
            self.close_ssh_tunnel(self._tunnels[0].local_port)

    def _cleanup_master_ssh(self):
        """
        Release all resources (process, temporary directory) used by an active
        master SSH connection.
        """
        # If a master SSH connection is running, kill it.
        if self._master_ssh_proc is not None:
            self.log.debug("Nuking master_ssh_job.")
            self._master_ssh_proc.kill()
            self._master_ssh_proc.wait()
            self._master_ssh_proc = None

        # Remove the temporary directory for the master SSH socket. It may
        # already be gone (e.g. removed by a tmp cleaner). That must not stop
        # the attribute from being reset, otherwise every later setup/close
        # call would fail the same way.
        if self._master_ssh_tempdir is not None:
            self.log.debug("Cleaning master_ssh_tempdir.")
            shutil.rmtree(self._master_ssh_tempdir, ignore_errors=True)
            self._master_ssh_tempdir = None

    def create_ssh_tunnel(self, port, local_port=None):
        """Create an ssh tunnel from local_port to port.

        This securely forwards traffic from local_port on this machine to the
        remote SSH host at port.

        Args:
            port: remote port on the host.
            local_port: local forwarding port, or None to pick an available
                port.

        Returns:
            The local port of the tunnel. If local_port is given and a tunnel
            to the same remote port already exists, no new tunnel is created
            and that tunnel's local port is returned. It may differ from
            local_port.
        """
        # NOTE(review): reuse of an existing tunnel only happens when
        # local_port is given; confirm this asymmetry is intended.
        if not local_port:
            local_port = host_utils.get_available_host_port()
        else:
            for tunnel in self._tunnels:
                if tunnel.remote_port == port:
                    return tunnel.local_port

        extra_flags = {
            "-n": None,  # Read from /dev/null for stdin
            "-N": None,  # Do not execute a remote command
            "-q": None,  # Suppress warnings and diagnostic commands
            "-L": f"{local_port}:localhost:{port}",
        }
        extra_options = dict()
        if self._master_ssh_proc is not None:
            extra_options["ControlPath"] = self.socket_path
        tunnel_cmd = self._formatter.format_ssh_local_command(
            self._settings, extra_flags=extra_flags, extra_options=extra_options
        )
        self.log.debug("Full tunnel command: %s", tunnel_cmd)
        # Exec the ssh process directly so that when we deliver signals, we
        # deliver them straight to the child process.
        tunnel_proc = job.run_async(tunnel_cmd)
        self.log.debug(
            "Started ssh tunnel, local = %d remote = %d, pid = %d",
            local_port,
            port,
            tunnel_proc.pid,
        )
        self._tunnels.append(_Tunnel(local_port, port, tunnel_proc))
        return local_port

    def close_ssh_tunnel(self, local_port):
        """Close a previously created ssh tunnel of a TCP port.

        Args:
            local_port: int port on localhost previously forwarded to the remote
                host.

        Returns:
            integer port number this port was forwarded to on the remote host or
            None if no tunnel was found.
        """
        idx = None
        for i, tunnel in enumerate(self._tunnels):
            if tunnel.local_port == local_port:
                idx = i
                break
        if idx is not None:
            tunnel = self._tunnels.pop(idx)
            tunnel.proc.kill()
            tunnel.proc.wait()
            return tunnel.remote_port
        return None

    def send_file(self, local_path, remote_path, ignore_status=False):
        """Send a file from the local host to the remote host using scp.

        Args:
            local_path: string path of file to send on local host.
            remote_path: string path to copy file to on remote host.
            ignore_status: Whether or not to ignore the command's exit_status.
        """
        # NOTE(review): paths are interpolated unquoted, and scp is not given
        # the ssh settings' options (e.g. port); confirm callers rely on this.
        # TODO: This may belong somewhere else: b/32572515
        user_host = self._formatter.format_host_name(self._settings)
        job.run(
            f"scp {local_path} {user_host}:{remote_path}",
            ignore_status=ignore_status,
        )

    def pull_file(self, local_path, remote_path, ignore_status=False):
        """Copy a file from the remote host to the local host using scp.

        Args:
            local_path: string path of file to recv on local host
            remote_path: string path to copy file from on remote host.
            ignore_status: Whether or not to ignore the command's exit_status.
        """
        user_host = self._formatter.format_host_name(self._settings)
        job.run(
            f"scp {user_host}:{remote_path} {local_path}",
            ignore_status=ignore_status,
        )

    def find_free_port(self, interface_name="localhost"):
        """Find a unused port on the remote host.

        Note that this method is inherently racy, since it is impossible
        to promise that the remote port will remain free.

        Args:
            interface_name: string name of interface to check whether a
                port is used against.

        Returns:
            integer port number on remote interface that was free.
        """
        # TODO: This may belong somewhere else: b/3257251
        # Binds port 0 on the remote side so the remote OS picks a free port,
        # prints it, then releases it.
        free_port_cmd = (
            'python -c "import socket; s=socket.socket(); '
            "s.bind(('%s', 0)); print(s.getsockname()[1]); s.close()\""
        ) % interface_name
        port = int(self.run(free_port_cmd).stdout)
        # Yield to the os to ensure the port gets cleaned up.
        time.sleep(0.001)
        return port