// Copyright 2018 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include <fuchsia/net/c/fidl.h>
#include <lib/zx/socket.h>
#include <lib/zxs/protocol.h>
#include <lib/zxs/zxs.h>
#include <stdlib.h>
#include <string.h>
#include <zircon/assert.h>
#include <zircon/device/ioctl.h>
#include <zircon/syscalls.h>

// Initializes |out_socket| from the Zircon socket handle |socket|.
//
// The handle is stored as-is, and ZXS_FLAG_DATAGRAM is set when the kernel
// reports the socket was created with ZX_SOCKET_DATAGRAM. If querying the
// socket info fails, that status is returned and |out_socket| is not touched.
zx_status_t zxs_socket(zx_handle_t socket, zxs_socket_t* out_socket) {
    zx_info_socket_t socket_info = {};
    const zx_status_t query_status = zx_object_get_info(
        socket, ZX_INFO_SOCKET, &socket_info, sizeof(socket_info), NULL, NULL);
    if (query_status == ZX_OK) {
        out_socket->socket = socket;
        out_socket->flags = 0;
        // Remember the socket mode so send/recv can choose the framing.
        if ((socket_info.options & ZX_SOCKET_DATAGRAM) != 0) {
            out_socket->flags |= ZXS_FLAG_DATAGRAM;
        }
    }
    return query_status;
}

// Closes |socket| through the fuchsia.net SocketControl protocol, then
// releases the local handle.
//
// If the FIDL call itself fails, its status is returned and the handle is
// left open. Otherwise the handle is closed. The remote result code (an
// errno value, per the TODO below) becomes ZX_OK when it is zero and
// ZX_ERR_INTERNAL otherwise.
zx_status_t zxs_close(const zxs_socket_t* socket) {
    const zx_handle_t handle = socket->socket;
    int16_t remote_code;
    const zx_status_t call_status =
        fuchsia_net_SocketControlClose(handle, &remote_code);
    if (call_status != ZX_OK) {
        return call_status;
    }
    zx_handle_close(handle);
    // TODO(tamird): we can't use errno_to_fdio_status because fdio
    //  depends on zxs.
    return (remote_code == 0) ? ZX_OK : ZX_ERR_INTERNAL;
}

// Writes up to |capacity| bytes from |buffer| to the underlying Zircon
// socket with no options. |*out_actual| receives the byte count reported
// by zx_socket_write().
static zx_status_t zxs_write(const zxs_socket_t* socket, const void* buffer,
                             size_t capacity, size_t* out_actual) {
    const zx_handle_t handle = socket->socket;
    const uint32_t no_options = 0;
    return zx_socket_write(handle, no_options, buffer, capacity, out_actual);
}

// Reads up to |capacity| bytes from the underlying Zircon socket into
// |buffer|.
//
// ZX_ERR_PEER_CLOSED and ZX_ERR_BAD_STATE (presumably the read side being
// shut down; confirm against the zx_socket_read docs) are reported as a
// successful zero-byte read, i.e. end-of-stream. All other statuses are
// passed through unchanged.
static zx_status_t zxs_read(const zxs_socket_t* socket, void* buffer,
                            size_t capacity, size_t* out_actual) {
    const zx_status_t read_status =
        zx_socket_read(socket->socket, 0, buffer, capacity, out_actual);
    switch (read_status) {
    case ZX_ERR_PEER_CLOSED:
    case ZX_ERR_BAD_STATE:
        *out_actual = 0u;
        return ZX_OK;
    default:
        return read_status;
    }
}

// Sends the scatter/gather list in |msg| over a stream socket, writing each
// iovec in order with zxs_write().
//
// On return of ZX_OK, |*out_actual| is the total number of bytes accepted.
// Writing stops early, and ZX_OK is still returned, when:
//   * a write is short (the socket is full), or
//   * a write fails after some bytes were already sent.
// This way the caller always learns how much data went out. An error status
// is returned only when nothing at all was sent.
//
// Zero-length iovecs are skipped. POSIX sendmsg allows them, and rejecting
// one partway through the list would throw away the count of bytes already
// written by earlier iovecs.
static zx_status_t zxs_sendmsg_stream(const zxs_socket_t* socket,
                                      const struct msghdr* msg,
                                      size_t* out_actual) {
    size_t total = 0u;
    for (int i = 0; i < msg->msg_iovlen; i++) {
        const struct iovec* iov = &msg->msg_iov[i];
        if (iov->iov_len == 0) {
            continue;
        }
        size_t actual = 0u;
        zx_status_t status = zxs_write(socket, iov->iov_base, iov->iov_len,
                                       &actual);
        if (status != ZX_OK) {
            // Report partial progress rather than the error once data is out.
            if (total > 0) {
                break;
            }
            return status;
        }
        total += actual;
        if (actual != iov->iov_len) {
            // Short write: the socket cannot take more right now.
            break;
        }
    }
    *out_actual = total;
    return ZX_OK;
}

