Merge pull request #27821 from apple/tensorflow-merge2

Merge swift-DEVELOPMENT-SNAPSHOT-2019-10-13-a into tensorflow
diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h
index d792b6e..ed94123 100644
--- a/include/swift/AST/AutoDiff.h
+++ b/include/swift/AST/AutoDiff.h
@@ -269,11 +269,11 @@
 struct AutoDiffConfig {
   IndexSubset *parameterIndices;
   IndexSubset *resultIndices;
-  GenericSignatureImpl* derivativeGenericSignature;
+  GenericSignature derivativeGenericSignature;
 
   /*implicit*/ AutoDiffConfig(IndexSubset *parameterIndices,
                               IndexSubset *resultIndices,
-                              GenericSignatureImpl *derivativeGenericSignature)
+                              GenericSignature derivativeGenericSignature)
       : parameterIndices(parameterIndices), resultIndices(resultIndices),
         derivativeGenericSignature(derivativeGenericSignature) {}
 
@@ -443,7 +443,7 @@
 
 using swift::AutoDiffConfig;
 using swift::AutoDiffDerivativeFunctionKind;
-using swift::GenericSignatureImpl;
+using swift::GenericSignature;
 using swift::IndexSubset;
 using swift::SILAutoDiffIndices;
 
@@ -453,27 +453,29 @@
   static AutoDiffConfig getEmptyKey() {
     auto *ptr = llvm::DenseMapInfo<void *>::getEmptyKey();
     return {static_cast<IndexSubset *>(ptr), static_cast<IndexSubset *>(ptr),
-            static_cast<GenericSignatureImpl *>(ptr)};
+            DenseMapInfo<GenericSignature>::getEmptyKey()};
   }
 
   static AutoDiffConfig getTombstoneKey() {
     auto *ptr = llvm::DenseMapInfo<void *>::getTombstoneKey();
     return {static_cast<IndexSubset *>(ptr), static_cast<IndexSubset *>(ptr),
-            static_cast<GenericSignatureImpl *>(ptr)};
+            DenseMapInfo<GenericSignature>::getTombstoneKey()};
   }
 
   static unsigned getHashValue(const AutoDiffConfig &Val) {
     unsigned combinedHash = hash_combine(
         ~1U, DenseMapInfo<void *>::getHashValue(Val.parameterIndices),
         DenseMapInfo<void *>::getHashValue(Val.resultIndices),
-        DenseMapInfo<void *>::getHashValue(Val.derivativeGenericSignature));
+        DenseMapInfo<GenericSignature>::getHashValue(
+            Val.derivativeGenericSignature));
     return combinedHash;
   }
 
   static bool isEqual(const AutoDiffConfig &LHS, const AutoDiffConfig &RHS) {
     return LHS.parameterIndices == RHS.parameterIndices &&
         LHS.resultIndices == RHS.resultIndices &&
-        LHS.derivativeGenericSignature == RHS.derivativeGenericSignature;
+        DenseMapInfo<GenericSignature>::isEqual(LHS.derivativeGenericSignature,
+                                                RHS.derivativeGenericSignature);
   }
 };
 
diff --git a/include/swift/AST/DiagnosticsParse.def b/include/swift/AST/DiagnosticsParse.def
index 57f88c4..98cd135 100644
--- a/include/swift/AST/DiagnosticsParse.def
+++ b/include/swift/AST/DiagnosticsParse.def
@@ -689,6 +689,8 @@
 // SIL differentiability witnesses
 ERROR(sil_diff_witness_expected_token,PointsToFirstBadToken,
       "expected '%0' in differentiability witness", (StringRef))
