blob: 8801621721dbc1e1599ac7decae370c5431a2fc1 [file] [log] [blame]
// Copyright 2018 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "src/connectivity/bluetooth/core/bt-host/public/pw_bluetooth_sapphire/internal/host/l2cap/dynamic_channel_registry.h"
#include <gtest/gtest.h>
namespace bt::l2cap::internal {
namespace {
// Capacity handed to the test registry: how many dynamic channel IDs exist.
constexpr uint16_t kNumChannelsAllowed{256};
// PSM requested by the outbound channels opened in these tests.
constexpr uint16_t kPsm{0x0001};
// Remote channel ID the tests assign when faking a connection.
constexpr ChannelId kRemoteCId{0x60a3};
// Default-constructed parameters passed to every OpenOutbound call.
constexpr ChannelParameters kChannelParams;
// Test double for DynamicChannel. Its connection and open state are not
// driven by L2CAP signaling; the test drives them by hand through the Do*
// methods below.
class FakeDynamicChannel final : public DynamicChannel {
 public:
  FakeDynamicChannel(DynamicChannelRegistry* registry,
                     Psm psm,
                     ChannelId local_cid,
                     ChannelId remote_cid)
      : DynamicChannel(registry, psm, local_cid, remote_cid) {}

  // DynamicChannel overrides
  bool IsConnected() const override { return connected_; }
  bool IsOpen() const override { return open_; }
  ChannelInfo info() const override {
    return ChannelInfo::MakeBasicMode(kDefaultMTU, kDefaultMTU);
  }

  // Records |remote_cid| as this channel's remote ID and marks the channel
  // connected. Fails the current test if SetRemoteChannelId rejects the ID.
  void DoConnect(ChannelId remote_cid) {
    ASSERT_TRUE(SetRemoteChannelId(remote_cid))
        << "Could not set non-unique remote_cid " << remote_cid;
    connected_ = true;
  }

  // Completes the pending Open() call. The channel is left open if |new_open|
  // is true and unopened otherwise, and then the result callback that the
  // registry passed to Open() is run.
  //
  // Running that callback may lead the registry to remove and destroy this
  // channel (e.g. after a failed open). The callback is therefore moved into a
  // local before it is invoked, so the executing closure is not destroyed
  // together with |this|. No members are touched after the call.
  void DoOpen(bool new_open = true) {
    ASSERT_TRUE(open_result_cb_) << "DoOpen() called before Open()";
    open_ = new_open;
    if (new_open) {
      set_opened();
    }
    fit::closure open_result_cb = std::move(open_result_cb_);
    open_result_cb();
  }

  // Simulates a peer-initiated close: marks the channel closed and
  // disconnected, then calls OnDisconnected().
  void DoRemoteClose() {
    open_ = false;
    connected_ = false;
    OnDisconnected();
  }

  // After calling |set_defer_disconnect_done_callback|, this returns the
  // callback passed to |Disconnect|, or an empty callback if |Disconnect|
  // hasn't been called. Running the callback may cause the registry to
  // destroy this channel, so move it out of here before invoking it.
  DisconnectDoneCallback& disconnect_done_callback() {
    return disconnect_done_callback_;
  }

  // Makes subsequent Disconnect() calls hold on to their completion callback
  // instead of invoking it immediately.
  void set_defer_disconnect_done_callback() {
    defer_disconnect_done_callback_ = true;
  }

 private:
  // DynamicChannel overrides

  // Stores the result callback; the test triggers it later through DoOpen().
  void Open(fit::closure open_result_cb) override {
    open_result_cb_ = std::move(open_result_cb);
  }

  // Marks the channel closed and disconnected. Then it either invokes
  // |done_cb| immediately or, if deferral was requested, stores it for the
  // test to run. |done_cb| is a by-value parameter, so it outlives this
  // channel even if invoking it destroys |this|.
  void Disconnect(DisconnectDoneCallback done_cb) override {
    open_ = false;
    connected_ = false;
    bt_log(DEBUG,
           "l2cap",
           "Got Disconnect %#.4x callback: %d",
           local_cid(),
           psm());
    ASSERT_FALSE(disconnect_done_callback_);
    if (defer_disconnect_done_callback_) {
      disconnect_done_callback_ = std::move(done_cb);
    } else {
      done_cb();
    }
  }

