| // Copyright 2023 Google Inc. All Rights Reserved. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "persistent_service.h" |
| |
| #include "async_loop.h" |
| #include "ipc_utils.h" |
| #include "util.h" |
| |
| #define DEBUG 0 |
| |
| #ifdef _WIN32 |
| #include <windows.h> |
| #else |
| #include <fcntl.h> |
| #include <sys/stat.h> |
| #include <unistd.h> |
| #endif |
| |
// Debug logging helpers taking printf-style arguments. When DEBUG is non-zero
// each call prints a tagged line to stderr; otherwise the macros expand to a
// no-op and their arguments are not evaluated.
#if DEBUG
// Shared expansion: print |prefix|, then the formatted message, then a newline.
#define PERSISTENT_SERVICE_LOG_(prefix, ...) \
  do {                                       \
    fprintf(stderr, prefix);                 \
    fprintf(stderr, __VA_ARGS__);            \
    fprintf(stderr, "\n");                   \
  } while (0)
#define SERVER_LOG(...) PERSISTENT_SERVICE_LOG_("SERVER_LOG: ", __VA_ARGS__)
#define CLIENT_LOG(...) PERSISTENT_SERVICE_LOG_("CLIENT_LOG: ", __VA_ARGS__)
#else  // !DEBUG
#define SERVER_LOG(...) (void)0
#define CLIENT_LOG(...) (void)0
#endif  // !DEBUG
| |
| namespace { |
| |
// Sleep for |delay_ms| milliseconds. A zero or negative delay returns
// immediately on every platform.
void SleepMilliSeconds(int delay_ms) {
  if (delay_ms <= 0)
    return;
#ifdef _WIN32
  ::Sleep(static_cast<DWORD>(delay_ms));
#else
  // usleep() is only specified for intervals below one second (1,000,000
  // microseconds), so longer delays are split into sub-second chunks.
  const int kMaxChunkMs = 999;
  while (delay_ms > 0) {
    const int chunk_ms = delay_ms < kMaxChunkMs ? delay_ms : kMaxChunkMs;
    usleep(static_cast<useconds_t>(chunk_ms) * 1000);
    delay_ms -= chunk_ms;
  }
#endif
}
| |
// Platform-specific identifier of a spawned process, used as the output
// parameter of SpawnServerProcess().
struct ProcessInfo {
#if defined(_WIN32)
  // Win32 process handle.
  HANDLE process_handle;
#else   // !_WIN32
  // Posix process id.
  pid_t process_pid;
#endif  // !_WIN32
};
| |
| // Start a new process, with command-line |args|. |
| // Uses fork()/exec() on Posix, and CreateProcess on Win32. |
| // Return true on success, or set |*err| and return false on failure. |
| bool SpawnServerProcess(const PersistentService::Config& config, |
| ProcessInfo* info, std::string* err) { |
| if (config.command.empty()) { |
| *err = "Empty command line!"; |
| return false; |
| } |
| |
| #ifdef _WIN32 |
| std::string command_string; |
| for (const auto& arg : config.command) { |
| std::string escaped; |
| GetWin32EscapedString(arg, &escaped); |
| command_string += escaped; |
| command_string += ' '; |
| } |
| // Remove trailing space if any. |
| if (!command_string.empty()) |
| command_string.resize(command_string.size() - 1u); |
| |
| SECURITY_ATTRIBUTES security_attributes = {}; |
| security_attributes.nLength = sizeof(SECURITY_ATTRIBUTES); |
| security_attributes.bInheritHandle = TRUE; |
| // Must be inheritable so subprocesses can dup to children. |
| HANDLE nul = |
| CreateFileA("NUL", GENERIC_READ | GENERIC_WRITE, |
| FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, |
| &security_attributes, OPEN_EXISTING, 0, NULL); |
| if (nul == INVALID_HANDLE_VALUE) |
| Win32Fatal("couldn't open nul"); |
| |
| HANDLE log = nul; |
| std::string log_file = config.log_file; |
| if (log_file.empty()) { |
| // As a debug helper, use the log file from this environment variable. |
| const char* env = getenv("DEBUG_PERSISTENT_SERVICE_LOG_FILE"); |
| if (env) |
| log_file = env; |
| } |
| if (!log_file.empty()) { |
| log = |
| CreateFileA(log_file.c_str(), STANDARD_RIGHTS_WRITE | FILE_APPEND_DATA, |
| FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, |
| &security_attributes, OPEN_ALWAYS, 0, NULL); |
| if (log == INVALID_HANDLE_VALUE) |
| Win32Fatal("couldn't open log file", log_file.c_str()); |
| } |
| |
| STARTUPINFOA startup_info = {}; |
| startup_info.cb = sizeof(STARTUPINFO); |
| startup_info.dwFlags = STARTF_USESTDHANDLES; |
| startup_info.hStdInput = nul; |
| startup_info.hStdOutput = log; |
| startup_info.hStdError = log; |
| PROCESS_INFORMATION process_info = {}; |
| |
| // Ninja handles ctrl-c, except for subprocesses in console pools. |
| DWORD process_flags = CREATE_NEW_PROCESS_GROUP; |
| |
| { |
| // TODO(digit): Create new environment variable block and add/replace |
| // the value in it, without touching the parent env. |
| for (const auto& pair : config.env_vars) { |
| const char* varname = pair.first.c_str(); |
| const char* value = pair.second.c_str(); |
| SetEnvironmentVariable(varname, value); |
| } |
| } |
| |
| // Do not prepend 'cmd /c' on Windows, this breaks command |
| // lines greater than 8,191 chars. |
| if (!CreateProcessA(NULL, (char*)command_string.c_str(), NULL, NULL, |
| /* inherit handles */ TRUE, process_flags, NULL, NULL, |
| &startup_info, &process_info)) { |
| DWORD error = GetLastError(); |
| if (error == ERROR_FILE_NOT_FOUND) { |
| CloseHandle(nul); |
| if (log != nul) |
| CloseHandle(log); |
| // child_ is already NULL; |
| *err = |
| "CreateProcess failed: The system cannot find the file " |
| "specified: [" + |
| command_string + "]"; |
| return false; |
| } else { |
| fprintf(stderr, "\nCreateProcess failed. Command attempted:\n\"%s\"\n", |
| command_string.c_str()); |
| const char* hint = NULL; |
| // ERROR_INVALID_PARAMETER means the command line was formatted |
| // incorrectly. This can be caused by a command line being too long or |
| // leading whitespace in the command. Give extra context for this case. |
| if (error == ERROR_INVALID_PARAMETER) { |
| hint = "is the command line too long?"; |
| } |
| Win32Fatal("CreateProcess", hint); |
| } |
| } |
| |
| CloseHandle(nul); |
| if (log != nul) |
| CloseHandle(log); |
| |
| CloseHandle(process_info.hThread); |
| return true; |
| #else // !_WIN32 |
| // Build arguments array for future exec() call. |
| std::vector<char*> exec_args; |
| exec_args.reserve(config.command.size() + 1); |
| for (const auto& arg : config.command) { |
| exec_args.push_back(const_cast<char*>(arg.data())); |
| CLIENT_LOG("CMD: [%s]", arg.c_str()); |
| } |
| exec_args.push_back(nullptr); |
| |
| pid_t process = fork(); |
| if (process < 0) { |
| *err = "fork failed()!"; |
| return false; |
| } |
| |
| if (process == 0) { |
| // Create new session to not receive signals from parent process group. |
| if (setsid() < 0) { |
| fprintf(stderr, "ERROR: setsid() failed: %s\n", strerror(errno)); |
| exit(1); |
| } |
| |
| // Change current working directory. |
| if (!config.working_dir.empty()) { |
| const char* work_dir = config.working_dir.c_str(); |
| if (chdir(work_dir) < 0) |
| ErrnoFatal("chdir", work_dir); |
| } |
| |
| // Redirect stdin to /dev/null and stdout/stderr to a log file if |
| // PERSISTENT_SERVER_LOG_FILE is set in the environment, or to |
| // /dev/null otherwise. |
| int null_fd = open("/dev/null", O_RDWR); |
| if (null_fd < 0) { |
| fprintf(stderr, "ERROR: open(/dev/null) failed: %s\n", strerror(errno)); |
| exit(1); |
| } |
| int log_fd = null_fd; |
| std::string log_file = config.log_file; |
| if (log_file.empty()) { |
| // As a debug helper, use the log file from this environment variable. |
| const char* env = getenv("DEBUG_PERSISTENT_SERVICE_LOG_FILE"); |
| if (env) |
| log_file = env; |
| } |
| if (!log_file.empty()) { |
| log_fd = open(log_file.c_str(), O_WRONLY | O_APPEND | O_CREAT, 0755); |
| if (log_fd < 0) { |
| fprintf(stderr, "ERROR: open(%s) failed: %s\n", log_file.c_str(), |
| strerror(errno)); |
| exit(1); |
| } |
| } |
| |
| fflush(stdout); |
| fflush(stderr); |
| |
| dup2(null_fd, 0); |
| dup2(log_fd, 1); |
| dup2(log_fd, 2); |
| |
| // Set extra environment variables. |
| for (const auto& pair : config.env_vars) { |
| const char* varname = pair.first.c_str(); |
| const char* value = pair.second.c_str(); |
| setenv(varname, value, 1); |
| } |
| |
| fprintf(stderr, "\n\nSTARTING NEW PERSISTENT SERVER: %s\n", |
| config.command[0].c_str()); |
| execv(config.command[0].c_str(), exec_args.data()); |
| fprintf(stderr, "ERROR: exec() failed: %s\n", strerror(errno)); |
| exit(1); |
| } |
| // In parent process, do not do anything. |
| info->process_pid = process; |
| return true; |
| #endif // !_WIN32 |
| } |
| |
// Commands that a client can send to the server. The value is transmitted
// as a single uint8_t, so the numeric values below are spelled out.
enum ServerCommandType {
  kServerCommandTypeStop = 0,         // Ask the server to stop.
  kServerCommandTypeGetPid = 1,       // Ask the server for its process id.
  kServerCommandTypeClientQuery = 2,  // Start a regular client query.
};
| |
| } // namespace |
| |
| /////////////////////////////////////////////////////////////////////////// |
| /// |
| /// C L I E N T S I D E |
| /// |
| |
| PersistentService::Client::Client(const std::string& service_name) |
| : service_name_(service_name) {} |
| |
| PersistentService::Client::~Client() = default; |
| |
| bool PersistentService::Client::HasServer() const { |
| return IpcServiceHandle::IsBound(service_name_); |
| } |
| |
| int PersistentService::Client::GetServerPid() const { |
| std::string err; |
| IpcHandle client = RawConnect(&err); |
| if (!client) |
| return -1; |
| |
| uint8_t query_type = kServerCommandTypeGetPid; |
| int server_pid = -1; |
| if (!RemoteWrite(query_type, client, &err) || |
| !RemoteRead(server_pid, client, &err)) { |
| return -1; |
| } |
| return server_pid; |
| } |
| |
| bool PersistentService::Client::StopServer(std::string* err) const { |
| IpcHandle client = RawConnect(err); |
| if (!client) |
| return false; |
| |
| uint8_t query_type = kServerCommandTypeStop; |
| if (!client.Write(&query_type, sizeof(query_type), err)) { |
| *err = StringFormat("Could not stop server: %s", err->c_str()); |
| return false; |
| } |
| |
| return true; |
| } |
| |
| bool PersistentService::Client::WaitForServerShutdown() { |
| std::string error; |
| for (int retry_count = 20; retry_count > 0; --retry_count) { |
| if (!HasServer()) |
| return true; |
| |
| SleepMilliSeconds(100); |
| } |
| return false; |
| } |
| |
| IpcHandle PersistentService::Client::Connect(const Config& config, |
| std::string* error) { |
| bool try_again = true; |
| |
| while (true) { |
| CLIENT_LOG("Trying to connect to server."); |
| IpcHandle client = ConnectOrStartServer(config, error); |
| if (!client) |
| return {}; |
| |
| CLIENT_LOG("Sending version info to server."); |
| if (!RemoteWrite(config.version_info, client, error)) { |
| *error = "Could not send version info: " + *error; |
| return {}; |
| } |
| |
| // Receive the |compatible| flag that indicates that the server is |
| // compatible with the current build plan. |
| std::string version_check; |
| if (!RemoteRead(version_check, client, error)) { |
| *error = "Could not read version check result: " + *error; |
| return {}; |
| } |
| |
| if (version_check.empty()) { |
| // Good, return the handle now. |
| return client; |
| } |
| |
| // The server was not compatible with the current client. |
| // Assume it exited, and start another one at least once. |
| CLIENT_LOG("Incompatible server version: %s", version_check.c_str()); |
| |
| if (!try_again) { |
| // Already tried once, so report failure. |
| CLIENT_LOG("Failed to connect to or start server!"); |
| return {}; |
| } |
| |
| try_again = false; |
| CLIENT_LOG("Waiting for incompatible server shutdown."); |
| if (!WaitForServerShutdown()) { |
| *error = "Could not shutdown incompatible server!?!"; |
| return {}; |
| } |
| } |
| } |
| |
| IpcHandle PersistentService::Client::RawConnect(std::string* err) const { |
| return IpcServiceHandle::ConnectTo(service_name_, err); |
| } |
| |
| IpcHandle PersistentService::Client::ConnectOrStartServer( |
| const Config& config, std::string* err) const { |
| int retry_count = 5; |
| int retry_delay_ms = 10; |
| bool server_started = false; |
| ProcessInfo info = {}; |
| while (true) { |
| // Try to connect to the server first. |
| IpcHandle client = RawConnect(err); |
| if (client) { |
| CLIENT_LOG("Got client connection to server!"); |
| uint8_t query_type = kServerCommandTypeClientQuery; |
| if (!client.Write(&query_type, sizeof(query_type), err)) { |
| CLIENT_LOG( |
| "ERROR: Could not write query type, did server disconnect?: %s", |
| err->c_str()); |
| client.Close(); |
| } else { |
| CLIENT_LOG("Sent query type"); |
| } |
| return client; |
| } |
| |
| if (!server_started) { |
| CLIENT_LOG("No initial connection. Spawning server"); |
| // Spawn a server if one wasn't already started. |
| if (!SpawnServerProcess(config, &info, err)) { |
| CLIENT_LOG("Could not spawn server: %s", err->c_str()); |
| return {}; |
| } |
| server_started = true; |
| |
| } else if (retry_count == 0) { |
| CLIENT_LOG("Failure to connect to server, exiting retry loop: %s", |
| err->c_str()); |
| return {}; |
| } else { |
| --retry_count; |
| if (retry_delay_ms < 1024) |
| retry_delay_ms *= 2; |
| } |
| |
| CLIENT_LOG("Waiting for %ld milliseconds", (long)retry_delay_ms); |
| SleepMilliSeconds(retry_delay_ms); |
| } |
| } |
| |
| /////////////////////////////////////////////////////////////////////////// |
| /// |
| /// S E R V E R S I D E |
| /// |
| |
| PersistentService::Server::Server(const std::string& service_name) |
| : service_name_(service_name) {} |
| |
| bool PersistentService::Server::BindService(std::string* err) { |
| if (service_handle_) { |
| *err = "Server already started!"; |
| return false; |
| } |
| SERVER_LOG("Trying to get service %s", service_name_.c_str()); |
| service_handle_ = IpcServiceHandle::BindTo(service_name_, err); |
| if (!service_handle_) |
| SERVER_LOG("Got error %s", err->c_str()); |
| else |
| SERVER_LOG("Got it!"); |
| return !!service_handle_; |
| } |
| |
| void PersistentService::Server::RunServerThenExit( |
| const VersionCheckHandler& version_check_handler, |
| const RequestHandler& request_handler) { |
| std::string error; |
| if (!service_handle_ && !BindService(&error)) { |
| Error("Could not acquire service handle: %s", error.c_str()); |
| exit(1); |
| } |
| SERVER_LOG("Service handle acquired, starting server loop"); |
| RunServerThenExitInternal(version_check_handler, std::move(request_handler)); |
| } |
| |
| void PersistentService::Server::RunServerThenExitInternal( |
| const VersionCheckHandler& version_check_handler, |
| const RequestHandler& request_handler) { |
| AsyncLoop& async_loop = AsyncLoop::Get(); |
| std::string error; |
| #ifndef _WIN32 |
| SigPipeBlocker sig_pipe_blocker; |
| #endif |
| |
| /// Convenience class to wait for client connections with a timeout. |
| /// Usage is: |
| /// |
| /// 1) Create class, passing references to an IpcServiceHandle |
| /// and an AsyncLoop instance. |
| /// |
| /// 2) Call Wait() to wait for a new connection. |
| /// |
| /// 3) Either destroy the instance, or call Close() before that to |
| /// ensure the corresponding server handle is closed before |
| /// calling exit(). |
| /// |
| class ConnectionWaiter { |
| public: |
| /// Constructor sets up an internal AsyncHandle to do the wait. |
| /// Note that this moves the native service handle into this instance, but |
| /// |service_handle| still needs to be closed manually before calling |
| /// ::exit(), because on MacOS it does extra work to remove a Unix socket |
| /// and pid file. |
| ConnectionWaiter(IpcServiceHandle& service_handle, AsyncLoop& async_loop) |
| : async_loop_(async_loop), async_( |
| AsyncHandle::Create(std::move(service_handle), async_loop, callback())) {} |
| |
| /// Wait for a new client connection for |timeout_ms| milliseconds. |
| /// Return AsyncLoop exit status, on success, set |client| to receive |
| /// the new client connection handle. |
| AsyncLoop::ExitStatus Wait(int64_t timeout_ms, IpcHandle& client) { |
| completed_ = false; |
| accepted_ = false; |
| |
| int64_t now_ms = async_loop_.NowMs(); |
| (void)now_ms; // silence compiler when SERVER_LOG() expansion is empty. |
| |
| if (timeout_ms < 0) { |
| SERVER_LOG("Waiting for new client connection (no timeout)"); |
| } else { |
| SERVER_LOG("Waiting for new client connection (for up to %.1f seconds)", |
| timeout_ms / 1000.); |
| } |
| |
| async_.StartAccept(); |
| auto status = async_loop_.RunUntil([this]() { return completed_; }, |
| timeout_ms); |
| |
| if (accepted_) |
| client = async_.TakeAcceptedHandle(); |
| |
| SERVER_LOG("Connection time %s", |
| StringFormatDurationMs(async_loop_.NowMs() - now_ms).c_str()); |
| return status; |
| } |
| |
| /// Close the internal handle explicitly. |
| void Close() { |
| async_.Close(); |
| } |
| |
| private: |
| AsyncLoop& async_loop_; |
| AsyncHandle async_; |
| IpcHandle client_; |
| bool completed_ = false; |
| bool accepted_ = false; |
| |
| // Return the AsyncHandle::Callback to be used by |async| in this instance. |
| AsyncHandle::Callback callback() { |
| return [this](AsyncError error, size_t) { |
| completed_ = true; |
| accepted_ = (error == 0); |
| }; |
| } |
| }; |
| |
| ConnectionWaiter waiter(service_handle_, async_loop); |
| while (true) { |
| IpcHandle client; |
| AsyncLoop::ExitStatus status = waiter.Wait(connection_timeout_ms_, client); |
| |
| if (status == AsyncLoop::ExitInterrupted) |
| break; |
| |
| if (status == AsyncLoop::ExitTimeout) { |
| SERVER_LOG("Timeout waiting for client connection"); |
| break; |
| } |
| if (!client) { |
| SERVER_LOG("Could not accept new client: %s", error.c_str()); |
| break; |
| } |
| |
| SERVER_LOG("Reading query type..."); |
| // Get command, which can be 'kill' or 'query' at the moment. |
| uint8_t query_type = 0; |
| if (!client.Read(&query_type, sizeof(query_type), &error)) { |
| SERVER_LOG("Could not read client query type !?"); |
| break; |
| } |
| |
| SERVER_LOG("Got query_type %d\n", query_type); |
| |
| if (query_type == kServerCommandTypeStop) { |
| SERVER_LOG("Client asking server to stop!"); |
| break; |
| } |
| if (query_type == kServerCommandTypeGetPid) { |
| SERVER_LOG("Client asking for server pid!"); |
| #ifdef _WIN32 |
| int pid = static_cast<int>(GetCurrentProcessId()); |
| #else |
| int pid = getpid(); |
| #endif |
| if (!RemoteWrite(pid, client, &error)) { |
| SERVER_LOG("Could not send pid %d back to client!: %s", pid, |
| error.c_str()); |
| break; |
| } |
| SERVER_LOG("Sent pid %d to client. Looping", pid); |
| continue; |
| } |
| if (query_type != kServerCommandTypeClientQuery) { |
| SERVER_LOG("Unknown client query type: %d", query_type); |
| break; |
| } |
| |
| SERVER_LOG("Accepted client connection, checking version info"); |
| std::string version_info; |
| if (!RemoteRead(version_info, client, &error)) { |
| SERVER_LOG("Could not client version info: %s", error.c_str()); |
| break; |
| } |
| |
| std::string version_check = version_check_handler(version_info); |
| if (!RemoteWrite(version_check, client, &error)) { |
| SERVER_LOG("Could not write version check result to client: %s", |
| error.c_str()); |
| break; |
| } |
| |
| if (!version_check.empty()) { |
| SERVER_LOG("Incompatible client version info: %s", version_check.c_str()); |
| break; |
| } |
| |
| if (!request_handler(std::move(client))) { |
| SERVER_LOG("Request handler returned false, exiting server loop"); |
| break; |
| } |
| async_loop.ClearInterrupt(); |
| } |
| SERVER_LOG("Exiting!"); |
| |
| // Since calling ::exit() directly here prevents the destructors |
| // of the variables defined in this function from running, close the |
| // waiter handle explicitly. |
| waiter.Close(); |
| |
| // Note: while the file descriptor was moved to waiter, this |
| // instance must be destroyed explicitly to remove the pid file and |
| // Unix socket filesystem entry on MacOS. |
| service_handle_.Close(); |
| |
| ::exit(0); |
| } |