//===--- Differentiation.cpp - SIL Automatic Differentiation --*- C++ -*---===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2018 - 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// This file implements automatic differentiation.
//
//===----------------------------------------------------------------------===//

#define DEBUG_TYPE "differentiation"

#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/CommandLine.h"
#include "swift/AST/ASTMangler.h"
#include "swift/AST/ASTPrinter.h"
#include "swift/AST/AnyFunctionRef.h"
#include "swift/AST/AutoDiff.h"
#include "swift/AST/Builtins.h"
#include "swift/AST/DeclContext.h"
#include "swift/AST/DiagnosticsSIL.h"
#include "swift/AST/Expr.h"
#include "swift/AST/GenericEnvironment.h"
#include "swift/AST/GenericSignatureBuilder.h"
#include "swift/AST/LazyResolver.h"
#include "swift/AST/ParameterList.h"
#include "swift/AST/SourceFile.h"
#include "swift/AST/SubstitutionMap.h"
#include "swift/AST/TypeCheckRequests.h"
#include "swift/SIL/FormalLinkage.h"
#include "swift/SIL/PrettyStackTrace.h"
#include "swift/SIL/SILBuilder.h"
#include "swift/SIL/TypeSubstCloner.h"
#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
#include "swift/SILOptimizer/Analysis/DominanceAnalysis.h"
#include "swift/SILOptimizer/Differentiation/ADContext.h"
#include "swift/SILOptimizer/Differentiation/JVPCloner.h"
#include "swift/SILOptimizer/Differentiation/Thunk.h"
#include "swift/SILOptimizer/Differentiation/VJPCloner.h"
#include "swift/SILOptimizer/PassManager/Passes.h"
#include "swift/SILOptimizer/PassManager/Transforms.h"
#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/BreadthFirstIterator.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/CommandLine.h"

using namespace swift;
using namespace swift::autodiff;
using llvm::DenseMap;
using llvm::SmallDenseMap;
using llvm::SmallDenseSet;
using llvm::SmallMapVector;
using llvm::SmallSet;

/// This flag enables experimental `@differentiable(linear)` function
/// transposition. Defaults to `false` (disabled).
static llvm::cl::opt<bool> EnableExperimentalLinearMapTransposition(
    "enable-experimental-linear-map-transposition", llvm::cl::init(false));

/// This flag is used to disable `differentiable_function_extract` instruction
/// folding for SIL testing purposes.
///
/// NOTE: the flag defaults to `true`, i.e. folding is skipped unless the flag
/// is explicitly set to `false`.
static llvm::cl::opt<bool> SkipFoldingDifferentiableFunctionExtraction(
    "differentiation-skip-folding-differentiable-function-extraction",
    llvm::cl::init(true));

//===----------------------------------------------------------------------===//
// Helpers
//===----------------------------------------------------------------------===//

/// Given a dumpable value, dumps it to `llvm::dbgs()`.
template <typename T> static inline void debugDump(T &v) {
  LLVM_DEBUG(llvm::dbgs() << "\n==== BEGIN DEBUG DUMP ====\n"
                          << v << "\n==== END DEBUG DUMP ====\n");
}

namespace {

/// Drives automatic differentiation over a module: canonicalizes
/// differentiability witnesses and processes `differentiable_function` and
/// `linear_function` instructions, filling in missing derivative/transpose
/// functions.
class DifferentiationTransformer {
private:
  /// Reference to the main transform.
  SILModuleTransform &transform;

  /// Context necessary for performing the transformations.
  /// NOTE: constructed from `transform`, so it must stay declared after it.
  ADContext context;

  /// Promotes the given `differentiable_function` instruction to a valid
  /// `@differentiable` function-typed value.
  SILValue promoteToDifferentiableFunction(DifferentiableFunctionInst *inst,
                                           SILBuilder &builder, SILLocation loc,
                                           DifferentiationInvoker invoker);

  /// Given a `linear_function` instruction that is missing a transpose operand,
  /// return a new `linear_function` instruction with the transpose filled in.
  SILValue promoteToLinearFunction(LinearFunctionInst *inst,
                                   SILBuilder &builder, SILLocation loc,
                                   DifferentiationInvoker invoker);

public:
  /// Construct a `DifferentiationTransformer` for the given module.
  explicit DifferentiationTransformer(SILModuleTransform &transform)
      : transform(transform), context(transform) {}

  SILModuleTransform &getTransform() { return transform; }

  ADContext &getContext() { return context; }

  /// Canonicalize the given witness, filling in derivative functions if
  /// missing.
  ///
  /// Generated derivative functions have the same linkage as the witness.
  ///
  /// \param serializeFunctions specifies whether generated functions should be
  ///        serialized.
  /// \returns true on failure (callers treat a `true` result as an error that
  ///          has already been diagnosed).
  bool canonicalizeDifferentiabilityWitness(
      SILFunction *original, SILDifferentiabilityWitness *witness,
      DifferentiationInvoker invoker, IsSerialized_t serializeFunctions);

  /// Process the given `differentiable_function` instruction, filling in
  /// missing derivative functions if necessary.
  /// NOTE(review): presumably returns true on error, like
  /// `canonicalizeDifferentiabilityWitness`; confirm against the definition.
  bool processDifferentiableFunctionInst(DifferentiableFunctionInst *dfi);

  /// Process the given `linear_function` instruction, filling in the missing
  /// transpose function if necessary.
  /// NOTE(review): presumably returns true on error; confirm against the
  /// definition.
  bool processLinearFunctionInst(LinearFunctionInst *lfi);

  /// Fold `differentiable_function_extract` users of the given
  /// `differentiable_function` instruction, directly replacing them with
  /// `differentiable_function` instruction operands. If the
  /// `differentiable_function` instruction has no remaining uses, delete the
  /// instruction itself after folding.
  ///
  /// Folding can be disabled by the
  /// `SkipFoldingDifferentiableFunctionExtraction` flag for SIL testing
  /// purposes.
  void foldDifferentiableFunctionExtraction(DifferentiableFunctionInst *source);
};

} // end anonymous namespace

/// Emits an `autodiff_missing_return` error at the end of `original` when the
/// function has no return block, since such a function cannot be
/// differentiated. Returns true iff the error was emitted.
static bool diagnoseNoReturn(ADContext &context, SILFunction *original,
                             DifferentiationInvoker invoker) {
  bool hasReturnBlock = original->findReturnBB() != original->end();
  if (!hasReturnBlock) {
    context.emitNondifferentiabilityError(
        original->getLocation().getEndSourceLoc(), invoker,
        diag::autodiff_missing_return);
  }
  return !hasReturnBlock;
}

/// Scans the terminators of `original` and emits a "control flow unsupported"
/// error at the first branching terminator that differentiation cannot yet
/// handle. Returns true iff an error was emitted.
///
/// Update the supported-terminator list as control flow support is added.
static bool diagnoseUnsupportedControlFlow(ADContext &context,
                                           SILFunction *original,
                                           DifferentiationInvoker invoker) {
  // A function with at most one block cannot contain branching control flow.
  if (original->getBlocks().size() < 2)
    return false;
  // Branching terminators that differentiation currently supports.
  auto isSupportedTerminator = [](TermInst *term) -> bool {
    return isa<BranchInst>(term) || isa<CondBranchInst>(term) ||
           isa<SwitchEnumInst>(term) || isa<SwitchEnumAddrInst>(term) ||
           isa<CheckedCastBranchInst>(term) ||
           isa<CheckedCastValueBranchInst>(term) ||
           isa<CheckedCastAddrBranchInst>(term) || isa<TryApplyInst>(term);
  };
  for (auto &block : *original) {
    auto *terminator = block.getTerminator();
    // Non-branching terminators (e.g. returns) are never diagnosed here.
    if (isSupportedTerminator(terminator) || !terminator->isBranch())
      continue;
    context.emitNondifferentiabilityError(
        terminator, invoker, diag::autodiff_control_flow_not_supported);
    return true;
  }
  return false;
}

