blob: d1f91090a62d404457f7049fe0b5a91fb7506c39 [file] [log] [blame]
//===- Bufferize.cpp - Bufferization utilities ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/Bufferize.h"
#include "mlir/IR/Operation.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// BufferizeTypeConverter
//===----------------------------------------------------------------------===//
/// Builds the bufferization type converter.
///
/// Ranked tensors are converted to memrefs with the same shape and element
/// type, unranked tensors to unranked memrefs in memory space 0, and all other
/// types are left untouched. It also registers the materializations that
/// bridge the two representations: `TensorLoadOp` turns a memref back into a
/// ranked tensor (source materialization) and `TensorToMemrefOp` turns a tensor
/// into a memref (target materialization).
BufferizeTypeConverter::BufferizeTypeConverter() {
  // Identity fallback. TypeConverter tries the most recently registered
  // conversions first, so this one only applies when none of the more
  // specific conversions below matches.
  addConversion([](Type anyType) { return anyType; });

  // tensor<...xT> -> memref<...xT>
  addConversion([](RankedTensorType tensorType) -> Type {
    Type elementType = tensorType.getElementType();
    return MemRefType::get(tensorType.getShape(), elementType);
  });

  // tensor<*xT> -> memref<*xT> (memory space 0)
  addConversion([](UnrankedTensorType tensorType) -> Type {
    Type elementType = tensorType.getElementType();
    return UnrankedMemRefType::get(elementType, 0);
  });

  // memref -> ranked tensor
  addSourceMaterialization([](OpBuilder &b, RankedTensorType resultType,
                              ValueRange inputs, Location loc) -> Value {
    assert(inputs.size() == 1);
    Value memref = inputs[0];
    assert(memref.getType().isa<BaseMemRefType>());
    return b.create<TensorLoadOp>(loc, resultType, memref);
  });

  // tensor -> ranked memref
  addTargetMaterialization([](OpBuilder &b, MemRefType resultType,
                              ValueRange inputs, Location loc) -> Value {
    assert(inputs.size() == 1);
    Value tensor = inputs[0];
    assert(tensor.getType().isa<TensorType>());
    return b.create<TensorToMemrefOp>(loc, resultType, tensor);
  });
}
/// Decomposes `value` (of type `type`) with the registered decompose-value
/// callbacks and appends the resulting pieces to `results`.
///
/// Callbacks are tried in registration order and the first one that reports it
/// handled the value wins. If no callback handles it, `value` itself is
/// appended to `results` unchanged.
void BufferizeTypeConverter::tryDecomposeValue(
    OpBuilder &builder, Location loc, Type type, Value value,
    SmallVectorImpl<Value> &results) {
  for (auto &decompose : decomposeValueConversions) {
    if (decompose(builder, loc, type, value, results))
      return;
  }
  // Nothing claimed the value: forward it as-is.
  results.push_back(value);
}
/// Decomposes `type` with the registered decompose-type callbacks and appends
/// the resulting types to `types`.
///
/// Callbacks are tried in registration order and the first one that reports it
/// handled the type wins. If no callback handles it, `type` itself is appended
/// to `types` unchanged.
void BufferizeTypeConverter::tryDecomposeType(Type type,
                                              SmallVectorImpl<Type> &types) {
  for (auto &decompose : decomposeTypeConversions) {
    if (decompose(type, types))
      return;
  }
  // Nothing claimed the type: forward it as-is.
  types.push_back(type);
}
/// Decides how a function result of type `origin`, converted to `converted`,
/// is passed.
///
/// The registered result-conversion callbacks are queried in registration
/// order, and the first one that yields a kind determines the answer. When no
/// callback yields a kind, the result stays a function result
/// (`KeepAsFunctionResult`).
BufferizeTypeConverter::ResultConversionKind
BufferizeTypeConverter::getResultConversionKind(Type origin, Type converted) {
  for (auto &lookup : resultTypeConversions) {
    auto maybeKind = lookup(origin, converted);
    if (maybeKind)
      return maybeKind.getValue();
  }
  // No callback expressed a preference.
  return KeepAsFunctionResult;
}
/// Marks the ops created by BufferizeTypeConverter's materializations
/// (`TensorLoadOp` and `TensorToMemrefOp`) as legal on `target`.
void mlir::populateBufferizeMaterializationLegality(ConversionTarget &target) {
  // memref -> tensor materialization.
  target.addLegalOp<TensorLoadOp>();
  // tensor -> memref materialization.
  target.addLegalOp<TensorToMemrefOp>();
}
//===----------------------------------------------------------------------===//
// BufferizeFuncOpConverter
//===----------------------------------------------------------------------===//
/// Performs the actual function signature rewriting step.
///
/// Each argument type is first decomposed with the user-provided decompose
/// callbacks, and every resulting piece is converted with the type converter.
/// Result types are decomposed and converted the same way. Each converted
/// result piece either stays a function result or, when the result-conversion
/// callbacks return `AppendToArgumentsList`, is appended to the signature as a
/// new trailing argument.
///
/// Returns failure, without touching the IR, if any argument or result type
/// cannot be converted.
LogicalResult BufferizeFuncOpConverter::matchAndRewrite(
    mlir::FuncOp funcOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  auto funcType = funcOp.getType();

  // Convert function arguments using the provided TypeConverter.
  TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
  for (auto argType : llvm::enumerate(funcType.getInputs())) {
    SmallVector<Type, 2> decomposedTypes, convertedTypes;
    converter.tryDecomposeType(argType.value(), decomposedTypes);
    // A partially converted (or empty) list would silently drop or mis-map the
    // original block argument, so bail out before modifying any IR.
    if (failed(converter.convertTypes(decomposedTypes, convertedTypes)))
      return failure();
    conversion.addInputs(argType.index(), convertedTypes);
  }

  // Convert the result types of the function.
  SmallVector<Type, 2> newResultTypes;
  newResultTypes.reserve(funcOp.getNumResults());
  for (Type resultType : funcType.getResults()) {
    SmallVector<Type, 2> originTypes;
    converter.tryDecomposeType(resultType, originTypes);
    for (auto origin : originTypes) {
      Type converted = converter.convertType(origin);
      // A null type signals a failed conversion. It must not reach the
      // result-kind callbacks or end up in the new function type.
      if (!converted)
        return failure();
      auto kind = converter.getResultConversionKind(origin, converted);
      if (kind == BufferizeTypeConverter::AppendToArgumentsList) {
        // Not a remapping of an existing input: appended as a new argument.
        conversion.addInputs(converted);
      } else {
        assert(kind == BufferizeTypeConverter::KeepAsFunctionResult);
        newResultTypes.push_back(converted);
      }
    }
  }

  if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), converter,
                                         &conversion)))
    return failure();

  // Update the signature of the function.
  rewriter.updateRootInPlace(funcOp, [&] {
    funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
                                            newResultTypes));
  });
  return success();
}
//===----------------------------------------------------------------------===//
// BufferizeCallOpConverter
//===----------------------------------------------------------------------===//
namespace {
/// Maps one result of the original `CallOp` to the ordered list of values that
/// replace it.
///
/// Some of those values already exist when the mapping is built. Others are
/// results of the new `CallOp`, which has not been constructed yet, so only
/// their result indices are recorded. Once the new op exists,
/// getMappingValues() substitutes the actual values for those indices. The
/// order in which entries are added is preserved in the output.
class CallOpResultMapping {
public:
  /// Add an available value to the mapping.
  void addMapping(Value value) { toValuesMapping.push_back({order++, value}); }

