blob: 536746e21dd3a2be590a8e0d528c1b5e277dfd87 [file] [log] [blame]
// Copyright 2018 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <fuchsia/net/c/fidl.h>
#include <lib/zx/socket.h>
#include <lib/zxs/protocol.h>
#include <lib/zxs/zxs.h>
#include <stdlib.h>
#include <string.h>
#include <zircon/assert.h>
#include <zircon/device/ioctl.h>
#include <zircon/syscalls.h>
zx_status_t zxs_socket(zx_handle_t socket, zxs_socket_t* out_socket) {
zx_info_socket_t info = {};
zx_status_t status = zx_object_get_info(socket, ZX_INFO_SOCKET, &info,
sizeof(info), NULL, NULL);
if (status != ZX_OK) {
return status;
}
out_socket->socket = socket;
out_socket->flags = 0;
if (info.options & ZX_SOCKET_DATAGRAM) {
out_socket->flags |= ZXS_FLAG_DATAGRAM;
}
return ZX_OK;
}
zx_status_t zxs_close(const zxs_socket_t* socket) {
int16_t out_code;
zx_status_t status = fuchsia_net_SocketControlClose(
socket->socket, &out_code);
if (status != ZX_OK) {
return status;
}
zx_handle_close(socket->socket);
if (out_code) {
// TODO(tamird): we can't use errno_to_fdio_status because fdio
// depends on zxs.
return ZX_ERR_INTERNAL;
}
return ZX_OK;
}
static zx_status_t zxs_write(const zxs_socket_t* socket, const void* buffer,
size_t capacity, size_t* out_actual) {
return zx_socket_write(socket->socket, 0, buffer, capacity, out_actual);
}
static zx_status_t zxs_read(const zxs_socket_t* socket, void* buffer,
size_t capacity, size_t* out_actual) {
zx_status_t status = zx_socket_read(socket->socket, 0, buffer, capacity,
out_actual);
if (status == ZX_ERR_PEER_CLOSED || status == ZX_ERR_BAD_STATE) {
*out_actual = 0u;
return ZX_OK;
}
return status;
}
static zx_status_t zxs_sendmsg_stream(const zxs_socket_t* socket,
const struct msghdr* msg,
size_t* out_actual) {
size_t total = 0u;
for (int i = 0; i < msg->msg_iovlen; i++) {
struct iovec* iov = &msg->msg_iov[i];
if (iov->iov_len <= 0) {
return ZX_ERR_INVALID_ARGS;
}
size_t actual = 0u;
zx_status_t status = zxs_write(socket, iov->iov_base, iov->iov_len,
&actual);
if (status != ZX_OK) {
if (total > 0) {
break;
}
return status;
}
total += actual;
if (actual != iov->iov_len) {
break;
}
}
*out_actual = total;
return ZX_OK;
}
static zx_status_t zxs_sendmsg_dgram(const zxs_socket_t* socket,
const struct msghdr* msg,
size_t* out_actual) {
size_t total = 0u;
for (int i = 0; i < msg->msg_iovlen; i++) {
struct iovec* iov = &msg->msg_iov[i];
if (iov->iov_len <= 0) {
return ZX_ERR_INVALID_ARGS;
}
total += iov->iov_len;
}
size_t encoded_size = total + FDIO_SOCKET_MSG_HEADER_SIZE;
// TODO: avoid malloc m
fdio_socket_msg_t* m = static_cast<fdio_socket_msg_t*>(malloc(encoded_size));
if (msg->msg_name != nullptr) {
// TODO(abarth): Validate msg->msg_namelen against sizeof(m->addr).
memcpy(&m->addr, msg->msg_name, msg->msg_namelen);
}
m->addrlen = msg->msg_namelen;
m->flags = 0;
char* data = m->data;
for (int i = 0; i < msg->msg_iovlen; i++) {
struct iovec* iov = &msg->msg_iov[i];
memcpy(data, iov->iov_base, iov->iov_len);
data += iov->iov_len;
}
size_t actual = 0u;
zx_status_t status = zxs_write(socket, m, encoded_size, &actual);
free(m);
if (status == ZX_OK) {
*out_actual = total;
}
return status;
}
static zx_status_t zxs_recvmsg_stream(const zxs_socket_t* socket,
struct msghdr* msg,
size_t* out_actual) {
size_t total = 0u;
for (int i = 0; i < msg->msg_iovlen; i++) {
struct iovec* iov = &msg->msg_iov[i];
size_t actual = 0u;
zx_status_t status = zxs_read(socket, iov->iov_base, iov->iov_len,
&actual);
if (status != ZX_OK) {
if (total > 0) {
break;
}
return status;
}
total += actual;
if (actual != iov->iov_len) {
break;
}
}
*out_actual = total;
return ZX_OK;
}
static zx_status_t zxs_recvmsg_dgram(const zxs_socket_t* socket,
struct msghdr* msg,
size_t* out_actual) {
// Read 1 extra byte to detect if the buffer is too small to fit the whole
// packet, so we can set MSG_TRUNC flag if necessary.
size_t encoded_size = FDIO_SOCKET_MSG_HEADER_SIZE + 1;
for (int i = 0; i < msg->msg_iovlen; i++) {
struct iovec* iov = &msg->msg_iov[i];
if (iov->iov_len <= 0) {
return ZX_ERR_INVALID_ARGS;
}
encoded_size += iov->iov_len;
}
// TODO: avoid malloc
fdio_socket_msg_t* m = static_cast<fdio_socket_msg_t*>(malloc(encoded_size));
size_t actual = 0u;
zx_status_t status = zxs_read(socket, m, encoded_size, &actual);
if (status != ZX_OK) {
free(m);
return status;
}
if (actual < FDIO_SOCKET_MSG_HEADER_SIZE) {
free(m);
return ZX_ERR_INTERNAL;
}
actual -= FDIO_SOCKET_MSG_HEADER_SIZE;
if (msg->msg_name != nullptr) {
int bytes_to_copy = (msg->msg_namelen < m->addrlen) ? msg->msg_namelen : m->addrlen;
memcpy(msg->msg_name, &m->addr, bytes_to_copy);
}
msg->msg_namelen = m->addrlen;
msg->msg_flags = m->flags;
char* data = m->data;
size_t remaining = actual;
for (int i = 0; i < msg->msg_iovlen; i++) {
struct iovec* iov = &msg->msg_iov[i];
if (remaining == 0) {
iov->iov_len = 0;
} else {
if (remaining < iov->iov_len)
iov->iov_len = remaining;
memcpy(iov->iov_base, data, iov->iov_len);
data += iov->iov_len;
remaining -= iov->iov_len;
}
}
if (remaining > 0) {
msg->msg_flags |= MSG_TRUNC;
actual -= remaining;
}
free(m);
*out_actual = actual;
return ZX_OK;
}
zx_status_t zxs_send(const zxs_socket_t* socket, const void* buffer,
size_t capacity, size_t* out_actual) {
if (socket->flags & ZXS_FLAG_DATAGRAM) {
struct iovec iov;
iov.iov_base = const_cast<void*>(buffer);
iov.iov_len = capacity;
struct msghdr msg;
msg.msg_name = nullptr;
msg.msg_namelen = 0;
msg.msg_iov = &iov;
msg.msg_iovlen = 1;
msg.msg_control = nullptr;
msg.msg_controllen = 0;
msg.msg_flags = 0;
return zxs_sendmsg_dgram(socket, &msg, out_actual);
} else {
return zxs_write(socket, buffer, capacity, out_actual);
}
}
zx_status_t zxs_recv(const zxs_socket_t* socket, void* buffer,
size_t capacity, size_t* out_actual) {
if (socket->flags & ZXS_FLAG_DATAGRAM) {
struct iovec iov;
iov.iov_base = buffer;
iov.iov_len = capacity;
struct msghdr msg;
msg.msg_name = nullptr;
msg.msg_namelen = 0;
msg.msg_iov = &iov;
msg.msg_iovlen = 1;
msg.msg_control = nullptr;
msg.msg_controllen = 0;
msg.msg_flags = 0;
return zxs_recvmsg_dgram(socket, &msg, out_actual);
} else {
return zxs_read(socket, buffer, capacity, out_actual);
}
}
zx_status_t zxs_sendto(const zxs_socket_t* socket, const struct sockaddr* addr,
size_t addr_length, const void* buffer, size_t capacity,
size_t* out_actual) {
struct iovec iov;
iov.iov_base = const_cast<void*>(buffer);
iov.iov_len = capacity;
struct msghdr msg;
msg.msg_name = const_cast<struct sockaddr*>(addr);
msg.msg_namelen = static_cast<socklen_t>(addr_length);
msg.msg_iov = &iov;
msg.msg_iovlen = 1;
msg.msg_control = nullptr;
msg.msg_controllen = 0;
msg.msg_flags = 0; // this field is ignored
return zxs_sendmsg(socket, &msg, out_actual);
}
zx_status_t zxs_recvfrom(const zxs_socket_t* socket, struct sockaddr* addr,
size_t addr_capacity, size_t* out_addr_actual,
void* buffer, size_t capacity, size_t* out_actual) {
struct iovec iov;
iov.iov_base = buffer;
iov.iov_len = capacity;
struct msghdr msg;
msg.msg_name = addr;
msg.msg_namelen = static_cast<socklen_t>(addr_capacity);
msg.msg_iov = &iov;
msg.msg_iovlen = 1;
msg.msg_control = nullptr;
msg.msg_controllen = 0;
msg.msg_flags = 0;
zx_status_t status = zxs_recvmsg(socket, &msg, out_actual);
*out_addr_actual = msg.msg_namelen;
return status;
}
zx_status_t zxs_sendmsg(const zxs_socket_t* socket, const struct msghdr* msg,
size_t* out_actual) {
if (socket->flags & ZXS_FLAG_DATAGRAM) {
return zxs_sendmsg_dgram(socket, msg, out_actual);
} else {
return zxs_sendmsg_stream(socket, msg, out_actual);
}
}
zx_status_t zxs_recvmsg(const zxs_socket_t* socket, struct msghdr* msg,
size_t* out_actual) {
if (socket->flags & ZXS_FLAG_DATAGRAM) {
return zxs_recvmsg_dgram(socket, msg, out_actual);
} else {
return zxs_recvmsg_stream(socket, msg, out_actual);
}
}