//===--- DerivedConformanceDifferentiable.cpp - Derived Differentiable ----===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// This file implements explicit derivation of the Differentiable protocol for
// struct and class types.
//
//===----------------------------------------------------------------------===//

#include "CodeSynthesis.h"
#include "TypeChecker.h"
#include "TypeCheckType.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "swift/AST/AutoDiff.h"
#include "swift/AST/Decl.h"
#include "swift/AST/Expr.h"
#include "swift/AST/Module.h"
#include "swift/AST/ParameterList.h"
#include "swift/AST/Pattern.h"
#include "swift/AST/PropertyWrappers.h"
#include "swift/AST/ProtocolConformance.h"
#include "swift/AST/Stmt.h"
#include "swift/AST/TypeCheckRequests.h"
#include "swift/AST/Types.h"
#include "DerivedConformances.h"

using namespace swift;

/// Determine whether `move(along:)` may be called on a stored property whose
/// type conforms to `Differentiable`.
///
/// A property introduced with `var` always qualifies. For any other
/// introducer, the property qualifies only when the `move(along:)` witness
/// found through \p diffableConformance is non-mutating; if no witness is
/// found, the answer is false.
static bool canInvokeMoveAlongOnProperty(
    VarDecl *vd, ProtocolConformanceRef diffableConformance) {
  assert(diffableConformance && "Property must conform to 'Differentiable'");
  const bool isVarProperty = vd->getIntroducer() == VarDecl::Introducer::Var;
  if (!isVarProperty) {
    // Not a `var`: the call is only legal if the witness does not need to
    // mutate the receiver.
    auto propertyType = vd->getInterfaceType();
    auto &ctx = vd->getASTContext();
    auto moveWitness = diffableConformance.getWitnessByName(
        propertyType, DeclName(ctx, ctx.Id_move, {ctx.Id_along}));
    if (moveWitness)
      return cast<FuncDecl>(moveWitness.getDecl())->isNonMutating();
    return false;
  }
  // `var` properties are mutable, so `move(along:)` can always be invoked.
  return true;
}

/// Append to \p result the stored properties of \p nominal that take part in
/// differentiation.
///
/// A stored property is left out when any of the following holds:
/// - it backs a property wrapper whose original wrapped property is not
///   settable from \p DC (otherwise the original wrapped property is used in
///   its place);
/// - it carries the `@noDerivative` attribute;
/// - its interface type contains an error;
/// - its type, mapped into \p DC, does not conform to `Differentiable`;
/// - \p includeLetPropertiesWithNonmutatingMoveAlong is false and
///   `canInvokeMoveAlongOnProperty` rejects it.
///
/// When \p includeLetPropertiesWithNonmutatingMoveAlong is true the last check
/// is not performed at all.
static void
getStoredPropertiesForDifferentiation(
    NominalTypeDecl *nominal, DeclContext *DC,
    SmallVectorImpl<VarDecl *> &result,
    bool includeLetPropertiesWithNonmutatingMoveAlong = false) {
  auto &ctx = nominal->getASTContext();
  auto *differentiableProto =
      ctx.getProtocol(KnownProtocolKind::Differentiable);
  const bool mustSupportMoveAlong =
      !includeLetPropertiesWithNonmutatingMoveAlong;

  for (auto *storedProp : nominal->getStoredProperties()) {
    VarDecl *property = storedProp;
    // Look through property wrappers and work with the wrapped property.
    if (auto *wrapped = storedProp->getOriginalWrappedProperty()) {
      // An immutable wrapped property cannot be updated by a synthesized
      // `mutating func move(along:)`.
      if (!wrapped->isSettable(DC))
        continue;
      property = wrapped;
    }

    // `@noDerivative` and ill-typed properties never participate.
    if (property->getAttrs().hasAttribute<NoDerivativeAttr>() ||
        property->getInterfaceType()->hasError())
      continue;

    auto propertyType =
        DC->mapTypeIntoContext(property->getValueInterfaceType());
    auto diffConformance = TypeChecker::conformsToProtocol(
        propertyType, differentiableProto, nominal);
    if (!diffConformance)
      continue;

    // Unless the caller opted out, drop properties on which `move(along:)`
    // cannot be invoked (e.g. a `let` whose witness is mutating).
    if (mustSupportMoveAlong &&
        !canInvokeMoveAlongOnProperty(property, diffConformance))
      continue;

    result.push_back(property);
  }
}

/// Map a `ValueDecl` to the `StructDecl` it denotes, if any.
///
/// \returns \p v itself when it is a `StructDecl`; otherwise, when \p v is a
/// `TypeDecl`, the nominal declaration of its declared interface type if that
/// nominal is a struct; otherwise `nullptr`.
static StructDecl *convertToStructDecl(ValueDecl *v) {
  StructDecl *asStruct = dyn_cast<StructDecl>(v);
  if (!asStruct) {
    // Not directly a struct: try to see through a type declaration (such as
    // a typealias) to its underlying nominal.
    if (auto *asTypeDecl = dyn_cast<TypeDecl>(v)) {
      auto *underlyingNominal =
          asTypeDecl->getDeclaredInterfaceType()->getAnyNominal();
      asStruct = dyn_cast_or_null<StructDecl>(underlyingNominal);
    }
  }
  return asStruct;
}

/// Look up the `Differentiable.TangentVector` type witness of
/// \p contextualType, using \p DC for the conformance lookup.
///
/// If the witness contains archetypes it is mapped out of context before being
/// returned; otherwise it is returned unchanged. When \p contextualType does
/// not conform to `Differentiable`, an assertion fires in asserts-enabled
/// builds and a null `Type` is returned otherwise.
static Type getTangentVectorInterfaceType(Type contextualType,
                                          DeclContext *DC) {
  auto &ctx = contextualType->getASTContext();
  auto *diffableProto = ctx.getProtocol(KnownProtocolKind::Differentiable);
  assert(diffableProto && "`Differentiable` protocol not found");

  auto conf =
      TypeChecker::conformsToProtocol(contextualType, diffableProto, DC);
  assert(conf && "Contextual type must conform to `Differentiable`");
  if (!conf)
    return Type();

  auto tangentWitness =
      conf.getTypeWitnessByName(contextualType, ctx.Id_TangentVector);
  if (tangentWitness->hasArchetype())
    return tangentWitness->mapTypeOutOfContext();
  return tangentWitness;
}

