[AutoDiff] Fix `differentiable_function_extract` extractee type serialization.
The explicit extractee type was previously not serialized.
The operand's type was incorrectly used instead.
diff --git a/lib/Serialization/DeserializeSIL.cpp b/lib/Serialization/DeserializeSIL.cpp
index fa034b1..792caf8 100644
--- a/lib/Serialization/DeserializeSIL.cpp
+++ b/lib/Serialization/DeserializeSIL.cpp
@@ -1186,7 +1186,7 @@
case SIL_INST_DIFFERENTIABLE_FUNCTION_EXTRACT:
SILInstDifferentiableFunctionExtractLayout::readRecord(
scratch, TyID, TyCategory, ValID, /*extractee*/ Attr,
- /*hasExplicitExtracteeType*/ Attr2);
+ /*hasExplicitExtracteeType*/ Attr2, /*explicitExtracteeType*/ TyID2);
RawOpCode = (unsigned)SILInstructionKind::DifferentiableFunctionExtractInst;
break;
case SIL_INST_LINEAR_FUNCTION_EXTRACT:
@@ -2747,8 +2747,11 @@
auto val = getLocalValue(ValID, silTy);
NormalDifferentiableFunctionTypeComponent extractee(Attr);
Optional<SILType> explicitExtracteeType = None;
- if (Attr2)
- explicitExtracteeType = silTy;
+ if (Attr2) {
+ auto extracteeASTType = MF->getType(TyID2);
+ explicitExtracteeType =
+ getSILType(extracteeASTType, SILValueCategory::Object, Fn);
+ }
ResultInst = Builder.createDifferentiableFunctionExtract(
Loc, extractee, val, explicitExtracteeType);
break;
diff --git a/lib/Serialization/ModuleFormat.h b/lib/Serialization/ModuleFormat.h
index cd9cdd0..b357d71 100644
--- a/lib/Serialization/ModuleFormat.h
+++ b/lib/Serialization/ModuleFormat.h
@@ -56,7 +56,7 @@
/// describe what change you made. The content of this comment isn't important;
/// it just ensures a conflict if two people change the module format.
/// Don't worry about adhering to the 80-column limit for this line.
-const uint16_t SWIFTMODULE_VERSION_MINOR = 589; // cache prespecialization decls.
+const uint16_t SWIFTMODULE_VERSION_MINOR = 590; // differentiable_function_extract explicit extractee type
/// A standard hash seed used for all string hashes in a serialized module.
///
diff --git a/lib/Serialization/SILFormat.h b/lib/Serialization/SILFormat.h
index 28b60b3..8d14762 100644
--- a/lib/Serialization/SILFormat.h
+++ b/lib/Serialization/SILFormat.h
@@ -478,8 +478,9 @@
TypeIDField,
SILTypeCategoryField,
ValueIDField,
- BCFixed<2>, // extractee
- BCFixed<1> // has explicit extractee type?
+ BCFixed<2>, // extractee
+ BCFixed<1>, // has explicit extractee type?
+ TypeIDField // explicit extractee type
>;
using SILInstLinearFunctionExtractLayout = BCRecordLayout<
diff --git a/lib/Serialization/SerializeSIL.cpp b/lib/Serialization/SerializeSIL.cpp
index 4656ea4..f6d2ae4 100644
--- a/lib/Serialization/SerializeSIL.cpp
+++ b/lib/Serialization/SerializeSIL.cpp
@@ -2294,11 +2294,13 @@
auto operandType = dfei->getOperand()->getType();
auto operandTypeRef = S.addTypeRef(operandType.getASTType());
auto rawExtractee = (unsigned)dfei->getExtractee();
+ auto extracteeTypeRef = S.addTypeRef(dfei->getType().getASTType());
SILInstDifferentiableFunctionExtractLayout::emitRecord(
Out, ScratchRecord,
SILAbbrCodes[SILInstDifferentiableFunctionExtractLayout::Code],
operandTypeRef, (unsigned)operandType.getCategory(), operandRef,
- rawExtractee, (unsigned)dfei->hasExplicitExtracteeType());
+ rawExtractee, (unsigned)dfei->hasExplicitExtracteeType(),
+ extracteeTypeRef);
break;
}
case SILInstructionKind::LinearFunctionExtractInst: {