//===- SparseAssembler.cpp - adds wrapper methods around sparse types -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "Utils/CodegenUtils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
using namespace sparse_tensor;

//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//

// Convert a type range to a new type range, with sparse tensors externalized.
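// As an illustrative sketch (the #CSR encoding and shapes below are
// hypothetical; the exact field shapes depend on the encoding), a CSR-style
// sparse type
//
//   tensor<64x64xf64, #CSR>
//
// expands into the external tensor types of its positions, coordinates, and
// values arrays:
//
//   tensor<?xindex>, tensor<?xindex>, tensor<?xf64>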
static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
SmallVectorImpl<Type> *extraTypes, bool directOut) {
for (auto type : types) {
// All "dense" data passes through unmodified.
if (!getSparseTensorEncoding(type)) {
convTypes.push_back(type);
continue;
}
// Convert the external representations of the pos/crd/val arrays.
const SparseTensorType stt(cast<RankedTensorType>(type));
foreachFieldAndTypeInSparseTensor(
stt, [&convTypes, extraTypes, directOut](Type t, FieldIndex,
SparseTensorFieldKind kind,
Level, LevelType) {
if (kind == SparseTensorFieldKind::PosMemRef ||
kind == SparseTensorFieldKind::CrdMemRef ||
kind == SparseTensorFieldKind::ValMemRef) {
auto rtp = cast<ShapedType>(t);
if (!directOut) {
rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
if (extraTypes)
extraTypes->push_back(rtp);
}
convTypes.push_back(rtp);
}
return true;
});
}
}

// Convert input and output values to [dis]assemble ops for sparse tensors.
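// As an illustrative sketch: on the input path, the external arrays are
// collected and combined into one sparse tensor with
//
//   %t = sparse_tensor.assemble ...
//
// while on the output path a sparse_tensor.disassemble is emitted (whose
// counter results are dropped), or, in direct-out mode, the underlying
// buffers are extracted directly with sparse_tensor.to_positions,
// sparse_tensor.to_coordinates, and sparse_tensor.to_values.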
static void convVals(OpBuilder &builder, Location loc, TypeRange types,
ValueRange fromVals, ValueRange extraVals,
SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn,
bool directOut) {
unsigned idx = 0;
for (auto type : types) {
// All "dense" data passes through unmodified.
if (!getSparseTensorEncoding(type)) {
toVals.push_back(fromVals[idx++]);
continue;
}
// Handle sparse data.
auto rtp = cast<RankedTensorType>(type);
const SparseTensorType stt(rtp);
SmallVector<Value> inputs;
SmallVector<Type> retTypes;
SmallVector<Type> cntTypes;
if (!isIn)
inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble
// Collect the external representations of the pos/crd/val arrays.
foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
SparseTensorFieldKind kind,
Level lv, LevelType) {
if (kind == SparseTensorFieldKind::PosMemRef ||
kind == SparseTensorFieldKind::CrdMemRef ||
kind == SparseTensorFieldKind::ValMemRef) {
if (isIn) {
inputs.push_back(fromVals[idx++]);
} else if (directOut) {
Value mem;
if (kind == SparseTensorFieldKind::PosMemRef)
mem = builder.create<sparse_tensor::ToPositionsOp>(loc, inputs[0],
lv);
else if (kind == SparseTensorFieldKind::CrdMemRef)
mem = builder.create<sparse_tensor::ToCoordinatesOp>(loc, inputs[0],
lv);
else
mem = builder.create<sparse_tensor::ToValuesOp>(loc, inputs[0]);
toVals.push_back(mem);
} else {
ShapedType rtp = cast<ShapedType>(t);
rtp = RankedTensorType::get(rtp.getShape(), rtp.getElementType());
inputs.push_back(extraVals[extra++]);
retTypes.push_back(rtp);
cntTypes.push_back(builder.getIndexType());
}
}
return true;
});
if (isIn) {
// Assemble multiple inputs into a single sparse tensor.
auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs);
toVals.push_back(a.getResult());
} else if (!directOut) {
      // Disassemble a single sparse input into multiple outputs. Note that
      // the returned types also include the counters, which are dropped.
unsigned len = retTypes.size();
retTypes.append(cntTypes);
auto d =
builder.create<sparse_tensor::DisassembleOp>(loc, retTypes, inputs);
for (unsigned i = 0; i < len; i++)
toVals.push_back(d.getResult(i));
}
}
}

//===----------------------------------------------------------------------===//
// Rewriting rules.
//===----------------------------------------------------------------------===//

namespace {
// A rewriting rule that converts public entry methods that use sparse tensors
// as input parameters and/or output return values into wrapper methods that
// [dis]assemble the individual tensors that constitute the actual storage used
// externally into MLIR sparse tensors before calling the original method.
//
// In particular, each sparse tensor input
//
// void foo(..., t, ...) { }
//
// makes the original foo() internal and adds the following wrapper method
//
// void foo(..., t1..tn, ...) {
// t = assemble t1..tn
// _internal_foo(..., t, ...)
// }
//
// and likewise, each output tensor
//
// ... T ... bar(...) { return ..., t, ...; }
//
// makes the original bar() internal and adds the following wrapper method
//
// ... T1..TN ... bar(..., t1'..tn') {
// ..., t, ... = _internal_bar(...)
// t1..tn = disassemble t, t1'..tn'
// return ..., t1..tn, ...
// }
//
// (with a direct-out variant without the disassemble).
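//
// As a concrete sketch (the type and #CSR encoding are illustrative only),
// a public entry method
//
//   func.func @foo(%t: tensor<64x64xf64, #CSR>)
//
// is renamed to @_internal_foo and made private, while the new public @foo
// accepts the external positions/coordinates/values arrays, assembles them
// into %t, and forwards the call.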
//
struct SparseFuncAssembler : public OpRewritePattern<func::FuncOp> {
  using OpRewritePattern::OpRewritePattern;

  SparseFuncAssembler(MLIRContext *context, bool dO)
      : OpRewritePattern(context), directOut(dO) {}

  LogicalResult matchAndRewrite(func::FuncOp funcOp,
                                PatternRewriter &rewriter) const override {
    // Only rewrite public entry methods.
    if (funcOp.isPrivate())
      return failure();

    // Translate sparse tensor types to external types.
    SmallVector<Type> inputTypes;
    SmallVector<Type> outputTypes;
    SmallVector<Type> extraTypes;
    convTypes(funcOp.getArgumentTypes(), inputTypes, nullptr, false);
    convTypes(funcOp.getResultTypes(), outputTypes, &extraTypes, directOut);

    // Only sparse inputs or outputs need a wrapper method.
    if (inputTypes.size() == funcOp.getArgumentTypes().size() &&
        outputTypes.size() == funcOp.getResultTypes().size())
      return failure();

    // Modify the original method into an internal, private method.
    auto orgName = funcOp.getName();
    std::string wrapper = llvm::formatv("_internal_{0}", orgName).str();
    funcOp.setName(wrapper);
    funcOp.setPrivate();

    // Start the new public wrapper method with the original name.
    Location loc = funcOp.getLoc();
    ModuleOp modOp = funcOp->getParentOfType<ModuleOp>();
    MLIRContext *context = modOp.getContext();
    OpBuilder moduleBuilder(modOp.getBodyRegion());
    unsigned extra = inputTypes.size();
    inputTypes.append(extraTypes);
    auto func = moduleBuilder.create<func::FuncOp>(
        loc, orgName, FunctionType::get(context, inputTypes, outputTypes));
    func.setPublic();

    // Construct the new wrapper method body.
    OpBuilder::InsertionGuard insertionGuard(rewriter);
    Block *body = func.addEntryBlock();
    rewriter.setInsertionPointToStart(body);

    // Convert inputs.
    SmallVector<Value> inputs;
    convVals(rewriter, loc, funcOp.getArgumentTypes(), body->getArguments(),
             ValueRange(), inputs, /*extra=*/0, /*isIn=*/true, directOut);

    // Call the original, now private method. A subsequent inlining pass can
    // determine whether cloning the method body in place is worthwhile.
    auto org = SymbolRefAttr::get(context, wrapper);
    auto call = rewriter.create<func::CallOp>(loc, funcOp.getResultTypes(), org,
                                              inputs);

    // Convert outputs and return.
    SmallVector<Value> outputs;
    convVals(rewriter, loc, funcOp.getResultTypes(), call.getResults(),
             body->getArguments(), outputs, extra, /*isIn=*/false, directOut);
    rewriter.create<func::ReturnOp>(loc, outputs);

    // Finally, migrate a potential C-interface property.
    if (funcOp->getAttrOfType<UnitAttr>(
            LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
      func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
                    UnitAttr::get(context));
      funcOp->removeAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName());
    }
    return success();
  }

private:
  const bool directOut;
};
} // namespace

//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//
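
// A minimal usage sketch (hypothetical driver code, not part of this file):
//
//   RewritePatternSet patterns(module.getContext());
//   populateSparseAssembler(patterns, /*directOut=*/false);
//   if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns))))
//     signalPassFailure();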
void mlir::populateSparseAssembler(RewritePatternSet &patterns,
bool directOut) {
patterns.add<SparseFuncAssembler>(patterns.getContext(), directOut);
}