#include <fcntl.h>
#include <poll.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>

#include <cerrno>
#include <chrono>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <memory>
#include <random>
#include <string>
#include <thread>
#include <vector>

#include <android-base/file.h>
#include <android-base/logging.h>
#include <android-base/scopeguard.h>

#include <IOUringSocketHandler/IOUringSocketHandler.h>

#include <benchmark/benchmark.h>

// Buffer count passed to IOUringSocketHandler::SetupIoUring() and
// AllocateAndRegisterBuffers() by the io_uring receiver.
#define MAX_BUFFERS 256

// Each sender thread sends NUM_MESSAGES_PER_THREAD (one million) datagrams
// of MESSAGE_SIZE (4 KiB) bytes.
const int MESSAGE_SIZE = 4096; // 4KB
const int NUM_MESSAGES_PER_THREAD = 1000000;

// Registration options shared by every benchmark in this file: a single
// iteration per run, repeated 4 times, reporting aggregates only, measuring
// process CPU time and displaying results in seconds.
//
// Each benchmark is run for every combination of the following arguments:
//
// a: {1, 4, 8, 16} -> number of sender threads (state.range(0))
// b: {0, 1}        -> sender mode (state.range(1)):
//                     0 = non-blocking sender, 1 = blocking sender
//
// The macro body already ends with ';', so it is used as
// `BENCHMARK(fn)->BENCH_OPTIONS` with no trailing semicolon.
#define BENCH_OPTIONS                 \
  MeasureProcessCPUTime()             \
      ->Unit(benchmark::kSecond) \
      ->Iterations(1)                \
      ->Repetitions(4)                \
      ->ReportAggregatesOnly(true) \
      ->ArgsProduct({{1, 4, 8, 16}, {0, 1}});

// Returns a string of `length` characters, each picked uniformly at random
// from the 62 ASCII letters and digits. The generator is re-seeded from
// std::random_device on every call.
std::string generateRandomString(size_t length) {
    static const char kAlphabet[] =
        "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";

    std::random_device seed_source;
    std::mt19937 engine(seed_source());
    // sizeof() counts the terminating NUL, hence "- 2" for the last valid index.
    std::uniform_int_distribution<> pick(0, sizeof(kAlphabet) - 2);

    std::string result(length, '\0');
    for (char& ch : result) {
        ch = kAlphabet[pick(engine)];
    }
    return result;
}

// Labels the run as "<N>-SendThreads/<synchronous|asynchronous>".
//
// range(0) is the number of sender threads. range(1) is the sender mode:
// a non-zero value means blocking ("synchronous") senders and zero means
// non-blocking ("asynchronous") senders, which is how SocketBenchMark() and
// sendThread() interpret the same argument.
static void SetLabel(benchmark::State& state) {
    const std::string num_senders = std::to_string(state.range(0));
    const bool blocking_sender = state.range(1) != 0;
    const std::string type = blocking_sender ? "synchronous" : "asynchronous";
    state.SetLabel(num_senders + "-SendThreads" + "/" + type);
}

