blob: 09d3c78fe22ea227c407ee4ede1d231357fa37fd [file] [log] [blame]
// Copyright 2023 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <errno.h>
#include <fcntl.h>
#include <stddef.h>
#include <stdio.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/un.h>
#include <unistd.h>
#include "ipc_handle.h"
#include "util.h" // For StringFormat() and GetLastErrorString()
namespace {
// Helper macro to loop on EINTR during syscalls: the expression |x| is
// evaluated again for as long as it yields -1 with errno == EINTR, and the
// macro evaluates to the last result. Note that |x| is re-evaluated on each
// retry, so it must be safe to repeat.
// This relies on the `({ ... })` statement-expression compiler extension.
// Important: Do not use it for close(), use CloseFd() instead.
#define HANDLE_EINTR(x) \
({ \
decltype(x) eintr_wrapper_result; \
do { \
eintr_wrapper_result = (x); \
} while (eintr_wrapper_result == -1 && errno == EINTR); \
eintr_wrapper_result; \
})
// Closes |fd| when it is non-negative, then resets it to -1. The value of
// errno is saved before and restored after the close() call, and the call
// is never retried: when close() is interrupted by EINTR there is no way to
// tell whether the descriptor was actually released (this is kernel and
// system specific), so nothing useful can be done about it.
void CloseFd(int& fd) {
  if (fd < 0)
    return;

  const int saved_errno = errno;
  ::close(fd);
  fd = -1;
  errno = saved_errno;
}
// Stores the strerror() text for the current errno value into
// |*error_message|. Always returns false, so that a failure path can simply
// `return SetErrnoMessage(...)`.
bool SetErrnoMessage(std::string* error_message) {
  const int current_errno = errno;
  error_message->assign(strerror(current_errno));
  return false;
}
// Returns true when the file status flags of |fd| include O_NONBLOCK.
// Returns false when the flags cannot be read (fcntl() failure).
bool IsNonBlockingFd(int fd) {
  const int status_flags = fcntl(fd, F_GETFL);
  if (status_flags < 0)
    return false;
  return (status_flags & O_NONBLOCK) != 0;
}
// Turns the O_NONBLOCK status flag of |fd| on or off according to |enabled|.
// Does nothing when the current flags cannot be read, and skips the
// F_SETFL call when the flag already has the requested state.
void SetNonBlockingFd(int fd, bool enabled) {
  const int current = fcntl(fd, F_GETFL);
  if (current < 0)
    return;

  int wanted = current;
  if (enabled)
    wanted |= O_NONBLOCK;
  else
    wanted &= ~O_NONBLOCK;

  if (wanted == current)
    return;
  fcntl(fd, F_SETFL, wanted);
}
// Returns true when the descriptor flags of |fd| include FD_CLOEXEC.
// Returns false when the flags cannot be read (fcntl() failure).
bool IsClosedOnExec(int fd) {
  const int fd_flags = fcntl(fd, F_GETFD);
  if (fd_flags < 0)
    return false;
  return (fd_flags & FD_CLOEXEC) != 0;
}
// Turns the FD_CLOEXEC descriptor flag of |fd| on or off according to
// |enabled|. Does nothing when |fd| is negative or when its current flags
// cannot be read, and skips the F_SETFD call when the flag already has the
// requested state.
void SetClosedOnExecFlag(int fd, bool enabled) {
  if (fd < 0)
    return;

  const int flags = fcntl(fd, F_GETFD);
  if (flags < 0) {
    // fcntl() failed: -1 is an error indicator, not a set of flags, so it
    // must not be combined with FD_CLOEXEC and written back.
    return;
  }

  const int new_flags = enabled ? (flags | FD_CLOEXEC) : (flags & ~FD_CLOEXEC);
  if (new_flags != flags)
    fcntl(fd, F_SETFD, new_flags);
}
// Creates a new AF_UNIX stream socket with the close-on-exec flag set, and
// in non-blocking mode when |non_blocking| is true. Returns the descriptor,
// or a negative value if socket() fails.
//
// When the platform provides SOCK_NONBLOCK and SOCK_CLOEXEC, both flags are
// requested directly in the socket() call; otherwise they are applied with
// fcntl() after creation.
int CreateUnixSocket(bool non_blocking) {
#if defined(SOCK_NONBLOCK) && defined(SOCK_CLOEXEC)
  int socket_type = SOCK_STREAM | SOCK_CLOEXEC;
  if (non_blocking)
    socket_type |= SOCK_NONBLOCK;
  return socket(AF_UNIX, socket_type, 0);
#else   // !SOCK_NONBLOCK || !SOCK_CLOEXEC
  const int fd = socket(AF_UNIX, SOCK_STREAM, 0);
  if (fd < 0)
    return -1;
  const int fd_flags = fcntl(fd, F_GETFD);
  fcntl(fd, F_SETFD, fd_flags | FD_CLOEXEC);
  if (non_blocking)
    SetNonBlockingFd(fd, true);
  return fd;
#endif  // !SOCK_NONBLOCK || !SOCK_CLOEXEC
}
// Convenience class to wait for a single file descriptor to become
// either readable or writable. Usage is:
//
//  - Create instance.
//  - Call Wait() passing a timeout.
//
struct PosixIoConditionWaiter {
  // Constructor. Set |writable| to true to wait for write events on |fd|,
  // otherwise for read events.
  PosixIoConditionWaiter(int fd, bool writable) : fd_(fd), writable_(writable) {
    FD_ZERO(&fds_);
  }

