blob: 91dad050b28852c2df08873d881d3607bf5c513f [file] [log] [blame]
//===- SparseStorageSpecifierToLLVM.cpp - convert specifier to llvm -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "Utils/CodegenUtils.h"
#include "mlir/Conversion/LLVMCommon/StructBuilder.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include <optional>
using namespace mlir;
using namespace sparse_tensor;
namespace {
//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//
/// Returns the field types of the LLVM struct that lowers specifier type `tp`,
/// in order:
///   0: i64 array of level sizes (one entry per level),
///   1: i64 array of memory sizes (getNumDataFieldsFromEncoding(enc) entries),
/// and, only for sliced encodings,
///   2: i64 array of slice offsets (lvlRank entries),
///   3: i64 array of slice strides (lvlRank entries).
static SmallVector<Type, 4> getSpecifierFields(StorageSpecifierType tp) {
  MLIRContext *ctx = tp.getContext();
  auto enc = tp.getEncoding();
  const Level lvlRank = enc.getLvlRank();

  // TODO: how can we get the lowering type for index type in the later pipeline
  // to be consistent? LLVM::StructureType does not allow index fields.
  auto i64Type = IntegerType::get(ctx, 64);
  auto makeI64Array = [&](uint64_t numElements) {
    return LLVM::LLVMArrayType::get(ctx, i64Type, numElements);
  };

  SmallVector<Type, 4> fields;
  fields.push_back(makeI64Array(lvlRank));
  fields.push_back(makeI64Array(getNumDataFieldsFromEncoding(enc)));
  if (!enc.isSlice())
    return fields;

  // Extra fields are required for the slice information.
  fields.push_back(makeI64Array(lvlRank));
  fields.push_back(makeI64Array(lvlRank));
  return fields;
}
/// Lowers a storage specifier type to a literal (unnamed) LLVM struct whose
/// body is the list returned by getSpecifierFields.
static Type convertSpecifier(StorageSpecifierType tp) {
  SmallVector<Type, 4> body = getSpecifierFields(tp);
  MLIRContext *ctx = tp.getContext();
  return LLVM::LLVMStructType::getLiteral(ctx, body);
}
//===----------------------------------------------------------------------===//
// Specifier struct builder.
//===----------------------------------------------------------------------===//
// Positions of the fields inside the specifier struct built from
// getSpecifierFields. The two slice fields (offset/stride) are present only
// for sliced encodings.
constexpr uint64_t kLvlSizePosInSpecifier{0};
constexpr uint64_t kMemSizePosInSpecifier{1};
constexpr uint64_t kDimOffsetPosInSpecifier{2};
constexpr uint64_t kDimStridePosInSpecifier{3};
class SpecifierStructBuilder : public StructBuilder {
private:
Value extractField(OpBuilder &builder, Location loc,
ArrayRef<int64_t> indices) const {
return genCast(builder, loc,
builder.create<LLVM::ExtractValueOp>(loc, value, indices),
builder.getIndexType());
}
void insertField(OpBuilder &builder, Location loc, ArrayRef<int64_t> indices,
Value v) {
value = builder.create<LLVM::InsertValueOp>(
loc, value, genCast(builder, loc, v, builder.getIntegerType(64)),
indices);
}
public:
explicit SpecifierStructBuilder(Value specifier) : StructBuilder(specifier) {
assert(value);
}
// Undef value for dimension sizes, all zero value for memory sizes.
static Value getInitValue(OpBuilder &builder, Location loc, Type structType,
Value source);
Value lvlSize(OpBuilder &builder, Location loc, Level lvl) const;
void setLvlSize(OpBuilder &builder, Location loc, Level lvl, Value size);
Value dimOffset(OpBuilder &builder, Location loc, Dimension dim) const;
void setDimOffset(OpBuilder &builder, Location loc, Dimension dim,
Value size);
Value dimStride(OpBuilder &builder, Location loc, Dimension dim) const;
void setDimStride(OpBuilder &builder, Location loc, Dimension dim,
Value size);
Value memSize(OpBuilder &builder, Location loc, FieldIndex fidx) const;
void setMemSize(OpBuilder &builder, Location loc, FieldIndex fidx,
Value size);
Value memSizeArray(OpBuilder &builder, Location loc) const;
void setMemSizeArray(OpBuilder &builder, Location loc, Value array);
};
/// Creates a fresh specifier value of LLVM struct type `structType`.
///
/// The struct starts as `llvm.mlir.undef`. Level sizes, and for slices the
/// offsets and strides, are never written here and so stay undef.
/// - If `source` is null, every entry of the memory-size array is set to zero.
/// - Otherwise the whole memory-size array is copied from `source`.
Value SpecifierStructBuilder::getInitValue(OpBuilder &builder, Location loc,
                                           Type structType, Value source) {
  Value metaData = builder.create<LLVM::UndefOp>(loc, structType);
  SpecifierStructBuilder md(metaData);
  if (!source) {
    auto memSizeArrayType =
        cast<LLVM::LLVMArrayType>(cast<LLVM::LLVMStructType>(structType)
                                      .getBody()[kMemSizePosInSpecifier]);
    Value zero = constantZero(builder, loc, memSizeArrayType.getElementType());
    // Fill memSizes array with zero. Index with the same unsigned FieldIndex
    // type that setMemSize takes, instead of a signed int initialized from
    // the unsigned element count.
    for (FieldIndex i = 0, e = memSizeArrayType.getNumElements(); i < e; ++i)
      md.setMemSize(builder, loc, i, zero);
  } else {
    // Only the non-slice information (the memory-size array) is copied from
    // `source`.
    SpecifierStructBuilder sourceMd(source);
    md.setMemSizeArray(builder, loc, sourceMd.memSizeArray(builder, loc));
  }
  return md;
}
/// Builds IR extracting the slice offset of dimension `dim` from the
/// descriptor, cast to `index`. Only valid for sliced encodings, whose
/// specifier struct carries the offset field.
Value SpecifierStructBuilder::dimOffset(OpBuilder &builder, Location loc,
                                        Dimension dim) const {
  return extractField(
      builder, loc,
      ArrayRef<int64_t>{kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)});
}
/// Builds IR inserting `size` (cast to i64) as the slice offset of dimension
/// `dim` into the descriptor. Only valid for sliced encodings, whose specifier
/// struct carries the offset field.
void SpecifierStructBuilder::setDimOffset(OpBuilder &builder, Location loc,
                                          Dimension dim, Value size) {
  insertField(
      builder, loc,
      ArrayRef<int64_t>{kDimOffsetPosInSpecifier, static_cast<int64_t>(dim)},
      size);
}
/// Builds IR reading the size of level `lvl` from the descriptor, cast to
/// `index`.
Value SpecifierStructBuilder::lvlSize(OpBuilder &builder, Location loc,
                                      Level lvl) const {
  // `lvl` must be narrowed explicitly: braced initialization rejects the
  // implicit conversion.
  const int64_t position[] = {kLvlSizePosInSpecifier,
                              static_cast<int64_t>(lvl)};
  return extractField(builder, loc, position);
}
/// Builds IR writing `size` (cast to i64) as the size of level `lvl` into the
/// descriptor.
void SpecifierStructBuilder::setLvlSize(OpBuilder &builder, Location loc,
                                        Level lvl, Value size) {
  // `lvl` must be narrowed explicitly: braced initialization rejects the
  // implicit conversion.
  const int64_t position[] = {kLvlSizePosInSpecifier,
                              static_cast<int64_t>(lvl)};
  insertField(builder, loc, position, size);
}
/// Builds IR extracting the slice stride of dimension `dim` from the
/// descriptor, cast to `index`. Only valid for sliced encodings, whose
/// specifier struct carries the stride field.
Value SpecifierStructBuilder::dimStride(OpBuilder &builder, Location loc,
                                        Dimension dim) const {
  return extractField(
      builder, loc,
      ArrayRef<int64_t>{kDimStridePosInSpecifier, static_cast<int64_t>(dim)});
}
/// Builds IR inserting `size` (cast to i64) as the slice stride of dimension
/// `dim` into the descriptor. Only valid for sliced encodings, whose specifier
/// struct carries the stride field.
void SpecifierStructBuilder::setDimStride(OpBuilder &builder, Location loc,
                                          Dimension dim, Value size) {
  insertField(
      builder, loc,
      ArrayRef<int64_t>{kDimStridePosInSpecifier, static_cast<int64_t>(dim)},
      size);
}
/// Builds IR extracting the `fidx`-th memory size from the descriptor, cast
/// to `index`.
Value SpecifierStructBuilder::memSize(OpBuilder &builder, Location loc,
                                      FieldIndex fidx) const {
  return extractField(
      builder, loc,
      ArrayRef<int64_t>{kMemSizePosInSpecifier, static_cast<int64_t>(fidx)});
}
/// Builds IR writing `size` (cast to i64) as the `fidx`-th entry of the
/// memory-size array in the descriptor.
void SpecifierStructBuilder::setMemSize(OpBuilder &builder, Location loc,
                                        FieldIndex fidx, Value size) {
  const int64_t position[] = {kMemSizePosInSpecifier,
                              static_cast<int64_t>(fidx)};
  insertField(builder, loc, position, size);
}
/// Builds IR reading the whole memory-size array out of the descriptor as a
/// single LLVM array value (no cast is applied).
Value SpecifierStructBuilder::memSizeArray(OpBuilder &builder,
                                           Location loc) const {
  const int64_t position = kMemSizePosInSpecifier;
  return builder.create<LLVM::ExtractValueOp>(loc, value, position);
}
/// Builds IR storing `array` as the whole memory-size array of the descriptor,
/// replacing the wrapped struct value with the updated one.
void SpecifierStructBuilder::setMemSizeArray(OpBuilder &builder, Location loc,
                                             Value array) {
  const int64_t position = kMemSizePosInSpecifier;
  value = builder.create<LLVM::InsertValueOp>(loc, value, array, position);
}
} // namespace
//===----------------------------------------------------------------------===//
// The sparse storage specifier type converter (declared in Passes.h).
//===----------------------------------------------------------------------===//
StorageSpecifierToLLVMTypeConverter::StorageSpecifierToLLVMTypeConverter() {
  // Fallback: any type not matched by a later conversion is kept as-is.
  addConversion([](Type other) { return other; });
  // Storage specifier types lower to their LLVM struct representation.
  // Conversions registered later are tried first, so this one takes
  // precedence over the identity fallback.
  addConversion([](StorageSpecifierType specifierTp) -> Type {
    return convertSpecifier(specifierTp);
  });
}
//===----------------------------------------------------------------------===//
// Storage specifier conversion rules.
//===----------------------------------------------------------------------===//
/// Shared conversion pattern for the storage specifier getter/setter ops.
///
/// Dispatches on the specifier kind of the op and delegates the field access
/// to the static hooks of `Base`: `onLvlSize`, `onDimOffset`, `onDimStride`
/// and `onMemSize`. Each hook returns the value that replaces the op.
template <typename Base, typename SourceOp>
class SpecifierGetterSetterOpConverter : public OpConversionPattern<SourceOp> {
public:
  using OpAdaptor = typename SourceOp::Adaptor;
  using OpConversionPattern<SourceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SpecifierStructBuilder spec(adaptor.getSpecifier());
    Value replacement = buildAccess(rewriter, op, spec);
    rewriter.replaceOp(op, replacement);
    return success();
  }