+ERROR(sil_diff_witness_serialized_declaration,none,
+      "differentiability witness declaration should not be serialized", ())
 
 // SIL Coverage Map
 ERROR(sil_coverage_func_not_found, none,
diff --git a/include/swift/AST/IndexSubset.h b/include/swift/AST/IndexSubset.h
index 3bdb021..e2b8a98 100644
--- a/include/swift/AST/IndexSubset.h
+++ b/include/swift/AST/IndexSubset.h
@@ -57,6 +57,10 @@
   /// The number of bit words in the index subset.
   unsigned numBitWords;
 
+  static unsigned getNumBytesNeededForCapacity(unsigned capacity) {
+    return getNumBitWordsNeededForCapacity(capacity) * bitWordSize;
+  }
+
   BitWord *getBitWordsData() {
     return reinterpret_cast<BitWord *>(this + 1);
   }
diff --git a/include/swift/Basic/LangOptions.h b/include/swift/Basic/LangOptions.h
index be1fa2e..1a893a5 100644
--- a/include/swift/Basic/LangOptions.h
+++ b/include/swift/Basic/LangOptions.h
@@ -331,6 +331,11 @@
     /// `@differentiable` declaration attribute, etc.
     bool EnableExperimentalDifferentiableProgramming = false;
 
+    // SWIFT_ENABLE_TENSORFLOW
+    /// Whether to enable forward mode differentiation.
+    bool EnableExperimentalForwardModeDifferentiation = false;
+    // SWIFT_ENABLE_TENSORFLOW END
+
     /// Whether to enable #quote, #unquote and @quoted.
     bool EnableExperimentalQuasiquotes = false;
 
diff --git a/include/swift/Option/Options.td b/include/swift/Option/Options.td
index 9b6cea3..a703542 100644
--- a/include/swift/Option/Options.td
+++ b/include/swift/Option/Options.td
@@ -428,6 +428,12 @@
 def enable_experimental_differentiable_programming : Flag<["-"], "enable-experimental-differentiable-programming">,
   Flags<[FrontendOption]>,
   HelpText<"Enable experimental differentiable programming features">;
+// SWIFT_ENABLE_TENSORFLOW
+// NOTE: This flag will be removed when JVP/differential generation is robust.
+def enable_experimental_forward_mode_differentiation : Flag<["-"], "enable-experimental-forward-mode-differentiation">,
+  Flags<[FrontendOption]>,
+  HelpText<"Enable experimental forward mode differentiation">;
+// SWIFT_ENABLE_TENSORFLOW END
 
 def enable_experimental_quasiquotes : Flag<["-"], "enable-experimental-quasiquotes">,
   Flags<[FrontendOption]>,
diff --git a/include/swift/SIL/SILDifferentiabilityWitness.h b/include/swift/SIL/SILDifferentiabilityWitness.h
index 0a5ab1f..f37459f 100644
--- a/include/swift/SIL/SILDifferentiabilityWitness.h
+++ b/include/swift/SIL/SILDifferentiabilityWitness.h
@@ -43,26 +43,28 @@
 {
 private:
   /// The module which contains the differentiability witness.
-  SILModule &module;
+  SILModule &Module;
   /// The linkage of the differentiability witness.
-  SILLinkage linkage;
+  SILLinkage Linkage;
   /// The original function.
-  SILFunction *originalFunction;
+  SILFunction *OriginalFunction;
   /// The autodiff configuration: parameter indices, result indices, derivative
   /// generic signature (optional).
-  AutoDiffConfig config;
+  AutoDiffConfig Config;
   /// The JVP (Jacobian-vector products) derivative function.
-  SILFunction *jvp;
+  SILFunction *JVP;
   /// The VJP (vector-Jacobian products) derivative function.
-  SILFunction *vjp;
+  SILFunction *VJP;
+  /// Whether or not this differentiability witness is a declaration.
+  bool IsDeclaration;
   /// Whether or not this differentiability witness is serialized, which allows
   /// devirtualization from another module.
-  bool serialized;
+  bool IsSerialized;
   /// The AST `@differentiable` or `@differentiating` attribute from which the
   /// differentiability witness is generated. Used for diagnostics.
   /// Null if the differentiability witness is parsed from SIL or if it is
   /// deserialized.
-  DeclAttribute *attribute = nullptr;
+  DeclAttribute *Attribute = nullptr;
 
   SILDifferentiabilityWitness(SILModule &module, SILLinkage linkage,
                               SILFunction *originalFunction,
@@ -70,51 +72,60 @@
                               IndexSubset *resultIndices,
                               GenericSignature derivativeGenSig,
                               SILFunction *jvp, SILFunction *vjp,
-                              bool isSerialized, DeclAttribute *attribute)
-    : module(module), linkage(linkage), originalFunction(originalFunction),
-      config(parameterIndices, resultIndices, derivativeGenSig.getPointer()),
-      jvp(jvp), vjp(vjp), serialized(isSerialized), attribute(attribute) {}
+                              bool isDeclaration, bool isSerialized,
+                              DeclAttribute *attribute)
+    : Module(module), Linkage(linkage), OriginalFunction(originalFunction),
+      Config(parameterIndices, resultIndices, derivativeGenSig),
+      JVP(jvp), VJP(vjp), IsDeclaration(isDeclaration),
+      IsSerialized(isSerialized), Attribute(attribute) {}
 
 public:
-  static SILDifferentiabilityWitness *create(
+  static SILDifferentiabilityWitness *createDeclaration(
+      SILModule &module, SILLinkage linkage, SILFunction *originalFunction,
+      IndexSubset *parameterIndices, IndexSubset *resultIndices,
+      GenericSignature derivativeGenSig, DeclAttribute *attribute = nullptr);
+
+  static SILDifferentiabilityWitness *createDefinition(
       SILModule &module, SILLinkage linkage, SILFunction *originalFunction,
       IndexSubset *parameterIndices, IndexSubset *resultIndices,
       GenericSignature derivativeGenSig, SILFunction *jvp, SILFunction *vjp,
       bool isSerialized, DeclAttribute *attribute = nullptr);
 
   SILDifferentiabilityWitnessKey getKey() const;
-  SILModule &getModule() const { return module; }
-  SILLinkage getLinkage() const { return linkage; }
-  SILFunction *getOriginalFunction() const { return originalFunction; }
-  const AutoDiffConfig &getConfig() const { return config; }
+  SILModule &getModule() const { return Module; }
+  SILLinkage getLinkage() const { return Linkage; }
+  SILFunction *getOriginalFunction() const { return OriginalFunction; }
+  const AutoDiffConfig &getConfig() const { return Config; }
   IndexSubset *getParameterIndices() const {
-    return config.parameterIndices;
+    return Config.parameterIndices;
   }
   IndexSubset *getResultIndices() const {
-    return config.resultIndices;
+    return Config.resultIndices;
   }
   GenericSignature getDerivativeGenericSignature() const {
-    return config.derivativeGenericSignature;
+    return Config.derivativeGenericSignature;
   }
-  SILFunction *getJVP() const { return jvp; }
-  SILFunction *getVJP() const { return vjp; }
+  SILFunction *getJVP() const { return JVP; }
+  SILFunction *getVJP() const { return VJP; }
   SILFunction *getDerivative(AutoDiffDerivativeFunctionKind kind) const {
     switch (kind) {
-    case AutoDiffDerivativeFunctionKind::JVP: return jvp;
-    case AutoDiffDerivativeFunctionKind::VJP: return vjp;
+    case AutoDiffDerivativeFunctionKind::JVP: return JVP;
+    case AutoDiffDerivativeFunctionKind::VJP: return VJP;
     }
   }
-  void setJVP(SILFunction *jvp) { this->jvp = jvp; }
-  void setVJP(SILFunction *vjp) { this->vjp = vjp; }
+  void setJVP(SILFunction *jvp) { JVP = jvp; }
+  void setVJP(SILFunction *vjp) { VJP = vjp; }
   void setDerivative(AutoDiffDerivativeFunctionKind kind,
                      SILFunction *derivative) {
     switch (kind) {
-    case AutoDiffDerivativeFunctionKind::JVP: jvp = derivative; break;
-    case AutoDiffDerivativeFunctionKind::VJP: vjp = derivative; break;
+    case AutoDiffDerivativeFunctionKind::JVP: JVP = derivative; break;
+    case AutoDiffDerivativeFunctionKind::VJP: VJP = derivative; break;
     }
   }
-  bool isSerialized() const { return serialized; }
-  DeclAttribute *getAttribute() const { return attribute; }
+  bool isDeclaration() const { return IsDeclaration; }
+  bool isDefinition() const { return !IsDeclaration; }
+  bool isSerialized() const { return IsSerialized; }
+  DeclAttribute *getAttribute() const { return Attribute; }
 
   /// Verify that the differentiability witness is well-formed.
   void verify(const SILModule &module) const;
diff --git a/lib/AST/ASTContext.cpp b/lib/AST/ASTContext.cpp
index fc9ba85..8905b96 100644
--- a/lib/AST/ASTContext.cpp
+++ b/lib/AST/ASTContext.cpp
@@ -4821,7 +4821,7 @@
   if (existing)
     return existing;
   auto sizeToAlloc = sizeof(IndexSubset) +
-      getNumBitWordsNeededForCapacity(capacity);
+      getNumBytesNeededForCapacity(capacity);
   auto *buf = reinterpret_cast<IndexSubset *>(
       ctx.Allocate(sizeToAlloc, alignof(IndexSubset)));
   auto *newNode = new (buf) IndexSubset(indices);
diff --git a/lib/AST/ASTMangler.cpp b/lib/AST/ASTMangler.cpp
index 6871ce1..04d8b64 100644
--- a/lib/AST/ASTMangler.cpp
+++ b/lib/AST/ASTMangler.cpp
@@ -434,7 +434,7 @@
   auto originalName = key.first;
   auto *parameterIndices = key.second.parameterIndices;
   auto *resultIndices = key.second.resultIndices;
-  auto *derivativeGenericSignature = key.second.derivativeGenericSignature;
+  auto derivativeGenericSignature = key.second.derivativeGenericSignature;
 
   Buffer << "AD__" << originalName << '_';
   Buffer << "P" << parameterIndices->getString();
diff --git a/lib/Driver/ToolChains.cpp b/lib/Driver/ToolChains.cpp
index efd41ed..d8a3548 100644
--- a/lib/Driver/ToolChains.cpp
+++ b/lib/Driver/ToolChains.cpp
@@ -231,6 +231,10 @@
                        options::OPT_enable_experimental_dependencies);
   inputArgs.AddLastArg(arguments,
                        options::OPT_experimental_dependency_include_intrafile);
+  // SWIFT_ENABLE_TENSORFLOW
+  inputArgs.AddLastArg(
+      arguments, options::OPT_enable_experimental_forward_mode_differentiation);
+  // SWIFT_ENABLE_TENSORFLOW END
   inputArgs.AddLastArg(arguments,
                        options::OPT_enable_experimental_quasiquotes);
   inputArgs.AddLastArg(arguments, options::OPT_package_description_version);
diff --git a/lib/Frontend/CompilerInvocation.cpp b/lib/Frontend/CompilerInvocation.cpp
index d2b9616..c07065a 100644
--- a/lib/Frontend/CompilerInvocation.cpp
+++ b/lib/Frontend/CompilerInvocation.cpp
@@ -358,6 +358,12 @@
   if (Args.hasArg(OPT_experimental_dependency_include_intrafile))
     Opts.ExperimentalDependenciesIncludeIntrafileOnes = true;
 
+  // SWIFT_ENABLE_TENSORFLOW
+  // TODO: Ignore if enable-experimental-differentiable-programming is false.
+  Opts.EnableExperimentalForwardModeDifferentiation |=
+      Args.hasArg(OPT_enable_experimental_forward_mode_differentiation);
+  // SWIFT_ENABLE_TENSORFLOW END
+
   if (Args.hasArg(OPT_enable_experimental_quasiquotes))
     Opts.EnableExperimentalQuasiquotes = true;
 
diff --git a/lib/IRGen/GenFunc.cpp b/lib/IRGen/GenFunc.cpp
index bdd375f..f25ffb2 100644
--- a/lib/IRGen/GenFunc.cpp
+++ b/lib/IRGen/GenFunc.cpp
@@ -735,6 +735,16 @@
   // Create a new explosion for potentially reabstracted parameters.
   Explosion args;
 
+  // SWIFT_ENABLE_TENSORFLOW
+  // The witness method self argument comes after polymorphic arguments (and is
+  // followed by the self type and the witness table). However, we may encounter
+  // the witness method self value before reaching the polymorphic arguments. So
+  // we create a special explosion for storing the witness method self value
+  // until it's time to add it to 'args'.
+  bool isWitnessMethodCallee = origType->getRepresentation() ==
+                               SILFunctionTypeRepresentation::WitnessMethod;
+  Explosion witnessMethodSelfValue;
+
   Address resultValueAddr;
 
   {
@@ -775,6 +785,10 @@
     
     // Reemit the parameters as unsubstituted.
     for (unsigned i = 0; i < outType->getParameters().size(); ++i) {
+      // SWIFT_ENABLE_TENSORFLOW
+      bool isWitnessMethodCalleeSelf =
+          (isWitnessMethodCallee && i + 1 == origType->getParameters().size());
+
       auto origParamInfo = origType->getParameters()[i];
       auto &ti = IGM.getTypeInfoForLowered(origParamInfo.getType());
       auto schema = ti.getSchema();
@@ -788,7 +802,8 @@
         if (addr->getType() != ti.getStorageType()->getPointerTo())
           addr = subIGF.Builder.CreateBitCast(addr,
                                            ti.getStorageType()->getPointerTo());
-        args.add(addr);
+        // SWIFT_ENABLE_TENSORFLOW
+        (isWitnessMethodCalleeSelf ? witnessMethodSelfValue : args).add(addr);
         continue;
       }
 
@@ -796,8 +811,10 @@
       // Indirect parameters need no mapping through the native calling
       // convention.
       if (isIndirectParam) {
-        emitApplyArgument(subIGF, origParamInfo, outTypeParamInfo, origParams,
-                          args);
+        emitApplyArgument(
+            subIGF, origParamInfo, outTypeParamInfo, origParams,
+            // SWIFT_ENABLE_TENSORFLOW
+            (isWitnessMethodCalleeSelf ? witnessMethodSelfValue : args));
         continue;
       }
 
@@ -824,7 +841,10 @@
       Explosion nativeApplyArg = nativeSchemaOrigParam.mapIntoNative(
           subIGF.IGM, subIGF, nonNativeApplyArg, origParamSILType, false);
       assert(nonNativeApplyArg.empty());
-      nativeApplyArg.transferInto(args, nativeApplyArg.size());
+      // SWIFT_ENABLE_TENSORFLOW
+      nativeApplyArg.transferInto(
+          (isWitnessMethodCalleeSelf ? witnessMethodSelfValue : args),
+          nativeApplyArg.size());
     }
   }
 