/// Check whether \p nominal may use `Self` as its `TangentVector` in the
/// conformance context \p DC.
///
/// This requires all of the following:
/// - \p nominal is not a class declaration;
/// - `Self` conforms to `AdditiveArithmetic`;
/// - no stored property is marked `@noDerivative`;
/// - every stored property's type conforms to `Differentiable` and is equal to
///   its own `TangentVector`.
static bool canDeriveTangentVectorAsSelf(NominalTypeDecl *nominal,
                                         DeclContext *DC) {
  // Classes never qualify.
  if (nominal->getSelfClassDecl())
    return false;

  auto selfTypeInContext =
      DC->mapTypeIntoContext(nominal->getDeclaredInterfaceType());
  auto &ctx = nominal->getASTContext();
  auto *differentiableProto =
      ctx.getProtocol(KnownProtocolKind::Differentiable);
  auto *additiveArithmeticProto =
      ctx.getProtocol(KnownProtocolKind::AdditiveArithmetic);

  // `Self` itself has to be `AdditiveArithmetic`.
  if (!TypeChecker::conformsToProtocol(selfTypeInContext,
                                       additiveArithmeticProto, DC))
    return false;

  // Per-property condition: differentiable, not `@noDerivative`, and the
  // property type is its own tangent vector.
  auto isItsOwnTangent = [&](VarDecl *storedProp) -> bool {
    if (storedProp->getAttrs().hasAttribute<NoDerivativeAttr>())
      return false;
    auto propType =
        DC->mapTypeIntoContext(storedProp->getValueInterfaceType());
    auto propConformance =
        TypeChecker::conformsToProtocol(propType, differentiableProto, DC);
    if (!propConformance)
      return false;
    auto propTangentType =
        propConformance.getTypeWitnessByName(propType, ctx.Id_TangentVector);
    return propType->isEqual(propTangentType);
  };
  return llvm::all_of(nominal->getStoredProperties(), isItsOwnTangent);
}

/// The `Differentiable` protocol requirements that this file can synthesize.
/// Values are produced by `getDifferentiableRequirementKind` from a
/// requirement's base name.
enum class DifferentiableRequirement {
  /// `associatedtype TangentVector`
  TangentVector,
  /// `mutating func move(along direction: TangentVector)`
  MoveAlong,
  /// `var zeroTangentVectorInitializer: () -> TangentVector`
  ZeroTangentVectorInitializer,
};

/// Classify a `Differentiable` protocol requirement by its base name.
///
/// Reaches `llvm_unreachable` when \p requirement is named neither
/// `TangentVector`, `move`, nor `zeroTangentVectorInitializer`.
static DifferentiableRequirement
getDifferentiableRequirementKind(ValueDecl *requirement) {
  auto &ctx = requirement->getASTContext();
  auto baseName = requirement->getBaseName();

  if (baseName == ctx.Id_TangentVector) {
    return DifferentiableRequirement::TangentVector;
  }
  if (baseName == ctx.Id_move) {
    return DifferentiableRequirement::MoveAlong;
  }
  if (baseName == ctx.Id_zeroTangentVectorInitializer) {
    return DifferentiableRequirement::ZeroTangentVectorInitializer;
  }
  llvm_unreachable("Invalid `Differentiable` protocol requirement");
}

/// Decide whether \p requirement of the `Differentiable` protocol can be
/// synthesized for \p nominal in the conformance context \p DC.
///
/// The checks, in order:
/// 1. If \p DC has a parent source file, differentiable programming must be
///    enabled for it.
/// 2. \p nominal may directly declare at most one member named
///    `TangentVector`, and if it declares one it must be a valid candidate
///    (see the lambda below).
/// 3. For `zeroTangentVectorInitializer`, an existing valid `TangentVector`
///    is sufficient.
/// 4. Otherwise \p nominal must be a struct or a class, and every
///    differentiation stored property must be well-typed and conform to
///    `Differentiable`.
bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal,
                                                 DeclContext *DC,
                                                 ValueDecl *requirement) {
  // Experimental differentiable programming must be switched on.
  auto *sourceFile = DC->getParentSourceFile();
  if (sourceFile && !isDifferentiableProgrammingEnabled(*sourceFile))
    return false;

  auto requirementKind = getDifferentiableRequirementKind(requirement);
  const bool isZeroTanInitRequirement =
      requirementKind == DifferentiableRequirement::ZeroTangentVectorInitializer;

  auto &ctx = nominal->getASTContext();
  const bool selfCanBeTangentVector = canDeriveTangentVectorAsSelf(nominal, DC);

  // A `TangentVector` candidate is acceptable when either
  //  - the requirement is `zeroTangentVectorInitializer` and the candidate is
  //    a type declaration conforming to `AdditiveArithmetic`; or
  //  - it is a struct (or names one) that is implicit, i.e. previously
  //    synthesized; or
  //  - it is the nominal itself and the nominal can use `Self` as its
  //    `TangentVector`. Customizing `TangentVector` to anything else is not
  //    supported by synthesis.
  auto isValidTangentVectorCandidate = [&](ValueDecl *candidate) -> bool {
    if (isZeroTanInitRequirement) {
      if (auto *candidateTypeDecl = dyn_cast<TypeDecl>(candidate)) {
        auto candidateType = DC->mapTypeIntoContext(
            candidateTypeDecl->getDeclaredInterfaceType());
        auto *additiveArithmeticProto =
            ctx.getProtocol(KnownProtocolKind::AdditiveArithmetic);
        if (TypeChecker::conformsToProtocol(candidateType,
                                            additiveArithmeticProto, DC))
          return true;
      }
    }
    auto *candidateStruct = convertToStructDecl(candidate);
    if (!candidateStruct)
      return false;
    return candidateStruct->isImplicit() ||
           (candidateStruct == nominal && selfCanBeTangentVector);
  };

  // Zero or one `TangentVector` declaration is allowed, and a lone
  // declaration has to be valid.
  auto tangentCandidates = nominal->lookupDirect(ctx.Id_TangentVector);
  const auto numTangentCandidates = tangentCandidates.size();
  if (numTangentCandidates > 1)
    return false;
  if (numTangentCandidates == 1 &&
      !isValidTangentVectorCandidate(tangentCandidates.front()))
    return false;
  const bool hasValidTangentDecl = numTangentCandidates != 0;

  // `zeroTangentVectorInitializer` only needs a valid `TangentVector`; if
  // there is none, fall through to the struct derivation conditions.
  if (isZeroTanInitRequirement && hasValidTangentDecl)
    return true;

  // `TangentVector` struct derivation: only structs and classes qualify.
  // (Having no stored properties is fine.)
  if (!(isa<StructDecl>(nominal) || isa<ClassDecl>(nominal)))
    return false;

  // Every differentiation property must be well-typed and `Differentiable`.
  SmallVector<VarDecl *, 16> diffProperties;
  getStoredPropertiesForDifferentiation(nominal, DC, diffProperties);
  auto *differentiableProto =
      ctx.getProtocol(KnownProtocolKind::Differentiable);
  for (auto *property : diffProperties) {
    if (property->getInterfaceType()->hasError())
      return false;
    auto propertyType =
        DC->mapTypeIntoContext(property->getValueInterfaceType());
    if (!TypeChecker::conformsToProtocol(propertyType, differentiableProto,
                                         DC))
      return false;
  }
  return true;
}

