blob: 6a1a08cb932f65a9f93519cf27b3cb6b9dbf48af [file] [log] [blame]
// Copyright 2019 The Fuchsia Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "fidl/type_shape.h"
#include <algorithm>
#include <safemath/clamped_math.h>
#include "fidl/flat_ast.h"
#include "fidl/recursion_detector.h"
namespace {
// TODO(fxbug.dev/7680): We may want to fail instead of saturating.
// Numeric type used for the sizes, alignments, depths and counts computed in this file. It is a
// clamped (saturating) uint32_t: arithmetic that would leave the uint32_t range sticks at the
// nearest bound instead of wrapping around.
using DataSize = safemath::ClampedNumeric<uint32_t>;
// Returns how many padding bytes must follow |offset| (in bytes) so that the next byte lands on a
// multiple of |alignment|. The bit trick below is only correct when |alignment| is a power of two.
DataSize Padding(const DataSize offset, const DataSize alignment) {
  // Equivalent to "(alignment - (offset % alignment)) % alignment" for power-of-two alignments.
  // See <https://en.wikipedia.org/wiki/Data_structure_alignment#Computing_padding> and
  // <https://stackoverflow.com/a/32104582>: the two's-complement negation of |offset|, masked to
  // the bits below |alignment|, is exactly the distance to the next boundary.
  const uint32_t low_bits_mask = alignment.RawValue() - 1;
  const uint32_t negated_offset = ~offset.RawValue() + 1;
  return negated_offset & low_bits_mask;
}
// Rounds |size| (in bytes) up to the next multiple of |alignment|, which must be a power of two.
// The arithmetic is done in 64 bits; the result is then converted to the clamped DataSize type.
DataSize AlignTo(uint32_t size, uint64_t alignment) {
  // From <https://en.wikipedia.org/wiki/Data_structure_alignment#Computing_padding>.
  // ~(alignment - 1) has the same bit pattern as -alignment: it clears every bit below the
  // alignment boundary.
  const uint64_t bumped = static_cast<uint64_t>(size) + (alignment - 1);
  return bumped & ~(alignment - 1);
}
// Rounds |size| up to the 8-byte boundary used for out-of-line FIDL objects.
DataSize ObjectAlign(uint32_t size) {
  constexpr uint64_t kOutOfLineAlignment = 8;
  return AlignTo(size, kOutOfLineAlignment);
}
} // namespace
namespace std {
// Add an explicit (full) specialization of std::numeric_limits for DataSize so that
// std::numeric_limits<DataSize>::max() yields the largest uint32_t. Without it, max() would
// otherwise return 0 (see
// <https://stackoverflow.com/questions/35575276/why-does-stdnumeric-limitssecondsmax-return-0> if
// you're curious about why.)
// Only max() is provided; the visitors return it as the saturated result for unbounded shapes
// (e.g. recursive structs or strings without a maximum size).
template <>
struct numeric_limits<DataSize> {
static constexpr DataSize max() noexcept { return DataSize(numeric_limits<uint32_t>::max()); }
};
static_assert(numeric_limits<DataSize>::max() == numeric_limits<uint32_t>::max());
} // namespace std
namespace {
namespace flat = fidl::flat;
namespace types = fidl::types;
using WireFormat = fidl::WireFormat;
// Bytes added to the inline size of request/response structs, except under
// WireFormat::kV1Header.
constexpr uint32_t kSizeOfTransactionHeader = 16;
// Alignment, in bytes, assigned to every request/response struct.
constexpr uint32_t kAlignmentOfTransactionHeader = 8;
// Inline size and alignment, in bytes, of handles, request handles, and references to protocols
// and services.
constexpr uint32_t kHandleSize = 4;
// Forward declarations of the per-property entry points. The visitors defined below call into
// one another through these (e.g. AlignmentVisitor uses UnalignedSize()), so they must be
// declared before any visitor. Each pointer overload dereferences its argument and forwards to
// the reference overload.
DataSize UnalignedSize(const flat::Object& object, const WireFormat wire_format);
DataSize UnalignedSize(const flat::Object* object, const WireFormat wire_format);
[[maybe_unused]] DataSize Alignment(const flat::Object& object, const WireFormat wire_format);
[[maybe_unused]] DataSize Alignment(const flat::Object* object, const WireFormat wire_format);
DataSize Depth(const flat::Object& object, const WireFormat wire_format);
[[maybe_unused]] DataSize Depth(const flat::Object* object, const WireFormat wire_format);
DataSize MaxHandles(const flat::Object& object);
[[maybe_unused]] DataSize MaxHandles(const flat::Object* object);
DataSize MaxOutOfLine(const flat::Object& object, const WireFormat wire_format);
[[maybe_unused]] DataSize MaxOutOfLine(const flat::Object* object, const WireFormat wire_format);
bool HasPadding(const flat::Object& object, const WireFormat wire_format);
[[maybe_unused]] bool HasPadding(const flat::Object* object, const WireFormat wire_format);
bool HasFlexibleEnvelope(const flat::Object& object, const WireFormat wire_format);
[[maybe_unused]] bool HasFlexibleEnvelope(const flat::Object* object, const WireFormat wire_format);
// Returns the inline size of |object| rounded up to the object's own alignment.
DataSize AlignedSize(const flat::Object& object, const WireFormat wire_format) {
  const DataSize raw_size = UnalignedSize(object, wire_format);
  const DataSize required_alignment = Alignment(object, wire_format);
  return AlignTo(raw_size, required_alignment);
}
// Pointer overload; |object| is dereferenced without a null check.
[[maybe_unused]] DataSize AlignedSize(const flat::Object* object, const WireFormat wire_format) {
  const flat::Object& target = *object;
  return AlignedSize(target, wire_format);
}
// Common base for the visitors below whose result depends on the wire format. A wire format must
// always be supplied explicitly; there is no default.
template <typename T>
class TypeShapeVisitor : public flat::Object::Visitor<T> {
 public:
  explicit TypeShapeVisitor(const WireFormat wire_format) : wire_format_(wire_format) {}
  TypeShapeVisitor() = delete;