/// Check whether the given requirements are satisfied, with the given
/// derivative generic signature (containing requirements), and substitution
/// map. Returns true if error is emitted.
///
/// Each requirement of `derivativeGenSig` has its types substituted with
/// `substMap` (looking up conformances in the current Swift module) and is
/// then checked. All unsatisfied requirements are collected and reported
/// together in a single `autodiff_function_assoc_func_unmet_requirements`
/// diagnostic at `loc`.
///
/// Only the `Class` layout constraint is checked; other layout requirements
/// are currently treated as satisfied (see the TODO below).
static bool diagnoseUnsatisfiedRequirements(ADContext &context,
                                            CanSILFunctionType origFnTy,
                                            GenericSignature derivativeGenSig,
                                            SubstitutionMap substMap,
                                            DifferentiationInvoker invoker,
                                            SourceLoc loc) {
  // If the original function is polymorphic and its generic signature is the
  // same as the derivative generic signature, then the requirements are
  // satisfied. This check is necessary because the subsequent logic does not
  // correctly handle polymorphic original functions.
  // TODO(TF-1055): Can be removed after we have a robust solution for TF-1055.
  if (origFnTy->getInvocationGenericSignature() && derivativeGenSig &&
      origFnTy->getInvocationGenericSignature()->isEqual(derivativeGenSig))
    return false;

  // If there are no derivative requirements, return false.
  if (!derivativeGenSig)
    return false;
  auto requirements = derivativeGenSig->getRequirements();
  if (requirements.empty())
    return false;
  // Iterate through all requirements and check whether they are satisfied.
  auto *swiftModule = context.getModule().getSwiftModule();
  SmallVector<Requirement, 2> unsatisfiedRequirements;
  for (auto req : requirements) {
    auto firstType = req.getFirstType();
    Type secondType;
    // Substitute first and second types using the given substitution map,
    // looking up conformances in the current module, if possible.
    if (auto substFirstType =
            firstType.subst(QuerySubstitutionMap{substMap},
                            LookUpConformanceInModule(swiftModule))) {
      firstType = substFirstType;
    }
    // Layout requirements have no second type.
    if (req.getKind() != RequirementKind::Layout) {
      secondType = req.getSecondType();
      if (auto substSecondType =
              secondType.subst(QuerySubstitutionMap{substMap},
                               LookUpConformanceInModule(swiftModule))) {
        secondType = substSecondType;
      }
    }
    switch (req.getKind()) {
    // Check layout requirements.
    case RequirementKind::Layout: {
      auto layout = req.getLayoutConstraint();
      switch (layout->getKind()) {
      case LayoutConstraintKind::Class:
        if (!firstType->satisfiesClassConstraint())
          unsatisfiedRequirements.push_back(req);
        continue;
      default:
        // TODO: Check other layout requirements. Note that `@differentiable`
        // attribute type-checking does not yet support layout requirements in
        // where clauses; layout requirements in derivative generic signatures
        // can be formed only from `differentiable_function` instructions whose
        // original function operand is generic with layout requirements.
        break;
      }
      continue;
    }
    // Check same type requirements.
    case RequirementKind::SameType:
      // If the first type does not equal the second type, then record the
      // unsatisfied requirement.
      if (!firstType->isEqual(secondType))
        unsatisfiedRequirements.push_back(req);
      continue;
    // Check superclass requirements.
    case RequirementKind::Superclass: {
      // If the second type is not an exact superclass of the first type, then
      // record the unsatisfied requirement.
      if (!secondType->isExactSuperclassOf(firstType))
        unsatisfiedRequirements.push_back(req);
      continue;
    }
    // Check conformance requirements.
    case RequirementKind::Conformance: {
      // The protocol is read from the unsubstituted requirement.
      auto protocolType = req.getSecondType()->castTo<ProtocolType>();
      auto protocol = protocolType->getDecl();
      assert(protocol && "Expected protocol in generic signature requirement");
      // If the first type does not conform to the second type in the current
      // module, then record the unsatisfied requirement.
      if (!swiftModule->lookupConformance(firstType, protocol))
        unsatisfiedRequirements.push_back(req);
      continue;
    }
    }
  }
  if (unsatisfiedRequirements.empty())
    return false;
  // Diagnose unsatisfied requirements as a single comma-separated list.
  std::string reqText;
  llvm::raw_string_ostream stream(reqText);
  interleave(
      unsatisfiedRequirements,
      [&](Requirement req) { req.print(stream, PrintOptions()); },
      [&] { stream << ", "; });
  context.emitNondifferentiabilityError(
      loc, invoker, diag::autodiff_function_assoc_func_unmet_requirements,
      stream.str());
  return true;
}

//===----------------------------------------------------------------------===//
// Code emission utilities
//===----------------------------------------------------------------------===//

/// Duplicates the argument list of `applySite`, emitting the copies
/// immediately before the apply-site instruction, and appends the resulting
/// values to `copiedArgs` in argument order.
///
/// - Object arguments with ownership get a `copy_value`; objects without
///   ownership (`OwnershipKind::None`) are reused as-is.
/// - `@inout_aliasable` addresses are reused as-is (they are `@noescape`
///   captures).
/// - All other addresses are copied into a fresh `alloc_stack` buffer via
///   `copy_addr`; each buffer is appended to `newBuffersToDealloc`.
///
/// Copied values passed with a guaranteed convention (other than
/// `@inout_aliasable`) are also appended to `newArgsToDestroy`.
static void copyParameterArgumentsForApply(
    ApplySite applySite, SmallVectorImpl<SILValue> &copiedArgs,
    SmallVectorImpl<SILValue> &newArgsToDestroy,
    SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc) {
  LLVM_DEBUG({
    auto &s = getADDebugStream() << "Copying arguments from apply site: ";
    applySite.getInstruction()->print(s);
  });
  auto applyLoc = applySite.getLoc();
  copiedArgs.reserve(applySite.getNumArguments());
  SILBuilderWithScope argCopier(applySite.getInstruction());
  for (auto &operand : applySite.getArgumentOperands()) {
    SILValue sourceArg = operand.get();
    auto convention = applySite.getArgumentConvention(operand);
    bool isInoutAliasable =
        convention == SILArgumentConvention::Indirect_InoutAliasable;
    SILValue copiedArg;
    if (sourceArg->getType().isObject()) {
      // Objects are retained unless they carry no ownership.
      copiedArg = sourceArg.getOwnershipKind() == OwnershipKind::None
                      ? sourceArg
                      : argCopier.emitCopyValueOperation(applyLoc, sourceArg);
    } else if (isInoutAliasable) {
      // `@noescape` capture: share the address instead of copying it.
      copiedArg = sourceArg;
    } else {
      // Address-only argument: copy it into a new stack buffer.
      auto *buffer = argCopier.createAllocStack(applyLoc, sourceArg->getType());
      newBuffersToDealloc.push_back(buffer);
      argCopier.createCopyAddr(applyLoc, sourceArg, buffer, IsNotTake,
                               IsInitialization);
      copiedArg = buffer;
    }
    copiedArgs.push_back(copiedArg);
    if (convention.isGuaranteedConvention() && !isInoutAliasable)
      newArgsToDestroy.push_back(copiedArg);
  }
}

/// When a function value is used in an instruction (usually `apply`), there may
/// be conversion instructions in between, e.g. `thin_to_thick_function`. Given
/// a new function value and an old function value, this helper function
/// recursively converts the new function just like how the old function is
/// converted.
///
/// Handled conversions: `copy_value`, `begin_borrow` and `convert_function`
/// (looked through, not re-emitted), `thin_to_thick_function` and
/// `partial_apply` (re-emitted on the new function). Any other instruction
/// between `oldFunc` and `oldConvertedFunc` hits `llvm_unreachable`.
///
/// Stack buffers created while copying `partial_apply` arguments are appended
/// to `newBuffersToDealloc` (presumably deallocated by the caller).
///
/// If the new function's generic signature is specified, it is used
/// to create substitution maps for reapplied `partial_apply` instructions.
static SILValue reapplyFunctionConversion(
    ADContext &context, SILValue newFunc, SILValue oldFunc,
    SILValue oldConvertedFunc, SILBuilder &builder, SILLocation loc,
    SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc,
    IndexSubset *parameterIndices, IndexSubset *resultIndices,
    GenericSignature newFuncGenSig = GenericSignature()) {
  // If the old func is the new func, then there's no conversion.
  if (oldFunc == oldConvertedFunc)
    return newFunc;
  // Handle a few instruction cases.
  // copy_value
  if (auto *cvi = dyn_cast<CopyValueInst>(oldConvertedFunc)) {
    // Note: no `copy_value` is needed for the re-converted function because the
    // caller of `reapplyFunctionConversion` should consume the re-converted
    // function.
    return reapplyFunctionConversion(
        context, newFunc, oldFunc, cvi->getOperand(), builder, loc,
        newBuffersToDealloc, parameterIndices, resultIndices, newFuncGenSig);
  }
  // begin_borrow
  if (auto *bbi = dyn_cast<BeginBorrowInst>(oldConvertedFunc)) {
    // Note: no `begin_borrow` is needed for the re-converted function because
    // the caller of `reapplyFunctionConversion` should consume the re-converted
    // function.
    return reapplyFunctionConversion(
        context, newFunc, oldFunc, bbi->getOperand(), builder, loc,
        newBuffersToDealloc, parameterIndices, resultIndices, newFuncGenSig);
  }
  // convert_function
  if (auto *cfi = dyn_cast<ConvertFunctionInst>(oldConvertedFunc)) {
    return reapplyFunctionConversion(
        context, newFunc, oldFunc, cfi->getOperand(), builder, loc,
        newBuffersToDealloc, parameterIndices, resultIndices, newFuncGenSig);
  }
  // thin_to_thick_function
  if (auto *tttfi = dyn_cast<ThinToThickFunctionInst>(oldConvertedFunc)) {
    auto innerNewFunc = reapplyFunctionConversion(
        context, newFunc, oldFunc, tttfi->getOperand(), builder, loc,
        newBuffersToDealloc, parameterIndices, resultIndices, newFuncGenSig);
    // Re-derive the thick type from the (possibly different) new operand type.
    auto operandFnTy = innerNewFunc->getType().castTo<SILFunctionType>();
    auto thickTy = operandFnTy->getWithRepresentation(
        SILFunctionTypeRepresentation::Thick);
    auto silTy = SILType::getPrimitiveObjectType(thickTy);
    return builder.createThinToThickFunction(loc, innerNewFunc, silTy);
  }
  // partial_apply
  if (auto *pai = dyn_cast<PartialApplyInst>(oldConvertedFunc)) {
    SmallVector<SILValue, 8> newArgs;
    newArgs.reserve(pai->getNumArguments());
    // NOTE(review): `newArgsToDestroy` is populated but never used below;
    // confirm that the copied guaranteed arguments are released elsewhere.
    SmallVector<SILValue, 1> newArgsToDestroy;
    copyParameterArgumentsForApply(pai, newArgs, newArgsToDestroy,
                                   newBuffersToDealloc);
    auto innerNewFunc = reapplyFunctionConversion(
        context, newFunc, oldFunc, pai->getCallee(), builder, loc,
        newBuffersToDealloc, parameterIndices, resultIndices, newFuncGenSig);
    // Reabstraction thunk `partial_apply` reapplications require special
    // support. Reabstraction thunk JVP/VJP expects a `@differentiable`
    // function-typed argument to avoid opaque function non-differentiability
    // errors. Thus, `partial_apply` reapplications must first form a
    // `differentiable_function` of the function-typed thunk argument.
    auto isReabstractionThunkCallee = [&]() -> bool {
      auto *fri = dyn_cast<FunctionRefInst>(oldFunc);
      return fri && fri->getReferencedFunctionOrNull()->isThunk() ==
                        IsReabstractionThunk;
    };
    if (isReabstractionThunkCallee()) {
      assert(newArgs.size() == 1 &&
             "Expected reabstraction thunk to be partially applied with only "
             "one argument");
      auto *dfi = context.createDifferentiableFunction(
          builder, loc, parameterIndices, resultIndices, newArgs.back());
      // Queue the new instruction so its derivatives get filled in later.
      context.getDifferentiableFunctionInstWorklist().push_back(dfi);
      newArgs.back() = dfi;
    }
    // Compute substitution map for reapplying `partial_apply`.
    // - If reapplied function is not polymorphic, use empty substitution map
    //   regardless of the original `partial_apply`'s substitution map.
    //   - This case is triggered for reapplying `partial_apply` where `newFunc`
    //     is a `differentiability_witness_function` where the witness generic
    //     signature has all concrete parameters while the original function's
    //     generic signature does not. In this case, the original function type
    //     is polymorphic while derivative function types are not (specialized
    //     with concrete types from same-type requirements).
    // - Otherwise, if `newFuncGenSig` is not specified, use the original
    //   `partial_apply`'s substitution map.
    // - Otherwise, if `newFuncGenSig` is specified, combine it with the
    //   original `partial_apply`'s substitution map.
    SubstitutionMap substMap;
    if (innerNewFunc->getType().castTo<SILFunctionType>()->isPolymorphic()) {
      if (!newFuncGenSig) {
        substMap = pai->getSubstitutionMap();
      } else {
        substMap = SubstitutionMap::get(
            newFuncGenSig, QuerySubstitutionMap{pai->getSubstitutionMap()},
            LookUpConformanceInModule(builder.getModule().getSwiftModule()));
      }
    }
    return builder.createPartialApply(loc, innerNewFunc, substMap, newArgs,
                                      ParameterConvention::Direct_Guaranteed);
  }
  llvm_unreachable("Unhandled function conversion instruction");
}