@@ -934,13 +954,6 @@
   auto haveContextArgument =
       calleeHasContext || hasSelfContextParameter(origType);
 
-  // Witness method calls expect self, followed by the self type followed by,
-  // the witness table at the end of the parameter list. But polymorphic
-  // arguments come before this.
-  bool isWitnessMethodCallee = origType->getRepresentation() ==
-      SILFunctionTypeRepresentation::WitnessMethod;
-  Explosion witnessMethodSelfValue;
-
   // If there's a data pointer required, but it's a swift-retainable
   // value being passed as the context, just forward it down.
   if (!layout) {
diff --git a/lib/ParseSIL/ParseSIL.cpp b/lib/ParseSIL/ParseSIL.cpp
index f12aa7d..50ba2f3 100644
--- a/lib/ParseSIL/ParseSIL.cpp
+++ b/lib/ParseSIL/ParseSIL.cpp
@@ -6933,7 +6933,9 @@
 ///   '[' 'parameters' index-subset ']'
 ///   '[' 'results' index-subset ']'
 ///   ('[' 'where' derivatve-generic-signature-requirements ']')?
-///   sil-function-name ':' sil-type
+///   decl-sil-differentiability-witness-body?
+///
+/// decl-sil-differentiability-witness-body ::=
 ///   '{'
 ///   ('jvp' sil-function-name ':' sil-type)?
 ///   ('vjp' sil-function-name ':' sil-type)?
@@ -6949,9 +6951,6 @@
   Optional<SILLinkage> linkage;
   if (parseSILLinkage(linkage, P))
     return true;
-  // Default to public linkage.
-  if (!linkage)
-    linkage = SILLinkage::Public;
 
   // Parse '[serialized]' flag (optional).
   bool isSerialized = false;
@@ -6986,8 +6985,7 @@
       P.diagnose(fnNameLoc, diag::expected_sil_function_type);
       return true;
     }