// Sender thread body: sends NUM_MESSAGES_PER_THREAD datagrams of
// MESSAGE_SIZE random characters on `sock_send`, each carrying this process's
// pid/uid/gid as SCM_CREDENTIALS ancillary data.
//
// sock_send:   socket to send on; it is not closed here.
// sync_sender: non-zero -> plain sendmsg(); zero -> sendmsg() with
//              MSG_DONTWAIT, retried immediately whenever it would block.
//
// On any other sendmsg() failure the error is reported with perror() and the
// thread returns early, leaving the remaining messages unsent.
void sendThread(int sock_send, int sync_sender) {
    std::string message = generateRandomString(MESSAGE_SIZE);

    struct ucred cred;
    memset(&cred, 0, sizeof(cred));
    cred.pid = getpid();
    cred.uid = getuid();
    cred.gid = getgid();

    struct iovec iov_send;
    iov_send.iov_base = &message[0];
    iov_send.iov_len = MESSAGE_SIZE;

    struct msghdr msg_send;
    memset(&msg_send, 0, sizeof(msg_send));
    msg_send.msg_iov = &iov_send;
    msg_send.msg_iovlen = 1;

    char control_buffer_send[CMSG_SPACE(sizeof(cred))];
    memset(control_buffer_send, 0, sizeof(control_buffer_send));
    msg_send.msg_control = control_buffer_send;
    msg_send.msg_controllen = sizeof(control_buffer_send);

    struct cmsghdr* cmsg_send = CMSG_FIRSTHDR(&msg_send);
    cmsg_send->cmsg_level = SOL_SOCKET;
    cmsg_send->cmsg_type = SCM_CREDENTIALS;
    cmsg_send->cmsg_len = CMSG_LEN(sizeof(cred));
    memcpy(CMSG_DATA(cmsg_send), &cred, sizeof(cred));

    const int flags = sync_sender ? 0 : MSG_DONTWAIT;
    for (int i = 0; i < NUM_MESSAGES_PER_THREAD; ++i) {
        // Keep retrying the same message until it is accepted.
        while (sendmsg(sock_send, &msg_send, flags) < 0) {
            // EAGAIN/EWOULDBLOCK: the non-blocking send would have blocked.
            // EINTR: the call was interrupted by a signal; nothing was sent.
            if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
                continue;
            }
            perror("sendmsg failed");
            return;
        }
    }
    LOG(DEBUG) << "sendThread exiting";
}

// Receiver that uses io_uring through IOUringSocketHandler (multishot
// recvmsg). Loops until num_threads * NUM_MESSAGES_PER_THREAD non-empty
// messages have been received on `sock_recv`.
//
// sock_recv:            socket handed to the IOUringSocketHandler.
// num_threads:          number of sender threads feeding the socket.
// total_bytes_received: incremented by the length of every non-empty message.
// average_latency:      set to the total time (usec) spent in ReceiveData()
//                       divided by the number of messages received, or 0 if
//                       no message was received.
//
// Returns false (leaving both outputs untouched) when SetupIoUring() or
// EnqueueMultishotRecvmsg() fails; true once all messages were received.
bool receiveThreaduring(int sock_recv, int num_threads, long long& total_bytes_received,
                        double& average_latency) {
    std::unique_ptr<IOUringSocketHandler> listener =
        std::make_unique<IOUringSocketHandler>(sock_recv);
    if (!listener->SetupIoUring(MAX_BUFFERS)) {
        LOG(ERROR) << "SetupIoUring failed";
        return false;
    }
    // NOTE(review): the result of AllocateAndRegisterBuffers() is not checked
    // here -- confirm whether it can fail and should abort the run.
    listener->AllocateAndRegisterBuffers(MAX_BUFFERS, MESSAGE_SIZE);

    if (!listener->EnqueueMultishotRecvmsg()) {
        LOG(ERROR) << "EnqueueMultishotRecvmsg failed";
        return false;
    }

    // Computed in 64-bit so the product cannot overflow int.
    const long long expected_messages =
        static_cast<long long>(num_threads) * NUM_MESSAGES_PER_THREAD;
    long long received_messages = 0;
    long long total_latency = 0;

    while (received_messages < expected_messages) {
        struct ucred* cred = nullptr;
        void* this_recv = nullptr;
        size_t len = 0;
        const auto receive_time = std::chrono::high_resolution_clock::now();
        listener->ReceiveData(&this_recv, len, &cred);
        // ReleaseBuffer() runs when this guard leaves scope, i.e. at the end
        // of every iteration including the early `continue` below.
        auto release_buffer = android::base::make_scope_guard(
            [&listener]() -> void { listener->ReleaseBuffer(); });
        const auto end_receive_time = std::chrono::high_resolution_clock::now();
        total_latency += std::chrono::duration_cast<std::chrono::microseconds>(
                             end_receive_time - receive_time).count();

        if (len == 0) {
            // Nothing usable was delivered; do not count it as a message.
            LOG(DEBUG) << "Received zero length for: " << received_messages;
            continue;
        }
        ++received_messages;
        total_bytes_received += len;
    }

    average_latency = received_messages > 0
                          ? static_cast<double>(total_latency) / received_messages
                          : 0.0;
    return true;
}