/// Synthesize body for `move(along:)`.
///
/// The body consists of one statement per differentiation stored property of
/// the enclosing nominal type (as selected by
/// `getStoredPropertiesForDifferentiation` with its default flag):
///
///     self.<member>.move(along: direction.<member>)
///
/// Here `direction` stands for the function's first parameter, and
/// `direction.<member>` is the stored property of the parameter's nominal type
/// whose name equals the name of `<member>`.
///
/// \param funcDecl the `move(along:)` declaration whose body is synthesized.
/// \returns the implicit brace statement, paired with `false`
///   (NOTE(review): presumably the "body is already type-checked" flag of the
///   body synthesizer protocol; confirm against `BodySynthesizer`).
static std::pair<BraceStmt *, bool>
deriveBodyDifferentiable_move(AbstractFunctionDecl *funcDecl, void *) {
  auto &C = funcDecl->getASTContext();
  auto *parentDC = funcDecl->getParent();
  auto *nominal = parentDC->getSelfNominalTypeDecl();

  // Get `Differentiable.move(along:)` protocol requirement.
  auto *diffProto = C.getProtocol(KnownProtocolKind::Differentiable);
  auto *requirement = getProtocolRequirement(diffProto, C.Id_move);

  // Get references to `self` and parameter declarations.
  auto *selfDecl = funcDecl->getImplicitSelfDecl();
  auto *selfDRE =
      new (C) DeclRefExpr(selfDecl, DeclNameLoc(), /*Implicit*/ true);
  auto *paramDecl = funcDecl->getParameters()->get(0);
  auto *paramDRE =
      new (C) DeclRefExpr(paramDecl, DeclNameLoc(), /*Implicit*/ true);

  SmallVector<VarDecl *, 8> diffProperties;
  getStoredPropertiesForDifferentiation(nominal, parentDC, diffProperties);

  // Create call expression applying a member `move(along:)` method to a
  // parameter member: `self.<member>.move(along: direction.<member>)`.
  auto createMemberMethodCallExpr = [&](VarDecl *member) -> Expr * {
    auto *module = nominal->getModuleContext();
    auto memberType =
        parentDC->mapTypeIntoContext(member->getValueInterfaceType());
    auto confRef = module->lookupConformance(memberType, diffProto);
    assert(confRef && "Member does not conform to `Differentiable`");

    // Get member type's requirement witness: `<Member>.move(along:)`.
    // Default to the protocol requirement itself; prefer the concrete witness
    // when the conformance is concrete and provides one.
    ValueDecl *memberWitnessDecl = requirement;
    if (confRef.isConcrete())
      if (auto *witness = confRef.getConcrete()->getWitnessDecl(requirement))
        memberWitnessDecl = witness;
    assert(memberWitnessDecl && "Member witness declaration must exist");

    // Create reference to member method: `self.<member>.move(along:)`.
    Expr *memberExpr =
        new (C) MemberRefExpr(selfDRE, SourceLoc(), member, DeclNameLoc(),
                              /*Implicit*/ true);
    auto *memberMethodExpr =
        new (C) MemberRefExpr(memberExpr, SourceLoc(), memberWitnessDecl,
                              DeclNameLoc(), /*Implicit*/ true);

    // Create reference to parameter member: `direction.<member>`.
    VarDecl *paramMember = nullptr;
    auto *paramNominal = paramDecl->getType()->getAnyNominal();
    assert(paramNominal && "Parameter should have a nominal type");
    // Find parameter member corresponding to returned nominal member.
    // Members are matched purely by name.
    for (auto *candidate : paramNominal->getStoredProperties()) {
      if (candidate->getName() == member->getName()) {
        paramMember = candidate;
        break;
      }
    }
    assert(paramMember && "Could not find corresponding parameter member");
    auto *paramMemberExpr =
        new (C) MemberRefExpr(paramDRE, SourceLoc(), paramMember, DeclNameLoc(),
                              /*Implicit*/ true);
    // Create expression: `self.<member>.move(along: direction.<member>)`.
    return CallExpr::createImplicit(C, memberMethodExpr, {paramMemberExpr},
                                    {C.Id_along});
  };

  // Collect member `move(along:)` method call expressions, one per
  // differentiation property, in stored-property order.
  SmallVector<ASTNode, 2> memberMethodCallExprs;
  for (auto *member : diffProperties)
    memberMethodCallExprs.push_back(createMemberMethodCallExpr(member));

  auto *braceStmt = BraceStmt::create(C, SourceLoc(), memberMethodCallExprs,
                                      SourceLoc(), true);
  return std::pair<BraceStmt *, bool>(braceStmt, false);
}