-    fn = State.getGlobalNameForReference(name, fnType, fnNameLoc, true);
-    State.TUState.PotentialZombieFns.insert(fn);
+    fn = State.getGlobalNameForReference(name, fnType, fnNameLoc);
     return false;
   };
 
@@ -7063,7 +7061,26 @@
             nullptr);
   }
 
-  // Parse differentiability witness body.
+  auto origFnType = originalFn->getLoweredFunctionType();
+  auto *parameterIndexSet = IndexSubset::get(
+      P.Context, origFnType->getNumParameters(), parameterIndices);
+  auto *resultIndexSet = IndexSubset::get(
+      P.Context, origFnType->getNumResults(), resultIndices);
+
+  // If this is just a declaration, create the declaration now and return.
+  if (!P.Tok.is(tok::l_brace)) {
+    if (isSerialized) {
+      P.diagnose(lastLoc, diag::sil_diff_witness_serialized_declaration);
+      return true;
+    }
+
+    SILDifferentiabilityWitness::createDeclaration(
+        M, linkage ? *linkage : SILLinkage::DefaultForDeclaration, originalFn,
+        parameterIndexSet, resultIndexSet, derivativeGenSig);
+    return false;
+  }
+
+  // This is a definition, so parse differentiability witness body.
   SILFunction *jvp = nullptr;
   SILFunction *vjp = nullptr;
   if (P.Tok.is(tok::l_brace)) {
@@ -7094,14 +7111,10 @@
       return true;
   }
 
-  auto origFnType = originalFn->getLoweredFunctionType();
-  auto *parameterIndexSet = IndexSubset::get(
-      P.Context, origFnType->getNumParameters(), parameterIndices);
-  auto *resultIndexSet = IndexSubset::get(
-      P.Context, origFnType->getNumResults(), resultIndices);
-  SILDifferentiabilityWitness::create(
-      M, *linkage, originalFn, parameterIndexSet, resultIndexSet,
-      derivativeGenSig, jvp, vjp, isSerialized);
+  SILDifferentiabilityWitness::createDefinition(
+      M, linkage ? *linkage : SILLinkage::DefaultForDefinition, originalFn,
+      parameterIndexSet, resultIndexSet, derivativeGenSig, jvp, vjp,
+      isSerialized);
   return false;
 }
 // SWIFT_ENABLE_TENSORFLOW END
diff --git a/lib/SIL/SILDifferentiabilityWitness.cpp b/lib/SIL/SILDifferentiabilityWitness.cpp
index 0901692..2ad7bad 100644
--- a/lib/SIL/SILDifferentiabilityWitness.cpp
+++ b/lib/SIL/SILDifferentiabilityWitness.cpp
@@ -17,14 +17,14 @@
 
 using namespace swift;
 