// Receiver that uses poll() + recvmsg(). Loops until
// num_threads * NUM_MESSAGES_PER_THREAD messages have been received on
// `sock_recv`, or until poll()/recvmsg() fails with an error other than EINTR
// (reported with perror()).
//
// sock_recv:            socket to receive from.
// num_threads:          number of sender threads feeding the socket.
// total_bytes_received: incremented by the size of every received message.
// average_latency:      set to the mean time (usec) from just before poll()
//                       to just after recvmsg() per received message, or 0 if
//                       no message was received.
void receiveThread(int sock_recv, int num_threads, long long& total_bytes_received,
                   double& average_latency) {
    char recv_buffer[MESSAGE_SIZE];

    struct iovec iov_recv;
    iov_recv.iov_base = recv_buffer;
    iov_recv.iov_len = sizeof(recv_buffer);

    char control_buffer_recv[CMSG_SPACE(sizeof(struct ucred))];
    memset(control_buffer_recv, 0, sizeof(control_buffer_recv));

    struct msghdr msg_recv;
    memset(&msg_recv, 0, sizeof(msg_recv));
    msg_recv.msg_iov = &iov_recv;
    msg_recv.msg_iovlen = 1;
    msg_recv.msg_control = control_buffer_recv;

    struct pollfd pfd;
    pfd.fd = sock_recv;
    pfd.events = POLLIN;
    pfd.revents = 0;

    // Computed in 64-bit so the product cannot overflow int.
    const long long expected_messages =
        static_cast<long long>(num_threads) * NUM_MESSAGES_PER_THREAD;
    long long received_messages = 0;
    long long total_latency = 0;

    while (received_messages < expected_messages) {
        const auto receive_time = std::chrono::high_resolution_clock::now();
        const int ready = poll(&pfd, 1, -1);
        if (ready < 0) {
            if (errno == EINTR) {
                continue;
            }
            // A persistent poll() error would otherwise spin here forever.
            perror("poll failed");
            break;
        }
        if (ready == 0) {
            continue;
        }

        // msg_controllen is value-result: recvmsg() overwrites it with the
        // amount of control data returned, so offer the full buffer each time.
        msg_recv.msg_controllen = sizeof(control_buffer_recv);
        const ssize_t received_bytes = recvmsg(sock_recv, &msg_recv, 0);
        if (received_bytes < 0) {
            if (errno == EINTR) {
                continue;
            }
            perror("recvmsg failed");
            break;
        }

        const auto end_receive_time = std::chrono::high_resolution_clock::now();
        total_latency += std::chrono::duration_cast<std::chrono::microseconds>(
                             end_receive_time - receive_time).count();

        ++received_messages;
        total_bytes_received += received_bytes;
    }

    // Avoid 0/0 when receiving stopped before the first message arrived.
    average_latency = received_messages > 0
                          ? static_cast<double>(total_latency) / received_messages
                          : 0.0;
}

static int CreateServerSocket(std::string& path) {
    int sock_recv = socket(AF_UNIX, SOCK_DGRAM, 0);
    if (sock_recv < 0) {
        PLOG(ERROR) << "socket failed";
        return -1;
    }

    std::string tmp_path = android::base::GetExecutableDirectory();
    std::string socket_path = tmp_path + "/temp.sock";
    struct sockaddr_un addr_recv;
    memset(&addr_recv, 0, sizeof(addr_recv));
    addr_recv.sun_family = AF_UNIX;
    strcpy(addr_recv.sun_path, socket_path.c_str());

    unlink(socket_path.c_str()); // Remove existing socket file if any

    if (bind(sock_recv, (struct sockaddr*)&addr_recv, sizeof(addr_recv)) < 0) {
        PLOG(ERROR) << "bind failed";
        close(sock_recv);
        return -1;
    }

    path = socket_path;
    return sock_recv;
}

