| //===- BufferDeallocationOpInterface.cpp ----------------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h" |
| #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
| #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| #include "mlir/IR/AsmState.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/IR/Operation.h" |
| #include "mlir/IR/TypeUtilities.h" |
| #include "mlir/IR/Value.h" |
| #include "llvm/ADT/SetOperations.h" |
| |
| //===----------------------------------------------------------------------===// |
| // BufferDeallocationOpInterface |
| //===----------------------------------------------------------------------===// |
| |
| namespace mlir { |
| namespace bufferization { |
| |
| #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp.inc" |
| |
| } // namespace bufferization |
| } // namespace mlir |
| |
| using namespace mlir; |
| using namespace bufferization; |
| |
| //===----------------------------------------------------------------------===// |
| // Helpers |
| //===----------------------------------------------------------------------===// |
| |
| static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) { |
| return builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(value)); |
| } |
| |
| static bool isMemref(Value v) { return isa<BaseMemRefType>(v.getType()); } |
| |
| //===----------------------------------------------------------------------===// |
| // Ownership |
| //===----------------------------------------------------------------------===// |
| |
| Ownership::Ownership(Value indicator) |
| : indicator(indicator), state(State::Unique) {} |
| |
| Ownership Ownership::getUnknown() { |
| Ownership unknown; |
| unknown.indicator = Value(); |
| unknown.state = State::Unknown; |
| return unknown; |
| } |
| Ownership Ownership::getUnique(Value indicator) { return Ownership(indicator); } |
| Ownership Ownership::getUninitialized() { return Ownership(); } |
| |
| bool Ownership::isUninitialized() const { |
| return state == State::Uninitialized; |
| } |
| bool Ownership::isUnique() const { return state == State::Unique; } |
| bool Ownership::isUnknown() const { return state == State::Unknown; } |
| |
| Value Ownership::getIndicator() const { |
| assert(isUnique() && "must have unique ownership to get the indicator"); |
| return indicator; |
| } |
| |
| Ownership Ownership::getCombined(Ownership other) const { |
| if (other.isUninitialized()) |
| return *this; |
| if (isUninitialized()) |
| return other; |
| |
| if (!isUnique() || !other.isUnique()) |
| return getUnknown(); |
| |
| // Since we create a new constant i1 value for (almost) each use-site, we |
| // should compare the actual value rather than just the SSA Value to avoid |
| // unnecessary invalidations. |
| if (isEqualConstantIntOrValue(indicator, other.indicator)) |
| return *this; |
| |
| // Return the join of the lattice if the indicator of both ownerships cannot |
| // be merged. |
| return getUnknown(); |
| } |
| |
| void Ownership::combine(Ownership other) { *this = getCombined(other); } |
| |
| //===----------------------------------------------------------------------===// |
| // DeallocationState |
| //===----------------------------------------------------------------------===// |
| |
| DeallocationState::DeallocationState(Operation *op) : liveness(op) {} |
| |
| void DeallocationState::updateOwnership(Value memref, Ownership ownership, |
| Block *block) { |
| // In most cases we care about the block where the value is defined. |
| if (block == nullptr) |
| block = memref.getParentBlock(); |
| |
| // Update ownership of current memref itself. |
| ownershipMap[{memref, block}].combine(ownership); |
| } |
| |
| void DeallocationState::resetOwnerships(ValueRange memrefs, Block *block) { |
| for (Value val : memrefs) |
| ownershipMap[{val, block}] = Ownership::getUninitialized(); |
| } |
| |
| Ownership DeallocationState::getOwnership(Value memref, Block *block) const { |
| return ownershipMap.lookup({memref, block}); |
| } |
| |
| void DeallocationState::addMemrefToDeallocate(Value memref, Block *block) { |
| memrefsToDeallocatePerBlock[block].push_back(memref); |
| } |
| |
| void DeallocationState::dropMemrefToDeallocate(Value memref, Block *block) { |
| llvm::erase(memrefsToDeallocatePerBlock[block], memref); |
| } |
| |
| void DeallocationState::getLiveMemrefsIn(Block *block, |
| SmallVectorImpl<Value> &memrefs) { |
| SmallVector<Value> liveMemrefs( |
| llvm::make_filter_range(liveness.getLiveIn(block), isMemref)); |
| llvm::sort(liveMemrefs, ValueComparator()); |
| memrefs.append(liveMemrefs); |
| } |
| |
| std::pair<Value, Value> |
| DeallocationState::getMemrefWithUniqueOwnership(OpBuilder &builder, |
| Value memref, Block *block) { |
| auto iter = ownershipMap.find({memref, block}); |
| assert(iter != ownershipMap.end() && |
| "Value must already have been registered in the ownership map"); |
| |
| Ownership ownership = iter->second; |
| if (ownership.isUnique()) |
| return {memref, ownership.getIndicator()}; |
| |
| // Instead of inserting a clone operation we could also insert a dealloc |
| // operation earlier in the block and use the updated ownerships returned by |
| // the op for the retained values. Alternatively, we could insert code to |
| // check aliasing at runtime and use this information to combine two unique |
| // ownerships more intelligently to not end up with an 'Unknown' ownership in |
| // the first place. |
| auto cloneOp = |
| builder.create<bufferization::CloneOp>(memref.getLoc(), memref); |
| Value condition = buildBoolValue(builder, memref.getLoc(), true); |
| Value newMemref = cloneOp.getResult(); |
| updateOwnership(newMemref, condition); |
| memrefsToDeallocatePerBlock[newMemref.getParentBlock()].push_back(newMemref); |
| return {newMemref, condition}; |
| } |
| |
| void DeallocationState::getMemrefsToRetain( |
| Block *fromBlock, Block *toBlock, ValueRange destOperands, |
| SmallVectorImpl<Value> &toRetain) const { |
| for (Value operand : destOperands) { |
| if (!isMemref(operand)) |
| continue; |
| toRetain.push_back(operand); |
| } |
| |
| SmallPtrSet<Value, 16> liveOut; |
| for (auto val : liveness.getLiveOut(fromBlock)) |
| if (isMemref(val)) |
| liveOut.insert(val); |
| |
| if (toBlock) |
| llvm::set_intersect(liveOut, liveness.getLiveIn(toBlock)); |
| |
| // liveOut has non-deterministic order because it was constructed by iterating |
| // over a hash-set. |
| SmallVector<Value> retainedByLiveness(liveOut.begin(), liveOut.end()); |
| std::sort(retainedByLiveness.begin(), retainedByLiveness.end(), |
| ValueComparator()); |
| toRetain.append(retainedByLiveness); |
| } |
| |
| LogicalResult DeallocationState::getMemrefsAndConditionsToDeallocate( |
| OpBuilder &builder, Location loc, Block *block, |
| SmallVectorImpl<Value> &memrefs, SmallVectorImpl<Value> &conditions) const { |
| |
| for (auto [i, memref] : |
| llvm::enumerate(memrefsToDeallocatePerBlock.lookup(block))) { |
| Ownership ownership = ownershipMap.lookup({memref, block}); |
| if (!ownership.isUnique()) |
| return emitError(memref.getLoc(), |
| "MemRef value does not have valid ownership"); |
| |
| // Simply cast unranked MemRefs to ranked memrefs with 0 dimensions such |
| // that we can call extract_strided_metadata on it. |
| if (auto unrankedMemRefTy = dyn_cast<UnrankedMemRefType>(memref.getType())) |
| memref = builder.create<memref::ReinterpretCastOp>( |
| loc, MemRefType::get({}, unrankedMemRefTy.getElementType()), memref, |
| 0, SmallVector<int64_t>{}, SmallVector<int64_t>{}); |
| |
| // Use the `memref.extract_strided_metadata` operation to get the base |
| // memref. This is needed because the same MemRef that was produced by the |
| // alloc operation has to be passed to the dealloc operation. Passing |
| // subviews, etc. to a dealloc operation is not allowed. |
| memrefs.push_back( |
| builder.create<memref::ExtractStridedMetadataOp>(loc, memref) |
| .getResult(0)); |
| conditions.push_back(ownership.getIndicator()); |
| } |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ValueComparator |
| //===----------------------------------------------------------------------===// |
| |
| bool ValueComparator::operator()(const Value &lhs, const Value &rhs) const { |
| if (lhs == rhs) |
| return false; |
| |
| // Block arguments are less than results. |
| bool lhsIsBBArg = isa<BlockArgument>(lhs); |
| if (lhsIsBBArg != isa<BlockArgument>(rhs)) { |
| return lhsIsBBArg; |
| } |
| |
| Region *lhsRegion; |
| Region *rhsRegion; |
| if (lhsIsBBArg) { |
| auto lhsBBArg = llvm::cast<BlockArgument>(lhs); |
| auto rhsBBArg = llvm::cast<BlockArgument>(rhs); |
| if (lhsBBArg.getArgNumber() != rhsBBArg.getArgNumber()) { |
| return lhsBBArg.getArgNumber() < rhsBBArg.getArgNumber(); |
| } |
| lhsRegion = lhsBBArg.getParentRegion(); |
| rhsRegion = rhsBBArg.getParentRegion(); |
| assert(lhsRegion != rhsRegion && |
| "lhsRegion == rhsRegion implies lhs == rhs"); |
| } else if (lhs.getDefiningOp() == rhs.getDefiningOp()) { |
| return llvm::cast<OpResult>(lhs).getResultNumber() < |
| llvm::cast<OpResult>(rhs).getResultNumber(); |
| } else { |
| lhsRegion = lhs.getDefiningOp()->getParentRegion(); |
| rhsRegion = rhs.getDefiningOp()->getParentRegion(); |
| if (lhsRegion == rhsRegion) { |
| return lhs.getDefiningOp()->isBeforeInBlock(rhs.getDefiningOp()); |
| } |
| } |
| |
| // lhsRegion != rhsRegion, so if we look at their ancestor chain, they |
| // - have different heights |
| // - or there's a spot where their region numbers differ |
| // - or their parent regions are the same and their parent ops are |
| // different. |
| while (lhsRegion && rhsRegion) { |
| if (lhsRegion->getRegionNumber() != rhsRegion->getRegionNumber()) { |
| return lhsRegion->getRegionNumber() < rhsRegion->getRegionNumber(); |
| } |
| if (lhsRegion->getParentRegion() == rhsRegion->getParentRegion()) { |
| return lhsRegion->getParentOp()->isBeforeInBlock( |
| rhsRegion->getParentOp()); |
| } |
| lhsRegion = lhsRegion->getParentRegion(); |
| rhsRegion = rhsRegion->getParentRegion(); |
| } |
| if (rhsRegion) |
| return true; |
| assert(lhsRegion && "this should only happen if lhs == rhs"); |
| return false; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Implementation utilities |
| //===----------------------------------------------------------------------===// |
| |
| FailureOr<Operation *> deallocation_impl::insertDeallocOpForReturnLike( |
| DeallocationState &state, Operation *op, ValueRange operands, |
| SmallVectorImpl<Value> &updatedOperandOwnerships) { |
| assert(op->hasTrait<OpTrait::IsTerminator>() && "must be a terminator"); |
| assert(!op->hasSuccessors() && "must not have any successors"); |
| // Collect the values to deallocate and retain and use them to create the |
| // dealloc operation. |
| OpBuilder builder(op); |
| Block *block = op->getBlock(); |
| SmallVector<Value> memrefs, conditions, toRetain; |
| if (failed(state.getMemrefsAndConditionsToDeallocate( |
| builder, op->getLoc(), block, memrefs, conditions))) |
| return failure(); |
| |
| state.getMemrefsToRetain(block, /*toBlock=*/nullptr, operands, toRetain); |
| if (memrefs.empty() && toRetain.empty()) |
| return op; |
| |
| auto deallocOp = builder.create<bufferization::DeallocOp>( |
| op->getLoc(), memrefs, conditions, toRetain); |
| |
| // We want to replace the current ownership of the retained values with the |
| // result values of the dealloc operation as they are always unique. |
| state.resetOwnerships(deallocOp.getRetained(), block); |
| for (auto [retained, ownership] : |
| llvm::zip(deallocOp.getRetained(), deallocOp.getUpdatedConditions())) |
| state.updateOwnership(retained, ownership, block); |
| |
| unsigned numMemrefOperands = llvm::count_if(operands, isMemref); |
| auto newOperandOwnerships = |
| deallocOp.getUpdatedConditions().take_front(numMemrefOperands); |
| updatedOperandOwnerships.append(newOperandOwnerships.begin(), |
| newOperandOwnerships.end()); |
| |
| return op; |
| } |