  // Wait for the specific condition, using pselect(). A negative
  // |timeout_ms| waits without any time limit, otherwise the wait lasts at
  // most |timeout_ms| milliseconds.
  //
  // On success return true. On failure, set |*error| then return false.
  // In all cases |*did_timeout| will be set to true only in case of timeout.
  // An interrupted wait (EINTR) is not retried and is reported as an error.
  bool Wait(int64_t timeout_ms, bool* did_timeout, std::string* error) {
    *did_timeout = false;

    // FD_SET() and FD_ISSET() have undefined behavior for descriptors
    // outside of [0, FD_SETSIZE), so reject these instead of corrupting
    // memory next to |fds_|.
    if (fd_ < 0 || fd_ >= FD_SETSIZE) {
      *error = "file descriptor out of range for select()";
      return false;
    }

    // Reset fds_ in case Wait() is called multiple times.
    FD_ZERO(&fds_);
    FD_SET(fd_, &fds_);

    struct timespec timeout_ts;
    struct timespec* ts = nullptr;  // null means "no time limit".
    if (timeout_ms >= 0) {
      timeout_ts.tv_sec = static_cast<time_t>(timeout_ms / 1000);
      timeout_ts.tv_nsec = static_cast<long>((timeout_ms % 1000) * 1000000LL);
      ts = &timeout_ts;
    }

    // Only one of the two sets is passed, depending on the condition that
    // was selected at construction time.
    fd_set* readable = writable_ ? nullptr : &fds_;
    fd_set* writable = writable_ ? &fds_ : nullptr;
    int ret = pselect(fd_ + 1, readable, writable, nullptr, ts, nullptr);
    if (ret < 0)
      return SetErrnoMessage(error);

    // pselect() leaves in the set only descriptors that are ready, so a
    // missing |fd_| means the timeout expired first.
    if (!FD_ISSET(fd_, &fds_)) {
      *did_timeout = true;
      *error = "timed out";
      return false;
    }
    return true;
  }

