blob: 575f6afc968b008c02f648316472aba5ac505854 [file] [log] [blame]
//===--- ADContext.cpp - Differentiation Context --------------*- C++ -*---===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// Per-module contextual information for the differentiation transform.
//
//===----------------------------------------------------------------------===//
#define DEBUG_TYPE "differentiation"
#include "swift/SILOptimizer/Differentiation/ADContext.h"
#include "swift/AST/DiagnosticsSIL.h"
#include "swift/SILOptimizer/PassManager/Transforms.h"
using llvm::DenseMap;
using llvm::SmallPtrSet;
using llvm::SmallVector;
namespace swift {
namespace autodiff {
//===----------------------------------------------------------------------===//
// Local helpers
//===----------------------------------------------------------------------===//
/// Searches `protocol` for the requirement implementing `operatorName`.
///
/// Each declaration returned by `protocol->lookupDirect(operatorName)` is
/// checked in order. The first one that is a protocol requirement and is a
/// static operator `FuncDecl` is returned.
///
/// \param operatorName The operator to look for, e.g. `+`. It must satisfy
///        `isOperator()` (asserted).
/// \param protocol The protocol whose direct lookup results are searched.
/// \returns The matching requirement, or null if no candidate qualifies.
static FuncDecl *findOperatorDeclInProtocol(DeclName operatorName,
                                            ProtocolDecl *protocol) {
  assert(operatorName.isOperator());
  auto candidates = protocol->lookupDirect(operatorName);
  for (auto *candidate : candidates) {
    // Non-requirements (e.g. extension members) are not what we want.
    if (candidate->isProtocolRequirement()) {
      if (auto *funcDecl = dyn_cast<FuncDecl>(candidate)) {
        if (funcDecl->isStatic() && funcDecl->isOperator())
          return funcDecl;
      }
    }
  }
  // No qualifying requirement in this protocol.
  return nullptr;
}
//===----------------------------------------------------------------------===//
// ADContext methods
//===----------------------------------------------------------------------===//
/// Constructs the differentiation context for a module transform.
///
/// Stores `transform` and binds `module` and `passManager` from
/// `*transform.getModule()` and `*transform.getPassManager()`. Both pointers
/// are dereferenced here, so they must be non-null when the context is built.
/// NOTE(review): other members used in this file (e.g. `astCtx`,
/// `additiveArithmeticProtocol`, the cached operator decls) are not
/// initialized here. Presumably they have in-class initializers in
/// ADContext.h; confirm.
ADContext::ADContext(SILModuleTransform &transform)
: transform(transform), module(*transform.getModule()),
passManager(*transform.getPassManager()) {}
/// Resolves the `SourceFile` associated with the SIL function `f`.
///
/// When `f` has a location whose decl context lives in a source file, that
/// file is returned. Otherwise the first `SourceFile` among the files of
/// `f`'s Swift module is used. Failing both is treated as unreachable.
static SourceFile &getSourceFile(SILFunction *f) {
  // Preferred: derive the file from the function's own source location.
  if (f->hasLocation()) {
    auto *declContext = f->getLocation().getAsDeclContext();
    SourceFile *locationFile =
        declContext ? declContext->getParentSourceFile() : nullptr;
    if (locationFile)
      return *locationFile;
  }
  // Fallback: the first source file of the enclosing Swift module.
  auto *swiftModule = f->getModule().getSwiftModule();
  for (auto *fileUnit : swiftModule->getFiles()) {
    if (auto *candidate = dyn_cast<SourceFile>(fileUnit))
      return *candidate;
  }
  llvm_unreachable("Could not resolve SourceFile from SILFunction");
}
/// Returns the synthesized file unit of the source file that `original`
/// resolves to, as provided by `SourceFile::getOrCreateSynthesizedFile()`
/// (presumably created on demand; see that method).
SynthesizedFileUnit &
ADContext::getOrCreateSynthesizedFile(SILFunction *original) {
  return getSourceFile(original).getOrCreateSynthesizedFile();
}
/// Returns the `+` requirement of `additiveArithmeticProtocol`.
///
/// The lookup result is memoized in `cachedPlusFn`. Absence is asserted.
/// If the lookup yields null with assertions disabled, the cache stays empty
/// and the lookup is retried on the next call.
FuncDecl *ADContext::getPlusDecl() const {
  if (cachedPlusFn)
    return cachedPlusFn;
  auto plusName = astCtx.getIdentifier("+");
  cachedPlusFn =
      findOperatorDeclInProtocol(plusName, additiveArithmeticProtocol);
  assert(cachedPlusFn && "AdditiveArithmetic.+ not found");
  return cachedPlusFn;
}
/// Returns the `+=` requirement of `additiveArithmeticProtocol`.
///
/// The lookup result is memoized in `cachedPlusEqualFn`. Absence is asserted.
/// If the lookup yields null with assertions disabled, the cache stays empty
/// and the lookup is retried on the next call.
FuncDecl *ADContext::getPlusEqualDecl() const {
  if (cachedPlusEqualFn)
    return cachedPlusEqualFn;
  auto plusEqualName = astCtx.getIdentifier("+=");
  cachedPlusEqualFn =
      findOperatorDeclInProtocol(plusEqualName, additiveArithmeticProtocol);
  assert(cachedPlusEqualFn && "AdditiveArithmetic.+= not found");
  return cachedPlusEqualFn;
}
/// Removes everything the differentiation transform generated from the module.
///
/// Runs in two phases:
///  1. For each entry in `generatedFunctionReferences` that resolves to a
///     `FunctionRefInst`, its uses are replaced with undef and it is erased.
///  2. Each function in `generatedFunctions` drops its references, is
///     reported via `transform.notifyWillDeleteFunction`, and is erased from
///     the module.
void ADContext::cleanUp() {
// Delete all references to generated functions.
// These are removed before the functions themselves, presumably so no
// instruction still refers to a function when that function is erased.
for (auto fnRef : generatedFunctionReferences) {
// Look through any function conversions to reach the underlying
// `FunctionRefInst`; references that do not resolve to one are skipped.
if (auto *fnRefInst =
peerThroughFunctionConversions<FunctionRefInst>(fnRef)) {
// NOTE(review): if two recorded references resolve to the same
// `FunctionRefInst`, it would be erased twice here. Confirm entries are
// unique.
fnRefInst->replaceAllUsesWithUndef();
fnRefInst->eraseFromParent();
}
}
// Delete all generated functions.
for (auto *generatedFunction : generatedFunctions) {
LLVM_DEBUG(getADDebugStream() << "Deleting generated function "
<< generatedFunction->getName() << '\n');
// Drop the function's own references before it is erased.
generatedFunction->dropAllReferences();
// Notify the transform before the module erases the function.
transform.notifyWillDeleteFunction(generatedFunction);
module.eraseFunction(generatedFunction);
}
}
/// Builds a `differentiable_function` instruction with `builder` and makes
/// sure it is not recorded in `processedDifferentiableFunctionInsts`.
///
/// All arguments are forwarded unchanged to
/// `SILBuilder::createDifferentiableFunction`.
DifferentiableFunctionInst *ADContext::createDifferentiableFunction(
    SILBuilder &builder, SILLocation loc, IndexSubset *parameterIndices,
    IndexSubset *resultIndices, SILValue original,
    Optional<std::pair<SILValue, SILValue>> derivativeFunctions) {
  auto *newInst = builder.createDifferentiableFunction(
      loc, parameterIndices, resultIndices, original, derivativeFunctions);
  // A fresh instruction must not look "already processed". Presumably this
  // guards against a stale entry left by a deleted instruction that occupied
  // the same address; confirm.
  processedDifferentiableFunctionInsts.erase(newInst);
  return newInst;
}
/// Builds a `linear_function` instruction with `builder` and makes sure it is
/// not recorded in `processedLinearFunctionInsts`.
///
/// All arguments are forwarded unchanged to
/// `SILBuilder::createLinearFunction`.
LinearFunctionInst *ADContext::createLinearFunction(
    SILBuilder &builder, SILLocation loc, IndexSubset *parameterIndices,
    SILValue original, Optional<SILValue> transposeFunction) {
  auto *newInst = builder.createLinearFunction(loc, parameterIndices, original,
                                               transposeFunction);
  // A fresh instruction must not look "already processed". Presumably this
  // guards against a stale entry left by a deleted instruction that occupied
  // the same address; confirm.
  processedLinearFunctionInsts.erase(newInst);
  return newInst;
}
/// Returns the `DifferentiableFunctionExpr` AST node recorded in `inst`'s SIL
/// location, as given by `SILLocation::getAsASTNode` (presumably null when the
/// location does not wrap such an expression).
DifferentiableFunctionExpr *
ADContext::findDifferentialOperator(DifferentiableFunctionInst *inst) {
  SILLocation instLoc = inst->getLoc();
  return instLoc.getAsASTNode<DifferentiableFunctionExpr>();
}
/// Returns the `LinearFunctionExpr` AST node recorded in `inst`'s SIL
/// location, as given by `SILLocation::getAsASTNode` (presumably null when the
/// location does not wrap such an expression).
LinearFunctionExpr *
ADContext::findDifferentialOperator(LinearFunctionInst *inst) {
  SILLocation instLoc = inst->getLoc();
  return instLoc.getAsASTNode<LinearFunctionExpr>();
}
} // end namespace autodiff
} // end namespace swift