/// Emits a reference to a derivative function of `original`, differentiated
/// with respect to a superset of `desiredConfig`'s parameter indices. Returns
/// the `SILValue` for the derivative function and the actual configuration
/// (indices) that the derivative function is with respect to.
///
/// Returns `None` on failure, signifying that a diagnostic has been emitted
/// using `invoker`.
///
/// The following forms of `original` are handled, in order:
/// - A `@differentiable` function value (optionally reached through a
///   `differentiable_function_extract` of the original component): the
///   derivative function is extracted directly.
/// - A `function_ref`, possibly behind function conversions: an exact or
///   minimal differentiability witness is looked up (or a new private one is
///   created and canonicalized), and a `differentiability_witness_function`
///   reference is emitted and re-converted like `original`.
/// - A `witness_method` or `class_method`: a method reference to the derivative
///   function of the minimal derivative configuration is emitted.
/// Any other value produces an "opaque function not differentiable" error.
///
/// Stack buffers created while re-applying function conversions are appended
/// to `newBuffersToDealloc`.
static Optional<std::pair<SILValue, AutoDiffConfig>>
emitDerivativeFunctionReference(
    DifferentiationTransformer &transformer, SILBuilder &builder,
    AutoDiffConfig desiredConfig, AutoDiffDerivativeFunctionKind kind,
    SILValue original, DifferentiationInvoker invoker,
    SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc) {
  ADContext &context = transformer.getContext();

  // If `original` is a `differentiable_function_extract` of the original
  // function component, look through it to its `@differentiable` function
  // operand so that the derivative function can be extracted from that
  // operand below.
  if (auto *inst = original->getDefiningInstruction())
    if (auto *dfei = dyn_cast<DifferentiableFunctionExtractInst>(inst))
      if (dfei->getExtractee() ==
          NormalDifferentiableFunctionTypeComponent::Original)
        original = dfei->getOperand();

  // If `original` is a `@differentiable` function, just extract the
  // derivative function.
  if (auto diffableFnType = original->getType().castTo<SILFunctionType>()) {
    if (diffableFnType->isDifferentiable()) {
      // Every desired parameter must be a differentiability parameter of the
      // `@differentiable` function type.
      auto paramIndices =
          diffableFnType->getDifferentiabilityParameterIndices();
      for (auto i : desiredConfig.parameterIndices->getIndices()) {
        if (!paramIndices->contains(i)) {
          context.emitNondifferentiabilityError(
              original, invoker,
              diag::
                  autodiff_function_noderivative_parameter_not_differentiable);
          return None;
        }
      }
      // Borrow the function, extract the derivative, and copy it (when it has
      // ownership) before ending the borrow.
      auto borrowedDiffFunc =
          builder.emitBeginBorrowOperation(original.getLoc(), original);
      SILValue derivativeFn = builder.createDifferentiableFunctionExtract(
          borrowedDiffFunc.getLoc(), kind, borrowedDiffFunc);
      if (derivativeFn.getOwnershipKind() != OwnershipKind::None)
        derivativeFn =
            builder.emitCopyValueOperation(original.getLoc(), derivativeFn);
      builder.emitEndBorrowOperation(original.getLoc(), borrowedDiffFunc);
      return std::make_pair(derivativeFn, desiredConfig);
    }
  }

  // Handle `function_ref` original function.
  if (auto *originalFRI =
          peerThroughFunctionConversions<FunctionRefInst>(original)) {
    auto loc = originalFRI->getLoc();
    auto *originalFn = originalFRI->getReferencedFunctionOrNull();
    assert(originalFn);
    auto originalFnTy = originalFn->getLoweredFunctionType();
    auto *desiredParameterIndices = desiredConfig.parameterIndices;
    auto *desiredResultIndices = desiredConfig.resultIndices;
    // NOTE(TF-893): Extending capacity is necessary when `originalFnTy` has
    // parameters corresponding to captured variables.
    // TODO: If possible, change `autodiff::getLoweredParameterIndices` to
    // take `CaptureInfo` into account.
    if (originalFnTy->getNumParameters() >
        desiredParameterIndices->getCapacity()) {
      desiredParameterIndices = desiredParameterIndices->extendingCapacity(
          context.getASTContext(), originalFnTy->getNumParameters());
    }
    // Look up a differentiability witness with the exact configuration.
    auto *minimalWitness = getExactDifferentiabilityWitness(
        context.getModule(), originalFn, desiredParameterIndices,
        desiredResultIndices);
    // Otherwise, look up a differentiability witness with a minimal superset
    // configuration.
    if (!minimalWitness)
      minimalWitness = getOrCreateMinimalASTDifferentiabilityWitness(
          context.getModule(), originalFn, desiredParameterIndices,
          desiredResultIndices);
    // If no minimal witness exists, check non-differentiable cases before
    // creating a new private differentiability witness.
    if (!minimalWitness) {
      // If the function is intentionally marked as being opaque to
      // differentiation, then we should not create a task for it.
      if (originalFn->hasSemanticsAttr("autodiff.opaque")) {
        context.emitNondifferentiabilityError(
            original, invoker,
            diag::autodiff_opaque_function_not_differentiable);
        return None;
      }
      // Check and diagnose non-differentiable arguments.
      for (unsigned paramIndex : range(originalFnTy->getNumParameters())) {
        if (desiredConfig.isWrtParameter(paramIndex) &&
            !originalFnTy->getParameters()[paramIndex]
                 .getSILStorageInterfaceType()
                 .isDifferentiable(context.getModule())) {
          context.emitNondifferentiabilityError(
              original, invoker, diag::autodiff_nondifferentiable_argument);
          return None;
        }
      }
      // Check and diagnose non-differentiable results. Result indices past the
      // formal results refer to `inout` parameters, in order.
      for (auto resultIndex : desiredResultIndices->getIndices()) {
        SILType resultType;
        if (resultIndex >= originalFnTy->getNumResults()) {
          auto inoutParamIdx = resultIndex - originalFnTy->getNumResults();
          auto inoutParam =
              *std::next(originalFnTy->getIndirectMutatingParameters().begin(),
                         inoutParamIdx);
          resultType = inoutParam.getSILStorageInterfaceType();
        } else {
          resultType = originalFnTy->getResults()[resultIndex]
                           .getSILStorageInterfaceType();
        }
        if (!resultType.isDifferentiable(context.getModule())) {
          context.emitNondifferentiabilityError(
              original, invoker, diag::autodiff_nondifferentiable_result);
          return None;
        }
      }
      // Check and diagnose external declarations.
      if (originalFn->isExternalDeclaration()) {
        context.emitNondifferentiabilityError(
            original, invoker,
            diag::autodiff_external_nondifferentiable_function);
        return None;
      }
      // Sanity check passed. Create a new differentiability witness and
      // canonicalize it.
      GenericSignature contextualDerivativeGenSig = GenericSignature();
      if (invoker.getKind() ==
          DifferentiationInvoker::Kind::IndirectDifferentiation)
        contextualDerivativeGenSig =
            invoker.getIndirectDifferentiation()
                .second->getDerivativeGenericSignature();
      auto derivativeConstrainedGenSig =
          autodiff::getConstrainedDerivativeGenericSignature(
              originalFn->getLoweredFunctionType(), desiredParameterIndices,
              contextualDerivativeGenSig,
              LookUpConformanceInModule(context.getModule().getSwiftModule()));
      minimalWitness = SILDifferentiabilityWitness::createDefinition(
          context.getModule(), SILLinkage::Private, originalFn,
          desiredParameterIndices, desiredResultIndices,
          derivativeConstrainedGenSig, /*jvp*/ nullptr,
          /*vjp*/ nullptr, /*isSerialized*/ false);
      if (transformer.canonicalizeDifferentiabilityWitness(
              originalFn, minimalWitness, invoker, IsNotSerialized))
        return None;
    }
    assert(minimalWitness);
    // A serialized (fragile) caller cannot reference a non-public witness.
    if (original->getFunction()->isSerialized() &&
        !hasPublicVisibility(minimalWitness->getLinkage())) {
      enum { Inlinable = 0, DefaultArgument = 1 };
      unsigned fragileKind = Inlinable;
      // FIXME: This is not a very robust way of determining if the function is
      // a default argument. Also, we have not exhaustively listed all the kinds
      // of fragility.
      if (original->getFunction()->getLinkage() == SILLinkage::PublicNonABI)
        fragileKind = DefaultArgument;
      context.emitNondifferentiabilityError(
          original, invoker, diag::autodiff_private_derivative_from_fragile,
          fragileKind,
          llvm::isa_and_nonnull<AbstractClosureExpr>(
              originalFRI->getLoc().getAsASTNode<Expr>()));
      return None;
    }
    // TODO(TF-482): Move generic requirement checking logic to
    // `getExactDifferentiabilityWitness` and
    // `getOrCreateMinimalASTDifferentiabilityWitness`.
    // Get the substitution map for checking unmet generic requirements.
    // By default, use the forwarding substitution map of the original function.
    // If the original callee is a `partial_apply` or `apply` instruction, use
    // its substitution map instead.
    auto substMap = original->getFunction()->getForwardingSubstitutionMap();
    if (auto *pai =
            peerThroughFunctionConversions<PartialApplyInst>(original)) {
      substMap = pai->getSubstitutionMap();
    } else if (auto *ai = peerThroughFunctionConversions<ApplyInst>(original)) {
      substMap = ai->getSubstitutionMap();
    }
    if (diagnoseUnsatisfiedRequirements(
            context, original->getType().castTo<SILFunctionType>(),
            minimalWitness->getDerivativeGenericSignature(), substMap, invoker,
            original.getLoc().getSourceLoc()))
      return None;
    DifferentiabilityWitnessFunctionKind witnessKind;
    switch (kind) {
    case AutoDiffDerivativeFunctionKind::JVP:
      witnessKind = DifferentiabilityWitnessFunctionKind::JVP;
      break;
    case AutoDiffDerivativeFunctionKind::VJP:
      witnessKind = DifferentiabilityWitnessFunctionKind::VJP;
      break;
    }
    auto *derivativeFnRef = builder.createDifferentiabilityWitnessFunction(
        loc, witnessKind, minimalWitness);
    // Convert the derivative reference the same way `original` was converted
    // from the `function_ref`.
    auto convertedRef = reapplyFunctionConversion(
        context, derivativeFnRef, originalFRI, original, builder, loc,
        newBuffersToDealloc, desiredConfig.parameterIndices,
        desiredConfig.resultIndices,
        derivativeFnRef->getType()
            .getASTType()
            ->castTo<SILFunctionType>()
            ->getSubstGenericSignature());
    return std::make_pair(convertedRef, minimalWitness->getConfig());
  }

  // Handle `witness_method`.
  if (auto *witnessMethod =
          peerThroughFunctionConversions<WitnessMethodInst>(original)) {
    auto loc = witnessMethod->getLoc();
    auto requirementDeclRef = witnessMethod->getMember();
    auto *requirementDecl = requirementDeclRef.getAbstractFunctionDecl();
    // If requirement declaration does not have any derivative function
    // configurations, produce an error.
    if (requirementDecl->getDerivativeFunctionConfigurations().empty()) {
      context.emitNondifferentiabilityError(
          original, invoker, diag::autodiff_protocol_member_not_differentiable);
      return None;
    }
    // Find the minimal derivative configuration: minimal parameter indices and
    // corresponding derivative generic signature. If it does not exist, produce
    // an error.
    IndexSubset *minimalASTParamIndices = nullptr;
    auto minimalConfig = findMinimalDerivativeConfiguration(
        requirementDecl, desiredConfig.parameterIndices,
        minimalASTParamIndices);
    if (!minimalConfig) {
      context.emitNondifferentiabilityError(
          original, invoker,
          diag::autodiff_member_subset_indices_not_differentiable);
      return None;
    }
    // Emit a `witness_method` instruction for the derivative function.
    auto originalType = witnessMethod->getType().castTo<SILFunctionType>();
    auto assocType = originalType->getAutoDiffDerivativeFunctionType(
        minimalConfig->parameterIndices, minimalConfig->resultIndices, kind,
        context.getTypeConverter(),
        LookUpConformanceInModule(builder.getModule().getSwiftModule()));
    auto *autoDiffFuncId = AutoDiffDerivativeFunctionIdentifier::get(
        kind, minimalASTParamIndices, minimalConfig->derivativeGenericSignature,
        context.getASTContext());
    auto *ref = builder.createWitnessMethod(
        loc, witnessMethod->getLookupType(), witnessMethod->getConformance(),
        requirementDeclRef.asAutoDiffDerivativeFunction(autoDiffFuncId),
        SILType::getPrimitiveObjectType(assocType));
    auto convertedRef = reapplyFunctionConversion(
        context, ref, witnessMethod, original, builder, loc,
        newBuffersToDealloc, desiredConfig.parameterIndices,
        desiredConfig.resultIndices);
    return std::make_pair(convertedRef, *minimalConfig);
  }

  // Handle `class_method`.
  if (auto *classMethod =
          peerThroughFunctionConversions<ClassMethodInst>(original)) {
    auto loc = classMethod->getLoc();
    auto methodDeclRef = classMethod->getMember();
    auto *methodDecl = methodDeclRef.getAbstractFunctionDecl();
    // If method declaration does not have any derivative function
    // configurations, produce an error.
    if (methodDecl->getDerivativeFunctionConfigurations().empty()) {
      context.emitNondifferentiabilityError(
          original, invoker, diag::autodiff_class_member_not_differentiable);
      return None;
    }
    // Find the minimal derivative configuration: minimal parameter indices and
    // corresponding derivative generic signature. If it does not exist, produce
    // an error.
    IndexSubset *minimalASTParamIndices = nullptr;
    auto minimalConfig = findMinimalDerivativeConfiguration(
        methodDecl, desiredConfig.parameterIndices, minimalASTParamIndices);
    if (!minimalConfig) {
      context.emitNondifferentiabilityError(
          original, invoker,
          diag::autodiff_member_subset_indices_not_differentiable);
      return None;
    }
    // Emit a `class_method` instruction for the derivative function.
    auto originalType = classMethod->getType().castTo<SILFunctionType>();
    auto assocType = originalType->getAutoDiffDerivativeFunctionType(
        minimalConfig->parameterIndices, minimalConfig->resultIndices, kind,
        context.getTypeConverter(),
        LookUpConformanceInModule(builder.getModule().getSwiftModule()));
    auto *autoDiffFuncId = AutoDiffDerivativeFunctionIdentifier::get(
        kind, minimalASTParamIndices, minimalConfig->derivativeGenericSignature,
        context.getASTContext());
    auto *ref = builder.createClassMethod(
        loc, classMethod->getOperand(),
        methodDeclRef.asAutoDiffDerivativeFunction(autoDiffFuncId),
        SILType::getPrimitiveObjectType(assocType));
    auto convertedRef = reapplyFunctionConversion(
        context, ref, classMethod, original, builder, loc, newBuffersToDealloc,
        desiredConfig.parameterIndices, desiredConfig.resultIndices);
    return std::make_pair(convertedRef, *minimalConfig);
  }

  // Emit the general opaque function error.
  context.emitNondifferentiabilityError(
      original, invoker, diag::autodiff_opaque_function_not_differentiable);
  return None;
}

