blob: b2770456d69f4806c9ac76a331ccb46b45b83a2d [file] [log] [blame]
# Copyright 2022 The Fuchsia Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Create an SshSettings from a dictionary taken from an ACTS config.

Args:
    config: A dict instance from an ACTS config.

Returns:
    An instance of SshSettings. Missing or wrongly typed values raise
    ValueError; None is never returned.
"""
from typing import Mapping, Union
class SshSettings(object):
"""Contains settings for ssh.
Container for ssh connection settings.
Attributes:
username: The name of the user to log in as.
hostname: The name of the host to connect to.
executable: The ssh executable to use.
port: The port to connect through (usually 22).
host_file: The known host file to use.
connect_timeout: How long to wait on a connection before giving a
timeout.
alive_interval: How long between ssh heartbeat signals to keep the
connection alive.
"""
def __init__(
self,
hostname: str,
username: str,
identity_file: str,
port: int = 22,
host_file: str = "/dev/null",
connect_timeout: int = 30,
alive_interval: int = 300,
executable: str = "/usr/bin/ssh",
ssh_config: str | None = None,
):
self.username = username
self.hostname = hostname
self.executable = executable
self.port = port
self.host_file = host_file
self.connect_timeout = connect_timeout
self.alive_interval = alive_interval
self.identity_file = identity_file
self.ssh_config = ssh_config
def construct_ssh_options(self) -> dict[str, Union[str, int, bool]]:
"""Construct the ssh options.
Constructs a dictionary of option that should be used with the ssh
command.
Returns:
A dictionary of option name to value.
"""
current_options: dict[str, Union[str, int, bool]] = {}
current_options["StrictHostKeyChecking"] = False
current_options["UserKnownHostsFile"] = self.host_file
current_options["ConnectTimeout"] = self.connect_timeout
current_options["ServerAliveInterval"] = self.alive_interval
return current_options
def construct_ssh_flags(self) -> dict[str, Union[None, str, int]]:
"""Construct the ssh flags.
Constructs what flags should be used in the ssh connection.
Returns:
A dictionary of flag name to value. If value is none then it is
treated as a binary flag.
"""
current_flags: dict[str, Union[None, str, int]] = {}
current_flags["-a"] = None
current_flags["-x"] = None
current_flags["-p"] = self.port
if self.identity_file:
current_flags["-i"] = self.identity_file
if self.ssh_config:
current_flags["-F"] = self.ssh_config
return current_flags
def from_config(config: Mapping[str, Union[str, int]]) -> SshSettings:
    """Build an SshSettings from an ssh configuration mapping.

    Values are validated one key at a time, in the order listed below, and
    the first invalid value aborts construction. Type checks are exact:
    subclasses of ``str``/``int`` (``bool`` included) are rejected.

    Args:
        config: Mapping that may contain the keys
            ``ssh_binary_path`` (str, defaults to "/usr/bin/ssh"),
            ``user`` (str, required),
            ``host`` (str, required),
            ``port`` (int, defaults to 22),
            ``identity_file`` (str, required),
            ``ssh_config`` (str, optional) and
            ``connect_timeout`` (int, defaults to 30).

    Returns:
        An SshSettings populated from the validated values; settings that
        are not read from ``config`` keep the SshSettings defaults.

    Raises:
        ValueError: If a required key is missing or any value has the wrong
            type.
    """
    executable = config.get("ssh_binary_path", "/usr/bin/ssh")
    if type(executable) is not str:
        raise ValueError(f"ssh_binary_path must be a string, got {executable}")

    username = config.get("user")
    if type(username) is not str:
        raise ValueError(f"user must be a string, got {username}")

    hostname = config.get("host")
    if type(hostname) is not str:
        raise ValueError(f"host must be a string, got {hostname}")

    port_number = config.get("port", 22)
    if type(port_number) is not int:
        raise ValueError(f"port must be an integer, got {port_number}")

    # A missing key yields None, which fails the exact-type test as well.
    key_path = config.get("identity_file")
    if type(key_path) is not str:
        raise ValueError(f"identity_file must be a string, got {key_path}")

    # ssh_config is the only value allowed to be absent.
    config_path = config.get("ssh_config")
    if not (config_path is None or type(config_path) is str):
        raise ValueError(f"ssh_config must be a string, got {config_path}")

    timeout = config.get("connect_timeout", 30)
    if type(timeout) is not int:
        raise ValueError(f"connect_timeout must be an integer, got {timeout}")

    return SshSettings(
        hostname=hostname,
        username=username,
        identity_file=key_path,
        port=port_number,
        ssh_config=config_path,
        connect_timeout=timeout,
        executable=executable,
    )