Merge pull request #19683 from lorentey/inline-customization-points
[stdlib] Force-inline some Sequence/Collection customization points
diff --git a/lib/AST/TypeJoinMeet.cpp b/lib/AST/TypeJoinMeet.cpp
index 3c5d9c8..4661b4d 100644
--- a/lib/AST/TypeJoinMeet.cpp
+++ b/lib/AST/TypeJoinMeet.cpp
@@ -51,6 +51,9 @@
}
static CanType getSuperclassJoin(CanType first, CanType second);
+ CanType computeProtocolCompositionJoin(ArrayRef<Type> firstMembers,
+ ArrayRef<Type> secondMembers);
+
CanType visitErrorType(CanType second);
CanType visitTupleType(CanType second);
@@ -105,10 +108,10 @@
// Likewise, rather than making every visitor deal with Any,
// always dispatch to the protocol composition side of the join.
- if (first->isAny())
+ if (first->is<ProtocolCompositionType>())
return TypeJoin(second).visit(first);
- if (second->isAny())
+ if (second->is<ProtocolCompositionType>())
return TypeJoin(first).visit(second);
// Otherwise the first type might be an optional (or not), so
@@ -184,16 +187,6 @@
return getSuperclassJoin(First, second);
}
-CanType TypeJoin::visitProtocolType(CanType second) {
- assert(First != second);
-
- // FIXME: We should compute a tighter bound and/or return nullptr if
- // we cannot. We do this now because existing tests rely on
- // producing Any for the join of protocols that have a common
- // supertype.
- return TheAnyType;
-}
-
CanType TypeJoin::visitBoundGenericClassType(CanType second) {
return getSuperclassJoin(First, second);
}
@@ -352,16 +345,111 @@
return Unimplemented;
}
+// Compute the join of two protocol compositions, given as their member
+// lists, using the distributive law:
+//
+// (A ^ B) v (C ^ D)
+// = (A v C) ^ (A v D) ^ (B v C) ^ (B v D)
+//
+// In general this law only holds in distributive lattices.
+//
+// In our case it should be safe, because our meet operation only
+// produces an existing nominal type when that type is one of the
+// operands. So we can never arbitrarily climb down the lattice in ways
+// that would break distributivity.
+//
+// Returns Unimplemented if any pairwise join has no value, Any if every
+// pairwise join is Any (or either member list is empty), and otherwise
+// the canonical composition of the non-Any pairwise joins.
+//
+CanType TypeJoin::computeProtocolCompositionJoin(ArrayRef<Type> firstMembers,
+ ArrayRef<Type> secondMembers) {
+ SmallVector<Type, 8> result;
+ for (auto first : firstMembers) {
+ for (auto second : secondMembers) {
+ auto joined = Type::join(first, second);
+ // Type::join returned no value for this pair; give up on the whole
+ // composition rather than return a partial answer.
+ if (!joined)
+ return Unimplemented;
+
+ // Any adds no constraint to a composition, so leave it out.
+ if ((*joined)->isAny())
+ continue;
+
+ result.push_back(*joined);
+ }
+ }
+
+ // Every pairwise join was Any, or one side had no members.
+ if (result.empty())
+ return TheAnyType;
+
+ auto &ctx = result[0]->getASTContext();
+ // NOTE(review): `result` may hold duplicates and protocols implied by
+ // other members (e.g. L and M where M : L); presumably
+ // ProtocolCompositionType::get / getCanonicalType() minimizes them —
+ // confirm. The `false` argument presumably means "no explicit
+ // AnyObject", so an AnyObject constraint shared by both sides is not
+ // carried over — confirm that is intended.
+ return ProtocolCompositionType::get(ctx, result, false)->getCanonicalType();
+}
+
CanType TypeJoin::visitProtocolCompositionType(CanType second) {
+ // Join First with the protocol composition `second` (which may be Any).
+ //
+ // The join of Any and a no-escape function doesn't exist; it isn't
+ // Any. If it were Any, it would mean we would allow these functions
+ // to escape through Any.
if (second->isAny()) {
auto *fnTy = First->getAs<AnyFunctionType>();
if (fnTy && fnTy->getExtInfo().isNoEscape())
return Nonexistent;
- return second;
+ return TheAnyType;
}
- return Unimplemented;
+ assert(First != second);
+
+ // For the same reason, a no-escape function has no join with any other
+ // composition either. Function types don't conform to protocols, so the
+ // only common upper bound would be Any, and these functions must not
+ // escape to Any. Without this check, the fallback below would return Any
+ // and contradict the Any case above.
+ auto *fnTy = First->getAs<AnyFunctionType>();
+ if (fnTy && fnTy->getExtInfo().isNoEscape())
+ return Nonexistent;
+
+ // FIXME: Handle other types here.
+ if (!First->isExistentialType())
+ return TheAnyType;
+
+ // Treat a lone protocol as a single-member composition, so both sides
+ // go through computeProtocolCompositionJoin. The cast below presumes
+ // isExistentialType() leaves only a protocol or a composition — confirm.
+ SmallVector<Type, 1> protocolType;
+ ArrayRef<Type> firstMembers;
+ if (First->is<ProtocolType>()) {
+ protocolType.push_back(First);
+ firstMembers = protocolType;
+ } else {
+ firstMembers = cast<ProtocolCompositionType>(First)->getMembers();
+ }
+ auto secondMembers = cast<ProtocolCompositionType>(second)->getMembers();
+
+ return computeProtocolCompositionJoin(firstMembers, secondMembers);
+}
+
+// Join First with the protocol type `second`.
+//
+// If one protocol inherits from the other, the inherited (ancestor)
+// protocol is the join. Otherwise each protocol is replaced by the
+// composition of the protocols it directly inherits from, and those
+// compositions are joined.
+CanType TypeJoin::visitProtocolType(CanType second) {
+ assert(First != second);
+
+ assert(!First->is<ProtocolCompositionType>() &&
+ !second->is<ProtocolCompositionType>());
+
+ // A no-escape function has no join with a protocol. Functions don't
+ // conform to protocols, so the only common upper bound would be Any, and
+ // these functions must not escape through Any. (This is the same reason
+ // visitProtocolCompositionType returns Nonexistent for Any.) Without this
+ // check, the kind mismatch below would return Any.
+ auto *fnTy = First->getAs<AnyFunctionType>();
+ if (fnTy && fnTy->getExtInfo().isNoEscape())
+ return Nonexistent;
+
+ // FIXME: Handle other types here.
+ if (First->getKind() != second->getKind())
+ return TheAnyType;
+
+ auto *firstDecl =
+ cast<ProtocolDecl>(First->getNominalOrBoundGenericNominal());
+
+ auto *secondDecl =
+ cast<ProtocolDecl>(second->getNominalOrBoundGenericNominal());
+
+ // Two distinct root protocols have nothing in common but Any.
+ if (firstDecl->getInheritedProtocols().empty() &&
+ secondDecl->getInheritedProtocols().empty())
+ return TheAnyType;
+
+ if (firstDecl->inheritsFrom(secondDecl))
+ return second;
+
+ if (secondDecl->inheritsFrom(firstDecl))
+ return First;
+
+ // One isn't the supertype of the other, so instead, treat each as
+ // if it's a protocol composition of its inherited members, and join
+ // those.
+ SmallVector<Type, 4> firstMembers;
+ for (auto *decl : firstDecl->getInheritedProtocols())
+ firstMembers.push_back(decl->getDeclaredInterfaceType());
+
+ SmallVector<Type, 4> secondMembers;
+ for (auto *decl : secondDecl->getInheritedProtocols())
+ secondMembers.push_back(decl->getDeclaredInterfaceType());
+
+ return computeProtocolCompositionJoin(firstMembers, secondMembers);
}
CanType TypeJoin::visitLValueType(CanType second) { return Unimplemented; }
diff --git a/stdlib/public/core/Bitset.swift b/stdlib/public/core/Bitset.swift
index 4e5a1d4..3a0d971 100644
--- a/stdlib/public/core/Bitset.swift
+++ b/stdlib/public/core/Bitset.swift
@@ -38,7 +38,10 @@
@inline(__always)
internal static func word(for element: Int) -> Int {
_sanityCheck(element >= 0)
- return element / Word.capacity
+ // Note: We divide as UInts to get faster unsigned math (shifts). This
+ // matches signed division only because `element` is non-negative (see
+ // the sanity check above) and Word.capacity is presumably positive.
+ let element = UInt(bitPattern: element)
+ let capacity = UInt(bitPattern: Word.capacity)
+ return Int(bitPattern: element / capacity)
}
@inlinable
@@ -61,7 +64,7 @@
@inline(__always)
internal static func join(word: Int, bit: Int) -> Int {
_sanityCheck(bit >= 0 && bit < Word.capacity)
- return word * Word.capacity + bit
+ // Flat bit index of `bit` within word number `word`. The wrapping `&*`
+ // skips the overflow check; presumably word indices are always small
+ // enough that the product cannot overflow — confirm. The `+` is checked.
+ return word &* Word.capacity + bit
}
}
@@ -69,14 +72,14 @@
@inlinable
@inline(__always)
internal static func wordCount(forCapacity capacity: Int) -> Int {
- return (capacity + Word.capacity - 1) / Word.capacity
+ // Ceiling division of `capacity` by the word size. It goes through
+ // word(for:) to reuse its unsigned division, so the sum passed in must
+ // be non-negative (word(for:) sanity-checks that).
+ return word(for: capacity + Word.capacity - 1)
}
@inlinable
internal var capacity: Int {
@inline(__always)
get {
- return wordCount * Word.capacity
+ // Total number of bits across all words. The wrapping `&*` skips the
+ // overflow check.
+ return wordCount &* Word.capacity
}
}
diff --git a/stdlib/public/core/Dictionary.swift b/stdlib/public/core/Dictionary.swift
index 52715dd..0054651 100644
--- a/stdlib/public/core/Dictionary.swift
+++ b/stdlib/public/core/Dictionary.swift
@@ -1544,11 +1544,13 @@
internal var _base: Dictionary<Key, Value>.Iterator
+ /// Wraps a dictionary iterator; this iterator's `next()` returns only keys.
@inlinable
+ @inline(__always)
internal init(_ base: Dictionary<Key, Value>.Iterator) {
self._base = base
}
@inlinable
+ @inline(__always)
public mutating func next() -> Key? {
#if _runtime(_ObjC)
if case .cocoa(let cocoa) = _base._variant {
@@ -1562,6 +1564,7 @@
}
+ /// Returns an `Iterator` that wraps the storage variant's own iterator.
@inlinable
+ @inline(__always)
public func makeIterator() -> Iterator {
return Iterator(_variant.makeIterator())
}
@@ -1574,11 +1577,13 @@
internal var _base: Dictionary<Key, Value>.Iterator
+ /// Wraps a dictionary iterator; this iterator's `next()` returns only values.
@inlinable
+ @inline(__always)
internal init(_ base: Dictionary<Key, Value>.Iterator) {
self._base = base
}
@inlinable
+ @inline(__always)
public mutating func next() -> Value? {
#if _runtime(_ObjC)
if case .cocoa(let cocoa) = _base._variant {
@@ -1592,6 +1597,7 @@
}
+ /// Returns an `Iterator` that wraps the storage variant's own iterator.
@inlinable
+ @inline(__always)
public func makeIterator() -> Iterator {
return Iterator(_variant.makeIterator())
}
diff --git a/stdlib/public/core/HashTable.swift b/stdlib/public/core/HashTable.swift
index f410b29..39ee521 100644
--- a/stdlib/public/core/HashTable.swift
+++ b/stdlib/public/core/HashTable.swift
@@ -233,6 +233,7 @@
var word: Word
@inlinable
+ @inline(__always)
init(_ hashTable: _HashTable) {
self.hashTable = hashTable
self.wordIndex = 0
@@ -260,6 +261,7 @@
}
+ /// Returns a fresh `Iterator` over this hash table, starting at word 0.
@inlinable
+ @inline(__always)
internal func makeIterator() -> Iterator {
return Iterator(self)
}
diff --git a/stdlib/public/core/NativeDictionary.swift b/stdlib/public/core/NativeDictionary.swift
index 502ba1f..2a7605f 100644
--- a/stdlib/public/core/NativeDictionary.swift
+++ b/stdlib/public/core/NativeDictionary.swift
@@ -656,6 +656,7 @@
internal var iterator: _HashTable.Iterator
@inlinable
+ @inline(__always)
init(_ base: __owned _NativeDictionary) {
self.base = base
self.iterator = base.hashTable.makeIterator()
@@ -673,18 +674,21 @@
internal typealias Element = (key: Key, value: Value)
+ /// Advances the hash-table iterator and returns the key at the index it
+ /// yields (read via `uncheckedKey(at:)`), or `nil` once it is exhausted.
@inlinable
+ @inline(__always)
internal mutating func nextKey() -> Key? {
guard let index = iterator.next() else { return nil }
return base.uncheckedKey(at: index)
}
+ /// Advances the hash-table iterator and returns the value at the index it
+ /// yields (read via `uncheckedValue(at:)`), or `nil` once it is exhausted.
@inlinable
+ @inline(__always)
internal mutating func nextValue() -> Value? {
guard let index = iterator.next() else { return nil }
return base.uncheckedValue(at: index)
}
@inlinable
+ @inline(__always)
internal mutating func next() -> Element? {
guard let index = iterator.next() else { return nil }
let key = base.uncheckedKey(at: index)
diff --git a/stdlib/public/core/NativeSet.swift b/stdlib/public/core/NativeSet.swift
index 636dc24..5ebeba3 100644
--- a/stdlib/public/core/NativeSet.swift
+++ b/stdlib/public/core/NativeSet.swift
@@ -494,6 +494,7 @@
internal var iterator: _HashTable.Iterator
@inlinable
+ @inline(__always)
init(_ base: __owned _NativeSet) {
self.base = base
self.iterator = base.hashTable.makeIterator()
@@ -501,6 +502,7 @@
}
+ /// Returns an `Iterator` over the elements of this native set.
@inlinable
+ @inline(__always)
internal __consuming func makeIterator() -> Iterator {
return Iterator(self)
}
@@ -508,6 +510,7 @@
extension _NativeSet.Iterator: IteratorProtocol {
@inlinable
+ @inline(__always)
internal mutating func next() -> Element? {
guard let index = iterator.next() else { return nil }
return base.uncheckedElement(at: index)
diff --git a/test/Sema/type_join.swift b/test/Sema/type_join.swift
index e3c15d1..96d9724 100644
--- a/test/Sema/type_join.swift
+++ b/test/Sema/type_join.swift
@@ -5,6 +5,31 @@
class C {}
class D : C {}
+// Protocol hierarchy for the basic join tests:
+// M, N and R refine L; P and Q refine M; Y is unrelated to all of them.
+protocol L {}
+protocol M : L {}
+protocol N : L {}
+protocol P : M {}
+protocol Q : M {}
+protocol R : L {}
+protocol Y {}
+
+// The Fake* protocols appear to mirror the inheritance structure of the
+// standard library's numeric protocols (Equatable, Numeric, BinaryInteger,
+// FloatingPoint, ...) without depending on the real ones, so they exercise
+// joins through multiple inheritance.
+protocol FakeEquatable {}
+protocol FakeHashable : FakeEquatable {}
+protocol FakeExpressibleByIntegerLiteral {}
+protocol FakeNumeric : FakeEquatable, FakeExpressibleByIntegerLiteral {}
+protocol FakeSignedNumeric : FakeNumeric {}
+protocol FakeComparable : FakeEquatable {}
+protocol FakeStrideable : FakeComparable {}
+protocol FakeCustomStringConvertible {}
+protocol FakeBinaryInteger : FakeHashable, FakeNumeric, FakeCustomStringConvertible, FakeStrideable {}
+protocol FakeLosslessStringConvertible {}
+protocol FakeFixedWidthInteger : FakeBinaryInteger, FakeLosslessStringConvertible {}
+protocol FakeUnsignedInteger : FakeBinaryInteger {}
+protocol FakeSignedInteger : FakeBinaryInteger, FakeSignedNumeric {}
+protocol FakeFloatingPoint : FakeSignedNumeric, FakeStrideable, FakeHashable {}
+protocol FakeExpressibleByFloatLiteral {}
+protocol FakeBinaryFloatingPoint : FakeFloatingPoint, FakeExpressibleByFloatLiteral {}
+
func expectEqualType<T>(_: T.Type, _: T.Type) {}
func commonSupertype<T>(_: T, _: T) -> T {}
@@ -38,6 +63,27 @@
expectEqualType(Builtin.type_join(Builtin.Int32.self, Builtin.Int1.self), Any.self)
expectEqualType(Builtin.type_join(Builtin.Int1.self, Builtin.Int32.self), Any.self)
+// A protocol joined with itself or with a protocol that refines it yields
+// the ancestor; unrelated root protocols join to Any.
+expectEqualType(Builtin.type_join(L.self, L.self), L.self)
+expectEqualType(Builtin.type_join(L.self, M.self), L.self)
+expectEqualType(Builtin.type_join(L.self, P.self), L.self)
+expectEqualType(Builtin.type_join(L.self, Y.self), Any.self)
+// Sibling protocols join to their nearest common ancestor, and
+// compositions are joined member-wise, keeping only the most refined
+// results.
+expectEqualType(Builtin.type_join(N.self, P.self), L.self)
+expectEqualType(Builtin.type_join(Q.self, P.self), M.self)
+expectEqualType(Builtin.type_join((N & P).self, (Q & R).self), M.self)
+expectEqualType(Builtin.type_join((Q & P).self, (Y & R).self), L.self)
+// Joins through the multiply-inheriting Fake* hierarchy, including joins
+// whose result is itself a composition and checks that argument order
+// doesn't matter.
+expectEqualType(Builtin.type_join(FakeEquatable.self, FakeEquatable.self), FakeEquatable.self)
+expectEqualType(Builtin.type_join(FakeHashable.self, FakeEquatable.self), FakeEquatable.self)
+expectEqualType(Builtin.type_join(FakeEquatable.self, FakeHashable.self), FakeEquatable.self)
+expectEqualType(Builtin.type_join(FakeNumeric.self, FakeHashable.self), FakeEquatable.self)
+expectEqualType(Builtin.type_join((FakeHashable & FakeStrideable).self, (FakeHashable & FakeNumeric).self),
+ FakeHashable.self)
+expectEqualType(Builtin.type_join((FakeNumeric & FakeStrideable).self,
+ (FakeHashable & FakeNumeric).self), FakeNumeric.self)
+expectEqualType(Builtin.type_join(FakeBinaryInteger.self, FakeFloatingPoint.self),
+ (FakeHashable & FakeNumeric & FakeStrideable).self)
+expectEqualType(Builtin.type_join(FakeFloatingPoint.self, FakeBinaryInteger.self),
+ (FakeHashable & FakeNumeric & FakeStrideable).self)
+
func joinFunctions(
_ escaping: @escaping () -> (),
_ nonescaping: () -> ()