blob: 64b2566cc36c50ccfd821659d2dbeabd49fca691 [file] [log] [blame]
// Copyright 2018 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <fuchsia/net/c/fidl.h>
#include <lib/zx/socket.h>
#include <lib/zxs/protocol.h>
#include <lib/zxs/zxs.h>
#include <stdlib.h>
#include <string.h>
#include <zircon/assert.h>
#include <zircon/device/ioctl.h>
#include <zircon/syscalls.h>
// Returns true when |msg| is acceptable on the control plane: its payload
// length fits within ZXSIO_PAYLOAD_SZ and it carries no handles.
static bool is_rio_message_valid(zxsio_msg_t* msg) {
    const bool payload_fits = msg->datalen <= ZXSIO_PAYLOAD_SZ;
    const bool carries_no_handles = msg->hcount <= 0;
    return payload_fits && carries_no_handles;
}
// Returns true when a reply of |size| bytes read into |msg| is well formed:
// it contains at least a full header, its datalen equals the number of bytes
// received beyond that header, and it passes is_rio_message_valid().
static bool is_rio_message_reply_valid(zxsio_msg_t* msg, size_t size) {
    // Check the size first so the subtraction below cannot underflow.
    if (size < ZXSIO_HDR_SZ) {
        return false;
    }
    if (msg->datalen != size - ZXSIO_HDR_SZ) {
        return false;
    }
    return is_rio_message_valid(msg);
}
// Writes |msg| (header plus |msg->datalen| payload bytes) to the control plane
// of |socket|, blocking until the control plane becomes writable if needed.
//
// Returns the number of bytes written on success. Returns
// ZX_ERR_NOT_SUPPORTED when the write fails with ZX_ERR_BAD_STATE (the socket
// has no control plane), ZX_ERR_PEER_CLOSED if the peer closes while waiting,
// ZX_ERR_INTERNAL if the wait returns without either awaited signal, or any
// other negative status reported by the kernel calls.
static ssize_t zxsio_write_control(zx_handle_t socket, zxsio_msg_t* msg) {
    while (true) {
        size_t written = ZXSIO_HDR_SZ + msg->datalen;
        zx_status_t status =
            zx_socket_write(socket, ZX_SOCKET_CONTROL, msg, written, &written);
        if (status == ZX_OK) {
            return static_cast<ssize_t>(written);
        }
        if (status == ZX_ERR_BAD_STATE) {
            // If the socket has no control plane then control messages are
            // not supported.
            return ZX_ERR_NOT_SUPPORTED;
        }
        if (status != ZX_ERR_SHOULD_WAIT) {
            return status;
        }

        zx_signals_t observed;
        status = zx_object_wait_one(socket,
                                    ZX_SOCKET_CONTROL_WRITABLE | ZX_SOCKET_PEER_CLOSED,
                                    ZX_TIME_INFINITE, &observed);
        if (status < 0) {
            return status;
        }
        if (observed & ZX_SOCKET_PEER_CLOSED) {
            return ZX_ERR_PEER_CLOSED;
        }
        if (!(observed & ZX_SOCKET_CONTROL_WRITABLE)) {
            // Neither awaited signal is set; should be impossible.
            return ZX_ERR_INTERNAL;
        }
        // Writable again: retry the write.
    }
}
// Reads up to |len| bytes from the control plane of |socket| into |data|,
// blocking until a control message is readable if needed.
//
// Returns the number of bytes read. Returns 0 when |len| is 0, when the read
// reports ZX_ERR_PEER_CLOSED or ZX_ERR_BAD_STATE, or when the peer closes
// while waiting. Returns ZX_ERR_INTERNAL if the wait returns without either
// awaited signal, or any other negative status reported by the kernel calls.
static ssize_t zxsio_read_control(zx_handle_t socket, void* data, size_t len) {
    // TODO: let the generic read() to do this loop
    while (true) {
        size_t bytes_read;
        zx_status_t status =
            zx_socket_read(socket, ZX_SOCKET_CONTROL, data, len, &bytes_read);
        if (status == ZX_OK) {
            // zx_socket_read() sets *actual to the number of bytes in the
            // buffer when data is NULL and len is 0. read() should return 0
            // in that case.
            return (len == 0) ? 0 : static_cast<ssize_t>(bytes_read);
        }
        if (status == ZX_ERR_PEER_CLOSED || status == ZX_ERR_BAD_STATE) {
            return 0;
        }
        if (status != ZX_ERR_SHOULD_WAIT) {
            return status;
        }

        zx_signals_t observed;
        status = zx_object_wait_one(socket,
                                    ZX_SOCKET_CONTROL_READABLE | ZX_SOCKET_PEER_CLOSED,
                                    ZX_TIME_INFINITE, &observed);
        if (status < 0) {
            return status;
        }
        // Readability wins over peer-closed so pending data is drained first.
        if (observed & ZX_SOCKET_CONTROL_READABLE) {
            continue;
        }
        if (observed & ZX_SOCKET_PEER_CLOSED) {
            return 0;
        }
        // Neither awaited signal is set; should be impossible.
        return ZX_ERR_INTERNAL;
    }
}
// Performs one request/reply round trip on the control plane of |socket|.
//
// |msg| is sent as the request and then overwritten by the reply. Returns
// ZX_ERR_INVALID_ARGS if the request is malformed, a negative transport
// status if writing or reading fails, ZX_ERR_IO if the reply is malformed or
// answers a different op, and otherwise the reply's |arg| field.
static zx_status_t zxsio_txn(zx_handle_t socket, zxsio_msg_t* msg) {
    if (!is_rio_message_valid(msg)) {
        return ZX_ERR_INVALID_ARGS;
    }

    const ssize_t written = zxsio_write_control(socket, msg);
    if (written < 0) {
        return static_cast<zx_status_t>(written);
    }

    // Remember the op before the reply overwrites |msg|.
    const uint32_t sent_op = msg->op;
    const ssize_t received = zxsio_read_control(socket, msg, sizeof(*msg));
    if (received < 0) {
        return static_cast<zx_status_t>(received);
    }

    // Reject protocol errors: a malformed reply or one for another op.
    const bool reply_ok =
        is_rio_message_reply_valid(msg, static_cast<size_t>(received)) &&
        msg->op == sent_op;
    if (!reply_ok) {
        return ZX_ERR_IO;
    }
    return msg->arg;
}
// Issues a single control-plane operation |op| on |socket|.
//
// Sends |length| bytes from |buffer| as the payload, with |maxreply| in
// |arg| and |off| in |arg2.off|. On success, copies the reply payload back
// into |buffer|; the reply may be at most |maxreply| bytes, so |buffer| must
// hold at least that much. Returns ZX_ERR_INVALID_ARGS if |length| or
// |maxreply| exceeds ZXSIO_PAYLOAD_SZ, ZX_ERR_NOT_SUPPORTED for ops outside
// the allowed set, ZX_ERR_IO if the reply is longer than |maxreply|, a
// negative status from the transaction, or otherwise the transaction status.
static zx_status_t zxsio_op(zx_handle_t socket, uint32_t op, int64_t off,
                            uint32_t maxreply, void* buffer, size_t length) {
    // Both the request and the largest acceptable reply must fit in one
    // control message.
    if (length > ZXSIO_PAYLOAD_SZ || maxreply > ZXSIO_PAYLOAD_SZ) {
        return ZX_ERR_INVALID_ARGS;
    }

    bool supported;
    switch (op) {
    case ZXSIO_GETSOCKNAME:
    case ZXSIO_GETPEERNAME:
    case ZXSIO_GETSOCKOPT:
    case ZXSIO_SETSOCKOPT:
    case ZXSIO_CONNECT:
    case ZXSIO_BIND:
    case ZXSIO_LISTEN:
        supported = true;
        break;
    default:
        supported = false;
        break;
    }
    if (!supported) {
        return ZX_ERR_NOT_SUPPORTED;
    }

    zxsio_msg_t request;
    // Only the header is zeroed; payload bytes beyond |length| are never
    // written to the socket because datalen bounds the transfer.
    memset(&request, 0, ZXSIO_HDR_SZ);
    request.op = op;
    request.arg = maxreply;
    request.arg2.off = off;
    request.datalen = static_cast<uint32_t>(length);
    if (buffer != nullptr && length != 0) {
        memcpy(request.data, buffer, length);
    }

    const zx_status_t status = zxsio_txn(socket, &request);
    if (status < 0) {
        return status;
    }
    // The server must not reply with more payload than the caller allowed.
    if (request.datalen > maxreply) {
        return ZX_ERR_IO;
    }
    if (buffer != nullptr && request.datalen != 0) {
        memcpy(buffer, request.data, request.datalen);
    }
    return status;
}
// Opens a new socket through |socket_provider| and applies |options| to it.
//
// Sockets of type fuchsia_net_SocketType_dgram are tagged with
// ZXS_FLAG_DATAGRAM. Returns ZX_ERR_IO if the FIDL call itself fails, the
// provider's status if it reports an error, or the first setsockopt failure
// (the new socket is closed in that case). On success stores the socket in
// |*out_socket| and returns ZX_OK.
zx_status_t zxs_socket(zx_handle_t socket_provider,
                       fuchsia_net_SocketDomain domain,
                       fuchsia_net_SocketType type,
                       fuchsia_net_SocketProtocol protocol,
                       const zxs_option_t* options,
                       size_t options_count,
                       zxs_socket_t* out_socket) {
    zxs_socket_t created = {};
    zx_status_t provider_status = ZX_OK;
    const zx_status_t io_status = fuchsia_net_LegacySocketProviderOpenSocket(
        socket_provider, domain, type, protocol, &created.socket, &provider_status);
    if (io_status != ZX_OK) {
        return ZX_ERR_IO;
    }
    if (provider_status != ZX_OK) {
        return provider_status;
    }

    if (type == fuchsia_net_SocketType_dgram) {
        created.flags |= ZXS_FLAG_DATAGRAM;
    }

    const zx_status_t opt_status = zxs_setsockopts(&created, options, options_count);
    if (opt_status != ZX_OK) {
        zxs_close(&created);
        return opt_status;
    }

    *out_socket = created;
    return ZX_OK;
}
// Sends a ZXSIO_CLOSE request over the control plane, then closes the socket
// handle whether or not the request succeeded. Returns the request's status.
zx_status_t zxs_close(const zxs_socket_t* socket) {
    zxsio_msg_t request;
    memset(&request, 0, ZXSIO_HDR_SZ);
    request.op = ZXSIO_CLOSE;

    const zx_status_t result = zxsio_txn(socket->socket, &request);
    zx_handle_close(socket->socket);
    return result;
}
// Connects |socket| to the |addr_length|-byte address |addr| via ZXSIO_CONNECT.
zx_status_t zxs_connect(const zxs_socket_t* socket, const struct sockaddr* addr,
                        size_t addr_length) {
    // zxsio_op() takes a mutable buffer because it may copy a reply into it.
    // With maxreply == 0 any non-empty reply is rejected before that copy, so
    // |addr| is never written.
    void* payload = const_cast<struct sockaddr*>(addr);
    return zxsio_op(socket->socket, ZXSIO_CONNECT, 0, 0, payload, addr_length);
}
// Binds |socket| to the |addr_length|-byte address |addr| via ZXSIO_BIND.
zx_status_t zxs_bind(const zxs_socket_t* socket, const struct sockaddr* addr,
                     size_t addr_length) {
    // zxsio_op() takes a mutable buffer because it may copy a reply into it.
    // With maxreply == 0 any non-empty reply is rejected before that copy, so
    // |addr| is never written.
    void* payload = const_cast<struct sockaddr*>(addr);
    return zxsio_op(socket->socket, ZXSIO_BIND, 0, 0, payload, addr_length);
}
// Marks |socket| as listening via ZXSIO_LISTEN, sending |backlog| as the raw
// bytes of the request payload.
zx_status_t zxs_listen(const zxs_socket_t* socket, uint32_t backlog) {
    uint32_t payload = backlog;
    return zxsio_op(socket->socket, ZXSIO_LISTEN, 0, 0, &payload, sizeof(payload));
}
// Accepts a pending connection on |socket| and reports the peer address.
//
// The peer address is written to |addr| (up to |addr_capacity| bytes) with
// its full length in |*out_addr_actual|. The new socket is returned in
// |*out_socket| with no flags set. If the peer address cannot be queried, the
// accepted handle is closed and that error is returned.
zx_status_t zxs_accept(const zxs_socket_t* socket, struct sockaddr* addr,
                       size_t addr_capacity, size_t* out_addr_actual,
                       zxs_socket_t* out_socket) {
    zx::socket accepted;
    zx_status_t status = zx_socket_accept(socket->socket,
                                          accepted.reset_and_get_address());
    if (status != ZX_OK) {
        return status;
    }

    // |accepted| keeps ownership until everything has succeeded, so any early
    // return below closes the new handle.
    zxs_socket_t result = {};
    result.socket = accepted.get();
    result.flags = 0u;
    status = zxs_getpeername(&result, addr, addr_capacity, out_addr_actual);
    if (status != ZX_OK) {
        return status;
    }

    result.socket = accepted.release();
    *out_socket = result;
    return ZX_OK;
}
// Queries the socket's own address via ZXSIO_GETSOCKNAME.
//
// Copies at most |capacity| bytes of the address into |addr| and stores the
// address length reported by the server in |*out_actual|. That length may
// exceed |capacity|, which lets callers detect truncation. Returns the status
// of the request; |addr| and |*out_actual| are untouched on failure.
zx_status_t zxs_getsockname(const zxs_socket_t* socket, struct sockaddr* addr,
                            size_t capacity, size_t* out_actual) {
    zxrio_sockaddr_reply_t reply = {};
    zx_status_t status = zxsio_op(socket->socket, ZXSIO_GETSOCKNAME, 0,
                                  sizeof(zxrio_sockaddr_reply_t), &reply,
                                  sizeof(reply));
    if (status != ZX_OK) {
        return status;
    }
    *out_actual = reply.len;
    // reply.len comes from the server. Bound the copy by the addr field as
    // well as by |capacity|, so a bogus length cannot read past |reply|.
    size_t copy_length = reply.len;
    if (copy_length > capacity) {
        copy_length = capacity;
    }
    if (copy_length > sizeof(reply.addr)) {
        copy_length = sizeof(reply.addr);
    }
    // Skip the copy when there is nothing to copy: memcpy with a NULL |addr|
    // is undefined even for zero bytes.
    if (copy_length > 0) {
        memcpy(addr, &reply.addr, copy_length);
    }
    return status;
}
// Queries the address of the socket's peer via ZXSIO_GETPEERNAME.
//
// Copies at most |capacity| bytes of the address into |addr| and stores the
// address length reported by the server in |*out_actual|. That length may
// exceed |capacity|, which lets callers detect truncation. Returns the status
// of the request; |addr| and |*out_actual| are untouched on failure.
zx_status_t zxs_getpeername(const zxs_socket_t* socket, struct sockaddr* addr,
                            size_t capacity, size_t* out_actual) {
    zxrio_sockaddr_reply_t reply = {};
    zx_status_t status = zxsio_op(socket->socket, ZXSIO_GETPEERNAME, 0,
                                  sizeof(zxrio_sockaddr_reply_t), &reply,
                                  sizeof(reply));
    if (status != ZX_OK) {
        return status;
    }
    *out_actual = reply.len;
    // reply.len comes from the server. Bound the copy by the addr field as
    // well as by |capacity|, so a bogus length cannot read past |reply|.
    size_t copy_length = reply.len;
    if (copy_length > capacity) {
        copy_length = capacity;
    }
    if (copy_length > sizeof(reply.addr)) {
        copy_length = sizeof(reply.addr);
    }
    // Skip the copy when there is nothing to copy: memcpy with a NULL |addr|
    // is undefined even for zero bytes.
    if (copy_length > 0) {
        memcpy(addr, &reply.addr, copy_length);
    }
    return status;
}
// Reads socket option |name| at |level| via ZXSIO_GETSOCKOPT.
//
// Copies at most |capacity| bytes of the option value into |buffer| and
// stores the option length reported by the server in |*out_actual|. Returns
// a negative status if the request fails, otherwise ZX_OK.
zx_status_t zxs_getsockopt(const zxs_socket_t* socket, int32_t level,
                           int32_t name, void* buffer, size_t capacity,
                           size_t* out_actual) {
    zxrio_sockopt_req_reply_t req_reply;
    memset(&req_reply, 0, sizeof(req_reply));
    req_reply.level = level;
    req_reply.optname = name;
    zx_status_t status = zxsio_op(socket->socket, ZXSIO_GETSOCKOPT, 0,
                                  sizeof(req_reply), &req_reply,
                                  sizeof(req_reply));
    if (status < 0) {
        return status;
    }
    // optlen comes from the server. Bound the copy by the optval field as
    // well as by |capacity|, so a bogus length cannot read past |req_reply|.
    size_t actual = req_reply.optlen;
    if (actual > capacity) {
        actual = capacity;
    }
    if (actual > sizeof(req_reply.optval)) {
        actual = sizeof(req_reply.optval);
    }
    // memcpy with a NULL |buffer| is undefined even for zero bytes.
    if (actual > 0) {
        memcpy(buffer, req_reply.optval, actual);
    }
    // Notice that |*out_actual| could be larger than |capacity| if the server
    // misbehaves. It would be safer to set |*out_actual| to |actual|.
    *out_actual = req_reply.optlen;
    return ZX_OK;
}
// Applies |count| socket options to |socket| in order via ZXSIO_SETSOCKOPT.
//
// Stops at the first failure and returns it. Returns ZX_ERR_INVALID_ARGS if
// an option value is larger than the request's optval field. Options applied
// before a failure are not rolled back.
zx_status_t zxs_setsockopts(const zxs_socket_t* socket,
                            const zxs_option_t* options,
                            size_t count) {
    for (size_t idx = 0u; idx < count; ++idx) {
        const zxs_option_t& option = options[idx];
        const size_t value_length = option.length;

        zxrio_sockopt_req_reply_t request;
        if (value_length > sizeof(request.optval)) {
            return ZX_ERR_INVALID_ARGS;
        }
        memset(&request, 0, sizeof(request));
        request.level = option.level;
        request.optname = option.name;
        memcpy(request.optval, option.value, value_length);
        request.optlen = static_cast<socklen_t>(value_length);

        const zx_status_t status = zxsio_op(socket->socket, ZXSIO_SETSOCKOPT, 0, 0,
                                            &request, sizeof(request));
        if (status != ZX_OK) {
            return status;
        }
    }
    return ZX_OK;
}
// Writes up to |capacity| bytes from |buffer| to the data plane of |socket|
// (options 0), storing the count reported by zx_socket_write() in
// |*out_actual|.
static zx_status_t zxs_write(const zxs_socket_t* socket, const void* buffer,
                             size_t capacity, size_t* out_actual) {
    const zx_handle_t handle = socket->socket;
    return zx_socket_write(handle, 0, buffer, capacity, out_actual);
}
// Reads up to |capacity| bytes from the data plane of |socket| into |buffer|.
//
// ZX_ERR_PEER_CLOSED and ZX_ERR_BAD_STATE are reported as end-of-stream:
// |*out_actual| is set to 0 and ZX_OK is returned. All other statuses are
// passed through unchanged.
static zx_status_t zxs_read(const zxs_socket_t* socket, void* buffer,
                            size_t capacity, size_t* out_actual) {
    const zx_status_t status =
        zx_socket_read(socket->socket, 0, buffer, capacity, out_actual);
    switch (status) {
    case ZX_ERR_PEER_CLOSED:
    case ZX_ERR_BAD_STATE:
        *out_actual = 0u;
        return ZX_OK;
    default:
        return status;
    }
}
// Writes the iovecs of |msg| to a stream socket in order.
//
// Returns ZX_ERR_INVALID_ARGS when it reaches a zero-length iovec. Stops
// early after a short write, or after an error once some bytes have been
// sent; an error before any bytes are sent is returned as-is. On success
// |*out_actual| is the total number of bytes written.
static zx_status_t zxs_sendmsg_stream(const zxs_socket_t* socket,
                                      const struct msghdr* msg,
                                      size_t* out_actual) {
    size_t sent = 0u;
    for (int idx = 0; idx < msg->msg_iovlen; ++idx) {
        const struct iovec& vec = msg->msg_iov[idx];
        if (vec.iov_len == 0) {
            return ZX_ERR_INVALID_ARGS;
        }
        size_t chunk = 0u;
        const zx_status_t status = zxs_write(socket, vec.iov_base, vec.iov_len, &chunk);
        if (status != ZX_OK) {
            if (sent == 0) {
                return status;
            }
            // Report the partial send rather than the error.
            break;
        }
        sent += chunk;
        if (chunk != vec.iov_len) {
            break;
        }
    }
    *out_actual = sent;
    return ZX_OK;
}
// Sends the iovecs of |msg| as a single datagram.
//
// The payload is framed as an fdio_socket_msg_t: a header holding the
// destination address (from |msg->msg_name|), its length and flags, followed
// by the concatenated iovec data. Returns ZX_ERR_INVALID_ARGS for a
// zero-length iovec or an address longer than the header's address field,
// ZX_ERR_NO_MEMORY if the frame cannot be allocated, or the write status. On
// success |*out_actual| is the payload size (excluding the header).
static zx_status_t zxs_sendmsg_dgram(const zxs_socket_t* socket,
                                     const struct msghdr* msg,
                                     size_t* out_actual) {
    size_t total = 0u;
    for (int i = 0; i < msg->msg_iovlen; i++) {
        struct iovec* iov = &msg->msg_iov[i];
        if (iov->iov_len <= 0) {
            return ZX_ERR_INVALID_ARGS;
        }
        total += iov->iov_len;
    }
    // The address is copied into a fixed-size header field; reject lengths
    // that would overflow it. Without a name, no address is sent.
    socklen_t addrlen = 0;
    if (msg->msg_name != nullptr) {
        if (msg->msg_namelen > sizeof(fdio_socket_msg_t::addr)) {
            return ZX_ERR_INVALID_ARGS;
        }
        addrlen = msg->msg_namelen;
    }
    size_t encoded_size = total + FDIO_SOCKET_MSG_HEADER_SIZE;
    // TODO: avoid malloc m
    fdio_socket_msg_t* m = static_cast<fdio_socket_msg_t*>(malloc(encoded_size));
    if (m == nullptr) {
        return ZX_ERR_NO_MEMORY;
    }
    // Zero the header so unused address bytes do not leak stale heap data
    // to the peer.
    memset(m, 0, FDIO_SOCKET_MSG_HEADER_SIZE);
    if (addrlen > 0) {
        memcpy(&m->addr, msg->msg_name, addrlen);
    }
    m->addrlen = addrlen;
    m->flags = 0;
    char* data = m->data;
    for (int i = 0; i < msg->msg_iovlen; i++) {
        struct iovec* iov = &msg->msg_iov[i];
        memcpy(data, iov->iov_base, iov->iov_len);
        data += iov->iov_len;
    }
    size_t actual = 0u;
    zx_status_t status = zxs_write(socket, m, encoded_size, &actual);
    free(m);
    if (status == ZX_OK) {
        *out_actual = total;
    }
    return status;
}
// Reads from a stream socket into the iovecs of |msg| in order.
//
// Stops early after a short read, or after an error once some bytes have
// been received; an error before any bytes arrive is returned as-is. On
// success |*out_actual| is the total number of bytes read.
static zx_status_t zxs_recvmsg_stream(const zxs_socket_t* socket,
                                      struct msghdr* msg,
                                      size_t* out_actual) {
    size_t received = 0u;
    for (int idx = 0; idx < msg->msg_iovlen; ++idx) {
        const struct iovec& vec = msg->msg_iov[idx];
        size_t chunk = 0u;
        const zx_status_t status = zxs_read(socket, vec.iov_base, vec.iov_len, &chunk);
        if (status != ZX_OK) {
            if (received == 0) {
                return status;
            }
            // Report the partial read rather than the error.
            break;
        }
        received += chunk;
        if (chunk != vec.iov_len) {
            break;
        }
    }
    *out_actual = received;
    return ZX_OK;
}
// Receives one datagram into the iovecs of |msg|.
//
// The data plane carries fdio_socket_msg_t frames: a header holding an
// address, its length and flags, followed by the payload. The payload is
// scattered across |msg->msg_iov|. Each iov_len is rewritten to the number
// of bytes stored in it, and entries past the end of the payload get 0. The
// address is copied into |msg->msg_name| (truncated to msg_namelen), and
// msg_namelen is set to the address length from the frame. msg_flags takes
// the frame's flags, with MSG_TRUNC added if the datagram did not fit.
//
// Returns ZX_ERR_INVALID_ARGS for a zero-length iovec, ZX_ERR_NO_MEMORY if
// the staging buffer cannot be allocated, ZX_ERR_INTERNAL if fewer bytes
// than a header were read, or the read status. On success |*out_actual| is
// the number of payload bytes stored.
static zx_status_t zxs_recvmsg_dgram(const zxs_socket_t* socket,
                                     struct msghdr* msg,
                                     size_t* out_actual) {
    // Read 1 extra byte to detect if the buffer is too small to fit the whole
    // packet, so we can set MSG_TRUNC flag if necessary.
    size_t encoded_size = FDIO_SOCKET_MSG_HEADER_SIZE + 1;
    for (int i = 0; i < msg->msg_iovlen; i++) {
        struct iovec* iov = &msg->msg_iov[i];
        if (iov->iov_len <= 0) {
            return ZX_ERR_INVALID_ARGS;
        }
        encoded_size += iov->iov_len;
    }
    // TODO: avoid malloc
    fdio_socket_msg_t* m = static_cast<fdio_socket_msg_t*>(malloc(encoded_size));
    if (m == nullptr) {
        return ZX_ERR_NO_MEMORY;
    }
    size_t actual = 0u;
    zx_status_t status = zxs_read(socket, m, encoded_size, &actual);
    if (status != ZX_OK) {
        free(m);
        return status;
    }
    if (actual < FDIO_SOCKET_MSG_HEADER_SIZE) {
        free(m);
        return ZX_ERR_INTERNAL;
    }
    actual -= FDIO_SOCKET_MSG_HEADER_SIZE;
    if (msg->msg_name != nullptr) {
        // m->addrlen comes from the peer. Never copy more than the caller's
        // buffer holds, or more than the header's address field contains.
        size_t bytes_to_copy = m->addrlen;
        if (bytes_to_copy > msg->msg_namelen) {
            bytes_to_copy = msg->msg_namelen;
        }
        if (bytes_to_copy > sizeof(m->addr)) {
            bytes_to_copy = sizeof(m->addr);
        }
        memcpy(msg->msg_name, &m->addr, bytes_to_copy);
    }
    msg->msg_namelen = m->addrlen;
    msg->msg_flags = m->flags;
    char* data = m->data;
    size_t remaining = actual;
    for (int i = 0; i < msg->msg_iovlen; i++) {
        struct iovec* iov = &msg->msg_iov[i];
        if (remaining == 0) {
            iov->iov_len = 0;
        } else {
            if (remaining < iov->iov_len) {
                iov->iov_len = remaining;
            }
            memcpy(iov->iov_base, data, iov->iov_len);
            data += iov->iov_len;
            remaining -= iov->iov_len;
        }
    }
    if (remaining > 0) {
        msg->msg_flags |= MSG_TRUNC;
        actual -= remaining;
    }
    free(m);
    *out_actual = actual;
    return ZX_OK;
}
// Sends |capacity| bytes from |buffer| on |socket|.
//
// Stream sockets write directly to the data plane. Sockets flagged
// ZXS_FLAG_DATAGRAM go through zxs_sendmsg_dgram() with a single iovec and
// no destination address.
zx_status_t zxs_send(const zxs_socket_t* socket, const void* buffer,
                     size_t capacity, size_t* out_actual) {
    if (!(socket->flags & ZXS_FLAG_DATAGRAM)) {
        return zxs_write(socket, buffer, capacity, out_actual);
    }

    struct iovec iov = {};
    iov.iov_base = const_cast<void*>(buffer);
    iov.iov_len = capacity;

    // Name, control and flags all stay zero/null.
    struct msghdr msg = {};
    msg.msg_iov = &iov;
    msg.msg_iovlen = 1;
    return zxs_sendmsg_dgram(socket, &msg, out_actual);
}
// Receives up to |capacity| bytes into |buffer| from |socket|.
//
// Stream sockets read directly from the data plane. Sockets flagged
// ZXS_FLAG_DATAGRAM go through zxs_recvmsg_dgram() with a single iovec and
// no address buffer.
zx_status_t zxs_recv(const zxs_socket_t* socket, void* buffer,
                     size_t capacity, size_t* out_actual) {
    if (!(socket->flags & ZXS_FLAG_DATAGRAM)) {
        return zxs_read(socket, buffer, capacity, out_actual);
    }

    struct iovec iov = {};
    iov.iov_base = buffer;
    iov.iov_len = capacity;

    // Name, control and flags all start zero/null.
    struct msghdr msg = {};
    msg.msg_iov = &iov;
    msg.msg_iovlen = 1;
    return zxs_recvmsg_dgram(socket, &msg, out_actual);
}
// Sends |capacity| bytes from |buffer| to the |addr_length|-byte address
// |addr| by wrapping them in a single-iovec msghdr for zxs_sendmsg().
zx_status_t zxs_sendto(const zxs_socket_t* socket, const struct sockaddr* addr,
                       size_t addr_length, const void* buffer, size_t capacity,
                       size_t* out_actual) {
    struct iovec iov = {};
    iov.iov_base = const_cast<void*>(buffer);
    iov.iov_len = capacity;

    // Control data stays empty; msg_flags stays 0 and is ignored on send.
    struct msghdr msg = {};
    msg.msg_name = const_cast<struct sockaddr*>(addr);
    msg.msg_namelen = static_cast<socklen_t>(addr_length);
    msg.msg_iov = &iov;
    msg.msg_iovlen = 1;
    return zxs_sendmsg(socket, &msg, out_actual);
}
// Receives up to |capacity| bytes into |buffer| via zxs_recvmsg(), collecting
// the sender address into |addr| (up to |addr_capacity| bytes).
//
// |*out_addr_actual| is set from msg_namelen after the call, even when the
// receive fails. NOTE(review): the stream receive path does not modify
// msg_namelen, so stream sockets report |addr_capacity| here — confirm that
// callers expect this.
zx_status_t zxs_recvfrom(const zxs_socket_t* socket, struct sockaddr* addr,
                         size_t addr_capacity, size_t* out_addr_actual,
                         void* buffer, size_t capacity, size_t* out_actual) {
    struct iovec iov = {};
    iov.iov_base = buffer;
    iov.iov_len = capacity;

    struct msghdr msg = {};
    msg.msg_name = addr;
    msg.msg_namelen = static_cast<socklen_t>(addr_capacity);
    msg.msg_iov = &iov;
    msg.msg_iovlen = 1;

    const zx_status_t status = zxs_recvmsg(socket, &msg, out_actual);
    *out_addr_actual = msg.msg_namelen;
    return status;
}
// Sends |msg|, dispatching on the socket type: sockets flagged
// ZXS_FLAG_DATAGRAM send one framed datagram; others write the iovecs as a
// stream.
zx_status_t zxs_sendmsg(const zxs_socket_t* socket, const struct msghdr* msg,
                        size_t* out_actual) {
    const bool is_datagram = (socket->flags & ZXS_FLAG_DATAGRAM) != 0;
    return is_datagram ? zxs_sendmsg_dgram(socket, msg, out_actual)
                       : zxs_sendmsg_stream(socket, msg, out_actual);
}
// Receives into |msg|, dispatching on the socket type: sockets flagged
// ZXS_FLAG_DATAGRAM receive one framed datagram; others read the iovecs as a
// stream.
zx_status_t zxs_recvmsg(const zxs_socket_t* socket, struct msghdr* msg,
                        size_t* out_actual) {
    const bool is_datagram = (socket->flags & ZXS_FLAG_DATAGRAM) != 0;
    return is_datagram ? zxs_recvmsg_dgram(socket, msg, out_actual)
                       : zxs_recvmsg_stream(socket, msg, out_actual);
}
// Issues ioctl |op| over the control plane as a ZXSIO_IOCTL request.
//
// Sends |in_capacity| bytes from |in_buffer| and copies at most
// |out_capacity| bytes of the reply payload into |out_buffer|, storing the
// copied length in |*out_actual|. Returns ZX_ERR_INVALID_ARGS if either size
// exceeds ZXSIO_PAYLOAD_SZ, ZX_ERR_NOT_SUPPORTED for anything other than
// IOCTL_KIND_DEFAULT, a negative status if the transaction fails, or
// otherwise the transaction's status.
zx_status_t zxs_ioctl(const zxs_socket_t* socket, uint32_t op,
                      const void* in_buffer, size_t in_capacity,
                      void* out_buffer, size_t out_capacity,
                      size_t* out_actual) {
    if (in_capacity > ZXSIO_PAYLOAD_SZ || out_capacity > ZXSIO_PAYLOAD_SZ) {
        return ZX_ERR_INVALID_ARGS;
    }
    if (IOCTL_KIND(op) != IOCTL_KIND_DEFAULT) {
        return ZX_ERR_NOT_SUPPORTED;
    }
    zxsio_msg_t msg;
    memset(&msg, 0, ZXSIO_HDR_SZ);
    msg.op = ZXSIO_IOCTL;
    msg.datalen = static_cast<uint32_t>(in_capacity);
    msg.arg = static_cast<uint32_t>(out_capacity);
    msg.arg2.op = op;
    // Buffers may be NULL when their capacity is 0. memcpy with a NULL
    // pointer is undefined even for zero bytes, so only copy non-empty data.
    if (in_capacity > 0) {
        memcpy(msg.data, in_buffer, in_capacity);
    }
    zx_status_t status = zxsio_txn(socket->socket, &msg);
    if (status < 0) {
        return status;
    }
    size_t copy_length = msg.datalen;
    if (copy_length > out_capacity) {
        copy_length = out_capacity;
    }
    if (copy_length > 0) {
        memcpy(out_buffer, msg.data, copy_length);
    }
    *out_actual = copy_length;
    return status;
}