Merge pull request #270 from dan-zheng/holiday-patches
[AutoDiff] Disable `differentiable_function_extract` explicit type assertion.
diff --git a/include/swift/SIL/SILInstruction.h b/include/swift/SIL/SILInstruction.h
index 4b99d34..39b2edb 100644
--- a/include/swift/SIL/SILInstruction.h
+++ b/include/swift/SIL/SILInstruction.h
@@ -8920,7 +8920,12 @@
SILModule &module);
public:
- /// Note: explicit extractee type may be specified only in lowered SIL.
+ /// Note: explicit extractee type is used to avoid inconsistent typing in:
+ /// - Canonical SIL, due to generic specialization.
+ /// - Lowered SIL, due to LoadableByAddress.
+ /// - Raw SIL, due to deserialization of canonical/lowered SIL functions.
+ /// See `TypeSubstCloner::visitDifferentiableFunctionExtractInst` for an
+ /// explanation of how explicit extractee type is used.
explicit DifferentiableFunctionExtractInst(
SILModule &module, SILDebugLocation debugLoc,
NormalDifferentiableFunctionTypeComponent extractee, SILValue function,
diff --git a/lib/SIL/IR/SILInstructions.cpp b/lib/SIL/IR/SILInstructions.cpp
index 59993c2..0e085f7 100644
--- a/lib/SIL/IR/SILInstructions.cpp
+++ b/lib/SIL/IR/SILInstructions.cpp
@@ -727,18 +727,6 @@
: getExtracteeType(function, extractee, module),
function.getOwnershipKind()),
Extractee(extractee), HasExplicitExtracteeType(extracteeType.hasValue()) {
-#ifndef NDEBUG
- if (extracteeType.hasValue()) {
- // Note: explicit extractee type is used to avoid inconsistent typing in:
- // - Canonical SIL, due to generic specialization.
- // - Lowered SIL, due to LoadableByAddress.
- // See `TypeSubstCloner::visitDifferentiableFunctionExtractInst` for an
- // explanation of how explicit extractee type is used.
- assert((module.getStage() == SILStage::Canonical ||
- module.getStage() == SILStage::Lowered) &&
- "Explicit type is valid only in canonical or lowered SIL");
- }
-#endif
}
SILType LinearFunctionExtractInst::
diff --git a/lib/Serialization/DeserializeSIL.cpp b/lib/Serialization/DeserializeSIL.cpp
index fa034b1..792caf8 100644
--- a/lib/Serialization/DeserializeSIL.cpp
+++ b/lib/Serialization/DeserializeSIL.cpp
@@ -1186,7 +1186,7 @@
case SIL_INST_DIFFERENTIABLE_FUNCTION_EXTRACT:
SILInstDifferentiableFunctionExtractLayout::readRecord(
scratch, TyID, TyCategory, ValID, /*extractee*/ Attr,
- /*hasExplicitExtracteeType*/ Attr2);
+ /*hasExplicitExtracteeType*/ Attr2, /*explicitExtracteeType*/ TyID2);
RawOpCode = (unsigned)SILInstructionKind::DifferentiableFunctionExtractInst;
break;
case SIL_INST_LINEAR_FUNCTION_EXTRACT:
@@ -2747,8 +2747,11 @@
auto val = getLocalValue(ValID, silTy);
NormalDifferentiableFunctionTypeComponent extractee(Attr);
Optional<SILType> explicitExtracteeType = None;
- if (Attr2)
- explicitExtracteeType = silTy;
+ if (Attr2) {
+ auto extracteeASTType = MF->getType(TyID2);
+ explicitExtracteeType =
+ getSILType(extracteeASTType, SILValueCategory::Object, Fn);
+ }
ResultInst = Builder.createDifferentiableFunctionExtract(
Loc, extractee, val, explicitExtracteeType);
break;
diff --git a/lib/Serialization/ModuleFormat.h b/lib/Serialization/ModuleFormat.h
index cd9cdd0..b357d71 100644
--- a/lib/Serialization/ModuleFormat.h
+++ b/lib/Serialization/ModuleFormat.h
@@ -56,7 +56,7 @@
/// describe what change you made. The content of this comment isn't important;
/// it just ensures a conflict if two people change the module format.
/// Don't worry about adhering to the 80-column limit for this line.
-const uint16_t SWIFTMODULE_VERSION_MINOR = 589; // cache prespecialization decls.
+const uint16_t SWIFTMODULE_VERSION_MINOR = 590; // differentiable_function_extract explicit extractee type
/// A standard hash seed used for all string hashes in a serialized module.
///
diff --git a/lib/Serialization/SILFormat.h b/lib/Serialization/SILFormat.h
index 28b60b3..8d14762 100644
--- a/lib/Serialization/SILFormat.h
+++ b/lib/Serialization/SILFormat.h
@@ -478,8 +478,9 @@
TypeIDField,
SILTypeCategoryField,
ValueIDField,
- BCFixed<2>, // extractee
- BCFixed<1> // has explicit extractee type?
+ BCFixed<2>, // extractee
+ BCFixed<1>, // has explicit extractee type?
+ TypeIDField // explicit extractee type
>;
using SILInstLinearFunctionExtractLayout = BCRecordLayout<
diff --git a/lib/Serialization/SerializeSIL.cpp b/lib/Serialization/SerializeSIL.cpp
index 4656ea4..f6d2ae4 100644
--- a/lib/Serialization/SerializeSIL.cpp
+++ b/lib/Serialization/SerializeSIL.cpp
@@ -2294,11 +2294,13 @@
auto operandType = dfei->getOperand()->getType();
auto operandTypeRef = S.addTypeRef(operandType.getASTType());
auto rawExtractee = (unsigned)dfei->getExtractee();
+ auto extracteeTypeRef = S.addTypeRef(dfei->getType().getASTType());
SILInstDifferentiableFunctionExtractLayout::emitRecord(
Out, ScratchRecord,
SILAbbrCodes[SILInstDifferentiableFunctionExtractLayout::Code],
operandTypeRef, (unsigned)operandType.getCategory(), operandRef,
- rawExtractee, (unsigned)dfei->hasExplicitExtracteeType());
+ rawExtractee, (unsigned)dfei->hasExplicitExtracteeType(),
+ extracteeTypeRef);
break;
}
case SILInstructionKind::LinearFunctionExtractInst: {
diff --git a/test/AutoDiff/compiler_crashers_fixed/sr14004-cross-module-differentiation-differentiable-function-extract-inlining.swift b/test/AutoDiff/compiler_crashers_fixed/sr14004-cross-module-differentiation-differentiable-function-extract-inlining.swift
new file mode 100644
index 0000000..0b4d6f5
--- /dev/null
+++ b/test/AutoDiff/compiler_crashers_fixed/sr14004-cross-module-differentiation-differentiable-function-extract-inlining.swift
@@ -0,0 +1,44 @@
+// RUN: %empty-directory(%t)
+// RUN: %target-build-swift-dylib(%t/%target-library-name(Library)) -emit-module -emit-module-path %t/Library.swiftmodule -module-name Library -DLIBRARY %s
+// RUN: %target-build-swift -I %t -O -emit-module -emit-module-path %t/Client.swiftmodule -module-name Client %s
+
+// SR-14004: Assertion failure when a function containing a
+// `differentiable_function_extract` with an explicit extractee type was
+// deserialized into a raw SIL module: the instruction's constructor asserted
+// that explicit extractee types only occur in canonical or lowered SIL.
+//
+// The `LIBRARY` half of this file builds `Library`; the `#else` half is a
+// client compiled with `-O` against it. Both modules are emitted into `%t`,
+// never into the current working directory.
+
+#if LIBRARY
+
+import _Differentiation
+
+public struct Struct<Scalar>: Differentiable {}
+
+@differentiable
+public func foo<Scalar>(_ x: Struct<Scalar>) -> Struct<Scalar> { x }
+
+// `@inlinable` makes `bar`'s body available to clients, so its serialized SIL
+// can be deserialized while compiling the client half.
+@inlinable
+@differentiable
+public func bar<Scalar>(_ x: Struct<Scalar>) -> Struct<Scalar> {
+  foo(x)
+}
+
+#else
+
+import _Differentiation
+import Library
+
+// Reproducer: `bar`, specialized to `Struct<Float>`, is used as the default
+// value of a `@differentiable` function parameter.
+public func foo(
+  body: @differentiable (Struct<Float>) -> Struct<Float> = bar
+) {
+  fatalError()
+}
+
+#endif