  const int fd_;
  const bool writable_;
  fd_set fds_;
};
// Reads the SO_ERROR socket option of |fd|, which holds the outcome of a
// non-blocking connection attempt. Returns the option value (0 meaning no
// pending error), or the negative getsockopt() result if the option could
// not be read.
int GetSocketConnectStatus(int fd) {
  int pending_error = 0;
  socklen_t pending_error_size = sizeof(pending_error);
  const int rc = getsockopt(fd, SOL_SOCKET, SO_ERROR, &pending_error,
                            &pending_error_size);
  return (rc < 0) ? rc : pending_error;
}
// Set USE_LINUX_NAMESPACE to 1 to use the Linux abstract
// Unix socket namespace, which does not require a filesystem
// entry point.
#ifdef __linux__
#define USE_LINUX_NAMESPACE 1
#endif
#if !USE_LINUX_NAMESPACE
// Returns the runtime directory in which a Unix socket should be created.
// Only used for non-Linux systems. The lookup order is:
//   1. $XDG_RUNTIME_DIR, when set to a non-empty value (it might be defined
//      on BSDs and other operating systems),
//   2. $TMPDIR, when set to a non-empty value,
//   3. "/tmp".
// Callers can assume that the directory already exists (otherwise, the
// system is not configured properly, and an error on bind or connect
// operation is expected).
std::string GetRuntimeDirectory() {
  if (const char* xdg_dir = getenv("XDG_RUNTIME_DIR")) {
    if (xdg_dir[0] != '\0')
      return std::string(xdg_dir);
  }

  const char* tmp_dir = getenv("TMPDIR");
  const bool has_tmp_dir = tmp_dir && tmp_dir[0];
  return std::string(has_tmp_dir ? tmp_dir : "/tmp");
}
#endif // !USE_LINUX_NAMESPACE
// Builds the Unix socket path for |service_name|, of the form
// "<prefix>basic_ipc-<user>-<service_name>", where <user> is the value of
// $USER, or "unknown_user" when that variable is unset or empty.
//
// With USE_LINUX_NAMESPACE, <prefix> is a single NUL byte, which selects the
// Linux abstract socket namespace. Otherwise <prefix> is the runtime
// directory followed by "/".
std::string CreateUnixSocketPath(StringPiece service_name) {
#if USE_LINUX_NAMESPACE
  std::string socket_path(1, '\0');
#else
  std::string socket_path = GetRuntimeDirectory() + "/";
#endif
  const char* user_name = getenv("USER");
  const bool has_user_name = user_name && user_name[0];

  socket_path.append("basic_ipc-");
  socket_path.append(has_user_name ? user_name : "unknown_user");
  socket_path.append("-");
  socket_path.append(service_name.AsString());
  return socket_path;
}
// Models the Unix socket address associated with a service name.
// Usage is:
//   1) Create instance, passing service name.
//   2) Call valid() to check the instance. If false the
//      service name was too long to fit into sockaddr_un::sun_path.
//   3) Use address() and size() to pass to socket calls such as
//      bind() or connect().
class LocalAddress {
 public:
  LocalAddress(StringPiece service_name) {
    local_ = {};
    local_.sun_family = AF_UNIX;

    const std::string socket_path = CreateUnixSocketPath(service_name);
    const size_t path_length = socket_path.size();
    // One byte of sun_path is reserved for the terminating NUL. When the
    // path does not fit, |size_| stays 0 and valid() reports false.
    if (path_length < sizeof(local_.sun_path)) {
      memcpy(local_.sun_path, socket_path.data(), path_length);
      local_.sun_path[path_length] = '\0';
      // The address length covers the terminating NUL byte as well.
      size_ = offsetof(sockaddr_un, sun_path) + path_length + 1;
    }
  }

  // True when the service name could be turned into a socket address.
  bool valid() const { return size_ > 0; }

  // Address pointer and length, in the form expected by socket calls.
  sockaddr* address() const { return const_cast<sockaddr*>(&generic_); }
  size_t size() const { return size_; }

  // The content of sun_path, viewed as a C string.
  const char* path() const { return local_.sun_path; }

  // path() followed by a ".pid" suffix.
  std::string pid_path() const {
    return StringFormat("%s.pid", local_.sun_path);
  }

