| //===--- TypeCheckAttr.cpp - Type Checking for Attributes -----------------===// |
| // |
| // This source file is part of the Swift.org open source project |
| // |
| // Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors |
| // Licensed under Apache License v2.0 with Runtime Library Exception |
| // |
| // See https://swift.org/LICENSE.txt for license information |
| // See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements semantic analysis for attributes. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "MiscDiagnostics.h" |
| #include "TypeCheckObjC.h" |
| #include "TypeCheckType.h" |
| #include "TypeChecker.h" |
| #include "swift/AST/ASTVisitor.h" |
| #include "swift/AST/ClangModuleLoader.h" |
| #include "swift/AST/DiagnosticsParse.h" |
| #include "swift/AST/GenericEnvironment.h" |
| #include "swift/AST/GenericSignatureBuilder.h" |
| #include "swift/AST/ModuleNameLookup.h" |
| #include "swift/AST/NameLookup.h" |
| #include "swift/AST/NameLookupRequests.h" |
| #include "swift/AST/ParameterList.h" |
| #include "swift/AST/PropertyWrappers.h" |
| #include "swift/AST/SourceFile.h" |
| #include "swift/AST/TypeCheckRequests.h" |
| #include "swift/AST/Types.h" |
| #include "swift/Parse/Lexer.h" |
| #include "swift/Sema/IDETypeChecking.h" |
| #include "clang/Basic/CharInfo.h" |
| #include "llvm/Support/Debug.h" |
| |
| using namespace swift; |
| |
| namespace { |
| /// This emits a diagnostic with a fixit to remove the attribute. |
| template<typename ...ArgTypes> |
| void diagnoseAndRemoveAttr(TypeChecker &TC, Decl *D, DeclAttribute *attr, |
| ArgTypes &&...Args) { |
| assert(!D->hasClangNode() && "Clang importer propagated a bogus attribute"); |
| if (!D->hasClangNode()) { |
| SourceLoc loc = attr->getLocation(); |
| assert(loc.isValid() && "Diagnosing attribute with invalid location"); |
| if (loc.isInvalid()) { |
| loc = D->getLoc(); |
| } |
| if (loc.isValid()) { |
| TC.diagnose(loc, std::forward<ArgTypes>(Args)...) |
| .fixItRemove(attr->getRangeWithAt()); |
| } |
| } |
| |
| attr->setInvalid(); |
| } |
| |
| /// This visits each attribute on a decl. The visitor should return true if |
| /// the attribute is invalid and should be marked as such. |
| class AttributeChecker : public AttributeVisitor<AttributeChecker> { |
| TypeChecker &TC; |
| Decl *D; |
| |
| public: |
| AttributeChecker(TypeChecker &TC, Decl *D) : TC(TC), D(D) {} |
| |
| /// This emits a diagnostic with a fixit to remove the attribute. |
| template<typename ...ArgTypes> |
| void diagnoseAndRemoveAttr(DeclAttribute *attr, ArgTypes &&...Args) { |
| ::diagnoseAndRemoveAttr(TC, D, attr, std::forward<ArgTypes>(Args)...); |
| } |
| |
| /// Deleting this ensures that all attributes are covered by the visitor |
| /// below. |
| bool visitDeclAttribute(DeclAttribute *A) = delete; |
| |
| #define IGNORED_ATTR(X) void visit##X##Attr(X##Attr *) {} |
| IGNORED_ATTR(AlwaysEmitIntoClient) |
| IGNORED_ATTR(HasInitialValue) |
| IGNORED_ATTR(ClangImporterSynthesizedType) |
| IGNORED_ATTR(Convenience) |
| IGNORED_ATTR(Effects) |
| IGNORED_ATTR(Exported) |
| IGNORED_ATTR(ForbidSerializingReference) |
| IGNORED_ATTR(HasStorage) |
| IGNORED_ATTR(Inline) |
| IGNORED_ATTR(ObjCBridged) |
| IGNORED_ATTR(ObjCNonLazyRealization) |
| IGNORED_ATTR(ObjCRuntimeName) |
| IGNORED_ATTR(RawDocComment) |
| IGNORED_ATTR(RequiresStoredPropertyInits) |
| IGNORED_ATTR(RestatedObjCConformance) |
| IGNORED_ATTR(Semantics) |
| IGNORED_ATTR(ShowInInterface) |
| IGNORED_ATTR(SILGenName) |
| IGNORED_ATTR(StaticInitializeObjCMetadata) |
| IGNORED_ATTR(SynthesizedProtocol) |
| IGNORED_ATTR(Testable) |
| IGNORED_ATTR(WeakLinked) |
| IGNORED_ATTR(PrivateImport) |
| IGNORED_ATTR(DisfavoredOverload) |
| IGNORED_ATTR(ProjectedValueProperty) |
| IGNORED_ATTR(ReferenceOwnership) |
| // SWIFT_ENABLE_TENSORFLOW |
| // TODO(TF-715): Allow @quoted on more decls. |
| IGNORED_ATTR(Quoted) |
| #undef IGNORED_ATTR |
| |
| void visitAlignmentAttr(AlignmentAttr *attr) { |
| // Alignment must be a power of two. |
| auto value = attr->getValue(); |
| if (value == 0 || (value & (value - 1)) != 0) |
| TC.diagnose(attr->getLocation(), diag::alignment_not_power_of_two); |
| } |
| |
| void visitBorrowedAttr(BorrowedAttr *attr) { |
| // These criteria are the same preconditions laid out by |
| // AbstractStorageDecl::requiresOpaqueModifyCoroutine(). |
| |
| assert(!D->hasClangNode() && "@_borrowed on imported declaration?"); |
| |
| if (D->getAttrs().hasAttribute<DynamicAttr>()) { |
| TC.diagnose(attr->getLocation(), diag::borrowed_with_objc_dynamic, |
| D->getDescriptiveKind()) |
| .fixItRemove(attr->getRange()); |
| D->getAttrs().removeAttribute(attr); |
| return; |
| } |
| |
| auto dc = D->getDeclContext(); |
| auto protoDecl = dyn_cast<ProtocolDecl>(dc); |
| if (protoDecl && protoDecl->isObjC()) { |
| TC.diagnose(attr->getLocation(), |
| diag::borrowed_on_objc_protocol_requirement, |
| D->getDescriptiveKind()) |
| .fixItRemove(attr->getRange()); |
| D->getAttrs().removeAttribute(attr); |
| return; |
| } |
| } |
| |
| void visitTransparentAttr(TransparentAttr *attr); |
| void visitMutationAttr(DeclAttribute *attr); |
| void visitMutatingAttr(MutatingAttr *attr) { visitMutationAttr(attr); } |
| void visitNonMutatingAttr(NonMutatingAttr *attr) { visitMutationAttr(attr); } |
| void visitConsumingAttr(ConsumingAttr *attr) { visitMutationAttr(attr); } |
| void visitDynamicAttr(DynamicAttr *attr); |
| |
| void visitIndirectAttr(IndirectAttr *attr) { |
| if (auto caseDecl = dyn_cast<EnumElementDecl>(D)) { |
| // An indirect case should have a payload. |
| if (!caseDecl->hasAssociatedValues()) |
| TC.diagnose(attr->getLocation(), |
| diag::indirect_case_without_payload, caseDecl->getName()); |
| // If the enum is already indirect, its cases don't need to be. |
| else if (caseDecl->getParentEnum()->getAttrs() |
| .hasAttribute<IndirectAttr>()) |
| TC.diagnose(attr->getLocation(), |
| diag::indirect_case_in_indirect_enum); |
| } |
| } |
| |
| void visitWarnUnqualifiedAccessAttr(WarnUnqualifiedAccessAttr *attr) { |
| if (!D->getDeclContext()->isTypeContext()) { |
| diagnoseAndRemoveAttr(attr, diag::attr_methods_only, attr); |
| } |
| } |
| |
| void visitFinalAttr(FinalAttr *attr); |
| void visitIBActionAttr(IBActionAttr *attr); |
| void visitIBSegueActionAttr(IBSegueActionAttr *attr); |
| void visitLazyAttr(LazyAttr *attr); |
| void visitIBDesignableAttr(IBDesignableAttr *attr); |
| void visitIBInspectableAttr(IBInspectableAttr *attr); |
| void visitGKInspectableAttr(GKInspectableAttr *attr); |
| void visitIBOutletAttr(IBOutletAttr *attr); |
| void visitLLDBDebuggerFunctionAttr(LLDBDebuggerFunctionAttr *attr); |
| void visitNSManagedAttr(NSManagedAttr *attr); |
| void visitOverrideAttr(OverrideAttr *attr); |
| void visitNonOverrideAttr(NonOverrideAttr *attr); |
| void visitAccessControlAttr(AccessControlAttr *attr); |
| void visitSetterAccessAttr(SetterAccessAttr *attr); |
| bool visitAbstractAccessControlAttr(AbstractAccessControlAttr *attr); |
| |
| void visitObjCAttr(ObjCAttr *attr); |
| void visitNonObjCAttr(NonObjCAttr *attr); |
| void visitObjCMembersAttr(ObjCMembersAttr *attr); |
| |
| void visitOptionalAttr(OptionalAttr *attr); |
| |
| void visitAvailableAttr(AvailableAttr *attr); |
| |
| void visitCDeclAttr(CDeclAttr *attr); |
| |
| void visitDynamicCallableAttr(DynamicCallableAttr *attr); |
| |
| void visitDynamicMemberLookupAttr(DynamicMemberLookupAttr *attr); |
| |
| void visitNSCopyingAttr(NSCopyingAttr *attr); |
| void visitRequiredAttr(RequiredAttr *attr); |
| void visitRethrowsAttr(RethrowsAttr *attr); |
| |
| void checkApplicationMainAttribute(DeclAttribute *attr, |
| Identifier Id_ApplicationDelegate, |
| Identifier Id_Kit, |
| Identifier Id_ApplicationMain); |
| |
| void visitNSApplicationMainAttr(NSApplicationMainAttr *attr); |
| void visitUIApplicationMainAttr(UIApplicationMainAttr *attr); |
| |
| void visitUnsafeNoObjCTaggedPointerAttr(UnsafeNoObjCTaggedPointerAttr *attr); |
| void visitSwiftNativeObjCRuntimeBaseAttr( |
| SwiftNativeObjCRuntimeBaseAttr *attr); |
| |
| void checkOperatorAttribute(DeclAttribute *attr); |
| |
| void visitInfixAttr(InfixAttr *attr) { checkOperatorAttribute(attr); } |
| void visitPostfixAttr(PostfixAttr *attr) { checkOperatorAttribute(attr); } |
| void visitPrefixAttr(PrefixAttr *attr) { checkOperatorAttribute(attr); } |
| |
| void visitSpecializeAttr(SpecializeAttr *attr); |
| |
| void visitFixedLayoutAttr(FixedLayoutAttr *attr); |
| void visitUsableFromInlineAttr(UsableFromInlineAttr *attr); |
| void visitInlinableAttr(InlinableAttr *attr); |
| void visitOptimizeAttr(OptimizeAttr *attr); |
| |
| void visitDiscardableResultAttr(DiscardableResultAttr *attr); |
| void visitDynamicReplacementAttr(DynamicReplacementAttr *attr); |
| void visitImplementsAttr(ImplementsAttr *attr); |
| |
| void visitFrozenAttr(FrozenAttr *attr); |
| |
| void visitCustomAttr(CustomAttr *attr); |
| void visitPropertyWrapperAttr(PropertyWrapperAttr *attr); |
| void visitFunctionBuilderAttr(FunctionBuilderAttr *attr); |
| |
| void visitImplementationOnlyAttr(ImplementationOnlyAttr *attr); |
| |
| // SWIFT_ENABLE_TENSORFLOW |
| void visitDifferentiableAttr(DifferentiableAttr *attr); |
| void visitDifferentiatingAttr(DifferentiatingAttr *attr); |
| void visitCompilerEvaluableAttr(CompilerEvaluableAttr *attr); |
| void visitNoDerivativeAttr(NoDerivativeAttr *attr); |
| void visitTransposingAttr(TransposingAttr *attr); |
| }; |
| } // end anonymous namespace |
| |
| void AttributeChecker::visitTransparentAttr(TransparentAttr *attr) { |
| DeclContext *Ctx = D->getDeclContext(); |
| // Protocol declarations cannot be transparent. |
| if (isa<ProtocolDecl>(Ctx)) |
| diagnoseAndRemoveAttr(attr, diag::transparent_in_protocols_not_supported); |
| // Class declarations cannot be transparent. |
| if (isa<ClassDecl>(Ctx)) { |
| |
| // @transparent is always ok on implicitly generated accessors: they can |
| // be dispatched (even in classes) when the references are within the |
| // class themself. |
| if (!(isa<AccessorDecl>(D) && D->isImplicit())) |
| diagnoseAndRemoveAttr(attr, diag::transparent_in_classes_not_supported); |
| } |
| |
| if (auto *VD = dyn_cast<VarDecl>(D)) { |
| // Stored properties and variables can't be transparent. |
| if (VD->hasStorage()) |
| diagnoseAndRemoveAttr(attr, diag::attribute_invalid_on_stored_property, |
| attr); |
| } |
| } |
| |
| void AttributeChecker::visitMutationAttr(DeclAttribute *attr) { |
| FuncDecl *FD = cast<FuncDecl>(D); |
| |
| SelfAccessKind attrModifier; |
| switch (attr->getKind()) { |
| case DeclAttrKind::DAK_Consuming: |
| attrModifier = SelfAccessKind::Consuming; |
| break; |
| case DeclAttrKind::DAK_Mutating: |
| attrModifier = SelfAccessKind::Mutating; |
| break; |
| case DeclAttrKind::DAK_NonMutating: |
| attrModifier = SelfAccessKind::NonMutating; |
| break; |
| default: |
| llvm_unreachable("unhandled attribute kind"); |
| } |
| |
| // mutation attributes may only appear in type context. |
| if (auto contextTy = FD->getDeclContext()->getDeclaredInterfaceType()) { |
| // 'mutating' and 'nonmutating' are not valid on types |
| // with reference semantics. |
| if (contextTy->hasReferenceSemantics()) { |
| if (attrModifier != SelfAccessKind::Consuming) |
| diagnoseAndRemoveAttr(attr, diag::mutating_invalid_classes, |
| attrModifier); |
| } |
| } else { |
| diagnoseAndRemoveAttr(attr, diag::mutating_invalid_global_scope, |
| attrModifier); |
| } |
| |
| // Verify we don't have more than one of mutating, nonmutating, |
| // and __consuming. |
| if ((FD->getAttrs().hasAttribute<MutatingAttr>() + |
| FD->getAttrs().hasAttribute<NonMutatingAttr>() + |
| FD->getAttrs().hasAttribute<ConsumingAttr>()) > 1) { |
| if (auto *NMA = FD->getAttrs().getAttribute<NonMutatingAttr>()) { |
| if (attrModifier != SelfAccessKind::NonMutating) { |
| diagnoseAndRemoveAttr(NMA, diag::functions_mutating_and_not, |
| SelfAccessKind::NonMutating, attrModifier); |
| } |
| } |
| |
| if (auto *MUA = FD->getAttrs().getAttribute<MutatingAttr>()) { |
| if (attrModifier != SelfAccessKind::Mutating) { |
| diagnoseAndRemoveAttr(MUA, diag::functions_mutating_and_not, |
| SelfAccessKind::Mutating, attrModifier); |
| } |
| } |
| |
| if (auto *CSA = FD->getAttrs().getAttribute<ConsumingAttr>()) { |
| if (attrModifier != SelfAccessKind::Consuming) { |
| diagnoseAndRemoveAttr(CSA, diag::functions_mutating_and_not, |
| SelfAccessKind::Consuming, attrModifier); |
| } |
| } |
| } |
| |
| // Verify that we don't have a static function. |
| if (FD->isStatic()) |
| diagnoseAndRemoveAttr(attr, diag::static_functions_not_mutating); |
| } |
| |
| void AttributeChecker::visitDynamicAttr(DynamicAttr *attr) { |
| // Members cannot be both dynamic and @_transparent. |
| if (D->getAttrs().hasAttribute<TransparentAttr>()) |
| diagnoseAndRemoveAttr(attr, diag::dynamic_with_transparent); |
| if (!D->getAttrs().hasAttribute<ObjCAttr>() && |
| D->getModuleContext()->isResilient()) |
| diagnoseAndRemoveAttr(attr, |
| diag::dynamic_and_library_evolution_not_supported); |
| } |
| |
| static bool |
| validateIBActionSignature(TypeChecker &TC, DeclAttribute *attr, const FuncDecl *FD, |
| unsigned minParameters, unsigned maxParameters, |
| bool hasVoidResult = true) { |
| bool valid = true; |
| |
| auto arity = FD->getParameters()->size(); |
| auto resultType = FD->getResultInterfaceType(); |
| |
| if (arity < minParameters || arity > maxParameters) { |
| auto diagID = diag::invalid_ibaction_argument_count; |
| if (minParameters == maxParameters) |
| diagID = diag::invalid_ibaction_argument_count_exact; |
| else if (minParameters == 0) |
| diagID = diag::invalid_ibaction_argument_count_max; |
| TC.diagnose(FD, diagID, attr->getAttrName(), minParameters, maxParameters); |
| valid = false; |
| } |
| |
| if (resultType->isVoid() != hasVoidResult) { |
| TC.diagnose(FD, diag::invalid_ibaction_result, attr->getAttrName(), |
| hasVoidResult); |
| valid = false; |
| } |
| |
| // We don't need to check here that parameter or return types are |
| // ObjC-representable; IsObjCRequest will validate that. |
| |
| if (!valid) |
| attr->setInvalid(); |
| return valid; |
| } |
| |
| static bool isiOS(TypeChecker &TC) { |
| return TC.getLangOpts().Target.isiOS(); |
| } |
| |
| static bool iswatchOS(TypeChecker &TC) { |
| return TC.getLangOpts().Target.isWatchOS(); |
| } |
| |
| static bool isRelaxedIBAction(TypeChecker &TC) { |
| return isiOS(TC) || iswatchOS(TC); |
| } |
| |
| void AttributeChecker::visitIBActionAttr(IBActionAttr *attr) { |
| // Only instance methods can be IBActions. |
| const FuncDecl *FD = cast<FuncDecl>(D); |
| if (!FD->isPotentialIBActionTarget()) { |
| diagnoseAndRemoveAttr(attr, diag::invalid_ibaction_decl, |
| attr->getAttrName()); |
| return; |
| } |
| |
| if (isRelaxedIBAction(TC)) |
| // iOS, tvOS, and watchOS allow 0-2 parameters to an @IBAction method. |
| validateIBActionSignature(TC, attr, FD, /*minParams=*/0, /*maxParams=*/2); |
| else |
| // macOS allows 1 parameter to an @IBAction method. |
| validateIBActionSignature(TC, attr, FD, /*minParams=*/1, /*maxParams=*/1); |
| } |
| |
| void AttributeChecker::visitIBSegueActionAttr(IBSegueActionAttr *attr) { |
| // Only instance methods can be IBActions. |
| const FuncDecl *FD = cast<FuncDecl>(D); |
| if (!FD->isPotentialIBActionTarget()) |
| diagnoseAndRemoveAttr(attr, diag::invalid_ibaction_decl, |
| attr->getAttrName()); |
| |
| if (!validateIBActionSignature(TC, attr, FD, |
| /*minParams=*/1, /*maxParams=*/3, |
| /*hasVoidResult=*/false)) |
| return; |
| |
| // If the IBSegueAction method's selector belongs to one of the ObjC method |
| // families (like -newDocumentSegue: or -copyScreen), it would return the |
| // object at +1, but the caller would expect it to be +0 and would therefore |
| // leak it. |
| // |
| // To prevent that, diagnose if the selector belongs to one of the method |
| // families and suggest that the user change the Swift name or Obj-C selector. |
| auto currentSelector = FD->getObjCSelector(); |
| |
| SmallString<32> prefix("make"); |
| |
| switch (currentSelector.getSelectorFamily()) { |
| case ObjCSelectorFamily::None: |
| // No error--exit early. |
| return; |
| |
| case ObjCSelectorFamily::Alloc: |
| case ObjCSelectorFamily::Init: |
| case ObjCSelectorFamily::New: |
| // Fix-it will replace the "alloc"/"init"/"new" in the selector with "make". |
| break; |
| |
| case ObjCSelectorFamily::Copy: |
| // Fix-it will replace the "copy" in the selector with "makeCopy". |
| prefix += "Copy"; |
| break; |
| |
| case ObjCSelectorFamily::MutableCopy: |
| // Fix-it will replace the "mutable" in the selector with "makeMutable". |
| prefix += "Mutable"; |
| break; |
| } |
| |
| // Emit the actual error. |
| TC.diagnose(FD, diag::ibsegueaction_objc_method_family, |
| attr->getAttrName(), currentSelector); |
| |
| // The rest of this is just fix-it generation. |
| |
| /// Replaces the first word of \c oldName with the prefix, where "word" is a |
| /// sequence of lowercase characters. |
| auto replacingPrefix = [&](Identifier oldName) -> Identifier { |
| SmallString<32> scratch = prefix; |
| scratch += oldName.str().drop_while(clang::isLowercase); |
| return TC.Context.getIdentifier(scratch); |
| }; |
| |
| // Suggest changing the Swift name of the method, unless there is already an |
| // explicit selector. |
| if (!FD->getAttrs().hasAttribute<ObjCAttr>() || |
| !FD->getAttrs().getAttribute<ObjCAttr>()->hasName()) { |
| auto newSwiftBaseName = replacingPrefix(FD->getBaseName().getIdentifier()); |
| auto argumentNames = FD->getFullName().getArgumentNames(); |
| DeclName newSwiftName(TC.Context, newSwiftBaseName, argumentNames); |
| |
| auto diag = TC.diagnose(FD, diag::fixit_rename_in_swift, newSwiftName); |
| fixDeclarationName(diag, FD, newSwiftName); |
| } |
| |
| // Suggest changing just the selector to one with a different first piece. |
| auto oldPieces = currentSelector.getSelectorPieces(); |
| SmallVector<Identifier, 4> newPieces(oldPieces.begin(), oldPieces.end()); |
| newPieces[0] = replacingPrefix(newPieces[0]); |
| ObjCSelector newSelector(TC.Context, currentSelector.getNumArgs(), newPieces); |
| |
| auto diag = TC.diagnose(FD, diag::fixit_rename_in_objc, newSelector); |
| fixDeclarationObjCName(diag, FD, currentSelector, newSelector); |
| } |
| |
| void AttributeChecker::visitIBDesignableAttr(IBDesignableAttr *attr) { |
| if (auto *ED = dyn_cast<ExtensionDecl>(D)) { |
| if (auto nominalDecl = ED->getExtendedNominal()) { |
| if (!isa<ClassDecl>(nominalDecl)) |
| diagnoseAndRemoveAttr(attr, diag::invalid_ibdesignable_extension); |
| } |
| } |
| } |
| |
| void AttributeChecker::visitIBInspectableAttr(IBInspectableAttr *attr) { |
| // Only instance properties can be 'IBInspectable'. |
| auto *VD = cast<VarDecl>(D); |
| if (!VD->getDeclContext()->getSelfClassDecl() || VD->isStatic()) |
| diagnoseAndRemoveAttr(attr, diag::invalid_ibinspectable, |
| attr->getAttrName()); |
| } |
| |
| void AttributeChecker::visitGKInspectableAttr(GKInspectableAttr *attr) { |
| // Only instance properties can be 'GKInspectable'. |
| auto *VD = cast<VarDecl>(D); |
| if (!VD->getDeclContext()->getSelfClassDecl() || VD->isStatic()) |
| diagnoseAndRemoveAttr(attr, diag::invalid_ibinspectable, |
| attr->getAttrName()); |
| } |
| |
| static Optional<Diag<bool,Type>> |
| isAcceptableOutletType(Type type, bool &isArray, TypeChecker &TC) { |
| if (type->isObjCExistentialType() || type->isAny()) |
| return None; // @objc existential types are okay |
| |
| auto nominal = type->getAnyNominal(); |
| |
| if (auto classDecl = dyn_cast_or_null<ClassDecl>(nominal)) { |
| if (classDecl->isObjC()) |
| return None; // @objc class types are okay. |
| return diag::iboutlet_nonobjc_class; |
| } |
| |
| if (nominal == TC.Context.getStringDecl()) { |
| // String is okay because it is bridged to NSString. |
| // FIXME: BridgesTypes.def is almost sufficient for this. |
| return None; |
| } |
| |
| if (nominal == TC.Context.getArrayDecl()) { |
| // Arrays of arrays are not allowed. |
| if (isArray) |
| return diag::iboutlet_nonobject_type; |
| |
| isArray = true; |
| |
| // Handle Array<T>. T must be an Objective-C class or protocol. |
| auto boundTy = type->castTo<BoundGenericStructType>(); |
| auto boundArgs = boundTy->getGenericArgs(); |
| assert(boundArgs.size() == 1 && "invalid Array declaration"); |
| Type elementTy = boundArgs.front(); |
| return isAcceptableOutletType(elementTy, isArray, TC); |
| } |
| |
| if (type->isExistentialType()) |
| return diag::iboutlet_nonobjc_protocol; |
| |
| // No other types are permitted. |
| return diag::iboutlet_nonobject_type; |
| } |
| |
| |
| void AttributeChecker::visitIBOutletAttr(IBOutletAttr *attr) { |
| // Only instance properties can be 'IBOutlet'. |
| auto *VD = cast<VarDecl>(D); |
| if (!VD->getDeclContext()->getSelfClassDecl() || VD->isStatic()) |
| diagnoseAndRemoveAttr(attr, diag::invalid_iboutlet); |
| |
| if (!VD->isSettable(nullptr)) { |
| // Allow non-mutable IBOutlet properties in module interfaces, |
| // as they may have been private(set) |
| SourceFile *Parent = VD->getDeclContext()->getParentSourceFile(); |
| if (!Parent || Parent->Kind != SourceFileKind::Interface) |
| diagnoseAndRemoveAttr(attr, diag::iboutlet_only_mutable); |
| } |
| |
| // Verify that the field type is valid as an outlet. |
| auto type = VD->getType(); |
| |
| if (VD->isInvalid()) |
| return; |
| |
| // Look through ownership types, and optionals. |
| type = type->getReferenceStorageReferent(); |
| bool wasOptional = false; |
| if (Type underlying = type->getOptionalObjectType()) { |
| type = underlying; |
| wasOptional = true; |
| } |
| |
| bool isArray = false; |
| if (auto isError = isAcceptableOutletType(type, isArray, TC)) |
| diagnoseAndRemoveAttr(attr, isError.getValue(), |
| /*array=*/isArray, type); |
| |
| // If the type wasn't optional, an array, or unowned, complain. |
| if (!wasOptional && !isArray) { |
| TC.diagnose(attr->getLocation(), diag::iboutlet_non_optional, type); |
| auto typeRange = VD->getTypeSourceRangeForDiagnostics(); |
| { // Only one diagnostic can be active at a time. |
| auto diag = TC.diagnose(typeRange.Start, diag::note_make_optional, |
| OptionalType::get(type)); |
| if (type->hasSimpleTypeRepr()) { |
| diag.fixItInsertAfter(typeRange.End, "?"); |
| } else { |
| diag.fixItInsert(typeRange.Start, "(") |
| .fixItInsertAfter(typeRange.End, ")?"); |
| } |
| } |
| { // Only one diagnostic can be active at a time. |
| auto diag = TC.diagnose(typeRange.Start, |
| diag::note_make_implicitly_unwrapped_optional); |
| if (type->hasSimpleTypeRepr()) { |
| diag.fixItInsertAfter(typeRange.End, "!"); |
| } else { |
| diag.fixItInsert(typeRange.Start, "(") |
| .fixItInsertAfter(typeRange.End, ")!"); |
| } |
| } |
| } |
| } |
| |
| void AttributeChecker::visitNSManagedAttr(NSManagedAttr *attr) { |
| // @NSManaged only applies to instance methods and properties within a class. |
| if (cast<ValueDecl>(D)->isStatic() || |
| !D->getDeclContext()->getSelfClassDecl()) { |
| diagnoseAndRemoveAttr(attr, diag::attr_NSManaged_not_instance_member); |
| } |
| |
| if (auto *method = dyn_cast<FuncDecl>(D)) { |
| // Separate out the checks for methods. |
| if (method->hasBody()) |
| diagnoseAndRemoveAttr(attr, diag::attr_NSManaged_method_body); |
| |
| return; |
| } |
| |
| // Everything below deals with restrictions on @NSManaged properties. |
| auto *VD = cast<VarDecl>(D); |
| |
| // @NSManaged properties cannot be @NSCopying |
| if (auto *NSCopy = VD->getAttrs().getAttribute<NSCopyingAttr>()) |
| diagnoseAndRemoveAttr(NSCopy, diag::attr_NSManaged_NSCopying); |
| |
| } |
| |
| void AttributeChecker:: |
| visitLLDBDebuggerFunctionAttr(LLDBDebuggerFunctionAttr *attr) { |
| // This is only legal when debugger support is on. |
| if (!D->getASTContext().LangOpts.DebuggerSupport) |
| diagnoseAndRemoveAttr(attr, diag::attr_for_debugger_support_only); |
| } |
| |
| void AttributeChecker::visitOverrideAttr(OverrideAttr *attr) { |
| if (!isa<ClassDecl>(D->getDeclContext()) && |
| !isa<ProtocolDecl>(D->getDeclContext()) && |
| !isa<ExtensionDecl>(D->getDeclContext())) |
| diagnoseAndRemoveAttr(attr, diag::override_nonclass_decl); |
| } |
| |
| void AttributeChecker::visitNonOverrideAttr(NonOverrideAttr *attr) { |
| if (auto overrideAttr = D->getAttrs().getAttribute<OverrideAttr>()) |
| diagnoseAndRemoveAttr(overrideAttr, diag::nonoverride_and_override_attr); |
| |
| if (!isa<ClassDecl>(D->getDeclContext()) && |
| !isa<ProtocolDecl>(D->getDeclContext()) && |
| !isa<ExtensionDecl>(D->getDeclContext())) { |
| diagnoseAndRemoveAttr(attr, diag::nonoverride_wrong_decl_context); |
| } |
| } |
| |
| void AttributeChecker::visitLazyAttr(LazyAttr *attr) { |
| // lazy may only be used on properties. |
| auto *VD = cast<VarDecl>(D); |
| |
| auto attrs = VD->getAttrs(); |
| // 'lazy' is not allowed to have reference attributes |
| if (auto *refAttr = attrs.getAttribute<ReferenceOwnershipAttr>()) |
| diagnoseAndRemoveAttr(attr, diag::lazy_not_strong, refAttr->get()); |
| |
| auto varDC = VD->getDeclContext(); |
| |
| // 'lazy' is not allowed on a global variable or on a static property (which |
| // are already lazily initialized). |
| // TODO: we can't currently support lazy properties on non-type-contexts. |
| if (VD->isStatic() || |
| (varDC->isModuleScopeContext() && |
| !varDC->getParentSourceFile()->isScriptMode())) { |
| diagnoseAndRemoveAttr(attr, diag::lazy_on_already_lazy_global); |
| } else if (!VD->getDeclContext()->isTypeContext()) { |
| diagnoseAndRemoveAttr(attr, diag::lazy_must_be_property); |
| } |
| } |
| |
| bool AttributeChecker::visitAbstractAccessControlAttr( |
| AbstractAccessControlAttr *attr) { |
| // Access control attr may only be used on value decls and extensions. |
| if (!isa<ValueDecl>(D) && !isa<ExtensionDecl>(D)) { |
| diagnoseAndRemoveAttr(attr, diag::invalid_decl_modifier, attr); |
| return true; |
| } |
| |
| if (auto extension = dyn_cast<ExtensionDecl>(D)) { |
| if (!extension->getInherited().empty()) { |
| diagnoseAndRemoveAttr(attr, diag::extension_access_with_conformances, |
| attr); |
| return true; |
| } |
| } |
| |
| // And not on certain value decls. |
| if (isa<DestructorDecl>(D) || isa<EnumElementDecl>(D)) { |
| diagnoseAndRemoveAttr(attr, diag::invalid_decl_modifier, attr); |
| return true; |
| } |
| |
| // Or within protocols. |
| if (isa<ProtocolDecl>(D->getDeclContext())) { |
| diagnoseAndRemoveAttr(attr, diag::access_control_in_protocol, attr); |
| TC.diagnose(attr->getLocation(), diag::access_control_in_protocol_detail); |
| return true; |
| } |
| |
| return false; |
| } |
| |
| void AttributeChecker::visitAccessControlAttr(AccessControlAttr *attr) { |
| visitAbstractAccessControlAttr(attr); |
| |
| if (auto extension = dyn_cast<ExtensionDecl>(D)) { |
| if (attr->getAccess() == AccessLevel::Open) { |
| TC.diagnose(attr->getLocation(), diag::access_control_extension_open) |
| .fixItReplace(attr->getRange(), "public"); |
| attr->setInvalid(); |
| return; |
| } |
| |
| NominalTypeDecl *nominal = extension->getExtendedNominal(); |
| |
| // Extension is ill-formed; suppress the attribute. |
| if (!nominal) { |
| attr->setInvalid(); |
| return; |
| } |
| |
| AccessLevel typeAccess = nominal->getFormalAccess(); |
| if (attr->getAccess() > typeAccess) { |
| TC.diagnose(attr->getLocation(), diag::access_control_extension_more, |
| typeAccess, |
| nominal->getDescriptiveKind(), |
| attr->getAccess()) |
| .fixItRemove(attr->getRange()); |
| attr->setInvalid(); |
| return; |
| } |
| |
| } else if (auto extension = dyn_cast<ExtensionDecl>(D->getDeclContext())) { |
| AccessLevel maxAccess = extension->getMaxAccessLevel(); |
| if (std::min(attr->getAccess(), AccessLevel::Public) > maxAccess) { |
| // FIXME: It would be nice to say what part of the requirements actually |
| // end up being problematic. |
| auto diag = |
| TC.diagnose(attr->getLocation(), |
| diag::access_control_ext_requirement_member_more, |
| attr->getAccess(), |
| D->getDescriptiveKind(), |
| maxAccess); |
| swift::fixItAccess(diag, cast<ValueDecl>(D), maxAccess); |
| return; |
| } |
| |
| if (auto extAttr = |
| extension->getAttrs().getAttribute<AccessControlAttr>()) { |
| AccessLevel defaultAccess = extension->getDefaultAccessLevel(); |
| if (attr->getAccess() > defaultAccess) { |
| auto diag = TC.diagnose(attr->getLocation(), |
| diag::access_control_ext_member_more, |
| attr->getAccess(), |
| extAttr->getAccess()); |
| // Don't try to fix this one; it's just a warning, and fixing it can |
| // lead to diagnostic fights between this and "declaration must be at |
| // least this accessible" checking for overrides and protocol |
| // requirements. |
| } else if (attr->getAccess() == defaultAccess) { |
| TC.diagnose(attr->getLocation(), |
| diag::access_control_ext_member_redundant, |
| attr->getAccess(), |
| D->getDescriptiveKind(), |
| extAttr->getAccess()) |
| .fixItRemove(attr->getRange()); |
| } |
| } |
| } |
| |
| if (attr->getAccess() == AccessLevel::Open) { |
| if (!isa<ClassDecl>(D) && !D->isPotentiallyOverridable() && |
| !attr->isInvalid()) { |
| TC.diagnose(attr->getLocation(), diag::access_control_open_bad_decl) |
| .fixItReplace(attr->getRange(), "public"); |
| attr->setInvalid(); |
| } |
| } |
| } |
| |
| void AttributeChecker::visitSetterAccessAttr( |
| SetterAccessAttr *attr) { |
| auto storage = dyn_cast<AbstractStorageDecl>(D); |
| if (!storage) |
| diagnoseAndRemoveAttr(attr, diag::access_control_setter, attr->getAccess()); |
| |
| if (visitAbstractAccessControlAttr(attr)) |
| return; |
| |
| if (!storage->isSettable(storage->getDeclContext())) { |
| // This must stay in sync with diag::access_control_setter_read_only. |
| enum { |
| SK_Constant = 0, |
| SK_Variable, |
| SK_Property, |
| SK_Subscript |
| } storageKind; |
| if (isa<SubscriptDecl>(storage)) |
| storageKind = SK_Subscript; |
| else if (storage->getDeclContext()->isTypeContext()) |
| storageKind = SK_Property; |
| else if (cast<VarDecl>(storage)->isLet()) |
| storageKind = SK_Constant; |
| else |
| storageKind = SK_Variable; |
| diagnoseAndRemoveAttr(attr, diag::access_control_setter_read_only, |
| attr->getAccess(), storageKind); |
| } |
| |
| auto getterAccess = cast<ValueDecl>(D)->getFormalAccess(); |
| if (attr->getAccess() > getterAccess) { |
| // This must stay in sync with diag::access_control_setter_more. |
| enum { |
| SK_Variable = 0, |
| SK_Property, |
| SK_Subscript |
| } storageKind; |
| if (isa<SubscriptDecl>(D)) |
| storageKind = SK_Subscript; |
| else if (D->getDeclContext()->isTypeContext()) |
| storageKind = SK_Property; |
| else |
| storageKind = SK_Variable; |
| TC.diagnose(attr->getLocation(), diag::access_control_setter_more, |
| getterAccess, storageKind, attr->getAccess()); |
| attr->setInvalid(); |
| return; |
| |
| } else if (attr->getAccess() == getterAccess) { |
| TC.diagnose(attr->getLocation(), |
| diag::access_control_setter_redundant, |
| attr->getAccess(), |
| D->getDescriptiveKind(), |
| getterAccess) |
| .fixItRemove(attr->getRange()); |
| return; |
| } |
| } |
| |
| static bool checkObjCDeclContext(Decl *D) { |
| DeclContext *DC = D->getDeclContext(); |
| if (DC->getSelfClassDecl()) |
| return true; |
| if (auto *PD = dyn_cast<ProtocolDecl>(DC)) |
| if (PD->isObjC()) |
| return true; |
| return false; |
| } |
| |
| void AttributeChecker::visitObjCAttr(ObjCAttr *attr) { |
| // Only certain decls can be ObjC. |
| Optional<Diag<>> error; |
| if (isa<ClassDecl>(D) || |
| isa<ProtocolDecl>(D)) { |
| /* ok */ |
| } else if (auto Ext = dyn_cast<ExtensionDecl>(D)) { |
| if (!Ext->getSelfClassDecl()) |
| error = diag::objc_extension_not_class; |
| } else if (auto ED = dyn_cast<EnumDecl>(D)) { |
| if (ED->isGenericContext()) |
| error = diag::objc_enum_generic; |
| } else if (auto EED = dyn_cast<EnumElementDecl>(D)) { |
| auto ED = EED->getParentEnum(); |
| if (!ED->getAttrs().hasAttribute<ObjCAttr>()) |
| error = diag::objc_enum_case_req_objc_enum; |
| else if (attr->hasName() && EED->getParentCase()->getElements().size() > 1) |
| error = diag::objc_enum_case_multi; |
| } else if (auto *func = dyn_cast<FuncDecl>(D)) { |
| if (!checkObjCDeclContext(D)) |
| error = diag::invalid_objc_decl_context; |
| else if (auto accessor = dyn_cast<AccessorDecl>(func)) |
| if (!accessor->isGetterOrSetter()) |
| error = diag::objc_observing_accessor; |
| } else if (isa<ConstructorDecl>(D) || |
| isa<DestructorDecl>(D) || |
| isa<SubscriptDecl>(D) || |
| isa<VarDecl>(D)) { |
| if (!checkObjCDeclContext(D)) |
| error = diag::invalid_objc_decl_context; |
| /* ok */ |
| } else { |
| error = diag::invalid_objc_decl; |
| } |
| |
| if (error) { |
| diagnoseAndRemoveAttr(attr, *error); |
| return; |
| } |
| |
| // If there is a name, check whether the kind of name is |
| // appropriate. |
| if (auto objcName = attr->getName()) { |
| if (isa<ClassDecl>(D) || isa<ProtocolDecl>(D) || isa<VarDecl>(D) |
| || isa<EnumDecl>(D) || isa<EnumElementDecl>(D) |
| || isa<ExtensionDecl>(D)) { |
| // Types and properties can only have nullary |
| // names. Complain and recover by chopping off everything |
| // after the first name. |
| if (objcName->getNumArgs() > 0) { |
| SourceLoc firstNameLoc = attr->getNameLocs().front(); |
| SourceLoc afterFirstNameLoc = |
| Lexer::getLocForEndOfToken(TC.Context.SourceMgr, firstNameLoc); |
| TC.diagnose(firstNameLoc, diag::objc_name_req_nullary, |
| D->getDescriptiveKind()) |
| .fixItRemoveChars(afterFirstNameLoc, attr->getRParenLoc()); |
| const_cast<ObjCAttr *>(attr)->setName( |
| ObjCSelector(TC.Context, 0, objcName->getSelectorPieces()[0]), |
| /*implicit=*/false); |
| } |
| } else if (isa<SubscriptDecl>(D) || isa<DestructorDecl>(D)) { |
| TC.diagnose(attr->getLParenLoc(), |
| isa<SubscriptDecl>(D) |
| ? diag::objc_name_subscript |
| : diag::objc_name_deinit); |
| const_cast<ObjCAttr *>(attr)->clearName(); |
| } else { |
| // We have a function. Make sure that the number of parameters |
| // matches the "number of colons" in the name. |
| auto func = cast<AbstractFunctionDecl>(D); |
| auto params = func->getParameters(); |
| unsigned numParameters = params->size(); |
| if (auto CD = dyn_cast<ConstructorDecl>(func)) |
| if (CD->isObjCZeroParameterWithLongSelector()) |
| numParameters = 0; // Something like "init(foo: ())" |
| |
| // A throwing method has an error parameter. |
| if (func->hasThrows()) |
| ++numParameters; |
| |
| unsigned numArgumentNames = objcName->getNumArgs(); |
| if (numArgumentNames != numParameters) { |
| TC.diagnose(attr->getNameLocs().front(), |
| diag::objc_name_func_mismatch, |
| isa<FuncDecl>(func), |
| numArgumentNames, |
| numArgumentNames != 1, |
| numParameters, |
| numParameters != 1, |
| func->hasThrows()); |
| D->getAttrs().add( |
| ObjCAttr::createUnnamed(TC.Context, |
| attr->AtLoc, |
| attr->Range.Start)); |
| D->getAttrs().removeAttribute(attr); |
| } |
| } |
| } else if (isa<EnumElementDecl>(D)) { |
| // Enum elements require names. |
| diagnoseAndRemoveAttr(attr, diag::objc_enum_case_req_name); |
| } |
| } |
| |
| void AttributeChecker::visitNonObjCAttr(NonObjCAttr *attr) { |
| // Only extensions of classes; methods, properties, subscripts |
| // and constructors can be NonObjC. |
| // The last three are handled automatically by generic attribute |
| // validation -- for the first one, we have to check FuncDecls |
| // ourselves. |
| auto func = dyn_cast<FuncDecl>(D); |
| if (func && |
| (isa<DestructorDecl>(func) || |
| !checkObjCDeclContext(func) || |
| (isa<AccessorDecl>(func) && |
| !cast<AccessorDecl>(func)->isGetterOrSetter()))) { |
| diagnoseAndRemoveAttr(attr, diag::invalid_nonobjc_decl); |
| } |
| |
| if (auto ext = dyn_cast<ExtensionDecl>(D)) { |
| if (!ext->getSelfClassDecl()) |
| diagnoseAndRemoveAttr(attr, diag::invalid_nonobjc_extension); |
| } |
| } |
| |
| void AttributeChecker::visitObjCMembersAttr(ObjCMembersAttr *attr) { |
| if (!isa<ClassDecl>(D)) |
| diagnoseAndRemoveAttr(attr, diag::objcmembers_attribute_nonclass); |
| } |
| |
| void AttributeChecker::visitOptionalAttr(OptionalAttr *attr) { |
| if (!isa<ProtocolDecl>(D->getDeclContext())) { |
| diagnoseAndRemoveAttr(attr, diag::optional_attribute_non_protocol); |
| } else if (!cast<ProtocolDecl>(D->getDeclContext())->isObjC()) { |
| diagnoseAndRemoveAttr(attr, diag::optional_attribute_non_objc_protocol); |
| } else if (isa<ConstructorDecl>(D)) { |
| diagnoseAndRemoveAttr(attr, diag::optional_attribute_initializer); |
| } else { |
| auto objcAttr = D->getAttrs().getAttribute<ObjCAttr>(); |
| if (!objcAttr || objcAttr->isImplicit()) { |
| auto diag = TC.diagnose(attr->getLocation(), |
| diag::optional_attribute_missing_explicit_objc); |
| if (auto VD = dyn_cast<ValueDecl>(D)) |
| diag.fixItInsert(VD->getAttributeInsertionLoc(false), "@objc "); |
| } |
| } |
| } |
| |
| void TypeChecker::checkDeclAttributes(Decl *D) { |
| AttributeChecker Checker(*this, D); |
| for (auto attr : D->getAttrs()) { |
| if (!attr->isValid()) continue; |
| |
| // If Attr.def says that the attribute cannot appear on this kind of |
| // declaration, diagnose it and disable it. |
| if (attr->canAppearOnDecl(D)) { |
| // Otherwise, check it. |
| Checker.visit(attr); |
| continue; |
| } |
| |
| // Otherwise, this attribute cannot be applied to this declaration. If the |
| // attribute is only valid on one kind of declaration (which is pretty |
| // common) give a specific helpful error. |
| auto PossibleDeclKinds = attr->getOptions() & DeclAttribute::OnAnyDecl; |
| StringRef OnlyKind; |
| switch (PossibleDeclKinds) { |
| case DeclAttribute::OnAccessor: OnlyKind = "accessor"; break; |
| case DeclAttribute::OnClass: OnlyKind = "class"; break; |
| case DeclAttribute::OnConstructor: OnlyKind = "init"; break; |
| case DeclAttribute::OnDestructor: OnlyKind = "deinit"; break; |
| case DeclAttribute::OnEnum: OnlyKind = "enum"; break; |
| case DeclAttribute::OnEnumCase: OnlyKind = "case"; break; |
| case DeclAttribute::OnFunc | DeclAttribute::OnAccessor: // FIXME |
| case DeclAttribute::OnFunc: OnlyKind = "func"; break; |
| case DeclAttribute::OnImport: OnlyKind = "import"; break; |
| case DeclAttribute::OnModule: OnlyKind = "module"; break; |
| case DeclAttribute::OnParam: OnlyKind = "parameter"; break; |
| case DeclAttribute::OnProtocol: OnlyKind = "protocol"; break; |
| case DeclAttribute::OnStruct: OnlyKind = "struct"; break; |
| case DeclAttribute::OnSubscript: OnlyKind = "subscript"; break; |
| case DeclAttribute::OnTypeAlias: OnlyKind = "typealias"; break; |
| case DeclAttribute::OnVar: OnlyKind = "var"; break; |
| default: break; |
| } |
| |
| if (!OnlyKind.empty()) |
| Checker.diagnoseAndRemoveAttr(attr, diag::attr_only_one_decl_kind, |
| attr, OnlyKind); |
| else if (attr->isDeclModifier()) |
| Checker.diagnoseAndRemoveAttr(attr, diag::invalid_decl_modifier, attr); |
| else |
| Checker.diagnoseAndRemoveAttr(attr, diag::invalid_decl_attribute, attr); |
| } |
| } |
| |
| // SWIFT_ENABLE_TENSORFLOW |
| // TODO(TF-789): Figure out the proper way to typecheck these. |
| void TypeChecker::checkDeclDifferentiableAttributes(Decl *D) { |
| AttributeChecker Checker(*this, D); |
| for (auto attr : D->getAttrs()) { |
| if (!isa<DifferentiableAttr>(attr) || !attr->isValid() || |
| !attr->canAppearOnDecl(D)) |
| continue; |
| Checker.visit(attr); |
| } |
| } |
| |
| /// Returns true if the given method is an valid implementation of a |
| /// @dynamicCallable attribute requirement. The method is given to be defined |
| /// as one of the following: `dynamicallyCall(withArguments:)` or |
| /// `dynamicallyCall(withKeywordArguments:)`. |
| bool swift::isValidDynamicCallableMethod(FuncDecl *decl, DeclContext *DC, |
| TypeChecker &TC, |
| bool hasKeywordArguments) { |
| // There are two cases to check. |
| // 1. `dynamicallyCall(withArguments:)`. |
| // In this case, the method is valid if the argument has type `A` where |
| // `A` conforms to `ExpressibleByArrayLiteral`. |
| // `A.ArrayLiteralElement` and the return type can be arbitrary. |
| // 2. `dynamicallyCall(withKeywordArguments:)` |
| // In this case, the method is valid if the argument has type `D` where |
| // `D` conforms to `ExpressibleByDictionaryLiteral` and `D.Key` conforms to |
| // `ExpressibleByStringLiteral`. |
| // `D.Value` and the return type can be arbitrary. |
| |
| // FIXME(InterfaceTypeRequest): Remove this. |
| (void)decl->getInterfaceType(); |
| auto paramList = decl->getParameters(); |
| if (paramList->size() != 1 || paramList->get(0)->isVariadic()) return false; |
| auto argType = paramList->get(0)->getType(); |
| |
| // If non-keyword (positional) arguments, check that argument type conforms to |
| // `ExpressibleByArrayLiteral`. |
| if (!hasKeywordArguments) { |
| auto arrayLitProto = |
| TC.Context.getProtocol(KnownProtocolKind::ExpressibleByArrayLiteral); |
| return TypeChecker::conformsToProtocol(argType, arrayLitProto, DC, |
| ConformanceCheckOptions()).hasValue(); |
| } |
| // If keyword arguments, check that argument type conforms to |
| // `ExpressibleByDictionaryLiteral` and that the `Key` associated type |
| // conforms to `ExpressibleByStringLiteral`. |
| auto stringLitProtocol = |
| TC.Context.getProtocol(KnownProtocolKind::ExpressibleByStringLiteral); |
| auto dictLitProto = |
| TC.Context.getProtocol(KnownProtocolKind::ExpressibleByDictionaryLiteral); |
| auto dictConf = TypeChecker::conformsToProtocol(argType, dictLitProto, DC, |
| ConformanceCheckOptions()); |
| if (!dictConf) return false; |
| auto keyType = dictConf.getValue().getTypeWitnessByName( |
| argType, TC.Context.Id_Key); |
| return TypeChecker::conformsToProtocol(keyType, stringLitProtocol, DC, |
| ConformanceCheckOptions()).hasValue(); |
| } |
| |
| /// Returns true if the given nominal type has a valid implementation of a |
| /// @dynamicCallable attribute requirement with the given argument name. |
| static bool hasValidDynamicCallableMethod(TypeChecker &TC, |
| NominalTypeDecl *decl, |
| Identifier argumentName, |
| bool hasKeywordArgs) { |
| auto declType = decl->getDeclaredType(); |
| auto methodName = DeclName(TC.Context, |
| DeclBaseName(TC.Context.Id_dynamicallyCall), |
| { argumentName }); |
| auto candidates = TC.lookupMember(decl, declType, methodName); |
| if (candidates.empty()) return false; |
| |
| // Filter valid candidates. |
| candidates.filter([&](LookupResultEntry entry, bool isOuter) { |
| auto candidate = cast<FuncDecl>(entry.getValueDecl()); |
| return isValidDynamicCallableMethod(candidate, decl, TC, hasKeywordArgs); |
| }); |
| |
| // If there are no valid candidates, return false. |
| if (candidates.size() == 0) return false; |
| return true; |
| } |
| |
| void AttributeChecker:: |
| visitDynamicCallableAttr(DynamicCallableAttr *attr) { |
| // This attribute is only allowed on nominal types. |
| auto decl = cast<NominalTypeDecl>(D); |
| auto type = decl->getDeclaredType(); |
| |
| bool hasValidMethod = false; |
| hasValidMethod |= |
| hasValidDynamicCallableMethod(TC, decl, TC.Context.Id_withArguments, |
| /*hasKeywordArgs*/ false); |
| hasValidMethod |= |
| hasValidDynamicCallableMethod(TC, decl, TC.Context.Id_withKeywordArguments, |
| /*hasKeywordArgs*/ true); |
| if (!hasValidMethod) { |
| TC.diagnose(attr->getLocation(), diag::invalid_dynamic_callable_type, type); |
| attr->setInvalid(); |
| } |
| } |
| |
| static bool hasSingleNonVariadicParam(SubscriptDecl *decl, |
| Identifier expectedLabel, |
| bool ignoreLabel = false) { |
| auto *indices = decl->getIndices(); |
| if (decl->isInvalid() || indices->size() != 1) |
| return false; |
| |
| auto *index = indices->get(0); |
| if (index->isVariadic() || !index->hasInterfaceType()) |
| return false; |
| |
| if (ignoreLabel) { |
| return true; |
| } |
| |
| return index->getArgumentName() == expectedLabel; |
| } |
| |
| /// Returns true if the given subscript method is an valid implementation of |
| /// the `subscript(dynamicMember:)` requirement for @dynamicMemberLookup. |
| /// The method is given to be defined as `subscript(dynamicMember:)`. |
| bool swift::isValidDynamicMemberLookupSubscript(SubscriptDecl *decl, |
| DeclContext *DC, |
| TypeChecker &TC, |
| bool ignoreLabel) { |
| // It could be |
| // - `subscript(dynamicMember: {Writable}KeyPath<...>)`; or |
| // - `subscript(dynamicMember: String*)` |
| return isValidKeyPathDynamicMemberLookup(decl, TC, ignoreLabel) || |
| isValidStringDynamicMemberLookup(decl, DC, TC, ignoreLabel); |
| } |
| |
| bool swift::isValidStringDynamicMemberLookup(SubscriptDecl *decl, |
| DeclContext *DC, TypeChecker &TC, |
| bool ignoreLabel) { |
| // There are two requirements: |
| // - The subscript method has exactly one, non-variadic parameter. |
| // - The parameter type conforms to `ExpressibleByStringLiteral`. |
| if (!hasSingleNonVariadicParam(decl, TC.Context.Id_dynamicMember, |
| ignoreLabel)) |
| return false; |
| |
| const auto *param = decl->getIndices()->get(0); |
| auto paramType = param->getType(); |
| |
| auto stringLitProto = |
| TC.Context.getProtocol(KnownProtocolKind::ExpressibleByStringLiteral); |
| |
| // If this is `subscript(dynamicMember: String*)` |
| return bool(TypeChecker::conformsToProtocol(paramType, stringLitProto, DC, |
| ConformanceCheckOptions())); |
| } |
| |
| bool swift::isValidKeyPathDynamicMemberLookup(SubscriptDecl *decl, |
| TypeChecker &TC, |
| bool ignoreLabel) { |
| if (!hasSingleNonVariadicParam(decl, TC.Context.Id_dynamicMember, |
| ignoreLabel)) |
| return false; |
| |
| const auto *param = decl->getIndices()->get(0); |
| if (auto NTD = param->getType()->getAnyNominal()) { |
| return NTD == TC.Context.getKeyPathDecl() || |
| NTD == TC.Context.getWritableKeyPathDecl() || |
| NTD == TC.Context.getReferenceWritableKeyPathDecl(); |
| } |
| return false; |
| } |
| |
| /// The @dynamicMemberLookup attribute is only allowed on types that have at |
| /// least one subscript member declared like this: |
| /// |
| /// subscript<KeywordType: ExpressibleByStringLiteral, LookupValue> |
| /// (dynamicMember name: KeywordType) -> LookupValue { get } |
| /// |
| /// ... but doesn't care about the mutating'ness of the getter/setter. |
| /// We just manually check the requirements here. |
| void AttributeChecker:: |
| visitDynamicMemberLookupAttr(DynamicMemberLookupAttr *attr) { |
| // This attribute is only allowed on nominal types. |
| auto decl = cast<NominalTypeDecl>(D); |
| auto type = decl->getDeclaredType(); |
| auto &ctx = decl->getASTContext(); |
| |
| auto emitInvalidTypeDiagnostic = [&](const SourceLoc loc) { |
| TC.diagnose(loc, diag::invalid_dynamic_member_lookup_type, type); |
| attr->setInvalid(); |
| }; |
| |
| // Look up `subscript(dynamicMember:)` candidates. |
| auto subscriptName = |
| DeclName(ctx, DeclBaseName::createSubscript(), ctx.Id_dynamicMember); |
| auto candidates = TC.lookupMember(decl, type, subscriptName); |
| |
| if (!candidates.empty()) { |
| // If no candidates are valid, then reject one. |
| auto oneCandidate = candidates.front().getValueDecl(); |
| candidates.filter([&](LookupResultEntry entry, bool isOuter) -> bool { |
| auto cand = cast<SubscriptDecl>(entry.getValueDecl()); |
| // FIXME(InterfaceTypeRequest): Remove this. |
| (void)cand->getInterfaceType(); |
| return isValidDynamicMemberLookupSubscript(cand, decl, TC); |
| }); |
| |
| if (candidates.empty()) { |
| emitInvalidTypeDiagnostic(oneCandidate->getLoc()); |
| } |
| |
| return; |
| } |
| |
| // If we couldn't find any candidates, it's likely because: |
| // |
| // 1. We don't have a subscript with `dynamicMember` label. |
| // 2. We have a subscript with `dynamicMember` label, but no argument label. |
| // |
| // Let's do another lookup using just the base name. |
| auto newCandidates = |
| TC.lookupMember(decl, type, DeclBaseName::createSubscript()); |
| |
| // Validate the candidates while ignoring the label. |
| newCandidates.filter([&](const LookupResultEntry entry, bool isOuter) { |
| auto cand = cast<SubscriptDecl>(entry.getValueDecl()); |
| // FIXME(InterfaceTypeRequest): Remove this. |
| (void)cand->getInterfaceType(); |
| return isValidDynamicMemberLookupSubscript(cand, decl, TC, |
| /*ignoreLabel*/ true); |
| }); |
| |
| // If there were no potentially valid candidates, then throw an error. |
| if (newCandidates.empty()) { |
| emitInvalidTypeDiagnostic(attr->getLocation()); |
| return; |
| } |
| |
| // For each candidate, emit a diagnostic. If we don't have an explicit |
| // argument label, then emit a fix-it to suggest the user to add one. |
| for (auto cand : newCandidates) { |
| auto SD = cast<SubscriptDecl>(cand.getValueDecl()); |
| auto index = SD->getIndices()->get(0); |
| TC.diagnose(SD, diag::invalid_dynamic_member_lookup_type, type); |
| |
| // If we have something like `subscript(foo:)` then we want to insert |
| // `dynamicMember` before `foo`. |
| if (index->getParameterNameLoc().isValid() && |
| index->getArgumentNameLoc().isInvalid()) { |
| TC.diagnose(SD, diag::invalid_dynamic_member_subscript) |
| .highlight(index->getSourceRange()) |
| .fixItInsert(index->getParameterNameLoc(), "dynamicMember "); |
| } |
| } |
| |
| attr->setInvalid(); |
| return; |
| } |
| |
| /// Get the innermost enclosing declaration for a declaration. |
| static Decl *getEnclosingDeclForDecl(Decl *D) { |
| // If the declaration is an accessor, treat its storage declaration |
| // as the enclosing declaration. |
| if (auto *accessor = dyn_cast<AccessorDecl>(D)) { |
| return accessor->getStorage(); |
| } |
| |
| return D->getDeclContext()->getInnermostDeclarationDeclContext(); |
| } |
| |
| void AttributeChecker::visitAvailableAttr(AvailableAttr *attr) { |
| if (TC.getLangOpts().DisableAvailabilityChecking) |
| return; |
| |
| if (auto *PD = dyn_cast<ProtocolDecl>(D->getDeclContext())) { |
| if (auto *VD = dyn_cast<ValueDecl>(D)) { |
| if (VD->isProtocolRequirement()) { |
| if (attr->isActivePlatform(TC.Context) || |
| attr->isLanguageVersionSpecific() || |
| attr->isPackageDescriptionVersionSpecific()) { |
| auto versionAvailability = attr->getVersionAvailability(TC.Context); |
| if (attr->isUnconditionallyUnavailable() || |
| versionAvailability == AvailableVersionComparison::Obsoleted || |
| versionAvailability == AvailableVersionComparison::Unavailable) { |
| if (!PD->isObjC()) { |
| diagnoseAndRemoveAttr(attr, diag::unavailable_method_non_objc_protocol); |
| return; |
| } |
| } |
| } |
| } |
| } |
| } |
| |
| if (!attr->hasPlatform() || !attr->isActivePlatform(TC.Context) || |
| !attr->Introduced.hasValue()) { |
| return; |
| } |
| |
| SourceLoc attrLoc = attr->getLocation(); |
| |
| Optional<Diag<>> MaybeNotAllowed = |
| TC.diagnosticIfDeclCannotBePotentiallyUnavailable(D); |
| if (MaybeNotAllowed.hasValue()) { |
| TC.diagnose(attrLoc, MaybeNotAllowed.getValue()); |
| } |
| |
| // Find the innermost enclosing declaration with an availability |
| // range annotation and ensure that this attribute's available version range |
| // is fully contained within that declaration's range. If there is no such |
| // enclosing declaration, then there is nothing to check. |
| Optional<AvailabilityContext> EnclosingAnnotatedRange; |
| Decl *EnclosingDecl = getEnclosingDeclForDecl(D); |
| |
| while (EnclosingDecl) { |
| EnclosingAnnotatedRange = |
| AvailabilityInference::annotatedAvailableRange(EnclosingDecl, |
| TC.Context); |
| |
| if (EnclosingAnnotatedRange.hasValue()) |
| break; |
| |
| EnclosingDecl = getEnclosingDeclForDecl(EnclosingDecl); |
| } |
| |
| if (!EnclosingDecl) |
| return; |
| |
| AvailabilityContext AttrRange{ |
| VersionRange::allGTE(attr->Introduced.getValue())}; |
| |
| if (!AttrRange.isContainedIn(EnclosingAnnotatedRange.getValue())) { |
| TC.diagnose(attr->getLocation(), |
| diag::availability_decl_more_than_enclosing); |
| TC.diagnose(EnclosingDecl->getLoc(), |
| diag::availability_decl_more_than_enclosing_enclosing_here); |
| } |
| } |
| |
| void AttributeChecker::visitCDeclAttr(CDeclAttr *attr) { |
| // Only top-level func decls are currently supported. |
| if (D->getDeclContext()->isTypeContext()) |
| TC.diagnose(attr->getLocation(), |
| diag::cdecl_not_at_top_level); |
| |
| // The name must not be empty. |
| if (attr->Name.empty()) |
| TC.diagnose(attr->getLocation(), |
| diag::cdecl_empty_name); |
| } |
| |
| void AttributeChecker::visitUnsafeNoObjCTaggedPointerAttr( |
| UnsafeNoObjCTaggedPointerAttr *attr) { |
| // Only class protocols can have the attribute. |
| auto proto = dyn_cast<ProtocolDecl>(D); |
| if (!proto) { |
| TC.diagnose(attr->getLocation(), |
| diag::no_objc_tagged_pointer_not_class_protocol); |
| attr->setInvalid(); |
| } |
| |
| if (!proto->requiresClass() |
| && !proto->getAttrs().hasAttribute<ObjCAttr>()) { |
| TC.diagnose(attr->getLocation(), |
| diag::no_objc_tagged_pointer_not_class_protocol); |
| attr->setInvalid(); |
| } |
| } |
| |
| void AttributeChecker::visitSwiftNativeObjCRuntimeBaseAttr( |
| SwiftNativeObjCRuntimeBaseAttr *attr) { |
| // Only root classes can have the attribute. |
| auto theClass = dyn_cast<ClassDecl>(D); |
| if (!theClass) { |
| TC.diagnose(attr->getLocation(), |
| diag::swift_native_objc_runtime_base_not_on_root_class); |
| attr->setInvalid(); |
| return; |
| } |
| |
| if (theClass->hasSuperclass()) { |
| TC.diagnose(attr->getLocation(), |
| diag::swift_native_objc_runtime_base_not_on_root_class); |
| attr->setInvalid(); |
| return; |
| } |
| } |
| |
| void AttributeChecker::visitFinalAttr(FinalAttr *attr) { |
| // Reject combining 'final' with 'open'. |
| if (auto accessAttr = D->getAttrs().getAttribute<AccessControlAttr>()) { |
| if (accessAttr->getAccess() == AccessLevel::Open) { |
| TC.diagnose(attr->getLocation(), diag::open_decl_cannot_be_final, |
| D->getDescriptiveKind()); |
| return; |
| } |
| } |
| |
| if (isa<ClassDecl>(D)) |
| return; |
| |
| // 'final' only makes sense in the context of a class declaration. |
| // Reject it on global functions, protocols, structs, enums, etc. |
| if (!D->getDeclContext()->getSelfClassDecl()) { |
| TC.diagnose(attr->getLocation(), diag::member_cannot_be_final) |
| .fixItRemove(attr->getRange()); |
| |
| // Remove the attribute so child declarations are not flagged as final |
| // and duplicate the error message. |
| D->getAttrs().removeAttribute(attr); |
| return; |
| } |
| |
| // We currently only support final on var/let, func and subscript |
| // declarations. |
| if (!isa<VarDecl>(D) && !isa<FuncDecl>(D) && !isa<SubscriptDecl>(D)) { |
| TC.diagnose(attr->getLocation(), diag::final_not_allowed_here) |
| .fixItRemove(attr->getRange()); |
| return; |
| } |
| |
| if (auto *accessor = dyn_cast<AccessorDecl>(D)) { |
| if (!attr->isImplicit()) { |
| unsigned Kind = 2; |
| if (auto *VD = dyn_cast<VarDecl>(accessor->getStorage())) |
| Kind = VD->isLet() ? 1 : 0; |
| TC.diagnose(attr->getLocation(), diag::final_not_on_accessors, Kind) |
| .fixItRemove(attr->getRange()); |
| return; |
| } |
| } |
| } |
| |
| /// Return true if this is a builtin operator that cannot be defined in user |
| /// code. |
| static bool isBuiltinOperator(StringRef name, DeclAttribute *attr) { |
| return ((isa<PrefixAttr>(attr) && name == "&") || // lvalue to inout |
| (isa<PostfixAttr>(attr) && name == "!") || // optional unwrapping |
| (isa<PostfixAttr>(attr) && name == "?") || // optional chaining |
| (isa<InfixAttr>(attr) && name == "?") || // ternary operator |
| (isa<PostfixAttr>(attr) && name == ">") || // generic argument list |
| (isa<PrefixAttr>(attr) && name == "<")); // generic argument list |
| } |
| |
| void AttributeChecker::checkOperatorAttribute(DeclAttribute *attr) { |
| // Check out the operator attributes. They may be attached to an operator |
| // declaration or a function. |
| if (auto *OD = dyn_cast<OperatorDecl>(D)) { |
| // Reject attempts to define builtin operators. |
| if (isBuiltinOperator(OD->getName().str(), attr)) { |
| TC.diagnose(D->getStartLoc(), diag::redefining_builtin_operator, |
| attr->getAttrName(), OD->getName().str()); |
| attr->setInvalid(); |
| return; |
| } |
| |
| // Otherwise, the attribute is always ok on an operator. |
| return; |
| } |
| |
| // Operators implementations may only be defined as functions. |
| auto *FD = dyn_cast<FuncDecl>(D); |
| if (!FD) { |
| TC.diagnose(D->getLoc(), diag::operator_not_func); |
| attr->setInvalid(); |
| return; |
| } |
| |
| // Only functions with an operator identifier can be declared with as an |
| // operator. |
| if (!FD->isOperator()) { |
| TC.diagnose(D->getStartLoc(), diag::attribute_requires_operator_identifier, |
| attr->getAttrName()); |
| attr->setInvalid(); |
| return; |
| } |
| |
| // Reject attempts to define builtin operators. |
| if (isBuiltinOperator(FD->getName().str(), attr)) { |
| TC.diagnose(D->getStartLoc(), diag::redefining_builtin_operator, |
| attr->getAttrName(), FD->getName().str()); |
| attr->setInvalid(); |
| return; |
| } |
| |
| // Otherwise, must be unary. |
| if (!FD->isUnaryOperator()) { |
| TC.diagnose(attr->getLocation(), diag::attribute_requires_single_argument, |
| attr->getAttrName()); |
| attr->setInvalid(); |
| return; |
| } |
| } |
| |
| void AttributeChecker::visitNSCopyingAttr(NSCopyingAttr *attr) { |
| // The @NSCopying attribute is only allowed on stored properties. |
| auto *VD = cast<VarDecl>(D); |
| |
| // It may only be used on class members. |
| auto classDecl = D->getDeclContext()->getSelfClassDecl(); |
| if (!classDecl) { |
| TC.diagnose(attr->getLocation(), diag::nscopying_only_on_class_properties); |
| attr->setInvalid(); |
| return; |
| } |
| |
| if (!VD->isSettable(VD->getDeclContext())) { |
| TC.diagnose(attr->getLocation(), diag::nscopying_only_mutable); |
| attr->setInvalid(); |
| return; |
| } |
| |
| if (!VD->hasStorage()) { |
| TC.diagnose(attr->getLocation(), diag::nscopying_only_stored_property); |
| attr->setInvalid(); |
| return; |
| } |
| |
| if (VD->hasInterfaceType()) { |
| if (!TC.checkConformanceToNSCopying(VD)) { |
| attr->setInvalid(); |
| return; |
| } |
| } |
| |
| assert(VD->getOverriddenDecl() == nullptr && |
| "Can't have value with storage that is an override"); |
| |
| // Check the type. It must be an [unchecked]optional, weak, a normal |
| // class, AnyObject, or classbound protocol. |
| // It must conform to the NSCopying protocol. |
| |
| } |
| |
| void AttributeChecker::checkApplicationMainAttribute(DeclAttribute *attr, |
| Identifier Id_ApplicationDelegate, |
| Identifier Id_Kit, |
| Identifier Id_ApplicationMain) { |
| // %select indexes for ApplicationMain diagnostics. |
| enum : unsigned { |
| UIApplicationMainClass, |
| NSApplicationMainClass, |
| }; |
| |
| unsigned applicationMainKind; |
| if (isa<UIApplicationMainAttr>(attr)) |
| applicationMainKind = UIApplicationMainClass; |
| else if (isa<NSApplicationMainAttr>(attr)) |
| applicationMainKind = NSApplicationMainClass; |
| else |
| llvm_unreachable("not an ApplicationMain attr"); |
| |
| auto *CD = dyn_cast<ClassDecl>(D); |
| |
| // The applicant not being a class should have been diagnosed by the early |
| // checker. |
| if (!CD) return; |
| |
| // The class cannot be generic. |
| if (CD->isGenericContext()) { |
| TC.diagnose(attr->getLocation(), |
| diag::attr_generic_ApplicationMain_not_supported, |
| applicationMainKind); |
| attr->setInvalid(); |
| return; |
| } |
| |
| // @XXApplicationMain classes must conform to the XXApplicationDelegate |
| // protocol. |
| auto *SF = cast<SourceFile>(CD->getModuleScopeContext()); |
| auto &C = SF->getASTContext(); |
| |
| auto KitModule = C.getLoadedModule(Id_Kit); |
| ProtocolDecl *ApplicationDelegateProto = nullptr; |
| if (KitModule) { |
| SmallVector<ValueDecl *, 1> decls; |
| namelookup::lookupInModule(KitModule, Id_ApplicationDelegate, |
| decls, NLKind::QualifiedLookup, |
| namelookup::ResolutionKind::TypesOnly, |
| SF); |
| if (decls.size() == 1) |
| ApplicationDelegateProto = dyn_cast<ProtocolDecl>(decls[0]); |
| } |
| |
| if (!ApplicationDelegateProto || |
| !TypeChecker::conformsToProtocol(CD->getDeclaredType(), |
| ApplicationDelegateProto, |
| CD, None)) { |
| TC.diagnose(attr->getLocation(), |
| diag::attr_ApplicationMain_not_ApplicationDelegate, |
| applicationMainKind); |
| attr->setInvalid(); |
| } |
| |
| if (attr->isInvalid()) |
| return; |
| |
| // Register the class as the main class in the module. If there are multiples |
| // they will be diagnosed. |
| if (SF->registerMainClass(CD, attr->getLocation())) |
| attr->setInvalid(); |
| } |
| |
| void AttributeChecker::visitNSApplicationMainAttr(NSApplicationMainAttr *attr) { |
| auto &C = D->getASTContext(); |
| checkApplicationMainAttribute(attr, |
| C.getIdentifier("NSApplicationDelegate"), |
| C.getIdentifier("AppKit"), |
| C.getIdentifier("NSApplicationMain")); |
| } |
| void AttributeChecker::visitUIApplicationMainAttr(UIApplicationMainAttr *attr) { |
| auto &C = D->getASTContext(); |
| checkApplicationMainAttribute(attr, |
| C.getIdentifier("UIApplicationDelegate"), |
| C.getIdentifier("UIKit"), |
| C.getIdentifier("UIApplicationMain")); |
| } |
| |
| /// Determine whether the given context is an extension to an Objective-C class |
| /// where the class is defined in the Objective-C module and the extension is |
| /// defined within its module. |
| static bool isObjCClassExtensionInOverlay(DeclContext *dc) { |
| // Check whether we have an extension. |
| auto ext = dyn_cast<ExtensionDecl>(dc); |
| if (!ext) |
| return false; |
| |
| // Find the extended class. |
| auto classDecl = ext->getSelfClassDecl(); |
| if (!classDecl) |
| return false; |
| |
| auto clangLoader = dc->getASTContext().getClangModuleLoader(); |
| if (!clangLoader) return false; |
| return clangLoader->isInOverlayModuleForImportedModule(ext, classDecl); |
| } |
| |
| void AttributeChecker::visitRequiredAttr(RequiredAttr *attr) { |
| // The required attribute only applies to constructors. |
| auto ctor = cast<ConstructorDecl>(D); |
| auto parentTy = ctor->getDeclContext()->getDeclaredInterfaceType(); |
| if (!parentTy) { |
| // Constructor outside of nominal type context; we've already complained |
| // elsewhere. |
| attr->setInvalid(); |
| return; |
| } |
| // Only classes can have required constructors. |
| if (parentTy->getClassOrBoundGenericClass()) { |
| // The constructor must be declared within the class itself. |
| // FIXME: Allow an SDK overlay to add a required initializer to a class |
| // defined in Objective-C |
| if (!isa<ClassDecl>(ctor->getDeclContext()) && |
| !isObjCClassExtensionInOverlay(ctor->getDeclContext())) { |
| TC.diagnose(ctor, diag::required_initializer_in_extension, parentTy) |
| .highlight(attr->getLocation()); |
| attr->setInvalid(); |
| return; |
| } |
| } else { |
| if (!parentTy->hasError()) { |
| TC.diagnose(ctor, diag::required_initializer_nonclass, parentTy) |
| .highlight(attr->getLocation()); |
| } |
| attr->setInvalid(); |
| return; |
| } |
| } |
| |
| static bool hasThrowingFunctionParameter(CanType type) { |
| // Only consider throwing function types. |
| if (auto fnType = dyn_cast<AnyFunctionType>(type)) { |
| return fnType->getExtInfo().throws(); |
| } |
| |
| // Look through tuples. |
| if (auto tuple = dyn_cast<TupleType>(type)) { |
| for (auto eltType : tuple.getElementTypes()) { |
| if (hasThrowingFunctionParameter(eltType)) |
| return true; |
| } |
| return false; |
| } |
| |
| // Suppress diagnostics in the presence of errors. |
| if (type->hasError()) { |
| return true; |
| } |
| |
| return false; |
| } |
| |
| void AttributeChecker::visitRethrowsAttr(RethrowsAttr *attr) { |
| // 'rethrows' only applies to functions that take throwing functions |
| // as parameters. |
| auto fn = cast<AbstractFunctionDecl>(D); |
| for (auto param : *fn->getParameters()) { |
| if (hasThrowingFunctionParameter(param->getType() |
| ->lookThroughAllOptionalTypes() |
| ->getCanonicalType())) |
| return; |
| } |
| |
| TC.diagnose(attr->getLocation(), diag::rethrows_without_throwing_parameter); |
| attr->setInvalid(); |
| } |
| |
| /// Collect all used generic parameter types from a given type. |
| static void collectUsedGenericParameters( |
| Type Ty, SmallPtrSetImpl<TypeBase *> &ConstrainedGenericParams) { |
| if (!Ty) |
| return; |
| |
| if (!Ty->hasTypeParameter()) |
| return; |
| |
| // Add used generic parameters/archetypes. |
| Ty.visit([&](Type Ty) { |
| if (auto GP = dyn_cast<GenericTypeParamType>(Ty->getCanonicalType())) { |
| ConstrainedGenericParams.insert(GP); |
| } |
| }); |
| } |
| |
| /// Perform some sanity checks for the requirements provided by |
| /// the @_specialize attribute. |
| static void checkSpecializeAttrRequirements( |
| SpecializeAttr *attr, |
| AbstractFunctionDecl *FD, |
| const SmallPtrSet<TypeBase *, 4> &constrainedGenericParams, |
| TypeChecker &TC) { |
| auto genericSig = FD->getGenericSignature(); |
| |
| if (!attr->isFullSpecialization()) |
| return; |
| |
| if (constrainedGenericParams.size() == genericSig->getGenericParams().size()) |
| return; |
| |
| TC.diagnose( |
| attr->getLocation(), diag::specialize_attr_type_parameter_count_mismatch, |
| genericSig->getGenericParams().size(), constrainedGenericParams.size(), |
| constrainedGenericParams.size() < genericSig->getGenericParams().size()); |
| |
| if (constrainedGenericParams.size() < genericSig->getGenericParams().size()) { |
| // Figure out which archetypes are not constrained. |
| for (auto gp : genericSig->getGenericParams()) { |
| if (constrainedGenericParams.count(gp->getCanonicalType().getPointer())) |
| continue; |
| auto gpDecl = gp->getDecl(); |
| if (gpDecl) { |
| TC.diagnose(attr->getLocation(), |
| diag::specialize_attr_missing_constraint, |
| gpDecl->getFullName()); |
| } |
| } |
| } |
| } |
| |
| /// Require that the given type either not involve type parameters or be |
| /// a type parameter. |
| static bool diagnoseIndirectGenericTypeParam(SourceLoc loc, Type type, |
| TypeRepr *typeRepr) { |
| if (type->hasTypeParameter() && !type->is<GenericTypeParamType>()) { |
| type->getASTContext().Diags.diagnose( |
| loc, |
| diag::specialize_attr_only_generic_param_req) |
| .highlight(typeRepr->getSourceRange()); |
| return true; |
| } |
| |
| return false; |
| } |
| |
| /// Type check that a set of requirements provided by @_specialize. |
| /// Store the set of requirements in the attribute. |
| void AttributeChecker::visitSpecializeAttr(SpecializeAttr *attr) { |
| DeclContext *DC = D->getDeclContext(); |
| auto *FD = cast<AbstractFunctionDecl>(D); |
| auto genericSig = FD->getGenericSignature(); |
| auto *trailingWhereClause = attr->getTrailingWhereClause(); |
| |
| if (!trailingWhereClause) { |
| // Report a missing "where" clause. |
| TC.diagnose(attr->getLocation(), diag::specialize_missing_where_clause); |
| return; |
| } |
| |
| if (trailingWhereClause->getRequirements().empty()) { |
| // Report an empty "where" clause. |
| TC.diagnose(attr->getLocation(), diag::specialize_empty_where_clause); |
| return; |
| } |
| |
| if (!genericSig) { |
| // Only generic functions are permitted to have trailing where clauses. |
| TC.diagnose(attr->getLocation(), |
| diag::specialize_attr_nongeneric_trailing_where, |
| FD->getFullName()) |
| .highlight(trailingWhereClause->getSourceRange()); |
| return; |
| } |
| |
| // Form a new generic signature based on the old one. |
| GenericSignatureBuilder Builder(D->getASTContext()); |
| |
| // First, add the old generic signature. |
| Builder.addGenericSignature(genericSig); |
| |
| // Set of generic parameters being constrained. It is used to |
| // determine if a full specialization misses requirements for |
| // some of the generic parameters. |
| SmallPtrSet<TypeBase *, 4> constrainedGenericParams; |
| |
| // Go over the set of requirements, adding them to the builder. |
| WhereClauseOwner(FD, attr).visitRequirements(TypeResolutionStage::Interface, |
| [&](const Requirement &req, RequirementRepr *reqRepr) { |
| // Collect all of the generic parameters used by these types. |
| switch (req.getKind()) { |
| case RequirementKind::Conformance: |
| case RequirementKind::SameType: |
| case RequirementKind::Superclass: |
| collectUsedGenericParameters(req.getSecondType(), |
| constrainedGenericParams); |
| LLVM_FALLTHROUGH; |
| |
| case RequirementKind::Layout: |
| collectUsedGenericParameters(req.getFirstType(), |
| constrainedGenericParams); |
| break; |
| } |
| |
| // Check additional constraints. |
| // FIXME: These likely aren't fundamental limitations. |
| switch (req.getKind()) { |
| case RequirementKind::SameType: { |
| bool firstHasTypeParameter = req.getFirstType()->hasTypeParameter(); |
| bool secondHasTypeParameter = req.getSecondType()->hasTypeParameter(); |
| |
| // Exactly one type can have a type parameter. |
| if (firstHasTypeParameter == secondHasTypeParameter) { |
| TC.diagnose( |
| attr->getLocation(), |
| firstHasTypeParameter |
| ? diag::specialize_attr_non_concrete_same_type_req |
| : diag::specialize_attr_only_one_concrete_same_type_req) |
| .highlight(reqRepr->getSourceRange()); |
| return false; |
| } |
| |
| // We either need a fully-concrete type or a generic type parameter. |
| if (diagnoseIndirectGenericTypeParam(attr->getLocation(), |
| req.getFirstType(), |
| reqRepr->getFirstTypeRepr()) || |
| diagnoseIndirectGenericTypeParam(attr->getLocation(), |
| req.getSecondType(), |
| reqRepr->getSecondTypeRepr())) { |
| return false; |
| } |
| break; |
| } |
| |
| case RequirementKind::Superclass: |
| TC.diagnose(attr->getLocation(), |
| diag::specialize_attr_non_protocol_type_constraint_req) |
| .highlight(reqRepr->getSourceRange()); |
| return false; |
| |
| case RequirementKind::Conformance: |
| if (diagnoseIndirectGenericTypeParam(attr->getLocation(), |
| req.getFirstType(), |
| reqRepr->getSubjectRepr())) { |
| return false; |
| } |
| |
| if (!req.getSecondType()->is<ProtocolType>()) { |
| TC.diagnose(attr->getLocation(), |
| diag::specialize_attr_non_protocol_type_constraint_req) |
| .highlight(reqRepr->getSourceRange()); |
| return false; |
| } |
| |
| TC.diagnose(attr->getLocation(), |
| diag::specialize_attr_unsupported_kind_of_req) |
| .highlight(reqRepr->getSourceRange()); |
| |
| return false; |
| |
| case RequirementKind::Layout: |
| if (diagnoseIndirectGenericTypeParam(attr->getLocation(), |
| req.getFirstType(), |
| reqRepr->getSubjectRepr())) { |
| return false; |
| } |
| break; |
| } |
| |
| // Add the requirement to the generic signature builder. |
| using FloatingRequirementSource = |
| GenericSignatureBuilder::FloatingRequirementSource; |
| Builder.addRequirement(req, reqRepr, |
| FloatingRequirementSource::forExplicit(reqRepr), |
| nullptr, DC->getParentModule()); |
| return false; |
| }); |
| |
| // Check the validity of provided requirements. |
| checkSpecializeAttrRequirements(attr, FD, constrainedGenericParams, TC); |
| |
| // Check the result. |
| auto specializedSig = std::move(Builder).computeGenericSignature( |
| attr->getLocation(), |
| /*allowConcreteGenericParams=*/true); |
| attr->setSpecializedSignature(specializedSig); |
| } |
| |
| void AttributeChecker::visitFixedLayoutAttr(FixedLayoutAttr *attr) { |
| if (isa<StructDecl>(D)) { |
| TC.diagnose(attr->getLocation(), diag::fixed_layout_struct) |
| .fixItReplace(attr->getRange(), "@frozen"); |
| } |
| |
| auto *VD = cast<ValueDecl>(D); |
| |
| if (VD->getFormalAccess() < AccessLevel::Public && |
| !VD->getAttrs().hasAttribute<UsableFromInlineAttr>()) { |
| diagnoseAndRemoveAttr(attr, diag::fixed_layout_attr_on_internal_type, |
| VD->getFullName(), VD->getFormalAccess()); |
| } |
| } |
| |
| void AttributeChecker::visitUsableFromInlineAttr(UsableFromInlineAttr *attr) { |
| auto *VD = cast<ValueDecl>(D); |
| |
| // FIXME: Once protocols can contain nominal types, do we want to allow |
| // these nominal types to have access control (and also @usableFromInline)? |
| if (isa<ProtocolDecl>(VD->getDeclContext())) { |
| diagnoseAndRemoveAttr(attr, diag::usable_from_inline_attr_in_protocol); |
| return; |
| } |
| |
| // @usableFromInline can only be applied to internal declarations. |
| if (VD->getFormalAccess() != AccessLevel::Internal) { |
| diagnoseAndRemoveAttr(attr, |
| diag::usable_from_inline_attr_with_explicit_access, |
| VD->getFullName(), |
| VD->getFormalAccess()); |
| return; |
| } |
| |
| // On internal declarations, @inlinable implies @usableFromInline. |
| if (VD->getAttrs().hasAttribute<InlinableAttr>()) { |
| if (TC.Context.isSwiftVersionAtLeast(4,2)) |
| diagnoseAndRemoveAttr(attr, diag::inlinable_implies_usable_from_inline); |
| return; |
| } |
| } |
| |
| void AttributeChecker::visitInlinableAttr(InlinableAttr *attr) { |
| // @inlinable cannot be applied to stored properties. |
| // |
| // If the type is fixed-layout, the accessors are inlinable anyway; |
| // if the type is resilient, the accessors cannot be inlinable |
| // because clients cannot directly access storage. |
| if (auto *VD = dyn_cast<VarDecl>(D)) { |
| if (VD->hasStorage() || VD->getAttrs().hasAttribute<LazyAttr>()) { |
| diagnoseAndRemoveAttr(attr, |
| diag::attribute_invalid_on_stored_property, |
| attr); |
| return; |
| } |
| } |
| |
| auto *VD = cast<ValueDecl>(D); |
| |
| // Calls to dynamically-dispatched declarations are never devirtualized, |
| // so marking them as @inlinable does not make sense. |
| if (VD->isDynamic()) { |
| diagnoseAndRemoveAttr(attr, diag::inlinable_dynamic_not_supported); |
| return; |
| } |
| |
| // @inlinable can only be applied to public or internal declarations. |
| auto access = VD->getFormalAccess(); |
| if (access < AccessLevel::Internal) { |
| diagnoseAndRemoveAttr(attr, diag::inlinable_decl_not_public, |
| VD->getBaseName(), |
| access); |
| return; |
| } |
| |
| // @inlinable cannot be applied to deinitializers in resilient classes. |
| if (auto *DD = dyn_cast<DestructorDecl>(D)) { |
| if (auto *CD = dyn_cast<ClassDecl>(DD->getDeclContext())) { |
| if (CD->isResilient()) { |
| diagnoseAndRemoveAttr(attr, diag::inlinable_resilient_deinit); |
| return; |
| } |
| } |
| } |
| } |
| |
| void AttributeChecker::visitOptimizeAttr(OptimizeAttr *attr) { |
| if (auto *VD = dyn_cast<VarDecl>(D)) { |
| if (VD->hasStorage()) { |
| diagnoseAndRemoveAttr(attr, |
| diag::attribute_invalid_on_stored_property, |
| attr); |
| return; |
| } |
| } |
| } |
| |
| void AttributeChecker::visitDiscardableResultAttr(DiscardableResultAttr *attr) { |
| if (auto *FD = dyn_cast<FuncDecl>(D)) { |
| if (auto result = FD->getResultInterfaceType()) { |
| auto resultIsVoid = result->isVoid(); |
| if (resultIsVoid || result->isUninhabited()) { |
| diagnoseAndRemoveAttr(attr, |
| diag::discardable_result_on_void_never_function, |
| resultIsVoid); |
| } |
| } |
| } |
| } |
| |
| /// Lookup the replaced decl in the replacments scope. |
| void lookupReplacedDecl(DeclName replacedDeclName, |
| const DynamicReplacementAttr *attr, |
| const ValueDecl *replacement, |
| SmallVectorImpl<ValueDecl *> &results) { |
| auto *declCtxt = replacement->getDeclContext(); |
| |
| // Look at the accessors' storage's context. |
| if (auto *accessor = dyn_cast<AccessorDecl>(replacement)) { |
| auto *storage = accessor->getStorage(); |
| declCtxt = storage->getDeclContext(); |
| } |
| |
| auto *moduleScopeCtxt = declCtxt->getModuleScopeContext(); |
| if (isa<FileUnit>(declCtxt)) { |
| UnqualifiedLookup lookup(replacedDeclName, moduleScopeCtxt, |
| attr->getLocation()); |
| if (lookup.isSuccess()) { |
| for (auto entry : lookup.Results) { |
| results.push_back(entry.getValueDecl()); |
| } |
| } |
| return; |
| } |
| |
| assert(declCtxt->isTypeContext()); |
| auto typeCtx = dyn_cast<NominalTypeDecl>(declCtxt->getAsDecl()); |
| if (!typeCtx) |
| typeCtx = cast<ExtensionDecl>(declCtxt->getAsDecl())->getExtendedNominal(); |
| |
| if (typeCtx) |
| moduleScopeCtxt->lookupQualified({typeCtx}, replacedDeclName, |
| NL_QualifiedDefault, results); |
| } |
| |
| /// Remove any argument labels from the interface type of the given value that |
| /// are extraneous from the type system's point of view, producing the |
| /// type to compare against for the purposes of dynamic replacement. |
| static Type getDynamicComparisonType(ValueDecl *value) { |
| unsigned numArgumentLabels = 0; |
| |
| if (isa<AbstractFunctionDecl>(value)) { |
| ++numArgumentLabels; |
| |
| if (value->getDeclContext()->isTypeContext()) |
| ++numArgumentLabels; |
| } else if (isa<SubscriptDecl>(value)) { |
| ++numArgumentLabels; |
| } |
| |
| auto interfaceType = value->getInterfaceType(); |
| if (!interfaceType) |
| return ErrorType::get(value->getASTContext()); |
| |
| return interfaceType->removeArgumentLabels(numArgumentLabels); |
| } |
| |
| static FuncDecl *findReplacedAccessor(DeclName replacedVarName, |
| AccessorDecl *replacement, |
| DynamicReplacementAttr *attr, |
| TypeChecker &TC) { |
| |
| // Retrieve the replaced abstract storage decl. |
| SmallVector<ValueDecl *, 4> results; |
| lookupReplacedDecl(replacedVarName, attr, replacement, results); |
| |
| // Filter out any accessors that won't work. |
| if (!results.empty()) { |
| auto replacementStorage = replacement->getStorage(); |
| Type replacementStorageType = getDynamicComparisonType(replacementStorage); |
| results.erase(std::remove_if(results.begin(), results.end(), |
| [&](ValueDecl *result) { |
| // Protocol requirements are not replaceable. |
| if (isa<ProtocolDecl>(result->getDeclContext())) |
| return true; |
| // Check for static/instance mismatch. |
| if (result->isStatic() != replacementStorage->isStatic()) |
| return true; |
| |
| // Check for type mismatch. |
| auto resultType = getDynamicComparisonType(result); |
| if (!resultType->isEqual(replacementStorageType) && |
| !resultType->matches( |
| replacementStorageType, |
| TypeMatchFlags::AllowCompatibleOpaqueTypeArchetypes)) { |
| return true; |
| } |
| |
| return false; |
| }), |
| results.end()); |
| } |
| |
| if (results.empty()) { |
| TC.diagnose(attr->getLocation(), |
| diag::dynamic_replacement_accessor_not_found, replacedVarName); |
| attr->setInvalid(); |
| return nullptr; |
| } |
| |
| if (results.size() > 1) { |
| TC.diagnose(attr->getLocation(), |
| diag::dynamic_replacement_accessor_ambiguous, replacedVarName); |
| for (auto result : results) { |
| TC.diagnose(result, |
| diag::dynamic_replacement_accessor_ambiguous_candidate, |
| result->getModuleContext()->getFullName()); |
| } |
| attr->setInvalid(); |
| return nullptr; |
| } |
| |
| assert(!isa<FuncDecl>(results[0])); |
| |
| // FIXME(InterfaceTypeRequest): Remove this. |
| (void)results[0]->getInterfaceType(); |
| auto *origStorage = cast<AbstractStorageDecl>(results[0]); |
| if (!origStorage->isDynamic()) { |
| TC.diagnose(attr->getLocation(), |
| diag::dynamic_replacement_accessor_not_dynamic, |
| replacedVarName); |
| attr->setInvalid(); |
| return nullptr; |
| } |
| |
| // Find the accessor in the replaced storage decl. |
| auto *origAccessor = origStorage->getOpaqueAccessor( |
| replacement->getAccessorKind()); |
| if (!origAccessor) |
| return nullptr; |
| |
| // FIXME(InterfaceTypeRequest): Remove this. |
| (void)origAccessor->getInterfaceType(); |
| if (origAccessor->isImplicit() && |
| !(origStorage->getReadImpl() == ReadImplKind::Stored && |
| origStorage->getWriteImpl() == WriteImplKind::Stored)) { |
| TC.diagnose(attr->getLocation(), |
| diag::dynamic_replacement_accessor_not_explicit, |
| (unsigned)origAccessor->getAccessorKind(), replacedVarName); |
| attr->setInvalid(); |
| return nullptr; |
| } |
| |
| return origAccessor; |
| } |
| |
| static AbstractFunctionDecl * |
| findReplacedFunction(DeclName replacedFunctionName, |
| const AbstractFunctionDecl *replacement, |
| DynamicReplacementAttr *attr, TypeChecker *TC) { |
| |
| // Note: we might pass a constant attribute when typechecker is nullptr. |
| // Any modification to attr must be guarded by a null check on TC. |
| // |
| SmallVector<ValueDecl *, 4> results; |
| lookupReplacedDecl(replacedFunctionName, attr, replacement, results); |
| |
| for (auto *result : results) { |
| // Protocol requirements are not replaceable. |
| if (isa<ProtocolDecl>(result->getDeclContext())) |
| continue; |
| // Check for static/instance mismatch. |
| if (result->isStatic() != replacement->isStatic()) |
| continue; |
| |
| TypeMatchOptions matchMode = TypeMatchFlags::AllowABICompatible; |
| matchMode |= TypeMatchFlags::AllowCompatibleOpaqueTypeArchetypes; |
| if (result->getInterfaceType()->getCanonicalType()->matches( |
| replacement->getInterfaceType()->getCanonicalType(), matchMode)) { |
| if (!result->isDynamic()) { |
| if (TC) { |
| TC->diagnose(attr->getLocation(), |
| diag::dynamic_replacement_function_not_dynamic, |
| replacedFunctionName); |
| attr->setInvalid(); |
| } |
| return nullptr; |
| } |
| return cast<AbstractFunctionDecl>(result); |
| } |
| } |
| |
| if (!TC) |
| return nullptr; |
| |
| if (results.empty()) { |
| TC->diagnose(attr->getLocation(), |
| diag::dynamic_replacement_function_not_found, |
| attr->getReplacedFunctionName()); |
| } else { |
| TC->diagnose(attr->getLocation(), |
| diag::dynamic_replacement_function_of_type_not_found, |
| attr->getReplacedFunctionName(), |
| replacement->getInterfaceType()->getCanonicalType()); |
| |
| for (auto *result : results) { |
| TC->diagnose(SourceLoc(), |
| diag::dynamic_replacement_found_function_of_type, |
| attr->getReplacedFunctionName(), |
| result->getInterfaceType()->getCanonicalType()); |
| } |
| } |
| attr->setInvalid(); |
| return nullptr; |
| } |
| |
| static AbstractStorageDecl * |
| findReplacedStorageDecl(DeclName replacedFunctionName, |
| const AbstractStorageDecl *replacement, |
| const DynamicReplacementAttr *attr) { |
| |
| SmallVector<ValueDecl *, 4> results; |
| lookupReplacedDecl(replacedFunctionName, attr, replacement, results); |
| |
| for (auto *result : results) { |
| // Check for static/instance mismatch. |
| if (result->isStatic() != replacement->isStatic()) |
| continue; |
| if (result->getInterfaceType()->getCanonicalType()->matches( |
| replacement->getInterfaceType()->getCanonicalType(), |
| TypeMatchFlags::AllowABICompatible)) { |
| if (!result->isDynamic()) { |
| return nullptr; |
| } |
| return cast<AbstractStorageDecl>(result); |
| } |
| } |
| return nullptr; |
| } |
| |
| ValueDecl *TypeChecker::findReplacedDynamicFunction(const ValueDecl *vd) { |
| assert(isa<AbstractFunctionDecl>(vd) || isa<AbstractStorageDecl>(vd)); |
| if (isa<AccessorDecl>(vd)) |
| return nullptr; |
| |
| auto *attr = vd->getAttrs().getAttribute<DynamicReplacementAttr>(); |
| if (!attr) |
| return nullptr; |
| |
| auto *afd = dyn_cast<AbstractFunctionDecl>(vd); |
| if (afd) { |
| // When we pass nullptr as the type checker argument attr is truely const. |
| return findReplacedFunction(attr->getReplacedFunctionName(), afd, |
| const_cast<DynamicReplacementAttr *>(attr), |
| nullptr); |
| } |
| auto *storageDecl = dyn_cast<AbstractStorageDecl>(vd); |
| if (!storageDecl) |
| return nullptr; |
| return findReplacedStorageDecl(attr->getReplacedFunctionName(), storageDecl, attr); |
| } |
| |
| void AttributeChecker::visitDynamicReplacementAttr(DynamicReplacementAttr *attr) { |
| assert(isa<AbstractFunctionDecl>(D) || isa<AbstractStorageDecl>(D)); |
| auto *VD = cast<ValueDecl>(D); |
| |
| if (!isa<ExtensionDecl>(VD->getDeclContext()) && |
| !VD->getDeclContext()->isModuleScopeContext()) { |
| TC.diagnose(attr->getLocation(), diag::dynamic_replacement_not_in_extension, |
| VD->getBaseName()); |
| attr->setInvalid(); |
| return; |
| } |
| |
| if (VD->isNativeDynamic()) { |
| TC.diagnose(attr->getLocation(), diag::dynamic_replacement_must_not_be_dynamic, |
| VD->getBaseName()); |
| attr->setInvalid(); |
| return; |
| } |
| |
| // Don't process a declaration twice. This will happen to accessor decls after |
| // we have processed their var decls. |
| if (attr->getReplacedFunction()) |
| return; |
| |
| SmallVector<AbstractFunctionDecl *, 4> replacements; |
| SmallVector<AbstractFunctionDecl *, 4> origs; |
| |
| // Collect the accessor replacement mapping if this is an abstract storage. |
| if (auto *var = dyn_cast<AbstractStorageDecl>(VD)) { |
| var->visitParsedAccessors([&](AccessorDecl *accessor) { |
| if (attr->isInvalid()) |
| return; |
| |
| // FIXME(InterfaceTypeRequest): Remove this. |
| (void)accessor->getInterfaceType(); |
| auto *orig = findReplacedAccessor(attr->getReplacedFunctionName(), |
| accessor, attr, TC); |
| if (!orig) |
| return; |
| |
| origs.push_back(orig); |
| replacements.push_back(accessor); |
| }); |
| } else { |
| // Otherwise, find the matching function. |
| auto *fun = cast<AbstractFunctionDecl>(VD); |
| if (auto *orig = findReplacedFunction(attr->getReplacedFunctionName(), fun, |
| attr, &TC)) { |
| origs.push_back(orig); |
| replacements.push_back(fun); |
| } else |
| return; |
| } |
| |
| // Annotate the replacement with the original func decl. |
| for (auto index : indices(replacements)) { |
| if (auto *attr = replacements[index] |
| ->getAttrs() |
| .getAttribute<DynamicReplacementAttr>()) { |
| auto *replacedFun = origs[index]; |
| auto *replacement = replacements[index]; |
| if (replacedFun->isObjC() && !replacement->isObjC()) { |
| TC.diagnose(attr->getLocation(), |
| diag::dynamic_replacement_replacement_not_objc_dynamic, |
| attr->getReplacedFunctionName()); |
| attr->setInvalid(); |
| return; |
| } |
| if (!replacedFun->isObjC() && replacement->isObjC()) { |
| TC.diagnose(attr->getLocation(), |
| diag::dynamic_replacement_replaced_not_objc_dynamic, |
| attr->getReplacedFunctionName()); |
| attr->setInvalid(); |
| return; |
| } |
| attr->setReplacedFunction(replacedFun); |
| continue; |
| } |
| auto *newAttr = DynamicReplacementAttr::create( |
| VD->getASTContext(), attr->getReplacedFunctionName(), origs[index]); |
| DeclAttributes &attrs = replacements[index]->getAttrs(); |
| attrs.add(newAttr); |
| } |
| if (auto *CD = dyn_cast<ConstructorDecl>(VD)) { |
| auto *attr = CD->getAttrs().getAttribute<DynamicReplacementAttr>(); |
| auto replacedIsConvenienceInit = |
| cast<ConstructorDecl>(attr->getReplacedFunction())->isConvenienceInit(); |
| if (replacedIsConvenienceInit &&!CD->isConvenienceInit()) { |
| TC.diagnose(attr->getLocation(), |
| diag::dynamic_replacement_replaced_constructor_is_convenience, |
| attr->getReplacedFunctionName()); |
| } else if (!replacedIsConvenienceInit && CD->isConvenienceInit()) { |
| TC.diagnose( |
| attr->getLocation(), |
| diag::dynamic_replacement_replaced_constructor_is_not_convenience, |
| attr->getReplacedFunctionName()); |
| } |
| } |
| |
| |
| // Remove the attribute on the abstract storage (we have moved it to the |
| // accessor decl). |
| if (!isa<AbstractStorageDecl>(VD)) |
| return; |
| D->getAttrs().removeAttribute(attr); |
| } |
| |
| void AttributeChecker::visitImplementsAttr(ImplementsAttr *attr) { |
| TypeLoc &ProtoTypeLoc = attr->getProtocolType(); |
| |
| DeclContext *DC = D->getDeclContext(); |
| |
| Type T = ProtoTypeLoc.getType(); |
| if (!T && ProtoTypeLoc.getTypeRepr()) { |
| TypeResolutionOptions options = None; |
| options |= TypeResolutionFlags::AllowUnboundGenerics; |
| |
| auto resolution = TypeResolution::forContextual(DC); |
| T = resolution.resolveType(ProtoTypeLoc.getTypeRepr(), options); |
| ProtoTypeLoc.setType(T); |
| } |
| |
| // Definite error-types were already diagnosed in resolveType. |
| if (T->hasError()) |
| return; |
| |
| // Check that we got a ProtocolType. |
| if (auto PT = T->getAs<ProtocolType>()) { |
| ProtocolDecl *PD = PT->getDecl(); |
| |
| // Check that the ProtocolType has the specified member. |
| LookupResult R = TC.lookupMember(PD->getDeclContext(), |
| PT, attr->getMemberName()); |
| if (!R) { |
| TC.diagnose(attr->getLocation(), |
| diag::implements_attr_protocol_lacks_member, |
| PD->getBaseName(), attr->getMemberName()) |
| .highlight(attr->getMemberNameLoc().getSourceRange()); |
| } |
| |
| // Check that the decl we're decorating is a member of a type that actually |
| // conforms to the specified protocol. |
| NominalTypeDecl *NTD = DC->getSelfNominalTypeDecl(); |
| SmallVector<ProtocolConformance *, 2> conformances; |
| if (!NTD->lookupConformance(DC->getParentModule(), PD, conformances)) { |
| TC.diagnose(attr->getLocation(), |
| diag::implements_attr_protocol_not_conformed_to, |
| NTD->getFullName(), PD->getFullName()) |
| .highlight(ProtoTypeLoc.getTypeRepr()->getSourceRange()); |
| } |
| |
| } else { |
| TC.diagnose(attr->getLocation(), |
| diag::implements_attr_non_protocol_type) |
| .highlight(ProtoTypeLoc.getTypeRepr()->getSourceRange()); |
| } |
| } |
| |
| void AttributeChecker::visitFrozenAttr(FrozenAttr *attr) { |
| if (auto *ED = dyn_cast<EnumDecl>(D)) { |
| if (!ED->getModuleContext()->isResilient()) { |
| diagnoseAndRemoveAttr(attr, diag::enum_frozen_nonresilient, attr); |
| return; |
| } |
| |
| if (ED->getFormalAccess() < AccessLevel::Public && |
| !ED->getAttrs().hasAttribute<UsableFromInlineAttr>()) { |
| diagnoseAndRemoveAttr(attr, diag::enum_frozen_nonpublic, attr); |
| return; |
| } |
| } |
| |
| auto *VD = cast<ValueDecl>(D); |
| |
| if (VD->getFormalAccess() < AccessLevel::Public && |
| !VD->getAttrs().hasAttribute<UsableFromInlineAttr>()) { |
| diagnoseAndRemoveAttr(attr, diag::frozen_attr_on_internal_type, |
| VD->getFullName(), VD->getFormalAccess()); |
| } |
| } |
| |
| void AttributeChecker::visitCustomAttr(CustomAttr *attr) { |
| auto dc = D->getInnermostDeclContext(); |
| |
| // Figure out which nominal declaration this custom attribute refers to. |
| auto nominal = evaluateOrDefault( |
| TC.Context.evaluator, CustomAttrNominalRequest{attr, dc}, nullptr); |
| |
| // If there is no nominal type with this name, complain about this being |
| // an unknown attribute. |
| if (!nominal) { |
| std::string typeName; |
| if (auto typeRepr = attr->getTypeLoc().getTypeRepr()) { |
| llvm::raw_string_ostream out(typeName); |
| typeRepr->print(out); |
| } else { |
| typeName = attr->getTypeLoc().getType().getString(); |
| } |
| |
| TC.diagnose(attr->getLocation(), diag::unknown_attribute, |
| typeName); |
| attr->setInvalid(); |
| return; |
| } |
| |
| // If the nominal type is a property wrapper type, we can be delegating |
| // through a property. |
| if (nominal->getPropertyWrapperTypeInfo()) { |
| // property wrappers can only be applied to variables |
| if (!isa<VarDecl>(D) || isa<ParamDecl>(D)) { |
| TC.diagnose(attr->getLocation(), |
| diag::property_wrapper_attribute_not_on_property, |
| nominal->getFullName()); |
| attr->setInvalid(); |
| return; |
| } |
| |
| return; |
| } |
| |
| // If the nominal type is a function builder type, verify that D is a |
| // function, storage with an explicit getter, or parameter of function type. |
| if (nominal->getAttrs().hasAttribute<FunctionBuilderAttr>()) { |
| ValueDecl *decl; |
| if (auto param = dyn_cast<ParamDecl>(D)) { |
| decl = param; |
| } else if (auto func = dyn_cast<FuncDecl>(D)) { |
| decl = func; |
| } else if (auto storage = dyn_cast<AbstractStorageDecl>(D)) { |
| decl = storage; |
| auto getter = storage->getParsedAccessor(AccessorKind::Get); |
| if (!getter || !getter->hasBody()) { |
| TC.diagnose(attr->getLocation(), |
| diag::function_builder_attribute_on_storage_without_getter, |
| nominal->getFullName(), |
| isa<SubscriptDecl>(storage) ? 0 |
| : storage->getDeclContext()->isTypeContext() ? 1 |
| : cast<VarDecl>(storage)->isLet() ? 2 : 3); |
| attr->setInvalid(); |
| return; |
| } |
| } else { |
| TC.diagnose(attr->getLocation(), |
| diag::function_builder_attribute_not_allowed_here, |
| nominal->getFullName()); |
| attr->setInvalid(); |
| return; |
| } |
| |
| // Diagnose and ignore arguments. |
| if (attr->getArg()) { |
| TC.diagnose(attr->getLocation(), diag::function_builder_arguments) |
| .highlight(attr->getArg()->getSourceRange()); |
| } |
| |
| // Complain if this isn't the primary function-builder attribute. |
| auto attached = decl->getAttachedFunctionBuilder(); |
| if (attached != attr) { |
| TC.diagnose(attr->getLocation(), diag::function_builder_multiple, |
| isa<ParamDecl>(decl)); |
| TC.diagnose(attached->getLocation(), |
| diag::previous_function_builder_here); |
| attr->setInvalid(); |
| return; |
| } else { |
| // Force any diagnostics associated with computing the function-builder |
| // type. |
| (void) decl->getFunctionBuilderType(); |
| } |
| |
| return; |
| } |
| |
| TC.diagnose(attr->getLocation(), diag::nominal_type_not_attribute, |
| nominal->getDescriptiveKind(), nominal->getFullName()); |
| nominal->diagnose(diag::decl_declared_here, nominal->getFullName()); |
| attr->setInvalid(); |
| } |
| |
| |
| void AttributeChecker::visitPropertyWrapperAttr(PropertyWrapperAttr *attr) { |
| auto nominal = dyn_cast<NominalTypeDecl>(D); |
| if (!nominal) |
| return; |
| |
| // Force checking of the property wrapper type. |
| (void)nominal->getPropertyWrapperTypeInfo(); |
| } |
| |
| void AttributeChecker::visitFunctionBuilderAttr(FunctionBuilderAttr *attr) { |
| // TODO: check that the type at least provides a `sequence` factory? |
| // Any other validation? |
| } |
| |
| // SWIFT_ENABLE_TENSORFLOW |
| /// Returns true if the given type conforms to `Differentiable` in the given |
| /// module. |
| static bool conformsToDifferentiable(Type type, DeclContext *DC) { |
| auto &ctx = type->getASTContext(); |
| auto *differentiableProto = |
| ctx.getProtocol(KnownProtocolKind::Differentiable); |
| auto conf = TypeChecker::conformsToProtocol( |
| type, differentiableProto, DC, ConformanceCheckFlags::InExpression); |
| if (!conf) |
| return false; |
| // Try to get the `TangentVector` type witness, in case the conformance has |
| // not been fully checked and the type witness cannot be resolved. |
| Type tanType = conf->getTypeWitnessByName(type, ctx.Id_TangentVector); |
| return !tanType.isNull() && !tanType->hasError(); |
| }; |
| |
| // SWIFT_ENABLE_TENSORFLOW |
| /// Returns true if the given type's `TangentVector` is equal to itself in the |
| /// given module. |
| static bool tangentVectorEqualSelf(Type type, DeclContext *DC) { |
| assert(conformsToDifferentiable(type, DC)); |
| auto &ctx = type->getASTContext(); |
| auto *differentiableProto = |
| ctx.getProtocol(KnownProtocolKind::Differentiable); |
| auto conf = TypeChecker::conformsToProtocol( |
| type, differentiableProto, DC, |
| ConformanceCheckFlags::InExpression); |
| auto tanType = conf->getTypeWitnessByName(type, ctx.Id_TangentVector); |
| return type->getCanonicalType() == tanType->getCanonicalType(); |
| }; |
| |
| // SWIFT_ENABLE_TENSORFLOW |
| /// Creates a `IndexSubset` for the given function type, representing |
| /// all inferred differentiation parameters. |
| /// The differentiation parameters are inferred to be: |
| /// - All parameters of the function type that conform to `Differentiable`. |
| /// - If the function type's result is a function type (i.e. it is a curried |
| /// method type), then also all parameters of the function result type that |
| /// conform to `Differentiable`. |
| IndexSubset * |
| TypeChecker::inferDifferentiableParameters( |
| AbstractFunctionDecl *AFD, GenericEnvironment *derivativeGenEnv) { |
| auto &ctx = AFD->getASTContext(); |
| auto *functionType = AFD->getInterfaceType()->castTo<AnyFunctionType>(); |
| auto numUncurriedParams = functionType->getNumParams(); |
| if (auto *resultFnType = |
| functionType->getResult()->getAs<AnyFunctionType>()) { |
| numUncurriedParams += resultFnType->getNumParams(); |
| } |
| llvm::SmallBitVector parameterBits(numUncurriedParams); |
| SmallVector<Type, 4> allParamTypes; |
| |
| // Returns true if the i-th parameter type is differentiable. |
| auto isDifferentiableParam = [&](unsigned i) -> bool { |
| if (i >= allParamTypes.size()) |
| return false; |
| auto paramType = allParamTypes[i]; |
| if (derivativeGenEnv) |
| paramType = derivativeGenEnv->mapTypeIntoContext(paramType); |
| else |
| paramType = AFD->mapTypeIntoContext(paramType); |
| // Return false for existential types. |
| if (paramType->isExistentialType()) |
| return false; |
| // Return false for function types. |
| if (paramType->is<AnyFunctionType>()) |
| return false; |
| // Return true if the type conforms to `Differentiable`. |
| return conformsToDifferentiable(paramType, AFD); |
| }; |
| |
| // Get all parameter types. |
| // NOTE: To be robust, result function type parameters should be added only if |
| // `functionType` comes from a static/instance method, and not a free function |
| // returning a function type. In practice, this code path should not be |
| // reachable for free functions returning a function type. |
| if (auto resultFnType = functionType->getResult()->getAs<AnyFunctionType>()) |
| for (auto ¶m : resultFnType->getParams()) |
| allParamTypes.push_back(param.getPlainType()); |
| for (auto ¶m : functionType->getParams()) |
| allParamTypes.push_back(param.getPlainType()); |
| |
| // Set differentiation parameters. |
| for (unsigned i : range(parameterBits.size())) |
| if (isDifferentiableParam(i)) |
| parameterBits.set(i); |
| |
| return IndexSubset::get(ctx, parameterBits); |
| } |
| |
| // SWIFT_ENABLE_TENSORFLOW |
| static FuncDecl *resolveAutoDiffDerivativeFunction( |
| TypeChecker &TC, DeclNameWithLoc specifier, AbstractFunctionDecl *original, |
| Type expectedTy, std::function<bool(FuncDecl *)> isValid) { |
| auto nameLoc = specifier.Loc.getBaseNameLoc(); |
| auto overloadDiagnostic = [&]() { |
| TC.diagnose(nameLoc, diag::differentiable_attr_overload_not_found, |
| specifier.Name, expectedTy); |
| }; |
| auto ambiguousDiagnostic = [&]() { |
| TC.diagnose(nameLoc, |
| diag::differentiable_attr_ambiguous_function_identifier, |
| specifier.Name); |
| }; |
| auto notFunctionDiagnostic = [&]() { |
| TC.diagnose(nameLoc, diag::differentiable_attr_specified_not_function, |
| specifier.Name); |
| }; |
| std::function<void()> invalidTypeContextDiagnostic = [&]() { |
| TC.diagnose(nameLoc, |
| diag::differentiable_attr_function_not_same_type_context, |
| specifier.Name); |
| }; |
| |
| // Returns true if the original function and derivative function candidate are |
| // defined in compatible type contexts. If the original function and the |
| // derivative function have different parents, or if they both have no type |
| // context and are in different modules, return false. |
| std::function<bool(FuncDecl *)> hasValidTypeContext = [&](FuncDecl *func) { |
| // Check if both functions are top-level. |
| if (!original->getInnermostTypeContext() && |
| !func->getInnermostTypeContext() && |
| original->getParentModule() == func->getParentModule()) |
| return true; |
| // Check if both functions are defined in the same type context. |
| if (auto typeCtx1 = original->getInnermostTypeContext()) |
| if (auto typeCtx2 = func->getInnermostTypeContext()) |
| return typeCtx1->getSelfNominalTypeDecl() == |
| typeCtx2->getSelfNominalTypeDecl(); |
| return original->getParent() == func->getParent(); |
| }; |
| |
| auto isABIPublic = [&](AbstractFunctionDecl *func) { |
| return func->getFormalAccess() >= AccessLevel::Public || |
| func->getAttrs().hasAttribute<InlinableAttr>() || |
| func->getAttrs().hasAttribute<UsableFromInlineAttr>(); |
| }; |
| |
| // If the original function is exported (i.e. it is public or |
| // @usableFromInline), then the derivative functions must also be exported. |
| // Returns true on error. |
| auto checkAccessControl = [&](FuncDecl *func) { |
| if (!isABIPublic(original)) |
| return false; |
| if (isABIPublic(func)) |
| return false; |
| TC.diagnose(nameLoc, diag::differentiable_attr_invalid_access, |
| specifier.Name, original->getFullName()); |
| return true; |
| }; |
| |
| auto originalTypeCtx = original->getInnermostTypeContext(); |
| if (!originalTypeCtx) originalTypeCtx = original->getParent(); |
| assert(originalTypeCtx); |
| |
| // Set lookup options. |
| auto lookupOptions = defaultMemberLookupOptions |
| | NameLookupFlags::IgnoreAccessControl; |
| |
| auto candidate = TC.lookupFuncDecl( |
| specifier.Name, nameLoc, /*baseType*/ Type(), originalTypeCtx, isValid, |
| overloadDiagnostic, ambiguousDiagnostic, notFunctionDiagnostic, |
| lookupOptions, hasValidTypeContext, invalidTypeContextDiagnostic); |
| |
| if (!candidate) |
| return nullptr; |
| |
| if (checkAccessControl(candidate)) |
| return nullptr; |
| |
| // Derivatives of class members must be final. |
| if (original->getDeclContext()->getSelfClassDecl() && |
| !candidate->isFinal()) { |
| TC.diagnose(nameLoc, |
| diag::differentiable_attr_class_derivative_not_final); |
| return nullptr; |
| } |
| |
| return candidate; |
| } |
| |
| // SWIFT_ENABLE_TENSORFLOW |
| // Checks that the `candidate` function type equals the `required` function |
| // type, disregarding parameter labels and tuple result labels. |
| // `checkGenericSignature` is used to check generic signatures, if specified. |
| // Otherwise, generic signatures are checked for equality. |
| static bool checkFunctionSignature( |
| CanAnyFunctionType required, CanType candidate, |
| Optional<std::function<bool(GenericSignature, GenericSignature)>> |
| checkGenericSignature = None) { |
| // Check that candidate is actually a function. |
| auto candidateFnTy = dyn_cast<AnyFunctionType>(candidate); |
| if (!candidateFnTy) |
| return false; |
| |
| // Erase dynamic self types. |
| required = dyn_cast<AnyFunctionType>(required->getCanonicalType()); |
| candidateFnTy = dyn_cast<AnyFunctionType>(candidateFnTy->getCanonicalType()); |
| |
| // Check that generic signatures match. |
| auto requiredGenSig = required.getOptGenericSignature(); |
| auto candidateGenSig = candidateFnTy.getOptGenericSignature(); |
| // Call generic signature check function, if specified. |
| // Otherwise, check that generic signatures are equal. |
| if (!checkGenericSignature) { |
| if (candidateGenSig != requiredGenSig) |
| return false; |
| } else if (!(*checkGenericSignature)(requiredGenSig, candidateGenSig)) { |
| return false; |
| } |
| |
| // Check that parameter types match, disregarding labels. |
| if (required->getNumParams() != candidateFnTy->getNumParams()) |
| return false; |
| if (!std::equal(required->getParams().begin(), required->getParams().end(), |
| candidateFnTy->getParams().begin(), |
| [](AnyFunctionType::Param x, AnyFunctionType::Param y) { |
| return x.getPlainType()->isEqual(y.getPlainType()); |
| })) |
| return false; |
| |
| // If required result type is not a function type, check that result types |
| // match exactly. |
| auto requiredResultFnTy = dyn_cast<AnyFunctionType>(required.getResult()); |
| if (!requiredResultFnTy) { |
| auto requiredResultTupleTy = dyn_cast<TupleType>(required.getResult()); |
| auto candidateResultTupleTy = |
| dyn_cast<TupleType>(candidateFnTy.getResult()); |
| if (!requiredResultTupleTy || !candidateResultTupleTy) |
| return required.getResult()->isEqual(candidateFnTy.getResult()); |
| // If result types are tuple types, check that element types match, |
| // ignoring labels. |
| if (requiredResultTupleTy->getNumElements() != |
| candidateResultTupleTy->getNumElements()) |
| return false; |
| return std::equal(requiredResultTupleTy.getElementTypes().begin(), |
| requiredResultTupleTy.getElementTypes().end(), |
| candidateResultTupleTy.getElementTypes().begin(), |
| [](CanType x, CanType y) { return x->isEqual(y); }); |
| } |
| |
| // Required result type is a function. Recurse. |
| return checkFunctionSignature(requiredResultFnTy, candidateFnTy.getResult()); |
| }; |
| |
| // SWIFT_ENABLE_TENSORFLOW |
| // Computes `IndexSubset` from the given parsed differentiation parameters |
| // (possibly empty) for the given function and derivative generic environment, |
| // then verifies that the parameter indices are valid. |
| // - If parsed parameters are empty, infer parameter indices. |
| // - Otherwise, build parameter indices from parsed parameters. |
| // The attribute name/location are used in diagnostics. |
| static IndexSubset *computeDifferentiationParameters( |
| TypeChecker &TC, ArrayRef<ParsedAutoDiffParameter> parsedWrtParams, |
| AbstractFunctionDecl *function, GenericEnvironment *derivativeGenEnv, |
| StringRef attrName, SourceLoc attrLoc |
| ) { |
| // Get function type and parameters. |
| TC.resolveDeclSignature(function); |
| auto *functionType = function->getInterfaceType()->castTo<AnyFunctionType>(); |
| auto ¶ms = *function->getParameters(); |
| auto numParams = function->getParameters()->size(); |
| auto isInstanceMethod = function->isInstanceMember(); |
| |
| // Diagnose if function has no parameters. |
| if (params.size() == 0) { |
| // If function is not an instance method, diagnose immediately. |
| if (!isInstanceMethod) { |
| TC.diagnose(attrLoc, diag::diff_function_no_parameters, |
| function->getFullName()) |
| .highlight(function->getSignatureSourceRange()); |
| return nullptr; |
| } |
| // If function is an instance method, diagnose only if `self` does not |
| // conform to `Differentiable`. |
| else { |
| auto selfType = function->getImplicitSelfDecl()->getInterfaceType(); |
| if (derivativeGenEnv) |
| selfType = derivativeGenEnv->mapTypeIntoContext(selfType); |
| else |
| selfType = function->mapTypeIntoContext(selfType); |
| // FIXME(TF-568): `Differentiable`-conforming protocols cannot define |
| // `@differentiable` computed properties because the check below returns |
| // false. |
| if (!conformsToDifferentiable(selfType, function)) { |
| TC.diagnose(attrLoc, diag::diff_function_no_parameters, |
| function->getFullName()) |
| .highlight(function->getSignatureSourceRange()); |
| return nullptr; |
| } |
| } |
| } |
| |
| // If parsed differentiation parameters are empty, infer parameter indices |
| // from the function type. |
| if (parsedWrtParams.empty()) |
| return TypeChecker::inferDifferentiableParameters( |
| function, derivativeGenEnv); |
| |
| // Otherwise, build parameter indices from parsed differentiation parameters. |
| auto numUncurriedParams = functionType->getNumParams(); |
| if (auto *resultFnType = |
| functionType->getResult()->getAs<AnyFunctionType>()) { |
| numUncurriedParams += resultFnType->getNumParams(); |
| } |
| llvm::SmallBitVector parameterBits(numUncurriedParams); |
| int lastIndex = -1; |
| for (unsigned i : indices(parsedWrtParams)) { |
| auto paramLoc = parsedWrtParams[i].getLoc(); |
| switch (parsedWrtParams[i].getKind()) { |
| case ParsedAutoDiffParameter::Kind::Named: { |
| auto nameIter = |
| llvm::find_if(params.getArray(), [&](ParamDecl *param) { |
| return param->getName() == parsedWrtParams[i].getName(); |
| }); |
| // Parameter name must exist. |
| if (nameIter == params.end()) { |
| TC.diagnose(paramLoc, diag::diff_params_clause_param_name_unknown, |
| parsedWrtParams[i].getName()); |
| return nullptr; |
| } |
| // Parameter names must be specified in the original order. |
| unsigned index = std::distance(params.begin(), nameIter); |
| if ((int)index <= lastIndex) { |
| TC.diagnose(paramLoc, |
| diag::diff_params_clause_params_not_original_order); |
| return nullptr; |
| } |
| parameterBits.set(index); |
| lastIndex = index; |
| break; |
| } |
| case ParsedAutoDiffParameter::Kind::Self: { |
| // 'self' is only applicable to instance methods. |
| if (!isInstanceMethod) { |
| TC.diagnose(paramLoc, |
| diag::diff_params_clause_self_instance_method_only); |
| return nullptr; |
| } |
| // 'self' can only be the first in the list. |
| if (i > 0) { |
| TC.diagnose(paramLoc, diag::diff_params_clause_self_must_be_first); |
| return nullptr; |
| } |
| parameterBits.set(parameterBits.size() - 1); |
| break; |
| } |
| case ParsedAutoDiffParameter::Kind::Ordered: { |
| auto index = parsedWrtParams[i].getIndex(); |
| if (index >= numParams) { |
| TC.diagnose(paramLoc, diag::diff_params_clause_param_index_out_of_range); |
| return nullptr; |
| } |
| // Parameter names must be specified in the original order. |
| if ((int)index <= lastIndex) { |
| TC.diagnose(paramLoc, |
| diag::diff_params_clause_params_not_original_order); |
| return nullptr; |
| } |
| parameterBits.set(index); |
| lastIndex = index; |
| break; |
| } |
| } |
| } |
| return IndexSubset::get(TC.Context, parameterBits); |
| } |
| |
| // SWIFT_ENABLE_TENSORFLOW |
| // Computes `IndexSubset` from the given parsed transposing parameters |
| // (possibly empty) for the given function, then verifies that the parameter |
| // indices are valid. |
| // - If parsed parameters are empty, infer parameter indices. |
| // - Otherwise, build parameter indices from parsed parameters. |
| // The attribute name/location are used in diagnostics. |
| static IndexSubset *computeTransposingParameters( |
| TypeChecker &TC, ArrayRef<ParsedAutoDiffParameter> parsedWrtParams, |
| AbstractFunctionDecl *transposeFunc, bool isCurried, |
| GenericEnvironment *derivativeGenEnv, SourceLoc attrLoc |
| ) { |
| // Get function type and parameters. |
| TC.resolveDeclSignature(transposeFunc); |
| auto *functionType = transposeFunc->getInterfaceType() |
| ->castTo<AnyFunctionType>(); |
| |
| ArrayRef<TupleTypeElt> transposeResultTypes; |
| // Return type of '@transposing' function can have single type or tuple |
| // of types. |
| auto temp = functionType->getResult(); |
| if (isCurried) |
| temp = temp->getAs<AnyFunctionType>()->getResult(); |
| |
| if (auto t = temp->getAs<TupleType>()) { |
| transposeResultTypes = t->getElements(); |
| } else { |
| transposeResultTypes = ArrayRef<TupleTypeElt>(temp); |
| } |
| |
| auto ¶ms = *transposeFunc->getParameters(); |
| auto isInstanceMethod = transposeFunc->isInstanceMember(); |
| |
| bool wrtSelf = false; |
| if (isCurried && !parsedWrtParams.empty() && |
| parsedWrtParams.front().getKind() == ParsedAutoDiffParameter::Kind::Self) |
| wrtSelf = true; |
| |
| // Make sure the self type is differentiable. |
| if (isCurried && wrtSelf) { |
| auto selfType = transposeFunc->getImplicitSelfDecl()->getInterfaceType(); |
| if (derivativeGenEnv) |
| selfType = derivativeGenEnv->mapTypeIntoContext(selfType); |
| if (!conformsToDifferentiable(selfType, transposeFunc)) { |
| TC.diagnose(attrLoc, diag::diff_function_no_parameters, |
| transposeFunc->getFullName()) |
| .highlight(transposeFunc->getSignatureSourceRange()); |
| return nullptr; |
| } |
| } |
| |
| // If parsed differentiation parameters are empty, infer parameter indices |
| // from the function type. |
| // TODO(bartchr): still need to do this! |
| // if (parsedWrtParams.empty()) |
| // return TypeChecker::inferTransposingParameters( |
| // function, derivativeGenEnv); |
| |
| // Otherwise, build parameter indices from parsed differentiation parameters. |
| unsigned numParams = params.size() + transposeResultTypes.size(); |
| auto paramIndices = SmallBitVector(numParams); |
| int lastIndex = -1; |
| for (unsigned i : indices(parsedWrtParams)) { |
| auto paramLoc = parsedWrtParams[i].getLoc(); |
| switch (parsedWrtParams[i].getKind()) { |
| case ParsedAutoDiffParameter::Kind::Named: { |
| TC.diagnose(paramLoc, diag::transposing_attr_cant_use_named_wrt_params, |
| parsedWrtParams[i].getName()); |
| return nullptr; |
| } |
| case ParsedAutoDiffParameter::Kind::Self: { |
| // 'self' is only applicable to instance methods. |
| if (!isInstanceMethod) { |
| TC.diagnose(paramLoc, |
| diag::diff_params_clause_self_instance_method_only); |
| return nullptr; |
| } |
| // 'self' can only be the first in the list. |
| if (i > 0) { |
| TC.diagnose(paramLoc, diag::diff_params_clause_self_must_be_first); |
| return nullptr; |
| } |
| paramIndices.set(numParams - 1); |
| break; |
| } |
| case ParsedAutoDiffParameter::Kind::Ordered: { |
| auto index = parsedWrtParams[i].getIndex(); |
| if (index >= numParams) { |
| TC.diagnose(paramLoc, diag::diff_params_clause_param_index_out_of_range); |
| return nullptr; |
| } |
| // Parameter names must be specified in the original order. |
| if ((int)index <= lastIndex) { |
| TC.diagnose(paramLoc, |
| diag::diff_params_clause_params_not_original_order); |
| return nullptr; |
| } |
| paramIndices.set(index); |
| lastIndex = index; |
| break; |
| } |
| } |
| } |
| return IndexSubset::get(TC.Context, paramIndices); |
| } |
| |
| // SWIFT_ENABLE_TENSORFLOW |
| // Checks if the given `IndexSubset` instance is valid for the given function |
| // type in the given derivative generic environment and module context. Returns |
| // true on error. |
| // The parsed differentiation parameters and attribute location are used in |
| // diagnostics. |
| static bool checkDifferentiationParameters( |
| TypeChecker &TC, AbstractFunctionDecl *AFD, IndexSubset *indices, |
| AnyFunctionType *functionType, GenericEnvironment *derivativeGenEnv, |
| ModuleDecl *module, ArrayRef<ParsedAutoDiffParameter> parsedWrtParams, |
| SourceLoc attrLoc) { |
| // Diagnose empty parameter indices. This occurs when no `wrt` clause is |
| // declared and no differentiation parameters can be inferred. |
| if (indices->isEmpty()) { |
| TC.diagnose(attrLoc, diag::diff_params_clause_no_inferred_parameters); |
| return true; |
| } |
| |
| // Check that differentiation parameters have allowed types. |
| SmallVector<Type, 4> wrtParamTypes; |
| autodiff::getSubsetParameterTypes(indices, functionType, wrtParamTypes); |
| for (unsigned i : range(wrtParamTypes.size())) { |
| SourceLoc loc = parsedWrtParams.empty() |
| ? attrLoc |
| : parsedWrtParams[i].getLoc(); |
| auto wrtParamType = wrtParamTypes[i]; |
| if (wrtParamType->is<InOutType>()) { |
| TC.diagnose( |
| loc, |
| diag::diff_params_clause_inout_argument, |
| wrtParamType); |
| return true; |
| } |
| if (!wrtParamType->hasTypeParameter()) |
| wrtParamType = wrtParamType->mapTypeOutOfContext(); |
| if (derivativeGenEnv) |
| wrtParamType = |
| derivativeGenEnv->mapTypeIntoContext(wrtParamType); |
| else |
| wrtParamType = AFD->mapTypeIntoContext(wrtParamType); |
| // Parameter cannot have an existential type. |
| if (wrtParamType->isExistentialType()) { |
| TC.diagnose( |
| loc, diag::diff_params_clause_cannot_diff_wrt_existentials, |
| wrtParamType); |
| return true; |
| } |
| // Parameter cannot have a function type. |
| if (wrtParamType->is<AnyFunctionType>()) { |
| TC.diagnose(loc, diag::diff_params_clause_cannot_diff_wrt_functions, |
| wrtParamType); |
| return true; |
| } |
| // Parameter must conform to `Differentiable`. |
| if (!conformsToDifferentiable(wrtParamType, AFD)) { |
| TC.diagnose(loc, diag::diff_params_clause_param_not_differentiable, |
| wrtParamType); |
| return true; |
| } |
| } |
| return false; |
| } |
| // SWIFT_ENABLE_TENSORFLOW |
| // Checks if the given `IndexSubset` instance is valid for the |
| // given function type in the given derivative generic environment and module |
| // context. Returns true on error. |
| // The parsed differentiation parameters and attribute location are used in |
| // diagnostics. |
| static bool checkTransposingParameters( |
| TypeChecker &TC, AbstractFunctionDecl *AFD, |
| SmallVector<Type, 4> wrtParamTypes, GenericEnvironment *derivativeGenEnv, |
| ModuleDecl *module, ArrayRef<ParsedAutoDiffParameter> parsedWrtParams, |
| SourceLoc attrLoc) { |
| // Check that differentiation parameters have allowed types. |
| for (unsigned i : range(wrtParamTypes.size())) { |
| auto wrtParamType = wrtParamTypes[i]; |
| if (!wrtParamType->hasTypeParameter()) |
| wrtParamType = wrtParamType->mapTypeOutOfContext(); |
| if (derivativeGenEnv) |
| wrtParamType = derivativeGenEnv->mapTypeIntoContext(wrtParamType); |
| else |
| wrtParamType = AFD->mapTypeIntoContext(wrtParamType); |
| SourceLoc loc = parsedWrtParams.empty() |
| ? attrLoc |
| : parsedWrtParams[i].getLoc(); |
| // Parameter cannot have an existential type. |
| if (wrtParamType->isExistentialType()) { |
| TC.diagnose(loc, diag::diff_params_clause_cannot_diff_wrt_existentials, |
| wrtParamType); |
| return true; |
| } |
| // Parameter cannot have a function type. |
| if (wrtParamType->is<AnyFunctionType>()) { |
| TC.diagnose(loc, diag::diff_params_clause_cannot_diff_wrt_functions, |
| wrtParamType); |
| return true; |
| } |
| // Parameter must conform to `Differentiable` |
| // and `Type.TangentVector == Type`. |
| if (!conformsToDifferentiable(wrtParamType, AFD) || |
| !tangentVectorEqualSelf(wrtParamType, AFD)) { |
| TC.diagnose(loc, diag::transpose_params_clause_param_not_differentiable, |
| wrtParamType.getString()); |
| return true; |
| } |
| } |
| return false; |
| } |
| |
| // SWIFT_ENABLE_TENSORFLOW |
| void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { |
| // Skip checking implicit `@differentiable` attributes. We currently assume |
| // that all implicit `@differentiable` attributes are valid. |
| // Motivation: some implicit attributes do not contain a where clause, and |
| // this function assumes that the where clauses are available. Propagating |
| // where clauses and requirements consistently is a larger problem, to be |
| // revisited. |
| if (attr->isImplicit()) |
| return; |
| |
| auto &ctx = TC.Context; |
| auto lookupConformance = |
| LookUpConformanceInModule(D->getDeclContext()->getParentModule()); |
| |
| // If functions is marked as linear, you cannot have a custom VJP and/or |
| // a JVP. |
| if (attr->isLinear() && (attr->getVJP() || attr->getJVP())) { |
| diagnoseAndRemoveAttr(attr, |
| diag::attr_differentiable_no_vjp_or_jvp_when_linear); |
| return; |
| } |
| |
| AbstractFunctionDecl *original = dyn_cast<AbstractFunctionDecl>(D); |
| if (auto *asd = dyn_cast<AbstractStorageDecl>(D)) { |
| if (asd->getImplInfo().isSimpleStored() && |
| (attr->getJVP() || attr->getVJP())) { |
| diagnoseAndRemoveAttr(attr, |
| diag::differentiable_attr_stored_property_variable_unsupported); |
| return; |
| } |
| // When used directly on a storage decl (stored/computed property or |
| // subscript), the getter is currently inferred to be `@differentiable`. |
| // TODO(TF-129): Infer setter to also be `@differentiable` after |
| // differentiation supports inout parameters. This requires refactoring to |
| // handle multiple `original` functions (both getter and setter). |
| if (!asd->getDeclContext()->isModuleScopeContext()) { |
| original = asd->getSynthesizedAccessor(AccessorKind::Get); |
| } else { |
| original = nullptr; |
| } |
| } |
| // Setters are not yet supported. |
| // TODO(TF-129): Remove this when differentiation supports inout parameters. |
| if (auto *accessor = dyn_cast_or_null<AccessorDecl>(original)) |
| if (accessor->isSetter()) |
| original = nullptr; |
| |
| // Global immutable vars, for example, have no getter, and therefore trigger |
| // this. |
| if (!original) { |
| diagnoseAndRemoveAttr(attr, diag::invalid_decl_attribute, attr); |
| return; |
| } |
| |
| TC.resolveDeclSignature(original); |
| auto *originalFnTy = original->getInterfaceType()->castTo<AnyFunctionType>(); |
| bool isMethod = original->hasImplicitSelfDecl(); |
| |
| // If the original function returns the empty tuple type, there's no output to |
| // differentiate from. |
| auto originalResultTy = originalFnTy->getResult(); |
| if (isMethod) |
| originalResultTy = originalResultTy->castTo<AnyFunctionType>()->getResult(); |
| if (originalResultTy->isEqual(ctx.TheEmptyTupleType)) { |
| TC.diagnose(attr->getLocation(), diag::differentiable_attr_void_result, |
| original->getFullName()) |
| .highlight(original->getSourceRange()); |
| attr->setInvalid(); |
| return; |
| } |
| |
| bool isOriginalProtocolRequirement = |
| isa<ProtocolDecl>(original->getDeclContext()) && |
| original->isProtocolRequirement(); |
| |
| bool isOriginalClassMember = |
| original->getDeclContext() && |
| original->getDeclContext()->getSelfClassDecl(); |
| |
| // Diagnose invalid class conditions. |
| if (isOriginalClassMember) { |
| // Class methods returning dynamic `Self` are not supported. |
| // (For class methods, dynamic `Self` is supported only as the single |
| // result - JVPs/VJPs would not type-check. |
| if (auto *originalFn = dyn_cast<FuncDecl>(original)) { |
| if (originalFn->hasDynamicSelfResult()) { |
| TC.diagnose(attr->getLocation(), |
| diag::differentiable_attr_class_member_no_dynamic_self); |
| attr->setInvalid(); |
| return; |
| } |
| } |
| |
| // TODO(TF-654): Class initializers are not yet supported. |
| // Extra JVP/VJP type calculation logic is necessary because classes have |
| // both allocators and initializers. |
| if (auto *initDecl = dyn_cast<ConstructorDecl>(original)) { |
| TC.diagnose(attr->getLocation(), |
| diag::differentiable_attr_class_init_not_yet_supported); |
| attr->setInvalid(); |
| return; |
| } |
| } |
| |
| // Start type-checking the arguments of the @differentiable attribute. This |
| // covers 'wrt:', 'jvp:', 'vjp:', and 'where', all of which are optional. |
| |
| // Handle 'where' clause, if it exists. |
| // - Resolve attribute where clause requirements and store in the attribute |
| // for serialization. |
| // - Compute generic signature for autodiff derivative functions based on |
| // the original function's generate signature and the attribute's where |
| // clause requirements. |
| GenericSignature whereClauseGenSig = GenericSignature(); |
| GenericEnvironment *whereClauseGenEnv = nullptr; |
| if (auto *whereClause = attr->getWhereClause()) { |
| // `@differentiable` attributes on protocol requirements do not support |
| // 'where' clauses. |
| if (isOriginalProtocolRequirement) { |
| TC.diagnose(attr->getLocation(), |
| diag::differentiable_attr_protocol_req_where_clause); |
| attr->setInvalid(); |
| return; |
| } |
| if (whereClause->getRequirements().empty()) { |
| // Where clause must not be empty. |
| TC.diagnose(attr->getLocation(), |
| diag::differentiable_attr_empty_where_clause); |
| attr->setInvalid(); |
| return; |
| } |
| |
| auto originalGenSig = original->getGenericSignature(); |
| if (!originalGenSig) { |
| // Attributes with where clauses can only be declared on |
| // generic functions. |
| TC.diagnose(attr->getLocation(), |
| diag::differentiable_attr_nongeneric_trailing_where, |
| original->getFullName()) |
| .highlight(whereClause->getSourceRange()); |
| attr->setInvalid(); |
| return; |
| } |
| |
| // Build a new generic signature for autodiff derivative functions. |
| GenericSignatureBuilder builder(ctx); |
| // Add the original function's generic signature. |
| builder.addGenericSignature(originalGenSig); |
| |
| using FloatingRequirementSource = |
| GenericSignatureBuilder::FloatingRequirementSource; |
| |
| bool errorOccurred = false; |
| WhereClauseOwner(original, attr).visitRequirements( |
| TypeResolutionStage::Structural, |
| [&](const Requirement &req, RequirementRepr *reqRepr) { |
| switch (req.getKind()) { |
| case RequirementKind::SameType: |
| case RequirementKind::Superclass: |
| case RequirementKind::Conformance: |
| break; |
| |
| // Layout requirements are not supported. |
| case RequirementKind::Layout: |
| TC.diagnose(attr->getLocation(), |
| diag::differentiable_attr_layout_req_unsupported) |
| .highlight(reqRepr->getSourceRange()); |
| errorOccurred = true; |
| return false; |
| } |
| |
| // Add requirement to generic signature builder. |
| builder.addRequirement(req, reqRepr, |
| FloatingRequirementSource::forExplicit(reqRepr), |
| nullptr, original->getModuleContext()); |
| return false; |
| }); |
| |
| if (errorOccurred) { |
| attr->setInvalid(); |
| return; |
| } |
| |
| // Compute generic signature and environment for autodiff associated |
| // functions. |
| whereClauseGenSig = std::move(builder).computeGenericSignature( |
| attr->getLocation(), /*allowConcreteGenericParams=*/true); |
| whereClauseGenEnv = whereClauseGenSig->getGenericEnvironment(); |
| // Store the resolved derivative generic signature in the attribute. |
| attr->setDerivativeGenericSignature(ctx, whereClauseGenSig); |
| } |
| |
| // Validate the 'wrt:' parameters. |
| |
| // Get the parsed wrt param indices, which have not yet been checked. |
| // This is defined for parsed attributes. |
| auto parsedWrtParams = attr->getParsedParameters(); |
| // Get checked wrt param indices. |
| // This is defined only for compiler-synthesized attributes. |
| auto *checkedWrtParamIndices = attr->getParameterIndices(); |
| |
| // Compute the derivative function type. |
| auto derivativeFnTy = originalFnTy; |
| if (whereClauseGenEnv) |
| derivativeFnTy = whereClauseGenEnv->mapTypeIntoContext(derivativeFnTy) |
| ->castTo<AnyFunctionType>(); |
| |
| // If checked wrt param indices are not specified, compute them. |
| if (!checkedWrtParamIndices) |
| checkedWrtParamIndices = |
| computeDifferentiationParameters(TC, parsedWrtParams, original, |
| whereClauseGenEnv, attr->getAttrName(), |
| attr->getLocation()); |
| if (!checkedWrtParamIndices) { |
| attr->setInvalid(); |
| return; |
| } |
| |
| // Check if differentiation parameter indices are valid. |
| if (checkDifferentiationParameters( |
| TC, original, checkedWrtParamIndices, derivativeFnTy, whereClauseGenEnv, |
| original->getModuleContext(), parsedWrtParams, attr->getLocation())) { |
| attr->setInvalid(); |
| return; |
| } |
| |
| // Set the checked differentiation parameter indices in the attribute. |
| attr->setParameterIndices(checkedWrtParamIndices); |
| |
| if (whereClauseGenEnv) |
| originalResultTy = |
| whereClauseGenEnv->mapTypeIntoContext(originalResultTy); |
| else |
| originalResultTy = original->mapTypeIntoContext(originalResultTy); |
| // Check that original function's result type conforms to `Differentiable`. |
| if (!conformsToDifferentiable(originalResultTy, original)) { |
| TC.diagnose(attr->getLocation(), |
| diag::differentiable_attr_result_not_differentiable, |
| originalResultTy); |
| attr->setInvalid(); |
| return; |
| } |
| |
| // `@differentiable` attributes on protocol requirements do not support |
| // JVP/VJP. |
| if (isOriginalProtocolRequirement && (attr->getJVP() || attr->getVJP())) { |
| TC.diagnose(attr->getLocation(), |
| diag::differentiable_attr_protocol_req_assoc_func); |
| attr->setInvalid(); |
| return; |
| } |
| |
| // Resolve the JVP declaration, if it exists. |
| if (attr->getJVP()) { |
| AnyFunctionType *expectedJVPFnTy = |
| originalFnTy->getAutoDiffDerivativeFunctionType( |
| checkedWrtParamIndices, /*resultIndex*/ 0, |
| AutoDiffDerivativeFunctionKind::JVP, lookupConformance, |
| whereClauseGenSig, /*makeSelfParamFirst*/ true); |
| |
| auto isValidJVP = [&](FuncDecl *jvpCandidate) { |
| TC.validateDecl(jvpCandidate); |
| return checkFunctionSignature( |
| cast<AnyFunctionType>(expectedJVPFnTy->getCanonicalType()), |
| jvpCandidate->getInterfaceType()->getCanonicalType()); |
| }; |
| |
| FuncDecl *jvp = resolveAutoDiffDerivativeFunction( |
| TC, attr->getJVP().getValue(), original, expectedJVPFnTy, isValidJVP); |
| |
| if (!jvp) { |
| attr->setInvalid(); |
| return; |
| } |
| // Memorize the jvp reference in the attribute. |
| attr->setJVPFunction(jvp); |
| } |
| |
| // Resolve the VJP declaration, if it exists. |
| if (attr->getVJP()) { |
| AnyFunctionType *expectedVJPFnTy = |
| originalFnTy->getAutoDiffDerivativeFunctionType( |
| checkedWrtParamIndices, /*resultIndex*/ 0, |
| AutoDiffDerivativeFunctionKind::VJP, lookupConformance, |
| whereClauseGenSig, /*makeSelfParamFirst*/ true); |
| |
| auto isValidVJP = [&](FuncDecl *vjpCandidate) { |
| TC.validateDecl(vjpCandidate); |
| return checkFunctionSignature( |
| cast<AnyFunctionType>(expectedVJPFnTy->getCanonicalType()), |
| vjpCandidate->getInterfaceType()->getCanonicalType()); |
| }; |
| |
| FuncDecl *vjp = resolveAutoDiffDerivativeFunction( |
| TC, attr->getVJP().getValue(), original, expectedVJPFnTy, isValidVJP); |
| |
| if (!vjp) { |
| attr->setInvalid(); |
| return; |
| } |
| // Memorize the vjp reference in the attribute. |
| attr->setVJPFunction(vjp); |
| } |
| |
| if (auto *asd = dyn_cast<AbstractStorageDecl>(D)) { |
| // Remove `@differentiable` attribute from storage declaration to prevent |
| // duplicate attribute registration during SILGen. |
| D->getAttrs().removeAttribute(attr); |
| // Transfer `@differentiable` attribute from storage declaration to |
| // getter accessor. |
| auto *newAttr = DifferentiableAttr::create( |
| ctx, /*implicit*/ true, attr->AtLoc, attr->getRange(), attr->isLinear(), |
| attr->getParameterIndices(), attr->getJVP(), attr->getVJP(), |
| attr->getDerivativeGenericSignature()); |
| newAttr->setJVPFunction(attr->getJVPFunction()); |
| newAttr->setVJPFunction(attr->getVJPFunction()); |
| auto insertion = ctx.DifferentiableAttrs.try_emplace( |
| {asd->getAccessor(AccessorKind::Get), newAttr->getParameterIndices()}, |
| newAttr); |
| // Valid `@differentiable` attributes are uniqued by their parameter |
| // indices. Reject duplicate attributes for the same decl and parameter |
| // indices pair. |
| if (!insertion.second) { |
| diagnoseAndRemoveAttr(attr, diag::differentiable_attr_duplicate); |
| TC.diagnose(insertion.first->getSecond()->getLocation(), |
| diag::differentiable_attr_duplicate_note); |
| return; |
| } |
| asd->getAccessor(AccessorKind::Get)->getAttrs().add(newAttr); |
| return; |
| } |
| auto insertion = ctx.DifferentiableAttrs.try_emplace( |
| {D, attr->getParameterIndices()}, attr); |
| // `@differentiable` attributes are uniqued by their parameter indices. |
| // Reject duplicate attributes for the same decl and parameter indices pair. |
| if (!insertion.second && insertion.first->getSecond() != attr) { |
| diagnoseAndRemoveAttr(attr, diag::differentiable_attr_duplicate); |
| TC.diagnose(insertion.first->getSecond()->getLocation(), |
| diag::differentiable_attr_duplicate_note); |
| return; |
| } |
| } |
| |
| // SWIFT_ENABLE_TENSORFLOW |
| void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) { |
| auto &ctx = TC.Context; |
| FuncDecl *derivative = dyn_cast<FuncDecl>(D); |
| auto lookupConformance = |
| LookUpConformanceInModule(D->getDeclContext()->getParentModule()); |
| auto original = attr->getOriginal(); |
| |
| auto *derivativeInterfaceType = derivative->getInterfaceType() |
| ->castTo<AnyFunctionType>(); |
| |
| // Perform preliminary derivative checks. |
| |
| // The result type should be a two-element tuple. |
| // Either a value and pullback: |
| // (value: R, pullback: (R.TangentVector) -> (T.TangentVector...) |
| // Or a value and differential: |
| // (value: R, differential: (T.TangentVector...) -> (R.TangentVector) |
| auto derivativeResultType = derivative->getResultInterfaceType(); |
| auto derivativeResultTupleType = derivativeResultType->getAs<TupleType>(); |
| if (!derivativeResultTupleType || |
| derivativeResultTupleType->getNumElements() != 2) { |
| TC.diagnose(attr->getLocation(), |
| diag::differentiating_attr_expected_result_tuple); |
| attr->setInvalid(); |
| return; |
| } |
| auto valueResultElt = derivativeResultTupleType->getElement(0); |
| auto funcResultElt = derivativeResultTupleType->getElement(1); |
| // Get derivative kind and derivative function identifier. |
| AutoDiffDerivativeFunctionKind kind; |
| if (valueResultElt.getName().str() != "value") { |
| TC.diagnose(attr->getLocation(), |
| diag::differentiating_attr_invalid_result_tuple_value_label); |
| attr->setInvalid(); |
| return; |
| } |
| if (funcResultElt.getName().str() == "differential") { |
| kind = AutoDiffDerivativeFunctionKind::JVP; |
| } else if (funcResultElt.getName().str() == "pullback") { |
| kind = AutoDiffDerivativeFunctionKind::VJP; |
| } else { |
| TC.diagnose(attr->getLocation(), |
| diag::differentiating_attr_invalid_result_tuple_func_label); |
| attr->setInvalid(); |
| return; |
| } |
| // `value: R` result tuple element must conform to `Differentiable`. |
| auto diffableProto = ctx.getProtocol(KnownProtocolKind::Differentiable); |
| auto valueResultType = valueResultElt.getType(); |
| if (valueResultType->hasTypeParameter()) |
| valueResultType = derivative->mapTypeIntoContext(valueResultType); |
| auto valueResultConf = TC.conformsToProtocol(valueResultType, diffableProto, |
| derivative->getDeclContext(), |
| None); |
| if (!valueResultConf) { |
| TC.diagnose(attr->getLocation(), |
| diag::differentiating_attr_result_value_not_differentiable, |
| valueResultElt.getType()); |
| attr->setInvalid(); |
| return; |
| } |
| |
| // Compute expected original function type and look up original function. |
| auto *originalFnType = |
| derivativeInterfaceType->getAutoDiffOriginalFunctionType(); |
| |
| std::function<bool(GenericSignature, GenericSignature)> |
| checkGenericSignatureSatisfied = |
| [&](GenericSignature source, GenericSignature target) { |
| // If target is null, then its requirements are satisfied. |
| if (!target) |
| return true; |
| // If source is null but target is not null, then target's |
| // requirements are not satisfied. |
| if (!source) |
| return false; |
| // Check if target's requirements are satisfied by source. |
| return TC.checkGenericArguments( |
| derivative, original.Loc.getBaseNameLoc(), |
| original.Loc.getBaseNameLoc(), Type(), |
| source->getGenericParams(), target->getRequirements(), |
| [](SubstitutableType *dependentType) { |
| return Type(dependentType); |
| }, lookupConformance, None) == RequirementCheckResult::Success; |
| }; |
| |
| auto isValidOriginal = [&](FuncDecl *originalCandidate) { |
| TC.validateDecl(originalCandidate); |
| return checkFunctionSignature( |
| cast<AnyFunctionType>(originalFnType->getCanonicalType()), |
| originalCandidate->getInterfaceType()->getCanonicalType(), |
| checkGenericSignatureSatisfied); |
| }; |
| |
| // TODO: Do not reuse incompatible `@differentiable` attribute diagnostics. |
| // Rename compatible diagnostics so that they're not attribute-specific. |
| auto overloadDiagnostic = [&]() { |
| TC.diagnose(original.Loc, diag::differentiating_attr_overload_not_found, |
| original.Name, originalFnType); |
| }; |
| auto ambiguousDiagnostic = [&]() { |
| TC.diagnose(original.Loc, |
| diag::differentiable_attr_ambiguous_function_identifier, |
| original.Name); |
| }; |
| auto notFunctionDiagnostic = [&]() { |
| TC.diagnose(original.Loc, diag::differentiable_attr_specified_not_function, |
| original.Name); |
| }; |
| std::function<void()> invalidTypeContextDiagnostic = [&]() { |
| TC.diagnose(original.Loc, |
| diag::differentiable_attr_function_not_same_type_context, |
| original.Name); |
| }; |
| |
| // Returns true if the derivative function and original function candidate are |
| // defined in compatible type contexts. If the derivative function and the |
| // original function candidate have different parents, return false. |
| std::function<bool(FuncDecl *)> hasValidTypeContext = [&](FuncDecl *func) { |
| // Check if both functions are top-level. |
| if (!derivative->getInnermostTypeContext() && |
| !func->getInnermostTypeContext()) |
| return true; |
| // Check if both functions are defined in the same type context. |
| if (auto typeCtx1 = derivative->getInnermostTypeContext()) |
| if (auto typeCtx2 = func->getInnermostTypeContext()) { |
| return typeCtx1->getSelfNominalTypeDecl() == |
| typeCtx2->getSelfNominalTypeDecl(); |
| } |
| return derivative->getParent() == func->getParent(); |
| }; |
| |
| auto lookupOptions = defaultMemberLookupOptions |
| | NameLookupFlags::IgnoreAccessControl; |
| auto derivativeTypeCtx = derivative->getInnermostTypeContext(); |
| if (!derivativeTypeCtx) derivativeTypeCtx = derivative->getParent(); |
| assert(derivativeTypeCtx); |
| |
| // Look up original function. |
| auto *originalFn = TC.lookupFuncDecl( |
| original.Name, original.Loc.getBaseNameLoc(), /*baseType*/ Type(), |
| derivativeTypeCtx, isValidOriginal, overloadDiagnostic, |
| ambiguousDiagnostic, notFunctionDiagnostic, lookupOptions, |
| hasValidTypeContext, invalidTypeContextDiagnostic); |
| if (!originalFn) { |
| attr->setInvalid(); |
| return; |
| } |
| attr->setOriginalFunction(originalFn); |
| |
| // Get checked wrt param indices. |
| auto *checkedWrtParamIndices = attr->getParameterIndices(); |
| |
| // Get the parsed wrt param indices, which have not yet been checked. |
| // This is defined for parsed attributes. |
| auto parsedWrtParams = attr->getParsedParameters(); |
| |
| // If checked wrt param indices are not specified, compute them. |
| if (!checkedWrtParamIndices) |
| checkedWrtParamIndices = |
| computeDifferentiationParameters(TC, parsedWrtParams, derivative, |
| derivative->getGenericEnvironment(), |
| attr->getAttrName(), |
| attr->getLocation()); |
| if (!checkedWrtParamIndices) { |
| attr->setInvalid(); |
| return; |
| } |
| |
| // Check if differentiation parameter indices are valid. |
| if (checkDifferentiationParameters( |
| TC, originalFn, checkedWrtParamIndices, originalFnType, |
| derivative->getGenericEnvironment(), derivative->getModuleContext(), |
| parsedWrtParams, attr->getLocation())) { |
| attr->setInvalid(); |
| return; |
| } |
| |
| // Set the checked differentiation parameter indices in the attribute. |
| attr->setParameterIndices(checkedWrtParamIndices); |
| |
| // Gather differentiation parameters. |
| SmallVector<Type, 4> wrtParamTypes; |
| autodiff::getSubsetParameterTypes(checkedWrtParamIndices, originalFnType, |
| wrtParamTypes); |
| |
| auto diffParamElts = |
| map<SmallVector<TupleTypeElt, 4>>(wrtParamTypes, [&](Type paramType) { |
| if (paramType->hasTypeParameter()) |
| paramType = derivative->mapTypeIntoContext(paramType); |
| auto conf = TC.conformsToProtocol(paramType, diffableProto, derivative, |
| None); |
| assert(conf && |
| "Expected checked parameter to conform to `Differentiable`"); |
| auto paramAssocType = conf->getTypeWitnessByName( |
| paramType, ctx.Id_TangentVector); |
| return TupleTypeElt(paramAssocType); |
| }); |
| |
| // Check differential/pullback type. |
| // Get vector type: the associated type of the value result type. |
| auto vectorTy = valueResultConf->getTypeWitnessByName( |
| valueResultType, ctx.Id_TangentVector); |
| |
| // Compute expected differential/pullback type. |
| auto funcEltType = funcResultElt.getType(); |
| Type expectedFuncEltType; |
| if (kind == AutoDiffDerivativeFunctionKind::JVP) { |
| auto diffParams = map<SmallVector<AnyFunctionType::Param, 4>>( |
| diffParamElts, [&](TupleTypeElt elt) { |
| return AnyFunctionType::Param(elt.getType()); |
| }); |
| expectedFuncEltType = FunctionType::get(diffParams, vectorTy); |
| } else { |
| expectedFuncEltType = FunctionType::get({AnyFunctionType::Param(vectorTy)}, |
| TupleType::get(diffParamElts, ctx)); |
| } |
| expectedFuncEltType = expectedFuncEltType->mapTypeOutOfContext(); |
| |
| // Check if differential/pullback type matches expected type. |
| if (!funcEltType->isEqual(expectedFuncEltType)) { |
| // Emit differential/pullback type mismatch error on attribute. |
| TC.diagnose(attr->getLocation(), |
| diag::differentiating_attr_result_func_type_mismatch, |
| funcResultElt.getName(), originalFn->getFullName()); |
| // Emit note with expected differential/pullback type on actual type |
| // location. |
| auto *tupleReturnTypeRepr = |
| cast<TupleTypeRepr>(derivative->getBodyResultTypeLoc().getTypeRepr()); |
| auto *funcEltTypeRepr = tupleReturnTypeRepr->getElementType(1); |
| TC.diagnose(funcEltTypeRepr->getStartLoc(), |
| diag::differentiating_attr_result_func_type_mismatch_note, |
| funcResultElt.getName(), expectedFuncEltType) |
| .highlight(funcEltTypeRepr->getSourceRange()); |
| // Emit note showing original function location, if possible. |
| if (originalFn->getLoc().isValid()) |
| TC.diagnose(originalFn->getLoc(), |
| diag::differentiating_attr_result_func_original_note, |
| originalFn->getFullName()); |
| attr->setInvalid(); |
| return; |
| } |
| |
| // Reject different-file retroactive derivatives. |
| // TODO(TF-136): Full support for cross-file/cross-module retroactive |
| // differentiability will require SIL differentiability witnesses and lots of |
| // plumbing. |
| if (originalFn->getParentSourceFile() != derivative->getParentSourceFile()) { |
| diagnoseAndRemoveAttr( |
| attr, diag::differentiating_attr_not_in_same_file_as_original); |
| return; |
| } |
| |
| // Try to find a `@differentiable` attribute on the original function with the |
| // same differentiation parameters. |
| DifferentiableAttr *da = nullptr; |
| for (auto *cda : originalFn->getAttrs().getAttributes<DifferentiableAttr>()) |
| if (checkedWrtParamIndices == cda->getParameterIndices()) |
| da = const_cast<DifferentiableAttr *>(cda); |
| // If the original function does not have a `@differentiable` attribute with |
| // the same differentiation parameters, create one. |
| if (!da) { |
| da = DifferentiableAttr::create(ctx, /*implicit*/ true, attr->AtLoc, |
| attr->getRange(), attr->isLinear(), |
| checkedWrtParamIndices, /*jvp*/ None, |
| /*vjp*/ None, |
| derivative->getGenericSignature()); |
| switch (kind) { |
| case AutoDiffDerivativeFunctionKind::JVP: |
| da->setJVPFunction(derivative); |
| break; |
| case AutoDiffDerivativeFunctionKind::VJP: |
| da->setVJPFunction(derivative); |
| break; |
| } |
| auto insertion = ctx.DifferentiableAttrs.try_emplace( |
| {originalFn, checkedWrtParamIndices}, da); |
| // Valid `@differentiable` attributes are uniqued by their parameter |
| // indices. Reject duplicate attributes for the same decl and parameter |
| // indices pair. |
| if (!insertion.second && insertion.first->getSecond() != da) { |
| diagnoseAndRemoveAttr(da, diag::differentiable_attr_duplicate); |
| TC.diagnose(insertion.first->getSecond()->getLocation(), |
| diag::differentiable_attr_duplicate_note); |
| return; |
| } |
| originalFn->getAttrs().add(da); |
| return; |
| } |
| // If the original function has a `@differentiable` attribute with the same |
| // differentiation parameters, check if the `@differentiable` attribute |
| // already has a different registered derivative. If so, emit an error on the |
| // `@differentiating` attribute. Otherwise, register the derivative in the |
| // `@differentiable` attribute. |
| switch (kind) { |
| case AutoDiffDerivativeFunctionKind::JVP: |
| // If there's a different registered derivative, emit an error. |
| if ((da->getJVP() && |
| da->getJVP()->Name.getBaseName() != derivative->getBaseName()) || |
| (da->getJVPFunction() && da->getJVPFunction() != derivative)) { |
| diagnoseAndRemoveAttr( |
| attr, diag::differentiating_attr_original_already_has_derivative, |
| originalFn->getFullName()); |
| return; |
| } |
| da->setJVPFunction(derivative); |
| break; |
| case AutoDiffDerivativeFunctionKind::VJP: |
| // If there's a different registered derivative, emit an error. |
| if ((da->getVJP() && |
| da->getVJP()->Name.getBaseName() != derivative->getBaseName()) || |
| (da->getVJPFunction() && da->getVJPFunction() != derivative)) { |
| diagnoseAndRemoveAttr( |
| attr, diag::differentiating_attr_original_already_has_derivative, |
| originalFn->getFullName()); |
| return; |
| } |
| da->setVJPFunction(derivative); |
| break; |
| } |
| } |
| |
| /// Pushes the subset's parameter's types to `paramTypes`, in the order in |
| /// which they appear in the function type. For example, |
| /// |
| /// functionType = (A, B, C) -> R |
| /// if "A" and "C" are in the set, |
| /// ==> pushes {A, C} to `paramTypes`. |
| /// |
| void getIndexSubsetParameterTypes( |
| IndexSubset *indexSubset, AnyFunctionType *functionType, |
| SmallVectorImpl<Type> ¶mTypes, bool isCurried) { |
| auto *fnTy = functionType; |
| if (isCurried) { |
| fnTy = fnTy->getResult()->getAs<AnyFunctionType>(); |
| } |
| |
| for (unsigned paramIndex : range(fnTy->getNumParams())) { |
| if ((paramIndex < indexSubset->getCapacity()) && |
| indexSubset->contains(paramIndex)) { |
| paramTypes.push_back(fnTy->getParams()[paramIndex].getPlainType()); |
| } |
| } |
| } |
| |
| void AttributeChecker::visitTransposingAttr(TransposingAttr *attr) { |
| auto &ctx = TC.Context; |
| auto *transpose = dyn_cast<FuncDecl>(D); |
| auto lookupConformance = |
| LookUpConformanceInModule(D->getDeclContext()->getParentModule()); |
| auto original = attr->getOriginal(); |
| TC.resolveDeclSignature(transpose); |
| auto *transposeInterfaceType = transpose->getInterfaceType() |
| ->castTo<AnyFunctionType>(); |
| |
| // Get checked wrt param indices. |
| auto *wrtParamIndices = attr->getParameterIndexSubset(); |
| |
| // Get the parsed wrt param indices, which have not yet been checked. |
| // This is defined for parsed attributes. |
| auto parsedWrtParams = attr->getParsedParameters(); |
| |
| bool wrtSelf = false; |
| if (!parsedWrtParams.empty()) |
| wrtSelf = parsedWrtParams.front().getKind() == |
| ParsedAutoDiffParameter::Kind::Self; |
| |
| // If checked wrt param indices are not specified, compute them. |
| bool isCurried = transposeInterfaceType->getResult()->is<AnyFunctionType>(); |
| if (!wrtParamIndices) |
| wrtParamIndices = computeTransposingParameters( |
| TC, parsedWrtParams, transpose, isCurried, |
| transpose->getGenericEnvironment(), |
| attr->getLocation()); |
| if (!wrtParamIndices) { |
| D->getAttrs().removeAttribute(attr); |
| attr->setInvalid(); |
| return; |
| } |
| |
| // Diagnose empty parameter indices. This occurs when no `wrt` clause is |
| // declared and no differentiation parameters can be inferred. |
| if (wrtParamIndices->isEmpty()) { |
| TC.diagnose(attr->getLocation(), |
| diag::diff_params_clause_no_inferred_parameters); |
| D->getAttrs().removeAttribute(attr); |
| attr->setInvalid(); |
| return; |
| } |
| |
| auto *expectedOriginalFnType = |
| transposeInterfaceType->getTransposeOriginalFunctionType( |
| attr, wrtParamIndices, wrtSelf); |
| |
| // `R` result type must conform to `Differentiable`. |
| auto expectedOriginalResultType = expectedOriginalFnType->getResult(); |
| if (isCurried) { |
| expectedOriginalResultType = transpose->mapTypeIntoContext( |
| expectedOriginalResultType->getAs<AnyFunctionType>()->getResult()); |
| } |
| if (expectedOriginalResultType->hasTypeParameter()) |
| expectedOriginalResultType = transpose->mapTypeIntoContext( |
| expectedOriginalResultType); |
| auto diffableProto = ctx.getProtocol(KnownProtocolKind::Differentiable); |
| auto valueResultConf = TC.conformsToProtocol(expectedOriginalResultType, |
| diffableProto, transpose->getDeclContext(), None); |
| |
| if (!valueResultConf) { |
| TC.diagnose(attr->getLocation(), |
| diag::transposing_attr_result_value_not_differentiable, |
| expectedOriginalFnType); |
| D->getAttrs().removeAttribute(attr); |
| attr->setInvalid(); |
| return; |
| } |
| |
| // Compute expected original function type. |
| std::function<bool(GenericSignature, GenericSignature)> |
| checkGenericSignatureSatisfied = [&](GenericSignature source, |
| GenericSignature target) { |
| // If target is null, then its requirements are satisfied. |
| if (!target) |
| return true; |
| // If source is null but target is not null, then target's |
| // requirements are not satisfied. |
| if (!source) |
| return false; |
| // Check if target's requirements are satisfied by source. |
| return TC.checkGenericArguments( |
| transpose, original.Loc.getBaseNameLoc(), |
| original.Loc.getBaseNameLoc(), Type(), |
| source->getGenericParams(), target->getRequirements(), |
| [](SubstitutableType *dependentType) { |
| return Type(dependentType); |
| }, |
| lookupConformance, None) == RequirementCheckResult::Success; |
| }; |
| |
| auto isValidOriginal = [&](FuncDecl *originalCandidate) { |
| TC.validateDecl(originalCandidate); |
| return checkFunctionSignature( |
| cast<AnyFunctionType>(expectedOriginalFnType->getCanonicalType()), |
| originalCandidate->getInterfaceType()->getCanonicalType(), |
| checkGenericSignatureSatisfied); |
| }; |
| |
| // TODO: Do not reuse incompatible `@differentiable` attribute diagnostics. |
| // Rename compatible diagnostics so that they're not attribute-specific. |
| auto overloadDiagnostic = [&]() { |
| TC.diagnose(original.Loc, diag::differentiating_attr_overload_not_found, |
| original.Name, expectedOriginalFnType); |
| }; |
| auto ambiguousDiagnostic = [&]() { |
| TC.diagnose(original.Loc, |
| diag::differentiable_attr_ambiguous_function_identifier, |
| original.Name); |
| }; |
| auto notFunctionDiagnostic = [&]() { |
| TC.diagnose(original.Loc, diag::differentiable_attr_specified_not_function, |
| original.Name); |
| }; |
| std::function<void()> invalidTypeContextDiagnostic = [&]() { |
| TC.diagnose(original.Loc, |
| diag::differentiable_attr_function_not_same_type_context, |
| original.Name); |
| }; |
| |
| // Returns true if the derivative function and original function candidate are |
| // defined in compatible type contexts. If the derivative function and the |
| // original function candidate have different parents, return false. |
| std::function<bool(FuncDecl *)> hasValidTypeContext = [&](FuncDecl *func) { |
| return true; |
| }; |
| |
| auto typeRes = TypeResolution::forContextual(transpose->getDeclContext()); |
| auto baseType = Type(); |
| if (attr->getBaseType()) |
| baseType = typeRes.resolveType(attr->getBaseType(), None); |
| auto lookupOptions = (attr->getBaseType() ? defaultMemberLookupOptions |
| : defaultUnqualifiedLookupOptions) | |
| NameLookupFlags::IgnoreAccessControl; |
| auto transposeTypeCtx = transpose->getInnermostTypeContext(); |
| if (!transposeTypeCtx) transposeTypeCtx = transpose->getParent(); |
| assert(transposeTypeCtx); |
| |
| // Look up original function. |
| auto funcLoc = original.Loc.getBaseNameLoc(); |
| if (attr->getBaseType()) |
| funcLoc = attr->getBaseType()->getLoc(); |
| auto *originalFn = TC.lookupFuncDecl( |
| original.Name, funcLoc, baseType, transposeTypeCtx, isValidOriginal, |
| overloadDiagnostic, ambiguousDiagnostic, notFunctionDiagnostic, |
| lookupOptions, hasValidTypeContext, invalidTypeContextDiagnostic); |
| |
| if (!originalFn) { |
| D->getAttrs().removeAttribute(attr); |
| attr->setInvalid(); |
| return; |
| } |
| |
| attr->setOriginalFunction(originalFn); |
| |
| // Gather differentiation parameters. |
| // Differentiation parameters are with respect to the original function. |
| SmallVector<Type, 4> wrtParamTypes; |
| getIndexSubsetParameterTypes(wrtParamIndices, expectedOriginalFnType, |
| wrtParamTypes, isCurried); |
| |
| // Check if differentiation parameter indices are valid. |
| if (checkTransposingParameters(TC, originalFn, wrtParamTypes, |
| transpose->getGenericEnvironment(), |
| transpose->getModuleContext(), parsedWrtParams, |
| attr->getLocation())) { |
| D->getAttrs().removeAttribute(attr); |
| attr->setInvalid(); |
| return; |
| } |
| |
| // Set the checked differentiation parameter indices in the attribute. |
| attr->setParameterIndices(wrtParamIndices); |
| |
| // Check if original function type matches expected original function type |
| // we computed. |
| std::function<bool(GenericSignature, GenericSignature)> |
| genericComparison = |
| [&](GenericSignature a, GenericSignature b) { return a.getPointer() == b.getPointer(); }; |
| if (!checkFunctionSignature( |
| cast<AnyFunctionType>(expectedOriginalFnType->getCanonicalType()), |
| originalFn->getInterfaceType()->getCanonicalType(), |
| genericComparison)) { |
| // Emit differential/pullback type mismatch error on attribute. |
| TC.diagnose(attr->getLocation(), |
| diag::differentiating_attr_result_func_type_mismatch, |
| transpose->getName(), originalFn->getName()); |
| // Emit note with expected differential/pullback type on actual type |
| // location. |
| auto *tupleReturnTypeRepr = |
| cast<TupleTypeRepr>(transpose->getBodyResultTypeLoc().getTypeRepr()); |
| auto *funcEltTypeRepr = tupleReturnTypeRepr->getElementType(1); |
| TC.diagnose(funcEltTypeRepr->getStartLoc(), |
| diag::differentiating_attr_result_func_type_mismatch_note, |
| transpose->getName(), expectedOriginalFnType) |
| .highlight(funcEltTypeRepr->getSourceRange()); |
| // Emit note showing original function location, if possible. |
| if (originalFn->getLoc().isValid()) |
| TC.diagnose(originalFn->getLoc(), |
| diag::differentiating_attr_result_func_original_note, |
| originalFn->getFullName()); |
| D->getAttrs().removeAttribute(attr); |
| attr->setInvalid(); |
| return; |
| } |
| } |
| |
| static bool |
| compilerEvaluableAllowedInExtensionDecl(ExtensionDecl *extensionDecl) { |
| auto extendedTypeKind = extensionDecl->getExtendedType()->getKind(); |
| return extendedTypeKind == TypeKind::Enum || |
| extendedTypeKind == TypeKind::Protocol || |
| extendedTypeKind == TypeKind::Struct || |
| extendedTypeKind == TypeKind::BoundGenericEnum || |
| extendedTypeKind == TypeKind::BoundGenericStruct; |
| } |
| |
| void AttributeChecker::visitCompilerEvaluableAttr(CompilerEvaluableAttr *attr) { |
| // Check that the function is defined in an allowed context. |
| // TODO(marcrasi): In many cases, we can probably generate a more informative |
| // error message than just saying that it's "not allowed here". (Like "not |
| // allowed in a class [point at the class decl], put it at the top level or in |
| // a struct instead"). |
| auto declContext = D->getDeclContext(); |
| switch (declContext->getContextKind()) { |
| case DeclContextKind::AbstractFunctionDecl: |
| // Nested functions are okay. |
| break; |
| case DeclContextKind::ExtensionDecl: |
| // Enum, Protocol, and Struct extensions are okay. For Enums and Structs |
| // extensions, the extended type must be compiler-representable. |
| // TODO(marcrasi): Check that the extended type is compiler-representable. |
| if (!compilerEvaluableAllowedInExtensionDecl( |
| cast<ExtensionDecl>(declContext))) { |
| TC.diagnose(D, diag::compiler_evaluable_bad_context); |
| attr->setInvalid(); |
| return; |
| } |
| break; |
| case DeclContextKind::FileUnit: |
| // Top level functions are okay. |
| break; |
| case DeclContextKind::GenericTypeDecl: |
| switch (cast<GenericTypeDecl>(declContext)->getKind()) { |
| case DeclKind::Enum: |
| // Enums are okay, if they are compiler-representable. |
| // TODO(marcrasi): Check that it's compiler-representable. |
| break; |
| case DeclKind::Struct: |
| // Structs are okay, if they are compiler-representable. |
| // TODO(marcrasi): Check that it's compiler-representable. |
| break; |
| default: |
| TC.diagnose(D, diag::compiler_evaluable_bad_context); |
| attr->setInvalid(); |
| return; |
| } |
| break; |
| default: |
| TC.diagnose(D, diag::compiler_evaluable_bad_context); |
| attr->setInvalid(); |
| return; |
| } |
| |
| // Check that the signature only has allowed types. |
| // TODO(marcrasi): Do this. |
| |
| // For @compilerEvaluable to be truly valid, the function body must also |
| // follow certain rules. We can only check these rules after the body is type |
| // checked, and it's not type checked yet, so we check these rules later in |
| // TypeChecker::checkFunctionBodyCompilerEvaluable(). |
| } |
| |
| // SWIFT_ENABLE_TENSORFLOW |
| void AttributeChecker::visitNoDerivativeAttr(NoDerivativeAttr *attr) { |
| auto *vd = dyn_cast<VarDecl>(D); |
| if (attr->isImplicit()) |
| return; |
| if (!vd || vd->isStatic()) { |
| diagnoseAndRemoveAttr(attr, |
| diag::noderivative_only_on_differentiable_struct_or_class_fields); |
| return; |
| } |
| auto *nominal = vd->getDeclContext()->getSelfNominalTypeDecl(); |
| if (!nominal || (!isa<StructDecl>(nominal) && !isa<ClassDecl>(nominal))) { |
| diagnoseAndRemoveAttr(attr, |
| diag::noderivative_only_on_differentiable_struct_or_class_fields); |
| return; |
| } |
| if (!conformsToDifferentiable( |
| nominal->getDeclaredInterfaceType(), |
| nominal->getDeclContext())) { |
| diagnoseAndRemoveAttr(attr, |
| diag::noderivative_only_on_differentiable_struct_or_class_fields); |
| return; |
| } |
| } |
| |
| void |
| AttributeChecker::visitImplementationOnlyAttr(ImplementationOnlyAttr *attr) { |
| if (isa<ImportDecl>(D)) { |
| // These are handled elsewhere. |
| return; |
| } |
| |
| auto *VD = cast<ValueDecl>(D); |
| auto *overridden = VD->getOverriddenDecl(); |
| if (!overridden) { |
| diagnoseAndRemoveAttr(attr, diag::implementation_only_decl_non_override); |
| return; |
| } |
| |
| // Check if VD has the exact same type as what it overrides. |
| // Note: This is specifically not using `swift::getMemberTypeForComparison` |
| // because that erases more information than we want, like `throws`-ness. |
| auto baseInterfaceTy = overridden->getInterfaceType(); |
| auto derivedInterfaceTy = VD->getInterfaceType(); |
| |
| auto selfInterfaceTy = VD->getDeclContext()->getDeclaredInterfaceType(); |
| |
| auto overrideInterfaceTy = |
| selfInterfaceTy->adjustSuperclassMemberDeclType(overridden, VD, |
| baseInterfaceTy); |
| |
| if (isa<AbstractFunctionDecl>(VD)) { |
| // Drop the 'Self' parameter. |
| // FIXME: The real effect here, though, is dropping the generic signature. |
| // This should be okay because it should already be checked as part of |
| // making an override, but that isn't actually the case as of this writing, |
| // and it's kind of suspect anyway. |
| derivedInterfaceTy = |
| derivedInterfaceTy->castTo<AnyFunctionType>()->getResult(); |
| overrideInterfaceTy = |
| overrideInterfaceTy->castTo<AnyFunctionType>()->getResult(); |
| } else if (isa<SubscriptDecl>(VD)) { |
| // For subscripts, we don't have a 'Self' type, but turn it |
| // into a monomorphic function type. |
| // FIXME: does this actually make sense, though? |
| auto derivedInterfaceFuncTy = derivedInterfaceTy->castTo<AnyFunctionType>(); |
| derivedInterfaceTy = |
| FunctionType::get(derivedInterfaceFuncTy->getParams(), |
| derivedInterfaceFuncTy->getResult()); |
| auto overrideInterfaceFuncTy = |
| overrideInterfaceTy->castTo<AnyFunctionType>(); |
| overrideInterfaceTy = |
| FunctionType::get(overrideInterfaceFuncTy->getParams(), |
| overrideInterfaceFuncTy->getResult()); |
| } |
| |
| if (!derivedInterfaceTy->isEqual(overrideInterfaceTy)) { |
| TC.diagnose(VD, diag::implementation_only_override_changed_type, |
| overrideInterfaceTy); |
| TC.diagnose(overridden, diag::overridden_here); |
| return; |
| } |
| |
| // FIXME: When compiling without library evolution enabled, this should also |
| // check whether VD or any of its accessors need a new vtable entry, even if |
| // it won't necessarily be able to say why. |
| } |
| |
| void TypeChecker::checkParameterAttributes(ParameterList *params) { |
| for (auto param: *params) { |
| checkDeclAttributes(param); |
| } |
| } |
| |
| Type TypeChecker::checkReferenceOwnershipAttr(VarDecl *var, Type type, |
| ReferenceOwnershipAttr *attr) { |
| auto *dc = var->getDeclContext(); |
| |
| // Don't check ownership attribute if the type is invalid. |
| if (attr->isInvalid() || type->is<ErrorType>()) |
| return type; |
| |
| auto ownershipKind = attr->get(); |
| |
| // A weak variable must have type R? or R! for some ownership-capable type R. |
| auto underlyingType = type->getOptionalObjectType(); |
| auto isOptional = bool(underlyingType); |
| |
| switch (optionalityOf(ownershipKind)) { |
| case ReferenceOwnershipOptionality::Disallowed: |
| if (isOptional) { |
| diagnose(var->getStartLoc(), diag::invalid_ownership_with_optional, |
| ownershipKind) |
| .fixItReplace(attr->getRange(), "weak"); |
| attr->setInvalid(); |
| } |
| break; |
| case ReferenceOwnershipOptionality::Allowed: |
| break; |
| case ReferenceOwnershipOptionality::Required: |
| if (var->isLet()) { |
| diagnose(var->getStartLoc(), diag::invalid_ownership_is_let, |
| ownershipKind); |
| attr->setInvalid(); |
| } |
| |
| if (!isOptional) { |
| attr->setInvalid(); |
| |
| // @IBOutlet has its own diagnostic when the property type is |
| // non-optional. |
| if (var->getAttrs().hasAttribute<IBOutletAttr>()) |
| break; |
| |
| auto diag = diagnose(var->getStartLoc(), |
| diag::invalid_ownership_not_optional, |
| ownershipKind, |
| OptionalType::get(type)); |
| auto typeRange = var->getTypeSourceRangeForDiagnostics(); |
| if (type->hasSimpleTypeRepr()) { |
| diag.fixItInsertAfter(typeRange.End, "?"); |
| } else { |
| diag.fixItInsert(typeRange.Start, "(") |
| .fixItInsertAfter(typeRange.End, ")?"); |
| } |
| } |
| break; |
| } |
| |
| if (!underlyingType) |
| underlyingType = type; |
| |
| auto sig = var->getDeclContext()->getGenericSignatureOfContext(); |
| if (!underlyingType->allowsOwnership(sig.getPointer())) { |
| auto D = diag::invalid_ownership_type; |
| |
| if (underlyingType->isExistentialType() || |
| underlyingType->isTypeParameter()) { |
| // Suggest the possibility of adding a class bound. |
| D = diag::invalid_ownership_protocol_type; |
| } |
| |
| diagnose(var->getStartLoc(), D, ownershipKind, underlyingType); |
| attr->setInvalid(); |
| } |
| |
| ClassDecl *underlyingClass = underlyingType->getClassOrBoundGenericClass(); |
| if (underlyingClass && underlyingClass->isIncompatibleWithWeakReferences()) { |
| diagnose(attr->getLocation(), |
| diag::invalid_ownership_incompatible_class, |
| underlyingType, ownershipKind) |
| .fixItRemove(attr->getRange()); |
| attr->setInvalid(); |
| } |
| |
| auto PDC = dyn_cast<ProtocolDecl>(dc); |
| if (PDC && !PDC->isObjC()) { |
| // Ownership does not make sense in protocols, except for "weak" on |
| // properties of Objective-C protocols. |
| auto D = Context.isSwiftVersionAtLeast(5) |
| ? diag::ownership_invalid_in_protocols |
| : diag::ownership_invalid_in_protocols_compat_warning; |
| diagnose(attr->getLocation(), D, ownershipKind) |
| .fixItRemove(attr->getRange()); |
| attr->setInvalid(); |
| } |
| |
| if (attr->isInvalid()) |
| return type; |
| |
| // Change the type to the appropriate reference storage type. |
| return ReferenceStorageType::get(type, ownershipKind, Context); |
| } |
| |
| Optional<Diag<>> |
| TypeChecker::diagnosticIfDeclCannotBePotentiallyUnavailable(const Decl *D) { |
| DeclContext *DC = D->getDeclContext(); |
| // Do not permit potential availability of script-mode global variables; |
| // their initializer expression is not lazily evaluated, so this would |
| // not be safe. |
| if (isa<VarDecl>(D) && DC->isModuleScopeContext() && |
| DC->getParentSourceFile()->isScriptMode()) { |
| return diag::availability_global_script_no_potential; |
| } |
| |
| // For now, we don't allow stored properties to be potentially unavailable. |
| // We will want to support these eventually, but we haven't figured out how |
| // this will interact with Definite Initialization, deinitializers and |
| // resilience yet. |
| if (auto *VD = dyn_cast<VarDecl>(D)) { |
| // Globals and statics are lazily initialized, so they are safe |
| // for potential unavailability. Note that if D is a global in script |
| // mode (which are not lazy) then we will already have returned |
| // a diagnosis above. |
| bool lazilyInitializedStored = VD->isStatic() || |
| VD->getAttrs().hasAttribute<LazyAttr>() || |
| DC->isModuleScopeContext(); |
| |
| if (VD->hasStorage() && !lazilyInitializedStored) { |
| return diag::availability_stored_property_no_potential; |
| } |
| } |
| |
| return None; |
| } |
| |
| static bool shouldBlockImplicitDynamic(Decl *D) { |
| if (D->getAttrs().hasAttribute<NonObjCAttr>() || |
| D->getAttrs().hasAttribute<SILGenNameAttr>() || |
| D->getAttrs().hasAttribute<TransparentAttr>() || |
| D->getAttrs().hasAttribute<InlinableAttr>()) |
| return true; |
| return false; |
| } |
| void TypeChecker::addImplicitDynamicAttribute(Decl *D) { |
| if (!D->getModuleContext()->isImplicitDynamicEnabled()) |
| return; |
| |
| // Add the attribute if the decl kind allows it and it is not an accessor |
| // decl. Accessor decls should always infer the var/subscript's attribute. |
| if (!DeclAttribute::canAttributeAppearOnDecl(DAK_Dynamic, D) || |
| isa<AccessorDecl>(D)) |
| return; |
| |
| // Don't add dynamic if decl is inlinable or tranparent. |
| if (shouldBlockImplicitDynamic(D)) |
| return; |
| |
| if (auto *FD = dyn_cast<FuncDecl>(D)) { |
| // Don't add dynamic to defer bodies. |
| if (FD->isDeferBody()) |
| return; |
| // Don't add dynamic to functions with a cdecl. |
| if (FD->getAttrs().hasAttribute<CDeclAttr>()) |
| return; |
| // Don't add dynamic to local function definitions. |
| if (!FD->getDeclContext()->isTypeContext() && |
| FD->getDeclContext()->isLocalContext()) |
| return; |
| } |
| |
| // Don't add dynamic if accessor is inlinable or transparent. |
| if (auto *asd = dyn_cast<AbstractStorageDecl>(D)) { |
| bool blocked = false; |
| asd->visitParsedAccessors([&](AccessorDecl *accessor) { |
| blocked |= shouldBlockImplicitDynamic(accessor); |
| }); |
| if (blocked) |
| return; |
| } |
| |
| if (auto *VD = dyn_cast<VarDecl>(D)) { |
| // Don't turn stored into computed properties. This could conflict with |
| // exclusivity checking. |
| // If there is a didSet or willSet function we allow dynamic replacement. |
| if (VD->hasStorage() && |
| !VD->getParsedAccessor(AccessorKind::DidSet) && |
| !VD->getParsedAccessor(AccessorKind::WillSet)) |
| return; |
| // Don't add dynamic to local variables. |
| if (VD->getDeclContext()->isLocalContext()) |
| return; |
| // Don't add to implicit variables. |
| if (VD->isImplicit()) |
| return; |
| } |
| |
| if (!D->getAttrs().hasAttribute<DynamicAttr>() && |
| !D->getAttrs().hasAttribute<DynamicReplacementAttr>()) { |
| auto attr = new (D->getASTContext()) DynamicAttr(/*implicit=*/true); |
| D->getAttrs().add(attr); |
| } |
| } |