/// Synthesize body for `var zeroTangentVectorInitializer` getter.
///
/// Two shapes of body can be produced:
///
/// - Memberwise derivation, when the `TangentVector` type is a struct whose
///   effective memberwise initializer has exactly one parameter per
///   differentiation property, with matching labels and with parameter types
///   equal to the corresponding property's `TangentVector`:
///
///       return { [x_zeroTangentVectorInitializer =
///                     self.x.zeroTangentVectorInitializer, ...] in
///           TangentVector(x: x_zeroTangentVectorInitializer(), ...)
///       }
///
/// - Fallback, in every other case (including when the `TangentVector` struct
///   has no effective memberwise initializer at all):
///
///       return { TangentVector.zero }
///
/// \param funcDecl the getter whose body is synthesized.
/// \returns the implicit brace statement, paired with `false`
///   (NOTE(review): presumably the "body is already type-checked" flag of the
///   body synthesizer protocol; confirm against `BodySynthesizer`).
static std::pair<BraceStmt *, bool>
deriveBodyDifferentiable_zeroTangentVectorInitializer(
    AbstractFunctionDecl *funcDecl, void *) {
  auto &C = funcDecl->getASTContext();
  auto *parentDC = funcDecl->getParent();
  auto *nominal = parentDC->getSelfNominalTypeDecl();

  // Get method protocol requirement.
  auto *diffProto = C.getProtocol(KnownProtocolKind::Differentiable);
  auto *requirement =
      getProtocolRequirement(diffProto, C.Id_zeroTangentVectorInitializer);

  // Resolve the nominal type's own `TangentVector` witness.
  auto nominalType =
      parentDC->mapTypeIntoContext(nominal->getDeclaredInterfaceType());
  auto conf = TypeChecker::conformsToProtocol(nominalType, diffProto, parentDC);
  auto tangentType = conf.getTypeWitnessByName(nominalType, C.Id_TangentVector);
  auto *tangentTypeExpr = TypeExpr::createImplicit(tangentType, C);

  // Get differentiation properties. Passing `true` disables the
  // `move(along:)` invocability filter, so every `Differentiable` stored
  // property is considered here.
  SmallVector<VarDecl *, 8> diffProperties;
  getStoredPropertiesForDifferentiation(
      nominal, parentDC, diffProperties,
      /*includeLetPropertiesWithNonmutatingMoveAlong*/ true);

  // Check whether memberwise derivation of `zeroTangentVectorInitializer` is
  // possible.
  bool canPerformMemberwiseDerivation = [&]() -> bool {
    // Memberwise derivation is possible only for struct `TangentVector` types.
    auto *tangentTypeDecl = tangentType->getAnyNominal();
    if (!tangentTypeDecl || !tangentTypeDecl->getSelfStructDecl())
      return false;
    // Get effective memberwise initializer.
    auto *memberwiseInitDecl =
        tangentTypeDecl->getEffectiveMemberwiseInitializer();
    // A struct may have no effective memberwise initializer. There is then
    // nothing to call memberwise, so use the `TangentVector.zero` fallback
    // rather than dereferencing a null declaration.
    if (!memberwiseInitDecl)
      return false;
    // Return false if number of memberwise initializer parameters does not
    // equal number of differentiation properties.
    if (memberwiseInitDecl->getParameters()->size() != diffProperties.size())
      return false;
    // Iterate over all initializer parameters and differentiation properties.
    for (auto pair : llvm::zip(memberwiseInitDecl->getParameters()->getArray(),
                               diffProperties)) {
      auto *initParam = std::get<0>(pair);
      auto *diffProp = std::get<1>(pair);
      // Return false if parameter label does not equal property name.
      if (initParam->getParameterName() != diffProp->getName())
        return false;
      auto diffPropContextualType =
          parentDC->mapTypeIntoContext(diffProp->getValueInterfaceType());
      auto diffPropTangentType =
          getTangentVectorInterfaceType(diffPropContextualType, parentDC);
      // Return false if parameter type does not equal property tangent type.
      if (!initParam->getValueInterfaceType()->isEqual(diffPropTangentType))
        return false;
    }
    return true;
  }();

  // If memberwise derivation is not possible, synthesize
  // `{ TangentVector.zero }` as a fallback.
  if (!canPerformMemberwiseDerivation) {
    auto *module = nominal->getModuleContext();
    auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic);
    auto confRef = module->lookupConformance(tangentType, addArithProto);
    assert(confRef &&
           "`TangentVector` does not conform to `AdditiveArithmetic`");
    auto *zeroDecl = getProtocolRequirement(addArithProto, C.Id_zero);
    // If conformance reference is concrete, then use concrete witness
    // declaration for the operator.
    if (confRef.isConcrete())
      if (auto *witnessDecl = confRef.getConcrete()->getWitnessDecl(zeroDecl))
        zeroDecl = witnessDecl;
    assert(zeroDecl && "Member method declaration must exist");
    auto *zeroExpr =
        new (C) MemberRefExpr(tangentTypeExpr, SourceLoc(), zeroDecl,
                              DeclNameLoc(), /*Implicit*/ true);

    // Create closure expression.
    unsigned discriminator = 0;
    auto resultTy = funcDecl->getMethodInterfaceType()
                        ->castTo<AnyFunctionType>()
                        ->getResult();

    auto *closureParams = ParameterList::createEmpty(C);
    auto *closure = new (C) ClosureExpr(
        SourceRange(), /*capturedSelfDecl*/ nullptr, closureParams,
        SourceLoc(), SourceLoc(), SourceLoc(), SourceLoc(),
        TypeExpr::createImplicit(resultTy, C), discriminator, funcDecl);
    closure->setImplicit();
    auto *closureReturn = new (C) ReturnStmt(SourceLoc(), zeroExpr, true);
    auto *closureBody =
        BraceStmt::create(C, SourceLoc(), {closureReturn}, SourceLoc(), true);
    closure->setBody(closureBody, /*isSingleExpression=*/true);

    ASTNode returnStmt = new (C) ReturnStmt(SourceLoc(), closure, true);
    auto *braceStmt =
        BraceStmt::create(C, SourceLoc(), returnStmt, SourceLoc(), true);
    return std::pair<BraceStmt *, bool>(braceStmt, false);
  }

  // Otherwise, perform memberwise derivation.
  // Get effective memberwise initializer: `Nominal.init(...)`.
  auto *tangentTypeDecl = tangentType->getAnyNominal();
  auto *memberwiseInitDecl =
      tangentTypeDecl->getEffectiveMemberwiseInitializer();
  assert(memberwiseInitDecl && "Memberwise initializer must exist");
  auto *initDRE =
      new (C) DeclRefExpr(memberwiseInitDecl, DeclNameLoc(), /*Implicit*/ true);
  initDRE->setFunctionRefKind(FunctionRefKind::SingleApply);
  auto *initExpr = new (C) ConstructorRefCallExpr(initDRE, tangentTypeExpr);

  // Get the implicit `self` declaration of the getter.
  auto *selfDecl = funcDecl->getImplicitSelfDecl();

  // Create `self.<member>.zeroTangentVectorInitializer` capture list entry.
  auto createMemberZeroTanInitCaptureListEntry =
      [&](VarDecl *member) -> CaptureListEntry {
    // Create `<member>_zeroTangentVectorInitializer` capture var declaration.
    auto memberCaptureName = C.getIdentifier(std::string(member->getNameStr()) +
                                             "_zeroTangentVectorInitializer");
    auto *memberZeroTanInitCaptureDecl = new (C) VarDecl(
        /*isStatic*/ false, VarDecl::Introducer::Let,
        SourceLoc(), memberCaptureName, funcDecl);
    memberZeroTanInitCaptureDecl->setImplicit();
    auto *memberZeroTanInitPattern =
        NamedPattern::createImplicit(C, memberZeroTanInitCaptureDecl);

    auto *module = nominal->getModuleContext();
    auto memberType =
        parentDC->mapTypeIntoContext(member->getValueInterfaceType());
    auto confRef = module->lookupConformance(memberType, diffProto);
    assert(confRef && "Member does not conform to `Differentiable`");

    // Get member type's `zeroTangentVectorInitializer` requirement witness.
    // Default to the protocol requirement itself; prefer the concrete witness
    // when the conformance is concrete and provides one.
    ValueDecl *memberWitnessDecl = requirement;
    if (confRef.isConcrete())
      if (auto *witness = confRef.getConcrete()->getWitnessDecl(requirement))
        memberWitnessDecl = witness;
    assert(memberWitnessDecl && "Member witness declaration must exist");

    // self.<member>.zeroTangentVectorInitializer
    auto *selfDRE =
        new (C) DeclRefExpr(selfDecl, DeclNameLoc(), /*Implicit*/ true);
    auto *memberExpr =
        new (C) MemberRefExpr(selfDRE, SourceLoc(), member, DeclNameLoc(),
                              /*Implicit*/ true);
    auto *memberZeroTangentVectorInitExpr =
        new (C) MemberRefExpr(memberExpr, SourceLoc(), memberWitnessDecl,
                              DeclNameLoc(), /*Implicit*/ true);
    auto *memberZeroTanInitPBD = PatternBindingDecl::createImplicit(
        C, StaticSpellingKind::None, memberZeroTanInitPattern,
        memberZeroTangentVectorInitExpr, funcDecl);
    CaptureListEntry captureEntry(memberZeroTanInitCaptureDecl,
                                  memberZeroTanInitPBD);
    return captureEntry;
  };

  // Create `<member>_zeroTangentVectorInitializer()` call expression.
  auto createMemberZeroTanInitCallExpr =
      [&](CaptureListEntry memberZeroTanInitEntry) -> Expr * {
    // <member>_zeroTangentVectorInitializer
    auto *memberZeroTanInitDRE = new (C) DeclRefExpr(
        memberZeroTanInitEntry.Var, DeclNameLoc(), /*Implicit*/ true);
    // <member>_zeroTangentVectorInitializer()
    auto *memberZeroTangentVector =
        CallExpr::createImplicit(C, memberZeroTanInitDRE, {}, {});
    return memberZeroTangentVector;
  };

  // Collect member zero tangent vector expressions. The three vectors are
  // kept index-aligned: label, call expression, and capture for each member.
  SmallVector<Identifier, 4> memberNames;
  SmallVector<Expr *, 4> memberZeroTanExprs;
  SmallVector<CaptureListEntry, 2> memberZeroTanInitCaptures;
  for (auto *member : diffProperties) {
    memberNames.push_back(member->getName());
    auto memberZeroTanInitCapture =
        createMemberZeroTanInitCaptureListEntry(member);
    memberZeroTanInitCaptures.push_back(memberZeroTanInitCapture);
    memberZeroTanExprs.push_back(
        createMemberZeroTanInitCallExpr(memberZeroTanInitCapture));
  }

  // Create `zeroTangentVectorInitializer` closure body:
  // `TangentVector(x: x_zeroTangentVectorInitializer(), ...)`.
  auto *callExpr =
      CallExpr::createImplicit(C, initExpr, memberZeroTanExprs, memberNames);

  // Create closure expression:
  // `{ TangentVector(x: x_zeroTangentVectorInitializer(), ...) }`.
  unsigned discriminator = 0;
  auto resultTy = funcDecl->getMethodInterfaceType()
                      ->castTo<AnyFunctionType>()
                      ->getResult();
  auto *closureParams = ParameterList::createEmpty(C);
  auto *closure = new (C) ClosureExpr(
      SourceRange(), /*capturedSelfDecl*/ nullptr, closureParams, SourceLoc(),
      SourceLoc(), SourceLoc(), SourceLoc(),
      TypeExpr::createImplicit(resultTy, C), discriminator, funcDecl);
  closure->setImplicit();
  auto *closureReturn = new (C) ReturnStmt(SourceLoc(), callExpr, true);
  auto *closureBody =
      BraceStmt::create(C, SourceLoc(), {closureReturn}, SourceLoc(), true);
  closure->setBody(closureBody, /*isSingleExpression=*/true);

  // Create capture list expression:
  // ```
  // { [x_zeroTangentVectorInitializer = x.zeroTangentVectorInitializer, ...] in
  //     TangentVector(x: x_zeroTangentVectorInitializer(), ...)
  // }
  // ```
  auto *captureList =
      CaptureListExpr::create(C, memberZeroTanInitCaptures, closure);
  captureList->setImplicit();

  ASTNode returnStmt = new (C) ReturnStmt(SourceLoc(), captureList, true);
  auto *braceStmt =
      BraceStmt::create(C, SourceLoc(), returnStmt, SourceLoc(), true);
  return std::pair<BraceStmt *, bool>(braceStmt, false);
}

