| //===- Inliner.cpp - Pass to inline function calls ------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements a basic inlining algorithm that operates bottom up over |
| // the Strongly Connected Components (SCCs) of the CallGraph. This enables a more |
| // incremental propagation of inlining decisions from the leaves to the roots of |
| // the callgraph. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Analysis/CallGraph.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Transforms/InliningUtils.h" |
| #include "mlir/Transforms/Passes.h" |
| #include "llvm/ADT/SCCIterator.h" |
| #include "llvm/Support/Debug.h" |
| #include "llvm/Support/Parallel.h" |
| |
| #define DEBUG_TYPE "inlining" |
| |
| using namespace mlir; |
| |
| static llvm::cl::opt<bool> disableCanonicalization( |
| "mlir-disable-inline-simplify", |
| llvm::cl::desc("Disable running simplifications during inlining"), |
| llvm::cl::ReallyHidden, llvm::cl::init(false)); |
| |
| static llvm::cl::opt<unsigned> maxInliningIterations( |
| "mlir-max-inline-iterations", |
| llvm::cl::desc("Maximum number of iterations when inlining within an SCC"), |
| llvm::cl::ReallyHidden, llvm::cl::init(4)); |
| |
| //===----------------------------------------------------------------------===// |
| // CallGraph traversal |
| //===----------------------------------------------------------------------===// |
| |
| /// Run a given transformation over the SCCs of the callgraph in a bottom up |
| /// traversal. |
| static void runTransformOnCGSCCs( |
| const CallGraph &cg, |
| function_ref<void(ArrayRef<CallGraphNode *>)> sccTransformer) { |
| std::vector<CallGraphNode *> currentSCCVec; |
| auto cgi = llvm::scc_begin(&cg); |
| while (!cgi.isAtEnd()) { |
| // Copy the current SCC and increment so that the transformer can modify the |
| // SCC without invalidating our iterator. |
| currentSCCVec = *cgi; |
| ++cgi; |
| sccTransformer(currentSCCVec); |
| } |
| } |
| |
| namespace { |
| /// This struct represents a resolved call to a given callgraph node. Given that |
| /// the call does not actually contain a direct reference to the |
| /// Region(CallGraphNode) that it is dispatching to, we need to resolve them |
| /// explicitly. |
| struct ResolvedCall { |
| ResolvedCall(CallOpInterface call, CallGraphNode *targetNode) |
| : call(call), targetNode(targetNode) {} |
| CallOpInterface call; |
| CallGraphNode *targetNode; |
| }; |
| } // end anonymous namespace |
| |
| /// Collect all of the callable operations within the given range of blocks. If |
| /// `traverseNestedCGNodes` is true, this will also collect call operations |
| /// inside of nested callgraph nodes. |
| static void collectCallOps(iterator_range<Region::iterator> blocks, |
| CallGraph &cg, SmallVectorImpl<ResolvedCall> &calls, |
| bool traverseNestedCGNodes) { |
| SmallVector<Block *, 8> worklist; |
| auto addToWorklist = [&](iterator_range<Region::iterator> blocks) { |
| for (Block &block : blocks) |
| worklist.push_back(&block); |
| }; |
| |
| addToWorklist(blocks); |
| while (!worklist.empty()) { |
| for (Operation &op : *worklist.pop_back_val()) { |
| if (auto call = dyn_cast<CallOpInterface>(op)) { |
| // TODO(riverriddle) Support inlining nested call references. |
| CallInterfaceCallable callable = call.getCallableForCallee(); |
| if (SymbolRefAttr symRef = callable.dyn_cast<SymbolRefAttr>()) { |
| if (!symRef.isa<FlatSymbolRefAttr>()) |
| continue; |
| } |
| |
| CallGraphNode *node = cg.resolveCallable(call); |
| if (!node->isExternal()) |
| calls.emplace_back(call, node); |
| continue; |
| } |
| |
| // If this is not a call, traverse the nested regions. If |
| // `traverseNestedCGNodes` is false, then don't traverse nested call graph |
| // regions. |
| for (auto &nestedRegion : op.getRegions()) |
| if (traverseNestedCGNodes || !cg.lookupNode(&nestedRegion)) |
| addToWorklist(nestedRegion); |
| } |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Inliner |
| //===----------------------------------------------------------------------===// |
| namespace { |
| /// This class provides a specialization of the main inlining interface. |
| struct Inliner : public InlinerInterface { |
| Inliner(MLIRContext *context, CallGraph &cg) |
| : InlinerInterface(context), cg(cg) {} |
| |
| /// Process a set of blocks that have been inlined. This callback is invoked |
| /// *before* inlined terminator operations have been processed. |
| void |
| processInlinedBlocks(iterator_range<Region::iterator> inlinedBlocks) final { |
| collectCallOps(inlinedBlocks, cg, calls, /*traverseNestedCGNodes=*/true); |
| } |
| |
| /// The current set of call instructions to consider for inlining. |
| SmallVector<ResolvedCall, 8> calls; |
| |
| /// The callgraph being operated on. |
| CallGraph &cg; |
| }; |
| } // namespace |
| |
| /// Returns true if the given call should be inlined. |
| static bool shouldInline(ResolvedCall &resolvedCall) { |
| // Don't allow inlining terminator calls. We currently don't support this |
| // case. |
| if (resolvedCall.call.getOperation()->isKnownTerminator()) |
| return false; |
| |
| // Don't allow inlining if the target is an ancestor of the call. This |
| // prevents inlining recursively. |
| if (resolvedCall.targetNode->getCallableRegion()->isAncestor( |
| resolvedCall.call.getParentRegion())) |
| return false; |
| |
| // Otherwise, inline. |
| return true; |
| } |
| |
| /// Attempt to inline calls within the given scc. This function returns |
| /// success if any calls were inlined, failure otherwise. |
| static LogicalResult inlineCallsInSCC(Inliner &inliner, |
| ArrayRef<CallGraphNode *> currentSCC) { |
| CallGraph &cg = inliner.cg; |
| auto &calls = inliner.calls; |
| |
| // Collect all of the direct calls within the nodes of the current SCC. We |
| // don't traverse nested callgraph nodes, because they are handled separately |
| // likely within a different SCC. |
| for (auto *node : currentSCC) { |
| if (!node->isExternal()) |
| collectCallOps(*node->getCallableRegion(), cg, calls, |
| /*traverseNestedCGNodes=*/false); |
| } |
| if (calls.empty()) |
| return failure(); |
| |
| // Try to inline each of the call operations. Don't cache the end iterator |
| // here as more calls may be added during inlining. |
| bool inlinedAnyCalls = false; |
| for (unsigned i = 0; i != calls.size(); ++i) { |
| ResolvedCall &it = calls[i]; |
| LLVM_DEBUG({ |
| llvm::dbgs() << "* Considering inlining call: "; |
| it.call.dump(); |
| }); |
| if (!shouldInline(it)) |
| continue; |
| |
| CallOpInterface call = it.call; |
| Region *targetRegion = it.targetNode->getCallableRegion(); |
| LogicalResult inlineResult = inlineCall( |
| inliner, call, cast<CallableOpInterface>(targetRegion->getParentOp()), |
| targetRegion); |
| if (failed(inlineResult)) |
| continue; |
| |
| // If the inlining was successful, then erase the call. |
| call.erase(); |
| inlinedAnyCalls = true; |
| } |
| calls.clear(); |
| return success(inlinedAnyCalls); |
| } |
| |
| /// Canonicalize the nodes within the given SCC with the given set of |
| /// canonicalization patterns. |
| static void canonicalizeSCC(CallGraph &cg, ArrayRef<CallGraphNode *> currentSCC, |
| MLIRContext *context, |
| const OwningRewritePatternList &canonPatterns) { |
| // Collect the sets of nodes to canonicalize. |
| SmallVector<CallGraphNode *, 4> nodesToCanonicalize; |
| for (auto *node : currentSCC) { |
| // Don't canonicalize the external node, it has no valid callable region. |
| if (node->isExternal()) |
| continue; |
| |
| // Don't canonicalize nodes with children. Nodes with children |
| // require special handling as we may remove the node during |
| // canonicalization. In the future, we should be able to handle this |
| // case with proper node deletion tracking. |
| if (node->hasChildren()) |
| continue; |
| |
| // We also won't apply canonicalizations for nodes that are not |
| // isolated. This avoids potentially mutating the regions of nodes defined |
| // above, this is also a stipulation of the 'applyPatternsGreedily' driver. |
| auto *region = node->getCallableRegion(); |
| if (!region->getParentOp()->isKnownIsolatedFromAbove()) |
| continue; |
| nodesToCanonicalize.push_back(node); |
| } |
| if (nodesToCanonicalize.empty()) |
| return; |
| |
| // Canonicalize each of the nodes within the SCC in parallel. |
| // NOTE: This is simple now, because we don't enable canonicalizing nodes |
| // within children. When we remove this restriction, this logic will need to |
| // be reworked. |
| ParallelDiagnosticHandler canonicalizationHandler(context); |
| llvm::parallel::for_each_n( |
| llvm::parallel::par, /*Begin=*/size_t(0), |
| /*End=*/nodesToCanonicalize.size(), [&](size_t index) { |
| // Set the order for this thread so that diagnostics will be properly |
| // ordered. |
| canonicalizationHandler.setOrderIDForThread(index); |
| |
| // Apply the canonicalization patterns to this region. |
| auto *node = nodesToCanonicalize[index]; |
| applyPatternsGreedily(*node->getCallableRegion(), canonPatterns); |
| |
| // Make sure to reset the order ID for the diagnostic handler, as this |
| // thread may be used in a different context. |
| canonicalizationHandler.eraseOrderIDForThread(); |
| }); |
| } |
| |
| /// Attempt to inline calls within the given scc, and run canonicalizations with |
| /// the given patterns, until a fixed point is reached. This allows for the |
| /// inlining of newly devirtualized calls. |
| static void inlineSCC(Inliner &inliner, ArrayRef<CallGraphNode *> currentSCC, |
| MLIRContext *context, |
| const OwningRewritePatternList &canonPatterns) { |
| // If we successfully inlined any calls, run some simplifications on the |
| // nodes of the scc. Continue attempting to inline until we reach a fixed |
| // point, or a maximum iteration count. We canonicalize here as it may |
| // devirtualize new calls, as well as give us a better cost model. |
| unsigned iterationCount = 0; |
| while (succeeded(inlineCallsInSCC(inliner, currentSCC))) { |
| // If we aren't allowing simplifications or the max iteration count was |
| // reached, then bail out early. |
| if (disableCanonicalization || ++iterationCount >= maxInliningIterations) |
| break; |
| canonicalizeSCC(inliner.cg, currentSCC, context, canonPatterns); |
| } |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // InlinerPass |
| //===----------------------------------------------------------------------===// |
| |
| // TODO(riverriddle) This pass should currently only be used for basic testing |
| // of inlining functionality. |
| namespace { |
| struct InlinerPass : public OperationPass<InlinerPass> { |
| void runOnOperation() override { |
| CallGraph &cg = getAnalysis<CallGraph>(); |
| auto *context = &getContext(); |
| |
| // The inliner should only be run on operations that define a symbol table, |
| // as the callgraph will need to resolve references. |
| Operation *op = getOperation(); |
| if (!op->hasTrait<OpTrait::SymbolTable>()) { |
| op->emitOpError() << " was scheduled to run under the inliner, but does " |
| "not define a symbol table"; |
| return signalPassFailure(); |
| } |
| |
| // Collect a set of canonicalization patterns to use when simplifying |
| // callable regions within an SCC. |
| OwningRewritePatternList canonPatterns; |
| for (auto *op : context->getRegisteredOperations()) |
| op->getCanonicalizationPatterns(canonPatterns, context); |
| |
| // Run the inline transform in post-order over the SCCs in the callgraph. |
| Inliner inliner(context, cg); |
| runTransformOnCGSCCs(cg, [&](ArrayRef<CallGraphNode *> scc) { |
| inlineSCC(inliner, scc, context, canonPatterns); |
| }); |
| } |
| }; |
| } // end anonymous namespace |
| |
| std::unique_ptr<Pass> mlir::createInlinerPass() { |
| return std::make_unique<InlinerPass>(); |
| } |
| |
| static PassRegistration<InlinerPass> pass("inline", "Inline function calls"); |