//===----------------------------------------------------------------------===//
// `SILDifferentiabilityWitness` processing
//===----------------------------------------------------------------------===//

/// Creates an empty (body-less) derivative function of the given `kind` for
/// `original`, as described by the differentiability witness `witness`.
///
/// The new function's name is mangled from the original name, the derivative
/// kind and the witness config. Its lowered type is the original's
/// autodiff derivative function type for the witness's parameter/result
/// indices and derivative generic signature. Its linkage is taken from the
/// witness. The generic environment is set only when the derivative
/// generic signature has at least one non-concrete parameter.
///
/// \param kind Whether to create a JVP or a VJP.
/// \param isSerialized Serialization kind for the new function.
/// \returns The newly created function, which has no basic blocks.
static SILFunction *createEmptyDerivativeForWitness(
    ADContext &context, SILFunction *original,
    SILDifferentiabilityWitness *witness, IsSerialized_t isSerialized,
    AutoDiffDerivativeFunctionKind kind) {
  // Human-readable kind name; used only for debug output.
  const char *kindName =
      kind == AutoDiffDerivativeFunctionKind::VJP ? "VJP" : "JVP";
  LLVM_DEBUG({
    auto &s = getADDebugStream();
    s << "Creating " << kindName << ":\n\t";
    s << "Original type: " << original->getLoweredFunctionType() << "\n\t";
  });

  auto &module = context.getModule();
  auto originalTy = original->getLoweredFunctionType();
  auto config = witness->getConfig();

  // === Create an empty derivative function. ===
  Mangle::ASTMangler mangler;
  auto derivativeName =
      original->getASTContext()
          .getIdentifier(mangler.mangleAutoDiffDerivativeFunctionHelper(
              original->getName(), kind, config))
          .str();
  CanGenericSignature derivativeCanGenSig;
  if (auto derivativeGenSig = witness->getDerivativeGenericSignature())
    derivativeCanGenSig = derivativeGenSig->getCanonicalSignature();
  // A fully concrete signature needs no generic environment.
  GenericEnvironment *derivativeGenericEnv = nullptr;
  if (derivativeCanGenSig && !derivativeCanGenSig->areAllParamsConcrete())
    derivativeGenericEnv = derivativeCanGenSig->getGenericEnvironment();
  auto derivativeType = originalTy->getAutoDiffDerivativeFunctionType(
      config.parameterIndices, config.resultIndices, kind, module.Types,
      LookUpConformanceInModule(module.getSwiftModule()), derivativeCanGenSig,
      /*isReabstractionThunk*/ original->isThunk() == IsReabstractionThunk);

  SILOptFunctionBuilder fb(context.getTransform());
  auto *derivative = fb.createFunction(
      witness->getLinkage(), derivativeName, derivativeType,
      derivativeGenericEnv, original->getLocation(), original->isBare(),
      IsNotTransparent, isSerialized, original->isDynamicallyReplaceable());
  derivative->setDebugScope(
      new (module) SILDebugScope(original->getLocation(), derivative));

  LLVM_DEBUG(llvm::dbgs() << kindName << " type: "
                          << derivative->getLoweredFunctionType() << "\n");
  return derivative;
}

