| //===- TestDenseBackwardDataFlowAnalysis.cpp - Test pass ------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // Test pass for backward dense dataflow analysis. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "TestDenseDataFlowAnalysis.h" |
| #include "TestDialect.h" |
| #include "TestOps.h" |
| #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" |
| #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" |
| #include "mlir/Analysis/DataFlow/DenseAnalysis.h" |
| #include "mlir/Analysis/DataFlowFramework.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/SymbolTable.h" |
| #include "mlir/Interfaces/CallInterfaces.h" |
| #include "mlir/Interfaces/ControlFlowInterfaces.h" |
| #include "mlir/Interfaces/SideEffectInterfaces.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Support/TypeID.h" |
| #include "llvm/Support/raw_ostream.h" |
| |
| using namespace mlir; |
| using namespace mlir::dataflow; |
| using namespace mlir::dataflow::test; |
| |
| namespace { |
| |
| class NextAccess : public AbstractDenseLattice, public AccessLatticeBase { |
| public: |
| MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NextAccess) |
| |
| using dataflow::AbstractDenseLattice::AbstractDenseLattice; |
| |
| ChangeResult meet(const AbstractDenseLattice &lattice) override { |
| return AccessLatticeBase::merge(static_cast<AccessLatticeBase>( |
| static_cast<const NextAccess &>(lattice))); |
| } |
| |
| void print(raw_ostream &os) const override { |
| return AccessLatticeBase::print(os); |
| } |
| }; |
| |
| class NextAccessAnalysis : public DenseBackwardDataFlowAnalysis<NextAccess> { |
| public: |
| NextAccessAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable, |
| bool assumeFuncReads = false) |
| : DenseBackwardDataFlowAnalysis(solver, symbolTable), |
| assumeFuncReads(assumeFuncReads) {} |
| |
| void visitOperation(Operation *op, const NextAccess &after, |
| NextAccess *before) override; |
| |
| void visitCallControlFlowTransfer(CallOpInterface call, |
| CallControlFlowAction action, |
| const NextAccess &after, |
| NextAccess *before) override; |
| |
| void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch, |
| RegionBranchPoint regionFrom, |
| RegionBranchPoint regionTo, |
| const NextAccess &after, |
| NextAccess *before) override; |
| |
| // TODO: this isn't ideal for the analysis. When there is no next access, it |
| // means "we don't know what the next access is" rather than "there is no next |
| // access". But it's unclear how to differentiate the two cases... |
| void setToExitState(NextAccess *lattice) override { |
| propagateIfChanged(lattice, lattice->setKnownToUnknown()); |
| } |
| |
| const bool assumeFuncReads; |
| }; |
| } // namespace |
| |
| void NextAccessAnalysis::visitOperation(Operation *op, const NextAccess &after, |
| NextAccess *before) { |
| auto memory = dyn_cast<MemoryEffectOpInterface>(op); |
| // If we can't reason about the memory effects, conservatively assume we can't |
| // say anything about the next access. |
| if (!memory) |
| return setToExitState(before); |
| |
| SmallVector<MemoryEffects::EffectInstance> effects; |
| memory.getEffects(effects); |
| |
| // First, check if all underlying values are already known. Otherwise, avoid |
| // propagating and stay in the "undefined" state to avoid incorrectly |
| // propagating values that may be overwritten later on as that could be |
| // problematic for convergence based on monotonicity of lattice updates. |
| SmallVector<Value> underlyingValues; |
| underlyingValues.reserve(effects.size()); |
| for (const MemoryEffects::EffectInstance &effect : effects) { |
| Value value = effect.getValue(); |
| |
| // Effects with unspecified value are treated conservatively and we cannot |
| // assume anything about the next access. |
| if (!value) |
| return setToExitState(before); |
| |
| // If cannot find the most underlying value, we cannot assume anything about |
| // the next accesses. |
| std::optional<Value> underlyingValue = |
| UnderlyingValueAnalysis::getMostUnderlyingValue( |
| value, [&](Value value) { |
| return getOrCreateFor<UnderlyingValueLattice>(op, value); |
| }); |
| |
| // If the underlying value is not known yet, don't propagate. |
| if (!underlyingValue) |
| return; |
| |
| underlyingValues.push_back(*underlyingValue); |
| } |
| |
| // Update the state if all underlying values are known. |
| ChangeResult result = before->meet(after); |
| for (const auto &[effect, value] : llvm::zip(effects, underlyingValues)) { |
| // If the underlying value is known to be unknown, set to fixpoint. |
| if (!value) |
| return setToExitState(before); |
| |
| result |= before->set(value, op); |
| } |
| propagateIfChanged(before, result); |
| } |
| |
| void NextAccessAnalysis::visitCallControlFlowTransfer( |
| CallOpInterface call, CallControlFlowAction action, const NextAccess &after, |
| NextAccess *before) { |
| if (action == CallControlFlowAction::ExternalCallee && assumeFuncReads) { |
| SmallVector<Value> underlyingValues; |
| underlyingValues.reserve(call->getNumOperands()); |
| for (Value operand : call.getArgOperands()) { |
| std::optional<Value> underlyingValue = |
| UnderlyingValueAnalysis::getMostUnderlyingValue( |
| operand, [&](Value value) { |
| return getOrCreateFor<UnderlyingValueLattice>( |
| call.getOperation(), value); |
| }); |
| if (!underlyingValue) |
| return; |
| underlyingValues.push_back(*underlyingValue); |
| } |
| |
| ChangeResult result = before->meet(after); |
| for (Value operand : underlyingValues) { |
| result |= before->set(operand, call); |
| } |
| return propagateIfChanged(before, result); |
| } |
| auto testCallAndStore = |
| dyn_cast<::test::TestCallAndStoreOp>(call.getOperation()); |
| if (testCallAndStore && ((action == CallControlFlowAction::EnterCallee && |
| testCallAndStore.getStoreBeforeCall()) || |
| (action == CallControlFlowAction::ExitCallee && |
| !testCallAndStore.getStoreBeforeCall()))) { |
| visitOperation(call, after, before); |
| } else { |
| AbstractDenseBackwardDataFlowAnalysis::visitCallControlFlowTransfer( |
| call, action, after, before); |
| } |
| } |
| |
| void NextAccessAnalysis::visitRegionBranchControlFlowTransfer( |
| RegionBranchOpInterface branch, RegionBranchPoint regionFrom, |
| RegionBranchPoint regionTo, const NextAccess &after, NextAccess *before) { |
| auto testStoreWithARegion = |
| dyn_cast<::test::TestStoreWithARegion>(branch.getOperation()); |
| |
| if (testStoreWithARegion && |
| ((regionTo.isParent() && !testStoreWithARegion.getStoreBeforeRegion()) || |
| (regionFrom.isParent() && |
| testStoreWithARegion.getStoreBeforeRegion()))) { |
| visitOperation(branch, static_cast<const NextAccess &>(after), |
| static_cast<NextAccess *>(before)); |
| } else { |
| propagateIfChanged(before, before->meet(after)); |
| } |
| } |
| |
| namespace { |
| struct TestNextAccessPass |
| : public PassWrapper<TestNextAccessPass, OperationPass<>> { |
| TestNextAccessPass() = default; |
| TestNextAccessPass(const TestNextAccessPass &other) : PassWrapper(other) { |
| interprocedural = other.interprocedural; |
| assumeFuncReads = other.assumeFuncReads; |
| } |
| |
| MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestNextAccessPass) |
| |
| StringRef getArgument() const override { return "test-next-access"; } |
| |
| Option<bool> interprocedural{ |
| *this, "interprocedural", llvm::cl::init(true), |
| llvm::cl::desc("perform interprocedural analysis")}; |
| Option<bool> assumeFuncReads{ |
| *this, "assume-func-reads", llvm::cl::init(false), |
| llvm::cl::desc( |
| "assume external functions have read effect on all arguments")}; |
| |
| static constexpr llvm::StringLiteral kTagAttrName = "name"; |
| static constexpr llvm::StringLiteral kNextAccessAttrName = "next_access"; |
| static constexpr llvm::StringLiteral kAtEntryPointAttrName = |
| "next_at_entry_point"; |
| |
| static Attribute makeNextAccessAttribute(Operation *op, |
| const DataFlowSolver &solver, |
| const NextAccess *nextAccess) { |
| if (!nextAccess) |
| return StringAttr::get(op->getContext(), "not computed"); |
| |
| // Note that if the underlying value could not be computed or is unknown, we |
| // conservatively treat the result also unknown. |
| SmallVector<Attribute> attrs; |
| for (Value operand : op->getOperands()) { |
| std::optional<Value> underlyingValue = |
| UnderlyingValueAnalysis::getMostUnderlyingValue( |
| operand, [&](Value value) { |
| return solver.lookupState<UnderlyingValueLattice>(value); |
| }); |
| if (!underlyingValue) { |
| attrs.push_back(StringAttr::get(op->getContext(), "unknown")); |
| continue; |
| } |
| Value value = *underlyingValue; |
| const AdjacentAccess *nextAcc = nextAccess->getAdjacentAccess(value); |
| if (!nextAcc || !nextAcc->isKnown()) { |
| attrs.push_back(StringAttr::get(op->getContext(), "unknown")); |
| continue; |
| } |
| |
| SmallVector<Attribute> innerAttrs; |
| innerAttrs.reserve(nextAcc->get().size()); |
| for (Operation *nextAccOp : nextAcc->get()) { |
| if (auto nextAccTag = |
| nextAccOp->getAttrOfType<StringAttr>(kTagAttrName)) { |
| innerAttrs.push_back(nextAccTag); |
| continue; |
| } |
| std::string repr; |
| llvm::raw_string_ostream os(repr); |
| nextAccOp->print(os); |
| innerAttrs.push_back(StringAttr::get(op->getContext(), os.str())); |
| } |
| attrs.push_back(ArrayAttr::get(op->getContext(), innerAttrs)); |
| } |
| return ArrayAttr::get(op->getContext(), attrs); |
| } |
| |
| void runOnOperation() override { |
| Operation *op = getOperation(); |
| SymbolTableCollection symbolTable; |
| |
| auto config = DataFlowConfig().setInterprocedural(interprocedural); |
| DataFlowSolver solver(config); |
| solver.load<DeadCodeAnalysis>(); |
| solver.load<NextAccessAnalysis>(symbolTable, assumeFuncReads); |
| solver.load<SparseConstantPropagation>(); |
| solver.load<UnderlyingValueAnalysis>(); |
| if (failed(solver.initializeAndRun(op))) { |
| emitError(op->getLoc(), "dataflow solver failed"); |
| return signalPassFailure(); |
| } |
| op->walk([&](Operation *op) { |
| auto tag = op->getAttrOfType<StringAttr>(kTagAttrName); |
| if (!tag) |
| return; |
| |
| const NextAccess *nextAccess = solver.lookupState<NextAccess>( |
| op->getNextNode() == nullptr ? ProgramPoint(op->getBlock()) |
| : op->getNextNode()); |
| op->setAttr(kNextAccessAttrName, |
| makeNextAccessAttribute(op, solver, nextAccess)); |
| |
| auto iface = dyn_cast<RegionBranchOpInterface>(op); |
| if (!iface) |
| return; |
| |
| SmallVector<Attribute> entryPointNextAccess; |
| SmallVector<RegionSuccessor> regionSuccessors; |
| iface.getSuccessorRegions(RegionBranchPoint::parent(), regionSuccessors); |
| for (const RegionSuccessor &successor : regionSuccessors) { |
| if (!successor.getSuccessor() || successor.getSuccessor()->empty()) |
| continue; |
| Block &successorBlock = successor.getSuccessor()->front(); |
| ProgramPoint successorPoint = successorBlock.empty() |
| ? ProgramPoint(&successorBlock) |
| : &successorBlock.front(); |
| entryPointNextAccess.push_back(makeNextAccessAttribute( |
| op, solver, solver.lookupState<NextAccess>(successorPoint))); |
| } |
| op->setAttr(kAtEntryPointAttrName, |
| ArrayAttr::get(op->getContext(), entryPointNextAccess)); |
| }); |
| } |
| }; |
| } // namespace |
| |
| namespace mlir::test { |
| void registerTestNextAccessPass() { PassRegistration<TestNextAccessPass>(); } |
| } // namespace mlir::test |