Merge pull request #270 from dan-zheng/holiday-patches

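[AutoDiff] Disable `differentiable_function_extract` explicit type assertion.

A `differentiable_function_extract` with an explicit extractee type can
legitimately appear in raw SIL when canonical/lowered SIL functions (e.g.
`@inlinable` bodies) are deserialized into a raw SIL module, so the
SIL-stage assertion in the instruction constructor is removed.

Serialization now records the explicit extractee type as its own type ID
instead of the deserializer reusing the operand type for it;
SWIFTMODULE_VERSION_MINOR is bumped to 590 accordingly. Fixes SR-14004.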
diff --git a/include/swift/SIL/SILInstruction.h b/include/swift/SIL/SILInstruction.h
index 4b99d34..39b2edb 100644
--- a/include/swift/SIL/SILInstruction.h
+++ b/include/swift/SIL/SILInstruction.h
@@ -8920,7 +8920,12 @@
                    SILModule &module);
 
 public:
-  /// Note: explicit extractee type may be specified only in lowered SIL.
+  /// Note: explicit extractee type is used to avoid inconsistent typing in:
+  /// - Canonical SIL, due to generic specialization.
+  /// - Lowered SIL, due to LoadableByAddress.
+  /// - Raw SIL, due to deserialization of canonical/lowered SIL functions.
+  /// See `TypeSubstCloner::visitDifferentiableFunctionExtractInst` for an
+  /// explanation of how explicit extractee type is used.
   explicit DifferentiableFunctionExtractInst(
       SILModule &module, SILDebugLocation debugLoc,
       NormalDifferentiableFunctionTypeComponent extractee, SILValue function,
diff --git a/lib/SIL/IR/SILInstructions.cpp b/lib/SIL/IR/SILInstructions.cpp
index 59993c2..0e085f7 100644
--- a/lib/SIL/IR/SILInstructions.cpp
+++ b/lib/SIL/IR/SILInstructions.cpp
@@ -727,18 +727,6 @@
                                : getExtracteeType(function, extractee, module),
                            function.getOwnershipKind()),
       Extractee(extractee), HasExplicitExtracteeType(extracteeType.hasValue()) {
-#ifndef NDEBUG
-  if (extracteeType.hasValue()) {
-    // Note: explicit extractee type is used to avoid inconsistent typing in:
-    // - Canonical SIL, due to generic specialization.
-    // - Lowered SIL, due to LoadableByAddress.
-    // See `TypeSubstCloner::visitDifferentiableFunctionExtractInst` for an
-    // explanation of how explicit extractee type is used.
-    assert((module.getStage() == SILStage::Canonical ||
-            module.getStage() == SILStage::Lowered) &&
-           "Explicit type is valid only in canonical or lowered SIL");
-  }
-#endif
 }
 
 SILType LinearFunctionExtractInst::
diff --git a/lib/Serialization/DeserializeSIL.cpp b/lib/Serialization/DeserializeSIL.cpp
index fa034b1..792caf8 100644
--- a/lib/Serialization/DeserializeSIL.cpp
+++ b/lib/Serialization/DeserializeSIL.cpp
@@ -1186,7 +1186,7 @@
   case SIL_INST_DIFFERENTIABLE_FUNCTION_EXTRACT:
     SILInstDifferentiableFunctionExtractLayout::readRecord(
         scratch, TyID, TyCategory, ValID, /*extractee*/ Attr,
-        /*hasExplicitExtracteeType*/ Attr2);
+        /*hasExplicitExtracteeType*/ Attr2, /*explicitExtracteeType*/ TyID2);
     RawOpCode = (unsigned)SILInstructionKind::DifferentiableFunctionExtractInst;
     break;
   case SIL_INST_LINEAR_FUNCTION_EXTRACT:
@@ -2747,8 +2747,11 @@
     auto val = getLocalValue(ValID, silTy);
     NormalDifferentiableFunctionTypeComponent extractee(Attr);
     Optional<SILType> explicitExtracteeType = None;
-    if (Attr2)
-      explicitExtracteeType = silTy;
+    if (Attr2) {
+      auto extracteeASTType = MF->getType(TyID2);
+      explicitExtracteeType =
+          getSILType(extracteeASTType, SILValueCategory::Object, Fn);
+    }
     ResultInst = Builder.createDifferentiableFunctionExtract(
         Loc, extractee, val, explicitExtracteeType);
     break;
diff --git a/lib/Serialization/ModuleFormat.h b/lib/Serialization/ModuleFormat.h
index cd9cdd0..b357d71 100644
--- a/lib/Serialization/ModuleFormat.h
+++ b/lib/Serialization/ModuleFormat.h
@@ -56,7 +56,7 @@
 /// describe what change you made. The content of this comment isn't important;
 /// it just ensures a conflict if two people change the module format.
 /// Don't worry about adhering to the 80-column limit for this line.
-const uint16_t SWIFTMODULE_VERSION_MINOR = 589; // cache prespecialization decls.
+const uint16_t SWIFTMODULE_VERSION_MINOR = 590; // differentiable_function_extract explicit extractee type
 
 /// A standard hash seed used for all string hashes in a serialized module.
 ///
diff --git a/lib/Serialization/SILFormat.h b/lib/Serialization/SILFormat.h
index 28b60b3..8d14762 100644
--- a/lib/Serialization/SILFormat.h
+++ b/lib/Serialization/SILFormat.h
@@ -478,8 +478,9 @@
     TypeIDField,
     SILTypeCategoryField,
     ValueIDField,
-    BCFixed<2>, // extractee
-    BCFixed<1>  // has explicit extractee type?
+    BCFixed<2>,  // extractee
+    BCFixed<1>,  // has explicit extractee type?
+    TypeIDField  // explicit extractee type
   >;
 
   using SILInstLinearFunctionExtractLayout = BCRecordLayout<
diff --git a/lib/Serialization/SerializeSIL.cpp b/lib/Serialization/SerializeSIL.cpp
index 4656ea4..f6d2ae4 100644
--- a/lib/Serialization/SerializeSIL.cpp
+++ b/lib/Serialization/SerializeSIL.cpp
@@ -2294,11 +2294,13 @@
     auto operandType = dfei->getOperand()->getType();
     auto operandTypeRef = S.addTypeRef(operandType.getASTType());
     auto rawExtractee = (unsigned)dfei->getExtractee();
+    auto extracteeTypeRef = S.addTypeRef(dfei->getType().getASTType());
     SILInstDifferentiableFunctionExtractLayout::emitRecord(
         Out, ScratchRecord,
         SILAbbrCodes[SILInstDifferentiableFunctionExtractLayout::Code],
         operandTypeRef, (unsigned)operandType.getCategory(), operandRef,
-        rawExtractee, (unsigned)dfei->hasExplicitExtracteeType());
+        rawExtractee, (unsigned)dfei->hasExplicitExtracteeType(),
+        extracteeTypeRef);
     break;
   }
   case SILInstructionKind::LinearFunctionExtractInst: {
diff --git a/test/AutoDiff/compiler_crashers_fixed/sr14004-cross-module-differentiation-differentiable-function-extract-inlining.swift b/test/AutoDiff/compiler_crashers_fixed/sr14004-cross-module-differentiation-differentiable-function-extract-inlining.swift
new file mode 100644
index 0000000..0b4d6f5
--- /dev/null
+++ b/test/AutoDiff/compiler_crashers_fixed/sr14004-cross-module-differentiation-differentiable-function-extract-inlining.swift
@@ -0,0 +1,34 @@
+// RUN: %empty-directory(%t)
+// RUN: %target-build-swift-dylib(%t/%target-library-name(Library)) -emit-module -emit-module-path %t/Library.swiftmodule -module-name Library -DLIBRARY %s
+// RUN: %target-build-swift -I %t -O -emit-module %s
+
+// SR-14004: Assertion failure due to function with `differentiable_function_extract`
+// with explicit extractee type being deserialized into a raw SIL module.
+
+#if LIBRARY
+
+import _Differentiation
+
+public struct Struct<Scalar>: Differentiable {}
+
+@differentiable
+public func foo<Scalar>(_ x: Struct<Scalar>) -> Struct<Scalar> { x }
+
+@inlinable
+@differentiable
+public func bar<Scalar>(_ x: Struct<Scalar>) -> Struct<Scalar> {
+  foo(x)
+}
+
+#else
+
+import _Differentiation
+import Library
+
+public func foo(
+  body: @differentiable (Struct<Float>) -> Struct<Float> = bar
+) {
+  fatalError()
+}
+
+#endif