// Sends the iovecs in |msg| as a single datagram.
//
// The payload is gathered into one heap buffer and prefixed with the
// fdio_socket_msg_t header. The header carries the destination address from
// |msg->msg_name| (if any) and zeroed flags. The whole encoded message is
// then written with one zxs_write() call.
//
// Returns:
//   * ZX_ERR_INVALID_ARGS if |msg->msg_namelen| is larger than the header's
//     address field (copying it would overflow the buffer).
//   * ZX_ERR_NO_MEMORY if the encoding buffer cannot be allocated.
//   * Otherwise, the zxs_write() status. On ZX_OK, |*out_actual| is set to
//     the payload size, which excludes the header.
//
// Zero-length iovecs are allowed. They add nothing to the payload, which is
// consistent with msg_iovlen == 0 already producing an empty datagram.
static zx_status_t zxs_sendmsg_dgram(const zxs_socket_t* socket,
                                     const struct msghdr* msg,
                                     size_t* out_actual) {
    size_t total = 0u;
    for (int i = 0; i < msg->msg_iovlen; i++) {
        total += msg->msg_iov[i].iov_len;
    }
    size_t encoded_size = total + FDIO_SOCKET_MSG_HEADER_SIZE;

    // Validate before allocating so the memcpy below can never overflow
    // the fixed-size address field.
    if (msg->msg_name != nullptr &&
        msg->msg_namelen > sizeof(fdio_socket_msg_t::addr)) {
        return ZX_ERR_INVALID_ARGS;
    }

    // TODO: avoid malloc m
    fdio_socket_msg_t* m = static_cast<fdio_socket_msg_t*>(malloc(encoded_size));
    if (m == nullptr) {
        return ZX_ERR_NO_MEMORY;
    }
    if (msg->msg_name != nullptr) {
        memcpy(&m->addr, msg->msg_name, msg->msg_namelen);
        m->addrlen = msg->msg_namelen;
    } else {
        // No address supplied: do not advertise a length for bytes that were
        // never initialized.
        m->addrlen = 0;
    }
    m->flags = 0;

    char* data = m->data;
    for (int i = 0; i < msg->msg_iovlen; i++) {
        const struct iovec* iov = &msg->msg_iov[i];
        if (iov->iov_len == 0) {
            // memcpy with a possibly-NULL base is undefined even for 0 bytes.
            continue;
        }
        memcpy(data, iov->iov_base, iov->iov_len);
        data += iov->iov_len;
    }

    size_t actual = 0u;
    zx_status_t status = zxs_write(socket, m, encoded_size, &actual);
    free(m);
    if (status == ZX_OK) {
        *out_actual = total;
    }
    return status;
}

// Fills the iovecs in |msg|, in order, from a stream socket.
//
// Reading stops at the first short read, since no more data is available
// right now (or end-of-stream was reached). It also stops at the first error
// once some bytes were already received; in that case ZX_OK is still
// returned. An error status is returned only if the very first read fails.
// On ZX_OK, |*out_actual| holds the total bytes received.
static zx_status_t zxs_recvmsg_stream(const zxs_socket_t* socket,
                                      struct msghdr* msg,
                                      size_t* out_actual) {
    struct iovec* const iovs = msg->msg_iov;
    const int iov_count = msg->msg_iovlen;
    size_t received = 0u;

    for (int idx = 0; idx < iov_count; ++idx) {
        size_t chunk = 0u;
        const zx_status_t read_status =
            zxs_read(socket, iovs[idx].iov_base, iovs[idx].iov_len, &chunk);
        if (read_status != ZX_OK) {
            if (received == 0) {
                return read_status;
            }
            break;
        }
        received += chunk;
        if (chunk != iovs[idx].iov_len) {
            break;
        }
    }

    *out_actual = received;
    return ZX_OK;
}