 protected:
  // The wire format this visitor was constructed with.
  WireFormat wire_format() const { return wire_format_; }

 private:
  const WireFormat wire_format_;
};
// Computes the inline size, in bytes, of an object before it is rounded up to its alignment.
class UnalignedSizeVisitor final : public TypeShapeVisitor<DataSize> {
 public:
  using TypeShapeVisitor<DataSize>::TypeShapeVisitor;

  std::any Visit(const flat::ArrayType& object) override {
    // Arrays are stored inline: element size times element count (saturating).
    const DataSize element_size = UnalignedSize(object.element_type);
    return element_size * object.element_count->value;
  }

  // Vectors and strings both occupy 16 bytes inline regardless of their contents.
  std::any Visit(const flat::VectorType& object) override { return DataSize(16); }
  std::any Visit(const flat::StringType& object) override { return DataSize(16); }

  std::any Visit(const flat::HandleType& object) override { return DataSize(kHandleSize); }

  std::any Visit(const flat::PrimitiveType& object) override {
    switch (object.subtype) {
      case types::PrimitiveSubtype::kInt64:
      case types::PrimitiveSubtype::kUint64:
      case types::PrimitiveSubtype::kFloat64:
        return DataSize(8);
      case types::PrimitiveSubtype::kInt32:
      case types::PrimitiveSubtype::kUint32:
      case types::PrimitiveSubtype::kFloat32:
        return DataSize(4);
      case types::PrimitiveSubtype::kInt16:
      case types::PrimitiveSubtype::kUint16:
        return DataSize(2);
      case types::PrimitiveSubtype::kBool:
      case types::PrimitiveSubtype::kInt8:
      case types::PrimitiveSubtype::kUint8:
        return DataSize(1);
    }
  }

  std::any Visit(const flat::IdentifierType& object) override {
    if (object.nullability == types::Nullability::kNullable) {
      // A nullable identifier takes a fixed number of inline bytes that depends only on the kind
      // of declaration it names.
      switch (object.type_decl->kind) {
        case flat::Decl::Kind::kProtocol:
        case flat::Decl::Kind::kService:
          return DataSize(kHandleSize);
        case flat::Decl::Kind::kStruct:
          return DataSize(8);
        case flat::Decl::Kind::kUnion:
          return DataSize(24);
        case flat::Decl::Kind::kBits:
        case flat::Decl::Kind::kConst:
        case flat::Decl::Kind::kEnum:
        case flat::Decl::Kind::kResource:
        case flat::Decl::Kind::kTable:
        case flat::Decl::Kind::kTypeAlias:
          assert(false && "UnalignedSize(flat::IdentifierType&) called on invalid nullable kind");
          return DataSize(0);
      }
    }
    // A non-nullable identifier is exactly as large as the declaration it refers to.
    return UnalignedSize(object.type_decl);
  }

  std::any Visit(const flat::RequestHandleType& object) override { return DataSize(kHandleSize); }

  // Enums and bits are as large as their underlying primitive type.
  std::any Visit(const flat::Enum& object) override {
    return UnalignedSize(object.subtype_ctor->type);
  }
  std::any Visit(const flat::Bits& object) override {
    return UnalignedSize(object.subtype_ctor->type);
  }

  std::any Visit(const flat::Service& object) override { return DataSize(kHandleSize); }

  std::any Visit(const flat::Struct& object) override {
    const bool includes_header =
        object.is_request_or_response && wire_format() != WireFormat::kV1Header;
    if (object.members.empty()) {
      // An empty struct still occupies one byte; an empty request/response struct is just the
      // transaction header.
      return includes_header ? DataSize(kSizeOfTransactionHeader) : DataSize(1);
    }
    DataSize total = includes_header ? DataSize(kSizeOfTransactionHeader) : DataSize(0);
    for (const auto& member : object.members) {
      // Each member contributes its own size plus the padding that follows it.
      total += UnalignedSize(member) + member.fieldshape(wire_format()).Padding();
    }
    return total;
  }

  std::any Visit(const flat::Struct::Member& object) override {
    return UnalignedSize(object.type_ctor->type);
  }

  std::any Visit(const flat::Table& object) override { return DataSize(16); }

  std::any Visit(const flat::Table::Member& object) override {
    if (!object.maybe_used) {
      return DataSize(0);
    }
    return UnalignedSize(*object.maybe_used);
  }

  std::any Visit(const flat::Table::Member::Used& object) override {
    return UnalignedSize(object.type_ctor->type);
  }

  std::any Visit(const flat::Union& object) override { return DataSize(24); }

  std::any Visit(const flat::Union::Member& object) override {
    if (!object.maybe_used) {
      return DataSize(0);
    }
    return UnalignedSize(*object.maybe_used);
  }

  std::any Visit(const flat::Union::Member::Used& object) override {
    return UnalignedSize(object.type_ctor->type);
  }

  std::any Visit(const flat::Protocol& object) override { return DataSize(kHandleSize); }

 private:
  // Recurse with this same visitor, and therefore the same wire format.
  DataSize UnalignedSize(const flat::Object& object) { return object.Accept(this); }
  DataSize UnalignedSize(const flat::Object* object) { return UnalignedSize(*object); }
};
// Computes the alignment requirement, in bytes, of an object's inline representation.
class AlignmentVisitor final : public TypeShapeVisitor<DataSize> {
 public:
  using TypeShapeVisitor<DataSize>::TypeShapeVisitor;

  // An array is aligned like its elements.
  std::any Visit(const flat::ArrayType& object) override { return Alignment(object.element_type); }

