// Copyright 2022 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "dgram_cache.h"

#include <fidl/fuchsia.posix.socket/cpp/wire.h>
#include <lib/fit/result.h>
#include <lib/zx/eventpair.h>
#include <lib/zx/handle.h>
#include <lib/zx/time.h>
#include <zircon/types.h>

#include <optional>
#include <utility>
#include <vector>

#include "hash.h"
#include "socket_address.h"

namespace fnet = fuchsia_net;
namespace fsocket = fuchsia_posix_socket;

using fuchsia_posix_socket::wire::CmsgRequests;
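
// Extracts the set of requested control messages from a RecvMsgPostflight
// response. Timestamp reporting defaults to disabled when the response omits
// the timestamp option.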
RequestedCmsgSet::RequestedCmsgSet(
const fsocket::wire::DatagramSocketRecvMsgPostflightResponse& response) {
if (response.has_requests()) {
requests_ = response.requests();
}
if (response.has_timestamp()) {
so_timestamp_filter_ = response.timestamp();
} else {
so_timestamp_filter_ = fsocket::wire::TimestampOption::kDisabled;
}
}

std::optional<fsocket::wire::TimestampOption> RequestedCmsgSet::so_timestamp() const {
  return so_timestamp_filter_;
}

bool RequestedCmsgSet::ip_tos() const {
  return static_cast<bool>(requests_ & CmsgRequests::kIpTos);
}

bool RequestedCmsgSet::ip_ttl() const {
  return static_cast<bool>(requests_ & CmsgRequests::kIpTtl);
}

bool RequestedCmsgSet::ip_recvorigdstaddr() const {
  return static_cast<bool>(requests_ & CmsgRequests::kIpRecvorigdstaddr);
}

bool RequestedCmsgSet::ipv6_tclass() const {
  return static_cast<bool>(requests_ & CmsgRequests::kIpv6Tclass);
}

bool RequestedCmsgSet::ipv6_hoplimit() const {
  return static_cast<bool>(requests_ & CmsgRequests::kIpv6Hoplimit);
}

bool RequestedCmsgSet::ipv6_pktinfo() const {
  return static_cast<bool>(requests_ & CmsgRequests::kIpv6Pktinfo);
}
using RequestedCmsgResult = fit::result<ErrOrOutCode, std::optional<RequestedCmsgSet>>;
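
// Returns the set of control messages requested on this socket, refreshing the
// cached copy via RecvMsgPostflight when the cache is empty or its validity
// eventpair has signaled ZX_EVENTPAIR_PEER_CLOSED. Returns std::nullopt when
// `get_requested_cmsg_set` is false. If `err_wait_item` reports a pending
// asynchronous error, that error is surfaced instead.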
RequestedCmsgResult RequestedCmsgCache::Get(zx_wait_item_t err_wait_item,
bool get_requested_cmsg_set,
fidl::WireSyncClient<fsocket::DatagramSocket>& client) {
// TODO(https://fxbug.dev/42054723): Circumvent fast-path pessimization caused by lock
// contention between multiple fast paths.
std::lock_guard lock(lock_);
constexpr size_t MAX_WAIT_ITEMS = 2;
zx_wait_item_t wait_items[MAX_WAIT_ITEMS];
constexpr uint32_t ERR_WAIT_ITEM_IDX = 0;
wait_items[ERR_WAIT_ITEM_IDX] = err_wait_item;
std::optional<size_t> cmsg_idx;
while (true) {
uint32_t num_wait_items = ERR_WAIT_ITEM_IDX + 1;
if (get_requested_cmsg_set && cache_.has_value()) {
wait_items[num_wait_items] = {
.handle = cache_.value().validity.get(),
.waitfor = ZX_EVENTPAIR_PEER_CLOSED,
};
cmsg_idx = num_wait_items;
num_wait_items++;
}
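    // Poll (the deadline is in the past) to learn whether an asynchronous
    // error is pending or the cached cmsg set has been invalidated.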
zx_status_t status =
zx::handle::wait_many(wait_items, num_wait_items, zx::time::infinite_past());
switch (status) {
case ZX_OK: {
const zx_wait_item_t& err_wait_item_ref = wait_items[ERR_WAIT_ITEM_IDX];
if (err_wait_item_ref.pending & err_wait_item_ref.waitfor) {
std::optional err = GetErrorWithClient(client);
if (err.has_value()) {
return fit::error(err.value());
}
continue;
}
ZX_ASSERT_MSG(cmsg_idx.has_value(), "wait_many({{.pending = %d, .waitfor = %d}}) == ZX_OK",
err_wait_item_ref.pending, err_wait_item_ref.waitfor);
const zx_wait_item_t& cmsg_wait_item_ref = wait_items[cmsg_idx.value()];
ZX_ASSERT_MSG(cmsg_wait_item_ref.pending & cmsg_wait_item_ref.waitfor,
"wait_many({{.pending = %d, .waitfor = %d}, {.pending = %d, .waitfor = "
"%d}}) == ZX_OK",
err_wait_item_ref.pending, err_wait_item_ref.waitfor,
cmsg_wait_item_ref.pending, cmsg_wait_item_ref.waitfor);
} break;
case ZX_ERR_TIMED_OUT: {
if (!get_requested_cmsg_set) {
return fit::ok(std::nullopt);
}
if (cache_.has_value()) {
return fit::ok(cache_.value().requested_cmsg_set);
}
} break;
default:
ErrOrOutCode err = zx::error(status);
return fit::error(err);
}
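    // The cache is empty or stale; fetch a fresh requested-cmsg set together
    // with its validity eventpair.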
const fidl::WireResult response = client->RecvMsgPostflight();
if (!response.ok()) {
ErrOrOutCode err = zx::error(response.status());
return fit::error(err);
}
const auto& result = response.value();
if (result.is_error()) {
return fit::error(zx::ok(static_cast<int16_t>(result.error_value())));
}
fsocket::wire::DatagramSocketRecvMsgPostflightResponse& response_inner = *result.value();
if (!response_inner.has_validity()) {
return fit::error(zx::ok(static_cast<int16_t>(EIO)));
}
cache_ = Value{
.validity = std::move(response_inner.validity()),
.requested_cmsg_set = RequestedCmsgSet(response_inner),
};
}
}
// TODO(https://fxbug.dev/42159831): remove this custom implementation when FIDL
// wire types support deep equality.
bool RouteCache::Key::operator==(const RouteCache::Key& o) const {
if (remote_addr != o.remote_addr) {
return false;
}
if (local_iface_and_addr.has_value() != o.local_iface_and_addr.has_value()) {
return false;
}
if (!local_iface_and_addr.has_value()) {
return true;
}
const auto& [iface, addr] = local_iface_and_addr.value();
const auto& [other_iface, other_addr] = o.local_iface_and_addr.value();
if (iface != other_iface) {
return false;
}
return addr.addr == other_addr.addr;
}
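
// Combines the hash of the remote address with the local interface and
// address, when present.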
size_t RouteCache::KeyHasher::operator()(const Key& k) const {
size_t h = k.remote_addr.hash();
if (k.local_iface_and_addr.has_value()) {
const auto& [iface, addr] = k.local_iface_and_addr.value();
hash_combine(h, iface);
for (const auto& addr_bits : addr.addr) {
hash_combine(h, addr_bits);
}
}
return h;
}
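
// Pushes `k` onto the front of the LRU list and points `lru` at the new entry.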
void RouteCache::LruAddToFront(const Key& k, std::list<Key>::iterator& lru) {
lru_.push_front(k);
lru = lru_.begin();
}
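
// Moves an existing entry to the front of the LRU list unless it is already
// the most recently used.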
void RouteCache::LruMoveToFront(const Key& k, std::list<Key>::iterator& lru) {
if (lru == lru_.begin()) {
return;
}
lru_.erase(lru);
LruAddToFront(k, lru);
}
using RouteCacheResult = fit::result<ErrOrOutCode, uint32_t>;
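
// Returns the maximum send size (as reported by SendMsgPreflight) for the
// route identified by `remote_addr` (or the connected address when
// `remote_addr` is empty) and the optional local interface/address. A cached
// entry is used as long as none of its validity eventpairs has signaled
// ZX_EVENTPAIR_PEER_CLOSED; otherwise the route is recomputed via
// SendMsgPreflight and stored, evicting the least-recently-used entry once the
// cache holds kMaxEntries. When `remote_addr` is empty, the address returned
// by the server is remembered as the connected address. Pending asynchronous
// errors reported through `err_wait_item` are surfaced instead.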
RouteCacheResult RouteCache::Get(
std::optional<SocketAddress>& remote_addr,
const std::optional<std::pair<uint64_t, fuchsia_net::wire::Ipv6Address>>& local_iface_and_addr,
const zx_wait_item_t& err_wait_item, fidl::WireSyncClient<fsocket::DatagramSocket>& client) {
// TODO(https://fxbug.dev/42054723): Circumvent fast-path pessimization caused by lock
// contention 1) between multiple fast paths and 2) between fast path and slow path.
std::lock_guard lock(lock_);
zx_wait_item_t wait_items[ZX_WAIT_MANY_MAX_ITEMS];
constexpr uint32_t ERR_WAIT_ITEM_IDX = 0;
wait_items[ERR_WAIT_ITEM_IDX] = err_wait_item;
while (true) {
std::optional<Key> cache_key;
std::optional<std::reference_wrapper<Value>> cache_value;
uint32_t num_wait_items = ERR_WAIT_ITEM_IDX + 1;
const std::optional<SocketAddress>& addr_to_lookup =
remote_addr.has_value() ? remote_addr : connected_;
// NOTE: `addr_to_lookup` might not have a value if we're looking up the
// connected addr for the first time. We still proceed with the syscall
// to check for errors in that case (since the socket might have been
// connected by another process).
if (addr_to_lookup.has_value()) {
const Key& key = cache_key.emplace(Key{
.remote_addr = addr_to_lookup.value(),
.local_iface_and_addr = local_iface_and_addr,
});
if (auto it = cache_.find(key); it != cache_.end()) {
const Value& value = cache_value.emplace(it->second);
ZX_ASSERT_MSG(value.eventpairs.size() + 1 <= ZX_WAIT_MANY_MAX_ITEMS,
"number of wait_items (%lu) exceeds maximum allowed (%zu)",
value.eventpairs.size() + 1, ZX_WAIT_MANY_MAX_ITEMS);
for (const zx::eventpair& eventpair : value.eventpairs) {
wait_items[num_wait_items] = {
.handle = eventpair.get(),
.waitfor = ZX_EVENTPAIR_PEER_CLOSED,
};
num_wait_items++;
}
}
}
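    // Poll (the deadline is in the past) to learn whether an asynchronous
    // error is pending or the cached route has been invalidated.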
zx_status_t status =
zx::handle::wait_many(wait_items, num_wait_items, zx::time::infinite_past());
switch (status) {
case ZX_OK: {
if (wait_items[ERR_WAIT_ITEM_IDX].pending & wait_items[ERR_WAIT_ITEM_IDX].waitfor) {
std::optional err = GetErrorWithClient(client);
if (err.has_value()) {
return fit::error(err.value());
}
continue;
}
} break;
case ZX_ERR_TIMED_OUT: {
if (cache_value.has_value()) {
ZX_ASSERT_MSG(cache_key.has_value(),
"cache_key was not set even though we retrieved an entry from the cache");
// Mark this entry in the cache as the most recently-used.
Value& value = cache_value.value();
LruMoveToFront(cache_key.value(), value.lru);
return fit::success(value.maximum_size);
}
} break;
default:
ErrOrOutCode err = zx::error(status);
return fit::error(err);
}
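    // Cache miss or stale entry: recompute the route via SendMsgPreflight.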
constexpr size_t kSendMsgPreflightRequestArenaSize =
fidl::MaxSizeInChannel<fsocket::wire::DatagramSocketSendMsgPreflightRequest,
fidl::MessageDirection::kSending>();
// Set a sensible upper limit for how much stack space we're going to allow
// using here to prevent deep stack usage in zxio/fdio. If this grows to
// untenable sizes we might have to change strategies here.
static_assert(kSendMsgPreflightRequestArenaSize <= 128);
fidl::Arena<kSendMsgPreflightRequestArenaSize> alloc;
const fidl::WireResult response = [&client, &alloc, &remote_addr, &local_iface_and_addr]() {
fidl::WireTableBuilder request_builder =
fsocket::wire::DatagramSocketSendMsgPreflightRequest::Builder(alloc);
if (remote_addr.has_value()) {
remote_addr.value().WithFIDL(
[&request_builder](fnet::wire::SocketAddress address) { request_builder.to(address); });
}
if (local_iface_and_addr.has_value()) {
const auto& [iface, addr] = local_iface_and_addr.value();
request_builder.ipv6_pktinfo(fsocket::wire::Ipv6PktInfoSendControlData{
.iface = iface,
.local_addr = addr,
});
}
return client->SendMsgPreflight(request_builder.Build());
}();
if (!response.ok()) {
ErrOrOutCode err = zx::error(response.status());
return fit::error(err);
}
const auto& result = response.value();
if (result.is_error()) {
return fit::error(zx::ok(static_cast<int16_t>(result.error_value())));
}
fsocket::wire::DatagramSocketSendMsgPreflightResponse& res = *result.value();
std::optional<SocketAddress> returned_addr;
if (!remote_addr.has_value()) {
if (res.has_to()) {
returned_addr = SocketAddress::FromFidl(res.to());
}
}
const std::optional<SocketAddress>& addr_to_store =
remote_addr.has_value() ? remote_addr : returned_addr;
if (!addr_to_store.has_value()) {
return fit::error(zx::ok(static_cast<int16_t>(EIO)));
}
if (!res.has_maximum_size() || !res.has_validity()) {
return fit::error(zx::ok(static_cast<int16_t>(EIO)));
}
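    // Take ownership of the validity eventpairs; the cached route stays valid
    // until any of them signals ZX_EVENTPAIR_PEER_CLOSED.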
std::vector<zx::eventpair> eventpairs;
eventpairs.reserve(res.validity().count());
std::move(res.validity().begin(), res.validity().end(), std::back_inserter(eventpairs));
// Remove least-recently-used element if cache is at capacity.
if (cache_.size() == kMaxEntries) {
const Key& k = lru_.back();
size_t removed = cache_.erase(k);
ZX_ASSERT_MSG(removed == 1,
"tried to remove least-recently-used item from route cache; removed %zu items",
removed);
lru_.pop_back();
}
const Key key = {
.remote_addr = addr_to_store.value(),
.local_iface_and_addr = local_iface_and_addr,
};
// Add the entry to the cache and set its eventpairs and maximum size to
// those returned by SendMsgPreflight. If the entry is newly added to the
// cache, add it to the front of the LRU list; if it already existed in the
// cache, move it to the front.
auto [it, inserted] = cache_.try_emplace(key);
Value& value = it->second;
value.eventpairs = std::move(eventpairs);
value.maximum_size = res.maximum_size();
if (inserted) {
LruAddToFront(key, value.lru);
} else {
LruMoveToFront(key, value.lru);
}
if (!remote_addr.has_value()) {
connected_ = addr_to_store.value();
}
}
}
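
// Queries the server for a pending asynchronous socket error. Returns the
// error code (or the status of a failed FIDL call) if there is one, and
// std::nullopt otherwise.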
std::optional<ErrOrOutCode> GetErrorWithClient(
fidl::WireSyncClient<fuchsia_posix_socket::DatagramSocket>& client) {
const fidl::WireResult response = client->GetError();
if (!response.ok()) {
return zx::error(response.status());
}
const auto& result = response.value();
if (result.is_error()) {
return zx::ok(static_cast<int16_t>(result.error_value()));
}
return std::nullopt;
}