// Receives one datagram and scatters its payload across the iovecs in |msg|.
//
// The encoded message (fdio_socket_msg_t header + payload) is read into a
// temporary heap buffer. The source address is copied into |msg->msg_name|,
// truncated to both the caller's |msg_namelen| and the header's address
// field. |msg->msg_namelen| is then set to the address length the sender
// reported, and |msg->msg_flags| to the header flags. If the payload does
// not fit in the iovecs, MSG_TRUNC is set and only the fitting bytes are
// counted.
//
// The caller's iovec array is not modified, as POSIX recvmsg requires.
// (Previously each iov_len was overwritten with the bytes copied into it.)
// Zero-length iovecs are allowed; they simply receive nothing.
//
// Returns:
//   * ZX_ERR_NO_MEMORY if the temporary buffer cannot be allocated.
//   * Any error from zxs_read().
//   * ZX_ERR_INTERNAL if fewer bytes than a header were read.
//   * Otherwise ZX_OK, with |*out_actual| set to the payload bytes copied.
static zx_status_t zxs_recvmsg_dgram(const zxs_socket_t* socket,
                                     struct msghdr* msg,
                                     size_t* out_actual) {
    size_t capacity = 0u;
    for (int i = 0; i < msg->msg_iovlen; i++) {
        capacity += msg->msg_iov[i].iov_len;
    }
    // Read 1 extra byte to detect if the buffer is too small to fit the whole
    // packet, so we can set MSG_TRUNC flag if necessary.
    size_t encoded_size = FDIO_SOCKET_MSG_HEADER_SIZE + capacity + 1;

    // TODO: avoid malloc
    fdio_socket_msg_t* m = static_cast<fdio_socket_msg_t*>(malloc(encoded_size));
    if (m == nullptr) {
        return ZX_ERR_NO_MEMORY;
    }
    size_t actual = 0u;
    zx_status_t status = zxs_read(socket, m, encoded_size, &actual);
    if (status != ZX_OK) {
        free(m);
        return status;
    }
    if (actual < FDIO_SOCKET_MSG_HEADER_SIZE) {
        free(m);
        return ZX_ERR_INTERNAL;
    }
    size_t payload = actual - FDIO_SOCKET_MSG_HEADER_SIZE;

    if (msg->msg_name != nullptr) {
        // Clamp to the header's address field as well as the caller's buffer,
        // so a bogus addrlen from the peer cannot cause an over-read.
        size_t to_copy = m->addrlen;
        if (to_copy > sizeof(m->addr)) {
            to_copy = sizeof(m->addr);
        }
        if (to_copy > msg->msg_namelen) {
            to_copy = msg->msg_namelen;
        }
        memcpy(msg->msg_name, &m->addr, to_copy);
    }
    msg->msg_namelen = m->addrlen;
    msg->msg_flags = m->flags;

    const char* data = m->data;
    size_t remaining = payload;
    for (int i = 0; i < msg->msg_iovlen && remaining > 0; i++) {
        const struct iovec* iov = &msg->msg_iov[i];
        size_t n = (remaining < iov->iov_len) ? remaining : iov->iov_len;
        if (n > 0) {
            memcpy(iov->iov_base, data, n);
            data += n;
            remaining -= n;
        }
    }

    if (remaining > 0) {
        // Payload (including the sentinel extra byte) did not fit.
        msg->msg_flags |= MSG_TRUNC;
        payload -= remaining;
    }

    free(m);
    *out_actual = payload;
    return ZX_OK;
}

// Sends |capacity| bytes from |buffer| on a connected socket.
//
// Stream sockets write the bytes directly. Datagram sockets go through
// zxs_sendmsg_dgram(), using a single-element msghdr with no destination
// address.
zx_status_t zxs_send(const zxs_socket_t* socket, const void* buffer,
                     size_t capacity, size_t* out_actual) {
    if ((socket->flags & ZXS_FLAG_DATAGRAM) == 0) {
        return zxs_write(socket, buffer, capacity, out_actual);
    }

    struct iovec single = {};
    single.iov_base = const_cast<void*>(buffer);
    single.iov_len = capacity;

    // Zero-initialization leaves name, control data and flags empty.
    struct msghdr header = {};
    header.msg_iov = &single;
    header.msg_iovlen = 1;

    return zxs_sendmsg_dgram(socket, &header, out_actual);
}