  std::any Visit(const flat::VectorType& object) override { return DataSize(8); }
  std::any Visit(const flat::StringType& object) override { return DataSize(8); }
  std::any Visit(const flat::HandleType& object) override { return DataSize(kHandleSize); }

  // Primitives are naturally aligned: their alignment equals their size.
  std::any Visit(const flat::PrimitiveType& object) override {
    return UnalignedSize(object, wire_format());
  }

  std::any Visit(const flat::IdentifierType& object) override {
    if (object.nullability == types::Nullability::kNullable) {
      switch (object.type_decl->kind) {
        case flat::Decl::Kind::kProtocol:
        case flat::Decl::Kind::kService:
          return DataSize(kHandleSize);
        case flat::Decl::Kind::kStruct:
        case flat::Decl::Kind::kUnion:
          return DataSize(8);
        case flat::Decl::Kind::kBits:
        case flat::Decl::Kind::kConst:
        case flat::Decl::Kind::kEnum:
        case flat::Decl::Kind::kResource:
        case flat::Decl::Kind::kTable:
        case flat::Decl::Kind::kTypeAlias:
          assert(false && "Alignment(flat::IdentifierType&) called on invalid nullable kind");
          return DataSize(0);
      }
    }
    // A non-nullable identifier is aligned like the declaration it refers to.
    return Alignment(object.type_decl);
  }

  std::any Visit(const flat::RequestHandleType& object) override { return DataSize(kHandleSize); }

  // Enums and bits are aligned like their underlying primitive type.
  std::any Visit(const flat::Enum& object) override { return Alignment(object.subtype_ctor->type); }
  std::any Visit(const flat::Bits& object) override { return Alignment(object.subtype_ctor->type); }

  std::any Visit(const flat::Service& object) override { return DataSize(kHandleSize); }

  std::any Visit(const flat::Struct& object) override {
    // |object| is recursive, therefore there must be a pointer to this struct in the recursion
    // chain, with pointer-sized alignment.
    if (object.recursive) {
      return DataSize(8);
    }
    // Request/response structs have an alignment of 8. (Note that this was a bug before FTP-029,
    // which changed method ordinals from 32 to 64 bits. Before FTP-029, the assumed alignment was
    // 4, but in practice, all FIDL bindings and typeshape calculation code were assuming a
    // minimum alignment of 8.)
    if (object.is_request_or_response) {
      return DataSize(kAlignmentOfTransactionHeader);
    }
    // An empty struct is byte-aligned.
    if (object.members.empty()) {
      return DataSize(1);
    }
    // Otherwise the struct is aligned like its most strictly aligned member.
    DataSize widest = 0;
    for (const auto& member : object.members) {
      widest = std::max(widest, Alignment(member));
    }
    return widest;
  }

  std::any Visit(const flat::Struct::Member& object) override {
    return Alignment(object.type_ctor->type);
  }

  std::any Visit(const flat::Table& object) override { return DataSize(8); }

  std::any Visit(const flat::Table::Member& object) override {
    if (!object.maybe_used) {
      return DataSize(0);
    }
    return Alignment(*object.maybe_used);
  }

  std::any Visit(const flat::Table::Member::Used& object) override {
    return Alignment(object.type_ctor->type);
  }

  std::any Visit(const flat::Union& object) override { return DataSize(8); }

  std::any Visit(const flat::Union::Member& object) override {
    if (!object.maybe_used) {
      return DataSize(0);
    }
    return Alignment(*object.maybe_used);
  }

  std::any Visit(const flat::Union::Member::Used& object) override {
    return Alignment(object.type_ctor->type);
  }

  std::any Visit(const flat::Protocol& object) override { return DataSize(kHandleSize); }

 private:
  // Recurse with this same visitor, and therefore the same wire format.
  DataSize Alignment(const flat::Object& object) { return object.Accept(this); }
  DataSize Alignment(const flat::Object* object) { return Alignment(*object); }
};
// Computes the maximum nesting depth of out-of-line objects reachable from an object.
// Not final: OldWireFormatDepthVisitor overrides the union-related cases.
class DepthVisitor : public TypeShapeVisitor<DataSize> {
 public:
  using TypeShapeVisitor<DataSize>::TypeShapeVisitor;

  std::any Visit(const flat::ArrayType& object) override { return Depth(object.element_type); }

  // A vector's contents add one level on top of whatever its elements add.
  std::any Visit(const flat::VectorType& object) override {
    const DataSize element_depth = Depth(object.element_type);
    return DataSize(1) + element_depth;
  }

  std::any Visit(const flat::StringType& object) override { return DataSize(1); }
  std::any Visit(const flat::HandleType& object) override { return DataSize(0); }
  std::any Visit(const flat::PrimitiveType& object) override { return DataSize(0); }

  std::any Visit(const flat::IdentifierType& object) override {
    // Cut off self-referential declarations: an identifier that is re-entered while it is still
    // being visited contributes no additional depth.
    thread_local RecursionDetector recursion_detector;
    auto guard = recursion_detector.Enter(&object);
    if (!guard) {
      return DataSize(0);
    }
    if (object.nullability == types::Nullability::kNullable) {
      switch (object.type_decl->kind) {
        case flat::Decl::Kind::kProtocol:
        case flat::Decl::Kind::kService:
          return DataSize(0);
        case flat::Decl::Kind::kStruct:
          // A nullable struct adds one level of depth.
          return DataSize(1) + Depth(object.type_decl);
        case flat::Decl::Kind::kUnion:
          return Depth(object.type_decl);
        case flat::Decl::Kind::kBits:
        case flat::Decl::Kind::kConst:
        case flat::Decl::Kind::kEnum:
        case flat::Decl::Kind::kTable:
        case flat::Decl::Kind::kResource:
        case flat::Decl::Kind::kTypeAlias:
          assert(false && "Depth(flat::IdentifierType&) called on invalid nullable kind");
          return DataSize(0);
      }
    }
    switch (object.type_decl->kind) {
      case flat::Decl::Kind::kBits:
      case flat::Decl::Kind::kConst:
      case flat::Decl::Kind::kEnum:
      case flat::Decl::Kind::kProtocol:
      case flat::Decl::Kind::kResource:
      case flat::Decl::Kind::kService:
        return DataSize(0);
      case flat::Decl::Kind::kUnion:
      case flat::Decl::Kind::kTable:
      case flat::Decl::Kind::kTypeAlias:
      case flat::Decl::Kind::kStruct:
        return Depth(object.type_decl);
    }
  }