private:
  /// Emits the access selected by `op.getSpecifierKind()` through the matching
  /// `Base` hook and returns the value that replaces `op`.
  static Value buildAccess(OpBuilder &builder, SourceOp op,
                           SpecifierStructBuilder &spec) {
    switch (op.getSpecifierKind()) {
    case StorageSpecifierKind::LvlSize:
      return Base::onLvlSize(builder, op, spec, *op.getLevel());
    case StorageSpecifierKind::DimOffset:
      return Base::onDimOffset(builder, op, spec, *op.getLevel());
    case StorageSpecifierKind::DimStride:
      return Base::onDimStride(builder, op, spec, *op.getLevel());
    case StorageSpecifierKind::CrdMemSize:
    case StorageSpecifierKind::PosMemSize:
    case StorageSpecifierKind::ValMemSize: {
      // Memory sizes are addressed by the layout's memref field index, which
      // is computed from the field kind and (when present) the level.
      StorageLayout layout(op.getSpecifier().getType().getEncoding());
      // Keep the level at full `Level` width rather than narrowing it through
      // `unsigned` on its way to getMemRefFieldIndex.
      std::optional<Level> lvl;
      if (op.getLevel())
        lvl = *op.getLevel();
      FieldIndex idx =
          layout.getMemRefFieldIndex(toFieldKind(op.getSpecifierKind()), lvl);
      return Base::onMemSize(builder, op, spec, idx);
    }
    }
    llvm_unreachable("unrecognized specifier kind");
  }
};
/// Lowers `storage_specifier.set`. The op's value is written into the field
/// selected by the specifier kind, and the updated specifier struct replaces
/// the op.
struct StorageSpecifierSetOpConverter
    : public SpecifierGetterSetterOpConverter<StorageSpecifierSetOpConverter,
                                              SetStorageSpecifierOp> {
  using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;

  static Value onLvlSize(OpBuilder &builder, SetStorageSpecifierOp op,
                         SpecifierStructBuilder &spec, Level lvl) {
    const Location loc = op.getLoc();
    spec.setLvlSize(builder, loc, lvl, op.getValue());
    return spec;
  }

  static Value onDimOffset(OpBuilder &builder, SetStorageSpecifierOp op,
                           SpecifierStructBuilder &spec, Dimension d) {
    const Location loc = op.getLoc();
    spec.setDimOffset(builder, loc, d, op.getValue());
    return spec;
  }

  static Value onDimStride(OpBuilder &builder, SetStorageSpecifierOp op,
                           SpecifierStructBuilder &spec, Dimension d) {
    const Location loc = op.getLoc();
    spec.setDimStride(builder, loc, d, op.getValue());
    return spec;
  }

  static Value onMemSize(OpBuilder &builder, SetStorageSpecifierOp op,
                         SpecifierStructBuilder &spec, FieldIndex fidx) {
    const Location loc = op.getLoc();
    spec.setMemSize(builder, loc, fidx, op.getValue());
    return spec;
  }
};
/// Lowers `storage_specifier.get`. The field selected by the specifier kind is
/// extracted (cast to `index`) and replaces the op.
///
/// None of these hooks modify the specifier, so every one of them takes it by
/// const reference. The original code did this for the offset/stride hooks
/// only.
struct StorageSpecifierGetOpConverter
    : public SpecifierGetterSetterOpConverter<StorageSpecifierGetOpConverter,
                                              GetStorageSpecifierOp> {
  using SpecifierGetterSetterOpConverter::SpecifierGetterSetterOpConverter;

  static Value onLvlSize(OpBuilder &builder, GetStorageSpecifierOp op,
                         const SpecifierStructBuilder &spec, Level lvl) {
    return spec.lvlSize(builder, op.getLoc(), lvl);
  }

  static Value onDimOffset(OpBuilder &builder, GetStorageSpecifierOp op,
                           const SpecifierStructBuilder &spec, Dimension d) {
    return spec.dimOffset(builder, op.getLoc(), d);
  }

  static Value onDimStride(OpBuilder &builder, GetStorageSpecifierOp op,
                           const SpecifierStructBuilder &spec, Dimension d) {
    return spec.dimStride(builder, op.getLoc(), d);
  }

  static Value onMemSize(OpBuilder &builder, GetStorageSpecifierOp op,
                         const SpecifierStructBuilder &spec, FieldIndex fidx) {
    return spec.memSize(builder, op.getLoc(), fidx);
  }
};
/// Lowers `storage_specifier.init` to a fresh LLVM specifier struct built by
/// SpecifierStructBuilder::getInitValue. When the op has a `source` specifier,
/// its converted value is passed along so that its memory sizes are copied.
struct StorageSpecifierInitOpConverter
    : public OpConversionPattern<StorageSpecifierInitOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(StorageSpecifierInitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type llvmType = getTypeConverter()->convertType(op.getResult().getType());
    // The type converter signals failure with a null type. Bail out instead of
    // emitting an undef of null type, which would then fail the struct cast in
    // getInitValue.
    if (!llvmType)
      return rewriter.notifyMatchFailure(
          op, "failed to convert storage specifier type");
    rewriter.replaceOp(
        op, SpecifierStructBuilder::getInitValue(
                rewriter, op.getLoc(), llvmType, adaptor.getSource()));
    return success();
  }
};
//===----------------------------------------------------------------------===//
// Public method for populating conversion rules.
//===----------------------------------------------------------------------===//
/// Adds the get/set/init storage specifier lowering patterns to `patterns`.
/// All of them use `converter` for their type conversions.
void mlir::populateStorageSpecifierToLLVMPatterns(TypeConverter &converter,
                                                  RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<StorageSpecifierGetOpConverter, StorageSpecifierSetOpConverter,
               StorageSpecifierInitOpConverter>(converter, context);
}