static void SocketBenchMark(benchmark::State& state, const bool io_uring) {
  state.PauseTiming();
  while (state.KeepRunning()) {
    std::string socket_path;
    int sock_recv = CreateServerSocket(socket_path);
    if (sock_recv < 0) {
        LOG(ERROR) << "CreateServerSocket failed";
        return;
    }

    const size_t num_sender_threads = state.range(0);
    const size_t sync_sender = state.range(1);
    std::vector<int> sender_sockets(num_sender_threads);
    // Sender socket setup (for each thread)
    for (int i = 0; i < num_sender_threads; ++i) {
        int sock_send = socket(AF_UNIX, SOCK_DGRAM, 0);
        if (sock_send < 0) {
            perror("socket failed");
            return;
        }

        if (!sync_sender) {
          // Set non-blocking for the sender socket
          int flags = fcntl(sock_send, F_GETFL, 0);
          if (flags == -1) {
            perror("fcntl F_GETFL failed");
            close(sock_send);
            for (int j = 0; j < i; ++j) { // Close previously opened sockets
              close(sender_sockets[j]);
            }
            close(sock_recv);
            return;
          }
          if (fcntl(sock_send, F_SETFL, flags | O_NONBLOCK) == -1) {
            perror("fcntl F_SETFL failed");
            close(sock_send);
            for (int j = 0; j < i; ++j) { // Close previously opened sockets
              close(sender_sockets[j]);
            }
            close(sock_recv);
            return;
          }
        }

        struct sockaddr_un addr_send;
        memset(&addr_send, 0, sizeof(addr_send));
        addr_send.sun_family = AF_UNIX;
        strcpy(addr_send.sun_path, socket_path.c_str()); // Connect to the receiver

        if (connect(sock_send, (struct sockaddr*)&addr_send, sizeof(addr_send)) < 0) {
            perror("connect failed");
            close(sock_send);
            for (int j = 0; j < i; ++j) { // Close previously opened sockets
                close(sender_sockets[j]);
            }
            close(sock_recv);
            return;
        }

        sender_sockets[i] = sock_send;
    }

    std::vector<std::thread> send_threads;
    for (int i = 0; i < num_sender_threads; ++i) {
        send_threads.emplace_back(sendThread, sender_sockets[i], sync_sender);
    }

    long long total_bytes_received = 0;
    double average_latency = 0.0;

    // Reset counters for each benchmark iteration
    total_bytes_received = 0;
    average_latency = 0.0;

    state.ResumeTiming();
    if (io_uring) {
        receiveThreaduring(sock_recv, num_sender_threads,
                           std::ref(total_bytes_received), std::ref(average_latency));
    } else {
        receiveThread(sock_recv, num_sender_threads,
                      std::ref(total_bytes_received), std::ref(average_latency));
    }
    state.PauseTiming();

    for (auto& thread : send_threads) {
        thread.join();
    }

    state.counters["Total_Data"] = total_bytes_received;
    state.counters["Latency(usec)"] = average_latency;
    state.SetBytesProcessed(total_bytes_received);
    state.SetItemsProcessed(num_sender_threads * NUM_MESSAGES_PER_THREAD);

    // Cleanup
    close(sock_recv);
    unlink(socket_path.c_str()); // Remove the socket file

    for (int sock : sender_sockets) {
        close(sock);
    }
  }
  SetLabel(state);
}

// Benchmark for the io_uring receive path: runs SocketBenchMark() with
// io_uring = true. Registered with the shared BENCH_OPTIONS.
static void BM_ReceiveIOUring(benchmark::State& state) {
    SocketBenchMark(state, true);
}
BENCHMARK(BM_ReceiveIOUring)->BENCH_OPTIONS

// Benchmark for the poll()+recvmsg() receive path: runs SocketBenchMark()
// with io_uring = false. Registered with the shared BENCH_OPTIONS.
static void BM_ReceiveSync(benchmark::State& state) {
    SocketBenchMark(state, false);
}
BENCHMARK(BM_ReceiveSync)->BENCH_OPTIONS

// Entry point: sends android-base logging to stderr, lets Google Benchmark
// consume its command-line flags and runs the selected benchmarks.
// Returns 1 when an argument is left unrecognized, 0 otherwise.
int main(int argc, char** argv) {
    android::base::InitLogging(argv, &android::base::StderrLogger);
    benchmark::Initialize(&argc, argv);
    // Reject mistyped flags up front instead of silently ignoring them and
    // starting a long benchmark run with unintended settings.
    if (benchmark::ReportUnrecognizedArguments(argc, argv)) {
        return 1;
    }
    benchmark::RunSpecifiedBenchmarks();
    return 0;
}
