blob: 6ff012fa8112a748f18b8f49df3419777b438650 [file] [log] [blame]
/*
* Copyright (C) 2025 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <chrono>
#include <iostream>
#include <cstring>
#include <vector>
#include <thread>
#include <chrono>
#include <cstring>
#include <unistd.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <poll.h>
#include <unistd.h>
#include <random>
#include <future>
#include <thread>
#include <android-base/file.h>
#include <android-base/logging.h>
#include <android-base/scopeguard.h>
#include <IOUringSocketHandler/IOUringSocketHandler.h>
#include <gtest/gtest.h>
// One parameterized test configuration. GetConfigs() produces every
// combination of queue depth and message count, and each TEST_P case runs once
// per combination.
struct TestParam {
  // Queue depth passed to SetupIoUring(); also used as the number of
  // registered receive buffers.
  int queue_depth = 0;
  // Number of datagrams the sender emits and the receiver must verify.
  int numMessages = 0;
};
// Fixture shared by two kinds of tests. The TEST_F cases exercise ring and
// buffer setup on handler_. The TEST_P cases are parameterized by TestParam
// and round-trip datagrams through a bound AF_UNIX socket.
class IOUringSocketHandlerTest : public ::testing::TestWithParam<TestParam> {
 public:
  bool IsIouringSupported() {
    return IOUringSocketHandler::IsIouringSupported();
  }
  // Sends GetParam().numMessages credential-carrying datagrams on sock_send.
  void SendMsg(int sock_send, const bool non_block);

 protected:
  void SetUp() override {
  }
  // Only release what CreateServerSocket() actually created. The TEST_F cases
  // never call it, so sock_recv_ and socket_path_ may still hold their defaults.
  void TearDown() override {
    if (sock_recv_ >= 0) {
      close(sock_recv_);
      sock_recv_ = -1;
    }
    if (!socket_path_.empty()) {
      unlink(socket_path_.c_str());
    }
  }
  // Receives and verifies GetParam().numMessages datagrams from sock_recv.
  void ReceiveThreaduring(int sock_recv);
  // Creates and binds the receiving socket, filling sock_recv_ and socket_path_.
  bool CreateServerSocket();
  std::unique_ptr<IOUringSocketHandler> handler_;
  void InitializeHandler(int socket_fd = 1);
  // Default queue depth
  int queue_depth_ = 1;
  // Receiving socket; -1 until CreateServerSocket() succeeds.
  int sock_recv_ = -1;
  // Path sock_recv_ is bound to; empty until CreateServerSocket() succeeds.
  std::string socket_path_;
  const int kMessageSize = 4096;
  std::vector<std::string> sent_messages;  // Store sent messages for comparison
};
bool IOUringSocketHandlerTest::CreateServerSocket() {
int sock_recv = socket(AF_UNIX, SOCK_DGRAM, 0);
if (sock_recv < 0) {
PLOG(ERROR) << "socket failed";
return false;
}
std::string tmp_path = android::base::GetExecutableDirectory();
std::string socket_path = tmp_path + "/temp.sock";
struct sockaddr_un addr_recv;
memset(&addr_recv, 0, sizeof(addr_recv));
addr_recv.sun_family = AF_UNIX;
strcpy(addr_recv.sun_path, socket_path.c_str());
unlink(socket_path.c_str()); // Remove existing socket file if any
if (bind(sock_recv, (struct sockaddr*)&addr_recv, sizeof(addr_recv)) < 0) {
PLOG(ERROR) << "bind failed";
close(sock_recv);
return false;
}
int on = 1;
if (setsockopt(sock_recv, SOL_SOCKET, SO_PASSCRED, &on, sizeof(on))) {
return false;
}
sock_recv_ = sock_recv;
socket_path_ = socket_path;
return true;
}
// Returns a string of `length` characters drawn uniformly from [a-zA-Z0-9].
static std::string generateRandomString(size_t length) {
  static constexpr char charset[] =
      "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
  // One engine per thread, seeded once. This is called once per sent message,
  // and constructing and seeding a std::mt19937 from std::random_device on
  // every call is needlessly expensive.
  thread_local std::mt19937 gen{std::random_device{}()};
  // sizeof(charset) counts the trailing NUL, so the last valid index is size - 2.
  std::uniform_int_distribution<size_t> dis(0, sizeof(charset) - 2);
  std::string str(length, '\0');
  for (char& c : str) {
    c = charset[dis(gen)];
  }
  return str;
}
void IOUringSocketHandlerTest::InitializeHandler(int socket_fd) {
handler_ = std::make_unique<IOUringSocketHandler>(socket_fd);
}
// Smoke test: a handler using the fixture's default queue depth can set up its ring.
TEST_F(IOUringSocketHandlerTest, SetupIoUring) {
  const bool uring_supported = IsIouringSupported();
  if (!uring_supported) {
    GTEST_SKIP() << "io_uring not supported. Skipping Test.";
  }
  InitializeHandler();
  EXPECT_TRUE(handler_->SetupIoUring(queue_depth_));
}
// Registering 8 buffers of 4096 bytes succeeds on a freshly set-up ring.
TEST_F(IOUringSocketHandlerTest, AllocateAndRegisterBuffers) {
  if (!IsIouringSupported()) {
    GTEST_SKIP() << "io_uring not supported. Skipping Test.";
  }
  InitializeHandler();
  // Ring setup is a precondition for the buffer call below. Stop here instead
  // of exercising the buffer API on a handler whose setup failed.
  ASSERT_TRUE(handler_->SetupIoUring(queue_depth_));
  EXPECT_TRUE(handler_->AllocateAndRegisterBuffers(8, 4096));
}
// Buffers can be registered, deregistered, and re-registered repeatedly with
// different geometries on the same ring. A non-power-of-two buffer count is
// rejected.
TEST_F(IOUringSocketHandlerTest, MultipleAllocateAndRegisterBuffers) {
  if (!IsIouringSupported()) {
    GTEST_SKIP() << "io_uring not supported. Skipping Test.";
  }
  InitializeHandler();
  // Every step below operates on the ring; without it the rest is meaningless.
  ASSERT_TRUE(handler_->SetupIoUring(queue_depth_));
  EXPECT_TRUE(handler_->AllocateAndRegisterBuffers(4, 4096));
  handler_->DeRegisterBuffers();
  EXPECT_TRUE(handler_->AllocateAndRegisterBuffers(2, 1024 * 1024L));
  handler_->DeRegisterBuffers();
  EXPECT_TRUE(handler_->AllocateAndRegisterBuffers(32, 1024));
  handler_->DeRegisterBuffers();
  // num_buffers should be power of 2
  EXPECT_FALSE(handler_->AllocateAndRegisterBuffers(5, 4096));
}
void IOUringSocketHandlerTest::SendMsg(int sock_send, const bool non_block) {
const TestParam params = GetParam();
for (int i = 0; i < params.numMessages; ++i) {
std::string message = generateRandomString(kMessageSize);
sent_messages.push_back(message);
struct ucred cred;
memset(&cred, 0, sizeof(cred));
cred.pid = getpid();
cred.uid = getuid();
cred.gid = getgid();
struct iovec iov_send;
iov_send.iov_base = const_cast<char*>(message.data());
iov_send.iov_len = kMessageSize;
struct msghdr msg_send;
memset(&msg_send, 0, sizeof(msg_send));
msg_send.msg_iov = &iov_send;
msg_send.msg_iovlen = 1;
char control_buffer_send[CMSG_SPACE(sizeof(cred))];
memset(control_buffer_send, 0, sizeof(control_buffer_send));
msg_send.msg_control = control_buffer_send;
msg_send.msg_controllen = sizeof(control_buffer_send);
struct cmsghdr* cmsg_send = CMSG_FIRSTHDR(&msg_send);
cmsg_send->cmsg_level = SOL_SOCKET;
cmsg_send->cmsg_type = SCM_CREDENTIALS;
cmsg_send->cmsg_len = CMSG_LEN(sizeof(cred));
memcpy(CMSG_DATA(cmsg_send), &cred, sizeof(cred));
int flags = 0;
if (non_block) {
flags = MSG_DONTWAIT;
}
ssize_t sent_bytes;
while (true) {
sent_bytes = sendmsg(sock_send, &msg_send, flags);
if (sent_bytes >= 0) {
break; // Success
}
if (errno == EAGAIN || errno == EWOULDBLOCK) {
// Try again
continue;
} else {
perror("sendmsg failed");
return;
}
}
}
}
void IOUringSocketHandlerTest::ReceiveThreaduring(int sock_recv) {
std::unique_ptr<IOUringSocketHandler> uring_listener;
uring_listener = std::make_unique<IOUringSocketHandler>(sock_recv);
const TestParam params = GetParam();
ASSERT_TRUE(uring_listener->SetupIoUring(params.queue_depth));
uring_listener->AllocateAndRegisterBuffers(
params.queue_depth, kMessageSize);
ASSERT_TRUE(uring_listener->EnqueueMultishotRecvmsg());
long long received_messages = 0;
int index = 0;
while (received_messages < params.numMessages) {
struct ucred* cred = nullptr;
void* this_recv = nullptr;
size_t len = 0;
uring_listener->ReceiveData(&this_recv, len, &cred);
// Release the buffer from here onwards
{
auto scope_guard =
android::base::make_scope_guard([&uring_listener]() -> void {
uring_listener->ReleaseBuffer(); });
if (len <= 0) {
continue;
}
received_messages++;
char* char_ptr = static_cast<char*>(this_recv);
std::string payload_string(char_ptr, len);
std::string orig_string = sent_messages[index];
// Compare payload data
EXPECT_EQ(payload_string, orig_string);
// Verify credentials
EXPECT_EQ(cred->uid, getuid());
EXPECT_EQ(cred->gid, getgid());
EXPECT_EQ(cred->pid, getpid());
index += 1;
}
}
}
// A blocking sender thread pushes GetParam().numMessages datagrams to the bound
// socket, and the io_uring receiver on the test thread verifies every payload
// and its credentials.
TEST_P(IOUringSocketHandlerTest, RecvmsgDataIntegrity) {
  if (!IsIouringSupported()) {
    GTEST_SKIP() << "io_uring not supported. Skipping Test.";
  }
  ASSERT_TRUE(CreateServerSocket());
  int sock_send = socket(AF_UNIX, SOCK_DGRAM, 0);
  ASSERT_GE(sock_send, 0);
  // Close the sending socket on every exit path, including failed ASSERTs.
  auto close_send =
      android::base::make_scope_guard([sock_send]() -> void { close(sock_send); });
  struct sockaddr_un addr_send;
  memset(&addr_send, 0, sizeof(addr_send));
  addr_send.sun_family = AF_UNIX;
  ASSERT_LT(socket_path_.size(), sizeof(addr_send.sun_path));
  strcpy(addr_send.sun_path, socket_path_.c_str());  // Connect to the receiver
  ASSERT_EQ(connect(sock_send, reinterpret_cast<struct sockaddr*>(&addr_send),
                    sizeof(addr_send)),
            0);
  std::thread sender([this, sock_send]() { SendMsg(sock_send, false); });
  ReceiveThreaduring(sock_recv_);
  if (HasFatalFailure()) {
    // The receiver stopped early, so a blocking sender may be waiting for queue
    // space forever. Closing the receiving end is intended to make its pending
    // sendmsg fail so the join below cannot hang.
    close(sock_recv_);
    sock_recv_ = -1;
  }
  sender.join();
  // sock_recv_ and the socket file are released by TearDown().
}
// Same round-trip as RecvmsgDataIntegrity, but the sending socket is O_NONBLOCK
// and SendMsg uses MSG_DONTWAIT, so the sender must cope with would-block
// errors while the receiver drains the queue.
TEST_P(IOUringSocketHandlerTest, RecvmsgDataIntegrityNonBlockingSend) {
  if (!IsIouringSupported()) {
    GTEST_SKIP() << "io_uring not supported. Skipping Test.";
  }
  ASSERT_TRUE(CreateServerSocket());
  int sock_send = socket(AF_UNIX, SOCK_DGRAM, 0);
  ASSERT_GE(sock_send, 0);
  // Close the sending socket on every exit path, including failed ASSERTs.
  auto close_send =
      android::base::make_scope_guard([sock_send]() -> void { close(sock_send); });
  int flags = fcntl(sock_send, F_GETFL, 0);
  ASSERT_NE(flags, -1);
  // Set O_NONBLOCK
  ASSERT_NE(fcntl(sock_send, F_SETFL, flags | O_NONBLOCK), -1);
  struct sockaddr_un addr_send;
  memset(&addr_send, 0, sizeof(addr_send));
  addr_send.sun_family = AF_UNIX;
  ASSERT_LT(socket_path_.size(), sizeof(addr_send.sun_path));
  strcpy(addr_send.sun_path, socket_path_.c_str());  // Connect to the receiver
  ASSERT_EQ(connect(sock_send, reinterpret_cast<struct sockaddr*>(&addr_send),
                    sizeof(addr_send)),
            0);
  std::thread sender([this, sock_send]() { SendMsg(sock_send, true); });
  ReceiveThreaduring(sock_recv_);
  if (HasFatalFailure()) {
    // The receiver stopped early, so the sender could keep retrying would-block
    // sends forever. Closing the receiving end is intended to make its sends
    // fail outright so the join below cannot hang.
    close(sock_recv_);
    sock_recv_ = -1;
  }
  sender.join();
  // sock_recv_ and the socket file are released by TearDown().
}
std::vector<TestParam> GetConfigs() {
std::vector<TestParam> testParams;
std::vector<int> queue_depth = {1, 8, 16, 32, 64, 128, 256, 512};
std::vector<int> num_messages = {1, 100, 250, 500, 1000, 1500, 2000, 5000};
// This will test 64 combinations
for (auto q_depth : queue_depth) {
for (auto n_messages : num_messages) {
TestParam param;
param.queue_depth = q_depth;
param.numMessages = n_messages;
testParams.push_back(std::move(param));
}
}
return testParams;
}
// Run every TEST_P case once per TestParam produced by GetConfigs().
INSTANTIATE_TEST_SUITE_P(Io, IOUringSocketHandlerTest, testing::ValuesIn(GetConfigs()));
int main(int argc, char** argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}