| // Copyright 2016 The Fuchsia Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "lib/netconnector/cpp/message_relay.h" |
| |
| #include "lib/fxl/logging.h" |
| #include "lib/fsl/tasks/message_loop.h" |
| |
| namespace netconnector { |
| |
| MessageRelayBase::MessageRelayBase() |
| : read_wait_(async_get_default()), |
| write_wait_(async_get_default()) {} |
| |
| MessageRelayBase::~MessageRelayBase() {} |
| |
| void MessageRelayBase::SetChannel(zx::channel channel) { |
| FXL_DCHECK(channel); |
| FXL_DCHECK(!channel_) |
| << "SetChannel called twice without intervening call to CloseChannel"; |
| |
| channel_.swap(channel); |
| |
| read_wait_.set_object(channel_.get()); |
| read_wait_.set_trigger(ZX_CHANNEL_READABLE | ZX_CHANNEL_PEER_CLOSED); |
| read_wait_.set_handler( |
| fbl::BindMember(this, &MessageRelayBase::ReadChannelMessages)); |
| |
| write_wait_.set_object(channel_.get()); |
| write_wait_.set_trigger(ZX_CHANNEL_WRITABLE | ZX_CHANNEL_PEER_CLOSED); |
| write_wait_.set_handler( |
| fbl::BindMember(this, &MessageRelayBase::WriteChannelMessages)); |
| |
| // We defer handling channel messages so that the caller doesn't get callbacks |
| // during SetChannel. |
| |
| read_wait_.Begin(); |
| |
| if (!messages_to_write_.empty()) { |
| write_wait_.Begin(); |
| } |
| } |
| |
| void MessageRelayBase::SendMessage(std::vector<uint8_t> message) { |
| messages_to_write_.push(std::move(message)); |
| |
| if (channel_ && !write_wait_.is_pending()) { |
| async_wait_result_t result = WriteChannelMessages(nullptr, ZX_OK, nullptr); |
| if (result == ASYNC_WAIT_AGAIN) |
| write_wait_.Begin(); |
| } |
| } |
| |
| void MessageRelayBase::CloseChannel() { |
| read_wait_.Cancel(); |
| write_wait_.Cancel(); |
| channel_.reset(); |
| OnChannelClosed(); |
| } |
| |
| async_wait_result_t MessageRelayBase::ReadChannelMessages( |
| async_t* async, zx_status_t status, const zx_packet_signal_t* signal) { |
| while (channel_) { |
| uint32_t actual_byte_count; |
| uint32_t actual_handle_count; |
| zx_status_t status = channel_.read(0, nullptr, 0, &actual_byte_count, |
| nullptr, 0, &actual_handle_count); |
| |
| if (status == ZX_ERR_SHOULD_WAIT) { |
| return ASYNC_WAIT_AGAIN; |
| } |
| |
| if (status == ZX_ERR_PEER_CLOSED) { |
| // Remote end of the channel closed. |
| CloseChannel(); |
| break; |
| } |
| |
| if (status != ZX_ERR_BUFFER_TOO_SMALL) { |
| FXL_LOG(ERROR) << "Failed to read (peek) from channel, status " << status; |
| CloseChannel(); |
| break; |
| } |
| |
| if (actual_handle_count != 0) { |
| FXL_LOG(ERROR) |
| << "Message received over channel has handles, closing connection"; |
| CloseChannel(); |
| break; |
| } |
| |
| std::vector<uint8_t> message(actual_byte_count); |
| status = |
| channel_.read(0, message.data(), message.size(), &actual_byte_count, |
| nullptr, 0, &actual_handle_count); |
| |
| if (status != ZX_OK) { |
| FXL_LOG(ERROR) << "Failed to read from channel, status " << status; |
| CloseChannel(); |
| break; |
| } |
| |
| FXL_DCHECK(actual_byte_count == message.size()); |
| |
| OnMessageReceived(std::move(message)); |
| } |
| |
| return ASYNC_WAIT_FINISHED; |
| } |
| |
| async_wait_result_t MessageRelayBase::WriteChannelMessages( |
| async_t* async, zx_status_t status, const zx_packet_signal_t* signal) { |
| if (!channel_) { |
| return ASYNC_WAIT_FINISHED; |
| } |
| |
| while (!messages_to_write_.empty()) { |
| const std::vector<uint8_t>& message = messages_to_write_.front(); |
| |
| zx_status_t status = |
| channel_.write(0, message.data(), message.size(), nullptr, 0); |
| |
| if (status == ZX_ERR_SHOULD_WAIT) { |
| return ASYNC_WAIT_AGAIN; |
| } |
| |
| if (status == ZX_ERR_PEER_CLOSED) { |
| // Remote end of the channel closed. |
| CloseChannel(); |
| break; |
| } |
| |
| if (status != ZX_OK) { |
| FXL_LOG(ERROR) << "zx::channel::write failed, status " << status; |
| CloseChannel(); |
| break; |
| } |
| |
| messages_to_write_.pop(); |
| } |
| |
| return ASYNC_WAIT_FINISHED; |
| } |
| |
| MessageRelay::MessageRelay() {} |
| |
| MessageRelay::~MessageRelay() {} |
| |
| void MessageRelay::SetMessageReceivedCallback( |
| std::function<void(std::vector<uint8_t>)> callback) { |
| message_received_callback_ = callback; |
| } |
| |
| void MessageRelay::SetChannelClosedCallback(std::function<void()> callback) { |
| channel_closed_callback_ = callback; |
| } |
| |
| void MessageRelay::OnMessageReceived(std::vector<uint8_t> message) { |
| if (message_received_callback_) { |
| message_received_callback_(std::move(message)); |
| } |
| } |
| |
| void MessageRelay::OnChannelClosed() { |
| if (channel_closed_callback_) { |
| channel_closed_callback_(); |
| } |
| } |
| |
} // namespace netconnector