/// Creates an empty VJP for `original` as described by `witness`.
static SILFunction *createEmptyVJP(ADContext &context, SILFunction *original,
                                   SILDifferentiabilityWitness *witness,
                                   IsSerialized_t isSerialized) {
  return createEmptyDerivativeForWitness(context, original, witness,
                                         isSerialized,
                                         AutoDiffDerivativeFunctionKind::VJP);
}

/// Creates an empty JVP for `original` as described by `witness`.
static SILFunction *createEmptyJVP(ADContext &context, SILFunction *original,
                                   SILDifferentiabilityWitness *witness,
                                   IsSerialized_t isSerialized) {
  return createEmptyDerivativeForWitness(context, original, witness,
                                         isSerialized,
                                         AutoDiffDerivativeFunctionKind::JVP);
}

/// Fills the body-less function `f` with a trap.
///
/// The body creates an entry block with arguments derived from `f`'s type and
/// destroys every owned argument. It then calls the function named
/// `fatalErrorFuncName`, of type `@convention(thin) () -> Never`, and ends
/// with `unreachable`.
static void emitFatalError(ADContext &context, SILFunction *f,
                           StringRef fatalErrorFuncName) {
  auto *entryBB = f->createBasicBlock();
  createEntryArguments(f);
  SILBuilder b(entryBB);
  auto fnLoc = f->getLocation();

  // Owned arguments must be consumed so that ownership verification passes.
  for (auto *entryArg : entryBB->getArguments()) {
    if (entryArg->getOwnershipKind() != OwnershipKind::Owned)
      continue;
    b.emitDestroyOperation(fnLoc, entryArg);
  }

  // Callee signature: no parameters, a single unowned `Never` result, thin.
  SILResultInfo neverResult(context.getModule().getASTContext().getNeverType(),
                            ResultConvention::Unowned);
  auto calleeTy = SILFunctionType::get(
      /*genericSig*/ nullptr, SILFunctionType::ExtInfo::getThin(),
      SILCoroutineKind::None, ParameterConvention::Direct_Unowned, {},
      /*interfaceYields*/ {}, neverResult,
      /*interfaceErrorResults*/ None, {}, {}, context.getASTContext());

  // Reference (declaring if necessary) the external fatal error function.
  SILOptFunctionBuilder funcBuilder(context.getTransform());
  auto *callee = funcBuilder.getOrCreateFunction(
      fnLoc, fatalErrorFuncName, SILLinkage::PublicExternal, calleeTy,
      IsNotBare, IsNotTransparent, IsNotSerialized, IsNotDynamic,
      ProfileCounter(), IsNotThunk);
  auto *calleeRef = b.createFunctionRef(fnLoc, callee);
  b.createApply(fnLoc, calleeRef, SubstitutionMap(), {});
  b.createUnreachable(fnLoc);
}

/// Canonicalizes the differentiability witness `witness` for `original` by
/// synthesizing any missing derivative functions and recording them on the
/// witness. `witness` must be a definition (asserted).
///
/// - JVP: when missing, an empty JVP is always created and set on the
///   witness. A body is generated by `JVPCloner` only when experimental
///   forward-mode differentiation is enabled and the witness has no VJP at
///   this point, i.e. no user-defined VJP. Otherwise the body is a call to
///   `_fatalErrorForwardModeDifferentiationDisabled`.
/// - VJP: when missing, an empty VJP is created, set on the witness, and its
///   body is generated by `VJPCloner`.
///
/// \param invoker What triggered differentiation; used for diagnostics.
/// \param serializeFunctions Serialization kind for synthesized functions.
/// Returns true on error.
bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness(
    SILFunction *original, SILDifferentiabilityWitness *witness,
    DifferentiationInvoker invoker, IsSerialized_t serializeFunctions) {
  // Build the crash-trace message. `traceMessage` lives for the whole
  // function because `trace` is constructed from its `c_str()`.
  std::string traceMessage;
  llvm::raw_string_ostream OS(traceMessage);
  OS << "processing ";
  witness->print(OS);
  OS << " on";
  OS.flush();
  PrettyStackTraceSILFunction trace(traceMessage.c_str(), original);

  assert(witness->isDefinition());

  // If the JVP doesn't exist, need to synthesize it.
  if (!witness->getJVP()) {
    // Diagnose:
    // - Functions with no return.
    // - Functions with unsupported control flow.
    if (context.getASTContext()
            .LangOpts.EnableExperimentalForwardModeDifferentiation &&
        (diagnoseNoReturn(context, original, invoker) ||
         diagnoseUnsupportedControlFlow(context, original, invoker)))
      return true;

    // Create empty JVP.
    auto *jvp = createEmptyJVP(context, original, witness, serializeFunctions);
    witness->setJVP(jvp);
    context.recordGeneratedFunction(jvp);

    // For now, only do JVP generation if the flag is enabled and if custom VJP
    // does not exist. If custom VJP exists but custom JVP does not, skip JVP
    // generation because generated JVP may not match semantics of custom VJP.
    // Instead, create an empty JVP.
    // Note: this check must run before the VJP section below, which may set a
    // synthesized VJP on the witness.
    if (context.getASTContext()
            .LangOpts.EnableExperimentalForwardModeDifferentiation &&
        !witness->getVJP()) {
      // JVP and differential generation do not currently support functions with
      // multiple basic blocks.
      if (original->getBlocks().size() > 1) {
        context.emitNondifferentiabilityError(
            original->getLocation().getSourceLoc(), invoker,
            diag::autodiff_jvp_control_flow_not_supported);
        return true;
      }
      // Emit JVP function.
      JVPCloner cloner(context, original, witness, jvp, invoker);
      if (cloner.run())
        return true;
    } else {
      // If JVP generation is disabled or a user-defined custom VJP function
      // exists, fatal error with a nice message.
      emitFatalError(context, jvp,
                     "_fatalErrorForwardModeDifferentiationDisabled");
      LLVM_DEBUG(getADDebugStream()
                 << "Generated empty JVP for " << original->getName() << ":\n"
                 << *jvp);
    }
  }

  // If the VJP doesn't exist, need to synthesize it.
  if (!witness->getVJP()) {
    // Diagnose:
    // - Functions with no return.
    // - Functions with unsupported control flow.
    if (diagnoseNoReturn(context, original, invoker) ||
        diagnoseUnsupportedControlFlow(context, original, invoker))
      return true;

    // Create empty VJP.
    auto *vjp = createEmptyVJP(context, original, witness, serializeFunctions);
    witness->setVJP(vjp);
    context.recordGeneratedFunction(vjp);
    // Emit VJP function.
    VJPCloner cloner(context, original, witness, vjp, invoker);
    return cloner.run();
  }
  return false;
}

