Merge pull request #27821 from apple/tensorflow-merge2
Merge swift-DEVELOPMENT-SNAPSHOT-2019-10-13-a into tensorflow
diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h
index d792b6e..ed94123 100644
--- a/include/swift/AST/AutoDiff.h
+++ b/include/swift/AST/AutoDiff.h
@@ -269,11 +269,11 @@
struct AutoDiffConfig {
IndexSubset *parameterIndices;
IndexSubset *resultIndices;
- GenericSignatureImpl* derivativeGenericSignature;
+ GenericSignature derivativeGenericSignature;
/*implicit*/ AutoDiffConfig(IndexSubset *parameterIndices,
IndexSubset *resultIndices,
- GenericSignatureImpl *derivativeGenericSignature)
+ GenericSignature derivativeGenericSignature)
: parameterIndices(parameterIndices), resultIndices(resultIndices),
derivativeGenericSignature(derivativeGenericSignature) {}
@@ -443,7 +443,7 @@
using swift::AutoDiffConfig;
using swift::AutoDiffDerivativeFunctionKind;
-using swift::GenericSignatureImpl;
+using swift::GenericSignature;
using swift::IndexSubset;
using swift::SILAutoDiffIndices;
@@ -453,27 +453,29 @@
static AutoDiffConfig getEmptyKey() {
auto *ptr = llvm::DenseMapInfo<void *>::getEmptyKey();
return {static_cast<IndexSubset *>(ptr), static_cast<IndexSubset *>(ptr),
- static_cast<GenericSignatureImpl *>(ptr)};
+ DenseMapInfo<GenericSignature>::getEmptyKey()};
}
static AutoDiffConfig getTombstoneKey() {
auto *ptr = llvm::DenseMapInfo<void *>::getTombstoneKey();
return {static_cast<IndexSubset *>(ptr), static_cast<IndexSubset *>(ptr),
- static_cast<GenericSignatureImpl *>(ptr)};
+ DenseMapInfo<GenericSignature>::getTombstoneKey()};
}
static unsigned getHashValue(const AutoDiffConfig &Val) {
unsigned combinedHash = hash_combine(
~1U, DenseMapInfo<void *>::getHashValue(Val.parameterIndices),
DenseMapInfo<void *>::getHashValue(Val.resultIndices),
- DenseMapInfo<void *>::getHashValue(Val.derivativeGenericSignature));
+ DenseMapInfo<GenericSignature>::getHashValue(
+ Val.derivativeGenericSignature));
return combinedHash;
}
static bool isEqual(const AutoDiffConfig &LHS, const AutoDiffConfig &RHS) {
return LHS.parameterIndices == RHS.parameterIndices &&
LHS.resultIndices == RHS.resultIndices &&
- LHS.derivativeGenericSignature == RHS.derivativeGenericSignature;
+ DenseMapInfo<GenericSignature>::isEqual(LHS.derivativeGenericSignature,
+ RHS.derivativeGenericSignature);
}
};
diff --git a/include/swift/AST/DiagnosticsParse.def b/include/swift/AST/DiagnosticsParse.def
index 57f88c4..98cd135 100644
--- a/include/swift/AST/DiagnosticsParse.def
+++ b/include/swift/AST/DiagnosticsParse.def
@@ -689,6 +689,8 @@
// SIL differentiability witnesses
ERROR(sil_diff_witness_expected_token,PointsToFirstBadToken,
"expected '%0' in differentiability witness", (StringRef))
+ERROR(sil_diff_witness_serialized_declaration,none,
+ "differentiability witness declaration should not be serialized", ())
// SIL Coverage Map
ERROR(sil_coverage_func_not_found, none,
diff --git a/include/swift/AST/IndexSubset.h b/include/swift/AST/IndexSubset.h
index 3bdb021..e2b8a98 100644
--- a/include/swift/AST/IndexSubset.h
+++ b/include/swift/AST/IndexSubset.h
@@ -57,6 +57,10 @@
/// The number of bit words in the index subset.
unsigned numBitWords;
+ /// Returns the number of bytes of trailing bit-word storage (the storage
+ /// that `getBitWordsData` reads right after the `IndexSubset` object)
+ /// needed for an index subset with the given capacity.
+ static unsigned getNumBytesNeededForCapacity(unsigned capacity) {
+ return getNumBitWordsNeededForCapacity(capacity) * bitWordSize;
+ }
+
BitWord *getBitWordsData() {
return reinterpret_cast<BitWord *>(this + 1);
}
diff --git a/include/swift/Basic/LangOptions.h b/include/swift/Basic/LangOptions.h
index be1fa2e..1a893a5 100644
--- a/include/swift/Basic/LangOptions.h
+++ b/include/swift/Basic/LangOptions.h
@@ -331,6 +331,11 @@
/// `@differentiable` declaration attribute, etc.
bool EnableExperimentalDifferentiableProgramming = false;
+ // SWIFT_ENABLE_TENSORFLOW
+ /// Whether to enable forward mode differentiation.
+ bool EnableExperimentalForwardModeDifferentiation = false;
+ // SWIFT_ENABLE_TENSORFLOW END
+
/// Whether to enable #quote, #unquote and @quoted.
bool EnableExperimentalQuasiquotes = false;
diff --git a/include/swift/Option/Options.td b/include/swift/Option/Options.td
index 9b6cea3..a703542 100644
--- a/include/swift/Option/Options.td
+++ b/include/swift/Option/Options.td
@@ -428,6 +428,12 @@
def enable_experimental_differentiable_programming : Flag<["-"], "enable-experimental-differentiable-programming">,
Flags<[FrontendOption]>,
HelpText<"Enable experimental differentiable programming features">;
+// SWIFT_ENABLE_TENSORFLOW
+// NOTE: This flag will be removed when JVP/differential generation is robust.
+def enable_experimental_forward_mode_differentiation : Flag<["-"], "enable-experimental-forward-mode-differentiation">,
+ Flags<[FrontendOption]>,
+ HelpText<"Enable experimental forward mode differentiation">;
+// SWIFT_ENABLE_TENSORFLOW END
def enable_experimental_quasiquotes : Flag<["-"], "enable-experimental-quasiquotes">,
Flags<[FrontendOption]>,
diff --git a/include/swift/SIL/SILDifferentiabilityWitness.h b/include/swift/SIL/SILDifferentiabilityWitness.h
index 0a5ab1f..f37459f 100644
--- a/include/swift/SIL/SILDifferentiabilityWitness.h
+++ b/include/swift/SIL/SILDifferentiabilityWitness.h
@@ -43,26 +43,28 @@
{
private:
/// The module which contains the differentiability witness.
- SILModule &module;
+ SILModule &Module;
/// The linkage of the differentiability witness.
- SILLinkage linkage;
+ SILLinkage Linkage;
/// The original function.
- SILFunction *originalFunction;
+ SILFunction *OriginalFunction;
/// The autodiff configuration: parameter indices, result indices, derivative
/// generic signature (optional).
- AutoDiffConfig config;
+ AutoDiffConfig Config;
/// The JVP (Jacobian-vector products) derivative function.
- SILFunction *jvp;
+ SILFunction *JVP;
/// The VJP (vector-Jacobian products) derivative function.
- SILFunction *vjp;
+ SILFunction *VJP;
+ /// Whether or not this differentiability witness is a declaration.
+ bool IsDeclaration;
/// Whether or not this differentiability witness is serialized, which allows
/// devirtualization from another module.
- bool serialized;
+ bool IsSerialized;
/// The AST `@differentiable` or `@differentiating` attribute from which the
/// differentiability witness is generated. Used for diagnostics.
/// Null if the differentiability witness is parsed from SIL or if it is
/// deserialized.
- DeclAttribute *attribute = nullptr;
+ DeclAttribute *Attribute = nullptr;
+ /// Private: construct witnesses via `createDeclaration` or
+ /// `createDefinition`, which also register the witness in the module.
SILDifferentiabilityWitness(SILModule &module, SILLinkage linkage,
SILFunction *originalFunction,
@@ -70,51 +72,60 @@
IndexSubset *resultIndices,
GenericSignature derivativeGenSig,
SILFunction *jvp, SILFunction *vjp,
- bool isSerialized, DeclAttribute *attribute)
- : module(module), linkage(linkage), originalFunction(originalFunction),
- config(parameterIndices, resultIndices, derivativeGenSig.getPointer()),
- jvp(jvp), vjp(vjp), serialized(isSerialized), attribute(attribute) {}
+ bool isDeclaration, bool isSerialized,
+ DeclAttribute *attribute)
+ : Module(module), Linkage(linkage), OriginalFunction(originalFunction),
+ // `AutoDiffConfig` takes a `GenericSignature`; no need to unwrap it.
+ Config(parameterIndices, resultIndices, derivativeGenSig),
+ JVP(jvp), VJP(vjp), IsDeclaration(isDeclaration),
+ IsSerialized(isSerialized), Attribute(attribute) {}
public:
- static SILDifferentiabilityWitness *create(
+ static SILDifferentiabilityWitness *createDeclaration(
+ SILModule &module, SILLinkage linkage, SILFunction *originalFunction,
+ IndexSubset *parameterIndices, IndexSubset *resultIndices,
+ GenericSignature derivativeGenSig, DeclAttribute *attribute = nullptr);
+
+ static SILDifferentiabilityWitness *createDefinition(
SILModule &module, SILLinkage linkage, SILFunction *originalFunction,
IndexSubset *parameterIndices, IndexSubset *resultIndices,
GenericSignature derivativeGenSig, SILFunction *jvp, SILFunction *vjp,
bool isSerialized, DeclAttribute *attribute = nullptr);
SILDifferentiabilityWitnessKey getKey() const;
- SILModule &getModule() const { return module; }
- SILLinkage getLinkage() const { return linkage; }
- SILFunction *getOriginalFunction() const { return originalFunction; }
- const AutoDiffConfig &getConfig() const { return config; }
+ SILModule &getModule() const { return Module; }
+ SILLinkage getLinkage() const { return Linkage; }
+ SILFunction *getOriginalFunction() const { return OriginalFunction; }
+ const AutoDiffConfig &getConfig() const { return Config; }
IndexSubset *getParameterIndices() const {
- return config.parameterIndices;
+ return Config.parameterIndices;
}
IndexSubset *getResultIndices() const {
- return config.resultIndices;
+ return Config.resultIndices;
}
GenericSignature getDerivativeGenericSignature() const {
- return config.derivativeGenericSignature;
+ return Config.derivativeGenericSignature;
}
- SILFunction *getJVP() const { return jvp; }
- SILFunction *getVJP() const { return vjp; }
+ SILFunction *getJVP() const { return JVP; }
+ SILFunction *getVJP() const { return VJP; }
+ /// Returns the JVP or the VJP of this witness, selected by `kind`. The
+ /// result may be null (e.g. for a declaration or a not-yet-filled witness).
SILFunction *getDerivative(AutoDiffDerivativeFunctionKind kind) const {
switch (kind) {
- case AutoDiffDerivativeFunctionKind::JVP: return jvp;
- case AutoDiffDerivativeFunctionKind::VJP: return vjp;
+ case AutoDiffDerivativeFunctionKind::JVP: return JVP;
+ case AutoDiffDerivativeFunctionKind::VJP: return VJP;
}
+ // The switch is fully covered; this avoids falling off the end of a
+ // non-void function (-Wreturn-type) and traps on an invalid enum value.
+ llvm_unreachable("invalid derivative kind");
}
- void setJVP(SILFunction *jvp) { this->jvp = jvp; }
- void setVJP(SILFunction *vjp) { this->vjp = vjp; }
+ void setJVP(SILFunction *jvp) { JVP = jvp; }
+ void setVJP(SILFunction *vjp) { VJP = vjp; }
void setDerivative(AutoDiffDerivativeFunctionKind kind,
SILFunction *derivative) {
switch (kind) {
- case AutoDiffDerivativeFunctionKind::JVP: jvp = derivative; break;
- case AutoDiffDerivativeFunctionKind::VJP: vjp = derivative; break;
+ case AutoDiffDerivativeFunctionKind::JVP: JVP = derivative; break;
+ case AutoDiffDerivativeFunctionKind::VJP: VJP = derivative; break;
}
}
- bool isSerialized() const { return serialized; }
- DeclAttribute *getAttribute() const { return attribute; }
+ bool isDeclaration() const { return IsDeclaration; }
+ bool isDefinition() const { return !IsDeclaration; }
+ bool isSerialized() const { return IsSerialized; }
+ DeclAttribute *getAttribute() const { return Attribute; }
/// Verify that the differentiability witness is well-formed.
void verify(const SILModule &module) const;
diff --git a/lib/AST/ASTContext.cpp b/lib/AST/ASTContext.cpp
index fc9ba85..8905b96 100644
--- a/lib/AST/ASTContext.cpp
+++ b/lib/AST/ASTContext.cpp
@@ -4821,7 +4821,7 @@
if (existing)
return existing;
auto sizeToAlloc = sizeof(IndexSubset) +
- getNumBitWordsNeededForCapacity(capacity);
+ getNumBytesNeededForCapacity(capacity);
auto *buf = reinterpret_cast<IndexSubset *>(
ctx.Allocate(sizeToAlloc, alignof(IndexSubset)));
auto *newNode = new (buf) IndexSubset(indices);
diff --git a/lib/AST/ASTMangler.cpp b/lib/AST/ASTMangler.cpp
index 6871ce1..04d8b64 100644
--- a/lib/AST/ASTMangler.cpp
+++ b/lib/AST/ASTMangler.cpp
@@ -434,7 +434,7 @@
auto originalName = key.first;
auto *parameterIndices = key.second.parameterIndices;
auto *resultIndices = key.second.resultIndices;
- auto *derivativeGenericSignature = key.second.derivativeGenericSignature;
+ auto derivativeGenericSignature = key.second.derivativeGenericSignature;
Buffer << "AD__" << originalName << '_';
Buffer << "P" << parameterIndices->getString();
diff --git a/lib/Driver/ToolChains.cpp b/lib/Driver/ToolChains.cpp
index efd41ed..d8a3548 100644
--- a/lib/Driver/ToolChains.cpp
+++ b/lib/Driver/ToolChains.cpp
@@ -231,6 +231,10 @@
options::OPT_enable_experimental_dependencies);
inputArgs.AddLastArg(arguments,
options::OPT_experimental_dependency_include_intrafile);
+ // SWIFT_ENABLE_TENSORFLOW
+ inputArgs.AddLastArg(
+ arguments, options::OPT_enable_experimental_forward_mode_differentiation);
+ // SWIFT_ENABLE_TENSORFLOW END
inputArgs.AddLastArg(arguments,
options::OPT_enable_experimental_quasiquotes);
inputArgs.AddLastArg(arguments, options::OPT_package_description_version);
diff --git a/lib/Frontend/CompilerInvocation.cpp b/lib/Frontend/CompilerInvocation.cpp
index d2b9616..c07065a 100644
--- a/lib/Frontend/CompilerInvocation.cpp
+++ b/lib/Frontend/CompilerInvocation.cpp
@@ -358,6 +358,9 @@
if (Args.hasArg(OPT_experimental_dependency_include_intrafile))
Opts.ExperimentalDependenciesIncludeIntrafileOnes = true;
+ // TODO: Ignore if enable-experimental-differentiable-programming is false
+ Opts.EnableExperimentalForwardModeDifferentiation |=
+ Args.hasArg(OPT_enable_experimental_forward_mode_differentiation);
if (Args.hasArg(OPT_enable_experimental_quasiquotes))
Opts.EnableExperimentalQuasiquotes = true;
diff --git a/lib/IRGen/GenFunc.cpp b/lib/IRGen/GenFunc.cpp
index bdd375f..f25ffb2 100644
--- a/lib/IRGen/GenFunc.cpp
+++ b/lib/IRGen/GenFunc.cpp
@@ -735,6 +735,16 @@
// Create a new explosion for potentially reabstracted parameters.
Explosion args;
+ // SWIFT_ENABLE_TENSORFLOW
+ // The witness method self argument comes after polymorphic arguments (and is
+ // followed by the self type and the witness table). However, we may encounter
+ // the witness method self value before reaching the polymorphic arguments. So
+ // we create a special explosion for storing the witness method self value
+ // until it's time to add it to 'args'.
+ bool isWitnessMethodCallee = origType->getRepresentation() ==
+ SILFunctionTypeRepresentation::WitnessMethod;
+ Explosion witnessMethodSelfValue;
+
Address resultValueAddr;
{
@@ -775,6 +785,10 @@
// Reemit the parameters as unsubstituted.
for (unsigned i = 0; i < outType->getParameters().size(); ++i) {
+ // SWIFT_ENABLE_TENSORFLOW
+ bool isWitnessMethodCalleeSelf =
+ (isWitnessMethodCallee && i + 1 == origType->getParameters().size());
+
auto origParamInfo = origType->getParameters()[i];
auto &ti = IGM.getTypeInfoForLowered(origParamInfo.getType());
auto schema = ti.getSchema();
@@ -788,7 +802,8 @@
if (addr->getType() != ti.getStorageType()->getPointerTo())
addr = subIGF.Builder.CreateBitCast(addr,
ti.getStorageType()->getPointerTo());
- args.add(addr);
+ // SWIFT_ENABLE_TENSORFLOW
+ (isWitnessMethodCalleeSelf ? witnessMethodSelfValue : args).add(addr);
continue;
}
@@ -796,8 +811,10 @@
// Indirect parameters need no mapping through the native calling
// convention.
if (isIndirectParam) {
- emitApplyArgument(subIGF, origParamInfo, outTypeParamInfo, origParams,
- args);
+ emitApplyArgument(
+ subIGF, origParamInfo, outTypeParamInfo, origParams,
+ // SWIFT_ENABLE_TENSORFLOW
+ (isWitnessMethodCalleeSelf ? witnessMethodSelfValue : args));
continue;
}
@@ -824,7 +841,10 @@
Explosion nativeApplyArg = nativeSchemaOrigParam.mapIntoNative(
subIGF.IGM, subIGF, nonNativeApplyArg, origParamSILType, false);
assert(nonNativeApplyArg.empty());
- nativeApplyArg.transferInto(args, nativeApplyArg.size());
+ // SWIFT_ENABLE_TENSORFLOW
+ nativeApplyArg.transferInto(
+ (isWitnessMethodCalleeSelf ? witnessMethodSelfValue : args),
+ nativeApplyArg.size());
}
}
@@ -934,13 +954,6 @@
auto haveContextArgument =
calleeHasContext || hasSelfContextParameter(origType);
- // Witness method calls expect self, followed by the self type followed by,
- // the witness table at the end of the parameter list. But polymorphic
- // arguments come before this.
- bool isWitnessMethodCallee = origType->getRepresentation() ==
- SILFunctionTypeRepresentation::WitnessMethod;
- Explosion witnessMethodSelfValue;
-
// If there's a data pointer required, but it's a swift-retainable
// value being passed as the context, just forward it down.
if (!layout) {
diff --git a/lib/ParseSIL/ParseSIL.cpp b/lib/ParseSIL/ParseSIL.cpp
index f12aa7d..50ba2f3 100644
--- a/lib/ParseSIL/ParseSIL.cpp
+++ b/lib/ParseSIL/ParseSIL.cpp
@@ -6933,7 +6933,9 @@
/// '[' 'parameters' index-subset ']'
/// '[' 'results' index-subset ']'
/// ('[' 'where' derivatve-generic-signature-requirements ']')?
-/// sil-function-name ':' sil-type
+/// decl-sil-differentiability-witness-body?
+///
+/// decl-sil-differentiability-witness-body ::=
/// '{'
/// ('jvp' sil-function-name ':' sil-type)?
/// ('vjp' sil-function-name ':' sil-type)?
@@ -6949,9 +6951,6 @@
Optional<SILLinkage> linkage;
if (parseSILLinkage(linkage, P))
return true;
- // Default to public linkage.
- if (!linkage)
- linkage = SILLinkage::Public;
// Parse '[serialized]' flag (optional).
bool isSerialized = false;
@@ -6986,8 +6985,7 @@
P.diagnose(fnNameLoc, diag::expected_sil_function_type);
return true;
}
- fn = State.getGlobalNameForReference(name, fnType, fnNameLoc, true);
- State.TUState.PotentialZombieFns.insert(fn);
+ fn = State.getGlobalNameForReference(name, fnType, fnNameLoc);
return false;
};
@@ -7063,7 +7061,26 @@
nullptr);
}
- // Parse differentiability witness body.
+ auto origFnType = originalFn->getLoweredFunctionType();
+ auto *parameterIndexSet = IndexSubset::get(
+ P.Context, origFnType->getNumParameters(), parameterIndices);
+ auto *resultIndexSet = IndexSubset::get(
+ P.Context, origFnType->getNumResults(), resultIndices);
+
+ // If this is just a declaration, create the declaration now and return.
+ if (!P.Tok.is(tok::l_brace)) {
+ if (isSerialized) {
+ P.diagnose(lastLoc, diag::sil_diff_witness_serialized_declaration);
+ return true;
+ }
+
+ SILDifferentiabilityWitness::createDeclaration(
+ M, linkage ? *linkage : SILLinkage::DefaultForDeclaration, originalFn,
+ parameterIndexSet, resultIndexSet, derivativeGenSig);
+ return false;
+ }
+
+ // This is a definition, so parse differentiability witness body.
SILFunction *jvp = nullptr;
SILFunction *vjp = nullptr;
if (P.Tok.is(tok::l_brace)) {
@@ -7094,14 +7111,10 @@
return true;
}
- auto origFnType = originalFn->getLoweredFunctionType();
- auto *parameterIndexSet = IndexSubset::get(
- P.Context, origFnType->getNumParameters(), parameterIndices);
- auto *resultIndexSet = IndexSubset::get(
- P.Context, origFnType->getNumResults(), resultIndices);
- SILDifferentiabilityWitness::create(
- M, *linkage, originalFn, parameterIndexSet, resultIndexSet,
- derivativeGenSig, jvp, vjp, isSerialized);
+ SILDifferentiabilityWitness::createDefinition(
+ M, linkage ? *linkage : SILLinkage::DefaultForDefinition, originalFn,
+ parameterIndexSet, resultIndexSet, derivativeGenSig, jvp, vjp,
+ isSerialized);
return false;
}
// SWIFT_ENABLE_TENSORFLOW END
diff --git a/lib/SIL/SILDifferentiabilityWitness.cpp b/lib/SIL/SILDifferentiabilityWitness.cpp
index 0901692..2ad7bad 100644
--- a/lib/SIL/SILDifferentiabilityWitness.cpp
+++ b/lib/SIL/SILDifferentiabilityWitness.cpp
@@ -17,14 +17,14 @@
using namespace swift;
-SILDifferentiabilityWitness *SILDifferentiabilityWitness::create(
+SILDifferentiabilityWitness *SILDifferentiabilityWitness::createDeclaration(
SILModule &module, SILLinkage linkage, SILFunction *originalFunction,
IndexSubset *parameterIndices, IndexSubset *resultIndices,
- GenericSignature derivativeGenSig, SILFunction *jvp, SILFunction *vjp,
- bool isSerialized, DeclAttribute *attribute) {
+ GenericSignature derivativeGenSig, DeclAttribute *attribute) {
auto *diffWitness = new (module) SILDifferentiabilityWitness(
module, linkage, originalFunction, parameterIndices, resultIndices,
- derivativeGenSig, jvp, vjp, isSerialized, attribute);
+ derivativeGenSig, /*jvp*/ nullptr, /*vjp*/ nullptr,
+ /*isDeclaration*/ true, /*isSerialized*/ false, attribute);
// Register the differentiability witness in the module.
assert(!module.DifferentiabilityWitnessMap.count(diffWitness->getKey()) &&
"Cannot create duplicate differentiability witness in a module");
@@ -33,6 +33,24 @@
return diffWitness;
}
+/// Creates a differentiability witness definition (not a declaration) with
+/// the given JVP and VJP, either of which may be null, registers it in the
+/// module's differentiability witness map and list, and returns it. Asserts
+/// that the module does not already contain a witness with the same key.
+SILDifferentiabilityWitness *SILDifferentiabilityWitness::createDefinition(
+ SILModule &module, SILLinkage linkage, SILFunction *originalFunction,
+ IndexSubset *parameterIndices, IndexSubset *resultIndices,
+ GenericSignature derivativeGenSig, SILFunction *jvp, SILFunction *vjp,
+ bool isSerialized, DeclAttribute *attribute) {
+ auto *diffWitness = new (module) SILDifferentiabilityWitness(
+ module, linkage, originalFunction, parameterIndices, resultIndices,
+ derivativeGenSig, jvp, vjp, /*isDeclaration*/ false, isSerialized,
+ attribute);
+ // Register the differentiability witness in the module.
+ assert(!module.DifferentiabilityWitnessMap.count(diffWitness->getKey()) &&
+ "Cannot create duplicate differentiability witness in a module");
+ module.DifferentiabilityWitnessMap[diffWitness->getKey()] = diffWitness;
+ module.getDifferentiabilityWitnessList().push_back(diffWitness);
+ return diffWitness;
+}
+
SILDifferentiabilityWitnessKey SILDifferentiabilityWitness::getKey() const {
- return std::make_pair(originalFunction->getName(), getConfig());
+ return std::make_pair(getOriginalFunction()->getName(), getConfig());
}
diff --git a/lib/SIL/SILPrinter.cpp b/lib/SIL/SILPrinter.cpp
index 653e127..2c2e73f 100644
--- a/lib/SIL/SILPrinter.cpp
+++ b/lib/SIL/SILPrinter.cpp
@@ -3164,11 +3164,11 @@
void SILDifferentiabilityWitness::print(
llvm::raw_ostream &OS, bool verbose) const {
OS << "// differentiability witness for "
- << demangleSymbol(originalFunction->getName()) << '\n';
+ << demangleSymbol(getOriginalFunction()->getName()) << '\n';
PrintOptions qualifiedSILTypeOptions = PrintOptions::printQualifiedSILType();
// sil_differentiability_witness (linkage)?
OS << "sil_differentiability_witness ";
- printLinkage(OS, linkage, ForDefinition);
+ printLinkage(OS, getLinkage(), /*isDefinition*/ isDefinition());
// ([serialized])?
if (isSerialized())
OS << "[serialized] ";
@@ -3187,7 +3187,7 @@
if (auto derivativeGenSig = getDerivativeGenericSignature()) {
ArrayRef<Requirement> requirements;
SmallVector<Requirement, 4> requirementsScratch;
- auto *origGenEnv = originalFunction->getGenericEnvironment();
+ auto *origGenEnv = getOriginalFunction()->getGenericEnvironment();
if (derivativeGenSig) {
if (origGenEnv) {
requirementsScratch = derivativeGenSig->requirementsNotSatisfiedBy(
@@ -3210,18 +3210,22 @@
}
}
// @original-function-name : $original-sil-type
- printSILFunctionNameAndType(OS, originalFunction);
+ printSILFunctionNameAndType(OS, getOriginalFunction());
+
+ if (isDeclaration())
+ return;
+
// {
// jvp: @jvp-function-name : $jvp-sil-type
// vjp: @vjp-function-name : $vjp-sil-type
// }
OS << " {\n";
- if (jvp) {
+ if (auto *jvp = getJVP()) {
OS << " jvp: ";
printSILFunctionNameAndType(OS, jvp);
OS << '\n';
}
- if (vjp) {
+ if (auto *vjp = getVJP()) {
OS << " vjp: ";
printSILFunctionNameAndType(OS, vjp);
OS << '\n';
diff --git a/lib/SIL/SILVerifier.cpp b/lib/SIL/SILVerifier.cpp
index 5809f5b..deb21f5 100644
--- a/lib/SIL/SILVerifier.cpp
+++ b/lib/SIL/SILVerifier.cpp
@@ -5388,7 +5388,7 @@
if (!M.getOptions().VerifyAll)
return;
#endif
- auto origFnType = originalFunction->getLoweredFunctionType();
+ auto origFnType = getOriginalFunction()->getLoweredFunctionType();
CanGenericSignature derivativeCanGenSig;
if (auto derivativeGenSig = getDerivativeGenericSignature())
derivativeCanGenSig = derivativeGenSig->getCanonicalSignature();
@@ -5408,7 +5408,7 @@
else
exit(1);
};
- if (jvp) {
+ if (auto *jvp = getJVP()) {
// TODO(TF-893): Change `SILFunctionType::getAutoDiffDerivativeFunctionType`
// to accept result indices.
auto expectedJVPType = origFnType->getAutoDiffDerivativeFunctionType(
@@ -5418,7 +5418,7 @@
requireSameType(jvp->getLoweredFunctionType(), expectedJVPType,
"JVP type does not match expected JVP type");
}
- if (vjp) {
+ if (auto *vjp = getVJP()) {
// TODO(TF-893): Change `SILFunctionType::getAutoDiffDerivativeFunctionType`
// to result indices.
auto expectedVJPType = origFnType->getAutoDiffDerivativeFunctionType(
diff --git a/lib/SILGen/SILGen.cpp b/lib/SILGen/SILGen.cpp
index 15ec27f..3a62d72 100644
--- a/lib/SILGen/SILGen.cpp
+++ b/lib/SILGen/SILGen.cpp
@@ -810,7 +810,7 @@
bool reorderSelf = shouldReorderSelf();
CanGenericSignature derivativeCanGenSig;
- if (auto *derivativeGenSig = config.derivativeGenericSignature)
+ if (auto derivativeGenSig = config.derivativeGenericSignature)
derivativeCanGenSig = derivativeGenSig->getCanonicalSignature();
// TODO(TF-835): Use simpler derivative generic signature logic below when
// type-checking no longer generates implicit `@differentiable` attributes.
@@ -832,7 +832,7 @@
// TODO(TF-919): Explore creating serialized differentiability witnesses.
// Currently, differentiability witnesses are never serialized to avoid
// deserialization issues where JVP/VJP functions cannot be found.
- auto *diffWitness = SILDifferentiabilityWitness::create(
+ auto *diffWitness = SILDifferentiabilityWitness::createDefinition(
M, originalFunction->getLinkage(), originalFunction,
loweredParamIndices, config.resultIndices, derivativeCanGenSig,
/*jvp*/ nullptr, /*vjp*/ nullptr, /*isSerialized*/ false);
diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp
index 575fbe7..166c21b 100644
--- a/lib/SILOptimizer/Mandatory/Differentiation.cpp
+++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp
@@ -65,12 +65,6 @@
"differentiation-skip-folding-differentiable-function-extraction",
llvm::cl::init(true));
-/// This flag is used to enable full JVP generation.
-/// It will be removed when JVP/differential generation is robust.
-static llvm::cl::opt<bool> RunJVPGeneration(
- "run-jvp-generation",
- llvm::cl::init(false));
-
//===----------------------------------------------------------------------===//
// Helpers
//===----------------------------------------------------------------------===//
@@ -1824,18 +1818,18 @@
class DifferentiableActivityCollection {
public:
- SmallDenseMap<GenericSignatureImpl *, DifferentiableActivityInfo> activityInfoMap;
+ SmallDenseMap<GenericSignature, DifferentiableActivityInfo> activityInfoMap;
SILFunction &function;
DominanceInfo *domInfo;
PostDominanceInfo *postDomInfo;
DifferentiableActivityInfo &getActivityInfo(
GenericSignature assocGenSig, AutoDiffDerivativeFunctionKind kind) {
- auto activityInfoLookup = activityInfoMap.find(assocGenSig.getPointer());
+ auto activityInfoLookup = activityInfoMap.find(assocGenSig);
if (activityInfoLookup != activityInfoMap.end())
return activityInfoLookup->getSecond();
auto insertion = activityInfoMap.insert(
- {assocGenSig.getPointer(), DifferentiableActivityInfo(*this, assocGenSig)});
+ {assocGenSig, DifferentiableActivityInfo(*this, assocGenSig)});
return insertion.first->getSecond();
}
@@ -8031,8 +8025,9 @@
// Diagnose:
// - Functions with no return.
// - Functions with unsupported control flow.
- if (RunJVPGeneration && (diagnoseNoReturn(*this, original, invoker) ||
- diagnoseUnsupportedControlFlow(*this, original, invoker)))
+ if (getASTContext().LangOpts.EnableExperimentalForwardModeDifferentiation &&
+ (diagnoseNoReturn(*this, original, invoker) ||
+ diagnoseUnsupportedControlFlow(*this, original, invoker)))
return true;
jvp = createEmptyJVP(*this, original, attr, isDerivativeFnExported);
@@ -8042,7 +8037,8 @@
// does not exist. If custom VJP exists but custom JVP does not, skip JVP
// generation because generated JVP may not match semantics of custom VJP.
// Instead, create an empty JVP.
- if (RunJVPGeneration && !vjp) {
+ if (getASTContext().LangOpts.EnableExperimentalForwardModeDifferentiation &&
+ !vjp) {
// JVP and differential generation do not currently support functions with
// multiple basic blocks.
if (original->getBlocks().size() > 1) {
diff --git a/lib/Serialization/DeserializeSIL.cpp b/lib/Serialization/DeserializeSIL.cpp
index 5b94d78..c0b656b 100644
--- a/lib/Serialization/DeserializeSIL.cpp
+++ b/lib/Serialization/DeserializeSIL.cpp
@@ -3407,14 +3407,19 @@
(void)kind;
DeclID originalNameId, jvpNameId, vjpNameId;
- unsigned rawLinkage, isSerialized, numParameterIndices, numResultIndices;
+ unsigned rawLinkage, isDeclaration, isSerialized, numParameterIndices,
+ numResultIndices;
GenericSignatureID derivativeGenSigID;
ArrayRef<uint64_t> rawParameterAndResultIndices;
DifferentiabilityWitnessLayout::readRecord(
- scratch, originalNameId, rawLinkage, isSerialized, derivativeGenSigID,
- jvpNameId, vjpNameId, numParameterIndices, numResultIndices,
- rawParameterAndResultIndices);
+ scratch, originalNameId, rawLinkage, isDeclaration, isSerialized,
+ derivativeGenSigID, jvpNameId, vjpNameId, numParameterIndices,
+ numResultIndices, rawParameterAndResultIndices);
+
+ if (isDeclaration) {
+ assert(!isSerialized && "declaration must not be serialized");
+ }
auto linkage = fromStableSILLinkage(rawLinkage);
assert(linkage && "Expected value linkage for sil_differentiability_witness");
@@ -3424,11 +3429,15 @@
auto *original = getFuncForReference(originalName);
assert(original && "Original function must be found");
auto *jvp = getFuncForReference(jvpName);
- if (!jvpName.empty())
+ if (!jvpName.empty()) {
+ assert(!isDeclaration && "JVP must not be defined in declaration");
assert(jvp && "JVP function must be found if JVP name is not empty");
+ }
auto *vjp = getFuncForReference(vjpName);
- if (!vjpName.empty())
+ if (!vjpName.empty()) {
+ assert(!isDeclaration && "VJP must not be defined in declaration");
assert(vjp && "VJP function must be found if VJP name is not empty");
+ }
auto derivativeGenSig = MF->getGenericSignature(derivativeGenSigID);
SmallVector<unsigned, 8> parameterAndResultIndices(
@@ -3446,7 +3455,15 @@
ArrayRef<unsigned>(parameterAndResultIndices)
.take_back(numResultIndices));
- auto *diffWitness = SILDifferentiabilityWitness::create(
+ if (isDeclaration) {
+ auto *diffWitness = SILDifferentiabilityWitness::createDeclaration(
+ SILMod, *linkage, original, parameterIndices, resultIndices,
+ derivativeGenSig);
+ diffWitnessOrOffset.set(diffWitness, /*isFullyDeserialized*/ false);
+ return diffWitness;
+ }
+
+ auto *diffWitness = SILDifferentiabilityWitness::createDefinition(
SILMod, *linkage, original, parameterIndices, resultIndices,
derivativeGenSig, jvp, vjp, isSerialized);
diffWitnessOrOffset.set(diffWitness, /*isFullyDeserialized*/ true);
diff --git a/lib/Serialization/SILFormat.h b/lib/Serialization/SILFormat.h
index 66fa50d..0f2f1c6 100644
--- a/lib/Serialization/SILFormat.h
+++ b/lib/Serialization/SILFormat.h
@@ -294,6 +294,7 @@
SIL_DIFFERENTIABILITY_WITNESS,
DeclIDField, // Original function name
SILLinkageField, // Linkage
+ BCFixed<1>, // Is declaration?
BCFixed<1>, // Is serialized?
GenericSignatureIDField, // Derivative function generic signature
DeclIDField, // JVP function name
diff --git a/lib/Serialization/SerializeSIL.cpp b/lib/Serialization/SerializeSIL.cpp
index 6e7328f..8151f58 100644
--- a/lib/Serialization/SerializeSIL.cpp
+++ b/lib/Serialization/SerializeSIL.cpp
@@ -2586,6 +2586,7 @@
Out, ScratchRecord, SILAbbrCodes[DifferentiabilityWitnessLayout::Code],
addSILFunctionRef(original),
toStableSILLinkage(dw.getLinkage()),
+ dw.isDeclaration(),
dw.isSerialized(),
S.addGenericSignatureRef(dw.getDerivativeGenericSignature()),
jvpID, vjpID,
diff --git a/test/AutoDiff/control_flow_sil.swift b/test/AutoDiff/control_flow_sil.swift
index a3f9110..3f18db2 100644
--- a/test/AutoDiff/control_flow_sil.swift
+++ b/test/AutoDiff/control_flow_sil.swift
@@ -1,5 +1,6 @@
// RUN: %target-swift-frontend -emit-sil -verify -Xllvm -debug-only=differentiation 2>&1 %s | %FileCheck %s -check-prefix=CHECK-DATA-STRUCTURES
// RUN: %target-swift-frontend -emit-sil -verify -Xllvm -sil-print-after=differentiation -o /dev/null 2>&1 %s | %FileCheck %s -check-prefix=CHECK-SIL
+// REQUIRES: asserts
// TODO: Add FileCheck tests.
diff --git a/test/AutoDiff/forward_mode_diagnostics.swift b/test/AutoDiff/forward_mode_diagnostics.swift
index f451e09..5d2f8fb 100644
--- a/test/AutoDiff/forward_mode_diagnostics.swift
+++ b/test/AutoDiff/forward_mode_diagnostics.swift
@@ -1,4 +1,4 @@
-// RUN: %target-swift-frontend -Xllvm -run-jvp-generation -emit-sil -verify %s
+// RUN: %target-swift-frontend -enable-experimental-forward-mode-differentiation -emit-sil -verify %s
// TODO: move these tests back into `autodiff_diagnostics.swift` once
// forward mode reaches feature parity with reverse mode.
diff --git a/test/AutoDiff/forward_mode_sil.swift b/test/AutoDiff/forward_mode_sil.swift
index ec40f38..8a490a0 100644
--- a/test/AutoDiff/forward_mode_sil.swift
+++ b/test/AutoDiff/forward_mode_sil.swift
@@ -1,5 +1,6 @@
-// RUN: %target-swift-frontend -emit-sil -verify -Xllvm -run-jvp-generation -Xllvm -debug-only=differentiation %s 2>&1 | %FileCheck %s -check-prefix=CHECK-DATA-STRUCTURES
-// RUN: %target-swift-frontend -emit-sil -verify -Xllvm -sil-print-after=differentiation -Xllvm -run-jvp-generation -o /dev/null 2>&1 %s | %FileCheck %s -check-prefix=CHECK-SIL
+// RUN: %target-swift-frontend -emit-sil -verify -enable-experimental-forward-mode-differentiation -Xllvm -debug-only=differentiation %s 2>&1 | %FileCheck %s -check-prefix=CHECK-DATA-STRUCTURES
+// RUN: %target-swift-frontend -emit-sil -verify -Xllvm -sil-print-after=differentiation -enable-experimental-forward-mode-differentiation -o /dev/null 2>&1 %s | %FileCheck %s -check-prefix=CHECK-SIL
+// REQUIRES: asserts
//===----------------------------------------------------------------------===//
diff --git a/test/AutoDiff/irgen_crashers.swift b/test/AutoDiff/irgen_crashers.swift
new file mode 100644
index 0000000..7860d89
--- /dev/null
+++ b/test/AutoDiff/irgen_crashers.swift
@@ -0,0 +1,12 @@
+// RUN: %target-swift-frontend -emit-ir %s
+
+// TF-917: `partial_apply` IRGen crash.
+public protocol TF_917: Differentiable {
+ @differentiable
+ func r<A>(_ a: A) -> Float
+}
+@differentiable
+public func tf_917<B: TF_917>(_ b: B) -> Float {
+ return b.r(0.0)
+}
+
diff --git a/test/AutoDiff/refcounting.swift b/test/AutoDiff/refcounting.swift
index ac29649..8fdb2c7 100644
--- a/test/AutoDiff/refcounting.swift
+++ b/test/AutoDiff/refcounting.swift
@@ -1,5 +1,6 @@
// RUN: %target-swift-frontend -emit-sil -Xllvm -debug-only=differentiation 2>&1 %s | %FileCheck %s -check-prefix=CHECK-DATA-STRUCTURES
// RUN: %target-swift-frontend -emit-sil -Xllvm -differentiation-skip-folding-differentiable-function-extraction %s | %FileCheck %s
+// REQUIRES: asserts
public class NonTrivialStuff : Equatable {
public init() {}
diff --git a/test/AutoDiff/sil_differentiability_witness.sil b/test/AutoDiff/sil_differentiability_witness.sil
index f3d41b0..3df2a0c 100644
--- a/test/AutoDiff/sil_differentiability_witness.sil
+++ b/test/AutoDiff/sil_differentiability_witness.sil
@@ -6,8 +6,6 @@
// RUN: %target-sil-opt %t/tmp.2.sib -module-name main | %FileCheck %s
// Round-trip parsing/printing and serialization/deserialization test.
-// NOTE: deserialization currently fails if public function bodies are removed
-// so that they are only declarations. This may require investigation.
sil_stage raw
@@ -78,3 +76,51 @@
// CHECK: jvp: @AD__generic__jvp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector)
// CHECK: vjp: @AD__generic__vjp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float))
// CHECK: }
+
+// Test SIL differentiability witness for bodiless original function, with defined jvp/vjp.
+
+sil @externalFn1 : $@convention(thin) (Float) -> Float
+
+sil @AD__externalFn1__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
+bb0(%0 : $Float):
+ return undef : $(Float, @callee_guaranteed (Float) -> Float)
+}
+
+sil @AD__externalFn1__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
+bb0(%0 : $Float):
+ return undef : $(Float, @callee_guaranteed (Float) -> Float)
+}
+
+sil_differentiability_witness [parameters 0] [results 0] @externalFn1 : $@convention(thin) (Float) -> Float {
+ jvp: @AD__externalFn1__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
+ vjp: @AD__externalFn1__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
+}
+
+// CHECK-LABEL: // differentiability witness for externalFn1
+// CHECK: sil_differentiability_witness [parameters 0] [results 0] @externalFn1 : $@convention(thin) (Float) -> Float {
+// CHECK: jvp: @AD__externalFn1__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
+// CHECK: vjp: @AD__externalFn1__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
+// CHECK: }
+
+// Test SIL differentiability witness for bodiless original function, with bodiless jvp/vjp.
+
+sil @externalFn2 : $@convention(thin) (Float) -> Float
+
+sil @AD__externalFn2__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
+
+sil @AD__externalFn2__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
+
+sil_differentiability_witness [parameters 0] [results 0] @externalFn2 : $@convention(thin) (Float) -> Float {
+ jvp: @AD__externalFn2__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
+ vjp: @AD__externalFn2__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
+}
+
+// Test SIL differentiability witness declaration.
+
+sil @externalFn3 : $@convention(thin) (Float) -> Float
+
+sil_differentiability_witness [parameters 0] [results 0] @externalFn3 : $@convention(thin) (Float) -> Float
+
+// CHECK-LABEL: // differentiability witness for externalFn3
+// CHECK: sil_differentiability_witness [parameters 0] [results 0] @externalFn3 : $@convention(thin) (Float) -> Float
+// CHECK-NOT: {
diff --git a/test/lit.cfg b/test/lit.cfg
index 1281490..fd1c59f 100644
--- a/test/lit.cfg
+++ b/test/lit.cfg
@@ -1538,7 +1538,7 @@
# TODO: Remove when forward mode AD support is robust.
config.target_run_simple_swift_forward_mode_differentiation = (
'%%empty-directory(%%t) && '
- '%s %s %%s -Xllvm -run-jvp-generation -o %%t/a.out %s -module-name main && '
+ '%s %s %%s -enable-experimental-forward-mode-differentiation -o %%t/a.out %s -module-name main && '
'%s %%t/a.out &&'
'%s %%t/a.out'
% (config.target_build_swift, mcp_opt, swift_tensorflow_extra_options, config.target_codesign, config.target_run))