  std::any Visit(const flat::RequestHandleType& object) override { return DataSize(0); }
  std::any Visit(const flat::Enum& object) override { return Depth(object.subtype_ctor->type); }
  std::any Visit(const flat::Bits& object) override { return Depth(object.subtype_ctor->type); }
  std::any Visit(const flat::Service& object) override { return DataSize(0); }

  std::any Visit(const flat::Struct& object) override {
    // A recursive struct can nest itself without bound.
    if (object.recursive) {
      return std::numeric_limits<DataSize>::max();
    }
    return DeepestMember(object.members);
  }

  std::any Visit(const flat::Struct::Member& object) override {
    return Depth(object.type_ctor->type);
  }

  std::any Visit(const flat::Table& object) override {
    return DataSize(1) + DeepestMember(object.members);
  }

  std::any Visit(const flat::Table::Member& object) override {
    if (!object.maybe_used) {
      return DataSize(0);
    }
    return Depth(*object.maybe_used);
  }

  std::any Visit(const flat::Table::Member::Used& object) override {
    return DataSize(1) + Depth(object.type_ctor->type);
  }

  std::any Visit(const flat::Union& object) override {
    return DataSize(1) + DeepestMember(object.members);
  }

  std::any Visit(const flat::Union::Member& object) override {
    if (!object.maybe_used) {
      return DataSize(0);
    }
    return Depth(*object.maybe_used);
  }

  std::any Visit(const flat::Union::Member::Used& object) override {
    return Depth(object.type_ctor->type);
  }

  std::any Visit(const flat::Protocol& object) override { return DataSize(0); }

 protected:
  // Recurse with this same visitor (dispatching to subclass overrides, if any).
  DataSize Depth(const flat::Object& object) { return object.Accept(this); }
  DataSize Depth(const flat::Object* object) { return Depth(*object); }

 private:
  // Returns the largest depth among |members|, or 0 when there are none.
  template <typename MemberList>
  DataSize DeepestMember(const MemberList& members) {
    DataSize deepest = 0;
    for (const auto& member : members) {
      deepest = std::max(deepest, Depth(member));
    }
    return deepest;
  }
};
// Computes depth under the "old" wire format, in which unions were static (their data was held
// inline rather than in an envelope). Cases that do not depend on the union encoding are
// delegated to DepthVisitor; only the union-related cases are overridden here.
class OldWireFormatDepthVisitor final : public DepthVisitor {
 public:
  // DepthVisitor has no default constructor, so a wire format has to be passed through. It does
  // not matter which one: this class always reports depth under the "old" wire format.
  explicit OldWireFormatDepthVisitor(const WireFormat wire_format) : DepthVisitor(wire_format) {}

  // A nullable static union is replaced by a presence pointer, which adds one level of depth.
  std::any Visit(const flat::IdentifierType& object) override {
    const bool nullable_union = object.nullability == types::Nullability::kNullable &&
                                object.type_decl->kind == flat::Decl::Kind::kUnion;
    if (!nullable_union) {
      return DepthVisitor::Visit(object);
    }
    return DataSize(1) + Depth(object.type_decl);
  }

  // A static union stores its data inline without an envelope, so it adds no depth of its own.
  std::any Visit(const flat::Union& object) override {
    DataSize deepest = 0;
    for (const auto& member : object.members) {
      deepest = std::max(deepest, Depth(member));
    }
    return deepest;
  }
};
// Computes an upper bound on the number of handles an object can carry. Saturates at the
// uint32_t maximum when the count is unbounded. Independent of the wire format.
class MaxHandlesVisitor final : public flat::Object::Visitor<DataSize> {
 public:
  std::any Visit(const flat::ArrayType& object) override {
    return MaxHandles(object.element_type) * object.element_count->value;
  }
  std::any Visit(const flat::VectorType& object) override {
    return MaxHandles(object.element_type) * object.element_count->value;
  }
  std::any Visit(const flat::StringType& object) override { return DataSize(0); }
  std::any Visit(const flat::HandleType& object) override { return DataSize(1); }
  std::any Visit(const flat::PrimitiveType& object) override { return DataSize(0); }

  std::any Visit(const flat::IdentifierType& object) override {
    // Re-entering an identifier that is still being visited reports 0 so that recursion
    // terminates. The recursive struct that closes the cycle compensates by reporting an
    // unbounded count whenever any handle is reachable from it (see Visit(const flat::Struct&)).
    // TODO(fxbug.dev/36327): cycles that do not pass through a recursive struct may still be
    // undercounted.
    thread_local RecursionDetector recursion_detector;
    auto guard = recursion_detector.Enter(&object);
    if (!guard) {
      return DataSize(0);
    }
    return MaxHandles(object.type_decl);
  }

  std::any Visit(const flat::RequestHandleType& object) override { return DataSize(1); }
  std::any Visit(const flat::Enum& object) override {
    return MaxHandles(object.subtype_ctor->type);
  }
  std::any Visit(const flat::Bits& object) override {
    return MaxHandles(object.subtype_ctor->type);
  }
  std::any Visit(const flat::Service& object) override { return DataSize(1); }

