//===- CodegenUtils.cpp - Utilities for generating MLIR -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "CodegenUtils.h"
#include "SparseTensorDescriptor.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include <optional>

using namespace mlir;
using namespace mlir::sparse_tensor;

//===----------------------------------------------------------------------===//
// ExecutionEngine/SparseTensorUtils helper functions.
//===----------------------------------------------------------------------===//

OverheadType mlir::sparse_tensor::overheadTypeEncoding(unsigned width) {
switch (width) {
case 64:
return OverheadType::kU64;
case 32:
return OverheadType::kU32;
case 16:
return OverheadType::kU16;
case 8:
return OverheadType::kU8;
case 0:
return OverheadType::kIndex;
}
llvm_unreachable("Unsupported overhead bitwidth");
}
OverheadType mlir::sparse_tensor::overheadTypeEncoding(Type tp) {
if (tp.isIndex())
return OverheadType::kIndex;
if (auto intTp = dyn_cast<IntegerType>(tp))
return overheadTypeEncoding(intTp.getWidth());
llvm_unreachable("Unknown overhead type");
}
Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) {
switch (ot) {
case OverheadType::kIndex:
return builder.getIndexType();
case OverheadType::kU64:
return builder.getIntegerType(64);
case OverheadType::kU32:
return builder.getIntegerType(32);
case OverheadType::kU16:
return builder.getIntegerType(16);
case OverheadType::kU8:
return builder.getIntegerType(8);
}
llvm_unreachable("Unknown OverheadType");
}
OverheadType
mlir::sparse_tensor::posTypeEncoding(SparseTensorEncodingAttr enc) {
return overheadTypeEncoding(enc.getPosWidth());
}
OverheadType
mlir::sparse_tensor::crdTypeEncoding(SparseTensorEncodingAttr enc) {
return overheadTypeEncoding(enc.getCrdWidth());
}
// TODO: we ought to add some `static_assert` tests to ensure that the
// `STEA::get{Pos,Crd}Type` methods agree with `getOverheadType(builder,
// {pos,crd}OverheadTypeEncoding(enc))`.
// TODO: Adjust the naming convention for the constructors of
// `OverheadType` so we can use the `MLIR_SPARSETENSOR_FOREVERY_O` x-macro
// here instead of `MLIR_SPARSETENSOR_FOREVERY_FIXED_O`, to further reduce
// the possibility of typo bugs or things getting out of sync.
StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(OverheadType ot) {
switch (ot) {
case OverheadType::kIndex:
return "0";
#define CASE(ONAME, O) \
case OverheadType::kU##ONAME: \
return #ONAME;
MLIR_SPARSETENSOR_FOREVERY_FIXED_O(CASE)
#undef CASE
}
llvm_unreachable("Unknown OverheadType");
}
StringRef mlir::sparse_tensor::overheadTypeFunctionSuffix(Type tp) {
return overheadTypeFunctionSuffix(overheadTypeEncoding(tp));
}
PrimaryType mlir::sparse_tensor::primaryTypeEncoding(Type elemTp) {
if (elemTp.isF64())
return PrimaryType::kF64;
if (elemTp.isF32())
return PrimaryType::kF32;
if (elemTp.isF16())
return PrimaryType::kF16;
if (elemTp.isBF16())
return PrimaryType::kBF16;
if (elemTp.isInteger(64))
return PrimaryType::kI64;
if (elemTp.isInteger(32))
return PrimaryType::kI32;
if (elemTp.isInteger(16))
return PrimaryType::kI16;
if (elemTp.isInteger(8))
return PrimaryType::kI8;
if (auto complexTp = dyn_cast<ComplexType>(elemTp)) {
auto complexEltTp = complexTp.getElementType();
if (complexEltTp.isF64())
return PrimaryType::kC64;
if (complexEltTp.isF32())
return PrimaryType::kC32;
}
llvm_unreachable("Unknown primary type");
}
StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(PrimaryType pt) {
switch (pt) {
#define CASE(VNAME, V) \
case PrimaryType::k##VNAME: \
return #VNAME;
MLIR_SPARSETENSOR_FOREVERY_V(CASE)
#undef CASE
}
llvm_unreachable("Unknown PrimaryType");
}
StringRef mlir::sparse_tensor::primaryTypeFunctionSuffix(Type elemTp) {
return primaryTypeFunctionSuffix(primaryTypeEncoding(elemTp));
}

//===----------------------------------------------------------------------===//
// Misc code generators.
//===----------------------------------------------------------------------===//

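// Generates a cast between two scalar types. Casts to or from index go
// through arith.index_cast, e.g. roughly:
//   %0 = arith.index_cast %v : i32 to index
// All remaining cases are delegated to convertScalarToDtype, with the
// unsigned-ness of the cast taken from the source integer type.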
Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value,
Type dstTp) {
const Type srcTp = value.getType();
if (srcTp == dstTp)
return value;
// int <=> index
if (isa<IndexType>(srcTp) || isa<IndexType>(dstTp))
return builder.create<arith::IndexCastOp>(loc, dstTp, value);
const auto srcIntTp = dyn_cast_or_null<IntegerType>(srcTp);
const bool isUnsignedCast = srcIntTp ? srcIntTp.isUnsigned() : false;
return mlir::convertScalarToDtype(builder, loc, value, dstTp, isUnsignedCast);
}
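
// Converts a scalar value into the destination type, wrapping it in a
// 0-ranked tensor when requested, roughly:
//   %0 = tensor.from_elements %elem : tensor<f64>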
Value sparse_tensor::genScalarToTensor(OpBuilder &builder, Location loc,
Value elem, Type dstTp) {
if (auto rtp = dyn_cast<RankedTensorType>(dstTp)) {
// Scalars can only be converted to 0-ranked tensors.
assert(rtp.getRank() == 0);
elem = sparse_tensor::genCast(builder, loc, elem, rtp.getElementType());
return builder.create<tensor::FromElementsOp>(loc, rtp, elem);
}
return sparse_tensor::genCast(builder, loc, elem, dstTp);
}
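
// Loads `mem[s]` and converts the value to index type. Note that
// arith.index_cast sign-extends, so values of the narrower unsigned
// overhead types are first zero-extended to i64, roughly:
//   %0 = memref.load %mem[%s] : memref<?xi32>
//   %1 = arith.extui %0 : i32 to i64
//   %2 = arith.index_cast %1 : i64 to index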
Value sparse_tensor::genIndexLoad(OpBuilder &builder, Location loc, Value mem,
Value s) {
Value load = builder.create<memref::LoadOp>(loc, mem, s);
if (!isa<IndexType>(load.getType())) {
if (load.getType().getIntOrFloatBitWidth() < 64)
load = builder.create<arith::ExtUIOp>(loc, builder.getI64Type(), load);
load =
builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), load);
}
return load;
}
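
// Returns an attribute holding the scalar value one for float, index, and
// integer types, or a splat thereof for ranked tensor and vector types.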
mlir::TypedAttr mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
if (isa<FloatType>(tp))
return builder.getFloatAttr(tp, 1.0);
if (isa<IndexType>(tp))
return builder.getIndexAttr(1);
if (auto intTp = dyn_cast<IntegerType>(tp))
return builder.getIntegerAttr(tp, APInt(intTp.getWidth(), 1));
if (isa<RankedTensorType, VectorType>(tp)) {
auto shapedTp = cast<ShapedType>(tp);
if (auto one = getOneAttr(builder, shapedTp.getElementType()))
return DenseElementsAttr::get(shapedTp, one);
}
llvm_unreachable("Unsupported attribute type");
}
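
// Generates a comparison of `v` against zero of its own type, using
// arith::CmpFOp (une), arith::CmpIOp (ne), or complex::NotEqualOp,
// as appropriate.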
Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc,
Value v) {
Type tp = v.getType();
Value zero = constantZero(builder, loc, tp);
if (isa<FloatType>(tp))
return builder.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNE, v,
zero);
if (tp.isIntOrIndex())
return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, v,
zero);
if (isa<ComplexType>(tp))
return builder.create<complex::NotEqualOp>(loc, v, zero);
llvm_unreachable("Non-numeric type");
}
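
// Computes the shape of the destination of a reshape. For example,
// expanding a source shape [%d] into the static destination shape <2x?x4>
// yields [2, %d / 8, 4]: the single dynamic size is obtained by dividing
// the source size by the product (8) of the known static sizes.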
void mlir::sparse_tensor::genReshapeDstShape(
OpBuilder &builder, Location loc, SmallVectorImpl<Value> &dstShape,
ArrayRef<Value> srcShape, ArrayRef<Size> staticDstShape,
ArrayRef<ReassociationIndices> reassociation) {
// Collapse shape.
if (reassociation.size() < srcShape.size()) {
unsigned start = 0;
for (const auto &map : llvm::enumerate(reassociation)) {
auto dstDim = constantIndex(builder, loc, 1);
for (unsigned i = start; i < start + map.value().size(); i++) {
dstDim = builder.create<arith::MulIOp>(loc, dstDim, srcShape[i]);
}
dstShape.push_back(dstDim);
start = start + map.value().size();
}
assert(start == srcShape.size());
return;
}
// Expand shape.
assert(reassociation.size() == srcShape.size());
unsigned start = 0;
// Expand the i-th dimension in srcShape.
for (unsigned i = 0, size = srcShape.size(); i < size; i++) {
const auto &map = reassociation[i];
auto srcDim = srcShape[i];
// Iterate through dimensions expanded from the i-th dimension.
for (unsigned j = start; j < start + map.size(); j++) {
// There can be only one dynamically sized dimension among the
// dimensions expanded from the i-th dimension in srcShape.
// For example, if srcDim = 8, then the expanded shape could be <2x?x2>,
// but not <2x?x?>.
if (staticDstShape[j] == ShapedType::kDynamic) {
// The expanded dimension has dynamic size. We compute the dimension
// by dividing srcDim by the product of the static dimensions.
Size product = 1;
for (unsigned k = start; k < start + map.size(); k++) {
if (staticDstShape[k] != ShapedType::kDynamic) {
product *= staticDstShape[k];
}
}
// Compute the dynamic dimension size.
Value productVal = constantIndex(builder, loc, product);
Value dynamicSize =
builder.create<arith::DivUIOp>(loc, srcDim, productVal);
dstShape.push_back(dynamicSize);
} else {
// The expanded dimension is statically known.
dstShape.push_back(constantIndex(builder, loc, staticDstShape[j]));
}
}
start = start + map.size();
}
assert(start == staticDstShape.size());
}
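
// Converts coordinates between the source and destination of a reshape.
// For example, collapsing two dimensions of sizes (3, 4) maps a coordinate
// pair (i, j) to the linearized coordinate i * 4 + j, while the
// corresponding expansion applies div/rem by the suffix products to
// recover (i, j) again.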
void mlir::sparse_tensor::reshapeCvs(
OpBuilder &builder, Location loc,
ArrayRef<ReassociationIndices> reassociation, // NOLINT
ValueRange srcSizes, ValueRange srcCvs, // NOLINT
ValueRange dstSizes, SmallVectorImpl<Value> &dstCvs) {
const unsigned srcRank = srcSizes.size();
const unsigned dstRank = dstSizes.size();
assert(srcRank == srcCvs.size() && "Source rank mismatch");
const bool isCollapse = srcRank > dstRank;
const ValueRange sizes = isCollapse ? srcSizes : dstSizes;
// Iterate over reassociation map.
unsigned i = 0;
unsigned start = 0;
for (const auto &map : llvm::enumerate(reassociation)) {
// Compute the total size of this dimension slice; the loop below
// successively divides it down to the stride of each coordinate.
Value linear = constantIndex(builder, loc, 1);
for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
linear = builder.create<arith::MulIOp>(loc, linear, sizes[j]);
}
// Start expansion.
Value val;
if (!isCollapse)
val = srcCvs[i];
// Iterate over dimension slice.
for (unsigned j = start, end = start + map.value().size(); j < end; j++) {
linear = builder.create<arith::DivUIOp>(loc, linear, sizes[j]);
if (isCollapse) {
const Value mul = builder.create<arith::MulIOp>(loc, srcCvs[j], linear);
val = val ? builder.create<arith::AddIOp>(loc, val, mul) : mul;
} else {
const Value old = val;
val = builder.create<arith::DivUIOp>(loc, val, linear);
assert(dstCvs.size() == j);
dstCvs.push_back(val);
val = builder.create<arith::RemUIOp>(loc, old, linear);
}
}
// Finalize collapse.
if (isCollapse) {
assert(dstCvs.size() == i);
dstCvs.push_back(val);
}
start += map.value().size();
i++;
}
assert(dstCvs.size() == dstRank);
}
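
// Looks up the runtime support function `name` in the module, creating a
// private declaration on first use; with EmitCInterface::On, the
// declaration is additionally tagged llvm.emit_c_interface so that a
// C-compatible wrapper is emitted during lowering.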
FlatSymbolRefAttr mlir::sparse_tensor::getFunc(ModuleOp module, StringRef name,
TypeRange resultType,
ValueRange operands,
EmitCInterface emitCInterface) {
MLIRContext *context = module.getContext();
auto result = SymbolRefAttr::get(context, name);
auto func = module.lookupSymbol<func::FuncOp>(result.getAttr());
if (!func) {
OpBuilder moduleBuilder(module.getBodyRegion());
func = moduleBuilder.create<func::FuncOp>(
module.getLoc(), name,
FunctionType::get(context, operands.getTypes(), resultType));
func.setPrivate();
if (static_cast<bool>(emitCInterface))
func->setAttr(LLVM::LLVMDialect::getEmitCWrapperAttrName(),
UnitAttr::get(context));
}
return result;
}
func::CallOp mlir::sparse_tensor::createFuncCall(
OpBuilder &builder, Location loc, StringRef name, TypeRange resultType,
ValueRange operands, EmitCInterface emitCInterface) {
auto module = builder.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
FlatSymbolRefAttr fn =
getFunc(module, name, resultType, operands, emitCInterface);
return builder.create<func::CallOp>(loc, resultType, fn, operands);
}
Type mlir::sparse_tensor::getOpaquePointerType(MLIRContext *ctx) {
return LLVM::LLVMPointerType::get(ctx);
}
Type mlir::sparse_tensor::getOpaquePointerType(Builder &builder) {
return getOpaquePointerType(builder.getContext());
}
Value mlir::sparse_tensor::genAlloca(OpBuilder &builder, Location loc,
unsigned sz, Type tp, bool staticShape) {
if (staticShape) {
auto memTp = MemRefType::get({sz}, tp);
return builder.create<memref::AllocaOp>(loc, memTp);
}
return genAlloca(builder, loc, constantIndex(builder, loc, sz), tp);
}
Value mlir::sparse_tensor::genAlloca(OpBuilder &builder, Location loc, Value sz,
Type tp) {
auto memTp = MemRefType::get({ShapedType::kDynamic}, tp);
return builder.create<memref::AllocaOp>(loc, memTp, ValueRange{sz});
}
Value mlir::sparse_tensor::genAllocaScalar(OpBuilder &builder, Location loc,
Type tp) {
return builder.create<memref::AllocaOp>(loc, MemRefType::get({}, tp));
}
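
// Stores the given values into a freshly allocated stack buffer, roughly:
//   %buf = memref.alloca(%c2) : memref<?xindex>
//   memref.store %v0, %buf[%c0] : memref<?xindex>
//   memref.store %v1, %buf[%c1] : memref<?xindex>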
Value mlir::sparse_tensor::allocaBuffer(OpBuilder &builder, Location loc,
ValueRange values) {
const unsigned sz = values.size();
assert(sz >= 1);
Value buffer = genAlloca(builder, loc, sz, values[0].getType());
for (unsigned i = 0; i < sz; i++) {
Value idx = constantIndex(builder, loc, i);
builder.create<memref::StoreOp>(loc, values[i], buffer, idx);
}
return buffer;
}
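
// Allocates a dense buffer for the given tensor type (with any dynamic
// sizes taken from `sizes`) and zero-initializes it, roughly:
//   %mem = memref.alloc(%dynSizes) : memref<?x4xf64>
//   linalg.fill ins(%zero : f64) outs(%mem : memref<?x4xf64>)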
Value mlir::sparse_tensor::allocDenseTensor(OpBuilder &builder, Location loc,
RankedTensorType tensorTp,
ValueRange sizes) {
Type elemTp = tensorTp.getElementType();
auto shape = tensorTp.getShape();
auto memTp = MemRefType::get(shape, elemTp);
SmallVector<Value> dynamicSizes;
for (unsigned i = 0, rank = tensorTp.getRank(); i < rank; i++) {
if (shape[i] == ShapedType::kDynamic)
dynamicSizes.push_back(sizes[i]);
}
Value mem = builder.create<memref::AllocOp>(loc, memTp, dynamicSizes);
Value zero = constantZero(builder, loc, elemTp);
builder.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{mem});
return mem;
}
void mlir::sparse_tensor::deallocDenseTensor(OpBuilder &builder, Location loc,
Value buffer) {
builder.create<memref::DeallocOp>(loc, buffer);
}
void mlir::sparse_tensor::sizesFromSrc(OpBuilder &builder,
SmallVectorImpl<Value> &sizes,
Location loc, Value src) {
const Dimension dimRank = getSparseTensorType(src).getDimRank();
for (Dimension d = 0; d < dimRank; d++)
sizes.push_back(linalg::createOrFoldDimOp(builder, loc, src, d));
}
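
// Walks up the parent chain and returns the first ancestor of `op`
// (possibly `op` itself) that is not directly nested inside an
// scf.for/while/parallel/if operation.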
Operation *mlir::sparse_tensor::getTop(Operation *op) {
for (; isa<scf::ForOp>(op->getParentOp()) ||
isa<scf::WhileOp>(op->getParentOp()) ||
isa<scf::ParallelOp>(op->getParentOp()) ||
isa<scf::IfOp>(op->getParentOp());
op = op->getParentOp())
;
return op;
}
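
// Traverses a SparseElementsAttr in coordinate order (optionally remapped
// by `order`) and invokes the callback once per stored element. For
// example, under the identity order the attribute
//   sparse<[[1, 2], [0, 0]], [2.0, 1.0]>
// yields the callback invocations ((0, 0), 1.0) and then ((1, 2), 2.0).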
void sparse_tensor::foreachInSparseConstant(
OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order,
function_ref<void(ArrayRef<Value>, Value)> callback) {
if (!order)
order = builder.getMultiDimIdentityMap(attr.getType().getRank());
auto stt = SparseTensorType(getRankedTensorType(attr));
const Dimension dimRank = stt.getDimRank();
const auto coordinates = attr.getIndices().getValues<IntegerAttr>();
const auto values = attr.getValues().getValues<Attribute>();
// This is like the `Element<V>` class in the runtime library, but for
// MLIR attributes. In the future we may want to move this out into
// a proper class definition to help improve code legibility (e.g.,
// `first` -> `coords`, `second` -> `value`) as well as being able
// to factor out analogues of `ElementLT<V>` for the sort below, etc.
using ElementAttr = std::pair<SmallVector<IntegerAttr>, Attribute>;
// Construct the COO from the SparseElementsAttr.
SmallVector<ElementAttr> elems;
for (size_t i = 0, nse = values.size(); i < nse; i++) {
elems.emplace_back();
elems.back().second = values[i];
auto &coords = elems.back().first;
coords.reserve(dimRank);
for (Dimension d = 0; d < dimRank; d++)
coords.push_back(coordinates[i * dimRank + d]);
}
// Sort the sparse elements by their (mapped) level coordinates.
std::sort(elems.begin(), elems.end(),
[order](const ElementAttr &lhs, const ElementAttr &rhs) {
if (std::addressof(lhs) == std::addressof(rhs))
return false;
auto lhsCoords = llvm::map_to_vector(
lhs.first, [](IntegerAttr i) { return i.getInt(); });
auto rhsCoords = llvm::map_to_vector(
rhs.first, [](IntegerAttr i) { return i.getInt(); });
SmallVector<int64_t, 4> lhsLvlCrds = order.compose(lhsCoords);
SmallVector<int64_t, 4> rhsLvlCrds = order.compose(rhsCoords);
// Compare the elements by their level coordinates.
for (Level l = 0; l < order.getNumResults(); l++) {
if (lhsLvlCrds[l] == rhsLvlCrds[l])
continue;
return lhsLvlCrds[l] < rhsLvlCrds[l];
}
llvm_unreachable("no equal coordinate in sparse element attr");
});
SmallVector<Value> cvs;
cvs.reserve(dimRank);
for (size_t i = 0, nse = values.size(); i < nse; i++) {
// Remap coordinates.
cvs.clear();
for (Dimension d = 0; d < dimRank; d++) {
auto crd = elems[i].first[d].getInt();
cvs.push_back(builder.create<arith::ConstantIndexOp>(loc, crd));
}
// Remap value.
Value val;
if (isa<ComplexType>(attr.getElementType())) {
auto valAttr = cast<ArrayAttr>(elems[i].second);
val = builder.create<complex::ConstantOp>(loc, attr.getElementType(),
valAttr);
} else {
auto valAttr = cast<TypedAttr>(elems[i].second);
val = builder.create<arith::ConstantOp>(loc, valAttr);
}
assert(val);
callback(cvs, val);
}
}
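
// Loads `size` consecutive values from a 1-d memref, adding `offsetVal`
// (when provided) to the element loaded at position `offsetIdx`.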
SmallVector<Value> sparse_tensor::loadAll(OpBuilder &builder, Location loc,
size_t size, Value mem,
size_t offsetIdx, Value offsetVal) {
#ifndef NDEBUG
const auto memTp = cast<MemRefType>(mem.getType());
assert(memTp.getRank() == 1);
const Size memSh = memTp.getDimSize(0);
assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<Size>(size));
assert(offsetIdx == 0 || offsetIdx < size);
#endif // NDEBUG
SmallVector<Value> vs;
vs.reserve(size);
for (size_t i = 0; i < size; i++) {
Value v = builder.create<memref::LoadOp>(loc, mem,
constantIndex(builder, loc, i));
if (i == offsetIdx && offsetVal)
v = builder.create<arith::AddIOp>(loc, v, offsetVal);
vs.push_back(v);
}
return vs;
}
void sparse_tensor::storeAll(OpBuilder &builder, Location loc, Value mem,
ValueRange vs, size_t offsetIdx, Value offsetVal) {
#ifndef NDEBUG
const size_t vsize = vs.size();
const auto memTp = cast<MemRefType>(mem.getType());
assert(memTp.getRank() == 1);
const Size memSh = memTp.getDimSize(0);
assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<Size>(vsize));
assert(offsetIdx == 0 || offsetIdx < vsize);
#endif // NDEBUG
for (const auto &v : llvm::enumerate(vs)) {
const Value w =
(offsetIdx == v.index() && offsetVal)
? builder.create<arith::AddIOp>(loc, v.value(), offsetVal)
: v.value();
builder.create<memref::StoreOp>(loc, w, mem,
constantIndex(builder, loc, v.index()));
}
}
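
// Casts a dense tensor to an equivalent memref of the same shape and
// element type by way of a bufferization.to_memref op.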
TypedValue<BaseMemRefType>
sparse_tensor::genToMemref(OpBuilder &builder, Location loc, Value tensor) {
auto tTp = llvm::cast<TensorType>(tensor.getType());
auto mTp = MemRefType::get(tTp.getShape(), tTp.getElementType());
return builder.create<bufferization::ToMemrefOp>(loc, mTp, tensor)
.getResult();
}
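
// Generates a sparse_tensor.positions op that extracts the positions
// array of level `lvl` as a flat memref of the tensor's position type.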
Value sparse_tensor::genToPositions(OpBuilder &builder, Location loc,
Value tensor, Level lvl) {
const auto srcTp = getSparseTensorType(tensor);
const Type posTp = srcTp.getPosType();
const Type memTp = get1DMemRefType(posTp, /*withLayout=*/false);
return builder.create<ToPositionsOp>(loc, memTp, tensor,
builder.getIndexAttr(lvl));
}
Value sparse_tensor::genToCoordinates(OpBuilder &builder, Location loc,
Value tensor, Level lvl, Level cooStart) {
const auto srcTp = getSparseTensorType(tensor);
const Type crdTp = srcTp.getCrdType();
const Type memTp = get1DMemRefType(crdTp, /*withLayout=*/lvl >= cooStart);
return builder.create<ToCoordinatesOp>(loc, memTp, tensor,
builder.getIndexAttr(lvl));
}
Value sparse_tensor::genToCoordinatesBuffer(OpBuilder &builder, Location loc,
Value tensor) {
const auto srcTp = getSparseTensorType(tensor);
const Type crdTp = srcTp.getCrdType();
const Type memTp = get1DMemRefType(crdTp, /*withLayout=*/false);
return builder.create<ToCoordinatesBufferOp>(loc, memTp, tensor);
}
Value sparse_tensor::genToValues(OpBuilder &builder, Location loc,
Value tensor) {
RankedTensorType srcTp = getRankedTensorType(tensor);
Type valTp = get1DMemRefType(srcTp.getElementType(),
/*withLayout=*/false);
return builder.create<ToValuesOp>(loc, valTp, tensor);
}
Value sparse_tensor::genValMemSize(OpBuilder &builder, Location loc,
Value tensor) {
return getDescriptorFromTensorTuple(tensor).getValMemSize(builder, loc);
}
Value sparse_tensor::createOrFoldSliceOffsetOp(OpBuilder &builder, Location loc,
Value tensor, Dimension dim) {
auto enc = getSparseTensorEncoding(tensor.getType());
assert(enc && enc.isSlice());
std::optional<unsigned> offset = enc.getStaticDimSliceOffset(dim);
if (offset.has_value())
return constantIndex(builder, loc, *offset);
return builder.create<ToSliceOffsetOp>(loc, tensor, APInt(64, dim));
}
Value sparse_tensor::createOrFoldSliceStrideOp(OpBuilder &builder, Location loc,
Value tensor, Dimension dim) {
auto enc = getSparseTensorEncoding(tensor.getType());
assert(enc && enc.isSlice());
std::optional<unsigned> stride = enc.getStaticDimSliceStride(dim);
if (stride.has_value())
return constantIndex(builder, loc, *stride);
return builder.create<ToSliceStrideOp>(loc, tensor, APInt(64, dim));
}
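
// Generates code that constructs a checked sparse tensor reader for the
// given input and returns the opaque reader pointer; the dimension sizes
// are reported back through the two out-parameters, both as individual
// SSA values and as a buffer.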
Value sparse_tensor::genReader(OpBuilder &builder, Location loc,
SparseTensorType stt, Value tensor,
/*out*/ SmallVectorImpl<Value> &dimSizesValues,
/*out*/ Value &dimSizesBuffer) {
// Construct the dimension **shapes** buffer. The buffer contains the
// static size per dimension, or zero for each dynamic size.
Dimension dimRank = stt.getDimRank();
dimSizesValues.clear();
dimSizesValues.reserve(dimRank);
for (const Size sz : stt.getDimShape()) {
const auto s = ShapedType::isDynamic(sz) ? 0 : sz;
dimSizesValues.push_back(constantIndex(builder, loc, s));
}
Value dimShapesBuffer = allocaBuffer(builder, loc, dimSizesValues);
// Create the `CheckedSparseTensorReader`. This reader performs a
// consistency check on the static sizes, but accepts any size for
// each dynamic dimension.
Type opaqueTp = getOpaquePointerType(builder);
Type eltTp = stt.getElementType();
Value valTp = constantPrimaryTypeEncoding(builder, loc, eltTp);
Value reader =
createFuncCall(builder, loc, "createCheckedSparseTensorReader", opaqueTp,
{tensor, dimShapesBuffer, valTp}, EmitCInterface::On)
.getResult(0);
// For static shapes, the shape buffer can be used right away. For dynamic
// shapes, use the information from the reader to construct a buffer that
// supplies the actual size for each dynamic dimension.
dimSizesBuffer = dimShapesBuffer;
if (stt.hasDynamicDimShape()) {
Type indexTp = builder.getIndexType();
auto memTp = MemRefType::get({ShapedType::kDynamic}, indexTp);
dimSizesBuffer =
createFuncCall(builder, loc, "getSparseTensorReaderDimSizes", memTp,
reader, EmitCInterface::On)
.getResult(0);
// Also convert the dim shapes values into dim sizes values, just in case
// subsequent clients need the values (DCE will remove them if unused).
for (Dimension d = 0; d < dimRank; d++) {
if (stt.isDynamicDim(d))
dimSizesValues[d] = builder.create<memref::LoadOp>(
loc, dimSizesBuffer, constantIndex(builder, loc, d));
}
}
return reader;
}
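
// Generates the level sizes and the dim2lvl/lvl2dim map buffers for a
// sparse tensor type. For example, a 2x2 block-sparse (BSR) tensor uses
// the mapping
//   (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2)
// with inverse
//   (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 2 + l3),
// exercising forms (2)/(3) of dim2lvl and form (2) of lvl2dim below.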
Value sparse_tensor::genMapBuffers(
OpBuilder &builder, Location loc, SparseTensorType stt,
ArrayRef<Value> dimSizesValues, Value dimSizesBuffer,
/*out*/ SmallVectorImpl<Value> &lvlSizesValues,
/*out*/ Value &dim2lvlBuffer,
/*out*/ Value &lvl2dimBuffer) {
const Dimension dimRank = stt.getDimRank();
const Level lvlRank = stt.getLvlRank();
lvlSizesValues.clear();
lvlSizesValues.reserve(lvlRank);
// For an identity mapping, the dim2lvl and lvl2dim mappings are
// identical, as are dimSizes and lvlSizes, so buffers are reused
// as much as possible.
if (stt.isIdentity()) {
assert(dimRank == lvlRank);
SmallVector<Value> iotaValues;
iotaValues.reserve(lvlRank);
for (Level l = 0; l < lvlRank; l++) {
iotaValues.push_back(constantIndex(builder, loc, l));
lvlSizesValues.push_back(dimSizesValues[l]);
}
dim2lvlBuffer = lvl2dimBuffer = allocaBuffer(builder, loc, iotaValues);
return dimSizesBuffer; // now lvlSizesBuffer
}
// Otherwise, some code needs to be generated to set up the buffers.
// This code deals with permutations as well as non-permutations that
// arise from rank-changing blocking.
const auto dimToLvl = stt.getDimToLvl();
const auto lvlToDim = stt.getLvlToDim();
SmallVector<Value> dim2lvlValues(lvlRank); // for each lvl, expr in dim vars
SmallVector<Value> lvl2dimValues(dimRank); // for each dim, expr in lvl vars
// Generate dim2lvl.
assert(lvlRank == dimToLvl.getNumResults());
for (Level l = 0; l < lvlRank; l++) {
AffineExpr exp = dimToLvl.getResult(l);
// We expect:
// (1) l = d
// (2) l = d / c
// (3) l = d % c
Dimension d = 0;
uint64_t cf = 0, cm = 0;
switch (exp.getKind()) {
case AffineExprKind::DimId: {
d = cast<AffineDimExpr>(exp).getPosition();
break;
}
case AffineExprKind::FloorDiv: {
auto floor = cast<AffineBinaryOpExpr>(exp);
d = cast<AffineDimExpr>(floor.getLHS()).getPosition();
cf = cast<AffineConstantExpr>(floor.getRHS()).getValue();
break;
}
case AffineExprKind::Mod: {
auto mod = cast<AffineBinaryOpExpr>(exp);
d = cast<AffineDimExpr>(mod.getLHS()).getPosition();
cm = cast<AffineConstantExpr>(mod.getRHS()).getValue();
break;
}
default:
llvm::report_fatal_error("unsupported dim2lvl in sparse tensor type");
}
dim2lvlValues[l] = constantIndex(builder, loc, encodeDim(d, cf, cm));
// Compute the level sizes.
// (1) l = d : size(d)
// (2) l = d / c : size(d) / c
// (3) l = d % c : c
Value lvlSz;
if (cm == 0) {
lvlSz = dimSizesValues[d];
if (cf != 0)
lvlSz = builder.create<arith::DivUIOp>(loc, lvlSz,
constantIndex(builder, loc, cf));
} else {
lvlSz = constantIndex(builder, loc, cm);
}
lvlSizesValues.push_back(lvlSz);
}
// Generate lvl2dim.
assert(dimRank == lvlToDim.getNumResults());
for (Dimension d = 0; d < dimRank; d++) {
AffineExpr exp = lvlToDim.getResult(d);
// We expect:
// (1) d = l
// (2) d = l' * c + l
Level l = 0, ll = 0;
uint64_t c = 0;
switch (exp.getKind()) {
case AffineExprKind::DimId: {
l = cast<AffineDimExpr>(exp).getPosition();
break;
}
case AffineExprKind::Add: {
// Must be of the form l' * c + l: mul on the lhs, level var on the rhs.
auto add = cast<AffineBinaryOpExpr>(exp);
assert(add.getLHS().getKind() == AffineExprKind::Mul);
auto mul = cast<AffineBinaryOpExpr>(add.getLHS());
ll = cast<AffineDimExpr>(mul.getLHS()).getPosition();
c = cast<AffineConstantExpr>(mul.getRHS()).getValue();
l = cast<AffineDimExpr>(add.getRHS()).getPosition();
break;
}
default:
llvm::report_fatal_error("unsupported lvl2dim in sparse tensor type");
}
lvl2dimValues[d] = constantIndex(builder, loc, encodeLvl(l, c, ll));
}
// Return buffers.
dim2lvlBuffer = allocaBuffer(builder, loc, dim2lvlValues);
lvl2dimBuffer = allocaBuffer(builder, loc, lvl2dimValues);
return allocaBuffer(builder, loc, lvlSizesValues); // lvlSizesBuffer
}