/// Build an implicit single-parameter method declaration for a
/// `Differentiable` method requirement and add it to the conformance context
/// of \p derived.
///
/// \param methodName base name of the method.
/// \param argumentName external label of the single parameter.
/// \param parameterName internal name of the single parameter.
/// \param parameterType interface type assigned to the parameter.
/// \param returnType result type of the method.
/// \param bodySynthesizer function and context used to synthesize the body.
///
/// The method is marked `mutating` unless the nominal type is a class. It
/// takes the generic signature of the conformance context and copies its
/// formal access from the nominal type.
static ValueDecl *deriveDifferentiable_method(
    DerivedConformance &derived, Identifier methodName, Identifier argumentName,
    Identifier parameterName, Type parameterType, Type returnType,
    AbstractFunctionDecl::BodySynthesizer bodySynthesizer) {
  auto &ctx = derived.Context;
  auto *conformanceDC = derived.getConformanceContext();
  auto *nominalDecl = derived.Nominal;

  // The sole parameter of the method.
  auto *onlyParam = new (ctx) ParamDecl(SourceLoc(), SourceLoc(), argumentName,
                                        SourceLoc(), parameterName,
                                        conformanceDC);
  onlyParam->setSpecifier(ParamDecl::Specifier::Default);
  onlyParam->setInterfaceType(parameterType);
  ParameterList *paramList = ParameterList::create(ctx, {onlyParam});

  // The method declaration itself.
  DeclName fullName(ctx, methodName, paramList);
  auto *const method = FuncDecl::createImplicit(
      ctx, StaticSpellingKind::None, fullName, /*NameLoc=*/SourceLoc(),
      /*Async=*/false,
      /*Throws=*/false,
      /*GenericParams=*/nullptr, paramList, returnType, conformanceDC);

  // Value types need a `mutating` method; classes do not.
  const bool isClass = nominalDecl->getSelfClassDecl() != nullptr;
  if (!isClass) {
    method->setSelfAccessKind(SelfAccessKind::Mutating);
  }
  method->setBodySynthesizer(bodySynthesizer.Fn, bodySynthesizer.Context);

  method->setGenericSignature(conformanceDC->getGenericSignatureOfContext());
  method->copyFormalAccessFrom(nominalDecl, /*sourceIsParentContext*/ true);

  derived.addMembersToConformanceContext({method});
  return method;
}

/// Declare the synthesized `move(along direction: TangentVector)` method.
///
/// The parameter type is the `TangentVector` of the conformance context's
/// `Self` type, the result type is the empty tuple, and the body comes from
/// `deriveBodyDifferentiable_move`.
static ValueDecl *deriveDifferentiable_move(DerivedConformance &derived) {
  auto *conformanceDC = derived.getConformanceContext();
  auto directionType = getTangentVectorInterfaceType(
      conformanceDC->getSelfTypeInContext(), conformanceDC);

  auto &ctx = derived.Context;
  AbstractFunctionDecl::BodySynthesizer moveBodySynthesizer{
      deriveBodyDifferentiable_move, nullptr};
  return deriveDifferentiable_method(derived, ctx.Id_move, ctx.Id_along,
                                     ctx.Id_direction, directionType,
                                     ctx.TheEmptyTupleType,
                                     moveBodySynthesizer);
}

/// Declare the synthesized `zeroTangentVectorInitializer` computed property.
///
/// The property has type `() -> TangentVector`, where `TangentVector` is that
/// of the conformance context's `Self` type. It is declared non-static and
/// final, and receives a getter whose body comes from
/// `deriveBodyDifferentiable_zeroTangentVectorInitializer`.
static ValueDecl *
deriveDifferentiable_zeroTangentVectorInitializer(DerivedConformance &derived) {
  auto *conformanceDC = derived.getConformanceContext();
  auto &ctx = derived.Context;

  // Property type: `() -> TangentVector`.
  auto selfTangentType = getTangentVectorInterfaceType(
      conformanceDC->getSelfTypeInContext(), conformanceDC);
  auto initializerType = FunctionType::get({}, selfTangentType);

  // Declare the property together with its pattern binding.
  VarDecl *property = nullptr;
  PatternBindingDecl *binding = nullptr;
  std::tie(property, binding) = derived.declareDerivedProperty(
      ctx.Id_zeroTangentVectorInitializer, initializerType, initializerType,
      /*isStatic*/ false, /*isFinal*/ true);

  // Attach the getter. It is implicitly `@noDerivative` because calls to the
  // `zeroTangentVectorInitializer` getter should never be differentiated.
  auto *getter =
      derived.addGetterToReadOnlyDerivedProperty(property, initializerType);
  getter->getAttrs().add(new (ctx) NoDerivativeAttr(/*Implicit*/ true));
  getter->setBodySynthesizer(
      &deriveBodyDifferentiable_zeroTangentVectorInitializer);

  derived.addMembersToConformanceContext({property, binding});
  return property;
}