//===----------------------------------------------------------------------===//
// Differentiation pass implementation
//===----------------------------------------------------------------------===//

namespace {
/// The automatic differentiation pass.
///
/// A module transform that canonicalizes differentiability witnesses and
/// `differentiable_function` / `linear_function` instructions.
class Differentiation : public SILModuleTransform {
public:
  Differentiation() : SILModuleTransform() {}

  /// Pass entry point (overrides `SILModuleTransform::run`).
  void run() override;
};
} // end anonymous namespace

/// Given a curry thunk application, clone the thunk to return a
/// `@differentiable` function-typed value and apply the cloned thunk.
///
/// Curry thunk type: `(Self) -> (T, ...) -> U`.
/// Cloned thunk type: `(Self) -> @differentiable (T, ...) -> U`.
///
/// Returns a null `SILValue` when `dfi`'s original function operand does not
/// match the curry thunk application pattern. It also returns null when
/// processing the `differentiable_function` inside a newly created thunk
/// fails. Otherwise it returns the `apply` of the new thunk, emitted with
/// `builder` at `loc`.
static SILValue promoteCurryThunkApplicationToDifferentiableFunction(
    DifferentiationTransformer &dt, DifferentiableFunctionInst *dfi,
    SILBuilder &builder, SILLocation loc, DifferentiationInvoker invoker) {
  auto origFnOperand = dfi->getOriginalFunction();
  auto *parameterIndices = dfi->getParameterIndices();
  auto *resultIndices = dfi->getResultIndices();
  auto &context = dt.getContext();

  // Check for curry thunk application:
  // - The original function operand must be an `apply` instruction.
  // - The `apply` callee must be a `function_ref` instruction.
  // - The callee must return a function-typed value.
  auto *ai = dyn_cast<ApplyInst>(origFnOperand);
  if (!ai)
    return nullptr;
  auto *thunkRef = dyn_cast<FunctionRefInst>(ai->getCallee());
  if (!thunkRef)
    return nullptr;
  auto *thunk = thunkRef->getReferencedFunctionOrNull();
  if (!thunk)
    return nullptr;
  auto thunkTy = thunk->getLoweredFunctionType();
  // `getSingleResult` requires exactly one result; anything else is not a
  // curry thunk.
  if (thunkTy->getNumResults() != 1)
    return nullptr;
  auto thunkResult = thunkTy->getSingleResult();
  auto resultFnTy = thunkResult.getInterfaceType()->getAs<SILFunctionType>();
  if (!resultFnTy)
    return nullptr;

  // Create a new curry thunk.
  AutoDiffConfig desiredConfig(parameterIndices, resultIndices);
  // TODO(TF-685): Use more principled mangling for thunks.
  auto newThunkName = "AD__" + thunk->getName().str() +
                      "__differentiable_curry_thunk_" + desiredConfig.mangle();

  // Construct new curry thunk type with `@differentiable` function
  // result.
  auto diffResultFnTy = resultFnTy->getWithExtInfo(
      resultFnTy->getExtInfo()
          .intoBuilder()
          .withDifferentiabilityKind(DifferentiabilityKind::Normal)
          .build());
  auto newThunkResult = thunkResult.getWithInterfaceType(diffResultFnTy);
  auto thunkType = SILFunctionType::get(
      thunkTy->getSubstGenericSignature(), thunkTy->getExtInfo(),
      thunkTy->getCoroutineKind(), thunkTy->getCalleeConvention(),
      thunkTy->getParameters(), {}, {newThunkResult}, {},
      thunkTy->getPatternSubstitutions(), thunkTy->getInvocationSubstitutions(),
      thunkTy->getASTContext());

  // Construct new curry thunk, returning a `@differentiable` function.
  SILOptFunctionBuilder fb(dt.getTransform());
  auto *newThunk = fb.getOrCreateFunction(
      loc, newThunkName, getSpecializedLinkage(thunk, thunk->getLinkage()),
      thunkType, thunk->isBare(), thunk->isTransparent(), thunk->isSerialized(),
      thunk->isDynamicallyReplaceable(), ProfileCounter(), thunk->isThunk());
  // If new thunk is newly created: clone the old thunk body, wrap the
  // returned function value with a `differentiable_function`
  // instruction, and process the `differentiable_function` instruction.
  if (newThunk->empty()) {
    if (auto newThunkGenSig = thunkType->getSubstGenericSignature())
      newThunk->setGenericEnvironment(newThunkGenSig->getGenericEnvironment());
    // TODO(TF-1206): Enable ownership in all differentiation thunks.
    newThunk->setOwnershipEliminated();
    BasicTypeSubstCloner cloner(thunk, newThunk);
    cloner.cloneFunction();
    auto *retInst = cast<ReturnInst>(newThunk->findReturnBB()->getTerminator());
    auto returnValue = retInst->getOperand();
    // Create `differentiable_function` instruction directly after the
    // defining instruction (e.g. `partial_apply`) of the returned value.
    // Note: `differentiable_function` is not created at the end of the
    // new thunk to avoid `alloc_stack`/`dealloc_stack` ordering issues.
    SILBuilderWithScope dfiBuilder(
        std::next(returnValue->getDefiningInstruction()->getIterator()));
    auto *dfi = context.createDifferentiableFunction(
        dfiBuilder, loc, parameterIndices, resultIndices, returnValue);
    dfiBuilder.setInsertionPoint(newThunk->findReturnBB());
    dfiBuilder.createReturn(loc, dfi);
    retInst->eraseFromParent();

    context.recordGeneratedFunction(newThunk);
    context.getDifferentiableFunctionInstWorklist().push_back(dfi);
    if (dt.processDifferentiableFunctionInst(dfi))
      return nullptr;
  }

  // Apply the new curry thunk with copies of the original apply's arguments.
  auto *newThunkRef = builder.createFunctionRef(loc, newThunk);
  context.recordGeneratedFunctionReference(newThunkRef);
  SmallVector<SILValue, 8> newArgs;
  SmallVector<SILValue, 8> newArgsToDestroy;
  SmallVector<AllocStackInst *, 1> newBuffersToDealloc;
  copyParameterArgumentsForApply(ai, newArgs, newArgsToDestroy,
                                 newBuffersToDealloc);
  auto *newApply = builder.createApply(
      loc, newThunkRef, ai->getSubstitutionMap(), newArgs, ai->isNonThrowing());
  for (auto arg : newArgsToDestroy)
    builder.emitDestroyOperation(loc, arg);
  for (auto *alloc : newBuffersToDealloc)
    builder.createDeallocStack(loc, alloc);
  return newApply;
}

