blob: 5989b76e3e62b8a1136bddcdc6a95398950d4d43 [file] [log] [blame]
// Copyright 2018 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "src/connectivity/overnet/overnetstack/bound_channel.h"
#include <lib/zx/socket.h>
#include <zircon/assert.h>
#include "src/connectivity/overnet/lib/protocol/coding.h"
#include "src/connectivity/overnet/overnetstack/fuchsia_port.h"
#include "src/connectivity/overnet/overnetstack/overnet_app.h"
namespace overnetstack {
// Constructs a bidirectional proxy between an overnet stream and a zircon
// channel. Takes ownership of both |ns| and |channel|; |app| must outlive
// this object (it is stored unowned).
BoundChannel::BoundChannel(OvernetApp* app,
                           overnet::RouterEndpoint::NewStream ns,
                           zx::channel channel)
    : app_(app),
      overnet_stream_(std::move(ns)),
      zx_channel_(std::move(channel)) {
  assert(zx_channel_.is_valid());
  // Kick off the two read loops: one from the network and the other from the zx
  // channel. Each proceeds much the same: as data is read, it's written and
  // then the next read is begun.
  StartNetRead();
  StartChannelRead();
}
// Tears down both directions of the proxy. Idempotent: only the first call
// has any effect. Drops the zircon channel immediately, aborts any pending
// network receive, then closes the overnet stream asynchronously.
void BoundChannel::Close(const overnet::Status& status) {
  OVERNET_TRACE(TRACE) << "CLOSE: " << status << " closed=" << closed_;
  if (closed_) {
    return;
  }
  closed_ = true;
  zx_channel_.reset();
  if (net_recv_) {
    net_recv_->Close(status);
  }
  // NOTE(review): the Unref in the completion callback presumably balances a
  // ref held for the object's lifetime — confirm against the class header.
  overnet_stream_.Close(status, [this] { Unref(); });
}
// Writes a decoded FIDL message to the zircon channel and, on success,
// begins the next network read. On kernel push-back (ZX_ERR_SHOULD_WAIT)
// the payload is stashed in pending_chan_bytes_/pending_chan_handles_ and
// an async wait is armed; OnSendReady retries once the channel is writable.
// Any other write error closes the proxy.
void BoundChannel::WriteToChannelAndStartNextRead(fidl::Message message) {
  OVERNET_TRACE(TRACE) << "WriteToChannelAndStartNextRead hdl="
                       << zx_channel_.get();
  auto err = message.Write(zx_channel_.get(), 0);
  switch (err) {
    case ZX_OK:
      // Any retry buffers from a previous push-back were consumed by this
      // successful write.
      pending_chan_bytes_.clear();
      pending_chan_handles_.clear();
      StartNetRead();
      break;
    case ZX_ERR_SHOULD_WAIT:
      // Kernel push back: capture the slice, and ask to be informed when we
      // can write.
      std::vector<uint8_t>(message.bytes().begin(), message.bytes().end())
          .swap(pending_chan_bytes_);
      std::vector<zx::handle>(message.handles().begin(),
                              message.handles().end())
          .swap(pending_chan_handles_);
      // Handle ownership has moved into pending_chan_handles_.
      message.ClearHandlesUnsafe();
      // Hold a ref for the duration of the wait; OnSendReady releases it.
      Ref();
      err = async_begin_wait(dispatcher_, &wait_send_.wait);
      if (err != ZX_OK) {
        Close(overnet::Status::FromZx(err).WithContext(
            "Beginning wait for write"));
        // The wait will never fire, so drop the ref taken for it (this
        // mirrors the error handling in StartChannelRead; previously the
        // ref leaked here).
        Unref();
      }
      break;
    default:
      // If the write failed, close the stream.
      Close(overnet::Status::FromZx(err).WithContext("Write"));
  }
}
// Attempts one read from the zircon channel. On success the message is
// encoded for the wire and forwarded to the network via proxy_.Message
// (which, when the send completes, schedules the next channel read). On
// kernel push-back an async wait is armed and OnRecvReady retries. Any
// other error closes the proxy.
void BoundChannel::StartChannelRead() {
  OVERNET_TRACE(TRACE) << "StartChannelRead hdl=" << zx_channel_.get();
  // Heap-allocate the maximum-size read buffers; a channel message can be up
  // to ZX_CHANNEL_MAX_MSG_BYTES/ZX_CHANNEL_MAX_MSG_HANDLES, too big for the
  // stack.
  struct Buffer {
    uint8_t bytes[ZX_CHANNEL_MAX_MSG_BYTES];
    zx_handle_t handles[ZX_CHANNEL_MAX_MSG_HANDLES];
  };
  auto buffer = std::make_unique<Buffer>();
  fidl::Message message(
      fidl::BytePart(buffer->bytes, ZX_CHANNEL_MAX_MSG_BYTES),
      fidl::HandlePart(buffer->handles, ZX_CHANNEL_MAX_MSG_HANDLES));
  auto err = message.Read(zx_channel_.get(), 0);
  OVERNET_TRACE(DEBUG) << "StartChannelRead read result: "
                       << overnet::Status::FromZx(err);
  switch (err) {
    case ZX_OK: {
      auto encoded = EncodeMessage(std::move(message));
      if (encoded.is_error()) {
        Close(encoded.AsStatus());
        return;
      }
      proxy_.Message(std::move(*encoded));
    } break;
    case ZX_ERR_SHOULD_WAIT:
      // Kernel push back: ask to be informed when we can try again.
      // Hold a ref for the pending wait; released in OnRecvReady, or below
      // if arming the wait fails.
      Ref();
      err = async_begin_wait(dispatcher_, &wait_recv_.wait);
      OVERNET_TRACE(DEBUG) << "async_begin_wait result: "
                           << overnet::Status::FromZx(err);
      if (err != ZX_OK) {
        Close(overnet::Status::FromZx(err).WithContext(
            "Beginning wait for read"));
        Unref();
        break;
      }
      break;
    default:
      // If the read failed, close the stream.
      Close(overnet::Status::FromZx(err).WithContext("Read"));
      break;
  }
}
// Sends an already-serialized, handle-free FIDL message over the overnet
// stream. When the push completes, the next channel read is started so that
// channel reads are throttled by network sends.
void BoundChannel::Proxy::Send_(fidl::Message message) {
  // Handles must have been stripped/encoded before reaching this point.
  assert(message.handles().size() == 0);
  // NOTE(review): the result of overnet::Encode is dereferenced without an
  // is_error() check — presumably encoding a well-formed slice cannot fail,
  // but confirm against the coding library.
  auto send_slice =
      *overnet::Encode(overnet::Slice::FromContainer(message.bytes()));
  // Hold a ref until the send completes; released in the completion lambda.
  channel_->Ref();
  overnet::RouterEndpoint::Stream::SendOp(&channel_->overnet_stream_,
                                          send_slice.length())
      .Push(std::move(send_slice), [this] {
        channel_->StartChannelRead();
        channel_->Unref();
      });
}
// Static async-wait callback for the write path: recovers the owning
// BoundWait from the embedded async_wait_t and resumes the pending write.
void BoundChannel::SendReady(async_dispatcher_t* dispatcher, async_wait_t* wait,
                             zx_status_t status,
                             const zx_packet_signal_t* signal) {
  // Trampoline back into writing.
  static_assert(offsetof(BoundWait, wait) == 0,
                "The wait must be the first member of BoundWait for this "
                "cast to be valid.");
  auto* bound = reinterpret_cast<BoundWait*>(wait);
  bound->stream->OnSendReady(status, signal);
}
// Resumes a write that was previously blocked by kernel push-back: rebuilds
// the fidl::Message from the stashed bytes/handles and retries the write.
// The trailing Unref balances the Ref taken when the wait was armed.
void BoundChannel::OnSendReady(zx_status_t status,
                               const zx_packet_signal_t* signal) {
  OVERNET_TRACE(TRACE) << "OnSendReady: status="
                       << overnet::Status::FromZx(status);
  if (status != ZX_OK) {
    // The wait itself failed (e.g. dispatcher shutdown): close instead of
    // attempting the write, mirroring the handling in OnRecvReady.
    Close(overnet::Status::FromZx(status).WithContext("OnSendReady"));
    Unref();
    return;
  }
  // Release ownership of the stashed handles into a raw-handle vector that
  // the rebuilt message takes ownership of.
  std::vector<zx_handle_t> handles;
  for (auto& h : pending_chan_handles_) {
    handles.push_back(h.release());
  }
  WriteToChannelAndStartNextRead(fidl::Message(
      fidl::BytePart(pending_chan_bytes_.data(), pending_chan_bytes_.size(),
                     pending_chan_bytes_.size()),
      fidl::HandlePart(handles.data(), handles.size(), handles.size())));
  Unref();
}
// Static async-wait callback for the read path: recovers the owning
// BoundWait from the embedded async_wait_t and resumes reading.
void BoundChannel::RecvReady(async_dispatcher_t* dispatcher, async_wait_t* wait,
                             zx_status_t status,
                             const zx_packet_signal_t* signal) {
  // Trampoline back into reading.
  static_assert(offsetof(BoundWait, wait) == 0,
                "The wait must be the first member of BoundWait for this "
                "cast to be valid.");
  auto* bound = reinterpret_cast<BoundWait*>(wait);
  bound->stream->OnRecvReady(status, signal);
}
// Resumes reading from the zircon channel after a blocked read's wait
// fires. Every path releases the ref taken when the wait was armed.
void BoundChannel::OnRecvReady(zx_status_t status,
                               const zx_packet_signal_t* signal) {
  OVERNET_TRACE(TRACE) << "OnRecvReady: status="
                       << overnet::Status::FromZx(status)
                       << " observed=" << signal->observed;
  if (status != ZX_OK) {
    // The wait failed (e.g. dispatcher shutdown): tear down the proxy.
    Close(overnet::Status::FromZx(status).WithContext("OnRecvReady"));
    Unref();
    return;
  }
  if (signal->observed & ZX_CHANNEL_READABLE) {
    StartChannelRead();
    Unref();
    return;
  }
  // Note: we flush all reads before propagating the close.
  ZX_DEBUG_ASSERT(signal->observed & ZX_CHANNEL_PEER_CLOSED);
  Close(overnet::Status::Ok());
  Unref();
}
// Pulls the next message from the overnet stream, decodes it, and hands it
// to stub_.Process_ (which ultimately writes it to the zircon channel and
// schedules the next network read). A ref is held for the duration of the
// async pull and released on every exit path of the callback.
void BoundChannel::StartNetRead() {
  OVERNET_TRACE(DEBUG) << "StartNetRead";
  net_recv_.Reset(&overnet_stream_);
  Ref();
  net_recv_->PullAll(
      [this](overnet::StatusOr<overnet::Optional<std::vector<overnet::Slice>>>
                 status) {
        OVERNET_TRACE(DEBUG) << "StartNetRead got " << status;
        if (status.is_error() || !status->has_value()) {
          // If a read failed, close the stream.
          // (An empty Optional signals end-of-stream; in that case
          // status.AsStatus() is Ok and this is a clean close.)
          Close(status.AsStatus());
          Unref();
          return;
        }
        if (closed_) {
          Unref();
          return;
        }
        // Write the message to the channel, then start reading again.
        // Importantly: don't start reading until we've written to ensure
        // that there's push back in the system.
        auto decode_status = overnet::Decode(
            overnet::Slice::Join((*status)->begin(), (*status)->end()));
        if (decode_status.is_error()) {
          // Failed to decode: close stream
          Close(decode_status.AsStatus());
          Unref();
          return;
        }
        auto packet = overnet::Slice::Aligned(std::move(*decode_status));
        if (auto process_status = stub_.Process_(fidl::Message(
                fidl::BytePart(const_cast<uint8_t*>(packet.begin()),
                               packet.length(), packet.length()),
                fidl::HandlePart()));
            process_status != ZX_OK) {
          Close(overnet::Status::FromZx(process_status)
                    .WithContext("Processing incoming channel message"));
          Unref();
          return;
        }
        Unref();
      });
}
// Handles an incoming wire message: reconstructs local handles via
// DecodeMessageThen and writes the resulting FIDL message to the zircon
// channel. Decode failures tear down the proxy.
void BoundChannel::Stub::Message(
    fuchsia::overnet::protocol::ZirconChannelMessage message) {
  auto write_and_continue = [this](fidl::Message fidl_msg) {
    channel_->WriteToChannelAndStartNextRead(std::move(fidl_msg));
    return overnet::Status::Ok();
  };
  auto status =
      channel_->DecodeMessageThen(&message, std::move(write_and_continue));
  if (status.is_error()) {
    channel_->Close(status);
  }
}
// Converts a locally-read fidl::Message into a wire ZirconChannelMessage.
// Each zircon handle in the message is replaced by a forked overnet stream:
// the handle itself is bound (and proxied) via app_, and only the fork's
// stream id travels on the wire. Consumes all of the message's handles even
// on error so none leak.
overnet::StatusOr<fuchsia::overnet::protocol::ZirconChannelMessage>
BoundChannel::EncodeMessage(fidl::Message message) {
  if (!message.has_header()) {
    return overnet::Status(overnet::StatusCode::FAILED_PRECONDITION,
                           "FIDL message without a header");
  }
  assert(message.flags() == 0);
  fuchsia::overnet::protocol::ZirconChannelMessage msg;
  std::vector<uint8_t>(message.bytes().begin(), message.bytes().end())
      .swap(msg.bytes);
  // Keep track of failure.
  // We cannot leave the loop below early or we risk leaking handles.
  overnet::Status status = overnet::Status::Ok();
  for (auto handle : message.handles()) {
    // Wrap the raw handle so it is closed even if we're in a failed state.
    zx::handle hdl(handle);
    if (status.is_error()) {
      continue;
    }
    zx_info_handle_basic_t info;
    auto err = hdl.get_info(ZX_INFO_HANDLE_BASIC, &info, sizeof(info), nullptr,
                            nullptr);
    if (err != ZX_OK) {
      status = overnet::Status::FromZx(err).WithContext("Getting handle type");
      continue;
    }
    fuchsia::overnet::protocol::ZirconHandle zh;
    switch (info.type) {
      case ZX_OBJ_TYPE_CHANNEL: {
        // Fork a new stream to carry this channel's traffic.
        auto fork_status = overnet_stream_.InitiateFork(
            fuchsia::overnet::protocol::ReliabilityAndOrdering::
                ReliableOrdered);
        if (fork_status.is_error()) {
          status = fork_status.AsStatus();
          continue;
        }
        zh.set_channel(fuchsia::overnet::protocol::ChannelHandle{
            fork_status->stream_id().as_fidl()});
        auto channel = zx::channel(hdl.release());
        assert(channel.is_valid());
        assert(!hdl.is_valid());
        app_->BindChannel(std::move(*fork_status), std::move(channel));
        assert(!channel.is_valid());
      } break;
      case ZX_OBJ_TYPE_SOCKET: {
        // Socket options must travel on the wire so the peer can recreate
        // an equivalent socket pair.
        zx_info_socket_t socket_info;
        err = hdl.get_info(ZX_INFO_SOCKET, &socket_info, sizeof(socket_info),
                           nullptr, nullptr);
        if (err != ZX_OK) {
          status =
              overnet::Status::FromZx(err).WithContext("Getting socket info");
          continue;
        }
        auto fork_status = overnet_stream_.InitiateFork(
            // TODO(ctiller): unreliable for udp!
            fuchsia::overnet::protocol::ReliabilityAndOrdering::
                ReliableOrdered);
        if (fork_status.is_error()) {
          status = fork_status.AsStatus();
          continue;
        }
        zh.set_socket(fuchsia::overnet::protocol::SocketHandle{
            fork_status->stream_id().as_fidl(), socket_info.options});
        auto socket = zx::socket(hdl.release());
        assert(socket.is_valid());
        assert(!hdl.is_valid());
        app_->BindSocket(std::move(*fork_status), std::move(socket));
        assert(!socket.is_valid());
      } break;
      default:
        status = overnet::Status(overnet::StatusCode::INVALID_ARGUMENT,
                                 "Bad handle type");
        continue;
    }
    assert(status.is_ok());
    msg.handles.emplace_back(std::move(zh));
  }
  // All handles are now owned (and closed) by the loop above.
  message.ClearHandlesUnsafe();
  if (status.is_error()) {
    return status;
  }
  return msg;
}
// Reconstitutes a wire ZirconChannelMessage into a local fidl::Message and
// hands it to |then|. For each encoded handle a fresh kernel object pair is
// created: one end is bound to the received forked stream (so traffic keeps
// being proxied) and the other end is placed into the rebuilt message.
// Returns an error (without invoking |then|) if any handle cannot be
// reconstructed or the message is malformed.
overnet::Status BoundChannel::DecodeMessageThen(
    fuchsia::overnet::protocol::ZirconChannelMessage* message,
    fit::function<overnet::Status(fidl::Message)> then) {
  std::vector<zx::handle> handle_objects;
  for (const auto& h : message->handles) {
    switch (h.Which()) {
      case fuchsia::overnet::protocol::ZirconHandle::Tag::kChannel: {
        zx::channel a, b;
        if (auto err = zx::channel::create(0, &a, &b); err != ZX_OK) {
          return overnet::Status::FromZx(err).WithContext("zx_channel_create");
        }
        auto fork_status = overnet_stream_.ReceiveFork(
            h.channel().stream_id, fuchsia::overnet::protocol::
                                       ReliabilityAndOrdering::ReliableOrdered);
        if (fork_status.is_error()) {
          return fork_status.AsStatus();
        }
        app_->BindChannel(std::move(*fork_status), std::move(a));
        handle_objects.push_back(std::move(b));
      } break;
      case fuchsia::overnet::protocol::ZirconHandle::Tag::kSocket: {
        zx::socket a, b;
        if (auto err = zx::socket::create(h.socket().options, &a, &b);
            err != ZX_OK) {
          return overnet::Status::FromZx(err).WithContext("zx_socket_create");
        }
        auto fork_status = overnet_stream_.ReceiveFork(
            // Fixed: previously read h.channel().stream_id here, accessing
            // the wrong union member when the handle is a socket.
            h.socket().stream_id,
            // TODO(ctiller): unreliable for udp
            fuchsia::overnet::protocol::ReliabilityAndOrdering::
                ReliableOrdered);
        if (fork_status.is_error()) {
          return fork_status.AsStatus();
        }
        app_->BindSocket(std::move(*fork_status), std::move(a));
        handle_objects.push_back(std::move(b));
      } break;
      case fuchsia::overnet::protocol::ZirconHandle::Tag::Empty: {
        return overnet::Status(overnet::StatusCode::INVALID_ARGUMENT,
                               "Bad handle type");
      } break;
    }
  }
  if (handle_objects.size() > ZX_CHANNEL_MAX_MSG_HANDLES) {
    return overnet::Status(overnet::StatusCode::INVALID_ARGUMENT,
                           "Too many handles");
  }
  // fidl::Message borrows byte and handle storage; release the raw handles
  // into a vector the message takes ownership of.
  std::vector<zx_handle_t> handles;
  for (auto& h : handle_objects) {
    handles.push_back(h.release());
  }
  fidl::Message out(
      fidl::BytePart(message->bytes.data(), message->bytes.size(),
                     message->bytes.size()),
      fidl::HandlePart(handles.data(), handles.size(), handles.size()));
  if (!out.has_header()) {
    return overnet::Status(overnet::StatusCode::INVALID_ARGUMENT,
                           "Message has no header");
  }
  return then(std::move(out));
}
} // namespace overnetstack