blob: 2d329a1f3d88929a0ff18061ff873dc89c251b6e [file] [log] [blame]
//===- OneShotAnalysis.cpp - One-Shot (Single Pass) Analysis --------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// One-Shot Analysis analyzes function bodies. By default, function boundaries
// (FuncOp bbArgs, CallOps, ReturnOps) are treated as "unknown" ops.
// OneShotModuleBufferization.cpp is an extension of One-Shot Analysis for
// simple call graphs without loops.
//
// One-Shot Bufferize consists of three phases.
//
// 1. Analyze ops to decide which OpOperands can bufferize inplace, i.e.,
// without inserting buffer copies. The analysis queries op bufferization
// semantics via `BufferizableOpInterface`.
// 2. Insert copies for OpOperands that were decided to bufferize out-of-place
// in tensor land during `TensorCopyInsertion`.
// 3. Bufferize ops by calling `BufferizableOpInterface::bufferize`.
//
// This file contains only the analysis. For convenience, this file also
// contains a helper function `runOneShotBufferize` that analyzes an op (and its
// nested ops) and then bufferizes it.
//
// Inplace bufferization decisions are passed from the analysis to the
// `TensorCopyInsertion` phase via `AnalysisState`. They can be printed for
// debugging purposes with `testAnalysisOnly`.
//
// Ops that do not implement `BufferizableOpInterface` can be analyzed but are
// treated conservatively. E.g., the analysis has to assume that their tensor
// OpOperands bufferize to memory writes. While such ops can be analyzed, they
// are not bufferized and remain in the IR. to_tensor and to_memref ops are
// inserted at the bufferization boundary.
//
// This analysis caters to high-performance codegen where buffer reuse is deemed
// critical: the analysis should fail if the bufferized form of the function
// needs to return a buffer, unless `allowReturnAllocs` is enabled.
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include <optional>
#include <random>
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Iterators.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetVector.h"
// Provide the explicit TypeID definition for OneShotAnalysisState. The
// constructor in this file passes `TypeID::get<OneShotAnalysisState>()` to its
// `AnalysisState` base.
MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::OneShotAnalysisState)
// Run mlir-opt with `-debug-only="one-shot-analysis"` for detailed debug
// output.
#define DEBUG_TYPE "one-shot-analysis"
// Unqualified names used throughout this file (Value, OpOperand,
// AnalysisState, ...) are resolved through these two namespaces.
using namespace mlir;
using namespace mlir::bufferization;
/// Return `true` if `t` is a `TensorType`.
static bool isaTensor(Type t) {
  const bool isTensor = isa<TensorType>(t);
  return isTensor;
}
//===----------------------------------------------------------------------===//
// Bufferization-specific attribute manipulation.
// These are for testing and debugging only. Bufferization information is stored
// in OneShotBufferizationState. When run with `testAnalysisOnly`, the IR is
// annotated with the results of the analysis, so that they can be checked in
// tests.
//===----------------------------------------------------------------------===//
/// Attribute marker to specify op operands that bufferize in-place. The
/// attribute is a string array with one entry per operand; `setInPlaceOpOperand`
/// writes the entries "true", "false" or "none".
constexpr StringLiteral kInPlaceOperandsAttrName = "__inplace_operands_attr__";
/// Attribute name for alias-set annotations on op results.
/// NOTE(review): not referenced in the part of the file visible here; the
/// purpose is inferred from the name only - confirm against its users.
constexpr StringLiteral kOpResultAliasSetAttrName =
"__opresult_alias_set_attr__";
/// Attribute name for alias-set annotations on block arguments.
/// NOTE(review): not referenced in the part of the file visible here; the
/// purpose is inferred from the name only - confirm against its users.
constexpr StringLiteral kBbArgAliasSetAttrName = "__bbarg_alias_set_attr__";
/// Record on the owning op whether `opOperand` bufferizes in-place.
///
/// The decision is stored in the `kInPlaceOperandsAttrName` string array
/// attribute of the owner, which holds one entry per operand. If the attribute
/// does not exist yet, it is created with "false" for every tensor operand and
/// "none" for every other operand before the entry of `opOperand` is updated.
///
/// \param opOperand the operand whose decision is recorded.
/// \param inPlace   `true` writes "true", `false` writes "false".
static void setInPlaceOpOperand(OpOperand &opOperand, bool inPlace) {
  Operation *owner = opOperand.getOwner();
  SmallVector<StringRef> flags;

  Attribute existing = owner->getAttr(kInPlaceOperandsAttrName);
  if (!existing) {
    // First annotation on this op: start from the default entries.
    flags = SmallVector<StringRef>(owner->getNumOperands(), "none");
    for (OpOperand &candidate : owner->getOpOperands()) {
      if (isa<TensorType>(candidate.get().getType()))
        flags[candidate.getOperandNumber()] = "false";
    }
  } else {
    // Continue from the entries that were recorded earlier.
    flags = SmallVector<StringRef>(llvm::to_vector<4>(
        cast<ArrayAttr>(existing).getAsValueRange<StringAttr>()));
  }

  if (inPlace)
    flags[opOperand.getOperandNumber()] = "true";
  else
    flags[opOperand.getOperandNumber()] = "false";

  owner->setAttr(kInPlaceOperandsAttrName,
                 OpBuilder(owner).getStrArrayAttr(flags));
}
//===----------------------------------------------------------------------===//
// OneShotAnalysisState
//===----------------------------------------------------------------------===//
/// Build the analysis state for `op` and everything nested in it.
///
/// Two walks are performed:
///  1. Every tensor-typed op result and block argument gets its own entry in
///     the alias and equivalence sets.
///  2. Every tensor OpOperand for which an allowed bufferizable op reports
///     `mustBufferizeInPlace` is marked as in-place right away.
OneShotAnalysisState::OneShotAnalysisState(
    Operation *op, const OneShotBufferizationOptions &options)
    : AnalysisState(options, TypeID::get<OneShotAnalysisState>()) {
  // Walk 1: set up alias sets.
  op->walk([&](Operation *nested) {
    for (Value result : nested->getResults()) {
      if (!isa<TensorType>(result.getType()))
        continue;
      createAliasInfoEntry(result);
    }
    for (Region &region : nested->getRegions()) {
      for (Block &block : region.getBlocks()) {
        for (BlockArgument arg : block.getArguments()) {
          if (!isa<TensorType>(arg.getType()))
            continue;
          createAliasInfoEntry(arg);
        }
      }
    }
  });

  // Walk 2: mark OpOperands in-place that must bufferize in-place.
  op->walk([&](BufferizableOpInterface bufferizableOp) {
    if (!options.isOpAllowed(bufferizableOp))
      return WalkResult::skip();
    for (OpOperand &opOperand : bufferizableOp->getOpOperands()) {
      if (!isa<TensorType>(opOperand.get().getType()))
        continue;
      if (bufferizableOp.mustBufferizeInPlace(opOperand, *this))
        bufferizeInPlace(opOperand);
    }
    return WalkResult::advance();
  });
}
/// Invoke `fun` on every member of the equivalence class of `v`, starting at
/// the class leader.
void OneShotAnalysisState::applyOnEquivalenceClass(
    Value v, function_ref<void(Value)> fun) const {
  auto member = equivalentInfo.findLeader(v);
  const auto end = equivalentInfo.member_end();
  while (member != end) {
    fun(*member);
    ++member;
  }
}
/// Invoke `fun` on every member of the alias set of `v`, starting at the set
/// leader.
void OneShotAnalysisState::applyOnAliases(Value v,
                                          function_ref<void(Value)> fun) const {
  auto member = aliasInfo.findLeader(v);
  const auto end = aliasInfo.member_end();
  while (member != end) {
    fun(*member);
    ++member;
  }
}
/// Return `true` if `v1` and `v2` are in the same equivalence class.
bool OneShotAnalysisState::areEquivalentBufferizedValues(Value v1,
                                                         Value v2) const {
  const bool sameClass = equivalentInfo.isEquivalent(v1, v2);
  return sameClass;
}
/// Return `true` if `v1` and `v2` are in the same alias set.
bool OneShotAnalysisState::areAliasingBufferizedValues(Value v1,
                                                       Value v2) const {
  const bool sameAliasSet = aliasInfo.isEquivalent(v1, v2);
  return sameAliasSet;
}
/// Record the decision that `operand` bufferizes in-place.
///
/// The first time an operand is marked, the alias set of each of its aliasing
/// values is merged with the alias set of the operand's value and the in-place
/// statistic is incremented. Marking the same operand again does nothing.
void OneShotAnalysisState::bufferizeInPlace(OpOperand &operand) {
  if (!inplaceBufferized.contains(&operand)) {
    inplaceBufferized.insert(&operand);
    for (AliasingValue aliasing : getAliasingValues(operand))
      aliasInfo.unionSets(aliasing.value, operand.get());
    ++statNumTensorInPlace;
  }
}
/// Record the decision that `operand` bufferizes out-of-place. Only the
/// statistic is updated; the operand must not have been marked in-place before.
void OneShotAnalysisState::bufferizeOutOfPlace(OpOperand &operand) {
  assert(!isInPlace(operand) &&
         "OpOperand was already decided to bufferize inplace");
  ++statNumTensorOutOfPlace;
}
/// Register `v` in both the equivalence classes and the alias sets.
void OneShotAnalysisState::createAliasInfoEntry(Value v) {
  equivalentInfo.insert(v);
  aliasInfo.insert(v);
}
/// Walk `op` and collect into `undefinedTensorUses` every use of a tensor
/// OpResult for which `findDefinitionsCached` returns no definition. Only
/// results of ops that the options recognize as bufferizable are considered.
void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
  op->walk([&](Operation *nested) {
    // Skip unknown ops.
    if (!getOptions().dynCastBufferizableOp(nested))
      return WalkResult::skip();

    // Check all tensor OpResults.
    for (OpResult result : nested->getOpResults()) {
      // If there is no preceding definition, the tensor contents are
      // undefined.
      const bool hasUndefinedContents =
          isa<TensorType>(result.getType()) &&
          findDefinitionsCached(result).empty();
      if (!hasUndefinedContents)
        continue;
      for (OpOperand &use : result.getUses())
        undefinedTensorUses.insert(&use);
    }
    return WalkResult::advance();
  });
}
/// Return `true` if `opOperand` was recorded by `gatherUndefinedTensorUses` as
/// a use of a tensor without a definition.
bool OneShotAnalysisState::hasUndefinedContents(OpOperand *opOperand) const {
  return undefinedTensorUses.count(opOperand) != 0;
}
/// Return `true` if `opOperand` has been decided to bufferize in-place.
bool OneShotAnalysisState::isInPlace(OpOperand &opOperand) const {
  return inplaceBufferized.count(&opOperand) != 0;
}
/// Return `true` if any use of any alias of `value` is both decided in-place
/// and bufferizes to a memory write.
bool OneShotAnalysisState::isValueWritten(Value value) const {
  bool written = false;
  auto scanUses = [&](Value alias) {
    for (OpOperand &use : alias.getUses()) {
      if (!isInPlace(use))
        continue;
      if (bufferizesToMemoryWrite(use))
        written = true;
    }
  };
  applyOnAliases(value, scanUses);
  return written;
}
/// Return `true` if the op that owns `value` is bufferizable and reports the
/// value as writable; otherwise return `false`.
bool OneShotAnalysisState::isWritable(Value value) const {
  // TODO: Out-of-place bufferized value could be considered writable.
  Operation *owner = getOwnerOfValue(value);
  auto bufferizableOp = getOptions().dynCastBufferizableOp(owner);

  // Not a bufferizable op: The conservative answer is "not writable".
  if (!bufferizableOp)
    return false;

  // Query BufferizableOpInterface to see if the value is writable.
  return bufferizableOp.isWritable(value, *this);
}
// Merge the alias sets that contain `v1` and `v2`.
void OneShotAnalysisState::unionAliasSets(Value v1, Value v2) {
aliasInfo.unionSets(v1, v2);
}
// Merge the equivalence classes that contain `v1` and `v2`.
void OneShotAnalysisState::unionEquivalenceClasses(Value v1, Value v2) {
equivalentInfo.unionSets(v1, v2);
}
// Out-of-line definition of the defaulted `Extension` destructor.
OneShotAnalysisState::Extension::~Extension() = default;
//===----------------------------------------------------------------------===//
// Bufferization-specific alias analysis.
//===----------------------------------------------------------------------===//
/// Return true if `opOperand` bufferizes to a memory write and has been
/// decided to bufferize in-place.
static bool isInplaceMemoryWrite(OpOperand &opOperand,
                                 const OneShotAnalysisState &state) {
  // OpOperands that do not bufferize to a memory write do not write in-place;
  // the current bufferization decision is consulted only for writes.
  return state.bufferizesToMemoryWrite(opOperand) &&
         state.isInPlace(opOperand);
}
/// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors
/// properly dominates `b` and `b` is not inside `a`.
///
/// The check starts at `a` and moves up through its parent ops. It stops with
/// `false` as soon as the current op is a proper ancestor of `b`, with `true`
/// as soon as the current op properly dominates `b`, and with `false` when
/// there is no further parent op.
static bool happensBefore(Operation *a, Operation *b,
                          const DominanceInfo &domInfo) {
  Operation *current = a;
  while (true) {
    // TODO: Instead of isProperAncestor + properlyDominates, we should use
    // properlyDominatesImpl(a, b, /*enclosingOpOk=*/false)
    if (current->isProperAncestor(b))
      return false;
    if (domInfo.properlyDominates(current, b))
      return true;
    current = current->getParentOp();
    if (!current)
      return false;
  }
}
/// Return `true` if `to` can be reached from `from` by following at least one
/// successor edge, without traversing any block listed in `except`.
///
/// Blocks in `except` are discarded before they are compared against `to`, so
/// a `to` block that is also listed in `except` is never reported as
/// reachable. `from` itself only counts as reached if a successor path leads
/// back to it.
static bool isReachable(Block *from, Block *to, ArrayRef<Block *> except) {
  DenseSet<Block *> seen;
  SmallVector<Block *> pending;
  for (Block *successor : from->getSuccessors())
    pending.push_back(successor);

  while (!pending.empty()) {
    Block *current = pending.pop_back_val();
    if (llvm::is_contained(except, current))
      continue;
    if (current == to)
      return true;
    // Expand every block at most once.
    if (!seen.insert(current).second)
      continue;
    for (Block *successor : current->getSuccessors())
      pending.push_back(successor);
  }
  return false;
}
/// Return `true` if op dominance can be used to rule out read-after-write
/// conflicts based on the ordering of ops. Returns `false` if op dominance
/// cannot be used due to region-based loops.
///
/// Generalized op dominance can often be used to rule out potential conflicts
/// due to "read happens before write". E.g., the following IR is not a RaW
/// conflict because the read happens *before* the write.
///
/// Example 1:
/// %0 = ... : tensor<?xf32> // DEF
/// "reading_op"(%0) : tensor<?xf32> // READ
/// %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32> // WRITE
///
/// This is no longer true inside loops (or repetitive regions). In such cases,
/// there may not be a meaningful `happensBefore` relationship because ops
/// could be executed multiple times. E.g.:
///
/// Example 2:
/// %0 = ... : tensor<?xf32> // DEF
/// scf.for ... {
/// "reading_op"(%0) : tensor<?xf32> // READ
/// %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32> // WRITE
/// ...
/// }
///
/// In the above example, reading_op happens before writing_op according to
/// op dominance. However, both ops may happen multiple times; in
/// particular, the second execution of reading_op happens after the first
/// execution of writing_op. This is problematic because the tensor %0 they
/// operate on (i.e., the "definition") is defined outside of the loop.
///
/// On a high-level, there is a potential RaW in a program if there exists a
/// possible program execution such that there is a sequence of DEF, followed
/// by WRITE, followed by READ. Each additional DEF resets the sequence.
///
/// E.g.:
/// No conflict: DEF, WRITE, DEF, READ
/// Potential conflict: DEF, READ, WRITE, READ, WRITE
///
/// Example 1 has no conflict: DEF, READ, WRITE
/// Example 2 has a potential conflict: DEF, (READ, WRITE)*
///
/// Example 3:
/// scf.for ... {
/// %0 = ... : tensor<?xf32>
/// "reading_op"(%0) : tensor<?xf32>
/// %1 = "writing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
/// ...
/// }
/// This has no conflict: (DEF, READ, WRITE)*
///
/// Example 4:
/// %0 = ... : tensor<?xf32>
/// scf.for ... {
/// scf.for ... { "reading_op"(%0) }
/// %1 = "writing_op"(%0)
/// }
/// This has a potential conflict: DEF, ((READ)*, WRITE)*
///
/// Example 5:
/// %0 = ... : tensor<?xf32>
/// scf.for ... { %1 = "writing_op"(%0) }
/// scf.for ... { "reading_op"(%0) }
/// This has a potential conflict: DEF, WRITE*, READ*
///
/// The following rules are used to rule out RaW conflicts via ordering of ops:
///
/// 1. If the closest enclosing repetitive region of DEF is a proper ancestor of
/// a repetitive region that encloses both READ and WRITE, we cannot rule
/// out RaW conflict due to the ordering of ops.
/// 2. Otherwise: There are no loops that interfere with our analysis; for
/// analysis purposes, we can assume that there are no loops/repetitive
/// regions. I.e., we can rule out a RaW conflict if READ happensBefore WRITE
/// or WRITE happensBefore DEF. (Checked in `hasReadAfterWriteInterference`.)
///
static bool canUseOpDominanceDueToRegions(OpOperand *uRead, OpOperand *uWrite,
                                          const SetVector<Value> &definitions,
                                          AnalysisState &state) {
  const BufferizationOptions &options = state.getOptions();
  Operation *readOp = uRead->getOwner();
  Operation *writeOp = uWrite->getOwner();

  for (Value def : definitions) {
    Region *readRegion = state.getEnclosingRepetitiveRegion(readOp, options);
    Region *defRegion = state.getEnclosingRepetitiveRegion(def, options);

    // READ and DEF are in the same repetitive region. `happensBefore` can be
    // used to rule out RaW conflicts due to op ordering.
    if (readRegion == defRegion)
      continue;

    // Climb from the repetitive region of READ towards DEF and stop at the
    // enclosing repetitive region that is closest to DEF but is not the
    // repetitive region of DEF itself.
    for (Region *outer = getNextEnclosingRepetitiveRegion(readRegion, options);
         outer != defRegion;
         outer = getNextEnclosingRepetitiveRegion(readRegion, options)) {
      assert(outer && "expected to find another repetitive region");
      readRegion = outer;
    }

    // We cannot use op dominance if WRITE is inside the same repetitive region.
    if (readRegion->getParentOp()->isAncestor(writeOp))
      return false;
  }
  return true;
}
/// Return `true` if op dominance can be used to rule out read-after-write
/// conflicts based on the ordering of ops. Returns `false` if op dominance
/// cannot be used due to block-based loops within a region.
///
/// Refer to the `canUseOpDominanceDueToRegions` documentation for details on
/// how op dominance is used during RaW conflict detection.
///
/// On a high-level, there is a potential RaW in a program if there exists a
/// possible program execution such that there is a sequence of DEF, followed
/// by WRITE, followed by READ. Each additional DEF resets the sequence.
///
/// Op dominance cannot be used if there is a path from block(READ) to
/// block(WRITE) and a path from block(WRITE) to block(READ). block(DEF) should
/// not appear on that path.
static bool canUseOpDominanceDueToBlocks(OpOperand *uRead, OpOperand *uWrite,
                                         const SetVector<Value> &definitions,
                                         AnalysisState &state) {
  Operation *readOp = uRead->getOwner();
  Operation *writeOp = uWrite->getOwner();

  // Fast path: If READ and WRITE are in different regions, their block cannot
  // be reachable just via unstructured control flow. (Loops due to regions are
  // covered by `canUseOpDominanceDueToRegions`.)
  if (readOp->getParentRegion() != writeOp->getParentRegion())
    return true;

  Block *readBlock = readOp->getBlock();
  Block *writeBlock = writeOp->getBlock();

  // A definition rules out op dominance if block(READ) and block(WRITE) can
  // reach each other without going through block(DEF).
  auto formsLoopAvoidingDef = [&](Value def) {
    Block *defBlock = def.getParentBlock();
    return isReachable(readBlock, writeBlock, {defBlock}) &&
           isReachable(writeBlock, readBlock, {defBlock});
  };
  return llvm::none_of(definitions, formsLoopAvoidingDef);
}
/// Return `true` if neither region-based nor block-based loops prevent the use
/// of op dominance for the given read, write and definitions. The region check
/// runs first; the block check runs only if the region check passes.
static bool canUseOpDominance(OpOperand *uRead, OpOperand *uWrite,
                              const SetVector<Value> &definitions,
                              AnalysisState &state) {
  if (!canUseOpDominanceDueToRegions(uRead, uWrite, definitions, state))
    return false;
  return canUseOpDominanceDueToBlocks(uRead, uWrite, definitions, state);
}
/// Annotate IR with details about the detected RaW conflict.
///
/// Three unit attributes that share the prefix `C_<n>` are attached: one on the
/// op that owns the conflicting write, one on the op that owns the read, and
/// one on the op that provides the definition (the defining op of an OpResult,
/// or the parent op of the owning block for a block argument). `<n>` comes from
/// a function-local counter that is incremented on every call.
static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
                             Value definition) {
  static uint64_t counter = 0;
  Operation *writeOp = uConflictingWrite->getOwner();
  Operation *readOp = uRead->getOwner();
  OpBuilder b(writeOp->getContext());
  const std::string id = "C_" + std::to_string(counter++);

  // Attach the unit attribute `<id><label>` to `target`.
  auto mark = [&](Operation *target, const std::string &label) {
    target->setAttr(id + label, b.getUnitAttr());
  };

  mark(writeOp,
       "[CONFL-WRITE: " +
           std::to_string(uConflictingWrite->getOperandNumber()) + "]");
  mark(readOp, "[READ: " + std::to_string(uRead->getOperandNumber()) + "]");

  if (auto opResult = dyn_cast<OpResult>(definition)) {
    mark(opResult.getDefiningOp(),
         "[DEF: result " + std::to_string(opResult.getResultNumber()) + "]");
    return;
  }
  auto bbArg = cast<BlockArgument>(definition);
  mark(bbArg.getOwner()->getParentOp(),
       "[DEF: bbArg " + std::to_string(bbArg.getArgNumber()) + "]");
}
/// Return 'true' if a tensor that is equivalent to `other` can be found in the
/// reverse use-def chain of `start`. Note: If an OpOperand bufferizes out of
/// place along that use-def chain, the two tensors may not materialize as
/// equivalent buffers (but separate allocations).
///
/// Note: This function also requires that the two tensors have equivalent
/// indexing. I.e., the tensor types do not change along the use-def chain,
/// apart from static <-> dynamic dim casts.
static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state,
                                                   Value start, Value other) {
  // Follow only equivalent values whose types are the same or related by
  // casts, and do not report leaves that fail the match.
  TraversalConfig config;
  config.followSameTypeOrCastsOnly = true;
  config.alwaysIncludeLeaves = false;
  config.followEquivalentOnly = true;

  auto isOther = [&](Value candidate) { return candidate == other; };
  SetVector<Value> matches =
      state.findValueInReverseUseDefChain(start, isOther, config);
  return !matches.empty();
}
/// Return "true" if `value` is originating from a subset that is equivalent to
/// the subset that `subsetOp` inserts into.
///
/// Equivalence of values is decided with
/// `state.areEquivalentBufferizedValues`. Only OpResults can match; every
/// value returned by the reverse use-def chain lookup has to match.
static bool matchesInsertDestination(const AnalysisState &state, Value value,
                                     SubsetInsertionOpInterface subsetOp) {
  auto equivalentValues = [&](Value lhs, Value rhs) {
    return state.areEquivalentBufferizedValues(lhs, rhs);
  };
  auto matchingSubset = [&](Value candidate) {
    auto opResult = dyn_cast<OpResult>(candidate);
    return opResult && subsetOp.isEquivalentSubset(opResult, equivalentValues);
  };

  // There may be multiple leaves at which the reverse SSA use-def chain lookup
  // terminates. All of them must be equivalent subsets.
  SetVector<Value> backwardSlice =
      state.findValueInReverseUseDefChain(value, matchingSubset);
  return llvm::all_of(backwardSlice, matchingSubset);
}
/// Return "true" if the given "read" and potentially conflicting "write" are
/// not conflicting due to their subset relationship. The comments in this
/// function are expressed in terms of tensor.extract_slice/tensor.insert_slice
/// pairs, but apply to any subset ops that implement the
/// `SubsetInsertionOpInterface`.
static bool areNonConflictingSubsets(OpOperand *uRead,
                                     OpOperand *uConflictingWrite,
                                     const AnalysisState &state) {
  // All cases below are explained with the following IR. (The last op is only
  // relevant for case 3.)
  //
  // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
  // %1 = linalg.fill %cst, %0 {inplace= [true] }
  // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
  //     {inplace= [true] }
  // %3 = vector.transfer_read %1, %cst
  Operation *readOp = uRead->getOwner();
  Operation *writeOp = uConflictingWrite->getOwner();

  // Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
  // uRead is an InsertSliceOp...
  if (auto readSubsetOp = dyn_cast<SubsetInsertionOpInterface>(readOp)) {
    // Case 1: The main insight is that InsertSliceOp reads only part of
    // the destination tensor. The overwritten area is not read. If
    // uConflictingWrite writes into exactly the memory location that is
    // being read by uRead, this is not a conflict.
    //
    // In the above example:
    // uRead             = OpOperand 1 (%t) of tensor.insert_slice
    // uConflictingWrite = OpOperand 1 (%0) of linalg.fill
    //
    // The read of %t does not conflict with the write of the FillOp
    // (same aliases!) because the area that the FillOp operates on is
    // exactly the one that is *not* read via %t.
    const bool readsDestination =
        uRead == &readSubsetOp.getDestinationOperand();
    if (readsDestination &&
        matchesInsertDestination(state, uConflictingWrite->get(),
                                 readSubsetOp))
      return true;

    // Case 2: The read of the source tensor and the write to the dest
    // tensor via an InsertSliceOp is not a conflict if the read is
    // reading exactly that part of an equivalent tensor that the
    // InsertSliceOp is writing.
    //
    // In the above example:
    // uRead             = OpOperand 0 (%1) of tensor.insert_slice
    // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
    const bool readsSource = uRead == &readSubsetOp.getSourceOperand();
    if (readsSource &&
        uConflictingWrite == &readSubsetOp.getDestinationOperand() &&
        matchesInsertDestination(state, uRead->get(), readSubsetOp))
      return true;
  }

  // If uConflictingWrite is an InsertSliceOp...
  auto writeSubsetOp = dyn_cast<SubsetInsertionOpInterface>(writeOp);
  if (!writeSubsetOp)
    return false;

  // Case 3:
  //
  // In the above example:
  // uRead             = OpOperand 0 (%1) of vector.transfer_read
  // uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
  // definition        = %1
  //
  // This is not a conflict because the InsertSliceOp overwrites the
  // memory segment of %1 with the exact same data. (Effectively, there
  // is no memory write here.)
  if (uConflictingWrite != &writeSubsetOp.getDestinationOperand())
    return false;
  Value insertedSource = writeSubsetOp.getSourceOperand().get();
  if (!state.areEquivalentBufferizedValues(uRead->get(), insertedSource))
    return false;
  return matchesInsertDestination(state, insertedSource, writeSubsetOp);
}
/// Given sets of uses and writes, return true if there is a RaW conflict under
/// the assumption that all given reads/writes alias the same buffer and that
/// all given writes bufferize inplace.
///
/// A conflict is: According to SSA use-def chains, a read R is supposed to read
/// the result of a definition W1. But because of bufferization decisions, R
/// actually reads another definition W2.
///
/// The function has two phases:
/// 1. If there is at least one read, every write is checked against its
/// definitions (or leaves): a write whose definition lives in a different
/// parallel region makes the function return true immediately.
/// 2. Every (read, write) pair is run through a sequence of "no conflict"
/// checks. A pair that passes none of them is then checked against every
/// definition of the read; the first definition that cannot be ruled out
/// makes the function return true.
///
/// \param usesRead  OpOperands that are treated as reads.
/// \param usesWrite OpOperands that are treated as inplace writes.
/// \param domInfo   Dominance information used by `happensBefore`.
/// \param state     Analysis state that is queried for definitions, aliasing
///                  values and equivalence; its options control debugging
///                  annotations (`printConflicts`).
/// \returns true if a conflict was found, false otherwise.
static bool
hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
const DenseSet<OpOperand *> &usesWrite,
const DominanceInfo &domInfo,
OneShotAnalysisState &state) {
const BufferizationOptions &options = state.getOptions();
// Before going through the main RaW analysis, find cases where a buffer must
// be privatized due to parallelism. If the result of a write is never read,
// privatization is not necessary (and large parts of the IR are likely dead).
if (!usesRead.empty()) {
for (OpOperand *uConflictingWrite : usesWrite) {
// Find the allocation point or last write (definition) of the buffer.
// Note: In contrast to `findDefinitions`, this also returns results of
// ops that do not bufferize to memory write when no other definition
// could be found. E.g., "bufferization.alloc_tensor" would be included,
// even though that op just bufferizes to an allocation but does not
// define the contents of the buffer.
SetVector<Value> definitionsOrLeaves =
state.findValueInReverseUseDefChain(
uConflictingWrite->get(),
[&](Value v) { return state.bufferizesToMemoryWrite(v); });
assert(!definitionsOrLeaves.empty() &&
"expected at least one definition or leaf");
// The writing op must bufferize out-of-place if the definition is in a
// different parallel region than this write.
for (Value def : definitionsOrLeaves) {
if (getParallelRegion(def.getParentRegion(), options) !=
getParallelRegion(uConflictingWrite->getOwner()->getParentRegion(),
options)) {
LLVM_DEBUG(
llvm::dbgs()
<< "\n- bufferizes out-of-place due to parallel region:\n");
LLVM_DEBUG(llvm::dbgs()
<< " unConflictingWrite = operand "
<< uConflictingWrite->getOperandNumber() << " of "
<< *uConflictingWrite->getOwner() << "\n");
return true;
}
}
}
}
// Main RaW analysis: examine every (read, write) pair.
for (OpOperand *uRead : usesRead) {
Operation *readingOp = uRead->getOwner();
LLVM_DEBUG(llvm::dbgs() << "\n- check conflict:\n");
LLVM_DEBUG(llvm::dbgs() << " uRead = operand " << uRead->getOperandNumber()
<< " of " << *readingOp << "\n");
// Find the definition of uRead by following the SSA use-def chain.
// E.g.:
//
// %0 = "writing_op"(%t) : tensor<?xf32> -> tensor<?xf32>
// %1 = "aliasing_op"(%0) : tensor<?xf32> -> tensor<?xf32>
// %2 = "reading_op"(%1) : tensor<?xf32> -> not_a_tensor_type
//
// In the above example, if uRead is the OpOperand of reading_op, the
// definition is %0. Note that operations that create an alias but do not
// bufferize to a memory write (such as ExtractSliceOp) are skipped.
const SetVector<Value> &definitions =
state.findDefinitionsCached(uRead->get());
if (definitions.empty()) {
// Fast path: No conflict if there are no definitions.
LLVM_DEBUG(llvm::dbgs()
<< " no conflict: read value has no definitions\n");
continue;
}
// Look for conflicting memory writes. Potential conflicts are writes to an
// alias that have been decided to bufferize inplace.
for (OpOperand *uConflictingWrite : usesWrite) {
LLVM_DEBUG(llvm::dbgs() << " unConflictingWrite = operand "
<< uConflictingWrite->getOperandNumber() << " of "
<< *uConflictingWrite->getOwner() << "\n");
// Check if op dominance can be used to rule out read-after-write
// conflicts.
bool useDominance =
canUseOpDominance(uRead, uConflictingWrite, definitions, state);
LLVM_DEBUG(llvm::dbgs() << "\n- useDominance = " << useDominance << "\n");
// Throughout this loop, check for multiple requirements that have to be
// met for uConflictingWrite to be an actual conflict.
Operation *conflictingWritingOp = uConflictingWrite->getOwner();
// Inside of repetitive regions, ops may be executed multiple times and op
// dominance cannot be used to rule out conflicts. The three checks in this
// `if` are therefore applied only when `useDominance` is set.
if (useDominance) {
// No conflict if the readingOp dominates conflictingWritingOp, i.e.,
// the write is not visible when reading.
//
// Note: If ops are executed multiple times (e.g., because they are
// inside a loop), there may be no meaningful `happensBefore`
// relationship.
if (happensBefore(readingOp, conflictingWritingOp, domInfo)) {
LLVM_DEBUG(llvm::dbgs()
<< " no conflict: read happens before write\n");
continue;
}
// No conflict if the reading use equals the use of the conflicting
// write. A use cannot conflict with itself.
//
// Note: Just being the same op is not enough. It has to be the same
// use.
// Note: If the op is executed multiple times (e.g., because it is
// inside a loop), it may be conflicting with itself.
if (uConflictingWrite == uRead) {
LLVM_DEBUG(llvm::dbgs()
<< " no conflict: read and write are same use\n");
continue;
}
// Ops are not conflicting if they are in mutually exclusive regions.
//
// Note: If ops are executed multiple times (e.g., because they are
// inside a loop), mutually exclusive regions may be executed
// multiple times.
if (insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp)) {
LLVM_DEBUG(llvm::dbgs() << " no conflict: read and write are in "
"mutually exclusive regions\n");
continue;
}
}
// Two equivalent operands of the same op are not conflicting if the op
// bufferizes to element-wise access. I.e., all loads at a position happen
// before all stores to the same position.
if (conflictingWritingOp == readingOp) {
if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {
if (bufferizableOp.bufferizesToElementwiseAccess(
state, {uRead, uConflictingWrite})) {
// The equivalent value may be found starting from either operand, so
// both directions of the reverse use-def chain are searched.
if (hasEquivalentValueInReverseUseDefChain(
state, uRead->get(), uConflictingWrite->get()) ||
hasEquivalentValueInReverseUseDefChain(
state, uConflictingWrite->get(), uRead->get())) {
LLVM_DEBUG(
llvm::dbgs()
<< " no conflict: op bufferizes to element-wise access\n");
continue;
}
}
}
}
// No conflict if the operands are non-conflicting subsets.
if (areNonConflictingSubsets(uRead, uConflictingWrite, state)) {
LLVM_DEBUG(llvm::dbgs() << " no conflict: non-conflicting subsets\n");
continue;
}
// No conflict if the op interface says so. The reading op is asked first;
// the writing op is asked only if it is a different op.
if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {
if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state)) {
LLVM_DEBUG(llvm::dbgs()
<< " no conflict: op interace of reading op says 'no'\n");
continue;
}
}
if (conflictingWritingOp != readingOp) {
if (auto bufferizableOp =
options.dynCastBufferizableOp(conflictingWritingOp)) {
if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite,
state)) {
LLVM_DEBUG(
llvm::dbgs()
<< " no conflict: op interace of writing op says 'no'\n");
continue;
}
}
}
// Check all possible definitions. Each `continue` in this loop moves on to
// the next definition; the first definition that passes none of the
// checks below is reported as a conflict.
for (Value definition : definitions) {
LLVM_DEBUG(llvm::dbgs() << " * definition = " << definition << "\n");
// No conflict if the conflicting write happens before the definition.
if (Operation *defOp = definition.getDefiningOp()) {
if (happensBefore(conflictingWritingOp, defOp, domInfo)) {
// conflictingWritingOp happens before defOp. No conflict.
LLVM_DEBUG(llvm::dbgs()
<< " no conflict: write happens before definition\n");
continue;
}
// No conflict if conflictingWritingOp is contained in defOp.
if (defOp->isProperAncestor(conflictingWritingOp)) {
LLVM_DEBUG(
llvm::dbgs()
<< " no conflict: write is contained in definition\n");
continue;
}
} else {
// The definition has no defining op, i.e., it is a block argument.
auto bbArg = cast<BlockArgument>(definition);
Block *block = bbArg.getOwner();
if (!block->findAncestorOpInBlock(*conflictingWritingOp)) {
LLVM_DEBUG(llvm::dbgs() << " no conflict: definition is bbArg "
"and write happens outside of block\n");
// conflictingWritingOp happens outside of the block. No
// conflict.
continue;
}
}
// No conflict if the conflicting write and the definition are the same
// use.
AliasingValueList aliases = state.getAliasingValues(*uConflictingWrite);
if (aliases.getNumAliases() == 1 &&
aliases.getAliases()[0].value == definition) {
LLVM_DEBUG(llvm::dbgs()
<< " no conflict: definition and write are same\n");
continue;
}
// All requirements are met. Conflict found!
// If requested via the options, mark the involved ops in the IR.
if (options.printConflicts)
annotateConflict(uRead, uConflictingWrite, definition);
LLVM_DEBUG(llvm::dbgs() << " => RaW CONFLICT FOUND\n");
return true;
}
}
}
return false;
}
/// Collect into `res` every use of an alias of `root` that is an inplace
/// memory write according to `isInplaceMemoryWrite`.
static void getAliasingInplaceWrites(DenseSet<OpOperand *> &res, Value root,
                                     const OneShotAnalysisState &state) {
  auto collectWrites = [&](Value alias) {
    for (OpOperand &use : alias.getUses()) {
      // Keep only inplace writes to a value that aliases root.
      if (!isInplaceMemoryWrite(use, state))
        continue;
      res.insert(&use);
    }
  };
  state.applyOnAliases(root, collectWrites);
}
/// Collect into `res` every use of an alias of `root` that is a read, either
/// directly or through a dependent value in the SSA use-def chain.
static void getAliasingReads(DenseSet<OpOperand *> &res, Value root,
                             const OneShotAnalysisState &state) {
  // Decide whether `use` counts as a read of the aliased buffer.
  auto countsAsRead = [&](OpOperand &use) {
    // Read of a value that aliases root.
    if (state.bufferizesToMemoryRead(use))
      return true;

    // Read of a dependent value in the SSA use-def chain. E.g.:
    //
    // %0 = ...
    // %1 = tensor.extract_slice %0 {not_analyzed_yet}
    // "read"(%1)
    //
    // In the above example, getAliasingReads(%0) includes the first OpOperand
    // of the tensor.extract_slice op. The extract_slice itself does not read
    // but its aliasing result is eventually fed into an op that does.
    //
    // Note: This is considered a "read" only if the use does not bufferize to
    // a memory write. (We already ruled out memory reads. In case of a memory
    // write, the buffer would be entirely overwritten; in the above example
    // there would then be no flow of data from the extract_slice operand to
    // its result's uses.)
    if (state.bufferizesToMemoryWrite(use))
      return false;
    AliasingValueList aliases = state.getAliasingValues(use);
    return llvm::any_of(aliases, [&](AliasingValue a) {
      return state.isValueRead(a.value);
    });
  };

  state.applyOnAliases(root, [&](Value alias) {
    for (OpOperand &use : alias.getUses())
      if (countsAsRead(use))
        res.insert(&use);
  });
}
/// Return true if bufferizing `operand` inplace would introduce a
/// read-after-write (RaW) conflict. A read R and a write W on the same alias
/// set conflict if bufferizing W inplace makes R observe a value other than
/// the one obtained by following R's origin along SSA use-def chains. Such a
/// conflict can only arise from a new alias and/or from an inplace
/// bufferization decision.
///
/// Example:
/// %0 = tensor.extract_slice %t[...][...][1, 1] {inplace?}
/// %1 = vector.transfer_write %v1, %t {inplace} : vector<5xf32>, tensor<?xf32>
/// %e = tensor.extract_slice %1
/// %2 = vector.transfer_write %v2, %0 {inplace} : vector<6xf32>, tensor<?xf32>
/// %3 = vector.transfer_read %e, %cst : tensor<?xf32>, vector<7xf32>
///
/// Both TransferWriteOps were already decided to bufferize inplace. Making the
/// ExtractSliceOp inplace as well would be a conflict:
/// * Following SSA use-def chains, the TransferReadOp must read the result of
///   %1.
/// * With the additional alias {%0, %t}, the second TransferWriteOp would
///   overwrite the result of the first one, so the TransferReadOp would no
///   longer read the result of %1.
///
/// If `checkConsistencyOnly` is set, `operand` itself is not added as a write.
/// Only the reads and inplace writes that already exist among the aliases of
/// `operand` and of its aliasing values are checked against each other. A
/// conflict in that mode points to a problem with the bufferization decisions
/// made so far.
static bool wouldCreateReadAfterWriteInterference(
    OpOperand &operand, const DominanceInfo &domInfo,
    OneShotAnalysisState &state, bool checkConsistencyOnly = false) {
  DenseSet<OpOperand *> usesRead, usesWrite;

  // Gather the reads and inplace writes on all aliases of one value.
  auto collectUsesOf = [&](Value v) {
    getAliasingReads(usesRead, v, state);
    getAliasingInplaceWrites(usesWrite, v, state);
  };

  // Start with the operand's own value, then continue with every value that
  // the operand aliases with.
  collectUsesOf(operand.get());
  for (AliasingValue alias : state.getAliasingValues(operand))
    collectUsesOf(alias.value);

  // Outside of consistency-only mode, pretend that `operand` was already
  // bufferized inplace: if it writes, it is a candidate conflicting write.
  const bool treatOperandAsWrite =
      !checkConsistencyOnly && state.bufferizesToMemoryWrite(operand);
  if (treatOperandAsWrite)
    usesWrite.insert(&operand);

  return hasReadAfterWriteInterference(usesRead, usesWrite, domInfo, state);
}
/// Attach a unit attribute describing a non-writability conflict on `value`.
/// The attribute name is "W_<n>[NOT-WRITABLE: result <i>]" for an OpResult
/// (set on the defining op) and "W_<n>[NOT-WRITABLE: bbArg <i>]" for a block
/// argument (set on the parent op of the owning block). <n> comes from a
/// function-local static counter that is incremented on every call.
static void annotateNonWritableTensor(Value value) {
  static int64_t counter = 0;
  OpBuilder b(value.getContext());
  const std::string prefix =
      "W_" + std::to_string(counter++) + "[NOT-WRITABLE: ";

  if (auto opResult = dyn_cast<OpResult>(value)) {
    const std::string attrName =
        prefix + "result " + std::to_string(opResult.getResultNumber()) + "]";
    opResult.getDefiningOp()->setAttr(attrName, b.getUnitAttr());
    return;
  }

  // Not an OpResult, so `value` must be a block argument.
  auto bbArg = cast<BlockArgument>(value);
  const std::string attrName =
      prefix + "bbArg " + std::to_string(bbArg.getArgNumber()) + "]";
  bbArg.getOwner()->getParentOp()->setAttr(attrName, b.getUnitAttr());
}
/// Return true if bufferizing `operand` inplace would result in a write to a
/// buffer that is not writable.
///
/// A write exists if `operand` itself bufferizes to a memory write (this is
/// ignored when `checkConsistencyOnly` is set) or if any alias of `operand` or
/// of its aliasing values already has an inplace write. If there is a write,
/// the function reports a conflict when at least one of those aliases is not
/// writable according to `state.isWritable`. When `printConflicts` is enabled,
/// every such non-writable alias is annotated in the IR.
static bool
wouldCreateWriteToNonWritableBuffer(OpOperand &operand,
                                    OneShotAnalysisState &state,
                                    bool checkConsistencyOnly = false) {
  const bool operandItselfWrites =
      !checkConsistencyOnly && state.bufferizesToMemoryWrite(operand);

  // Only evaluated if the operand itself does not count as a write: look for
  // inplace writes among all aliases of the operand and its aliasing values.
  auto hasAliasingInplaceWrite = [&]() {
    DenseSet<OpOperand *> usesWrite;
    getAliasingInplaceWrites(usesWrite, operand.get(), state);
    for (AliasingValue alias : state.getAliasingValues(operand))
      getAliasingInplaceWrites(usesWrite, alias.value, state);
    return !usesWrite.empty();
  };

  // Without any write, there cannot be a write to a read-only buffer.
  if (!operandItselfWrites && !hasAliasingInplaceWrite())
    return false;

  // There is a write. Check whether any alias is a read-only tensor.
  bool sawReadOnlyAlias = false;
  auto checkReadOnly = [&](Value v) {
    if (state.isWritable(v))
      return;
    sawReadOnlyAlias = true;
    if (state.getOptions().printConflicts)
      annotateNonWritableTensor(v);
  };
  state.applyOnAliases(operand.get(), checkReadOnly);
  for (AliasingValue alias : state.getAliasingValues(operand))
    state.applyOnAliases(alias.value, checkReadOnly);

  if (!sawReadOnlyAlias)
    return false;
  LLVM_DEBUG(llvm::dbgs() << "=> NOT WRITABLE\n");
  return true;
}
//===----------------------------------------------------------------------===//
// Bufferization analyses.
//===----------------------------------------------------------------------===//
/// Find the values that define the contents of the given value. The result of
/// `findDefinitions(value)` is computed on the first query for `value` and
/// stored in `cachedDefinitions`; later queries return the stored set.
///
/// The returned reference points at an entry of `cachedDefinitions`, so it
/// should not be held across operations that modify the cache (e.g.,
/// `resetCache()`, which clears it).
const llvm::SetVector<Value> &
OneShotAnalysisState::findDefinitionsCached(Value value) {
  // Look up once. On a hit, no further hashing is needed.
  auto it = cachedDefinitions.find(value);
  if (it == cachedDefinitions.end()) {
    // Cache miss: the definitions are computed before the entry is created
    // (function arguments are evaluated before the call), then moved into the
    // new map entry.
    it = cachedDefinitions.try_emplace(value, findDefinitions(value)).first;
  }
  return it->second;
}
/// Reset cached analysis data: forwards to `AnalysisState::resetCache()` and
/// then drops all entries memoized by `findDefinitionsCached`.
void OneShotAnalysisState::resetCache() {
AnalysisState::resetCache();
// Forget all cached definition sets; they are recomputed on the next query.
cachedDefinitions.clear();
}
/// Decide whether `operand` bufferizes in-place or out-of-place and record the
/// decision in `state`. The operand is marked out-of-place if an inplace
/// bufferization would write to a non-writable buffer or would create a
/// read-after-write conflict; otherwise it is marked in-place. Always returns
/// success.
static LogicalResult
bufferizableInPlaceAnalysisImpl(OpOperand &operand, OneShotAnalysisState &state,
                                const DominanceInfo &domInfo) {
  LLVM_DEBUG(
      llvm::dbgs() << "//===-------------------------------------------===//\n"
                   << "Analyzing operand #" << operand.getOperandNumber()
                   << " of " << *operand.getOwner() << "\n");

  // The writability check runs first; the RaW check is skipped if it already
  // found a problem.
  if (wouldCreateWriteToNonWritableBuffer(operand, state) ||
      wouldCreateReadAfterWriteInterference(operand, domInfo, state)) {
    state.bufferizeOutOfPlace(operand);
  } else {
    state.bufferizeInPlace(operand);
  }

  LLVM_DEBUG(llvm::dbgs()
             << "//===-------------------------------------------===//\n");
  return success();
}
/// Run the in-place analysis on every tensor-typed OpOperand of `op`, in
/// operand order. Stops and returns failure as soon as the analysis of one
/// operand fails.
LogicalResult
OneShotAnalysisState::analyzeSingleOp(Operation *op,
                                      const DominanceInfo &domInfo) {
  for (OpOperand &opOperand : op->getOpOperands()) {
    // Non-tensor operands are not subject to bufferization decisions.
    if (!isa<TensorType>(opOperand.get().getType()))
      continue;
    if (failed(bufferizableInPlaceAnalysisImpl(opOperand, *this, domInfo)))
      return failure();
  }
  return success();
}
/// For each op in `ops` that implements BufferizableOpInterface (as reported
/// by the options), inspect every tensor OpResult and merge its equivalence
/// class with that of an aliasing OpOperand where this is known to be safe.
static void equivalenceAnalysis(SmallVector<Operation *> &ops,
                                OneShotAnalysisState &state) {
  for (Operation *op : ops) {
    auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op);
    if (!bufferizableOp)
      continue;

    for (OpResult opResult : op->getOpResults()) {
      if (!isa<TensorType>(opResult.getType()))
        continue;

      AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult);
      // Without aliasing OpOperands, there is nothing to merge.
      if (aliases.getNumAliases() == 0)
        continue;

      Value firstOperand = aliases.begin()->opOperand->get();
      bool allEquivalent = true;
      for (AliasingOpOperand alias : aliases) {
        const bool equivalentRelation =
            alias.relation == BufferRelation::Equivalent;
        const bool inPlace = state.isInPlace(*alias.opOperand);
        Value operand = alias.opOperand->get();

        if (equivalentRelation && inPlace && alias.isDefinite) {
          // A definite, equivalent, in-place alias: merge the equivalence
          // sets. There can be at most one definite alias, so the search ends
          // here. The "maybe" alias handling below must not run in this case.
          state.unionEquivalenceClasses(opResult, operand);
          allEquivalent = false;
          break;
        }

        // A "maybe" alias only keeps `allEquivalent` alive if it is an
        // equivalent, in-place alias whose operand is equivalent to the first
        // aliasing operand.
        const bool sameBufferAsFirst =
            state.areEquivalentBufferizedValues(operand, firstOperand);
        if (!(equivalentRelation && inPlace && sameBufferAsFirst))
          allEquivalent = false;
      }

      // If all "maybe" aliases are equivalent and the OpResult is not a new
      // allocation, the OpResult is a definite, equivalent alias. E.g.:
      //
      // aliasingOpOperands(%r) = {(%t0, EQUIV, MAYBE), (%t1, EQUIV, MAYBE)}
      // aliasingValues(%t0) = {(%r, EQUIV, MAYBE)}
      // aliasingValues(%t1) = {(%r, EQUIV, MAYBE)}
      // %r = arith.select %c, %t0, %t1 : tensor<?xf32>
      //
      // If %t0 and %t1 are equivalent, the equivalence classes of %r, %t0 and
      // %t1 can safely be merged.
      if (allEquivalent && !bufferizableOp.bufferizesToAllocation(opResult))
        state.unionEquivalenceClasses(opResult, firstOperand);
    }
  }
}
/// Run the equivalence analysis on `op` and all ops nested in it. Only ops
/// with at least one tensor result are considered. Ops are collected in
/// post-order, i.e., nested ops come before their enclosing ops.
static void equivalenceAnalysis(Operation *op, OneShotAnalysisState &state) {
  SmallVector<Operation *> opsWithTensorResults;
  op->walk<WalkOrder::PostOrder>([&](Operation *nestedOp) {
    // Ops without tensor results have no buffers to reason about.
    if (llvm::any_of(nestedOp->getResultTypes(), isaTensor))
      opsWithTensorResults.push_back(nestedOp);
  });
  equivalenceAnalysis(opsWithTensorResults, state);
}
/// "Bottom-up from terminators" heuristic: compute the order in which the ops
/// nested in `op` are analyzed. Region branch terminators and the ops reached
/// by following tensor operands backwards (through aliasing OpOperands) come
/// first, followed by all remaining ops with tensor semantics in reverse
/// post-order.
static SmallVector<Operation *>
bottomUpFromTerminatorsHeuristic(Operation *op,
                                 const OneShotAnalysisState &state) {
  SetVector<Operation *> traversedOps;

  // Visit all region terminators.
  op->walk<WalkOrder::PostOrder>([&](RegionBranchTerminatorOpInterface term) {
    if (!traversedOps.insert(term))
      return;

    SmallVector<OpResult> worklist;
    // Queue `v` if it is a tensor that is produced by an op. Block arguments
    // and non-tensor values end the traversal.
    auto enqueueTensorOpResult = [&](Value v) {
      if (!isa<TensorType>(v.getType()))
        return;
      if (auto producedValue = dyn_cast<OpResult>(v))
        worklist.push_back(producedValue);
    };

    // Seed the traversal with the yielded values, then walk the reverse SSA
    // use-def chain for as long as the defining op is inside the region of
    // the terminator.
    for (Value yielded : term->getOperands())
      enqueueTensorOpResult(yielded);

    while (!worklist.empty()) {
      OpResult current = worklist.pop_back_val();
      Operation *defOp = current.getDefiningOp();
      // Each op is expanded at most once.
      if (!traversedOps.insert(defOp))
        continue;
      // Do not follow the chain out of the terminator's region.
      if (!term->getParentRegion()->findAncestorOpInRegion(*defOp))
        continue;

      AliasingOpOperandList aliases = state.getAliasingOpOperands(current);
      for (auto alias : aliases)
        enqueueTensorOpResult(alias.opOperand->get());
    }
  });

  // The traversed ops are analyzed first. All other ops with tensor semantics
  // are appended afterwards.
  SmallVector<Operation *> result(traversedOps.begin(), traversedOps.end());
  op->walk<WalkOrder::PostOrder, ReverseIterator>([&](Operation *nestedOp) {
    if (traversedOps.contains(nestedOp))
      return;
    if (hasTensorSemantics(nestedOp))
      result.push_back(nestedOp);
  });
  return result;
}
/// Analyze `op` and its nested ops: order the ops according to the configured
/// analysis heuristic, run the in-place analysis on each of them, and finish
/// with the equivalence analysis. Returns failure if the analysis of any
/// single op fails (in which case the equivalence analysis is not run).
LogicalResult OneShotAnalysisState::analyzeOp(Operation *op,
                                              const DominanceInfo &domInfo) {
  using Heuristic = OneShotBufferizationOptions::AnalysisHeuristic;
  const Heuristic heuristic = getOptions().analysisHeuristic;

  SmallVector<Operation *> orderedOps;
  if (heuristic == Heuristic::BottomUpFromTerminators) {
    orderedOps = bottomUpFromTerminatorsHeuristic(op, *this);
  } else {
    // All other heuristics start from the default walk order and permute it.
    op->walk([&](Operation *nestedOp) {
      // No tensors => no buffers.
      if (hasTensorSemantics(nestedOp))
        orderedOps.push_back(nestedOp);
    });

    switch (heuristic) {
    case Heuristic::BottomUp: {
      // Default: Analyze ops in reverse for better interference analysis.
      std::reverse(orderedOps.begin(), orderedOps.end());
      break;
    }
    case Heuristic::TopDown: {
      // `orderedOps` is already in top-down order.
      break;
    }
    case Heuristic::Fuzzer: {
      assert(getOptions().analysisFuzzerSeed &&
             "expected that fuzzer seed it set");
      // For testing purposes only: analyze the ops in a random order. The
      // bufferization quality is likely worse, but no assertion should be
      // triggered anywhere regardless of the order.
      std::mt19937 g(getOptions().analysisFuzzerSeed);
      llvm::shuffle(orderedOps.begin(), orderedOps.end(), g);
      break;
    }
    default: {
      llvm_unreachable("unsupported heuristic");
    }
    }
  }

  // Run the in-place analysis in the computed order.
  for (Operation *orderedOp : orderedOps) {
    if (failed(analyzeSingleOp(orderedOp, domInfo)))
      return failure();
  }

  equivalenceAnalysis(op, *this);
  return success();
}
/// Perform various checks on the input IR to see if it contains IR constructs
/// that are unsupported by One-Shot Bufferize.
///
/// Only ops that implement BufferizableOpInterface and pass
/// `options.isOpAllowed` are checked. The following conditions emit an error
/// on the offending op and make this function return failure:
/// * The op does not support unstructured control flow, but one of its
///   regions has more than one block.
/// * The op is a ToTensorOp without `restrict` that has uses.
/// * A tensor OpOperand is already involved in a RaW conflict before any
///   bufferization decision was made.
/// * A tensor OpOperand that is already marked in-place would write to a
///   read-only buffer.
static LogicalResult
checkPreBufferizationAssumptions(Operation *op, const DominanceInfo &domInfo,
OneShotAnalysisState &state) {
const BufferizationOptions &options = state.getOptions();
// Note: This walk cannot be combined with the one below because interface
// methods of invalid/unsupported ops may be called during the second walk.
// (On ops different from `op`.)
WalkResult walkResult = op->walk([&](BufferizableOpInterface op) {
// Skip ops that are not in the filter.
if (!options.isOpAllowed(op.getOperation()))
return WalkResult::advance();
// Check for unsupported unstructured control flow.
if (!op.supportsUnstructuredControlFlow()) {
for (Region &r : op->getRegions()) {
// A region with more than one block means unstructured control flow.
if (r.getBlocks().size() > 1) {
op->emitOpError("op or BufferizableOpInterface implementation does "
"not support unstructured control flow, but at least "
"one region has multiple blocks");
return WalkResult::interrupt();
}
}
}
return WalkResult::advance();
});
// The first walk found unsupported control flow: do not run the second walk.
if (walkResult.wasInterrupted())
return failure();
walkResult = op->walk([&](BufferizableOpInterface op) {
// Skip ops that are not in the filter.
if (!options.isOpAllowed(op.getOperation()))
return WalkResult::advance();
// Input IR may not contain any ToTensorOps without the "restrict"
// attribute. Such tensors may alias any other tensor, which is currently
// not handled in the analysis.
if (auto toTensorOp = dyn_cast<ToTensorOp>(op.getOperation())) {
// ToTensorOps without uses are tolerated even without "restrict".
if (!toTensorOp.getRestrict() && !toTensorOp->getUses().empty()) {
op->emitOpError("to_tensor ops without `restrict` are not supported by "
"One-Shot Analysis");
return WalkResult::interrupt();
}
}
// Consistency checks on the bufferization decisions that already exist
// before the analysis runs. Only tensor operands are relevant.
for (OpOperand &opOperand : op->getOpOperands()) {
if (isa<TensorType>(opOperand.get().getType())) {
if (wouldCreateReadAfterWriteInterference(
opOperand, domInfo, state,
/*checkConsistencyOnly=*/true)) {
// This error can happen if certain "mustBufferizeInPlace" interface
// methods are implemented incorrectly, such that the IR already has
// a RaW conflict before making any bufferization decisions. It can
// also happen if the bufferization.materialize_in_destination is used
// in such a way that a RaW conflict is not avoidable.
op->emitOpError("not bufferizable under the given constraints: "
"cannot avoid RaW conflict");
return WalkResult::interrupt();
}
// The writability check only applies to operands that are already
// decided to bufferize in-place.
if (state.isInPlace(opOperand) &&
wouldCreateWriteToNonWritableBuffer(
opOperand, state, /*checkConsistencyOnly=*/true)) {
op->emitOpError("not bufferizable under the given constraints: would "
"write to read-only buffer");
return WalkResult::interrupt();
}
}
}
return WalkResult::advance();
});
// Success only if the second walk ran to completion without an error.
return success(!walkResult.wasInterrupted());
}
/// Annotate the IR with the in-place decisions of the analysis: for `op` and
/// every op nested in it, each tensor OpOperand is marked via
/// `setInPlaceOpOperand` with the decision stored in `state`. For
/// testing/debugging only.
static void
annotateOpsWithBufferizationMarkers(Operation *op,
                                    const OneShotAnalysisState &state) {
  // Add __inplace_operands_attr__.
  op->walk([&](Operation *nestedOp) {
    for (OpOperand &opOperand : nestedOp->getOpOperands()) {
      if (!isa<TensorType>(opOperand.get().getType()))
        continue;
      const bool inPlace = state.isInPlace(opOperand);
      setInPlaceOpOperand(opOperand, inPlace);
    }
  });
}
/// Annotate `op` and every op nested in it with the alias sets of its tensor
/// values:
/// * `kOpResultAliasSetAttrName`: one array of alias names per tensor
///   OpResult. Only set if the op has at least one tensor result.
/// * `kBbArgAliasSetAttrName`: nested arrays (region -> block -> tensor block
///   argument -> alias names). Only set if at least one block argument of the
///   op's regions is a tensor.
/// Alias names are the SSA names produced by `Value::printAsOperand`.
static void annotateOpsWithAliasSets(Operation *op,
                                     const OneShotAnalysisState &state) {
  AsmState asmState(op);
  Builder b(op->getContext());

  // Build an array attribute with the printed SSA names of all aliases of `v`.
  auto buildAliasesArray = [&](Value v) {
    SmallVector<Attribute> aliasNames;
    state.applyOnAliases(v, [&](Value alias) {
      std::string printed;
      llvm::raw_string_ostream os(printed);
      alias.printAsOperand(os, asmState);
      aliasNames.push_back(b.getStringAttr(os.str()));
    });
    return b.getArrayAttr(aliasNames);
  };

  op->walk([&](Operation *nestedOp) {
    // Alias sets of the tensor OpResults.
    SmallVector<Attribute> resultAliasSets;
    for (OpResult result : nestedOp->getOpResults())
      if (llvm::isa<TensorType>(result.getType()))
        resultAliasSets.push_back(buildAliasesArray(result));
    if (!resultAliasSets.empty())
      nestedOp->setAttr(kOpResultAliasSetAttrName,
                        b.getArrayAttr(resultAliasSets));

    // Alias sets of the tensor BlockArguments. Every region and every block
    // contributes an (possibly empty) array so that positions stay aligned.
    SmallVector<Attribute> perRegionSets;
    bool sawTensorBbArg = false;
    for (Region &region : nestedOp->getRegions()) {
      SmallVector<Attribute> perBlockSets;
      for (Block &block : region.getBlocks()) {
        SmallVector<Attribute> perArgSets;
        for (BlockArgument bbArg : block.getArguments()) {
          if (!llvm::isa<TensorType>(bbArg.getType()))
            continue;
          perArgSets.push_back(buildAliasesArray(bbArg));
          sawTensorBbArg = true;
        }
        perBlockSets.push_back(b.getArrayAttr(perArgSets));
      }
      perRegionSets.push_back(b.getArrayAttr(perBlockSets));
    }
    if (sawTensorBbArg)
      nestedOp->setAttr(kBbArgAliasSetAttrName, b.getArrayAttr(perRegionSets));
  });
}
/// Run One-Shot Analysis on `op`:
/// 1. Check the pre-bufferization assumptions; fail if they are violated.
/// 2. Run the in-place/equivalence analysis; fail if it fails.
/// 3. If `statistics` is non-null, record the in-place and out-of-place
///    tensor counters of `state`.
/// 4. Gather undefined tensor uses and let every bufferizable op verify the
///    analysis result.
/// 5. Annotate the IR if `testAnalysisOnly` and/or `dumpAliasSets` is set.
/// Returns failure if step 1 or 2 failed or if any op rejected the analysis
/// in step 4. In the latter case, the annotations of step 5 are still added.
LogicalResult bufferization::analyzeOp(Operation *op,
                                       OneShotAnalysisState &state,
                                       BufferizationStatistics *statistics) {
  DominanceInfo domInfo(op);
  const OneShotBufferizationOptions &options = state.getOptions();

  if (failed(checkPreBufferizationAssumptions(op, domInfo, state)))
    return failure();

  // Nothing else to do if the analysis itself fails.
  if (failed(state.analyzeOp(op, domInfo)))
    return failure();

  if (statistics) {
    statistics->numTensorInPlace = state.getStatNumTensorInPlace();
    statistics->numTensorOutOfPlace = state.getStatNumTensorOutOfPlace();
  }

  // Gather some extra analysis data.
  state.gatherUndefinedTensorUses(op);

  // Analysis verification: Now that alias/equivalence sets are set up, every
  // op may check its expected invariants/limitations and reject the analysis
  // result. All ops are visited, even after a failure was found.
  bool verificationFailed = false;
  op->walk([&](Operation *nestedOp) {
    BufferizableOpInterface bufferizableOp =
        options.dynCastBufferizableOp(nestedOp);
    if (!bufferizableOp)
      return;
    if (failed(bufferizableOp.verifyAnalysis(state)))
      verificationFailed = true;
  });

  // Annotate operations if only the analysis result should be reported.
  if (options.testAnalysisOnly)
    annotateOpsWithBufferizationMarkers(op, state);
  if (options.dumpAliasSets)
    annotateOpsWithAliasSets(op, state);

  return success(!verificationFailed);
}
/// Analyze `op` (and its nested ops) and bufferize it.
///
/// Unless `copyBeforeWrite` is set, tensor copies are first inserted via
/// `insertTensorCopies`; failure of that step is propagated. If
/// `testAnalysisOnly` is set, the function stops after that step and returns
/// success without bufferizing. Otherwise the result of `bufferizeOp` is
/// returned.
LogicalResult
bufferization::runOneShotBufferize(Operation *op,
                                   const OneShotBufferizationOptions &options,
                                   BufferizationStatistics *statistics) {
  // copy-before-write deactivates the analysis. It cannot be used together with
  // test-analysis-only.
  assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
         "invalid combination of bufferization flags");

  // With copy-before-write, a buffer is copied before each write and no
  // analysis is needed. Otherwise (the default), One-Shot Analysis decides
  // where buffer copies (on the tensor level) are required, which is much
  // more efficient.
  if (!options.copyBeforeWrite) {
    if (failed(insertTensorCopies(op, options, statistics)))
      return failure();
    // If test-analysis-only is set, the IR was annotated with RaW conflict
    // markers (attributes) during One-Shot Analysis. Stop here.
    if (options.testAnalysisOnly)
      return success();
  }

  // Bufferize the op and its nested ops. If options.copyBeforeWrite is set,
  // a new buffer copy is allocated every time a buffer is written to.
  return bufferizeOp(op, options, statistics);
}