/// Emits, with `builder` at `loc`, a `differentiable_function` that bundles
/// `dfi`'s original function operand with JVP and VJP derivative functions.
///
/// It first tries the curry thunk promotion. Otherwise, for each derivative
/// kind it:
/// - emits a derivative function reference;
/// - wraps it in a subset-parameters thunk if the available config does not
///   match the desired one;
/// - converts thin to thick if needed;
/// - reabstracts it if it is not ABI-compatible with the expected
///   derivative type.
///
/// Returns a null `SILValue` on error (a diagnostic has been emitted).
SILValue DifferentiationTransformer::promoteToDifferentiableFunction(
    DifferentiableFunctionInst *dfi, SILBuilder &builder, SILLocation loc,
    DifferentiationInvoker invoker) {
  auto &astCtx = context.getASTContext();
  auto origFnOperand = dfi->getOriginalFunction();
  auto origFnTy = origFnOperand->getType().castTo<SILFunctionType>();
  auto *parameterIndices = dfi->getParameterIndices();
  auto *resultIndices = dfi->getResultIndices();

  // Curry thunk applications are handled by cloning the thunk.
  if (auto diffFn = promoteCurryThunkApplicationToDifferentiableFunction(
          *this, dfi, builder, loc, invoker))
    return diffFn;

  AutoDiffConfig desiredConfig(parameterIndices, resultIndices);
  // Collected in JVP, VJP order; indexed as [0] and [1] below.
  SmallVector<SILValue, 2> derivativeFns;
  SmallVector<AllocStackInst *, 2> newBuffersToDealloc;
  for (auto derivativeFnKind : {AutoDiffDerivativeFunctionKind::JVP,
                                AutoDiffDerivativeFunctionKind::VJP}) {
    auto derivativeFnAndIndices = emitDerivativeFunctionReference(
        *this, builder, desiredConfig, derivativeFnKind, origFnOperand,
        invoker, newBuffersToDealloc);
    // Show an error at the operator, highlight the argument, and show a note
    // at the definition site of the argument.
    if (!derivativeFnAndIndices)
      return nullptr;

    auto derivativeFn = derivativeFnAndIndices->first;
    context.recordGeneratedFunctionReference(derivativeFn);

    // If desired indices are a subset of actual indices, create a "subset
    // indices thunk" and destroy the emitted derivative function reference.
    // - For JVPs: the thunked JVP returns a differential taking fewer
    //   parameters (using `.zero` for the dropped parameters).
    // - For VJPs: the thunked VJP returns a pullback that drops the unused
    //   tangent values.
    auto actualConfig = derivativeFnAndIndices->second;
    // NOTE: `desiredIndices` may come from a partially-applied function and
    // have smaller capacity than `actualIndices`. We expect this logic to go
    // away when we support `@differentiable` partial apply.
    // if (actualIndices != desiredIndices) { // TODO: Re-enable.
    auto extendedDesiredParameterIndices =
        desiredConfig.parameterIndices->extendingCapacity(
            astCtx, actualConfig.parameterIndices->getCapacity());
    if (!actualConfig.parameterIndices->equals(extendedDesiredParameterIndices)
        || !actualConfig.resultIndices->equals(desiredConfig.resultIndices)) {
      // Destroy the already emitted derivative function reference because it
      // is no longer used.
      builder.emitDestroyValueOperation(loc, derivativeFn);
      // Check if underlying original function reference has been partially
      // applied with arguments. If so, produce an error: parameter subset
      // thunks do not yet support this case because partially applied arguments
      // cannot be propagated to parameter subset thunks.
      auto didPartiallyApplyArguments = [](SILValue original) {
        while (auto *pai =
                   peerThroughFunctionConversions<PartialApplyInst>(original)) {
          if (pai->getNumArguments() > 0)
            return true;
          original = pai->getCallee();
        }
        return false;
      };
      if (didPartiallyApplyArguments(origFnOperand)) {
        context.emitNondifferentiabilityError(
            origFnOperand, invoker,
            diag::autodiff_cannot_param_subset_thunk_partially_applied_orig_fn);
        return nullptr;
      }
      // Create the parameter subset thunk.
      assert(actualConfig.parameterIndices->isSupersetOf(
          extendedDesiredParameterIndices));
      SILFunction *thunk;
      SubstitutionMap interfaceSubs;
      SILOptFunctionBuilder fb(transform);
      std::tie(thunk, interfaceSubs) =
          getOrCreateSubsetParametersThunkForDerivativeFunction(
              fb, origFnOperand, derivativeFn, derivativeFnKind, desiredConfig,
              actualConfig);
      auto *thunkFRI = builder.createFunctionRef(loc, thunk);
      // A generic thunk is partially applied to its substitutions (no
      // arguments); a non-generic thunk is used directly.
      if (auto genSig =
              thunk->getLoweredFunctionType()->getSubstGenericSignature()) {
        derivativeFn =
            builder.createPartialApply(loc, thunkFRI, interfaceSubs, {},
                                       ParameterConvention::Direct_Guaranteed);
      } else {
        derivativeFn = thunkFRI;
      }
    }
    auto expectedDerivativeFnTy = origFnTy->getAutoDiffDerivativeFunctionType(
        parameterIndices, resultIndices, derivativeFnKind,
        context.getTypeConverter(),
        LookUpConformanceInModule(context.getModule().getSwiftModule()));
    // If `derivativeFn` is `@convention(thin)` but is expected to be
    // `@convention(thick)`, emit a `thin_to_thick` instruction.
    if (expectedDerivativeFnTy->getRepresentation() ==
            SILFunctionTypeRepresentation::Thick &&
        derivativeFn->getType()
                .castTo<SILFunctionType>()
                ->getRepresentation() == SILFunctionTypeRepresentation::Thin) {
      derivativeFn = builder.createThinToThickFunction(
          loc, derivativeFn,
          SILType::getPrimitiveObjectType(expectedDerivativeFnTy));
    }
    // If derivative function value's type is not ABI-compatible with the
    // expected derivative function type (i.e. parameter and result conventions
    // do not match), perform reabstraction.
    auto abiCompatibility = expectedDerivativeFnTy->isABICompatibleWith(
        derivativeFn->getType().castTo<SILFunctionType>(), *dfi->getFunction());
    if (!abiCompatibility.isCompatible()) {
      SILOptFunctionBuilder fb(context.getTransform());
      auto newDerivativeFn = reabstractFunction(
          builder, fb, loc, derivativeFn, expectedDerivativeFnTy,
          [](SubstitutionMap substMap) { return substMap; });
      derivativeFn = newDerivativeFn;
      assert(expectedDerivativeFnTy
                 ->isABICompatibleWith(
                     derivativeFn->getType().castTo<SILFunctionType>(),
                     *dfi->getFunction())
                 .isCompatible());
    }

    derivativeFns.push_back(derivativeFn);
  }
  // Deallocate temporary buffers used for creating derivative functions.
  for (auto *buf : llvm::reverse(newBuffersToDealloc))
    builder.createDeallocStack(loc, buf);

  // If our original copy does not have none ownership, copy it.
  if (origFnOperand.getOwnershipKind() != OwnershipKind::None)
    origFnOperand = builder.emitCopyValueOperation(loc, origFnOperand);
  auto *newDiffFn = context.createDifferentiableFunction(
      builder, loc, parameterIndices, resultIndices, origFnOperand,
      std::make_pair(derivativeFns[0], derivativeFns[1]));
  // NOTE(review): this re-queues the input `dfi`, not `newDiffFn`; the caller
  // (`processDifferentiableFunctionInst`) marks `dfi` as processed right
  // after. Confirm that this is intended.
  context.getDifferentiableFunctionInstWorklist().push_back(dfi);
  return newDiffFn;
}

/// Emits, with `builder` at `loc`, a new `linear_function` bundling `lfi`'s
/// original function operand with a transpose function. The operand is
/// copied unless its ownership is `none`.
///
/// For now the transpose operand is `undef` of the expected transpose type.
/// Eventually a real transpose function should be synthesized here.
SILValue DifferentiationTransformer::promoteToLinearFunction(
    LinearFunctionInst *lfi, SILBuilder &builder, SILLocation loc,
    DifferentiationInvoker invoker) {
  SILValue origFn = lfi->getOriginalFunction();
  const bool mustCopy = origFn.getOwnershipKind() != OwnershipKind::None;
  if (mustCopy)
    origFn = builder.emitCopyValueOperation(loc, origFn);

  auto *paramIndices = lfi->getParameterIndices();
  auto origFnTy = origFn->getType().castTo<SILFunctionType>();
  auto transposeTy = origFnTy->getAutoDiffTransposeFunctionType(
      paramIndices, context.getTypeConverter(),
      LookUpConformanceInModule(builder.getModule().getSwiftModule()));
  SILValue undefTranspose = SILUndef::get(
      SILType::getPrimitiveObjectType(transposeTy), builder.getFunction());

  auto *promoted = context.createLinearFunction(builder, loc, paramIndices,
                                                origFn, undefTranspose);
  context.getLinearFunctionInstWorklist().push_back(lfi);
  return promoted;
}

/// Fold `differentiable_function_extract` users of the given
/// `differentiable_function` instruction, directly replacing them with
/// `differentiable_function` instruction operands. If the
/// `differentiable_function` instruction has no remaining uses, delete the
/// instruction itself after folding.
///
/// Folding can be disabled by the `SkipFoldingDifferentiableFunctionExtraction`
/// flag for SIL testing purposes.
// FIXME: This function is not correctly detecting the foldable pattern and
// needs to be rewritten.
void DifferentiationTransformer::foldDifferentiableFunctionExtraction(
    DifferentiableFunctionInst *source) {
  // Snapshot the extractor users first. Erasing an extractor removes its
  // operand from `source`'s use list, which would invalidate an iterator
  // that is walking that list.
  SmallVector<DifferentiableFunctionExtractInst *, 4> extractors;
  for (auto *use : source->getUses())
    if (auto *dfei = dyn_cast<DifferentiableFunctionExtractInst>(use->getUser()))
      extractors.push_back(dfei);

  for (auto *dfei : extractors) {
    // Replace the extractor with the corresponding `source` operand: the
    // original function, or the JVP/VJP for derivative extractors.
    SILValue extracted =
        dfei->getExtractee() ==
                NormalDifferentiableFunctionTypeComponent::Original
            ? source->getOriginalFunction()
            : source->getDerivativeFunction(dfei->getDerivativeFunctionKind());
    dfei->replaceAllUsesWith(extracted);
    dfei->eraseFromParent();
  }
  // If the `differentiable_function` instruction has no remaining uses, erase
  // it.
  if (isInstructionTriviallyDead(source)) {
    SILBuilder builder(source);
    // NOTE(review): these derivative operands are function values; confirm
    // that `emitDestroyAddrAndFold` (an address-oriented helper) is the
    // intended operation here.
    builder.emitDestroyAddrAndFold(source->getLoc(), source->getJVPFunction());
    builder.emitDestroyAddrAndFold(source->getLoc(), source->getVJPFunction());
    source->eraseFromParent();
  }
  // Mark `source` as processed so that it won't be reprocessed after deletion.
  context.markDifferentiableFunctionInstAsProcessed(source);
}