  std::any Visit(const flat::Struct& object) override {
    if (object.recursive) {
      // A recursive struct can nest itself without bound, so if any handle is reachable from any
      // of its members (directly, or through arrays, vectors or other declarations), the total
      // is unbounded. Previously only members whose own type was a handle were considered, so a
      // handle reachable indirectly (e.g. via a nested struct) was reported as 0 handles.
      for (const auto& member : object.members) {
        if (MaxHandles(member) > 0) {
          return std::numeric_limits<DataSize>::max();
        }
      }
      return DataSize(0);
    }
    DataSize max_handles = 0;
    for (const auto& member : object.members) {
      max_handles += MaxHandles(member);
    }
    return max_handles;
  }

  std::any Visit(const flat::Struct::Member& object) override {
    return MaxHandles(object.type_ctor->type);
  }

  // Every table member may be present at once, so member counts add up.
  std::any Visit(const flat::Table& object) override {
    DataSize max_handles = 0;
    for (const auto& member : object.members) {
      max_handles += MaxHandles(member);
    }
    return max_handles;
  }
  std::any Visit(const flat::Table::Member& object) override {
    return object.maybe_used ? MaxHandles(*object.maybe_used) : DataSize(0);
  }
  std::any Visit(const flat::Table::Member::Used& object) override {
    return MaxHandles(object.type_ctor->type);
  }

  // Only one union member is present at a time, so take the maximum rather than the sum.
  std::any Visit(const flat::Union& object) override {
    DataSize max_handles = 0;
    for (const auto& member : object.members) {
      max_handles = std::max(max_handles, MaxHandles(member));
    }
    return max_handles;
  }
  std::any Visit(const flat::Union::Member& object) override {
    return object.maybe_used ? MaxHandles(*object.maybe_used) : DataSize(0);
  }
  std::any Visit(const flat::Union::Member::Used& object) override {
    return MaxHandles(object.type_ctor->type);
  }

  std::any Visit(const flat::Protocol& object) override { return DataSize(1); }

 private:
  // Recurse with this visitor instead of constructing a new one per node, matching the other
  // visitors in this file.
  DataSize MaxHandles(const flat::Object& object) { return object.Accept(this); }
  DataSize MaxHandles(const flat::Object* object) { return MaxHandles(*object); }
};
// Computes an upper bound, in bytes, on the out-of-line data reachable from an object, including
// the 8-byte alignment of each out-of-line object. Saturates at the uint32_t maximum for
// unbounded (e.g. recursive or unbounded-string) shapes.
class MaxOutOfLineVisitor final : public TypeShapeVisitor<DataSize> {
 public:
  using TypeShapeVisitor<DataSize>::TypeShapeVisitor;

  std::any Visit(const flat::ArrayType& object) override {
    return MaxOutOfLine(object.element_type) * DataSize(object.element_count->value);
  }

  // The element storage itself, plus whatever each element references out-of-line.
  std::any Visit(const flat::VectorType& object) override {
    return ObjectAlign(UnalignedSize(object.element_type, wire_format()) *
                       object.element_count->value) +
           ObjectAlign(MaxOutOfLine(object.element_type)) * object.element_count->value;
  }

  // A string without a maximum size can reference an unbounded amount of data.
  std::any Visit(const flat::StringType& object) override {
    return object.max_size ? ObjectAlign(object.max_size->value)
                           : std::numeric_limits<DataSize>::max();
  }

  std::any Visit(const flat::HandleType& object) override { return DataSize(0); }
  std::any Visit(const flat::PrimitiveType& object) override { return DataSize(0); }

  std::any Visit(const flat::IdentifierType& object) override {
    if (object.type_decl->recursive) {
      return std::numeric_limits<DataSize>::max();
    }
    if (object.nullability == types::Nullability::kNullable) {
      switch (object.type_decl->kind) {
        case flat::Decl::Kind::kProtocol:
        case flat::Decl::Kind::kService:
          return DataSize(0);
        case flat::Decl::Kind::kStruct:
          // The struct itself moves out-of-line, followed by anything it references.
          return ObjectAlign(UnalignedSize(object.type_decl, wire_format())) +
                 MaxOutOfLine(object.type_decl);
        case flat::Decl::Kind::kUnion:
          return MaxOutOfLine(object.type_decl);
        case flat::Decl::Kind::kBits:
        case flat::Decl::Kind::kConst:
        case flat::Decl::Kind::kEnum:
        case flat::Decl::Kind::kResource:
        case flat::Decl::Kind::kTable:
        case flat::Decl::Kind::kTypeAlias:
          assert(false && "MaxOutOfLine(flat::IdentifierType&) called on invalid nullable kind");
          // Must be a DataSize like every other result of this visitor: a bare `0` would store
          // an `int` in the std::any, and std::any_cast requires an exact type match.
          return DataSize(0);
      }
    }
    return MaxOutOfLine(object.type_decl);
  }

  std::any Visit(const flat::RequestHandleType& object) override { return DataSize(0); }
  std::any Visit(const flat::Enum& object) override {
    return MaxOutOfLine(object.subtype_ctor->type);
  }
  std::any Visit(const flat::Bits& object) override {
    return MaxOutOfLine(object.subtype_ctor->type);
  }
  std::any Visit(const flat::Service& object) override { return DataSize(0); }

  std::any Visit(const flat::Struct& object) override {
    DataSize max_out_of_line = 0;
    for (const auto& member : object.members) {
      max_out_of_line += MaxOutOfLine(member);
    }
    return max_out_of_line;
  }

  std::any Visit(const flat::Struct::Member& object) override {
    return MaxOutOfLine(object.type_ctor->type);
  }

