Merge pull request #19683 from lorentey/inline-customization-points

[stdlib] Force-inline some Sequence/Collection customization points
diff --git a/lib/AST/TypeJoinMeet.cpp b/lib/AST/TypeJoinMeet.cpp
index 3c5d9c8..4661b4d 100644
--- a/lib/AST/TypeJoinMeet.cpp
+++ b/lib/AST/TypeJoinMeet.cpp
@@ -51,6 +51,9 @@
   }
 
   static CanType getSuperclassJoin(CanType first, CanType second);
+  CanType computeProtocolCompositionJoin(ArrayRef<Type> firstMembers,
+                                         ArrayRef<Type> secondMembers);
+
 
   CanType visitErrorType(CanType second);
   CanType visitTupleType(CanType second);
@@ -105,10 +108,10 @@
 
     // Likewise, rather than making every visitor deal with Any,
     // always dispatch to the protocol composition side of the join.
-    if (first->isAny())
+    if (first->is<ProtocolCompositionType>())
       return TypeJoin(second).visit(first);
 
-    if (second->isAny())
+    if (second->is<ProtocolCompositionType>())
       return TypeJoin(first).visit(second);
 
     // Otherwise the first type might be an optional (or not), so
@@ -184,16 +187,6 @@
   return getSuperclassJoin(First, second);
 }
 
-CanType TypeJoin::visitProtocolType(CanType second) {
-  assert(First != second);
-
-  // FIXME: We should compute a tighter bound and/or return nullptr if
-  // we cannot. We do this now because existing tests rely on
-  // producing Any for the join of protocols that have a common
-  // supertype.
-  return TheAnyType;
-}
-
 CanType TypeJoin::visitBoundGenericClassType(CanType second) {
   return getSuperclassJoin(First, second);
 }
@@ -352,16 +345,111 @@
   return Unimplemented;
 }
 
+// Use the distributive law to compute the join of the protocol
+// compositions.
+//
+//   (A ^ B) v (C ^ D)
+// = (A v C) ^ (A v D) ^ (B v C) ^ (B v D)
+//
+// In general this law only applies to distributive lattices.
+//
+// In our case, this should be safe because our meet operation only
+// produces an existing nominal type when it is one of the operands of
+// the operation. So we can never arbitrarily climb down the lattice
+// in ways that would break distributivity.
+//
+CanType TypeJoin::computeProtocolCompositionJoin(ArrayRef<Type> firstMembers,
+                                                 ArrayRef<Type> secondMembers) {
+  SmallVector<Type, 8> result;
+  for (auto first : firstMembers) {
+    for (auto second : secondMembers) {
+      auto joined = Type::join(first, second);
+      if (!joined)
+        return Unimplemented;
+
+      if ((*joined)->isAny())
+        continue;
+
+      result.push_back(*joined);
+    }
+  }
+
+  if (result.empty())
+    return TheAnyType;
+
+  auto &ctx = result[0]->getASTContext();
+  return ProtocolCompositionType::get(ctx, result, false)->getCanonicalType();
+}
+
 CanType TypeJoin::visitProtocolCompositionType(CanType second) {
+  // The join of Any and a no-escape function doesn't exist; it isn't
+  // Any. If it were Any, it would mean we would allow these functions
+  // to escape through Any.
   if (second->isAny()) {
     auto *fnTy = First->getAs<AnyFunctionType>();
     if (fnTy && fnTy->getExtInfo().isNoEscape())
       return Nonexistent;
 
-    return second;
+    return TheAnyType;
   }
 
-  return Unimplemented;
+  assert(First != second);
+
+  // FIXME: Handle other types here.
+  if (!First->isExistentialType())
+    return TheAnyType;
+
+  SmallVector<Type, 1> protocolType;
+  ArrayRef<Type> firstMembers;
+  if (First->is<ProtocolType>()) {
+    protocolType.push_back(First);
+    firstMembers = protocolType;
+  } else {
+    firstMembers = cast<ProtocolCompositionType>(First)->getMembers();
+  }
+  auto secondMembers = cast<ProtocolCompositionType>(second)->getMembers();
+
+  return computeProtocolCompositionJoin(firstMembers, secondMembers);
+}
+
+CanType TypeJoin::visitProtocolType(CanType second) {
+  assert(First != second);
+
+  assert(!First->is<ProtocolCompositionType>() &&
+         !second->is<ProtocolCompositionType>());
+
+  // FIXME: Handle other types here.
+  if (First->getKind() != second->getKind())
+    return TheAnyType;
+
+  auto *firstDecl =
+    cast<ProtocolDecl>(First->getNominalOrBoundGenericNominal());
+
+  auto *secondDecl =
+    cast<ProtocolDecl>(second->getNominalOrBoundGenericNominal());
+
+  if (firstDecl->getInheritedProtocols().empty() &&
+      secondDecl->getInheritedProtocols().empty())
+    return TheAnyType;
+
+  if (firstDecl->inheritsFrom(secondDecl))
+    return second;
+
+  if (secondDecl->inheritsFrom(firstDecl))
+    return First;
+
+  // Neither is a supertype of the other, so instead, treat each as
+  // if it's a protocol composition of its inherited members, and join
+  // those.
+  SmallVector<Type, 4> firstMembers;
+  for (auto *decl : firstDecl->getInheritedProtocols())
+    firstMembers.push_back(decl->getDeclaredInterfaceType());
+
+  SmallVector<Type, 4> secondMembers;
+  for (auto *decl : secondDecl->getInheritedProtocols())
+    secondMembers.push_back(decl->getDeclaredInterfaceType());
+
+  return computeProtocolCompositionJoin(firstMembers, secondMembers);
 }
 
 CanType TypeJoin::visitLValueType(CanType second) { return Unimplemented; }