  fit::closure open_result_cb_;
  DisconnectDoneCallback disconnect_done_callback_;
  // If true, the Disconnect call does not immediately signal its callback. The
  // test will have to call it explicitly via |disconnect_done_callback()|.
  bool defer_disconnect_done_callback_ = false;
  bool connected_ = false;
  bool open_ = false;
};
// Fake registry subclass for testing inherited logic. Stubs out |MakeOutbound|
// to vend FakeDynamicChannels.
class TestDynamicChannelRegistry final : public DynamicChannelRegistry {
public:
TestDynamicChannelRegistry(DynamicChannelCallback close_cb,
ServiceRequestCallback service_request_cb)
: DynamicChannelRegistry(kNumChannelsAllowed,
std::move(close_cb),
std::move(service_request_cb),
/*random_channel_ids=*/true) {}
// Returns previous channel created.
FakeDynamicChannel* last_channel() { return last_channel_; }
// Make public for testing.
using DynamicChannelRegistry::AliveChannelCount;
using DynamicChannelRegistry::FindAvailableChannelId;
using DynamicChannelRegistry::FindChannelByLocalId;
using DynamicChannelRegistry::FindChannelByRemoteId;
using DynamicChannelRegistry::ForEach;
using DynamicChannelRegistry::RequestService;
private:
// DynamicChannelRegistry overrides
DynamicChannelPtr MakeOutbound(Psm psm,
ChannelId local_cid,
ChannelParameters params) override {
return MakeChannelInternal(psm, local_cid, kInvalidChannelId);
}
DynamicChannelPtr MakeInbound(Psm psm,
ChannelId local_cid,
ChannelId remote_cid,
ChannelParameters params) override {
auto channel = MakeChannelInternal(psm, local_cid, remote_cid);
channel->DoConnect(remote_cid);
return channel;
}
std::unique_ptr<FakeDynamicChannel> MakeChannelInternal(
Psm psm, ChannelId local_cid, ChannelId remote_cid) {
auto channel =
std::make_unique<FakeDynamicChannel>(this, psm, local_cid, remote_cid);
last_channel_ = channel.get();
return channel;
}
FakeDynamicChannel* last_channel_ = nullptr;
};
// DynamicChannelCallback that ignores whatever channel it receives.
void DoNothing(const DynamicChannel* /*channel*/) {
  // Intentionally empty.
}
// ServiceRequestCallback that declines every inbound service request,
// whatever PSM it is for.
std::optional<DynamicChannelRegistry::ServiceInfo> RejectAllServices(Psm) {
  // An empty optional means "no service registered for this PSM".
  return {};
}
// Opens an outbound channel and then has the peer close it. The registry must
// report the closure through its close callback and forget the channel.
TEST(DynamicChannelRegistryTest, OpenAndRemoteCloseChannel) {
  ChannelId opened_local_id = kInvalidChannelId;
  ChannelId opened_remote_id = kInvalidChannelId;

  bool closed = false;
  TestDynamicChannelRegistry registry(
      [&](const DynamicChannel* chan) {
        EXPECT_FALSE(closed);
        closed = true;
        EXPECT_TRUE(chan);
        EXPECT_FALSE(chan->IsConnected());
        EXPECT_FALSE(chan->IsOpen());
        EXPECT_EQ(opened_local_id, chan->local_cid());
        EXPECT_EQ(opened_remote_id, chan->remote_cid());
      },
      RejectAllServices);
  EXPECT_NE(kInvalidChannelId, registry.FindAvailableChannelId());

  bool opened = false;
  registry.OpenOutbound(kPsm, kChannelParams, [&](const DynamicChannel* chan) {
    EXPECT_FALSE(opened);
    opened = true;
    EXPECT_TRUE(chan);
    EXPECT_EQ(kPsm, chan->psm());
    opened_local_id = chan->local_cid();
    opened_remote_id = chan->remote_cid();
  });

  FakeDynamicChannel* const fake = registry.last_channel();
  fake->DoConnect(kRemoteCId);
  fake->DoOpen();
  EXPECT_TRUE(opened);
  EXPECT_FALSE(closed);

  // Both lookups must resolve to the same live channel.
  auto by_local = registry.FindChannelByLocalId(opened_local_id);
  auto by_remote = registry.FindChannelByRemoteId(opened_remote_id);
  EXPECT_TRUE(by_local);
  EXPECT_TRUE(by_remote);
  EXPECT_EQ(by_local, by_remote);

  fake->DoRemoteClose();
  EXPECT_TRUE(closed);
  EXPECT_FALSE(registry.FindChannelByLocalId(opened_local_id));
  EXPECT_FALSE(registry.FindChannelByRemoteId(opened_remote_id));
}
// Opens an outbound channel and closes it locally. The caller's completion
// callback runs, the registry-wide close callback is not invoked, and the
// channel is forgotten.
TEST(DynamicChannelRegistryTest, OpenAndLocalCloseChannel) {
  bool registry_closed = false;
  TestDynamicChannelRegistry registry(
      [&](const DynamicChannel*) { registry_closed = true; },
      RejectAllServices);

  bool opened = false;
  ChannelId local_id = kInvalidChannelId;
  registry.OpenOutbound(kPsm, kChannelParams, [&](const DynamicChannel* chan) {
    EXPECT_FALSE(opened);
    opened = true;
    EXPECT_TRUE(chan);
    local_id = chan->local_cid();
  });

  FakeDynamicChannel* const fake = registry.last_channel();
  fake->DoConnect(kRemoteCId);
  fake->DoOpen();
  EXPECT_TRUE(opened);
  EXPECT_TRUE(registry.FindChannelByLocalId(local_id));

  bool done = false;
  registry.CloseChannel(local_id, [&] { done = true; });
  EXPECT_FALSE(registry_closed);
  EXPECT_TRUE(done);
  EXPECT_FALSE(registry.FindChannelByLocalId(local_id));
}
// When the service callback declines an inbound request, no channel is
// created.
TEST(DynamicChannelRegistryTest, RejectServiceRequest) {
  bool asked = false;
  TestDynamicChannelRegistry registry(
      DoNothing,
      [&asked](Psm requested)
          -> std::optional<DynamicChannelRegistry::ServiceInfo> {
        EXPECT_FALSE(asked);
        EXPECT_EQ(kPsm, requested);
        asked = true;
        return std::nullopt;
      });

  const ChannelId local_id = registry.FindAvailableChannelId();
  registry.RequestService(kPsm, local_id, kRemoteCId);
  EXPECT_TRUE(asked);
  EXPECT_FALSE(registry.last_channel());
}
// Accepting an inbound request yields a channel. Once that channel opens, it
// is delivered to the service's open callback, and it can then be closed
// locally.
TEST(DynamicChannelRegistryTest, AcceptServiceRequestThenOpenOk) {
  bool delivered = false;
  ChannelId local_id = kInvalidChannelId;
  ChannelId remote_id = kInvalidChannelId;
  DynamicChannelRegistry::DynamicChannelCallback on_open =
      [&](const DynamicChannel* chan) {
        EXPECT_FALSE(delivered);
        delivered = true;
        EXPECT_TRUE(chan);
        EXPECT_EQ(kPsm, chan->psm());
        local_id = chan->local_cid();
        remote_id = chan->remote_cid();
      };

  bool asked = false;
  auto accept_service = [&asked,
                         on_open = std::move(on_open)](Psm requested) mutable {
    EXPECT_FALSE(asked);
    EXPECT_EQ(kPsm, requested);
    asked = true;
    return DynamicChannelRegistry::ServiceInfo{ChannelParameters(),
                                               on_open.share()};
  };
  TestDynamicChannelRegistry registry(DoNothing, std::move(accept_service));

  registry.RequestService(kPsm, registry.FindAvailableChannelId(), kRemoteCId);
  EXPECT_TRUE(asked);
  FakeDynamicChannel* const inbound = registry.last_channel();
  ASSERT_TRUE(inbound);

  inbound->DoOpen();
  EXPECT_TRUE(delivered);
  EXPECT_NE(kInvalidChannelId, local_id);
  EXPECT_NE(kInvalidChannelId, remote_id);
  EXPECT_TRUE(registry.FindChannelByLocalId(local_id));

  bool done = false;
  registry.CloseChannel(local_id, [&] { done = true; });
  EXPECT_TRUE(done);
  EXPECT_FALSE(registry.FindChannelByLocalId(local_id));
}
// An accepted inbound channel that fails to open is never delivered to the
// service, and the registry releases it.
TEST(DynamicChannelRegistryTest, AcceptServiceRequestThenOpenFails) {
  bool delivered = false;
  DynamicChannelRegistry::DynamicChannelCallback on_open =
      [&delivered](const DynamicChannel*) { delivered = true; };

  bool asked = false;
  auto accept_service = [&asked,
                         on_open = std::move(on_open)](Psm requested) mutable {
    EXPECT_FALSE(asked);
    EXPECT_EQ(kPsm, requested);
    asked = true;
    return DynamicChannelRegistry::ServiceInfo{ChannelParameters(),
                                               on_open.share()};
  };
  TestDynamicChannelRegistry registry(DoNothing, std::move(accept_service));

  const ChannelId local_id = registry.FindAvailableChannelId();
  EXPECT_NE(kInvalidChannelId, local_id);
  registry.RequestService(kPsm, local_id, kRemoteCId);
  EXPECT_TRUE(asked);
  FakeDynamicChannel* const inbound = registry.last_channel();
  ASSERT_TRUE(inbound);

  inbound->DoOpen(/*new_open=*/false);

  // A channel that failed to open is not handed to the service...
  EXPECT_FALSE(delivered);
  EXPECT_FALSE(registry.FindChannelByLocalId(local_id));
  // ...and the registry no longer keeps it alive.
  EXPECT_EQ(0u, registry.AliveChannelCount());
}
// Destroys a registry that still owns an open channel. The teardown must not
// be reported as a channel closure through the registry's close callback.
TEST(DynamicChannelRegistryTest,
     DestroyRegistryWithOpenChannelNoDisconnectionRequest) {
  // Declared outside the registry's scope so the close callback can safely
  // record an invocation even while the registry is being destroyed.
  bool close_cb_called = false;
  {
    bool open_result_cb_called = false;
    auto close_cb = [&close_cb_called](const DynamicChannel* chan) {
      close_cb_called = true;
    };
    TestDynamicChannelRegistry registry(std::move(close_cb),
                                        RejectAllServices);

    auto open_result_cb = [&open_result_cb_called](const DynamicChannel* chan) {
      EXPECT_FALSE(open_result_cb_called);
      open_result_cb_called = true;
      EXPECT_TRUE(chan);
    };
    registry.OpenOutbound(kPsm, kChannelParams, std::move(open_result_cb));
    registry.last_channel()->DoConnect(kRemoteCId);
    registry.last_channel()->DoOpen();
    EXPECT_TRUE(open_result_cb_called);
    EXPECT_TRUE(registry.FindChannelByRemoteId(kRemoteCId));
    EXPECT_FALSE(close_cb_called);
    // |registry| is destroyed here while its channel is still open.
  }
  // The close callback must not have fired during registry teardown either.
  EXPECT_FALSE(close_cb_called);
}
// An outbound channel that fails to open is reported to the requester as
// nullptr. It is not reported as closed, and the registry releases it.
TEST(DynamicChannelRegistryTest, ErrorConnectingChannel) {
  bool open_reported = false;
  bool close_reported = false;
  TestDynamicChannelRegistry registry(
      [&close_reported](auto) { close_reported = true; }, RejectAllServices);

  registry.OpenOutbound(kPsm,
                        kChannelParams,
                        [&open_reported](const DynamicChannel* chan) {
                          EXPECT_FALSE(open_reported);
                          open_reported = true;
                          EXPECT_FALSE(chan);
                        });
  registry.last_channel()->DoOpen(/*new_open=*/false);

  EXPECT_TRUE(open_reported);
  EXPECT_FALSE(close_reported);
  // The failed channel must not linger in the registry.
  EXPECT_EQ(0u, registry.AliveChannelCount());
}
// Uses up every channel ID and checks that a further open is rejected. Then
// checks that a peer closing one channel frees an ID for the next open.
TEST(DynamicChannelRegistryTest, ExhaustedChannelIds) {
  int opens = 0;
  // Expects the registry to deliver a valid channel.
  DynamicChannelRegistry::DynamicChannelCallback expect_success =
      [&opens](const DynamicChannel* chan) {
        ASSERT_NE(nullptr, chan);
        EXPECT_NE(kInvalidChannelId, chan->local_cid());
        opens++;
      };

  int peer_closes = 0;
  TestDynamicChannelRegistry registry([&peer_closes](auto) { peer_closes++; },
                                      RejectAllServices);

  // Fill every available channel ID.
  for (int offset = 0; offset < kNumChannelsAllowed; ++offset) {
    registry.OpenOutbound(
        kPsm + offset, kChannelParams, expect_success.share());
    FakeDynamicChannel* const fake = registry.last_channel();
    fake->DoConnect(kRemoteCId + offset);
    fake->DoOpen();
  }
  EXPECT_EQ(kNumChannelsAllowed, opens);
  EXPECT_EQ(0, peer_closes);
  EXPECT_EQ(kInvalidChannelId, registry.FindAvailableChannelId());

  // With no IDs left, the next open must be answered with nullptr.
  registry.OpenOutbound(
      kPsm, kChannelParams, [&opens](const DynamicChannel* chan) {
        EXPECT_FALSE(chan);
        opens++;
      });
  EXPECT_EQ(kNumChannelsAllowed + 1, opens);
  EXPECT_EQ(0, peer_closes);

  // The peer closes the most recently opened channel, freeing its ID.
  FakeDynamicChannel* const newest = registry.last_channel();
  const auto freed_remote_id = newest->remote_cid();
  newest->DoRemoteClose();
  EXPECT_EQ(1, peer_closes);
  EXPECT_NE(kInvalidChannelId, registry.FindAvailableChannelId());

  // Opening succeeds again now that an ID is free.
  registry.OpenOutbound(kPsm, kChannelParams, expect_success.share());
  FakeDynamicChannel* const reopened = registry.last_channel();
  reopened->DoConnect(freed_remote_id);
  reopened->DoOpen();
  EXPECT_EQ(kNumChannelsAllowed + 2, opens);
  EXPECT_EQ(1, peer_closes);

  // Close every channel locally and confirm the reopened one is gone.
  const ChannelId reopened_local_id = reopened->local_cid();
  int local_closes = 0;
  for (int offset = 0; offset < kNumChannelsAllowed; ++offset) {
    registry.CloseChannel(kFirstDynamicChannelId + offset,
                          [&] { local_closes++; });
  }
  EXPECT_EQ(kNumChannelsAllowed, local_closes);
  EXPECT_FALSE(registry.FindChannelByLocalId(reopened_local_id));
}
// While a local disconnection is still in progress, the channel keeps its
// local ID reserved. The ID becomes available again only once the
// disconnection completes.
TEST(DynamicChannelRegistryTest,
     ChannelIdNotReusedUntilDisconnectionCompletes) {
  TestDynamicChannelRegistry registry(DoNothing, RejectAllServices);

  // This callback expects the channel to be creatable.
  int open_result_cb_count = 0;
  DynamicChannelRegistry::DynamicChannelCallback success_open_result_cb =
      [&](const DynamicChannel* chan) {
        ASSERT_NE(nullptr, chan);
        EXPECT_NE(kInvalidChannelId, chan->local_cid());
        open_result_cb_count++;
      };

  // Open all but one of the available channels.
  for (int i = 0; i < kNumChannelsAllowed - 1; i++) {
    registry.OpenOutbound(
        kPsm + i, kChannelParams, success_open_result_cb.share());
    registry.last_channel()->DoConnect(kRemoteCId + i);
    registry.last_channel()->DoOpen();
  }
  EXPECT_EQ(kNumChannelsAllowed - 1, open_result_cb_count);

  // This callback records the local ID of the channel that was created.
  ChannelId last_local_cid = kInvalidChannelId;
  auto record_open_result_cb = [&](const DynamicChannel* chan) {
    ASSERT_TRUE(chan);
    last_local_cid = chan->local_cid();
  };

  // Ensure that channel IDs are not exhausted yet.
  EXPECT_NE(kInvalidChannelId, registry.FindAvailableChannelId());

  // Open the last available channel.
  registry.OpenOutbound(kPsm, kChannelParams, std::move(record_open_result_cb));
  registry.last_channel()->DoConnect(kRemoteCId + kNumChannelsAllowed - 1);
  registry.last_channel()->DoOpen();
  EXPECT_NE(kInvalidChannelId, last_local_cid);
  ASSERT_TRUE(registry.FindChannelByLocalId(last_local_cid));

  // The channel IDs are exhausted now.
  ASSERT_EQ(kInvalidChannelId, registry.FindAvailableChannelId());

  // Close the channel but don't let the disconnection complete.
  FakeDynamicChannel* const last_channel = registry.last_channel();
  last_channel->set_defer_disconnect_done_callback();
  int close_cb_called = 0;
  registry.CloseChannel(last_local_cid, [&] { close_cb_called++; });
  EXPECT_EQ(0, close_cb_called);

  // New channels must not reuse the "mostly disconnected" channel's ID, so
  // there should still be no channel IDs left.
  EXPECT_EQ(kInvalidChannelId, registry.FindAvailableChannelId());
  ASSERT_TRUE(registry.FindChannelByLocalId(last_local_cid));
  EXPECT_FALSE(registry.FindChannelByLocalId(last_local_cid)->IsConnected());

  // Complete the disconnection of the last channel opened. Completing it
  // removes |last_channel| from the registry, which destroys the channel and
  // the callback it stores. The callback is therefore moved out before it is
  // run, so it is not destroyed mid-execution. |last_channel| must not be used
  // after this point.
  ASSERT_TRUE(last_channel->disconnect_done_callback());
  auto disconnect_done_cb = std::move(last_channel->disconnect_done_callback());
  disconnect_done_cb();
  EXPECT_EQ(1, close_cb_called);
  EXPECT_EQ(last_local_cid, registry.FindAvailableChannelId());

  // Open a new channel and make sure that the last ID can be reused now.
  bool open_result_cb_called = false;
  auto open_result_cb = [&](const DynamicChannel* chan) {
    EXPECT_FALSE(open_result_cb_called);
    open_result_cb_called = true;
    ASSERT_TRUE(chan);
    EXPECT_EQ(last_local_cid, chan->local_cid());
  };
  registry.OpenOutbound(kPsm, kChannelParams, std::move(open_result_cb));
  registry.last_channel()->DoConnect(kRemoteCId + kNumChannelsAllowed - 1);
  registry.last_channel()->DoOpen();
  EXPECT_TRUE(open_result_cb_called);

  close_cb_called = 0;
  for (int i = 0; i < kNumChannelsAllowed; i++) {
    registry.CloseChannel(kFirstDynamicChannelId + i,
                          [&] { close_cb_called++; });
  }
  EXPECT_EQ(kNumChannelsAllowed, close_cb_called);
  EXPECT_FALSE(registry.FindChannelByLocalId(last_local_cid));
}
// Removing a channel from the channel map while iterating the channels in
// ForEach should not cause a use-after-free of the invalidated pointer.
TEST(DynamicChannelRegistryTest, CloseChannelInForEachCallback) {
  bool registry_close_cb_called = false;
  auto registry_close_cb = [&](const DynamicChannel*) {
    registry_close_cb_called = true;
  };
  TestDynamicChannelRegistry registry(std::move(registry_close_cb),
                                      RejectAllServices);

  bool open_result_cb_called = false;
  ChannelId local_cid = kInvalidChannelId;
  auto open_result_cb = [&](const DynamicChannel* chan) {
    EXPECT_FALSE(open_result_cb_called);
    open_result_cb_called = true;
    EXPECT_TRUE(chan);
    local_cid = chan->local_cid();
  };
  registry.OpenOutbound(kPsm, kChannelParams, std::move(open_result_cb));
  registry.last_channel()->DoConnect(kRemoteCId);
  registry.last_channel()->DoOpen();
  EXPECT_TRUE(open_result_cb_called);
  EXPECT_TRUE(registry.FindChannelByLocalId(local_cid));

  // Even if the next iterator is "end", it would still be unsafe to advance the
  // erased iterator.
  int for_each_count = 0;
  int close_done_count = 0;
  registry.ForEach([&](DynamicChannel* chan) {
    for_each_count++;
    registry.CloseChannel(chan->local_cid(), [&] { close_done_count++; });
  });

  // The single channel was visited exactly once and its close completed.
  EXPECT_EQ(1, for_each_count);
  EXPECT_EQ(1, close_done_count);
  // A local close is not reported through the registry-wide close callback,
  // matching OpenAndLocalCloseChannel.
  EXPECT_FALSE(registry_close_cb_called);
  EXPECT_FALSE(registry.FindChannelByLocalId(local_cid));
}
} // namespace
} // namespace bt::l2cap::internal