  std::any Visit(const flat::Table& object) override {
    DataSize max_out_of_line = 0;
    for (const auto& member : object.members) {
      max_out_of_line += ObjectAlign(UnalignedSize(member, wire_format())) + MaxOutOfLine(member);
    }
    // The maximum number of envelopes is determined by the maximum _unreserved_ ordinal.
    // Any trailing reserved ordinals MUST NOT be present in the array of envelopes.
    // For example, a table that looks like
    // "table T { 1: int32 i; 2: reserved; 3: uint32 u; 4: reserved; }"
    // has an envelope array size of 3, not 4. Trim trailing reserved members (those without
    // |maybe_used|) from the end.
    size_t envelope_array_size = object.members.size();
    while (envelope_array_size > 0 && !object.members.at(envelope_array_size - 1).maybe_used) {
      --envelope_array_size;
    }
    constexpr DataSize kEnvelopeSize = 16;
    return DataSize(envelope_array_size) * kEnvelopeSize + max_out_of_line;
  }

  std::any Visit(const flat::Table::Member& object) override {
    return object.maybe_used ? MaxOutOfLine(*object.maybe_used) : DataSize(0);
  }
  std::any Visit(const flat::Table::Member::Used& object) override {
    return ObjectAlign(MaxOutOfLine(object.type_ctor->type));
  }

  // Only one union member is present at a time, so take the largest member.
  std::any Visit(const flat::Union& object) override {
    DataSize max_out_of_line = 0;
    for (const auto& member : object.members) {
      max_out_of_line =
          std::max(max_out_of_line,
                   ObjectAlign(UnalignedSize(member, wire_format())) + MaxOutOfLine(member));
    }
    return max_out_of_line;
  }
  std::any Visit(const flat::Union::Member& object) override {
    return object.maybe_used ? MaxOutOfLine(*object.maybe_used) : DataSize(0);
  }
  std::any Visit(const flat::Union::Member::Used& object) override {
    return MaxOutOfLine(object.type_ctor->type);
  }

  std::any Visit(const flat::Protocol& object) override { return DataSize(0); }

 private:
  // Recurse with this same visitor, and therefore the same wire format.
  DataSize MaxOutOfLine(const flat::Object& object) { return object.Accept(this); }
  DataSize MaxOutOfLine(const flat::Object* object) { return MaxOutOfLine(*object); }
};
// Determines whether the encoding of an object can contain padding bytes, either inline or in
// any out-of-line object it references.
class HasPaddingVisitor final : public TypeShapeVisitor<bool> {
 public:
  using TypeShapeVisitor<bool>::TypeShapeVisitor;

  std::any Visit(const flat::ArrayType& object) override { return HasPadding(object.element_type); }

  std::any Visit(const flat::VectorType& object) override {
    // A vector will always have padding out-of-line for its contents unless its element_type's
    // natural size is a multiple of 8.
    const bool pads_after_elements =
        Padding(UnalignedSize(object.element_type, wire_format()), 8) != 0;
    if (pads_after_elements) {
      return true;
    }
    // Otherwise it depends on whether the elements themselves contain padding.
    return HasPadding(object.element_type);
  }

  std::any Visit(const flat::StringType& object) override { return true; }
  std::any Visit(const flat::HandleType& object) override { return false; }
  std::any Visit(const flat::PrimitiveType& object) override { return false; }

  std::any Visit(const flat::IdentifierType& object) override {
    // A re-entered identifier reports no padding so that recursion terminates.
    thread_local RecursionDetector recursion_detector;
    auto guard = recursion_detector.Enter(&object);
    if (!guard) {
      return false;
    }
    if (object.nullability == types::Nullability::kNullable) {
      switch (object.type_decl->kind) {
        case flat::Decl::Kind::kProtocol:
        case flat::Decl::Kind::kService:
          return false;
        case flat::Decl::Kind::kStruct:
        case flat::Decl::Kind::kUnion:
          // Padded if the out-of-line copy does not end on an 8-byte boundary, or if the
          // declaration itself contains padding.
          return Padding(UnalignedSize(object.type_decl, wire_format()), 8) > 0 ||
                 HasPadding(object.type_decl);
        case flat::Decl::Kind::kBits:
        case flat::Decl::Kind::kConst:
        case flat::Decl::Kind::kEnum:
        case flat::Decl::Kind::kResource:
        case flat::Decl::Kind::kTable:
        case flat::Decl::Kind::kTypeAlias:
          assert(false && "HasPadding(flat::IdentifierType&) called on invalid nullable kind");
          return false;
      }
    }
    return HasPadding(object.type_decl);
  }

  std::any Visit(const flat::RequestHandleType& object) override { return false; }
  std::any Visit(const flat::Enum& object) override {
    return HasPadding(object.subtype_ctor->type);
  }
  std::any Visit(const flat::Bits& object) override {
    return HasPadding(object.subtype_ctor->type);
  }
  std::any Visit(const flat::Service& object) override { return false; }

  std::any Visit(const flat::Struct& object) override { return AnyMemberPadded(object.members); }

  // A struct member is padded if padding follows it or its type contains padding.
  std::any Visit(const flat::Struct::Member& object) override {
    const bool padded_after = object.fieldshape(wire_format()).Padding() > 0;
    return padded_after || HasPadding(object.type_ctor->type);
  }

  std::any Visit(const flat::Table& object) override { return AnyMemberPadded(object.members); }

  std::any Visit(const flat::Table::Member& object) override {
    if (!object.maybe_used) {
      return false;
    }
    return HasPadding(*object.maybe_used);
  }

  std::any Visit(const flat::Table::Member::Used& object) override {
    return Padding(UnalignedSize(object.type_ctor->type, wire_format()), 8) > 0 ||
           HasPadding(object.type_ctor->type) || object.fieldshape(wire_format()).Padding() > 0;
  }

