blob: 950d97d93cf4d206f71dbb552e9e86422cd2c700 [file] [log] [blame]
//===--- CallerAnalysis.cpp - Determine callsites to a function ----------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
#define DEBUG_TYPE "sil-caller-analysis"
#include "swift/SILOptimizer/Analysis/CallerAnalysis.h"
#include "swift/SIL/InstructionUtils.h"
#include "swift/SIL/SILModule.h"
#include "swift/SILOptimizer/Utils/Local.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/YAMLTraits.h"
using namespace swift;
namespace {
// File-local shorthand so the definitions below can refer to
// CallerAnalysis::FunctionInfo without spelling out the enclosing class.
using FunctionInfo = CallerAnalysis::FunctionInfo;
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// CallerAnalysis::FunctionInfo
//===----------------------------------------------------------------------===//
/// Build the info record for \p f.
///
/// The set of known callers starts out empty. Whether \p f may have indirect
/// callers is derived solely from its function representation via
/// canBeCalledIndirectly().
CallerAnalysis::FunctionInfo::FunctionInfo(SILFunction *f)
: callerStates(),
// TODO: Make this more aggressive by considering
// final/visibility/etc.
mayHaveIndirectCallers(canBeCalledIndirectly(f->getRepresentation())) {}
//===----------------------------------------------------------------------===//
// CallerAnalysis::ApplySiteFinderVisitor
//===----------------------------------------------------------------------===//
/// Instruction visitor that records, for a single caller function, which
/// functions it references and how those references are applied.
///
/// Each visit method returns a bool: true when the instruction was a function
/// reference that has been processed, false for every other instruction.
struct CallerAnalysis::ApplySiteFinderVisitor
: SILInstructionVisitor<ApplySiteFinderVisitor, bool> {
/// The analysis whose per-function tables are being updated.
CallerAnalysis *analysis;
/// The function whose body is being scanned.
SILFunction *callerFn;
/// The analysis record for callerFn, looked up once at construction time.
FunctionInfo &callerInfo;
#ifndef NDEBUG
/// Apply-site instructions that were reached through a function reference.
SmallPtrSet<SILInstruction *, 8> visitedCallSites;
/// Apply-site instructions with a known callee that have not (yet) been
/// reached through a function reference. The destructor asserts that this is
/// empty.
SmallSetVector<SILInstruction *, 8> callSitesThatMustBeVisited;
#endif
ApplySiteFinderVisitor(CallerAnalysis *analysis, SILFunction *callerFn)
: analysis(analysis), callerFn(callerFn),
callerInfo(analysis->unsafeGetFunctionInfo(callerFn)) {}
~ApplySiteFinderVisitor();
/// Fallback for every instruction that is not a function reference.
bool visitSILInstruction(SILInstruction *) { return false; }
// All three function-reference flavors share one implementation.
bool visitFunctionRefInst(FunctionRefInst *fri) {
return visitFunctionRefBaseInst(fri);
}
bool visitDynamicFunctionRefInst(DynamicFunctionRefInst *fri) {
return visitFunctionRefBaseInst(fri);
}
bool
visitPreviousDynamicFunctionRefInst(PreviousDynamicFunctionRefInst *fri) {
return visitFunctionRefBaseInst(fri);
}
/// Records the caller/callee edge for \p fri and merges what is known about
/// its local apply sites into the callee's record. Always returns true.
bool visitFunctionRefBaseInst(FunctionRefBaseInst *fri);
/// Visits every instruction of callerFn.
void process();
/// Debug-only bookkeeping for partial/any apply sites found via a reference.
void processApplySites(ArrayRef<ApplySite> applySites);
/// Debug-only bookkeeping for full apply sites found via a reference.
void processApplySites(ArrayRef<FullApplySite> applySites);
/// Debug-only: remembers apply sites with a known callee that were not
/// reached through a function reference.
void checkCallSiteInvariants(SILInstruction &i);
};
/// Note that each of \p applySites has been reached through a function
/// reference.
///
/// Today this only maintains the debug-only invariant-checking sets; in
/// NDEBUG builds it does nothing. If apply sites ever need to feed release
/// state, this is the place to add it.
void CallerAnalysis::ApplySiteFinderVisitor::processApplySites(
    ArrayRef<ApplySite> applySites) {
#ifndef NDEBUG
  for (ApplySite site : applySites) {
    SILInstruction *inst = site.getInstruction();
    visitedCallSites.insert(inst);
    callSitesThatMustBeVisited.remove(inst);
  }
#endif
}
/// Note that each of the full apply sites in \p applySites has been reached
/// through a function reference.
///
/// Today this only maintains the debug-only invariant-checking sets; in
/// NDEBUG builds it does nothing. If apply sites ever need to feed release
/// state, this is the place to add it.
void CallerAnalysis::ApplySiteFinderVisitor::processApplySites(
    ArrayRef<FullApplySite> applySites) {
#ifndef NDEBUG
  for (FullApplySite site : applySites) {
    SILInstruction *inst = site.getInstruction();
    visitedCallSites.insert(inst);
    callSitesThatMustBeVisited.remove(inst);
  }
#endif
}
/// In asserts-enabled builds, verify that no apply site with a known callee
/// was left unvisited. Any leftovers are printed to stderr before asserting.
/// In NDEBUG builds the destructor does nothing.
CallerAnalysis::ApplySiteFinderVisitor::~ApplySiteFinderVisitor() {
#ifndef NDEBUG
  if (!callSitesThatMustBeVisited.empty()) {
    llvm::errs() << "Found unhandled call sites!\n";
    // Drain from the back so the most recently recorded site prints first.
    do {
      SILInstruction *inst = callSitesThatMustBeVisited.pop_back_val();
      llvm::errs() << "Inst: " << *inst;
    } while (!callSitesThatMustBeVisited.empty());
    assert(false && "Unhandled call site?!");
  }
#endif
}
/// Record that callerFn references the function named by \p fri and merge
/// what findLocalApplySites() discovered about that reference into the
/// callee's per-caller state.
///
/// \returns true unconditionally, signalling that \p fri was a function
/// reference and has been handled.
bool CallerAnalysis::ApplySiteFinderVisitor::visitFunctionRefBaseInst(
    FunctionRefBaseInst *fri) {
  auto optResult = findLocalApplySites(fri);
  auto *calleeFn = fri->getReferencedFunction();
  FunctionInfo &calleeInfo = analysis->unsafeGetFunctionInfo(calleeFn);

  // Caller -> callee edge; needed later so invalidation can find the callee.
  callerInfo.calleeStates.insert(calleeFn);

  // Callee -> caller state. A freshly created entry starts out optimistic
  // (the direct caller set is assumed complete) and is only ever weakened
  // below.
  auto inserted = calleeInfo.callerStates.insert({callerFn, {}});
  auto &callerState = inserted.first->second;
  if (inserted.second)
    callerState.isDirectCallerSetComplete = true;

  // No information at all about the uses of this reference: treat it as
  // escaping and stop.
  if (!optResult.hasValue()) {
    callerState.isDirectCallerSetComplete = false;
    return true;
  }

  // Otherwise fold the discovered information into the existing state.
  auto &result = optResult.getValue();
  callerState.isDirectCallerSetComplete &= !result.isEscaping();

  if (!result.fullApplySites.empty()) {
    callerState.hasFullApply = true;
    processApplySites(llvm::makeArrayRef(result.fullApplySites));
  }

  if (!result.partialApplySites.empty()) {
    // Track the smallest argument count over all partial applies seen so far
    // for this caller, including any from earlier references.
    auto previousMin = callerState.getNumPartiallyAppliedArguments();
    unsigned fewestArgs = previousMin.getValueOr(UINT_MAX);
    for (ApplySite partialSite : result.partialApplySites)
      fewestArgs = std::min(fewestArgs, partialSite.getNumArguments());
    callerState.setNumPartiallyAppliedArguments(fewestArgs);
    processApplySites(result.partialApplySites);
  }

  return true;
}
/// Debug-only consistency check for \p i.
///
/// If \p i is a full apply or a partial apply whose callee function is known
/// and it has not been reached through a function reference yet, remember it
/// so the destructor can complain if it is never reached. Does nothing in
/// NDEBUG builds.
void CallerAnalysis::ApplySiteFinderVisitor::checkCallSiteInvariants(
    SILInstruction &i) {
#ifndef NDEBUG
  SILInstruction *inst = &i;

  // Full and partial applies are handled the same way; they only differ in
  // how the callee is queried.
  bool hasKnownCallee = false;
  if (auto fullApply = FullApplySite::isa(inst)) {
    hasKnownCallee = fullApply.getCalleeFunction() != nullptr;
  } else if (auto *partialApply = dyn_cast<PartialApplyInst>(inst)) {
    hasKnownCallee = partialApply->getCalleeFunction() != nullptr;
  }

  if (hasKnownCallee && !visitedCallSites.count(inst))
    callSitesThatMustBeVisited.insert(inst);
#endif
}
/// Walk every instruction of every block in callerFn and dispatch it through
/// the visitor. Function references are handled by the visit methods; in
/// asserts-enabled builds every other instruction is additionally checked by
/// checkCallSiteInvariants().
void CallerAnalysis::ApplySiteFinderVisitor::process() {
for (auto &block : *callerFn) {
for (auto &i : block) {
#ifndef NDEBUG
// If this is a call site that we visited as part of seeing a different
// function_ref, skip it. We know that it has been processed correctly.
//
// NOTE: This is only done in asserts-enabled (non-NDEBUG) builds since we
// only use this as part of the verification that we can find all callees
// going forward along def-use edges that FullApplySite is able to track
// backwards along def-use edges.
if (visitedCallSites.count(&i))
continue;
#endif
// Try to find the apply sites for this specific FRI. visit() returns true
// only when the instruction was a function reference.
if (visit(&i))
continue;
#ifndef NDEBUG
checkCallSiteInvariants(i);
#endif
}
}
}
//===----------------------------------------------------------------------===//
// CallerAnalysis
//===----------------------------------------------------------------------===//
// Public accessor for external users of CallerAnalysis. It first recomputes
// everything on the recompute list, so the returned info is up to date.
// Code inside the analysis should call getOrInsertFunctionInfo or
// unsafeGetFunctionInfo instead, which do not trigger recomputation.
const FunctionInfo &CallerAnalysis::getFunctionInfo(SILFunction *f) const {
  // Recomputation mutates the analysis even though this accessor is const, so
  // the constness is cast away here.
  auto *mutableThis = const_cast<CallerAnalysis *>(this);
  mutableThis->processRecomputeFunctionList();
  return mutableThis->unsafeGetFunctionInfo(f);
}
// Internal, mutable lookup: returns the info for f, creating a fresh entry
// (constructed from f) if none exists yet.
FunctionInfo &CallerAnalysis::getOrInsertFunctionInfo(SILFunction *f) {
  LLVM_DEBUG(llvm::dbgs() << "CallerAnalysis: Creating caller info for: "
                          << f->getName() << "\n");
  // try_emplace leaves an existing entry untouched.
  auto emplaced = funcInfos.try_emplace(f, f);
  return emplaced.first->second;
}
/// Return the mutable info for \p f without recomputing anything.
///
/// "Unsafe" because \p f must already have an entry; this is only checked by
/// an assertion.
FunctionInfo &CallerAnalysis::unsafeGetFunctionInfo(SILFunction *f) {
  auto entry = funcInfos.find(f);
  assert(entry != funcInfos.end() && "Function does not have functionInfo!");
  return entry->second;
}
/// Return the read-only info for \p f without recomputing anything.
///
/// "Unsafe" because \p f must already have an entry; this is only checked by
/// an assertion.
const FunctionInfo &
CallerAnalysis::unsafeGetFunctionInfo(SILFunction *f) const {
  auto entry = funcInfos.find(f);
  assert(entry != funcInfos.end() && "Function does not have functionInfo!");
  return entry->second;
}
/// Set up the analysis for module \p m.
///
/// Every function currently in the module gets an (empty) info record and is
/// queued on the recompute list, so nothing is actually analyzed until the
/// list is processed.
CallerAnalysis::CallerAnalysis(SILModule *m)
    : SILAnalysis(SILAnalysisKind::Caller), mod(*m) {
  for (SILFunction &fn : mod) {
    SILFunction *fnPtr = &fn;
    getOrInsertFunctionInfo(fnPtr);
    recomputeFunctionList.insert(fnPtr);
  }
}
/// Scan the body of \p callerFn and record all function references and apply
/// sites it contains. The visitor's destructor runs right after the scan and
/// performs the debug-only "all call sites were visited" check.
void CallerAnalysis::processFunctionCallSites(SILFunction *callerFn) {
  ApplySiteFinderVisitor(this, callerFn).process();
}
/// Drop every edge that touches \p f: first the edges from \p f to the
/// functions it references, then the edges from the functions that
/// reference \p f.
void CallerAnalysis::invalidateAllInfo(SILFunction *f) {
  FunctionInfo &fInfo = unsafeGetFunctionInfo(f);

  // Outgoing edges: remove f from the caller set of each of its callees.
  invalidateKnownCallees(f, fInfo);

  // Incoming edges: for each remaining caller, remove f from that caller's
  // callee set, then forget the caller.
  while (!fInfo.callerStates.empty()) {
    SILFunction *caller = fInfo.callerStates.back().first;
    FunctionInfo &callerInfo = unsafeGetFunctionInfo(caller);
    LLVM_DEBUG(llvm::dbgs()
               << " caller-backedge: " << caller->getName() << "\n");
    bool foundF = callerInfo.calleeStates.remove(f);
    (void)foundF;
    assert(foundF && "Bad caller edge pointing at f?");
    fInfo.callerStates.pop_back();
  }
}
/// Remove every edge from \p caller to the functions it references.
///
/// \p callerInfo must be the info record of \p caller. On return its callee
/// set is empty and \p caller no longer appears in the caller set of any of
/// those callees.
void CallerAnalysis::invalidateKnownCallees(SILFunction *caller,
                                            FunctionInfo &callerInfo) {
  LLVM_DEBUG(llvm::dbgs() << "Invalidating caller: " << caller->getName()
                          << "\n");
  while (!callerInfo.calleeStates.empty()) {
    SILFunction *callee = callerInfo.calleeStates.pop_back_val();
    FunctionInfo &calleeInfo = unsafeGetFunctionInfo(callee);
    LLVM_DEBUG(llvm::dbgs() << " callee: " << callee->getName() << "\n");
    assert(calleeInfo.callerStates.count(caller) &&
           "Referenced callee is not fully/partially applied in the caller?!");
    // Forget everything this callee knew about the caller.
    calleeInfo.callerStates.erase(caller);
  }
}
/// Convenience overload: look up the info record of \p caller and remove all
/// edges from \p caller to the functions it references.
void CallerAnalysis::invalidateKnownCallees(SILFunction *caller) {
  FunctionInfo &callerInfo = unsafeGetFunctionInfo(caller);
  invalidateKnownCallees(caller, callerInfo);
}
/// Verify the edges recorded for the single function \p caller. Does nothing
/// in NDEBUG builds.
void CallerAnalysis::verify(SILFunction *caller) const {
#ifndef NDEBUG
  verify(caller, unsafeGetFunctionInfo(caller));
#endif
}
/// Check that the edges stored in \p functionInfo (the record of
/// \p function) are mirrored on the other side: every callee must list
/// \p function as a caller, and every caller must list \p function as a
/// callee. Does nothing in NDEBUG builds.
void CallerAnalysis::verify(SILFunction *function,
                            const FunctionInfo &functionInfo) const {
#ifndef NDEBUG
  LLVM_DEBUG(llvm::dbgs() << "Validating function: " << function->getName()
                          << "\n");

  // Outgoing edges.
  for (SILFunction *callee : functionInfo.calleeStates) {
    LLVM_DEBUG(llvm::dbgs() << " callee: " << callee->getName() << "\n");
    const FunctionInfo &calleeInfo = unsafeGetFunctionInfo(callee);
    assert(calleeInfo.callerStates.count(function) &&
           "Referenced callee is not fully/partially applied in the caller");
  }

  // Incoming edges.
  for (const auto &callerEntry : functionInfo.callerStates) {
    SILFunction *caller = callerEntry.first;
    LLVM_DEBUG(llvm::dbgs() << " caller: " << caller->getName() << "\n");
    const FunctionInfo &callerInfo = unsafeGetFunctionInfo(caller);
    assert(callerInfo.calleeStates.count(function) &&
           "Referencing caller does not have a callee edge for function");
  }
#endif
}
/// Verify the whole analysis against the module. Does nothing in NDEBUG
/// builds.
///
/// Checks that (1) every function in the module has an info record, (2) every
/// info record belongs to a function that is still in the module, and (3) the
/// caller/callee edges of each record are consistent. A violation of (1) or
/// (2) prints a diagnostic to stderr and aborts via llvm_unreachable.
void CallerAnalysis::verify() const {
#ifndef NDEBUG
  std::vector<SILFunction *> seenFunctions;
  for (auto &fn : mod) {
    if (!funcInfos.count(&fn)) {
      llvm::errs() << "Missing notification for added function: '"
                   << fn.getName() << "'\n";
      llvm_unreachable("standard error assertion");
    }
    seenFunctions.push_back(&fn);
  }

  // Sort so the membership test below can use a binary search.
  sortUnique(seenFunctions);

  for (auto &pair : funcInfos) {
    bool found = std::binary_search(seenFunctions.begin(), seenFunctions.end(),
                                    pair.first);
    if (!found) {
      // NOTE(review): if the function really was deleted, pair.first may be
      // dangling when its name is read here -- confirm that deleted
      // functions stay readable at this point.
      //
      // End the line so this diagnostic does not run into the message that
      // llvm_unreachable prints, matching the "missing notification" case.
      llvm::errs() << "Notification not sent for deleted function: '"
                   << pair.first->getName() << "'\n";
      llvm_unreachable("standard error assertion");
    }
    verify(pair.first, pair.second);
  }
#endif
}
void CallerAnalysis::invalidate() {
for (auto &f : mod) {
// Since we are going over all functions in the module
// invalidateKnownCallees should be sufficient.
invalidateKnownCallees(&f);
// We do not need to clear recompute function list since we know that it can
// at most contain a subset of the functions in the module so the SetVector
// will unique for us.
recomputeFunctionList.insert(&f);
}
}
//===----------------------------------------------------------------------===//
// CallerAnalysis YAML Dumper
//===----------------------------------------------------------------------===//
namespace {
using llvm::yaml::IO;
using llvm::yaml::MappingTraits;
using llvm::yaml::Output;
using llvm::yaml::ScalarEnumerationTraits;
using llvm::yaml::SequenceTraits;
/// A special struct that marshals call graph state into a form that
/// is easy for llvm's yaml i/o to dump. Its structure is meant to
/// correspond to how the data should be shown by the printer, so
/// naturally it is slightly redundant.
///
/// One node is built per function by CallerAnalysis::print(). The type can
/// only be created through the all-fields constructor and can be neither
/// copied nor moved.
struct YAMLCallGraphNode {
/// Name of the function this node describes.
StringRef calleeName;
/// Value of FunctionInfo::hasDirectCaller() for the function.
bool hasCaller;
/// Value of FunctionInfo::getMinPartialAppliedArgs() for the function.
unsigned minPartialAppliedArgs;
/// Value of FunctionInfo::hasOnlyCompleteDirectCallerSets().
bool hasOnlyCompleteDirectCallerSets;
/// Value of FunctionInfo::foundAllCallers().
bool hasAllCallers;
/// Names of the callers that have a partially-applied-argument count
/// recorded for this function.
std::vector<StringRef> partialAppliers;
/// Names of the callers that have a full apply of this function.
std::vector<StringRef> fullAppliers;
YAMLCallGraphNode() = delete;
~YAMLCallGraphNode() = default;
YAMLCallGraphNode(const YAMLCallGraphNode &) = delete;
YAMLCallGraphNode(YAMLCallGraphNode &&) = delete;
YAMLCallGraphNode &operator=(const YAMLCallGraphNode &) = delete;
YAMLCallGraphNode &operator=(YAMLCallGraphNode &&) = delete;
/// Initializes every field; the two name lists are moved into the node.
YAMLCallGraphNode(StringRef calleeName, bool hasCaller,
unsigned minPartialAppliedArgs,
bool hasOnlyCompleteDirectCallerSets, bool hasAllCallers,
std::vector<StringRef> &&partialAppliers,
std::vector<StringRef> &&fullAppliers)
: calleeName(calleeName), hasCaller(hasCaller),
minPartialAppliedArgs(minPartialAppliedArgs),
hasOnlyCompleteDirectCallerSets(hasOnlyCompleteDirectCallerSets),
hasAllCallers(hasAllCallers),
partialAppliers(std::move(partialAppliers)),
fullAppliers(std::move(fullAppliers)) {}
};
} // end anonymous namespace
namespace llvm {
namespace yaml {
/// Describes to llvm's YAML I/O how a YAMLCallGraphNode maps onto a YAML
/// mapping: every field is registered as a required key, using the field's
/// own name as the key.
template <> struct MappingTraits<YAMLCallGraphNode> {
static void mapping(IO &io, YAMLCallGraphNode &func) {
// The keys are registered in the same order as the fields are declared.
io.mapRequired("calleeName", func.calleeName);
io.mapRequired("hasCaller", func.hasCaller);
io.mapRequired("minPartialAppliedArgs", func.minPartialAppliedArgs);
io.mapRequired("hasOnlyCompleteDirectCallerSets",
func.hasOnlyCompleteDirectCallerSets);
io.mapRequired("hasAllCallers", func.hasAllCallers);
io.mapRequired("partialAppliers", func.partialAppliers);
io.mapRequired("fullAppliers", func.fullAppliers);
}
};
} // namespace yaml
} // namespace llvm
/// Print the YAML dump of the analysis to stderr.
void CallerAnalysis::dump() const {
  print(llvm::errs());
}
/// Write the YAML dump of the analysis to the text file at \p filePath.
///
/// If the file cannot be opened for writing, the reason is printed to stderr
/// and the process is terminated via llvm::report_fatal_error.
void CallerAnalysis::print(const char *filePath) const {
  using namespace llvm::sys;
  std::error_code error;
  llvm::raw_fd_ostream fileOutputStream(filePath, error, fs::F_Text);
  if (error) {
    llvm::errs() << "Failed to open path \"" << filePath
                 << "\" for writing: " << error.message() << "\n";
    // Failing to open a file is an environmental error, not a broken
    // invariant. llvm_unreachable must not be used for it: in NDEBUG builds
    // it is only an optimizer hint, so execution would continue with an
    // unusable stream. report_fatal_error terminates in every build mode.
    llvm::report_fatal_error("CallerAnalysis: could not open dump file");
  }
  print(fileOutputStream);
}
/// Emit one YAML document per function in the module to \p os.
///
/// The module's function list is walked rather than the internal table, so
/// that every function is dumped and the output follows module order.
void CallerAnalysis::print(llvm::raw_ostream &os) const {
  llvm::yaml::Output yout(os);

  for (auto &f : mod) {
    const FunctionInfo &info = getFunctionInfo(&f);

    // Sort the referencing callers into the two lists shown by the dump. A
    // caller can end up in both.
    std::vector<StringRef> partialAppliers;
    std::vector<StringRef> fullAppliers;
    for (auto &callerEntry : info.getAllReferencingCallers()) {
      StringRef callerName = callerEntry.first->getName();
      const auto &callerState = callerEntry.second;
      if (callerState.hasFullApply)
        fullAppliers.push_back(callerName);
      if (callerState.getNumPartiallyAppliedArguments().hasValue())
        partialAppliers.push_back(callerName);
    }

    YAMLCallGraphNode node(
        f.getName(), info.hasDirectCaller(), info.getMinPartialAppliedArgs(),
        info.hasOnlyCompleteDirectCallerSets(), info.foundAllCallers(),
        std::move(partialAppliers), std::move(fullAppliers));
    yout << node;
  }
}
//===----------------------------------------------------------------------===//
// Main Entry Point
//===----------------------------------------------------------------------===//
/// Factory for the caller analysis of module \p mod. The object is allocated
/// with `new` and returned as its SILAnalysis base; the receiver is
/// presumably responsible for deleting it (not visible from this file).
SILAnalysis *swift::createCallerAnalysis(SILModule *mod) {
  CallerAnalysis *analysis = new CallerAnalysis(mod);
  return analysis;
}