/// Insert into \p protos every protocol reached by walking the inherited
/// protocols of each protocol named in the inheritance clause of \p decl.
///
/// Inherited entries that are protocol compositions are expanded member by
/// member; entries that are neither protocols nor compositions are ignored.
///
/// Precondition: `decl` is a nominal type decl or an extension decl.
void getInheritedProtocols(Decl *decl, SmallPtrSetImpl<ProtocolDecl *> &protos) {
  // Fetch the inheritance clause of the declaration.
  ArrayRef<TypeLoc> inheritanceClause;
  if (auto *asNominal = dyn_cast<NominalTypeDecl>(decl)) {
    inheritanceClause = asNominal->getInherited();
  } else if (auto *asExtension = dyn_cast<ExtensionDecl>(decl)) {
    inheritanceClause = asExtension->getInherited();
  } else {
    llvm_unreachable("conformance is not a nominal or an extension");
  }

  // Recursively record protocols from a single inherited type.
  std::function<void(Type)> collectFrom = [&](Type inherited) -> void {
    if (auto *protoTy = inherited->getAs<ProtocolType>()) {
      protoTy->getDecl()->walkInheritedProtocols(
          [&](ProtocolDecl *visited) -> TypeWalker::Action {
            protos.insert(visited);
            return TypeWalker::Action::Continue;
          });
      return;
    }
    if (auto *compositionTy = inherited->getAs<ProtocolCompositionType>()) {
      for (auto memberTy : compositionTy->getMembers())
        collectFrom(memberTy);
    }
  };

  for (auto inheritedLoc : inheritanceClause) {
    // Prefer structurally resolving the written type representation; fall
    // back to the type stored in the `TypeLoc` when there is none.
    auto *repr = inheritedLoc.getTypeRepr();
    Type inheritedTy =
        repr ? TypeResolution::forStructural(cast<DeclContext>(decl), None,
                                             /*unboundTyOpener*/ nullptr)
                   .resolveType(repr)
             : inheritedLoc.getType();
    collectFrom(inheritedTy);
  }
}

/// Return associated `TangentVector` struct for a nominal type, if it exists.
/// If not, synthesize the struct.
///
/// The synthesized struct is created in the conformance context, is given one
/// stored property per property returned by
/// `getStoredPropertiesForDifferentiation`, receives an explicitly synthesized
/// memberwise initializer, and is added to the conformance context before
/// being returned.
///
/// \param derived The derived-conformance state; supplies the nominal type,
///        the conformance declaration and the conformance context.
/// \param id Not referenced by the body; the lookup and the new struct both
///        use `C.Id_TangentVector` directly.
/// \returns The existing or newly synthesized struct. In builds without
///          assertions this may be null if the existing `TangentVector`
///          member cannot be converted to a struct.
static StructDecl *
getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
  auto *parentDC = derived.getConformanceContext();
  auto *nominal = derived.Nominal;
  auto &C = nominal->getASTContext();

  // If the associated struct already exists, return it.
  auto lookup = nominal->lookupDirect(C.Id_TangentVector);
  assert(lookup.size() < 2 &&
         "Expected at most one associated type named `TangentVector`");
  if (lookup.size() == 1) {
    auto *structDecl = convertToStructDecl(lookup.front());
    assert(structDecl && "Expected lookup result to be a struct");
    return structDecl;
  }

  // Otherwise, synthesize a new struct.

  // Compute `tvDesiredProtos`, the set of protocols that the new
  // `TangentVector` struct must inherit, by collecting all the `TangentVector`
  // conformance requirements imposed by the protocols that
  // `derived.ConformanceDecl` inherits.
  //
  // Note that, for example, this will always find `AdditiveArithmetic` and
  // `Differentiable` because the `Differentiable` protocol itself requires
  // that its `TangentVector` conforms to `AdditiveArithmetic` and
  // `Differentiable`.
  //
  // NOTE(review): both sets are pointer sets, so the order in which protocols
  // end up in the synthesized inheritance clause follows set iteration order
  // rather than source order -- confirm that this order is not observable.
  llvm::SmallPtrSet<ProtocolType *, 4> tvDesiredProtos;
  llvm::SmallPtrSet<ProtocolDecl *, 4> conformanceInheritedProtos;
  getInheritedProtocols(derived.ConformanceDecl, conformanceInheritedProtos);
  auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable);
  auto *tvAssocType = diffableProto->getAssociatedType(C.Id_TangentVector);
  for (auto proto : conformanceInheritedProtos) {
    for (auto req : proto->getRequirementSignature()) {
      // Only conformance requirements are of interest.
      if (req.getKind() != RequirementKind::Conformance)
        continue;
      // The constrained type must be a dependent member type that refers to
      // `Differentiable`'s `TangentVector` associated type.
      auto *firstType = req.getFirstType()->getAs<DependentMemberType>();
      if (!firstType || firstType->getAssocType() != tvAssocType)
        continue;
      // The constraint itself must be a single protocol type.
      auto tvRequiredProto = req.getSecondType()->getAs<ProtocolType>();
      if (!tvRequiredProto)
        continue;
      tvDesiredProtos.insert(tvRequiredProto);
    }
  }
  // Turn the collected protocol types into location-less inheritance entries.
  SmallVector<TypeLoc, 4> tvDesiredProtoTypeLocs;
  for (auto *p : tvDesiredProtos)
    tvDesiredProtoTypeLocs.push_back(TypeLoc::withoutLoc(p));

  // Cache original members and their associated types for later use.
  SmallVector<VarDecl *, 8> diffProperties;
  getStoredPropertiesForDifferentiation(nominal, parentDC, diffProperties);

  // The struct is anchored at the end location of the conformance declaration
  // and declared in the conformance context.
  auto synthesizedLoc = derived.ConformanceDecl->getEndLoc();
  auto *structDecl =
      new (C) StructDecl(synthesizedLoc, C.Id_TangentVector, synthesizedLoc,
                         /*Inherited*/ C.AllocateCopy(tvDesiredProtoTypeLocs),
                         /*GenericParams*/ {}, parentDC);
  structDecl->setBraces({synthesizedLoc, synthesizedLoc});
  structDecl->setImplicit();
  structDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true);

  // Add stored properties to the `TangentVector` struct.
  for (auto *member : diffProperties) {
    // Add a tangent stored property to the `TangentVector` struct, with the
    // name and `TangentVector` type of the original property.
    auto *tangentProperty = new (C) VarDecl(
        member->isStatic(), member->getIntroducer(),
        /*NameLoc*/ SourceLoc(), member->getName(), structDecl);
    // Note: `tangentProperty` is not marked as implicit here, because that
    // incorrectly affects memberwise initializer synthesis.
    auto memberContextualType =
        parentDC->mapTypeIntoContext(member->getValueInterfaceType());
    auto memberTanType =
        getTangentVectorInterfaceType(memberContextualType, parentDC);
    tangentProperty->setInterfaceType(memberTanType);
    // Build the typed pattern and the initializer-less pattern binding that
    // declare the tangent property inside the struct.
    Pattern *memberPattern = NamedPattern::createImplicit(C, tangentProperty);
    memberPattern->setType(memberTanType);
    memberPattern =
        TypedPattern::createImplicit(C, memberPattern, memberTanType);
    memberPattern->setType(memberTanType);
    auto *memberBinding = PatternBindingDecl::createImplicit(
        C, StaticSpellingKind::None, memberPattern, /*initExpr*/ nullptr,
        structDecl);
    structDecl->addMember(tangentProperty);
    structDecl->addMember(memberBinding);
    // Access control mirrors the original property; the setter access is set
    // to the original property's formal access.
    tangentProperty->copyFormalAccessFrom(member,
                                          /*sourceIsParentContext*/ true);
    tangentProperty->setSetterAccess(member->getFormalAccess());

    // Cache the tangent property. The request is keyed on the original
    // property together with an empty `CanType`.
    C.evaluator.cacheOutput(TangentStoredPropertyRequest{member, CanType()},
                            TangentPropertyInfo(tangentProperty));

    // Now that the original property has a corresponding tangent property, it
    // should be marked `@differentiable` so that the differentiation transform
    // will synthesize derivative functions for its accessors. We only add this
    // to public stored properties, because their access outside the module will
    // go through accessor declarations.
    if (member->getEffectiveAccess() > AccessLevel::Internal &&
        !member->getAttrs().hasAttribute<DifferentiableAttr>()) {
      auto *getter = member->getSynthesizedAccessor(AccessorKind::Get);
      // The result is discarded; presumably this forces the getter's
      // interface type to be computed before its attributes and generic
      // signature are inspected -- TODO confirm.
      (void)getter->getInterfaceType();
      // If member or its getter already has a `@differentiable` attribute,
      // continue.
      if (member->getAttrs().hasAttribute<DifferentiableAttr>() ||
          getter->getAttrs().hasAttribute<DifferentiableAttr>())
        continue;
      GenericSignature derivativeGenericSignature =
          getter->getGenericSignature();
      // If the parent declaration context is an extension, the nominal type may
      // conditionally conform to `Differentiable`. Use the extension generic
      // requirements in getter `@differentiable` attributes.
      if (auto *extDecl = dyn_cast<ExtensionDecl>(parentDC->getAsDecl()))
        if (auto extGenSig = extDecl->getGenericSignature())
          derivativeGenericSignature = extGenSig;
      // The parameter index subset has capacity 1 with index 0 set;
      // presumably this selects the getter's `self` parameter -- TODO confirm.
      // The attribute is created for the getter but attached to the property.
      auto *diffableAttr = DifferentiableAttr::create(
          getter, /*implicit*/ true, SourceLoc(), SourceLoc(),
          /*linear*/ false, /*parameterIndices*/ IndexSubset::get(C, 1, {0}),
          derivativeGenericSignature);
      member->getAttrs().add(diffableAttr);
    }
  }

  // If nominal type is `@_fixed_layout`, also mark `TangentVector` struct as
  // `@_fixed_layout`.
  if (nominal->getAttrs().hasAttribute<FixedLayoutAttr>())
    addFixedLayoutAttr(structDecl);

  // If nominal type is `@frozen`, also mark `TangentVector` struct as
  // `@frozen`.
  if (nominal->getAttrs().hasAttribute<FrozenAttr>())
    structDecl->getAttrs().add(new (C) FrozenAttr(/*implicit*/ true));

  // If nominal type is `@usableFromInline`, also mark `TangentVector` struct as
  // `@usableFromInline`.
  if (nominal->getAttrs().hasAttribute<UsableFromInlineAttr>())
    structDecl->getAttrs().add(new (C) UsableFromInlineAttr(/*implicit*/ true));

  // The implicit memberwise constructor must be explicitly created so that it
  // can be called in `AdditiveArithmetic` and `Differentiable` methods.
  // Normally, the memberwise constructor is synthesized during SILGen, which
  // is too late.
  TypeChecker::addImplicitConstructors(structDecl);

  // After memberwise initializer is synthesized, mark members as implicit.
  for (auto *member : structDecl->getStoredProperties())
    member->setImplicit();

  // Register the new struct as a member of the conformance context.
  derived.addMembersToConformanceContext({structDecl});
  return structDecl;
}