 private:
  size_t size_ = 0;
  // Both members alias the same storage, so that the address can be handed
  // out as a generic sockaddr.
  union {
    sockaddr_un local_;
    sockaddr generic_;
  };
};
} // namespace
// static
// Namespace-scope definition of the constexpr static data member. Its
// initializer is not visible here and presumably sits with the in-class
// declaration.
constexpr int IpcHandle::kInvalid;
// static
// Duplicates |handle| with dup() and returns the new descriptor, or the
// negative dup() result on failure. When |inherited| is true the duplicate
// is left inheritable by exec'ed child processes, otherwise it is marked
// close-on-exec.
IpcHandle::HandleType IpcHandle::CloneNativeHandle(HandleType handle,
                                                   bool inherited) {
  int fd = ::dup(handle);
  // A descriptor survives exec() only when FD_CLOEXEC is NOT set, so the
  // close-on-exec flag is the negation of |inherited|. A failed dup() (-1)
  // is ignored by SetClosedOnExecFlag().
  SetClosedOnExecFlag(fd, !inherited);
  return fd;
}
void IpcHandle::Close() {
CloseFd(handle_);
}
// Gives up the native descriptor without closing it: returns its value and
// resets |handle_| to -1.
int IpcHandle::ReleaseNativeHandle() {
  const int released = handle_;
  handle_ = -1;
  return released;
}
// Reads up to |buffer_size| bytes into |buff|, calling read() repeatedly
// until the buffer is full, end of input is reached, or an error occurs.
// Calls interrupted by EINTR are retried.
//
// Returns the number of bytes read, which can be less than |buffer_size|.
// Returns -1 and sets |*error_message| only if an error happens before any
// byte was read; an error after a partial read just ends the loop.
ssize_t IpcHandle::Read(void* buff, size_t buffer_size,
                        std::string* error_message) const {
  char* cursor = static_cast<char*>(buff);
  size_t remaining = buffer_size;
  ssize_t total = 0;

  while (remaining > 0) {
    const ssize_t got = read(handle_, cursor, remaining);
    if (got == 0)
      break;  // No more input.

    if (got > 0) {
      cursor += got;
      remaining -= static_cast<size_t>(got);
      total += got;
      continue;
    }

    // got < 0: the call failed.
    if (errno == EINTR)
      continue;
    if (total > 0) {
      // Ignore this error to return the current read result.
      // This assumes the error will repeat on the next call.
      break;
    }
    *error_message = strerror(errno);
    return -1;
  }
  return total;
}
// Writes up to |buffer_size| bytes from |buff|, calling write() repeatedly
// until everything was written, write() reports 0 bytes, or an error
// occurs. Calls interrupted by EINTR are retried.
//
// Returns the number of bytes written, which can be less than
// |buffer_size|. Returns -1 and sets |*error_message| only if an error
// happens before any byte was written; an error after a partial write just
// ends the loop.
ssize_t IpcHandle::Write(const void* buff, size_t buffer_size,
                         std::string* error_message) const {
  const char* cursor = static_cast<const char*>(buff);
  size_t remaining = buffer_size;
  ssize_t total = 0;

  while (remaining > 0) {
    const ssize_t sent = write(handle_, cursor, remaining);
    if (sent == 0)
      break;

    if (sent > 0) {
      cursor += sent;
      remaining -= static_cast<size_t>(sent);
      total += sent;
      continue;
    }

    // sent < 0: the call failed.
    if (errno == EINTR)
      continue;
    if (total > 0)
      break;  // Report the partial write instead of the error.
    *error_message = strerror(errno);
    return -1;
  }
  return total;
}
// Sends the descriptor |native| over this handle as SCM_RIGHTS ancillary
// data, attached to a one-byte payload ('x'). Returns true on success. On
// failure, sets |*error_message| from errno and returns false.
bool IpcHandle::SendNativeHandle(HandleType native,
                                 std::string* error_message) const {
  // One byte of regular data that carries the ancillary data.
  char ch = 'x';
  iovec iov = { &ch, 1 };

  // The union guarantees that the buffer is suitably aligned for a cmsghdr.
  union {
    char buf[CMSG_SPACE(sizeof(int))];
    cmsghdr align;
  } control;
  memset(control.buf, 0, sizeof(control.buf));

  msghdr header = {};
  header.msg_iov = &iov;
  header.msg_iovlen = 1;
  header.msg_control = control.buf;
  header.msg_controllen = sizeof(control.buf);

  cmsghdr* control_header = CMSG_FIRSTHDR(&header);
  control_header->cmsg_len = CMSG_LEN(sizeof(int));
  control_header->cmsg_level = SOL_SOCKET;
  control_header->cmsg_type = SCM_RIGHTS;

  // The pointer returned by CMSG_DATA() is not guaranteed to be aligned for
  // an int, so copy the payload instead of storing through a cast pointer.
  const int fd_to_send = native;
  memcpy(CMSG_DATA(control_header), &fd_to_send, sizeof(fd_to_send));

  ssize_t ret = HANDLE_EINTR(sendmsg(handle_, &header, 0));
  if (ret == -1)
    return SetErrnoMessage(error_message);
  return true;
}
// Receives one message from this handle and extracts the descriptor it
// carries as SCM_RIGHTS ancillary data, storing it into |*native|.
// Returns true on success. On failure, sets |*error_message| and returns
// false; this includes a message that does not carry exactly one
// descriptor in its first control header.
bool IpcHandle::ReceiveNativeHandle(IpcHandle* native,
                                    std::string* error_message) const {
  // Room for the one byte of regular data that accompanies the descriptor.
  char ch = '\0';
  iovec iov = { &ch, 1 };

  // The union guarantees that the buffer is suitably aligned for a cmsghdr.
  union {
    char buf[CMSG_SPACE(sizeof(int))];
    cmsghdr align;
  } control;
  memset(control.buf, 0, sizeof(control.buf));

  msghdr header = {};
  header.msg_iov = &iov;
  header.msg_iovlen = 1;
  header.msg_control = control.buf;
  header.msg_controllen = sizeof(control.buf);

  ssize_t ret = HANDLE_EINTR(recvmsg(handle_, &header, 0));
  if (ret == -1)
    return SetErrnoMessage(error_message);

  cmsghdr* control_header = CMSG_FIRSTHDR(&header);
  if (!control_header || control_header->cmsg_len != CMSG_LEN(sizeof(int)) ||
      control_header->cmsg_level != SOL_SOCKET ||
      control_header->cmsg_type != SCM_RIGHTS) {
    *error_message =
        std::string("Invalid data when receiving file descriptor!");
    return false;
  }

  // The pointer returned by CMSG_DATA() is not guaranteed to be aligned for
  // an int, so copy the payload out instead of reading through a cast
  // pointer.
  int received_fd = -1;
  memcpy(&received_fd, CMSG_DATA(control_header), sizeof(received_fd));
  *native = IpcHandle(received_fd);
  return true;
}
// Returns true if the descriptor stays open in exec'ed child processes,
// i.e. if its FD_CLOEXEC flag is NOT set.
// NOTE: IsClosedOnExec() returns false when the flags cannot be read, so an
// invalid handle is reported as inheritable here.
bool IpcHandle::IsInheritable() const {
  return !IsClosedOnExec(handle_);
}