// Receives up to |capacity| bytes into |buffer| from a connected socket.
//
// Stream sockets read directly. Datagram sockets go through
// zxs_recvmsg_dgram(), using a single-element msghdr that does not ask for
// the source address.
zx_status_t zxs_recv(const zxs_socket_t* socket, void* buffer,
                     size_t capacity, size_t* out_actual) {
    if ((socket->flags & ZXS_FLAG_DATAGRAM) == 0) {
        return zxs_read(socket, buffer, capacity, out_actual);
    }

    struct iovec single = {};
    single.iov_base = buffer;
    single.iov_len = capacity;

    // Zero-initialization leaves name, control data and flags empty.
    struct msghdr header = {};
    header.msg_iov = &single;
    header.msg_iovlen = 1;

    return zxs_recvmsg_dgram(socket, &header, out_actual);
}

// Sends |capacity| bytes from |buffer| to the address |addr| (of
// |addr_length| bytes). The call is wrapped in a single-element msghdr and
// forwarded to zxs_sendmsg(), which picks stream or datagram handling.
zx_status_t zxs_sendto(const zxs_socket_t* socket, const struct sockaddr* addr,
                       size_t addr_length, const void* buffer, size_t capacity,
                       size_t* out_actual) {
    struct iovec single = {};
    single.iov_base = const_cast<void*>(buffer);
    single.iov_len = capacity;

    // Control data stays empty; msg_flags is ignored on send and stays zero.
    struct msghdr header = {};
    header.msg_name = const_cast<struct sockaddr*>(addr);
    header.msg_namelen = static_cast<socklen_t>(addr_length);
    header.msg_iov = &single;
    header.msg_iovlen = 1;

    return zxs_sendmsg(socket, &header, out_actual);
}

// Receives up to |capacity| bytes into |buffer|. If the sender's address is
// reported, it is written to |addr|, limited to |addr_capacity| bytes.
// Forwards to zxs_recvmsg() with a single-element msghdr.
//
// |*out_addr_actual| is always written with the msghdr's final msg_namelen,
// even when the receive fails. If nothing updated it, it still holds
// |addr_capacity|.
zx_status_t zxs_recvfrom(const zxs_socket_t* socket, struct sockaddr* addr,
                         size_t addr_capacity, size_t* out_addr_actual,
                         void* buffer, size_t capacity, size_t* out_actual) {
    struct iovec single = {};
    single.iov_base = buffer;
    single.iov_len = capacity;

    struct msghdr header = {};
    header.msg_name = addr;
    header.msg_namelen = static_cast<socklen_t>(addr_capacity);
    header.msg_iov = &single;
    header.msg_iovlen = 1;

    const zx_status_t recv_status = zxs_recvmsg(socket, &header, out_actual);
    *out_addr_actual = header.msg_namelen;
    return recv_status;
}

// Sends |msg|, using datagram framing or plain stream writes depending on
// how |socket| was created (see zxs_socket()).
zx_status_t zxs_sendmsg(const zxs_socket_t* socket, const struct msghdr* msg,
                        size_t* out_actual) {
    const bool is_datagram = (socket->flags & ZXS_FLAG_DATAGRAM) != 0;
    return is_datagram ? zxs_sendmsg_dgram(socket, msg, out_actual)
                       : zxs_sendmsg_stream(socket, msg, out_actual);
}
// Receives into |msg|, using datagram decoding or plain stream reads
// depending on how |socket| was created (see zxs_socket()).
zx_status_t zxs_recvmsg(const zxs_socket_t* socket, struct msghdr* msg,
                        size_t* out_actual) {
    const bool is_datagram = (socket->flags & ZXS_FLAG_DATAGRAM) != 0;
    return is_datagram ? zxs_recvmsg_dgram(socket, msg, out_actual)
                       : zxs_recvmsg_stream(socket, msg, out_actual);
}