  /// Add the index of a not-yet-available result value to the mapping.
  void addMapping(unsigned index) {
    toIndicesMapping.push_back({order++, index});
  }

  /// Appends the mapped values to `values`, in the order they were added.
  /// Every recorded index `i` is replaced with `valuesToReplaceIndices[i]`.
  void getMappingValues(ValueRange valuesToReplaceIndices,
                        SmallVectorImpl<Value> &values) const {
    // Start with the values that were available up front.
    SmallVector<std::pair<unsigned, Value>, 2> res(toValuesMapping.begin(),
                                                   toValuesMapping.end());
    // Resolve the recorded indices to the actual values.
    for (const std::pair<unsigned, unsigned> &entry : toIndicesMapping) {
      assert(entry.second < valuesToReplaceIndices.size() &&
             "The value index is out of range.");
      res.push_back({entry.first, valuesToReplaceIndices[entry.second]});
    }
    // Restore insertion order. Order numbers are unique, so sorting on them
    // alone is deterministic.
    llvm::sort(res, [](const std::pair<unsigned, Value> &v1,
                       const std::pair<unsigned, Value> &v2) {
      return v1.first < v2.first;
    });
    for (const std::pair<unsigned, Value> &entry : res)
      values.push_back(entry.second);
  }

private:
  /// Insertion counter. The next entry added receives this order number.
  /// Unsigned to match the order field of the stored pairs.
  unsigned order = 0;
  /// Already-available mapping values, tagged with their insertion order.
  SmallVector<std::pair<unsigned, Value>, 2> toValuesMapping;
  /// Indices of not-yet-available result values, tagged with their insertion
  /// order.
  SmallVector<std::pair<unsigned, unsigned>, 2> toIndicesMapping;
};
} // namespace
/// Performs the actual rewriting step.
///
/// Replaces `callOp` with a new `CallOp`. Its operands are the (already
/// converted) `operands`, each decomposed with the user-provided decompose
/// callbacks. Every original result type is decomposed and converted, and each
/// converted piece is handled in one of two ways:
///   * it stays a result of the new call, or
///   * for results of kind `AppendToArgumentsList`, a buffer is allocated with
///     `AllocOp` before the call and passed as an additional trailing operand.
/// If an original result maps to several values, its uses are replaced by a
/// single value packed with the converter's argument materialization.
///
/// All IR is created through `rewriter` rather than through separate
/// OpBuilders, so the conversion driver tracks the new operations and can roll
/// them back when this pattern fails.
LogicalResult BufferizeCallOpConverter::matchAndRewrite(
    CallOp callOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  Location loc = callOp.getLoc();
  SmallVector<Value, 2> newOperands;

  // TODO: if the CallOp references a FuncOp that only has a declaration (e.g.
  // to an externally defined symbol like an external library calls), only
  // convert if some special attribute is set.
  // This will allow more control of interop across ABI boundaries.

  // Create the operands list of the new `CallOp`. It unpacks the decomposable
  // values if a decompose callback function has been provided by the user.
  for (auto operand : operands) {
    SmallVector<Value, 2> values;
    converter.tryDecomposeValue(rewriter, loc, operand.getType(), operand,
                                values);
    newOperands.append(values.begin(), values.end());
  }

  // Create the new result types for the new `CallOp` and a mapping from the old
  // result to new value(s).
  SmallVector<Type, 2> newResultTypes;
  SmallVector<CallOpResultMapping, 4> mappings;
  mappings.resize(callOp.getNumResults());
  for (auto result : llvm::enumerate(callOp.getResults())) {
    SmallVector<Type, 2> originTypes;
    converter.tryDecomposeType(result.value().getType(), originTypes);
    auto &resultMapping = mappings[result.index()];
    for (Type origin : originTypes) {
      Type converted = converter.convertType(origin);
      // A null type signals a failed conversion. It must not reach the
      // result-kind callbacks, `dyn_cast`, or the new call's result types.
      if (!converted)
        return failure();
      auto kind = converter.getResultConversionKind(origin, converted);
      if (kind == BufferizeTypeConverter::KeepAsFunctionResult) {
        newResultTypes.push_back(converted);
        // The result value is not yet available. Its index is kept and it is
        // replaced with the actual value of the new `CallOp` later.
        resultMapping.addMapping(newResultTypes.size() - 1);
      } else {
        assert(kind == BufferizeTypeConverter::AppendToArgumentsList);
        MemRefType memref = converted.dyn_cast<MemRefType>();
        if (!memref)
          return callOp.emitError("Cannot allocate for a non-Memref type");
        Value alloc = rewriter.create<AllocOp>(loc, memref);
        newOperands.push_back(alloc);
        resultMapping.addMapping(alloc);
      }
    }
  }

  CallOp newCallOp = rewriter.create<CallOp>(loc, callOp.getCallee(),
                                             newResultTypes, newOperands);

  // Build a replacing value for each result to replace its uses. If a result
  // has multiple mapping values, it needs to be packed to a single value. The
  // rewriter still inserts right before `callOp`, i.e. after `newCallOp`, so
  // packing ops can use the new call's results.
  SmallVector<Value, 2> replacedValues;
  replacedValues.reserve(callOp.getNumResults());
  for (unsigned i = 0, e = callOp.getNumResults(); i < e; ++i) {
    SmallVector<Value, 2> valuesToPack;
    mappings[i].getMappingValues(newCallOp.getResults(), valuesToPack);
    if (valuesToPack.empty()) {
      // No replacement is required.
      replacedValues.push_back(nullptr);
    } else if (valuesToPack.size() == 1) {
      replacedValues.push_back(valuesToPack.front());
    } else {
      // Values need to be packed using callback function. The same callback
      // that is used for materializeArgumentConversion is used for packing.
      Value packed = converter.materializeArgumentConversion(
          rewriter, loc, callOp.getType(i), valuesToPack);
      replacedValues.push_back(packed);
    }
  }
  rewriter.replaceOp(callOp, replacedValues);
  return success();
}