blob: 24171a5dcbb21ba8d31efe7390d5e9591c480e05 [file] [log] [blame]
// Copyright 2018 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "socket.h"
#include <assert.h>
#include <fuchsia/hardware/vsock/c/fidl.h>
#include <lib/async/cpp/task.h>
#include <limits.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>
#include <zircon/assert.h>
#include <zircon/status.h>
#include <zircon/types.h>
#include <algorithm>
#include <memory>
#include <ddk/debug.h>
#include <ddk/io-buffer.h>
#include <fbl/algorithm.h>
#include <fbl/alloc_checker.h>
#include <fbl/auto_call.h>
#include <fbl/auto_lock.h>
#include <pretty/hexdump.h>
#include <virtio/virtio.h>
namespace virtio {
// Descriptors per data virtqueue (rx and tx).
constexpr uint16_t kDataBacklog{32};
// Descriptors in the event virtqueue.
constexpr uint16_t kEventBacklog{4};
// Size of each data-queue buffer: one packet header followed by up to 468
// bytes of payload.
constexpr size_t kFrameSize{sizeof(virtio_vsock_hdr_t) + 468};
// Virtqueue indices handed to the ring Init() calls.
constexpr uint16_t kRxId{0u};
constexpr uint16_t kTxId{1u};
constexpr uint16_t kEventId{2u};
// Wrappers for linking generated C fidl interfaces to the actual message handlers
// in a SocketDevice
static zx_status_t fidl_Start(void* ctx, zx_handle_t cb, fidl_txn_t* txn) {
virtio::SocketDevice* sock = static_cast<virtio::SocketDevice*>(ctx);
sock->MessageStart(zx::channel(cb));
return fuchsia_hardware_vsock_DeviceStart_reply(txn, ZX_OK);
}
static zx_status_t fidl_SendRequest(void* ctx, const vsock_Addr* addr, zx_handle_t socket,
fidl_txn_t* txn) {
virtio::SocketDevice* sock = static_cast<virtio::SocketDevice*>(ctx);
zx_status_t status = sock->MessageSendRequest(*addr, zx::socket(socket));
return fuchsia_hardware_vsock_DeviceSendRequest_reply(txn, status);
}
static zx_status_t fidl_SendShutdown(void* ctx, const vsock_Addr* addr, fidl_txn_t* txn) {
virtio::SocketDevice* sock = static_cast<virtio::SocketDevice*>(ctx);
zx_status_t status = sock->MessageSendShutdown(*addr);
return fuchsia_hardware_vsock_DeviceSendShutdown_reply(txn, status);
}
static zx_status_t fidl_SendRst(void* ctx, const vsock_Addr* addr, fidl_txn_t* txn) {
virtio::SocketDevice* sock = static_cast<virtio::SocketDevice*>(ctx);
zx_status_t status = sock->MessageSendRst(*addr);
return fuchsia_hardware_vsock_DeviceSendRst_reply(txn, status);
}
static zx_status_t fidl_SendResponse(void* ctx, const vsock_Addr* addr, zx_handle_t socket,
fidl_txn_t* txn) {
virtio::SocketDevice* sock = static_cast<virtio::SocketDevice*>(ctx);
zx_status_t status = sock->MessageSendResponse(*addr, zx::socket(socket));
return fuchsia_hardware_vsock_DeviceSendResponse_reply(txn, status);
}
static zx_status_t fidl_GetCid(void* ctx, fidl_txn_t* txn) {
virtio::SocketDevice* sock = static_cast<virtio::SocketDevice*>(ctx);
return fuchsia_hardware_vsock_DeviceGetCid_reply(txn, sock->MessageGetCid());
}
static zx_status_t fidl_SendVmo(void* ctx, const vsock_Addr* addr, zx_handle_t vmo, uint64_t off,
uint64_t len, fidl_txn_t* txn) {
virtio::SocketDevice* sock = static_cast<virtio::SocketDevice*>(ctx);
zx_status_t status = sock->MessageSendVmo(*addr, zx::vmo(vmo), off, len);
return fuchsia_hardware_vsock_DeviceSendVmo_reply(txn, status);
}
static fuchsia_hardware_vsock_Device_ops_t fidl_ops = {
.Start = fidl_Start,
.SendRequest = fidl_SendRequest,
.SendShutdown = fidl_SendShutdown,
.SendRst = fidl_SendRst,
.SendResponse = fidl_SendResponse,
.GetCid = fidl_GetCid,
.SendVmo = fidl_SendVmo,
};
// DDK message hook: decodes a fuchsia.hardware.vsock/Device request and routes
// it through |fidl_ops|, with this device as the handler context.
zx_status_t SocketDevice::DdkMessage(fidl_incoming_msg_t* msg, fidl_txn_t* txn) {
  return fuchsia_hardware_vsock_Device_dispatch(this, txn, msg, &fidl_ops);
}
// Builds a zero-length packet header from our cid |cid| to the remote end of
// |key|, carrying operation |op| and our current receive credit |credit|.
static virtio_vsock_hdr_t make_hdr(const SocketDevice::ConnectionKey& key, uint16_t op,
                                   uint32_t cid, const SocketDevice::CreditInfo& credit) {
  virtio_vsock_hdr_t hdr = {};
  hdr.src_cid = cid;
  hdr.dst_cid = key.addr_.remote_cid;
  hdr.src_port = key.addr_.local_port;
  hdr.dst_port = key.addr_.remote_port;
  hdr.len = 0;
  // 1 is presumably VIRTIO_VSOCK_TYPE_STREAM — confirm against virtio.h.
  hdr.type = 1;
  hdr.op = op;
  // SHUTDOWN sets flag bits 0b11, presumably "no more receive" and "no more send".
  hdr.flags = (op == VIRTIO_VSOCK_OP_SHUTDOWN) ? 3u : 0u;
  hdr.buf_alloc = credit.buf_alloc;
  hdr.fwd_cnt = credit.fwd_count;
  return hdr;
}
// Constructs the device object. Only bookkeeping happens here:
// - rx/tx rings get kDataBacklog buffers of kFrameSize bytes each;
// - the event ring gets kEventBacklog virtio_vsock_event_t slots;
// - the dispatch loop is created but not started.
// Rings, threads and the timer are set up in Init().
SocketDevice::SocketDevice(zx_device_t* bus_device, zx::bti bti, std::unique_ptr<Backend> backend)
    : virtio::Device(bus_device, std::move(bti), std::move(backend)),
      ddk::Device<SocketDevice, ddk::Unbindable, ddk::Messageable>(bus_device),
      dispatch_loop_(&kAsyncLoopConfigNoAttachToCurrentThread),
      rx_(this, kDataBacklog, kFrameSize),
      tx_(this, kDataBacklog, kFrameSize),
      event_(this, kEventBacklog, sizeof(virtio_vsock_event_t)),
      have_timer_(false),
      timer_wait_handler_(this),
      callback_closed_handler_(this) {}

// Teardown happens in DdkRelease()/ReleaseLocked().
SocketDevice::~SocketDevice() = default;
// Installs |callbacks| as the client channel. A previous client, and every
// connection it owned, is torn down first. A peer-closed watch is then placed
// on the new channel, and the rx ring is processed so that packets which
// arrived while there was no client are delivered.
void SocketDevice::MessageStart(zx::channel callbacks) {
  fbl::AutoLock lock(&lock_);
  if (callbacks_.is_valid()) {
    RemoveCallbacksLocked();
  }
  callbacks_ = std::move(callbacks);
  callback_closed_handler_.set_object(callbacks_.get());
  // |callbacks_| is a channel, so watch the channel peer-closed signal.
  callback_closed_handler_.set_trigger(ZX_CHANNEL_PEER_CLOSED);
  zx_status_t status = callback_closed_handler_.Begin(dispatch_loop_.dispatcher());
  if (status != ZX_OK) {
    // Without the watch we would never notice the client going away; make it visible.
    zxlogf(ERROR, "%s: Failed to watch callbacks channel: %s", tag(),
           zx_status_get_string(status));
  }
  // Go and process the rings to handle any pending rx descriptors and start
  // queueing new ones.
  UpdateRxRingLocked();
}
// Client-initiated reset. Tears down the connection for |key| if one exists,
// and sends an RST to the peer either way. Always reports ZX_OK.
zx_status_t SocketDevice::MessageSendRst(const ConnectionKey& key) {
  fbl::AutoLock guard(&lock_);
  CleanupConAndRstLocked(key.addr_);
  return ZX_OK;
}
// Client-initiated shutdown of the connection for |key|.
// Returns ZX_ERR_BAD_STATE when:
// - there is no client;
// - the connection is unknown;
// - the connection is already shutting down.
// BeginShutdown() returns false while tx is still pending, after recording
// CON_WILL_SHUT_DOWN. In that case no SHUTDOWN op is sent from here.
zx_status_t SocketDevice::MessageSendShutdown(const ConnectionKey& key) {
  fbl::AutoLock guard(&lock_);
  if (!callbacks_.is_valid()) {
    return ZX_ERR_BAD_STATE;
  }
  auto it = connections_.find(key);
  const bool can_shut_down = it != connections_.end() && !it->IsShuttingDown();
  if (!can_shut_down) {
    return ZX_ERR_BAD_STATE;
  }
  const bool send_now = it->BeginShutdown();
  if (send_now) {
    SendOpLocked(it.CopyPointer(), VIRTIO_VSOCK_OP_SHUTDOWN);
  }
  return ZX_OK;
}
// Opens an outgoing connection for |key| that uses |data| as its local end.
// The connection is tracked immediately, in CON_WAIT_RESPONSE, and a REQUEST
// op is sent (or queued).
// Errors:
// - ZX_ERR_BAD_STATE: no client is installed.
// - ZX_ERR_ALREADY_BOUND: |key| is already in use.
// - ZX_ERR_NO_MEMORY: allocation failed.
zx_status_t SocketDevice::MessageSendRequest(const ConnectionKey& key, zx::socket data) {
  fbl::AutoLock guard(&lock_);
  if (!callbacks_.is_valid()) {
    return ZX_ERR_BAD_STATE;
  }
  const bool key_in_use = connections_.find(key) != connections_.end();
  if (key_in_use) {
    return ZX_ERR_ALREADY_BOUND;
  }
  fbl::AllocChecker checker;
  auto connection = fbl::MakeRefCountedChecked<Connection>(
      &checker, key, std::move(data),
      fbl::BindMember(this, &SocketDevice::ConnectionSocketSignalled), cid_, lock_);
  if (!checker.check()) {
    return ZX_ERR_NO_MEMORY;
  }
  // Insert before sending so the peer's RESPONSE can find the connection.
  connections_.insert(connection);
  SendOpLocked(connection, VIRTIO_VSOCK_OP_REQUEST);
  return ZX_OK;
}
// Accepts an incoming connection for |key|, using |data| as its local end.
// The connection is activated right away, since the peer already sent its
// REQUEST, and a RESPONSE op is sent (or queued).
// Errors:
// - ZX_ERR_BAD_STATE: no client is installed.
// - ZX_ERR_ALREADY_BOUND: |key| is already in use.
// - ZX_ERR_NO_MEMORY: allocation failed.
zx_status_t SocketDevice::MessageSendResponse(const ConnectionKey& key, zx::socket data) {
  fbl::AutoLock guard(&lock_);
  if (!callbacks_.is_valid()) {
    return ZX_ERR_BAD_STATE;
  }
  const bool key_in_use = connections_.find(key) != connections_.end();
  if (key_in_use) {
    return ZX_ERR_ALREADY_BOUND;
  }
  fbl::AllocChecker checker;
  auto connection = fbl::MakeRefCountedChecked<Connection>(
      &checker, key, std::move(data),
      fbl::BindMember(this, &SocketDevice::ConnectionSocketSignalled), cid_, lock_);
  if (!checker.check()) {
    return ZX_ERR_NO_MEMORY;
  }
  connection->MakeActive(dispatch_loop_.dispatcher());
  connections_.insert(connection);
  SendOpLocked(connection, VIRTIO_VSOCK_OP_RESPONSE);
  return ZX_OK;
}
// Queues [off, off + len) of |vmo| for transmission on the connection for
// |key|, and immediately tries to push as much as credit and ring space allow.
// Errors:
// - ZX_ERR_BAD_STATE: no client is installed.
// - ZX_ERR_NOT_FOUND: the connection is unknown.
// - ZX_ERR_INVALID_ARGS: |len| is 0.
// - Any error returned by Connection::SetVmo.
zx_status_t SocketDevice::MessageSendVmo(const ConnectionKey& key, zx::vmo vmo, uint64_t off,
                                         uint64_t len) {
  fbl::AutoLock guard(&lock_);
  if (!callbacks_.is_valid()) {
    return ZX_ERR_BAD_STATE;
  }
  auto it = connections_.find(key);
  if (it == connections_.end()) {
    return ZX_ERR_NOT_FOUND;
  }
  // Zero-length transfers are refused: the VMO transfer code cannot handle them.
  if (len == 0) {
    return ZX_ERR_INVALID_ARGS;
  }
  const zx_status_t status = it->SetVmo(bti_, std::move(vmo), off, len, bti_contiguity_);
  if (status == ZX_OK) {
    ContinueTxLocked(false, it.CopyPointer());
  }
  return status;
}
uint32_t SocketDevice::MessageGetCid() {
fbl::AutoLock lock(&lock_);
return cid_;
}
// Brings the device up, in this order:
// 1. Reset the device and negotiate VIRTIO_F_VERSION_1 (legacy devices are refused).
// 2. Read the guest cid.
// 3. Set up the event, rx and tx rings.
// 4. Record the BTI's minimum contiguity, used later to split VMO transfers.
// 5. Start the IRQ and connection-dispatch threads.
// 6. Create the tx retry timer.
// 7. Publish the device with DdkAdd.
// Any failure after |cleanup| is created runs ReleaseLocked() to undo the
// partial setup.
zx_status_t SocketDevice::Init() {
  fbl::AutoLock lock(&lock_);
  // It's a common part for all virtio devices: reset the device, notify
  // about the driver and negotiate supported features.
  DeviceReset();
  DriverStatusAck();
  if (!DeviceFeatureSupported(VIRTIO_F_VERSION_1)) {
    zxlogf(ERROR, "%s: Legacy virtio interface is not supported by this driver", tag());
    return ZX_ERR_NOT_SUPPORTED;
  }
  DriverFeatureAck(VIRTIO_F_VERSION_1);
  // Plan to clean up unless everything goes right.
  auto cleanup = fbl::MakeAutoCall([this]() TA_NO_THREAD_SAFETY_ANALYSIS { ReleaseLocked(); });
  UpdateCidLocked();
  zx_status_t rc;
  rc = event_.Init(kEventId, bti());
  if (rc != ZX_OK) {
    zxlogf(ERROR, "%s: Failed to allocate event ring: %s", tag(), zx_status_get_string(rc));
    return rc;
  }
  rc = rx_.Init(kRxId, bti());
  if (rc != ZX_OK) {
    zxlogf(ERROR, "%s: Failed to allocate rx ring: %s", tag(), zx_status_get_string(rc));
    return rc;
  }
  rc = tx_.Init(kTxId, bti());
  if (rc != ZX_OK) {
    zxlogf(ERROR, "%s: Failed to allocate tx ring: %s", tag(), zx_status_get_string(rc));
    return rc;
  }
  // Determine our bti contiguity. VMO transfers are split into chunks that
  // never cross a contiguity boundary (see VmoWalker::NextChunkLen).
  zx_info_bti_t bti_info;
  rc = bti_.get_info(ZX_INFO_BTI, &bti_info, sizeof(bti_info), nullptr, nullptr);
  if (rc != ZX_OK) {
    zxlogf(ERROR, "%s: Failed to determine BTI contiguity", tag());
    return rc;
  }
  bti_contiguity_ = bti_info.minimum_contiguity;
  // Start the interrupt thread and set the driver OK status
  StartIrqThread();
  // Start out dispatcher for connections
  rc = dispatch_loop_.StartThread("virtio-vsock-connection");
  if (rc != ZX_OK) {
    zxlogf(ERROR, "%s: Failed to start dispatch thread: %s", tag(), zx_status_get_string(rc));
    return rc;
  }
  // Setup our timer for retrying TX operations. The wait itself is only armed
  // on demand by EnableTxRetryTimerLocked().
  rc = zx::timer::create(ZX_TIMER_SLACK_CENTER, ZX_CLOCK_MONOTONIC, &tx_retry_timer_);
  if (rc != ZX_OK) {
    zxlogf(ERROR, "%s: Failed to create timer: %s", tag(), zx_status_get_string(rc));
    return rc;
  }
  timer_wait_handler_.set_object(tx_retry_timer_.get());
  timer_wait_handler_.set_trigger(ZX_TIMER_SIGNALED);
  // Initialize the zx_device and publish us.
  zx_status_t status = DdkAdd("virtio-vsock");
  if (status != ZX_OK) {
    zxlogf(ERROR, "%s: failed to add device: %s", tag(), zx_status_get_string(status));
    return status;
  }
  device_ = zxdev();
  // Only the event ring is filled here; rx buffers are posted once a client
  // installs callbacks (UpdateRxRingLocked returns early without them).
  event_.RefillRing();
  cleanup.cancel();
  DriverStatusOk();
  return ZX_OK;
}
// DDK release hook: performs the full teardown while holding the device lock.
void SocketDevice::DdkRelease() {
  fbl::AutoLock guard(&lock_);
  ReleaseLocked();
}
// Interrupt handler for used-ring updates, in four steps:
// 1. Reap completed tx descriptors; for VMO transfers, report the payload
//    address so the owning connection can detect its final chunk.
// 2. Handle device events (only transport reset is recognized).
// 3. Process received packets.
// 4. Use the freed tx descriptors: queued control ops first, then pending data.
void SocketDevice::IrqRingUpdate() {
  fbl::AutoLock guard(&lock_);
  tx_.ProcessDescriptors(
      [this](const ConnectionKey& key, uint64_t payload) TA_NO_THREAD_SAFETY_ANALYSIS {
        auto it = connections_.find(key);
        if (it == connections_.end()) {
          return;
        }
        if (it->NotifyVmoTxComplete(payload)) {
          PerformCallbackLocked(fuchsia_hardware_vsock_CallbacksSendVmoComplete, key);
        }
      });
  event_.ProcessDescriptors<virtio_vsock_event_t>(
      [this](virtio_vsock_event_t* event, void* data, uint32_t data_len)
          TA_NO_THREAD_SAFETY_ANALYSIS {
            if (event->id != VIRTIO_VSOCK_EVENT_TRANSPORT_RESET) {
              zxlogf(ERROR, "%s: Received unknown event: %d", tag(), event->id);
              return;
            }
            TransportResetLocked();
          });
  UpdateRxRingLocked();
  // Control ops take precedence over queued data transfers.
  while (!has_pending_op_.is_empty()) {
    auto next = has_pending_op_.pop_front();
    const uint16_t op = next->TakePendingOp();
    if (SendOp_RawLocked(next->GetKey().addr_, op, next->GetCreditInfo())) {
      continue;
    }
    // Ring is full again: restore the op at the head of the queue and stop.
    next->QueueOp(op);
    has_pending_op_.push_front(next);
    break;
  }
  RetryTxLocked(false);
}
// Config-change interrupt: re-reads the guest cid. A changed cid invalidates
// every connection, so it is handled as a transport reset.
void SocketDevice::IrqConfigChange() {
  fbl::AutoLock guard(&lock_);
  const uint32_t previous_cid = cid_;
  UpdateCidLocked();
  const bool cid_changed = previous_cid != cid_;
  if (cid_changed) {
    TransportResetLocked();
  }
}
// Handles one received packet: |header| plus |data_len| payload bytes at |data|.
// - Packets for another cid are logged and dropped.
// - Credit fields are applied to a known connection.
// - RW payload goes into the connection's socket. It is answered with RST if
//   the connection is unknown, and tears the connection down if the write fails.
// - Every other op is handed to RxOpLocked.
void SocketDevice::ProcessRxDescriptor(virtio_vsock_hdr_t* header, void* data, uint32_t data_len) {
  if (header->dst_cid != cid_) {
    zxlogf(ERROR, " %s: Received message for cid %d, but believe our cid is %d", tag(),
           static_cast<uint32_t>(header->dst_cid), cid_);
    return;
  }
  ConnectionKey key = ConnectionKey::FromHdr(header);
  auto conn = connections_.find(key);
  const bool known = conn != connections_.end();
  if (known) {
    conn->UpdateCredit(header->buf_alloc, header->fwd_cnt);
  }
  if (header->op != VIRTIO_VSOCK_OP_RW) {
    RxOpLocked(conn, key, header->op);
    return;
  }
  if (!known) {
    SendRstLocked(key);
  } else if (!conn->Rx(data, data_len)) {
    NotifyAndCleanupConLocked(conn.CopyPointer());
  }
}
void SocketDevice::UpdateRxRingLocked() {
// Refuse to process rx buffers if we don't have callbacks. If the callbacks
// somehow vanish mid process then that's fine, we'll just dump a lot of
// requests on the floor, but there's little else we can do.
if (!callbacks_.is_valid()) {
return;
}
rx_.ProcessDescriptors<virtio_vsock_hdr_t>(
[this](virtio_vsock_hdr_t* header, void* data, uint32_t data_len)
TA_NO_THREAD_SAFETY_ANALYSIS { this->ProcessRxDescriptor(header, data, data_len); });
}
// Handles a received control (non-RW) packet with operation |op| for |key|.
// |conn| is the connection lookup result and may be connections_.end().
// Credit fields were already applied by ProcessRxDescriptor for known
// connections.
void SocketDevice::RxOpLocked(ConnectionIterator conn, const ConnectionKey& key, uint16_t op) {
  switch (op) {
    case VIRTIO_VSOCK_OP_INVALID:
      zxlogf(ERROR, "%s: Received invalid op", tag());
      break;
    case VIRTIO_VSOCK_OP_REQUEST:
      // Don't care if we have a connection or not, just send it to the
      // service.
      PerformCallbackLocked(fuchsia_hardware_vsock_CallbacksRequest, key);
      break;
    case VIRTIO_VSOCK_OP_RESPONSE: {
      // Check for existing partial connection.
      if (conn == connections_.end()) {
        zxlogf(ERROR, "%s: Received response for unknown connection", tag());
        // We weren't trying to make a connection, so reject this
        SendRstLocked(key);
        break;
      }
      // Upgrade the channel.
      conn->MakeActive(dispatch_loop_.dispatcher());
      PerformCallbackLocked(fuchsia_hardware_vsock_CallbacksResponse, key);
      break;
    }
    case VIRTIO_VSOCK_OP_RST:
      if (conn != connections_.end()) {
        CleanupConLocked(conn.CopyPointer());
      }
      PerformCallbackLocked(fuchsia_hardware_vsock_CallbacksRst, key);
      break;
    case VIRTIO_VSOCK_OP_SHUTDOWN:
      if (conn != connections_.end()) {
        // Shutdown and move into the zombie state until the service
        // confirms shutdown by sending the RST
        conn->Close(dispatch_loop_.dispatcher());
        DequeueTxLocked(conn.CopyPointer());
        DequeueOpLocked(conn.CopyPointer());
      }
      PerformCallbackLocked(fuchsia_hardware_vsock_CallbacksShutdown, key);
      break;
    case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
      if (conn == connections_.end()) {
        SendRstLocked(key);
        // Must stop here: conn.CopyPointer() on the end iterator is a null
        // RefPtr, and the queue check below would dereference it.
        break;
      }
      // New credit may unblock a connection that was waiting to transmit.
      if (QueuedForTxLocked(conn.CopyPointer())) {
        ContinueTxLocked(true, conn.CopyPointer());
      }
      break;
    case VIRTIO_VSOCK_OP_CREDIT_REQUEST:
      if (conn == connections_.end()) {
        SendRstLocked(key);
      } else {
        SendOpLocked(conn.CopyPointer(), VIRTIO_VSOCK_OP_CREDIT_UPDATE);
      }
      break;
    case VIRTIO_VSOCK_OP_RW:
      // RW packets are handled in ProcessRxDescriptor and never reach here.
      zxlogf(ERROR, "%s: OP_RW not handled here", tag());
      break;
    default:
      zxlogf(ERROR, "%s: Unexpected op %d from host", tag(), op);
      break;
  }
}
bool SocketDevice::SendOp_RawLocked(const ConnectionKey& key, uint16_t op,
const CreditInfo& credit) {
virtio_vsock_hdr_t hdr = make_hdr(key, op, cid_, credit);
// Grab a free descriptor
uint16_t id;
if (!tx_.AllocInPlace(&id)) {
return false;
}
tx_.SetHeader(id, hdr);
tx_.SubmitChain(id, 0);
// Typically we call this in a path with a single TX, so minimal gains from
// trying to coalesce this
tx_.Kick();
return true;
}
// Sends control op |op| for |conn|, or records it on the connection and queues
// the connection in has_pending_op_.
// The ring is only tried directly when no other ops are waiting, so this op
// cannot overtake ones already queued.
void SocketDevice::SendOpLocked(fbl::RefPtr<Connection> conn, uint16_t op) {
  bool sent = false;
  if (has_pending_op_.is_empty()) {
    const CreditInfo credit = conn->GetCreditInfo();
    sent = SendOp_RawLocked(conn->GetKey(), op, credit);
  }
  if (!sent) {
    // Merges with any op already stored on the connection.
    conn->QueueOp(op);
    QueueForOpLocked(conn);
  }
}
// Gives every connection waiting in has_pending_tx_ another chance to transmit.
// With |force_credit_request|, out-of-credit connections ask the peer for more.
void SocketDevice::RetryTxLocked(bool force_credit_request) {
  auto it = has_pending_tx_.begin();
  while (it != has_pending_tx_.end()) {
    // ContinueTxLocked may unlink the current entry, so advance first.
    auto current = it++;
    ContinueTxLocked(force_credit_request, current.CopyPointer());
  }
}
void SocketDevice::ContinueTxLocked(bool force_credit_request, fbl::RefPtr<Connection> conn) {
zx_status_t status = conn->ContinueTx(force_credit_request, tx_, dispatch_loop_.dispatcher());
if (status == ZX_OK || status == ZX_ERR_SHOULD_WAIT) {
if (conn->HasPendingOp() && !QueuedForOpLocked(conn)) {
SendOpLocked(conn, conn->TakePendingOp());
}
if (status == ZX_ERR_SHOULD_WAIT) {
QueueForTxLocked(conn);
} else if (status == ZX_OK) {
DequeueTxLocked(conn);
}
} else {
NotifyAndCleanupConLocked(conn);
}
}
void SocketDevice::SendRstLocked(const ConnectionKey& key) {
SendOp_RawLocked(key, VIRTIO_VSOCK_OP_RST, CreditInfo());
}
// Closes |conn| and forgets it. It is removed from the tx-retry queue, from
// the pending-op queue (discarding any queued op) and from the connection table.
void SocketDevice::CleanupConLocked(fbl::RefPtr<Connection> conn) {
  conn->Close(dispatch_loop_.dispatcher());
  DequeueTxLocked(conn);
  DequeueOpLocked(conn);
  Connection& entry = *conn;
  connections_.erase(entry);
}
// Reports an RST for |conn| to the client, then drops the connection.
void SocketDevice::NotifyAndCleanupConLocked(fbl::RefPtr<Connection> conn) {
  const ConnectionKey& key = conn->GetKey();
  PerformCallbackLocked(fuchsia_hardware_vsock_CallbacksRst, key);
  CleanupConLocked(conn);
}
void SocketDevice::CleanupConAndRstLocked(const ConnectionKey& key) {
auto it = connections_.find(key);
if (it != connections_.end()) {
SendOpLocked(it.CopyPointer(), VIRTIO_VSOCK_OP_RST);
CleanupConLocked(it.CopyPointer());
} else {
SendRstLocked(key);
}
}
// Detaches the current client. Every connection is reset (RST sent or queued)
// and closed, the peer-closed watch is cancelled, and the channel is dropped.
void SocketDevice::RemoveCallbacksLocked() {
  for (auto it = connections_.begin(); it != connections_.end(); ++it) {
    SendOpLocked(it.CopyPointer(), VIRTIO_VSOCK_OP_RST);
    it->Close(dispatch_loop_.dispatcher());
  }
  connections_.clear();
  callback_closed_handler_.Cancel();
  callbacks_.reset();
  has_pending_tx_.clear();
  // has_pending_op_ is intentionally kept: it may hold the RST ops queued
  // above, and those still need to finish sending.
}
// Invokes the generated client callback |func| for |key|'s address.
// Does nothing if no client is installed; the callback's status is ignored.
void SocketDevice::PerformCallbackLocked(zx_status_t (*func)(zx_handle_t, const vsock_Addr*),
                                         const ConnectionKey& key) {
  if (!callbacks_.is_valid()) {
    return;
  }
  func(callbacks_.get(), &key.addr_);
}
// Whether |conn| is linked into has_pending_tx_.
bool SocketDevice::QueuedForTxLocked(fbl::RefPtr<Connection> conn) {
  const Connection& entry = *conn;
  return fbl::InContainer<PendingTxTag>(entry);
}

// Adds |conn| to the tx-retry queue (once), and makes sure the retry timer is
// armed so the connection is revisited.
void SocketDevice::QueueForTxLocked(fbl::RefPtr<Connection> conn) {
  if (QueuedForTxLocked(conn)) {
    return;
  }
  has_pending_tx_.push_back(conn);
  EnableTxRetryTimerLocked();
}

// Removes |conn| from the tx-retry queue if it is there.
void SocketDevice::DequeueTxLocked(fbl::RefPtr<Connection> conn) {
  if (!QueuedForTxLocked(conn)) {
    return;
  }
  has_pending_tx_.erase(*conn);
}

// Whether |conn| is linked into has_pending_op_.
bool SocketDevice::QueuedForOpLocked(fbl::RefPtr<Connection> conn) {
  const Connection& entry = *conn;
  return fbl::InContainer<PendingOpTag>(entry);
}

// Appends |conn| to the pending-op queue (once).
void SocketDevice::QueueForOpLocked(fbl::RefPtr<Connection> conn) {
  if (QueuedForOpLocked(conn)) {
    return;
  }
  has_pending_op_.push_back(conn);
}

// Removes |conn| from the pending-op queue if it is there.
void SocketDevice::DequeueOpLocked(fbl::RefPtr<Connection> conn) {
  if (!QueuedForOpLocked(conn)) {
    return;
  }
  has_pending_op_.erase(*conn);
}
// Arms the tx retry timer unless it is already armed. The timer gets a
// one-second deadline with one second of slack, and the wait that calls
// TimerWaitHandler is started. Failures are logged, and have_timer_ stays
// false so a later call can try again.
void SocketDevice::EnableTxRetryTimerLocked() {
  if (have_timer_) {
    return;
  }
  zx_status_t status = tx_retry_timer_.set(zx::deadline_after(zx::sec(1)), zx::sec(1));
  if (status != ZX_OK) {
    zxlogf(ERROR, "%s: Failed to set timer %s", tag(), zx_status_get_string(status));
    return;
  }
  status = timer_wait_handler_.Begin(dispatch_loop_.dispatcher());
  if (status != ZX_OK) {
    zxlogf(ERROR, "%s: Failed to wait for timer %s", tag(), zx_status_get_string(status));
    return;
  }
  have_timer_ = true;
}
// Tx retry timer fired.
// 1. Disarm the timer.
// 2. Retry every queued transmission, forcing credit requests for connections
//    that are out of credit.
// 3. Re-arm the timer if anything is still waiting.
void SocketDevice::TimerWaitHandler(async_dispatcher_t* dispatcher, async::WaitBase* wait,
                                    zx_status_t status, const zx_packet_signal_t* signal) {
  // A non-OK status means the dispatcher is shutting down.
  if (status != ZX_OK) {
    return;
  }
  fbl::AutoLock guard(&lock_);
  have_timer_ = false;
  tx_retry_timer_.cancel();
  RetryTxLocked(true);
  const bool still_waiting = !has_pending_tx_.is_empty();
  if (still_waiting) {
    EnableTxRetryTimerLocked();
  }
}
// The client's callbacks channel was closed: drop the client and its connections.
void SocketDevice::CallbacksSignalled(async_dispatcher_t* dispatcher, async::WaitBase* wait,
                                      zx_status_t status, const zx_packet_signal_t* signal) {
  // A non-OK status means the dispatcher is shutting down.
  if (status != ZX_OK) {
    return;
  }
  fbl::AutoLock guard(&lock_);
  RemoveCallbacksLocked();
}
// The local socket of |conn| became readable or its peer closed.
// - Connections that are shutting down are ignored.
// - Peer-closed resets the connection.
// - Otherwise the new data is transmitted.
void SocketDevice::ConnectionSocketSignalled(zx_status_t status, const zx_packet_signal_t* signal,
                                             fbl::RefPtr<Connection> conn) {
  // A non-OK status means the dispatcher is shutting down.
  if (status != ZX_OK) {
    return;
  }
  fbl::AutoLock guard(&lock_);
  if (conn->IsShuttingDown()) {
    return;
  }
  const bool peer_closed = (signal->observed & ZX_SOCKET_PEER_CLOSED) != 0;
  if (peer_closed) {
    NotifyAndCleanupConLocked(conn);
  } else {
    ContinueTxLocked(false, conn);
  }
}
void SocketDevice::UpdateCidLocked() {
virtio_vsock_config_t config;
CopyDeviceConfig(&config, sizeof(config));
cid_ = static_cast<uint32_t>(config.guest_cid);
}
// Full teardown, used by DdkRelease() and by Init()'s failure path:
// - drops the client and every connection;
// - discards all queued ops;
// - stops the dispatch loop;
// - frees the ring buffers;
// - releases the virtio device.
void SocketDevice::ReleaseLocked() {
  RemoveCallbacksLocked();
  // RemoveCallbacksLocked keeps queued ops so RSTs can drain; at release time
  // there is nothing left to send them on, so drop them too.
  has_pending_op_.clear();
  // Shutting down the dispatch loop will remove any existing wait handlers for
  // things like timer_wait_handler_.
  dispatch_loop_.Shutdown();
  rx_.FreeBuffers();
  tx_.FreeBuffers();
  event_.FreeBuffers();
  virtio::Device::Release();
}
void SocketDevice::TransportResetLocked() {
// shit this is complicated, just reload the cid for now
zxlogf(INFO, "%s: Received transport reset!", tag());
for (auto& it : connections_) {
it.Close(dispatch_loop_.dispatcher());
}
connections_.clear();
has_pending_tx_.clear();
has_pending_op_.clear();
UpdateCidLocked();
if (callbacks_.is_valid()) {
fuchsia_hardware_vsock_CallbacksTransportReset(callbacks_.get(), cid_);
}
}
// A virtqueue of |count| descriptors backed by |count| buffers of |buf_size|
// bytes. |host_write_only| marks rings whose buffers the device writes into.
// No memory is allocated until Init().
SocketDevice::IoBufferRing::IoBufferRing(virtio::Device* device, uint16_t count, uint32_t buf_size,
                                         bool host_write_only)
    : ring_(device),
      host_write_only_(host_write_only),
      count_(count),
      buf_size_(buf_size) {}

// Releases the backing io_buffer, if any.
SocketDevice::IoBufferRing::~IoBufferRing() {
  FreeBuffers();
}
// Initializes virtqueue |index| and allocates one physically contiguous
// io_buffer of count_ * buf_size_ bytes.
// For device-writable (host_write_only_) rings, descriptor |id| is bound to
// the |id|-th buffer once, with VRING_DESC_F_WRITE set: the device (host) may
// write into the buffer, and the driver only reads it.
// Tx rings fill in descriptor addresses at allocation time instead (see
// TxIoBufferRing::AllocInPlace / AllocIndirect).
zx_status_t SocketDevice::IoBufferRing::Init(uint16_t index, const zx::bti& bti) {
  zx_status_t status = ring_.Init(index, count_);
  if (status != ZX_OK) {
    return status;
  }
  status = io_buffer_init(&io_buffer_, bti.get(), buf_size_ * count_,
                          IO_BUFFER_CONTIG | (host_write_only_ ? IO_BUFFER_RO : IO_BUFFER_RW));
  if (status != ZX_OK) {
    return status;
  }
  if (!host_write_only_) {
    return ZX_OK;
  }
  const zx_paddr_t base = io_buffer_phys(&io_buffer_);
  for (uint16_t id = 0; id < count_; ++id) {
    struct vring_desc* slot = ring_.DescFromIndex(id);
    slot->addr = base + id * buf_size_;
    slot->len = buf_size_;
    slot->flags |= VRING_DESC_F_WRITE;
  }
  return ZX_OK;
}
// Releases the io_buffer if one is held. Safe to call more than once.
void SocketDevice::IoBufferRing::FreeBuffers() {
  if (!io_buffer_is_valid(&io_buffer_)) {
    return;
  }
  io_buffer_release(&io_buffer_);
}
// Receive-side ring: its buffers are written by the device.
SocketDevice::RxIoBufferRing::RxIoBufferRing(virtio::Device* device, uint16_t count,
                                             uint32_t buf_size)
    : IoBufferRing(device, count, buf_size, /*host_write_only=*/true) {}
// Hands every free descriptor back to the device as a full-size
// (buf_size_-byte) receive buffer, then kicks once if anything was posted.
void SocketDevice::RxIoBufferRing::RefillRing() {
  assert(io_buffer_is_valid(&io_buffer_));
  bool posted_any = false;
  uint16_t id;
  for (struct vring_desc* slot = ring_.AllocDescChain(1, &id); slot != nullptr;
       slot = ring_.AllocDescChain(1, &id)) {
    slot->len = buf_size_;
    ring_.SubmitChain(id);
    posted_any = true;
  }
  if (posted_any) {
    Kick();
  }
}
template <typename H, typename F>
void SocketDevice::RxIoBufferRing::ProcessDescriptors(F func) {
ring_.IrqRingUpdate([this, &func](vring_used_elem* used_elem) {
uint16_t last_id = static_cast<uint16_t>(used_elem->id);
struct vring_desc* desc = ring_.DescFromIndex(last_id);
if (desc->len < sizeof(H)) {
zxlogf(ERROR, "Descriptor is too short");
} else if ((desc->flags & VRING_DESC_F_NEXT) != 0) {
zxlogf(ERROR, "Chained descriptors are not supported");
} else {
func(reinterpret_cast<H*>(GetRawDesc(last_id, sizeof(H))), GetRawDesc(last_id, 0, sizeof(H)),
static_cast<uint32_t>(used_elem->len - sizeof(H)));
}
// Handle freeing arbitrarily long descriptor chains
while ((desc->flags & VRING_DESC_F_NEXT) != 0) {
uint16_t next_id = desc->next;
ring_.FreeDesc(last_id);
desc = ring_.DescFromIndex(last_id);
last_id = next_id;
}
ring_.FreeDesc(last_id);
});
RefillRing();
}
// Transmit-side ring: its buffers are written by the driver and read by the device.
SocketDevice::TxIoBufferRing::TxIoBufferRing(virtio::Device* device, uint16_t count,
                                             uint32_t buf_size)
    : IoBufferRing(device, count, buf_size, /*host_write_only=*/false) {}
// Allocates a single tx descriptor, stores its index in |*id|, and points it
// at the |*id|-th buffer. Returns the location just past the packet header,
// where payload can be written in place, or nullptr if the ring is full.
void* SocketDevice::TxIoBufferRing::AllocInPlace(uint16_t* id) {
  struct vring_desc* slot = ring_.AllocDescChain(1, id);
  if (slot == nullptr) {
    return nullptr;
  }
  slot->addr = io_buffer_phys(&io_buffer_) + *id * buf_size_;
  return GetRawDesc(*id, 0, sizeof(virtio_vsock_hdr_t));
}
bool SocketDevice::TxIoBufferRing::AllocIndirect(const ConnectionKey& key, uint16_t* id) {
struct vring_desc* desc = ring_.AllocDescChain(2, id);
if (!desc) {
return false;
}
desc->addr = io_buffer_phys(&io_buffer_) + *id * buf_size_;
*reinterpret_cast<ConnectionKey*>(
GetRawDesc(*id, sizeof(ConnectionKey), sizeof(virtio_vsock_hdr_t))) = key;
return true;
}
// Points the second descriptor of chain |id| at the physical address |payload|.
void SocketDevice::TxIoBufferRing::SetIndirectPayload(uint16_t id, uintptr_t payload) {
  const uint16_t payload_id = ring_.DescFromIndex(id)->next;
  ring_.DescFromIndex(payload_id)->addr = payload;
}
// Sets the descriptor lengths for a packet with |data_len| payload bytes, then
// submits chain |id|.
// - Single descriptor: the header and payload share one buffer.
// - Chained: the head carries only the header, and the next descriptor
//   carries the payload.
void SocketDevice::TxIoBufferRing::SubmitChain(uint16_t id, uint32_t data_len) {
  struct vring_desc* head = ring_.DescFromIndex(id);
  const uint32_t header_len = sizeof(virtio_vsock_hdr_t);
  const bool chained = (head->flags & VRING_DESC_F_NEXT) != 0;
  if (chained) {
    head->len = header_len;
    ring_.DescFromIndex(head->next)->len = data_len;
  } else {
    head->len = header_len + data_len;
  }
  ring_.SubmitChain(id);
}
// Returns an unsubmitted chain (one or two descriptors) to the free list.
void SocketDevice::TxIoBufferRing::FreeChain(uint16_t id) {
  struct vring_desc* head = ring_.DescFromIndex(id);
  const bool chained = (head->flags & VRING_DESC_F_NEXT) != 0;
  if (chained) {
    ring_.FreeDesc(head->next);
  }
  ring_.FreeDesc(id);
}
// Reaps completed tx chains.
// For two-descriptor chains (from AllocIndirect), calls
// func(key, payload_paddr) with the stashed connection key and the payload
// address, so the connection can recognize its final VMO chunk.
// All descriptors are freed.
template <typename F>
void SocketDevice::TxIoBufferRing::ProcessDescriptors(F func) {
  ring_.IrqRingUpdate([this, &func](vring_used_elem* used_elem) {
    const uint16_t head_id = static_cast<uint16_t>(used_elem->id);
    struct vring_desc* head = ring_.DescFromIndex(head_id);
    const bool chained = (head->flags & VRING_DESC_F_NEXT) != 0;
    if (chained) {
      const uint16_t payload_id = head->next;
      const ConnectionKey& key = *reinterpret_cast<ConnectionKey*>(
          GetRawDesc(head_id, sizeof(ConnectionKey), sizeof(virtio_vsock_hdr_t)));
      func(key, ring_.DescFromIndex(payload_id)->addr);
      ring_.FreeDesc(payload_id);
    }
    ring_.FreeDesc(head_id);
  });
}
// Creates a connection for |key| whose local end is the socket |data|.
// - It starts in CON_WAIT_RESPONSE with all credit counters at zero.
// - |lock| is the owning device's lock; BeginWait's posted task acquires it.
// - |wait_handler_| watches |data_| for ZX_SOCKET_READABLE | ZX_SOCKET_PEER_CLOSED,
//   but is not armed here (see BeginWait).
SocketDevice::Connection::Connection(const ConnectionKey& key, zx::socket data,
                                     SignalHandler wait_handler, uint32_t cid, fbl::Mutex& lock)
    : lock_(lock),
      key_(key),
      state_(CON_WAIT_RESPONSE),
      tx_count_(0),
      rx_count_(0),
      buf_alloc_(0),
      fwd_cnt_(0),
      data_(std::move(data)),
      // Uses data_.get(), so data_ must be declared before wait_handler_.
      // When the wait fires, the handler takes the self-reference that
      // BeginWait stored in wait_handler_ref_ (leaving it empty) and passes it
      // to |wait_handler|. That keeps the Connection alive for the callback.
      wait_handler_(data_.get(), ZX_SOCKET_READABLE | ZX_SOCKET_PEER_CLOSED, 0,
                    async::Wait::Handler([this, wait_handler = std::move(wait_handler)](
                                             async_dispatcher_t* dispatcher, async::Wait* wait,
                                             zx_status_t status, const zx_packet_signal_t* signal) {
                      fbl::RefPtr<Connection> ref;
                      wait_handler_ref_.swap(ref);
                      wait_handler(status, signal, ref);
                    })),
      pending_vmo_(false),
      has_pending_op_(false),
      cid_(cid) {}
// True while a VMO transfer is outstanding or the local socket has unread data.
bool SocketDevice::Connection::PendingTx() {
  return pending_vmo_ || SocketTxPending();
}
// True once a shutdown has begun (requested or deferred), or once the
// connection has been closed (zombie).
bool SocketDevice::Connection::IsShuttingDown() {
  switch (state_) {
    case Connection::CON_ZOMBIE:
    case Connection::CON_SHUTTING_DOWN:
    case Connection::CON_WILL_SHUT_DOWN:
      return true;
    default:
      return false;
  }
}
// Starts a local shutdown.
// - Returns true (state CON_SHUTTING_DOWN) when the SHUTDOWN op can be sent now.
// - Returns false (state CON_WILL_SHUT_DOWN) while transmit data is still pending.
bool SocketDevice::Connection::BeginShutdown() {
  assert(!IsShuttingDown());
  const bool can_send_now = !PendingTx();
  state_ = can_send_now ? Connection::CON_SHUTTING_DOWN : Connection::CON_WILL_SHUT_DOWN;
  return can_send_now;
}
// Called for each completed indirect tx packet with its payload address.
// If |paddr| is the final chunk of the outstanding VMO transfer, the VMO is
// released and true is returned.
bool SocketDevice::Connection::NotifyVmoTxComplete(uintptr_t paddr) {
  const bool final_chunk = pending_vmo_ && vmo_.final_paddr_ == paddr;
  if (!final_chunk) {
    return false;
  }
  vmo_.Release();
  pending_vmo_ = false;
  return true;
}
// Records the peer's latest advertised buffer size (|buf|) and forwarded-byte
// count (|fwd|).
void SocketDevice::Connection::UpdateCredit(uint32_t buf, uint32_t fwd) {
  fwd_cnt_ = fwd;
  buf_alloc_ = buf;
}
// Moves a connection from CON_WAIT_RESPONSE to CON_ACTIVE and starts watching
// its socket. Any other state is logged and left unchanged.
void SocketDevice::Connection::MakeActive(async_dispatcher_t* disp) {
  const bool awaiting_response = state_ == Connection::CON_WAIT_RESPONSE;
  if (!awaiting_response) {
    zxlogf(ERROR, "Received response for already established connection");
    return;
  }
  BeginWait(disp);
  state_ = Connection::CON_ACTIVE;
}
bool SocketDevice::Connection::Rx(void* data, size_t len) {
size_t written = 0;
zx_status_t status = data_.write(0, data, len, &written);
rx_count_ += static_cast<uint32_t>(written);
// The way flow control works in vsock we should never end up in a
// situation where the socket cannot hold the data. Therefore we consider
// any failure to be catastrophic and terminate the connection.
return status == ZX_OK && written == len;
}
// Credit we advertise to the peer for this connection:
// - buf_alloc: how much the local socket can buffer (its tx_buf_max).
// - fwd_cnt: the free-running count of bytes the local reader has consumed.
//   That is every byte written into the socket by Rx() (rx_count_) minus the
//   bytes still sitting in it (tx_buf_size). The peer computes its send window
//   as buf_alloc - (its tx count - fwd_cnt), so fwd_cnt must be this running
//   total, not the current buffer fill.
// Returns a default CreditInfo if the socket cannot be queried.
SocketDevice::CreditInfo SocketDevice::Connection::GetCreditInfo() {
  zx_info_socket_t info;
  zx_status_t status = data_.get_info(ZX_INFO_SOCKET, &info, sizeof(info), nullptr, nullptr);
  if (status != ZX_OK) {
    return CreditInfo();
  }
  const uint32_t still_buffered = static_cast<uint32_t>(info.tx_buf_size);
  // Unsigned wrap-around matches the free-running counter semantics.
  const uint32_t forwarded = static_cast<uint32_t>(rx_count_ - still_buffered);
  return CreditInfo(static_cast<uint32_t>(info.tx_buf_max), forwarded);
}
// Header for |op| on this connection, carrying our current receive credit.
virtio_vsock_hdr_t SocketDevice::Connection::MakeHdr(uint16_t op) {
  const CreditInfo credit = GetCreditInfo();
  return make_hdr(key_.addr_, op, cid_, credit);
}
// Marks the connection as a zombie and cancels its socket wait. The cancel is
// posted to |dispatcher|, where the wait lives.
// The posted task holds its own reference. Callers typically erase the
// connection from the table right after Close() and may drop the last other
// reference before the task runs; capturing a raw |this| would then be a
// use-after-free.
void SocketDevice::Connection::Close(async_dispatcher_t* dispatcher) {
  state_ = CON_ZOMBIE;
  fbl::RefPtr<Connection> self = fbl::RefPtr(this);
  zx_status_t __UNUSED status = async::PostTask(dispatcher, [self] {
    zx_status_t cancel_status = self->wait_handler_.Cancel();
    if (cancel_status == ZX_OK) {
      // The wait was still pending, so drop the reference it was holding.
      self->wait_handler_ref_.reset();
    }
  });
  ZX_DEBUG_ASSERT(status == ZX_OK);
}
// Transmits as much as credit and ring space allow. An outstanding VMO
// transfer goes first; socket data is only sent once the VMO has been fully
// handed to the device.
// Returns:
// - ZX_ERR_SHOULD_WAIT while work remains;
// - ZX_OK when caught up (the socket wait is re-armed);
// - another status on a hard socket error.
zx_status_t SocketDevice::Connection::ContinueTx(bool force_credit_request, TxIoBufferRing& tx,
                                                 async_dispatcher_t* dispatcher) {
  if (pending_vmo_ && DoVmoTx(force_credit_request, tx)) {
    return ZX_ERR_SHOULD_WAIT;
  }
  if (!SocketTxPending()) {
    BeginWait(dispatcher);
    return ZX_OK;
  }
  return DoSocketTx(force_credit_request, tx, dispatcher);
}
// Starts a VMO transfer of [offset, offset + len) of |vmo|.
// Returns ZX_ERR_BAD_STATE while another VMO transfer is outstanding; the
// transfer is only marked pending when the pin succeeds.
zx_status_t SocketDevice::Connection::SetVmo(zx::bti& bti, zx::vmo vmo, uint64_t offset,
                                             uint64_t len, uint64_t bti_contiguity) {
  if (pending_vmo_) {
    return ZX_ERR_BAD_STATE;
  }
  const zx_status_t status = vmo_.Set(bti, std::move(vmo), offset, len, bti_contiguity);
  pending_vmo_ = (status == ZX_OK);
  return status;
}
// Records |new_op| as this connection's single pending control op, merging it
// with one that is already pending.
// Priority is RST > SHUTDOWN > CREDIT_REQUEST; otherwise the newest op wins.
// Overwriting a CREDIT_UPDATE with a CREDIT_REQUEST is fine because the
// request carries an update anyway. REQUEST and RESPONSE never queue over
// themselves or other ops, except RST, which is covered above.
void SocketDevice::Connection::QueueOp(uint16_t new_op) {
  // RW operations don't get queued here
  assert(new_op != VIRTIO_VSOCK_OP_RW);
  if (!has_pending_op_) {
    has_pending_op_ = true;
    pending_op_ = new_op;
    return;
  }
  constexpr uint16_t kStickyOps[] = {VIRTIO_VSOCK_OP_RST, VIRTIO_VSOCK_OP_SHUTDOWN,
                                     VIRTIO_VSOCK_OP_CREDIT_REQUEST};
  for (uint16_t sticky : kStickyOps) {
    if (pending_op_ == sticky || new_op == sticky) {
      pending_op_ = sticky;
      return;
    }
  }
  pending_op_ = new_op;
}
// Whether a control op is waiting to be sent for this connection.
bool SocketDevice::Connection::HasPendingOp() {
  return has_pending_op_;
}

// Clears and returns the pending control op. One must be pending.
uint16_t SocketDevice::Connection::TakePendingOp() {
  assert(HasPendingOp());
  const uint16_t op = pending_op_;
  has_pending_op_ = false;
  return op;
}
// Hash for the connection table: the sum of both ports and the remote cid.
size_t SocketDevice::Connection::GetHash(const ConnectionKey& addr) {
  const auto& a = addr.addr_;
  return a.local_port + a.remote_port + a.remote_cid;
}

// Key (address tuple) identifying this connection.
const SocketDevice::ConnectionKey& SocketDevice::Connection::GetKey() const {
  return key_;
}
// Accounts for |len| more bytes sent to the peer.
// A CREDIT_REQUEST is queued when this send pushes the peer's buffer
// utilization across 40% or 80%, so fresh credit can arrive before the window
// closes.
void SocketDevice::Connection::CountTx(uint32_t len) {
  // With no advertised buffer (buf_alloc_ == 0) utilization is undefined.
  // Dividing by it would crash, so skip the threshold check.
  if (buf_alloc_ != 0) {
    const uint64_t buf_alloc = buf_alloc_;
    // Utilization in percent for |used| bytes in flight.
    // - 64-bit math keeps free * 100 from overflowing for large buffers.
    // - A free amount that would be negative (peer shrank buf_alloc) is
    //   clamped to zero instead of wrapping.
    auto utilization = [buf_alloc](uint64_t used) -> uint64_t {
      const uint64_t peer_free = used >= buf_alloc ? 0 : buf_alloc - used;
      return 100 - (peer_free * 100) / buf_alloc;
    };
    // Free-running counters: the unsigned difference is the in-flight byte count.
    const uint64_t in_flight = static_cast<uint32_t>(tx_count_ - fwd_cnt_);
    const uint64_t prev_util = utilization(in_flight);
    const uint64_t next_util = utilization(in_flight + len);
    if ((prev_util < 40 && next_util >= 40) || (prev_util < 80 && next_util >= 80)) {
      QueueOp(VIRTIO_VSOCK_OP_CREDIT_REQUEST);
    }
  }
  tx_count_ += len;
}
bool SocketDevice::Connection::SocketTxPending() {
zx_info_socket_t info;
zx_status_t status = data_.get_info(ZX_INFO_SOCKET, &info, sizeof(info), nullptr, nullptr);
if (status != ZX_OK) {
return false;
}
return info.rx_buf_size != 0;
}
// Sends the outstanding VMO transfer as indirect packets. Each payload
// descriptor points directly at pinned VMO pages.
// Returns true if work remains (out of peer credit, or tx ring full), and
// false once the whole range has been submitted. The device is kicked once on
// exit if anything was queued.
bool SocketDevice::Connection::DoVmoTx(bool force_credit_request, TxIoBufferRing& tx) {
  bool submitted = false;
  auto kick_on_exit = fbl::MakeAutoCall([&submitted, &tx]() {
    if (submitted) {
      tx.Kick();
    }
  });
  for (;;) {
    if (vmo_.transfer_length_ == 0) {
      return false;
    }
    const uint32_t peer_free = GetPeerFree(force_credit_request);
    if (peer_free == 0) {
      return true;
    }
    uint16_t id;
    if (!tx.AllocIndirect(key_, &id)) {
      return true;
    }
    const uint32_t chunk = static_cast<uint32_t>(vmo_.NextChunkLen(peer_free));
    tx.SetIndirectPayload(id, vmo_.Consume(chunk));
    virtio_vsock_hdr_t hdr = MakeHdr(VIRTIO_VSOCK_OP_RW);
    hdr.len = chunk;
    tx.SetHeader(id, hdr);
    tx.SubmitChain(id, chunk);
    submitted = true;
    CountTx(chunk);
  }
}
// Copies socket data into in-place tx buffers, one packet per read. Each read
// is capped at the frame payload size and at the peer's remaining credit.
// Returns:
// - ZX_ERR_SHOULD_WAIT when credit or ring space runs out;
// - ZX_OK once the socket is drained (the wait is re-armed);
// - the read error otherwise.
// The device is kicked once on exit if anything was queued.
zx_status_t SocketDevice::Connection::DoSocketTx(bool force_credit_request, TxIoBufferRing& tx,
                                                 async_dispatcher_t* dispatcher) {
  bool submitted = false;
  auto kick_on_exit = fbl::MakeAutoCall([&submitted, &tx]() {
    if (submitted) {
      tx.Kick();
    }
  });
  for (;;) {
    const uint32_t peer_free = GetPeerFree(force_credit_request);
    if (peer_free == 0) {
      return ZX_ERR_SHOULD_WAIT;
    }
    uint16_t id;
    void* payload = tx.AllocInPlace(&id);
    if (payload == nullptr) {
      return ZX_ERR_SHOULD_WAIT;
    }
    const uint32_t max_payload = static_cast<uint32_t>(kFrameSize - sizeof(virtio_vsock_hdr_t));
    size_t actual = 0;
    const zx_status_t status = data_.read(0, payload, std::min(max_payload, peer_free), &actual);
    if (status != ZX_OK) {
      tx.FreeChain(id);
      if (status != ZX_ERR_SHOULD_WAIT) {
        return status;
      }
      // The socket is drained. Report ZX_OK so the caller does not think tx is
      // still pending, and wait for more data.
      BeginWait(dispatcher);
      return ZX_OK;
    }
    const uint32_t read = static_cast<uint32_t>(actual);
    virtio_vsock_hdr_t hdr = MakeHdr(VIRTIO_VSOCK_OP_RW);
    hdr.len = read;
    tx.SetHeader(id, hdr);
    tx.SubmitChain(id, read);
    submitted = true;
    CountTx(read);
  }
}
// Arms |wait_handler_| on |disp|.
// The work runs as a task posted to |disp|. The task holds a RefPtr to this
// Connection and acquires |lock_|. If the wait is not already pending:
// - the reference is stored in |wait_handler_ref_| before Begin(), so the
//   Connection outlives the pending wait;
// - the reference is dropped again if Begin() fails. Only ZX_ERR_BAD_STATE is
//   expected there, per the debug assert.
void SocketDevice::Connection::BeginWait(async_dispatcher_t* disp) {
  fbl::RefPtr<Connection> wait_ref = fbl::RefPtr(this);
  zx_status_t __UNUSED status = async::PostTask(disp, [wait_ref, disp] {
    fbl::AutoLock lock(&wait_ref->lock_);
    if (!wait_ref->wait_handler_.is_pending()) {
      wait_ref->wait_handler_ref_ = wait_ref;
      zx_status_t status = wait_ref->wait_handler_.Begin(disp);
      if (status != ZX_OK) {
        ZX_DEBUG_ASSERT(status == ZX_ERR_BAD_STATE);
        wait_ref->wait_handler_ref_.reset();
      }
    }
  });
  ZX_DEBUG_ASSERT(status == ZX_OK);
}
// Bytes we may still send before exhausting the peer's advertised buffer.
// When none remain and |request_credit| is set, a CREDIT_REQUEST is queued.
uint32_t SocketDevice::Connection::GetPeerFree(bool request_credit) {
  // Free-running counters: the unsigned difference is the in-flight byte count.
  const uint32_t in_flight = static_cast<uint32_t>(tx_count_ - fwd_cnt_);
  // The peer may shrink buf_alloc below what is already in flight. Report no
  // credit rather than letting buf_alloc - in_flight wrap to a huge window.
  const uint32_t peer_free = (in_flight >= buf_alloc_) ? 0u : buf_alloc_ - in_flight;
  if (peer_free == 0 && request_credit) {
    QueueOp(VIRTIO_VSOCK_OP_CREDIT_REQUEST);
  }
  return peer_free;
}
// Prepares a transfer of [offset, offset + len) of |vmo| and resets the walk
// cursor. The enclosing |bti_contiguity|-aligned range is pinned for device
// reads with ZX_BTI_COMPRESS, giving one physical address per contiguity-sized
// run.
// Returns:
// - ZX_ERR_OUT_OF_RANGE if the range end (or its rounded-up size) overflows
//   64 bits;
// - ZX_ERR_NO_MEMORY if the address table cannot be allocated;
// - the pin error otherwise.
// On failure neither the VMO nor a pin is retained.
zx_status_t SocketDevice::Connection::VmoWalker::Set(zx::bti& bti, zx::vmo vmo, uint64_t offset,
                                                     uint64_t len, uint64_t bti_contiguity) {
  Release();
  // An overflowing end would shrink the pinned range below what Consume()
  // later walks, so paddrs_ would be indexed out of bounds.
  if (len > UINT64_MAX - offset) {
    return ZX_ERR_OUT_OF_RANGE;
  }
  vmo_ = std::move(vmo);
  contiguity_ = bti_contiguity;
  transfer_offset_ = offset;
  transfer_length_ = len;
  // Construct a base pointer that is aligned to the contiguity
  base_addr_ = fbl::round_down(offset, contiguity_);
  // Determine an extended range to take into account the rounding amount
  const uint64_t span = (offset - base_addr_) + len;
  uint64_t full_range = fbl::round_up(span, contiguity_);
  if (full_range < span) {
    // Rounding up wrapped past 2^64.
    Release();
    return ZX_ERR_OUT_OF_RANGE;
  }
  num_paddr_ = full_range / contiguity_;
  fbl::AllocChecker ac;
  paddrs_ = new (&ac) zx_paddr_t[num_paddr_];
  if (!ac.check()) {
    // Do not keep holding the VMO handle for a transfer that never starts.
    Release();
    return ZX_ERR_NO_MEMORY;
  }
  zx_status_t status = bti.pin(ZX_BTI_PERM_READ | ZX_BTI_COMPRESS, vmo_, base_addr_, full_range,
                               paddrs_, num_paddr_, &pinned_pages_);
  if (status != ZX_OK) {
    Release();
  }
  return status;
}
// Drops the pin, the VMO and the physical-address table, and clears
// final_paddr_.
// NOTE(review): the pin is dropped by resetting the PMT handle, not via unpin.
// Per the zx_bti_pin docs, closing a PMT without zx_pmt_unpin quarantines the
// pages; confirm this is intended.
void SocketDevice::Connection::VmoWalker::Release() {
  pinned_pages_.reset();
  vmo_.reset();
  final_paddr_ = 0;
  // delete[] on nullptr is a no-op.
  delete[] paddrs_;
  paddrs_ = nullptr;
}
uint64_t SocketDevice::Connection::VmoWalker::NextChunkLen(uint64_t max) {
// First constrain max by the remaining transfer
uint64_t next_len = std::min(max, transfer_length_);
// Determine the end of the current contiguity region
uint64_t contiguity_area_end = fbl::round_up(transfer_offset_ + 1, contiguity_);
uint64_t max_in_contiguity = contiguity_area_end - transfer_offset_;
// Take the minimum of our transfer and the contiguity
return std::min(next_len, max_in_contiguity);
}
// Advances the cursor by |len| bytes and returns the physical address where
// that chunk starts. |len| must fit in the current contiguity region. The
// address of the last chunk is remembered in final_paddr_, so completion of
// the whole transfer can be recognized.
zx_paddr_t SocketDevice::Connection::VmoWalker::Consume(uint64_t len) {
  assert(NextChunkLen(len) >= len);
  // base_addr_ is contiguity-aligned, so it drops out of the modulo.
  const uint64_t within_region = transfer_offset_ % contiguity_;
  const uint64_t region_index = (transfer_offset_ - base_addr_) / contiguity_;
  const zx_paddr_t chunk_paddr = paddrs_[region_index] + within_region;
  transfer_offset_ += len;
  transfer_length_ -= len;
  if (transfer_length_ == 0) {
    final_paddr_ = chunk_paddr;
  }
  return chunk_paddr;
}
} // namespace virtio