  std::any Visit(const flat::Union& object) override {
    // TODO(fxbug.dev/36332): Unions currently return true for has_padding in all cases, which
    // should be fixed.
    return true;
  }

  std::any Visit(const flat::Union::Member& object) override {
    if (!object.maybe_used) {
      return false;
    }
    return HasPadding(*object.maybe_used);
  }

  std::any Visit(const flat::Union::Member::Used& object) override {
    // TODO(fxbug.dev/36331): This code only accounts for inline padding for the union member. We
    // also need to account for out-of-line padding.
    return object.fieldshape(wire_format()).Padding() > 0;
  }

  std::any Visit(const flat::Protocol& object) override { return false; }

 private:
  // Recurse with this same visitor, and therefore the same wire format.
  bool HasPadding(const flat::Object& object) { return object.Accept(this); }
  bool HasPadding(const flat::Object* object) { return HasPadding(*object); }

  // True as soon as any member of |members| (checked in order) reports padding.
  template <typename MemberList>
  bool AnyMemberPadded(const MemberList& members) {
    return std::any_of(members.begin(), members.end(),
                       [this](const auto& member) { return this->HasPadding(member); });
  }
};
// Determines whether an object contains a table or union declared flexible, either directly or
// through any type it references.
class HasFlexibleEnvelopeVisitor final : public TypeShapeVisitor<bool> {
 public:
  using TypeShapeVisitor<bool>::TypeShapeVisitor;

  std::any Visit(const flat::ArrayType& object) override {
    return HasFlexibleEnvelope(object.element_type);
  }
  std::any Visit(const flat::VectorType& object) override {
    return HasFlexibleEnvelope(object.element_type);
  }
  std::any Visit(const flat::StringType& object) override { return false; }
  std::any Visit(const flat::HandleType& object) override { return false; }
  std::any Visit(const flat::PrimitiveType& object) override { return false; }

  std::any Visit(const flat::IdentifierType& object) override {
    // A re-entered identifier reports false so that recursion through self-referential
    // declarations terminates.
    thread_local RecursionDetector recursion_detector;
    auto guard = recursion_detector.Enter(&object);
    if (!guard) {
      return false;
    }
    return HasFlexibleEnvelope(object.type_decl);
  }

  std::any Visit(const flat::RequestHandleType& object) override { return false; }
  std::any Visit(const flat::Enum& object) override {
    return HasFlexibleEnvelope(object.subtype_ctor->type);
  }
  std::any Visit(const flat::Bits& object) override {
    return HasFlexibleEnvelope(object.subtype_ctor->type);
  }
  std::any Visit(const flat::Service& object) override { return false; }

  std::any Visit(const flat::Struct& object) override {
    return AnyMemberHasFlexibleEnvelope(object.members);
  }
  std::any Visit(const flat::Struct::Member& object) override {
    return HasFlexibleEnvelope(object.type_ctor->type);
  }

  std::any Visit(const flat::Table& object) override {
    if (object.strictness == types::Strictness::kFlexible) {
      return true;
    }
    return AnyMemberHasFlexibleEnvelope(object.members);
  }
  std::any Visit(const flat::Table::Member& object) override {
    return object.maybe_used ? HasFlexibleEnvelope(*object.maybe_used) : false;
  }
  std::any Visit(const flat::Table::Member::Used& object) override {
    return HasFlexibleEnvelope(object.type_ctor->type);
  }

  std::any Visit(const flat::Union& object) override {
    if (object.strictness == types::Strictness::kFlexible) {
      return true;
    }
    return AnyMemberHasFlexibleEnvelope(object.members);
  }
  std::any Visit(const flat::Union::Member& object) override {
    return object.maybe_used ? HasFlexibleEnvelope(*object.maybe_used) : false;
  }
  std::any Visit(const flat::Union::Member::Used& object) override {
    return HasFlexibleEnvelope(object.type_ctor->type);
  }

  std::any Visit(const flat::Protocol& object) override { return false; }

 private:
  // Recurse with this visitor, like the other visitors in this file. Previously every step went
  // through the free HasFlexibleEnvelope(), which constructs a brand-new visitor per node.
  bool HasFlexibleEnvelope(const flat::Object& object) { return object.Accept(this); }
  bool HasFlexibleEnvelope(const flat::Object* object) { return HasFlexibleEnvelope(*object); }