diff --git a/stdlib/public/core/Bitset.swift b/stdlib/public/core/Bitset.swift
index 4e5a1d4..3a0d971 100644
--- a/stdlib/public/core/Bitset.swift
+++ b/stdlib/public/core/Bitset.swift
@@ -38,7 +38,10 @@
   @inline(__always)
   internal static func word(for element: Int) -> Int {
     _sanityCheck(element >= 0)
-    return element / Word.capacity
+    // Note: We divide as UInts to get faster unsigned math (shifts).
+    let element = UInt(bitPattern: element)
+    let capacity = UInt(bitPattern: Word.capacity)
+    return Int(bitPattern: element / capacity)
   }
 
   @inlinable
@@ -61,7 +64,7 @@
   @inline(__always)
   internal static func join(word: Int, bit: Int) -> Int {
     _sanityCheck(bit >= 0 && bit < Word.capacity)
-    return word * Word.capacity + bit
+    return word &* Word.capacity + bit
   }
 }
 
@@ -69,14 +72,14 @@
   @inlinable
   @inline(__always)
   internal static func wordCount(forCapacity capacity: Int) -> Int {
-    return (capacity + Word.capacity - 1) / Word.capacity
+    return word(for: capacity + Word.capacity - 1)
   }
 
   @inlinable
   internal var capacity: Int {
     @inline(__always)
     get {
-      return wordCount * Word.capacity
+      return wordCount &* Word.capacity
     }
   }
 
diff --git a/stdlib/public/core/Dictionary.swift b/stdlib/public/core/Dictionary.swift
index 52715dd..0054651 100644
--- a/stdlib/public/core/Dictionary.swift
+++ b/stdlib/public/core/Dictionary.swift
@@ -1544,11 +1544,13 @@
     internal var _base: Dictionary<Key, Value>.Iterator
 
     @inlinable
+    @inline(__always)
     internal init(_ base: Dictionary<Key, Value>.Iterator) {
       self._base = base
     }
 
     @inlinable