-SILDifferentiabilityWitness *SILDifferentiabilityWitness::create(
+SILDifferentiabilityWitness *SILDifferentiabilityWitness::createDeclaration(
     SILModule &module, SILLinkage linkage, SILFunction *originalFunction,
     IndexSubset *parameterIndices, IndexSubset *resultIndices,
-    GenericSignature derivativeGenSig, SILFunction *jvp, SILFunction *vjp,
-    bool isSerialized, DeclAttribute *attribute) {
+    GenericSignature derivativeGenSig, DeclAttribute *attribute) {
   auto *diffWitness = new (module) SILDifferentiabilityWitness(
       module, linkage, originalFunction, parameterIndices, resultIndices,
-      derivativeGenSig, jvp, vjp, isSerialized, attribute);
+      derivativeGenSig, /*jvp*/ nullptr, /*vjp*/ nullptr,
+      /*isDeclaration*/ true, /*isSerialized*/ false, attribute);
   // Register the differentiability witness in the module.
   assert(!module.DifferentiabilityWitnessMap.count(diffWitness->getKey()) &&
          "Cannot create duplicate differentiability witness in a module");
@@ -33,6 +33,23 @@
   return diffWitness;
 }
 
+SILDifferentiabilityWitness *SILDifferentiabilityWitness::createDefinition(
+    SILModule &module, SILLinkage linkage, SILFunction *originalFunction,
+    IndexSubset *parameterIndices, IndexSubset *resultIndices,
+    GenericSignature derivativeGenSig, SILFunction *jvp, SILFunction *vjp,
+    bool isSerialized, DeclAttribute *attribute) {
+  auto *diffWitness = new (module) SILDifferentiabilityWitness(
+      module, linkage, originalFunction, parameterIndices, resultIndices,
+      derivativeGenSig, jvp, vjp, /*isDeclaration*/ false, isSerialized,
+      attribute);
+  // Register the differentiability witness in the module.
+  assert(!module.DifferentiabilityWitnessMap.count(diffWitness->getKey()) &&
+         "Cannot create duplicate differentiability witness in a module");
+  module.DifferentiabilityWitnessMap[diffWitness->getKey()] = diffWitness;
+  module.getDifferentiabilityWitnessList().push_back(diffWitness);
+  return diffWitness;
+}
+
 SILDifferentiabilityWitnessKey SILDifferentiabilityWitness::getKey() const {
-  return std::make_pair(originalFunction->getName(), getConfig());
+  return std::make_pair(getOriginalFunction()->getName(), getConfig());
 }
