//===--- Differentiation.cpp - SIL Automatic Differentiation --*- C++ -*---===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// SWIFT_ENABLE_TENSORFLOW
//
// This file implements automatic differentiation.
//
// NOTE: Although the AD feature is developed as part of the Swift for
// TensorFlow project, it is completely independent from TensorFlow support.
//
//===----------------------------------------------------------------------===//

#define DEBUG_TYPE "differentiation"

#include "Differentiation.h"
#include "swift/AST/ASTMangler.h"
#include "swift/AST/ASTPrinter.h"
#include "swift/AST/AnyFunctionRef.h"
#include "swift/AST/AutoDiff.h"
#include "swift/AST/Builtins.h"
#include "swift/AST/DeclContext.h"
#include "swift/AST/DiagnosticsSIL.h"
#include "swift/AST/Expr.h"
#include "swift/AST/GenericEnvironment.h"
#include "swift/AST/GenericSignatureBuilder.h"
#include "swift/AST/SourceFile.h"
#include "swift/AST/ParameterList.h"
#include "swift/AST/SubstitutionMap.h"
#include "swift/AST/TypeCheckRequests.h"
#include "swift/SIL/FormalLinkage.h"
#include "swift/SIL/LoopInfo.h"
#include "swift/SIL/Projection.h"
#include "swift/SIL/SILBuilder.h"
#include "swift/SIL/TypeSubstCloner.h"
#include "swift/SILOptimizer/Analysis/DominanceAnalysis.h"
#include "swift/SILOptimizer/Analysis/LoopAnalysis.h"
#include "swift/SILOptimizer/PassManager/Passes.h"
#include "swift/SILOptimizer/PassManager/Transforms.h"
#include "swift/SILOptimizer/Utils/LoopUtils.h"
#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/BreadthFirstIterator.h"
#include "llvm/ADT/DenseSet.h"

using namespace swift;
using llvm::DenseMap;
using llvm::SmallDenseMap;
using llvm::SmallDenseSet;
using llvm::SmallMapVector;
using llvm::SmallSet;

/// This flag is used to disable `differentiable_function_extract` instruction
/// folding for SIL testing purposes.
///
/// Note: the flag defaults to `true` (see `llvm::cl::init` below), so it is in
/// effect unless explicitly turned off on the command line, e.g. with
/// `-differentiation-skip-folding-differentiable-function-extraction=false`.
static llvm::cl::opt<bool> SkipFoldingDifferentiableFunctionExtraction(
    "differentiation-skip-folding-differentiable-function-extraction",
    llvm::cl::init(true));

//===----------------------------------------------------------------------===//
// Helpers
//===----------------------------------------------------------------------===//

/// Writes the "[AD] " prefix to the LLVM debug stream (`llvm::dbgs()`) and
/// returns that same stream, so short AD-pass debug messages can be chained
/// onto it.
static raw_ostream &getADDebugStream() {
  raw_ostream &stream = llvm::dbgs();
  stream << "[AD] ";
  return stream;
}

/// Given a dumpable value, dumps it to `llvm::dbgs()`.
template <typename T> static inline void debugDump(T &v) {
  LLVM_DEBUG(llvm::dbgs() << "\n==== BEGIN DEBUG DUMP ====\n"
                          << v << "\n==== END DEBUG DUMP ====\n");
}

/// Returns true if `v` is a `function_ref` whose referenced function has the
/// "autodiff.nonvarying" semantics attribute.
///
/// Returns false for any other value. It also returns false for a
/// `function_ref` whose referenced function is unavailable, that is, when
/// `getReferencedFunctionOrNull()` returns null.
static bool isWithoutDerivative(SILValue v) {
  auto *fnRef = dyn_cast<FunctionRefInst>(v);
  if (!fnRef)
    return false;
  // `getReferencedFunctionOrNull` is allowed to return null; guard against it
  // instead of dereferencing the result unconditionally.
  auto *referencedFn = fnRef->getReferencedFunctionOrNull();
  if (!referencedFn)
    return false;
  return referencedFn->hasSemanticsAttr("autodiff.nonvarying");
}

/// Returns true if `ai` carries the "array.uninitialized_intrinsic" semantics,
/// as reported by `ApplyInst::hasSemantics`.
static bool isArrayLiteralIntrinsic(ApplyInst *ai) {
  StringRef arrayLiteralSemantics = "array.uninitialized_intrinsic";
  return ai->hasSemantics(arrayLiteralSemantics);
}

/// Returns `v` as an `apply` instruction if it is an `apply` that satisfies
/// `isArrayLiteralIntrinsic`; otherwise returns nullptr.
static ApplyInst *getAllocateUninitializedArrayIntrinsic(SILValue v) {
  auto *candidate = dyn_cast<ApplyInst>(v);
  if (!candidate || !isArrayLiteralIntrinsic(candidate))
    return nullptr;
  return candidate;
}

/// Returns the `destructure_tuple` user of `value`, or nullptr if `value` is
/// not tuple-typed or has no such user.
///
/// In asserting builds, it is an error for a tuple to have more than one
/// `destructure_tuple` user. Without assertions, the last such user
/// encountered in the use list is returned.
static DestructureTupleInst *getSingleDestructureTupleUser(SILValue value) {
  if (!value->getType().is<TupleType>())
    return nullptr;
  DestructureTupleInst *found = nullptr;
  for (auto *operand : value->getUses()) {
    auto *candidate = dyn_cast<DestructureTupleInst>(operand->getUser());
    if (!candidate)
      continue;
    assert(!found &&
           "There should only be one `destructure_tuple` user of a tuple");
    found = candidate;
  }
  return found;
}

/// Invokes `resultCallback` on each direct result of `ai`.
///
/// - A non-tuple-typed `apply` is itself the single direct result.
/// - For a tuple-typed `apply`, the callback runs on each result of its
///   single `destructure_tuple` user. If there is no such user, the callback
///   is never invoked.
static void forEachApplyDirectResult(
    ApplyInst *ai, llvm::function_ref<void(SILValue)> resultCallback) {
  if (ai->getType().is<TupleType>()) {
    auto *destructure = getSingleDestructureTupleUser(ai);
    if (!destructure)
      return;
    for (SILValue element : destructure->getResults())
      resultCallback(element);
    return;
  }
  resultCallback(ai);
}

/// Appends all formal results of `function` (direct and indirect) to
/// `results`, in the order given by the function's result types.
///
/// "Formal results" are result values inside the function body, not at call
/// sites. Indirect results come from the function's indirect result
/// arguments. Direct results come from the operand of the `return`
/// instruction: its elements if it is produced by a `tuple` instruction,
/// otherwise the operand itself.
static void
collectAllFormalResultsInTypeOrder(SILFunction &function,
                                   SmallVectorImpl<SILValue> &results) {
  SILFunctionConventions conventions(function.getLoweredFunctionType(),
                                     function.getModule());
  auto indirectResults = function.getIndirectResults();
  auto *returnInst =
      cast<ReturnInst>(function.findReturnBB()->getTerminator());
  SILValue returnValue = returnInst->getOperand();

  SmallVector<SILValue, 8> directResults;
  auto *returnedTuple =
      dyn_cast_or_null<TupleInst>(returnValue->getDefiningInstruction());
  if (returnedTuple) {
    auto elements = returnedTuple->getElements();
    directResults.append(elements.begin(), elements.end());
  } else {
    directResults.push_back(returnValue);
  }

  unsigned nextIndirect = 0;
  unsigned nextDirect = 0;
  for (auto &resultInfo : conventions.getResults()) {
    if (resultInfo.isFormalDirect())
      results.push_back(directResults[nextDirect++]);
    else
      results.push_back(indirectResults[nextIndirect++]);
  }
}

/// Given a function, gather all of its direct results in an order defined by
/// its result type.
///
/// "Direct results" here are result values in the body of the function (the
/// operand of its `return` instruction), not at call sites. If the returned
/// value is a `tuple` instruction, each tuple element is gathered
/// individually. Otherwise the returned value itself is the only direct
/// result.
static void
collectAllDirectResultsInTypeOrder(SILFunction &function,
                                   SmallVectorImpl<SILValue> &results) {
  auto *retInst = cast<ReturnInst>(function.findReturnBB()->getTerminator());
  auto retVal = retInst->getOperand();
  if (auto *tupleInst = dyn_cast<TupleInst>(retVal))
    results.append(tupleInst->getElements().begin(),
                   tupleInst->getElements().end());
  else
    results.push_back(retVal);
}

/// Appends the actual results of the call site `ai` (direct and indirect) to
/// `results`, in the order of the callee's substituted result types.
///
/// Direct results are consumed in order from `extractedDirectResults`.
/// Indirect results are consumed in order from the call's indirect result
/// operands.
static void collectAllActualResultsInTypeOrder(
    ApplyInst *ai, ArrayRef<SILValue> extractedDirectResults,
    SmallVectorImpl<SILValue> &results) {
  auto calleeConventions = ai->getSubstCalleeConv();
  auto indirectResults = ai->getIndirectSILResults();
  unsigned nextDirect = 0;
  unsigned nextIndirect = 0;
  for (auto &resultInfo : calleeConventions.getResults()) {
    if (resultInfo.isFormalDirect()) {
      results.push_back(extractedDirectResults[nextDirect++]);
      continue;
    }
    results.push_back(indirectResults[nextIndirect++]);
  }
}

/// Joins a range of types into a single type.
///
/// A range with exactly one element yields that element type unchanged. Any
/// other range, including an empty one, yields a tuple type of all element
/// types, with no element labels.
template <typename TypeRange>
static CanType joinElementTypes(TypeRange &&range, const ASTContext &ctx) {
  if (range.size() != 1) {
    SmallVector<TupleTypeElt, 8> tupleElts;
    for (Type eltType : range)
      tupleElts.push_back(TupleTypeElt(eltType));
    return TupleType::get(tupleElts, ctx);
  }
  return range.front();
}

/// Given a range of SIL values, retrieves the AST types of these values and
/// joins them into a single type.
///
/// A single value yields its own type. Any other number of values yields the
/// canonical tuple type of all the value types, in range order.
template <typename SILValueRange>
static CanType joinElementTypesFromValues(SILValueRange &&range,
                                          const ASTContext &ctx) {
  if (range.size() == 1)
    return range.front()->getType().getASTType();
  // The element vector starts out empty, so element types must be appended.
  // Writing through `elts.begin()` (as an output-iterator `transform` into an
  // empty vector does) writes past the vector's end and leaves it empty.
  SmallVector<TupleTypeElt, 8> elts;
  elts.reserve(range.size());
  for (SILValue val : range)
    elts.push_back(TupleTypeElt(val->getType().getASTType()));
  return TupleType::get(elts, ctx)->getCanonicalType();
}

/// Looks up the operator named `operatorName` (for example '+') among the
/// direct members of `protocol`.
///
/// Returns the first member that is a protocol requirement and a static
/// operator `FuncDecl`, or nullptr if there is none. `operatorName` must be
/// an operator name; this is asserted.
static FuncDecl *findOperatorDeclInProtocol(DeclName operatorName,
                                            ProtocolDecl *protocol) {
  assert(operatorName.isOperator());
  for (auto *candidate : protocol->lookupDirect(operatorName)) {
    if (!candidate->isProtocolRequirement())
      continue;
    if (auto *funcDecl = dyn_cast<FuncDecl>(candidate))
      if (funcDecl->isStatic() && funcDecl->isOperator())
        return funcDecl;
  }
  return nullptr;
}

/// Returns the "constrained" derivative generic signature given:
/// - An original SIL function type.
/// - A wrt parameter index subset.
/// - A possibly uncanonical derivative generic signature (optional). If it is
///   null, the original function type's generic signature is used instead.
///
/// The constrained derivative generic signature adds a requirement that the
/// type of every wrt parameter conforms to `Differentiable`.
///
/// Returns null when neither the given derivative generic signature nor the
/// original function type has a generic signature.
static GenericSignature getConstrainedDerivativeGenericSignature(
    CanSILFunctionType originalFnTy, IndexSubset *paramIndexSet,
    GenericSignature derivativeGenSig) {
  if (!derivativeGenSig)
    derivativeGenSig = originalFnTy->getGenericSignature();
  if (!derivativeGenSig)
    return nullptr;
  // Constrain all wrt parameters to `Differentiable`.
  auto &ctx = derivativeGenSig->getASTContext();
  auto *diffableProto = ctx.getProtocol(KnownProtocolKind::Differentiable);
  auto originalParams = originalFnTy->getParameters();
  SmallVector<Requirement, 4> requirements;
  for (unsigned paramIdx : paramIndexSet->getIndices()) {
    auto paramType = originalParams[paramIdx].getType();
    Requirement req(RequirementKind::Conformance, paramType,
                    diffableProto->getDeclaredType());
    requirements.push_back(req);
  }
  return evaluateOrDefault(
      ctx.evaluator,
      AbstractGenericSignatureRequest{
          derivativeGenSig.getPointer(),
          /*addedGenericParams*/ {},
          std::move(requirements)},
      nullptr);
}

/// Returns the canonical derivative generic signature for the given
/// `[differentiable]` attribute and original function.
///
/// This is the canonicalized derivative generic signature of `attr` when it
/// has one. Otherwise it is the generic signature of `original`'s lowered
/// function type.
static CanGenericSignature getDerivativeGenericSignature(
    SILDifferentiableAttr *attr, SILFunction *original) {
  auto derivativeGenSig = attr->getDerivativeGenericSignature();
  if (!derivativeGenSig)
    return original->getLoweredFunctionType()->getGenericSignature();
  return derivativeGenSig->getCanonicalSignature();
}

/// Creates a new `GenericParamList` of implicit `GenericTypeParamDecl`s
/// living in `dc`, one per generic parameter of `sig`. Each clone keeps the
/// name, depth and index of its original parameter. All source locations are
/// empty.
static GenericParamList *cloneGenericParameters(ASTContext &ctx,
                                                DeclContext *dc,
                                                CanGenericSignature sig) {
  SmallVector<GenericTypeParamDecl *, 2> clones;
  for (auto genericParam : sig->getGenericParams()) {
    auto *clone = new (ctx) GenericTypeParamDecl(
        dc, genericParam->getName(), SourceLoc(), genericParam->getDepth(),
        genericParam->getIndex());
    clone->setDeclContext(dc);
    clone->setImplicit(true);
    clones.push_back(clone);
  }
  return GenericParamList::create(ctx, SourceLoc(), clones, SourceLoc());
}

/// Given a `differentiable_function` instruction, returns the
/// `DifferentiableFunctionExpr` (the differential operator used in the AST)
/// that the instruction's location refers to. Returns nullptr if the location
/// is not such an expression.
static DifferentiableFunctionExpr *
findDifferentialOperator(DifferentiableFunctionInst *inst) {
  SILLocation loc = inst->getLoc();
  return loc.getAsASTNode<DifferentiableFunctionExpr>();
}

/// Returns the underlying instruction of type `Inst` for `value`, looking
/// through the following function conversions:
/// - the operand of `thin_to_thick_function`,
/// - the operand of `convert_function`,
/// - the callee of `partial_apply`.
///
/// Returns nullptr as soon as a value is reached that is neither an `Inst`
/// nor one of these conversions.
template<class Inst>
static Inst *peerThroughFunctionConversions(SILValue value) {
  for (SILValue current = value;;) {
    if (auto *match = dyn_cast<Inst>(current))
      return match;
    if (auto *thinToThick = dyn_cast<ThinToThickFunctionInst>(current)) {
      current = thinToThick->getOperand();
    } else if (auto *convertFn = dyn_cast<ConvertFunctionInst>(current)) {
      current = convertFn->getOperand();
    } else if (auto *partialApply = dyn_cast<PartialApplyInst>(current)) {
      current = partialApply->getCallee();
    } else {
      return nullptr;
    }
  }
}

//===----------------------------------------------------------------------===//
// Auxiliary data structures
//===----------------------------------------------------------------------===//

namespace {
class ADContext;

/// The invoker of a differentiation task. It can be some user syntax, e.g.
/// a `differentiable_function` instruction lowered from a
/// `DifferentiableFunctionExpr` expression, the differentiation pass, or
/// nothing at all. This will be used to emit informative diagnostics.
struct DifferentiationInvoker {
public:
  /// The kind of the invoker of a differentiation task.
  enum class Kind {
    // Invoked by a `differentiable_function` instruction, which may or may not
    // be linked to a Swift AST node (e.g. a `DifferentiableFunctionExpr`
    // expression).
    DifferentiableFunctionInst,

    // Invoked by the indirect application of differentiation. This case has an
    // associated original `apply` instruction and `[differentiable]` attribute.
    IndirectDifferentiation,

    // Invoked by a `[differentiable]` attribute in SIL **without** being linked
    // to a Swift AST attribute. This case has an associated `[differentiable]`
    // attribute.
    SILDifferentiableAttribute
  };

private:
  /// Discriminator selecting the active member of `value`.
  Kind kind;
  union Value {
    /// The instruction associated with the `DifferentiableFunctionInst` case.
    DifferentiableFunctionInst *diffFuncInst;
    Value(DifferentiableFunctionInst *inst) : diffFuncInst(inst) {}

    /// The parent `apply` instruction and `[differentiable]` attribute
    /// associated with the `IndirectDifferentiation` case.
    std::pair<ApplyInst *, SILDifferentiableAttr *>
        indirectDifferentiation;
    Value(ApplyInst *applyInst, SILDifferentiableAttr *attr)
        : indirectDifferentiation({applyInst, attr}) {}

    /// The `[differentiable]` attribute associated with the
    /// `SILDifferentiableAttribute` case.
    SILDifferentiableAttr *silDifferentiableAttribute;
    Value(SILDifferentiableAttr *attr) : silDifferentiableAttribute(attr) {}
  } value;

  /*implicit*/
  DifferentiationInvoker(Kind kind, Value value) : kind(kind), value(value) {}

public:
  DifferentiationInvoker(DifferentiableFunctionInst *inst)
      : kind(Kind::DifferentiableFunctionInst), value(inst) {}
  DifferentiationInvoker(ApplyInst *applyInst, SILDifferentiableAttr *attr)
      : kind(Kind::IndirectDifferentiation),
        value({applyInst, attr}) {}
  DifferentiationInvoker(SILDifferentiableAttr *attr)
      : kind(Kind::SILDifferentiableAttribute), value(attr) {}

  Kind getKind() const { return kind; }

  /// Returns the associated instruction. Requires (asserts) that the kind is
  /// `DifferentiableFunctionInst`.
  DifferentiableFunctionInst *getDifferentiableFunctionInst() const {
    assert(kind == Kind::DifferentiableFunctionInst);
    return value.diffFuncInst;
  }

  /// Returns the associated `apply` instruction and `[differentiable]`
  /// attribute. Requires (asserts) that the kind is `IndirectDifferentiation`.
  std::pair<ApplyInst *, SILDifferentiableAttr *>
  getIndirectDifferentiation() const {
    assert(kind == Kind::IndirectDifferentiation);
    return value.indirectDifferentiation;
  }

  /// Returns the associated `[differentiable]` attribute. Requires (asserts)
  /// that the kind is `SILDifferentiableAttribute`.
  SILDifferentiableAttr *getSILDifferentiableAttribute() const {
    assert(kind == Kind::SILDifferentiableAttribute);
    return value.silDifferentiableAttribute;
  }

  /// Returns a source location for diagnostics:
  /// - the `differentiable_function` instruction's location,
  /// - the location of the parent `apply` instruction, or
  /// - the location of the attribute's original function,
  /// depending on the kind.
  SourceLoc getLocation() const {
    switch (kind) {
    case Kind::DifferentiableFunctionInst:
      return getDifferentiableFunctionInst()->getLoc().getSourceLoc();
    case Kind::IndirectDifferentiation:
      return getIndirectDifferentiation().first->getLoc().getSourceLoc();
    case Kind::SILDifferentiableAttribute:
      return getSILDifferentiableAttribute()->getOriginal()
          ->getLocation().getSourceLoc();
    }
    // The switch is exhaustive. This keeps compilers from warning about
    // control reaching the end of a non-void function, and traps if `kind`
    // ever holds an out-of-range value.
    llvm_unreachable("Invalid differentiation invoker kind");
  }

  void print(llvm::raw_ostream &os) const;
};

class DifferentiableActivityInfo;

/// Information about the JVP/VJP function produced during JVP/VJP generation,
/// e.g. mappings from original values to corresponding values in the
/// pullback/differential struct.
///
/// A linear map struct is an aggregate value containing linear maps checkpointed
/// during the JVP/VJP computation. Linear map structs are generated for every
/// original function during JVP/VJP generation. Linear map struct values are
/// constructed by JVP/VJP functions and consumed by pullback/differential
/// functions.
class LinearMapInfo {
private:
  /// The linear map kind.
  AutoDiffLinearMapKind kind;

  /// The original function.
  SILFunction *const original;

  /// The derivative function.
  SILFunction *const derivative;

  /// Activity info of the original function.
  const DifferentiableActivityInfo &activityInfo;

  /// Differentiation indices of the function.
  const SILAutoDiffIndices &indices;

  /// Mapping from original basic blocks to linear map structs.
  DenseMap<SILBasicBlock *, StructDecl *> linearMapStructs;

  /// Mapping from original basic blocks to branching trace enums.
  /// For pullbacks: these are predecessor enums.
  /// For differentials: these are successor enums.
  DenseMap<SILBasicBlock *, EnumDecl *> branchingTraceDecls;

  /// Mapping from `apply` and `struct_extract` instructions in the original
  /// function to the corresponding linear map declaration in the linear map
  /// struct.
  DenseMap<SILInstruction *, VarDecl *> linearMapValueMap;

  /// Mapping from predecessor+successor basic block pairs in original function
  /// to the corresponding branching trace enum case.
  DenseMap<std::pair<SILBasicBlock *, SILBasicBlock *>, EnumElementDecl *>
      branchingTraceEnumCases;

  /// Mapping from linear map structs to their branching trace enum fields.
  DenseMap<StructDecl *, VarDecl *> linearMapStructEnumFields;

  /// A type converter, used to compute struct/enum SIL types.
  Lowering::TypeConverter &typeConverter;

private:
  /// Remaps the given type into the derivative function's context.
  SILType remapTypeInDerivative(SILType ty) {
    if (ty.hasArchetype())
      return derivative->mapTypeIntoContext(ty.mapTypeOutOfContext());
    return derivative->mapTypeIntoContext(ty);
  }

  /// Adds a `VarDecl` member with the given name and type to the given nominal
  /// declaration. The member gets the nominal's effective access level. If the
  /// type contains archetypes, it is mapped out of context before being used
  /// as the member's interface type.
  VarDecl *addVarDecl(NominalTypeDecl *nominal, StringRef name, Type type) {
    auto &astCtx = nominal->getASTContext();
    auto id = astCtx.getIdentifier(name);
    auto *varDecl = new (astCtx) VarDecl(
        /*IsStatic*/ false, VarDecl::Introducer::Var, /*IsCaptureList*/ false,
        SourceLoc(), id, nominal);
    varDecl->setAccess(nominal->getEffectiveAccess());
    if (type->hasArchetype())
      varDecl->setInterfaceType(type->mapTypeOutOfContext());
    else
      varDecl->setInterfaceType(type);
    nominal->addMember(varDecl);
    return varDecl;
  }

  /// Retrieves the file unit that contains implicit declarations in the
  /// current Swift module. If it does not exist, create one.
  ///
  // FIXME: Currently it defaults to the file containing `original`, if it can
  // be determined. Otherwise, it defaults to any file unit in the module. To
  // handle this more properly, we could revive the DerivedFileUnit class to
  // contain all synthesized implicit type declarations.
  SourceFile &getDeclarationFileUnit() {
    if (original->hasLocation())
      if (auto *declContext = original->getLocation().getAsDeclContext())
        if (auto *parentSourceFile = declContext->getParentSourceFile())
          return *parentSourceFile;
    for (auto *file : original->getModule().getSwiftModule()->getFiles())
      if (auto *src = dyn_cast<SourceFile>(file))
        return *src;
    llvm_unreachable("No files?");
  }

  /// Compute and set the access level for the given nominal type, given the
  /// original function linkage.
  void computeAccessLevel(
      NominalTypeDecl *nominal, SILLinkage originalLinkage) {
    auto &astCtx = nominal->getASTContext();
    switch (originalLinkage) {
    case swift::SILLinkage::Public:
    case swift::SILLinkage::PublicNonABI:
      nominal->setAccess(AccessLevel::Internal);
      nominal->getAttrs().add(
          new (astCtx) UsableFromInlineAttr(/*Implicit*/ true));
      break;
    case swift::SILLinkage::Hidden:
    case swift::SILLinkage::Shared:
      nominal->setAccess(AccessLevel::Internal);
      break;
    case swift::SILLinkage::Private:
      nominal->setAccess(AccessLevel::FilePrivate);
      break;
    default:
      // When the original function has external linkage, we create an internal
      // struct for use by our own module. This is necessary for cross-cell
      // differentiation in Jupyter.
      // TODO: Add a test in the compiler that exercises a similar situation as
      // cross-cell differentiation in Jupyter.
      nominal->setAccess(AccessLevel::Internal);
    }
  }

  /// Creates an enum declaration with the given JVP/VJP generic signature,
  /// whose cases represent the predecessors/successors of the given original
  /// block.
  EnumDecl *createBranchingTraceDecl(SILBasicBlock *originalBB,
                                     SILAutoDiffIndices indices,
                                     CanGenericSignature genericSig,
                                     SILLoopInfo *loopInfo) {
    assert(originalBB->getParent() == original);
    auto &astCtx = original->getASTContext();
    auto *moduleDecl = original->getModule().getSwiftModule();
    auto &file = getDeclarationFileUnit();
    // Create a branching trace enum.
    std::string enumName;
    switch (kind) {
    case AutoDiffLinearMapKind::Differential:
      enumName =
          "_AD__" + original->getName().str() +
          "_bb" + std::to_string(originalBB->getDebugID()) +
          "__Succ__" + indices.mangle();
      break;
    case AutoDiffLinearMapKind::Pullback:
      enumName =
          "_AD__" + original->getName().str() +
          "_bb" + std::to_string(originalBB->getDebugID()) +
          "__Pred__" + indices.mangle();
      break;
    }
    auto enumId = astCtx.getIdentifier(enumName);
    auto loc = original->getLocation().getSourceLoc();
    GenericParamList *genericParams = nullptr;
    if (genericSig)
      genericParams = cloneGenericParameters(astCtx, &file, genericSig);
    auto *branchingTraceDecl = new (astCtx) EnumDecl(
        /*EnumLoc*/ SourceLoc(), /*Name*/ enumId, /*NameLoc*/ loc,
        /*Inherited*/ {}, /*GenericParams*/ genericParams, /*DC*/ &file);
    // Note: must mark enum as implicit to satisfy assertion in
    // `Parser::parseDeclListDelayed`.
    branchingTraceDecl->setImplicit();
    if (genericSig)
      branchingTraceDecl->setGenericSignature(genericSig);
    computeAccessLevel(branchingTraceDecl,
                       original->getEffectiveSymbolLinkage());
    branchingTraceDecl->computeType();
    assert(branchingTraceDecl->hasInterfaceType());
    file.addVisibleDecl(branchingTraceDecl);
    // Add basic block enum cases.
    for (auto *predBB : originalBB->getPredecessorBlocks()) {
      auto bbId = "bb" + std::to_string(predBB->getDebugID());
      auto *linearMapStruct = getLinearMapStruct(predBB);
      assert(linearMapStruct);
      auto linearMapStructTy =
          linearMapStruct->getDeclaredInterfaceType()->getCanonicalType();
      // Create dummy declaration representing enum case parameter.
      auto *decl = new (astCtx)
          ParamDecl(ParamDecl::Specifier::Default, loc, loc, Identifier(), loc,
                    Identifier(), moduleDecl);
      if (linearMapStructTy->hasArchetype())
        decl->setInterfaceType(linearMapStructTy->mapTypeOutOfContext());
      else
        decl->setInterfaceType(linearMapStructTy);
      // Create enum element and enum case declarations.
      auto *paramList = ParameterList::create(astCtx, {decl});
      auto *enumEltDecl = new (astCtx) EnumElementDecl(
          /*IdentifierLoc*/ loc, DeclName(astCtx.getIdentifier(bbId)),
          paramList, loc, /*RawValueExpr*/ nullptr, branchingTraceDecl);
      enumEltDecl->setImplicit();
      enumEltDecl->computeType();
      auto *enumCaseDecl = EnumCaseDecl::create(
          /*CaseLoc*/ loc, {enumEltDecl}, branchingTraceDecl);
      enumCaseDecl->setImplicit();
      branchingTraceDecl->addMember(enumEltDecl);
      branchingTraceDecl->addMember(enumCaseDecl);
      // Record enum element declaration.
      branchingTraceEnumCases.insert({{predBB, originalBB}, enumEltDecl});
    }
    // If original block is in a loop, mark branching trace enum as indirect.
    if (loopInfo->getLoopFor(originalBB))
      branchingTraceDecl->getAttrs().add(
          new (astCtx) IndirectAttr(/*Implicit*/ true));
    return branchingTraceDecl;
  }

  /// Creates a struct declaration with the given JVP/VJP generic signature, for
  /// storing the linear map values and predecessor/successor basic block of the
  /// given original block.
  StructDecl *
  createLinearMapStruct(SILBasicBlock *originalBB, SILAutoDiffIndices indices,
                        CanGenericSignature genericSig) {
    // `originalBB` must belong to the member `original`. The member is used
    // directly below rather than being shadowed by a local of the same name.
    assert(originalBB->getParent() == original);
    auto &astCtx = original->getASTContext();
    auto &file = getDeclarationFileUnit();
    std::string structName;
    switch (kind) {
    case swift::AutoDiffLinearMapKind::Differential:
      structName =
          "_AD__" + original->getName().str() +
          "_bb" + std::to_string(originalBB->getDebugID()) +
          "__DF__" + indices.mangle();
      break;
    case swift::AutoDiffLinearMapKind::Pullback:
      structName =
          "_AD__" + original->getName().str() +
          "_bb" + std::to_string(originalBB->getDebugID()) +
          "__PB__" + indices.mangle();
      break;
    }
    auto structId = astCtx.getIdentifier(structName);
    GenericParamList *genericParams = nullptr;
    if (genericSig)
      genericParams = cloneGenericParameters(astCtx, &file, genericSig);
    auto *linearMapStruct = new (astCtx) StructDecl(
        /*StructLoc*/ SourceLoc(), /*Name*/ structId, /*NameLoc*/ SourceLoc(),
        /*Inherited*/ {}, /*GenericParams*/ genericParams, /*DC*/ &file);
    // Note: must mark struct as implicit to satisfy assertion in
    // `Parser::parseDeclListDelayed`.
    linearMapStruct->setImplicit();
    if (genericSig)
      linearMapStruct->setGenericSignature(genericSig);
    computeAccessLevel(
        linearMapStruct, original->getEffectiveSymbolLinkage());
    linearMapStruct->computeType();
    assert(linearMapStruct->hasInterfaceType());
    file.addVisibleDecl(linearMapStruct);
    return linearMapStruct;
  }

  /// Add a linear map to the linear map struct.
  VarDecl *addLinearMapDecl(SILInstruction *inst, SILType linearMapType) {
    // IRGen requires decls to have AST types (not `SILFunctionType`), so we
    // convert the `SILFunctionType` of the linear map to a `FunctionType` with
    // the same parameters and results.
    auto silFnTy = linearMapType.castTo<SILFunctionType>();
    SmallVector<AnyFunctionType::Param, 8> params;
    for (auto &param : silFnTy->getParameters())
      params.push_back(AnyFunctionType::Param(param.getType()));
    AnyFunctionType *astFnTy;
    if (auto genSig = silFnTy->getGenericSignature())
      astFnTy = GenericFunctionType::get(
          genSig, params, silFnTy->getAllResultsType().getASTType());
    else
      astFnTy = FunctionType::get(
          params, silFnTy->getAllResultsType().getASTType());

    auto *origBB = inst->getParent();
    auto *linMapStruct = getLinearMapStruct(origBB);
    assert(linMapStruct && "No linear map struct for the instruction's block");
    std::string linearMapName;
    switch (kind) {
    case AutoDiffLinearMapKind::Differential:
      linearMapName = "differential_" + llvm::itostr(linearMapValueMap.size());
      break;
    case AutoDiffLinearMapKind::Pullback:
      linearMapName = "pullback_" + llvm::itostr(linearMapValueMap.size());
      break;
    }
    auto *linearMapDecl = addVarDecl(linMapStruct, linearMapName, astFnTy);
    linearMapValueMap.insert({inst, linearMapDecl});
    return linearMapDecl;
  }

  /// Given an `apply` instruction, conditionally adds its linear map function
  /// to the linear map struct if it is active.
  void addLinearMapToStruct(ADContext &context, ApplyInst *ai,
                            const SILAutoDiffIndices &indices);

  /// Generate linear map struct and branching enum declarations for the given
  /// function. Linear map structs are populated with linear map fields and a
  /// branching enum field.
  void generateDifferentiationDataStructures(
      ADContext &context, const SILAutoDiffIndices &indices,
      SILFunction *derivative);

public:
  bool shouldDifferentiateApplyInst(ApplyInst *ai);
  bool shouldDifferentiateInstruction(SILInstruction *inst);

  LinearMapInfo(const LinearMapInfo &) = delete;
  LinearMapInfo &operator=(const LinearMapInfo &) = delete;

  explicit LinearMapInfo(ADContext &context,
                         AutoDiffLinearMapKind kind,
                         SILFunction *original, SILFunction *derivative,
                         const SILAutoDiffIndices &indices,
                         const DifferentiableActivityInfo &activityInfo);

  /// Returns the linear map struct associated with the given original block,
  /// or nullptr if there is none.
  StructDecl *getLinearMapStruct(SILBasicBlock *origBB) const {
    return linearMapStructs.lookup(origBB);
  }

  /// Returns the lowered SIL type of the linear map struct associated with the
  /// given original block. The block must have a linear map struct.
  SILType getLinearMapStructLoweredType(SILBasicBlock *origBB) const {
    auto *linMapStruct = getLinearMapStruct(origBB);
    assert(linMapStruct && "No linear map struct for the given block");
    auto linMapStructType =
        linMapStruct->getDeclaredInterfaceType()->getCanonicalType();
    return typeConverter.getLoweredType(linMapStructType,
                                        ResilienceExpansion::Minimal);
  }

  /// Returns the branching trace enum associated with the given original block,
  /// or nullptr if there is none.
  EnumDecl *getBranchingTraceDecl(SILBasicBlock *origBB) const {
    return branchingTraceDecls.lookup(origBB);
  }

  /// Returns the lowered SIL type of the branching trace enum associated with
  /// the given original block. The block must have a branching trace enum.
  SILType getBranchingTraceEnumLoweredType(SILBasicBlock *origBB) const {
    auto *traceDecl = getBranchingTraceDecl(origBB);
    assert(traceDecl && "No branching trace enum for the given block");
    auto traceDeclType =
        traceDecl->getDeclaredInterfaceType()->getCanonicalType();
    return typeConverter.getLoweredType(traceDeclType,
                                        ResilienceExpansion::Minimal);
  }

  /// Returns the enum element in the given successor block's branching trace
  /// enum corresponding to the given predecessor block, or nullptr if none was
  /// recorded.
  EnumElementDecl *
  lookUpBranchingTraceEnumElement(SILBasicBlock *origPredBB,
                                  SILBasicBlock *origSuccBB) const {
    assert(origPredBB->getParent() == original);
    return branchingTraceEnumCases.lookup({origPredBB, origSuccBB});
  }

  /// Returns the mapping from linear map structs to their branching trace enum
  /// fields.
  DenseMap<StructDecl *, VarDecl *> &getLinearMapStructEnumFields() {
    return linearMapStructEnumFields;
  }

  /// Returns the branching trace enum field for the linear map struct of the
  /// given original block.
  VarDecl *lookUpLinearMapStructEnumField(SILBasicBlock *origBB) {
    auto *linearMapStruct = getLinearMapStruct(origBB);
    return linearMapStructEnumFields.lookup(linearMapStruct);
  }

  /// Finds the linear map declaration in the linear map (pullback or
  /// differential) struct for an `apply` or `struct_extract` in the original
  /// function. Asserts that such a declaration exists.
  VarDecl *lookUpLinearMapDecl(SILInstruction *inst) {
    auto lookup = linearMapValueMap.find(inst);
    assert(lookup != linearMapValueMap.end() &&
           "No linear map declaration corresponding to the given instruction");
    return lookup->getSecond();
  }
};

/// Per-`apply` bookkeeping recorded while generating a VJP.
struct NestedApplyInfo {
  /// Indices (parameters and source result) with respect to which this
  /// `apply` instruction is differentiated.
  SILAutoDiffIndices indices;

  /// Pullback type as it was before reabstraction; holds `None` when the
  /// pullback type was left unreabstracted.
  Optional<CanSILFunctionType> originalPullbackType;
};

/// Writes a textual description of `invoker` to `os` by delegating to
/// `DifferentiationInvoker::print`, returning the stream for chaining.
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
                                            DifferentiationInvoker invoker) {
  llvm::raw_ostream &stream = os;
  invoker.print(stream);
  return stream;
}

/// Prints a single-line, parenthesized description of this invoker, e.g.
/// `(differentiation_invoker differentiable_function_inst=(...))`.
/// Every case leaves the parentheses balanced.
void DifferentiationInvoker::print(llvm::raw_ostream &os) const {
  os << "(differentiation_invoker ";
  switch (kind) {
  case Kind::DifferentiableFunctionInst:
    os << "differentiable_function_inst=(" << *getDifferentiableFunctionInst()
       << ")";
    break;
  case Kind::IndirectDifferentiation: {
    auto indDiff = getIndirectDifferentiation();
    os << "indirect_differentiation=(" << *std::get<0>(indDiff) << ')';
    // TODO: Enable printing parent invokers.
    // May require storing a `DifferentiableInvoker *` in the
    // `IndirectDifferentiation` case.
    /*
    SILInstruction *inst;
    SILDifferentiableAttr *attr;
    std::tie(inst, attr) = getIndirectDifferentiation();
    auto invokerLookup = invokers.find(attr); // No access to ADContext?
    assert(invokerLookup != invokers.end() && "Expected parent invoker");
    */
    break;
  }
  case Kind::SILDifferentiableAttribute: {
    auto diffAttr = getSILDifferentiableAttribute();
    os << "sil_differentiable_attribute=(attr=(";
    diffAttr->print(os);
    // The `)` right after the attribute closes `attr=(`; the trailing `)`
    // closes `sil_differentiable_attribute=(`, which was previously left open.
    os << ") function=" << diffAttr->getOriginal()->getName() << ')';
    break;
  }
  }
  os << ')';
}

//===----------------------------------------------------------------------===//
// ADContext - Per-module contextual information for the Differentiation pass.
//===----------------------------------------------------------------------===//

class ADContext {
private:
  /// Reference to the main transform.
  SILModuleTransform &transform;

  /// The module where Differentiation is performed on.
  SILModule &module;

  /// AST context.
  ASTContext &astCtx = module.getASTContext();

  /// Shared pass manager.
  SILPassManager &passManager;

  /// The worklist (stack) of `differentiable_function` instructions to be
  /// processed.
  SmallVector<DifferentiableFunctionInst *, 32> differentiableFunctionInsts;

  /// The set of `differentiable_function` instructions that have been
  /// processed. Used to avoid reprocessing invalidated instructions.
  SmallPtrSet<DifferentiableFunctionInst *, 32>
      processedDifferentiableFunctionInsts;

  /// Mapping from `[differentiable]` attributes to invokers.
  /// `SmallMapVector` is used for deterministic insertion order iteration.
  SmallMapVector<SILDifferentiableAttr *, DifferentiationInvoker, 32>
      invokers;

  /// Mapping from `differentiable_function` instructions to result indices.
  DenseMap<DifferentiableFunctionInst *, unsigned> resultIndices;

  /// Mapping from original `apply` instructions to their corresponding
  /// `NestedApplyInfo`s.
  DenseMap<ApplyInst *, NestedApplyInfo> nestedApplyInfo;

  /// List of generated functions (JVPs, VJPs, pullbacks, and thunks).
  /// Saved for deletion during cleanup.
  SmallVector<SILFunction *, 32> generatedFunctions;

  /// List of references to generated functions.
  /// Saved for deletion during cleanup.
  SmallVector<SILValue, 32> generatedFunctionReferences;

  /// The AdditiveArithmetic protocol in the standard library.
  ProtocolDecl *additiveArithmeticProtocol =
      astCtx.getProtocol(KnownProtocolKind::AdditiveArithmetic);

  /// `AdditiveArithmetic.+` declaration.
  mutable FuncDecl *cachedPlusFn = nullptr;
  /// `AdditiveArithmetic.+=` declaration.
  mutable FuncDecl *cachedPlusEqualFn = nullptr;

public:
  /// Construct an ADContext for the given module.
  explicit ADContext(SILModuleTransform &transform);

  //--------------------------------------------------------------------------//
  // General utilities
  //--------------------------------------------------------------------------//

  SILModuleTransform &getTransform() const { return transform; }
  SILModule &getModule() const { return module; }
  ASTContext &getASTContext() const { return module.getASTContext(); }
  SILPassManager &getPassManager() const { return passManager; }
  Lowering::TypeConverter &getTypeConverter() { return module.Types; }

  SmallVectorImpl<DifferentiableFunctionInst *> &
  getDifferentiableFunctionInsts() {
    return differentiableFunctionInsts;
  }

  SmallPtrSetImpl<DifferentiableFunctionInst *> &
  getProcessedDifferentiableFunctionInsts() {
    return processedDifferentiableFunctionInsts;
  }

  llvm::SmallMapVector<SILDifferentiableAttr *, DifferentiationInvoker, 32> &
  getInvokers() {
    return invokers;
  }

  DenseMap<DifferentiableFunctionInst *, unsigned> &getResultIndices() {
    return resultIndices;
  }

  DenseMap<ApplyInst *, NestedApplyInfo> &getNestedApplyInfo() {
    return nestedApplyInfo;
  }

  SmallVector<SILFunction *, 32> &getGeneratedFunctions() {
    return generatedFunctions;
  }

  SmallVector<SILValue, 32> &getGeneratedFunctionReferences() {
    return generatedFunctionReferences;
  }

  ProtocolDecl *getAdditiveArithmeticProtocol() const {
    return additiveArithmeticProtocol;
  }

  FuncDecl *getPlusDecl() const {
    if (!cachedPlusFn) {
      cachedPlusFn = findOperatorDeclInProtocol(
          astCtx.getIdentifier("+"), additiveArithmeticProtocol);
      assert(cachedPlusFn && "AdditiveArithmetic.+ not found");
    }
    return cachedPlusFn;
  }

  FuncDecl *getPlusEqualDecl() const {
    if (!cachedPlusEqualFn) {
      cachedPlusEqualFn = findOperatorDeclInProtocol(
          astCtx.getIdentifier("+="), additiveArithmeticProtocol);
      assert(cachedPlusEqualFn && "AdditiveArithmetic.+= not found");
    }
    return cachedPlusEqualFn;
  }

  void cleanUp() {
    for (auto invokerPair : invokers) {
      auto *attr = std::get<0>(invokerPair);
      auto *original = attr->getOriginal();
      LLVM_DEBUG(getADDebugStream()
                 << "Removing [differentiable] attribute for "
                 << original->getName() << '\n');
      original->removeDifferentiableAttr(attr);
    }
    // Delete all references to generated functions.
    for (auto fnRef : generatedFunctionReferences) {
      if (auto *fnRefInst =
              peerThroughFunctionConversions<FunctionRefInst>(fnRef)) {
        fnRefInst->replaceAllUsesWithUndef();
        fnRefInst->eraseFromParent();
      }
    }
    // Delete all generated functions.
    for (auto *generatedFunction : generatedFunctions) {
      LLVM_DEBUG(getADDebugStream()
                 << "Deleting generated function "
                 << generatedFunction->getName() << '\n');
      generatedFunction->dropAllReferences();
      transform.notifyWillDeleteFunction(generatedFunction);
      module.eraseFunction(generatedFunction);
    }
  }

  //--------------------------------------------------------------------------//
  // `[differentiable]` attribute lookup and registration
  //--------------------------------------------------------------------------//

  /// Finds the `[differentiable]` attribute on the specified original function
  /// with the exact specified parameter indices. Returns nullptr if no such
  /// attribute exists.
  SILDifferentiableAttr *lookUpDifferentiableAttr(
      SILFunction *original, const SILAutoDiffIndices &indices) const {
    for (auto *attr : original->getDifferentiableAttrs())
      if (attr->getIndices() == indices)
        return attr;
    return nullptr;
  }

  /// Finds the `[differentiable]` attribute on the specified original function
  /// whose parameter indices are a minimal superset of the specified parameter
  /// indices. Returns nullptr if no such attribute exists.
  SILDifferentiableAttr *lookUpMinimalDifferentiableAttr(
      SILFunction *original, const SILAutoDiffIndices &indices) const {
    auto *minimalIndexSet = IndexSubset::getDefault(
        getASTContext(),
        original->getLoweredFunctionType()->getNumParameters(), false);
    auto *indexSet = indices.parameters;
    if (auto *exactAttr = lookUpDifferentiableAttr(original, indices))
      return exactAttr;
    SILDifferentiableAttr *minimalAttr = nullptr;
    for (auto *da : original->getDifferentiableAttrs()) {
      if (da->getIndices().source != indices.source)
        continue;
      auto *daIndexSet = da->getIndices().parameters;
      // If all indices in `indexSet` are in `daIndexSet`, and it has fewer
      // indices than our current candidate and a primitive VJP, then `da` is
      // our new candidate.
      //
      // NOTE(TF-642): `da` may come from a un-partial-applied function and
      // have larger capacity than the desired indices. We expect this logic to
      // go away when `partial_apply` supports `@differentiable` callees.
      if (daIndexSet->isSupersetOf(indexSet->extendingCapacity(
              getASTContext(), daIndexSet->getCapacity())) &&
          // fewer parameters than before
          (minimalIndexSet->isEmpty() ||
           daIndexSet->getNumIndices() < minimalIndexSet->getNumIndices())) {
        minimalAttr = da;
        minimalIndexSet = daIndexSet;
      }
    }
    return minimalAttr;
  }

  /// Finds the `@differentiable` attribute (and its parameter indices) on the
  /// specified original function whose parameter indices are a minimal
  /// superset of the specified parameter indices. Returns nullptr if no such
  /// attribute exists.
  std::pair<const DifferentiableAttr *, IndexSubset *>
  lookUpMinimalASTDifferentiableAttrAndIndexSubset(
      SILDeclRef originalDeclRef, CanSILFunctionType originalFnType,
      const SILAutoDiffIndices &indices) {
    auto *original = originalDeclRef.getDecl();
    const DifferentiableAttr *minimalAttr = nullptr;
    auto *minimalIndexSet = IndexSubset::getDefault(
        getASTContext(), originalFnType->getNumParameters(), false);
    auto *indexSet = indices.parameters;
    for (auto *da : original->getAttrs().getAttributes<DifferentiableAttr>()) {
      auto *daParamIndices = da->getParameterIndices();
      auto *daIndexSet = autodiff::getLoweredParameterIndices(
          daParamIndices, original->getInterfaceType()->castTo<AnyFunctionType>());
      // If all indices in `indexSet` are in `daIndexSet`, and it has fewer
      // indices than our current candidate and a primitive VJP, then `da` is
      // our new candidate.
      //
      // NOTE(TF-642): `da` may come from a un-partial-applied function and
      // have larger capacity than the desired indices. We expect this logic to
      // go away when `partial_apply` supports `@differentiable` callees.
      if (daIndexSet->isSupersetOf(indexSet->extendingCapacity(getASTContext(),
              daIndexSet->getCapacity())) &&
          // fewer parameters than before
          (minimalIndexSet->isEmpty() ||
           daIndexSet->getNumIndices() < minimalIndexSet->getNumIndices())) {
        minimalAttr = da;
        minimalIndexSet = daIndexSet;
      }
    }
    return std::make_pair(minimalAttr, minimalIndexSet);
  }

  /// Creates a `[differentiable]` attribute on the specified original function
  /// with the specified parameter indices.
  SILDifferentiableAttr *createDifferentiableAttr(
      SILFunction *original, const SILAutoDiffIndices &indices,
      GenericSignature derivativeGenericSignature) const {
    assert(!lookUpDifferentiableAttr(original, indices));
    auto derivativeConstrainedGenSig = getConstrainedDerivativeGenericSignature(
        original->getLoweredFunctionType(), indices.parameters,
        derivativeGenericSignature);
    auto *attr = SILDifferentiableAttr::create(getModule(), indices,
                                               /*jvpName*/ StringRef(),
                                               /*vjpName*/ StringRef(),
                                               derivativeConstrainedGenSig);
    original->addDifferentiableAttr(attr);
    return attr;
  }

  /// Finds or creates a `[differentiable]` attribute on the specified
  /// original function corresponding to the specified parameter indices.
  SILDifferentiableAttr *getOrCreateDifferentiableAttr(
      SILFunction *original, const SILAutoDiffIndices &indices,
      GenericSignature derivativeGenericSignature) {
    if (auto *attr = lookUpDifferentiableAttr(original, indices))
      return attr;
    assert(original->isDefinition());
    return createDifferentiableAttr(original, indices,
                                    derivativeGenericSignature);
  }

  /// Creates an `differentiable_function` instruction using the given builder
  /// and arguments. Erase the newly created instruction from the processed set,
  /// if it exists - it may exist in the processed set if it has the same
  /// pointer value as a previously processed and deleted instruction.
  DifferentiableFunctionInst *createDifferentiableFunction(
      SILBuilder &builder, SILLocation loc,
      IndexSubset *parameterIndices, SILValue original,
      Optional<std::pair<SILValue, SILValue>> derivativeFunctions = None) {
    auto *dfi = builder.createDifferentiableFunction(
        loc, parameterIndices, original, derivativeFunctions);
    processedDifferentiableFunctionInsts.erase(dfi);
    return dfi;
  }

private:
  /// Promotes the given `differentiable_function` instruction to a valid
  /// `@differentiable` function-typed value.
  SILValue promoteToDifferentiableFunction(
      DifferentiableFunctionInst *inst, SILBuilder &builder, SILLocation loc,
      DifferentiationInvoker invoker);

public:
  /// Process the given `[differentiable]` attribute, filling in JVP/VJPs if
  /// missing.
  bool processDifferentiableAttribute(
      SILFunction *original, SILDifferentiableAttr *attr,
      DifferentiationInvoker invoker);

  /// Process the given `differentiable_function` instruction, filling in
  /// missing derivative functions if necessary.
  bool processDifferentiableFunctionInst(DifferentiableFunctionInst *dfi);

  /// Fold `differentiable_function_extract` users of the given
  /// `differentiable_function` instruction, directly replacing them with
  /// `differentiable_function` instruction operands. If the
  /// `differentiable_function` instruction has no remaining uses, delete the
  /// instruction itself after folding.
  ///
  /// Folding can be disabled by the
  /// `SkipFoldingDifferentiableFunctionExtraction` flag for SIL testing
  /// purposes.
  void foldDifferentiableFunctionExtraction(DifferentiableFunctionInst *source);

  /// Get or create a derivative function parameter index subset thunk from
  /// `actualIndices` to `desiredIndices` for the given associated function
  /// value and original function operand. Returns a pair of the parameter
  /// index subset thunk and its interface substitution map (used to partially
  /// apply the thunk).
  /// Calls `getOrCreateSubsetParametersThunkForLinearMap` to thunk the linear
  /// map returned by the derivative function.
  std::pair<SILFunction *, SubstitutionMap>
  getOrCreateSubsetParametersThunkForDerivativeFunction(
      SILValue origFnOperand, SILValue derivativeFn,
      AutoDiffDerivativeFunctionKind kind, SILAutoDiffIndices desiredIndices,
      SILAutoDiffIndices actualIndices);

  /// Get or create a derivative function parameter index subset thunk from
  /// `actualIndices` to `desiredIndices` for the given associated function
  /// value and original function operand. Returns a pair of the parameter
  /// index subset thunk and its interface substitution map (used to partially
  /// apply the thunk).
  std::pair<SILFunction *, SubstitutionMap>
  getOrCreateSubsetParametersThunkForLinearMap(
      SILFunction *assocFn, CanSILFunctionType linearMapType,
      CanSILFunctionType targetType, AutoDiffDerivativeFunctionKind kind,
      SILAutoDiffIndices desiredIndices, SILAutoDiffIndices actualIndices);

public:
  /// Declare an external reference to a derivative function of `original`,
  /// given a `[differentiable]` attribute of `original` and the associated
  /// function kind.
  SILFunction *
  declareExternalDerivativeFunction(SILFunction *original,
                                    SILDifferentiableAttr *attr, StringRef name,
                                    AutoDiffDerivativeFunctionKind kind);

  template <typename ...T, typename ...U>
  InFlightDiagnostic diagnose(SourceLoc loc, Diag<T...> diag,
                              U &&...args) const {
    return getASTContext().Diags.diagnose(loc, diag, std::forward<U>(args)...);
  }

  /// Given an instruction and a differentiation task associated with the
  /// parent function, emits a "not differentiable" error based on the task. If
  /// the task is indirect, emits notes all the way up to the outermost task,
  /// and emits an error at the outer task. Otherwise, emits an error directly.
  template<typename ...T, typename ...U>
  InFlightDiagnostic emitNondifferentiabilityError(
      SILInstruction *inst, DifferentiationInvoker invoker,
      Diag<T...> diag, U &&...args);

  /// Given a value and a differentiation task associated with the parent
  /// function, emits a "not differentiable" error based on the task. If the
  /// task is indirect, emits notes all the way up to the outermost task, and
  /// emits an error at the outer task. Otherwise, emits an error directly.
  template<typename ...T, typename ...U>
  InFlightDiagnostic emitNondifferentiabilityError(
      SILValue value, DifferentiationInvoker invoker,
      Diag<T...> diag, U &&...args);

  /// Emit a "not differentiable" error based on the given differentiation task
  /// and diagnostic.
  template<typename ...T, typename ...U>
  InFlightDiagnostic emitNondifferentiabilityError(
      SourceLoc loc, DifferentiationInvoker invoker,
      Diag<T...> diag, U &&...args);
};
} // end anonymous namespace

/// Binds the context to `transform`, caching the module it runs on and the
/// pass manager that owns it. Remaining members use their in-class
/// initializers.
ADContext::ADContext(SILModuleTransform &transform)
    : transform(transform),
      module(*transform.getModule()),
      passManager(*transform.getPassManager()) {}

/// Value-based overload: logs the value and invoker in debug builds, picks a
/// source location for `value`, and forwards to the `SourceLoc` overload.
template<typename ...T, typename ...U>
InFlightDiagnostic
ADContext::emitNondifferentiabilityError(SILValue value,
                                         DifferentiationInvoker invoker,
                                         Diag<T...> diag, U &&...args) {
  LLVM_DEBUG({
    getADDebugStream() << "Diagnosing non-differentiability.\n";
    getADDebugStream() << "For value:\n" << value;
    getADDebugStream() << "With invoker:\n" << invoker << '\n';
  });
  auto valueLoc = value.getLoc().getSourceLoc();
  // If the value does not have a valid location, use the location of its
  // parent function as a fallback. Improves diagnostics in some cases.
  if (valueLoc.isInvalid())
    valueLoc = value->getFunction()->getLocation().getSourceLoc();
  return emitNondifferentiabilityError(valueLoc, invoker, diag,
                                       std::forward<U>(args)...);
}

/// Instruction-based overload: logs the instruction and invoker in debug
/// builds, chooses the source location to report at, and forwards to the
/// `SourceLoc` overload.
template<typename ...T, typename ...U>
InFlightDiagnostic
ADContext::emitNondifferentiabilityError(SILInstruction *inst,
                                         DifferentiationInvoker invoker,
                                         Diag<T...> diag, U &&...args) {
  LLVM_DEBUG({
    getADDebugStream() << "Diagnosing non-differentiability.\n";
    getADDebugStream() << "For instruction:\n" << *inst;
    getADDebugStream() << "With invoker:\n" << invoker << '\n';
  });
  auto diagLoc = inst->getLoc().getSourceLoc();
  // Instructions lacking a usable location (such as `ref_element_addr` in
  // synthesized stored property getters) are reported at the enclosing
  // function's location instead, which reads better in diagnostics.
  if (!diagLoc.isValid())
    diagLoc = inst->getFunction()->getLocation().getSourceLoc();
  return emitNondifferentiabilityError(diagLoc, invoker, diag,
                                       std::forward<U>(args)...);
}

/// Emits the diagnostics for a non-differentiable operation located at `loc`.
/// What gets emitted depends on the kind of `invoker` that requested
/// differentiation. Every case returns the in-flight diagnostic for the last
/// diagnostic it emits at `loc`.
///
/// NOTE(review): there is no return after the switch, so this relies on every
/// `DifferentiationInvoker::Kind` being handled by one of the cases below.
template<typename ...T, typename ...U>
InFlightDiagnostic
ADContext::emitNondifferentiabilityError(SourceLoc loc,
                                         DifferentiationInvoker invoker,
                                         Diag<T...> diag, U &&...args) {
  switch (invoker.getKind()) {
  // For `differentiable_function` instructions: if the `differentiable_function`
  // instruction comes from a differential operator, emit an error on the
  // expression and a note on the non-differentiable operation. Otherwise, emit
  // both an error and note on the non-differentiable operation.
  case DifferentiationInvoker::Kind::DifferentiableFunctionInst: {
    auto *inst = invoker.getDifferentiableFunctionInst();
    if (auto *expr = findDifferentialOperator(inst)) {
      diagnose(expr->getLoc(), diag::autodiff_function_not_differentiable_error)
          .highlight(expr->getSubExpr()->getSourceRange());
      return diagnose(loc, diag, std::forward<U>(args)...);
    }
    diagnose(loc, diag::autodiff_expression_not_differentiable_error);
    return diagnose(loc, diag, std::forward<U>(args)...);
  }

  // For `[differentiable]` attributes, try to find an AST function declaration
  // and `@differentiable` attribute. If they are found, emit an error on the
  // `@differentiable` attribute and a "when differentiating this function
  // definition" note on the SIL function; otherwise, emit an error on the SIL
  // function. Then emit `diag` at the non-differentiable operation.
  case DifferentiationInvoker::Kind::SILDifferentiableAttribute: {
    auto *attr = invoker.getSILDifferentiableAttribute();
    auto *original = attr->getOriginal();
    bool foundAttr = false;
    if (auto *declContext = original->getDeclContext()) {
      if (auto *fnDecl = declContext->getInnermostDeclarationDeclContext()) {
        if (auto *diffAttr =
                fnDecl->getAttrs().getAttribute<DifferentiableAttr>()) {
          diagnose(diffAttr->getLocation(),
                   diag::autodiff_function_not_differentiable_error)
              .highlight(diffAttr->getRangeWithAt());
          diagnose(original->getLocation().getSourceLoc(),
                   diag::autodiff_when_differentiating_function_definition);
          foundAttr = true;
        }
      }
    }
    // Fallback if we cannot find the expected attribute.
    if (!foundAttr)
      diagnose(original->getLocation().getSourceLoc(),
               diag::autodiff_function_not_differentiable_error);
    return diagnose(loc, diag, std::forward<U>(args)...);
  }

  // For indirect differentiation, look up the parent invoker registered for
  // `attr` and recursively report `inst` (presumably the call site in the
  // parent function) with an "expression not differentiable" note against
  // that parent invoker. The recursion ends at a non-indirect invoker, which
  // emits the error. Finally, emit an
  // `autodiff_when_differentiating_function_call` note at `loc`.
  // NOTE(review): `diag` and `args` are not emitted in this case; confirm
  // that dropping the caller's specific diagnostic here is intended.
  case DifferentiationInvoker::Kind::IndirectDifferentiation: {
    SILInstruction *inst;
    SILDifferentiableAttr *attr;
    std::tie(inst, attr) = invoker.getIndirectDifferentiation();
    auto invokerLookup = invokers.find(attr);
    assert(invokerLookup != invokers.end() && "Expected parent invoker");
    emitNondifferentiabilityError(inst, invokerLookup->second,
        diag::autodiff_expression_not_differentiable_note);
    return diagnose(loc, diag::autodiff_when_differentiating_function_call);
  }
  }
}

//===----------------------------------------------------------------------===//
// Activity Analysis
//===----------------------------------------------------------------------===//

namespace {
class DifferentiableActivityCollection;

/// SIL function analysis that determines, per function, which values matter
/// for differentiation ("activity analysis").
///
/// Users of AD typically need derivatives of only some outputs of a program
/// `P` with respect to only some of its inputs. In any differentiation mode
/// (tangent, reverse, ...), restricting attention to those selected inputs and
/// outputs lets the AD tool ignore many intermediate differentiated variables
/// and produce much more efficient differentiated code. Every
/// transformation-based AD tool performs this analysis.
///
/// Terminology: the selected outputs are the "dependents" and the selected
/// inputs are the "independents". Variable `y` depends on `x` when the
/// derivative of `y` with respect to `x` is not trivially zero.
/// - A variable is "varied" if it depends on at least one independent.
/// - A variable is "useful" if at least one dependent depends on it.
/// - A variable is "active" if it is both varied and useful.
///
/// In tangent mode, if `v` is not varied at some program point, its
/// derivative `v̇` there is certainly zero; if `v` is not useful, the value of
/// `v̇` does not matter for the final result. Reverse mode is symmetric:
/// since differentiated variables flow upstream, a useless variable has a
/// zero derivative (the partial derivative of the output with respect to it
/// is zero), and when `v` is not varied, whatever the value of its adjoint,
/// that value does not matter for the final result.
///
/// Reference:
/// Laurent Hascoët. Automatic Differentiation by Program Transformation. 2007.
class DifferentiableActivityAnalysis
    : public FunctionAnalysisBase<DifferentiableActivityCollection> {
private:
  DominanceAnalysis *dominanceAnalysis = nullptr;
  PostDominanceAnalysis *postDominanceAnalysis = nullptr;

public:
  explicit DifferentiableActivityAnalysis()
      : FunctionAnalysisBase(SILAnalysisKind::DifferentiableActivity) {}

  /// LLVM-style RTTI hook: true iff `s` is an instance of this analysis.
  static bool classof(const SILAnalysis *s) {
    return SILAnalysisKind::DifferentiableActivity == s->getKind();
  }

  /// Requests invalidation whenever `k` shares any bit with
  /// `InvalidationKind::Everything`.
  bool shouldInvalidate(SILAnalysis::InvalidationKind k) override {
    return k & InvalidationKind::Everything;
  }

  std::unique_ptr<DifferentiableActivityCollection>
  newFunctionAnalysis(SILFunction *f) override;

  void initialize(SILPassManager *pm) override;
};
} // end anonymous namespace

namespace {
/// Bit flags describing how a SIL value participates in differentiation.
enum class ActivityFlags : unsigned {
  /// Set when the value depends on a function parameter.
  Varied = 0x2,
  /// Set when the value contributes to a result.
  Useful = 0x4,
  /// Combination of the two: the value is both varied and useful.
  Active = Varied | Useful,
};

/// A set of `ActivityFlags` describing one value's activity.
using Activity = OptionSet<ActivityFlags>;

/// Result of activity analysis on a function. Accepts queries for whether a
/// value is "varied", "useful" or "active" against certain differentiation
/// indices.
class DifferentiableActivityInfo {
private:
  DifferentiableActivityCollection &parent;

  /// The derivative generic signature.
  GenericSignature derivativeGenericSignature;

  /// Input values, i.e. parameters (both direct and indirect).
  SmallVector<SILValue, 4> inputValues;
  /// Output values, i.e. individual values (not the final tuple) being returned
  /// by the `return` instruction.
  SmallVector<SILValue, 4> outputValues;

  /// The set of useful variables, indexed by the corresponding dependent value
  /// (output) index.
  SmallVector<SmallDenseSet<SILValue>, 4> usefulValueSets;
  /// The set of varied variables, indexed by the corresponding independent
  /// value (input) index.
  SmallVector<SmallDenseSet<SILValue>, 4> variedValueSets;

  /// The original function.
  SILFunction &getFunction();

  /// The conformance lookup function. Looks up conformances in the derivative
  /// generic signature when one is set; otherwise looks them up in the Swift
  /// module of the original function.
  LookupConformanceFn getLookupConformanceFunction() {
    // Look up in derivative generic signature, if defined.
    if (derivativeGenericSignature)
      return LookUpConformanceInSignature(
          derivativeGenericSignature.getPointer());
    // Otherwise, look up in the module.
    return LookUpConformanceInModule(
        getFunction().getModule().getSwiftModule());
  }

  /// Perform analysis and populate sets.
  void analyze(DominanceInfo *di, PostDominanceInfo *pdi);

  void setVaried(SILValue value, unsigned independentVariableIndex);
  void setVariedAcrossArrayInitialization(SILValue value,
                                          unsigned independentVariableIndex);
  void setUseful(SILValue value, unsigned dependentVariableIndex);
  void setUsefulAcrossArrayInitialization(SILValue value,
                                          unsigned dependentVariableIndex);
  /// Marks the given value as "varied" and recursively propagates "varied"
  /// inwards (to operands) through projections. Skips any `@noDerivative`
  /// struct field projections.
  void propagateVariedInwardsThroughProjections(
      SILValue value, unsigned independentVariableIndex);
  void propagateUsefulThroughBuffer(SILValue value,
                                    unsigned dependentVariableIndex);

public:
  explicit DifferentiableActivityInfo(
      DifferentiableActivityCollection &parent,
      GenericSignature derivativeGenericSignature);

  bool isVaried(SILValue value, unsigned independentVariableIndex) const;
  bool isUseful(SILValue value, unsigned dependentVariableIndex) const;
  bool isVaried(SILValue value, IndexSubset *parameterIndices) const;
  bool isActive(SILValue value, const SILAutoDiffIndices &indices) const;

  Activity getActivity(SILValue value,
                       const SILAutoDiffIndices &indices) const;
  Activity getActivity(SILInstruction *inst,
                       const SILAutoDiffIndices &indices) const;
};

/// Decides whether `argument` (a parameter argument, not an indirect result)
/// is one the parent function is differentiated with respect to, i.e. whether
/// its position among the non-indirect-result arguments appears in `indices`.
/// A null `argument` is never a differentiation parameter.
static bool isDifferentiationParameter(SILArgument *argument,
                                       IndexSubset *indices) {
  if (argument == nullptr)
    return false;
  auto paramArgs =
      argument->getFunction()->getArgumentsWithoutIndirectResults();
  for (unsigned paramIdx : indices->getIndices()) {
    if (paramArgs[paramIdx] != argument)
      continue;
    return true;
  }
  return false;
}

/// For an `apply` instruction with active results, compute:
/// - The results of the `apply` instruction, in type order.
/// - The set of minimal parameter and result indices for differentiating the
///   `apply` instruction.
///
/// Parameter indices are positions in the callee's argument list (indirect
/// results excluded) whose argument is varied w.r.t.
/// `parentIndices.parameters`, or is itself one of the parent function's
/// differentiation parameters.
/// Result indices are positions in the callee's formal result list whose
/// result value is useful w.r.t. `parentIndices.source`.
///
/// A formal direct result with no corresponding SIL value (for example, a
/// tuple result that `forEachApplyDirectResult` does not destructure) is
/// recorded as a null `SILValue` in `results` and is never reported useful.
static void collectMinimalIndicesForFunctionCall(
    ApplyInst *ai, const SILAutoDiffIndices &parentIndices,
    const DifferentiableActivityInfo &activityInfo,
    SmallVectorImpl<SILValue> &results, SmallVectorImpl<unsigned> &paramIndices,
    SmallVectorImpl<unsigned> &resultIndices) {
  auto calleeFnTy = ai->getSubstCalleeType();
  auto calleeConvs = ai->getSubstCalleeConv();
  // Record all parameter indices in type order.
  unsigned currentParamIdx = 0;
  for (auto applyArg : ai->getArgumentsWithoutIndirectResults()) {
    if (activityInfo.isVaried(applyArg, parentIndices.parameters) ||
        isDifferentiationParameter(dyn_cast<SILArgument>(applyArg),
                                   parentIndices.parameters))
      paramIndices.push_back(currentParamIdx);
    ++currentParamIdx;
  }
  // Gather the SIL values backing the formal direct and indirect results.
  SmallVector<SILValue, 8> directResults;
  forEachApplyDirectResult(ai, [&](SILValue directResult) {
    directResults.push_back(directResult);
  });
  auto indirectResults = ai->getIndirectSILResults();
  // Record all results and result indices in type order.
  results.reserve(calleeFnTy->getNumResults());
  // Both cursors are zero-based positions into `directResults` and
  // `indirectResults` respectively. They are not SIL argument indices.
  unsigned dirResIdx = 0;
  unsigned indResIdx = 0;
  for (auto &resAndIdx : enumerate(calleeConvs.getResults())) {
    auto &res = resAndIdx.value();
    unsigned idx = resAndIdx.index();
    if (res.isFormalDirect()) {
      // Guard against formal direct results that have no SIL value.
      SILValue dirRes = dirResIdx < directResults.size()
                            ? directResults[dirResIdx]
                            : SILValue();
      results.push_back(dirRes);
      if (dirRes && activityInfo.isUseful(dirRes, parentIndices.source))
        resultIndices.push_back(idx);
      ++dirResIdx;
    } else {
      SILValue indRes = indirectResults[indResIdx];
      results.push_back(indRes);
      if (activityInfo.isUseful(indRes, parentIndices.source))
        resultIndices.push_back(idx);
      ++indResIdx;
    }
  }
  // Make sure the function call has active results.
  assert(results.size() == calleeFnTy->getNumResults());
  assert(llvm::any_of(results, [&](SILValue result) {
    return activityInfo.isActive(result, parentIndices);
  }));
}

/// Records the differentiation configuration. The constructor then eagerly
/// generates the linear map structs and branching trace enums for every basic
/// block of `original`.
LinearMapInfo::LinearMapInfo(ADContext &context,
                             AutoDiffLinearMapKind kind,
                             SILFunction *original, SILFunction *derivative,
                             const SILAutoDiffIndices &indices,
                             const DifferentiableActivityInfo &activityInfo)
    : kind(kind),
      original(original),
      derivative(derivative),
      activityInfo(activityInfo),
      indices(indices),
      typeConverter(context.getTypeConverter()) {
  // Data structure generation happens once, at construction time.
  generateDifferentiationDataStructures(context, indices, derivative);
}

/// Returns true if the `apply` instruction `ai` should be differentiated, given
/// the differentiation indices of its parent function. The checks run in order:
/// 1. Some `inout` argument is active: differentiate.
/// 2. No direct or indirect result is active: do not differentiate.
/// 3. The callee is the array literal initialization intrinsic
///    ("array.uninitialized_intrinsic"): differentiate. Whether an active
///    value is actually stored into the array buffer is not checked yet.
/// 4. Otherwise, differentiate iff some argument is active.
bool LinearMapInfo::shouldDifferentiateApplyInst(ApplyInst *ai) {
  auto calleeParams = ai->getSubstCalleeConv().getParameters();
  auto callArgs = ai->getArgumentsWithoutIndirectResults();
  auto isActiveValue = [&](SILValue value) {
    return activityInfo.isActive(value, indices);
  };

  // Applications with an active `inout` argument are always differentiated.
  for (auto paramIdx : swift::indices(calleeParams)) {
    if (!calleeParams[paramIdx].isIndirectInOut())
      continue;
    if (isActiveValue(callArgs[paramIdx]))
      return true;
  }

  bool anyActiveResult = false;
  forEachApplyDirectResult(ai, [&](SILValue directResult) {
    if (isActiveValue(directResult))
      anyActiveResult = true;
  });
  if (llvm::any_of(ai->getIndirectSILResults(), isActiveValue))
    anyActiveResult = true;
  if (!anyActiveResult)
    return false;

  // TODO: Pattern match to make sure there is at least one `store` to the
  // array's active buffer.
  if (isArrayLiteralIntrinsic(ai))
    return true;

  return llvm::any_of(callArgs, isActiveValue);
}

/// Returns true if `inst` should be differentiated, given the differentiation
/// indices of its parent function. The first applicable check decides:
/// 1. An `apply` defers entirely to `shouldDifferentiateApplyInst`.
/// 2. Any instruction with both an active operand and an active result.
/// 3. A `store`, `store_borrow`, `copy_addr`, or
///    `unconditional_checked_cast_addr` has no SSA result. It is treated as
///    source-in / destination-out and is differentiated iff its destination
///    is active.
/// 4. An allocation instruction with an active result.
/// 5. An instruction with an active operand that performs reference counting,
///    `end_access`, `end_borrow`, deallocation, `destroy_value`,
///    `destroy_addr`, or creates an SSA copy via `copy_value`.
/// Anything else is not differentiated.
bool LinearMapInfo::shouldDifferentiateInstruction(SILInstruction *inst) {
  if (auto *applyInst = dyn_cast<ApplyInst>(inst))
    return shouldDifferentiateApplyInst(applyInst);

  auto isActiveValue = [&](SILValue value) {
    return activityInfo.isActive(value, indices);
  };
  bool anyActiveOperand = false;
  for (auto &operand : inst->getAllOperands()) {
    if (isActiveValue(operand.get())) {
      anyActiveOperand = true;
      break;
    }
  }
  bool anyActiveResult = llvm::any_of(inst->getResults(), isActiveValue);
  if (anyActiveOperand && anyActiveResult)
    return true;

  // Store-like instructions: the destination plays the role of the output.
  if (auto *store = dyn_cast<StoreInst>(inst))
    return isActiveValue(store->getDest());
  if (auto *storeBorrow = dyn_cast<StoreBorrowInst>(inst))
    return isActiveValue(storeBorrow->getDest());
  if (auto *copyAddr = dyn_cast<CopyAddrInst>(inst))
    return isActiveValue(copyAddr->getDest());
  if (auto *castAddr = dyn_cast<UnconditionalCheckedCastAddrInst>(inst))
    return isActiveValue(castAddr->getDest());

  if (isa<AllocationInst>(inst) && anyActiveResult)
    return true;

  if (!anyActiveOperand)
    return false;
  // Reference counting, lifetime/access ending, destruction, or SSA copies of
  // an active operand.
  return isa<RefCountingInst>(inst) || isa<EndAccessInst>(inst) ||
         isa<EndBorrowInst>(inst) || isa<DeallocationInst>(inst) ||
         isa<DestroyValueInst>(inst) || isa<DestroyAddrInst>(inst) ||
         isa<CopyValueInst>(inst);
}

/// Takes an `apply` instruction and adds its linear map function to the
/// linear map struct if it is active.
///
/// No field is added in these cases:
/// - the `apply` has no active result;
/// - the `apply` has no active argument;
/// - a wrt parameter or the chosen result of the (remapped) callee type is
///   not differentiable.
/// The field's type is the last element of the derivative function's result
/// tuple.
void LinearMapInfo::addLinearMapToStruct(ADContext &context, ApplyInst *ai,
                                         const SILAutoDiffIndices &indices) {
  SmallVector<SILValue, 4> allResults;
  SmallVector<unsigned, 8> activeParamIndices;
  SmallVector<unsigned, 8> activeResultIndices;
  collectMinimalIndicesForFunctionCall(
      ai, indices, activityInfo, allResults, activeParamIndices,
      activeResultIndices);

  // Check if there are any active results or arguments. If not, skip
  // this instruction.
  auto hasActiveResults = llvm::any_of(allResults, [&](SILValue res) {
    return activityInfo.isActive(res, indices);
  });
  auto hasActiveArguments = llvm::any_of(
      ai->getArgumentsWithoutIndirectResults(), [&](SILValue arg) {
    return activityInfo.isActive(arg, indices);
  });
  if (!hasActiveResults || !hasActiveArguments)
    return;

  // Compute differentiation result index. An active result is also useful,
  // so `collectMinimalIndicesForFunctionCall` must have recorded its index.
  assert(!activeResultIndices.empty() &&
         "Active `apply` must have at least one useful result");
  auto source = activeResultIndices.front();
  // Compute differentiation parameters.
  // - If the callee has `@differentiable` function type, use differentiation
  //   parameters from the function type.
  // - Otherwise, use the active parameters.
  IndexSubset *parameters;
  auto origFnSubstTy = ai->getSubstCalleeType();
  auto remappedOrigFnSubstTy =
      remapTypeInDerivative(SILType::getPrimitiveObjectType(origFnSubstTy))
          .castTo<SILFunctionType>();
  if (remappedOrigFnSubstTy->isDifferentiable()) {
    parameters = remappedOrigFnSubstTy->getDifferentiationParameterIndices();
  } else {
    parameters = IndexSubset::get(
        original->getASTContext(),
        ai->getArgumentsWithoutIndirectResults().size(),
        activeParamIndices);
  }
  // Create autodiff indices for the `apply` instruction.
  SILAutoDiffIndices applyIndices(source, parameters);

  // Check for non-differentiable original function type.
  auto checkNondifferentiableOriginalFunctionType =
      [&](CanSILFunctionType origFnTy) {
        // Check non-differentiable arguments.
        for (unsigned paramIndex : range(origFnTy->getNumParameters())) {
          auto remappedParamType =
              origFnTy->getParameters()[paramIndex].getSILStorageType();
          if (applyIndices.isWrtParameter(paramIndex) &&
              !remappedParamType.isDifferentiable(derivative->getModule()))
            return true;
        }
        // Check non-differentiable results.
        auto remappedResultType =
            origFnTy->getResults()[applyIndices.source].getSILStorageType();
        if (!remappedResultType.isDifferentiable(derivative->getModule()))
          return true;
        return false;
      };
  if (checkNondifferentiableOriginalFunctionType(remappedOrigFnSubstTy))
    return;

  AutoDiffDerivativeFunctionKind derivativeFnKind(kind);
  auto derivativeFnType = remappedOrigFnSubstTy->getAutoDiffDerivativeFunctionType(
      parameters, source, derivativeFnKind, context.getTypeConverter(),
      LookUpConformanceInModule(derivative->getModule().getSwiftModule()));

  // The linear map is the last element of the derivative's result tuple.
  auto derivativeFnResultTypes =
      derivativeFnType->getAllResultsType().castTo<TupleType>();
  auto linearMapType =
      derivativeFnResultTypes
          ->getElement(derivativeFnResultTypes->getNumElements() - 1)
          .getType()
          ->getCanonicalType();
  auto linearMapSILType = SILType::getPrimitiveObjectType(linearMapType);
  addLinearMapDecl(ai, linearMapSILType);
}

/// Builds the per-block data structures used by derivative code generation:
/// 1. a linear map struct for every original basic block;
/// 2. a branching trace enum for every original basic block. The enum is
///    added as a field ("successor" for differentials, "predecessor" for
///    pullbacks) to each non-entry block's linear map struct;
/// 3. a linear map field for each `apply` that should be differentiated,
///    excluding applies with an active `inout` argument and array literal
///    intrinsic applications.
/// With `-debug-only=differentiation`, the generated declarations are printed.
void LinearMapInfo::generateDifferentiationDataStructures(
    ADContext &context, const SILAutoDiffIndices &indices,
    SILFunction *derivativeFn) {
  auto &astCtx = original->getASTContext();
  auto *loopInfo =
      context.getPassManager().getAnalysis<SILLoopAnalysis>()->get(original);

  // Canonical generic signature of the derivative function. It stays null
  // when the derivative function has no generic environment.
  CanGenericSignature derivativeFnGenSig = nullptr;
  if (auto *genEnv = derivativeFn->getGenericEnvironment())
    derivativeFnGenSig =
        genEnv->getGenericSignature()->getCanonicalSignature();

  // Step 1: one linear map struct per original block.
  for (auto &bb : *original) {
    auto *structDecl = createLinearMapStruct(&bb, indices, derivativeFnGenSig);
    linearMapStructs.insert({&bb, structDecl});
  }

  // Step 2: one branching trace enum per original block, stored as a field of
  // the block's linear map struct (except for the entry block).
  StringRef traceEnumFieldName;
  switch (kind) {
  case AutoDiffLinearMapKind::Differential:
    traceEnumFieldName = "successor";
    break;
  case AutoDiffLinearMapKind::Pullback:
    traceEnumFieldName = "predecessor";
    break;
  }
  for (auto &bb : *original) {
    auto *traceEnum =
        createBranchingTraceDecl(&bb, indices, derivativeFnGenSig, loopInfo);
    branchingTraceDecls.insert({&bb, traceEnum});
    if (bb.isEntry())
      continue;
    auto *structDecl = getLinearMapStruct(&bb);
    auto *traceEnumField =
        addVarDecl(structDecl,
                   astCtx.getIdentifier(traceEnumFieldName).str(),
                   traceEnum->getDeclaredInterfaceType());
    linearMapStructEnumFields.insert({structDecl, traceEnumField});
  }

  // Step 3: linear map fields for active `apply` instructions.
  auto hasActiveInoutArgument = [&](ApplyInst *ai) {
    auto paramInfos = ai->getSubstCalleeConv().getParameters();
    for (unsigned paramIdx : swift::indices(paramInfos))
      if (paramInfos[paramIdx].isIndirectInOut() &&
          activityInfo.isActive(
              ai->getArgumentsWithoutIndirectResults()[paramIdx], indices))
        return true;
    return false;
  };
  for (auto &bb : *original) {
    for (auto &inst : bb) {
      auto *ai = dyn_cast<ApplyInst>(&inst);
      if (!ai)
        continue;
      // Active `inout` arguments are not yet supported; no field is added.
      if (hasActiveInoutArgument(ai))
        continue;
      // Array literal initialization is linear and handled separately.
      if (!shouldDifferentiateApplyInst(ai) || isArrayLiteralIntrinsic(ai))
        continue;
      LLVM_DEBUG(getADDebugStream() << "Adding linear map struct field for "
                                    << *ai);
      addLinearMapToStruct(context, ai, indices);
    }
  }

  // Print generated linear map structs and branching trace enums.
  // These declarations are implicit, so they do not show up with `-emit-sil`.
  // Use `-Xllvm -debug-only=differentiation` to test them with FileCheck.
  LLVM_DEBUG({
    auto &s = getADDebugStream();
    PrintOptions printOptions;
    printOptions.TypeDefinitions = true;
    printOptions.ExplodePatternBindingDecls = true;
    printOptions.SkipImplicit = false;
    s << "Generated linear map structs and branching trace enums for @"
      << original->getName() << ":\n";
    for (auto &bb : *original) {
      getLinearMapStruct(&bb)->print(s, printOptions);
      s << '\n';
    }
    for (auto &bb : *original) {
      getBranchingTraceDecl(&bb)->print(s, printOptions);
      s << '\n';
    }
  });
}

/// Per-function cache of activity analysis results, keyed by derivative
/// generic signature. It also holds the function and the dominance and
/// post-dominance info from which each `DifferentiableActivityInfo` is
/// computed.
class DifferentiableActivityCollection {
public:
  SmallDenseMap<GenericSignature, DifferentiableActivityInfo> activityInfoMap;
  SILFunction &function;
  DominanceInfo *domInfo;
  PostDominanceInfo *postDomInfo;

  /// Returns the activity info for `assocGenSig`. On first request it runs
  /// activity analysis and caches the result. `kind` is currently unused.
  DifferentiableActivityInfo &getActivityInfo(
      GenericSignature assocGenSig, AutoDiffDerivativeFunctionKind kind) {
    auto entry = activityInfoMap.find(assocGenSig);
    if (entry == activityInfoMap.end())
      entry = activityInfoMap
                  .insert({assocGenSig,
                           DifferentiableActivityInfo(*this, assocGenSig)})
                  .first;
    return entry->getSecond();
  }

  explicit DifferentiableActivityCollection(SILFunction &f,
                                            DominanceInfo *di,
                                            PostDominanceInfo *pdi);
};

} // end anonymous namespace

/// Creates the activity collection for `f`. It receives the function's
/// dominance and post-dominance info, which activity analysis later uses to
/// order its traversals.
std::unique_ptr<DifferentiableActivityCollection>
DifferentiableActivityAnalysis::newFunctionAnalysis(SILFunction *f) {
  assert(dominanceAnalysis && "Expect a valid dominance analysis");
  assert(postDominanceAnalysis && "Expect a valid post-dominance analysis");
  return llvm::make_unique<DifferentiableActivityCollection>(
      *f, dominanceAnalysis->get(f), postDominanceAnalysis->get(f));
}

/// Caches the dominance and post-dominance analyses from the pass manager.
/// `newFunctionAnalysis` queries them for each function.
void DifferentiableActivityAnalysis::initialize(SILPassManager *pm) {
  auto *domAnalysis = pm->getAnalysis<DominanceAnalysis>();
  dominanceAnalysis = domAnalysis;
  auto *postDomAnalysis = pm->getAnalysis<PostDominanceAnalysis>();
  postDominanceAnalysis = postDomAnalysis;
}

/// Allocates a new `DifferentiableActivityAnalysis`. The module argument is
/// not used.
SILAnalysis *swift::createDifferentiableActivityAnalysis(SILModule *m) {
  auto *analysis = new DifferentiableActivityAnalysis();
  return analysis;
}

DifferentiableActivityCollection::DifferentiableActivityCollection(
    SILFunction &f, DominanceInfo *di, PostDominanceInfo *pdi)
    : function(f),
      domInfo(di),
      postDomInfo(pdi) {
  // Per-signature activity info is created lazily by `getActivityInfo`.
}

DifferentiableActivityInfo::DifferentiableActivityInfo(
    DifferentiableActivityCollection &parent, GenericSignature derivGenSig)
    : parent(parent),
      derivativeGenericSignature(derivGenSig) {
  // Activity is computed eagerly, once, using the parent's dominance info.
  auto *domInfo = parent.domInfo;
  auto *postDomInfo = parent.postDomInfo;
  analyze(domInfo, postDomInfo);
}

/// Returns the function this activity info was computed for.
SILFunction &DifferentiableActivityInfo::getFunction() {
  auto &fn = parent.function;
  return fn;
}

/// Runs activity analysis over the parent function. It fills one varied set
/// per input and one useful set per output.
///
/// - Inputs are the function's arguments, excluding indirect results. Each
///   input seeds its own varied set. Variedness is then propagated forward
///   through instructions, visiting blocks in dominance order (via `di`).
/// - Outputs are collected by `collectAllFormalResultsInTypeOrder`. Each output
///   seeds its own useful set. Usefulness is then propagated backward through
///   instructions in reverse, visiting blocks in post-dominance order (via
///   `pdi`) starting from the return block.
/// - Calls whose callee satisfies `isWithoutDerivative` propagate neither
///   variedness nor usefulness.
///
/// NOTE(review): `function.findReturnBB()` is dereferenced unconditionally,
/// so this assumes the function has a return block.
void DifferentiableActivityInfo::analyze(DominanceInfo *di,
                                         PostDominanceInfo *pdi) {
  auto &function = getFunction();
  LLVM_DEBUG(getADDebugStream()
             << "Running activity analysis on @" << function.getName() << '\n');
  // Inputs are just function's arguments, count `n`.
  auto paramArgs = function.getArgumentsWithoutIndirectResults();
  for (auto value : paramArgs)
    inputValues.push_back(value);
  LLVM_DEBUG({
    auto &s = getADDebugStream();
    s << "Inputs in @" << function.getName() << ":\n";
    for (auto val : inputValues)
      s << val << '\n';
  });
  // Outputs are indirect result buffers and return values, count `m`.
  collectAllFormalResultsInTypeOrder(function, outputValues);
  LLVM_DEBUG({
    auto &s = getADDebugStream();
    s << "Outputs in @" << function.getName() << ":\n";
    for (auto val : outputValues)
      s << val << '\n';
  });

  // Mark inputs as varied.
  assert(variedValueSets.empty());
  for (auto input : inputValues)
    variedValueSets.push_back({input});
  // Propagate varied-ness through the function in dominance order.
  // NOTE(review): each block presumably appears once in `domOrder`, so a block
  // argument that only becomes varied via a later-visited back-edge branch is
  // not re-propagated into blocks visited earlier. Confirm.
  DominanceOrder domOrder(function.getEntryBlock(), di);
  while (auto *bb = domOrder.getNext()) {
    for (auto &inst : *bb) {
      for (auto i : indices(inputValues)) {
        // Handle `apply`.
        if (auto *ai = dyn_cast<ApplyInst>(&inst)) {
          // If callee is non-varying, skip.
          // (`continue` moves on to the next input index `i`.)
          if (isWithoutDerivative(ai->getCallee()))
            continue;
          // If any argument is varied, set all direct and indirect results as
          // varied.
          for (auto arg : ai->getArgumentsWithoutIndirectResults()) {
            if (isVaried(arg, i)) {
              for (auto indRes : ai->getIndirectSILResults())
                setVaried(indRes, i);
              forEachApplyDirectResult(ai, [&](SILValue directResult) {
                setVaried(directResult, i);
              });
            }
          }
        }
        // Handle store-like instructions:
        //   `store`, `store_borrow`, `copy_addr`,
        //   `unconditional_checked_cast_addr`.
        // A varied source makes the destination (and, recursively, the
        // operands its address is projected from) varied.
#define PROPAGATE_VARIED_THROUGH_STORE(INST) \
        else if (auto *si = dyn_cast<INST##Inst>(&inst)) { \
          if (isVaried(si->getSrc(), i)) \
            propagateVariedInwardsThroughProjections(si->getDest(), i); \
        }
        PROPAGATE_VARIED_THROUGH_STORE(Store)
        PROPAGATE_VARIED_THROUGH_STORE(StoreBorrow)
        PROPAGATE_VARIED_THROUGH_STORE(CopyAddr)
        PROPAGATE_VARIED_THROUGH_STORE(UnconditionalCheckedCastAddr)
#undef PROPAGATE_VARIED_THROUGH_STORE
        // Handle `tuple_element_addr`.
        // The projection becomes varied only if its type (mapped out of
        // context into the derivative generic signature when it has
        // archetypes) has a tangent space.
        else if (auto *teai = dyn_cast<TupleElementAddrInst>(&inst)) {
          if (isVaried(teai->getOperand(), i)) {
            auto projType = teai->getType().getASTType();
            if (derivativeGenericSignature && projType->hasArchetype())
              projType = derivativeGenericSignature->getCanonicalTypeInContext(
                  projType->mapTypeOutOfContext());
            if (projType->getAutoDiffAssociatedTangentSpace(
                    getLookupConformanceFunction()))
              setVaried(teai, i);
          }
        }
        // Handle `struct_extract` and `struct_element_addr` instructions.
        // - If the field is marked `@noDerivative`, do not set the result as
        // varied because it is not in the set of differentiable variables.
        // - Otherwise, propagate variedness from operand to result as usual.
#define PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION(INST) \
        else if (auto *sei = dyn_cast<INST##Inst>(&inst)) { \
          if (isVaried(sei->getOperand(), i) && \
              !sei->getField()->getAttrs().hasAttribute<NoDerivativeAttr>()) \
            setVaried(sei, i); \
        }
        PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION(StructExtract)
        PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION(StructElementAddr)
#undef PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION
        // Handle `br`.
        // A varied branch operand makes the matching destination argument
        // varied.
        else if (auto *bi = dyn_cast<BranchInst>(&inst)) {
          for (auto &op : bi->getAllOperands())
            if (isVaried(op.get(), i))
              setVaried(bi->getArgForOperand(&op), i);
        }
        // Handle `cond_br`.
        // True/false operands map positionally onto the arguments of the
        // true/false destination blocks.
        else if (auto *cbi = dyn_cast<CondBranchInst>(&inst)) {
          for (unsigned opIdx : indices(cbi->getTrueOperands())) {
            auto &op = cbi->getTrueOperands()[opIdx];
            if (isVaried(op.get(), i))
              setVaried(cbi->getTrueBB()->getArgument(opIdx), i);
          }
          for (unsigned opIdx : indices(cbi->getFalseOperands())) {
            auto &op = cbi->getFalseOperands()[opIdx];
            if (isVaried(op.get(), i))
              setVaried(cbi->getFalseBB()->getArgument(opIdx), i);
          }
        }
        // Handle `switch_enum`.
        // A varied enum operand makes every argument of every successor
        // block varied.
        else if (auto *sei = dyn_cast<SwitchEnumInst>(&inst)) {
          if (isVaried(sei->getOperand(), i))
            for (auto *succBB : sei->getSuccessorBlocks())
              for (auto *arg : succBB->getArguments())
                setVaried(arg, i);
        }
        // Handle everything else.
        // Any varied operand makes all results varied.
        else {
          for (auto &op : inst.getAllOperands())
            if (isVaried(op.get(), i))
              for (auto result : inst.getResults())
                setVaried(result, i);
        }
      }
    }
    domOrder.pushChildren(bb);
  }

  // Mark outputs as useful.
  assert(usefulValueSets.empty());
  for (auto output : outputValues) {
    usefulValueSets.push_back({});
    // If the output has an address or class type, propagate usefulness
    // recursively.
    if (output->getType().isAddress() ||
        output->getType().isClassOrClassMetatype())
      propagateUsefulThroughBuffer(output, usefulValueSets.size() - 1);
    // Otherwise, just mark the output as useful.
    else
      setUseful(output, usefulValueSets.size() - 1);
  }
  // Propagate usefulness through the function in post-dominance order.
  PostDominanceOrder postDomOrder(&*function.findReturnBB(), pdi);
  while (auto *bb = postDomOrder.getNext()) {
    for (auto &inst : reversed(*bb)) {
      for (auto i : indices(outputValues)) {
        // Handle indirect results in `apply`.
        // If any result (or `inout` argument) is useful, every argument
        // becomes useful.
        if (auto *ai = dyn_cast<ApplyInst>(&inst)) {
          if (isWithoutDerivative(ai->getCallee()))
            continue;
          auto checkAndSetUseful = [&](SILValue res) {
            if (isUseful(res, i))
              for (auto arg : ai->getArgumentsWithoutIndirectResults())
                setUseful(arg, i);
          };
          for (auto dirRes : ai->getResults())
            checkAndSetUseful(dirRes);
          for (auto indRes : ai->getIndirectSILResults())
            checkAndSetUseful(indRes);
          auto paramInfos = ai->getSubstCalleeConv().getParameters();
          // NOTE: this `i` shadows the output index and indexes callee
          // parameters. `checkAndSetUseful` still uses the output index `i`
          // it captured where it was defined.
          for (auto i : indices(paramInfos))
            if (paramInfos[i].isIndirectInOut())
              checkAndSetUseful(ai->getArgumentsWithoutIndirectResults()[i]);
        }
        // Handle store-like instructions:
        //   `store`, `store_borrow`, `copy_addr`,
        //   `unconditional_checked_cast_addr`.
        // A useful destination makes the source useful. Address sources are
        // propagated through as buffers.
#define PROPAGATE_USEFUL_THROUGH_STORE(INST, PROPAGATE) \
        else if (auto *si = dyn_cast<INST##Inst>(&inst)) { \
          if (isUseful(si->getDest(), i)) \
            PROPAGATE(si->getSrc(), i); \
        }
        PROPAGATE_USEFUL_THROUGH_STORE(Store, setUseful)
        PROPAGATE_USEFUL_THROUGH_STORE(StoreBorrow, setUseful)
        PROPAGATE_USEFUL_THROUGH_STORE(CopyAddr, propagateUsefulThroughBuffer)
        PROPAGATE_USEFUL_THROUGH_STORE(UnconditionalCheckedCastAddr,
                                       propagateUsefulThroughBuffer)
#undef PROPAGATE_USEFUL_THROUGH_STORE
        // Handle struct element extraction, skipping `@noDerivative` fields:
        //   `struct_extract`, `struct_element_addr`.
#define PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION(INST, PROPAGATE) \
        else if (auto *sei = dyn_cast<INST##Inst>(&inst)) { \
          if (isUseful(sei, i)) { \
            auto hasNoDeriv = sei->getField()->getAttrs() \
                .hasAttribute<NoDerivativeAttr>(); \
            if (!hasNoDeriv) \
              PROPAGATE(sei->getOperand(), i); \
          } \
        }
        PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION(StructExtract, setUseful)
        PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION(StructElementAddr,
                                                   propagateUsefulThroughBuffer)
#undef PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION
        // Handle everything else.
        // Any useful result makes all operands useful. Address operands are
        // additionally propagated through as buffers.
        else if (llvm::any_of(inst.getResults(),
          [&](SILValue res) { return isUseful(res, i); })) {
          for (auto &op : inst.getAllOperands()) {
            auto value = op.get();
            if (value->getType().isAddress())
              propagateUsefulThroughBuffer(value, i);
            setUseful(value, i);
          }
        }
      }
    }
    // Propagate usefulness from basic block arguments to incoming phi values.
    for (auto i : indices(outputValues)) {
      for (auto *arg : bb->getArguments()) {
        if (isUseful(arg, i)) {
          SmallVector<SILValue, 4> incomingValues;
          arg->getSingleTerminatorOperands(incomingValues);
          for (auto incomingValue : incomingValues)
            setUseful(incomingValue, i);
        }
      }
    }
    postDomOrder.pushChildren(bb);
  }
}

/// Applies only when `value` is recognized by
/// `getAllocateUninitializedArrayIntrinsic`. In that case, the array (the
/// first result of each `destructure_tuple` user of `value`) is marked varied
/// w.r.t. the given independent variable.
void DifferentiableActivityInfo::setVariedAcrossArrayInitialization(
    SILValue value, unsigned independentVariableIndex) {
  if (!getAllocateUninitializedArrayIntrinsic(value))
    return;
  for (auto *use : value->getUses()) {
    auto *destructure = dyn_cast<DestructureTupleInst>(use->getUser());
    if (!destructure)
      continue;
    // Tuple field 0 of the intrinsic's return value is the array.
    setVaried(destructure->getResult(0), independentVariableIndex);
  }
}

/// Propagates usefulness across array literal initialization.
///
/// Array initializer syntax is lowered to an intrinsic and one or more stores
/// to a `RawPointer` returned by the intrinsic. This applies only when
/// `value` is recognized by `getAllocateUninitializedArrayIntrinsic`. Then
/// every value stored through the buffer address derived from the
/// `RawPointer` is marked useful w.r.t. the given dependent variable. This
/// covers stores directly to that address and stores through an `index_addr`
/// of it.
void DifferentiableActivityInfo::setUsefulAcrossArrayInitialization(
    SILValue value, unsigned dependentVariableIndex) {
  if (!getAllocateUninitializedArrayIntrinsic(value))
    return;
  for (auto *tupleUse : value->getUses()) {
    auto *dti = dyn_cast<DestructureTupleInst>(tupleUse->getUser());
    if (!dti)
      continue;
    // The second tuple field of the return value is the `RawPointer`.
    for (auto *pointerUse : dti->getResult(1)->getUses()) {
      // The `RawPointer` is expected to pass through a `pointer_to_address`.
      // Users that produce no result (e.g. `debug_value`) have no buffer
      // address to follow. Skip them instead of calling `getResult(0)`.
      auto *pointerUser = pointerUse->getUser();
      if (pointerUser->getResults().empty())
        continue;
      // A use of the buffer address is either a `store` whose source is
      // useful, or an `index_addr` whose `store` users have useful sources.
      for (auto *bufferUse : pointerUser->getResult(0)->getUses()) {
        auto *bufferUser = bufferUse->getUser();
        if (auto *si = dyn_cast<StoreInst>(bufferUser)) {
          setUseful(si->getSrc(), dependentVariableIndex);
        } else if (auto *iai = dyn_cast<IndexAddrInst>(bufferUser)) {
          for (auto *elementUse : iai->getUses())
            if (auto *si = dyn_cast<StoreInst>(elementUse->getUser()))
              setUseful(si->getSrc(), dependentVariableIndex);
        }
      }
    }
  }
}

/// Adds `value` to the varied set of the given independent variable. It then
/// forwards to `setVariedAcrossArrayInitialization`, which extends variedness
/// to the array when `value` is an array literal allocation.
void DifferentiableActivityInfo::setVaried(SILValue value,
                                           unsigned independentVariableIndex) {
  auto &variedSet = variedValueSets[independentVariableIndex];
  variedSet.insert(value);
  setVariedAcrossArrayInitialization(value, independentVariableIndex);
}

/// Adds `value` to the useful set of the given dependent variable. It then
/// forwards to `setUsefulAcrossArrayInitialization`, which marks stored
/// elements useful when `value` is an array literal allocation.
void DifferentiableActivityInfo::setUseful(SILValue value,
                                           unsigned dependentVariableIndex) {
  auto &usefulSet = usefulValueSets[dependentVariableIndex];
  usefulSet.insert(value);
  setUsefulAcrossArrayInitialization(value, dependentVariableIndex);
}

/// Marks `value` varied w.r.t. the given independent variable. It then
/// recurses into every operand of `value`'s defining instruction. The walk
/// stops at:
/// - values with no defining instruction (e.g. arguments);
/// - `apply` results;
/// - `struct_extract` / `struct_element_addr` of a `@noDerivative` field,
///   which is not itself marked varied.
void DifferentiableActivityInfo::propagateVariedInwardsThroughProjections(
    SILValue value, unsigned independentVariableIndex) {
  auto isNoDerivativeField = [](VarDecl *field) {
    return field->getAttrs().hasAttribute<NoDerivativeAttr>();
  };
  if (auto *extract = dyn_cast<StructExtractInst>(value))
    if (isNoDerivativeField(extract->getField()))
      return;
  if (auto *elementAddr = dyn_cast<StructElementAddrInst>(value))
    if (isNoDerivativeField(elementAddr->getField()))
      return;

  setVaried(value, independentVariableIndex);

  auto *definingInst = value->getDefiningInstruction();
  if (definingInst == nullptr || isa<ApplyInst>(definingInst))
    return;
  for (auto &operand : definingInst->getAllOperands())
    propagateVariedInwardsThroughProjections(operand.get(),
                                             independentVariableIndex);
}

/// Marks the buffer `value` (an address or class-typed value) useful w.r.t. the
/// given dependent variable, then spreads usefulness recursively:
/// - inwards, to every address-typed operand of `value`'s defining
///   instruction;
/// - outwards, to every result of a user of `value` that is an address
///   projection or a `begin_access`. Projections of `@noDerivative` struct
///   fields are skipped.
/// Values that are already useful are not revisited, which bounds the
/// recursion.
void DifferentiableActivityInfo::propagateUsefulThroughBuffer(
    SILValue value, unsigned dependentVariableIndex) {
  assert(value->getType().isAddress() ||
         value->getType().isClassOrClassMetatype());
  if (isUseful(value, dependentVariableIndex))
    return;
  setUseful(value, dependentVariableIndex);

  if (auto *definingInst = value->getDefiningInstruction()) {
    for (auto &operand : definingInst->getAllOperands()) {
      auto operandValue = operand.get();
      if (!operandValue->getType().isAddress())
        continue;
      propagateUsefulThroughBuffer(operandValue, dependentVariableIndex);
    }
  }

  auto isNoDerivativeProjection = [](SILValue projected) -> bool {
    if (auto *extract = dyn_cast<StructExtractInst>(projected))
      return extract->getField()->getAttrs().hasAttribute<NoDerivativeAttr>();
    if (auto *elementAddr = dyn_cast<StructElementAddrInst>(projected))
      return elementAddr->getField()
          ->getAttrs()
          .hasAttribute<NoDerivativeAttr>();
    return false;
  };
  for (auto *use : value->getUses()) {
    for (auto userResult : use->getUser()->getResults()) {
      if (isNoDerivativeProjection(userResult))
        continue;
      if (Projection::isAddressProjection(userResult) ||
          isa<BeginAccessInst>(userResult))
        propagateUsefulThroughBuffer(userResult, dependentVariableIndex);
    }
  }
}

/// Returns true if `value` is in the varied set of the given independent
/// variable.
bool DifferentiableActivityInfo::isVaried(
    SILValue value, unsigned independentVariableIndex) const {
  assert(independentVariableIndex < variedValueSets.size() &&
         "Independent variable index out of range");
  return variedValueSets[independentVariableIndex].count(value);
}

/// Returns true if `value` is varied w.r.t. at least one index in
/// `parameterIndices`.
bool DifferentiableActivityInfo::isVaried(
    SILValue value, IndexSubset *parameterIndices) const {
  for (unsigned paramIdx : parameterIndices->getIndices()) {
    if (!isVaried(value, paramIdx))
      continue;
    return true;
  }
  return false;
}

/// Returns true if `value` is in the useful set of the given dependent
/// variable.
bool DifferentiableActivityInfo::isUseful(
    SILValue value, unsigned dependentVariableIndex) const {
  assert(dependentVariableIndex < usefulValueSets.size() &&
         "Dependent variable index out of range");
  return usefulValueSets[dependentVariableIndex].count(value);
}

/// Returns true if `value` is both varied (w.r.t. `indices.parameters`) and
/// useful (w.r.t. `indices.source`). Usefulness is only queried for varied
/// values.
bool DifferentiableActivityInfo::isActive(
    SILValue value, const SILAutoDiffIndices &indices) const {
  if (!isVaried(value, indices.parameters))
    return false;
  return isUseful(value, indices.source);
}

/// Returns the activity flags of `value`: `Varied` if it is varied w.r.t.
/// `indices.parameters`, `Useful` if it is useful w.r.t. `indices.source`.
Activity DifferentiableActivityInfo::getActivity(
    SILValue value, const SILAutoDiffIndices &indices) const {
  bool varied = isVaried(value, indices.parameters);
  bool useful = isUseful(value, indices.source);
  Activity flags;
  if (varied)
    flags |= ActivityFlags::Varied;
  if (useful)
    flags |= ActivityFlags::Useful;
  return flags;
}

/// Returns the union of the activity flags of every result of `inst`.
Activity DifferentiableActivityInfo::getActivity(
    SILInstruction *inst, const SILAutoDiffIndices &indices) const {
  Activity combined;
  for (SILValue result : inst->getResults()) {
    auto resultActivity = getActivity(result, indices);
    combined |= resultActivity;
  }
  return combined;
}

/// Prints `value` prefixed with a bracketed activity label (`NONE`, `VARIED`,
/// `USEFUL`, or `ACTIVE`). Any other raw flag combination prints an empty
/// label.
static void dumpActivityInfo(SILValue value,
                             const SILAutoDiffIndices &indices,
                             const DifferentiableActivityInfo &activityInfo,
                             llvm::raw_ostream &s = llvm::dbgs()) {
  const auto rawActivity = activityInfo.getActivity(value, indices).toRaw();
  const char *label = "";
  if (rawActivity == 0)
    label = "NONE";
  else if (rawActivity == (unsigned)ActivityFlags::Varied)
    label = "VARIED";
  else if (rawActivity == (unsigned)ActivityFlags::Useful)
    label = "USEFUL";
  else if (rawActivity == (unsigned)ActivityFlags::Active)
    label = "ACTIVE";
  s << '[' << label << "] " << value;
}

/// Prints the activity of every basic block argument and every instruction
/// result in `fn`, grouped by basic block.
static void dumpActivityInfo(SILFunction &fn,
                             const SILAutoDiffIndices &indices,
                             const DifferentiableActivityInfo &activityInfo,
                             llvm::raw_ostream &s = llvm::dbgs()) {
  auto dumpValue = [&](SILValue v) {
    dumpActivityInfo(v, indices, activityInfo, s);
  };
  s << "Activity info for " << fn.getName() << " at " << indices << '\n';
  for (auto &block : fn) {
    s << "bb" << block.getDebugID() << ":\n";
    for (auto *blockArg : block.getArguments())
      dumpValue(blockArg);
    for (auto &inst : block) {
      for (auto result : inst.getResults())
        dumpValue(result);
    }
    s << '\n';
  }
}

/// Emits a "missing return" error at the end of `original` if it has no return
/// block; such a function cannot be differentiated.
/// Returns true if an error is emitted.
static bool diagnoseNoReturn(ADContext &context, SILFunction *original,
                             DifferentiationInvoker invoker) {
  const bool hasReturnBlock = original->findReturnBB() != original->end();
  if (!hasReturnBlock)
    context.emitNondifferentiabilityError(
        original->getLocation().getEndSourceLoc(), invoker,
        diag::autodiff_missing_return);
  return !hasReturnBlock;
}

/// Emits a "control flow unsupported" error at the first unsupported branching
/// terminator in `original`. Returns true if an error is emitted.
///
/// Update as control flow support is added. Currently, only the branching
/// terminators `br`, `cond_br` and `switch_enum` are supported; non-branching
/// terminators are never diagnosed here.
static bool diagnoseUnsupportedControlFlow(ADContext &context,
                                           SILFunction *original,
                                           DifferentiationInvoker invoker) {
  // A function with at most one block has no control flow to diagnose.
  if (original->getBlocks().size() <= 1)
    return false;
  auto isSupportedTerminator = [](TermInst *term) {
    return isa<BranchInst>(term) || isa<CondBranchInst>(term) ||
           isa<SwitchEnumInst>(term);
  };
  for (auto &block : *original) {
    auto *terminator = block.getTerminator();
    if (isSupportedTerminator(terminator) || !terminator->isBranch())
      continue;
    context.emitNondifferentiabilityError(
        terminator, invoker, diag::autodiff_control_flow_not_supported);
    return true;
  }
  return false;
}

/// Checks whether the requirements of `derivativeGenSig` hold once their types
/// are substituted with `substMap` (conformances are looked up in the current
/// module). All unsatisfied requirements are reported together in a single
/// nondifferentiability error at `loc`.
///
/// Layout requirements other than `class` are not checked.
/// Returns true if an error is emitted.
static bool diagnoseUnsatisfiedRequirements(ADContext &context,
                                            GenericSignature derivativeGenSig,
                                            SILFunction *original,
                                            SubstitutionMap substMap,
                                            DifferentiationInvoker invoker,
                                            SourceLoc loc) {
  // Nothing to check without a derivative signature or its requirements.
  if (!derivativeGenSig)
    return false;
  auto requirements = derivativeGenSig->getRequirements();
  if (requirements.empty())
    return false;

  auto *swiftModule = context.getModule().getSwiftModule();
  // Substitutes `type` with `substMap`; keeps `type` unchanged when the
  // substitution produces a null type.
  auto substitute = [&](Type type) -> Type {
    if (auto substType = type.subst(QuerySubstitutionMap{substMap},
                                    LookUpConformanceInModule(swiftModule)))
      return substType;
    return type;
  };
  // Returns whether `req` holds for the substituted types.
  auto isSatisfied = [&](const Requirement &req) -> bool {
    Type firstType = substitute(req.getFirstType());
    switch (req.getKind()) {
    case RequirementKind::Layout:
      // TODO: Check other layout requirements. Note that `@differentiable`
      // attribute type-checking does not yet support layout requirements in
      // where clauses; layout requirements in derivative generic signatures
      // can be formed only from `differentiable_function` instructions whose
      // original function operand is generic with layout requirements.
      if (req.getLayoutConstraint()->getKind() == LayoutConstraintKind::Class)
        return firstType->satisfiesClassConstraint();
      return true;
    case RequirementKind::SameType:
      // The substituted types must be equal.
      return firstType->isEqual(substitute(req.getSecondType()));
    case RequirementKind::Superclass:
      // The substituted second type must be an exact superclass of the
      // substituted first type.
      return substitute(req.getSecondType())->isExactSuperclassOf(firstType);
    case RequirementKind::Conformance: {
      // The substituted first type must conform to the protocol in the
      // current module.
      auto *protocol = req.getSecondType()->castTo<ProtocolType>()->getDecl();
      assert(protocol && "Expected protocol in generic signature requirement");
      if (!swiftModule->lookupConformance(firstType, protocol))
        return false;
      return true;
    }
    }
    llvm_unreachable("Unhandled requirement kind");
  };

  SmallVector<Requirement, 2> unsatisfiedRequirements;
  for (const auto &req : requirements)
    if (!isSatisfied(req))
      unsatisfiedRequirements.push_back(req);
  if (unsatisfiedRequirements.empty())
    return false;

  // Report the unsatisfied requirements as a comma-separated list.
  std::string reqText;
  llvm::raw_string_ostream stream(reqText);
  interleave(unsatisfiedRequirements,
             [&](Requirement req) { req.print(stream, PrintOptions()); },
             [&] { stream << ", "; });
  context.emitNondifferentiabilityError(
      loc, invoker, diag::autodiff_function_assoc_func_unmet_requirements,
      stream.str());
  return true;
}

//===----------------------------------------------------------------------===//
// Code emission utilities
//===----------------------------------------------------------------------===//

/// Appends the elements of `value` to `results` when `value` has tuple type.
/// Otherwise, appends `value` itself.
///
/// With ownership enabled on `builder`, a single `destructure_tuple` is used.
/// Otherwise, one `tuple_extract` is emitted per element.
static void extractAllElements(SILValue value, SILBuilder &builder,
                               SmallVectorImpl<SILValue> &results) {
  auto loc = value.getLoc();
  auto tupleTy = value->getType().getAs<TupleType>();
  if (!tupleTy) {
    results.push_back(value);
    return;
  }
  if (builder.hasOwnership()) {
    auto elements = builder.createDestructureTuple(loc, value)->getResults();
    results.append(elements.begin(), elements.end());
    return;
  }
  unsigned numElements = tupleTy->getNumElements();
  for (unsigned idx = 0; idx != numElements; ++idx)
    results.push_back(builder.createTupleExtract(loc, value, idx));
}

/// Combines `elements` into one value. A single element is returned as is;
/// any other count is packed into a new `tuple` instruction at `loc`.
static SILValue joinElements(ArrayRef<SILValue> elements, SILBuilder &builder,
                             SILLocation loc) {
  return elements.size() == 1 ? elements.front()
                              : SILValue(builder.createTuple(loc, elements));
}

/// Emits copies of all arguments of `applySite`, immediately before its
/// instruction, and appends them (in argument order) to `copiedArgs`. This
/// helper is used for duplicating an apply site.
///
/// - Object arguments are copied with a copy-value operation.
/// - `@inout_aliasable` address arguments are reused directly.
/// - Any other address argument is copied into a new `alloc_stack` buffer via
///   `copy_addr`. The buffer is appended to `newBuffersToDealloc`.
///
/// New arguments passed with a guaranteed convention are also appended to
/// `newArgsToDestroy`.
static void copyParameterArgumentsForApply(
    ApplySite applySite, SmallVectorImpl<SILValue> &copiedArgs,
    SmallVectorImpl<SILValue> &newArgsToDestroy,
    SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc) {
  LLVM_DEBUG({
    auto &s = getADDebugStream() << "Copying arguments from apply site: ";
    applySite.getInstruction()->print(s);
  });
  auto loc = applySite.getLoc();
  copiedArgs.reserve(applySite.getNumArguments());
  SILBuilder copyBuilder(applySite.getInstruction());
  for (auto &argOperand : applySite.getArgumentOperands()) {
    SILValue originalArg = argOperand.get();
    auto conv = applySite.getArgumentConvention(argOperand);
    SILValue newArg;
    if (originalArg->getType().isObject()) {
      newArg = copyBuilder.emitCopyValueOperation(loc, originalArg);
    } else if (conv == SILArgumentConvention::Indirect_InoutAliasable) {
      // `@noescape` capture: no copy is made.
      newArg = originalArg;
    } else {
      auto *buffer = copyBuilder.createAllocStack(loc, originalArg->getType());
      newBuffersToDealloc.push_back(buffer);
      copyBuilder.createCopyAddr(loc, originalArg, buffer, IsNotTake,
                                 IsInitialization);
      newArg = buffer;
    }
    copiedArgs.push_back(newArg);
    bool recordForDestroy =
        conv.isGuaranteedConvention() &&
        conv != SILArgumentConvention::Indirect_InoutAliasable;
    if (recordForDestroy)
      newArgsToDestroy.push_back(newArg);
  }
}

/// Applies to `newFunc` the same chain of conversion instructions that turned
/// `oldFunc` into `oldConvertedFunc`. Such a chain is typically found between a
/// function value and its use, for example by an `apply`. New instructions are
/// emitted with `builder` at `loc`.
///
/// Only `thin_to_thick_function` and `partial_apply` are handled; any other
/// instruction in the chain is unreachable. If `newFuncGenSig` is non-null, it
/// is used to build the substitution maps of re-created `partial_apply`
/// instructions. Stack buffers created for copied `partial_apply` arguments are
/// appended to `newBuffersToDealloc`.
static SILValue
reapplyFunctionConversion(
    SILValue newFunc, SILValue oldFunc, SILValue oldConvertedFunc,
    SILBuilder &builder, SILLocation loc,
    SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc,
    GenericSignature newFuncGenSig = GenericSignature()) {
  // Base case: the whole chain has been replayed.
  if (oldFunc == oldConvertedFunc)
    return newFunc;

  // Replays the part of the chain below `convertedOperand`.
  auto reapplyBelow = [&](SILValue convertedOperand) {
    return reapplyFunctionConversion(newFunc, oldFunc, convertedOperand,
                                     builder, loc, newBuffersToDealloc,
                                     newFuncGenSig);
  };

  // `thin_to_thick_function`: thicken the rebuilt inner function.
  if (auto *thinToThick = dyn_cast<ThinToThickFunctionInst>(oldConvertedFunc)) {
    SILValue inner = reapplyBelow(thinToThick->getOperand());
    auto thickFnTy = inner->getType().castTo<SILFunctionType>()
        ->getWithRepresentation(SILFunctionTypeRepresentation::Thick);
    return builder.createThinToThickFunction(
        loc, inner, SILType::getPrimitiveObjectType(thickFnTy));
  }

  // `partial_apply`: copy the captured arguments, rebuild the callee, then
  // partially apply the rebuilt callee to the copies.
  if (auto *partialApply = dyn_cast<PartialApplyInst>(oldConvertedFunc)) {
    SmallVector<SILValue, 8> capturedArgs;
    capturedArgs.reserve(partialApply->getNumArguments());
    // The destroy list is collected but not used here.
    SmallVector<SILValue, 1> unusedArgsToDestroy;
    copyParameterArgumentsForApply(partialApply, capturedArgs,
                                   unusedArgsToDestroy, newBuffersToDealloc);
    SILValue inner = reapplyBelow(partialApply->getCallee());
    auto oldSubstMap = partialApply->getSubstitutionMap();
    SubstitutionMap substMap = oldSubstMap;
    if (newFuncGenSig)
      substMap = SubstitutionMap::get(
          newFuncGenSig, QuerySubstitutionMap{oldSubstMap},
          LookUpConformanceInModule(builder.getModule().getSwiftModule()));
    return builder.createPartialApply(loc, inner, substMap, capturedArgs,
                                      ParameterConvention::Direct_Guaranteed);
  }

  llvm_unreachable("Unhandled function conversion instruction");
}

/// Emits a reference to a derivative function of `original`, differentiated
/// with respect to a superset of `desiredIndices`. Returns the `SILValue` for
/// the derivative function and the actual indices that the derivative function
/// is with respect to.
///
/// Returns `None` on failure, signifying that a diagnostic has been emitted.
///
/// Suppose `original` is a function reference with no `[differentiable]`
/// attribute satisfying `desiredIndices`. In that case, a new attribute is
/// created with `ADContext::getOrCreateDifferentiableAttr` and processed with
/// `ADContext::processDifferentiableAttribute`, using `invoker` as the invoker.
///
/// Stack buffers allocated while re-applying function conversions are appended
/// to `newBuffersToDealloc`. This function does not deallocate them.
///
/// FIXME: This is too complicated and needs to be rewritten.
static Optional<std::pair<SILValue, SILAutoDiffIndices>>
emitDerivativeFunctionReference(
    ADContext &context, SILBuilder &builder, SILAutoDiffIndices desiredIndices,
    AutoDiffDerivativeFunctionKind kind, SILValue original,
    DifferentiationInvoker invoker,
    SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc) {

  SILValue functionSource = original;

  // If `original` is a `differentiable_function_extract` instruction that
  // extracts the original function, look through it. The derivative function
  // is then extracted from its `@differentiable` function operand below.
  if (auto *inst = original->getDefiningInstruction())
    if (auto *dfei = dyn_cast<DifferentiableFunctionExtractInst>(inst))
      if (dfei->getExtractee() ==
              NormalDifferentiableFunctionTypeComponent::Original)
        functionSource = dfei->getFunctionOperand();

  // If `functionSource` is a `@differentiable` function, just extract the
  // derivative function.
  auto diffableFnType = functionSource->getType().castTo<SILFunctionType>();
  if (diffableFnType->isDifferentiable()) {
    // Every desired parameter must be a differentiation parameter of the
    // `@differentiable` function.
    auto paramIndices = diffableFnType->getDifferentiationParameterIndices();
    for (auto i : desiredIndices.parameters->getIndices()) {
      if (!paramIndices->contains(i)) {
        context.emitNondifferentiabilityError(functionSource, invoker,
            diag::autodiff_function_nondiff_parameter_not_differentiable);
        return None;
      }
    }
    // Borrow the `@differentiable` function, extract and copy the derivative
    // function, then end the borrow.
    auto borrowedDiffFunc = builder.emitBeginBorrowOperation(
        functionSource.getLoc(), functionSource);
    SILValue derivativeFn = builder.createDifferentiableFunctionExtract(
        borrowedDiffFunc.getLoc(), kind, borrowedDiffFunc);
    derivativeFn =
        builder.emitCopyValueOperation(functionSource.getLoc(), derivativeFn);
    builder.emitEndBorrowOperation(functionSource.getLoc(), borrowedDiffFunc);
    // NOTE(review): the extracted derivative is presumably w.r.t. all of
    // `paramIndices`, which may be a strict superset of the desired
    // parameters. Confirm that returning the desired indices is intended.
    SILAutoDiffIndices indices(0, desiredIndices.parameters);
    return std::make_pair(derivativeFn, indices);
  }

  // Find local function reference.
  if (auto *originalFRI =
          peerThroughFunctionConversions<FunctionRefInst>(original)) {
    auto loc = originalFRI->getLoc();
    auto *originalFn = originalFRI->getReferencedFunctionOrNull();
    // Attempt to look up a `[differentiable]` attribute that minimally
    // satisfies the specified indices.
    // TODO(TF-482): Change `lookUpMinimalDifferentiableAttr` to additionally
    // check whether `[differentiable]` attribute generic requirements are
    // satisfied.
    auto *minimalAttr =
        context.lookUpMinimalDifferentiableAttr(originalFn, desiredIndices);
    if (!minimalAttr) {
      // If the function is intentionally marked as being opaque to
      // differentiation, then we should not create a task for it.
      if (originalFn->hasSemanticsAttr("autodiff.opaque")) {
        context.emitNondifferentiabilityError(original, invoker,
            diag::autodiff_opaque_function_not_differentiable);
        return None;
      }
      // Check and diagnose non-differentiable arguments.
      auto originalFnTy = originalFn->getLoweredFunctionType();
      for (unsigned paramIndex : range(originalFnTy->getNumParameters())) {
        if (desiredIndices.isWrtParameter(paramIndex) &&
            !originalFnTy->getParameters()[paramIndex]
                 .getSILStorageType()
                 .isDifferentiable(context.getModule())) {
          context.emitNondifferentiabilityError(
              original, invoker, diag::autodiff_nondifferentiable_argument);
          return None;
        }
      }
      // Check and diagnose non-differentiable results.
      if (!originalFnTy->getResults()[desiredIndices.source]
               .getSILStorageType()
               .isDifferentiable(context.getModule())) {
        context.emitNondifferentiabilityError(
            original, invoker, diag::autodiff_nondifferentiable_result);
        return None;
      }
      // Check and diagnose external declarations.
      if (originalFn->isExternalDeclaration()) {
        context.emitNondifferentiabilityError(
            original, invoker,
            diag::autodiff_external_nondifferentiable_function);
        return None;
      }
      // Sanity check passed. Create a new `[differentiable]` attribute and
      // process it.
      GenericSignature contextualDerivativeGenSig = GenericSignature();
      if (invoker.getKind() ==
          DifferentiationInvoker::Kind::IndirectDifferentiation)
        contextualDerivativeGenSig = invoker.getIndirectDifferentiation().second
            ->getDerivativeGenericSignature();
      auto *newAttr = context.getOrCreateDifferentiableAttr(
          originalFn, desiredIndices, contextualDerivativeGenSig);
      if (context.processDifferentiableAttribute(originalFn, newAttr, invoker))
        return None;
      minimalAttr = newAttr;
    }
    assert(minimalAttr);
    // TODO(TF-482): Move generic requirement checking logic to
    // `lookUpMinimalDifferentiableAttr`.
    // Get the substitution map for checking unmet generic requirements.
    // By default, use the forwarding substitution map of the original function.
    // If the original callee is a `partial_apply` or `apply` instruction, use
    // its substitution map instead.
    auto substMap = original->getFunction()->getForwardingSubstitutionMap();
    if (auto *pai = dyn_cast<PartialApplyInst>(original)) {
      substMap = pai->getSubstitutionMap();
    } else if (auto *ai = dyn_cast<ApplyInst>(original)) {
      substMap = ai->getSubstitutionMap();
    }
    if (diagnoseUnsatisfiedRequirements(
            context, minimalAttr->getDerivativeGenericSignature(), originalFn,
            substMap, invoker, original.getLoc().getSourceLoc()))
      return None;
    if (context.processDifferentiableAttribute(
            originalFn, minimalAttr, invoker))
      return None;
    SILFunction *derivativeFn = nullptr;
    switch (kind) {
    case AutoDiffDerivativeFunctionKind::JVP:
      assert(!minimalAttr->getJVPName().empty() && "Expected JVP name");
      derivativeFn =
          context.getModule().lookUpFunction(minimalAttr->getJVPName());
      break;
    case AutoDiffDerivativeFunctionKind::VJP:
      assert(!minimalAttr->getVJPName().empty() && "Expected VJP name");
      derivativeFn =
          context.getModule().lookUpFunction(minimalAttr->getVJPName());
      break;
    }
    // `derivativeFn` is dereferenced below, so it must exist in the module.
    assert(derivativeFn && "Expected derivative function to exist in module");
    auto *derivativeFnRef = builder.createFunctionRef(loc, derivativeFn);
    // FIXME(TF-201): Handle direct differentiation of reabstraction thunks
    // (`originalFn->isThunk() == IsReabstractionThunk`). Tentative solution:
    // clone a new reabstraction thunk where function argument has a
    // `@differentiable` function type.
    auto convertedRef = reapplyFunctionConversion(
        derivativeFnRef, originalFRI, original, builder, loc,
        newBuffersToDealloc,
        derivativeFn->getLoweredFunctionType()->getGenericSignature());
    return std::make_pair(convertedRef, minimalAttr->getIndices());
  }

  // Find witness method retrieval.
  if (auto *witnessMethod =
          peerThroughFunctionConversions<WitnessMethodInst>(original)) {
    auto loc = witnessMethod->getLoc();
    auto requirementDeclRef = witnessMethod->getMember();
    auto *requirementDecl = requirementDeclRef.getDecl();
    auto witnessMethodType = witnessMethod->getType().castTo<SILFunctionType>();
    // If requirement declaration does not have any `@differentiable`
    // attributes, produce an error.
    if (!requirementDecl->getAttrs().hasAttribute<DifferentiableAttr>()) {
      context.emitNondifferentiabilityError(
          original, invoker, diag::autodiff_protocol_member_not_differentiable);
      return None;
    }
    // Get the minimal `@differentiable` attribute and parameter index subset.
    const DifferentiableAttr *minimalAttr;
    IndexSubset *minimalParamIndexSet;
    std::tie(minimalAttr, minimalParamIndexSet) =
        context.lookUpMinimalASTDifferentiableAttrAndIndexSubset(
            requirementDeclRef, witnessMethodType, desiredIndices);
    SILAutoDiffIndices minimalIndices(/*source*/ 0, minimalParamIndexSet);
    // If minimal `@differentiable` attribute does not exist, then no attribute
    // exists with a superset of the desired indices. Produce an error.
    if (!minimalAttr) {
      context.emitNondifferentiabilityError(
          original, invoker,
          diag::autodiff_member_subset_indices_not_differentiable);
      return None;
    }
    // Emit a `witness_method` instruction for the derivative function.
    auto assocType = witnessMethodType->getAutoDiffDerivativeFunctionType(
        minimalIndices.parameters, minimalIndices.source,
        kind, context.getTypeConverter(),
        LookUpConformanceInModule(builder.getModule().getSwiftModule()));
    auto *autoDiffFuncId = AutoDiffDerivativeFunctionIdentifier::get(
        kind, minimalAttr->getParameterIndices(), context.getASTContext());
    auto *ref = builder.createWitnessMethod(
        loc, witnessMethod->getLookupType(), witnessMethod->getConformance(),
        requirementDeclRef.asAutoDiffDerivativeFunction(autoDiffFuncId),
        SILType::getPrimitiveObjectType(assocType));
    auto convertedRef =
        reapplyFunctionConversion(ref, witnessMethod, original, builder, loc,
                                  newBuffersToDealloc);
    return std::make_pair(convertedRef, minimalIndices);
  }

  // Find class method.
  if (auto *classMethodInst =
          peerThroughFunctionConversions<ClassMethodInst>(original)) {
    auto loc = classMethodInst->getLoc();
    auto methodDeclRef = classMethodInst->getMember();
    auto *methodDecl = methodDeclRef.getDecl();
    auto classMethodType = classMethodInst->getType().castTo<SILFunctionType>();
    // If method declaration does not have any `@differentiable` attributes,
    // produce an error.
    if (!methodDecl->getAttrs().hasAttribute<DifferentiableAttr>()) {
      context.emitNondifferentiabilityError(
          original, invoker, diag::autodiff_class_member_not_differentiable);
      return None;
    }
    // Get the minimal `@differentiable` attribute and parameter index subset.
    const DifferentiableAttr *minimalAttr;
    IndexSubset *minimalParamIndexSet;
    std::tie(minimalAttr, minimalParamIndexSet) =
        context.lookUpMinimalASTDifferentiableAttrAndIndexSubset(
            methodDeclRef, classMethodType, desiredIndices);
    SILAutoDiffIndices minimalIndices(/*source*/ 0, minimalParamIndexSet);
    // If minimal `@differentiable` attribute does not exist, then no attribute
    // exists with a superset of the desired indices. Produce an error.
    if (!minimalAttr) {
      context.emitNondifferentiabilityError(
          original, invoker,
          diag::autodiff_member_subset_indices_not_differentiable);
      return None;
    }
    // Emit a `class_method` instruction for the derivative function.
    auto assocType = classMethodType->getAutoDiffDerivativeFunctionType(
        minimalIndices.parameters, minimalIndices.source,
        kind, context.getTypeConverter(),
        LookUpConformanceInModule(builder.getModule().getSwiftModule()));
    auto *autoDiffFuncId = AutoDiffDerivativeFunctionIdentifier::get(
        kind, minimalAttr->getParameterIndices(),
        context.getASTContext());
    auto *ref = builder.createClassMethod(
        loc, classMethodInst->getOperand(),
        methodDeclRef.asAutoDiffDerivativeFunction(autoDiffFuncId),
        SILType::getPrimitiveObjectType(assocType));
    auto convertedRef =
        reapplyFunctionConversion(ref, classMethodInst, original, builder, loc,
                                  newBuffersToDealloc);
    return std::make_pair(convertedRef, minimalIndices);
  }

  // Emit the general opaque function error.
  context.emitNondifferentiabilityError(original, invoker,
      diag::autodiff_opaque_function_not_differentiable);
  return None;
}

/// Emits a call to the `AdditiveArithmetic.zero` getter for `type`, passing
/// `bufferAccess` so that the zero value is written into it. The getter is
/// obtained through a `witness_method` instruction. `type` must conform to
/// `AdditiveArithmetic`; this is asserted.
static void emitZeroIntoBuffer(
    SILBuilder &builder, CanType type, SILValue bufferAccess,
    SILLocation loc) {
  auto &module = builder.getModule();
  auto &astCtx = builder.getASTContext();
  auto *additiveArithmeticProto =
      astCtx.getProtocol(KnownProtocolKind::AdditiveArithmetic);
  auto conformance =
      module.getSwiftModule()->lookupConformance(type, additiveArithmeticProto);
  assert(conformance.hasValue() &&
         "Missing conformance to `AdditiveArithmetic`");
  // Resolve the getter of the `AdditiveArithmetic.zero` requirement.
  auto *zeroVar = cast<VarDecl>(
      additiveArithmeticProto->lookupDirect(astCtx.Id_zero).front());
  assert(zeroVar->isProtocolRequirement());
  SILDeclRef getterRef(zeroVar->getAccessor(AccessorKind::Get),
                       SILDeclRef::Kind::Func);
  auto getterFnTy = module.Types.getConstantType(getterRef);
  // %wm = witness_method ...
  auto *getter = builder.createWitnessMethod(loc, type, *conformance,
                                             getterRef, getterFnTy);
  // %metatype = metatype $T
  auto *metatype = builder.createMetatype(
      loc, SILType::getPrimitiveObjectType(
               CanMetatypeType::get(type, MetatypeRepresentation::Thick)));
  builder.createApply(
      loc, getter,
      SubstitutionMap::getProtocolSubstitutions(additiveArithmeticProto, type,
                                                *conformance),
      {bufferAccess, metatype}, /*isNonThrowing*/ false);
  builder.emitDestroyValueOperation(loc, getter);
}

//===----------------------------------------------------------------------===//
// Thunk helpers
//===----------------------------------------------------------------------===//
// These helpers are copied/adapted from SILGen. They should be refactored and
// moved to a shared location.
//===----------------------------------------------------------------------===//

/// Builds the generic signature of a thunk.
///
/// Without an opened existential, the thunk reuses `fn`'s generic signature
/// and generic environment. `interfaceSubs` and `contextSubs` are then both set
/// to `fn`'s forwarding substitution map.
///
/// With an opened existential, a new generic parameter is introduced with a
/// conformance requirement to the opened existential type. It is placed one
/// level below `fn`'s generic parameters when `inheritGenericSig` is true and
/// `fn` is generic. In that case:
/// - `genericEnv` and `newArchetype` describe the new signature.
/// - `contextSubs` maps `fn`'s signature into the new environment; it is set
///   only if `fn` has a generic signature.
/// - `interfaceSubs` maps the new signature to `fn`'s context, sending the new
///   parameter to `openedExistential`.
static CanGenericSignature
buildThunkSignature(SILFunction *fn,
                    bool inheritGenericSig,
                    OpenedArchetypeType *openedExistential,
                    GenericEnvironment *&genericEnv,
                    SubstitutionMap &contextSubs,
                    SubstitutionMap &interfaceSubs,
                    ArchetypeType *&newArchetype) {
  auto fnGenericSig = fn->getLoweredFunctionType()->getGenericSignature();

  if (!openedExistential) {
    genericEnv = fn->getGenericEnvironment();
    interfaceSubs = fn->getForwardingSubstitutionMap();
    contextSubs = interfaceSubs;
    return fnGenericSig;
  }

  auto &astCtx = fn->getASTContext();
  GenericSignatureBuilder sigBuilder(astCtx);

  // Optionally seed the builder with `fn`'s own generic signature.
  unsigned paramDepth = 0;
  if (inheritGenericSig && fnGenericSig) {
    sigBuilder.addGenericSignature(fnGenericSig);
    paramDepth = fnGenericSig->getGenericParams().back()->getDepth() + 1;
  }

  // The replacement parameter stands in for the opened existential.
  auto *replacementParam = GenericTypeParamType::get(paramDepth, 0, astCtx);
  sigBuilder.addGenericParameter(replacementParam);
  Requirement conformanceReq(RequirementKind::Conformance, replacementParam,
                             openedExistential->getOpenedExistentialType());
  auto reqSource =
      GenericSignatureBuilder::FloatingRequirementSource::forAbstract();
  sigBuilder.addRequirement(conformanceReq, reqSource, nullptr);

  auto thunkSig = std::move(sigBuilder).computeGenericSignature(
      SourceLoc(), /*allowConcreteGenericParams=*/true);
  genericEnv = thunkSig->getGenericEnvironment();
  newArchetype =
      genericEnv->mapTypeIntoContext(replacementParam)->castTo<ArchetypeType>();

  // Map `fn`'s archetypes to the thunk's archetypes.
  if (fnGenericSig) {
    contextSubs = SubstitutionMap::get(
        fnGenericSig,
        [&](SubstitutableType *type) -> Type {
          return genericEnv->mapTypeIntoContext(type);
        },
        MakeAbstractConformanceForGenericType());
  }

  // Map the thunk's interface types to `fn`'s archetypes.
  interfaceSubs = SubstitutionMap::get(
      thunkSig,
      [&](SubstitutableType *type) -> Type {
        if (type->isEqual(replacementParam))
          return openedExistential;
        return fn->mapTypeIntoContext(type);
      },
      MakeAbstractConformanceForGenericType());

  return thunkSig->getCanonicalSignature();
}

/// The thunk kinds used in the differentiation transform.
enum class DifferentiationThunkKind {
  /// A reabstraction thunk.
  ///
  /// Reabstraction thunks transform a function-typed value to another one with
  /// different parameter/result abstraction patterns. This is identical to the
  /// thunks generated by SILGen.
  Reabstraction,

  /// An index subset thunk.
  ///
  /// An index subset thunk is used to transform JVPs/VJPs into versions that
  /// are "wrt" fewer differentiation parameters.
  /// - Differentials of thunked JVPs use zero for non-requested differentiation
  ///   parameters.
  /// - Pullbacks of thunked VJPs discard results for non-requested
  ///   differentiation parameters.
  IndexSubset
};

/// Build the type of a function transformation thunk.
///
/// \param fn The function in which the thunk will be referenced. Its generic
///   signature is used when either function type involves archetypes.
/// \param sourceType [in/out] The type of the function value being thunked.
///   On return it has been substituted into the thunk's generic context.
/// \param expectedType [in/out] The type the thunk must present. On return it
///   has been substituted into the thunk's generic context.
/// \param genericEnv [out] Set by `buildThunkSignature` when either type
///   involves archetypes; left untouched otherwise.
/// \param interfaceSubs [out] Set by `buildThunkSignature` under the same
///   condition as `genericEnv`.
/// \param withoutActuallyEscaping If true, `@noescape` is dropped from the
///   resulting ext info.
/// \param thunkKind Reabstraction thunks take the source function value as an
///   extra trailing parameter; index subset thunks do not.
/// \returns The SIL function type of the thunk, expressed with interface types
///   relative to the thunk's generic signature (if any).
static CanSILFunctionType buildThunkType(SILFunction *fn,
                                         CanSILFunctionType &sourceType,
                                         CanSILFunctionType &expectedType,
                                         GenericEnvironment *&genericEnv,
                                         SubstitutionMap &interfaceSubs,
                                         bool withoutActuallyEscaping,
                                         DifferentiationThunkKind thunkKind) {
  assert(!expectedType->isPolymorphic());
  assert(!sourceType->isPolymorphic());

  auto &module = fn->getModule();
  auto origType = sourceType;

  // Cannot build a reabstraction thunk without context. Ownership semantics
  // on the result type are required.
  assert((thunkKind != DifferentiationThunkKind::Reabstraction ||
          expectedType->getExtInfo().hasContext()) &&
         "reabstraction thunk requires an expected type with context");

  // This may inherit @noescape from the expected type. The `@noescape`
  // attribute is only stripped when using this type to materialize a new decl.
  // Use `@convention(thin)` if:
  // - Building a reabstraction thunk type.
  // - Building an index subset thunk type, where the expected type has context
  //   (i.e. is `@convention(thick)`).
  auto extInfo = expectedType->getExtInfo();
  if (thunkKind == DifferentiationThunkKind::Reabstraction ||
      extInfo.hasContext()) {
    extInfo = extInfo.withRepresentation(
        SILFunctionType::Representation::Thin);
  }
  if (withoutActuallyEscaping)
    extInfo = extInfo.withNoEscape(false);

  // Does the thunk type involve archetypes other than opened existentials?
  bool hasArchetypes = false;
  // Does the thunk type involve an open existential type?
  CanOpenedArchetypeType openedExistential;
  // Classify every archetype in the visited type: opened existentials are
  // recorded (at most one distinct one is allowed), any other archetype means
  // the thunk must inherit the caller's generic signature. The outer cast
  // must be to the general `ArchetypeType`; casting to `OpenedArchetypeType`
  // there would make the `else` branch unreachable.
  auto archetypeVisitor = [&](CanType t) {
    if (auto archetypeTy = dyn_cast<ArchetypeType>(t)) {
      if (auto opened = dyn_cast<OpenedArchetypeType>(archetypeTy)) {
        assert((openedExistential == CanArchetypeType() ||
                openedExistential == opened) &&
               "one too many open existentials");
        openedExistential = opened;
      } else {
        hasArchetypes = true;
      }
    }
  };

  // Use the generic signature from the context if the thunk involves
  // generic parameters.
  CanGenericSignature genericSig;
  SubstitutionMap contextSubs;
  ArchetypeType *newArchetype = nullptr;

  if (expectedType->hasArchetype() || sourceType->hasArchetype()) {
    expectedType.visit(archetypeVisitor);
    sourceType.visit(archetypeVisitor);
    genericSig = buildThunkSignature(
        fn, hasArchetypes, openedExistential, genericEnv, contextSubs,
        interfaceSubs, newArchetype);
  }

  // Utility function to apply contextSubs, and also replace the
  // opened existential with the new archetype.
  auto substIntoThunkContext = [&](CanType t) -> CanType {
    return t.subst(
        [&](SubstitutableType *type) -> Type {
          if (CanType(type) == openedExistential)
            return newArchetype;
          return Type(type).subst(contextSubs);
        },
        LookUpConformanceInSubstitutionMap(contextSubs),
        SubstFlags::AllowLoweredTypes)->getCanonicalType();
  };

  sourceType = cast<SILFunctionType>(substIntoThunkContext(sourceType));
  expectedType = cast<SILFunctionType>(substIntoThunkContext(expectedType));

  // If our parent function was pseudogeneric, this thunk must also be
  // pseudogeneric, since we have no way to pass generic parameters.
  if (genericSig)
    if (origType->isPseudogeneric())
      extInfo = extInfo.withIsPseudogeneric();

  // Add the function type as the parameter.
  auto contextConvention =
      SILType::getPrimitiveObjectType(sourceType).isTrivial(*fn)
          ? ParameterConvention::Direct_Unowned
          : ParameterConvention::Direct_Guaranteed;
  SmallVector<SILParameterInfo, 4> params;
  params.append(expectedType->getParameters().begin(),
                expectedType->getParameters().end());
  // Add reabstraction function parameter only if building a reabstraction thunk
  // type.
  if (thunkKind == DifferentiationThunkKind::Reabstraction)
    params.push_back({sourceType, sourceType->getExtInfo().hasContext()
                                      ? contextConvention
                                      : ParameterConvention::Direct_Unowned});

  // Map the parameter and expected types out of context to get the interface
  // type of the thunk.
  SmallVector<SILParameterInfo, 4> interfaceParams;
  interfaceParams.reserve(params.size());
  for (auto &param : params) {
    auto paramIfaceTy = param.getType()->mapTypeOutOfContext();
    interfaceParams.push_back(SILParameterInfo(
        paramIfaceTy->getCanonicalType(genericSig), param.getConvention()));
  }

  SmallVector<SILYieldInfo, 4> interfaceYields;
  for (auto &yield : expectedType->getYields()) {
    auto yieldIfaceTy = yield.getType()->mapTypeOutOfContext();
    auto interfaceYield =
        yield.getWithType(yieldIfaceTy->getCanonicalType(genericSig));
    interfaceYields.push_back(interfaceYield);
  }

  SmallVector<SILResultInfo, 4> interfaceResults;
  for (auto &result : expectedType->getResults()) {
    auto resultIfaceTy = result.getType()->mapTypeOutOfContext();
    auto interfaceResult =
        result.getWithType(resultIfaceTy->getCanonicalType(genericSig));
    interfaceResults.push_back(interfaceResult);
  }

  Optional<SILResultInfo> interfaceErrorResult;
  if (expectedType->hasErrorResult()) {
    auto errorResult = expectedType->getErrorResult();
    auto errorIfaceTy = errorResult.getType()->mapTypeOutOfContext();
    interfaceErrorResult =
        SILResultInfo(errorIfaceTy->getCanonicalType(genericSig),
                      expectedType->getErrorResult().getConvention());
  }

  // The type of the thunk function.
  return SILFunctionType::get(
      genericSig, extInfo, expectedType->getCoroutineKind(),
      ParameterConvention::Direct_Unowned, interfaceParams, interfaceYields,
      interfaceResults, interfaceErrorResult, module.getASTContext());
}

/// Get or create a reabstraction thunk from `fromType` to `toType`, to be
/// called in `caller`.
///
/// The thunk is looked up (or created) by its mangled reabstraction-thunk name
/// as a shared, bare, transparent, serialized function. If a function with that
/// name already has a body, it is returned unchanged. Otherwise its body is
/// emitted here:
/// - results and parameters whose direct/indirect conventions differ between
///   `fromType` and `toType` are converted through stack buffers or
///   loads/stores;
/// - the thunked function value (the thunk's last non-indirect-result
///   argument) is applied to the converted arguments;
/// - stack buffers are deallocated in reverse order before returning.
static SILFunction *getOrCreateReabstractionThunk(SILOptFunctionBuilder &fb,
                                                  SILModule &module,
                                                  SILLocation loc,
                                                  SILFunction *caller,
                                                  CanSILFunctionType fromType,
                                                  CanSILFunctionType toType) {
  SubstitutionMap interfaceSubs;
  GenericEnvironment *genericEnv = nullptr;
  // NOTE: `buildThunkType` takes `fromType`/`toType` by reference and may
  // substitute them into the thunk's generic context; everything below sees
  // the updated types.
  auto thunkType = buildThunkType(
      caller, fromType, toType, genericEnv, interfaceSubs,
      /*withoutActuallyEscaping*/ false,
      DifferentiationThunkKind::Reabstraction);
  // The declared function type drops `@noescape`; the mangled name is still
  // computed from `thunkType`.
  auto thunkDeclType =
      thunkType->getWithExtInfo(thunkType->getExtInfo().withNoEscape(false));

  auto fromInterfaceType = fromType->mapTypeOutOfContext()->getCanonicalType();
  auto toInterfaceType = toType->mapTypeOutOfContext()->getCanonicalType();

  Mangle::ASTMangler mangler;
  std::string name = mangler.mangleReabstractionThunkHelper(
      thunkType, fromInterfaceType, toInterfaceType,
      Type(), module.getSwiftModule());

  auto *thunk = fb.getOrCreateSharedFunction(
      loc, name, thunkDeclType, IsBare, IsTransparent, IsSerialized,
      ProfileCounter(), IsReabstractionThunk, IsNotDynamic);
  // A non-empty function means this thunk was already emitted; reuse it.
  if (!thunk->empty())
    return thunk;

  thunk->setGenericEnvironment(genericEnv);
  // The body is emitted without ownership SSA; loads and stores below use
  // `Unqualified` ownership qualifiers accordingly.
  thunk->setOwnershipEliminated();
  auto *entry = thunk->createBasicBlock();
  SILBuilder builder(entry);
  createEntryArguments(thunk);

  SILFunctionConventions fromConv(fromType, module);
  SILFunctionConventions toConv(toType, module);
  assert(toConv.useLoweredAddresses());

  // The function value being reabstracted is the thunk's last parameter.
  auto *fnArg = thunk->getArgumentsWithoutIndirectResults().back();

  // Arguments for the call to `fnArg`, built in order. `toArgIter` walks the
  // thunk's own arguments: the result loop consumes indirect results first,
  // then the parameter loop consumes parameters.
  SmallVector<SILValue, 4> arguments;
  auto toArgIter = thunk->getArguments().begin();
  auto useNextArgument = [&]() {
    arguments.push_back(*toArgIter++);
  };

  // Stack buffers created for abstraction conversions; deallocated in reverse
  // order right before the return.
  SmallVector<AllocStackInst *, 4> localAllocations;
  auto createAllocStack = [&](SILType type) {
    auto *alloc = builder.createAllocStack(loc, type);
    localAllocations.push_back(alloc);
    return alloc;
  };

  // Handle indirect results.
  assert(fromType->getNumResults() == toType->getNumResults());
  for (unsigned resIdx : range(toType->getNumResults())) {
    auto fromRes = fromConv.getResults()[resIdx];
    auto toRes = toConv.getResults()[resIdx];
    // No abstraction mismatch.
    if (fromRes.isFormalIndirect() == toRes.isFormalIndirect()) {
      // If result types are indirect, directly pass as next argument.
      if (toRes.isFormalIndirect())
        useNextArgument();
      continue;
    }
    // Convert indirect result to direct result.
    // The callee writes into a temporary buffer, which is loaded after the
    // call (see the result reabstraction loop below).
    if (fromRes.isFormalIndirect()) {
      SILType resultTy = fromConv.getSILType(fromRes);
      assert(resultTy.isAddress());
      auto *indRes = createAllocStack(resultTy);
      arguments.push_back(indRes);
      continue;
    }
    // Convert direct result to indirect result.
    // Increment thunk argument iterator; reabstraction handled later.
    toArgIter++;
  }

  // Reabstract parameters.
  assert(toType->getNumParameters() == fromType->getNumParameters());
  for (unsigned paramIdx : range(toType->getNumParameters())) {
    auto fromParam = fromConv.getParameters()[paramIdx];
    auto toParam = toConv.getParameters()[paramIdx];
    // No abstraction mismatch. Directly use next argument.
    if (fromParam.isFormalIndirect() == toParam.isFormalIndirect()) {
      useNextArgument();
      continue;
    }
    // Convert indirect parameter to direct parameter.
    // The thunk received the value directly; spill it to a stack buffer and
    // pass the buffer's address.
    if (fromParam.isFormalIndirect()) {
      // NOTE(review): `paramTy` only feeds the assert below; the buffer is
      // typed from the thunk argument instead.
      auto paramTy = fromConv.getSILType(fromType->getParameters()[paramIdx]);
      if (!paramTy.hasArchetype())
        paramTy = thunk->mapTypeIntoContext(paramTy);
      assert(paramTy.isAddress());
      auto *toArg = *toArgIter++;
      auto *buf = createAllocStack(toArg->getType());
      builder.createStore(loc, toArg, buf,
                          StoreOwnershipQualifier::Unqualified);
      arguments.push_back(buf);
      continue;
    }
    // Convert direct parameter to indirect parameter.
    // The thunk received an address; load the value and pass it directly.
    assert(toParam.isFormalIndirect());
    auto *toArg = *toArgIter++;
    auto *load = builder.createLoad(loc, toArg,
                                    LoadOwnershipQualifier::Unqualified);
    arguments.push_back(load);
  }

  auto *apply = builder.createApply(
      loc, fnArg, SubstitutionMap(), arguments, /*isNonThrowing*/ false);

  // Get return elements.
  SmallVector<SILValue, 4> results;
  // Extract all direct results.
  SmallVector<SILValue, 4> directResults;
  extractAllElements(apply, builder, directResults);

  // Three cursors advance in lockstep with the result conventions: the call's
  // direct results, the call's indirect result operands, and the thunk's own
  // indirect result arguments.
  auto fromDirResultsIter = directResults.begin();
  auto fromIndResultsIter = apply->getIndirectSILResults().begin();
  auto toIndResultsIter = thunk->getIndirectResults().begin();
  // Reabstract results.
  for (unsigned resIdx : range(toType->getNumResults())) {
    auto fromRes = fromConv.getResults()[resIdx];
    auto toRes = toConv.getResults()[resIdx];
    // No abstraction mismatch.
    if (fromRes.isFormalIndirect() == toRes.isFormalIndirect()) {
      // If result types are direct, add call result as direct thunk result.
      if (toRes.isFormalDirect())
        results.push_back(*fromDirResultsIter++);
      // If result types are indirect, increment indirect result iterators.
      else {
        ++fromIndResultsIter;
        ++toIndResultsIter;
      }
      continue;
    }
    // Load direct results from indirect results.
    if (fromRes.isFormalIndirect()) {
      auto indRes = *fromIndResultsIter++;
      auto *load = builder.createLoad(loc, indRes,
                                      LoadOwnershipQualifier::Unqualified);
      results.push_back(load);
      continue;
    }
    // Store direct results to indirect results.
    assert(toRes.isFormalIndirect());
    SILType resultTy = toConv.getSILType(toRes);
    assert(resultTy.isAddress());
    auto indRes = *toIndResultsIter++;
    builder.createStore(loc, *fromDirResultsIter++, indRes,
                        StoreOwnershipQualifier::Unqualified);
  }
  auto retVal = joinElements(results, builder, loc);

  // Deallocate local allocations.
  for (auto *alloc : reversed(localAllocations))
    builder.createDeallocStack(loc, alloc);

  // Create return.
  builder.createReturn(loc, retVal);

  LLVM_DEBUG(auto &s = getADDebugStream() << "Created reabstraction thunk.\n";
             s << "  From type: " << fromType << '\n';
             s << "  To type: " << toType << '\n';
             s << '\n' << *thunk);

  return thunk;
}

namespace {
class VJPEmitter final
    : public TypeSubstCloner<VJPEmitter, SILOptFunctionBuilder> {
  friend class PullbackEmitter;

private:
  // NOTE: the constructor's member initializer list depends on this
  // declaration order (`pullbackInfo` is initialized from `activityInfo`).

  /// The global context.
  ADContext &context;

  /// The original function.
  SILFunction *const original;

  /// The `[differentiable]` attribute. Supplies the differentiation indices
  /// (see `getIndices()`).
  SILDifferentiableAttr *const attr;

  /// The VJP function.
  SILFunction *const vjp;

  /// The pullback function. Created empty by the constructor via
  /// `createEmptyPullback()`.
  SILFunction *pullback;

  /// The differentiation invoker.
  DifferentiationInvoker invoker;

  /// Info from activity analysis on the original function.
  const DifferentiableActivityInfo &activityInfo;

  /// The linear map info. Provides the per-block pullback structs and the
  /// branching trace enums used while emitting the VJP.
  LinearMapInfo pullbackInfo;

  /// Caches basic blocks whose phi arguments have been remapped (adding a
  /// predecessor enum argument).
  SmallPtrSet<SILBasicBlock *, 4> remappedBasicBlocks;

  /// A pair of a trampoline block phi argument and its corresponding
  /// destination block phi argument.
  struct TrampolinedArgumentPair {
    SILPhiArgument *trampolineArgument;
    SILPhiArgument *destinationArgument;
  };
  /// An array that keeps track of all `@guaranteed` phi arguments in any
  /// trampoline blocks we've added. Each of these arguments needs to have a
  /// lifetime-ending use past its destination argument's lifetime-ending use,
  /// so we keep track of these pairs of arguments and emit `end_borrow`s when
  /// function cloning is finished.
  SmallVector<TrampolinedArgumentPair, 8> trampolinedGuaranteedPhiArguments;

  /// Set once a non-differentiability error has been emitted. After that,
  /// `visit`, `postProcess` and `remapBasicBlock` stop doing their normal
  /// cloning work.
  bool errorOccurred = false;

  /// Mapping from original blocks to pullback values. Used to build pullback
  /// struct instances.
  DenseMap<SILBasicBlock *, SmallVector<SILValue, 8>> pullbackValues;

  /// The AST context of the VJP function being emitted.
  ASTContext &getASTContext() const {
    return vjp->getASTContext();
  }
  /// The SIL module of the VJP function being emitted.
  SILModule &getModule() const {
    return vjp->getModule();
  }
  /// The differentiation indices from the `[differentiable]` attribute.
  const SILAutoDiffIndices &getIndices() const {
    return attr->getIndices();
  }

  /// Returns the substitution map used when cloning `original` into `vjp`.
  ///
  /// If the VJP has no generic environment, this is `original`'s forwarding
  /// substitution map. Otherwise the map is rebuilt over the VJP's generic
  /// signature from the VJP environment's forwarding substitutions.
  static SubstitutionMap getSubstitutionMap(SILFunction *original,
                                            SILFunction *vjp) {
    auto originalForwarding = original->getForwardingSubstitutionMap();
    auto *vjpGenEnv = vjp->getGenericEnvironment();
    if (!vjpGenEnv)
      return originalForwarding;

    auto vjpForwarding = vjpGenEnv->getForwardingSubstitutionMap();
    return SubstitutionMap::get(
        vjpGenEnv->getGenericSignature(),
        QuerySubstitutionMap{vjpForwarding},
        LookUpConformanceInSubstitutionMap(vjpForwarding));
  }

  /// Returns the activity info of `original` for the VJP's generic signature,
  /// as computed by the differentiable activity analysis. `indices` is used
  /// only for the debug dump.
  static const DifferentiableActivityInfo &getActivityInfo(
      ADContext &context, SILFunction *original,
      const SILAutoDiffIndices &indices, SILFunction *vjp) {
    auto *analysis = context.getPassManager()
                         .getAnalysis<DifferentiableActivityAnalysis>();
    auto &collection = *analysis->get(original);
    auto vjpGenericSig = vjp->getLoweredFunctionType()->getGenericSignature();
    auto &info = collection.getActivityInfo(
        vjpGenericSig, AutoDiffDerivativeFunctionKind::VJP);
    LLVM_DEBUG(dumpActivityInfo(*original, indices, info, getADDebugStream()));
    return info;
  }

public:
  /// Prepares to clone `original` into `vjp`.
  ///
  /// Type substitutions come from `getSubstitutionMap(original, vjp)`. Activity
  /// info and pullback linear map info are computed up front, and an empty
  /// pullback function is created and appended to the context's generated
  /// functions.
  ///
  /// NOTE: `pullbackInfo` is initialized from the member `activityInfo`, which
  /// relies on `activityInfo` being declared (and thus initialized) first.
  explicit VJPEmitter(ADContext &context, SILFunction *original,
                      SILDifferentiableAttr *attr, SILFunction *vjp,
                      DifferentiationInvoker invoker)
      : TypeSubstCloner(*vjp, *original, getSubstitutionMap(original, vjp)),
        context(context), original(original), attr(attr), vjp(vjp),
        invoker(invoker), activityInfo(getActivityInfo(
                              context, original, attr->getIndices(), vjp)),
        pullbackInfo(context, AutoDiffLinearMapKind::Pullback, original, vjp,
                     attr->getIndices(), activityInfo) {
    // Create empty pullback function.
    pullback = createEmptyPullback();
    context.getGeneratedFunctions().push_back(pullback);
  }

  /// Creates the body-less pullback function for this VJP and returns it.
  ///
  /// Pullback parameters:
  /// - the seed, i.e. the tangent of the original result `indices.source`;
  /// - the pullback struct of the original exit block, passed owned.
  /// Pullback results: one tangent per original parameter in
  /// `indices.parameters`.
  /// The pullback reuses the original's ext info, coroutine kind and callee
  /// convention, and uses the derivative generic signature for `attr`.
  SILFunction *createEmptyPullback() {
    auto &module = context.getModule();
    auto origFnTy = original->getLoweredFunctionType();
    auto conformanceLookup =
        LookUpConformanceInModule(module.getSwiftModule());

    // Keep the original function's generic signature pushed on `module.Types`
    // for the whole function, so that the type lowering queries made by the
    // helpers below can see the original's generic parameters.
    Lowering::GenericContextScope genericScope(
        module.Types, origFnTy->getGenericSignature());

    // Pullback parameter info for the tangent of an original result, derived
    // from that result's convention.
    auto seedParamInfo = [&](CanType tangentTy,
                             ResultConvention resultConv) -> SILParameterInfo {
      auto &lowering = context.getTypeConverter().getTypeLowering(
          tangentTy, ResilienceExpansion::Minimal);
      ParameterConvention paramConv;
      switch (resultConv) {
      case ResultConvention::Indirect:
        paramConv = ParameterConvention::Indirect_In_Guaranteed;
        break;
      case ResultConvention::Unowned:
      case ResultConvention::UnownedInnerPointer:
        paramConv = ParameterConvention::Direct_Unowned;
        break;
      case ResultConvention::Owned:
      case ResultConvention::Autoreleased:
        if (lowering.isTrivial())
          paramConv = ParameterConvention::Direct_Unowned;
        else
          paramConv = ParameterConvention::Direct_Guaranteed;
        break;
      }
      return {tangentTy, paramConv};
    };

    // Pullback result info for the tangent of an original parameter, derived
    // from that parameter's convention.
    auto tangentResultInfo =
        [&](CanType tangentTy, ParameterConvention paramConv) -> SILResultInfo {
      auto &lowering = context.getTypeConverter().getTypeLowering(
          tangentTy, ResilienceExpansion::Minimal);
      ResultConvention resultConv;
      switch (paramConv) {
      case ParameterConvention::Indirect_In:
      case ParameterConvention::Indirect_Inout:
      case ParameterConvention::Indirect_In_Constant:
      case ParameterConvention::Indirect_In_Guaranteed:
      case ParameterConvention::Indirect_InoutAliasable:
        resultConv = ResultConvention::Indirect;
        break;
      case ParameterConvention::Direct_Owned:
      case ParameterConvention::Direct_Guaranteed:
      case ParameterConvention::Direct_Unowned:
        if (lowering.isTrivial())
          resultConv = ResultConvention::Unowned;
        else
          resultConv = ResultConvention::Owned;
        break;
      }
      return {tangentTy, resultConv};
    };

    auto indices = attr->getIndices();
    auto origParams = origFnTy->getParameters();
    SmallVector<SILParameterInfo, 8> pbParams;
    SmallVector<SILResultInfo, 8> pbResults;

    // Seed parameter: the tangent of the differentiated original result.
    auto seedOrigResult = origFnTy->getResults()[indices.source];
    auto seedTangentTy =
        seedOrigResult.getType()
            ->getAutoDiffAssociatedTangentSpace(conformanceLookup)
            ->getCanonicalType();
    pbParams.push_back(
        seedParamInfo(seedTangentTy, seedOrigResult.getConvention()));

    // Context parameter: the pullback struct of the original exit block, which
    // is the closure context of the pullback returned by the VJP.
    auto *origExitBB = &*original->findReturnBB();
    auto *exitPbStruct = pullbackInfo.getLinearMapStruct(origExitBB);
    auto exitPbStructTy =
        exitPbStruct->getDeclaredInterfaceType()->getCanonicalType();
    pbParams.push_back({exitPbStructTy, ParameterConvention::Direct_Owned});

    // Results: tangents of the requested wrt parameters.
    for (auto paramIdx : indices.parameters->getIndices()) {
      auto origParam = origParams[paramIdx];
      auto paramTangentTy =
          origParam.getType()
              ->getAutoDiffAssociatedTangentSpace(conformanceLookup)
              ->getCanonicalType();
      pbResults.push_back(
          tangentResultInfo(paramTangentTy, origParam.getConvention()));
    }

    Mangle::ASTMangler mangler;
    auto mangledName = mangler.mangleAutoDiffLinearMapHelper(
        original->getName(), AutoDiffLinearMapKind::Pullback, indices);
    auto pbName = original->getASTContext().getIdentifier(mangledName).str();

    auto pbGenericSig = getDerivativeGenericSignature(attr, original);
    GenericEnvironment *pbGenericEnv = nullptr;
    if (pbGenericSig)
      pbGenericEnv = pbGenericSig->getGenericEnvironment();

    auto pbType = SILFunctionType::get(
        pbGenericSig, origFnTy->getExtInfo(), origFnTy->getCoroutineKind(),
        origFnTy->getCalleeConvention(), pbParams, {}, pbResults, None,
        original->getASTContext());

    // Generated pullbacks are never called cross-module, so they get hidden
    // linkage.
    SILOptFunctionBuilder fb(context.getTransform());
    auto *pbFunction = fb.createFunction(
        SILLinkage::Hidden, pbName, pbType, pbGenericEnv,
        original->getLocation(), original->isBare(), IsNotTransparent,
        original->isSerialized(), original->isDynamicallyReplaceable());
    pbFunction->setDebugScope(
        new (module) SILDebugScope(original->getLocation(), pbFunction));
    return pbFunction;
  }

  /// Run VJP generation. Returns true on error.
  ///
  /// Declared here; the definition is out of line.
  bool run();

  /// Forwards to `SILClonerWithScopes::postProcess` unless an error has
  /// already been recorded, in which case it does nothing.
  void postProcess(SILInstruction *orig, SILInstruction *cloned) {
    if (!errorOccurred)
      SILClonerWithScopes::postProcess(orig, cloned);
  }

  /// Returns the VJP block cloned from `bb`. The first time a block is
  /// remapped (and only while no error has occurred), an owned phi argument
  /// of `bb`'s branching trace (predecessor) enum type is appended to it.
  SILBasicBlock *remapBasicBlock(SILBasicBlock *bb) {
    auto *vjpBB = BBMap[bb];
    // After an error, or for a block already handled, just return the mapped
    // block unchanged.
    if (errorOccurred || remappedBasicBlocks.count(bb))
      return vjpBB;

    auto *branchingTraceEnum = pullbackInfo.getBranchingTraceDecl(bb);
    auto enumSILType = getNominalDeclLoweredType(branchingTraceEnum);
    vjpBB->createPhiArgument(enumSILType, ValueOwnershipKind::Owned);
    remappedBasicBlocks.insert(bb);
    return vjpBB;
  }

  /// Entry point for visiting any instruction. Dispatches to the
  /// `TypeSubstCloner` visitor only while no error has been recorded.
  void visit(SILInstruction *inst) {
    if (!errorOccurred)
      TypeSubstCloner::visit(inst);
  }

  /// Fallback for instructions without a dedicated visitor. It reports `inst`
  /// as not differentiable and records that an error occurred.
  void visitSILInstruction(SILInstruction *inst) {
    context.emitNondifferentiabilityError(
        inst, invoker, diag::autodiff_expression_not_differentiable_note);
    errorOccurred = true;
  }

private:
  /// Returns the lowered SIL type of `nominal`'s declared interface type,
  /// after mapping it with `getOpASTType`.
  SILType getNominalDeclLoweredType(NominalTypeDecl *nominal) {
    auto declaredTy = nominal->getDeclaredInterfaceType()->getCanonicalType();
    return context.getTypeConverter().getLoweredType(
        getOpASTType(declaredTy), ResilienceExpansion::Minimal);
  }

  /// Emits, with the current builder, the pullback struct value for the
  /// original block terminated by `termInst`.
  ///
  /// The struct elements are the block's recorded pullback values. For a
  /// non-entry block, the VJP block's last argument (its predecessor enum) is
  /// placed first.
  StructInst *buildPullbackValueStructValue(TermInst *termInst) {
    assert(termInst->getFunction() == original);
    auto loc = termInst->getFunction()->getLocation();
    auto *origBB = termInst->getParent();
    auto *vjpBB = BBMap[origBB];
    auto structSILType =
        getNominalDeclLoweredType(pullbackInfo.getLinearMapStruct(origBB));
    auto &recordedValues = pullbackValues[origBB];

    SmallVector<SILValue, 8> elements;
    if (!origBB->isEntry())
      elements.push_back(vjpBB->getArguments().back());
    elements.append(recordedValues.begin(), recordedValues.end());
    return getBuilder().createStruct(loc, structSILType, elements);
  }

  /// Build a predecessor enum instance using the given builder for the given
  /// original predecessor/successor blocks and pullback struct value.
  ///
  /// The enum type is `succBB`'s branching trace enum, and the case is the one
  /// for the `predBB` -> `succBB` edge. If that case's payload is a box
  /// (`indirect` case), `pbStructVal` is first stored into a new box, and the
  /// box becomes the payload.
  EnumInst *buildPredecessorEnumValue(SILBuilder &builder,
                                      SILBasicBlock *predBB,
                                      SILBasicBlock *succBB,
                                      SILValue pbStructVal) {
    auto loc = pbStructVal.getLoc();
    auto *succEnum = pullbackInfo.getBranchingTraceDecl(succBB);
    auto enumLoweredTy = getNominalDeclLoweredType(succEnum);
    auto *enumEltDecl =
        pullbackInfo.lookUpBranchingTraceEnumElement(predBB, succBB);
    auto enumEltType = getOpType(
        enumLoweredTy.getEnumElementType(enumEltDecl, getModule()));
    // If the enum element type does not have a box type (i.e. the enum case is
    // not indirect), then directly create an enum.
    auto boxType = dyn_cast<SILBoxType>(enumEltType.getASTType());
    if (!boxType)
      return builder.createEnum(loc, pbStructVal, enumEltDecl, enumLoweredTy);
    // Otherwise, box the pullback struct value and create an enum.
    auto *newBox = builder.createAllocBox(loc, boxType);
    builder.emitScopedBorrowOperation(
        loc, newBox, [&](SILValue borrowedBox) {
      // Project from the borrowed box so that the address projection and the
      // store live inside the borrow scope. Projecting the owned `newBox`
      // directly would leave the borrow unused.
      auto *projectBox =
          builder.createProjectBox(loc, borrowedBox, /*index*/ 0);
      builder.emitStoreValueOperation(loc, pbStructVal, projectBox,
                                      StoreOwnershipQualifier::Init);
    });
    return builder.createEnum(loc, newBox, enumEltDecl, enumLoweredTy);
  }

public:
  /// Replaces the original `return` with a VJP `return` of a tuple. The tuple
  /// holds the elements of the original result followed by the pullback,
  /// partially applied to the exit block's pullback struct.
  void visitReturnInst(ReturnInst *ri) {
    auto &builder = getBuilder();
    auto loc = ri->getOperand().getLoc();
    auto *exitPbStruct = buildPullbackValueStructValue(ri);

    // VJP counterparts of the original return operand, split into elements.
    // (`ri` is itself the original exit block's terminator.)
    SmallVector<SILValue, 8> returnElements;
    extractAllElements(getOpValue(ri->getOperand()), builder, returnElements);

    // Forward the VJP's generic parameters when partially applying the
    // pullback.
    auto *vjpGenEnv = vjp->getGenericEnvironment();
    SubstitutionMap forwardingSubs;
    if (vjpGenEnv)
      forwardingSubs = vjpGenEnv->getForwardingSubstitutionMap();
    else
      forwardingSubs = vjp->getForwardingSubstitutionMap();
    auto *pbFnRef = builder.createFunctionRef(loc, pullback);
    auto *pbClosure = builder.createPartialApply(
        loc, pbFnRef, forwardingSubs, {exitPbStruct},
        ParameterConvention::Direct_Guaranteed);

    returnElements.push_back(pbClosure);
    auto returnValue = joinElements(returnElements, builder, loc);
    builder.createReturn(ri->getLoc(), returnValue);
  }

  /// Clones a `br`. It builds the source block's pullback struct, wraps it in
  /// the destination's predecessor enum, and passes that enum as an extra
  /// trailing branch argument.
  void visitBranchInst(BranchInst *bi) {
    auto *srcBB = bi->getParent();
    auto *destBB = bi->getDestBB();
    auto *pbStruct = buildPullbackValueStructValue(bi);
    auto *predEnum =
        buildPredecessorEnumValue(getBuilder(), srcBB, destBB, pbStruct);

    SmallVector<SILValue, 8> newArgs;
    for (auto origArg : bi->getArgs())
      newArgs.push_back(getOpValue(origArg));
    newArgs.push_back(predEnum);

    getBuilder().createBranch(bi->getLoc(), getOpBasicBlock(destBB), newArgs);
  }

  /// Clones a `cond_br`. Each destination is replaced by a new trampoline
  /// block. The trampoline builds that destination's predecessor enum from
  /// this block's pullback struct, then branches to the VJP destination.
  void visitCondBranchInst(CondBranchInst *cbi) {
    auto *srcBB = cbi->getParent();
    auto *pbStruct = buildPullbackValueStructValue(cbi);

    // Creates a trampoline for `origDestBB`, inserted before its VJP
    // counterpart. The trampoline mirrors every argument of the VJP
    // destination except the trailing predecessor enum, which it constructs
    // and appends itself. The new `cond_br` targets these trampolines.
    auto makeTrampoline = [&](SILBasicBlock *origDestBB) -> SILBasicBlock * {
      auto *vjpDestBB = getOpBasicBlock(origDestBB);
      auto *trampoline = vjp->createBasicBlockBefore(vjpDestBB);
      for (auto *destArg : vjpDestBB->getArguments().drop_back())
        trampoline->createPhiArgument(destArg->getType(),
                                      destArg->getOwnershipKind());

      SILBuilder trampolineBuilder(trampoline);
      auto *predEnum = buildPredecessorEnumValue(trampolineBuilder, srcBB,
                                                 origDestBB, pbStruct);
      SmallVector<SILValue, 4> branchArgs(trampoline->getArguments().begin(),
                                          trampoline->getArguments().end());
      branchArgs.push_back(predEnum);
      trampolineBuilder.createBranch(cbi->getLoc(), vjpDestBB, branchArgs);
      return trampoline;
    };

    getBuilder().createCondBranch(
        cbi->getLoc(), getOpValue(cbi->getCondition()),
        makeTrampoline(cbi->getTrueBB()),
        makeTrampoline(cbi->getFalseBB()));
  }

  /// Clones a `switch_enum`. Every case destination (and the default, if
  /// any) is replaced by a trampoline block. The trampoline builds that
  /// destination's predecessor enum from this block's pullback struct, then
  /// branches to the VJP destination.
  void visitSwitchEnumInst(SwitchEnumInst *sei) {
    auto *srcBB = sei->getParent();
    auto *pbStruct = buildPullbackValueStructValue(sei);

    // Creates a trampoline for `origDestBB`, inserted before its VJP
    // counterpart. The trampoline mirrors every argument of the VJP
    // destination except the trailing predecessor enum, which it constructs
    // and appends itself.
    auto makeTrampoline = [&](SILBasicBlock *origDestBB) -> SILBasicBlock * {
      auto *vjpDestBB = getOpBasicBlock(origDestBB);
      auto *trampoline = vjp->createBasicBlockBefore(vjpDestBB);
      for (auto *destArg : vjpDestBB->getArguments().drop_back()) {
        auto *trampolineArg = trampoline->createPhiArgument(
            destArg->getType(), destArg->getOwnershipKind());
        // A `@guaranteed` trampoline argument must stay alive past the
        // lifetime-ending uses of its destination argument. Remember the pair
        // so `end_borrow`s can be emitted once cloning is finished.
        if (trampolineArg->getOwnershipKind() == ValueOwnershipKind::Guaranteed)
          trampolinedGuaranteedPhiArguments.push_back(
              {trampolineArg, cast<SILPhiArgument>(destArg)});
      }

      SILBuilder trampolineBuilder(trampoline);
      auto *predEnum = buildPredecessorEnumValue(trampolineBuilder, srcBB,
                                                 origDestBB, pbStruct);
      SmallVector<SILValue, 4> branchArgs(trampoline->getArguments().begin(),
                                          trampoline->getArguments().end());
      branchArgs.push_back(predEnum);
      trampolineBuilder.createBranch(sei->getLoc(), vjpDestBB, branchArgs);
      return trampoline;
    };

    // One trampoline per case, in case order.
    SmallVector<std::pair<EnumElementDecl *, SILBasicBlock *>, 4> caseBBs;
    for (unsigned caseIdx : range(sei->getNumCases())) {
      auto origCase = sei->getCase(caseIdx);
      auto *caseTrampoline = makeTrampoline(origCase.second);
      caseBBs.push_back({origCase.first, caseTrampoline});
    }

    // And one for the default destination, if present.
    SILBasicBlock *defaultTrampoline = nullptr;
    if (auto *origDefaultBB = sei->getDefaultBBOrNull().getPtrOrNull())
      defaultTrampoline = makeTrampoline(origDefaultBB);

    getBuilder().createSwitchEnum(sei->getLoc(), getOpValue(sei->getOperand()),
                                  defaultTrampoline, caseBBs);
  }

  // If an `apply` has active results or active inout parameters, replace it
  // with an `apply` of its VJP.
  //
  // The VJP is obtained either by extracting it from a `@differentiable`
  // callee, or by wrapping the (possibly specialized) callee in a
  // `differentiable_function` instruction that is canonicalized later. The
  // VJP's pullback result is recorded in `pullbackValues`, reabstracted via a
  // thunk when its type differs from the lowered pullback field type.
  void visitApplyInst(ApplyInst *ai) {
    // If the function should not be differentiated or it's the array literal
    // initialization intrinsic, just do standard cloning.
    if (!pullbackInfo.shouldDifferentiateApplyInst(ai) ||
        isArrayLiteralIntrinsic(ai)) {
      LLVM_DEBUG(getADDebugStream() << "No active results:\n" << *ai << '\n');
      TypeSubstCloner::visitApplyInst(ai);
      return;
    }

    // Check and reject functions with active inout arguments. It's not yet
    // supported.
    auto paramInfos = ai->getSubstCalleeConv().getParameters();
    auto paramArgs = ai->getArgumentsWithoutIndirectResults();
    for (unsigned i : swift::indices(paramInfos)) {
      if (paramInfos[i].isIndirectInOut() &&
          activityInfo.isActive(paramArgs[i], getIndices())) {
        context.emitNondifferentiabilityError(ai, invoker,
            diag::autodiff_cannot_differentiate_through_inout_arguments);
        errorOccurred = true;
        return;
      }
    }

    LLVM_DEBUG(getADDebugStream() << "VJP-transforming:\n" << *ai << '\n');

    // Get the minimal parameter and result indices required for differentiating
    // this `apply`.
    SmallVector<SILValue, 4> allResults;
    SmallVector<unsigned, 8> activeParamIndices;
    SmallVector<unsigned, 8> activeResultIndices;
    collectMinimalIndicesForFunctionCall(ai, getIndices(), activityInfo,
                                         allResults, activeParamIndices,
                                         activeResultIndices);
    assert(!activeParamIndices.empty() && "Parameter indices cannot be empty");
    assert(!activeResultIndices.empty() && "Result indices cannot be empty");
    LLVM_DEBUG(auto &s = getADDebugStream() << "Active indices: params={";
               interleave(activeParamIndices.begin(), activeParamIndices.end(),
                          [&s](unsigned i) { s << i; }, [&s] { s << ", "; });
               s << "}, results={"; interleave(
                   activeResultIndices.begin(), activeResultIndices.end(),
                   [&s](unsigned i) { s << i; }, [&s] { s << ", "; });
               s << "}\n";);
    // FIXME: We don't support multiple active results yet.
    if (activeResultIndices.size() > 1) {
      context.emitNondifferentiabilityError(
          ai, invoker, diag::autodiff_expression_not_differentiable_note);
      errorOccurred = true;
      return;
    }

    // Form expected indices, assuming there's only one result.
    SILAutoDiffIndices indices(
        activeResultIndices.front(),
        IndexSubset::get(
            getASTContext(), ai->getArgumentsWithoutIndirectResults().size(),
            activeParamIndices));

    // Emit the VJP.
    auto loc = ai->getLoc();
    auto &builder = getBuilder();
    auto original = getOpValue(ai->getCallee());
    SILValue vjpValue;
    // If functionSource is a `@differentiable` function, just extract it.
    auto originalFnTy = original->getType().castTo<SILFunctionType>();
    if (originalFnTy->isDifferentiable()) {
      auto paramIndices = originalFnTy->getDifferentiationParameterIndices();
      for (auto i : indices.parameters->getIndices()) {
        if (!paramIndices->contains(i)) {
          context.emitNondifferentiabilityError(original, invoker,
              diag::autodiff_function_nondiff_parameter_not_differentiable);
          errorOccurred = true;
          return;
        }
      }
      auto borrowedDiffFunc = builder.emitBeginBorrowOperation(loc, original);
      vjpValue = builder.createDifferentiableFunctionExtract(
          loc, NormalDifferentiableFunctionTypeComponent::VJP,
          borrowedDiffFunc);
      vjpValue = builder.emitCopyValueOperation(loc, vjpValue);
      // Close the borrow scope now that the extracted VJP has been copied out
      // of it (mirroring the `differentiable_function` path below). If
      // `emitBeginBorrowOperation` returned its operand unchanged, no borrow
      // scope was opened and none must be closed.
      if (borrowedDiffFunc != original)
        builder.emitEndBorrowOperation(loc, borrowedDiffFunc);
    }

    // Check and diagnose non-differentiable original function type.
    auto diagnoseNondifferentiableOriginalFunctionType =
        [&](CanSILFunctionType origFnTy) {
          // Check and diagnose non-differentiable arguments.
          for (unsigned paramIndex : range(origFnTy->getNumParameters())) {
            if (indices.isWrtParameter(paramIndex) &&
                !origFnTy->getParameters()[paramIndex]
                     .getSILStorageType()
                     .isDifferentiable(getModule())) {
              context.emitNondifferentiabilityError(
                  ai->getArgumentsWithoutIndirectResults()[paramIndex], invoker,
                  diag::autodiff_nondifferentiable_argument);
              errorOccurred = true;
              return true;
            }
          }
          // Check and diagnose non-differentiable results.
          if (!origFnTy->getResults()[indices.source]
                   .getSILStorageType()
                   .isDifferentiable(getModule())) {
            context.emitNondifferentiabilityError(
                original, invoker, diag::autodiff_nondifferentiable_result);
            errorOccurred = true;
            return true;
          }
          return false;
        };
    if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy))
      return;

    // If VJP has not yet been found, emit a `differentiable_function`
    // instruction on the remapped original function operand and
    // a `differentiable_function_extract` instruction to get the VJP.
    // The `differentiable_function` instruction will be canonicalized during
    // the transform main loop.
    if (!vjpValue) {
      // FIXME: Handle indirect differentiation invokers. This may require some
      // redesign: currently, each original function + attribute pair is mapped
      // only to one invoker.
      /*
      DifferentiationInvoker indirect(ai, attr);
      auto insertion =
          context.getInvokers().try_emplace({this->original, attr}, indirect);
      auto &invoker = insertion.first->getSecond();
      invoker = indirect;
      */

      // If the original `apply` instruction has a substitution map, then the
      // applied function is specialized.
      // In the VJP, specialization is also necessary for parity. The original
      // function operand is specialized with a remapped version of same
      // substitution map using an argument-less `partial_apply`.
      if (ai->getSubstitutionMap().empty()) {
        original = builder.emitCopyValueOperation(loc, original);
      } else {
        auto substMap = getOpSubstitutionMap(ai->getSubstitutionMap());
        auto vjpPartialApply = getBuilder().createPartialApply(
            ai->getLoc(), original, substMap, {},
            ParameterConvention::Direct_Guaranteed);
        original = vjpPartialApply;
        originalFnTy = original->getType().castTo<SILFunctionType>();
        // Diagnose if new original function type is non-differentiable.
        if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy))
          return;
      }

      auto *diffFuncInst = context.createDifferentiableFunction(
          getBuilder(), loc, indices.parameters, original);

      // Record the `differentiable_function` instruction.
      context.getDifferentiableFunctionInsts().push_back(diffFuncInst);
      // TODO(TF-689): Make `differentiable_function` store result indices and
      // remove `ADContext::resultIndices`.
      context.getResultIndices()[diffFuncInst] = activeResultIndices.front();

      auto borrowedADFunc =
          builder.emitBeginBorrowOperation(loc, diffFuncInst);
      auto extractedVJP = getBuilder().createDifferentiableFunctionExtract(
          loc, NormalDifferentiableFunctionTypeComponent::VJP,
          borrowedADFunc);
      vjpValue = builder.emitCopyValueOperation(loc, extractedVJP);
      builder.emitEndBorrowOperation(loc, borrowedADFunc);
      builder.emitDestroyValueOperation(loc, diffFuncInst);
    }

    // Record desired/actual VJP indices.
    // Temporarily set original pullback type to `None`.
    NestedApplyInfo info{indices, /*originalPullbackType*/ None};
    auto insertion = context.getNestedApplyInfo().try_emplace(ai, info);
    auto &nestedApplyInfo = insertion.first->getSecond();
    nestedApplyInfo = info;

    // Call the VJP using the original parameters.
    SmallVector<SILValue, 8> vjpArgs;
    auto vjpFnTy = getOpType(vjpValue->getType()).castTo<SILFunctionType>();
    auto numVJPArgs =
        vjpFnTy->getNumParameters() + vjpFnTy->getNumIndirectFormalResults();
    vjpArgs.reserve(numVJPArgs);
    // Collect substituted arguments.
    for (auto origArg : ai->getArguments())
      vjpArgs.push_back(getOpValue(origArg));
    assert(vjpArgs.size() == numVJPArgs);
    // Apply the VJP.
    // The VJP should be specialized, so no substitution map is necessary.
    auto *vjpCall = getBuilder().createApply(loc, vjpValue, SubstitutionMap(),
                                             vjpArgs, ai->isNonThrowing());
    LLVM_DEBUG(getADDebugStream() << "Applied vjp function\n" << *vjpCall);
    builder.emitDestroyValueOperation(loc, vjpValue);

    // Get the VJP results (original results and pullback).
    SmallVector<SILValue, 8> vjpDirectResults;
    extractAllElements(vjpCall, getBuilder(), vjpDirectResults);
    ArrayRef<SILValue> originalDirectResults =
        ArrayRef<SILValue>(vjpDirectResults).drop_back(1);
    SILValue originalDirectResult = joinElements(originalDirectResults,
                                                 getBuilder(),
                                                 vjpCall->getLoc());
    SILValue pullback = vjpDirectResults.back();

    // Store the original result to the value map.
    mapValue(ai, originalDirectResult);

    // Checkpoint the pullback.
    auto *pullbackDecl = pullbackInfo.lookUpLinearMapDecl(ai);

    // If actual pullback type does not match lowered pullback type, reabstract
    // the pullback using a thunk.
    auto actualPullbackType =
        getOpType(pullback->getType()).getAs<SILFunctionType>();
    auto vjpGenSig = SubsMap.getGenericSignature()
        ? SubsMap.getGenericSignature()->getCanonicalSignature()
        : nullptr;
    Lowering::GenericContextScope genericContextScope(
        context.getTypeConverter(), vjpGenSig);
    auto loweredPullbackType =
        getOpType(context.getTypeConverter().getLoweredType(
                      pullbackDecl->getInterfaceType()->getCanonicalType(),
                      ResilienceExpansion::Minimal))
            .castTo<SILFunctionType>();
    if (!loweredPullbackType->isEqual(actualPullbackType)) {
      // Set non-reabstracted original pullback type in nested apply info.
      nestedApplyInfo.originalPullbackType = actualPullbackType;
      SILOptFunctionBuilder fb(context.getTransform());
      auto *thunk = getOrCreateReabstractionThunk(
          fb, getModule(), loc, /*caller*/ vjp, actualPullbackType,
          loweredPullbackType);
      auto *thunkRef = getBuilder().createFunctionRef(loc, thunk);
      pullback = getBuilder().createPartialApply(
          ai->getLoc(), thunkRef,
          getOpSubstitutionMap(thunk->getForwardingSubstitutionMap()),
          {pullback}, actualPullbackType->getCalleeConvention());
    }
    pullbackValues[ai->getParent()].push_back(pullback);

    // Some instructions that produce the callee may have been cloned.
    // If the original callee did not have any users beyond this `apply`,
    // recursively kill the cloned callee. A callee produced by a
    // multiple-value instruction is not a `SingleValueInstruction`; it is
    // skipped here rather than tripping a `cast` assertion.
    if (auto *origCallee = dyn_cast_or_null<SingleValueInstruction>(
            ai->getCallee()->getDefiningInstruction()))
      if (origCallee->hasOneUse())
        recursivelyDeleteTriviallyDeadInstructions(
            getOpValue(origCallee)->getDefiningInstruction());
  }

  void visitDifferentiableFunctionInst(DifferentiableFunctionInst *dfi) {
    // Clone the instruction as usual, then enqueue the VJP-side clone on the
    // `differentiable_function` worklist so the transform main loop
    // canonicalizes it.
    TypeSubstCloner::visitDifferentiableFunctionInst(dfi);
    auto &worklist = context.getDifferentiableFunctionInsts();
    worklist.push_back(cast<DifferentiableFunctionInst>(getOpValue(dfi)));
  }
};
} // end anonymous namespace

//===----------------------------------------------------------------------===//
// AdjointValue - a symbolic representation for adjoint values that allows
// for efficient differentiation of aggregates.
//===----------------------------------------------------------------------===//

namespace {
class PullbackEmitter;
class AdjointValue;

enum AdjointValueKind {
  /// A zero adjoint. It is a distinct kind because zero is the additive
  /// identity (`0 + x = x`): combining a zero adjoint with another adjoint
  /// (e.g. when differentiating a fanout) is guaranteed to fold away.
  Zero = 0,

  /// An aggregate (struct or tuple) whose elements are adjoint values.
  Aggregate = 1,

  /// An adjoint backed by a single materialized SIL value.
  Concrete = 2,
};

class AdjointValueBase {
  friend class AdjointValue;

  /// The kind of this adjoint value.
  AdjointValueKind kind;

  /// The type of this value as if it were materialized as a SIL value.
  SILType type;

  /// The underlying value.
  union Value {
    ArrayRef<AdjointValue> aggregate;
    SILValue concrete;
    Value(ArrayRef<AdjointValue> v) : aggregate(v) {}
    Value(SILValue v) : concrete(v) {}
    Value() {}
  } value;

  explicit AdjointValueBase(SILType type,
                            ArrayRef<AdjointValue> aggregate)
      : kind(AdjointValueKind::Aggregate), type(type), value(aggregate) {}

  explicit AdjointValueBase(SILValue v)
      : kind(AdjointValueKind::Concrete), type(v->getType()), value(v) {}

  explicit AdjointValueBase(SILType type)
      : kind(AdjointValueKind::Zero), type(type) {}
};

/// A symbolic adjoint value: a zero, an aggregate of other adjoint values, or
/// a concrete materialized SILValue (see `AdjointValueKind`). Its only data
/// member is a pointer to `AdjointValueBase` storage obtained from a
/// `BumpPtrAllocator`, so it is cheap to pass around by value.
class AdjointValue final {
  friend class PullbackEmitter;

private:
  /// The underlying storage, allocated by one of the `create*` factories from
  /// the allocator passed to them.
  AdjointValueBase *base;
  /*implicit*/ AdjointValue(AdjointValueBase *base = nullptr) : base(base) {}

public:
  AdjointValueBase *operator->() const { return base; }
  AdjointValueBase &operator*() const { return *base; }

  /// Creates a concrete adjoint value wrapping `value`.
  static AdjointValue createConcrete(llvm::BumpPtrAllocator &allocator,
                                     SILValue value) {
    return new (allocator.Allocate<AdjointValueBase>()) AdjointValueBase(value);
  }

  /// Creates an aggregate adjoint value of type `type`. The elements are
  /// copied into storage taken from `allocator`, so `elements` itself need
  /// not outlive the result.
  template<typename EltRange>
  static AdjointValue createAggregate(llvm::BumpPtrAllocator &allocator,
                                      SILType type, EltRange elements) {
    AdjointValue *buf = reinterpret_cast<AdjointValue *>(allocator.Allocate(
        elements.size() * sizeof(AdjointValue), alignof(AdjointValue)));
    MutableArrayRef<AdjointValue> elementsCopy(buf, elements.size());
    std::uninitialized_copy(elements.begin(), elements.end(),
                            elementsCopy.begin());
    return new (allocator.Allocate<AdjointValueBase>())
        AdjointValueBase(type, elementsCopy);
  }

  /// Creates a zero adjoint value of type `type`.
  static AdjointValue createZero(llvm::BumpPtrAllocator &allocator,
                                 SILType type) {
    return new (allocator.Allocate<AdjointValueBase>()) AdjointValueBase(type);
  }

  AdjointValueKind getKind() const { return base->kind; }
  SILType getType() const { return base->type; }
  CanType getSwiftType() const { return getType().getASTType(); }

  NominalTypeDecl *getAnyNominal() const {
    return getSwiftType()->getAnyNominal();
  }

  bool isZero() const { return getKind() == AdjointValueKind::Zero; }
  bool isAggregate() const { return getKind() == AdjointValueKind::Aggregate; }
  bool isConcrete() const { return getKind() == AdjointValueKind::Concrete; }

  /// Number of elements of an aggregate adjoint value.
  unsigned getNumAggregateElements() const {
    assert(isAggregate());
    return base->value.aggregate.size();
  }

  /// The `i`-th element of an aggregate adjoint value.
  AdjointValue getAggregateElement(unsigned i) const {
    assert(isAggregate());
    return base->value.aggregate[i];
  }

  /// All elements of an aggregate adjoint value. Like the other aggregate
  /// accessors, only valid when `isAggregate()`: for other kinds the
  /// `aggregate` union member is not the active one.
  ArrayRef<AdjointValue> getAggregateElements() const {
    assert(isAggregate());
    return base->value.aggregate;
  }

  /// The wrapped SIL value of a concrete adjoint value.
  SILValue getConcreteValue() const {
    assert(isConcrete());
    return base->value.concrete;
  }

  /// Prints a human-readable description of this adjoint value; aggregates
  /// print their struct fields or tuple elements recursively.
  void print(llvm::raw_ostream &s) const {
    switch (getKind()) {
    case AdjointValueKind::Zero:
      s << "Zero";
      break;
    case AdjointValueKind::Aggregate:
      s << "Aggregate<";
      if (auto *decl =
            getType().getASTType()->getStructOrBoundGenericStruct()) {
        s << "Struct>(";
        interleave(llvm::zip(decl->getStoredProperties(),
                             base->value.aggregate),
                             [&s](std::tuple<VarDecl *,
                                             const AdjointValue &> elt) {
                               s << std::get<0>(elt)->getName() << ": ";
                               std::get<1>(elt).print(s);
                             }, [&s] { s << ", "; });
      } else if (auto tupleType = getType().getAs<TupleType>()) {
        s << "Tuple>(";
        interleave(base->value.aggregate,
                   [&s](const AdjointValue &elt) { elt.print(s); },
                   [&s] { s << ", "; });
      } else {
        llvm_unreachable("Invalid aggregate");
      }
      s << ')';
      break;
    case AdjointValueKind::Concrete:
      s << "Concrete(" << base->value.concrete << ')';
      break;
    }
  }
};

/// Streams `value` using `AdjointValue::print`, so `<<` and `print` agree.
inline llvm::raw_ostream &operator<<(llvm::raw_ostream &stream,
                                     const AdjointValue &value) {
  value.print(stream);
  return stream;
}

} // end anonymous namespace

namespace {

class JVPEmitter final
    : public TypeSubstCloner<JVPEmitter, SILOptFunctionBuilder> {
private:
  /// The global context.
  ADContext &context;

  /// The original function.
  SILFunction *const original;

  /// The `[differentiable]` attribute.
  SILDifferentiableAttr *const attr;

  /// The JVP function.
  SILFunction *const jvp;

  /// Allocator backing the `AdjointValueBase` storage of tangent values
  /// created by `makeZeroTangentValue` / `makeConcreteTangentValue`.
  llvm::BumpPtrAllocator allocator;

  /// The differentiation invoker.
  DifferentiationInvoker invoker;

  /// Info from activity analysis on the original function.
  const DifferentiableActivityInfo &activityInfo;

  /// The differential linear map info: provides each original block's
  /// differential struct (`getLinearMapStruct`) and decides which
  /// instructions need tangents (`shouldDifferentiateInstruction`).
  LinearMapInfo differentialInfo;

  /// Whether an error occurred during emission. NOTE(review): presumably set
  /// when a non-differentiability diagnostic is emitted, as in the VJP
  /// emitter; the setter is not visible here — confirm.
  bool errorOccurred = false;

  //--------------------------------------------------------------------------//
  // Differential generation related fields
  //--------------------------------------------------------------------------//

  /// The builder for the differential function.
  SILBuilder differentialBuilder;

  /// Mapping from original basic blocks to corresponding differential basic
  /// blocks.
  DenseMap<SILBasicBlock *, SILBasicBlock *> diffBBMap;

  /// Mapping from original values to their corresponding tangent values
  /// (keyed by value only). Looking up an unmapped value through
  /// `getTangentValue` records and returns a zero tangent.
  DenseMap<SILValue, AdjointValue> tangentValueMap;

  /// Mapping from original basic blocks and original buffers to corresponding
  /// tangent buffers.
  DenseMap<std::pair<SILBasicBlock *, SILValue>, SILValue> bufferMap;

  /// Mapping from basic blocks to differential struct arguments.
  /// NOTE(review): `getDifferentialStructArgument` looks entries up by
  /// *original* block, not by differential block — confirm the key domain.
  DenseMap<SILBasicBlock *, SILArgument *> differentialStructArguments;

  /// Mapping from differential struct field declarations to differential struct
  /// elements destructured from the linear map basic block argument. In the
  /// beginning of each differential basic block, the block's differential
  /// struct is destructured into the individual elements stored here.
  DenseMap<VarDecl *, SILValue> differentialStructElements;

  /// An auxiliary differential local allocation builder.
  SILBuilder diffLocalAllocBuilder;

  /// Stack buffers allocated for storing local tangent values.
  SmallVector<SILValue, 8> differentialLocalAllocations;

  /// Mapping from original blocks to differential values. Used by
  /// `buildDifferentialValueStructValue` to build differential struct
  /// instances.
  DenseMap<SILBasicBlock *, SmallVector<SILValue, 8>> differentialValues;

  //--------------------------------------------------------------------------//
  // Getters
  //--------------------------------------------------------------------------//

  /// The AST context of the JVP function.
  ASTContext &getASTContext() const { return jvp->getASTContext(); }
  /// The SIL module containing the JVP function.
  SILModule &getModule() const { return jvp->getModule(); }
  /// The differentiation indices of the `[differentiable]` attribute.
  const SILAutoDiffIndices &getIndices() const { return attr->getIndices(); }
  /// The builder used to emit the differential function.
  SILBuilder &getDifferentialBuilder() { return differentialBuilder; }
  /// The function that `differentialBuilder` emits into.
  SILFunction &getDifferential() { return differentialBuilder.getFunction(); }
  /// Returns the differential struct argument recorded for `origBB`.
  ///
  /// Uses `lookup` rather than `operator[]` so that a missing entry trips the
  /// assertion below instead of default-inserting a null argument and then
  /// dereferencing it in the debug-only type check.
  SILArgument *getDifferentialStructArgument(SILBasicBlock *origBB) {
    auto *diffStructArg = differentialStructArguments.lookup(origBB);
    assert(diffStructArg && "Differential struct argument should exist");
#ifndef NDEBUG
    auto *diffStruct =
        diffStructArg->getType().getStructOrBoundGenericStruct();
    assert(diffStruct == differentialInfo.getLinearMapStruct(origBB));
#endif
    return diffStructArg;
  }

  //--------------------------------------------------------------------------//
  // Initialization helpers
  //--------------------------------------------------------------------------//

  /// Returns the substitution map relating `original` and `jvp`
  /// (presumably consumed by the `TypeSubstCloner` base — confirm at the
  /// constructor).
  ///
  /// When the JVP has a generic environment, the map is rebuilt over the JVP's
  /// generic signature from that environment's forwarding substitutions;
  /// otherwise the original function's forwarding substitution map is used.
  static SubstitutionMap getSubstitutionMap(SILFunction *original,
                                            SILFunction *jvp) {
    SubstitutionMap result = original->getForwardingSubstitutionMap();
    auto *jvpGenEnv = jvp->getGenericEnvironment();
    if (!jvpGenEnv)
      return result;
    auto jvpForwardingMap = jvpGenEnv->getForwardingSubstitutionMap();
    return SubstitutionMap::get(
        jvpGenEnv->getGenericSignature(),
        QuerySubstitutionMap{jvpForwardingMap},
        LookUpConformanceInSubstitutionMap(jvpForwardingMap));
  }

  /// Returns the activity info of the values in `original`, as computed by
  /// `DifferentiableActivityAnalysis` for the JVP's lowered generic signature
  /// and the JVP derivative kind. With debug output enabled, the info is
  /// dumped for `indices`.
  static const DifferentiableActivityInfo &getActivityInfo(
      ADContext &context, SILFunction *original,
      const SILAutoDiffIndices &indices, SILFunction *jvp) {
    auto *analysis = context.getPassManager()
                         .getAnalysis<DifferentiableActivityAnalysis>();
    auto &collection = *analysis->get(original);
    auto jvpGenSig = jvp->getLoweredFunctionType()->getGenericSignature();
    auto &info = collection.getActivityInfo(
        jvpGenSig, AutoDiffDerivativeFunctionKind::JVP);
    LLVM_DEBUG(dumpActivityInfo(*original, indices, info, getADDebugStream()));
    return info;
  }

  //--------------------------------------------------------------------------//
  // Differential struct mapping
  //--------------------------------------------------------------------------//

  /// Records the destructured element values of `origBB`'s differential
  /// struct, pairing each stored property with the value at the same position
  /// in `values`. Each value must not be `@guaranteed`, and no field may
  /// already have a recorded element.
  void initializeDifferentialStructElements(SILBasicBlock *origBB,
                                            SILInstructionResultArray values) {
    auto *diffStructDecl = differentialInfo.getLinearMapStruct(origBB);
    assert(diffStructDecl->getStoredProperties().size() == values.size() &&
           "The number of differential struct fields must equal the number of "
           "differential struct element values");
    for (auto fieldAndValue :
         llvm::zip(diffStructDecl->getStoredProperties(), values)) {
      VarDecl *field = std::get<0>(fieldAndValue);
      SILValue elementValue = std::get<1>(fieldAndValue);
      assert(elementValue.getOwnershipKind() != ValueOwnershipKind::Guaranteed
                 && "Differential struct elements must be @owned");
      bool didInsert =
          differentialStructElements.insert({field, elementValue}).second;
      (void)didInsert;
      assert(didInsert &&
             "A differential struct element mapping already exists!");
    }
  }

  /// Returns the differential struct element recorded for `field` (e.g. by
  /// `initializeDifferentialStructElements`). `field` must be a stored
  /// property of `origBB`'s differential struct.
  SILValue getDifferentialStructElement(SILBasicBlock *origBB, VarDecl *field) {
    assert(differentialInfo.getLinearMapStruct(origBB) ==
               cast<StructDecl>(field->getDeclContext()));
    SILValue element = differentialStructElements.lookup(field);
    assert(differentialStructElements.count(field) &&
           "Differential struct element for this field does not exist!");
    return element;
  }

  //--------------------------------------------------------------------------//
  // General utilities
  //--------------------------------------------------------------------------//

  /// Returns where the next differential local allocation should be inserted.
  ///
  /// Once a local allocation exists, this is the position of the most recent
  /// one, so new allocations go *before* it; inserting before rather than
  /// after keeps allocation and zero initialization instructions grouped
  /// together. With no local allocations yet, it is the start of the
  /// differential's entry block.
  SILBasicBlock::iterator getNextDifferentialLocalAllocationInsertionPoint() {
    if (!differentialLocalAllocations.empty()) {
      SILValue lastLocalAlloc = differentialLocalAllocations.back();
      return lastLocalAlloc->getDefiningInstruction()->getIterator();
    }
    return getDifferential().getEntryBlock()->begin();
  }

  /// Returns the lowered SIL type (minimal resilience expansion) of
  /// `nominal`'s declared interface type, after remapping it with
  /// `getOpASTType`.
  SILType getNominalDeclLoweredType(NominalTypeDecl *nominal) {
    CanType declaredTy =
        nominal->getDeclaredInterfaceType()->getCanonicalType();
    return context.getTypeConverter().getLoweredType(
        getOpASTType(declaredTy), ResilienceExpansion::Minimal);
  }

  /// Builds, in the JVP, a differential struct value for the original block
  /// that `termInst` terminates. For non-entry blocks the first field is the
  /// last argument of the corresponding JVP block (presumably the predecessor
  /// enum); the remaining fields are the values recorded in
  /// `differentialValues` for that block.
  StructInst *buildDifferentialValueStructValue(TermInst *termInst) {
    assert(termInst->getFunction() == original);
    auto *origBB = termInst->getParent();
    auto *jvpBB = BBMap[origBB];
    assert(jvpBB && "Basic block mapping should exist");
    auto *diffStruct = differentialInfo.getLinearMapStruct(origBB);
    assert(diffStruct && "The differential struct should have been declared");
    auto structLoweredTy = getNominalDeclLoweredType(diffStruct);

    SmallVector<SILValue, 8> fieldValues;
    if (!origBB->isEntry())
      fieldValues.push_back(jvpBB->getArguments().back());
    auto &recordedValues = differentialValues[origBB];
    fieldValues.append(recordedValues.begin(), recordedValues.end());

    return getBuilder().createStruct(termInst->getFunction()->getLocation(),
                                     structLoweredTy, fieldValues);
  }

  //--------------------------------------------------------------------------//
  // Tangent value factory methods
  //--------------------------------------------------------------------------//

  /// Creates a zero tangent value whose type is `type` remapped into the
  /// differential's context.
  AdjointValue makeZeroTangentValue(SILType type) {
    SILType remappedType = remapSILTypeInDifferential(type);
    return AdjointValue::createZero(allocator, remappedType);
  }

  /// Wraps `value` as a concrete tangent value.
  AdjointValue makeConcreteTangentValue(SILValue value) {
    return AdjointValue::createConcrete(allocator, value);
  }

  //--------------------------------------------------------------------------//
  // Tangent materialization
  //--------------------------------------------------------------------------//

  /// Emits zero-initialization of the tangent space of `type` into
  /// `bufferAccess`, at the differential builder's insertion point.
  ///
  /// Vector spaces are zeroed via `emitZeroIntoBuffer`; tuple spaces are
  /// zeroed element-wise by recursing on each element address. Function
  /// tangent spaces are not supported yet.
  void emitZeroIndirect(CanType type, SILValue bufferAccess,
                        SILLocation loc) {
    // Bind by reference like the other emitters do: `getDifferentialBuilder()`
    // returns the emitter's builder, and `auto` alone would copy it.
    auto &builder = getDifferentialBuilder();
    auto tangentSpace = getTangentSpace(type);
    assert(tangentSpace && "No tangent space for this type");
    switch (tangentSpace->getKind()) {
    case VectorSpace::Kind::Vector:
      emitZeroIntoBuffer(builder, type, bufferAccess, loc);
      return;
    case VectorSpace::Kind::Tuple: {
      auto tupleType = tangentSpace->getTuple();
      for (unsigned i : range(tupleType->getNumElements())) {
        auto eltAddr = builder.createTupleElementAddr(loc, bufferAccess, i);
        emitZeroIndirect(tupleType->getElementType(i)->getCanonicalType(),
                         eltAddr, loc);
      }
      return;
    }
    case VectorSpace::Kind::Function: {
      llvm_unreachable(
          "Unimplemented: Emit thunks for abstracting zero initialization");
    }
    }
  }

  /// Returns a zero value of type `type` (whose tangent space must exist),
  /// materialized in the differential.
  ///
  /// The zero is built through a temporary stack buffer: it is allocated,
  /// zero-initialized with `emitZeroIndirect`, loaded with `take`, and
  /// deallocated.
  SILValue emitZeroDirect(CanType type, SILLocation loc) {
    // Bind by reference; `auto` alone would copy the emitter's SILBuilder.
    auto &diffBuilder = getDifferentialBuilder();
    auto silType = getModule().Types.getLoweredLoadableType(
        type, ResilienceExpansion::Minimal, getModule());
    auto *buffer = diffBuilder.createAllocStack(loc, silType);
    emitZeroIndirect(type, buffer, loc);
    auto loaded = diffBuilder.emitLoadValueOperation(
        loc, buffer, LoadOwnershipQualifier::Take);
    diffBuilder.createDeallocStack(loc, buffer);
    return loaded;
  }

  /// Materializes the object-typed tangent `val` as a SIL value in the
  /// differential. Zero tangents are emitted with `emitZeroDirect`; concrete
  /// tangents are returned as-is. Aggregate tangents are not supported in
  /// forward mode yet.
  SILValue materializeTangentDirect(AdjointValue val, SILLocation loc) {
    assert(val.getType().isObject());
    LLVM_DEBUG(getADDebugStream()
               << "Materializing tangents for " << val << '\n');
    switch (val.getKind()) {
    case AdjointValueKind::Zero:
      return emitZeroDirect(val.getSwiftType(), loc);
    case AdjointValueKind::Aggregate:
      llvm_unreachable(
          "Tuples and structs are not supported in forward mode yet.");
    case AdjointValueKind::Concrete:
      return val.getConcreteValue();
    }
    // The switch is fully covered; this keeps compilers that don't track that
    // from warning about control reaching the end of a non-void function.
    llvm_unreachable("Invalid adjoint value kind");
  }

  /// Returns `val` as a SIL value in the differential: concrete tangents are
  /// returned directly, anything else goes through `materializeTangentDirect`.
  SILValue materializeTangent(AdjointValue val, SILLocation loc) {
    if (!val.isConcrete()) {
      LLVM_DEBUG(getADDebugStream() << "Materializing tangent: Value is "
                                       "non-concrete. Materializing directly.\n");
      return materializeTangentDirect(val, loc);
    }
    LLVM_DEBUG(getADDebugStream()
               << "Materializing tangent: Value is concrete.\n");
    return val.getConcreteValue();
  }

  //--------------------------------------------------------------------------//
  // Tangent buffer mapping
  //--------------------------------------------------------------------------//

  /// Records `tangentBuffer` as the tangent buffer of the address
  /// `originalBuffer` within `origBB`; no mapping may exist yet.
  void setTangentBuffer(SILBasicBlock *origBB, SILValue originalBuffer,
                        SILValue tangentBuffer) {
    assert(originalBuffer->getType().isAddress());
    bool inserted =
        bufferMap.try_emplace({origBB, originalBuffer}, tangentBuffer).second;
    (void)inserted;
    assert(inserted && "tangent buffer already exists.");
  }

  /// Returns a reference to the tangent buffer mapped to the original address
  /// `originalBuffer` within `origBB`. The mapping must already exist; when
  /// assertions are disabled a missing one is inserted as a null `SILValue`.
  SILValue &getTangentBuffer(SILBasicBlock *origBB, SILValue originalBuffer) {
    assert(originalBuffer->getType().isAddress());
    assert(originalBuffer->getFunction() == original);
    auto result = bufferMap.try_emplace({origBB, originalBuffer}, SILValue());
    bool wasMissing = result.second;
    (void)wasMissing;
    assert(!wasMissing && "tangent buffer should already exist");
    return result.first->getSecond();
  }

  //--------------------------------------------------------------------------//
  // Differential type calculations
  //--------------------------------------------------------------------------//

  /// Substitutes all replacement types of `substMap` using the differential
  /// function's forwarding substitution map.
  SubstitutionMap remapSubstitutionMapInDifferential(SubstitutionMap substMap) {
    auto diffForwardingMap = getDifferential().getForwardingSubstitutionMap();
    return substMap.subst(diffForwardingMap);
  }

  /// Maps `ty` into the differential function's context; a type containing
  /// archetypes is first mapped out of its current context.
  Type remapTypeInDifferential(Type ty) {
    Type interfaceTy = ty->hasArchetype() ? ty->mapTypeOutOfContext() : ty;
    return getDifferential().mapTypeIntoContext(interfaceTy);
  }

  /// SILType counterpart of `remapTypeInDifferential`.
  SILType remapSILTypeInDifferential(SILType ty) {
    SILType interfaceTy = ty.hasArchetype() ? ty.mapTypeOutOfContext() : ty;
    return getDifferential().mapTypeIntoContext(interfaceTy);
  }

  /// Returns the tangent space of the canonical type `type`, resolving
  /// conformances in the current module; empty when there is none.
  Optional<VectorSpace> getTangentSpace(CanType type) {
    LookUpConformanceInModule lookupConformance(getModule().getSwiftModule());
    return type->getAutoDiffAssociatedTangentSpace(lookupConformance);
  }

  /// Assuming `type` conforms to `Differentiable` once remapped into the
  /// differential, returns its tangent space as a SIL type with the same
  /// value category as `type`.
  SILType getRemappedTangentType(SILType type) {
    CanType remappedASTType = remapSILTypeInDifferential(type).getASTType();
    auto tangentSpace = getTangentSpace(remappedASTType);
    return SILType::getPrimitiveType(tangentSpace->getCanonicalType(),
                                     type.getCategory());
  }

  //--------------------------------------------------------------------------//
  // Tangent value mapping
  //--------------------------------------------------------------------------//

  /// Get the tangent for an original value. The given value must be an object
  /// value in the original function.
  ///
  /// This method first tries to find an entry in `tangentValueMap`. If none
  /// exists, a zero tangent of the value's remapped tangent type is created,
  /// recorded in the map, and returned.
  AdjointValue getTangentValue(SILValue originalValue) {
    assert(originalValue->getType().isObject());
    assert(originalValue->getFunction() == original);
    auto it = tangentValueMap.find(originalValue);
    if (it != tangentValueMap.end())
      return it->getSecond();
    // Build the zero tangent only on a miss: passing it to `try_emplace`
    // directly would evaluate it eagerly, bump-allocating a fresh zero on
    // every lookup, including hits.
    auto zeroTangent = makeZeroTangentValue(
        getRemappedTangentType(originalValue->getType()));
    tangentValueMap.try_emplace(originalValue, zeroTangent);
    return zeroTangent;
  }

  /// Map the tangent value to the given original value.
  ///
  /// `originalValue` must be an object value in the original function that is
  /// not a tuple-typed `apply` result, `newTangentValue` must have its
  /// remapped tangent type, and no tangent may be mapped to it yet. `origBB`
  /// is currently unused: tangent values are keyed by value only.
  void setTangentValue(SILBasicBlock *origBB, SILValue originalValue,
                       AdjointValue newTangentValue) {
    if (auto *defInst = originalValue->getDefiningInstruction()) {
      bool isTupleTypedApplyResult =
          isa<ApplyInst>(defInst) && originalValue->getType().is<TupleType>();
      // Only read by the assertion; silence unused warnings in NDEBUG builds.
      (void)isTupleTypedApplyResult;
      assert(!isTupleTypedApplyResult &&
             "Should not set tangent value for tuple-typed result from `apply` "
             "instruction; use `destructure_tuple` on `apply` result and set "
             "tangent value for `destructure_tuple` results instead.");
    }
    assert(originalValue->getType().isObject());
    assert(newTangentValue.getType().isObject());
    assert(originalValue->getFunction() == original);
    LLVM_DEBUG(getADDebugStream() << "Adding tangent for " << originalValue);
    // The tangent value must be in the tangent space.
    assert(newTangentValue.getType() ==
           getRemappedTangentType(originalValue->getType()));
    auto insertion =
        tangentValueMap.try_emplace(originalValue, newTangentValue);
    auto inserted = insertion.second;
    // Only read by the assertion; silence unused warnings in NDEBUG builds.
    (void)inserted;
    assert(inserted && "The tangent value should not already exist.");
  }

  //--------------------------------------------------------------------------//
  // Tangent emission helpers
  //--------------------------------------------------------------------------//
public:
// Defines `visit<INST>Inst`. It first clones the instruction into the JVP via
// `TypeSubstCloner::visit<INST>Inst`. Then, if
// `differentialInfo.shouldDifferentiateInstruction(inst)` holds, it calls
// `emitTangentFor<INST>Inst`.
// The macro expansion ends with the signature of `emitTangentFor<INST>Inst`, so
// every use must be followed by that function's body; `ID` names its
// instruction parameter.
#define CLONE_AND_EMIT_TANGENT(INST, ID) \
  void visit##INST##Inst(INST##Inst *inst) { \
    TypeSubstCloner::visit##INST##Inst(inst); \
    if (differentialInfo.shouldDifferentiateInstruction(inst)) \
      emitTangentFor##INST##Inst(inst); \
  } \
  void emitTangentFor##INST##Inst(INST##Inst *(ID))

  /// Handle `begin_borrow` instruction.
  ///   Original: y = begin_borrow x
  ///    Tangent: tan[y] = begin_borrow tan[x]
  CLONE_AND_EMIT_TANGENT(BeginBorrow, bbi) {
    auto &diffBuilder = getDifferentialBuilder();
    auto loc = bbi->getLoc();
    auto tanOperand = getTangentValue(bbi->getOperand());
    auto tanVal = materializeTangent(tanOperand, loc);
    auto borrowed = diffBuilder.emitBeginBorrowOperation(loc, tanVal);
    setTangentValue(bbi->getParent(), bbi, makeConcreteTangentValue(borrowed));
  }

  /// Handle `end_borrow` instruction.
  ///   Original: end_borrow x
  ///    Tangent: end_borrow tan[x]
  CLONE_AND_EMIT_TANGENT(EndBorrow, ebi) {
    auto &diffBuilder = getDifferentialBuilder();
    auto loc = ebi->getLoc();
    auto tanOperand = getTangentValue(ebi->getOperand());
    auto tanVal = materializeTangent(tanOperand, loc);
    diffBuilder.emitEndBorrowOperation(loc, tanVal);
  }

  /// Handle `destroy_value` instruction.
  ///   Original: destroy_value x
  ///    Tangent: destroy_value tan[x]
  CLONE_AND_EMIT_TANGENT(DestroyValue, dvi) {
    auto &diffBuilder = getDifferentialBuilder();
    auto loc = dvi->getLoc();
    auto tanOperand = getTangentValue(dvi->getOperand());
    auto tanVal = materializeTangent(tanOperand, loc);
    diffBuilder.emitDestroyValue(loc, tanVal);
  }

  /// Handle `copy_value` instruction.
  ///   Original: y = copy_value x
  ///    Tangent: tan[y] = copy_value tan[x]
  CLONE_AND_EMIT_TANGENT(CopyValue, cvi) {
    auto &diffBuilder = getDifferentialBuilder();
    auto loc = cvi->getLoc();
    auto tanOperand =
        materializeTangent(getTangentValue(cvi->getOperand()), loc);
    auto tanCopy = diffBuilder.emitCopyValueOperation(loc, tanOperand);
    setTangentValue(cvi->getParent(), cvi, makeConcreteTangentValue(tanCopy));
  }

  /// Handle `load` instruction.
  ///   Original: y = load x
  ///    Tangent: tan[y] = load tan[x]
  /// The original ownership qualifier is reused for the tangent load.
  CLONE_AND_EMIT_TANGENT(Load, li) {
    auto &diffBuilder = getDifferentialBuilder();
    auto *origBB = li->getParent();
    SILValue tanBuffer = getTangentBuffer(origBB, li->getOperand());
    auto loadedTan = diffBuilder.emitLoadValueOperation(
        li->getLoc(), tanBuffer, li->getOwnershipQualifier());
    setTangentValue(origBB, li, makeConcreteTangentValue(loadedTan));
  }

  /// Handle `load_borrow` instruction.
  ///   Original: y = load_borrow x
  ///    Tangent: tan[y] = load_borrow tan[x]
  CLONE_AND_EMIT_TANGENT(LoadBorrow, lbi) {
    auto &diffBuilder = getDifferentialBuilder();
    auto *origBB = lbi->getParent();
    SILValue tanBuffer = getTangentBuffer(origBB, lbi->getOperand());
    auto borrowedTan =
        diffBuilder.emitLoadBorrowOperation(lbi->getLoc(), tanBuffer);
    setTangentValue(origBB, lbi, makeConcreteTangentValue(borrowedTan));
  }

  /// Handle `store` instruction in the differential.
  ///   Original: store x to y
  ///    Tangent: store tan[x] to tan[y]
  /// The original ownership qualifier is reused for the tangent store.
  CLONE_AND_EMIT_TANGENT(Store, si) {
    auto &diffBuilder = getDifferentialBuilder();
    auto loc = si->getLoc();
    auto tanSrc = materializeTangent(getTangentValue(si->getSrc()), loc);
    SILValue tanDest = getTangentBuffer(si->getParent(), si->getDest());
    diffBuilder.emitStoreValueOperation(loc, tanSrc, tanDest,
                                        si->getOwnershipQualifier());
  }

  /// Handle `store_borrow` instruction in the differential.
  ///   Original: store_borrow x to y
  ///    Tangent: store_borrow tan[x] to tan[y]
  CLONE_AND_EMIT_TANGENT(StoreBorrow, sbi) {
    auto &diffBuilder = getDifferentialBuilder();
    auto loc = sbi->getLoc();
    auto tanSrc = materializeTangent(getTangentValue(sbi->getSrc()), loc);
    SILValue tanDest = getTangentBuffer(sbi->getParent(), sbi->getDest());
    diffBuilder.createStoreBorrow(loc, tanSrc, tanDest);
  }

  /// Handle `copy_addr` instruction.
  ///   Original: copy_addr x to y
  ///    Tangent: copy_addr tan[x] to tan[y]
  /// The take/initialization flags of the original are reused.
  CLONE_AND_EMIT_TANGENT(CopyAddr, cai) {
    // Push the differential's generic signature onto the type converter for
    // the duration of this handler.
    auto *diffGenEnv = getDifferential().getGenericEnvironment();
    auto diffGenSig = diffGenEnv
        ? diffGenEnv->getGenericSignature()->getCanonicalSignature()
        : nullptr;
    Lowering::GenericContextScope genericContextScope(
        context.getTypeConverter(), diffGenSig);

    // Bind the builder by reference (as the other handlers do) instead of
    // emitting through a copy of it.
    auto &diffBuilder = getDifferentialBuilder();
    auto loc = cai->getLoc();
    auto *bb = cai->getParent();
    // Copy both buffers by value. `getTangentBuffer` returns a reference
    // (presumably into `bufferMap`), and keeping it alive across the second
    // lookup is fragile if that lookup ever inserts.
    SILValue tanSrc = getTangentBuffer(bb, cai->getSrc());
    SILValue tanDest = getTangentBuffer(bb, cai->getDest());

    diffBuilder.createCopyAddr(loc, tanSrc, tanDest, cai->isTakeOfSrc(),
                               cai->isInitializationOfDest());
  }

  /// Handle `unconditional_checked_cast_addr` instruction.
  ///   Original: unconditional_checked_cast_addr $X in x to $Y in y
  ///    Tangent: unconditional_checked_cast_addr $X.Tan in tan[x]
  ///                                          to $Y.Tan in tan[y]
  CLONE_AND_EMIT_TANGENT(UnconditionalCheckedCastAddr, uccai) {
    // Bind the builder by reference (as the other handlers do) instead of
    // emitting through a copy of it.
    auto &diffBuilder = getDifferentialBuilder();
    auto loc = uccai->getLoc();
    auto *bb = uccai->getParent();
    // Copy both buffers by value rather than holding a reference returned by
    // `getTangentBuffer` across the second lookup.
    SILValue tanSrc = getTangentBuffer(bb, uccai->getSrc());
    SILValue tanDest = getTangentBuffer(bb, uccai->getDest());

    // The cast's source and target types are those of the tangent buffers.
    diffBuilder.createUnconditionalCheckedCastAddr(
        loc, tanSrc, tanSrc->getType().getASTType(), tanDest,
        tanDest->getType().getASTType());
  }

  /// Handle `begin_access` instruction (and do differentiability checks).
  ///   Original: y = begin_access x
  ///    Tangent: tan[y] = begin_access tan[x]
  /// `modify` accesses to a `global_addr` or a `project_box` are rejected with
  /// a non-differentiability error.
  CLONE_AND_EMIT_TANGENT(BeginAccess, bai) {
    // Check for non-differentiable writes. Only the kind of the source matters,
    // so use `isa` instead of binding unused `dyn_cast` results.
    if (bai->getAccessKind() == SILAccessKind::Modify) {
      if (isa<GlobalAddrInst>(bai->getSource())) {
        context.emitNondifferentiabilityError(bai, invoker,
            diag::autodiff_cannot_differentiate_writes_to_global_variables);
        errorOccurred = true;
        return;
      }
      if (isa<ProjectBoxInst>(bai->getSource())) {
        context.emitNondifferentiabilityError(bai, invoker,
            diag::autodiff_cannot_differentiate_writes_to_mutable_captures);
        errorOccurred = true;
        return;
      }
    }

    auto &diffBuilder = getDifferentialBuilder();
    auto *bb = bai->getParent();

    auto tanSrc = getTangentBuffer(bb, bai->getSource());
    // Mirror the original access's kind, enforcement and flags.
    auto *tanDest = diffBuilder.createBeginAccess(
        bai->getLoc(), tanSrc, bai->getAccessKind(), bai->getEnforcement(),
        bai->hasNoNestedConflict(), bai->isFromBuiltin());
    setTangentBuffer(bb, bai, tanDest);
  }

  /// Handle `end_access` instruction.
  ///   Original: end_access x
  ///    Tangent: end_access tan[x]
  /// The `aborting` flag of the original is reused.
  CLONE_AND_EMIT_TANGENT(EndAccess, eai) {
    auto &diffBuilder = getDifferentialBuilder();
    auto *bb = eai->getParent();
    auto loc = eai->getLoc();
    auto tanSrc = getTangentBuffer(bb, eai->getOperand());
    diffBuilder.createEndAccess(loc, tanSrc, eai->isAborting());
  }

  /// Handle `alloc_stack` instruction.
  ///   Original: y = alloc_stack $T
  ///    Tangent: tan[y] = alloc_stack $T.Tangent
  /// The new buffer is registered in `bufferMap` for `asi` (an existing entry
  /// is left untouched).
  CLONE_AND_EMIT_TANGENT(AllocStack, asi) {
    auto &diffBuilder = getDifferentialBuilder();
    auto tanElementType = getRemappedTangentType(asi->getElementType());
    auto *tanBuffer = diffBuilder.createAllocStack(
        asi->getLoc(), tanElementType, asi->getVarInfo());
    bufferMap.try_emplace({asi->getParent(), asi}, tanBuffer);
  }

  /// Handle `dealloc_stack` instruction.
  ///   Original: dealloc_stack x
  ///    Tangent: dealloc_stack tan[x]
  CLONE_AND_EMIT_TANGENT(DeallocStack, dsi) {
    auto &diffBuilder = getDifferentialBuilder();
    SILValue tanBuffer = getTangentBuffer(dsi->getParent(), dsi->getOperand());
    diffBuilder.createDeallocStack(dsi->getLoc(), tanBuffer);
  }

  /// Handle `destroy_addr` instruction.
  ///   Original: destroy_addr x
  ///    Tangent: destroy_addr tan[x]
  CLONE_AND_EMIT_TANGENT(DestroyAddr, dai) {
    auto &diffBuilder = getDifferentialBuilder();
    SILValue tanBuffer = getTangentBuffer(dai->getParent(), dai->getOperand());
    diffBuilder.createDestroyAddr(dai->getLoc(), tanBuffer);
  }

  /// Handle `struct` instruction.
  ///   Original: y = struct $T (x0, x1, x2, ...)
  ///    Tangent: tan[y] = struct $T.Tangent (tan[x0], tan[x1], tan[x2], ...)
  CLONE_AND_EMIT_TANGENT(Struct, si) {
    auto &diffBuilder = getDifferentialBuilder();
    auto loc = si->getLoc();
    SmallVector<SILValue, 4> tangentElements;
    for (auto elem : si->getElements()) {
      // `getTangentValue` may return a freshly created zero tangent rather
      // than a concrete value, so materialize it (as the `tuple` handler does)
      // instead of calling `getConcreteValue()` on it.
      tangentElements.push_back(
          materializeTangent(getTangentValue(elem), loc));
    }
    auto tanStruct = diffBuilder.createStruct(
        loc, getRemappedTangentType(si->getType()), tangentElements);
    setTangentValue(si->getParent(), si, makeConcreteTangentValue(tanStruct));
  }

  /// Handle `struct_extract` instruction.
  ///   Original: y = struct_extract x, #field
  ///    Tangent: tan[y] = struct_extract tan[x], #field'
  ///                                             ^~~~~~~
  ///                          field in tangent space corresponding to #field
  /// Emits a non-differentiability error if the tangent struct has no member
  /// with the field's name.
  CLONE_AND_EMIT_TANGENT(StructExtract, sei) {
    assert(!sei->getField()->getAttrs().hasAttribute<NoDerivativeAttr>() &&
           "`struct_extract` with `@noDerivative` field should not be "
           "differentiated; activity analysis should not marked as varied.");

    // Bind the builder by reference (as the other handlers do) instead of
    // emitting through a copy of it.
    auto &diffBuilder = getDifferentialBuilder();
    auto tangentVectorTy =
        getRemappedTangentType(sei->getOperand()->getType());
    auto *tangentVectorDecl =
        tangentVectorTy.getStructOrBoundGenericStruct();

    // Find the corresponding field in the tangent space.
    VarDecl *tanField = nullptr;
    // If the tangent space is the original struct, then field is the same.
    if (tangentVectorDecl == sei->getStructDecl())
      tanField = sei->getField();
    // Otherwise, look up the field by name.
    else {
      auto tanFieldLookup =
          tangentVectorDecl->lookupDirect(sei->getField()->getName());
      if (tanFieldLookup.empty()) {
        context.emitNondifferentiabilityError(
            sei, invoker,
            diag::autodiff_stored_property_no_corresponding_tangent,
            sei->getStructDecl()->getNameStr(),
            sei->getField()->getNameStr());
        errorOccurred = true;
        return;
      }
      tanField = cast<VarDecl>(tanFieldLookup.front());
    }
    // Emit tangent `struct_extract`.
    auto tanStruct =
        materializeTangent(getTangentValue(sei->getOperand()), sei->getLoc());
    auto tangentInst =
        diffBuilder.createStructExtract(sei->getLoc(), tanStruct, tanField);
    // Update tangent value mapping for `struct_extract` result.
    auto tangentResult = makeConcreteTangentValue(tangentInst);
    setTangentValue(sei->getParent(), sei, tangentResult);
  }

  /// Handle `struct_element_addr` instruction.
  ///   Original: y = struct_element_addr x, #field
  ///    Tangent: tan[y] = struct_element_addr tan[x], #field'
  ///                                                  ^~~~~~~
  ///                          field in tangent space corresponding to #field
  /// Emits a non-differentiability error if the tangent struct has no member
  /// with the field's name.
  CLONE_AND_EMIT_TANGENT(StructElementAddr, seai) {
    assert(!seai->getField()->getAttrs().hasAttribute<NoDerivativeAttr>() &&
           "`struct_element_addr` with `@noDerivative` field should not be "
           "differentiated; activity analysis should not marked as varied.");

    // Bind the builder by reference (as the other handlers do) instead of
    // emitting through a copy of it.
    auto &diffBuilder = getDifferentialBuilder();
    auto *bb = seai->getParent();
    auto tangentVectorTy =
        getRemappedTangentType(seai->getOperand()->getType());
    auto *tangentVectorDecl =
        tangentVectorTy.getStructOrBoundGenericStruct();

    // Find the corresponding field in the tangent space.
    VarDecl *tanField = nullptr;
    // If the tangent space is the original struct, then field is the same.
    if (tangentVectorDecl == seai->getStructDecl())
      tanField = seai->getField();
    // Otherwise, look up the field by name.
    else {
      auto tanFieldLookup =
          tangentVectorDecl->lookupDirect(seai->getField()->getName());
      if (tanFieldLookup.empty()) {
        context.emitNondifferentiabilityError(
            seai, invoker,
            diag::autodiff_stored_property_no_corresponding_tangent,
            seai->getStructDecl()->getNameStr(),
            seai->getField()->getNameStr());
        errorOccurred = true;
        return;
      }
      tanField = cast<VarDecl>(tanFieldLookup.front());
    }

    // Emit tangent `struct_element_addr`.
    auto tanOperand = getTangentBuffer(bb, seai->getOperand());
    auto tangentInst = diffBuilder.createStructElementAddr(
        seai->getLoc(), tanOperand, tanField);
    // Update tangent buffer map for `struct_element_addr`.
    setTangentBuffer(bb, seai, tangentInst);
  }

  /// Handle `tuple` instruction.
  ///   Original: y = tuple (x0, x1, x2, ...)
  ///    Tangent: tan[y] = tuple (tan[x0], tan[x1], tan[x2], ...)
  CLONE_AND_EMIT_TANGENT(Tuple, ti) {
    // Bind the builder by reference (as the other handlers do) instead of
    // emitting through a copy of it.
    auto &diffBuilder = getDifferentialBuilder();
    auto loc = ti->getLoc();

    // Get the materialized tangents of all the tuple elements.
    SmallVector<SILValue, 8> tangentTupleElements;
    for (auto elem : ti->getElements())
      tangentTupleElements.push_back(
          materializeTangent(getTangentValue(elem), loc));

    // Emit the instruction and add the tangent mapping.
    auto tanTuple = diffBuilder.createTuple(loc, tangentTupleElements);
    setTangentValue(ti->getParent(), ti, makeConcreteTangentValue(tanTuple));
  }

  /// Handle `tuple_extract` instruction.
  ///   Original: y = tuple_extract x, <n>
  ///    Tangent: tan[y] = tuple_extract tan[x], <n'>
  ///                                            ^~~~
  ///                         tuple tangent space index corresponding to n
  CLONE_AND_EMIT_TANGENT(TupleExtract, tei) {
    auto &diffBuilder = getDifferentialBuilder();
    auto loc = tei->getLoc();
    auto origTupleTy = tei->getOperand()->getType().castTo<TupleType>();
    // The tangent index is the number of preceding original elements that
    // have a tangent space.
    unsigned tanIndex = 0;
    for (unsigned i : range(tei->getFieldNo())) {
      if (getTangentSpace(
              origTupleTy->getElement(i).getType()->getCanonicalType()))
        ++tanIndex;
    }
    auto tanType = getRemappedTangentType(tei->getType());
    auto tanSource =
        materializeTangent(getTangentValue(tei->getOperand()), loc);
    // If the tangent of the source does not have a tuple type, then it must
    // represent a "single element tuple type". Use it directly.
    SILValue tanElt = tanSource;
    if (tanSource->getType().is<TupleType>())
      tanElt =
          diffBuilder.createTupleExtract(loc, tanSource, tanIndex, tanType);
    // `tuple_extract` produces an object, so record its tangent in the tangent
    // value map (read by `getTangentValue`). Storing it in `bufferMap`, as was
    // done before, meant later `getTangentValue(tei)` calls missed it and
    // silently produced a zero tangent.
    setTangentValue(tei->getParent(), tei, makeConcreteTangentValue(tanElt));
  }

  /// Handle `tuple_element_addr` instruction.
  ///   Original: y = tuple_element_addr x, <n>
  ///    Tangent: tan[y] = tuple_element_addr tan[x], <n'>
  ///                                                ^~~~
  ///                            tuple tangent space index corresponding to n
  CLONE_AND_EMIT_TANGENT(TupleElementAddr, teai) {
    auto &diffBuilder = getDifferentialBuilder();
    auto *origBB = teai->getParent();
    auto origTupleTy = teai->getOperand()->getType().castTo<TupleType>();
    // Count the original elements before the projected field that have a
    // tangent space; that count is the field's index in the tangent tuple.
    unsigned tanIndex = 0;
    for (unsigned i : range(teai->getFieldNo()))
      if (getTangentSpace(
              origTupleTy->getElement(i).getType()->getCanonicalType()))
        ++tanIndex;
    auto tanType = getRemappedTangentType(teai->getType());
    SILValue tanSource = getTangentBuffer(origBB, teai->getOperand());
    // A tangent buffer without a tuple type represents a "single element tuple
    // type" and is used directly; otherwise project the element.
    SILValue tanBuf = tanSource;
    if (tanSource->getType().is<TupleType>())
      tanBuf = diffBuilder.createTupleElementAddr(teai->getLoc(), tanSource,
                                                  tanIndex, tanType);
    bufferMap.try_emplace({origBB, teai}, tanBuf);
  }

  /// Handle `destructure_tuple` instruction.
  ///   Original: (y0, y1, ...)  = destructure_tuple x
  ///    Tangent: (tan[y0], tan[y1], ...) = destructure_tuple tan[x]
  ///
  /// Like the `tuple_extract` handler, this treats the tangent tuple as having
  /// elements only for original elements with a tangent space. Original
  /// results are matched to tangent results by skipping elements without one.
  /// A tangent that is not tuple-typed represents a "single element tuple
  /// type" and is used directly.
  CLONE_AND_EMIT_TANGENT(DestructureTuple, dti) {
    assert(llvm::any_of(dti->getResults(),
                        [&](SILValue result) {
                          return activityInfo.isActive(result, getIndices());
                        }) &&
           "original 'destructure_tuple' should have at least one active "
           "result");

    auto &diffBuilder = getDifferentialBuilder();
    auto *bb = dti->getParent();
    auto loc = dti->getLoc();

    auto tanTuple =
        materializeTangent(getTangentValue(dti->getOperand()), loc);
    SmallVector<SILValue, 4> tanElts;
    if (tanTuple->getType().is<TupleType>()) {
      auto *tanDti = diffBuilder.createDestructureTuple(loc, tanTuple);
      for (auto i : range(tanDti->getNumResults()))
        tanElts.push_back(tanDti->getResult(i));
    } else {
      tanElts.push_back(tanTuple);
    }

    unsigned tanIndex = 0;
    for (auto i : range(dti->getNumResults())) {
      auto origElt = dti->getResult(i);
      // Elements without a tangent space have no tangent tuple counterpart.
      if (!getTangentSpace(origElt->getType().getASTType()))
        continue;
      assert(tanIndex < tanElts.size() &&
             "More differentiable tuple elements than tangent elements");
      setTangentValue(bb, origElt,
                      makeConcreteTangentValue(tanElts[tanIndex++]));
    }
  }

#undef CLONE_AND_EMIT_TANGENT

  /// Handle `apply` instruction.
  ///   Original: y = apply f(x0, x1, ...)
  ///    Tangent: tan[y] = apply diff_f(tan[x0], tan[x1], ...)
  ///
  /// `actualIndices.source` selects which original result receives the
  /// differential's result as its tangent. If the remapped type of the callee
  /// differential differs from `originalDifferentialType`, the differential is
  /// reabstracted through a thunk before it is called.
  void emitTangentForApplyInst(ApplyInst *ai,
                               const SILAutoDiffIndices &actualIndices,
                               CanSILFunctionType originalDifferentialType) {
    assert(differentialInfo.shouldDifferentiateApplyInst(ai));
    auto *bb = ai->getParent();
    auto loc = ai->getLoc();
    auto &diffBuilder = getDifferentialBuilder();

    // Get the differential value.
    auto *field = differentialInfo.lookUpLinearMapDecl(ai);
    assert(field);
    SILValue differential = getDifferentialStructElement(bb, field);
    auto differentialType = remapSILTypeInDifferential(differential->getType())
        .castTo<SILFunctionType>();

    // Get the differential arguments.
    SmallVector<SILValue, 8> diffArgs;

    for (auto indRes : ai->getIndirectSILResults())
      diffArgs.push_back(getTangentBuffer(bb, indRes));

    // Zero-initialized stack buffers created for inactive indirect arguments,
    // each paired with its argument index in `diffArgs`. They are cleaned up
    // after the differential call.
    SmallVector<std::pair<AllocStackInst *, unsigned>, 4> zeroIndirectArgs;

    auto origCalleeType = ai->getSubstCalleeType();
    auto paramArgs = ai->getArgumentsWithoutIndirectResults();
    // Get the tangent value of the original arguments.
    for (auto i : indices(paramArgs)) {
      auto origArg = paramArgs[i];
      // If the argument is not active:
      // - Skip the element, if it is not differentiable.
      // - Otherwise, add a zero value to that location.
      if (!activityInfo.isActive(origArg, getIndices())) {
        if (!origCalleeType->isDifferentiable())
          continue;
        auto actualOrigCalleeIndices =
            origCalleeType->getDifferentiationParameterIndices();
        if (!actualOrigCalleeIndices->contains(i))
          continue;
        auto tanType = getRemappedTangentType(origArg->getType());
        if (origArg->getType().isObject()) {
          diffArgs.push_back(emitZeroDirect(tanType.getASTType(), loc));
        } else {
          AllocStackInst *zeroBuf = diffBuilder.createAllocStack(loc, tanType);
          emitZeroIndirect(tanType.getASTType(), zeroBuf, loc);
          // Pass the zero buffer to the differential. Previously it was
          // created but never added to `diffArgs`, so the call had too few
          // arguments and the buffer was never deallocated.
          unsigned argIndex = diffArgs.size();
          zeroIndirectArgs.push_back({zeroBuf, argIndex});
          diffArgs.push_back(zeroBuf);
        }
        continue;
      }
      // Otherwise, the argument is active: handle it normally by getting its
      // tangent value (objects) or tangent buffer (addresses).
      SILValue tanParam;
      if (origArg->getType().isObject())
        tanParam = materializeTangent(getTangentValue(origArg), loc);
      else
        tanParam = getTangentBuffer(ai->getParent(), origArg);
      diffArgs.push_back(tanParam);
      if (errorOccurred)
        return;
    }

    // If callee differential was reabstracted in JVP, reabstract the callee
    // differential.
    if (!differentialType->isEqual(originalDifferentialType)) {
      SILOptFunctionBuilder fb(context.getTransform());
      auto *thunk = getOrCreateReabstractionThunk(
          fb, context.getModule(), loc, &getDifferential(),
          differentialType, originalDifferentialType);
      auto *thunkRef = diffBuilder.createFunctionRef(loc, thunk);
      differential = diffBuilder.createPartialApply(
         loc, thunkRef,
         remapSubstitutionMapInDifferential(
             thunk->getForwardingSubstitutionMap()),
         {differential}, differentialType->getCalleeConvention());
    }

    // Call the differential.
    auto *differentialCall = diffBuilder.createApply(
        loc, differential, SubstitutionMap(), diffArgs,
        /*isNonThrowing*/ false);
    diffBuilder.emitDestroyValueOperation(loc, differential);
    assert(differentialCall->getNumResults() == 1 &&
           "Expected differential to return one result");

    // Release the zero buffers in reverse allocation order (stack
    // discipline). A buffer passed with an owned (consuming) convention was
    // consumed by the callee. Otherwise it still holds an initialized zero
    // value that must be destroyed before deallocation.
    auto diffCallConv = differentialCall->getSubstCalleeConv();
    for (auto &zeroArg : llvm::reverse(zeroIndirectArgs)) {
      auto argConv = diffCallConv.getSILArgumentConvention(zeroArg.second);
      if (!argConv.isOwnedConvention())
        diffBuilder.createDestroyAddr(loc, zeroArg.first);
      diffBuilder.createDeallocStack(loc, zeroArg.first);
    }

    // Get the original results of the `apply` instructions.
    SmallVector<SILValue, 8> origDirectResults;
    forEachApplyDirectResult(ai, [&](SILValue directResult) {
      origDirectResults.push_back(directResult);
    });
    SmallVector<SILValue, 8> origAllResults;
    collectAllActualResultsInTypeOrder(ai, origDirectResults, origAllResults);
    auto origResult = origAllResults[actualIndices.source];

    // Get the differential results of the `apply` instructions.
    SmallVector<SILValue, 8> differentialDirectResults;
    forEachApplyDirectResult(differentialCall, [&](SILValue directResult) {
      differentialDirectResults.push_back(directResult);
    });
    SmallVector<SILValue, 8> differentialAllResults;
    collectAllActualResultsInTypeOrder(differentialCall,
                                       differentialDirectResults,
                                       differentialAllResults);
    auto differentialResult = differentialAllResults.front();

    // Add tangent for original result.
    if (origResult->getType().isObject()) {
      if (!origResult->getType().is<TupleType>()) {
        setTangentValue(bb, origResult,
            makeConcreteTangentValue(differentialResult));
      } else if (auto *dti = getSingleDestructureTupleUser(ai)) {
        bool notSetValue = true;
        for (auto result : dti->getResults()) {
          if (activityInfo.isActive(result, getIndices())) {
            assert(notSetValue &&
                   "This was incorrectly set, should only have one active "
                   "result from the tuple.");
            notSetValue = false;
            setTangentValue(bb, result,
                            makeConcreteTangentValue(differentialResult));
          }
        }
      }
    }
  }

  /// Generate a `return` instruction in the current differential basic block.
  void emitReturnInstForDifferential() {
    auto &differential = getDifferential();
    auto diffLoc = differential.getLocation();
    auto &diffBuilder = getDifferentialBuilder();

    SmallVector<SILValue, 2> activeResults;

    // This vector will contain all the materialized return elements.
    SmallVector<SILValue, 8> retElts;
    SmallVector<SILValue, 2> originalResults;
    collectAllDirectResultsInTypeOrder(*original, originalResults);

    // Materializes the return element corresponding to the result
    // `resultIndex` into the `retElts` vector.
    auto addActiveResult = [&](unsigned resultIndex) -> void {
      auto origResult = originalResults[resultIndex];
      assert(origResult->getType().isObject() &&
             "Should only be handling direct results for 'return' "
             "instruction.");
      if (activityInfo.isActive(origResult, getIndices())) {
        activeResults.push_back(origResult);
      }
    };
    // Create an array of the direct tangent values of the original results.
    for (auto i : range(originalResults.size()))
      addActiveResult(i);
    assert(activeResults.size() <= 1);

    if (activeResults.empty() && !originalResults.empty()) {
      // Create zero tangent value for direct result.
      auto origResult = originalResults[getIndices().source];
      assert(origResult->getType().isObject() &&
             "Should only be handling direct results for 'return' "
             "instruction.");
      auto zeroType = origResult->getType().getASTType();
      auto zero =
          emitZeroDirect(getTangentSpace(zeroType)->getCanonicalType(),
                         diffLoc);
      retElts.push_back(zero);
    } else if (!activeResults.empty()) {
      auto diffVal = getTangentValue(activeResults.front());
      auto val = materializeTangent(diffVal, diffLoc);
      retElts.push_back(val);
    }

    diffBuilder.createReturn(
        diffLoc, joinElements(retElts, diffBuilder, diffLoc));
  }

private:

  /// Set up the differential function. This includes:
  /// - Creating all differential blocks.
  /// - Creating differential entry block arguments based on the function type.
  /// - Creating tangent value mapping for original/differential parameters.
  /// - Creating tangent buffer mapping for original/differential indirect
  ///   results.
  /// (The unvaried-result warning is currently disabled; see TF-788.)
  void prepareForDifferentialGeneration() {
    // Create differential blocks and arguments.
    auto *diffGenEnv = getDifferential().getGenericEnvironment();
    auto diffGenSig = diffGenEnv
        ? diffGenEnv->getGenericSignature()->getCanonicalSignature()
        : nullptr;
    auto &differential = getDifferential();
    auto *origEntry = original->getEntryBlock();
    for (auto &origBB : *original) {
      auto *diffBB = differential.createBasicBlock();
      diffBBMap.insert({&origBB, diffBB});
      {
        Lowering::GenericContextScope genericContextScope(
            context.getTypeConverter(), diffGenSig);
        auto diffStructLoweredType = remapSILTypeInDifferential(
            differentialInfo.getLinearMapStructLoweredType(&origBB));

        // If the BB is the original entry, then the differential block that we
        // just created must be the differential function's entry. Create
        // differential entry arguments and continue.
        if (&origBB == origEntry) {
          assert(diffBB->isEntry());
          createEntryArguments(&differential);
          auto *lastArg = diffBB->getArguments().back();
          assert(lastArg->getType() == diffStructLoweredType);
          differentialStructArguments[&origBB] = lastArg;
        }
      }

      LLVM_DEBUG({
        auto &s = getADDebugStream()
                  << "Original bb" + std::to_string(origBB.getDebugID())
                  << ": To differentiate or not to differentiate?\n";
        for (auto &inst : origBB) {
          s << (differentialInfo.shouldDifferentiateInstruction(&inst)
                    ? "[∂] " : "[ ] ")
            << inst;
        }
      });
    }

    assert(diffBBMap.size() == 1 &&
           "Can only currently handle single basic block functions");

    // The differential function has type:
    // (arg0', ..., argn', entry_df_struct) -> result'.
    auto diffParamArgs =
        differential.getArgumentsWithoutIndirectResults().drop_back();
    assert(diffParamArgs.size() ==
           attr->getIndices().parameters->getNumIndices());
    auto origParamArgs = original->getArgumentsWithoutIndirectResults();

    // TODO(TF-788): Re-enable non-varied result warning.
    /*
    // Check if result is not varied.
    SmallVector<SILValue, 8> origFormalResults;
    collectAllFormalResultsInTypeOrder(*original, origFormalResults);
    auto origResult = origFormalResults[getIndices().source];
    // Emit warning if original result is not varied, because it will always
    // have a zero derivative.
    if (!activityInfo.isVaried(origResult, getIndices().parameters)) {
      // Emit fixit if original result has a valid source location.
      auto startLoc = origResult.getLoc().getStartSourceLoc();
      auto endLoc = origResult.getLoc().getEndSourceLoc();
      if (startLoc.isValid() && endLoc.isValid()) {
        context.diagnose(startLoc, diag::autodiff_nonvaried_result_fixit)
            .fixItInsert(startLoc, "withoutDerivative(at:")
            .fixItInsertAfter(endLoc, ")");
      }
    }
    */

    // Initialize tangent mapping for parameters. Differential parameter
    // `index` corresponds to the `index`-th differentiation parameter index.
    auto diffParamsIt = getIndices().parameters->begin();
    for (auto index : range(diffParamArgs.size())) {
      auto *diffArg = diffParamArgs[index];
      auto *origArg = origParamArgs[*diffParamsIt];
      diffParamsIt++;
      if (diffArg->getType().isAddress()) {
        setTangentBuffer(origEntry, origArg, diffArg);
      } else {
        setTangentValue(
            origEntry, origArg, makeConcreteTangentValue(diffArg));
      }
      LLVM_DEBUG(getADDebugStream()
                 << "Assigned parameter " << *diffArg
                 << " as the tangent of original argument " << *origArg);
    }

    // Initialize tangent mapping for indirect results.
    auto origIndResults = original->getIndirectResults();
    auto diffIndResults = differential.getIndirectResults();
    assert(origIndResults.size() == diffIndResults.size());

    for (auto &origBB : *original)
      for (auto i : indices(diffIndResults))
        setTangentBuffer(&origBB, origIndResults[i], diffIndResults[i]);
  }

public:
  /// Creates a JVP emitter that clones `original` into `jvp`.
  ///
  /// The member initializers compute activity information and differential
  /// linear map information for `attr`'s indices, and create an empty
  /// differential function whose builder becomes `differentialBuilder`.
  explicit JVPEmitter(ADContext &context, SILFunction *original,
                      SILDifferentiableAttr *attr, SILFunction *jvp,
                      DifferentiationInvoker invoker)
      : TypeSubstCloner(*jvp, *original, getSubstitutionMap(original, jvp)),
        context(context), original(original), attr(attr), jvp(jvp),
        invoker(invoker), activityInfo(getActivityInfo(
                              context, original, attr->getIndices(), jvp)),
        differentialInfo(context, AutoDiffLinearMapKind::Differential, original,
                         jvp, attr->getIndices(), activityInfo),
        differentialBuilder(SILBuilder(*createEmptyDifferential(
            context, original, attr, &differentialInfo))),
        diffLocalAllocBuilder(getDifferential()) {
    // Register the differential (created in the initializer list above) in
    // the context's list of generated functions.
    context.getGeneratedFunctions().push_back(&getDifferential());
  }

  /// Creates the differential function for `original`, without a body.
  ///
  /// The differential's type is built from `original`'s lowered type and
  /// `attr`'s indices:
  /// - parameters: the tangents of the wrt parameters, followed by the linear
  ///   map struct of the original entry block (passed `Direct_Owned`);
  /// - result: the tangent of the differentiated original result.
  /// The function gets hidden linkage and a debug scope at `original`'s
  /// location.
  static SILFunction *createEmptyDifferential(ADContext &context,
                                              SILFunction *original,
                                              SILDifferentiableAttr *attr,
                                              LinearMapInfo *linearMapInfo) {
    auto &module = context.getModule();
    auto &astCtx = original->getASTContext();
    auto origTy = original->getLoweredFunctionType();
    auto lookupConformance = LookUpConformanceInModule(module.getSwiftModule());

    // RAII that pushes the original function's generic signature to
    // `module.Types` so that type lowering below knows the original
    // function's generic parameter types.
    Lowering::GenericContextScope genericContextScope(
        module.Types, origTy->getGenericSignature());

    // Returns the canonical tangent space type of `type`.
    auto tangentOf = [&](CanType type) {
      return type->getAutoDiffAssociatedTangentSpace(lookupConformance)
          ->getCanonicalType();
    };

    auto indices = attr->getIndices();

    // Result: the tangent of the differentiated original result.
    auto origResInfo = origTy->getResults()[indices.source];
    SmallVector<SILResultInfo, 8> dfResults;
    dfResults.push_back(SILResultInfo(tangentOf(origResInfo.getType()),
                                      origResInfo.getConvention()));

    // Parameters: the tangents of the requested wrt parameters...
    auto origParams = origTy->getParameters();
    SmallVector<SILParameterInfo, 8> dfParams;
    for (auto paramIdx : indices.parameters->getIndices()) {
      auto origParam = origParams[paramIdx];
      dfParams.push_back(SILParameterInfo(tangentOf(origParam.getType()),
                                          origParam.getConvention()));
    }
    // ...followed by the differential struct of the original entry block,
    // which acts as the returned differential's closure context.
    auto *entryStruct =
        linearMapInfo->getLinearMapStruct(original->getEntryBlock());
    auto entryStructType =
        entryStruct->getDeclaredInterfaceType()->getCanonicalType();
    dfParams.push_back({entryStructType, ParameterConvention::Direct_Owned});

    Mangle::ASTMangler mangler;
    auto diffName = astCtx.getIdentifier(
        mangler.mangleAutoDiffLinearMapHelper(
            original->getName(), AutoDiffLinearMapKind::Differential,
            indices)).str();
    auto diffGenericSig = getDerivativeGenericSignature(attr, original);
    auto *diffGenericEnv =
        diffGenericSig ? diffGenericSig->getGenericEnvironment() : nullptr;
    auto diffType = SILFunctionType::get(
        diffGenericSig, origTy->getExtInfo(), origTy->getCoroutineKind(),
        origTy->getCalleeConvention(), dfParams, {}, dfResults, None, astCtx);

    // Generated tangents are never called cross-module, hence hidden linkage.
    SILOptFunctionBuilder fb(context.getTransform());
    auto *differential = fb.createFunction(
        SILLinkage::Hidden, diffName, diffType, diffGenericEnv,
        original->getLocation(), original->isBare(), IsNotTransparent,
        original->isSerialized(), original->isDynamicallyReplaceable());
    differential->setDebugScope(
        new (module) SILDebugScope(original->getLocation(), differential));
    return differential;
  }

  /// Run JVP generation. Returns true on error.
  ///
  /// Creates the JVP entry block and arguments, prepares the differential,
  /// clones the original body, and finally emits the differential's `return`.
  bool run() {
    LLVM_DEBUG(getADDebugStream()
               << "Cloning original @" << original->getName()
               << " to jvp @" << jvp->getName() << '\n');
    // Create JVP and differential entry and arguments.
    auto *entry = jvp->createBasicBlock();
    createEntryArguments(jvp);
    prepareForDifferentialGeneration();
    // Clone.
    SmallVector<SILValue, 4> entryArgs(entry->getArguments().begin(),
                                       entry->getArguments().end());
    cloneFunctionBody(original, entry, entryArgs);
    // If cloning failed, back out before emitting anything else into the
    // partially built differential.
    if (errorOccurred)
      return true;
    emitReturnInstForDifferential();
    if (errorOccurred)
      return true;
    LLVM_DEBUG(getADDebugStream() << "Generated JVP for "
               << original->getName() << ":\n" << *jvp);
    LLVM_DEBUG(getADDebugStream() << "Generated differential for "
               << original->getName() << ":\n" << getDifferential());
    return false;
  }

  /// Forwards to `SILClonerWithScopes::postProcess`, unless an error has
  /// already occurred, in which case the cloned instruction is left as is.
  void postProcess(SILInstruction *orig, SILInstruction *cloned) {
    if (!errorOccurred)
      SILClonerWithScopes::postProcess(orig, cloned);
  }

  /// Remap an original basic block to its JVP counterpart recorded in `BBMap`.
  SILBasicBlock *remapBasicBlock(SILBasicBlock *bb) { return BBMap[bb]; }

  /// General visitor for all instructions. If any error is emitted by previous
  /// visits, bail out. In debug builds, instructions that should be
  /// differentiated are logged together with the differential instructions
  /// emitted for them.
  void visit(SILInstruction *inst) {
    if (errorOccurred)
      return;
    if (!differentialInfo.shouldDifferentiateInstruction(inst)) {
      TypeSubstCloner::visit(inst);
      return;
    }
    LLVM_DEBUG(getADDebugStream() << "JVPEmitter visited:\n[ORIG]" << *inst);
#ifndef NDEBUG
    // Bind the live builder by reference (not a copy) so the insertion point
    // read after the visit reflects what was actually emitted.
    auto &diffBuilder = getDifferentialBuilder();
    auto beforeInsertion = std::prev(diffBuilder.getInsertionPoint());
#endif
    TypeSubstCloner::visit(inst);
    LLVM_DEBUG({
      auto &s = llvm::dbgs() << "[TAN] Emitted in differential:\n";
      auto afterInsertion = diffBuilder.getInsertionPoint();
      for (auto it = ++beforeInsertion; it != afterInsertion; ++it)
        s << *it;
    });
  }

  /// Fallback for instructions with no dedicated visitor: diagnose the
  /// instruction as not differentiable and mark the emitter as failed.
  void visitSILInstruction(SILInstruction *inst) {
    context.emitNondifferentiabilityError(
        inst, invoker, diag::autodiff_expression_not_differentiable_note);
    errorOccurred = true;
  }

  /// Before cloning the instructions of original block `bb`, destructure the
  /// differential struct passed as the last argument of the corresponding
  /// differential block and record its elements for `bb`.
  void visitInstructionsInBlock(SILBasicBlock *bb) {
    auto &diffBuilder = getDifferentialBuilder();
    auto *diffBB = diffBBMap.lookup(bb);
    diffBuilder.setInsertionPoint(diffBB);
    auto *structElements = diffBuilder.createDestructureStruct(
        getDifferential().getLocation(), diffBB->getArguments().back());
    initializeDifferentialStructElements(bb, structElements->getResults());
    TypeSubstCloner::visitInstructionsInBlock(bb);
  }

  // If an `apply` has active results or active inout parameters, replace it
  // with an `apply` of its JVP.
  //
  // The JVP is either extracted from a `@differentiable` callee, or obtained
  // by wrapping the (possibly specialized) callee in a new
  // `differentiable_function` and extracting its JVP. The JVP is called with
  // the remapped original arguments. Its trailing direct result (the
  // differential) is recorded for the differential struct, reabstracted with
  // a thunk if its type differs from the lowered struct field type.
  void visitApplyInst(ApplyInst *ai) {
    // If the function should not be differentiated or it's the array literal
    // initialization intrinsic, just do standard cloning.
    if (!differentialInfo.shouldDifferentiateApplyInst(ai) ||
        isArrayLiteralIntrinsic(ai)) {
      LLVM_DEBUG(getADDebugStream() << "No active results:\n" << *ai << '\n');
      TypeSubstCloner::visitApplyInst(ai);
      return;
    }

    // Check and reject functions with active inout arguments. It's not yet
    // supported.
    auto paramInfos = ai->getSubstCalleeConv().getParameters();
    auto paramArgs = ai->getArgumentsWithoutIndirectResults();
    for (unsigned i : swift::indices(paramInfos)) {
      if (paramInfos[i].isIndirectInOut() &&
          activityInfo.isActive(paramArgs[i], getIndices())) {
        context.emitNondifferentiabilityError(ai, invoker,
            diag::autodiff_cannot_differentiate_through_inout_arguments);
        errorOccurred = true;
        return;
      }
    }

    LLVM_DEBUG(getADDebugStream() << "JVP-transforming:\n" << *ai << '\n');

    // Get the minimal parameter and result indices required for differentiating
    // this `apply`.
    SmallVector<SILValue, 4> allResults;
    SmallVector<unsigned, 8> activeParamIndices;
    SmallVector<unsigned, 8> activeResultIndices;
    collectMinimalIndicesForFunctionCall(ai, getIndices(), activityInfo,
                                         allResults, activeParamIndices,
                                         activeResultIndices);
    assert(!activeParamIndices.empty() && "Parameter indices cannot be empty");
    assert(!activeResultIndices.empty() && "Result indices cannot be empty");
    LLVM_DEBUG(auto &s = getADDebugStream() << "Active indices: params={";
               interleave(activeParamIndices.begin(), activeParamIndices.end(),
                          [&s](unsigned i) { s << i; }, [&s] { s << ", "; });
               s << "}, results={"; interleave(
                   activeResultIndices.begin(), activeResultIndices.end(),
                   [&s](unsigned i) { s << i; }, [&s] { s << ", "; });
               s << "}\n";);
    // FIXME: We don't support multiple active results yet.
    if (activeResultIndices.size() > 1) {
      context.emitNondifferentiabilityError(
          ai, invoker, diag::autodiff_expression_not_differentiable_note);
      errorOccurred = true;
      return;
    }
    // Form expected indices, assuming there's only one result.
    SILAutoDiffIndices indices(
        activeResultIndices.front(),
        IndexSubset::get(
            getASTContext(), ai->getArgumentsWithoutIndirectResults().size(),
            activeParamIndices));

    // Emit the JVP.
    auto loc = ai->getLoc();
    auto &builder = getBuilder();
    auto original = getOpValue(ai->getCallee());
    SILValue jvpValue;
    // If functionSource is a `@differentiable` function, just extract it.
    auto originalFnTy = original->getType().castTo<SILFunctionType>();
    if (originalFnTy->isDifferentiable()) {
      auto paramIndices = originalFnTy->getDifferentiationParameterIndices();
      for (auto i : indices.parameters->getIndices()) {
        if (!paramIndices->contains(i)) {
          context.emitNondifferentiabilityError(original, invoker,
              diag::autodiff_function_nondiff_parameter_not_differentiable);
          errorOccurred = true;
          return;
        }
      }
      auto borrowedDiffFunc = builder.emitBeginBorrowOperation(loc, original);
      jvpValue = builder.createDifferentiableFunctionExtract(
          loc, NormalDifferentiableFunctionTypeComponent::JVP,
          borrowedDiffFunc);
      jvpValue = builder.emitCopyValueOperation(loc, jvpValue);
      // `emitBeginBorrowOperation` may hand back `original` itself when no new
      // borrow scope is needed. Only end a borrow scope that was actually
      // opened, now that the JVP has been copied out of it.
      if (borrowedDiffFunc != original)
        builder.emitEndBorrowOperation(loc, borrowedDiffFunc);
    }

    // If JVP has not yet been found, emit an `differentiable_function`
    // instruction on the remapped original function operand and
    // an `differentiable_function_extract` instruction to get the JVP.
    // The `differentiable_function` instruction will be canonicalized during
    // the transform main loop.
    if (!jvpValue) {
      // FIXME: Handle indirect differentiation invokers. This may require some
      // redesign: currently, each original function + attribute pair is mapped
      // only to one invoker.
      /*
       DifferentiationInvoker indirect(ai, attr);
       auto insertion =
           context.getInvokers().try_emplace({this->original, attr}, indirect);
       auto &invoker = insertion.first->getSecond();
       invoker = indirect;
       */

      // If the original `apply` instruction has a substitution map, then the
      // applied function is specialized.
      // In the JVP, specialization is also necessary for parity. The original
      // function operand is specialized with a remapped version of same
      // substitution map using an argument-less `partial_apply`.
      if (ai->getSubstitutionMap().empty()) {
        original = builder.emitCopyValueOperation(loc, original);
      } else {
        auto substMap = getOpSubstitutionMap(ai->getSubstitutionMap());
        auto jvpPartialApply = getBuilder().createPartialApply(
            ai->getLoc(), original, substMap, {},
            ParameterConvention::Direct_Guaranteed);
        original = jvpPartialApply;
      }

      // Check and diagnose non-differentiable original function type.
      // Returns true (and sets `errorOccurred`) if a diagnostic was emitted.
      auto diagnoseNondifferentiableOriginalFunctionType =
          [&](CanSILFunctionType origFnTy) {
            // Check and diagnose non-differentiable arguments.
            for (unsigned paramIndex : range(origFnTy->getNumParameters())) {
              if (indices.isWrtParameter(paramIndex) &&
                      !origFnTy->getParameters()[paramIndex]
                      .getSILStorageType()
                      .isDifferentiable(getModule())) {
                context.emitNondifferentiabilityError(
                    ai->getArgumentsWithoutIndirectResults()[paramIndex], invoker,
                    diag::autodiff_nondifferentiable_argument);
                errorOccurred = true;
                return true;
              }
            }
            // Check and diagnose non-differentiable results.
            if (!origFnTy->getResults()[indices.source]
                    .getSILStorageType()
                    .isDifferentiable(getModule())) {
              context.emitNondifferentiabilityError(
                  original, invoker, diag::autodiff_nondifferentiable_result);
              errorOccurred = true;
              return true;
            }
            return false;
          };
      if (diagnoseNondifferentiableOriginalFunctionType(originalFnTy))
        return;

      auto *diffFuncInst = context.createDifferentiableFunction(
          builder, loc, indices.parameters, original);

      // Record the `differentiable_function` instruction.
      context.getDifferentiableFunctionInsts().push_back(diffFuncInst);
      // TODO(TF-689): Make `differentiable_function` store result indices and
      // remove `ADContext::resultIndices`.
      context.getResultIndices()[diffFuncInst] = activeResultIndices.front();

      auto borrowedADFunc =
          builder.emitBeginBorrowOperation(loc, diffFuncInst);
      auto extractedJVP = builder.createDifferentiableFunctionExtract(
          loc, NormalDifferentiableFunctionTypeComponent::JVP,
          borrowedADFunc);
      jvpValue = builder.emitCopyValueOperation(loc, extractedJVP);
      builder.emitEndBorrowOperation(loc, borrowedADFunc);
      builder.emitDestroyValueOperation(loc, diffFuncInst);
    }

    // Call the JVP using the original parameters.
    SmallVector<SILValue, 8> jvpArgs;
    auto jvpFnTy = getOpType(jvpValue->getType()).castTo<SILFunctionType>();
    auto numJVPArgs =
        jvpFnTy->getNumParameters() + jvpFnTy->getNumIndirectFormalResults();
    jvpArgs.reserve(numJVPArgs);
    // Collect substituted arguments.
    for (auto origArg : ai->getArguments())
      jvpArgs.push_back(getOpValue(origArg));
    assert(jvpArgs.size() == numJVPArgs);
    // Apply the JVP.
    // The JVP should be specialized, so no substitution map is necessary.
    auto *jvpCall = getBuilder().createApply(loc, jvpValue, SubstitutionMap(),
                                             jvpArgs, ai->isNonThrowing());
    LLVM_DEBUG(getADDebugStream() << "Applied jvp function\n" << *jvpCall);

    // Destroy the JVP value now that it has been applied.
    builder.emitDestroyValueOperation(loc, jvpValue);

    // Get the JVP results (original results and differential).
    SmallVector<SILValue, 8> jvpDirectResults;
    extractAllElements(jvpCall, builder, jvpDirectResults);
    auto originalDirectResults =
        ArrayRef<SILValue>(jvpDirectResults).drop_back(1);
    auto originalDirectResult =
        joinElements(originalDirectResults, getBuilder(), jvpCall->getLoc());

    mapValue(ai, originalDirectResult);

    // Some instructions that produce the callee may have been cloned.
    // If the original callee did not have any users beyond this `apply`,
    // recursively kill the cloned callee. Callees defined by multi-result
    // instructions are skipped (a plain `cast` would be invalid for them), as
    // are remapped callees without a defining instruction.
    if (auto *origCallee = dyn_cast_or_null<SingleValueInstruction>(
            ai->getCallee()->getDefiningInstruction()))
      if (origCallee->hasOneUse())
        if (auto *clonedCallee =
                getOpValue(origCallee)->getDefiningInstruction())
          recursivelyDeleteTriviallyDeadInstructions(clonedCallee);

    // Add the differential function for when we create the struct we partially
    // apply to the differential we are generating.
    auto differential = jvpDirectResults.back();
    auto *differentialDecl = differentialInfo.lookUpLinearMapDecl(ai);
    auto originalDifferentialType =
        getOpType(differential->getType()).getAs<SILFunctionType>();
    auto differentialType =
        remapType(differential->getType())
            .castTo<SILFunctionType>();
    auto jvpGenSig = SubsMap.getGenericSignature()
        ? SubsMap.getGenericSignature()->getCanonicalSignature()
        : nullptr;
    Lowering::GenericContextScope genericContextScope(
        context.getTypeConverter(), jvpGenSig);
    auto loweredDifferentialType =
        getOpType(context.getTypeConverter().getLoweredType(
            differentialDecl->getInterfaceType()->getCanonicalType(),
            ResilienceExpansion::Minimal))
            .castTo<SILFunctionType>();
    // If actual differential type does not match lowered differential type,
    // reabstract the differential using a thunk.
    if (!loweredDifferentialType->isEqual(originalDifferentialType)) {
      SILOptFunctionBuilder fb(context.getTransform());
      auto *thunk = getOrCreateReabstractionThunk(
          fb, context.getModule(), loc, &getDifferential(),
          differentialType, loweredDifferentialType);
      auto *thunkRef = builder.createFunctionRef(loc, thunk);
      differential = builder.createPartialApply(
          loc, thunkRef,
          getOpSubstitutionMap(thunk->getForwardingSubstitutionMap()),
          {differential}, differentialType->getCalleeConvention());
    }
    differentialValues[ai->getParent()].push_back(differential);

    // Differential emission.
    emitTangentForApplyInst(ai, indices, originalDifferentialType);
  }

  /// Emit the JVP's `return`. It returns the remapped original direct results
  /// followed by the differential, which is partially applied to the
  /// differential struct value built for the returning block.
  void visitReturnInst(ReturnInst *ri) {
    auto &builder = getBuilder();
    auto loc = ri->getOperand().getLoc();
    auto *diffStructVal = buildDifferentialValueStructValue(ri);

    // Split the remapped return operand into its direct result elements.
    SmallVector<SILValue, 8> directResults;
    extractAllElements(getOpValue(ri->getOperand()), builder, directResults);

    // Reference the differential and close over the differential struct,
    // forwarding the JVP's own substitutions.
    auto *jvpEnv = jvp->getGenericEnvironment();
    SubstitutionMap jvpSubstMap;
    if (jvpEnv)
      jvpSubstMap = jvpEnv->getForwardingSubstitutionMap();
    else
      jvpSubstMap = jvp->getForwardingSubstitutionMap();
    auto *differentialRef =
        builder.createFunctionRef(loc, &getDifferential());
    auto *closedDifferential = builder.createPartialApply(
        loc, differentialRef, jvpSubstMap, {diffStructVal},
        ParameterConvention::Direct_Guaranteed);

    // Return the original results together with the differential.
    directResults.push_back(closedDifferential);
    builder.createReturn(ri->getLoc(),
                         joinElements(directResults, builder, loc));
  }

  void visitBranchInst(BranchInst *bi) {
    llvm_unreachable("Unsupported SIL instruction.");
  }

  void visitCondBranchInst(CondBranchInst *cbi) {
    llvm_unreachable("Unsupported SIL instruction.");
  }

  void visitSwitchEnumInst(SwitchEnumInst *sei) {
    llvm_unreachable("Unsupported SIL instruction.");
  }

  /// Clone a `differentiable_function` from the original into the JVP and
  /// append the clone to the context's `differentiable_function` worklist.
  void visitDifferentiableFunctionInst(DifferentiableFunctionInst *dfi) {
    TypeSubstCloner::visitDifferentiableFunctionInst(dfi);
    auto *clonedDFI = cast<DifferentiableFunctionInst>(getOpValue(dfi));
    auto &worklist = context.getDifferentiableFunctionInsts();
    worklist.push_back(clonedDFI);
  }
};
} // end anonymous namespace

//===----------------------------------------------------------------------===//
// PullbackEmitter - visitors on the original function for pullback code
// generation
//===----------------------------------------------------------------------===//

namespace {
class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
private:
  /// The parent VJP emitter.
  VJPEmitter &vjpEmitter;

  /// Dominance info for the original function.
  DominanceInfo *domInfo = nullptr;

  /// Post-dominance info for the original function.
  PostDominanceInfo *postDomInfo = nullptr;

  /// Post-order info for the original function.
  PostOrderFunctionInfo *postOrderInfo = nullptr;

  /// Mapping from original basic blocks to corresponding pullback basic blocks.
  /// Pullback basic blocks always have the predecessor as the single argument.
  DenseMap<SILBasicBlock *, SILBasicBlock *> pullbackBBMap;

  /// Mapping from original basic blocks and original values to corresponding
  /// adjoint values.
  DenseMap<std::pair<SILBasicBlock *, SILValue>, AdjointValue> valueMap;

  /// Mapping from original basic blocks and original buffers to corresponding
  /// adjoint buffers.
  DenseMap<std::pair<SILBasicBlock *, SILValue>, SILValue> bufferMap;

  /// Mapping from pullback basic blocks to pullback struct arguments.
  DenseMap<SILBasicBlock *, SILArgument *> pullbackStructArguments;

  /// Mapping from pullback struct field declarations to pullback struct
  /// elements destructured from the linear map basic block argument. In the
  /// beginning of each pullback basic block, the block's pullback struct is
  /// destructured into individual elements stored here.
  DenseMap<VarDecl *, SILValue> pullbackStructElements;

  /// Mapping from original basic blocks and successor basic blocks to
  /// corresponding pullback trampoline basic blocks. Trampoline basic blocks
  /// take additional arguments in addition to the predecessor enum argument.
  DenseMap<std::pair<SILBasicBlock *, SILBasicBlock *>, SILBasicBlock *>
      pullbackTrampolineBBMap;

  /// Mapping from original basic blocks to dominated active values.
  DenseMap<SILBasicBlock *, SmallVector<SILValue, 8>> activeValues;

  /// Mapping from original basic blocks and original active values to
  /// corresponding pullback block arguments.
  DenseMap<std::pair<SILBasicBlock *, SILValue>, SILArgument *>
      activeValuePullbackBBArgumentMap;

  /// Mapping from original basic blocks to local temporary values to be cleaned
  /// up. This is populated when pullback emission is run on one basic block and
  /// cleaned before processing another basic block.
  DenseMap<SILBasicBlock *, SmallVector<SILValue, 64>>
      blockTemporaries;

  llvm::DenseSet<SILValue> blockTemporarySet;

  /// The main builder.
  SILBuilder builder;

  /// An auxiliary local allocation builder.
  SILBuilder localAllocBuilder;

  /// Stack buffers allocated for storing local adjoint values.
  SmallVector<SILValue, 64> functionLocalAllocations;

  /// A set used to remember local allocations that were destroyed.
  llvm::SmallDenseSet<SILValue> destroyedLocalAllocations;

  /// The seed argument in the pullback function.
  SILArgument *seed = nullptr;

  llvm::BumpPtrAllocator allocator;

  bool errorOccurred = false;

  // Convenience accessors. Each forwards to state owned by the parent
  // `VJPEmitter` (or derived from it), so the pullback emitter keeps no copies.
  ADContext &getContext() const { return vjpEmitter.context; }
  SILModule &getModule() const { return getContext().getModule(); }
  ASTContext &getASTContext() const { return getPullback().getASTContext(); }
  SILFunction &getOriginal() const { return *vjpEmitter.original; }
  SILFunction &getPullback() const { return *vjpEmitter.pullback; }
  SILDifferentiableAttr *getAttr() const { return vjpEmitter.attr; }
  DifferentiationInvoker getInvoker() const { return vjpEmitter.invoker; }
  // Linear map (pullback struct) information computed by the VJP emitter.
  LinearMapInfo &getPullbackInfo() { return vjpEmitter.pullbackInfo; }
  // Differentiation indices (source result and wrt parameters).
  const SILAutoDiffIndices &getIndices() const {
    return vjpEmitter.getIndices();
  }
  const DifferentiableActivityInfo &getActivityInfo() const {
    return vjpEmitter.activityInfo;
  }

public:
  /// Create a pullback emitter driven by `vjpEmitter`. Both builders target
  /// the pullback function. Dominance, post-dominance and post-order
  /// information for the original function is fetched from the pass manager.
  explicit PullbackEmitter(VJPEmitter &vjpEmitter)
      : vjpEmitter(vjpEmitter), builder(getPullback()),
        localAllocBuilder(getPullback()) {
    auto &pm = getContext().getPassManager();
    auto *origFn = vjpEmitter.original;
    domInfo = pm.getAnalysis<DominanceAnalysis>()->get(origFn);
    postDomInfo = pm.getAnalysis<PostDominanceAnalysis>()->get(origFn);
    postOrderInfo = pm.getAnalysis<PostOrderAnalysis>()->get(origFn);
  }

private:
  //--------------------------------------------------------------------------//
  // Pullback struct mapping
  //--------------------------------------------------------------------------//

  /// Record, for each stored property of `origBB`'s pullback struct, the
  /// corresponding destructured element value. Fields and values are paired
  /// positionally; every field must be recorded exactly once.
  void initializePullbackStructElements(SILBasicBlock *origBB,
                                        SILInstructionResultArray values) {
    auto *pbStructDecl = getPullbackInfo().getLinearMapStruct(origBB);
    assert(pbStructDecl->getStoredProperties().size() == values.size() &&
           "The number of pullback struct fields must equal the number of "
           "pullback struct element values");
    for (auto fieldAndValue :
         llvm::zip(pbStructDecl->getStoredProperties(), values)) {
      auto *field = std::get<0>(fieldAndValue);
      SILValue element = std::get<1>(fieldAndValue);
      assert(element.getOwnershipKind() != ValueOwnershipKind::Guaranteed &&
             "Pullback struct elements must be @owned");
      bool newlyInserted =
          pullbackStructElements.insert({field, element}).second;
      (void)newlyInserted;
      assert(newlyInserted && "A pullback struct element already exists!");
    }
  }

  /// Return the recorded element value for pullback struct `field`, which
  /// must belong to `origBB`'s pullback struct.
  SILValue getPullbackStructElement(SILBasicBlock *origBB, VarDecl *field) {
    auto *expectedStruct = getPullbackInfo().getLinearMapStruct(origBB);
    (void)expectedStruct;
    assert(expectedStruct == cast<StructDecl>(field->getDeclContext()));
    assert(pullbackStructElements.count(field) &&
           "Pullback struct element for this field does not exist!");
    // In release builds a missing field yields a null `SILValue`.
    return pullbackStructElements.lookup(field);
  }

  //--------------------------------------------------------------------------//
  // Adjoint value factory methods
  //--------------------------------------------------------------------------//

  AdjointValue makeZeroAdjointValue(SILType type);

  AdjointValue makeConcreteAdjointValue(SILValue value);

  template<typename EltRange>
  AdjointValue makeAggregateAdjointValue(SILType type, EltRange elements);

  //--------------------------------------------------------------------------//
  // Temporary value management
  //--------------------------------------------------------------------------//

  /// Register object `value` as a temporary of its parent block so that it is
  /// destroyed when that block's temporaries are cleaned up. Returns `value`
  /// unchanged for convenient chaining.
  SILValue recordTemporary(SILValue value) {
    assert(value->getType().isObject());
    auto *parentBB = value->getParentBlock();
    blockTemporaries[parentBB].push_back(value);
    LLVM_DEBUG(getADDebugStream() << "Recorded temporary " << value);
    bool firstRecord = blockTemporarySet.insert(value).second;
    (void)firstRecord;
    assert(firstRecord && "Temporary already recorded?");
    return value;
  }

  /// Clean up all temporary values for the given block.
  ///
  /// Emits a destroy (at `loc`, via the main builder) for every temporary
  /// recorded for `bb`. Each one is removed from `blockTemporarySet`, and the
  /// block's list is then emptied.
  void cleanUpTemporariesForBlock(SILBasicBlock *bb, SILLocation loc) {
    LLVM_DEBUG(getADDebugStream() << "Cleaning up temporaries for bb"
               << bb->getDebugID() << '\n');
    auto &temporaries = blockTemporaries[bb];
    for (auto temp : temporaries) {
      builder.emitDestroyValueOperation(loc, temp);
      blockTemporarySet.erase(temp);
    }
    // Forget the destroyed temporaries. This keeps the per-block list
    // consistent with `blockTemporarySet` and prevents a later cleanup of the
    // same block from destroying them twice.
    temporaries.clear();
  }

  //--------------------------------------------------------------------------//
  // Symbolic value materializers
  //--------------------------------------------------------------------------//

  /// Materialize an adjoint value. The type of the given adjoint value must be
  /// loadable.
  SILValue materializeAdjointDirect(AdjointValue val, SILLocation loc);

  /// Materialize an adjoint value indirectly to a SIL buffer.
  void materializeAdjointIndirect(AdjointValue val, SILValue destBuffer,
                                  SILLocation loc);

  //--------------------------------------------------------------------------//
  // Helpers for symbolic value materializers
  //--------------------------------------------------------------------------//

  /// Emit a zero value by calling `AdditiveArithmetic.zero`. The given type
  /// must conform to `AdditiveArithmetic`.
  void emitZeroIndirect(CanType type, SILValue bufferAccess, SILLocation loc);

  /// Emit a zero value by calling `AdditiveArithmetic.zero`. The given type
  /// must conform to `AdditiveArithmetic` and be loadable in SIL.
  SILValue emitZeroDirect(CanType type, SILLocation loc);

  //--------------------------------------------------------------------------//
  // Accumulator
  //--------------------------------------------------------------------------//

  /// Materialize an adjoint value in the most efficient way.
  SILValue materializeAdjoint(AdjointValue val, SILLocation loc);

  /// Given two adjoint values, accumulate them.
  AdjointValue accumulateAdjointsDirect(AdjointValue lhs, AdjointValue rhs,
                                        SILLocation loc);

  /// Given two materialized adjoint values, accumulate them. These two
  /// adjoints must be objects of loadable type.
  SILValue accumulateDirect(SILValue lhs, SILValue rhs, SILLocation loc);

  /// Given two materialized adjoint values, accumulate them using
  /// `AdditiveArithmetic.+`, depending on the differentiation mode.
  void accumulateIndirect(SILValue resultBufAccess,
                          SILValue lhsBufAccess, SILValue rhsBufAccess,
                          SILLocation loc);

  /// Given two buffers of an `AdditiveArithmetic` type, accumulate the right
  /// hand side into the left hand side using `+=`.
  void accumulateIndirect(SILValue lhsDestAccess, SILValue rhsAccess,
                          SILLocation loc);

  //--------------------------------------------------------------------------//
  // Type transformer
  //--------------------------------------------------------------------------//

  /// Remap any archetypes into the current function's context. Types that
  /// contain archetypes are first mapped out of their context to interface
  /// types, then mapped into the pullback's context.
  SILType remapType(SILType ty) {
    SILType interfaceTy = ty.hasArchetype() ? ty.mapTypeOutOfContext() : ty;
    return getPullback().mapTypeIntoContext(interfaceTy);
  }

  /// Return the associated tangent space of `type`, if any. Conformances are
  /// looked up in the current module's Swift module.
  Optional<VectorSpace> getTangentSpace(CanType type) {
    auto *swiftModule = getModule().getSwiftModule();
    return type->getAutoDiffAssociatedTangentSpace(
        LookUpConformanceInModule(swiftModule));
  }

  /// Assuming the given type conforms to `Differentiable` after remapping,
  /// returns the associated tangent space type. The result keeps the input's
  /// object/address category.
  SILType getRemappedTangentType(SILType type) {
    auto remappedASTType = remapType(type).getASTType();
    auto tangentSpace = getTangentSpace(remappedASTType);
    return SILType::getPrimitiveType(tangentSpace->getCanonicalType(),
                                     type.getCategory());
  }

  /// Substitute every replacement type of `substMap` using the pullback
  /// function's forwarding substitution map.
  SubstitutionMap remapSubstitutionMap(SubstitutionMap substMap) {
    auto pullbackForwardingSubs = getPullback().getForwardingSubstitutionMap();
    return substMap.subst(pullbackForwardingSubs);
  }

  //--------------------------------------------------------------------------//
  // Managed value mapping
  //--------------------------------------------------------------------------//

  /// Returns true if an adjoint value has been recorded for the (object)
  /// original value in original block `origBB`.
  bool hasAdjointValue(SILBasicBlock *origBB, SILValue originalValue) const {
    assert(origBB->getParent() == &getOriginal());
    assert(originalValue->getType().isObject());
    return valueMap.find({origBB, originalValue}) != valueMap.end();
  }

  /// Sets the adjoint value of an original object value in `origBB`.
  /// Any adjoint value already recorded for the pair is overwritten.
  /// The adjoint must be an object in the remapped tangent space of the
  /// original value's type.
  void setAdjointValue(SILBasicBlock *origBB, SILValue originalValue,
                       AdjointValue adjointValue) {
    LLVM_DEBUG(getADDebugStream() << "Setting adjoint value for "
                                  << originalValue);
    assert(origBB->getParent() == &getOriginal());
    assert(originalValue->getType().isObject());
    assert(adjointValue.getType().isObject());
    assert(originalValue->getFunction() == &getOriginal());
    // The adjoint value must be in the tangent space.
    assert(adjointValue.getType() ==
               getRemappedTangentType(originalValue->getType()));
    auto insertion = valueMap.try_emplace({origBB, originalValue},
                                          adjointValue);
    if (insertion.second)
      return;
    // Only log the replacement when an existing entry is actually replaced;
    // on fresh insertion the slot already holds `adjointValue`.
    LLVM_DEBUG(getADDebugStream()
                   << "The existing adjoint value will be replaced: "
                   << insertion.first->getSecond());
    insertion.first->getSecond() = adjointValue;
  }

  /// Get the adjoint for an original value. The given value must be in the
  /// original function.
  ///
  /// This method first tries to find an entry in `valueMap`. If an adjoint
  /// doesn't exist, a zero adjoint of the remapped tangent type is created,
  /// recorded, and returned.
  AdjointValue getAdjointValue(SILBasicBlock *origBB, SILValue originalValue) {
    assert(origBB->getParent() == &getOriginal());
    assert(originalValue->getType().isObject());
    assert(originalValue->getFunction() == &getOriginal());
    auto existing = valueMap.find({origBB, originalValue});
    if (existing != valueMap.end())
      return existing->getSecond();
    // Only build the zero adjoint (and compute the tangent type) on a miss,
    // rather than on every lookup.
    auto zeroAdjoint = makeZeroAdjointValue(
        getRemappedTangentType(originalValue->getType()));
    valueMap.try_emplace({origBB, originalValue}, zeroAdjoint);
    return zeroAdjoint;
  }

  /// Add `newAdjointValue` to the adjoint of an original object value in
  /// `origBB`. If no adjoint exists yet, it is simply recorded. Otherwise the
  /// existing adjoint is removed, accumulated with the new one at `loc`, and
  /// the sum is stored.
  void addAdjointValue(SILBasicBlock *origBB, SILValue originalValue,
                       AdjointValue newAdjointValue, SILLocation loc) {
    assert(origBB->getParent() == &getOriginal());
    assert(originalValue->getType().isObject());
    assert(newAdjointValue.getType().isObject());
    assert(originalValue->getFunction() == &getOriginal());
    LLVM_DEBUG(getADDebugStream() << "Adding adjoint for " << originalValue);
    // The adjoint value must be in the tangent space.
    assert(newAdjointValue.getType() ==
               getRemappedTangentType(originalValue->getType()));
    auto slot =
        valueMap.try_emplace({origBB, originalValue}, newAdjointValue);
    if (slot.second)
      return;
    AdjointValue previousAdjoint = slot.first->getSecond();
    valueMap.erase(slot.first);
    setAdjointValue(
        origBB, originalValue,
        accumulateAdjointsDirect(previousAdjoint, newAdjointValue, loc));
  }

  /// Get the pullback block argument corresponding to the given original block
  /// and active value. The argument must already have been registered and
  /// must belong to the pullback block of `origBB`.
  SILArgument *getActiveValuePullbackBlockArgument(SILBasicBlock *origBB,
                                                   SILValue activeValue) {
    assert(origBB->getParent() == &getOriginal());
    // Use `lookup` instead of `operator[]` so that a query for an unregistered
    // pair does not default-insert a null entry into the map.
    auto *pullbackBBArg =
        activeValuePullbackBBArgumentMap.lookup({origBB, activeValue});
    assert(pullbackBBArg);
    assert(pullbackBBArg->getParent() == getPullbackBlock(origBB));
    return pullbackBBArg;
  }

  //--------------------------------------------------------------------------//
  // Buffer mapping
  //--------------------------------------------------------------------------//

  /// Record `adjointBuffer` as the adjoint buffer of the original address
  /// `originalBuffer` in `origBB`. No buffer may already be recorded for the
  /// pair.
  void setAdjointBuffer(SILBasicBlock *origBB,
                        SILValue originalBuffer,
                        SILValue adjointBuffer) {
    assert(originalBuffer->getType().isAddress());
    bool freshEntry =
        bufferMap.try_emplace({origBB, originalBuffer}, adjointBuffer).second;
    (void)freshEntry;
    assert(freshEntry);
  }

  /// If `originalProjection` is a supported address projection
  /// (`struct_element_addr`, `tuple_element_addr`, or `begin_access`), return
  /// the corresponding projection of the base's adjoint buffer. Returns a null
  /// `SILValue` for anything else.
  SILValue getAdjointProjection(SILBasicBlock *origBB,
                                SILValue originalProjection) {
    // `struct_element_addr`: project the same-named field of the adjoint
    // (tangent) struct.
    if (auto *structAddr =
            dyn_cast<StructElementAddrInst>(originalProjection)) {
      SILValue adjBase = getAdjointBuffer(origBB, structAddr->getOperand());
      auto *tanStructDecl = adjBase->getType().getStructOrBoundGenericStruct();
      auto fieldLookup =
          tanStructDecl->lookupDirect(structAddr->getField()->getName());
      assert(fieldLookup.size() == 1);
      auto *tanField = cast<VarDecl>(fieldLookup.front());
      return builder.createStructElementAddr(structAddr->getLoc(), adjBase,
                                             tanField);
    }

    // `tuple_element_addr`: the adjoint index counts only the preceding
    // original elements that have a tangent space.
    if (auto *tupleAddr = dyn_cast<TupleElementAddrInst>(originalProjection)) {
      SILValue origBase = tupleAddr->getOperand();
      SILValue adjBase = getAdjointBuffer(origBB, origBase);
      // A non-tuple adjoint buffer is returned without projecting.
      if (!adjBase->getType().is<TupleType>())
        return adjBase;
      auto origTupleTy = origBase->getType().castTo<TupleType>();
      unsigned tangentIndex = 0;
      for (unsigned i = 0, e = tupleAddr->getFieldNo(); i != e; ++i) {
        auto eltTy = origTupleTy->getElement(i).getType()->getCanonicalType();
        if (getTangentSpace(eltTy))
          ++tangentIndex;
      }
      return builder.createTupleElementAddr(tupleAddr->getLoc(), adjBase,
                                            tangentIndex);
    }

    // `begin_access`: reuse the adjoint buffer of the accessed base, or record
    // and return a null buffer if an error has been diagnosed.
    if (auto *access = dyn_cast<BeginAccessInst>(originalProjection)) {
      SILValue adjBase = getAdjointBuffer(origBB, access->getOperand());
      if (!errorOccurred)
        return adjBase;
      return (bufferMap[{origBB, originalProjection}] = SILValue());
    }

    return SILValue();
  }

  /// Returns where the next function-local allocation should be inserted in
  /// the pullback entry block.
  SILBasicBlock::iterator getNextFunctionLocalAllocationInsertionPoint() {
    // When local allocations already exist, insert right before the most
    // recent one. Inserting before (not after) keeps each allocation grouped
    // with its zero-initialization instructions.
    if (!functionLocalAllocations.empty()) {
      auto *mostRecentAllocInst =
          functionLocalAllocations.back()->getDefiningInstruction();
      return mostRecentAllocInst->getIterator();
    }
    // No local allocations yet: start at the top of the pullback entry.
    return getPullback().getEntryBlock()->begin();
  }

  /// Returns the adjoint buffer for the original address `originalBuffer` in
  /// the original block `origBB`, creating it on first request:
  /// - If an entry is already registered for (`origBB`, `originalBuffer`), it
  ///   is returned as-is.
  /// - If `originalBuffer` is a projection handled by `getAdjointProjection`,
  ///   the corresponding projection into the base's adjoint buffer is
  ///   registered and returned.
  /// - Otherwise a new `alloc_stack` of the remapped tangent type is created
  ///   in the pullback entry block, zero-initialized, recorded in
  ///   `functionLocalAllocations`, registered and returned.
  ///
  /// The returned reference points into `bufferMap`. Any later insertion into
  /// `bufferMap` may invalidate it, so callers should copy the value before
  /// making further queries.
  SILValue &getAdjointBuffer(SILBasicBlock *origBB, SILValue originalBuffer) {
    assert(originalBuffer->getType().isAddress());
    assert(originalBuffer->getFunction() == &getOriginal());
    auto insertion = bufferMap.try_emplace({origBB, originalBuffer},
                                           SILValue());
    if (!insertion.second) // not inserted
      return insertion.first->getSecond();

    // NOTE: `getAdjointProjection` may recursively call `getAdjointBuffer`.
    // That inserts into `bufferMap`, which can reallocate it and invalidate
    // `insertion.first`. From here on, the entry is always re-looked-up
    // instead of reusing the iterator.

    // If the original buffer is a projection, return a corresponding projection
    // into the adjoint buffer.
    if (auto adjProj = getAdjointProjection(origBB, originalBuffer))
      return (bufferMap[{origBB, originalBuffer}] = adjProj);

    // Set insertion point for local allocation builder: before the last local
    // allocation, or at the start of the pullback function's entry if no local
    // allocations exist yet.
    localAllocBuilder.setInsertionPoint(
        getPullback().getEntryBlock(),
        getNextFunctionLocalAllocationInsertionPoint());
    // Allocate local buffer and initialize to zero.
    auto bufObjectType = getRemappedTangentType(originalBuffer->getType());
    auto *newBuf = localAllocBuilder.createAllocStack(
        RegularLocation::getAutoGeneratedLocation(), bufObjectType);
    // Temporarily move the main builder next to the new allocation and emit
    // zero into the local buffer. Save the full insertion point (block and
    // iterator) so the builder is restored to exactly where it was, not merely
    // to the end of its block.
    auto *savedInsertionBB = builder.getInsertionBB();
    auto savedInsertionPoint = builder.getInsertionPoint();
    builder.setInsertionPoint(
        localAllocBuilder.getInsertionBB(),
        localAllocBuilder.getInsertionPoint());
    emitZeroIndirect(bufObjectType.getASTType(), newBuf, newBuf->getLoc());
    builder.setInsertionPoint(savedInsertionBB, savedInsertionPoint);
    // Register the local buffer. Re-look-up the map entry; see the NOTE above.
    functionLocalAllocations.push_back(newBuf);
    return (bufferMap[{origBB, originalBuffer}] = newBuf);
  }

  /// Accumulates the pullback buffer `rhsBufferAccess` into the adjoint buffer
  /// of the original address `originalBuffer` in original block `origBB`. If
  /// no adjoint buffer exists yet, `getAdjointBuffer` creates one.
  void addToAdjointBuffer(SILBasicBlock *origBB, SILValue originalBuffer,
                          SILValue rhsBufferAccess, SILLocation loc) {
    assert(originalBuffer->getType().isAddress() &&
           rhsBufferAccess->getType().isAddress());
    assert(originalBuffer->getFunction() == &getOriginal());
    assert(rhsBufferAccess->getFunction() == &getPullback());
    // Copy the buffer out of the map before emitting the accumulation.
    SILValue destBuffer = getAdjointBuffer(origBB, originalBuffer);
    accumulateIndirect(destBuffer, rhsBufferAccess, loc);
  }

  //--------------------------------------------------------------------------//
  // CFG mapping
  //--------------------------------------------------------------------------//

  /// Returns the pullback block created for `originalBlock` via a
  /// `pullbackBBMap` lookup.
  SILBasicBlock *getPullbackBlock(SILBasicBlock *originalBlock) {
    auto *pbBlock = pullbackBBMap.lookup(originalBlock);
    return pbBlock;
  }

  /// Returns the pullback trampoline block registered for the original CFG
  /// edge `originalBlock` -> `successorBlock`. Callers treat a null result as
  /// "no trampoline block".
  SILBasicBlock *getPullbackTrampolineBlock(
      SILBasicBlock *originalBlock, SILBasicBlock *successorBlock) {
    auto *trampolineBlock =
        pullbackTrampolineBBMap.lookup({originalBlock, successorBlock});
    return trampolineBlock;
  }

public:
  //--------------------------------------------------------------------------//
  // Entry point
  //--------------------------------------------------------------------------//

  /// Performs pullback generation on the empty pullback function. Returns true
  /// if any error occurs.
  ///
  /// Steps:
  /// 1. If the original result is not varied, emit a pullback returning zero
  ///    derivatives and stop (`emitZeroDerivativesForNonvariedResult`).
  /// 2. Walk original blocks in dominance order and collect each block's
  ///    active values, including those inherited from its immediate dominator.
  ///    Active enum values are diagnosed here.
  /// 3. Create one pullback block per original block, in post-order of the
  ///    post-dominator tree. Along the way, create block arguments, adjoint
  ///    buffers and trampoline blocks.
  /// 4. Seed the adjoint of the original result with the pullback's seed.
  /// 5. Differentiate every original block into its pullback block.
  /// 6. In the pullback exit (the pullback block of the original entry):
  ///    - gather parameter adjoints;
  ///    - move indirect ones into the indirect results;
  ///    - clean up temporaries and local allocations;
  ///    - emit `return`.
  bool run() {
    auto &original = getOriginal();
    auto &pullback = getPullback();
    auto pbLoc = getPullback().getLocation();
    LLVM_DEBUG(getADDebugStream() << "Running PullbackEmitter on\n"
                                  << original);

    // Lower types in the context of the pullback's generic signature (null
    // when the pullback has no generic environment).
    auto *pbGenEnv = getPullback().getGenericEnvironment();
    auto pbGenSig = pbGenEnv
        ? pbGenEnv->getGenericSignature()->getCanonicalSignature()
        : nullptr;
    Lowering::GenericContextScope genericContextScope(
        getContext().getTypeConverter(), pbGenSig);
    auto origExitIt = original.findReturnBB();
    assert(origExitIt != original.end() &&
           "Functions without returns must have been diagnosed");
    auto *origExit = &*origExitIt;

    // The original result being differentiated, selected by the source index
    // among all formal results in type order.
    SmallVector<SILValue, 8> origFormalResults;
    collectAllFormalResultsInTypeOrder(original, origFormalResults);
    auto origResult = origFormalResults[getIndices().source];

    // If original result is non-varied, it will always have a zero derivative.
    // Skip full pullback generation and simply emit zero derivatives for wrt
    // parameters.
    //
    // NOTE(TF-876): This shortcut is currently necessary for functions
    // returning non-varied result with >1 basic block where some basic blocks
    // have no dominated active values; control flow differentiation does not
    // handle this case. See TF-876 for context.
    if (!getActivityInfo().isVaried(origResult, getIndices().parameters)) {
      emitZeroDerivativesForNonvariedResult(origResult);
      return false;
    }

    // Get dominated active values in original blocks.
    // Adjoint values of dominated active values are passed as pullback block
    // arguments.
    DominanceOrder domOrder(original.getEntryBlock(), domInfo);
    while (auto *bb = domOrder.getNext()) {
      auto &bbActiveValues = activeValues[bb];
      // If the current block has an immediate dominator, append the immediate
      // dominator block's active values to the current block's active values.
      if (auto *domNode = domInfo->getNode(bb)->getIDom()) {
        auto &domBBActiveValues = activeValues[domNode->getBlock()];
        bbActiveValues.append(domBBActiveValues.begin(),
                              domBBActiveValues.end());
      }
      // Values already recorded for this block (including inherited ones).
      SmallPtrSet<SILValue, 8> visited(bbActiveValues.begin(),
                                       bbActiveValues.end());
      // Register a value as active if it has not yet been visited.
      auto addActiveValue = [&](SILValue v) {
        if (visited.count(v))
          return;
        // Diagnose active enum values. Differentiation of enum values is not
        // yet supported; requires special adjoint value handling.
        if (v->getType().getEnumOrBoundGenericEnum()) {
          getContext().emitNondifferentiabilityError(
              v, getInvoker(), diag::autodiff_enums_unsupported);
          errorOccurred = true;
        }
        // Skip address projections.
        // Address projections do not need their own adjoint buffers; they
        // become projections into their adjoint base buffer.
        if (Projection::isAddressProjection(v))
          return;
        visited.insert(v);
        bbActiveValues.push_back(v);
      };
      // Register bb arguments and all instruction operands/results.
      for (auto *arg : bb->getArguments())
        if (getActivityInfo().isActive(arg, getIndices()))
          addActiveValue(arg);
      for (auto &inst : *bb) {
        for (auto op : inst.getOperandValues())
          if (getActivityInfo().isActive(op, getIndices()))
            addActiveValue(op);
        for (auto result : inst.getResults())
          if (getActivityInfo().isActive(result, getIndices()))
            addActiveValue(result);
      }
      domOrder.pushChildren(bb);
      // Stop after the first block in which an error was diagnosed.
      if (errorOccurred)
        return true;
    }

    // Create pullback blocks and arguments, visiting original blocks in
    // post-order post-dominance order.
    SmallVector<SILBasicBlock *, 8> postOrderPostDomOrder;
    // Start from the root node, which may have a marker `nullptr` block if
    // there are multiple roots.
    PostOrderPostDominanceOrder postDomOrder(postDomInfo->getRootNode(),
                                             postOrderInfo, original.size());
    while (auto *origNode = postDomOrder.getNext()) {
      auto *origBB = origNode->getBlock();
      postDomOrder.pushChildren(origNode);
      // If node is the `nullptr` marker basic block, do not push it.
      if (!origBB)
        continue;
      postOrderPostDomOrder.push_back(origBB);
    }
    for (auto *origBB : postOrderPostDomOrder) {
      auto *pullbackBB = pullback.createBasicBlock();
      pullbackBBMap.insert({origBB, pullbackBB});
      auto pbStructLoweredType =
          remapType(getPullbackInfo().getLinearMapStructLoweredType(origBB));
      // If the BB is the original exit, then the pullback block that we just
      // created must be the pullback function's entry. For the pullback entry,
      // create entry arguments and continue to the next block.
      if (origBB == origExit) {
        assert(pullbackBB->isEntry());
        createEntryArguments(&pullback);
        // The last entry argument is the pullback struct of the exit block.
        auto *mainPullbackStruct = pullbackBB->getArguments().back();
        assert(mainPullbackStruct->getType() == pbStructLoweredType);
        pullbackStructArguments[origBB] = mainPullbackStruct;
        // Destructure the pullback struct to get the elements.
        builder.setInsertionPoint(pullbackBB);
        auto *dsi = builder.createDestructureStruct(pbLoc, mainPullbackStruct);
        initializePullbackStructElements(origBB, dsi->getResults());
        continue;
      }
      // Get all active values in the original block.
      // If the original block has no active values, continue.
      auto &bbActiveValues = activeValues[origBB];
      if (bbActiveValues.empty())
        continue;
      // Otherwise, if the original block has active values:
      // - For each active buffer in the original block, allocate a new local
      //   buffer in the pullback entry. (All adjoint buffers are allocated in
      //   the pullback entry and deallocated in the pullback exit.)
      // - For each active value in the original block, add adjoint value
      //   arguments to the pullback block.
      for (auto activeValue : bbActiveValues) {
        if (activeValue->getType().isAddress()) {
          // Allocate and zero initialize a new local buffer using
          // `getAdjointBuffer`.
          builder.setInsertionPoint(pullback.getEntryBlock());
          getAdjointBuffer(origBB, activeValue);
        } else {
          // Create and register pullback block argument for the active value.
          auto *pullbackArg = pullbackBB->createPhiArgument(
              getRemappedTangentType(activeValue->getType()),
              ValueOwnershipKind::Owned);
          activeValuePullbackBBArgumentMap[{origBB, activeValue}] = pullbackArg;
          recordTemporary(pullbackArg);
        }
      }
      // Add a pullback struct argument.
      auto *pbStructArg = pullbackBB->createPhiArgument(
          pbStructLoweredType, ValueOwnershipKind::Owned);
      pullbackStructArguments[origBB] = pbStructArg;
      // Destructure the pullback struct to get the elements.
      builder.setInsertionPoint(pullbackBB);
      auto *dsi = builder.createDestructureStruct(pbLoc, pbStructArg);
      initializePullbackStructElements(origBB, dsi->getResults());

      // - Create pullback trampoline blocks for each successor block of the
      //   original block. Pullback trampoline blocks only have a pullback
      //   struct argument. They branch from a pullback successor block to the
      //   pullback original block, passing adjoint values of active values.
      for (auto *succBB : origBB->getSuccessorBlocks()) {
        auto *pullbackTrampolineBB =
            pullback.createBasicBlockBefore(pullbackBB);
        pullbackTrampolineBBMap.insert({{origBB, succBB},
                                       pullbackTrampolineBB});
        // Get the enum element type (i.e. the pullback struct type). The enum
        // element type may be boxed if the enum is indirect.
        auto enumLoweredTy =
            getPullbackInfo().getBranchingTraceEnumLoweredType(succBB);
        auto *enumEltDecl =
            getPullbackInfo().lookUpBranchingTraceEnumElement(origBB, succBB);
        auto enumEltType = remapType(
            enumLoweredTy.getEnumElementType(enumEltDecl, getModule()));
        pullbackTrampolineBB->createPhiArgument(enumEltType,
                                                ValueOwnershipKind::Owned);
      }
    }

    auto *pullbackEntry = pullback.getEntryBlock();
    // The pullback function has type (seed, exit_pbs) -> ([arg0], ..., [argn]).
    auto pbParamArgs = pullback.getArgumentsWithoutIndirectResults();
    assert(pbParamArgs.size() == 2);
    seed = pbParamArgs[0];

    // Assign adjoint for original result.
    builder.setInsertionPoint(
        pullbackEntry, getNextFunctionLocalAllocationInsertionPoint());
    if (seed->getType().isAddress()) {
      // Copy an indirect seed into a function-local buffer. Like every other
      // local allocation, it is destroyed (unless moved out) and deallocated
      // in the pullback exit.
      auto *seedBufCopy = builder.createAllocStack(pbLoc, seed->getType());
      builder.createCopyAddr(pbLoc, seed, seedBufCopy, IsNotTake,
                             IsInitialization);
      setAdjointBuffer(origExit, origResult, seedBufCopy);
      functionLocalAllocations.push_back(seedBufCopy);
      LLVM_DEBUG(getADDebugStream()
                 << "Assigned seed buffer " << seedBufCopy
                 << " as the adjoint of original indirect result "
                 << origResult);
    } else {
      setAdjointValue(origExit, origResult, makeConcreteAdjointValue(seed));
      LLVM_DEBUG(getADDebugStream()
                 << "Assigned seed " << *seed
                 << " as the adjoint of original result " << origResult);
    }

    // Visit original blocks in post-order post-dominance order and perform
    // differentiation in the corresponding pullback blocks. If errors
    // occurred, back out.
    for (auto *bb : postOrderPostDomOrder) {
      visitSILBasicBlock(bb);
      if (errorOccurred)
        return true;
    }

    // Prepare and emit a `return` in the pullback exit block.
    auto *origEntry = getOriginal().getEntryBlock();
    auto *pbExit = getPullbackBlock(origEntry);
    builder.setInsertionPoint(pbExit);

    // This vector will contain all the materialized return elements.
    SmallVector<SILValue, 8> retElts;
    // This vector will contain all indirect parameter adjoint buffers.
    SmallVector<SILValue, 4> indParamAdjoints;

    auto origParams = getOriginal().getArgumentsWithoutIndirectResults();

    // Materializes the return element corresponding to the parameter
    // `parameterIndex` into the `retElts` vector.
    // Object parameters contribute a copied direct return element. Address
    // parameters contribute their adjoint buffer to `indParamAdjoints`.
    auto addRetElt = [&](unsigned parameterIndex) -> void {
      auto origParam = origParams[parameterIndex];
      if (origParam->getType().isObject()) {
        auto pbVal = getAdjointValue(origEntry, origParam);
        auto val = materializeAdjointDirect(pbVal, pbLoc);
        auto newVal = builder.emitCopyValueOperation(pbLoc, val);
        retElts.push_back(newVal);
      } else {
        auto adjBuf = getAdjointBuffer(origEntry, origParam);
        indParamAdjoints.push_back(adjBuf);
      }
    };
    // Collect differentiation parameter adjoints.
    for (auto i : getIndices().parameters->getIndices())
      addRetElt(i);

    // Copy them to adjoint indirect results.
    assert(indParamAdjoints.size() ==
               getPullback().getIndirectResults().size() &&
           "Indirect parameter adjoint count mismatch");
    for (auto pair : zip(indParamAdjoints,
                             getPullback().getIndirectResults())) {
      auto source = std::get<0>(pair);
      auto *dest = std::get<1>(pair);
      builder.createCopyAddr(pbLoc, source, dest, IsTake, IsInitialization);
      // Prevent source buffer from being deallocated, since the underlying
      // value is moved.
      // (Marking it destroyed skips its `destroy_addr` below; the buffer is
      // still deallocated with `dealloc_stack`.)
      destroyedLocalAllocations.insert(source);
    }

    // Emit cleanups for all local values.
    cleanUpTemporariesForBlock(pbExit, pbLoc);
    // Deallocate local allocations.
    for (auto alloc : functionLocalAllocations) {
      // Assert that local allocations have at least one use.
      // Buffers should not be allocated needlessly.
      assert(!alloc->use_empty());
      if (!destroyedLocalAllocations.count(alloc)) {
        builder.emitDestroyAddrAndFold(pbLoc, alloc);
        destroyedLocalAllocations.insert(alloc);
      }
      builder.createDeallocStack(pbLoc, alloc);
    }
    builder.createReturn(pbLoc, joinElements(retElts, builder, pbLoc));

#ifndef NDEBUG
    bool leakFound = false;
    // Ensure all temporaries have been cleaned up.
    for (auto &bb : pullback) {
      for (auto temp : blockTemporaries[&bb]) {
        if (blockTemporarySet.count(temp)) {
          leakFound = true;
          getADDebugStream() << "Found leaked temporary:\n" << temp;
        }
      }
    }
    // Ensure all local allocations have been cleaned up.
    for (auto localAlloc : functionLocalAllocations) {
      if (!destroyedLocalAllocations.count(localAlloc)) {
        leakFound = true;
        getADDebugStream() << "Found leaked local buffer:\n" << localAlloc;
      }
    }
    assert(!leakFound && "Leaks found!");
#endif

    LLVM_DEBUG(getADDebugStream() << "Generated pullback for "
                                  << original.getName() << ":\n" << pullback);
    return errorOccurred;
  }

  /// Builds the whole pullback as a single entry block that ignores its
  /// inputs and produces zero for every result. Used when the original result
  /// is non-varied, so its derivative with respect to every differentiation
  /// parameter is always zero.
  ///
  /// \param origNonvariedResult the non-varied original result. It is
  ///        currently unused; it was only needed by the disabled TF-788
  ///        diagnostic kept below.
  void emitZeroDerivativesForNonvariedResult(SILValue origNonvariedResult) {
    auto &pb = getPullback();
    auto loc = pb.getLocation();
    /*
    // TODO(TF-788): Re-enable non-varied result warning.
    // Emit fixit if original non-varied result has a valid source location.
    auto startLoc = origNonvariedResult.getLoc().getStartSourceLoc();
    auto endLoc = origNonvariedResult.getLoc().getEndSourceLoc();
    if (startLoc.isValid() && endLoc.isValid()) {
      getContext().diagnose(startLoc, diag::autodiff_nonvaried_result_fixit)
          .fixItInsert(startLoc, "withoutDerivative(at:")
          .fixItInsertAfter(endLoc, ")");
    }
    */
    LLVM_DEBUG(getADDebugStream() << getOriginal().getName()
                                  << " has non-varied result, returning zero"
                                     " for all pullback results\n");
    auto *entryBB = pb.createBasicBlock();
    createEntryArguments(&pb);
    builder.setInsertionPoint(entryBB);
    // None of the inputs are used, so consume every owned argument up front.
    for (auto *entryArg : entryBB->getArguments()) {
      if (entryArg->getOwnershipKind() == ValueOwnershipKind::Owned)
        builder.emitDestroyOperation(loc, entryArg);
    }
    // Produce a zero per result. Indirect results are zero-initialized in
    // place, in order. Direct results are collected for the `return`.
    SmallVector<SILValue, 4> zeroDirectResults;
    auto nextIndirectResult = pb.getIndirectResults().begin();
    auto pbFnTy = pb.getLoweredFunctionType();
    for (auto info : pbFnTy->getResults()) {
      auto zeroTy =
          pb.mapTypeIntoContext(info.getType())->getCanonicalType();
      if (!info.isFormalDirect()) {
        emitZeroIndirect(zeroTy, *nextIndirectResult, loc);
        ++nextIndirectResult;
        continue;
      }
      zeroDirectResults.push_back(emitZeroDirect(zeroTy, loc));
    }
    builder.createReturn(loc, joinElements(zeroDirectResults, builder, loc));
    LLVM_DEBUG(getADDebugStream() << "Generated pullback for "
                                  << getOriginal().getName() << ":\n"
                                  << pb);
  }

  /// A set of pullback trampoline blocks. Used to record which trampoline
  /// blocks forward a given adjoint value.
  using TrampolineBlockSet = SmallPtrSet<SILBasicBlock *, 4>;

  /// Determine the pullback successor block for a given original block and one
  /// of its predecessors. When a trampoline block is necessary, emit code into
  /// the trampoline block to trampoline the original block's active value's
  /// adjoint values. The map `pullbackTrampolineBlockMap` is populated to
  /// track which pullback trampoline blocks forward each materialized adjoint
  /// value. The caller then destroys those values in the pullback successor
  /// blocks that do not use them.
  ///
  /// \param origBB the original block whose pullback is being emitted.
  /// \param origPredBB a predecessor of `origBB` in the original function.
  /// \param pullbackTrampolineBlockMap maps materialized adjoint values to the
  ///        trampoline blocks that forward them. It is shared across all
  ///        predecessors of `origBB`.
  /// \returns the pullback trampoline block for the edge
  ///          `origPredBB` -> `origBB` if one exists, otherwise the pullback
  ///          block of `origPredBB`.
  SILBasicBlock *buildPullbackSuccessor(
      SILBasicBlock *origBB, SILBasicBlock *origPredBB,
      SmallDenseMap<SILValue, TrampolineBlockSet> &pullbackTrampolineBlockMap) {
    // Get the pullback block and optional pullback trampoline block of the
    // predecessor block.
    auto *pullbackBB = getPullbackBlock(origPredBB);
    auto *pullbackTrampolineBB = getPullbackTrampolineBlock(origPredBB, origBB);
    // If the predecessor block does not have a corresponding pullback
    // trampoline block, then the pullback successor is the pullback block.
    if (!pullbackTrampolineBB)
      return pullbackBB;

    // Otherwise, the pullback successor is the pullback trampoline block,
    // which branches to the pullback block and propagates adjoint values of
    // active values.
    assert(pullbackTrampolineBB->getNumArguments() == 1);
    auto loc = origBB->getParent()->getLocation();
    SmallVector<SILValue, 8> trampolineArguments;
    // Propagate adjoint values/buffers of active values/buffers to
    // predecessor blocks.
    auto &predBBActiveValues = activeValues[origPredBB];
    for (auto activeValue : predBBActiveValues) {
      LLVM_DEBUG(getADDebugStream()
                 << "Propagating active adjoint " << activeValue
                 << " to predecessors' pullback blocks\n");
      if (activeValue->getType().isObject()) {
        auto activeValueAdj = getAdjointValue(origBB, activeValue);
        auto concreteActiveValueAdj =
            materializeAdjointDirect(activeValueAdj, loc);

        // If this materialized adjoint is not yet tracked in the map, emit a
        // copy (via the main `builder`) and register the copy as the adjoint
        // value of `activeValue` in `origBB`.
        if (!pullbackTrampolineBlockMap.count(concreteActiveValueAdj)) {
          concreteActiveValueAdj =
              builder.emitCopyValueOperation(loc, concreteActiveValueAdj);
          setAdjointValue(origBB, activeValue,
                          makeConcreteAdjointValue(concreteActiveValueAdj));
        }
        // Record that this trampoline block forwards the value.
        auto insertion = pullbackTrampolineBlockMap.try_emplace(
            concreteActiveValueAdj, TrampolineBlockSet());
        auto &blockSet = insertion.first->getSecond();
        blockSet.insert(pullbackTrampolineBB);
        trampolineArguments.push_back(concreteActiveValueAdj);

        // If the pullback block does not yet have a registered adjoint
        // value for the active value, set the adjoint value to the
        // forwarded adjoint value argument.
        // TODO: Hoist this logic out of loop over predecessor blocks to
        // remove the `hasAdjointValue` check.
        if (!hasAdjointValue(origPredBB, activeValue)) {
          auto *pullbackBBArg =
              getActiveValuePullbackBlockArgument(origPredBB, activeValue);
          auto forwardedArgAdj = makeConcreteAdjointValue(pullbackBBArg);
          setAdjointValue(origPredBB, activeValue, forwardedArgAdj);
        }
      } else {
        // Propagate adjoint buffers using `copy_addr`.
        // NOTE: this is emitted with the main `builder` (into the current
        // pullback block), not into the trampoline block.
        auto adjBuf = getAdjointBuffer(origBB, activeValue);
        auto predAdjBuf = getAdjointBuffer(origPredBB, activeValue);
        builder.createCopyAddr(
            loc, adjBuf, predAdjBuf, IsNotTake, IsNotInitialization);
      }
    }
    // Propagate pullback struct argument.
    // If the enum payload is boxed (indirect enum), project the box, load a
    // copy of the pullback struct and destroy the box. Otherwise forward the
    // argument directly.
    SILBuilder pullbackTrampolineBBBuilder(pullbackTrampolineBB);
    auto *predPBStructVal = pullbackTrampolineBB->getArguments().front();
    auto boxType =
        dyn_cast<SILBoxType>(predPBStructVal->getType().getASTType());
    if (!boxType) {
      trampolineArguments.push_back(predPBStructVal);
    } else {
      auto *projectBox = pullbackTrampolineBBBuilder.createProjectBox(
          loc, predPBStructVal, /*index*/ 0);
      auto loaded = pullbackTrampolineBBBuilder.emitLoadValueOperation(
          loc, projectBox, LoadOwnershipQualifier::Copy);
      pullbackTrampolineBBBuilder.emitDestroyValueOperation(loc,
                                                            predPBStructVal);
      trampolineArguments.push_back(loaded);
    }
    // Branch from pullback trampoline block to pullback block.
    pullbackTrampolineBBBuilder.createBranch(loc, pullbackBB,
                                             trampolineArguments);
    return pullbackTrampolineBB;
  }

  /// Emit pullback code in the corresponding pullback block.
  ///
  /// Instructions of the original block `bb` that the pullback info marks for
  /// differentiation are visited in reverse order. Unless `bb` is the original
  /// entry, the pullback block is then terminated with a `switch_enum` on the
  /// predecessor (branching trace) enum value. That `switch_enum` dispatches
  /// to the pullback successors built for each original predecessor of `bb`.
  void visitSILBasicBlock(SILBasicBlock *bb) {
    auto pbLoc = getPullback().getLocation();
    // Get the corresponding pullback basic block.
    auto *pbBB = getPullbackBlock(bb);
    builder.setInsertionPoint(pbBB);

    LLVM_DEBUG({
      auto &s = getADDebugStream()
          << "Original bb" + std::to_string(bb->getDebugID())
          << ": To differentiate or not to differentiate?\n";
      for (auto &inst : reversed(*bb)) {
        s << (getPullbackInfo().shouldDifferentiateInstruction(&inst)
                  ? "[∂] " : "[ ] ")
          << inst;
      }
    });

    // Visit each instruction in reverse order.
    for (auto &inst : reversed(*bb)) {
      if (!getPullbackInfo().shouldDifferentiateInstruction(&inst))
        continue;
      // Differentiate instruction.
      visit(&inst);
      if (errorOccurred)
        return;
    }

    // Emit a branching terminator for the block.
    // If the original block is the original entry, then the pullback block is
    // the pullback exit. This is handled specially in `PullbackEmitter::run()`,
    // so we leave the block non-terminated.
    if (bb->isEntry())
      return;

    // Otherwise, add a `switch_enum` terminator for non-exit
    // pullback blocks.
    // 1. Get the pullback struct pullback block argument.
    // 2. Extract the predecessor enum value from the pullback struct value.
    auto *predEnum = getPullbackInfo().getBranchingTraceDecl(bb);
    auto *predEnumField =
        getPullbackInfo().lookUpLinearMapStructEnumField(bb);
    auto predEnumVal = getPullbackStructElement(bb, predEnumField);

    // Propagate adjoint values from active basic block arguments to
    // predecessor terminator operands.
    for (auto *bbArg : bb->getArguments()) {
      if (!getActivityInfo().isActive(bbArg, getIndices()))
        continue;
      // Get predecessor terminator operands.
      SmallVector<std::pair<SILBasicBlock *, SILValue>, 4> incomingValues;
      bbArg->getSingleTerminatorOperands(incomingValues);
      // Initialize adjoint value of predecessor terminator operands as
      // adjoint value of current block arguments.
      auto bbArgAdj = getAdjointValue(bb, bbArg);
      for (auto pair : incomingValues) {
        auto *predBB = std::get<0>(pair);
        auto incomingValue = std::get<1>(pair);
        setAdjointValue(predBB, incomingValue, bbArgAdj);
      }
    }

    // 3. Build the pullback successor cases for the `switch_enum`
    //    instruction. The pullback successors correspond to the predecessors
    //    of the current block.
    SmallVector<std::pair<EnumElementDecl *, SILBasicBlock *>, 4>
        pullbackSuccessorCases;
    // A map from active values' adjoint values to the trampoline blocks that
    // are using them.
    SmallDenseMap<SILValue, TrampolineBlockSet> pullbackTrampolineBlockMap;
    SmallVector<SILBasicBlock *, 8> pullbackSuccBBs;
    for (auto *predBB : bb->getPredecessorBlocks()) {
      auto *pullbackSuccBB = buildPullbackSuccessor(bb, predBB,
                                                    pullbackTrampolineBlockMap);
      pullbackSuccBBs.push_back(pullbackSuccBB);
      auto *enumEltDecl =
          getPullbackInfo().lookUpBranchingTraceEnumElement(predBB, bb);
      pullbackSuccessorCases.push_back({enumEltDecl, pullbackSuccBB});
    }
    // Values are trampolined by only a subset of pullback successor blocks.
    // Other successors blocks should destroy the value to balance the reference
    // count.
    for (auto pair : pullbackTrampolineBlockMap) {
      auto value = pair.getFirst();
      // The set of trampoline BBs that are users of `value`.
      auto &userTrampolineBBSet = pair.getSecond();
      // For each pullback successor block that does not trampoline the value,
      // release the value.
      for (auto *pullbackSuccBB : pullbackSuccBBs) {
        if (userTrampolineBBSet.count(pullbackSuccBB))
          continue;
        // Local builder (shadowing the member `builder`) that inserts at the
        // start of the successor block.
        SILBuilder builder(pullbackSuccBB->begin());
        builder.emitDestroyValueOperation(pbLoc, value);
      }
    }
    // Emit cleanups for all block-local temporaries.
    cleanUpTemporariesForBlock(pbBB, pbLoc);
    // 4. Terminate the pullback block with a `switch_enum` on the predecessor
    //    enum value. There is one case per original predecessor, each
    //    branching to the pullback successor (trampoline block or pullback
    //    block) built for that predecessor above.
    assert(pullbackSuccessorCases.size() == predEnum->getNumElements());
    builder.createSwitchEnum(
        pbLoc, predEnumVal, /*DefaultBB*/ nullptr, pullbackSuccessorCases);
  }

  /// Differentiates a single original instruction by dispatching to the
  /// matching `visit*` method. Does nothing once an error has been flagged.
  /// In debug builds, dumps the instructions emitted for `inst`.
  void visit(SILInstruction *inst) {
    if (errorOccurred)
      return;

    LLVM_DEBUG(getADDebugStream()
               << "PullbackEmitter visited:\n[ORIG]" << *inst);
#ifndef NDEBUG
    // Remember the instruction just before the current insertion point so that
    // everything emitted during the visit can be printed afterwards.
    auto lastBeforeVisit = std::prev(builder.getInsertionPoint());
#endif
    SILInstructionVisitor::visit(inst);
    LLVM_DEBUG({
      auto &os = llvm::dbgs() << "[ADJ] Emitted in pullback:\n";
      auto stopAt = builder.getInsertionPoint();
      auto emitted = std::next(lastBeforeVisit);
      while (emitted != stopAt) {
        os << *emitted;
        ++emitted;
      }
    });
  }

  /// Fallback for instruction kinds without a dedicated visitor: reports the
  /// instruction as non-differentiable and marks emission as failed.
  void visitSILInstruction(SILInstruction *inst) {
    LLVM_DEBUG(getADDebugStream()
               << "Unhandled instruction in PullbackEmitter: " << *inst);
    getContext().emitNondifferentiabilityError(
        inst, getInvoker(), diag::autodiff_expression_not_differentiable_note);
    errorOccurred = true;
  }

  /// Reads element `index` of the array tangent `adjointArray` by calling the
  /// subscript getter `fnRef` (generic signature `genericSig`), specialized to
  /// the element tangent type `eltType`.
  ///
  /// The element is written into a newly allocated stack buffer of type
  /// `eltType`, which is returned. The caller must consume the buffer's
  /// contents and emit the matching `dealloc_stack`. All emitted instructions
  /// use the location of the original array allocation `ai`.
  AllocStackInst *
  emitArrayTangentSubscript(ApplyInst *ai, SILType eltType,
                            SILValue adjointArray, SILValue fnRef,
                            CanGenericSignature genericSig, int index) {
    auto loc = ai->getLoc();
    auto &astContext = builder.getASTContext();
    auto eltASTType = eltType.getASTType();

    // Conformances required to substitute the getter's generic parameter.
    auto swiftModule = getModule().getSwiftModule();
    auto *diffProto =
        astContext.getProtocol(KnownProtocolKind::Differentiable);
    auto *addArithProto =
        astContext.getProtocol(KnownProtocolKind::AdditiveArithmetic);
    auto diffConf = swiftModule->lookupConformance(eltASTType, diffProto);
    assert(diffConf.hasValue() && "Missing conformance to `Differentiable`");
    auto addArithConf =
        swiftModule->lookupConformance(eltASTType, addArithProto);
    assert(addArithConf.hasValue() &&
           "Missing conformance to `AdditiveArithmetic`");
    auto subMap = SubstitutionMap::get(genericSig, {eltASTType},
                                       {*addArithConf, *diffConf});

    // Wrap the index literal in a Swift `Int`.
    auto *indexLiteral = builder.createIntegerLiteral(
        loc, SILType::getBuiltinIntegerType(64, astContext), index);
    auto intTy = SILType::getPrimitiveObjectType(
        astContext.getIntDecl()->getDeclaredType()->getCanonicalType());
    auto *indexValue = builder.createStruct(loc, intTy, {indexLiteral});

    // Call the getter, writing the element into a fresh stack buffer.
    auto *eltBuffer = builder.createAllocStack(loc, eltType);
    builder.createApply(loc, fnRef, subMap,
                        {eltBuffer, indexValue, adjointArray});
    return eltBuffer;
  }

  /// Moves an element tangent out of `subscriptBuffer` and accumulates it
  /// into the adjoint of the value stored by `si`, then deallocates the
  /// buffer. `eltType` is not used by this implementation.
  void accumulateArrayTangentSubscriptDirect(ApplyInst *ai, SILType eltType,
                                             StoreInst *si,
                                             AllocStackInst *subscriptBuffer) {
    auto eltAdjoint = builder.emitLoadValueOperation(
        ai->getLoc(), subscriptBuffer, LoadOwnershipQualifier::Take);
    recordTemporary(eltAdjoint);
    // Array literal initialization may store a `copy_value` of the element;
    // in that case, credit the adjoint to the copied value itself.
    SILValue origElt = si->getSrc();
    if (auto *copy = dyn_cast<CopyValueInst>(origElt))
      origElt = copy->getOperand();
    addAdjointValue(si->getParent(), origElt,
                    makeConcreteAdjointValue(eltAdjoint), si->getLoc());
    // NOTE(review): `eltAdjoint` is registered both through
    // `recordTemporary` above and directly in `blockTemporaries` here;
    // confirm this cannot cause it to be cleaned up twice.
    blockTemporaries[ai->getParent()].push_back(eltAdjoint);
    builder.createDeallocStack(ai->getLoc(), subscriptBuffer);
  }

  /// Accumulates the element tangent held in `subscriptBuffer` into the
  /// adjoint buffer of `cai`'s source, then destroys and deallocates the
  /// temporary buffer.
  void accumulateArrayTangentSubscriptIndirect(
      ApplyInst *ai, CopyAddrInst *cai, AllocStackInst *subscriptBuffer) {
    auto copyLoc = cai->getLoc();
    addToAdjointBuffer(cai->getParent(), cai->getSrc(), subscriptBuffer,
                       copyLoc);
    builder.emitDestroyAddrAndFold(copyLoc, subscriptBuffer);
    builder.createDeallocStack(ai->getLoc(), subscriptBuffer);
  }

  /// Handle the array literal allocation intrinsic.
  ///   Original: (array, ptr) = apply <array-uninitialized-intrinsic>(...)
  ///             store x0 to ptr; store x_i to (index_addr ptr, i); ...
  ///    Adjoint: adj[x_i] += adj[array][i]
  /// Elements are read from the array adjoint with its `TangentVector`
  /// subscript getter. Loadable elements are initialized with `store`;
  /// address-only elements are initialized with `copy_addr`.
  void visitArrayInitialization(ApplyInst *ai) {
    LLVM_DEBUG(getADDebugStream() << "Visiting array initialization:\n" << *ai);
    SILValue adjointArray;
    SILValue fnRef;
    CanGenericSignature genericSig;
    // First pass: find the array adjoint and a reference to the subscript
    // getter of its type. Both must be known before any element is processed.
    for (auto *use : ai->getUses()) {
      auto *dti = dyn_cast<DestructureTupleInst>(use->getUser());
      if (!dti) continue;
      // The first tuple field of the return value is the `Array`.
      adjointArray = getAdjointValue(ai->getParent(), dti->getResult(0))
          .getConcreteValue();
      assert(adjointArray && "Array does not have adjoint value");
      auto astType = adjointArray->getType().getASTType();
      auto typeDecl = astType->getStructOrBoundGenericStruct();
      auto subscriptDecl = cast<SubscriptDecl>(typeDecl->lookupDirect(
          DeclBaseName::createSubscript()).front());
      auto subscriptGet = subscriptDecl->getAccessor(AccessorKind::Get);
      SILDeclRef subscriptRef(subscriptGet, SILDeclRef::Kind::Func);
      auto fnBuilder = SILOptFunctionBuilder(getContext().getTransform());
      auto fn = fnBuilder.getOrCreateFunction(
          ai->getLoc(), subscriptRef, NotForDefinition);
      genericSig = fn->getLoweredFunctionType()->getGenericSignature();
      fnRef = builder.createFunctionRef(ai->getLoc(), fn);
    }
    assert(adjointArray && "Array does not have adjoint value");
    assert(genericSig && "No generic signature");
    assert(fnRef && "Could not create `function_ref`");

    // Accumulates the adjoint of the array element at `index` initialized by
    // `user`: a `store` (direct case) or a `copy_addr` (indirect case). Other
    // users are ignored.
    auto accumulateElementAdjoint = [&](SILInstruction *user, int index) {
      if (auto *si = dyn_cast<StoreInst>(user)) {
        auto tanType = getRemappedTangentType(si->getSrc()->getType());
        auto *subscriptBuffer = emitArrayTangentSubscript(
            ai, tanType, adjointArray, fnRef, genericSig, index);
        accumulateArrayTangentSubscriptDirect(ai, tanType, si,
                                              subscriptBuffer);
      } else if (auto *cai = dyn_cast<CopyAddrInst>(user)) {
        auto tanType = getRemappedTangentType(cai->getSrc()->getType());
        auto *subscriptBuffer = emitArrayTangentSubscript(
            ai, tanType, adjointArray, fnRef, genericSig, index);
        accumulateArrayTangentSubscriptIndirect(ai, cai, subscriptBuffer);
      }
    };

    // Second pass: walk the element initializations.
    for (auto *use : ai->getUses()) {
      auto *dti = dyn_cast<DestructureTupleInst>(use->getUser());
      if (!dti) continue;
      // The second tuple field is the `RawPointer`. It passes through a
      // `pointer_to_address`, whose uses are either the initialization of
      // element 0 or `index_addr`s addressing the remaining elements.
      for (auto *ptrUse : dti->getResult(1)->getUses()) {
        auto *ptrToAddr = ptrUse->getUser();
        for (auto *addrUse : ptrToAddr->getResult(0)->getUses()) {
          auto *inst = addrUse->getUser();
          if (auto *iai = dyn_cast<IndexAddrInst>(inst)) {
            // Array literal element addresses use literal indices. `cast`
            // asserts this, instead of dereferencing a possibly-null
            // `dyn_cast` result.
            auto *literal = cast<IntegerLiteralInst>(iai->getIndex());
            int index = literal->getValue().getLimitedValue();
            for (auto *eltUse : iai->getUses())
              accumulateElementAdjoint(eltUse->getUser(), index);
          } else {
            // Element 0 is initialized directly through the base address.
            accumulateElementAdjoint(inst, 0);
          }
        }
      }
    }
  }

  /// Handle `apply` instruction.
  ///   Original: y = apply f(x0, x1, ...)
  ///    Adjoint: adj[x_i] += pullback_f(adj[y])_i   for each diff. parameter i
  ///
  /// Array literal allocations are delegated to `visitArrayInitialization`.
  /// For other applies that have `NestedApplyInfo`, the callee's pullback is
  /// read from the pullback struct element for this `apply`. It is called
  /// with buffers for its indirect results followed by the seed (the adjoint
  /// of the differentiated result). Each returned tangent is accumulated into
  /// the adjoint of the corresponding original argument.
  void visitApplyInst(ApplyInst *ai) {
    assert(getPullbackInfo().shouldDifferentiateApplyInst(ai));
    // Handle array uninitialized allocation intrinsic specially.
    if (isArrayLiteralIntrinsic(ai))
      return visitArrayInitialization(ai);
    // Replace a call to a function with a call to its pullback.
    auto &nestedApplyInfo = getContext().getNestedApplyInfo();
    auto applyInfoLookup = nestedApplyInfo.find(ai);
    // If no `NestedApplyInfo` was found, then this `apply` does not need to
    // be differentiated.
    if (applyInfoLookup == nestedApplyInfo.end()) {
      // Must not be active.
      assert(!getActivityInfo().isActive(ai, getIndices()));
      return;
    }
    auto applyInfo = applyInfoLookup->getSecond();

    // Get the pullback.
    auto *field = getPullbackInfo().lookUpLinearMapDecl(ai);
    assert(field);
    auto loc = ai->getLoc();
    auto pullback = getPullbackStructElement(ai->getParent(), field);

    // Get the original result of the `apply` instruction.
    // `applyInfo.indices.source` indexes `origAllResults`, which holds the
    // results in the order produced by `collectAllActualResultsInTypeOrder`.
    SmallVector<SILValue, 8> args;
    SmallVector<SILValue, 8> origDirectResults;
    forEachApplyDirectResult(ai, [&](SILValue directResult) {
      origDirectResults.push_back(directResult);
    });
    SmallVector<SILValue, 8> origAllResults;
    collectAllActualResultsInTypeOrder(ai, origDirectResults, origAllResults);
    assert(applyInfo.indices.source < origAllResults.size());
    auto origResult = origAllResults[applyInfo.indices.source];
    assert(origResult);
    // Original arguments start after the indirect result arguments.
    auto origNumIndRes = ai->getNumIndirectResults();

    auto pullbackType =
        remapType(pullback->getType()).castTo<SILFunctionType>();

    // Get the seed (i.e. adjoint value of the original result).
    SILValue seed;
    auto *bb = ai->getParent();
    if (origResult->getType().isObject()) {
      // Object result: materialize the adjoint value of the result.
      seed = materializeAdjoint(getAdjointValue(bb, origResult), loc);
    } else {
      // Address result: its adjoint lives in an adjoint buffer.
      seed = getAdjointBuffer(bb, origResult);
    }

    // Create allocations for pullback indirect results.
    // When the pullback was reabstracted in the VJP, the indirect results
    // follow the original (pre-reabstraction) pullback type.
    SmallVector<AllocStackInst *, 4> pullbackIndirectResults;
    auto actualPullbackType = applyInfo.originalPullbackType
        ? *applyInfo.originalPullbackType
        : pullbackType;
    for (auto indRes : actualPullbackType->getIndirectFormalResults()) {
      auto *alloc =
          builder.createAllocStack(loc, remapType(indRes.getSILStorageType()));
      pullbackIndirectResults.push_back(alloc);
      args.push_back(alloc);
    }

    // If callee pullback was reabstracted in VJP, reabstract callee pullback.
    if (applyInfo.originalPullbackType) {
      SILOptFunctionBuilder fb(getContext().getTransform());
      auto *thunk = getOrCreateReabstractionThunk(
          fb, getContext().getModule(), loc, &getPullback(),
          pullbackType, *applyInfo.originalPullbackType);
      auto *thunkRef = builder.createFunctionRef(loc, thunk);
      pullback = builder.createPartialApply(
          loc, thunkRef,
          remapSubstitutionMap(thunk->getForwardingSubstitutionMap()),
          {pullback}, pullbackType->getCalleeConvention());
    }
    // The seed is the last pullback argument, after indirect result buffers.
    args.push_back(seed);

    // Call the callee pullback.
    auto *pullbackCall = builder.createApply(
        loc, pullback, SubstitutionMap(), args, /*isNonThrowing*/ false);
    // The pullback value is no longer needed after the call.
    builder.emitDestroyValueOperation(loc, pullback);

    // Extract all results from `pullbackCall`.
    SmallVector<SILValue, 8> dirResults;
    extractAllElements(pullbackCall, builder, dirResults);
    // Get all results in type-defined order.
    SmallVector<SILValue, 8> allResults;
    collectAllActualResultsInTypeOrder(pullbackCall, dirResults, allResults);
    LLVM_DEBUG({
      auto &s = getADDebugStream();
      s << "All results of the nested pullback call:\n";
      llvm::for_each(allResults, [&](SILValue v) { s << v; });
    });

    // Accumulate adjoints for original differentiation parameters.
    // The k-th pullback result is the tangent of the k-th differentiation
    // parameter index in `applyInfo.indices.parameters`.
    auto allResultsIt = allResults.begin();
    for (unsigned i : applyInfo.indices.parameters->getIndices()) {
      auto origArg = ai->getArgument(origNumIndRes + i);
      auto tan = *allResultsIt++;
      if (tan->getType().isAddress()) {
        addToAdjointBuffer(bb, origArg, tan, loc);
      } else {
        if (origArg->getType().isAddress()) {
          // Object tangent for an address argument: spill it to a temporary
          // buffer so it can be accumulated into the argument's adjoint
          // buffer.
          auto *tmpBuf = builder.createAllocStack(loc, tan->getType());
          builder.emitStoreValueOperation(loc, tan, tmpBuf,
                                          StoreOwnershipQualifier::Init);
          addToAdjointBuffer(bb, origArg, tmpBuf, loc);
          builder.emitDestroyAddrAndFold(loc, tmpBuf);
          builder.createDeallocStack(loc, tmpBuf);
        }
        else {
          recordTemporary(tan);
          addAdjointValue(bb, origArg, makeConcreteAdjointValue(tan), loc);
        }
      }
    }
    // Destroy and deallocate pullback indirect results.
    // Reverse order keeps stack deallocations properly nested.
    for (auto *alloc : reversed(pullbackIndirectResults)) {
      builder.emitDestroyAddrAndFold(loc, alloc);
      builder.createDeallocStack(loc, alloc);
    }
  }

  /// Handle `struct` instruction.
  ///   Original: y = struct (x0, x1, x2, ...)
  ///    Adjoint: adj[x0] += struct_extract adj[y], #x0
  ///             adj[x1] += struct_extract adj[y], #x1
  ///             adj[x2] += struct_extract adj[y], #x2
  ///             ...
  /// In the concrete case, stored properties marked `@noDerivative` receive
  /// no adjoint. If a differentiable property has no same-named member in the
  /// struct's `TangentVector`, a diagnostic is emitted and emission stops.
  void visitStructInst(StructInst *si) {
    auto *bb = si->getParent();
    auto loc = si->getLoc();
    auto *structDecl = si->getStructDecl();
    auto av = getAdjointValue(bb, si);
    switch (av.getKind()) {
    case AdjointValueKind::Zero:
      for (auto *field : structDecl->getStoredProperties()) {
        auto fv = si->getFieldValue(field);
        addAdjointValue(bb, fv,
            makeZeroAdjointValue(getRemappedTangentType(fv->getType())), loc);
      }
      break;
    case AdjointValueKind::Concrete: {
      auto adjStruct = materializeAdjointDirect(std::move(av), loc);
      // Find the struct `TangentVector` type.
      auto structTy = remapType(si->getType()).getASTType();
      auto tangentVectorTy =
          getTangentSpace(structTy)->getType()->getCanonicalType();
      assert(!getModule().Types.getTypeLowering(
                 tangentVectorTy, ResilienceExpansion::Minimal)
                     .isAddressOnly());
      auto *tangentVectorDecl =
          tangentVectorTy->getStructOrBoundGenericStruct();
      assert(tangentVectorDecl);

      // `destructure_struct` yields one result per stored property of the
      // tangent struct, in declaration order.
      auto *dti = builder.createDestructureStruct(si->getLoc(), adjStruct);
      // Accumulate adjoints for the fields of the `struct` operand.
      for (auto *field : structDecl->getStoredProperties()) {
        if (field->getAttrs().hasAttribute<NoDerivativeAttr>())
          continue;
        // Find the corresponding field in the tangent space.
        VarDecl *tanField = nullptr;
        if (tangentVectorDecl == structDecl) {
          // The tangent space is the original struct: same field.
          tanField = field;
        } else {
          // Otherwise, look up the field by name.
          auto tanFieldLookup =
              tangentVectorDecl->lookupDirect(field->getName());
          if (tanFieldLookup.empty()) {
            // Name the original struct, as `visitStructExtractInst` does.
            getContext().emitNondifferentiabilityError(
                si, getInvoker(),
                diag::autodiff_stored_property_no_corresponding_tangent,
                structDecl->getNameStr(), field->getNameStr());
            errorOccurred = true;
            return;
          }
          tanField = cast<VarDecl>(tanFieldLookup.front());
        }
        assert(tanField);
        // Index the destructured results by the tangent field's position in
        // the tangent struct, not by the original field's position. The two
        // differ whenever the stored properties of the original and tangent
        // structs do not line up one-to-one.
        unsigned tanFieldIndex = 0;
        bool foundTanField = false;
        for (auto *tanCandidate : tangentVectorDecl->getStoredProperties()) {
          if (tanCandidate == tanField) {
            foundTanField = true;
            break;
          }
          ++tanFieldIndex;
        }
        assert(foundTanField && "Tangent field is not a stored property");
        (void)foundTanField;
        auto tanElt = dti->getResult(tanFieldIndex);
        addAdjointValue(
            bb, si->getFieldValue(field),
            makeConcreteAdjointValue(tanElt), si->getLoc());
      }
      break;
    }
    case AdjointValueKind::Aggregate: {
      // Note: All user-called initializations go through the calls to the
      // initializer, and synthesized initializers only have one level of struct
      // formation which will not result into any aggregate adjoint values.
      llvm_unreachable("Aggregate adjoint values should not occur for `struct` "
                       "instructions");
    }
    }
  }

  /// Handle `struct_extract` instruction.
  ///   Original: y = struct_extract x, #field
  ///    Adjoint: adj[x] += struct (0, ..., #field': adj[y], ..., 0)
  ///                                       ^~~~~~~
  ///                     field in tangent space corresponding to #field
  /// The corresponding tangent field is the field itself when the struct is
  /// its own `TangentVector`. Otherwise it is looked up by name, and a
  /// diagnostic is emitted if it is missing. A non-zero adjoint is
  /// accumulated as a symbolic aggregate whose other elements are zero.
  void visitStructExtractInst(StructExtractInst *sei) {
    assert(!sei->getField()->getAttrs().hasAttribute<NoDerivativeAttr>() &&
           "`struct_extract` with `@noDerivative` field should not be "
           "differentiated; activity analysis should not marked as varied");
    auto *bb = sei->getParent();
    auto structTy = remapType(sei->getOperand()->getType()).getASTType();
    auto tangentVectorTy =
        getTangentSpace(structTy)->getType()->getCanonicalType();
    // The tangent struct is expected to be loadable.
    assert(!getModule().Types.getTypeLowering(
               tangentVectorTy, ResilienceExpansion::Minimal)
                   .isAddressOnly());
    auto tangentVectorSILTy =
        SILType::getPrimitiveObjectType(tangentVectorTy);
    auto *tangentVectorDecl =
        tangentVectorTy->getStructOrBoundGenericStruct();
    assert(tangentVectorDecl);
    // Find the corresponding field in the tangent space.
    VarDecl *tanField = nullptr;
    // If the tangent space is the original struct, then field is the same.
    if (tangentVectorDecl == sei->getStructDecl())
      tanField = sei->getField();
    // Otherwise, look up the field by name.
    else {
      auto tanFieldLookup =
          tangentVectorDecl->lookupDirect(sei->getField()->getName());
      if (tanFieldLookup.empty()) {
        getContext().emitNondifferentiabilityError(
            sei, getInvoker(),
            diag::autodiff_stored_property_no_corresponding_tangent,
            sei->getStructDecl()->getNameStr(),
            sei->getField()->getNameStr());
        errorOccurred = true;
        return;
      }
      tanField = cast<VarDecl>(tanFieldLookup.front());
    }
    // Accumulate adjoint for the `struct_extract` operand.
    auto av = getAdjointValue(bb, sei);
    switch (av.getKind()) {
    case AdjointValueKind::Zero:
      addAdjointValue(bb, sei->getOperand(),
                      makeZeroAdjointValue(tangentVectorSILTy), sei->getLoc());
      break;
    case AdjointValueKind::Concrete:
    case AdjointValueKind::Aggregate: {
      // Build a symbolic tangent struct: `av` in the extracted field's slot,
      // zeros of the appropriate substituted type everywhere else.
      SmallVector<AdjointValue, 8> eltVals;
      for (auto *field : tangentVectorDecl->getStoredProperties()) {
        if (field == tanField) {
          eltVals.push_back(av);
        } else {
          auto substMap = tangentVectorTy->getMemberSubstitutionMap(
              field->getModuleContext(), field);
          auto fieldTy = field->getType().subst(substMap);
          auto fieldSILTy =
              getContext().getTypeConverter().getLoweredType(
                  fieldTy, ResilienceExpansion::Minimal);
          assert(fieldSILTy.isObject());
          eltVals.push_back(makeZeroAdjointValue(fieldSILTy));
        }
      }
      addAdjointValue(bb, sei->getOperand(),
                      makeAggregateAdjointValue(tangentVectorSILTy, eltVals),
                      sei->getLoc());
    }
    }
  }

  /// Handle `tuple` instruction.
  ///   Original: y = tuple (x0, x1, x2, ...)
  ///    Adjoint: adj[x0] += tuple_extract adj[y], 0
  ///             ...
  /// Only elements whose type has a tangent space receive adjoints. `adjIdx`
  /// tracks the position among those elements in the tangent value.
  void visitTupleInst(TupleInst *ti) {
    auto *bb = ti->getParent();
    auto av = getAdjointValue(bb, ti);
    switch (av.getKind()) {
    case AdjointValueKind::Zero:
      for (auto eltVal : ti->getElements()) {
        if (!getTangentSpace(eltVal->getType().getASTType()))
          continue;
        addAdjointValue(bb, eltVal,
            makeZeroAdjointValue(getRemappedTangentType(eltVal->getType())),
            ti->getLoc());
      }
      break;
    case AdjointValueKind::Concrete: {
      auto val = av.getConcreteValue();
      // A non-tuple concrete adjoint is used as-is for the differentiable
      // element(s). Only tuple-typed adjoints may be destructured:
      // `destructure_tuple` requires a tuple operand.
      DestructureTupleInst *elts = nullptr;
      if (val->getType().is<TupleType>())
        elts = builder.createDestructureTuple(ti->getLoc(), val);
      unsigned adjIdx = 0;
      for (auto i : range(ti->getNumOperands())) {
        if (!getTangentSpace(ti->getOperand(i)->getType().getASTType()))
          continue;
        SILValue adjElt = elts ? elts->getResult(adjIdx++) : val;
        addAdjointValue(bb, ti->getOperand(i),
                        makeConcreteAdjointValue(adjElt), ti->getLoc());
      }
      break;
    }
    case AdjointValueKind::Aggregate: {
      unsigned adjIdx = 0;
      for (auto i : range(ti->getElements().size())) {
        if (!getTangentSpace(ti->getElement(i)->getType().getASTType()))
          continue;
        addAdjointValue(bb, ti->getElement(i), av.getAggregateElement(adjIdx++),
                        ti->getLoc());
      }
      break;
    }
    }
  }

  /// Handle `tuple_extract` instruction.
  ///   Original: y = tuple_extract x, <n>
  ///    Adjoint: adj[x] += tuple (0, 0, ..., adj[y], ..., 0, 0)
  ///                                         ^~~~~~
  ///                            n'-th element, where n' is tuple tangent space
  ///                            index corresponding to n
  /// Elements of the original tuple without a tangent space are skipped. If
  /// the tangent of the operand is not a tuple, or only one element remains,
  /// the adjoint is accumulated without wrapping it in an aggregate.
  void visitTupleExtractInst(TupleExtractInst *tei) {
    auto *bb = tei->getParent();
    auto tupleTanTy = getRemappedTangentType(tei->getOperand()->getType());
    auto av = getAdjointValue(bb, tei);
    switch (av.getKind()) {
    case AdjointValueKind::Zero:
      addAdjointValue(bb, tei->getOperand(), makeZeroAdjointValue(tupleTanTy),
                      tei->getLoc());
      break;
    case AdjointValueKind::Aggregate:
    case AdjointValueKind::Concrete: {
      auto tupleTy = tei->getTupleType();
      auto tupleTanTupleTy = tupleTanTy.getAs<TupleType>();
      if (!tupleTanTupleTy) {
        addAdjointValue(bb, tei->getOperand(), av, tei->getLoc());
        break;
      }
      SmallVector<AdjointValue, 8> elements;
      for (unsigned i : range(tupleTy->getNumElements())) {
        auto origEltTy = tupleTy->getElement(i).getType()->getCanonicalType();
        if (!getTangentSpace(origEltTy))
          continue;
        if (tei->getFieldNo() == i) {
          elements.push_back(av);
          continue;
        }
        // Derive each zero's type from original element `i`. A separate
        // tangent-element counter must advance for the extracted field too;
        // otherwise every zero after it gets its predecessor's type.
        elements.push_back(makeZeroAdjointValue(getRemappedTangentType(
            SILType::getPrimitiveObjectType(origEltTy))));
      }
      if (elements.size() == 1) {
        addAdjointValue(bb, tei->getOperand(), elements.front(), tei->getLoc());
        break;
      }
      addAdjointValue(bb, tei->getOperand(),
          makeAggregateAdjointValue(tupleTanTy, elements), tei->getLoc());
      break;
    }
    }
  }

  /// Handle `destructure_tuple` instruction.
  ///   Original: (y0, ..., yn) = destructure_tuple x
  ///    Adjoint: adj[x].0 += adj[y0]
  ///             ...
  ///             adj[x].n += adj[yn]
  /// Results without a tangent space are left out of the aggregate.
  void visitDestructureTupleInst(DestructureTupleInst *dti) {
    auto *bb = dti->getParent();
    auto tupleTanTy = getRemappedTangentType(dti->getOperand()->getType());
    SmallVector<AdjointValue, 8> eltAdjoints;
    for (auto result : dti->getResults())
      if (getTangentSpace(result->getType().getASTType()))
        eltAdjoints.push_back(getAdjointValue(bb, result));
    auto tupleAdjoint = makeAggregateAdjointValue(tupleTanTy, eltAdjoints);
    addAdjointValue(bb, dti->getOperand(), tupleAdjoint, dti->getLoc());
  }

  /// Handle `load` or `load_borrow` instruction
  ///   Original: y = load/load_borrow x
  ///    Adjoint: adj[x] += adj[y]
  /// The materialized adjoint of `y` is copied into a temporary stack buffer,
  /// which is accumulated into the adjoint buffer of `x` and then released.
  void visitLoadOperation(SingleValueInstruction *inst) {
    assert(isa<LoadInst>(inst) || isa<LoadBorrowInst>(inst));
    auto loc = inst->getLoc();
    auto *bb = inst->getParent();
    SILValue loadedAdj =
        materializeAdjointDirect(getAdjointValue(bb, inst), loc);
    auto *tmpBuf = builder.createAllocStack(loc, loadedAdj->getType());
    auto loadedAdjCopy = builder.emitCopyValueOperation(loc, loadedAdj);
    builder.emitStoreValueOperation(loc, loadedAdjCopy, tmpBuf,
                                    StoreOwnershipQualifier::Init);
    addToAdjointBuffer(bb, inst->getOperand(0), tmpBuf, loc);
    builder.emitDestroyAddr(loc, tmpBuf);
    builder.createDeallocStack(loc, tmpBuf);
  }
  void visitLoadInst(LoadInst *li) { visitLoadOperation(li); }
  void visitLoadBorrowInst(LoadBorrowInst *lbi) { visitLoadOperation(lbi); }

  /// Handle `store` or `store_borrow` instruction.
  ///   Original: store/store_borrow x to y
  ///    Adjoint: adj[x] += load adj[y]; adj[y] = 0
  /// The destination's adjoint is moved out (`load [take]`), credited to the
  /// stored value, and the destination's adjoint buffer is re-zeroed.
  void visitStoreOperation(SILBasicBlock *bb, SILLocation loc,
                           SILValue origSrc, SILValue origDest) {
    auto &destAdjBuf = getAdjointBuffer(bb, origDest);
    auto destAdjTy = remapType(destAdjBuf->getType());
    auto takenAdj = builder.emitLoadValueOperation(
        loc, destAdjBuf, LoadOwnershipQualifier::Take);
    recordTemporary(takenAdj);
    auto srcAdj = makeConcreteAdjointValue(takenAdj);
    addAdjointValue(bb, origSrc, srcAdj, loc);
    emitZeroIndirect(destAdjTy.getASTType(), destAdjBuf, loc);
  }
  void visitStoreInst(StoreInst *si) {
    auto *bb = si->getParent();
    visitStoreOperation(bb, si->getLoc(), si->getSrc(), si->getDest());
  }
  void visitStoreBorrowInst(StoreBorrowInst *sbi) {
    auto *bb = sbi->getParent();
    visitStoreOperation(bb, sbi->getLoc(), sbi->getSrc(), sbi->getDest());
  }

  /// Handle `copy_addr` instruction.
  ///   Original: copy_addr x to y
  ///    Adjoint: adj[x] += adj[y]; adj[y] = 0
  void visitCopyAddrInst(CopyAddrInst *cai) {
    auto *bb = cai->getParent();
    auto loc = cai->getLoc();
    auto &destAdjBuf = getAdjointBuffer(bb, cai->getDest());
    auto destAdjTy = remapType(destAdjBuf->getType());
    // Credit the destination's adjoint to the source ...
    addToAdjointBuffer(bb, cai->getSrc(), destAdjBuf, loc);
    // ... then reset the destination's adjoint buffer to zero.
    builder.emitDestroyAddrAndFold(loc, destAdjBuf);
    emitZeroIndirect(destAdjTy.getASTType(), destAdjBuf, loc);
  }

  /// Handle `copy_value` instruction.
  ///   Original: y = copy_value x
  ///    Adjoint: adj[x] += adj[y]
  void visitCopyValueInst(CopyValueInst *cvi) {
    auto *bb = cvi->getParent();
    // The adjoint of the copy flows unchanged to the copied operand.
    auto copyAdj = getAdjointValue(bb, cvi);
    auto operand = cvi->getOperand();
    addAdjointValue(bb, operand, copyAdj, cvi->getLoc());
  }

  /// Handle `begin_borrow` instruction.
  ///   Original: y = begin_borrow x
  ///    Adjoint: adj[x] += adj[y]
  void visitBeginBorrowInst(BeginBorrowInst *bbi) {
    auto *bb = bbi->getParent();
    // The adjoint of the borrow flows unchanged to the borrowed operand.
    auto borrowAdj = getAdjointValue(bb, bbi);
    auto operand = bbi->getOperand();
    addAdjointValue(bb, operand, borrowAdj, bbi->getLoc());
  }

  /// Handle `begin_access` instruction.
  ///   Original: y = begin_access x
  ///    Adjoint: nothing
  /// Emits a non-differentiability error for `[modify]` accesses whose source
  /// is a global variable (`global_addr`) or a boxed capture
  /// (`project_box`).
  void visitBeginAccessInst(BeginAccessInst *bai) {
    // Only modifying accesses can be non-differentiable writes.
    if (bai->getAccessKind() != SILAccessKind::Modify)
      return;
    auto source = bai->getSource();
    // `isa` instead of binding unused `dyn_cast` results.
    if (isa<GlobalAddrInst>(source)) {
      getContext().emitNondifferentiabilityError(bai, getInvoker(),
          diag::autodiff_cannot_differentiate_writes_to_global_variables);
      errorOccurred = true;
      return;
    }
    if (isa<ProjectBoxInst>(source)) {
      getContext().emitNondifferentiabilityError(bai, getInvoker(),
          diag::autodiff_cannot_differentiate_writes_to_mutable_captures);
      errorOccurred = true;
      return;
    }
  }

  /// Handle `unconditional_checked_cast_addr` instruction.
  ///   Original: y = unconditional_checked_cast_addr x
  ///    Adjoint: adj[x] += unconditional_checked_cast_addr adj[y]
  /// The destination's adjoint is cast into a temporary buffer of the source
  /// adjoint's type and accumulated into the source's adjoint buffer. The
  /// destination's adjoint buffer is then re-zeroed.
  void visitUnconditionalCheckedCastAddrInst(
      UnconditionalCheckedCastAddrInst *uccai) {
    auto *bb = uccai->getParent();
    auto loc = uccai->getLoc();
    // Copy the buffers out of `getAdjointBuffer` instead of holding
    // references. If the second lookup inserts into the storage backing
    // those references, the first reference could be invalidated. Copying is
    // safe either way.
    auto adjDest = getAdjointBuffer(bb, uccai->getDest());
    auto adjSrc = getAdjointBuffer(bb, uccai->getSrc());
    auto destType = remapType(adjDest->getType());
    auto castBuf = builder.createAllocStack(loc, adjSrc->getType());
    builder.createUnconditionalCheckedCastAddr(
        loc, adjDest, adjDest->getType().getASTType(), castBuf,
        adjSrc->getType().getASTType());
    addToAdjointBuffer(bb, uccai->getSrc(), castBuf, loc);
    builder.emitDestroyAddrAndFold(loc, castBuf);
    builder.createDeallocStack(loc, castBuf);
    emitZeroIndirect(destType.getASTType(), adjDest, loc);
  }

// Defines a visitor that rejects INST with the diagnostic DIAG and marks
// pullback emission as failed.
#define NOT_DIFFERENTIABLE(INST, DIAG)                                         \
  void visit##INST##Inst(INST##Inst *inst) {                                   \
    getContext().emitNondifferentiabilityError(inst, getInvoker(),             \
                                               diag::DIAG);                    \
    errorOccurred = true;                                                      \
  }
  NOT_DIFFERENTIABLE(RefElementAddr, autodiff_class_property_not_supported)
#undef NOT_DIFFERENTIABLE

#define NO_ADJOINT(INST) \
  void visit##INST##Inst(INST##Inst *inst) {}
  // Terminators.
  NO_ADJOINT(Return)
  NO_ADJOINT(Branch)
  NO_ADJOINT(CondBranch)

  // Buffer projection.
  NO_ADJOINT(StructElementAddr)
  NO_ADJOINT(TupleElementAddr)

  // Memory allocation/access.
  NO_ADJOINT(AllocStack)
  NO_ADJOINT(DeallocStack)
  NO_ADJOINT(EndAccess)

  // Debugging/reference counting instructions.
  NO_ADJOINT(DebugValue)
  NO_ADJOINT(DebugValueAddr)
  NO_ADJOINT(RetainValue)
  NO_ADJOINT(RetainValueAddr)
  NO_ADJOINT(ReleaseValue)
  NO_ADJOINT(ReleaseValueAddr)
  NO_ADJOINT(StrongRetain)
  NO_ADJOINT(StrongRelease)
  NO_ADJOINT(UnownedRetain)
  NO_ADJOINT(UnownedRelease)
  NO_ADJOINT(StrongRetainUnowned)
  NO_ADJOINT(DestroyValue)
  NO_ADJOINT(DestroyAddr)

  // Value ownership.
  NO_ADJOINT(EndBorrow)
#undef NO_DERIVATIVE
};
} // end anonymous namespace

/// Returns a symbolic zero adjoint of the remapped `type`.
AdjointValue PullbackEmitter::makeZeroAdjointValue(SILType type) {
  auto zeroType = remapType(type);
  return AdjointValue::createZero(allocator, zeroType);
}

/// Wraps the existing SIL value `value` as a concrete adjoint.
AdjointValue PullbackEmitter::makeConcreteAdjointValue(SILValue value) {
  auto concrete = AdjointValue::createConcrete(allocator, value);
  return concrete;
}

/// Returns a symbolic aggregate adjoint of the remapped `type` whose elements
/// are `elements`, in order.
template <typename EltRange>
AdjointValue
PullbackEmitter::makeAggregateAdjointValue(SILType type, EltRange elements) {
  auto aggregateType = remapType(type);
  return AdjointValue::createAggregate(allocator, aggregateType, elements);
}

/// Materializes the object adjoint `val` as a single SIL value at the current
/// insertion point.
///  - Zero: emits a zero of the adjoint's type.
///  - Aggregate: recursively materializes each element, copies it, and
///    rebuilds the `tuple` (tuple types) or `struct` (otherwise).
///  - Concrete: returns the underlying value unchanged.
/// Values created for the zero and aggregate cases are passed through
/// `recordTemporary`.
SILValue PullbackEmitter::materializeAdjointDirect(
    AdjointValue val, SILLocation loc) {
  assert(val.getType().isObject());
  LLVM_DEBUG(getADDebugStream() <<
             "Materializing adjoints for " << val << '\n');
  switch (val.getKind()) {
  case AdjointValueKind::Zero:
    return recordTemporary(emitZeroDirect(val.getType().getASTType(), loc));
  case AdjointValueKind::Aggregate: {
    SmallVector<SILValue, 8> elements;
    for (auto i : range(val.getNumAggregateElements())) {
      auto eltVal = materializeAdjointDirect(val.getAggregateElement(i), loc);
      elements.push_back(builder.emitCopyValueOperation(loc, eltVal));
    }
    if (val.getType().is<TupleType>())
      return recordTemporary(
          builder.createTuple(loc, val.getType(), elements));
    return recordTemporary(
        builder.createStruct(loc, val.getType(), elements));
  }
  case AdjointValueKind::Concrete:
    return val.getConcreteValue();
  }
  // Fully covered switch: silences "control reaches end of non-void
  // function" (GCC) and traps on a corrupted kind.
  llvm_unreachable("invalid adjoint value kind");
}

/// Returns `val` as a SIL value. Concrete adjoints are returned directly;
/// symbolic ones are materialized via `materializeAdjointDirect`.
SILValue PullbackEmitter::materializeAdjoint(AdjointValue val,
                                             SILLocation loc) {
  if (!val.isConcrete()) {
    LLVM_DEBUG(getADDebugStream() << "Materializing adjoint: Value is "
                                     "non-concrete. Materializing directly.\n");
    return materializeAdjointDirect(val, loc);
  }
  LLVM_DEBUG(getADDebugStream()
      << "Materializing adjoint: Value is concrete.\n");
  return val.getConcreteValue();
}

/// Materializes the symbolic adjoint `val` into the buffer at
/// `destBufferAccess`.
///
/// - Zero: writes a zero of the adjoint's Swift type via `emitZeroIndirect`
///   (a call to `AdditiveArithmetic.zero`, a builtin float literal, or an
///   element-wise zero for aggregates).
/// - Aggregate: for a tuple type `*(T0, T1, ...)`, each element buffer is
///   projected with `tuple_element_addr`; for a struct type, each stored
///   property is projected with `struct_element_addr`. The matching aggregate
///   element is then materialized recursively into that projection.
/// - Concrete: the value is already materialized and is stored into the
///   buffer with `[init]` ownership.
void PullbackEmitter::materializeAdjointIndirect(
    AdjointValue val, SILValue destBufferAccess, SILLocation loc) {
  switch (val.getKind()) {
  case AdjointValueKind::Zero: {
    emitZeroIndirect(val.getSwiftType(), destBufferAccess, loc);
    return;
  }
  case AdjointValueKind::Aggregate: {
    auto swiftTy = val.getSwiftType();
    if (auto *tupleTy = swiftTy->getAs<TupleType>()) {
      for (auto eltIdx : range(val.getNumAggregateElements())) {
        auto eltAddrTy = SILType::getPrimitiveAddressType(
            tupleTy->getElementType(eltIdx)->getCanonicalType());
        auto *eltAddr = builder.createTupleElementAddr(
            loc, destBufferAccess, eltIdx, eltAddrTy);
        materializeAdjointIndirect(val.getAggregateElement(eltIdx), eltAddr,
                                   loc);
      }
      return;
    }
    if (auto *structDecl = swiftTy->getStructOrBoundGenericStruct()) {
      unsigned fieldIdx = 0;
      for (auto *field : structDecl->getStoredProperties()) {
        auto *fieldAddr =
            builder.createStructElementAddr(loc, destBufferAccess, field);
        materializeAdjointIndirect(val.getAggregateElement(fieldIdx++),
                                   fieldAddr, loc);
      }
      return;
    }
    llvm_unreachable("Not an aggregate type");
  }
  case AdjointValueKind::Concrete: {
    builder.emitStoreValueOperation(loc, val.getConcreteValue(),
                                    destBufferAccess,
                                    StoreOwnershipQualifier::Init);
    return;
  }
  }
}

/// Writes a zero of type `type` into the buffer at `bufferAccess`.
///
/// The behavior depends on the type's tangent space:
/// - Vector: delegates to `emitZeroIntoBuffer`.
/// - Tuple: recurses into each element buffer projected with
///   `tuple_element_addr`.
/// - Function: not yet supported.
void PullbackEmitter::emitZeroIndirect(CanType type, SILValue bufferAccess,
                                       SILLocation loc) {
  auto tangentSpace = getTangentSpace(type);
  assert(tangentSpace && "No tangent space for this type");
  switch (tangentSpace->getKind()) {
  case VectorSpace::Kind::Vector:
    emitZeroIntoBuffer(builder, type, bufferAccess, loc);
    return;
  case VectorSpace::Kind::Tuple: {
    auto tupleType = tangentSpace->getTuple();
    for (unsigned i : range(tupleType->getNumElements())) {
      auto eltAddr = builder.createTupleElementAddr(loc, bufferAccess, i);
      emitZeroIndirect(tupleType->getElementType(i)->getCanonicalType(),
                       eltAddr, loc);
    }
    return;
  }
  case VectorSpace::Kind::Function: {
    llvm_unreachable(
      "Unimplemented: Emit thunks for abstracting zero initialization");
  }
  }
}

/// Produces a zero of type `type` as an object value. The zero is first
/// written into a temporary `alloc_stack` buffer by `emitZeroIndirect`, then
/// moved out with a `[take]` load, after which the buffer is deallocated.
SILValue PullbackEmitter::emitZeroDirect(CanType type, SILLocation loc) {
  auto &silModule = getModule();
  auto loweredTy = silModule.Types.getLoweredLoadableType(
      type, ResilienceExpansion::Minimal, silModule);
  auto *tmpBuffer = builder.createAllocStack(loc, loweredTy);
  emitZeroIndirect(type, tmpBuffer, loc);
  SILValue zero = builder.emitLoadValueOperation(
      loc, tmpBuffer, LoadOwnershipQualifier::Take);
  builder.createDeallocStack(loc, tmpBuffer);
  return zero;
}

/// Symbolically adds two adjoint values. SIL is emitted only where concrete
/// values actually have to be combined.
///
/// Rules (`x`, `y`, ... concrete; `0` zero; `(...)` aggregate):
///   x + y           => x + y, emitted via `accumulateDirect`
///   x + 0, 0 + x    => x
///   (x, y) + 0      => (x, y)
///   x + (y, z)      => (x.0 + y, x.1 + z), destructuring a copy of `x`
///   (x, y) + z      => (z.0 + x, z.1 + y), by commuting to the rule above
///   (x, y) + (z, w) => (x + z, y + w)
AdjointValue
PullbackEmitter::accumulateAdjointsDirect(AdjointValue lhs, AdjointValue rhs,
                                          SILLocation loc) {
  LLVM_DEBUG(getADDebugStream()
             << "Materializing adjoint directly.\nLHS: " << lhs
             << "\nRHS: " << rhs << '\n');

  switch (lhs.getKind()) {
  // x
  case AdjointValueKind::Concrete: {
    auto lhsVal = lhs.getConcreteValue();
    switch (rhs.getKind()) {
    // x + y
    case AdjointValueKind::Concrete: {
      auto rhsVal = rhs.getConcreteValue();
      auto sum = recordTemporary(accumulateDirect(lhsVal, rhsVal, loc));
      return makeConcreteAdjointValue(sum);
    }
    // x + 0 => x
    case AdjointValueKind::Zero:
      return lhs;
    // x + (y, z) => (x.0 + y, x.1 + z)
    case AdjointValueKind::Aggregate: {
      SmallVector<AdjointValue, 8> newElements;
      auto lhsTy = lhsVal->getType().getASTType();
      // Destructure a copy of `x` so that `x` itself is left intact; the
      // resulting elements are recorded as temporaries.
      auto lhsValCopy = builder.emitCopyValueOperation(loc, lhsVal);
      if (lhsTy->is<TupleType>()) {
        auto elts = builder.createDestructureTuple(loc, lhsValCopy);
        llvm::for_each(elts->getResults(),
                       [this](SILValue result) { recordTemporary(result); });
        for (auto i : indices(elts->getResults())) {
          auto rhsElt = rhs.getAggregateElement(i);
          newElements.push_back(accumulateAdjointsDirect(
              makeConcreteAdjointValue(elts->getResult(i)), rhsElt, loc));
        }
      } else if (lhsTy->getStructOrBoundGenericStruct()) {
        auto elts = builder.createDestructureStruct(loc, lhsValCopy);
        llvm::for_each(elts->getResults(),
                       [this](SILValue result) { recordTemporary(result); });
        for (unsigned i : indices(elts->getResults())) {
          auto rhsElt = rhs.getAggregateElement(i);
          newElements.push_back(accumulateAdjointsDirect(
              makeConcreteAdjointValue(elts->getResult(i)), rhsElt, loc));
        }
      } else {
        llvm_unreachable("Not an aggregate type");
      }
      return makeAggregateAdjointValue(lhsVal->getType(), newElements);
    }
    }
    llvm_unreachable("invalid adjoint value kind");
  }
  // 0
  case AdjointValueKind::Zero:
    // 0 + x => x
    return rhs;
  // (x, y)
  case AdjointValueKind::Aggregate: {
    switch (rhs.getKind()) {
    // (x, y) + z => (z.0 + x, z.1 + y)
    case AdjointValueKind::Concrete:
      // Adjoint addition is commutative, so reuse the `x + (y, z)` rule,
      // which destructures the concrete operand. Previously this case
      // returned `lhs` and silently dropped `z`.
      return accumulateAdjointsDirect(rhs, lhs, loc);
    // x + 0 => x
    case AdjointValueKind::Zero:
      return lhs;
    // (x, y) + (z, w) => (x + z, y + w)
    case AdjointValueKind::Aggregate: {
      SmallVector<AdjointValue, 8> newElements;
      for (auto i : range(lhs.getNumAggregateElements()))
        newElements.push_back(
            accumulateAdjointsDirect(lhs.getAggregateElement(i),
                                     rhs.getAggregateElement(i),
                                     loc));
      return makeAggregateAdjointValue(lhs.getType(), newElements);
    }
    }
    llvm_unreachable("invalid adjoint value kind");
  }
  }
  llvm_unreachable("invalid adjoint value kind");
}

/// Emits SIL computing `lhs + rhs` for two object values of the same adjoint
/// type, and returns the sum.
///
/// `lhs` and `rhs` are not consumed; both are copied before use.
/// - Vector: the copies are stored into temporary stack buffers and summed by
///   `accumulateIndirect` into a third buffer, which is then loaded with
///   `[take]`.
/// - Tuple: the copies are destructured, corresponding elements are summed
///   recursively, and the sums are reassembled with `tuple`.
/// - Function: not yet supported.
SILValue PullbackEmitter::accumulateDirect(SILValue lhs, SILValue rhs,
                                           SILLocation loc) {
  // TODO: Optimize for the case when lhs == rhs.
  LLVM_DEBUG(getADDebugStream() <<
             "Emitting adjoint accumulation for lhs: " << lhs <<
             " and rhs: " << rhs << "\n");
  assert(lhs->getType() == rhs->getType() && "Adjoints must have equal types!");
  assert(lhs->getType().isObject() && rhs->getType().isObject() &&
         "Adjoint types must be both object types!");
  auto adjointTy = lhs->getType();
  auto adjointASTTy = adjointTy.getASTType();
  auto tangentSpace = getTangentSpace(adjointASTTy);
  auto lhsCopy = builder.emitCopyValueOperation(loc, lhs);
  auto rhsCopy = builder.emitCopyValueOperation(loc, rhs);
  assert(tangentSpace && "No tangent space for this type");
  switch (tangentSpace->getKind()) {
  case VectorSpace::Kind::Vector: {
    // Allocate buffers for inputs and output.
    auto *resultBuf = builder.createAllocStack(loc, adjointTy);
    auto *lhsBuf = builder.createAllocStack(loc, adjointTy);
    auto *rhsBuf = builder.createAllocStack(loc, adjointTy);
    // Initialize input buffers.
    builder.emitStoreValueOperation(loc, lhsCopy, lhsBuf,
                                    StoreOwnershipQualifier::Init);
    builder.emitStoreValueOperation(loc, rhsCopy, rhsBuf,
                                    StoreOwnershipQualifier::Init);
    accumulateIndirect(resultBuf, lhsBuf, rhsBuf, loc);
    builder.emitDestroyAddr(loc, lhsBuf);
    builder.emitDestroyAddr(loc, rhsBuf);
    // Deallocate input buffers.
    builder.createDeallocStack(loc, rhsBuf);
    builder.createDeallocStack(loc, lhsBuf);
    auto val = builder.emitLoadValueOperation(
        loc, resultBuf, LoadOwnershipQualifier::Take);
    // Deallocate result buffer.
    builder.createDeallocStack(loc, resultBuf);
    return val;
  }
  case VectorSpace::Kind::Tuple: {
    SmallVector<SILValue, 8> adjElements;
    auto lhsElts = builder.createDestructureTuple(loc, lhsCopy)->getResults();
    auto rhsElts = builder.createDestructureTuple(loc, rhsCopy)->getResults();
    for (auto zipped : llvm::zip(lhsElts, rhsElts)) {
      SILValue lhsElt = std::get<0>(zipped);
      SILValue rhsElt = std::get<1>(zipped);
      adjElements.push_back(accumulateDirect(lhsElt, rhsElt, loc));
      // The recursive call copies its operands instead of consuming them, so
      // the elements produced by destructuring the copies above must be
      // released here; otherwise they would leak.
      builder.emitDestroyValueOperation(loc, lhsElt);
      builder.emitDestroyValueOperation(loc, rhsElt);
    }
    return builder.createTuple(loc, adjointTy, adjElements);
  }
  case VectorSpace::Kind::Function: {
    llvm_unreachable(
        "Unimplemented: Emit thunks for abstracting adjoint accumulation");
  }
  }
  llvm_unreachable("invalid vector space kind");
}

/// Emits an addition on memory: `rhs + lhs` is computed from the values at
/// `rhsBufAccess` and `lhsBufAccess`, and `resultBufAccess` is passed as the
/// result buffer.
///
/// The adjoint type's tangent space decides how:
/// - Vector: applies the `AdditiveArithmetic` `+` witness for the type.
/// - Tuple: recurses element-wise through `tuple_element_addr` projections.
/// - Function: not yet supported.
///
/// \param resultBufAccess buffer passed as the result of `+`.
/// \param lhsBufAccess address of the left operand.
/// \param rhsBufAccess address of the right operand; must have the same
///   address type as `lhsBufAccess`.
void PullbackEmitter::accumulateIndirect(
    SILValue resultBufAccess, SILValue lhsBufAccess, SILValue rhsBufAccess,
    SILLocation loc) {
  // TODO: Optimize for the case when lhs == rhs.
  assert(lhsBufAccess->getType() == rhsBufAccess->getType() &&
         "Adjoint values must have same type!");
  assert(lhsBufAccess->getType().isAddress() &&
         rhsBufAccess->getType().isAddress() &&
         "Adjoint values must both have address types!");
  auto adjointTy = lhsBufAccess->getType();
  auto adjointASTTy = adjointTy.getASTType();
  auto *swiftMod = getModule().getSwiftModule();
  auto tangentSpace = adjointASTTy->getAutoDiffAssociatedTangentSpace(
      LookUpConformanceInModule(swiftMod));
  assert(tangentSpace && "No tangent space for this type");
  switch (tangentSpace->getKind()) {
  case VectorSpace::Kind::Vector: {
    auto *proto = getContext().getAdditiveArithmeticProtocol();
    auto *combinerFuncDecl = getContext().getPlusDecl();
    // Call the combiner function and return.
    // The `AdditiveArithmetic` conformance is looked up in the module that
    // declares the tangent space's nominal type, or in the current Swift
    // module when the tangent space has no nominal type.
    auto adjointParentModule = tangentSpace->getNominal()
        ? tangentSpace->getNominal()->getModuleContext()
        : getModule().getSwiftModule();
    auto confRef = adjointParentModule->lookupConformance(adjointASTTy,
                                                           proto);
    assert(confRef.hasValue() && "Missing conformance to `AdditiveArithmetic`");
    SILDeclRef declRef(combinerFuncDecl, SILDeclRef::Kind::Func);
    auto silFnTy = getContext().getTypeConverter().getConstantType(declRef);
    // %0 = witness_method @+
    auto witnessMethod = builder.createWitnessMethod(loc, adjointASTTy,
                                                     *confRef, declRef,
                                                     silFnTy);
    auto subMap = SubstitutionMap::getProtocolSubstitutions(
        proto, adjointASTTy, *confRef);
    // %1 = metatype $T.Type
    auto metatypeType =
        CanMetatypeType::get(adjointASTTy, MetatypeRepresentation::Thick);
    auto metatypeSILType = SILType::getPrimitiveObjectType(metatypeType);
    auto metatype = builder.createMetatype(loc, metatypeSILType);
    // %2 = apply %0(%result, %new, %old, %1)
    // `rhs` ("new") is passed as the first operand and `lhs` ("old") as the
    // second, i.e. this computes `rhs + lhs`.
    // NOTE(review): relies on `+` being commutative for the tangent type.
    builder.createApply(loc, witnessMethod, subMap,
                        {resultBufAccess, rhsBufAccess, lhsBufAccess, metatype},
                        /*isNonThrowing*/ false);
    builder.emitDestroyValueOperation(loc, witnessMethod);
    return;
  }
  case VectorSpace::Kind::Tuple: {
    // Sum each pair of element buffers into the matching result element.
    auto tupleType = tangentSpace->getTuple();
    for (unsigned i : range(tupleType->getNumElements())) {
      auto *destAddr = builder.createTupleElementAddr(loc, resultBufAccess, i);
      auto *eltAddrLHS = builder.createTupleElementAddr(loc, lhsBufAccess, i);
      auto *eltAddrRHS = builder.createTupleElementAddr(loc, rhsBufAccess, i);
      accumulateIndirect(destAddr, eltAddrLHS, eltAddrRHS, loc);
    }
    return;
  }
  case VectorSpace::Kind::Function: {
    llvm_unreachable(
        "Unimplemented: Emit thunks for abstracting adjoint value "
        "accumulation");
  }
  }
}

/// Emits an in-place accumulation on memory: `lhs += rhs`, where `lhs` is the
/// value at `lhsDestAccess` and `rhs` is the value at `rhsAccess`.
///
/// Both addresses must belong to the pullback function being emitted. The
/// type's tangent space decides how:
/// - Vector: applies the `AdditiveArithmetic` `+=` witness for the type.
/// - Tuple: recurses element-wise through `tuple_element_addr` projections.
/// - Function: not yet supported.
void PullbackEmitter::accumulateIndirect(SILValue lhsDestAccess,
                                         SILValue rhsAccess, SILLocation loc) {
  assert(lhsDestAccess->getType().isAddress() &&
         rhsAccess->getType().isAddress());
  assert(lhsDestAccess->getFunction() == &getPullback());
  assert(rhsAccess->getFunction() == &getPullback());
  auto type = lhsDestAccess->getType();
  auto astType = type.getASTType();
  auto *swiftMod = getModule().getSwiftModule();
  auto tangentSpace = astType->getAutoDiffAssociatedTangentSpace(
      LookUpConformanceInModule(swiftMod));
  assert(tangentSpace && "No tangent space for this type");
  switch (tangentSpace->getKind()) {
  case VectorSpace::Kind::Vector: {
    auto *proto = getContext().getAdditiveArithmeticProtocol();
    auto *accumulatorFuncDecl = getContext().getPlusEqualDecl();
    // Call the accumulator (`+=`) function and return.
    // NOTE(review): the conformance is looked up in the current Swift module,
    // whereas the 3-operand `accumulateIndirect` looks it up in the module of
    // the tangent space's nominal type — confirm this difference is intended.
    auto confRef = swiftMod->lookupConformance(astType, proto);
    assert(confRef.hasValue() && "Missing conformance to `AdditiveArithmetic`");
    SILDeclRef declRef(accumulatorFuncDecl, SILDeclRef::Kind::Func);
    auto silFnTy = getContext().getTypeConverter().getConstantType(declRef);
    // %0 = witness_method @+=
    auto witnessMethod =
        builder.createWitnessMethod(loc, astType, *confRef, declRef, silFnTy);
    auto subMap =
        SubstitutionMap::getProtocolSubstitutions(proto, astType, *confRef);
    // %1 = metatype $T.Type
    auto metatypeType =
        CanMetatypeType::get(astType, MetatypeRepresentation::Thick);
    auto metatypeSILType = SILType::getPrimitiveObjectType(metatypeType);
    auto metatype = builder.createMetatype(loc, metatypeSILType);
    // %2 = apply %0(%lhs, %rhs, %1)
    builder.createApply(loc, witnessMethod, subMap,
                        {lhsDestAccess, rhsAccess, metatype},
                        /*isNonThrowing*/ false);
    builder.emitDestroyValueOperation(loc, witnessMethod);
    return;
  }
  case VectorSpace::Kind::Tuple: {
    // Accumulate each RHS element into the matching LHS element in place.
    auto tupleType = tangentSpace->getTuple();
    for (unsigned i : range(tupleType->getNumElements())) {
      auto *destAddr = builder.createTupleElementAddr(loc, lhsDestAccess, i);
      auto *eltAddrRHS = builder.createTupleElementAddr(loc, rhsAccess, i);
      accumulateIndirect(destAddr, eltAddrRHS, loc);
    }
    return;
  }
  case VectorSpace::Kind::Function: {
    llvm_unreachable(
        "Unimplemented: Emit thunks for abstracting adjoint value "
        "accumulation");
  }
  }
}

bool VJPEmitter::run() {
  LLVM_DEBUG(getADDebugStream()
             << "Cloning original @" << original->getName()
             << " to vjp @" << vjp->getName() << '\n');
  // Create entry BB and arguments.
  auto *entry = vjp->createBasicBlock();
  createEntryArguments(vjp);

  // Clone.
  SmallVector<SILValue, 4> entryArgs(entry->getArguments().begin(),
                                     entry->getArguments().end());
  cloneFunctionBody(original, entry, entryArgs);
  // If errors occurred, back out.
  if (errorOccurred)
    return true;

  // Each `@guaranteed` trampoline argument needs to have a lifetime-ending use
  // past its destination argument's lifetime-ending uses (aka. `end_borrow`).
  // `trampolinedGuaranteedPhiArguments` tracks all `@guaranteed` trampoline
  // arguments. We emit an `end_borrow` immediately past each destination
  // argument's lifetime-ending uses.
  for (auto &trampolinedArgPair : trampolinedGuaranteedPhiArguments) {
    for (auto *destArgUse : trampolinedArgPair.destinationArgument->getUses()) {
      if (auto *lifetimeEnd = dyn_cast<EndBorrowInst>(destArgUse->getUser())) {
        getBuilder().setInsertionPoint(lifetimeEnd->getParentBlock(),
                                       std::next(lifetimeEnd->getIterator()));
        getBuilder().emitEndBorrowOperation(
            lifetimeEnd->getLoc(), trampolinedArgPair.trampolineArgument);
      }
    }
  }

  // Generate pullback code.
  PullbackEmitter PullbackEmitter(*this);
  if (PullbackEmitter.run()) {
    errorOccurred = true;
    return true;
  }
  LLVM_DEBUG(getADDebugStream() << "Generated VJP for "
                                << original->getName() << ":\n" << *vjp);
  return errorOccurred;
}

//===----------------------------------------------------------------------===//
// `[differentiable]` attribute processing
//===----------------------------------------------------------------------===//

/// Declares a body-less derivative function of kind `kind`, named `name`, for
/// `original`.
///
/// The function type is derived from `original`'s lowered type using the
/// parameter/result indices of `attr` and its derivative generic signature.
/// The declaration has `PublicExternal` linkage, no generic environment, and
/// is not transparent. It copies bareness, serialization and dynamic
/// replaceability from `original`.
SILFunction *
ADContext::declareExternalDerivativeFunction(
    SILFunction *original, SILDifferentiableAttr *attr, StringRef name,
    AutoDiffDerivativeFunctionKind kind) {
  auto &silModule = getModule();
  const auto &attrIndices = attr->getIndices();
  auto originalFnTy = original->getLoweredFunctionType();
  auto declLoc = original->getLocation();
  auto derivativeGenSig = getDerivativeGenericSignature(attr, original);
  auto derivativeFnTy = originalFnTy->getAutoDiffDerivativeFunctionType(
      attrIndices.parameters, attrIndices.source, kind, silModule.Types,
      LookUpConformanceInModule(silModule.getSwiftModule()), derivativeGenSig);

  // Create the external function declaration.
  SILOptFunctionBuilder functionBuilder(getTransform());
  auto *derivativeFn = functionBuilder.createFunction(
      SILLinkage::PublicExternal, name, derivativeFnTy,
      /*genericEnv*/ nullptr, declLoc, original->isBare(), IsNotTransparent,
      original->isSerialized(), original->isDynamicallyReplaceable());
  // Note: Setting debug scope prevents crashes during later transforms.
  auto *debugScope = new (silModule) SILDebugScope(declLoc, derivativeFn);
  derivativeFn->setDebugScope(debugScope);
  return derivativeFn;
}

/// Shared implementation of `createEmptyVJP` and `createEmptyJVP`. Creates a
/// derivative function of kind `kind` for `original` with no body, and records
/// its name on `attr`.
///
/// Properties of the new function:
/// - Name: `original`'s name mangled with `kind` and the indices of `attr`.
/// - Type and generic environment: taken from the derivative generic
///   signature of `attr`.
/// - Linkage: `autodiff::getAutoDiffDerivativeFunctionLinkage` applied to
///   `original`'s linkage and `isExported`.
/// - Bareness, serialization and dynamic replaceability: copied from
///   `original`.
static SILFunction *createEmptyDerivativeFunctionForAttr(
    ADContext &context, SILFunction *original, SILDifferentiableAttr *attr,
    bool isExported, AutoDiffDerivativeFunctionKind kind) {
  StringRef kindName;
  switch (kind) {
  case AutoDiffDerivativeFunctionKind::JVP:
    kindName = "JVP";
    break;
  case AutoDiffDerivativeFunctionKind::VJP:
    kindName = "VJP";
    break;
  }

  LLVM_DEBUG({
    auto &s = getADDebugStream();
    s << "Creating " << kindName << ":\n\t";
    s << "Original type: " << original->getLoweredFunctionType() << "\n\t";
  });

  auto &module = context.getModule();
  auto originalTy = original->getLoweredFunctionType();
  auto attrIndices = attr->getIndices();

  // === Create an empty derivative function. ===
  Mangle::ASTMangler mangler;
  auto name = original->getASTContext()
                  .getIdentifier(mangler.mangleAutoDiffDerivativeFunctionHelper(
                      original->getName(), kind, attrIndices))
                  .str();
  auto genericSig = getDerivativeGenericSignature(attr, original);

  // RAII that pushes the derivative generic signature to `module.Types` so
  // that calls to `module.Types.getTypeLowering()` below will know the
  // derivative's generic parameter types.
  Lowering::GenericContextScope genericContextScope(module.Types, genericSig);

  auto *genericEnv =
      genericSig ? genericSig->getGenericEnvironment() : nullptr;
  auto derivativeType = originalTy->getAutoDiffDerivativeFunctionType(
      attrIndices.parameters, attrIndices.source, kind, module.Types,
      LookUpConformanceInModule(module.getSwiftModule()), genericSig);

  SILOptFunctionBuilder fb(context.getTransform());
  auto linkage = autodiff::getAutoDiffDerivativeFunctionLinkage(
      original->getLinkage(), isExported);
  auto *derivative = fb.createFunction(
      linkage, name, derivativeType, genericEnv, original->getLocation(),
      original->isBare(), IsNotTransparent, original->isSerialized(),
      original->isDynamicallyReplaceable());
  derivative->setDebugScope(
      new (module) SILDebugScope(original->getLocation(), derivative));
  switch (kind) {
  case AutoDiffDerivativeFunctionKind::JVP:
    attr->setJVPName(name);
    break;
  case AutoDiffDerivativeFunctionKind::VJP:
    attr->setVJPName(name);
    break;
  }

  LLVM_DEBUG(getADDebugStream() << kindName << " type: "
                                << derivative->getLoweredFunctionType()
                                << "\n");
  return derivative;
}

/// Creates an empty VJP for `original` and records its name on `attr`.
static SILFunction *createEmptyVJP(
    ADContext &context, SILFunction *original, SILDifferentiableAttr *attr,
    bool isExported) {
  return createEmptyDerivativeFunctionForAttr(
      context, original, attr, isExported, AutoDiffDerivativeFunctionKind::VJP);
}

/// Creates an empty JVP for `original` and records its name on `attr`.
static SILFunction *createEmptyJVP(
    ADContext &context, SILFunction *original, SILDifferentiableAttr *attr,
    bool isExported) {
  return createEmptyDerivativeFunctionForAttr(
      context, original, attr, isExported, AutoDiffDerivativeFunctionKind::JVP);
}

/// Processes the `[differentiable]` attribute `attr` on `original`. Resolves
/// the JVP and VJP the attribute refers to, and synthesizes whichever one does
/// not exist.
///
/// Resolving a derivative:
/// - A lookup happens only when `attr` specifies the derivative's name, or
///   when `original` is an external declaration (the mangled default name is
///   then used).
/// - A derivative that cannot be found is declared as an external function.
/// - The chosen name is recorded on `attr`.
///
/// Synthesizing a missing JVP:
/// - `JVPEmitter` generates it only when experimental forward-mode
///   differentiation is enabled and no VJP was resolved.
/// - Otherwise the JVP gets a body that destroys its owned arguments, calls
///   `_printJVPErrorAndExit`, and ends in `unreachable`.
///
/// A missing VJP is generated by `VJPEmitter`.
///
/// \returns true on error.
bool ADContext::processDifferentiableAttribute(
    SILFunction *original, SILDifferentiableAttr *attr,
    DifferentiationInvoker invoker) {
  auto &module = getModule();

  // Returns the mangled default name of `original`'s derivative of the given
  // kind.
  auto getDefaultDerivativeName =
      [&](AutoDiffDerivativeFunctionKind kind) -> StringRef {
    Mangle::ASTMangler mangler;
    return original->getASTContext()
        .getIdentifier(mangler.mangleAutoDiffDerivativeFunctionHelper(
            original->getName(), kind, attr->getIndices()))
        .str();
  };

  // Returns the function named `name`. If the module has no such function,
  // declares an external derivative function of the given kind.
  auto lookUpOrDeclareDerivative =
      [&](StringRef name, AutoDiffDerivativeFunctionKind kind) -> SILFunction * {
    if (auto *fn = module.lookUpFunction(name))
      return fn;
    return declareExternalDerivativeFunction(original, attr, name, kind);
  };

  // Try to look up JVP only if attribute specifies JVP name or if original
  // function is an external declaration. If JVP function cannot be found,
  // create an external JVP reference.
  StringRef jvpName;
  SILFunction *jvp = nullptr;
  if (attr->hasJVP())
    jvpName = attr->getJVPName();
  else if (original->isExternalDeclaration())
    jvpName = getDefaultDerivativeName(AutoDiffDerivativeFunctionKind::JVP);
  if (!jvpName.empty()) {
    jvp = lookUpOrDeclareDerivative(jvpName,
                                    AutoDiffDerivativeFunctionKind::JVP);
    attr->setJVPName(jvpName);
  }

  // If differentiation is triggered by `[differentiable]`, derivative function
  // should share linkage of original function.
  auto isDerivativeFnExported =
      invoker.getKind() ==
          DifferentiationInvoker::Kind::SILDifferentiableAttribute;

  // Try to look up VJP only if attribute specifies VJP name or if original
  // function is an external declaration. If VJP function cannot be found,
  // create an external VJP reference.
  StringRef vjpName;
  SILFunction *vjp = nullptr;
  if (attr->hasVJP())
    vjpName = attr->getVJPName();
  else if (original->isExternalDeclaration())
    vjpName = getDefaultDerivativeName(AutoDiffDerivativeFunctionKind::VJP);
  if (!vjpName.empty()) {
    vjp = lookUpOrDeclareDerivative(vjpName,
                                    AutoDiffDerivativeFunctionKind::VJP);
    attr->setVJPName(vjpName);
  }

  // If the JVP doesn't exist, need to synthesize it.
  if (!jvp) {
    // Diagnose:
    // - Functions with no return.
    // - Functions with unsupported control flow.
    if (getASTContext().LangOpts.EnableExperimentalForwardModeDifferentiation &&
        (diagnoseNoReturn(*this, original, invoker) ||
         diagnoseUnsupportedControlFlow(*this, original, invoker)))
      return true;

    jvp = createEmptyJVP(*this, original, attr, isDerivativeFnExported);
    getGeneratedFunctions().push_back(jvp);

    // For now, only do JVP generation if the flag is enabled and if custom VJP
    // does not exist. If custom VJP exists but custom JVP does not, skip JVP
    // generation because generated JVP may not match semantics of custom VJP.
    // Instead, create an empty JVP.
    if (getASTContext().LangOpts.EnableExperimentalForwardModeDifferentiation &&
        !vjp) {
      // JVP and differential generation do not currently support functions with
      // multiple basic blocks.
      if (original->getBlocks().size() > 1) {
        emitNondifferentiabilityError(
            original->getLocation().getSourceLoc(), invoker,
            diag::autodiff_jvp_control_flow_not_supported);
        return true;
      }

      JVPEmitter emitter(*this, original, attr, jvp, invoker);
      if (emitter.run())
        return true;
    } else {
      LLVM_DEBUG(getADDebugStream()
                 << "Generating empty JVP for original @"
                 << original->getName() << '\n');
      // Create empty JVP body since custom VJP exists.
      auto *entry = jvp->createBasicBlock();
      createEntryArguments(jvp);
      SILBuilder builder(entry);
      auto loc = jvp->getLocation();

      // Destroy all owned arguments.
      for (auto *arg : entry->getArguments())
        if (arg->getOwnershipKind() == ValueOwnershipKind::Owned)
          builder.emitDestroyOperation(loc, arg);

      // Fatal error in case this JVP is called by the user.
      auto neverResultInfo = SILResultInfo(
          module.getASTContext().getNeverType(), ResultConvention::Unowned);
      auto fatalErrorJVPType = SILFunctionType::get(
          /*genericSig*/ nullptr,
          SILFunctionType::ExtInfo().withRepresentation(
              SILFunctionTypeRepresentation::Thin),
          SILCoroutineKind::None, ParameterConvention::Direct_Unowned, {},
          /*interfaceYields*/ {}, neverResultInfo,
          /*interfaceErrorResults*/ None, getASTContext());
      auto fnBuilder = SILOptFunctionBuilder(getTransform());
      auto *fatalErrorJvpFunc = fnBuilder.getOrCreateFunction(
          loc, "_printJVPErrorAndExit", SILLinkage::PublicExternal,
          fatalErrorJVPType, IsNotBare, IsNotTransparent, IsNotSerialized,
          IsNotDynamic, ProfileCounter(), IsNotThunk);
      auto *jvpErrorFuncRef =
          builder.createFunctionRef(loc, fatalErrorJvpFunc);
      builder.createApply(loc, jvpErrorFuncRef, SubstitutionMap(), {});
      builder.createUnreachable(loc);
      LLVM_DEBUG(getADDebugStream() << "Generated empty JVP for "
                 << original->getName() << ":\n" << *jvp);
    }
  }

  // If the VJP doesn't exist, need to synthesize it.
  if (!vjp) {
    // Diagnose:
    // - Functions with no return.
    // - Functions with unsupported control flow.
    if (diagnoseNoReturn(*this, original, invoker) ||
        diagnoseUnsupportedControlFlow(*this, original, invoker))
      return true;

    vjp = createEmptyVJP(*this, original, attr, isDerivativeFnExported);
    getGeneratedFunctions().push_back(vjp);
    VJPEmitter emitter(*this, original, attr, vjp, invoker);
    return emitter.run();
  }

  return false;
}

//===----------------------------------------------------------------------===//
// Differentiation pass implementation
//===----------------------------------------------------------------------===//

namespace {
/// The automatic differentiation pass, implemented as a transform over the
/// whole SIL module.
class Differentiation : public SILModuleTransform {
public:
  Differentiation() : SILModuleTransform() {}

  /// Pass entry point; overrides `SILModuleTransform::run`.
  void run() override;
};
} // end anonymous namespace

/// Returns a thunk adapting the linear map `linearMapType` (a differential when
/// `kind` is JVP, a pullback when `kind` is VJP), which is taken with respect
/// to `actualIndices`, to `targetType`, which is taken with respect to the
/// subset `desiredIndices`. Both index sets are relative to the original
/// function.
///
/// - Differential thunks forward the desired tangent arguments and pass a
///   `.zero` value for every actual parameter that is not desired.
/// - Pullback thunks return only the desired results and destroy the others.
///
/// The linear map to call is passed as the thunk's last argument. If a thunk
/// with the same mangled name already has a body, it is returned unchanged.
///
/// \returns The thunk and the substitution map to apply it with.
std::pair<SILFunction *, SubstitutionMap>
ADContext::getOrCreateSubsetParametersThunkForLinearMap(
    SILFunction *parentThunk, CanSILFunctionType linearMapType,
    CanSILFunctionType targetType, AutoDiffDerivativeFunctionKind kind,
    SILAutoDiffIndices desiredIndices, SILAutoDiffIndices actualIndices) {
  LLVM_DEBUG(getADDebugStream()
             << "Getting a subset parameters thunk for " << linearMapType
             << " from " << actualIndices << " to " << desiredIndices << '\n');

  SubstitutionMap interfaceSubs;
  GenericEnvironment *genericEnv = nullptr;
  auto thunkType = buildThunkType(
      parentThunk, linearMapType, targetType, genericEnv, interfaceSubs,
      /*withoutActuallyEscaping*/ true,
      DifferentiationThunkKind::Reabstraction);

  // TODO(TF-685): Use more principled mangling for thunks.
  std::string thunkName;
  switch (kind) {
    case AutoDiffDerivativeFunctionKind::JVP:
      thunkName = "differential";
      break;
    case AutoDiffDerivativeFunctionKind::VJP:
      thunkName = "pullback";
  }
  Mangle::ASTMangler mangler;
  auto fromInterfaceType =
      linearMapType->mapTypeOutOfContext()->getCanonicalType();
  auto toInterfaceType = targetType->mapTypeOutOfContext()->getCanonicalType();
  CanType dynamicSelfType;
  thunkName = "AD__" + mangler.mangleReabstractionThunkHelper(
      thunkType, fromInterfaceType, toInterfaceType, dynamicSelfType,
      module.getSwiftModule()) + "_" + desiredIndices.mangle() + "_" +
      thunkName;
  thunkName += "_index_subset_thunk";

  auto loc = parentThunk->getLocation();
  SILOptFunctionBuilder fb(getTransform());
  auto *thunk = fb.getOrCreateSharedFunction(
      loc, thunkName, thunkType, IsBare, IsTransparent, IsSerialized,
      ProfileCounter(), IsThunk, IsNotDynamic);

  // Reuse an existing thunk body.
  if (!thunk->empty())
    return {thunk, interfaceSubs};

  thunk->setGenericEnvironment(genericEnv);
  thunk->setOwnershipEliminated();
  auto *entry = thunk->createBasicBlock();
  SILBuilder builder(entry);
  createEntryArguments(thunk);

  // Get arguments.
  SmallVector<SILValue, 4> arguments;
  SmallVector<AllocStackInst *, 4> localAllocations;
  // `.zero` arguments passed with a guaranteed convention. The differential
  // does not consume them, so the thunk must destroy them after the call;
  // otherwise non-trivial zero values leak when their buffers are deallocated.
  SmallVector<SILValue, 4> valuesToCleanup;

  // Build a `.zero` argument for the given `Differentiable`-conforming
  // linear map parameter.
  auto buildZeroArgument = [&](SILParameterInfo zeroSILParameter) {
    auto zeroSILType = zeroSILParameter.getSILStorageType();
    auto zeroSILObjType = zeroSILType.getObjectType();
    auto zeroType = zeroSILType.getASTType();
    auto *swiftMod = getModule().getSwiftModule();
    auto tangentSpace = zeroType->getAutoDiffAssociatedTangentSpace(
      LookUpConformanceInModule(swiftMod));
    assert(tangentSpace && "No tangent space for this type");
    switch (tangentSpace->getKind()) {
    case VectorSpace::Kind::Vector: {
      auto *buf = builder.createAllocStack(loc, zeroSILObjType);
      localAllocations.push_back(buf);
      emitZeroIntoBuffer(builder, zeroType, buf, loc);
      if (zeroSILType.isAddress()) {
        arguments.push_back(buf);
        if (zeroSILParameter.isGuaranteed())
          valuesToCleanup.push_back(buf);
      } else {
        auto *arg = builder.createLoad(loc, buf,
                                       LoadOwnershipQualifier::Unqualified);
        arguments.push_back(arg);
        if (zeroSILParameter.isGuaranteed())
          valuesToCleanup.push_back(arg);
      }
      break;
    }
    case VectorSpace::Kind::Tuple: {
      llvm_unreachable(
          "Unimplemented: Handle zero initialization for tuples");
    }
    case VectorSpace::Kind::Function:
      llvm_unreachable(
          "Unimplemented: Emit thunks for abstracting zero initialization");
    }
  };

  // `actualIndices` and `desiredIndices` are with respect to the original
  // function. However, the differential parameters and pullback results may
  // already be w.r.t. a subset. We create a map between the original function's
  // actual parameter indices and the linear map's actual indices.
  // Example:
  //   Original: (T0, T1, T2) -> R
  //   Actual indices: 0, 2
  //   Original differential: (T0, T2) -> R
  //   Original pullback: R -> (T0, T2)
  //   Desired indices w.r.t. original: 2
  //   Desired indices w.r.t. linear map: 1
  SmallVector<unsigned, 4> actualParamIndicesMap(
      actualIndices.parameters->getCapacity(), UINT_MAX);
  {
    unsigned indexInBitVec = 0;
    for (auto index : actualIndices.parameters->getIndices()) {
      actualParamIndicesMap[index] = indexInBitVec;
      indexInBitVec++;
    }
  }
  auto mapOriginalParameterIndex = [&](unsigned index) -> unsigned {
    auto mappedIndex = actualParamIndicesMap[index];
    assert(mappedIndex < actualIndices.parameters->getCapacity());
    return mappedIndex;
  };

  switch (kind) {
  // Differential arguments are:
  // - All indirect results, followed by:
  // - An interleaving of:
  //   - Thunk arguments (when parameter index is in both desired and actual
  //     indices).
  //   - Zeros (when parameter is not in desired indices).
  case AutoDiffDerivativeFunctionKind::JVP: {
    // Forward all indirect results.
    arguments.append(thunk->getIndirectResults().begin(),
                     thunk->getIndirectResults().end());
    auto toArgIter = thunk->getArgumentsWithoutIndirectResults().begin();
    auto useNextArgument = [&]() {
      arguments.push_back(*toArgIter++);
    };
    // Iterate over actual indices.
    for (unsigned i : actualIndices.parameters->getIndices()) {
      // If index is desired, use next argument.
      if (desiredIndices.isWrtParameter(i)) {
        useNextArgument();
      }
      // Otherwise, construct and use a zero argument.
      else {
        buildZeroArgument(
            linearMapType->getParameters()[mapOriginalParameterIndex(i)]);
      }
    }
    break;
  }
  // Pullback arguments are:
  // - An interleaving of:
  //   - Thunk indirect results (when parameter index is in both desired and
  //     actual indices).
  //   - Zeros (when parameter is not in desired indices).
  // - All actual arguments.
  case AutoDiffDerivativeFunctionKind::VJP: {
    auto toIndirectResultsIter = thunk->getIndirectResults().begin();
    auto useNextResult = [&]() {
      arguments.push_back(*toIndirectResultsIter++);
    };
    // Iterate over actual indices.
    for (unsigned i : actualIndices.parameters->getIndices()) {
      auto resultInfo =
          linearMapType->getResults()[mapOriginalParameterIndex(i)];
      // Skip direct results. Only indirect results are relevant as arguments.
      if (resultInfo.isFormalDirect())
        continue;
      // If index is desired, use next indirect result.
      if (desiredIndices.isWrtParameter(i)) {
        useNextResult();
        continue;
      }
      // Otherwise, construct and use an uninitialized indirect result.
      auto *indirectResult =
          builder.createAllocStack(loc, resultInfo.getSILStorageType());
      localAllocations.push_back(indirectResult);
      arguments.push_back(indirectResult);
    }
    // Forward all actual non-indirect-result arguments, except the trailing
    // linear map argument.
    arguments.append(thunk->getArgumentsWithoutIndirectResults().begin(),
                     thunk->getArgumentsWithoutIndirectResults().end() - 1);
    break;
  }
  }

  // Get the linear map thunk argument and apply it.
  auto *linearMap = thunk->getArguments().back();
  auto *ai = builder.createApply(
      loc, linearMap, SubstitutionMap(), arguments, /*isNonThrowing*/ false);

  // If differential thunk, destroy borrowed `.zero` arguments, deallocate
  // local allocations and directly return `apply` result.
  if (kind == AutoDiffDerivativeFunctionKind::JVP) {
    for (auto value : valuesToCleanup) {
      if (value->getType().isAddress())
        builder.emitDestroyAddrAndFold(loc, value);
      else
        builder.emitDestroyValueOperation(loc, value);
    }
    for (auto *alloc : reversed(localAllocations))
      builder.createDeallocStack(loc, alloc);
    builder.createReturn(loc, ai);
    return {thunk, interfaceSubs};
  }

  // If pullback thunk, return only the desired results and clean up the
  // undesired results.
  SmallVector<SILValue, 8> pullbackDirectResults;
  extractAllElements(ai, builder, pullbackDirectResults);
  SmallVector<SILValue, 8> allResults;
  collectAllActualResultsInTypeOrder(ai, pullbackDirectResults, allResults);

  SmallVector<SILValue, 8> results;
  for (unsigned i : actualIndices.parameters->getIndices()) {
    // If result is desired:
    // - Do nothing if result is indirect.
    //   (It was already forwarded to the `apply` instruction).
    // - Push it to `results` if result is direct.
    auto result = allResults[mapOriginalParameterIndex(i)];
    if (desiredIndices.isWrtParameter(i)) {
      if (result->getType().isObject())
        results.push_back(result);
    }
    // Otherwise, cleanup the unused results.
    else {
      if (result->getType().isAddress())
        builder.emitDestroyAddrAndFold(loc, result);
      else
        builder.emitDestroyValueOperation(loc, result);
    }
  }
  // Deallocate local allocations and return final direct result.
  for (auto *alloc : reversed(localAllocations))
    builder.createDeallocStack(loc, alloc);
  auto result = joinElements(results, builder, loc);
  builder.createReturn(loc, result);

  getGeneratedFunctions().push_back(thunk);
  return {thunk, interfaceSubs};
}

/// Returns a thunk that wraps `derivativeFn`, a JVP or VJP (per `kind`) of
/// `origFnOperand` with respect to `actualIndices`, so that it has the
/// derivative function type for the subset `desiredIndices`.
///
/// The thunk re-resolves the derivative function inside its own body, calls
/// it with all thunk arguments, and replaces the returned linear map with a
/// partial application of the matching linear map subset thunk from
/// `getOrCreateSubsetParametersThunkForLinearMap`. If a thunk with the same
/// mangled name already has a body, it is returned unchanged.
///
/// \returns The thunk and the substitution map to apply it with.
std::pair<SILFunction *, SubstitutionMap>
ADContext::getOrCreateSubsetParametersThunkForDerivativeFunction(
    SILValue origFnOperand, SILValue derivativeFn,
    AutoDiffDerivativeFunctionKind kind, SILAutoDiffIndices desiredIndices,
    SILAutoDiffIndices actualIndices) {
  LLVM_DEBUG(getADDebugStream()
             << "Getting a subset parameters thunk for derivative function "
             << derivativeFn << " of the original function " << origFnOperand
             << " from " << actualIndices << " to " << desiredIndices << '\n');

  auto origFnType = origFnOperand->getType().castTo<SILFunctionType>();
  auto &module = getModule();
  auto lookupConformance = LookUpConformanceInModule(module.getSwiftModule());

  // Compute target type for thunking.
  auto derivativeFnType = derivativeFn->getType().castTo<SILFunctionType>();
  auto targetType = origFnType->getAutoDiffDerivativeFunctionType(
      desiredIndices.parameters, desiredIndices.source, kind, module.Types,
      lookupConformance);
  auto *caller = derivativeFn->getFunction();
  // If the target type mentions archetypes, remap it into the generic context
  // of the function containing `derivativeFn`.
  if (targetType->hasArchetype()) {
    auto substTargetType = caller->mapTypeIntoContext(
        targetType->mapTypeOutOfContext())->getCanonicalType();
    targetType = SILType::getPrimitiveObjectType(substTargetType)
        .castTo<SILFunctionType>();
  }
  assert(derivativeFnType->getNumParameters() == targetType->getNumParameters());
  assert(derivativeFnType->getNumResults() == targetType->getNumResults());

  // Build thunk type.
  SubstitutionMap interfaceSubs;
  GenericEnvironment *genericEnv = nullptr;
  auto thunkType = buildThunkType(
      derivativeFn->getFunction(), derivativeFnType, targetType, genericEnv,
      interfaceSubs, /*withoutActuallyEscaping*/ false,
      DifferentiationThunkKind::IndexSubset);

  // Resolve the original function's name for use in the thunk name.
  // NOTE(review): only `function_ref` and `MethodInst` originals (seen through
  // function conversions) are handled; any other original trips the assertion
  // below.
  StringRef origName;
  if (auto *origFnRef =
          peerThroughFunctionConversions<FunctionRefInst>(origFnOperand)) {
    origName = origFnRef->getInitiallyReferencedFunction()->getName();
  } else if (auto *origMethodInst =
                 peerThroughFunctionConversions<MethodInst>(origFnOperand)) {
    origName = origMethodInst->getMember().getAnyFunctionRef()
        ->getAbstractFunctionDecl()->getNameStr();
  }
  assert(!origName.empty() && "Original function name could not be resolved");
  // TODO(TF-685): Use more principled mangling for thunks.
  std::string thunkName;
  switch (kind) {
    case AutoDiffDerivativeFunctionKind::JVP:
      thunkName = "jvp";
      break;
    case AutoDiffDerivativeFunctionKind::VJP:
      thunkName = "vjp";
  }
  Mangle::ASTMangler mangler;
  auto fromInterfaceType =
      derivativeFnType->mapTypeOutOfContext()->getCanonicalType();
  auto toInterfaceType = targetType->mapTypeOutOfContext()->getCanonicalType();
  CanType dynamicSelfType;
  thunkName = "AD__orig_" + origName.str() + "_" +
      mangler.mangleReabstractionThunkHelper(
          thunkType, fromInterfaceType, toInterfaceType, dynamicSelfType,
          module.getSwiftModule()) + "_" + desiredIndices.mangle() + "_" +
          thunkName;
  thunkName += "_subset_parameters_thunk";

  auto loc = origFnOperand.getLoc();
  SILOptFunctionBuilder fb(getTransform());
  auto *thunk = fb.getOrCreateSharedFunction(
      loc, thunkName, thunkType, IsBare, IsTransparent, caller->isSerialized(),
      ProfileCounter(), IsThunk, IsNotDynamic);

  // Reuse an existing thunk body.
  if (!thunk->empty())
    return {thunk, interfaceSubs};

  thunk->setOwnershipEliminated();
  thunk->setGenericEnvironment(genericEnv);
  auto *entry = thunk->createBasicBlock();
  SILBuilder builder(entry);
  createEntryArguments(thunk);

  // If `derivativeFn` is a `partial_apply`, reuse its substitutions (remapped
  // into the thunk's context below).
  SubstitutionMap assocSubstMap;
  if (auto *partialApply = dyn_cast<PartialApplyInst>(derivativeFn))
    assocSubstMap = partialApply->getSubstitutionMap();

  // FIXME: The logic for resolving `assocRef` does not reapply function
  // conversions, which is problematic if `derivativeFn` is a `partial_apply`
  // instruction.
  // Re-materialize a reference to the derivative function inside the thunk:
  // a `function_ref`, `witness_method`, or `class_method`.
  SILValue assocRef;
  if (auto *derivativeFnRef =
          peerThroughFunctionConversions<FunctionRefInst>(derivativeFn)) {
    auto *assoc = derivativeFnRef->getReferencedFunctionOrNull();
    assocRef = builder.createFunctionRef(loc, assoc);
  } else if (auto *assocMethodInst =
                 peerThroughFunctionConversions<WitnessMethodInst>(derivativeFn)) {
    assocRef = builder.createWitnessMethod(
        loc, assocMethodInst->getLookupType(),
        assocMethodInst->getConformance(), assocMethodInst->getMember(),
        thunk->mapTypeIntoContext(assocMethodInst->getType()));
  } else if (auto *assocMethodInst =
                 peerThroughFunctionConversions<ClassMethodInst>(derivativeFn)) {
    // The class instance for dynamic dispatch is the thunk's last
    // non-indirect-result argument.
    auto classOperand = thunk->getArgumentsWithoutIndirectResults().back();
    auto classOperandType = assocMethodInst->getOperand()->getType();
    assert(classOperand->getType() == classOperandType);
    assocRef = builder.createClassMethod(
        loc, classOperand, assocMethodInst->getMember(),
        thunk->mapTypeIntoContext(assocMethodInst->getType()));
  }
  assert(assocRef && "Expected derivative function to be resolved");

  assocSubstMap = assocSubstMap.subst(thunk->getForwardingSubstitutionMap());
  derivativeFnType = assocRef->getType().castTo<SILFunctionType>();

  // Forward every thunk argument (indirect results included) unchanged.
  SmallVector<SILValue, 4> arguments;
  arguments.append(thunk->getArguments().begin(), thunk->getArguments().end());
  assert(arguments.size() == derivativeFnType->getNumParameters() +
                                 derivativeFnType->getNumIndirectFormalResults());
  auto *apply = builder.createApply(
      loc, assocRef, assocSubstMap, arguments, /*isNonThrowing*/ false);

  // Extract all direct results.
  // The last direct result is the linear map; the rest belong to the original
  // function's result.
  SmallVector<SILValue, 8> directResults;
  extractAllElements(apply, builder, directResults);
  auto originalDirectResults = ArrayRef<SILValue>(directResults).drop_back(1);
  auto originalDirectResult =
      joinElements(originalDirectResults, builder, apply->getLoc());
  auto linearMap = directResults.back();

  auto linearMapType = linearMap->getType().castTo<SILFunctionType>();
  auto linearMapTargetType = targetType->getResults().back().getSILStorageType()
      .castTo<SILFunctionType>();

  SILFunction *linearMapThunk;
  SubstitutionMap linearMapSubs;
  std::tie(linearMapThunk, linearMapSubs) =
      getOrCreateSubsetParametersThunkForLinearMap(
          thunk, linearMapType, linearMapTargetType, kind,
          desiredIndices, actualIndices);

  // Close over the original linear map with the linear map subset thunk.
  auto *linearMapThunkFRI = builder.createFunctionRef(loc, linearMapThunk);
  auto *thunkedLinearMap = builder.createPartialApply(
      loc, linearMapThunkFRI, linearMapSubs, {linearMap},
      ParameterConvention::Direct_Guaranteed);

  // A direct original result is returned together with the thunked linear
  // map; an indirect one was already written through the forwarded indirect
  // result argument, so only the linear map is returned.
  assert(origFnType->getResults().size() == 1);
  if (origFnType->getResults().front().isFormalDirect()) {
    auto result = joinElements(
        {originalDirectResult, thunkedLinearMap}, builder, loc);
    builder.createReturn(loc, result);
  } else {
    builder.createReturn(loc, thunkedLinearMap);
  }

  getGeneratedFunctions().push_back(thunk);
  return {thunk, interfaceSubs};
}

/// Promotes the original function operand of `dfi` to a `@differentiable`
/// function value, emitting new instructions through `builder`.
///
/// - If the original operand is an `apply` of a `function_ref` whose single
///   result is a function, a new curry thunk returning a `@differentiable`
///   function is created (and processed, if it is new), and applied in place
///   of the old one.
/// - Otherwise, JVP and VJP references are emitted for the original function.
///   When their available parameter indices differ from the desired ones,
///   they are wrapped in parameter subset thunks. The results are bundled
///   with a copy of the original function into a new `differentiable_function`
///   instruction.
///
/// \returns The promoted value, or a null `SILValue` if an error was
/// diagnosed.
SILValue ADContext::promoteToDifferentiableFunction(
    DifferentiableFunctionInst *dfi, SILBuilder &builder, SILLocation loc,
    DifferentiationInvoker invoker) {
  auto origFnOperand = dfi->getOriginalFunction();
  auto origFnTy = origFnOperand->getType().castTo<SILFunctionType>();
  auto parameterIndices = dfi->getParameterIndices();
  unsigned resultIndex = resultIndices[dfi];

  // Handle curry thunk applications specially.
  if (auto *ai = dyn_cast<ApplyInst>(origFnOperand)) {
    if (auto *thunkRef = dyn_cast<FunctionRefInst>(ai->getCallee())) {
      // Create a new curry thunk.
      SILAutoDiffIndices desiredIndices(resultIndex, parameterIndices);
      auto *thunk = thunkRef->getReferencedFunctionOrNull();
      // TODO(TF-685): Use more principled mangling for thunks.
      auto newThunkName = "AD__" + thunk->getName().str() +
          "__differentiable_curry_thunk_" + desiredIndices.mangle();

      auto thunkTy = thunk->getLoweredFunctionType();
      auto thunkResult = thunkTy->getSingleResult();
      if (auto resultFnTy = thunkResult.getType()->getAs<SILFunctionType>()) {
        // Construct new curry thunk type with `@differentiable` result.
        auto diffableResultFnTy = resultFnTy->getWithExtInfo(
            resultFnTy->getExtInfo()
                .withDifferentiabilityKind(DifferentiabilityKind::Normal));
        auto newThunkResult = thunkResult.getWithType(diffableResultFnTy);
        auto thunkType = SILFunctionType::get(
            thunkTy->getGenericSignature(), thunkTy->getExtInfo(),
            thunkTy->getCoroutineKind(), thunkTy->getCalleeConvention(),
            thunkTy->getParameters(), {}, {newThunkResult}, {},
            thunkTy->getASTContext());

        // Construct new curry thunk, returning a `@differentiable` function.
        SILOptFunctionBuilder fb(transform);
        auto *newThunk = fb.getOrCreateFunction(
            loc, newThunkName,
            getSpecializedLinkage(thunk, thunk->getLinkage()), thunkType,
            thunk->isBare(), thunk->isTransparent(), thunk->isSerialized(),
            thunk->isDynamicallyReplaceable(), ProfileCounter(),
            thunk->isThunk());
        // If new thunk is newly created: clone the old thunk body, wrap the
        // returned function value with an `differentiable_function`
        // instruction, and process the `differentiable_function` instruction.
        if (newThunk->empty()) {
          if (auto newThunkGenSig = thunkType->getGenericSignature())
            newThunk->setGenericEnvironment(
                newThunkGenSig->getGenericEnvironment());
          newThunk->setOwnershipEliminated();
          BasicTypeSubstCloner cloner(thunk, newThunk);
          cloner.run();
          auto *retInst =
              cast<ReturnInst>(newThunk->findReturnBB()->getTerminator());
          SILBuilder thunkBuilder(retInst);
          auto *dfi = createDifferentiableFunction(thunkBuilder, loc,
                                                   parameterIndices,
                                                   retInst->getOperand());
          resultIndices[dfi] = resultIndex;
          thunkBuilder.createReturn(loc, dfi);
          retInst->eraseFromParent();

          getGeneratedFunctions().push_back(newThunk);
          getDifferentiableFunctionInsts().push_back(dfi);
          if (processDifferentiableFunctionInst(dfi))
            return nullptr;
        }

        // Apply the new curry thunk.
        auto *newThunkRef = builder.createFunctionRef(loc, newThunk);
        getGeneratedFunctionReferences().push_back(newThunkRef);
        SmallVector<SILValue, 8> newArgs;
        SmallVector<SILValue, 8> newArgsToDestroy;
        SmallVector<AllocStackInst *, 1> newBuffersToDealloc;
        copyParameterArgumentsForApply(ai, newArgs, newArgsToDestroy,
                                       newBuffersToDealloc);
        auto *newApply = builder.createApply(
            ai->getLoc(), newThunkRef, ai->getSubstitutionMap(), newArgs,
            ai->isNonThrowing());
        for (auto arg : newArgsToDestroy) {
          if (arg->getType().isObject())
            builder.emitDestroyValueOperation(loc, arg);
          else
            builder.emitDestroyAddr(loc, arg);
        }
        // Stack deallocations must be properly nested, so release the
        // temporary buffers in reverse allocation order.
        for (auto *alloc : reversed(newBuffersToDealloc))
          builder.createDeallocStack(loc, alloc);
        return newApply;
      }
    }
  }

  SILAutoDiffIndices desiredIndices(resultIndex, parameterIndices);
  SmallVector<SILValue, 2> derivativeFns;
  SmallVector<AllocStackInst *, 2> newBuffersToDealloc;
  for (auto derivativeFnKind : {AutoDiffDerivativeFunctionKind::JVP,
                           AutoDiffDerivativeFunctionKind::VJP}) {
    auto derivativeFnAndIndices = emitDerivativeFunctionReference(
        *this, builder, desiredIndices, derivativeFnKind, origFnOperand, invoker,
        newBuffersToDealloc);
    // Show an error at the operator, highlight the argument, and show a note
    // at the definition site of the argument.
    if (!derivativeFnAndIndices)
      return nullptr;

    auto derivativeFn = derivativeFnAndIndices->first;
    getGeneratedFunctionReferences().push_back(derivativeFn);

    // If desired indices are a subset of actual indices, create a "subset
    // indices thunk" and destroy the emitted derivative function reference.
    // - For JVPs: the thunked JVP returns a differential taking fewer
    //   parameters (using `.zero` for the dropped parameters).
    // - For VJPs: the thunked VJP returns a pullback that drops the unused
    //   tangent values.
    auto actualIndices = derivativeFnAndIndices->second;
    // NOTE: `desiredIndices` may come from a partially-applied function and
    // have smaller capacity than `actualIndices`. We expect this logic to go
    // away when we support `@differentiable` partial apply.
    // if (actualIndices != desiredIndices) { // TODO: Re-enable.
    auto extendedDesiredIndices = desiredIndices.parameters->extendingCapacity(
        getASTContext(), actualIndices.parameters->getCapacity());
    if (actualIndices.source != desiredIndices.source ||
        !actualIndices.parameters->equals(extendedDesiredIndices)) {
      // Destroy the already emitted derivative function reference because it
      // is no longer used.
      builder.emitDestroyValueOperation(loc, derivativeFn);
      // Check if underlying original function reference has been partially
      // applied with arguments. If so, produce an error: parameter subset
      // thunks do not yet support this case because partially applied arguments
      // cannot be propagated to parameter subset thunks.
      auto didPartiallyApplyArguments = [](SILValue original) {
        while (auto *pai =
                   peerThroughFunctionConversions<PartialApplyInst>(original)) {
          if (pai->getNumArguments() > 0)
            return true;
          original = pai->getCallee();
        }
        return false;
      };
      if (didPartiallyApplyArguments(origFnOperand)) {
        emitNondifferentiabilityError(
            origFnOperand, invoker,
            diag::autodiff_cannot_param_subset_thunk_partially_applied_orig_fn);
        return nullptr;
      }
      // Create the parameter subset thunk.
      assert(actualIndices.parameters->isSupersetOf(extendedDesiredIndices));
      SILFunction *thunk;
      SubstitutionMap interfaceSubs;
      std::tie(thunk, interfaceSubs) =
          getOrCreateSubsetParametersThunkForDerivativeFunction(
              origFnOperand, derivativeFn, derivativeFnKind, desiredIndices,
              actualIndices);
      auto *thunkFRI = builder.createFunctionRef(loc, thunk);
      // A generic thunk must be specialized to the caller's context before it
      // can be used as a derivative function value.
      if (thunk->getLoweredFunctionType()->getGenericSignature()) {
        derivativeFn = builder.createPartialApply(
            loc, thunkFRI, interfaceSubs, {},
            ParameterConvention::Direct_Guaranteed);
      } else {
        derivativeFn = thunkFRI;
      }
    }
    auto expectedDerivativeFnTy = origFnTy->getAutoDiffDerivativeFunctionType(
        parameterIndices, resultIndex, derivativeFnKind, getTypeConverter(),
        LookUpConformanceInModule(getModule().getSwiftModule()));
    // If `derivativeFn` is `@convention(thin)` but is expected to be
    // `@convention(thick)`, emit a `thin_to_thick` instruction.
    if (expectedDerivativeFnTy->getRepresentation()
            == SILFunctionTypeRepresentation::Thick &&
        derivativeFn->getType().castTo<SILFunctionType>()->getRepresentation()
            == SILFunctionTypeRepresentation::Thin) {
      derivativeFn = builder.createThinToThickFunction(
          loc, derivativeFn, SILType::getPrimitiveObjectType(expectedDerivativeFnTy));
    }

    derivativeFns.push_back(derivativeFn);
  }
  // Deallocate temporary buffers used for creating derivative functions.
  for (auto *buf : reversed(newBuffersToDealloc))
    builder.createDeallocStack(loc, buf);

  auto origFnCopy = builder.emitCopyValueOperation(loc, origFnOperand);
  auto *newDFI = createDifferentiableFunction(
      builder, loc, parameterIndices, origFnCopy,
      std::make_pair(derivativeFns[0], derivativeFns[1]));
  // Record the result index for the newly created instruction, as is done for
  // the curry-thunk `differentiable_function` above; `dfi` already has one.
  resultIndices[newDFI] = resultIndex;
  getDifferentiableFunctionInsts().push_back(dfi);

  return newDFI;
}

/// Fold `differentiable_function_extract` users of the given
/// `differentiable_function` instruction, directly replacing them with
/// `differentiable_function` instruction operands. If the
/// `differentiable_function` instruction has no remaining uses, delete the
/// instruction itself after folding.
///
/// Folding can be disabled by the `SkipFoldingDifferentiableFunctionExtraction`
/// flag for SIL testing purposes.
// FIXME: This function is not correctly detecting the foldable pattern and
// needs to be rewritten.
void ADContext::foldDifferentiableFunctionExtraction(
    DifferentiableFunctionInst *source) {
  // Snapshot the extractor users before mutating anything. Erasing an
  // extractor removes its operand from `source`'s use list, which would
  // invalidate an iterator over that same list.
  SmallVector<DifferentiableFunctionExtractInst *, 4> extractors;
  for (auto *use : source->getUses())
    if (auto *dfei =
            dyn_cast<DifferentiableFunctionExtractInst>(use->getUser()))
      extractors.push_back(dfei);

  for (auto *dfei : extractors) {
    // Fold original function extractors.
    if (dfei->getExtractee() ==
            NormalDifferentiableFunctionTypeComponent::Original) {
      auto originalFnValue = source->getOriginalFunction();
      dfei->replaceAllUsesWith(originalFnValue);
      dfei->eraseFromParent();
      continue;
    }
    // Fold derivative function extractors.
    auto derivativeFnValue =
        source->getDerivativeFunction(dfei->getDerivativeFunctionKind());
    dfei->replaceAllUsesWith(derivativeFnValue);
    dfei->eraseFromParent();
  }
  // If the `differentiable_function` instruction has no remaining uses, erase
  // it.
  if (isInstructionTriviallyDead(source)) {
    SILBuilder builder(source);
    // NOTE(review): the JVP/VJP operands are function values (objects), yet
    // they are destroyed via `emitDestroyAddrAndFold`; confirm this is
    // intended (see the FIXME above).
    builder.emitDestroyAddrAndFold(source->getLoc(), source->getJVPFunction());
    builder.emitDestroyAddrAndFold(source->getLoc(), source->getVJPFunction());
    source->eraseFromParent();
  }
  // Mark `source` as processed so that it won't be reprocessed after deletion.
  processedDifferentiableFunctionInsts.insert(source);
}

/// Replaces a `differentiable_function` instruction that lacks derivative
/// functions with a promoted `@differentiable` value, then folds extractors
/// of the result (unless folding is disabled for testing).
///
/// \returns `true` if an error occurred, `false` otherwise.
bool ADContext::processDifferentiableFunctionInst(
    DifferentiableFunctionInst *dfi) {
  LLVM_DEBUG({
    auto &stream = getADDebugStream();
    stream << "Processing DifferentiableFunctionInst:\n";
    dfi->printInContext(stream);
  });

  // Nothing to do if derivative functions are already attached.
  if (dfi->hasDerivativeFunctions())
    return false;

  auto *parentFn = dfi->getFunction();
  auto dfiLoc = dfi->getLoc();
  SILBuilder builder(dfi);

  SILValue promoted = promoteToDifferentiableFunction(dfi, builder, dfiLoc, dfi);
  // Record `dfi` as handled before it is erased, so the worklist skips it.
  processedDifferentiableFunctionInsts.insert(dfi);
  if (!promoted)
    return true;

  // Swap `dfi` out for the promoted value and release its original operand.
  dfi->replaceAllUsesWith(promoted);
  builder.emitDestroyValueOperation(dfiLoc, dfi->getOriginalFunction());
  dfi->eraseFromParent();

  // Fold extractors of a freshly created `differentiable_function`, unless
  // folding is disabled.
  if (!SkipFoldingDifferentiableFunctionExtraction) {
    auto *promotedDFI = dyn_cast<DifferentiableFunctionInst>(promoted);
    if (promotedDFI)
      foldDifferentiableFunctionExtraction(promotedDFI);
  }

  transform.invalidateAnalysis(parentFn,
                               SILAnalysis::InvalidationKind::FunctionBody);
  return false;
}

/// AD pass entry.
void Differentiation::run() {
  auto &module = *getModule();
  auto &astCtx = module.getASTContext();
  debugDump(module);

  // A global differentiation context.
  ADContext context(*this);

  bool errorOccurred = false;

  // Register all `@differentiable` attributes and `differentiable_function`
  // instructions in the module that trigger differentiation.
  for (SILFunction &f : module) {
    for (auto *diffAttr : f.getDifferentiableAttrs()) {
      DifferentiationInvoker invoker(diffAttr);
      assert(!context.getInvokers().count(diffAttr) &&
             "[differentiable] attribute already has an invoker");
      context.getInvokers().insert({diffAttr, invoker});
      continue;
    }
    for (SILBasicBlock &bb : f) {
      for (SILInstruction &i : bb) {
        if (auto *dfi = dyn_cast<DifferentiableFunctionInst>(&i))
          context.getDifferentiableFunctionInsts().push_back(dfi);
        else if (auto *lfi = dyn_cast<LinearFunctionInst>(&i)) {
          astCtx.Diags.diagnose(
              lfi->getLoc().getSourceLoc(),
              diag::autodiff_conversion_to_linear_function_not_supported);
          errorOccurred = true;
        }
      }
    }
  }

  // If nothing has triggered differentiation, there's nothing to do.
  if (context.getInvokers().empty() &&
      context.getDifferentiableFunctionInsts().empty())
    return;

  // AD relies on stdlib (the Swift module). If it's not imported, it's an
  // internal error.
  if (!astCtx.getStdlibModule()) {
    astCtx.Diags.diagnose(SourceLoc(),
                          diag::autodiff_internal_swift_not_imported);
    return;
  }

  // Process all `[differentiable]` attributes.
  for (auto invokerPair : context.getInvokers()) {
    auto *attr = invokerPair.first;
    auto *original = attr->getOriginal();
    auto invoker = invokerPair.second;
    errorOccurred |=
        context.processDifferentiableAttribute(original, attr, invoker);
  }

  // Iteratively process `differentiable_function` instruction worklist.
  while (!context.getDifferentiableFunctionInsts().empty()) {
    auto *dfi = context.getDifferentiableFunctionInsts().back();
    context.getDifferentiableFunctionInsts().pop_back();
    // Skip instructions that have been already been processed.
    if (context.getProcessedDifferentiableFunctionInsts().count(dfi)) continue;
    errorOccurred |= context.processDifferentiableFunctionInst(dfi);
  }

  // If any error occurred while processing `[differentiable]` attributes or
  // `differentiable_function` instructions, clean up.
  if (errorOccurred) {
    context.cleanUp();
    return;
  }

  LLVM_DEBUG(getADDebugStream() << "All differentiation finished\n");
}

//===----------------------------------------------------------------------===//
// Pass creation
//===----------------------------------------------------------------------===//

/// Factory used by the pass manager to instantiate the automatic
/// differentiation module pass.
SILTransform *swift::createDifferentiation() {
  SILTransform *pass = new Differentiation;
  return pass;
}
