blob: e3d6013269040de0caa7a7d6a56c1db465ae8364 [file] [log] [blame] [edit]
// Copyright 2024 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
#include <arpa/inet.h>
#include <endian.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <poll.h>
#include <string.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <vector>
// Connection represents a client TCP connection.
// It connects to the given addr:port (or uses stdin when addr is "stdin")
// and allows sending/receiving size-prefixed flatbuffers-encoded messages.
// Unrecoverable errors are reported via failmsg.
class Connection
{
public:
	// Connects to addr:port; see the string overload of Connect for the accepted addr forms.
	Connection(const char* addr, const char* port)
	    : fd_(Connect(addr, port))
	{
	}

	// Returns the file descriptor backing this connection.
	int FD() const
	{
		return fd_;
	}

	// Packs msg into a size-prefixed flatbuffer and writes it out in full.
	template <typename Msg>
	void Send(const Msg& msg)
	{
		using Raw = typename Msg::TableType;
		auto off = Raw::Pack(fbb_, &msg);
		fbb_.FinishSizePrefixed(off);
		auto data = fbb_.GetBufferSpan();
		Send(data.data(), data.size());
		// The builder is a member that is shared by all Send calls.
		fbb_.Reset();
	}

	// Reads one size-prefixed flatbuffer from the connection and unpacks it into msg.
	template <typename Msg>
	void Recv(Msg& msg)
	{
		using Raw = typename Msg::TableType;
		// The size prefix is little-endian on the wire.
		flatbuffers::uoffset_t size;
		Recv(&size, sizeof(size));
		size = le32toh(size);
		recv_buf_.resize(size);
		Recv(recv_buf_.data(), size);
		auto raw = flatbuffers::GetRoot<Raw>(recv_buf_.data());
		raw->UnPackTo(&msg);
	}

	// Writes exactly size bytes from data, retrying on EINTR and on EAGAIN
	// (after a 1ms sleep). Any other outcome is reported via failmsg.
	void Send(const void* data, size_t size)
	{
		for (size_t sent = 0; sent < size;) {
			ssize_t n = write(fd_, static_cast<const char*>(data) + sent, size - sent);
			if (n > 0) {
				sent += n;
				continue;
			}
			// errno is only meaningful if the call failed (n < 0). Consulting it
			// otherwise may pick up a stale EINTR/EAGAIN and retry forever.
			if (n < 0 && errno == EINTR)
				continue;
			if (n < 0 && errno == EAGAIN) {
				sleep_ms(1);
				continue;
			}
			failmsg("failed to send rpc", "fd=%d want=%zu sent=%zu n=%zd", fd_, size, sent, n);
		}
	}

private:
	const int fd_;
	std::vector<char> recv_buf_;
	flatbuffers::FlatBufferBuilder fbb_;

	// Reads exactly size bytes into data, retrying on EINTR and on EAGAIN
	// (after a 1ms sleep). End-of-file (n == 0) and any other error are
	// reported via failmsg.
	void Recv(void* data, size_t size)
	{
		for (size_t recv = 0; recv < size;) {
			ssize_t n = read(fd_, static_cast<char*>(data) + recv, size - recv);
			if (n > 0) {
				recv += n;
				continue;
			}
			// n == 0 means the peer closed the connection; errno is not set in
			// that case, so only look at it when the call actually failed.
			if (n < 0 && errno == EINTR)
				continue;
			if (n < 0 && errno == EAGAIN) {
				sleep_ms(1);
				continue;
			}
			failmsg("failed to recv rpc", "fd=%d want=%zu recv=%zu n=%zd", fd_, size, recv, n);
		}
	}

	// Resolves addr and connects to it on the given port.
	// addr may be:
	//  - "stdin": no socket is created, STDIN_FILENO is returned;
	//  - "localhost": IPv4 loopback is tried first, then IPv6 loopback;
	//  - an IPv4 or IPv6 literal;
	//  - a host name, resolved with gethostbyname (each returned address is tried in order).
	// Returns the connected fd, or -1 if a literal (non-localhost) address could not
	// be connected to. Other failures are reported via failmsg.
	static int Connect(const char* addr, const char* ports)
	{
		int port = atoi(ports);
		bool localhost = !strcmp(addr, "localhost");
		int fd = -1;
		if (!strcmp(addr, "stdin"))
			return STDIN_FILENO;
		// Reject values that atoi could not parse as well as values that
		// would be silently truncated by htons.
		if (port <= 0 || port > 65535)
			failmsg("failed to parse manager port", "port=%s", ports);
		sockaddr_in saddr4 = {};
		saddr4.sin_family = AF_INET;
		saddr4.sin_port = htons(static_cast<uint16_t>(port));
		if (localhost)
			addr = "127.0.0.1";
		// inet_pton returns 1 only if addr is a valid literal of the given family.
		if (inet_pton(AF_INET, addr, &saddr4.sin_addr) == 1) {
			fd = Connect(&saddr4, &saddr4.sin_addr, port);
			// For localhost fall through and try IPv6 loopback as well.
			if (fd != -1 || !localhost)
				return fd;
		}
		sockaddr_in6 saddr6 = {};
		saddr6.sin6_family = AF_INET6;
		saddr6.sin6_port = htons(static_cast<uint16_t>(port));
		if (localhost)
			addr = "0:0:0:0:0:0:0:1";
		if (inet_pton(AF_INET6, addr, &saddr6.sin6_addr) == 1) {
			fd = Connect(&saddr6, &saddr6.sin6_addr, port);
			if (fd != -1 || !localhost)
				return fd;
		}
		// Not an address literal (or both loopback attempts failed): resolve by name.
		auto* hostent = gethostbyname(addr);
		if (!hostent)
			failmsg("failed to resolve manager addr", "addr=%s h_errno=%d", addr, h_errno);
		for (char** ip = hostent->h_addr_list; *ip; ip++) {
			if (hostent->h_addrtype == AF_INET) {
				memcpy(&saddr4.sin_addr, *ip, std::min<size_t>(hostent->h_length, sizeof(saddr4.sin_addr)));
				fd = Connect(&saddr4, &saddr4.sin_addr, port);
			} else if (hostent->h_addrtype == AF_INET6) {
				memcpy(&saddr6.sin6_addr, *ip, std::min<size_t>(hostent->h_length, sizeof(saddr6.sin6_addr)));
				fd = Connect(&saddr6, &saddr6.sin6_addr, port);
			} else {
				failmsg("unknown socket family", "family=%d", hostent->h_addrtype);
			}
			if (fd != -1)
				return fd;
		}
		failmsg("can't connect to manager", "addr=%s:%s", addr, ports);
	}