// Makes the descriptor inheritable by exec'ed child processes when
// |enabled| is true (FD_CLOEXEC cleared), or not inheritable when false
// (FD_CLOEXEC set).
void IpcHandle::SetInheritable(bool enabled) {
  SetClosedOnExecFlag(handle_, !enabled);
}
// Returns true when the native descriptor is in non-blocking mode.
bool IpcHandle::IsNonBlocking() const {
  const bool non_blocking = IsNonBlockingFd(handle_);
  return non_blocking;
}
void IpcHandle::SetNonBlocking(bool enable) {
SetNonBlockingFd(handle_, enable);
}
// static
// Returns the descriptor that backs the stdio stream |file|, as reported
// by fileno().
IpcHandle::HandleType IpcHandle::NativeForStdio(FILE* file) {
  const int stream_fd = fileno(file);
  return stream_fd;
}
// static
// Flushes |file|, then returns a handle wrapping a duplicate of the
// descriptor that backs it.
IpcHandle IpcHandle::CloneFromStdio(FILE* file) {
  fflush(file);
  const int stream_fd = fileno(file);
  return { CloneNativeHandle(stream_fd) };
}
// Makes the descriptor behind |file| a duplicate of this handle, using
// dup2() after flushing |file|. Only stdin, stdout and stderr are accepted.
// Returns true on success. Returns false with errno set to EINVAL when
// |file| is another stream or when this handle holds a negative descriptor,
// and false when dup2() fails.
bool IpcHandle::CloneIntoStdio(FILE* file) {
  const bool is_standard_stream =
      (file == stdout || file == stderr || file == stdin);
  if (!is_standard_stream || handle_ < 0) {
    errno = EINVAL;
    return false;
  }

  fflush(file);
  return ::dup2(handle_, fileno(file)) >= 0;
}
void IpcServiceHandle::Close() {
this->IpcHandle::Close();
if (!socket_path_.empty() && socket_path_[0] != '\0') {
// Remove socket and pid file.
unlink(socket_path_.c_str());
std::string pid_path = socket_path_;
pid_path += ".pid";
unlink(pid_path.c_str());
socket_path_.clear();
}
}
// Releases the descriptor and any filesystem entries on destruction.
IpcServiceHandle::~IpcServiceHandle() {
  this->Close();
}
#if !USE_LINUX_NAMESPACE
// Reads the server process id stored in the pid file at |pid_path|.
// Returns:
//   - the process id (> 0) when the file holds a valid one,
//   - 0 when the file does not exist or is malformed (the server is then
//     considered not running, and no error is reported),
//   - -1, with |*err| set, when the file exists but could not be opened
//     (likely a permission issue).
int ReadPidFile(const std::string& pid_path, std::string* err) {
  FILE* file = fopen(pid_path.c_str(), "r");
  if (!file) {
    if (errno == ENOENT)
      return 0;  // There is no pid file.
    *err = StringFormat("Cannot open pid file: %s", strerror(errno));
    return -1;
  }

  int pid = -1;
  const int matched = fscanf(file, "%d", &pid);
  (void)fclose(file);

  const bool well_formed = (matched == 1) && (pid > 0);
  return well_formed ? pid : 0;
}
#endif // !USE_LINUX_NAMESPACE
// static
// Creates a listening Unix socket for |service_name| and returns a handle
// that owns it. On failure, sets |*error_message| and returns an empty
// handle.
//
// Without USE_LINUX_NAMESPACE, a ".pid" file next to the socket records the
// server process: binding is refused ("already in use") while the recorded
// process is alive, otherwise the pid file is atomically replaced with one
// holding the current process id and the stale socket file is removed.
IpcServiceHandle IpcServiceHandle::BindTo(StringPiece service_name,
                                          std::string* error_message) {
  LocalAddress address(service_name);
  if (!address.valid()) {
    *error_message = std::string("Service name too long");
    return {};
  }
#if !USE_LINUX_NAMESPACE
  // Try to see if another server is already running. Use a .pid file
  // that contains the server's process PID to do that, and check whether
  // it is still alive.
  std::string pid_path = address.pid_path();
  int server_pid = ReadPidFile(pid_path, error_message);
  if (server_pid < 0)
    return {};
  if (server_pid > 0 && kill(server_pid, 0) == 0) {
    // The server process is still running.
    *error_message = "already in use";
    return {};
  }
  // Create new temporary pid file before doing an atomic filesystem rename.
  int cur_pid = getpid();
  std::string temp_pid_path =
      StringFormat("%s.temp.%d", pid_path.c_str(), cur_pid);
  {
    bool pid_file_error = false;
    // errno is captured at the point of failure, since later calls
    // (fclose, unlink) may overwrite it.
    int pid_file_errno = 0;
    FILE* pid_file = fopen(temp_pid_path.c_str(), "w");
    if (!pid_file) {
      pid_file_error = true;
      pid_file_errno = errno;
    } else {
      if (fprintf(pid_file, "%d", cur_pid) <= 0) {
        pid_file_error = true;
        pid_file_errno = errno;
      }
      // The output is buffered, so a write error may only show up when the
      // stream is flushed by fclose().
      if (fclose(pid_file) != 0 && !pid_file_error) {
        pid_file_error = true;
        pid_file_errno = errno;
      }
      // Do not leave a partial temporary file behind.
      if (pid_file_error)
        (void)unlink(temp_pid_path.c_str());
    }
    if (pid_file_error) {
      *error_message = "Cannot create temporary pid file: ";
      *error_message += strerror(pid_file_errno);
      return {};
    }
  }
  // atomically rename the temporary file.
  // Note that EINTR can happen in practice in rename() :-(
  if (HANDLE_EINTR(rename(temp_pid_path.c_str(), pid_path.c_str())) < 0) {
    const int rename_errno = errno;
    // The temporary file is useless now, remove it.
    (void)unlink(temp_pid_path.c_str());
    *error_message = "Cannot rename pid file: ";
    *error_message += strerror(rename_errno);
    return {};
  }
  // Remove stale socket if any.
  if (server_pid > 0)
    (void)unlink(address.path());
#endif  // !USE_LINUX_NAMESPACE
  // Use the shared helper so that the listening socket is close-on-exec,
  // like every other socket created in this file, and does not leak into
  // spawned processes.
  int server_fd = CreateUnixSocket(false);
  if (server_fd == -1) {
    SetErrnoMessage(error_message);
    return {};
  }
  if (bind(server_fd, address.address(), address.size()) < 0 ||
      listen(server_fd, 1) < 0) {
    // CloseFd() preserves errno, so the message still describes the
    // bind() or listen() failure.
    CloseFd(server_fd);
    SetErrnoMessage(error_message);
    return {};
  }
  IpcServiceHandle result(server_fd, address.path());
  return result;
}
// Accepts the next client connection, retrying accept() when it is
// interrupted by EINTR. Returns a handle for the connection, or an empty
// handle with |*error_message| set from errno on failure.
IpcHandle IpcServiceHandle::AcceptClient(std::string* error_message) const {
  const int client_fd = HANDLE_EINTR(accept(handle_, nullptr, nullptr));
  if (client_fd >= 0)
    return { client_fd };

  SetErrnoMessage(error_message);
  return {};
}
// Same as the overload above, but first waits at most |timeout_ms|
// milliseconds for the listening socket to become readable. When the wait
// fails, an empty handle is returned with |*error_message| set;
// |*did_timeout| tells whether the failure was a timeout.
IpcHandle IpcServiceHandle::AcceptClient(int64_t timeout_ms, bool* did_timeout,
                                         std::string* error_message) const {
  PosixIoConditionWaiter readable_waiter(handle_, false);
  const bool ready =
      readable_waiter.Wait(timeout_ms, did_timeout, error_message);
  if (!ready)
    return {};

  const int client_fd = HANDLE_EINTR(accept(handle_, nullptr, nullptr));
  if (client_fd >= 0)
    return { client_fd };

  SetErrnoMessage(error_message);
  return {};
}
// static
// Returns true if a server appears to be bound to |service_name|.
//
// Without USE_LINUX_NAMESPACE, the pid file is checked first: no usable pid
// file means not bound, and a recorded process that is still alive means
// bound. Otherwise the answer comes from a bind() probe on a temporary
// socket: only an EADDRINUSE failure is reported as bound.
bool IpcServiceHandle::IsBound(StringPiece service_name) {
  LocalAddress address(service_name);
  if (!address.valid()) {
    // No server can be bound to a name that does not fit in a socket
    // address. Bail out before probing with an empty path and a zero
    // address length.
    return false;
  }
#if !USE_LINUX_NAMESPACE
  std::string pid_path = address.pid_path();
  std::string error;
  int server_pid = ReadPidFile(pid_path, &error);
  if (server_pid <= 0) {
    // No pid file means there is no server running.
    return false;
  } else if (kill(server_pid, 0) == 0) {
    // Server is still running.
    return true;
  }
#endif  // !USE_LINUX_NAMESPACE
  int server_fd = CreateUnixSocket(false);
  if (server_fd == -1)
    return false;
  bool result;
  if (bind(server_fd, address.address(), address.size()) == 0) {
    // fprintf(stderr, "IsBound(): bind() succeeded!\n");
    result = false;
  } else if (errno == EADDRINUSE) {
    // fprintf(stderr, "IsBound(): address in use!\n");
    result = true;
  } else {
    // fprintf(stderr, "IsBound(): bind() returned error: %s\n",
    // strerror(errno));
    result = false;
  }
  CloseFd(server_fd);
  return result;
}
// static
// Connects to the server bound to |service_name| with a blocking socket,
// retrying connect() when it is interrupted by EINTR. Returns a handle for
// the connection, or an empty handle with |*error_message| set on failure.
IpcHandle IpcServiceHandle::ConnectTo(StringPiece service_name,
                                      std::string* error_message) {
  LocalAddress server_address(service_name);
  if (!server_address.valid()) {
    *error_message = std::string("Service name too long");
    return {};
  }

  int socket_fd = CreateUnixSocket(false);
  if (socket_fd == -1) {
    SetErrnoMessage(error_message);
    return {};
  }

  const int rc = HANDLE_EINTR(
      connect(socket_fd, server_address.address(), server_address.size()));
  if (rc >= 0)
    return { socket_fd };

  // Record the error before closing (CloseFd() preserves errno anyway).
  SetErrnoMessage(error_message);
  CloseFd(socket_fd);
  return {};
}
// static
// Connects to the server bound to |service_name|, waiting at most
// |timeout_ms| milliseconds for the connection to complete. Returns a
// handle in blocking mode on success. On failure, returns an empty handle
// with |*error_message| set; |*did_timeout| is true only when the failure
// was a timeout.
IpcHandle IpcServiceHandle::ConnectTo(StringPiece service_name,
                                      int64_t timeout_ms, bool* did_timeout,
                                      std::string* error_message) {
  bool did_connect = false;
  *did_timeout = false;
  IpcHandle client = AsyncConnectTo(service_name, &did_connect, error_message);
  if (!client)
    return {};

  if (did_connect) {
    // AsyncConnectTo() creates a non-blocking socket. Switch it back to
    // blocking mode so that the result is the same as on the waited path
    // below.
    client.SetNonBlocking(false);
    return client;
  }

  // Completion of a non-blocking connect() is signaled by the socket
  // becoming writable, not readable.
  PosixIoConditionWaiter waiter(client.native_handle(), true);
  if (!waiter.Wait(timeout_ms, did_timeout, error_message)) {
    // Return an empty handle. (Returning `false` here would be converted
    // into a handle wrapping descriptor 0.)
    return {};
  }

  int so_error = GetSocketConnectStatus(client.native_handle());
  if (so_error != 0) {
    // A positive value is the error code of the connection attempt, which
    // errno does not hold at this point. A negative value means that
    // getsockopt() itself failed, with errno already set.
    if (so_error > 0)
      errno = so_error;
    SetErrnoMessage(error_message);
    return {};
  }
  client.SetNonBlocking(false);
  return client;
}
// static
// Starts a connection to the server bound to |service_name| on a
// non-blocking socket. On success returns a handle for the socket and sets
// |*did_connect| to true if the connection completed immediately, or to
// false if it is still in progress (EINPROGRESS). On failure, returns an
// empty handle with |*error_message| set, leaving |*did_connect| untouched.
IpcHandle IpcServiceHandle::AsyncConnectTo(StringPiece service_name,
                                           bool* did_connect,
                                           std::string* error_message) {
  LocalAddress server_address(service_name);
  if (!server_address.valid()) {
    *error_message = std::string("Service name too long");
    return {};
  }

  int socket_fd = CreateUnixSocket(true);
  if (socket_fd == -1) {
    SetErrnoMessage(error_message);
    return {};
  }

  const int rc = HANDLE_EINTR(
      connect(socket_fd, server_address.address(), server_address.size()));
  if (rc == 0) {
    // Connection completed immediately!
    *did_connect = true;
    return { socket_fd };
  }
  if (errno == EINPROGRESS) {
    // Connection could not be completed immediately.
    *did_connect = false;
    return { socket_fd };
  }

  SetErrnoMessage(error_message);
  CloseFd(socket_fd);
  return {};
}
// static
// Returns the connection status of socket |fd|, as reported by
// GetSocketConnectStatus().
int IpcHandle::GetNativeAsyncConnectStatus(int fd) {
  const int connect_status = GetSocketConnectStatus(fd);
  return connect_status;
}
// static
// Creates a pipe and stores its read end into |*read| and its write end
// into |*write|. Returns true on success. On failure, sets
// |*error_message| from errno, returns false, and leaves both handles
// untouched.
bool IpcHandle::CreatePipe(IpcHandle* read, IpcHandle* write,
                           std::string* error_message) {
  int pipe_fds[2] = { -1, -1 };
  if (pipe(pipe_fds) == 0) {
    *read = pipe_fds[0];
    *write = pipe_fds[1];
    return true;
  }
  return SetErrnoMessage(error_message);
}
// Same as CreatePipe(), but both ends are switched to non-blocking mode
// before returning.
bool IpcHandle::CreateAsyncPipe(IpcHandle* read, IpcHandle* write,
                                std::string* error_message) {
  const bool created = CreatePipe(read, write, error_message);
  if (created) {
    read->SetNonBlocking(true);
    write->SetNonBlocking(true);
  }
  return created;
}
// Reads exactly |buffer_size| bytes into |buffer|. Returns true on success.
// Returns false with |*error_message| set when Read() fails or when fewer
// bytes than requested could be read.
bool IpcHandle::ReadFull(void* buffer, size_t buffer_size,
                         std::string* error_message) const {
  ssize_t count = Read(buffer, buffer_size, error_message);
  if (count < 0)
    return false;
  if (count != static_cast<ssize_t>(buffer_size)) {
    // |count| is signed but known to be non-negative here; convert it so
    // that the argument type matches the %zu conversion.
    *error_message = StringFormat("Received %zu bytes, expected %zu",
                                  static_cast<size_t>(count), buffer_size);
    return false;
  }
  return true;
}
// Writes exactly |buffer_size| bytes from |buffer|. Returns true on
// success. Returns false with |*error_message| set when Write() fails or
// when fewer bytes than requested could be written.
bool IpcHandle::WriteFull(const void* buffer, size_t buffer_size,
                          std::string* error_message) const {
  ssize_t count = Write(buffer, buffer_size, error_message);
  if (count < 0)
    return false;
  if (count != static_cast<ssize_t>(buffer_size)) {
    // |count| is signed but known to be non-negative here; convert it so
    // that the argument type matches the %zu conversion.
    *error_message = StringFormat("Sent %zu bytes, expected %zu",
                                  static_cast<size_t>(count), buffer_size);
    return false;
  }
  return true;
}
std::string IpcHandle::display() const {
return StringFormat("fd=%d", handle_);
}