  // True as soon as any member of |members| (checked in order) reports a flexible envelope.
  template <typename MemberList>
  bool AnyMemberHasFlexibleEnvelope(const MemberList& members) {
    return std::any_of(members.begin(), members.end(), [this](const auto& member) {
      return this->HasFlexibleEnvelope(member);
    });
  }
};
// Entry point: the inline size of |object| under |wire_format|, before alignment.
DataSize UnalignedSize(const flat::Object& object, const WireFormat wire_format) {
  UnalignedSizeVisitor visitor{wire_format};
  return object.Accept(&visitor);
}
// Pointer overload; |object| is dereferenced without a null check.
[[maybe_unused]] DataSize UnalignedSize(const flat::Object* object, const WireFormat wire_format) {
  const flat::Object& target = *object;
  return UnalignedSize(target, wire_format);
}
// Entry point: the alignment requirement of |object| under |wire_format|.
DataSize Alignment(const flat::Object& object, const WireFormat wire_format) {
  AlignmentVisitor visitor{wire_format};
  return object.Accept(&visitor);
}
// Pointer overload; |object| is dereferenced without a null check.
[[maybe_unused]] DataSize Alignment(const flat::Object* object, const WireFormat wire_format) {
  const flat::Object& target = *object;
  return Alignment(target, wire_format);
}
// Entry point: the out-of-line nesting depth of |object| under |wire_format|.
DataSize Depth(const flat::Object& object, const WireFormat wire_format) {
  DepthVisitor visitor{wire_format};
  return object.Accept(&visitor);
}
// Pointer overload; |object| is dereferenced without a null check.
[[maybe_unused]] DataSize Depth(const flat::Object* object, const WireFormat wire_format) {
  const flat::Object& target = *object;
  return Depth(target, wire_format);
}
// Entry point: an upper bound on the handles in |object|. Takes no wire format.
DataSize MaxHandles(const flat::Object& object) {
  MaxHandlesVisitor visitor{};
  return object.Accept(&visitor);
}
// Pointer overload; |object| is dereferenced without a null check.
[[maybe_unused]] DataSize MaxHandles(const flat::Object* object) {
  const flat::Object& target = *object;
  return MaxHandles(target);
}
// Entry point: an upper bound on the out-of-line bytes reachable from |object|.
DataSize MaxOutOfLine(const flat::Object& object, const WireFormat wire_format) {
  MaxOutOfLineVisitor visitor{wire_format};
  return object.Accept(&visitor);
}
// Pointer overload; |object| is dereferenced without a null check.
[[maybe_unused]] DataSize MaxOutOfLine(const flat::Object* object, const WireFormat wire_format) {
  const flat::Object& target = *object;
  return MaxOutOfLine(target, wire_format);
}
// Entry point: whether the encoding of |object| can contain padding bytes.
bool HasPadding(const flat::Object& object, const WireFormat wire_format) {
  HasPaddingVisitor visitor{wire_format};
  return object.Accept(&visitor);
}
// Pointer overload; |object| is dereferenced without a null check.
[[maybe_unused]] bool HasPadding(const flat::Object* object, const WireFormat wire_format) {
  const flat::Object& target = *object;
  return HasPadding(target, wire_format);
}
// Entry point: whether |object| contains a flexible table or union.
bool HasFlexibleEnvelope(const flat::Object& object, const WireFormat wire_format) {
  HasFlexibleEnvelopeVisitor visitor{wire_format};
  return object.Accept(&visitor);
}
// Pointer overload; |object| is dereferenced without a null check.
[[maybe_unused]] bool HasFlexibleEnvelope(const flat::Object* object,
                                          const WireFormat wire_format) {
  const flat::Object& target = *object;
  return HasFlexibleEnvelope(target, wire_format);
}
} // namespace
namespace fidl {
// Depth of |object| under the "old" (static-union) wire format. The wire format handed to the
// visitor is only needed to construct it; OldWireFormatDepthVisitor ignores it.
uint32_t OldWireFormatDepth(const flat::Object& object) {
  OldWireFormatDepthVisitor visitor(WireFormat::kV1NoEe);
  const DataSize depth = object.Accept(&visitor);
  return depth;
}
// Pointer overload; |object| is dereferenced without a null check.
uint32_t OldWireFormatDepth(const flat::Object* object) {
  const flat::Object& target = *object;
  return OldWireFormatDepth(target);
}
// Computes every shape property of |object| under |wire_format| by running the corresponding
// visitor for each field. inline_size is the unaligned inline size rounded up to the object's
// alignment. max_handles is computed without a wire format.
TypeShape::TypeShape(const flat::Object& object, WireFormat wire_format)
: inline_size(::AlignedSize(object, wire_format)),
alignment(::Alignment(object, wire_format)),
depth(::Depth(object, wire_format)),
max_handles(::MaxHandles(object)),
max_out_of_line(::MaxOutOfLine(object, wire_format)),
has_padding(::HasPadding(object, wire_format)),
has_flexible_envelope(::HasFlexibleEnvelope(object, wire_format)) {}
// Pointer overload; |object| is dereferenced without a null check.
TypeShape::TypeShape(const flat::Object* object, WireFormat wire_format)
: TypeShape(*object, wire_format) {}
// Computes the byte offset of |member| within its parent struct and the padding that follows it.
// The padding after each member brings the running offset up to the next member's alignment,
// or, after the last member, up to the parent struct's alignment. For request/response structs,
// offsets start after the transaction header unless |wire_format| is WireFormat::kV1Header.
FieldShape::FieldShape(const flat::StructMember& member, const WireFormat wire_format) {
  assert(member.parent);
  const flat::Struct& parent = *member.parent;
  // Our parent struct must have at least one member if fieldshape() on a member is being
  // called.
  assert(parent.members.size());
  const std::vector<flat::StructMember>& members = parent.members;
  if (parent.is_request_or_response && wire_format != WireFormat::kV1Header) {
    offset += kSizeOfTransactionHeader;
  }
  for (size_t i = 0; i < members.size(); i++) {
    const flat::StructMember& current = members.at(i);
    const bool is_last = i + 1 == members.size();
    const DataSize alignment =
        is_last ? Alignment(parent, wire_format) : Alignment(members.at(i + 1), wire_format);
    const DataSize size = UnalignedSize(current, wire_format);
    // Do the arithmetic in DataSize so that an oversized struct saturates, like every other size
    // computed in this file, instead of wrapping around uint32_t and producing a bogus offset.
    padding = ::Padding(DataSize(offset) + size, alignment);
    if (&current == &member) {
      break;
    }
    offset = DataSize(offset) + size + padding;
  }
}
// A used table member is followed by enough padding to bring its size to a multiple of 8.
FieldShape::FieldShape(const flat::TableMemberUsed& member, const WireFormat wire_format) {
  const DataSize member_size = UnalignedSize(member, wire_format);
  padding = ::Padding(member_size, 8);
}
// A used union member sits at offset 0 and is padded up to the alignment of its parent union.
FieldShape::FieldShape(const flat::UnionMemberUsed& member, const WireFormat wire_format) {
  offset = 0u;
  const DataSize member_size = UnalignedSize(member, wire_format);
  const DataSize union_alignment = Alignment(member.parent, wire_format);
  padding = ::Padding(member_size, union_alignment);
}
} // namespace fidl