+    @inline(__always)
     public mutating func next() -> Key? {
 #if _runtime(_ObjC)
       if case .cocoa(let cocoa) = _base._variant {
@@ -1562,6 +1564,7 @@
   }
 
   @inlinable
+  @inline(__always)
   public func makeIterator() -> Iterator {
     return Iterator(_variant.makeIterator())
   }
@@ -1574,11 +1577,13 @@
     internal var _base: Dictionary<Key, Value>.Iterator
 
     @inlinable
+    @inline(__always)
     internal init(_ base: Dictionary<Key, Value>.Iterator) {
       self._base = base
     }
 
     @inlinable
+    @inline(__always)
     public mutating func next() -> Value? {
 #if _runtime(_ObjC)
       if case .cocoa(let cocoa) = _base._variant {
@@ -1592,6 +1597,7 @@
   }
 
   @inlinable
+  @inline(__always)
   public func makeIterator() -> Iterator {
     return Iterator(_variant.makeIterator())
   }
diff --git a/stdlib/public/core/HashTable.swift b/stdlib/public/core/HashTable.swift
index f410b29..39ee521 100644
--- a/stdlib/public/core/HashTable.swift
+++ b/stdlib/public/core/HashTable.swift
@@ -233,6 +233,7 @@
     var word: Word
 
     @inlinable
+    @inline(__always)
     init(_ hashTable: _HashTable) {
       self.hashTable = hashTable
       self.wordIndex = 0
@@ -260,6 +261,7 @@
   }
 
   @inlinable
+  @inline(__always)
   internal func makeIterator() -> Iterator {
     return Iterator(self)
   }
diff --git a/stdlib/public/core/NativeDictionary.swift b/stdlib/public/core/NativeDictionary.swift
index 502ba1f..2a7605f 100644
--- a/stdlib/public/core/NativeDictionary.swift
+++ b/stdlib/public/core/NativeDictionary.swift
@@ -656,6 +656,7 @@
     internal var iterator: _HashTable.Iterator
 
     @inlinable
+    @inline(__always)
     init(_ base: __owned _NativeDictionary) {
       self.base = base
       self.iterator = base.hashTable.makeIterator()
@@ -673,18 +674,21 @@
   internal typealias Element = (key: Key, value: Value)
 
   @inlinable
+  @inline(__always)
   internal mutating func nextKey() -> Key? {
     guard let index = iterator.next() else { return nil }
     return base.uncheckedKey(at: index)
   }
 
   @inlinable
+  @inline(__always)
   internal mutating func nextValue() -> Value? {
     guard let index = iterator.next() else { return nil }
     return base.uncheckedValue(at: index)
   }
 
   @inlinable
+  @inline(__always)
   internal mutating func next() -> Element? {
     guard let index = iterator.next() else { return nil }
     let key = base.uncheckedKey(at: index)
diff --git a/stdlib/public/core/NativeSet.swift b/stdlib/public/core/NativeSet.swift
index 636dc24..5ebeba3 100644
--- a/stdlib/public/core/NativeSet.swift
+++ b/stdlib/public/core/NativeSet.swift
@@ -494,6 +494,7 @@
     internal var iterator: _HashTable.Iterator
 
     @inlinable
+    @inline(__always)
     init(_ base: __owned _NativeSet) {
       self.base = base
       self.iterator = base.hashTable.makeIterator()
@@ -501,6 +502,7 @@
   }
 
   @inlinable
+  @inline(__always)
   internal __consuming func makeIterator() -> Iterator {
     return Iterator(self)
   }
@@ -508,6 +510,7 @@
 
 extension _NativeSet.Iterator: IteratorProtocol {
   @inlinable
+  @inline(__always)
   internal mutating func next() -> Element? {
     guard let index = iterator.next() else { return nil }
     return base.uncheckedElement(at: index)
diff --git a/test/Sema/type_join.swift b/test/Sema/type_join.swift
index e3c15d1..96d9724 100644
--- a/test/Sema/type_join.swift
+++ b/test/Sema/type_join.swift
@@ -5,6 +5,31 @@
 class C {}
 class D : C {}
 
+protocol L {}
+protocol M : L {}
+protocol N : L {}
+protocol P : M {}
+protocol Q : M {}
+protocol R : L {}
+protocol Y {}
+
+protocol FakeEquatable {}
+protocol FakeHashable : FakeEquatable {}
+protocol FakeExpressibleByIntegerLiteral {}
+protocol FakeNumeric : FakeEquatable, FakeExpressibleByIntegerLiteral {}
+protocol FakeSignedNumeric : FakeNumeric {}
+protocol FakeComparable : FakeEquatable {}
+protocol FakeStrideable : FakeComparable {}
+protocol FakeCustomStringConvertible {}
+protocol FakeBinaryInteger : FakeHashable, FakeNumeric, FakeCustomStringConvertible, FakeStrideable {}
+protocol FakeLosslessStringConvertible {}
+protocol FakeFixedWidthInteger : FakeBinaryInteger, FakeLosslessStringConvertible {}
+protocol FakeUnsignedInteger : FakeBinaryInteger {}
+protocol FakeSignedInteger : FakeBinaryInteger, FakeSignedNumeric {}
+protocol FakeFloatingPoint : FakeSignedNumeric, FakeStrideable, FakeHashable {}
+protocol FakeExpressibleByFloatLiteral {}
+protocol FakeBinaryFloatingPoint : FakeFloatingPoint, FakeExpressibleByFloatLiteral {}
+
 func expectEqualType<T>(_: T.Type, _: T.Type) {}
 func commonSupertype<T>(_: T, _: T) -> T {}
 
@@ -38,6 +63,27 @@
 expectEqualType(Builtin.type_join(Builtin.Int32.self, Builtin.Int1.self), Any.self)
 expectEqualType(Builtin.type_join(Builtin.Int1.self, Builtin.Int32.self), Any.self)
 
+expectEqualType(Builtin.type_join(L.self, L.self), L.self)
+expectEqualType(Builtin.type_join(L.self, M.self), L.self)
+expectEqualType(Builtin.type_join(L.self, P.self), L.self)
+expectEqualType(Builtin.type_join(L.self, Y.self), Any.self)
+expectEqualType(Builtin.type_join(N.self, P.self), L.self)
+expectEqualType(Builtin.type_join(Q.self, P.self), M.self)
+expectEqualType(Builtin.type_join((N & P).self, (Q & R).self), M.self)
+expectEqualType(Builtin.type_join((Q & P).self, (Y & R).self), L.self)
+expectEqualType(Builtin.type_join(FakeEquatable.self, FakeEquatable.self), FakeEquatable.self)
+expectEqualType(Builtin.type_join(FakeHashable.self, FakeEquatable.self), FakeEquatable.self)
+expectEqualType(Builtin.type_join(FakeEquatable.self, FakeHashable.self), FakeEquatable.self)
+expectEqualType(Builtin.type_join(FakeNumeric.self, FakeHashable.self), FakeEquatable.self)
+expectEqualType(Builtin.type_join((FakeHashable & FakeStrideable).self, (FakeHashable & FakeNumeric).self),
+                                  FakeHashable.self)
+expectEqualType(Builtin.type_join((FakeNumeric & FakeStrideable).self,
+                                  (FakeHashable & FakeNumeric).self), FakeNumeric.self)
+expectEqualType(Builtin.type_join(FakeBinaryInteger.self, FakeFloatingPoint.self),
+                                  (FakeHashable & FakeNumeric & FakeStrideable).self)
+expectEqualType(Builtin.type_join(FakeFloatingPoint.self, FakeBinaryInteger.self),
+                                  (FakeHashable & FakeNumeric & FakeStrideable).self)
+
 func joinFunctions(
   _ escaping: @escaping () -> (),
   _ nonescaping: () -> ()