//===- TestBackwardDataFlowAnalysis.cpp - Test backward dataflow analysis -===//
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" |
| #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" |
| #include "mlir/Analysis/DataFlow/SparseAnalysis.h" |
| #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| #include "mlir/Interfaces/SideEffectInterfaces.h" |
| #include "mlir/Pass/Pass.h" |
| |
| using namespace mlir; |
| using namespace mlir::dataflow; |
| |
| namespace { |
| |
| /// This lattice represents, for a given value, the set of memory resources that |
| /// this value, or anything derived from this value, is potentially written to. |
| struct WrittenTo : public AbstractSparseLattice { |
| MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WrittenTo) |
| using AbstractSparseLattice::AbstractSparseLattice; |
| |
| void print(raw_ostream &os) const override { |
| os << "["; |
| llvm::interleave( |
| writes, os, [&](const StringAttr &a) { os << a.str(); }, " "); |
| os << "]"; |
| } |
| ChangeResult addWrites(const SetVector<StringAttr> &writes) { |
| int sizeBefore = this->writes.size(); |
| this->writes.insert(writes.begin(), writes.end()); |
| int sizeAfter = this->writes.size(); |
| return sizeBefore == sizeAfter ? ChangeResult::NoChange |
| : ChangeResult::Change; |
| } |
| ChangeResult meet(const AbstractSparseLattice &other) override { |
| const auto *rhs = reinterpret_cast<const WrittenTo *>(&other); |
| return addWrites(rhs->writes); |
| } |
| |
| SetVector<StringAttr> writes; |
| }; |
| |
| /// An analysis that, by going backwards along the dataflow graph, annotates |
| /// each value with all the memory resources it (or anything derived from it) |
| /// is eventually written to. |
| class WrittenToAnalysis : public SparseBackwardDataFlowAnalysis<WrittenTo> { |
| public: |
| WrittenToAnalysis(DataFlowSolver &solver, SymbolTableCollection &symbolTable, |
| bool assumeFuncWrites) |
| : SparseBackwardDataFlowAnalysis(solver, symbolTable), |
| assumeFuncWrites(assumeFuncWrites) {} |
| |
| void visitOperation(Operation *op, ArrayRef<WrittenTo *> operands, |
| ArrayRef<const WrittenTo *> results) override; |
| |
| void visitBranchOperand(OpOperand &operand) override; |
| |
| void visitCallOperand(OpOperand &operand) override; |
| |
| void visitExternalCall(CallOpInterface call, ArrayRef<WrittenTo *> operands, |
| ArrayRef<const WrittenTo *> results) override; |
| |
| void setToExitState(WrittenTo *lattice) override { lattice->writes.clear(); } |
| |
| private: |
| bool assumeFuncWrites; |
| }; |
| |
| void WrittenToAnalysis::visitOperation(Operation *op, |
| ArrayRef<WrittenTo *> operands, |
| ArrayRef<const WrittenTo *> results) { |
| if (auto store = dyn_cast<memref::StoreOp>(op)) { |
| SetVector<StringAttr> newWrites; |
| newWrites.insert(op->getAttrOfType<StringAttr>("tag_name")); |
| propagateIfChanged(operands[0], operands[0]->addWrites(newWrites)); |
| return; |
| } // By default, every result of an op depends on every operand. |
| for (const WrittenTo *r : results) { |
| for (WrittenTo *operand : operands) { |
| meet(operand, *r); |
| } |
| addDependency(const_cast<WrittenTo *>(r), op); |
| } |
| } |
| |
| void WrittenToAnalysis::visitBranchOperand(OpOperand &operand) { |
| // Mark branch operands as "brancharg%d", with %d the operand number. |
| WrittenTo *lattice = getLatticeElement(operand.get()); |
| SetVector<StringAttr> newWrites; |
| newWrites.insert( |
| StringAttr::get(operand.getOwner()->getContext(), |
| "brancharg" + Twine(operand.getOperandNumber()))); |
| propagateIfChanged(lattice, lattice->addWrites(newWrites)); |
| } |
| |
| void WrittenToAnalysis::visitCallOperand(OpOperand &operand) { |
| // Mark call operands as "callarg%d", with %d the operand number. |
| WrittenTo *lattice = getLatticeElement(operand.get()); |
| SetVector<StringAttr> newWrites; |
| newWrites.insert( |
| StringAttr::get(operand.getOwner()->getContext(), |
| "callarg" + Twine(operand.getOperandNumber()))); |
| propagateIfChanged(lattice, lattice->addWrites(newWrites)); |
| } |
| |
| void WrittenToAnalysis::visitExternalCall(CallOpInterface call, |
| ArrayRef<WrittenTo *> operands, |
| ArrayRef<const WrittenTo *> results) { |
| if (!assumeFuncWrites) { |
| return SparseBackwardDataFlowAnalysis::visitExternalCall(call, operands, |
| results); |
| } |
| |
| for (WrittenTo *lattice : operands) { |
| SetVector<StringAttr> newWrites; |
| StringAttr name = call->getAttrOfType<StringAttr>("tag_name"); |
| if (!name) { |
| name = StringAttr::get(call->getContext(), |
| call.getOperation()->getName().getStringRef()); |
| } |
| newWrites.insert(name); |
| propagateIfChanged(lattice, lattice->addWrites(newWrites)); |
| } |
| } |
| |
| } // end anonymous namespace |
| |
| namespace { |
| struct TestWrittenToPass |
| : public PassWrapper<TestWrittenToPass, OperationPass<>> { |
| MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWrittenToPass) |
| |
| TestWrittenToPass() = default; |
| TestWrittenToPass(const TestWrittenToPass &other) : PassWrapper(other) { |
| interprocedural = other.interprocedural; |
| assumeFuncWrites = other.assumeFuncWrites; |
| } |
| |
| StringRef getArgument() const override { return "test-written-to"; } |
| |
| Option<bool> interprocedural{ |
| *this, "interprocedural", llvm::cl::init(true), |
| llvm::cl::desc("perform interprocedural analysis")}; |
| Option<bool> assumeFuncWrites{ |
| *this, "assume-func-writes", llvm::cl::init(false), |
| llvm::cl::desc( |
| "assume external functions have write effect on all arguments")}; |
| |
| void runOnOperation() override { |
| Operation *op = getOperation(); |
| |
| SymbolTableCollection symbolTable; |
| |
| DataFlowSolver solver(DataFlowConfig().setInterprocedural(interprocedural)); |
| solver.load<DeadCodeAnalysis>(); |
| solver.load<SparseConstantPropagation>(); |
| solver.load<WrittenToAnalysis>(symbolTable, assumeFuncWrites); |
| if (failed(solver.initializeAndRun(op))) |
| return signalPassFailure(); |
| |
| raw_ostream &os = llvm::outs(); |
| op->walk([&](Operation *op) { |
| auto tag = op->getAttrOfType<StringAttr>("tag"); |
| if (!tag) |
| return; |
| os << "test_tag: " << tag.getValue() << ":\n"; |
| for (auto [index, operand] : llvm::enumerate(op->getOperands())) { |
| const WrittenTo *writtenTo = solver.lookupState<WrittenTo>(operand); |
| assert(writtenTo && "expected a sparse lattice"); |
| os << " operand #" << index << ": "; |
| writtenTo->print(os); |
| os << "\n"; |
| } |
| for (auto [index, operand] : llvm::enumerate(op->getResults())) { |
| const WrittenTo *writtenTo = solver.lookupState<WrittenTo>(operand); |
| assert(writtenTo && "expected a sparse lattice"); |
| os << " result #" << index << ": "; |
| writtenTo->print(os); |
| os << "\n"; |
| } |
| }); |
| } |
| }; |
| } // end anonymous namespace |
| |
| namespace mlir { |
| namespace test { |
| void registerTestWrittenToPass() { PassRegistration<TestWrittenToPass>(); } |
| } // end namespace test |
| } // end namespace mlir |