blob: 0694ab596a048b515e798db9b407f438329a99ff [file] [log] [blame]
//===--- ADContext.h - Differentiation Context ----------------*- C++ -*---===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// Per-module contextual information for the differentiation transform.
//
//===----------------------------------------------------------------------===//
#ifndef SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_ADCONTEXT_H
#define SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_ADCONTEXT_H
#include "swift/SILOptimizer/Differentiation/Common.h"
#include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h"
#include "swift/AST/DiagnosticsSIL.h"
#include "swift/AST/Expr.h"
#include "swift/AST/SynthesizedFileUnit.h"
#include "swift/SIL/SILBuilder.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
namespace swift {
class ASTContext;
class DifferentiableFunctionExpr;
class DifferentiableFunctionInst;
class FuncDecl;
class SILDifferentiabilityWitness;
class SILFunction;
class SILModuleTransform;
class SILModule;
class SILPassManager;
namespace autodiff {
/// Stores `apply` instruction information calculated by VJP generation.
struct NestedApplyInfo {
/// The differentiation config that is used to differentiate this `apply`
/// instruction.
AutoDiffConfig config;
/// The original pullback type before reabstraction. `None` if the pullback
/// type is not reabstracted.
Optional<CanSILFunctionType> originalPullbackType;
};
/// Per-module contextual information for the Differentiation pass.
class ADContext {
private:
/// Reference to the main transform.
SILModuleTransform &transform;
/// The module where Differentiation is performed on.
SILModule &module;
/// AST context.
ASTContext &astCtx = module.getASTContext();
/// Shared pass manager.
SILPassManager &passManager;
/// The worklist (stack) of `differentiable_function` instructions to be
/// processed.
llvm::SmallVector<DifferentiableFunctionInst *, 32>
differentiableFunctionInsts;
/// The worklist (stack) of `linear_function` instructions to be processed.
llvm::SmallVector<LinearFunctionInst *, 32> linearFunctionInsts;
/// The set of `differentiable_function` instructions that have been
/// processed. Used to avoid reprocessing invalidated instructions.
/// NOTE(TF-784): if we use `CanonicalizeInstruction` subclass to replace
/// `ADContext::processDifferentiableFunctionInst`, this field may be removed.
llvm::SmallPtrSet<DifferentiableFunctionInst *, 32>
processedDifferentiableFunctionInsts;
/// The set of `linear_function` instructions that have been processed. Used
/// to avoid reprocessing invalidated instructions.
/// NOTE(TF-784): if we use `CanonicalizeInstruction` subclass to replace
/// `ADContext::processLinearFunctionInst`, this field may be removed.
llvm::SmallPtrSet<LinearFunctionInst *, 32> processedLinearFunctionInsts;
/// Mapping from witnesses to invokers.
/// `SmallMapVector` is used for deterministic insertion order iteration.
llvm::SmallMapVector<SILDifferentiabilityWitness *, DifferentiationInvoker,
32>
invokers;
/// Mapping from original `apply` instructions to their corresponding
/// `NestedApplyInfo`s.
llvm::DenseMap<ApplyInst *, NestedApplyInfo> nestedApplyInfo;
/// List of generated functions (JVPs, VJPs, pullbacks, and thunks).
/// Saved for deletion during cleanup.
llvm::SmallVector<SILFunction *, 32> generatedFunctions;
/// List of references to generated functions.
/// Saved for deletion during cleanup.
llvm::SmallVector<SILValue, 32> generatedFunctionReferences;
/// The AdditiveArithmetic protocol in the standard library.
ProtocolDecl *additiveArithmeticProtocol =
astCtx.getProtocol(KnownProtocolKind::AdditiveArithmetic);
/// `AdditiveArithmetic.+` declaration.
mutable FuncDecl *cachedPlusFn = nullptr;
/// `AdditiveArithmetic.+=` declaration.
mutable FuncDecl *cachedPlusEqualFn = nullptr;
public:
/// Construct an ADContext for the given module.
explicit ADContext(SILModuleTransform &transform);
// No copying.
ADContext(const ADContext &) = delete;
//--------------------------------------------------------------------------//
// General utilities
//--------------------------------------------------------------------------//
SILModuleTransform &getTransform() const { return transform; }
SILModule &getModule() const { return module; }
ASTContext &getASTContext() const { return module.getASTContext(); }
SILPassManager &getPassManager() const { return passManager; }
Lowering::TypeConverter &getTypeConverter() { return module.Types; }
llvm::SmallVectorImpl<DifferentiableFunctionInst *> &
getDifferentiableFunctionInstWorklist() {
return differentiableFunctionInsts;
}
llvm::SmallVectorImpl<LinearFunctionInst *> &getLinearFunctionInstWorklist() {
return linearFunctionInsts;
}
/// Get or create the synthesized file for the given `SILFunction`.
/// Used by `LinearMapInfo` for adding generated linear map struct and
/// branching trace enum declarations.
SynthesizedFileUnit &getOrCreateSynthesizedFile(SILFunction *original);
/// Returns true if the given `differentiable_function` instruction has
/// already been processed.
bool
isDifferentiableFunctionInstProcessed(DifferentiableFunctionInst *dfi) const {
return processedDifferentiableFunctionInsts.count(dfi);
}
/// Adds the given `differentiable_function` instruction to the worklist.
void
markDifferentiableFunctionInstAsProcessed(DifferentiableFunctionInst *dfi) {
processedDifferentiableFunctionInsts.insert(dfi);
}
/// Returns true if the given `linear_function` instruction has already been
/// processed.
bool isLinearFunctionInstProcessed(LinearFunctionInst *lfi) const {
return processedLinearFunctionInsts.count(lfi);
}
/// Adds the given `linear_function` instruction to the worklist.
void markLinearFunctionInstAsProcessed(LinearFunctionInst *lfi) {
processedLinearFunctionInsts.insert(lfi);
}
const llvm::SmallMapVector<SILDifferentiabilityWitness *,
DifferentiationInvoker, 32> &
getInvokers() const {
return invokers;
}
void addInvoker(SILDifferentiabilityWitness *witness) {
assert(!invokers.count(witness) &&
"Differentiability witness already has an invoker");
invokers.insert({witness, DifferentiationInvoker(witness)});
}
llvm::DenseMap<ApplyInst *, NestedApplyInfo> &getNestedApplyInfo() {
return nestedApplyInfo;
}
void recordGeneratedFunction(SILFunction *function) {
generatedFunctions.push_back(function);
}
void recordGeneratedFunctionReference(SILValue functionRef) {
generatedFunctionReferences.push_back(functionRef);
}
ProtocolDecl *getAdditiveArithmeticProtocol() const {
return additiveArithmeticProtocol;
}
FuncDecl *getPlusDecl() const;
FuncDecl *getPlusEqualDecl() const;
/// Cleans up all the internal state.
void cleanUp();
/// Creates an `differentiable_function` instruction using the given builder
/// and arguments. Erase the newly created instruction from the processed set,
/// if it exists - it may exist in the processed set if it has the same
/// pointer value as a previously processed and deleted instruction.
/// TODO(TF-784): The pointer reuse is a real concern and the use of
/// `CanonicalizeInstruction` may get rid of the need for this workaround.
DifferentiableFunctionInst *createDifferentiableFunction(
SILBuilder &builder, SILLocation loc, IndexSubset *parameterIndices,
IndexSubset *resultIndices, SILValue original,
Optional<std::pair<SILValue, SILValue>> derivativeFunctions = None);
/// Creates a `linear_function` instruction using the given builder
/// and arguments. Erase the newly created instruction from the processed set,
/// if it exists - it may exist in the processed set if it has the same
/// pointer value as a previously processed and deleted instruction.
LinearFunctionInst *
createLinearFunction(SILBuilder &builder, SILLocation loc,
IndexSubset *parameterIndices, SILValue original,
Optional<SILValue> transposeFunction = None);
// Given a `differentiable_function` instruction, finds the corresponding
// differential operator used in the AST. If no differential operator is
// found, return nullptr.
DifferentiableFunctionExpr *
findDifferentialOperator(DifferentiableFunctionInst *inst);
// Given a `linear_function` instruction, finds the corresponding differential
// operator used in the AST. If no differential operator is found, return
// nullptr.
LinearFunctionExpr *findDifferentialOperator(LinearFunctionInst *inst);
template <typename... T, typename... U>
InFlightDiagnostic diagnose(SourceLoc loc, Diag<T...> diag,
U &&... args) const {
return getASTContext().Diags.diagnose(loc, diag, std::forward<U>(args)...);
}
/// Given an instruction and a differentiation task associated with the
/// parent function, emits a "not differentiable" error based on the task. If
/// the task is indirect, emits notes all the way up to the outermost task,
/// and emits an error at the outer task. Otherwise, emits an error directly.
template <typename... T, typename... U>
InFlightDiagnostic
emitNondifferentiabilityError(SILInstruction *inst,
DifferentiationInvoker invoker, Diag<T...> diag,
U &&... args);
/// Given a value and a differentiation task associated with the parent
/// function, emits a "not differentiable" error based on the task. If the
/// task is indirect, emits notes all the way up to the outermost task, and
/// emits an error at the outer task. Otherwise, emits an error directly.
template <typename... T, typename... U>
InFlightDiagnostic
emitNondifferentiabilityError(SILValue value, DifferentiationInvoker invoker,
Diag<T...> diag, U &&... args);
/// Emit a "not differentiable" error based on the given differentiation task
/// and diagnostic.
template <typename... T, typename... U>
InFlightDiagnostic
emitNondifferentiabilityError(SourceLoc loc, DifferentiationInvoker invoker,
Diag<T...> diag, U &&... args);
};
template <typename... T, typename... U>
InFlightDiagnostic
ADContext::emitNondifferentiabilityError(SILValue value,
DifferentiationInvoker invoker,
Diag<T...> diag, U &&... args) {
LLVM_DEBUG({
getADDebugStream() << "Diagnosing non-differentiability.\n";
auto &s = getADDebugStream() << "For value:\n";
value->printInContext(s);
getADDebugStream() << "With invoker:\n" << invoker << '\n';
});
// If instruction does not have a valid location, use the function location
// as a fallback. Improves diagnostics in some cases.
auto valueLoc = getValidLocation(value).getSourceLoc();
return emitNondifferentiabilityError(valueLoc, invoker, diag,
std::forward<U>(args)...);
}
template <typename... T, typename... U>
InFlightDiagnostic
ADContext::emitNondifferentiabilityError(SILInstruction *inst,
DifferentiationInvoker invoker,
Diag<T...> diag, U &&... args) {
LLVM_DEBUG({
getADDebugStream() << "Diagnosing non-differentiability.\n";
auto &s = getADDebugStream() << "For instruction:\n";
inst->printInContext(s);
getADDebugStream() << "With invoker:\n" << invoker << '\n';
});
// If instruction does not have a valid location, use the function location
// as a fallback. Improves diagnostics for `ref_element_addr` generated in
// synthesized stored property getters.
auto instLoc = getValidLocation(inst).getSourceLoc();
return emitNondifferentiabilityError(instLoc, invoker, diag,
std::forward<U>(args)...);
}
template <typename... T, typename... U>
InFlightDiagnostic
ADContext::emitNondifferentiabilityError(SourceLoc loc,
DifferentiationInvoker invoker,
Diag<T...> diag, U &&... args) {
switch (invoker.getKind()) {
// For `differentiable_function` instructions: if the
// `differentiable_function` instruction comes from a differential operator,
// emit an error on the expression and a note on the non-differentiable
// operation. Otherwise, emit both an error and note on the
// non-differentiation operation.
case DifferentiationInvoker::Kind::DifferentiableFunctionInst: {
auto *inst = invoker.getDifferentiableFunctionInst();
if (auto *expr = findDifferentialOperator(inst)) {
diagnose(expr->getLoc(), diag::autodiff_function_not_differentiable_error)
.highlight(expr->getSubExpr()->getSourceRange());
return diagnose(loc, diag, std::forward<U>(args)...);
}
diagnose(loc, diag::autodiff_expression_not_differentiable_error);
return diagnose(loc, diag, std::forward<U>(args)...);
}
// For `linear_function` instructions: if the `linear_function` instruction
// comes from a differential operator, emit an error on the expression and a
// note on the non-differentiable operation. Otherwise, emit both an error and
// note on the non-differentiation operation.
case DifferentiationInvoker::Kind::LinearFunctionInst: {
auto *inst = invoker.getLinearFunctionInst();
if (auto *expr = findDifferentialOperator(inst)) {
diagnose(expr->getLoc(), diag::autodiff_function_not_differentiable_error)
.highlight(expr->getSubExpr()->getSourceRange());
return diagnose(loc, diag, std::forward<U>(args)...);
}
diagnose(loc, diag::autodiff_expression_not_differentiable_error);
return diagnose(loc, diag, std::forward<U>(args)...);
}
// For differentiability witnesses: try to find a `@differentiable` or
// `@derivative` attribute. If an attribute is found, emit an error on it;
// otherwise, emit an error on the original function.
case DifferentiationInvoker::Kind::SILDifferentiabilityWitnessInvoker: {
auto *witness = invoker.getSILDifferentiabilityWitnessInvoker();
auto *original = witness->getOriginalFunction();
// If the witness has an associated attribute, emit an error at its
// location.
if (auto *attr = witness->getAttribute()) {
diagnose(attr->getLocation(),
diag::autodiff_function_not_differentiable_error)
.highlight(attr->getRangeWithAt());
// Emit informative note.
bool emittedNote = false;
// If the witness comes from an implicit `@differentiable` attribute
// inherited from a protocol requirement's `@differentiable` attribute,
// emit a note on the inherited attribute.
if (auto *diffAttr = dyn_cast<DifferentiableAttr>(attr)) {
auto inheritedAttrLoc =
diffAttr->getImplicitlyInheritedDifferentiableAttrLocation();
if (inheritedAttrLoc.isValid()) {
diagnose(inheritedAttrLoc,
diag::autodiff_implicitly_inherited_differentiable_attr_here)
.highlight(inheritedAttrLoc);
emittedNote = true;
}
}
// Otherwise, emit a note on the original function.
if (!emittedNote) {
diagnose(original->getLocation().getSourceLoc(),
diag::autodiff_when_differentiating_function_definition);
}
}
// Otherwise, emit an error on the original function.
else {
diagnose(original->getLocation().getSourceLoc(),
diag::autodiff_function_not_differentiable_error);
}
return diagnose(loc, diag, std::forward<U>(args)...);
}
// For indirect differentiation, emit a "not differentiable" note on the
// expression first. Then emit an error at the source invoker of
// differentiation, and a "when differentiating this" note at each indirect
// invoker.
case DifferentiationInvoker::Kind::IndirectDifferentiation: {
SILInstruction *inst;
SILDifferentiabilityWitness *witness;
std::tie(inst, witness) = invoker.getIndirectDifferentiation();
auto invokerLookup = invokers.find(witness);
assert(invokerLookup != invokers.end() && "Expected parent invoker");
emitNondifferentiabilityError(
inst, invokerLookup->second,
diag::autodiff_expression_not_differentiable_note);
return diagnose(loc, diag::autodiff_when_differentiating_function_call);
}
}
llvm_unreachable("Invalid invoker kind"); // silences MSVC C4715
}
} // end namespace autodiff
} // end namespace swift
#endif // SWIFT_SILOPTIMIZER_UTILS_DIFFERENTIATION_ADCONTEXT_H