/// Warn about stored properties of `nominal` that are excluded from
/// differentiation without an explicit `@noDerivative` attribute.
///
/// A property is diagnosed when it has no `@noDerivative` attribute and:
/// - it is wrapped by a property wrapper and is not settable in `DC`, or
/// - its type does not conform to `Differentiable`, or
/// - `move(along:)` cannot be invoked on it (the `let` property case).
///
/// Every diagnosed property receives an implicit `@noDerivative` attribute,
/// and each diagnostic carries a fix-it inserting the attribute explicitly.
static void checkAndDiagnoseImplicitNoDerivative(ASTContext &Context,
                                                 NominalTypeDecl *nominal,
                                                 DeclContext *DC) {
  // When the nominal type could derive `AdditiveArithmetic`, the diagnostics
  // additionally suggest conforming to it: all `Differentiable` requirements
  // have default implementations for `AdditiveArithmetic`-conforming `Self`,
  // which makes a derived `Differentiable` conformance unnecessary.
  bool suggestAdditiveArithmetic =
      DerivedConformance::canDeriveAdditiveArithmetic(nominal, DC);
  auto *differentiableProto =
      Context.getProtocol(KnownProtocolKind::Differentiable);

  // Attaches an implicit `@noDerivative` attribute to `property`.
  auto markImplicitNoDerivative = [&Context](VarDecl *property) {
    property->getAttrs().add(
        new (Context) NoDerivativeAttr(/*Implicit*/ true));
  };

  for (auto *storedProperty : nominal->getStoredProperties()) {
    // The property that is ultimately checked. For wrapped properties this is
    // the original wrapped property rather than the stored one.
    VarDecl *property = storedProperty;

    if (auto *wrapped = storedProperty->getOriginalWrappedProperty()) {
      // An explicit `@noDerivative` on the wrapped property settles it.
      if (wrapped->getAttrs().hasAttribute<NoDerivativeAttr>())
        continue;

      // A wrapper without `wrappedValue.set` cannot be updated by a
      // synthesized `mutating func move(along:)`; diagnose and exclude it.
      if (!wrapped->isSettable(DC)) {
        auto *wrapperNominal = storedProperty->getInterfaceType()
                                   ->getNominalOrBoundGenericNominal();
        auto insertionLoc =
            wrapped->getAttributeInsertionLoc(/*forModifier*/ false);
        Context.Diags
            .diagnose(
                insertionLoc,
                diag::
                    differentiable_immutable_wrapper_implicit_noderivative_fixit,
                wrapperNominal->getName(), nominal->getName(),
                suggestAdditiveArithmetic)
            .fixItInsert(insertionLoc, "@noDerivative ");
        markImplicitNoDerivative(wrapped);
        continue;
      }

      property = wrapped;
    }

    // Nothing to report for ill-typed or explicitly excluded properties.
    if (property->getInterfaceType()->hasError())
      continue;
    if (property->getAttrs().hasAttribute<NoDerivativeAttr>())
      continue;

    // Look up the `Differentiable` conformance of the property's type.
    auto contextualType =
        DC->mapTypeIntoContext(property->getValueInterfaceType());
    auto conformance = TypeChecker::conformsToProtocol(
        contextualType, differentiableProto, nominal);

    // Properties that conform and support `move(along:)` need no diagnostic.
    if (conformance && canInvokeMoveAlongOnProperty(property, conformance))
      continue;

    // Everything else is implicitly `@noDerivative`.
    markImplicitNoDerivative(property);
    auto insertionLoc =
        property->getAttributeInsertionLoc(/*forModifier*/ false);
    assert(insertionLoc.isValid() && "Expected valid source location");

    if (!conformance) {
      // The property's type does not conform to `Differentiable`.
      Context.Diags
          .diagnose(
              insertionLoc,
              diag::differentiable_nondiff_type_implicit_noderivative_fixit,
              property->getName(), property->getType(), nominal->getName(),
              suggestAdditiveArithmetic)
          .fixItInsert(insertionLoc, "@noDerivative ");
    } else {
      // The type conforms, so the remaining case is a `let` property.
      Context.Diags
          .diagnose(
              insertionLoc,
              diag::differentiable_let_property_implicit_noderivative_fixit,
              nominal->getName(), suggestAdditiveArithmetic)
          .fixItInsert(insertionLoc, "@noDerivative ");
    }
  }
}