/// Canonicalizes one `differentiable_function` instruction that lacks
/// derivative functions. The steps are:
/// 1. Emit a replacement via `promoteToDifferentiableFunction`.
/// 2. Replace all uses of `dfi` with the replacement.
/// 3. Destroy `dfi`'s original operand and erase `dfi`.
/// 4. Unless `SkipFoldingDifferentiableFunctionExtraction` is set, fold
///    extractor users of the replacement when it is a
///    `differentiable_function`.
///
/// `dfi` is marked processed even when promotion fails.
/// Returns true on error. Returns false on success, and also when `dfi`
/// already has derivative functions (it is then left untouched).
bool DifferentiationTransformer::processDifferentiableFunctionInst(
    DifferentiableFunctionInst *dfi) {
  PrettyStackTraceSILNode dfiTrace("canonicalizing `differentiable_function`",
                                   cast<SILInstruction>(dfi));
  PrettyStackTraceSILFunction fnTrace("...in", dfi->getFunction());
  LLVM_DEBUG({
    auto &s = getADDebugStream() << "Processing DifferentiableFunctionInst:\n";
    dfi->printInContext(s);
  });

  // If `dfi` already has derivative functions, do not process.
  if (dfi->hasDerivativeFunctions())
    return false;

  SILFunction *parent = dfi->getFunction();
  auto loc = dfi->getLoc();
  // New instructions are inserted before `dfi`.
  SILBuilderWithScope builder(dfi);
  auto differentiableFnValue =
      promoteToDifferentiableFunction(dfi, builder, loc, dfi);
  // Mark `dfi` as processed so that it won't be reprocessed after deletion.
  context.markDifferentiableFunctionInstAsProcessed(dfi);
  if (!differentiableFnValue)
    return true;
  // Replace all uses of `dfi`.
  dfi->replaceAllUsesWith(differentiableFnValue);
  // Destroy the original operand.
  builder.emitDestroyValueOperation(loc, dfi->getOriginalFunction());
  dfi->eraseFromParent();
  // If the promoted `@differentiable` function-typed value is an
  // `differentiable_function` instruction, fold
  // `differentiable_function_extract` instructions. If
  // `differentiable_function_extract` folding is disabled, return.
  if (!SkipFoldingDifferentiableFunctionExtraction)
    if (auto *newDFI =
            dyn_cast<DifferentiableFunctionInst>(differentiableFnValue))
      foldDifferentiableFunctionExtraction(newDFI);
  transform.invalidateAnalysis(parent,
                               SILAnalysis::InvalidationKind::FunctionBody);
  return false;
}

/// Canonicalizes one `linear_function` instruction that lacks a transpose
/// function. The steps are:
/// 1. Emit a replacement via `promoteToLinearFunction`.
/// 2. Replace all uses of `lfi` with the replacement.
/// 3. Destroy `lfi`'s original operand and erase `lfi`.
///
/// `lfi` is marked processed even when promotion fails.
/// Returns true on error. Returns false on success, and also when `lfi`
/// already has a transpose function (it is then left untouched).
bool DifferentiationTransformer::processLinearFunctionInst(
    LinearFunctionInst *lfi) {
  PrettyStackTraceSILNode lfiTrace("canonicalizing `linear_function`",
                                   cast<SILInstruction>(lfi));
  PrettyStackTraceSILFunction fnTrace("...in", lfi->getFunction());
  LLVM_DEBUG({
    auto &os = getADDebugStream() << "Processing LinearFunctionInst:\n";
    lfi->printInContext(os);
  });

  // Already canonical: a transpose function is present.
  if (lfi->hasTransposeFunction())
    return false;

  SILFunction *enclosingFn = lfi->getFunction();
  auto instLoc = lfi->getLoc();
  // New instructions are inserted before `lfi`.
  SILBuilderWithScope b(lfi);
  auto promoted = promoteToLinearFunction(lfi, b, instLoc, lfi);
  // Mark before any deletion so the worklist never revisits `lfi`.
  context.markLinearFunctionInstAsProcessed(lfi);
  if (!promoted)
    return true;

  lfi->replaceAllUsesWith(promoted);
  b.emitDestroyValueOperation(instLoc, lfi->getOriginalFunction());
  lfi->eraseFromParent();

  transform.invalidateAnalysis(enclosingFn,
                               SILAnalysis::InvalidationKind::FunctionBody);
  return false;
}

/// Automatic differentiation transform entry.
///
/// The pass runs in these stages:
/// 1. Register every trigger in the module: defined differentiability
///    witnesses, `differentiable_function` instructions, and
///    `linear_function` instructions. Uncanonical `linear_function`s are
///    diagnosed when linear map transposition is disabled.
/// 2. Verify that the stdlib and the `Differentiable` protocol are available.
/// 3. Canonicalize the witnesses, then drain both instruction worklists.
///
/// If any error occurs, generated code is cleaned up via `ADContext::cleanUp`.
void Differentiation::run() {
  auto &module = *getModule();
  auto &astCtx = module.getASTContext();
  debugDump(module);

  // A transformation helper.
  DifferentiationTransformer transformer(*this);
  ADContext &context = transformer.getContext();

  bool errorOccurred = false;

  // Register all the SIL differentiability witnesses in the module that trigger
  // differentiation.
  for (auto &witness : module.getDifferentiabilityWitnesses()) {
    if (witness.isDeclaration())
      continue;
    context.addInvoker(&witness);
  }

  // Register all the `differentiable_function` and `linear_function`
  // instructions in the module that trigger differentiation.
  for (SILFunction &f : module) {
    for (SILBasicBlock &bb : f) {
      for (SILInstruction &i : bb) {
        if (auto *dfi = dyn_cast<DifferentiableFunctionInst>(&i)) {
          context.getDifferentiableFunctionInstWorklist().push_back(dfi);
        } else if (auto *lfi = dyn_cast<LinearFunctionInst>(&i)) {
          // If linear map transposition is not enabled and an uncanonical
          // `linear_function` instruction is encountered, emit a diagnostic.
          // FIXME(SR-11850): Finish support for linear map transposition.
          if (!EnableExperimentalLinearMapTransposition) {
            if (!lfi->hasTransposeFunction()) {
              astCtx.Diags.diagnose(
                lfi->getLoc().getSourceLoc(),
                diag::autodiff_conversion_to_linear_function_not_supported);
              errorOccurred = true;
            }
          }
          context.getLinearFunctionInstWorklist().push_back(lfi);
        }
      }
    }
  }

  // If nothing has triggered differentiation, there's nothing to do.
  if (context.getInvokers().empty() &&
      context.getDifferentiableFunctionInstWorklist().empty() &&
      context.getLinearFunctionInstWorklist().empty())
    return;

  // Differentiation relies on the stdlib (the Swift module).
  // If it's not imported, it's an internal error.
  if (!astCtx.getStdlibModule()) {
    astCtx.Diags.diagnose(SourceLoc(),
                          diag::autodiff_internal_swift_not_imported);
    return;
  }
  if (!astCtx.getProtocol(KnownProtocolKind::Differentiable)) {
    // Report at the location of some trigger. The early return above
    // guarantees that at least one of the three collections is non-empty,
    // but it may be only the `linear_function` worklist.
    SourceLoc loc;
    if (!context.getInvokers().empty()) {
      loc = context.getInvokers().front().second.getLocation();
    } else if (!context.getDifferentiableFunctionInstWorklist().empty()) {
      loc = context.getDifferentiableFunctionInstWorklist()
                .back()
                ->getLoc()
                .getSourceLoc();
    } else {
      assert(!context.getLinearFunctionInstWorklist().empty());
      loc = context.getLinearFunctionInstWorklist()
                .back()
                ->getLoc()
                .getSourceLoc();
    }
    astCtx.Diags.diagnose(loc,
                          diag::autodiff_differentiation_module_not_imported);
    return;
  }

  // Process all invokers.
  for (auto invokerPair : context.getInvokers()) {
    auto *witness = invokerPair.first;
    auto *original = witness->getOriginalFunction();
    auto invoker = invokerPair.second;

    if (transformer.canonicalizeDifferentiabilityWitness(
            original, witness, invoker, original->isSerialized()))
      errorOccurred = true;
  }

  // Iteratively process `differentiable_function` instruction worklist.
  while (!context.getDifferentiableFunctionInstWorklist().empty()) {
    auto *dfi = context.getDifferentiableFunctionInstWorklist().pop_back_val();
    // Skip instructions that have already been processed.
    if (context.isDifferentiableFunctionInstProcessed(dfi))
      continue;
    errorOccurred |= transformer.processDifferentiableFunctionInst(dfi);
  }

  // Iteratively process `linear_function` instruction worklist.
  while (!context.getLinearFunctionInstWorklist().empty()) {
    auto *lfi = context.getLinearFunctionInstWorklist().pop_back_val();
    // Skip instructions that have already been processed.
    if (context.isLinearFunctionInstProcessed(lfi))
      continue;
    errorOccurred |= transformer.processLinearFunctionInst(lfi);
  }

  // If any error occurred while processing witnesses, `differentiable_function`
  // instructions, or `linear_function` instructions, clean up.
  if (errorOccurred) {
    context.cleanUp();
    return;
  }

  LLVM_DEBUG(getADDebugStream() << "All differentiation finished\n");
}

//===----------------------------------------------------------------------===//
// Pass creation
//===----------------------------------------------------------------------===//

/// Factory used by the pass manager to instantiate the differentiation pass.
SILTransform *swift::createDifferentiation() {
  auto *pass = new Differentiation;
  return pass;
}
