Support application of AnyKeyPath/PartialKeyPath.
rdar://problem/32237567
diff --git a/include/swift/AST/KnownDecls.def b/include/swift/AST/KnownDecls.def
index c1997fc..60ac981 100644
--- a/include/swift/AST/KnownDecls.def
+++ b/include/swift/AST/KnownDecls.def
@@ -70,6 +70,8 @@
FUNC_DECL(UnsafeBitCast, "unsafeBitCast")
+FUNC_DECL(ProjectKeyPathAny, "_projectKeyPathAny")
+FUNC_DECL(ProjectKeyPathPartial, "_projectKeyPathPartial")
FUNC_DECL(ProjectKeyPathReadOnly, "_projectKeyPathReadOnly")
FUNC_DECL(ProjectKeyPathWritable, "_projectKeyPathWritable")
FUNC_DECL(ProjectKeyPathReferenceWritable, "_projectKeyPathReferenceWritable")
diff --git a/include/swift/AST/KnownStdlibTypes.def b/include/swift/AST/KnownStdlibTypes.def
index 6298287..8120bbb 100644
--- a/include/swift/AST/KnownStdlibTypes.def
+++ b/include/swift/AST/KnownStdlibTypes.def
@@ -51,6 +51,7 @@
KNOWN_STDLIB_TYPE_DECL(AnyHashable, NominalTypeDecl, 0)
KNOWN_STDLIB_TYPE_DECL(MutableCollection, ProtocolDecl, 1)
+KNOWN_STDLIB_TYPE_DECL(AnyKeyPath, NominalTypeDecl, 0)
KNOWN_STDLIB_TYPE_DECL(PartialKeyPath, NominalTypeDecl, 1)
KNOWN_STDLIB_TYPE_DECL(KeyPath, NominalTypeDecl, 2)
KNOWN_STDLIB_TYPE_DECL(WritableKeyPath, NominalTypeDecl, 2)
diff --git a/lib/SILGen/SILGenExpr.cpp b/lib/SILGen/SILGenExpr.cpp
index 55d7225..ac3698b 100644
--- a/lib/SILGen/SILGenExpr.cpp
+++ b/lib/SILGen/SILGenExpr.cpp
@@ -2789,34 +2789,49 @@
auto root = SGF.emitMaterializedRValueAsOrig(E->getBase(),
AbstractionPattern::getOpaque());
auto keyPath = SGF.emitRValueAsSingleValue(E->getKeyPath());
-
- // Get the root and leaf type from the key path type.
- auto keyPathTy = E->getKeyPath()->getType()->castTo<BoundGenericType>();
- // Upcast the keypath to KeyPath<T, U> if it isn't already.
- if (keyPathTy->getDecl() != SGF.getASTContext().getKeyPathDecl()) {
- auto castToTy = BoundGenericType::get(SGF.getASTContext().getKeyPathDecl(),
+ auto keyPathDecl = E->getKeyPath()->getType()->getAnyNominal();
+ FuncDecl *projectFn;
+ SmallVector<Substitution, 4> subs;
+
+ if (keyPathDecl == SGF.getASTContext().getAnyKeyPathDecl()) {
+ // Invoke projectKeyPathAny with the type of the base value.
+ // The result is always `Any?`.
+ projectFn = SGF.getASTContext().getProjectKeyPathAny(nullptr);
+ subs.push_back(Substitution(E->getBase()->getType(), {}));
+ } else {
+ auto keyPathTy = E->getKeyPath()->getType()->castTo<BoundGenericType>();
+ if (keyPathDecl == SGF.getASTContext().getPartialKeyPathDecl()) {
+ // Invoke projectKeyPathPartial with the type of the base value.
+ // The result is always `Any`.
+ projectFn = SGF.getASTContext().getProjectKeyPathPartial(nullptr);
+ subs.push_back(Substitution(keyPathTy->getGenericArgs()[0], {}));
+ } else {
+ projectFn = SGF.getASTContext().getProjectKeyPathReadOnly(nullptr);
+ // Get the root and leaf type from the key path type.
+ subs.push_back(Substitution(keyPathTy->getGenericArgs()[0], {}));
+ subs.push_back(Substitution(keyPathTy->getGenericArgs()[1], {}));
+
+ // Upcast the keypath to KeyPath<T, U> if it isn't already.
+ if (keyPathTy->getDecl() != SGF.getASTContext().getKeyPathDecl()) {
+ auto castToTy = BoundGenericType::get(
+ SGF.getASTContext().getKeyPathDecl(),
nullptr,
keyPathTy->getGenericArgs())
- ->getCanonicalType();
- auto upcast = SGF.B.createUpcast(SILLocation(E),
+ ->getCanonicalType();
+ auto upcast = SGF.B.createUpcast(SILLocation(E),
keyPath.forward(SGF),
SILType::getPrimitiveObjectType(castToTy));
- keyPath = SGF.emitManagedRValueWithCleanup(upcast);
+ keyPath = SGF.emitManagedRValueWithCleanup(upcast);
+ }
+ }
}
- auto projectFn = SGF.getASTContext().getProjectKeyPathReadOnly(nullptr);
- Substitution genericArgs[] = {
- Substitution(keyPathTy->getGenericArgs()[0], {}),
- Substitution(keyPathTy->getGenericArgs()[1], {}),
- };
-
auto genericArgsMap =
- projectFn->getGenericSignature()->getSubstitutionMap(genericArgs);
+ projectFn->getGenericSignature()->getSubstitutionMap(subs);
return SGF.emitApplyOfLibraryIntrinsic(SILLocation(E),
- SGF.getASTContext().getProjectKeyPathReadOnly(nullptr),
- genericArgsMap, {root, keyPath}, C);
+ projectFn, genericArgsMap, {root, keyPath}, C);
}
RValue RValueEmitter::
diff --git a/lib/Sema/CSApply.cpp b/lib/Sema/CSApply.cpp
index 9bb9f03..5ef9dd0 100644
--- a/lib/Sema/CSApply.cpp
+++ b/lib/Sema/CSApply.cpp
@@ -1306,24 +1306,53 @@
// Apply a key path if we have one.
if (choice.getKind() == OverloadChoiceKind::KeyPathApplication) {
// The index argument should be (keyPath: KeyPath<Root, Value>).
- auto keyPathTy = index->getType()->castTo<TupleType>()
- ->getElementType(0)->castTo<BoundGenericType>();
- auto valueTy = keyPathTy->getGenericArgs()[1];
+ auto keyPathTTy = index->getType()->castTo<TupleType>()
+ ->getElementType(0);
- // The result may be an lvalue based on the base and key path kind.
+ Type valueTy;
bool resultIsLValue;
- if (keyPathTy->getDecl() == cs.getASTContext().getKeyPathDecl()) {
- resultIsLValue = false;
- base = cs.coerceToRValue(base);
- } else if (keyPathTy->getDecl() ==
- cs.getASTContext().getWritableKeyPathDecl()) {
- resultIsLValue = base->getType()->isLValueType();
- } else if (keyPathTy->getDecl() ==
- cs.getASTContext().getReferenceWritableKeyPathDecl()) {
- resultIsLValue = true;
- base = cs.coerceToRValue(base);
+
+ if (auto nom = keyPathTTy->getAs<NominalType>()) {
+ // AnyKeyPath is <T> rvalue T -> rvalue Any?
+ if (nom->getDecl() == cs.getASTContext().getAnyKeyPathDecl()) {
+ valueTy = ProtocolCompositionType::get(cs.getASTContext(), {},
+ /*explicit anyobject*/ false);
+ valueTy = OptionalType::get(valueTy);
+ resultIsLValue = false;
+ base = cs.coerceToRValue(base);
+ } else {
+ llvm_unreachable("unknown key path class!");
+ }
} else {
- llvm_unreachable("unknown key path class!");
+ auto keyPathBGT = keyPathTTy->castTo<BoundGenericType>();
+
+ if (keyPathBGT->getDecl()
+ == cs.getASTContext().getPartialKeyPathDecl()) {
+ // PartialKeyPath<T> is rvalue T -> rvalue Any
+ valueTy = ProtocolCompositionType::get(cs.getASTContext(), {},
+ /*explicit anyobject*/ false);
+ resultIsLValue = false;
+ base = cs.coerceToRValue(base);
+ } else {
+ // *KeyPath<T, U> is T -> U, with rvalueness based on mutability
+ // of base and keypath
+ valueTy = keyPathBGT->getGenericArgs()[1];
+
+ // The result may be an lvalue based on the base and key path kind.
+ if (keyPathBGT->getDecl() == cs.getASTContext().getKeyPathDecl()) {
+ resultIsLValue = false;
+ base = cs.coerceToRValue(base);
+ } else if (keyPathBGT->getDecl() ==
+ cs.getASTContext().getWritableKeyPathDecl()) {
+ resultIsLValue = base->getType()->isLValueType();
+ } else if (keyPathBGT->getDecl() ==
+ cs.getASTContext().getReferenceWritableKeyPathDecl()) {
+ resultIsLValue = true;
+ base = cs.coerceToRValue(base);
+ } else {
+ llvm_unreachable("unknown key path class!");
+ }
+ }
}
if (resultIsLValue)
valueTy = LValueType::get(valueTy);
diff --git a/lib/Sema/CSSimplify.cpp b/lib/Sema/CSSimplify.cpp
index bde89d8..33868bb 100644
--- a/lib/Sema/CSSimplify.cpp
+++ b/lib/Sema/CSSimplify.cpp
@@ -3936,13 +3936,21 @@
return SolutionKind::Unsolved;
};
+ if (auto clas = keyPathTy->getAs<NominalType>()) {
+ if (clas->getDecl() == getASTContext().getAnyKeyPathDecl()) {
+ // Read-only keypath, whose projected value is upcast to `Any?`.
+ // The root type can be anything.
+ Type resultTy = ProtocolCompositionType::get(getASTContext(), {},
+ /*explicit AnyObject*/ false);
+ resultTy = OptionalType::get(resultTy);
+ return matchTypes(resultTy, valueTy, ConstraintKind::Bind,
+ subflags, locator);
+ }
+ }
+
if (auto bgt = keyPathTy->getAs<BoundGenericType>()) {
- if (bgt->getGenericArgs().size() < 2)
- return SolutionKind::Error;
-
// We have the key path type. Match it to the other ends of the constraint.
auto kpRootTy = bgt->getGenericArgs()[0];
- auto kpValueTy = bgt->getGenericArgs()[1];
// Try to match the root type.
rootTy = getFixedTypeRecursive(rootTy, flags, /*wantRValue=*/false);
@@ -3957,7 +3965,19 @@
case SolutionKind::Unsolved:
llvm_unreachable("should have generated constraints");
}
-
+
+ if (bgt->getDecl() == getASTContext().getPartialKeyPathDecl()) {
+ // Read-only keypath, whose projected value is upcast to `Any`.
+ auto resultTy = ProtocolCompositionType::get(getASTContext(), {},
+ /*explicit AnyObject*/ false);
+ return matchTypes(resultTy, valueTy,
+ ConstraintKind::Bind, subflags, locator);
+ }
+
+ if (bgt->getGenericArgs().size() < 2)
+ return SolutionKind::Error;
+ auto kpValueTy = bgt->getGenericArgs()[1];
+
/// Solve for an rvalue base.
auto solveRValue = [&]() -> ConstraintSystem::SolutionKind {
return matchTypes(kpValueTy, valueTy,
@@ -3976,6 +3996,7 @@
return matchTypes(LValueType::get(kpValueTy), valueTy,
ConstraintKind::Bind, subflags, locator);
};
+
if (bgt->getDecl() == getASTContext().getKeyPathDecl()) {
// Read-only keypath.
diff --git a/stdlib/public/core/KeyPath.swift b/stdlib/public/core/KeyPath.swift
index 9f7463c..bec387d 100644
--- a/stdlib/public/core/KeyPath.swift
+++ b/stdlib/public/core/KeyPath.swift
@@ -1132,6 +1132,39 @@
// MARK: Library intrinsics for projecting key paths.
+@_inlineable
+public // COMPILER_INTRINSIC
+func _projectKeyPathPartial<Root>(
+ root: Root,
+ keyPath: PartialKeyPath<Root>
+) -> Any {
+ func open<Value>(_: Value.Type) -> Any {
+ return _projectKeyPathReadOnly(root: root,
+ keyPath: unsafeDowncast(keyPath, to: KeyPath<Root, Value>.self))
+ }
+ return _openExistential(type(of: keyPath).valueType, do: open)
+}
+
+@_inlineable
+public // COMPILER_INTRINSIC
+func _projectKeyPathAny<RootValue>(
+ root: RootValue,
+ keyPath: AnyKeyPath
+) -> Any? {
+ let (keyPathRoot, keyPathValue) = type(of: keyPath)._rootAndValueType
+ func openRoot<KeyPathRoot>(_: KeyPathRoot.Type) -> Any? {
+ guard let rootForKeyPath = root as? KeyPathRoot else {
+ return nil
+ }
+ func openValue<Value>(_: Value.Type) -> Any {
+ return _projectKeyPathReadOnly(root: rootForKeyPath,
+ keyPath: unsafeDowncast(keyPath, to: KeyPath<KeyPathRoot, Value>.self))
+ }
+ return _openExistential(keyPathValue, do: openValue)
+ }
+ return _openExistential(keyPathRoot, do: openRoot)
+}
+
public // COMPILER_INTRINSIC
func _projectKeyPathReadOnly<Root, Value>(
root: Root,
diff --git a/test/SILGen/keypath_application.swift b/test/SILGen/keypath_application.swift
index f470d34..f23b2c2 100644
--- a/test/SILGen/keypath_application.swift
+++ b/test/SILGen/keypath_application.swift
@@ -220,3 +220,24 @@
readonly[keyPath: rkp] = value
writable[keyPath: rkp] = value
}
+
+// CHECK-LABEL: sil hidden @{{.*}}partial
+func partial<A>(valueA: A,
+ valueB: Int,
+ pkpA: PartialKeyPath<A>,
+ pkpB: PartialKeyPath<Int>,
+ akp: AnyKeyPath) {
+ // CHECK: [[PROJECT:%.*]] = function_ref @{{.*}}projectKeyPathAny
+ // CHECK: apply [[PROJECT]]<A>
+ _ = valueA[keyPath: akp]
+ // CHECK: [[PROJECT:%.*]] = function_ref @{{.*}}projectKeyPathPartial
+ // CHECK: apply [[PROJECT]]<A>
+ _ = valueA[keyPath: pkpA]
+
+ // CHECK: [[PROJECT:%.*]] = function_ref @{{.*}}projectKeyPathAny
+ // CHECK: apply [[PROJECT]]<Int>
+ _ = valueB[keyPath: akp]
+ // CHECK: [[PROJECT:%.*]] = function_ref @{{.*}}projectKeyPathPartial
+ // CHECK: apply [[PROJECT]]<Int>
+ _ = valueB[keyPath: pkpB]
+}
diff --git a/test/expr/unary/keypath/keypath.swift b/test/expr/unary/keypath/keypath.swift
index 641cd64..6be4264 100644
--- a/test/expr/unary/keypath/keypath.swift
+++ b/test/expr/unary/keypath/keypath.swift
@@ -209,9 +209,6 @@
readonly[keyPath: rkp] = sink
writable[keyPath: rkp] = sink
- // TODO: PartialKeyPath and AnyKeyPath application
-
- /*
let pkp: PartialKeyPath = rkp
var anySink1 = readonly[keyPath: pkp]
@@ -219,8 +216,8 @@
var anySink2 = writable[keyPath: pkp]
expect(&anySink2, toHaveType: Exactly<Any>.self)
- readonly[keyPath: pkp] = anySink1 // e/xpected-error{{cannot assign to immutable}}
- writable[keyPath: pkp] = anySink2 // e/xpected-error{{cannot assign to immutable}}
+ readonly[keyPath: pkp] = anySink1 // expected-error{{cannot assign to immutable}}
+ writable[keyPath: pkp] = anySink2 // expected-error{{cannot assign to immutable}}
let akp: AnyKeyPath = pkp
@@ -229,9 +226,8 @@
var anyqSink2 = writable[keyPath: akp]
expect(&anyqSink2, toHaveType: Exactly<Any?>.self)
- readonly[keyPath: akp] = anyqSink1 // e/xpected-error{{cannot assign to immutable}}
- writable[keyPath: akp] = anyqSink2 // e/xpected-error{{cannot assign to immutable}}
- */
+ readonly[keyPath: akp] = anyqSink1 // expected-error{{cannot assign to immutable}}
+ writable[keyPath: akp] = anyqSink2 // expected-error{{cannot assign to immutable}}
}
func testKeyPathSubscriptMetatype(readonly: Z.Type, writable: inout Z.Type,
diff --git a/test/stdlib/KeyPath.swift b/test/stdlib/KeyPath.swift
index 2ec6387..f4b8a75 100644
--- a/test/stdlib/KeyPath.swift
+++ b/test/stdlib/KeyPath.swift
@@ -277,4 +277,41 @@
}
}
+class AB {
+}
+class ABC: AB {
+ var a = LifetimeTracked(1)
+ var b = LifetimeTracked(2)
+ var c = LifetimeTracked(3)
+}
+
+keyPath.test("dynamically-typed application") {
+ let cPaths = [\ABC.a, \ABC.b, \ABC.c]
+
+ let subject = ABC()
+
+ do {
+ let fields = cPaths.map { subject[keyPath: $0] }
+ expectTrue(fields[0] as! AnyObject === subject.a)
+ expectTrue(fields[1] as! AnyObject === subject.b)
+ expectTrue(fields[2] as! AnyObject === subject.c)
+ }
+
+ let erasedSubject: AB = subject
+ let erasedPaths: [AnyKeyPath] = cPaths
+ let wrongSubject = AB()
+
+ do {
+ let fields = erasedPaths.map { erasedSubject[keyPath: $0] }
+ expectTrue(fields[0]! as! AnyObject === subject.a)
+ expectTrue(fields[1]! as! AnyObject === subject.b)
+ expectTrue(fields[2]! as! AnyObject === subject.c)
+
+ let wrongFields = erasedPaths.map { wrongSubject[keyPath: $0] }
+ expectTrue(wrongFields[0] == nil)
+ expectTrue(wrongFields[1] == nil)
+ expectTrue(wrongFields[2] == nil)
+ }
+}
+
runAllTests()