blob: de9849fa64f6c559390b068ca0862455669b2c60 [file] [log] [blame]
// Copyright 2022 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "hci_wrapper.h"
#include <fuchsia/hardware/bt/hci/cpp/banjo.h>
#include <fuchsia/hardware/bt/vendor/cpp/banjo.h>
#include <lib/async/cpp/task.h>
#include <lib/async/cpp/wait.h>
#include <lib/fit/defer.h>
#include <lib/zx/channel.h>
#include <fbl/macros.h>
#include <fbl/ref_counted.h>
#include <fbl/ref_ptr.h>
namespace bt::hci {
namespace {
// Translates the Banjo vendor feature bitmask reported by the driver into bt-host's
// VendorFeaturesBits. Only the two flags tested below are carried over; any other Banjo bits are
// ignored.
VendorFeaturesBits BanjoVendorFeaturesToVendorFeaturesBits(bt_vendor_features_t features) {
  VendorFeaturesBits translated{0};

  const bool supports_acl_priority =
      (features & BT_VENDOR_FEATURES_SET_ACL_PRIORITY_COMMAND) != 0;
  if (supports_acl_priority) {
    translated |= VendorFeaturesBits::kSetAclPriorityCommand;
  }

  const bool supports_android_extensions =
      (features & BT_VENDOR_FEATURES_ANDROID_VENDOR_EXTENSIONS) != 0;
  if (supports_android_extensions) {
    translated |= VendorFeaturesBits::kAndroidVendorExtensions;
  }

  return translated;
}
} // namespace
// HciWrapper implementation backed by a DeviceWrapper. It owns the command, ACL, and (optional)
// SCO channels obtained from |device_| in Initialize() and receives inbound packets through async
// waits that are started on |dispatcher_|.
class HciWrapperImpl final : public HciWrapper {
public:
HciWrapperImpl(std::unique_ptr<DeviceWrapper> device, async_dispatcher_t* dispatcher);
~HciWrapperImpl() override;
bool Initialize(ErrorCallback error_callback) override;
zx_status_t SendCommand(std::unique_ptr<CommandPacket> packet) override;
void SetEventCallback(EventPacketFunction callback) override;
zx_status_t SendAclPacket(std::unique_ptr<ACLDataPacket> packet) override;
void SetAclCallback(AclPacketFunction callback) override;
zx_status_t SendScoPacket(std::unique_ptr<ScoDataPacket> packet) override;
void SetScoCallback(ScoPacketFunction callback) override;
// True while a valid SCO channel is held: one is obtained (if the device offers it) in
// Initialize() and released by CleanUp().
bool IsScoSupported() override { return sco_channel_.is_valid(); }
void ConfigureSco(ScoCodingFormat coding_format, ScoEncoding encoding, ScoSampleRate sample_rate,
StatusCallback callback) override;
void ResetSco(StatusCallback callback) override;
VendorFeaturesBits GetVendorFeatures() override;
fitx::result<zx_status_t, DynamicByteBuffer> EncodeSetAclPriorityCommand(
hci_spec::ConnectionHandle connection, hci::AclPriority priority) override;
private:
// Used by Banjo callbacks to detect stack destruction & to dispatch callbacks onto the bt-host
// thread.
struct CallbackData : public fbl::RefCounted<CallbackData> {
// Lock to guard reads/writes to the |dispatcher| pointer variable below (not the underlying
// dispatcher). Calls to async::PostTask and async::WaitBase::Begin should be considered reads,
// and require the lock to be held.
std::mutex lock;
// Set to nullptr by CleanUp() (on HciWrapperImpl destruction or after an error) to indicate to
// Banjo callbacks, which may run on an HCI driver thread, that they should do nothing. It is safe
// to access |dispatcher| on a different thread than |HciWrapperImpl::dispatcher_| because
// operations on the underlying dispatcher, including waiting for signals and posting tasks, are
// thread-safe. The only concern is that the callbacks would use the dispatcher after it is
// destroyed and this pointer is invalid, but that is impossible because the dispatcher outlives
// HciWrapper, and HciWrapper sets |dispatcher| to null upon destruction.
async_dispatcher_t* dispatcher __TA_GUARDED(lock);
};
// Tears everything down via CleanUp(), then reports |status| to |error_cb_| if one is set.
void OnError(zx_status_t status);
// Detaches pending Banjo callbacks, cancels all waits, and closes all channels.
void CleanUp();
// Wraps a callback in a callback that posts the callback to the bt-host thread.
StatusCallback ThreadSafeCallbackWrapper(StatusCallback callback);
// Cancels |wait| and restarts it on |channel| (which must be valid) for READABLE/PEER_CLOSED
// signals on |dispatcher_|.
void InitializeWait(async::WaitBase& wait, zx::channel& channel);
// Reads one message from |channel| into |buffer_view| and checks its payload length against
// |size_from_header|. Returns ZX_OK on success, ZX_ERR_IO_DATA_INTEGRITY for a malformed packet
// (the wait is re-armed and the packet should be dropped), or ZX_ERR_IO when the wait or read
// failed (the wait is NOT re-armed).
zx_status_t OnChannelReadable(zx_status_t status, async::WaitBase* wait,
MutableBufferView buffer_view, size_t header_size,
zx::channel& channel, fit::function<uint16_t()> size_from_header);
// Async wait handlers for the ACL, command (HCI event), and SCO channels respectively.
void OnAclSignal(async_dispatcher_t* dispatcher, async::WaitBase* wait, zx_status_t status,
const zx_packet_signal_t* signal);
void OnCommandSignal(async_dispatcher_t* dispatcher, async::WaitBase* wait, zx_status_t status,
const zx_packet_signal_t* signal);
void OnScoSignal(async_dispatcher_t* dispatcher, async::WaitBase* wait, zx_status_t status,
const zx_packet_signal_t* signal);
std::unique_ptr<DeviceWrapper> device_;
// Channels obtained from |device_| in Initialize(); reset by CleanUp().
zx::channel acl_channel_;
zx::channel command_channel_;
zx::channel sco_channel_;
// Inbound packet sinks and the owner's error callback.
EventPacketFunction event_cb_;
AclPacketFunction acl_cb_;
ScoPacketFunction sco_cb_;
ErrorCallback error_cb_;
// NOTE: the waits are declared after the channels, so (members being destroyed in reverse
// declaration order) they are torn down before the channels they watch. Keep this ordering.
async::WaitMethod<HciWrapperImpl, &HciWrapperImpl::OnAclSignal> acl_wait_{this};
async::WaitMethod<HciWrapperImpl, &HciWrapperImpl::OnCommandSignal> command_wait_{this};
async::WaitMethod<HciWrapperImpl, &HciWrapperImpl::OnScoSignal> sco_wait_{this};
// Dispatcher that the waits run on and that Banjo callback results are posted to.
async_dispatcher_t* dispatcher_;
// Shared with the Banjo callbacks created by ThreadSafeCallbackWrapper().
fbl::RefPtr<CallbackData> callback_data_;
};
HciWrapperImpl::HciWrapperImpl(std::unique_ptr<DeviceWrapper> device,
async_dispatcher_t* dispatcher)
: device_(std::move(device)), dispatcher_(dispatcher) {
callback_data_ = fbl::AdoptRef(new CallbackData{.dispatcher = dispatcher_});
}
HciWrapperImpl::~HciWrapperImpl() {
  // Detach in-flight Banjo callbacks, cancel waits, and close all channels.
  CleanUp();
}
// Stores |error_callback| and opens the device channels. Returns false if the mandatory command
// or ACL channel cannot be opened; a missing SCO channel only disables SCO support.
bool HciWrapperImpl::Initialize(ErrorCallback error_callback) {
  error_cb_ = std::move(error_callback);

  command_channel_ = device_->GetCommandChannel();
  if (!command_channel_.is_valid()) {
    bt_log(ERROR, "hci", "Failed to open command channel");
    return false;
  }

  acl_channel_ = device_->GetACLDataChannel();
  if (!acl_channel_.is_valid()) {
    bt_log(ERROR, "hci", "Failed to open ACL channel");
    return false;
  }

  fitx::result<zx_status_t, zx::channel> sco = device_->GetScoChannel();
  if (sco.is_error()) {
    // Failing to open a SCO channel is not fatal, it just indicates lack of SCO support.
    bt_log(INFO, "hci", "Failed to open SCO channel: %s", zx_status_get_string(sco.error_value()));
  } else {
    sco_channel_ = std::move(sco.value());
  }
  return true;
}
// Writes the full serialized command packet as a single message on the command channel and
// returns the channel write status.
zx_status_t HciWrapperImpl::SendCommand(std::unique_ptr<CommandPacket> packet) {
  const auto& view = packet->view();
  return command_channel_.write(/*flags=*/0, view.data().data(), view.size(),
                                /*handles=*/nullptr, /*num_handles=*/0);
}
// Installs the (required) HCI event sink and starts waiting on the command channel.
void HciWrapperImpl::SetEventCallback(EventPacketFunction callback) {
  ZX_ASSERT(callback);
  event_cb_ = std::move(callback);
  // Events are only read once there is somewhere to deliver them.
  InitializeWait(command_wait_, command_channel_);
}
// Writes the full serialized ACL packet as a single message on the ACL channel and returns the
// channel write status.
zx_status_t HciWrapperImpl::SendAclPacket(std::unique_ptr<ACLDataPacket> packet) {
  const auto& view = packet->view();
  return acl_channel_.write(/*flags=*/0, view.data().data(), view.size(),
                            /*handles=*/nullptr, /*num_handles=*/0);
}
// Installs the ACL packet sink. A non-empty callback (re)starts the ACL wait; an empty one stops
// reading from the ACL channel.
void HciWrapperImpl::SetAclCallback(AclPacketFunction callback) {
  acl_cb_ = std::move(callback);
  if (acl_cb_) {
    InitializeWait(acl_wait_, acl_channel_);
  } else {
    acl_wait_.Cancel();
  }
}
// Writes the full serialized SCO packet as a single message on the SCO channel and returns the
// channel write status (an error if no SCO channel is open).
zx_status_t HciWrapperImpl::SendScoPacket(std::unique_ptr<ScoDataPacket> packet) {
  const auto& view = packet->view();
  return sco_channel_.write(/*flags=*/0, view.data().data(), view.size(),
                            /*handles=*/nullptr, /*num_handles=*/0);
}
// Installs the (required) SCO packet sink and starts waiting on the SCO channel. Only valid when
// a SCO channel was opened (see IsScoSupported()).
void HciWrapperImpl::SetScoCallback(ScoPacketFunction callback) {
  ZX_ASSERT(sco_channel_.is_valid());
  ZX_ASSERT(callback);
  sco_cb_ = std::move(callback);
  InitializeWait(sco_wait_, sco_channel_);
}
// Shuts down all channels and waits, then reports |status| to the owner if it registered an
// error callback.
void HciWrapperImpl::OnError(zx_status_t status) {
  CleanUp();
  if (!error_cb_) {
    return;
  }
  error_cb_(status);
}
// Releases everything that can deliver callbacks into this object. It can run more than once
// (OnError() calls it, and the destructor calls it again). The statement order below matters.
void HciWrapperImpl::CleanUp() {
// Null out the shared dispatcher first so that Banjo callbacks wrapped by
// ThreadSafeCallbackWrapper() stop posting results to this object's dispatcher.
{
std::lock_guard<std::mutex> guard(callback_data_->lock);
callback_data_->dispatcher = nullptr;
}
// Waits need to be canceled before the underlying channels are destroyed.
acl_wait_.Cancel();
command_wait_.Cancel();
sco_wait_.Cancel();
acl_channel_.reset();
sco_channel_.reset();
command_channel_.reset();
}
// Returns a callback that may be invoked from any thread (e.g. an HCI driver thread). When run,
// it posts |callback| onto the bt-host dispatcher, or drops it if CleanUp() has already detached
// this HciWrapperImpl.
HciWrapper::StatusCallback HciWrapperImpl::ThreadSafeCallbackWrapper(StatusCallback callback) {
  return [shared = callback_data_, result_cb = std::move(callback)](zx_status_t status) mutable {
    // The lock stays held across PostTask so CleanUp() cannot null the dispatcher mid-post.
    std::lock_guard<std::mutex> guard(shared->lock);
    if (!shared->dispatcher) {
      // HciWrapper has been destroyed (or torn down after an error); drop the result.
      return;
    }
    async::PostTask(shared->dispatcher,
                    [status, result_cb = std::move(result_cb)]() mutable { result_cb(status); });
  };
}
// (Re)starts |wait| on |channel|, which must be valid, so the handler fires on
// ZX_CHANNEL_READABLE or ZX_CHANNEL_PEER_CLOSED. The wait runs on |dispatcher_|. Any wait
// already pending on |wait| is canceled first.
void HciWrapperImpl::InitializeWait(async::WaitBase& wait, zx::channel& channel) {
  ZX_ASSERT(channel.is_valid());
  wait.Cancel();
  wait.set_object(channel.get());
  wait.set_trigger(ZX_CHANNEL_READABLE | ZX_CHANNEL_PEER_CLOSED);
  // Begin() is called outside the assertion macro so starting the wait never depends on how
  // assertions are configured, and the failing status is included in the panic message.
  const zx_status_t begin_status = wait.Begin(dispatcher_);
  ZX_ASSERT_MSG(begin_status == ZX_OK, "failed to begin channel wait: %s",
                zx_status_get_string(begin_status));
}
// Asks the driver to configure SCO with the given parameters. |callback| is wrapped with
// ThreadSafeCallbackWrapper() so the result is delivered on the bt-host dispatcher, or dropped if
// this object has been cleaned up.
void HciWrapperImpl::ConfigureSco(ScoCodingFormat coding_format, ScoEncoding encoding,
                                  ScoSampleRate sample_rate, StatusCallback callback) {
  // Ownership of the heap-allocated callback travels through the Banjo |ctx| cookie. The
  // trampoline reclaims and frees it when the driver reports completion.
  StatusCallback* cookie = std::make_unique<StatusCallback>(
                               ThreadSafeCallbackWrapper(std::move(callback)))
                               .release();
  auto trampoline = [](void* ctx, zx_status_t status) {
    std::unique_ptr<StatusCallback> owned(static_cast<StatusCallback*>(ctx));
    (*owned)(status);
  };
  device_->ConfigureSco(static_cast<uint8_t>(coding_format), static_cast<uint8_t>(encoding),
                        static_cast<uint8_t>(sample_rate), trampoline, cookie);
}
// Asks the driver to reset its SCO configuration. |callback| is wrapped with
// ThreadSafeCallbackWrapper() so the result is delivered on the bt-host dispatcher, or dropped if
// this object has been cleaned up.
void HciWrapperImpl::ResetSco(StatusCallback callback) {
  // Ownership of the heap-allocated callback travels through the Banjo |ctx| cookie. The
  // trampoline reclaims and frees it when the driver reports completion.
  StatusCallback* cookie = std::make_unique<StatusCallback>(
                               ThreadSafeCallbackWrapper(std::move(callback)))
                               .release();
  auto trampoline = [](void* ctx, zx_status_t status) {
    std::unique_ptr<StatusCallback> owned(static_cast<StatusCallback*>(ctx));
    (*owned)(status);
  };
  device_->ResetSco(trampoline, cookie);
}
// Queries the driver's Banjo vendor feature mask and translates it into VendorFeaturesBits.
VendorFeaturesBits HciWrapperImpl::GetVendorFeatures() {
  const bt_vendor_features_t banjo_features = device_->GetVendorFeatures();
  return BanjoVendorFeaturesToVendorFeaturesBits(banjo_features);
}
// Has the driver encode its vendor-specific "set ACL priority" command for |connection|.
// kNormal maps to normal priority and every other value to high priority. The direction is
// "source" only for kSource and "sink" otherwise. Returns ZX_ERR_INTERNAL if encoding fails.
fitx::result<zx_status_t, DynamicByteBuffer> HciWrapperImpl::EncodeSetAclPriorityCommand(
    hci_spec::ConnectionHandle connection, hci::AclPriority priority) {
  const bool is_normal = priority == AclPriority::kNormal;
  const bool is_source = priority == AclPriority::kSource;

  bt_vendor_set_acl_priority_params_t priority_params = {
      .connection_handle = connection,
      .priority = static_cast<bt_vendor_acl_priority_t>(is_normal ? BT_VENDOR_ACL_PRIORITY_NORMAL
                                                                  : BT_VENDOR_ACL_PRIORITY_HIGH),
      .direction = static_cast<bt_vendor_acl_direction_t>(
          is_source ? BT_VENDOR_ACL_DIRECTION_SOURCE : BT_VENDOR_ACL_DIRECTION_SINK)};
  bt_vendor_params_t cmd_params = {.set_acl_priority = priority_params};

  fpromise::result<DynamicByteBuffer> encoded =
      device_->EncodeVendorCommand(BT_VENDOR_COMMAND_SET_ACL_PRIORITY, cmd_params);
  // fpromise::result is tri-state (pending/ok/error); only an explicit error is rejected here.
  if (encoded.is_error()) {
    bt_log(WARN, "hci", "Failed to encode vendor command");
    return fitx::error(ZX_ERR_INTERNAL);
  }
  return fitx::ok(encoded.take_value());
}
// Reads a single message from |channel| into |buffer_view| and validates that the number of
// payload bytes after the |header_size|-byte header equals |size_from_header()|.
//
// Returns:
//  - ZX_OK: a well-formed packet was read. |wait| has been re-armed.
//  - ZX_ERR_IO_DATA_INTEGRITY: the packet is malformed and should be dropped. |wait| has been
//    re-armed.
//  - ZX_ERR_IO: the wait or the read failed. |wait| is NOT re-armed, so the caller should tear
//    the channel down.
zx_status_t HciWrapperImpl::OnChannelReadable(zx_status_t status, async::WaitBase* wait,
                                              MutableBufferView buffer_view, size_t header_size,
                                              zx::channel& channel,
                                              fit::function<uint16_t()> size_from_header) {
  if (status != ZX_OK) {
    bt_log(ERROR, "hci", "channel error: %s", zx_status_get_string(status));
    return ZX_ERR_IO;
  }

  uint32_t bytes_read = 0;
  const zx_status_t read_status =
      channel.read(/*flags=*/0u, buffer_view.mutable_data(), /*handles=*/nullptr,
                   buffer_view.size(), /*num_handles=*/0, &bytes_read,
                   /*actual_handles=*/nullptr);
  if (read_status != ZX_OK) {
    bt_log(DEBUG, "hci", "failed to read RX bytes: %s", zx_status_get_string(read_status));
    // Returning before the re-arm guard below is created stops receiving packets.
    return ZX_ERR_IO;
  }

  // The wait needs to be restarted after every signal, including when a malformed packet is
  // dropped below.
  auto rearm_wait = fit::defer([this, wait] {
    const zx_status_t begin_status = wait->Begin(dispatcher_);
    if (begin_status != ZX_OK) {
      bt_log(ERROR, "hci", "wait error: %s", zx_status_get_string(begin_status));
    }
  });

  if (bytes_read < header_size) {
    bt_log(ERROR, "hci", "malformed packet - expected at least %zu bytes, got %u", header_size,
           bytes_read);
    return ZX_ERR_IO_DATA_INTEGRITY;
  }

  const size_t received_payload_size = bytes_read - header_size;
  const uint16_t header_payload_size = size_from_header();
  if (received_payload_size == header_payload_size) {
    return ZX_OK;
  }
  bt_log(ERROR, "hci",
         "malformed packet - payload size from header (%hu) does not match"
         " received payload size: %zu",
         header_payload_size, received_payload_size);
  return ZX_ERR_IO_DATA_INTEGRITY;
}
// Async wait handler for the ACL channel. It reads one ACL packet and forwards it to |acl_cb_|.
// Malformed packets are dropped. Peer closure, wait/read failures, and allocation failures tear
// the wrapper down via OnError().
void HciWrapperImpl::OnAclSignal(async_dispatcher_t* dispatcher, async::WaitBase* wait,
                                 zx_status_t status, const zx_packet_signal_t* signal) {
  TRACE_DURATION("bluetooth", "HciWrapperImpl::OnAclSignal");
  // Per the async wait handler contract, |signal| is only guaranteed non-null when |status| is
  // ZX_OK (e.g. not when the wait is canceled by dispatcher shutdown). So |status| must be checked
  // before |signal| is dereferenced. This routes to the same OnError(ZX_ERR_IO) that
  // OnChannelReadable() would have produced.
  if (status != ZX_OK) {
    bt_log(ERROR, "hci", "ACL channel wait error: %s", zx_status_get_string(status));
    OnError(ZX_ERR_IO);
    return;
  }
  if (signal->observed & ZX_CHANNEL_PEER_CLOSED) {
    OnError(ZX_ERR_PEER_CLOSED);
    return;
  }
  ZX_ASSERT(signal->observed & ZX_CHANNEL_READABLE);
  // Allocate a buffer for the packet. Since we don't know the size beforehand
  // we allocate the largest possible buffer.
  auto packet = ACLDataPacket::New(slab_allocators::kLargeACLDataPayloadSize);
  if (!packet) {
    bt_log(ERROR, "hci", "failed to allocate ACL data packet");
    // The wait is not re-armed on this path, so simply returning would silently stall ACL
    // reception forever. Treat it as fatal, like the command and SCO handlers do.
    OnError(ZX_ERR_NO_MEMORY);
    return;
  }
  auto size_from_header = [&packet] { return le16toh(packet->view().header().data_total_length); };
  const zx_status_t read_status =
      OnChannelReadable(status, wait, packet->mutable_view()->mutable_data(),
                        sizeof(hci_spec::ACLDataHeader), acl_channel_, size_from_header);
  if (read_status == ZX_ERR_IO_DATA_INTEGRITY) {
    // TODO(fxbug.dev/97362): Handle these types of errors by calling error_cb_.
    bt_log(ERROR, "hci", "Received invalid ACL packet; dropping");
    return;
  }
  if (read_status != ZX_OK) {
    OnError(read_status);
    return;
  }
  packet->InitializeFromBuffer();
  acl_cb_(std::move(packet));
}
// Async wait handler for the command channel. It reads one HCI event packet and forwards it to
// |event_cb_|. Malformed packets are dropped. Peer closure, wait/read failures, and allocation
// failures tear the wrapper down via OnError().
void HciWrapperImpl::OnCommandSignal(async_dispatcher_t* dispatcher, async::WaitBase* wait,
                                     zx_status_t status, const zx_packet_signal_t* signal) {
  TRACE_DURATION("bluetooth", "HciWrapperImpl::OnCommandSignal");
  // Per the async wait handler contract, |signal| is only guaranteed non-null when |status| is
  // ZX_OK (e.g. not when the wait is canceled by dispatcher shutdown). So |status| must be checked
  // before |signal| is dereferenced. This routes to the same OnError(ZX_ERR_IO) that
  // OnChannelReadable() would have produced.
  if (status != ZX_OK) {
    bt_log(ERROR, "hci", "command channel wait error: %s", zx_status_get_string(status));
    OnError(ZX_ERR_IO);
    return;
  }
  if (signal->observed & ZX_CHANNEL_PEER_CLOSED) {
    bt_log(ERROR, "hci", "command channel closed");
    OnError(ZX_ERR_PEER_CLOSED);
    return;
  }
  ZX_ASSERT(signal->observed & ZX_CHANNEL_READABLE);
  // Allocate a buffer for the packet. Since we don't know the size beforehand
  // we allocate the largest possible buffer.
  std::unique_ptr<EventPacket> packet = EventPacket::New(slab_allocators::kLargeControlPayloadSize);
  if (!packet) {
    bt_log(ERROR, "hci", "failed to allocate event packet");
    OnError(ZX_ERR_NO_MEMORY);
    return;
  }
  auto size_from_header = [&packet] { return packet->view().header().parameter_total_size; };
  const zx_status_t read_status =
      OnChannelReadable(status, wait, packet->mutable_view()->mutable_data(),
                        sizeof(hci_spec::EventHeader), command_channel_, size_from_header);
  if (read_status == ZX_ERR_IO_DATA_INTEGRITY) {
    // TODO(fxbug.dev/97362): Handle these types of errors by calling error_cb_.
    bt_log(ERROR, "hci", "Received invalid event packet; dropping");
    return;
  }
  if (read_status != ZX_OK) {
    bt_log(ERROR, "hci", "failed to read event packet");
    OnError(read_status);
    return;
  }
  packet->InitializeFromBuffer();
  event_cb_(std::move(packet));
}
// Async wait handler for the SCO channel. It reads one SCO packet and forwards it to |sco_cb_|.
// Malformed packets are dropped. Peer closure, wait/read failures, and allocation failures tear
// the wrapper down via OnError().
void HciWrapperImpl::OnScoSignal(async_dispatcher_t* dispatcher, async::WaitBase* wait,
                                 zx_status_t status, const zx_packet_signal_t* signal) {
  TRACE_DURATION("bluetooth", "HciWrapperImpl::OnScoSignal");
  // Per the async wait handler contract, |signal| is only guaranteed non-null when |status| is
  // ZX_OK (e.g. not when the wait is canceled by dispatcher shutdown). So |status| must be checked
  // before |signal| is dereferenced. This routes to the same OnError(ZX_ERR_IO) that
  // OnChannelReadable() would have produced.
  if (status != ZX_OK) {
    bt_log(ERROR, "hci", "SCO channel wait error: %s", zx_status_get_string(status));
    OnError(ZX_ERR_IO);
    return;
  }
  if (signal->observed & ZX_CHANNEL_PEER_CLOSED) {
    OnError(ZX_ERR_PEER_CLOSED);
    return;
  }
  ZX_ASSERT(signal->observed & ZX_CHANNEL_READABLE);
  // Allocate a buffer for the packet. Since we don't know the size beforehand
  // we allocate the largest possible buffer.
  std::unique_ptr<ScoDataPacket> packet =
      ScoDataPacket::New(hci_spec::kMaxSynchronousDataPacketPayloadSize);
  if (!packet) {
    bt_log(ERROR, "hci", "failed to allocate SCO packet");
    OnError(ZX_ERR_NO_MEMORY);
    return;
  }
  auto size_from_header = [&packet] { return packet->view().header().data_total_length; };
  const zx_status_t read_status =
      OnChannelReadable(status, wait, packet->mutable_view()->mutable_data(),
                        sizeof(hci_spec::SynchronousDataHeader), sco_channel_, size_from_header);
  if (read_status == ZX_ERR_IO_DATA_INTEGRITY) {
    // TODO(fxbug.dev/97362): Handle these types of errors by calling error_cb_.
    bt_log(ERROR, "hci", "Received invalid SCO packet; dropping");
    return;
  }
  if (read_status != ZX_OK) {
    OnError(read_status);
    return;
  }
  packet->InitializeFromBuffer();
  sco_cb_(std::move(packet));
}
// Factory for the HciWrapper interface. The concrete HciWrapperImpl is defined in this file, and
// callers only see it through the HciWrapper interface.
std::unique_ptr<HciWrapper> HciWrapper::Create(std::unique_ptr<DeviceWrapper> device,
                                               async_dispatcher_t* dispatcher) {
  std::unique_ptr<HciWrapper> wrapper =
      std::make_unique<HciWrapperImpl>(std::move(device), dispatcher);
  return wrapper;
}
} // namespace bt::hci