/// Obtain the `TangentVector` struct for the conformance being derived and
/// return its type mapped into the conformance context, paired with the
/// struct declaration.
///
/// Returns a pair of null values when no struct is available. Otherwise the
/// nominal's stored properties are checked for implicit `@noDerivative`
/// members, with warnings emitted, before the type is returned.
static std::pair<Type, TypeDecl *>
getOrSynthesizeTangentVectorStructType(DerivedConformance &derived) {
  auto *nominalDecl = derived.Nominal;
  auto *conformanceDC = derived.getConformanceContext();
  auto &ctx = nominalDecl->getASTContext();

  // Find the existing `TangentVector` struct or synthesize a new one.
  StructDecl *tangentVectorDecl =
      getOrSynthesizeTangentVectorStruct(derived, ctx.Id_TangentVector);
  if (tangentVectorDecl == nullptr)
    return {Type(), nullptr};

  // Warn about members that are implicitly `@noDerivative`.
  checkAndDiagnoseImplicitNoDerivative(ctx, nominalDecl, conformanceDC);

  // Hand back the struct's declared type as seen from the conformance context.
  Type tangentVectorType = conformanceDC->mapTypeIntoContext(
      tangentVectorDecl->getDeclaredInterfaceType());
  return {tangentVectorType, tangentVectorDecl};
}

/// Derive the `TangentVector` associated type for the conformance.
///
/// When the nominal type can use its own contextual `Self` type as
/// `TangentVector`, that type is returned with no declaration. In every other
/// case the `TangentVector` struct is looked up or synthesized.
static std::pair<Type, TypeDecl *>
deriveDifferentiable_TangentVectorStruct(DerivedConformance &derived) {
  auto *conformanceDC = derived.getConformanceContext();

  // Fall back to a dedicated struct unless `Self` can serve as `TangentVector`.
  if (!canDeriveTangentVectorAsSelf(derived.Nominal, conformanceDC))
    return getOrSynthesizeTangentVectorStructType(derived);

  return {conformanceDC->getSelfTypeInContext(), nullptr};
}

/// Derive a witness for a value requirement of `Differentiable`.
///
/// Only `move` and `zeroTangentVectorInitializer` are supported. Returns the
/// derived declaration, or null after a diagnostic has been emitted.
ValueDecl *DerivedConformance::deriveDifferentiable(ValueDecl *requirement) {
  auto requirementName = requirement->getBaseName();
  bool wantsMove = requirementName == Context.Id_move;
  bool wantsZeroTangentVectorInitializer =
      requirementName == Context.Id_zeroTangentVectorInitializer;

  // Any other requirement is unknown to this derivation.
  if (!wantsMove && !wantsZeroTangentVectorInitializer) {
    Context.Diags.diagnose(requirement->getLoc(),
                           diag::broken_differentiable_requirement);
    return nullptr;
  }

  // Conformances declared in disallowed contexts are diagnosed and rejected.
  if (checkAndDiagnoseDisallowedContext(requirement))
    return nullptr;

  // Emit the failure diagnostics up front inside a transaction; they are
  // cancelled below if derivation turns out to be possible.
  DiagnosticTransaction failureDiagnostics(Context.Diags);
  ConformanceDecl->diagnose(diag::type_does_not_conform,
                            Nominal->getDeclaredType(), getProtocolType());
  requirement->diagnose(diag::no_witnesses,
                        getProtocolRequirementKind(requirement),
                        requirement->getName(), getProtocolType(),
                        /*AddFixIt=*/false);

  // Derivation is impossible: keep the diagnostics and report failure.
  if (!canDeriveDifferentiable(Nominal, getConformanceContext(), requirement))
    return nullptr;

  // Derivation is possible: drop the diagnostics and synthesize the witness.
  failureDiagnostics.abort();
  if (wantsMove)
    return deriveDifferentiable_move(*this);
  if (wantsZeroTangentVectorInitializer)
    return deriveDifferentiable_zeroTangentVectorInitializer(*this);

  // Otherwise, return nullptr.
  return nullptr;
}

/// Derive the `TangentVector` associated type requirement of
/// `Differentiable`.
///
/// Returns the derived type together with its declaration (if any), or a pair
/// of null values after a diagnostic has been emitted.
std::pair<Type, TypeDecl *>
DerivedConformance::deriveDifferentiable(AssociatedTypeDecl *requirement) {
  // `TangentVector` is the only associated type this derivation knows about.
  bool isTangentVector =
      requirement->getBaseName() == Context.Id_TangentVector;
  if (!isTangentVector) {
    Context.Diags.diagnose(requirement->getLoc(),
                           diag::broken_differentiable_requirement);
    return {Type(), nullptr};
  }

  // Emit the failure diagnostics up front inside a transaction; they are
  // cancelled below if derivation turns out to be possible.
  DiagnosticTransaction failureDiagnostics(Context.Diags);
  ConformanceDecl->diagnose(diag::type_does_not_conform,
                            Nominal->getDeclaredType(), getProtocolType());
  requirement->diagnose(diag::no_witnesses_type, requirement->getName());

  // Derivation is impossible: keep the diagnostics and report failure.
  if (!canDeriveDifferentiable(Nominal, getConformanceContext(), requirement))
    return {Type(), nullptr};

  // Derivation is possible: drop the diagnostics and derive the type.
  failureDiagnostics.abort();
  return deriveDifferentiable_TangentVectorStruct(*this);
}