diff --git a/lib/SIL/SILPrinter.cpp b/lib/SIL/SILPrinter.cpp
index 653e127..2c2e73f 100644
--- a/lib/SIL/SILPrinter.cpp
+++ b/lib/SIL/SILPrinter.cpp
@@ -3164,11 +3164,11 @@
 void SILDifferentiabilityWitness::print(
     llvm::raw_ostream &OS, bool verbose) const {
   OS << "// differentiability witness for "
-     << demangleSymbol(originalFunction->getName()) << '\n';
+     << demangleSymbol(getOriginalFunction()->getName()) << '\n';
   PrintOptions qualifiedSILTypeOptions = PrintOptions::printQualifiedSILType();
   // sil_differentiability_witness (linkage)?
   OS << "sil_differentiability_witness ";
-  printLinkage(OS, linkage, ForDefinition);
+  printLinkage(OS, getLinkage(), /*isDefinition*/ isDefinition());
   // ([serialized])?
   if (isSerialized())
     OS << "[serialized] ";
@@ -3187,7 +3187,7 @@
   if (auto derivativeGenSig = getDerivativeGenericSignature()) {
     ArrayRef<Requirement> requirements;
     SmallVector<Requirement, 4> requirementsScratch;
-    auto *origGenEnv = originalFunction->getGenericEnvironment();
+    auto *origGenEnv = getOriginalFunction()->getGenericEnvironment();
     if (derivativeGenSig) {
       if (origGenEnv) {
         requirementsScratch = derivativeGenSig->requirementsNotSatisfiedBy(
@@ -3210,18 +3210,22 @@
     }
   }
   // @original-function-name : $original-sil-type
-  printSILFunctionNameAndType(OS, originalFunction);
+  printSILFunctionNameAndType(OS, getOriginalFunction());
+
+  if (isDeclaration())
+    return;
+
   // {
   //   jvp: @jvp-function-name : $jvp-sil-type
   //   vjp: @vjp-function-name : $vjp-sil-type
   // }
   OS << " {\n";
-  if (jvp) {
+  if (auto *jvp = getJVP()) {
     OS << "  jvp: ";
     printSILFunctionNameAndType(OS, jvp);
     OS << '\n';
   }
-  if (vjp) {
+  if (auto *vjp = getVJP()) {
     OS << "  vjp: ";
     printSILFunctionNameAndType(OS, vjp);
     OS << '\n';
diff --git a/lib/SIL/SILVerifier.cpp b/lib/SIL/SILVerifier.cpp
index 5809f5b..deb21f5 100644
--- a/lib/SIL/SILVerifier.cpp
+++ b/lib/SIL/SILVerifier.cpp
@@ -5388,7 +5388,7 @@
   if (!M.getOptions().VerifyAll)
     return;
 #endif
-  auto origFnType = originalFunction->getLoweredFunctionType();
+  auto origFnType = getOriginalFunction()->getLoweredFunctionType();
   CanGenericSignature derivativeCanGenSig;
   if (auto derivativeGenSig = getDerivativeGenericSignature())
     derivativeCanGenSig = derivativeGenSig->getCanonicalSignature();
@@ -5408,7 +5408,7 @@
     else
       exit(1);
   };
-  if (jvp) {
+  if (auto *jvp = getJVP()) {
     // TODO(TF-893): Change `SILFunctionType::getAutoDiffDerivativeFunctionType`
     // to accept result indices.
     auto expectedJVPType = origFnType->getAutoDiffDerivativeFunctionType(
@@ -5418,7 +5418,7 @@
     requireSameType(jvp->getLoweredFunctionType(), expectedJVPType,
                     "JVP type does not match expected JVP type");
   }
-  if (vjp) {
+  if (auto *vjp = getVJP()) {
     // TODO(TF-893): Change `SILFunctionType::getAutoDiffDerivativeFunctionType`
     // to result indices.
     auto expectedVJPType = origFnType->getAutoDiffDerivativeFunctionType(
diff --git a/lib/SILGen/SILGen.cpp b/lib/SILGen/SILGen.cpp
index 15ec27f..3a62d72 100644
--- a/lib/SILGen/SILGen.cpp
+++ b/lib/SILGen/SILGen.cpp
@@ -810,7 +810,7 @@
   bool reorderSelf = shouldReorderSelf();
 
   CanGenericSignature derivativeCanGenSig;
-  if (auto *derivativeGenSig = config.derivativeGenericSignature)
+  if (auto derivativeGenSig = config.derivativeGenericSignature)
     derivativeCanGenSig = derivativeGenSig->getCanonicalSignature();
   // TODO(TF-835): Use simpler derivative generic signature logic below when
   // type-checking no longer generates implicit `@differentiable` attributes.
@@ -832,7 +832,7 @@
   // TODO(TF-919): Explore creating serialized differentiability witnesses.
   // Currently, differentiability witnesses are never serialized to avoid
   // deserialization issues where JVP/VJP functions cannot be found.
-  auto *diffWitness = SILDifferentiabilityWitness::create(
+  auto *diffWitness = SILDifferentiabilityWitness::createDefinition(
       M, originalFunction->getLinkage(), originalFunction,
       loweredParamIndices, config.resultIndices, derivativeCanGenSig,
       /*jvp*/ nullptr, /*vjp*/ nullptr, /*isSerialized*/ false);
diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp
index 575fbe7..166c21b 100644
--- a/lib/SILOptimizer/Mandatory/Differentiation.cpp
+++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp
@@ -65,12 +65,6 @@
     "differentiation-skip-folding-differentiable-function-extraction",
     llvm::cl::init(true));
 
-/// This flag is used to enable full JVP generation.
-/// It will be removed when JVP/differential generation is robust.
-static llvm::cl::opt<bool> RunJVPGeneration(
-    "run-jvp-generation",
-    llvm::cl::init(false));
-
 //===----------------------------------------------------------------------===//
 // Helpers
 //===----------------------------------------------------------------------===//
@@ -1824,18 +1818,18 @@
 
 class DifferentiableActivityCollection {
 public:
-  SmallDenseMap<GenericSignatureImpl *, DifferentiableActivityInfo> activityInfoMap;
+  SmallDenseMap<GenericSignature, DifferentiableActivityInfo> activityInfoMap;
   SILFunction &function;
   DominanceInfo *domInfo;
   PostDominanceInfo *postDomInfo;
 
   DifferentiableActivityInfo &getActivityInfo(
       GenericSignature assocGenSig, AutoDiffDerivativeFunctionKind kind) {
-    auto activityInfoLookup = activityInfoMap.find(assocGenSig.getPointer());
+    auto activityInfoLookup = activityInfoMap.find(assocGenSig);
     if (activityInfoLookup != activityInfoMap.end())
       return activityInfoLookup->getSecond();
     auto insertion = activityInfoMap.insert(
-        {assocGenSig.getPointer(), DifferentiableActivityInfo(*this, assocGenSig)});
+        {assocGenSig, DifferentiableActivityInfo(*this, assocGenSig)});
     return insertion.first->getSecond();
   }
 
@@ -8031,8 +8025,9 @@
     // Diagnose:
     // - Functions with no return.
     // - Functions with unsupported control flow.
-    if (RunJVPGeneration && (diagnoseNoReturn(*this, original, invoker) ||
-        diagnoseUnsupportedControlFlow(*this, original, invoker)))
+    if (getASTContext().LangOpts.EnableExperimentalForwardModeDifferentiation &&
+        (diagnoseNoReturn(*this, original, invoker) ||
+         diagnoseUnsupportedControlFlow(*this, original, invoker)))
       return true;
 
     jvp = createEmptyJVP(*this, original, attr, isDerivativeFnExported);
@@ -8042,7 +8037,8 @@
     // does not exist. If custom VJP exists but custom JVP does not, skip JVP
     // generation because generated JVP may not match semantics of custom VJP.
     // Instead, create an empty JVP.
-    if (RunJVPGeneration && !vjp) {
+    if (getASTContext().LangOpts.EnableExperimentalForwardModeDifferentiation &&
+        !vjp) {
       // JVP and differential generation do not currently support functions with
       // multiple basic blocks.
       if (original->getBlocks().size() > 1) {
diff --git a/lib/Serialization/DeserializeSIL.cpp b/lib/Serialization/DeserializeSIL.cpp
index 5b94d78..c0b656b 100644
--- a/lib/Serialization/DeserializeSIL.cpp
+++ b/lib/Serialization/DeserializeSIL.cpp
@@ -3407,14 +3407,19 @@
   (void)kind;
 
   DeclID originalNameId, jvpNameId, vjpNameId;
-  unsigned rawLinkage, isSerialized, numParameterIndices, numResultIndices;
+  unsigned rawLinkage, isDeclaration, isSerialized, numParameterIndices,
+           numResultIndices;
   GenericSignatureID derivativeGenSigID;
   ArrayRef<uint64_t> rawParameterAndResultIndices;
 
   DifferentiabilityWitnessLayout::readRecord(
-      scratch, originalNameId, rawLinkage, isSerialized, derivativeGenSigID,
-      jvpNameId, vjpNameId, numParameterIndices, numResultIndices,
-      rawParameterAndResultIndices);
+      scratch, originalNameId, rawLinkage, isDeclaration, isSerialized,
+      derivativeGenSigID, jvpNameId, vjpNameId, numParameterIndices,
+      numResultIndices, rawParameterAndResultIndices);
+
+  if (isDeclaration) {
+    assert(!isSerialized && "declaration must not be serialized");
+  }
 
   auto linkage = fromStableSILLinkage(rawLinkage);
   assert(linkage && "Expected value linkage for sil_differentiability_witness");
@@ -3424,11 +3429,15 @@
   auto *original = getFuncForReference(originalName);
   assert(original && "Original function must be found");
   auto *jvp = getFuncForReference(jvpName);
-  if (!jvpName.empty())
+  if (!jvpName.empty()) {
+    assert(!isDeclaration && "JVP must not be defined in declaration");
     assert(jvp && "JVP function must be found if JVP name is not empty");
+  }
   auto *vjp = getFuncForReference(vjpName);
-  if (!vjpName.empty())
+  if (!vjpName.empty()) {
+    assert(!isDeclaration && "VJP must not be defined in declaration");
     assert(vjp && "VJP function must be found if VJP name is not empty");
+  }
   auto derivativeGenSig = MF->getGenericSignature(derivativeGenSigID);
 
   SmallVector<unsigned, 8> parameterAndResultIndices(
@@ -3446,7 +3455,15 @@
       ArrayRef<unsigned>(parameterAndResultIndices)
           .take_back(numResultIndices));
 
-  auto *diffWitness = SILDifferentiabilityWitness::create(
+  if (isDeclaration) {
+    auto *diffWitness = SILDifferentiabilityWitness::createDeclaration(
+        SILMod, *linkage, original, parameterIndices, resultIndices,
+        derivativeGenSig);
+    diffWitnessOrOffset.set(diffWitness, /*isFullyDeserialized*/ false);
+    return diffWitness;
+  }
+
+  auto *diffWitness = SILDifferentiabilityWitness::createDefinition(
       SILMod, *linkage, original, parameterIndices, resultIndices,
       derivativeGenSig, jvp, vjp, isSerialized);
   diffWitnessOrOffset.set(diffWitness, /*isFullyDeserialized*/ true);
diff --git a/lib/Serialization/SILFormat.h b/lib/Serialization/SILFormat.h
index 66fa50d..0f2f1c6 100644
--- a/lib/Serialization/SILFormat.h
+++ b/lib/Serialization/SILFormat.h
@@ -294,6 +294,7 @@
     SIL_DIFFERENTIABILITY_WITNESS,
     DeclIDField,             // Original function name
     SILLinkageField,         // Linkage
+    BCFixed<1>,              // Is declaration?
     BCFixed<1>,              // Is serialized?
     GenericSignatureIDField, // Derivative function generic signature
     DeclIDField,             // JVP function name
diff --git a/lib/Serialization/SerializeSIL.cpp b/lib/Serialization/SerializeSIL.cpp
index 6e7328f..8151f58 100644
--- a/lib/Serialization/SerializeSIL.cpp
+++ b/lib/Serialization/SerializeSIL.cpp
@@ -2586,6 +2586,7 @@
       Out, ScratchRecord, SILAbbrCodes[DifferentiabilityWitnessLayout::Code],
       addSILFunctionRef(original),
       toStableSILLinkage(dw.getLinkage()),
+      dw.isDeclaration(),
       dw.isSerialized(),
       S.addGenericSignatureRef(dw.getDerivativeGenericSignature()),
       jvpID, vjpID,
diff --git a/test/AutoDiff/control_flow_sil.swift b/test/AutoDiff/control_flow_sil.swift
index a3f9110..3f18db2 100644
--- a/test/AutoDiff/control_flow_sil.swift
+++ b/test/AutoDiff/control_flow_sil.swift
@@ -1,5 +1,6 @@
 // RUN: %target-swift-frontend -emit-sil -verify -Xllvm -debug-only=differentiation 2>&1 %s | %FileCheck %s -check-prefix=CHECK-DATA-STRUCTURES
 // RUN: %target-swift-frontend -emit-sil -verify -Xllvm -sil-print-after=differentiation -o /dev/null 2>&1 %s | %FileCheck %s -check-prefix=CHECK-SIL
+// REQUIRES: asserts
 
 // TODO: Add FileCheck tests.
 
diff --git a/test/AutoDiff/forward_mode_diagnostics.swift b/test/AutoDiff/forward_mode_diagnostics.swift
index f451e09..5d2f8fb 100644
--- a/test/AutoDiff/forward_mode_diagnostics.swift
+++ b/test/AutoDiff/forward_mode_diagnostics.swift
@@ -1,4 +1,4 @@
-// RUN: %target-swift-frontend -Xllvm -run-jvp-generation -emit-sil -verify %s
+// RUN: %target-swift-frontend -enable-experimental-forward-mode-differentiation -emit-sil -verify %s
 
 // TODO: move these tests back into `autodiff_diagnostics.swift` once
 // forward mode reaches feature parity with reverse mode.
diff --git a/test/AutoDiff/forward_mode_sil.swift b/test/AutoDiff/forward_mode_sil.swift
index ec40f38..8a490a0 100644
--- a/test/AutoDiff/forward_mode_sil.swift
+++ b/test/AutoDiff/forward_mode_sil.swift
@@ -1,5 +1,6 @@
-// RUN: %target-swift-frontend -emit-sil -verify -Xllvm -run-jvp-generation -Xllvm -debug-only=differentiation %s 2>&1 | %FileCheck %s -check-prefix=CHECK-DATA-STRUCTURES
-// RUN: %target-swift-frontend -emit-sil -verify -Xllvm -sil-print-after=differentiation -Xllvm -run-jvp-generation -o /dev/null 2>&1 %s | %FileCheck %s -check-prefix=CHECK-SIL
+// RUN: %target-swift-frontend -emit-sil -verify -enable-experimental-forward-mode-differentiation -Xllvm -debug-only=differentiation %s 2>&1 | %FileCheck %s -check-prefix=CHECK-DATA-STRUCTURES
+// RUN: %target-swift-frontend -emit-sil -verify -Xllvm -sil-print-after=differentiation -enable-experimental-forward-mode-differentiation -o /dev/null 2>&1 %s | %FileCheck %s -check-prefix=CHECK-SIL
+// REQUIRES: asserts
 
 
 //===----------------------------------------------------------------------===//
diff --git a/test/AutoDiff/irgen_crashers.swift b/test/AutoDiff/irgen_crashers.swift
new file mode 100644
index 0000000..7860d89
--- /dev/null
+++ b/test/AutoDiff/irgen_crashers.swift
@@ -0,0 +1,12 @@
+// RUN: %target-swift-frontend -emit-ir %s
+
+// TF-917: `partial_apply` IRGen crash.
+public protocol TF_917: Differentiable { // reproducer for TF-917 (`partial_apply` IRGen crash)
+  @differentiable
+  func r<A>(_ a: A) -> Float // differentiable requirement that is generic over its own parameter `A`
+}
+@differentiable
+public func tf_917<B: TF_917>(_ b: B) -> Float {
+  return b.r(0.0) // `A` is inferred from the float literal (Double by default)
+}
+
diff --git a/test/AutoDiff/refcounting.swift b/test/AutoDiff/refcounting.swift
index ac29649..8fdb2c7 100644
--- a/test/AutoDiff/refcounting.swift
+++ b/test/AutoDiff/refcounting.swift
@@ -1,5 +1,6 @@
 // RUN: %target-swift-frontend -emit-sil -Xllvm -debug-only=differentiation 2>&1 %s | %FileCheck %s -check-prefix=CHECK-DATA-STRUCTURES
 // RUN: %target-swift-frontend -emit-sil -Xllvm -differentiation-skip-folding-differentiable-function-extraction %s | %FileCheck %s
+// REQUIRES: asserts
 
 public class NonTrivialStuff : Equatable {
   public init() {}
diff --git a/test/AutoDiff/sil_differentiability_witness.sil b/test/AutoDiff/sil_differentiability_witness.sil
index f3d41b0..3df2a0c 100644
--- a/test/AutoDiff/sil_differentiability_witness.sil
+++ b/test/AutoDiff/sil_differentiability_witness.sil
@@ -6,8 +6,6 @@
 // RUN: %target-sil-opt %t/tmp.2.sib -module-name main | %FileCheck %s
 
 // Round-trip parsing/printing and serialization/deserialization test.
-// NOTE: deserialization currently fails if public function bodies are removed
-// so that they are only declarations. This may require investigation.
 
 sil_stage raw
 
@@ -78,3 +76,51 @@
 // CHECK:   jvp: @AD__generic__jvp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector)
 // CHECK:   vjp: @AD__generic__vjp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float))
 // CHECK: }
+
+// Test SIL differentiability witness for bodiless original function, with defined jvp/vjp.
+
+sil @externalFn1 : $@convention(thin) (Float) -> Float
+
+sil @AD__externalFn1__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
+bb0(%0 : $Float):
+  return undef : $(Float, @callee_guaranteed (Float) -> Float)
+}
+
+sil @AD__externalFn1__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
+bb0(%0 : $Float):
+  return undef : $(Float, @callee_guaranteed (Float) -> Float)
+}
+
+sil_differentiability_witness [parameters 0] [results 0] @externalFn1 : $@convention(thin) (Float) -> Float {
+  jvp: @AD__externalFn1__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
+  vjp: @AD__externalFn1__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
+}
+
+// CHECK-LABEL: // differentiability witness for externalFn1
+// CHECK: sil_differentiability_witness [parameters 0] [results 0] @externalFn1 : $@convention(thin) (Float) -> Float {
+// CHECK:   jvp: @AD__externalFn1__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
+// CHECK:   vjp: @AD__externalFn1__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
+// CHECK: }
+
+// Test SIL differentiability witness for bodiless original function, with bodiless jvp/vjp.
+
+sil @externalFn2 : $@convention(thin) (Float) -> Float
+
+sil @AD__externalFn2__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
+
+sil @AD__externalFn2__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
+
+sil_differentiability_witness [parameters 0] [results 0] @externalFn2 : $@convention(thin) (Float) -> Float {
+  jvp: @AD__externalFn2__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
+  vjp: @AD__externalFn2__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
+}
+
+// Test SIL differentiability witness declaration.
+
+sil @externalFn3 : $@convention(thin) (Float) -> Float
+
+sil_differentiability_witness [parameters 0] [results 0] @externalFn3 : $@convention(thin) (Float) -> Float
+
+// CHECK-LABEL: // differentiability witness for externalFn3
+// CHECK: sil_differentiability_witness [parameters 0] [results 0] @externalFn3 : $@convention(thin) (Float) -> Float
+// CHECK-NOT: {
diff --git a/test/lit.cfg b/test/lit.cfg
index 1281490..fd1c59f 100644
--- a/test/lit.cfg
+++ b/test/lit.cfg
@@ -1538,7 +1538,7 @@
     # TODO: Remove when forward mode AD support is robust.	
     config.target_run_simple_swift_forward_mode_differentiation = (	
         '%%empty-directory(%%t) && '	
-        '%s %s %%s -Xllvm -run-jvp-generation -o %%t/a.out %s -module-name main  && '	
+        '%s %s %%s -enable-experimental-forward-mode-differentiation -o %%t/a.out %s -module-name main  && '
         '%s %%t/a.out &&'	
         '%s %%t/a.out'	
         % (config.target_build_swift, mcp_opt, swift_tensorflow_extra_options, config.target_codesign, config.target_run))