blob: 2a3726d61b5846b471dc5f594a98ea193b1d5a0d [file] [log] [blame]
// Copyright 2023 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "persistent_service.h"
#include "async_loop.h"
#include "ipc_utils.h"
#include "util.h"
#define DEBUG 0
#ifdef _WIN32
#include <windows.h>
#else
#include <fcntl.h>
#include <sys/stat.h>
#include <unistd.h>
#endif
// Debug logging helpers. When DEBUG is non-zero, SERVER_LOG() and
// CLIENT_LOG() print a printf-style message to stderr, preceded by a
// "SERVER_LOG: " or "CLIENT_LOG: " tag and followed by a newline. Otherwise
// both expand to a no-op expression and their arguments are not evaluated.
#if DEBUG
#define PERSISTENT_SERVICE_LOG_(tag, ...) \
  do {                                    \
    fprintf(stderr, tag);                 \
    fprintf(stderr, __VA_ARGS__);         \
    fprintf(stderr, "\n");                \
  } while (0)
#define SERVER_LOG(...) PERSISTENT_SERVICE_LOG_("SERVER_LOG: ", __VA_ARGS__)
#define CLIENT_LOG(...) PERSISTENT_SERVICE_LOG_("CLIENT_LOG: ", __VA_ARGS__)
#else
#define SERVER_LOG(...) (void)0
#define CLIENT_LOG(...) (void)0
#endif
namespace {
// Sleep for |delay_ms| milliseconds. A zero or negative delay returns
// immediately on every platform.
void SleepMilliSeconds(int delay_ms) {
  // Without this guard the Posix branch would convert a negative delay to
  // the unsigned useconds_t type and sleep for a very long time, while the
  // Win32 branch would simply return.
  if (delay_ms <= 0)
    return;
#ifdef _WIN32
  ::Sleep(static_cast<DWORD>(delay_ms));
#else
  usleep(static_cast<useconds_t>(delay_ms) * 1000);
#endif
}
// Platform-specific identifier of a spawned server process: a process
// HANDLE on Win32, a process id everywhere else. Both default to zero/null.
struct ProcessInfo {
#ifndef _WIN32
  pid_t process_pid = 0;
#else
  HANDLE process_handle = nullptr;
#endif
};
// Start a new server process, using the command line, working directory,
// log file and extra environment variables from |config|.
// Uses fork()/exec() on Posix, and CreateProcess on Win32.
//
// The server's stdin is redirected to the null device. Its stdout and stderr
// go to |config.log_file| if set, otherwise to the file named by the
// DEBUG_PERSISTENT_SERVICE_LOG_FILE environment variable if set, otherwise
// to the null device.
//
// On Posix, |info->process_pid| receives the child pid. On Win32 the process
// handle is closed before returning (nothing waits on the server), so
// |info->process_handle| is set to NULL.
//
// Return true on success, or set |*err| and return false on failure.
bool SpawnServerProcess(const PersistentService::Config& config,
                        ProcessInfo* info, std::string* err) {
  if (config.command.empty()) {
    *err = "Empty command line!";
    return false;
  }
#ifdef _WIN32
  std::string command_string;
  for (const auto& arg : config.command) {
    std::string escaped;
    GetWin32EscapedString(arg, &escaped);
    command_string += escaped;
    command_string += ' ';
  }
  // Remove trailing space if any.
  if (!command_string.empty())
    command_string.resize(command_string.size() - 1u);

  SECURITY_ATTRIBUTES security_attributes = {};
  security_attributes.nLength = sizeof(SECURITY_ATTRIBUTES);
  security_attributes.bInheritHandle = TRUE;
  // Must be inheritable so subprocesses can dup to children.
  HANDLE nul =
      CreateFileA("NUL", GENERIC_READ | GENERIC_WRITE,
                  FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE,
                  &security_attributes, OPEN_EXISTING, 0, NULL);
  if (nul == INVALID_HANDLE_VALUE)
    Win32Fatal("couldn't open nul");

  HANDLE log = nul;
  std::string log_file = config.log_file;
  if (log_file.empty()) {
    // As a debug helper, use the log file from this environment variable.
    const char* env = getenv("DEBUG_PERSISTENT_SERVICE_LOG_FILE");
    if (env)
      log_file = env;
  }
  if (!log_file.empty()) {
    log = CreateFileA(log_file.c_str(),
                      STANDARD_RIGHTS_WRITE | FILE_APPEND_DATA,
                      FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE,
                      &security_attributes, OPEN_ALWAYS, 0, NULL);
    if (log == INVALID_HANDLE_VALUE)
      Win32Fatal("couldn't open log file", log_file.c_str());
  }

  STARTUPINFOA startup_info = {};
  startup_info.cb = sizeof(startup_info);
  startup_info.dwFlags = STARTF_USESTDHANDLES;
  startup_info.hStdInput = nul;
  startup_info.hStdOutput = log;
  startup_info.hStdError = log;

  PROCESS_INFORMATION process_info = {};
  // Ninja handles ctrl-c, except for subprocesses in console pools.
  DWORD process_flags = CREATE_NEW_PROCESS_GROUP;
  {
    // TODO(digit): Create new environment variable block and add/replace
    // the value in it, without touching the parent env.
    for (const auto& pair : config.env_vars) {
      // Use the ANSI entry point explicitly, like every other Win32 call
      // in this function, so this compiles whether or not UNICODE is set.
      SetEnvironmentVariableA(pair.first.c_str(), pair.second.c_str());
    }
  }

  // Start the server in the configured working directory, matching what the
  // Posix implementation does with chdir(). NULL keeps the parent's one.
  const char* work_dir =
      config.working_dir.empty() ? NULL : config.working_dir.c_str();

  // Do not prepend 'cmd /c' on Windows, this breaks command
  // lines greater than 8,191 chars.
  // CreateProcessA() takes a mutable command-line buffer, so pass the
  // string's own storage rather than casting away constness.
  if (!CreateProcessA(NULL, &command_string[0], NULL, NULL,
                      /* inherit handles */ TRUE, process_flags, NULL,
                      work_dir, &startup_info, &process_info)) {
    DWORD error = GetLastError();
    if (error == ERROR_FILE_NOT_FOUND) {
      CloseHandle(nul);
      if (log != nul)
        CloseHandle(log);
      *err =
          "CreateProcess failed: The system cannot find the file "
          "specified: [" +
          command_string + "]";
      return false;
    } else {
      fprintf(stderr, "\nCreateProcess failed. Command attempted:\n\"%s\"\n",
              command_string.c_str());
      const char* hint = NULL;
      // ERROR_INVALID_PARAMETER means the command line was formatted
      // incorrectly. This can be caused by a command line being too long or
      // leading whitespace in the command. Give extra context for this case.
      if (error == ERROR_INVALID_PARAMETER) {
        hint = "is the command line too long?";
      }
      Win32Fatal("CreateProcess", hint);
    }
  }
  CloseHandle(nul);
  if (log != nul)
    CloseHandle(log);
  CloseHandle(process_info.hThread);
  // Nothing in this file waits on the server process, so release its handle
  // here instead of leaking it. This does not terminate the process.
  CloseHandle(process_info.hProcess);
  info->process_handle = NULL;
  return true;
#else  // !_WIN32
  // Build arguments array for future exec() call.
  std::vector<char*> exec_args;
  exec_args.reserve(config.command.size() + 1);
  for (const auto& arg : config.command) {
    exec_args.push_back(const_cast<char*>(arg.c_str()));
    CLIENT_LOG("CMD: [%s]", arg.c_str());
  }
  exec_args.push_back(nullptr);

  // Flush stdio buffers before forking: otherwise the child inherits a copy
  // of any pending output, which could end up being written twice.
  fflush(stdout);
  fflush(stderr);

  pid_t process = fork();
  if (process < 0) {
    *err = std::string("fork() failed: ") + strerror(errno);
    return false;
  }

  if (process == 0) {
    // In the child process. Every failure path below uses _exit() instead
    // of exit() so that atexit() handlers and static destructors inherited
    // from the parent image do not run a second time in the child.

    // Create new session to not receive signals from parent process group.
    if (setsid() < 0) {
      fprintf(stderr, "ERROR: setsid() failed: %s\n", strerror(errno));
      _exit(1);
    }

    // Change current working directory.
    if (!config.working_dir.empty()) {
      const char* work_dir = config.working_dir.c_str();
      if (chdir(work_dir) < 0) {
        fprintf(stderr, "ERROR: chdir(%s) failed: %s\n", work_dir,
                strerror(errno));
        _exit(1);
      }
    }

    // Redirect stdin to /dev/null, and stdout/stderr to the configured log
    // file, or to the one named by DEBUG_PERSISTENT_SERVICE_LOG_FILE in the
    // environment, or to /dev/null otherwise.
    int null_fd = open("/dev/null", O_RDWR);
    if (null_fd < 0) {
      fprintf(stderr, "ERROR: open(/dev/null) failed: %s\n", strerror(errno));
      _exit(1);
    }

    int log_fd = null_fd;
    std::string log_file = config.log_file;
    if (log_file.empty()) {
      // As a debug helper, use the log file from this environment variable.
      const char* env = getenv("DEBUG_PERSISTENT_SERVICE_LOG_FILE");
      if (env)
        log_file = env;
    }
    if (!log_file.empty()) {
      // A log file has no reason to be executable, hence 0644.
      log_fd = open(log_file.c_str(), O_WRONLY | O_APPEND | O_CREAT, 0644);
      if (log_fd < 0) {
        fprintf(stderr, "ERROR: open(%s) failed: %s\n", log_file.c_str(),
                strerror(errno));
        _exit(1);
      }
    }

    if (dup2(null_fd, STDIN_FILENO) < 0 || dup2(log_fd, STDOUT_FILENO) < 0 ||
        dup2(log_fd, STDERR_FILENO) < 0) {
      fprintf(stderr, "ERROR: dup2() failed: %s\n", strerror(errno));
      _exit(1);
    }

    // The standard descriptors now hold their own duplicates, so close the
    // originals to avoid leaking them into the server process. Skip any
    // that already are one of the standard descriptors.
    if (log_fd != null_fd && log_fd > STDERR_FILENO)
      close(log_fd);
    if (null_fd > STDERR_FILENO)
      close(null_fd);

    // Set extra environment variables.
    for (const auto& pair : config.env_vars) {
      setenv(pair.first.c_str(), pair.second.c_str(), 1);
    }

    fprintf(stderr, "\n\nSTARTING NEW PERSISTENT SERVER: %s\n",
            config.command[0].c_str());
    execv(config.command[0].c_str(), exec_args.data());

    // execv() only returns on failure.
    fprintf(stderr, "ERROR: exec() failed: %s\n", strerror(errno));
    _exit(1);
  }

  // In parent process, just record the child pid.
  info->process_pid = process;
  return true;
#endif  // !_WIN32
}
// Commands received by the server. The client sends one of these values as
// a single uint8_t right after connecting:
//  - kServerCommandTypeStop: the server leaves its loop and exits.
//  - kServerCommandTypeGetPid: the server replies with its process id.
//  - kServerCommandTypeClientQuery: the server performs the version check
//    and, if it passes, hands the connection to the request handler.
enum ServerCommandType {
  kServerCommandTypeStop = 0,
  kServerCommandTypeGetPid = 1,
  kServerCommandTypeClientQuery = 2,
};
} // namespace
///////////////////////////////////////////////////////////////////////////
///
/// C L I E N T S I D E
///
PersistentService::Client::Client(const std::string& service_name)
: service_name_(service_name) {}
// Defaulted destructor, defined out-of-line in this file.
PersistentService::Client::~Client() = default;
// Return whether IpcServiceHandle::IsBound() reports this client's service
// name as bound.
bool PersistentService::Client::HasServer() const {
  const bool is_bound = IpcServiceHandle::IsBound(service_name_);
  return is_bound;
}
int PersistentService::Client::GetServerPid() const {
std::string err;
IpcHandle client = RawConnect(&err);
if (!client)
return -1;
uint8_t query_type = kServerCommandTypeGetPid;
int server_pid = -1;
if (!RemoteWrite(query_type, client, &err) ||
!RemoteRead(server_pid, client, &err)) {
return -1;
}
return server_pid;
}
bool PersistentService::Client::StopServer(std::string* err) const {
IpcHandle client = RawConnect(err);
if (!client)
return false;
uint8_t query_type = kServerCommandTypeStop;
if (!client.Write(&query_type, sizeof(query_type), err)) {
*err = StringFormat("Could not stop server: %s", err->c_str());
return false;
}
return true;
}
// Poll until no server is bound to the service name anymore.
// Return true if the server went away within roughly two seconds
// (20 polls, 100 ms apart), false otherwise.
bool PersistentService::Client::WaitForServerShutdown() {
  const int kMaxPolls = 20;
  const int kPollDelayMs = 100;
  for (int poll = 0; poll < kMaxPolls; ++poll) {
    if (!HasServer())
      return true;
    SleepMilliSeconds(kPollDelayMs);
  }
  // Check once more after the last sleep, so a server that went away during
  // that final delay is not reported as still running.
  return !HasServer();
}
// Connect to the server, starting one if needed, and perform the version
// handshake by sending |config.version_info|.
//
// If the server reports an incompatible version, wait for it to shut down
// and retry once with a fresh server.
//
// Return a valid handle on success. On failure, return an empty handle and
// set |*error|.
IpcHandle PersistentService::Client::Connect(const Config& config,
                                             std::string* error) {
  bool try_again = true;
  while (true) {
    CLIENT_LOG("Trying to connect to server.");
    IpcHandle client = ConnectOrStartServer(config, error);
    if (!client)
      return {};

    CLIENT_LOG("Sending version info to server.");
    if (!RemoteWrite(config.version_info, client, error)) {
      *error = "Could not send version info: " + *error;
      return {};
    }

    // Receive the version check result. An empty string means the server
    // is compatible with this client; anything else describes the mismatch.
    std::string version_check;
    if (!RemoteRead(version_check, client, error)) {
      *error = "Could not read version check result: " + *error;
      return {};
    }
    if (version_check.empty()) {
      // Good, return the handle now.
      return client;
    }

    // The server was not compatible with the current client.
    // Assume it exited, and start another one at least once.
    CLIENT_LOG("Incompatible server version: %s", version_check.c_str());
    if (!try_again) {
      // Already tried once, so report failure. Set |*error| explicitly:
      // no lower-level call failed on this path, so it would otherwise be
      // left without a description of what went wrong.
      CLIENT_LOG("Failed to connect to or start server!");
      *error = "Incompatible server version: " + version_check;
      return {};
    }
    try_again = false;

    CLIENT_LOG("Waiting for incompatible server shutdown.");
    if (!WaitForServerShutdown()) {
      *error = "Could not shutdown incompatible server!?!";
      return {};
    }
  }
}
// Open a connection to the service name through
// IpcServiceHandle::ConnectTo(), without sending any command and without
// trying to start a server. |err| is forwarded to that call.
IpcHandle PersistentService::Client::RawConnect(std::string* err) const {
  IpcHandle connection = IpcServiceHandle::ConnectTo(service_name_, err);
  return connection;
}
// Connect to the server and send the client-query command. If no server
// answers, spawn one from |config| and retry with a doubling delay.
//
// Return the connection handle; it is closed if the command could not be
// written. Return an empty handle if the server cannot be spawned or never
// becomes reachable. |err| receives the error from the failing call.
IpcHandle PersistentService::Client::ConnectOrStartServer(
    const Config& config, std::string* err) const {
  int attempts_left = 5;
  int delay_ms = 10;
  bool spawned = false;
  ProcessInfo server_process = {};

  for (;;) {
    // Always try to reach an existing server first.
    IpcHandle connection = RawConnect(err);
    if (connection) {
      CLIENT_LOG("Got client connection to server!");
      uint8_t command = kServerCommandTypeClientQuery;
      if (connection.Write(&command, sizeof(command), err)) {
        CLIENT_LOG("Sent query type");
      } else {
        CLIENT_LOG(
            "ERROR: Could not write query type, did server disconnect?: %s",
            err->c_str());
        connection.Close();
      }
      return connection;
    }

    if (spawned) {
      // A server was already started: give up once the retries are spent,
      // otherwise back off before the next attempt.
      if (attempts_left == 0) {
        CLIENT_LOG("Failure to connect to server, exiting retry loop: %s",
                   err->c_str());
        return {};
      }
      --attempts_left;
      if (delay_ms < 1024)
        delay_ms *= 2;
    } else {
      CLIENT_LOG("No initial connection. Spawning server");
      // Spawn a server if one wasn't already started.
      if (!SpawnServerProcess(config, &server_process, err)) {
        CLIENT_LOG("Could not spawn server: %s", err->c_str());
        return {};
      }
      spawned = true;
    }

    CLIENT_LOG("Waiting for %ld milliseconds", static_cast<long>(delay_ms));
    SleepMilliSeconds(delay_ms);
  }
}
///////////////////////////////////////////////////////////////////////////
///
/// S E R V E R S I D E
///
PersistentService::Server::Server(const std::string& service_name)
: service_name_(service_name) {}
// Bind this server to its service name. Return true on success. Return
// false and set |*err| if the service handle is already set or if binding
// fails.
bool PersistentService::Server::BindService(std::string* err) {
  if (service_handle_) {
    *err = "Server already started!";
    return false;
  }

  SERVER_LOG("Trying to get service %s", service_name_.c_str());
  service_handle_ = IpcServiceHandle::BindTo(service_name_, err);
  if (service_handle_) {
    SERVER_LOG("Got it!");
    return true;
  }

  SERVER_LOG("Got error %s", err->c_str());
  return false;
}
// Bind the service if that was not done yet, then run the server loop.
// Exits the process with status 1 if the service cannot be bound.
// |version_check_handler| and |request_handler| are forwarded to the loop.
void PersistentService::Server::RunServerThenExit(
    const VersionCheckHandler& version_check_handler,
    const RequestHandler& request_handler) {
  std::string error;
  if (!service_handle_ && !BindService(&error)) {
    Error("Could not acquire service handle: %s", error.c_str());
    exit(1);
  }
  SERVER_LOG("Service handle acquired, starting server loop");
  // |request_handler| is a const reference, so std::move() on it would be a
  // misleading no-op: pass it through as-is.
  RunServerThenExitInternal(version_check_handler, request_handler);
}
// Main server loop. Accepts client connections one at a time, reads the
// one-byte command each client sends, and dispatches on it:
//
//  - kServerCommandTypeStop: leave the loop.
//  - kServerCommandTypeGetPid: send this process' pid, wait for next client.
//  - kServerCommandTypeClientQuery: read the client's version info, reply
//    with the result of |version_check_handler| (an empty string means
//    compatible), then pass the connection to |request_handler|.
//
// The loop ends on interruption, connection timeout, any I/O error, an
// unknown command, an incompatible client, or when |request_handler|
// returns false. This function then closes the service handle and calls
// ::exit(0); it never returns.
void PersistentService::Server::RunServerThenExitInternal(
    const VersionCheckHandler& version_check_handler,
    const RequestHandler& request_handler) {
  AsyncLoop& async_loop = AsyncLoop::Get();
  std::string error;

#ifndef _WIN32
  SigPipeBlocker sig_pipe_blocker;
#endif

  /// Convenience class to wait for client connections with a timeout.
  /// Usage is:
  ///
  ///  1) Create class, passing references to an IpcServiceHandle
  ///     and an AsyncLoop instance.
  ///
  ///  2) Call Wait() to wait for a new connection.
  ///
  ///  3) Either destroy the instance, or call Close() before that to
  ///     ensure the corresponding server handle is closed before
  ///     calling exit().
  ///
  class ConnectionWaiter {
   public:
    /// Constructor sets up an internal AsyncHandle to do the wait.
    /// Note that this moves the native service handle into this instance, but
    /// |service_handle| still needs to be closed manually before calling
    /// ::exit(), because on MacOS it does extra work to remove a Unix socket
    /// and pid file.
    ConnectionWaiter(IpcServiceHandle& service_handle, AsyncLoop& async_loop)
        : async_loop_(async_loop),
          async_(AsyncHandle::Create(std::move(service_handle), async_loop,
                                     callback())) {}

    /// Wait for a new client connection for |timeout_ms| milliseconds.
    /// A negative |timeout_ms| is logged as "no timeout".
    /// Return AsyncLoop exit status, on success, set |client| to receive
    /// the new client connection handle.
    AsyncLoop::ExitStatus Wait(int64_t timeout_ms, IpcHandle& client) {
      completed_ = false;
      accepted_ = false;
      int64_t now_ms = async_loop_.NowMs();
      (void)now_ms;  // silence compiler when SERVER_LOG() expansion is empty.
      if (timeout_ms < 0) {
        SERVER_LOG("Waiting for new client connection (no timeout)");
      } else {
        SERVER_LOG("Waiting for new client connection (for up to %.1f seconds)",
                   timeout_ms / 1000.);
      }
      async_.StartAccept();
      auto status =
          async_loop_.RunUntil([this]() { return completed_; }, timeout_ms);
      if (accepted_)
        client = async_.TakeAcceptedHandle();
      SERVER_LOG("Connection time %s",
                 StringFormatDurationMs(async_loop_.NowMs() - now_ms).c_str());
      return status;
    }

    /// Close the internal handle explicitly.
    void Close() { async_.Close(); }

   private:
    AsyncLoop& async_loop_;
    AsyncHandle async_;
    // Set by the accept callback: |completed_| when it ran at all,
    // |accepted_| when it ran without error.
    bool completed_ = false;
    bool accepted_ = false;

    // Return the AsyncHandle::Callback to be used by |async_| in this
    // instance. It captures |this|, so the instance must not be moved.
    AsyncHandle::Callback callback() {
      return [this](AsyncError error, size_t) {
        completed_ = true;
        accepted_ = (error == 0);
      };
    }
  };

  ConnectionWaiter waiter(service_handle_, async_loop);

  while (true) {
    IpcHandle client;
    AsyncLoop::ExitStatus status = waiter.Wait(connection_timeout_ms_, client);
    if (status == AsyncLoop::ExitInterrupted)
      break;

    if (status == AsyncLoop::ExitTimeout) {
      SERVER_LOG("Timeout waiting for client connection");
      break;
    }

    if (!client) {
      // Nothing has written to |error| at this point, so do not log it.
      SERVER_LOG("Could not accept new client");
      break;
    }

    SERVER_LOG("Reading query type...");
    // Get the command byte, see ServerCommandType for the possible values.
    uint8_t query_type = 0;
    if (!client.Read(&query_type, sizeof(query_type), &error)) {
      SERVER_LOG("Could not read client query type !?");
      break;
    }
    SERVER_LOG("Got query_type %d", query_type);

    if (query_type == kServerCommandTypeStop) {
      SERVER_LOG("Client asking server to stop!");
      break;
    }

    if (query_type == kServerCommandTypeGetPid) {
      SERVER_LOG("Client asking for server pid!");
#ifdef _WIN32
      int pid = static_cast<int>(GetCurrentProcessId());
#else
      int pid = getpid();
#endif
      if (!RemoteWrite(pid, client, &error)) {
        SERVER_LOG("Could not send pid %d back to client!: %s", pid,
                   error.c_str());
        break;
      }
      SERVER_LOG("Sent pid %d to client. Looping", pid);
      continue;
    }

    if (query_type != kServerCommandTypeClientQuery) {
      SERVER_LOG("Unknown client query type: %d", query_type);
      break;
    }

    SERVER_LOG("Accepted client connection, checking version info");
    std::string version_info;
    if (!RemoteRead(version_info, client, &error)) {
      SERVER_LOG("Could not read client version info: %s", error.c_str());
      break;
    }

    // An empty result means the client is compatible. The result is always
    // sent back so the client can tell why it was rejected.
    std::string version_check = version_check_handler(version_info);
    if (!RemoteWrite(version_check, client, &error)) {
      SERVER_LOG("Could not write version check result to client: %s",
                 error.c_str());
      break;
    }
    if (!version_check.empty()) {
      SERVER_LOG("Incompatible client version info: %s", version_check.c_str());
      break;
    }

    if (!request_handler(std::move(client))) {
      SERVER_LOG("Request handler returned false, exiting server loop");
      break;
    }

    async_loop.ClearInterrupt();
  }

  SERVER_LOG("Exiting!");
  // Since calling ::exit() directly here prevents the destructors
  // of the variables defined in this function from running, close the
  // waiter handle explicitly.
  waiter.Close();

  // Note: while the file descriptor was moved to waiter, this
  // instance must be destroyed explicitly to remove the pid file and
  // Unix socket filesystem entry on MacOS.
  service_handle_.Close();

  ::exit(0);
}