	// Creates a TCP socket for the family of addr and connects it.
	// ip points to the raw address inside addr and port is the host-order port;
	// both are used only for the diagnostic message.
	// Returns the connected fd, or -1 (after printing a diagnostic) on failure.
	template <typename addr_t>
	static int Connect(addr_t* addr, void* ip, int port)
	{
		auto* saddr = reinterpret_cast<sockaddr*>(addr);
		int fd = socket(saddr->sa_family, SOCK_STREAM, IPPROTO_TCP);
		if (fd == -1) {
			printf("failed to create socket for address family %d\n", saddr->sa_family);
			return -1;
		}
		char str[128] = {};
		inet_ntop(saddr->sa_family, ip, str, sizeof(str));
		int retcode = connect(fd, saddr, sizeof(*addr));
		// If connect was interrupted by a signal, wait for its completion with
		// ConnectWait instead of calling connect again.
		while (retcode == -1 && errno == EINTR)
			retcode = ConnectWait(fd);
		if (retcode != 0) {
			printf("failed to connect to manager at %s:%d: %s\n", str, port, strerror(errno));
			close(fd);
			return -1;
		}
		return fd;
	}

	Connection(const Connection&) = delete;
	Connection& operator=(const Connection&) = delete;

	// Waits until socket s becomes writable and then fetches the pending socket
	// error via SO_ERROR.
	// Returns 0 if SO_ERROR is 0; otherwise returns -1 with errno set
	// (to the SO_ERROR value, or by the failed poll/getsockopt call).
	static int ConnectWait(int s)
	{
		struct pollfd pfd[1] = {{.fd = s, .events = POLLOUT}};
		int error = 0;
		socklen_t len = sizeof(error);
		if (poll(pfd, 1, -1) == -1)
			return -1;
		if (getsockopt(s, SOL_SOCKET, SO_ERROR, &error, &len) == -1)
			return -1;
		if (error != 0) {
			errno = error;
			return -1;
		}
		return 0;
	}
};
// Select is a wrapper around the pselect system call.
// Usage: Arm the descriptors of interest, call Wait, then query Ready.
// Wait overwrites the armed set with the result, so descriptors need to be
// armed again before another Wait.
class Select
{
public:
	Select()
	{
		FD_ZERO(&rdset_);
	}

	// Adds fd to the set of descriptors that Wait checks for readability.
	void Arm(int fd)
	{
		// FD_SET is undefined for descriptors outside of [0, FD_SETSIZE).
		if (fd < 0 || fd >= FD_SETSIZE)
			fail("fd is out of range for select");
		FD_SET(fd, &rdset_);
		max_fd_ = std::max(max_fd_, fd);
	}

	// Reports whether fd is set in the descriptor set
	// (after Wait: whether pselect reported it as readable).
	bool Ready(int fd) const
	{
		// Descriptors that can't be in the set are never ready.
		if (fd < 0 || fd >= FD_SETSIZE)
			return false;
		return FD_ISSET(fd, &rdset_);
	}

	// Waits for up to ms milliseconds until any armed descriptor becomes readable.
	// EINTR/EAGAIN are tolerated and result in no descriptor being reported as ready;
	// any other pselect error is reported via fail.
	void Wait(int ms)
	{
		timespec timeout = {.tv_sec = ms / 1000, .tv_nsec = (ms % 1000) * 1000 * 1000};
		if (pselect(max_fd_ + 1, &rdset_, nullptr, nullptr, &timeout, nullptr) < 0) {
			if (errno != EINTR && errno != EAGAIN)
				fail("pselect failed");
			// On failure pselect leaves the set unmodified, so all armed
			// descriptors would still look ready. Clear it to avoid that.
			FD_ZERO(&rdset_);
		}
	}

	// Switches fd to non-blocking mode.
	static void Prepare(int fd)
	{
		// Check F_GETFL separately: on failure it returns -1, which must not be
		// OR-ed with O_NONBLOCK and passed to F_SETFL.
		int flags = fcntl(fd, F_GETFL, 0);
		if (flags == -1 || fcntl(fd, F_SETFL, flags | O_NONBLOCK))
			fail("fcntl(O_NONBLOCK) failed");
	}

private:
	fd_set rdset_;
	int max_fd_ = -1;

	Select(const Select&) = delete;
	Select& operator=(const Select&) = delete;
};