| //===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h" |
| |
| #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" |
| #include "mlir/Dialect/Bufferization/IR/Bufferization.h" |
| #include "mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h" |
| #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" |
| #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" |
| #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| #include "mlir/Dialect/SCF/IR/SCF.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/Dialect/Utils/StaticValueUtils.h" |
| #include "mlir/IR/Dialect.h" |
| #include "mlir/IR/Operation.h" |
| #include "mlir/IR/PatternMatch.h" |
| |
| using namespace mlir; |
| using namespace mlir::bufferization; |
| using namespace mlir::scf; |
| |
| namespace mlir { |
| namespace scf { |
| namespace { |
| |
| /// Helper function for loop bufferization. Cast the given buffer to the given |
| /// memref type. |
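| /// For illustration, a minimal sketch (types chosen arbitrarily): casting a |
| /// buffer of type `memref<5xf32>` to `memref<5xf32, strided<[?], offset: ?>>` |
| /// inserts a `memref.cast` op; casting to an incompatible type (e.g., a |
| /// different element type) would trip the assertion below. |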
| static Value castBuffer(OpBuilder &b, Value buffer, Type type) { |
| assert(isa<BaseMemRefType>(type) && "expected BaseMemRefType"); |
| assert(isa<BaseMemRefType>(buffer.getType()) && "expected BaseMemRefType"); |
| // If the buffer already has the correct type, no cast is needed. |
| if (buffer.getType() == type) |
| return buffer; |
| // TODO: In case `type` has a layout map that is not the fully dynamic |
| // one, we may not be able to cast the buffer. In that case, the loop |
| // iter_arg's layout map must be changed (see uses of `castBuffer`). |
| assert(memref::CastOp::areCastCompatible(buffer.getType(), type) && |
| "scf.while op bufferization: cast incompatible"); |
| return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult(); |
| } |
| |
| /// Helper function for loop bufferization. Return "true" if the given value |
| /// is guaranteed to not alias with an external tensor apart from values in |
| /// `exceptions`. A value is external if it is defined outside of the given |
| /// region or if it is an entry block argument of the region. |
| static bool doesNotAliasExternalValue(Value value, Region *region, |
| ValueRange exceptions, |
| const OneShotAnalysisState &state) { |
| assert(region->getBlocks().size() == 1 && |
| "expected region with single block"); |
| bool result = true; |
| state.applyOnAliases(value, [&](Value alias) { |
| if (llvm::is_contained(exceptions, alias)) |
| return; |
| Region *aliasRegion = alias.getParentRegion(); |
| if (isa<BlockArgument>(alias) && !region->isProperAncestor(aliasRegion)) |
| result = false; |
| if (isa<OpResult>(alias) && !region->isAncestor(aliasRegion)) |
| result = false; |
| }); |
| return result; |
| } |
| |
| /// Bufferization of scf.condition. |
| struct ConditionOpInterface |
| : public BufferizableOpInterface::ExternalModel<ConditionOpInterface, |
| scf::ConditionOp> { |
| bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, |
| const AnalysisState &state) const { |
| return true; |
| } |
| |
| bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, |
| const AnalysisState &state) const { |
| return false; |
| } |
| |
| AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, |
| const AnalysisState &state) const { |
| return {}; |
| } |
| |
| bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand, |
| const AnalysisState &state) const { |
| // Condition operands always bufferize inplace. Otherwise, an alloc + copy |
| // may be generated inside the block. We should avoid yielding allocations |
| // whenever possible. |
| return true; |
| } |
| |
| LogicalResult bufferize(Operation *op, RewriterBase &rewriter, |
| const BufferizationOptions &options) const { |
| auto conditionOp = cast<scf::ConditionOp>(op); |
| auto whileOp = cast<scf::WhileOp>(conditionOp->getParentOp()); |
| |
| SmallVector<Value> newArgs; |
| for (const auto &it : llvm::enumerate(conditionOp.getArgs())) { |
| Value value = it.value(); |
| if (isa<TensorType>(value.getType())) { |
| FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options); |
| if (failed(maybeBuffer)) |
| return failure(); |
| FailureOr<BaseMemRefType> resultType = bufferization::getBufferType( |
| whileOp.getAfterArguments()[it.index()], options); |
| if (failed(resultType)) |
| return failure(); |
| Value buffer = castBuffer(rewriter, *maybeBuffer, *resultType); |
| newArgs.push_back(buffer); |
| } else { |
| newArgs.push_back(value); |
| } |
| } |
| |
| replaceOpWithNewBufferizedOp<scf::ConditionOp>( |
| rewriter, op, conditionOp.getCondition(), newArgs); |
| return success(); |
| } |
| }; |
| |
| /// Return the unique scf.yield op. If there are multiple or no scf.yield ops, |
| /// return an empty op. |
| static scf::YieldOp getUniqueYieldOp(scf::ExecuteRegionOp executeRegionOp) { |
| scf::YieldOp result; |
| for (Block &block : executeRegionOp.getRegion()) { |
| if (auto yieldOp = dyn_cast<scf::YieldOp>(block.getTerminator())) { |
| if (result) |
| return {}; |
| result = yieldOp; |
| } |
| } |
| return result; |
| } |
| |
| /// Bufferization of scf.execute_region. Can be analyzed, but only ops whose |
| /// region terminates with a unique scf.yield are supported at the moment. |
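| /// |
| /// For illustration, a rough sketch (names and types hypothetical): |
| /// |
| /// %r = scf.execute_region -> tensor<5xf32> { |
| /// ... |
| /// scf.yield %t : tensor<5xf32> |
| /// } |
| /// |
| /// bufferizes to an scf.execute_region that yields a memref, followed by a |
| /// bufferization.to_tensor op that reconciles any remaining tensor uses of %r. |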
| struct ExecuteRegionOpInterface |
| : public OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel< |
| ExecuteRegionOpInterface, scf::ExecuteRegionOp> { |
| |
| static bool supportsUnstructuredControlFlow() { return true; } |
| |
| bool isWritable(Operation *op, Value value, |
| const AnalysisState &state) const { |
| return true; |
| } |
| |
| LogicalResult verifyAnalysis(Operation *op, |
| const AnalysisState &state) const { |
| auto executeRegionOp = cast<scf::ExecuteRegionOp>(op); |
| // TODO: scf.execute_region ops with multiple yields are not supported. |
| if (!getUniqueYieldOp(executeRegionOp)) |
| return op->emitOpError("without a unique scf.yield op is not supported"); |
| return success(); |
| } |
| |
| AliasingOpOperandList |
| getAliasingOpOperands(Operation *op, Value value, |
| const AnalysisState &state) const { |
| if (auto bbArg = dyn_cast<BlockArgument>(value)) |
| return getAliasingBranchOpOperands(op, bbArg, state); |
| |
| // ExecuteRegionOps do not have tensor OpOperands. The yielded value can be |
| // any SSA value that is in scope. To allow for use-def chain traversal |
| // through ExecuteRegionOps in the analysis, the corresponding yield value |
| // is considered to be aliasing with the result. |
| auto executeRegionOp = cast<scf::ExecuteRegionOp>(op); |
| auto it = llvm::find(op->getOpResults(), value); |
| assert(it != op->getOpResults().end() && "invalid value"); |
| size_t resultNum = std::distance(op->getOpResults().begin(), it); |
| auto yieldOp = getUniqueYieldOp(executeRegionOp); |
| // Note: If there is no unique scf.yield op, `verifyAnalysis` will fail. |
| if (!yieldOp) |
| return {}; |
| return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}}; |
| } |
| |
| LogicalResult bufferize(Operation *op, RewriterBase &rewriter, |
| const BufferizationOptions &options) const { |
| auto executeRegionOp = cast<scf::ExecuteRegionOp>(op); |
| auto yieldOp = getUniqueYieldOp(executeRegionOp); |
| TypeRange newResultTypes(yieldOp.getResults()); |
| |
| // Create new op and move over region. |
| auto newOp = |
| rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes); |
| newOp.getRegion().takeBody(executeRegionOp.getRegion()); |
| |
| // Bufferize every block. |
| for (Block &block : newOp.getRegion()) |
| if (failed(bufferization::bufferizeBlockSignature(&block, rewriter, |
| options))) |
| return failure(); |
| |
| // Update all uses of the old op. |
| rewriter.setInsertionPointAfter(newOp); |
| SmallVector<Value> newResults; |
| for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) { |
| if (isa<TensorType>(it.value())) { |
| newResults.push_back(rewriter.create<bufferization::ToTensorOp>( |
| executeRegionOp.getLoc(), newOp->getResult(it.index()))); |
| } else { |
| newResults.push_back(newOp->getResult(it.index())); |
| } |
| } |
| |
| // Replace old op. |
| rewriter.replaceOp(executeRegionOp, newResults); |
| |
| return success(); |
| } |
| }; |
| |
| /// Bufferization of scf.if. Replace with a new scf.if that yields memrefs. |
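| /// |
| /// For illustration, a rough sketch (names and types hypothetical): |
| /// |
| /// %r = scf.if %cond -> (tensor<5xf32>) { |
| /// scf.yield %t0 : tensor<5xf32> |
| /// } else { |
| /// scf.yield %t1 : tensor<5xf32> |
| /// } |
| /// |
| /// is rewritten into an scf.if that yields the corresponding buffers, e.g.: |
| /// |
| /// %r = scf.if %cond -> (memref<5xf32, strided<[?], offset: ?>>) { |
| /// scf.yield %b0 : memref<5xf32, strided<[?], offset: ?>> |
| /// } else { |
| /// scf.yield %b1 : memref<5xf32, strided<[?], offset: ?>> |
| /// } |
| /// |
| /// where the result layout may be the fully dynamic one if the two branches |
| /// bufferize to different layouts (see `getBufferType` below). |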
| struct IfOpInterface |
| : public BufferizableOpInterface::ExternalModel<IfOpInterface, scf::IfOp> { |
| AliasingOpOperandList |
| getAliasingOpOperands(Operation *op, Value value, |
| const AnalysisState &state) const { |
| // IfOps do not have tensor OpOperands. The yielded value can be any SSA |
| // value that is in scope. To allow for use-def chain traversal through |
| // IfOps in the analysis, both corresponding yield values from the then/else |
| // branches are considered to be aliasing with the result. |
| auto ifOp = cast<scf::IfOp>(op); |
| size_t resultNum = std::distance(op->getOpResults().begin(), |
| llvm::find(op->getOpResults(), value)); |
| OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(resultNum); |
| OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(resultNum); |
| return {{thenOperand, BufferRelation::Equivalent, /*isDefinite=*/false}, |
| {elseOperand, BufferRelation::Equivalent, /*isDefinite=*/false}}; |
| } |
| |
| LogicalResult bufferize(Operation *op, RewriterBase &rewriter, |
| const BufferizationOptions &options) const { |
| OpBuilder::InsertionGuard g(rewriter); |
| auto ifOp = cast<scf::IfOp>(op); |
| |
| // Compute bufferized result types. |
| SmallVector<Type> newTypes; |
| for (Value result : ifOp.getResults()) { |
| if (!isa<TensorType>(result.getType())) { |
| newTypes.push_back(result.getType()); |
| continue; |
| } |
| auto bufferType = bufferization::getBufferType(result, options); |
| if (failed(bufferType)) |
| return failure(); |
| newTypes.push_back(*bufferType); |
| } |
| |
| // Create new op. |
| rewriter.setInsertionPoint(ifOp); |
| auto newIfOp = |
| rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(), |
| /*withElseRegion=*/true); |
| |
| // Move over then/else blocks. |
| rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock()); |
| rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock()); |
| |
| // Replace op results. |
| replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults()); |
| |
| return success(); |
| } |
| |
| FailureOr<BaseMemRefType> |
| getBufferType(Operation *op, Value value, const BufferizationOptions &options, |
| SmallVector<Value> &invocationStack) const { |
| auto ifOp = cast<scf::IfOp>(op); |
| auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator()); |
| auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator()); |
| assert(value.getDefiningOp() == op && "invalid value"); |
| |
| // Determine buffer types of the true/false branches. |
| auto opResult = cast<OpResult>(value); |
| auto thenValue = thenYieldOp.getOperand(opResult.getResultNumber()); |
| auto elseValue = elseYieldOp.getOperand(opResult.getResultNumber()); |
| BaseMemRefType thenBufferType, elseBufferType; |
| if (isa<BaseMemRefType>(thenValue.getType())) { |
| // True branch was already bufferized. |
| thenBufferType = cast<BaseMemRefType>(thenValue.getType()); |
| } else { |
| auto maybeBufferType = |
| bufferization::getBufferType(thenValue, options, invocationStack); |
| if (failed(maybeBufferType)) |
| return failure(); |
| thenBufferType = *maybeBufferType; |
| } |
| if (isa<BaseMemRefType>(elseValue.getType())) { |
| // False branch was already bufferized. |
| elseBufferType = cast<BaseMemRefType>(elseValue.getType()); |
| } else { |
| auto maybeBufferType = |
| bufferization::getBufferType(elseValue, options, invocationStack); |
| if (failed(maybeBufferType)) |
| return failure(); |
| elseBufferType = *maybeBufferType; |
| } |
| |
| // Best case: Both branches have the exact same buffer type. |
| if (thenBufferType == elseBufferType) |
| return thenBufferType; |
| |
| // Memory space mismatch. |
| if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace()) |
| return op->emitError("inconsistent memory space on then/else branches"); |
| |
| // Layout maps are different: Promote to fully dynamic layout map. |
| return getMemRefTypeWithFullyDynamicLayout( |
| cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace()); |
| } |
| }; |
| |
| /// Bufferization of scf.index_switch. Replace with a new scf.index_switch that |
| /// yields memrefs. |
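| /// |
| /// For illustration (sketch): an scf.index_switch whose cases each yield a |
| /// `tensor<5xf32>` is rewritten so that every case yields the corresponding |
| /// buffer; if the cases bufferize to buffers with different layout maps, the |
| /// result type is promoted to a fully dynamic layout (see `getBufferType`). |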
| struct IndexSwitchOpInterface |
| : public BufferizableOpInterface::ExternalModel<IndexSwitchOpInterface, |
| scf::IndexSwitchOp> { |
| AliasingOpOperandList |
| getAliasingOpOperands(Operation *op, Value value, |
| const AnalysisState &state) const { |
| // IndexSwitchOps do not have tensor OpOperands. The yielded value can be |
| // any SSA value that is in scope. This is similar to IfOps. |
| auto switchOp = cast<scf::IndexSwitchOp>(op); |
| int64_t resultNum = cast<OpResult>(value).getResultNumber(); |
| AliasingOpOperandList result; |
| for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) { |
| auto yieldOp = |
| cast<scf::YieldOp>(switchOp.getCaseBlock(i).getTerminator()); |
| result.addAlias(AliasingOpOperand(&yieldOp->getOpOperand(resultNum), |
| BufferRelation::Equivalent, |
| /*isDefinite=*/false)); |
| } |
| auto defaultYieldOp = |
| cast<scf::YieldOp>(switchOp.getDefaultBlock().getTerminator()); |
| result.addAlias(AliasingOpOperand(&defaultYieldOp->getOpOperand(resultNum), |
| BufferRelation::Equivalent, |
| /*isDefinite=*/false)); |
| return result; |
| } |
| |
| LogicalResult bufferize(Operation *op, RewriterBase &rewriter, |
| const BufferizationOptions &options) const { |
| OpBuilder::InsertionGuard g(rewriter); |
| auto switchOp = cast<scf::IndexSwitchOp>(op); |
| |
| // Compute bufferized result types. |
| SmallVector<Type> newTypes; |
| for (Value result : switchOp.getResults()) { |
| if (!isa<TensorType>(result.getType())) { |
| newTypes.push_back(result.getType()); |
| continue; |
| } |
| auto bufferType = bufferization::getBufferType(result, options); |
| if (failed(bufferType)) |
| return failure(); |
| newTypes.push_back(*bufferType); |
| } |
| |
| // Create new op. |
| rewriter.setInsertionPoint(switchOp); |
| auto newSwitchOp = rewriter.create<scf::IndexSwitchOp>( |
| switchOp.getLoc(), newTypes, switchOp.getArg(), switchOp.getCases(), |
| switchOp.getCases().size()); |
| |
| // Move over blocks. |
| for (auto [src, dest] : |
| llvm::zip(switchOp.getCaseRegions(), newSwitchOp.getCaseRegions())) |
| rewriter.inlineRegionBefore(src, dest, dest.begin()); |
| rewriter.inlineRegionBefore(switchOp.getDefaultRegion(), |
| newSwitchOp.getDefaultRegion(), |
| newSwitchOp.getDefaultRegion().begin()); |
| |
| // Replace op results. |
| replaceOpWithBufferizedValues(rewriter, op, newSwitchOp->getResults()); |
| |
| return success(); |
| } |
| |
| FailureOr<BaseMemRefType> |
| getBufferType(Operation *op, Value value, const BufferizationOptions &options, |
| SmallVector<Value> &invocationStack) const { |
| auto switchOp = cast<scf::IndexSwitchOp>(op); |
| assert(value.getDefiningOp() == op && "invalid value"); |
| int64_t resultNum = cast<OpResult>(value).getResultNumber(); |
| |
| // Helper function to get buffer type of a case. |
| SmallVector<BaseMemRefType> yieldedTypes; |
| auto getYieldedBufferType = [&](Block &b) -> FailureOr<BaseMemRefType> { |
| auto yieldOp = cast<scf::YieldOp>(b.getTerminator()); |
| Value yieldedValue = yieldOp->getOperand(resultNum); |
| if (auto bufferType = dyn_cast<BaseMemRefType>(yieldedValue.getType())) |
| return bufferType; |
| auto maybeBufferType = |
| bufferization::getBufferType(yieldedValue, options, invocationStack); |
| if (failed(maybeBufferType)) |
| return failure(); |
| return maybeBufferType; |
| }; |
| |
| // Compute buffer type of the default case. |
| auto maybeBufferType = getYieldedBufferType(switchOp.getDefaultBlock()); |
| if (failed(maybeBufferType)) |
| return failure(); |
| BaseMemRefType bufferType = *maybeBufferType; |
| |
| // Compute buffer types of all other cases. |
| for (int64_t i = 0, numCases = switchOp.getNumCases(); i < numCases; ++i) { |
| auto yieldedBufferType = getYieldedBufferType(switchOp.getCaseBlock(i)); |
| if (failed(yieldedBufferType)) |
| return failure(); |
| |
| // Best case: This case has the exact same buffer type as all previous cases. |
| if (bufferType == *yieldedBufferType) |
| continue; |
| |
| // Memory space mismatch. |
| if (bufferType.getMemorySpace() != yieldedBufferType->getMemorySpace()) |
| return op->emitError("inconsistent memory space on switch cases"); |
| |
| // Layout maps are different: Promote to fully dynamic layout map. |
| bufferType = getMemRefTypeWithFullyDynamicLayout( |
| cast<TensorType>(value.getType()), bufferType.getMemorySpace()); |
| } |
| |
| return bufferType; |
| } |
| }; |
| |
| /// Helper function for loop bufferization. Return the indices of all values |
| /// that have a tensor type. |
| static DenseSet<int64_t> getTensorIndices(ValueRange values) { |
| DenseSet<int64_t> result; |
| for (const auto &it : llvm::enumerate(values)) |
| if (isa<TensorType>(it.value().getType())) |
| result.insert(it.index()); |
| return result; |
| } |
| |
| /// Helper function for loop bufferization. Return the indices of all |
| /// bbArg/yielded value pairs whose buffer relation is "Equivalent". |
| DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs, |
| ValueRange yieldedValues, |
| const AnalysisState &state) { |
| unsigned int minSize = std::min(bbArgs.size(), yieldedValues.size()); |
| DenseSet<int64_t> result; |
| for (unsigned int i = 0; i < minSize; ++i) { |
| if (!isa<TensorType>(bbArgs[i].getType()) || |
| !isa<TensorType>(yieldedValues[i].getType())) |
| continue; |
| if (state.areEquivalentBufferizedValues(bbArgs[i], yieldedValues[i])) |
| result.insert(i); |
| } |
| return result; |
| } |
| |
| /// Helper function for loop bufferization. Return the bufferized values of the |
| /// given OpOperands. If an operand is not a tensor, return the original value. |
| static FailureOr<SmallVector<Value>> |
| getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands, |
| const BufferizationOptions &options) { |
| SmallVector<Value> result; |
| for (OpOperand &opOperand : operands) { |
| if (isa<TensorType>(opOperand.get().getType())) { |
| FailureOr<Value> resultBuffer = |
| getBuffer(rewriter, opOperand.get(), options); |
| if (failed(resultBuffer)) |
| return failure(); |
| result.push_back(*resultBuffer); |
| } else { |
| result.push_back(opOperand.get()); |
| } |
| } |
| return result; |
| } |
| |
| /// Helper function for loop bufferization. Given a list of bbArgs of the new |
| /// (bufferized) loop op, wrap the bufferized tensor args (now memrefs) into |
| /// ToTensorOps, so that the block body can be moved over to the new op. |
| static SmallVector<Value> |
| getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs, |
| const DenseSet<int64_t> &tensorIndices) { |
| SmallVector<Value> result; |
| for (const auto &it : llvm::enumerate(bbArgs)) { |
| size_t idx = it.index(); |
| Value val = it.value(); |
| if (tensorIndices.contains(idx)) { |
| result.push_back( |
| rewriter.create<bufferization::ToTensorOp>(val.getLoc(), val) |
| .getResult()); |
| } else { |
| result.push_back(val); |
| } |
| } |
| return result; |
| } |
| |
| /// Compute the bufferized type of a loop iter_arg. This type must be equal to |
| /// the bufferized type of the corresponding init_arg and the bufferized type |
| /// of the corresponding yielded value. |
| /// |
| /// This function uses bufferization::getBufferType to compute the bufferized |
| /// type of the init_arg and of the yielded value. (The computation of the |
| /// bufferized yielded value type usually requires computing the bufferized type |
| /// of the iter_arg again; the implementation of getBufferType traces back the |
| /// use-def chain of the given value and computes a buffer type along the way.) |
| /// If both buffer types are equal, no casts are needed and the computed buffer |
| /// type can be used directly. Otherwise, the buffer types can only differ in |
| /// their layout map and a cast must be inserted. |
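| /// |
| /// For example (a sketch with hypothetical types): if the init_arg bufferizes |
| /// to `memref<5xf32>` but the yielded value bufferizes to |
| /// `memref<5xf32, strided<[?], offset: ?>>`, the iter_arg is assigned the |
| /// fully dynamic layout type `memref<5xf32, strided<[?], offset: ?>>` and the |
| /// init_arg is casted to that type at the loop boundary (see `castBuffer`). |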
| static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType( |
| Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue, |
| const BufferizationOptions &options, SmallVector<Value> &invocationStack) { |
| // Determine the buffer type of the init_arg. |
| auto initArgBufferType = |
| bufferization::getBufferType(initArg, options, invocationStack); |
| if (failed(initArgBufferType)) |
| return failure(); |
| |
| if (llvm::count(invocationStack, iterArg) >= 2) { |
| // If the iter_arg is already twice on the invocation stack, just take the |
| // type of the init_arg. This is to avoid infinite loops when calculating |
| // the buffer type. This will most likely result in computing a memref type |
| // with a fully dynamic layout map. |
| |
| // Note: For more precise layout map computation, a fixpoint iteration could |
| // be done (i.e., re-computing the yielded buffer type until the bufferized |
| // iter_arg type no longer changes). The current implementation immediately |
| // switches to a fully dynamic layout map when a mismatch between bufferized |
| // init_arg type and bufferized yield value type is detected. |
| return *initArgBufferType; |
| } |
| |
| // Compute the buffer type of the yielded value. |
| BaseMemRefType yieldedValueBufferType; |
| if (isa<BaseMemRefType>(yieldedValue.getType())) { |
| // scf.yield was already bufferized. |
| yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType()); |
| } else { |
| // Note: This typically triggers a recursive call for the buffer type of |
| // the iter_arg. |
| auto maybeBufferType = |
| bufferization::getBufferType(yieldedValue, options, invocationStack); |
| if (failed(maybeBufferType)) |
| return failure(); |
| yieldedValueBufferType = *maybeBufferType; |
| } |
| |
| // If yielded type and init_arg type are the same, use that type directly. |
| if (*initArgBufferType == yieldedValueBufferType) |
| return yieldedValueBufferType; |
| |
| // If there is a mismatch between the yielded buffer type and the init_arg |
| // buffer type, the buffer type must be promoted to a fully dynamic layout |
| // map. |
| auto yieldedBufferType = cast<BaseMemRefType>(yieldedValueBufferType); |
| auto iterTensorType = cast<TensorType>(iterArg.getType()); |
| auto initBufferType = llvm::cast<BaseMemRefType>(*initArgBufferType); |
| if (initBufferType.getMemorySpace() != yieldedBufferType.getMemorySpace()) |
| return loopOp->emitOpError( |
| "init_arg and yielded value bufferize to inconsistent memory spaces"); |
| #ifndef NDEBUG |
| if (auto yieldedRankedBufferType = dyn_cast<MemRefType>(yieldedBufferType)) { |
| assert( |
| llvm::all_equal({yieldedRankedBufferType.getShape(), |
| cast<MemRefType>(initBufferType).getShape(), |
| cast<RankedTensorType>(iterTensorType).getShape()}) && |
| "expected same shape"); |
| } |
| #endif // NDEBUG |
| return getMemRefTypeWithFullyDynamicLayout( |
| iterTensorType, yieldedBufferType.getMemorySpace()); |
| } |
| |
| /// Return `true` if the given loop may have 0 iterations. |
| bool mayHaveZeroIterations(scf::ForOp forOp) { |
| std::optional<int64_t> lb = getConstantIntValue(forOp.getLowerBound()); |
| std::optional<int64_t> ub = getConstantIntValue(forOp.getUpperBound()); |
| if (!lb.has_value() || !ub.has_value()) |
| return true; |
| return *ub <= *lb; |
| } |
| |
| /// Bufferization of scf.for. Replace with a new scf.for that operates on |
| /// memrefs. |
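| /// |
| /// For illustration, a rough sketch (names and types hypothetical): |
| /// |
| /// %0 = scf.for %iv = %lb to %ub step %c1 iter_args(%t = %init) |
| /// -> (tensor<?xf32>) { |
| /// ... |
| /// scf.yield %t2 : tensor<?xf32> |
| /// } |
| /// |
| /// becomes a loop over buffers whose iter_arg is wrapped in a |
| /// bufferization.to_tensor op at the top of the body, so that not-yet- |
| /// bufferized ops nested in the loop can still consume a tensor value: |
| /// |
| /// %0 = scf.for %iv = %lb to %ub step %c1 iter_args(%m = %init_buffer) |
| /// -> (memref<?xf32, strided<[?], offset: ?>>) { |
| /// %t = bufferization.to_tensor %m : ... |
| /// ... |
| /// scf.yield %m2 : memref<?xf32, strided<[?], offset: ?>> |
| /// } |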
| struct ForOpInterface |
| : public BufferizableOpInterface::ExternalModel<ForOpInterface, |
| scf::ForOp> { |
| bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, |
| const AnalysisState &state) const { |
| auto forOp = cast<scf::ForOp>(op); |
| |
| // If the loop has zero iterations, the results of the op are their |
| // corresponding init_args, meaning that the init_args bufferize to a read. |
| if (mayHaveZeroIterations(forOp)) |
| return true; |
| |
| // scf::ForOp alone doesn't bufferize to a memory read, but one of the uses |
| // of its matching bbArg may. |
| return state.isValueRead(forOp.getTiedLoopRegionIterArg(&opOperand)); |
| } |
| |
| bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, |
| const AnalysisState &state) const { |
| // Tensor iter_args of scf::ForOps are always considered as a write. |
| return true; |
| } |
| |
| AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, |
| const AnalysisState &state) const { |
| auto forOp = cast<scf::ForOp>(op); |
| OpResult opResult = forOp.getTiedLoopResult(&opOperand); |
| BufferRelation relation = bufferRelation(op, opResult, state); |
| return {{opResult, relation, |
| /*isDefinite=*/relation == BufferRelation::Equivalent}}; |
| } |
| |
| BufferRelation bufferRelation(Operation *op, OpResult opResult, |
| const AnalysisState &state) const { |
| // ForOp results are equivalent to their corresponding init_args if the |
| // corresponding iter_args and yield values are equivalent. |
| auto forOp = cast<scf::ForOp>(op); |
| BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult); |
| bool equivalentYield = state.areEquivalentBufferizedValues( |
| bbArg, forOp.getTiedLoopYieldedValue(bbArg)->get()); |
| return equivalentYield ? BufferRelation::Equivalent |
| : BufferRelation::Unknown; |
| } |
| |
| bool isWritable(Operation *op, Value value, |
| const AnalysisState &state) const { |
| // Interestingly, scf::ForOp's bbArg can **always** be viewed |
| // inplace from the perspective of ops nested under it: |
| // 1. Either the matching iter operand is not bufferized inplace and an |
| // alloc + optional copy makes the bbArg itself inplaceable. |
| // 2. Or the matching iter operand is bufferized inplace and bbArg just |
| // bufferizes to that too. |
| return true; |
| } |
| |
| LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter, |
| const AnalysisState &state) const { |
| auto bufferizableOp = cast<BufferizableOpInterface>(op); |
| if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state))) |
| return failure(); |
| |
| if (!state.getOptions().enforceAliasingInvariants) |
| return success(); |
| |
| // According to the `getAliasing...` implementations, a bufferized OpResult |
| // may alias only with the corresponding bufferized init_arg (or with a |
| // newly allocated buffer) and not with other buffers defined outside of the |
| // loop. I.e., the i-th OpResult may alias with the i-th init_arg; |
| // but not with any other OpOperand. |
| auto forOp = cast<scf::ForOp>(op); |
| auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator()); |
| OpBuilder::InsertionGuard g(rewriter); |
| rewriter.setInsertionPoint(yieldOp); |
| |
| // Indices of all iter_args that have tensor type. These are the ones that |
| // are bufferized. |
| DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs()); |
| // For every yielded value, does it alias with something defined outside of |
| // the loop? |
| SmallVector<Value> yieldValues; |
| for (const auto it : llvm::enumerate(yieldOp.getResults())) { |
| // Note: `state` is guaranteed to be a `OneShotAnalysisState`, but this |
| // type cannot be used in the signature of `resolveConflicts` because the |
| // op interface is in the "IR" build unit and the `OneShotAnalysisState` |
| // is defined in the "Transforms" build unit. |
| if (!indices.contains(it.index()) || |
| doesNotAliasExternalValue( |
| it.value(), &forOp.getRegion(), |
| /*exceptions=*/forOp.getRegionIterArg(it.index()), |
| static_cast<const OneShotAnalysisState &>(state))) { |
| yieldValues.push_back(it.value()); |
| continue; |
| } |
| FailureOr<Value> alloc = allocateTensorForShapedValue( |
| rewriter, yieldOp.getLoc(), it.value(), state.getOptions()); |
| if (failed(alloc)) |
| return failure(); |
| yieldValues.push_back(*alloc); |
| } |
| |
| rewriter.modifyOpInPlace( |
| yieldOp, [&]() { yieldOp.getResultsMutable().assign(yieldValues); }); |
| return success(); |
| } |
| |
| FailureOr<BaseMemRefType> |
| getBufferType(Operation *op, Value value, const BufferizationOptions &options, |
| SmallVector<Value> &invocationStack) const { |
| auto forOp = cast<scf::ForOp>(op); |
| assert(getOwnerOfValue(value) == op && "invalid value"); |
| assert(isa<TensorType>(value.getType()) && "expected tensor type"); |
| |
| if (auto opResult = dyn_cast<OpResult>(value)) { |
| // The type of an OpResult must match the corresponding iter_arg type. |
| BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult); |
| return bufferization::getBufferType(bbArg, options, invocationStack); |
| } |
| |
| // Compute result/argument number. |
| BlockArgument bbArg = cast<BlockArgument>(value); |
| unsigned resultNum = forOp.getTiedLoopResult(bbArg).getResultNumber(); |
| |
| // Compute the bufferized type. |
| auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator()); |
| Value yieldedValue = yieldOp.getOperand(resultNum); |
| BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum]; |
| Value initArg = forOp.getInitArgs()[resultNum]; |
| return computeLoopRegionIterArgBufferType( |
| op, iterArg, initArg, yieldedValue, options, invocationStack); |
| } |
| |
| LogicalResult bufferize(Operation *op, RewriterBase &rewriter, |
| const BufferizationOptions &options) const { |
| auto forOp = cast<scf::ForOp>(op); |
| Block *oldLoopBody = forOp.getBody(); |
| |
| // Indices of all iter_args that have tensor type. These are the ones that |
| // are bufferized. |
| DenseSet<int64_t> indices = getTensorIndices(forOp.getInitArgs()); |
| |
| // The new memref init_args of the loop. |
| FailureOr<SmallVector<Value>> maybeInitArgs = |
| getBuffers(rewriter, forOp.getInitArgsMutable(), options); |
| if (failed(maybeInitArgs)) |
| return failure(); |
| SmallVector<Value> initArgs = *maybeInitArgs; |
| |
| // Cast init_args if necessary. |
| SmallVector<Value> castedInitArgs; |
| for (const auto &it : llvm::enumerate(initArgs)) { |
| Value initArg = it.value(); |
| Value result = forOp->getResult(it.index()); |
| // If the type is not a tensor, bufferization doesn't need to touch it. |
| if (!isa<TensorType>(result.getType())) { |
| castedInitArgs.push_back(initArg); |
| continue; |
| } |
| auto targetType = bufferization::getBufferType(result, options); |
| if (failed(targetType)) |
| return failure(); |
| castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType)); |
| } |
| |
| // Construct a new scf.for op with memref instead of tensor values. |
| auto newForOp = rewriter.create<scf::ForOp>( |
| forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), |
| forOp.getStep(), castedInitArgs); |
| newForOp->setAttrs(forOp->getAttrs()); |
| Block *loopBody = newForOp.getBody(); |
| |
| // Set up new iter_args. The loop body uses tensors, so wrap the (memref) |
| // iter_args of the new loop in ToTensorOps. |
| rewriter.setInsertionPointToStart(loopBody); |
| SmallVector<Value> iterArgs = |
| getBbArgReplacements(rewriter, newForOp.getRegionIterArgs(), indices); |
| iterArgs.insert(iterArgs.begin(), newForOp.getInductionVar()); |
| |
| // Move loop body to new loop. |
| rewriter.mergeBlocks(oldLoopBody, loopBody, iterArgs); |
| |
| // Replace loop results. |
| replaceOpWithBufferizedValues(rewriter, op, newForOp->getResults()); |
| |
| return success(); |
| } |
| |
| /// Assert that yielded values of an scf.for op are equivalent to their |
| /// corresponding bbArgs. In that case, the buffer relations of the |
| /// corresponding OpResults are "Equivalent". |
| /// |
| /// If this is not the case, allocs+copies are inserted and yielded from |
| /// the loop. This could be a performance problem, so it must be explicitly |
| /// activated with `allow-return-allocs-from-loops`. |
| LogicalResult verifyAnalysis(Operation *op, |
| const AnalysisState &state) const { |
| const auto &options = |
| static_cast<const OneShotBufferizationOptions &>(state.getOptions()); |
| if (options.allowReturnAllocsFromLoops) |
| return success(); |
| |
| auto forOp = cast<scf::ForOp>(op); |
| auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator()); |
| for (OpResult opResult : op->getOpResults()) { |
| if (!isa<TensorType>(opResult.getType())) |
| continue; |
| |
| // Note: This is overly strict. We should check for aliasing bufferized |
| // values. But we don't have a "must-alias" analysis yet. |
| if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent) |
| return yieldOp->emitError() |
| << "Yield operand #" << opResult.getResultNumber() |
| << " is not equivalent to the corresponding iter bbArg"; |
| } |
| |
| return success(); |
| } |
| }; |
| |
| /// Bufferization of scf.while. Replace with a new scf.while that operates on |
| /// memrefs. |
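| /// |
| /// For illustration, a rough sketch (names and types hypothetical): |
| /// |
| /// %0 = scf.while (%arg0 = %init) : (tensor<5xf32>) -> tensor<5xf32> { |
| /// ... |
| /// scf.condition(%c) %arg0 : tensor<5xf32> |
| /// } do { |
| /// ^bb0(%arg1: tensor<5xf32>): |
| /// ... |
| /// scf.yield %next : tensor<5xf32> |
| /// } |
| /// |
| /// becomes an scf.while over memrefs: the init operands are casted to the |
| /// bufferized "before" bbArg types, the result types equal the bufferized |
| /// "after" bbArg types, and the memref bbArgs of both blocks are wrapped in |
| /// bufferization.to_tensor ops before the old blocks are merged in. |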
| struct WhileOpInterface |
| : public BufferizableOpInterface::ExternalModel<WhileOpInterface, |
| scf::WhileOp> { |
| bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, |
| const AnalysisState &state) const { |
| // Tensor iter_args of scf::WhileOps are always considered as a read. |
| return true; |
| } |
| |
| bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, |
| const AnalysisState &state) const { |
| // Tensor iter_args of scf::WhileOps are always considered as a write. |
| return true; |
| } |
| |
| AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, |
| const AnalysisState &state) const { |
| auto whileOp = cast<scf::WhileOp>(op); |
| unsigned int idx = opOperand.getOperandNumber(); |
| |
| // The OpResults and OpOperands of an scf.while op do not necessarily match: |
| // the operand and result at the same index may have different types, and the |
| // number of OpResults and OpOperands can differ. |
| if (idx >= op->getNumResults() || |
| opOperand.get().getType() != op->getResult(idx).getType()) |
| return {}; |
| |
| // The only aliasing OpResult may be the one at the same index. |
| OpResult opResult = whileOp->getResult(idx); |
| BufferRelation relation = bufferRelation(op, opResult, state); |
| return {{opResult, relation, |
| /*isDefinite=*/relation == BufferRelation::Equivalent}}; |
| } |
| |
| BufferRelation bufferRelation(Operation *op, OpResult opResult, |
| const AnalysisState &state) const { |
| // WhileOp results are equivalent to their corresponding init_args if the |
| // corresponding iter_args and yield values are equivalent (for both the |
| // "before" and the "after" block). |
| unsigned int resultNumber = opResult.getResultNumber(); |
| auto whileOp = cast<scf::WhileOp>(op); |
| |
| // The "before" region bbArgs and the OpResults may not match. |
| if (resultNumber >= whileOp.getBeforeArguments().size()) |
| return BufferRelation::Unknown; |
| if (opResult.getType() != |
| whileOp.getBeforeArguments()[resultNumber].getType()) |
| return BufferRelation::Unknown; |
| |
| auto conditionOp = whileOp.getConditionOp(); |
| BlockArgument conditionBbArg = whileOp.getBeforeArguments()[resultNumber]; |
| Value conditionOperand = conditionOp.getArgs()[resultNumber]; |
| bool equivCondition = |
| state.areEquivalentBufferizedValues(conditionBbArg, conditionOperand); |
| |
| auto yieldOp = whileOp.getYieldOp(); |
| BlockArgument bodyBbArg = whileOp.getAfterArguments()[resultNumber]; |
| Value yieldOperand = yieldOp.getOperand(resultNumber); |
| bool equivYield = |
| state.areEquivalentBufferizedValues(bodyBbArg, yieldOperand); |
| |
| return equivCondition && equivYield ? BufferRelation::Equivalent |
| : BufferRelation::Unknown; |
| } |
| |
| bool isWritable(Operation *op, Value value, |
| const AnalysisState &state) const { |
| // Interestingly, scf::WhileOp's bbArg can **always** be viewed |
| // inplace from the perspective of ops nested under it: |
| // 1. Either the matching iter operand is not bufferized inplace and an |
| // alloc + optional copy makes the bbArg itself inplaceable. |
| // 2. Or the matching iter operand is bufferized inplace and bbArg just |
| // bufferizes to that too. |
| return true; |
| } |
| |
| LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter, |
| const AnalysisState &state) const { |
| auto bufferizableOp = cast<BufferizableOpInterface>(op); |
| if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state))) |
| return failure(); |
| |
| if (!state.getOptions().enforceAliasingInvariants) |
| return success(); |
| |
| // According to the `getAliasing...` implementations, a bufferized OpResult |
| // may alias only with the corresponding bufferized init_arg and with no |
| // other buffers. I.e., the i-th OpResult may alias with the i-th init_arg; |
| // but not with any other OpOperand. If a corresponding OpResult/init_arg |
| // pair bufferizes to equivalent buffers, this aliasing requirement is |
| // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy. |
| // (New buffer copies do not alias with any buffer.) |
| OpBuilder::InsertionGuard g(rewriter); |
| auto whileOp = cast<scf::WhileOp>(op); |
| auto conditionOp = whileOp.getConditionOp(); |
| |
| // For every yielded value, is the value equivalent to its corresponding |
| // bbArg? |
| DenseSet<int64_t> equivalentYieldsBefore = getEquivalentBuffers( |
| whileOp.getBeforeArguments(), conditionOp.getArgs(), state); |
| DenseSet<int64_t> equivalentYieldsAfter = getEquivalentBuffers( |
| whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), state); |
| |
| // Update "before" region. |
| rewriter.setInsertionPoint(conditionOp); |
| SmallVector<Value> beforeYieldValues; |
| for (int64_t idx = 0; |
| idx < static_cast<int64_t>(conditionOp.getArgs().size()); ++idx) { |
| Value value = conditionOp.getArgs()[idx]; |
| if (!isa<TensorType>(value.getType()) || |
| (equivalentYieldsAfter.contains(idx) && |
| equivalentYieldsBefore.contains(idx))) { |
| beforeYieldValues.push_back(value); |
| continue; |
| } |
| FailureOr<Value> alloc = allocateTensorForShapedValue( |
| rewriter, conditionOp.getLoc(), value, state.getOptions()); |
| if (failed(alloc)) |
| return failure(); |
| beforeYieldValues.push_back(*alloc); |
| } |
| rewriter.modifyOpInPlace(conditionOp, [&]() { |
| conditionOp.getArgsMutable().assign(beforeYieldValues); |
| }); |
| |
| return success(); |
| } |
| |
| LogicalResult bufferize(Operation *op, RewriterBase &rewriter, |
| const BufferizationOptions &options) const { |
| auto whileOp = cast<scf::WhileOp>(op); |
| |
| // Indices of all bbArgs that have tensor type. These are the ones that |
| // are bufferized. The "before" and "after" regions may have different args. |
| DenseSet<int64_t> indicesBefore = getTensorIndices(whileOp.getInits()); |
| DenseSet<int64_t> indicesAfter = |
| getTensorIndices(whileOp.getAfterArguments()); |
| |
| // The new memref init_args of the loop. |
| FailureOr<SmallVector<Value>> maybeInitArgs = |
| getBuffers(rewriter, whileOp.getInitsMutable(), options); |
| if (failed(maybeInitArgs)) |
| return failure(); |
| SmallVector<Value> initArgs = *maybeInitArgs; |
| |
| // Cast init_args if necessary. |
| SmallVector<Value> castedInitArgs; |
| for (const auto &it : llvm::enumerate(initArgs)) { |
| Value initArg = it.value(); |
| Value beforeArg = whileOp.getBeforeArguments()[it.index()]; |
| // If the type is not a tensor, bufferization doesn't need to touch it. |
| if (!isa<TensorType>(beforeArg.getType())) { |
| castedInitArgs.push_back(initArg); |
| continue; |
| } |
| auto targetType = bufferization::getBufferType(beforeArg, options); |
| if (failed(targetType)) |
| return failure(); |
| castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType)); |
| } |
| |
| // The result types of a WhileOp are the same as the "after" bbArg types. |
| SmallVector<Type> argsTypesAfter = llvm::to_vector( |
| llvm::map_range(whileOp.getAfterArguments(), [&](BlockArgument bbArg) { |
| if (!isa<TensorType>(bbArg.getType())) |
| return bbArg.getType(); |
| // TODO: error handling |
| return llvm::cast<Type>( |
| *bufferization::getBufferType(bbArg, options)); |
| })); |
| |
| // Construct a new scf.while op with memref instead of tensor values. |
| ValueRange argsRangeBefore(castedInitArgs); |
| TypeRange argsTypesBefore(argsRangeBefore); |
| auto newWhileOp = rewriter.create<scf::WhileOp>( |
| whileOp.getLoc(), argsTypesAfter, castedInitArgs); |
| |
| // Add before/after regions to the new op. |
| SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(), |
| whileOp.getLoc()); |
| SmallVector<Location> bbArgLocsAfter(argsTypesAfter.size(), |
| whileOp.getLoc()); |
| Block *newBeforeBody = &newWhileOp.getBefore().emplaceBlock(); |
| newWhileOp.getBefore().addArguments(argsTypesBefore, bbArgLocsBefore); |
| Block *newAfterBody = &newWhileOp.getAfter().emplaceBlock(); |
| newWhileOp.getAfter().addArguments(argsTypesAfter, bbArgLocsAfter); |
| |
| // Set up new iter_args and move the loop condition block to the new op. |
| // The old block uses tensors, so wrap the (memref) bbArgs of the new block |
| // in ToTensorOps. |
| rewriter.setInsertionPointToStart(newBeforeBody); |
| SmallVector<Value> newBeforeArgs = getBbArgReplacements( |
| rewriter, newWhileOp.getBeforeArguments(), indicesBefore); |
| rewriter.mergeBlocks(whileOp.getBeforeBody(), newBeforeBody, newBeforeArgs); |
| |
| // Set up new iter_args and move the loop body block to the new op. |
| // The old block uses tensors, so wrap the (memref) bbArgs of the new block |
| // in ToTensorOps. |
| rewriter.setInsertionPointToStart(newAfterBody); |
| SmallVector<Value> newAfterArgs = getBbArgReplacements( |
| rewriter, newWhileOp.getAfterArguments(), indicesAfter); |
| rewriter.mergeBlocks(whileOp.getAfterBody(), newAfterBody, newAfterArgs); |
| |
| // Replace loop results. |
| replaceOpWithBufferizedValues(rewriter, op, newWhileOp->getResults()); |
| |
| return success(); |
| } |
| |
| FailureOr<BaseMemRefType> |
| getBufferType(Operation *op, Value value, const BufferizationOptions &options, |
| SmallVector<Value> &invocationStack) const { |
| auto whileOp = cast<scf::WhileOp>(op); |
| assert(getOwnerOfValue(value) == op && "invalid value"); |
| assert(isa<TensorType>(value.getType()) && "expected tensor type"); |
| |
| // Case 1: Block argument of the "before" region. |
| if (auto bbArg = dyn_cast<BlockArgument>(value)) { |
| if (bbArg.getOwner()->getParent() == &whileOp.getBefore()) { |
| Value initArg = whileOp.getInits()[bbArg.getArgNumber()]; |
| auto yieldOp = whileOp.getYieldOp(); |
| Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber()); |
| return computeLoopRegionIterArgBufferType( |
| op, bbArg, initArg, yieldedValue, options, invocationStack); |
| } |
| } |
| |
| // Case 2: OpResult of the loop or block argument of the "after" region. |
| // The bufferized "after" bbArg type can be directly computed from the |
| // bufferized "before" bbArg type. |
| unsigned resultNum; |
| if (auto opResult = dyn_cast<OpResult>(value)) { |
| resultNum = opResult.getResultNumber(); |
| } else if (cast<BlockArgument>(value).getOwner()->getParent() == |
| &whileOp.getAfter()) { |
| resultNum = cast<BlockArgument>(value).getArgNumber(); |
| } else { |
| llvm_unreachable("invalid value"); |
| } |
| Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum]; |
| if (!isa<TensorType>(conditionYieldedVal.getType())) { |
| // scf.condition was already bufferized. |
| return cast<BaseMemRefType>(conditionYieldedVal.getType()); |
| } |
| return bufferization::getBufferType(conditionYieldedVal, options, |
| invocationStack); |
| } |
| |
| /// Assert that yielded values of an scf.while op are equivalent to their |
| /// corresponding bbArgs. In that case, the buffer relations of the |
| /// corresponding OpResults are "Equivalent". |
| /// |
| /// If this is not the case, allocs+copies are inserted and yielded from |
| /// the loop. This could be a performance problem, so it must be explicitly |
| /// activated with `allow-return-allocs-from-loops`. |
| /// |
| /// Note: In contrast to scf::ForOp, scf::WhileOp has two regions and the |
| /// equivalence condition must be checked for both. |
| LogicalResult verifyAnalysis(Operation *op, |
| const AnalysisState &state) const { |
| auto whileOp = cast<scf::WhileOp>(op); |
| const auto &options = |
| static_cast<const OneShotBufferizationOptions &>(state.getOptions()); |
| if (options.allowReturnAllocsFromLoops) |
| return success(); |
| |
| auto conditionOp = whileOp.getConditionOp(); |
| for (const auto &it : llvm::enumerate(conditionOp.getArgs())) { |
| Block *block = conditionOp->getBlock(); |
| if (!isa<TensorType>(it.value().getType())) |
| continue; |
| if (it.index() >= block->getNumArguments() || |
| !state.areEquivalentBufferizedValues(it.value(), |
| block->getArgument(it.index()))) |
| return conditionOp->emitError() |
| << "Condition arg #" << it.index() |
| << " is not equivalent to the corresponding iter bbArg"; |
| } |
| |
| auto yieldOp = whileOp.getYieldOp(); |
| for (const auto &it : llvm::enumerate(yieldOp.getResults())) { |
| Block *block = yieldOp->getBlock(); |
| if (!isa<TensorType>(it.value().getType())) |
| continue; |
| if (it.index() >= block->getNumArguments() || |
| !state.areEquivalentBufferizedValues(it.value(), |
| block->getArgument(it.index()))) |
| return yieldOp->emitError() |
| << "Yield operand #" << it.index() |
| << " is not equivalent to the corresponding iter bbArg"; |
| } |
| |
| return success(); |
| } |
| }; |
| |
| /// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so |
| /// this is for analysis only. |
| struct YieldOpInterface |
| : public BufferizableOpInterface::ExternalModel<YieldOpInterface, |
| scf::YieldOp> { |
| bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, |
| const AnalysisState &state) const { |
| return true; |
| } |
| |
| bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, |
| const AnalysisState &state) const { |
| return false; |
| } |
| |
| AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, |
| const AnalysisState &state) const { |
| if (auto ifOp = dyn_cast<scf::IfOp>(op->getParentOp())) { |
| return {{op->getParentOp()->getResult(opOperand.getOperandNumber()), |
| BufferRelation::Equivalent, /*isDefinite=*/false}}; |
| } |
| if (isa<scf::ExecuteRegionOp>(op->getParentOp())) |
| return {{op->getParentOp()->getResult(opOperand.getOperandNumber()), |
| BufferRelation::Equivalent}}; |
| return {}; |
| } |
| |
| bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand, |
| const AnalysisState &state) const { |
| // Yield operands always bufferize inplace. Otherwise, an alloc + copy |
| // may be generated inside the block. We should avoid yielding allocations |
| // whenever possible. |
| return true; |
| } |
| |
| LogicalResult bufferize(Operation *op, RewriterBase &rewriter, |
| const BufferizationOptions &options) const { |
| auto yieldOp = cast<scf::YieldOp>(op); |
| if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::IndexSwitchOp, scf::ForOp, |
| scf::WhileOp>(yieldOp->getParentOp())) |
| return yieldOp->emitError("unsupported scf::YieldOp parent"); |
| |
| SmallVector<Value> newResults; |
| for (const auto &it : llvm::enumerate(yieldOp.getResults())) { |
| Value value = it.value(); |
| if (isa<TensorType>(value.getType())) { |
| FailureOr<Value> maybeBuffer = getBuffer(rewriter, value, options); |
| if (failed(maybeBuffer)) |
| return failure(); |
| Value buffer = *maybeBuffer; |
| // We may have to cast the value before yielding it. |
| if (isa<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>( |
| yieldOp->getParentOp())) { |
| FailureOr<BaseMemRefType> resultType = bufferization::getBufferType( |
| yieldOp->getParentOp()->getResult(it.index()), options); |
| if (failed(resultType)) |
| return failure(); |
| buffer = castBuffer(rewriter, buffer, *resultType); |
| } else if (auto whileOp = |
| dyn_cast<scf::WhileOp>(yieldOp->getParentOp())) { |
| FailureOr<BaseMemRefType> resultType = bufferization::getBufferType( |
| whileOp.getBeforeArguments()[it.index()], options); |
| if (failed(resultType)) |
| return failure(); |
| buffer = castBuffer(rewriter, buffer, *resultType); |
| } |
| newResults.push_back(buffer); |
| } else { |
| newResults.push_back(value); |
| } |
| } |
| |
| replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults); |
| return success(); |
| } |
| }; |
| |
| /// Return `true` if the given loop may have 0 iterations. |
| bool mayHaveZeroIterations(scf::ForallOp forallOp) { |
| for (auto [lb, ub] : llvm::zip(forallOp.getMixedLowerBound(), |
| forallOp.getMixedUpperBound())) { |
| std::optional<int64_t> lbConst = getConstantIntValue(lb); |
| std::optional<int64_t> ubConst = getConstantIntValue(ub); |
| if (!lbConst.has_value() || !ubConst.has_value() || *lbConst >= *ubConst) |
| return true; |
| } |
| return false; |
| } |
| |
| /// Bufferization of ForallOp. This also bufferizes the terminator of the |
| /// region. There are op interfaces for the terminators (InParallelOp |
| /// and ParallelInsertSliceOp), but these are used only during analysis, not |
| /// for bufferization. |
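| /// |
| /// For illustration (sketch): the shared_outs operands are replaced by their |
| /// buffers, all uses of the tied bbArgs inside the body are rewired through |
| /// bufferization.to_tensor ops, and a new scf.forall without results is |
| /// created; the results of the old op are then replaced by the output |
| /// buffers themselves. |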
| struct ForallOpInterface |
| : public BufferizableOpInterface::ExternalModel<ForallOpInterface, |
| ForallOp> { |
| bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, |
| const AnalysisState &state) const { |
| auto forallOp = cast<ForallOp>(op); |
| |
| // If the loop has zero iterations, the results of the op are their |
| // corresponding shared_outs, meaning that the shared_outs bufferize to a |
| // read. |
| if (mayHaveZeroIterations(forallOp)) |
| return true; |
| |
| // scf::ForallOp alone doesn't bufferize to a memory read, but one of the |
| // uses of its matching bbArg may. |
| return state.isValueRead(forallOp.getTiedBlockArgument(&opOperand)); |
| } |
| |
| bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, |
| const AnalysisState &state) const { |
| // Outputs of scf::ForallOps are always considered as a write. |
| return true; |
| } |
| |
| AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand, |
| const AnalysisState &state) const { |
| auto forallOp = cast<ForallOp>(op); |
| return { |
| {{forallOp.getTiedOpResult(&opOperand), BufferRelation::Equivalent}}}; |
| } |
| |
| bool isWritable(Operation *op, Value value, |
| const AnalysisState &state) const { |
| return true; |
| } |
| |
| LogicalResult bufferize(Operation *op, RewriterBase &rewriter, |
| const BufferizationOptions &options) const { |
| OpBuilder::InsertionGuard guard(rewriter); |
| auto forallOp = cast<ForallOp>(op); |
| int64_t rank = forallOp.getRank(); |
| |
| // Get buffers for all output operands. |
| SmallVector<Value> buffers; |
| for (Value out : forallOp.getOutputs()) { |
| FailureOr<Value> buffer = getBuffer(rewriter, out, options); |
| if (failed(buffer)) |
| return failure(); |
| buffers.push_back(*buffer); |
| } |
| |
| // Use buffers instead of block arguments. |
| rewriter.setInsertionPointToStart(forallOp.getBody()); |
| for (const auto &it : llvm::zip( |
| forallOp.getBody()->getArguments().drop_front(rank), buffers)) { |
| BlockArgument bbArg = std::get<0>(it); |
| Value buffer = std::get<1>(it); |
| Value bufferAsTensor = |
| rewriter.create<ToTensorOp>(forallOp.getLoc(), buffer); |
| bbArg.replaceAllUsesWith(bufferAsTensor); |
| } |
| |
| // Create new ForallOp without any results and drop the automatically |
| // introduced terminator. |
| rewriter.setInsertionPoint(forallOp); |
| ForallOp newForallOp; |
| newForallOp = rewriter.create<ForallOp>( |
| forallOp.getLoc(), forallOp.getMixedLowerBound(), |
| forallOp.getMixedUpperBound(), forallOp.getMixedStep(), |
| /*outputs=*/ValueRange(), forallOp.getMapping()); |
| |
| rewriter.eraseOp(newForallOp.getBody()->getTerminator()); |
| |
| // Move over block contents of the old op. |
| SmallVector<Value> replacementBbArgs; |
| replacementBbArgs.append(newForallOp.getBody()->getArguments().begin(), |
| newForallOp.getBody()->getArguments().end()); |
| replacementBbArgs.append(forallOp.getOutputs().size(), Value()); |
| rewriter.mergeBlocks(forallOp.getBody(), newForallOp.getBody(), |
| replacementBbArgs); |
| |
| // Remove the old op and replace all of its uses. |
| replaceOpWithBufferizedValues(rewriter, op, buffers); |
| |
| return success(); |
| } |
| |
| FailureOr<BaseMemRefType> |
| getBufferType(Operation *op, Value value, const BufferizationOptions &options, |
| SmallVector<Value> &invocationStack) const { |
| auto forallOp = cast<ForallOp>(op); |
| |
| if (auto bbArg = dyn_cast<BlockArgument>(value)) |
| // A tensor block argument has the same bufferized type as the |
| // corresponding output operand. |
| return bufferization::getBufferType( |
| forallOp.getTiedOpOperand(bbArg)->get(), options, invocationStack); |
| |
| // The bufferized result type is the same as the bufferized type of the |
| // corresponding output operand. |
| return bufferization::getBufferType( |
| forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options, |
| invocationStack); |
| } |
| |
| bool isRepetitiveRegion(Operation *op, unsigned index) const { |
| auto forallOp = cast<ForallOp>(op); |
| |
| // This op is repetitive if any of its dimensions runs for more than one |
| // iteration. If the bounds or steps are dynamic, it is conservatively |
| // considered repetitive as well. |
| for (auto [lb, ub, step] : |
| llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(), |
| forallOp.getMixedStep())) { |
| std::optional<int64_t> lbConstant = getConstantIntValue(lb); |
| if (!lbConstant) |
| return true; |
| |
| std::optional<int64_t> ubConstant = getConstantIntValue(ub); |
| if (!ubConstant) |
| return true; |
| |
| std::optional<int64_t> stepConstant = getConstantIntValue(step); |
| if (!stepConstant) |
| return true; |
| |
| if (*lbConstant + *stepConstant < *ubConstant) |
| return true; |
| } |
| return false; |
| } |
| |
| bool isParallelRegion(Operation *op, unsigned index) const { |
| return isRepetitiveRegion(op, index); |
| } |
| }; |
| |
| /// Nothing to do for InParallelOp. |
| struct InParallelOpInterface |
| : public BufferizableOpInterface::ExternalModel<InParallelOpInterface, |
| InParallelOp> { |
| LogicalResult bufferize(Operation *op, RewriterBase &b, |
| const BufferizationOptions &options) const { |
| llvm_unreachable("op does not have any tensor OpOperands / OpResults"); |
| return failure(); |
| } |
| }; |
| |
| } // namespace |
| } // namespace scf |
| } // namespace mlir |
| |
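| // Typical usage, as a sketch (the surrounding pass setup is hypothetical): |
| // |
| // DialectRegistry registry; |
| // registry.insert<scf::SCFDialect, bufferization::BufferizationDialect>(); |
| // scf::registerBufferizableOpInterfaceExternalModels(registry); |
| // context.appendDialectRegistry(registry); |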
| void mlir::scf::registerBufferizableOpInterfaceExternalModels( |
| DialectRegistry ®istry) { |
| registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) { |
| ConditionOp::attachInterface<ConditionOpInterface>(*ctx); |
| ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx); |
| ForOp::attachInterface<ForOpInterface>(*ctx); |
| IfOp::attachInterface<IfOpInterface>(*ctx); |
| IndexSwitchOp::attachInterface<IndexSwitchOpInterface>(*ctx); |
| ForallOp::attachInterface<ForallOpInterface>(*ctx); |
| InParallelOp::attachInterface<InParallelOpInterface>(*ctx); |
| WhileOp::attachInterface<WhileOpInterface>(*ctx); |
| YieldOp::attachInterface<YieldOpInterface>(*ctx); |
| }); |
| } |