| //===- SparseTensorDialect.cpp - Sparse tensor dialect implementation -----===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include <utility> |
| |
| #include "Detail/DimLvlMapParser.h" |
| |
| #include "mlir/Dialect/SparseTensor/IR/Enums.h" |
| #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" |
| #include "mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h" |
| #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" |
| |
| #include "mlir/Dialect/Arith/IR/Arith.h" |
| #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" |
| #include "mlir/Dialect/Utils/StaticValueUtils.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/DialectImplementation.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/IR/OpImplementation.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "llvm/ADT/TypeSwitch.h" |
| #include "llvm/Support/FormatVariadic.h" |
| |
| #define GET_ATTRDEF_CLASSES |
| #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" |
| #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc" |
| |
// Forward declarations: the following custom parse/print methods are
// referenced by the generated code for SparseTensorTypes.td.
| static mlir::ParseResult parseLevelRange(mlir::AsmParser &, |
| mlir::sparse_tensor::Level &, |
| mlir::sparse_tensor::Level &); |
| static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level, |
| mlir::sparse_tensor::Level); |
| |
| #define GET_TYPEDEF_CLASSES |
| #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc" |
| |
| using namespace mlir; |
| using namespace mlir::sparse_tensor; |
| |
| // Support hashing LevelType such that SparseTensorEncodingAttr can be hashed as |
| // well. |
| namespace mlir::sparse_tensor { |
| llvm::hash_code hash_value(LevelType lt) { |
| return llvm::hash_value(static_cast<uint64_t>(lt)); |
| } |
| } // namespace mlir::sparse_tensor |
| |
| //===----------------------------------------------------------------------===// |
| // Local Convenience Methods. |
| //===----------------------------------------------------------------------===// |
| |
/// Returns true iff `bitWidth` is one of the accepted position/coordinate
/// bit widths: 0, 8, 16, 32, or 64.
static constexpr bool acceptBitWidth(unsigned bitWidth) {
  return bitWidth == 0 || bitWidth == 8 || bitWidth == 16 || bitWidth == 32 ||
         bitWidth == 64;
}
| |
| static SmallVector<Size> |
| getSparseFieldShape(const SparseTensorEncodingAttr enc, |
| std::optional<ArrayRef<int64_t>> dimShape) { |
| assert(enc); |
| // With only encoding, we can not determine the static shape for leading |
| // batch levels, we therefore return a dynamic shape memref instead. |
| SmallVector<int64_t> memrefShape(enc.getBatchLvlRank(), ShapedType::kDynamic); |
| if (dimShape.has_value()) { |
| // If the actual tensor shape is provided, we can then refine the leading |
| // batch dimension. |
| SmallVector<int64_t> lvlShape = |
| enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl); |
| memrefShape.assign(lvlShape.begin(), |
| lvlShape.begin() + enc.getBatchLvlRank()); |
| } |
| // Another dynamic dimension to store the sparse level. |
| memrefShape.push_back(ShapedType::kDynamic); |
| return memrefShape; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // SparseTensorDialect StorageLayout. |
| //===----------------------------------------------------------------------===// |
| |
| static constexpr Level kInvalidLevel = -1u; |
| static constexpr Level kInvalidFieldIndex = -1u; |
| static constexpr FieldIndex kDataFieldStartingIdx = 0; |
| |
| void StorageLayout::foreachField( |
| llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level, |
| LevelType)> |
| callback) const { |
| const auto lvlTypes = enc.getLvlTypes(); |
| const Level lvlRank = enc.getLvlRank(); |
| SmallVector<COOSegment> cooSegs = SparseTensorType(enc).getCOOSegments(); |
| FieldIndex fieldIdx = kDataFieldStartingIdx; |
| |
| ArrayRef cooSegsRef = cooSegs; |
| // Per-level storage. |
| for (Level l = 0; l < lvlRank; /*l += 1 or l += AoSCooLen*/) { |
| const auto lt = lvlTypes[l]; |
| if (isWithPosLT(lt)) { |
| if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, lt))) |
| return; |
| } |
| if (isWithCrdLT(lt)) { |
| if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, lt))) |
| return; |
| } |
| if (!cooSegsRef.empty() && cooSegsRef.front().isSegmentStart(l)) { |
| if (!cooSegsRef.front().isSoA) { |
| // AoS COO, all singletons are fused into one memrefs. Skips the entire |
| // COO segement. |
| l = cooSegsRef.front().lvlRange.second; |
| } else { |
| // SoA COO, each singleton level has one memref. |
| l++; |
| } |
| // Expire handled COO segment. |
| cooSegsRef = cooSegsRef.drop_front(); |
| } else { |
| // Non COO levels. |
| l++; |
| } |
| } |
| // The values array. |
| if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel, |
| LevelFormat::Undef))) |
| return; |
| // Put metadata at the end. |
| if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel, |
| LevelFormat::Undef))) |
| return; |
| } |
| |
| void sparse_tensor::foreachFieldAndTypeInSparseTensor( |
| SparseTensorType stt, |
| llvm::function_ref<bool(Type, FieldIndex, SparseTensorFieldKind, Level, |
| LevelType)> |
| callback) { |
| assert(stt.hasEncoding()); |
| |
| SmallVector<int64_t> memrefShape = |
| getSparseFieldShape(stt.getEncoding(), stt.getDimShape()); |
| |
| const Type specType = StorageSpecifierType::get(stt.getEncoding()); |
| // memref<[batch] x ? x pos> positions |
| const Type posMemType = MemRefType::get(memrefShape, stt.getPosType()); |
| // memref<[batch] x ? x crd> coordinates |
| const Type crdMemType = MemRefType::get(memrefShape, stt.getCrdType()); |
| // memref<[batch] x ? x eltType> values |
| const Type valMemType = MemRefType::get(memrefShape, stt.getElementType()); |
| |
| StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType, |
| callback](FieldIndex fieldIdx, |
| SparseTensorFieldKind fieldKind, |
| Level lvl, LevelType lt) -> bool { |
| switch (fieldKind) { |
| case SparseTensorFieldKind::StorageSpec: |
| return callback(specType, fieldIdx, fieldKind, lvl, lt); |
| case SparseTensorFieldKind::PosMemRef: |
| return callback(posMemType, fieldIdx, fieldKind, lvl, lt); |
| case SparseTensorFieldKind::CrdMemRef: |
| return callback(crdMemType, fieldIdx, fieldKind, lvl, lt); |
| case SparseTensorFieldKind::ValMemRef: |
| return callback(valMemType, fieldIdx, fieldKind, lvl, lt); |
| }; |
| llvm_unreachable("unrecognized field kind"); |
| }); |
| } |
| |
| unsigned StorageLayout::getNumFields() const { |
| unsigned numFields = 0; |
| foreachField([&numFields](FieldIndex, SparseTensorFieldKind, Level, |
| LevelType) -> bool { |
| numFields++; |
| return true; |
| }); |
| return numFields; |
| } |
| |
| unsigned StorageLayout::getNumDataFields() const { |
| unsigned numFields = 0; // one value memref |
| foreachField([&numFields](FieldIndex fidx, SparseTensorFieldKind, Level, |
| LevelType) -> bool { |
| if (fidx >= kDataFieldStartingIdx) |
| numFields++; |
| return true; |
| }); |
| numFields -= 1; // the last field is StorageSpecifier |
| assert(numFields == getNumFields() - kDataFieldStartingIdx - 1); |
| return numFields; |
| } |
| |
| std::pair<FieldIndex, unsigned> |
| StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind, |
| std::optional<Level> lvl) const { |
| FieldIndex fieldIdx = kInvalidFieldIndex; |
| unsigned stride = 1; |
| if (kind == SparseTensorFieldKind::CrdMemRef) { |
| assert(lvl.has_value()); |
| const Level cooStart = SparseTensorType(enc).getAoSCOOStart(); |
| const Level lvlRank = enc.getLvlRank(); |
| if (lvl.value() >= cooStart && lvl.value() < lvlRank) { |
| lvl = cooStart; |
| stride = lvlRank - cooStart; |
| } |
| } |
| foreachField([lvl, kind, &fieldIdx](FieldIndex fIdx, |
| SparseTensorFieldKind fKind, Level fLvl, |
| LevelType lt) -> bool { |
| if ((lvl && fLvl == lvl.value() && kind == fKind) || |
| (kind == fKind && fKind == SparseTensorFieldKind::ValMemRef)) { |
| fieldIdx = fIdx; |
| // Returns false to break the iteration. |
| return false; |
| } |
| return true; |
| }); |
| assert(fieldIdx != kInvalidFieldIndex); |
| return std::pair<FieldIndex, unsigned>(fieldIdx, stride); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // SparseTensorDialect Attribute Methods. |
| //===----------------------------------------------------------------------===// |
| |
| std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) { |
| return isDynamic(v) ? std::nullopt |
| : std::make_optional(static_cast<uint64_t>(v)); |
| } |
| |
| std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticOffset() const { |
| return getStatic(getOffset()); |
| } |
| |
| std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticStride() const { |
| return getStatic(getStride()); |
| } |
| |
| std::optional<uint64_t> SparseTensorDimSliceAttr::getStaticSize() const { |
| return getStatic(getSize()); |
| } |
| |
| bool SparseTensorDimSliceAttr::isCompletelyDynamic() const { |
| return isDynamic(getOffset()) && isDynamic(getStride()) && |
| isDynamic(getSize()); |
| } |
| |
| std::string SparseTensorDimSliceAttr::getStaticString(int64_t v) { |
| return isDynamic(v) ? "?" : std::to_string(v); |
| } |
| |
| void SparseTensorDimSliceAttr::print(llvm::raw_ostream &os) const { |
| assert(getImpl() && "Uninitialized SparseTensorDimSliceAttr"); |
| os << '('; |
| os << getStaticString(getOffset()); |
| os << ", "; |
| os << getStaticString(getSize()); |
| os << ", "; |
| os << getStaticString(getStride()); |
| os << ')'; |
| } |
| |
| void SparseTensorDimSliceAttr::print(AsmPrinter &printer) const { |
| print(printer.getStream()); |
| } |
| |
| static ParseResult parseOptionalStaticSlice(int64_t &result, |
| AsmParser &parser) { |
| auto parseResult = parser.parseOptionalInteger(result); |
| if (parseResult.has_value()) { |
| if (parseResult.value().succeeded() && result < 0) { |
| parser.emitError( |
| parser.getCurrentLocation(), |
| "expect positive value or ? for slice offset/size/stride"); |
| return failure(); |
| } |
| return parseResult.value(); |
| } |
| |
| // Else, and '?' which represented dynamic slice |
| result = SparseTensorDimSliceAttr::kDynamic; |
| return parser.parseQuestion(); |
| } |
| |
| Attribute SparseTensorDimSliceAttr::parse(AsmParser &parser, Type type) { |
| int64_t offset = kDynamic, size = kDynamic, stride = kDynamic; |
| |
| if (failed(parser.parseLParen()) || |
| failed(parseOptionalStaticSlice(offset, parser)) || |
| failed(parser.parseComma()) || |
| failed(parseOptionalStaticSlice(size, parser)) || |
| failed(parser.parseComma()) || |
| failed(parseOptionalStaticSlice(stride, parser)) || |
| failed(parser.parseRParen())) |
| return {}; |
| |
| return parser.getChecked<SparseTensorDimSliceAttr>(parser.getContext(), |
| offset, size, stride); |
| } |
| |
| LogicalResult |
| SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError, |
| int64_t offset, int64_t size, int64_t stride) { |
| if (!isDynamic(offset) && offset < 0) |
| return emitError() << "expect non-negative value or ? for slice offset"; |
| if (!isDynamic(size) && size <= 0) |
| return emitError() << "expect positive value or ? for slice size"; |
| if (!isDynamic(stride) && stride <= 0) |
| return emitError() << "expect positive value or ? for slice stride"; |
| return success(); |
| } |
| |
| SparseTensorEncodingAttr |
| SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const { |
| assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); |
| return SparseTensorEncodingAttr::get( |
| getContext(), getLvlTypes(), dimToLvl, AffineMap(), getPosWidth(), |
| getCrdWidth(), getExplicitVal(), getImplicitVal()); |
| } |
| |
| SparseTensorEncodingAttr |
| SparseTensorEncodingAttr::withDimToLvl(SparseTensorEncodingAttr enc) const { |
| return withDimToLvl(enc ? enc.getDimToLvl() : AffineMap()); |
| } |
| |
| SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimToLvl() const { |
| return withDimToLvl(AffineMap()); |
| } |
| |
| SparseTensorEncodingAttr |
| SparseTensorEncodingAttr::withBitWidths(unsigned posWidth, |
| unsigned crdWidth) const { |
| assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); |
| return SparseTensorEncodingAttr::get( |
| getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), posWidth, |
| crdWidth, getExplicitVal(), getImplicitVal()); |
| } |
| |
| SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutBitWidths() const { |
| return withBitWidths(0, 0); |
| } |
| |
| SparseTensorEncodingAttr |
| SparseTensorEncodingAttr::withExplicitVal(Attribute explicitVal) const { |
| assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); |
| return SparseTensorEncodingAttr::get( |
| getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(), |
| getCrdWidth(), explicitVal, getImplicitVal()); |
| } |
| |
| SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutExplicitVal() const { |
| return withExplicitVal(Attribute()); |
| } |
| |
| SparseTensorEncodingAttr |
| SparseTensorEncodingAttr::withImplicitVal(Attribute implicitVal) const { |
| assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); |
| return SparseTensorEncodingAttr::get( |
| getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(), |
| getCrdWidth(), getExplicitVal(), implicitVal); |
| } |
| |
| SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutImplicitVal() const { |
| return withImplicitVal(Attribute()); |
| } |
| |
| SparseTensorEncodingAttr SparseTensorEncodingAttr::withDimSlices( |
| ArrayRef<SparseTensorDimSliceAttr> dimSlices) const { |
| return SparseTensorEncodingAttr::get( |
| getContext(), getLvlTypes(), getDimToLvl(), getLvlToDim(), getPosWidth(), |
| getCrdWidth(), getExplicitVal(), getImplicitVal(), dimSlices); |
| } |
| |
| SparseTensorEncodingAttr SparseTensorEncodingAttr::withoutDimSlices() const { |
| return withDimSlices(ArrayRef<SparseTensorDimSliceAttr>{}); |
| } |
| |
| uint64_t SparseTensorEncodingAttr::getBatchLvlRank() const { |
| ArrayRef<LevelType> lvlTypes = getLvlTypes(); |
| auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT); |
| return std::distance(lastBatch, lvlTypes.rend()); |
| } |
| |
| bool SparseTensorEncodingAttr::isAllDense() const { |
| return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT); |
| } |
| |
| bool SparseTensorEncodingAttr::isAllOrdered() const { |
| return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT); |
| } |
| |
| Type SparseTensorEncodingAttr::getCrdElemType() const { |
| if (!getImpl()) |
| return nullptr; |
| if (getCrdWidth()) |
| return IntegerType::get(getContext(), getCrdWidth()); |
| return IndexType::get(getContext()); |
| } |
| |
| Type SparseTensorEncodingAttr::getPosElemType() const { |
| if (!getImpl()) |
| return nullptr; |
| if (getPosWidth()) |
| return IntegerType::get(getContext(), getPosWidth()); |
| return IndexType::get(getContext()); |
| } |
| |
| MemRefType SparseTensorEncodingAttr::getCrdMemRefType( |
| std::optional<ArrayRef<int64_t>> dimShape) const { |
| SmallVector<Size> shape = getSparseFieldShape(*this, dimShape); |
| return MemRefType::get(shape, getCrdElemType()); |
| } |
| |
| MemRefType SparseTensorEncodingAttr::getPosMemRefType( |
| std::optional<ArrayRef<int64_t>> dimShape) const { |
| SmallVector<Size> shape = getSparseFieldShape(*this, dimShape); |
| return MemRefType::get(shape, getPosElemType()); |
| } |
| |
| bool SparseTensorEncodingAttr::isIdentity() const { |
| return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity(); |
| } |
| |
| bool SparseTensorEncodingAttr::isPermutation() const { |
| return !getImpl() || !getDimToLvl() || getDimToLvl().isPermutation(); |
| } |
| |
| Dimension SparseTensorEncodingAttr::getDimRank() const { |
| assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); |
| const auto dimToLvl = getDimToLvl(); |
| return dimToLvl ? dimToLvl.getNumDims() : getLvlRank(); |
| } |
| |
| Level SparseTensorEncodingAttr::getLvlRank() const { |
| assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); |
| return getLvlTypes().size(); |
| } |
| |
| LevelType SparseTensorEncodingAttr::getLvlType(Level l) const { |
| if (!getImpl()) |
| return LevelFormat::Batch; |
| assert(l < getLvlRank() && "Level is out of bounds"); |
| return getLvlTypes()[l]; |
| } |
| |
| bool SparseTensorEncodingAttr::isSlice() const { |
| assert(getImpl() && "Uninitialized SparseTensorEncodingAttr"); |
| return !getDimSlices().empty(); |
| } |
| |
| SparseTensorDimSliceAttr |
| SparseTensorEncodingAttr::getDimSlice(Dimension dim) const { |
| assert(isSlice() && "Is not a slice"); |
| const auto dimSlices = getDimSlices(); |
| assert(dim < dimSlices.size() && "Dimension is out of bounds"); |
| return dimSlices[dim]; |
| } |
| |
| std::optional<uint64_t> |
| SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const { |
| return getDimSlice(dim).getStaticOffset(); |
| } |
| |
| std::optional<uint64_t> |
| SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const { |
| return getDimSlice(dim).getStaticStride(); |
| } |
| |
| std::optional<uint64_t> |
| SparseTensorEncodingAttr::getStaticLvlSliceOffset(Level lvl) const { |
| return getStaticDimSliceOffset(toDim(*this, lvl)); |
| } |
| |
| std::optional<uint64_t> |
| SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const { |
| return getStaticDimSliceStride(toDim(*this, lvl)); |
| } |
| |
| SmallVector<int64_t> |
| SparseTensorEncodingAttr::translateShape(ArrayRef<int64_t> srcShape, |
| CrdTransDirectionKind dir) const { |
| if (isIdentity()) |
| return SmallVector<int64_t>(srcShape); |
| |
| SmallVector<int64_t> ret; |
| unsigned rank = |
| dir == CrdTransDirectionKind::dim2lvl ? getLvlRank() : getDimRank(); |
| ret.reserve(rank); |
| |
| if (isPermutation()) { |
| for (unsigned r = 0; r < rank; r++) { |
| unsigned trans = dir == CrdTransDirectionKind::dim2lvl ? toDim(*this, r) |
| : toLvl(*this, r); |
| ret.push_back(srcShape[trans]); |
| } |
| return ret; |
| } |
| |
| // Handle non-permutation maps. |
| AffineMap transMap = |
| dir == CrdTransDirectionKind::dim2lvl ? getDimToLvl() : getLvlToDim(); |
| |
| SmallVector<AffineExpr> dimRep; |
| dimRep.reserve(srcShape.size()); |
| for (int64_t sz : srcShape) { |
| if (!ShapedType::isDynamic(sz)) { |
| // Push back the max coordinate for the given dimension/level size. |
| dimRep.push_back(getAffineConstantExpr(sz - 1, getContext())); |
| } else { |
| // A dynamic size, use a AffineDimExpr to symbolize the value. |
| dimRep.push_back(getAffineDimExpr(dimRep.size(), getContext())); |
| } |
| }; |
| |
| for (AffineExpr exp : transMap.getResults()) { |
| // Do constant propagation on the affine map. |
| AffineExpr evalExp = |
| simplifyAffineExpr(exp.replaceDims(dimRep), srcShape.size(), 0); |
| // use llvm namespace here to avoid ambiguity |
| if (auto c = llvm::dyn_cast<AffineConstantExpr>(evalExp)) { |
| ret.push_back(c.getValue() + 1); |
| } else { |
| if (auto mod = llvm::dyn_cast<AffineBinaryOpExpr>(evalExp); |
| mod && mod.getKind() == AffineExprKind::Mod) { |
| // We can still infer a static bound for expressions in form |
| // "d % constant" since d % constant \in [0, constant). |
| if (auto bound = llvm::dyn_cast<AffineConstantExpr>(mod.getRHS())) { |
| ret.push_back(bound.getValue()); |
| continue; |
| } |
| } |
| ret.push_back(ShapedType::kDynamic); |
| } |
| } |
| assert(ret.size() == rank); |
| return ret; |
| } |
| |
| ValueRange |
| SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc, |
| ValueRange crds, |
| CrdTransDirectionKind dir) const { |
| if (!getImpl()) |
| return crds; |
| |
| SmallVector<Type> retType( |
| dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(), |
| builder.getIndexType()); |
| auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this); |
| return transOp.getOutCrds(); |
| } |
| |
| Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) { |
| // Open "<{" part. |
| if (failed(parser.parseLess())) |
| return {}; |
| if (failed(parser.parseLBrace())) |
| return {}; |
| |
| // Process the data from the parsed dictionary value into struct-like data. |
| SmallVector<LevelType> lvlTypes; |
| SmallVector<SparseTensorDimSliceAttr> dimSlices; |
| AffineMap dimToLvl = {}; |
| AffineMap lvlToDim = {}; |
| unsigned posWidth = 0; |
| unsigned crdWidth = 0; |
| Attribute explicitVal; |
| Attribute implicitVal; |
| StringRef attrName; |
| SmallVector<StringRef, 5> keys = {"map", "posWidth", "crdWidth", |
| "explicitVal", "implicitVal"}; |
| while (succeeded(parser.parseOptionalKeyword(&attrName))) { |
| // Detect admissible keyword. |
| auto *it = find(keys, attrName); |
| if (it == keys.end()) { |
| parser.emitError(parser.getNameLoc(), "unexpected key: ") << attrName; |
| return {}; |
| } |
| unsigned keyWordIndex = it - keys.begin(); |
| // Consume the `=` after keys |
| if (failed(parser.parseEqual())) |
| return {}; |
| // Dispatch on keyword. |
| switch (keyWordIndex) { |
| case 0: { // map |
| ir_detail::DimLvlMapParser cParser(parser); |
| auto res = cParser.parseDimLvlMap(); |
| if (failed(res)) |
| return {}; |
| const auto &dlm = *res; |
| |
| const Level lvlRank = dlm.getLvlRank(); |
| for (Level lvl = 0; lvl < lvlRank; lvl++) |
| lvlTypes.push_back(dlm.getLvlType(lvl)); |
| |
| const Dimension dimRank = dlm.getDimRank(); |
| for (Dimension dim = 0; dim < dimRank; dim++) |
| dimSlices.push_back(dlm.getDimSlice(dim)); |
| // NOTE: the old syntax requires an all-or-nothing approach to |
| // `dimSlices`; therefore, if any slice actually exists then we need |
| // to convert null-DSA into default/nop DSA. |
| const auto isDefined = [](SparseTensorDimSliceAttr slice) { |
| return static_cast<bool>(slice.getImpl()); |
| }; |
| if (llvm::any_of(dimSlices, isDefined)) { |
| const auto defaultSlice = |
| SparseTensorDimSliceAttr::get(parser.getContext()); |
| for (Dimension dim = 0; dim < dimRank; dim++) |
| if (!isDefined(dimSlices[dim])) |
| dimSlices[dim] = defaultSlice; |
| } else { |
| dimSlices.clear(); |
| } |
| |
| dimToLvl = dlm.getDimToLvlMap(parser.getContext()); |
| lvlToDim = dlm.getLvlToDimMap(parser.getContext()); |
| break; |
| } |
| case 1: { // posWidth |
| Attribute attr; |
| if (failed(parser.parseAttribute(attr))) |
| return {}; |
| auto intAttr = llvm::dyn_cast<IntegerAttr>(attr); |
| if (!intAttr) { |
| parser.emitError(parser.getNameLoc(), |
| "expected an integral position bitwidth"); |
| return {}; |
| } |
| posWidth = intAttr.getInt(); |
| break; |
| } |
| case 2: { // crdWidth |
| Attribute attr; |
| if (failed(parser.parseAttribute(attr))) |
| return {}; |
| auto intAttr = llvm::dyn_cast<IntegerAttr>(attr); |
| if (!intAttr) { |
| parser.emitError(parser.getNameLoc(), |
| "expected an integral index bitwidth"); |
| return {}; |
| } |
| crdWidth = intAttr.getInt(); |
| break; |
| } |
| case 3: { // explicitVal |
| Attribute attr; |
| if (failed(parser.parseAttribute(attr))) |
| return {}; |
| if (auto result = llvm::dyn_cast<FloatAttr>(attr)) { |
| explicitVal = result; |
| } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) { |
| explicitVal = result; |
| } else { |
| parser.emitError(parser.getNameLoc(), |
| "expected a numeric value for explicitVal"); |
| return {}; |
| } |
| break; |
| } |
| case 4: { // implicitVal |
| Attribute attr; |
| if (failed(parser.parseAttribute(attr))) |
| return {}; |
| if (auto result = llvm::dyn_cast<FloatAttr>(attr)) { |
| implicitVal = result; |
| } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) { |
| implicitVal = result; |
| } else { |
| parser.emitError(parser.getNameLoc(), |
| "expected a numeric value for implicitVal"); |
| return {}; |
| } |
| break; |
| } |
| } // switch |
| // Only last item can omit the comma. |
| if (parser.parseOptionalComma().failed()) |
| break; |
| } |
| |
| // Close "}>" part. |
| if (failed(parser.parseRBrace())) |
| return {}; |
| if (failed(parser.parseGreater())) |
| return {}; |
| |
| // Construct struct-like storage for attribute. |
| if (!lvlToDim || lvlToDim.isEmpty()) { |
| lvlToDim = inferLvlToDim(dimToLvl, parser.getContext()); |
| } |
| return parser.getChecked<SparseTensorEncodingAttr>( |
| parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth, |
| explicitVal, implicitVal, dimSlices); |
| } |
| |
| void SparseTensorEncodingAttr::print(AsmPrinter &printer) const { |
| auto map = static_cast<AffineMap>(getDimToLvl()); |
| // Empty affine map indicates identity map |
| if (!map) |
| map = AffineMap::getMultiDimIdentityMap(getLvlTypes().size(), getContext()); |
| printer << "<{ map = "; |
| printSymbols(map, printer); |
| printer << '('; |
| printDimensions(map, printer, getDimSlices()); |
| printer << ") -> ("; |
| printLevels(map, printer, getLvlTypes()); |
| printer << ')'; |
| // Print remaining members only for non-default values. |
| if (getPosWidth()) |
| printer << ", posWidth = " << getPosWidth(); |
| if (getCrdWidth()) |
| printer << ", crdWidth = " << getCrdWidth(); |
| if (getExplicitVal()) { |
| printer << ", explicitVal = " << getExplicitVal(); |
| } |
| if (getImplicitVal()) |
| printer << ", implicitVal = " << getImplicitVal(); |
| printer << " }>"; |
| } |
| |
| void SparseTensorEncodingAttr::printSymbols(AffineMap &map, |
| AsmPrinter &printer) const { |
| if (map.getNumSymbols() == 0) |
| return; |
| printer << '['; |
| for (unsigned i = 0, n = map.getNumSymbols() - 1; i < n; i++) |
| printer << 's' << i << ", "; |
| if (map.getNumSymbols() >= 1) |
| printer << 's' << map.getNumSymbols() - 1; |
| printer << ']'; |
| } |
| |
| void SparseTensorEncodingAttr::printDimensions( |
| AffineMap &map, AsmPrinter &printer, |
| ArrayRef<SparseTensorDimSliceAttr> dimSlices) const { |
| if (!dimSlices.empty()) { |
| for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++) |
| printer << 'd' << i << " : " << dimSlices[i] << ", "; |
| if (map.getNumDims() >= 1) { |
| printer << 'd' << map.getNumDims() - 1 << " : " |
| << dimSlices[map.getNumDims() - 1]; |
| } |
| } else { |
| for (unsigned i = 0, n = map.getNumDims() - 1; i < n; i++) |
| printer << 'd' << i << ", "; |
| if (map.getNumDims() >= 1) |
| printer << 'd' << map.getNumDims() - 1; |
| } |
| } |
| |
| void SparseTensorEncodingAttr::printLevels(AffineMap &map, AsmPrinter &printer, |
| ArrayRef<LevelType> lvlTypes) const { |
| for (unsigned i = 0, n = map.getNumResults() - 1; i < n; i++) { |
| map.getResult(i).print(printer.getStream()); |
| printer << " : " << toMLIRString(lvlTypes[i]) << ", "; |
| } |
| if (map.getNumResults() >= 1) { |
| auto lastIndex = map.getNumResults() - 1; |
| map.getResult(lastIndex).print(printer.getStream()); |
| printer << " : " << toMLIRString(lvlTypes[lastIndex]); |
| } |
| } |
| |
| LogicalResult SparseTensorEncodingAttr::verify( |
| function_ref<InFlightDiagnostic()> emitError, ArrayRef<LevelType> lvlTypes, |
| AffineMap dimToLvl, AffineMap lvlToDim, unsigned posWidth, |
| unsigned crdWidth, Attribute explicitVal, Attribute implicitVal, |
| ArrayRef<SparseTensorDimSliceAttr> dimSlices) { |
| if (!acceptBitWidth(posWidth)) |
| return emitError() << "unexpected position bitwidth: " << posWidth; |
| if (!acceptBitWidth(crdWidth)) |
| return emitError() << "unexpected coordinate bitwidth: " << crdWidth; |
| if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonLT); |
| it != std::end(lvlTypes)) { |
| if (it == lvlTypes.begin() || |
| (!isCompressedLT(*(it - 1)) && !isLooseCompressedLT(*(it - 1)))) |
| return emitError() << "expected compressed or loose_compressed level " |
| "before singleton level"; |
| if (!std::all_of(it, lvlTypes.end(), |
| [](LevelType i) { return isSingletonLT(i); })) |
| return emitError() << "expected all singleton lvlTypes " |
| "following a singleton level"; |
| // We can potentially support mixed SoA/AoS singleton levels. |
| if (!std::all_of(it, lvlTypes.end(), [it](LevelType i) { |
| return it->isa<LevelPropNonDefault::SoA>() == |
| i.isa<LevelPropNonDefault::SoA>(); |
| })) { |
| return emitError() << "expected all singleton lvlTypes stored in the " |
| "same memory layout (SoA vs AoS)."; |
| } |
| } |
| |
| auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT); |
| if (!std::all_of(lastBatch, lvlTypes.rend(), isBatchLT)) |
| return emitError() << "Batch lvlType can only be leading levels."; |
| |
| // SoA property can only be applied on singleton level. |
| auto soaLvls = llvm::make_filter_range(lvlTypes, [](LevelType lt) { |
| return lt.isa<LevelPropNonDefault::SoA>(); |
| }); |
| if (llvm::any_of(soaLvls, [](LevelType lt) { |
| return !lt.isa<LevelFormat::Singleton>(); |
| })) { |
| return emitError() << "SoA is only applicable to singleton lvlTypes."; |
| } |
| |
| // TODO: audit formats that actually are supported by backend. |
| if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isNOutOfMLT); |
| it != std::end(lvlTypes)) { |
| if (it != lvlTypes.end() - 1) |
| return emitError() << "expected n_out_of_m to be the last level type"; |
| if (!std::all_of(lvlTypes.begin(), it, |
| [](LevelType i) { return isDenseLT(i); })) |
| return emitError() << "expected all dense lvlTypes " |
| "before a n_out_of_m level"; |
| if (dimToLvl && (dimToLvl.getNumDims() != dimToLvl.getNumResults())) { |
| if (!isBlockSparsity(dimToLvl)) { |
| return emitError() |
| << "expected 1xm block structure for n_out_of_m level"; |
| } |
| auto sizes = getBlockSize(dimToLvl); |
| unsigned coefficient = 0; |
| for (const auto &elem : sizes) { |
| if (elem != 0) { |
| if (elem != coefficient && coefficient != 0) { |
| return emitError() << "expected only one blocked level " |
| "with the same coefficients"; |
| } |
| coefficient = elem; |
| } |
| } |
| if (coefficient != getM(*it)) { |
| return emitError() << "expected coeffiencts of Affine expressions " |
| "to be equal to m of n_out_of_m level"; |
| } |
| } |
| } |
| // Before we can check that the level-rank is consistent/coherent |
| // across all fields, we need to define it. The source-of-truth for |
| // the `getLvlRank` method is the length of the level-types array, |
| // since it must always be provided and have full rank; therefore we |
| // use that same source-of-truth here. |
| const Level lvlRank = lvlTypes.size(); |
| if (lvlRank == 0) |
| return emitError() << "expected a non-empty array for lvlTypes"; |
| // We save `dimRank` here because we'll also need it to verify `dimSlices`. |
| const Dimension dimRank = dimToLvl ? dimToLvl.getNumDims() : lvlRank; |
| if (dimToLvl) { |
| if (dimToLvl.getNumResults() != lvlRank) |
| return emitError() |
| << "level-rank mismatch between dimToLvl and lvlTypes: " |
| << dimToLvl.getNumResults() << " != " << lvlRank; |
| auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext()); |
| // Symbols can't be inferred but are acceptable. |
| if (!inferRes && dimToLvl.getNumSymbols() == 0) |
| return emitError() << "failed to infer lvlToDim from dimToLvl"; |
| if (lvlToDim && (inferRes != lvlToDim)) |
| return emitError() << "expected lvlToDim to be an inverse of dimToLvl"; |
| if (dimRank > lvlRank) |
| return emitError() << "unexpected dimToLvl mapping from " << dimRank |
| << " to " << lvlRank; |
| } |
| if (!dimSlices.empty()) { |
| if (dimSlices.size() != dimRank) |
| return emitError() |
| << "dimension-rank mismatch between dimSlices and dimToLvl: " |
| << dimSlices.size() << " != " << dimRank; |
| // Compiler support for `dimSlices` currently requires that the two |
| // ranks agree. (However, it does allow `dimToLvl` to be a permutation.) |
| if (dimRank != lvlRank) |
| return emitError() |
| << "dimSlices expected dimension-rank to match level-rank: " |
| << dimRank << " != " << lvlRank; |
| } |
| return success(); |
| } |
| |
| LogicalResult SparseTensorEncodingAttr::verifyEncoding( |
| ArrayRef<Size> dimShape, Type elementType, |
| function_ref<InFlightDiagnostic()> emitError) const { |
| // Check structural integrity. In particular, this ensures that the |
| // level-rank is coherent across all the fields. |
| if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(), |
| getPosWidth(), getCrdWidth(), getExplicitVal(), |
| getImplicitVal(), getDimSlices()))) |
| return failure(); |
| // Check integrity with tensor type specifics. In particular, we |
| // need only check that the dimension-rank of the tensor agrees with |
| // the dimension-rank of the encoding. |
| const Dimension dimRank = dimShape.size(); |
| if (dimRank == 0) |
| return emitError() << "expected non-scalar sparse tensor"; |
| if (getDimRank() != dimRank) |
| return emitError() |
| << "dimension-rank mismatch between encoding and tensor shape: " |
| << getDimRank() << " != " << dimRank; |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // SparseTensorType Methods. |
| //===----------------------------------------------------------------------===// |
| |
| bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl, |
| bool isUnique) const { |
| if (!hasEncoding()) |
| return false; |
| if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl)) |
| return false; |
| for (Level l = startLvl + 1; l < lvlRank; ++l) |
| if (!isSingletonLvl(l)) |
| return false; |
| // If isUnique is true, then make sure that the last level is unique, |
| // that is, when lvlRank == 1, the only compressed level is unique, |
| // and when lvlRank > 1, the last singleton is unique. |
| return !isUnique || isUniqueLvl(lvlRank - 1); |
| } |
| |
| Level mlir::sparse_tensor::SparseTensorType::getAoSCOOStart() const { |
| SmallVector<COOSegment> coo = getCOOSegments(); |
| assert(coo.size() == 1 || coo.empty()); |
| if (!coo.empty() && coo.front().isAoS()) { |
| return coo.front().lvlRange.first; |
| } |
| return lvlRank; |
| } |
| |
| SmallVector<COOSegment> |
| mlir::sparse_tensor::SparseTensorType::getCOOSegments() const { |
| SmallVector<COOSegment> ret; |
| if (!hasEncoding() || lvlRank <= 1) |
| return ret; |
| |
| ArrayRef<LevelType> lts = getLvlTypes(); |
| Level l = 0; |
| while (l < lvlRank) { |
| auto lt = lts[l]; |
| if (lt.isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>()) { |
| auto cur = lts.begin() + l; |
| auto end = std::find_if(cur + 1, lts.end(), [](LevelType lt) { |
| return !lt.isa<LevelFormat::Singleton>(); |
| }); |
| unsigned cooLen = std::distance(cur, end); |
| if (cooLen > 1) { |
| // To support mixed SoA/AoS COO, we should break the segment when the |
| // storage scheme changes, for now we faithfully assume that all |
| // consecutive singleton levels have the same storage format as verified |
| // STEA. |
| ret.push_back(COOSegment{std::make_pair(l, l + cooLen), |
| lts[l + 1].isa<LevelPropNonDefault::SoA>()}); |
| } |
| l += cooLen; |
| } else { |
| l++; |
| } |
| } |
| return ret; |
| } |
| |
| RankedTensorType |
| mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const { |
| SmallVector<LevelType> lvlTypes; |
| lvlTypes.reserve(lvlRank); |
| // A non-unique compressed level at beginning (unless this is |
| // also the last level, then it is unique). |
| lvlTypes.push_back( |
| *buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1)); |
| if (lvlRank > 1) { |
| // Followed by n-2 non-unique singleton levels. |
| std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2, |
| *buildLevelType(LevelFormat::Singleton, ordered, false)); |
| // Ends by a unique singleton level. |
| lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true)); |
| } |
| auto enc = SparseTensorEncodingAttr::get( |
| getContext(), lvlTypes, getDimToLvl(), getLvlToDim(), getPosWidth(), |
| getCrdWidth(), getExplicitVal(), getImplicitVal()); |
| return RankedTensorType::get(getDimShape(), getElementType(), enc); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Convenience Methods. |
| //===----------------------------------------------------------------------===// |
| |
| SparseTensorEncodingAttr |
| mlir::sparse_tensor::getSparseTensorEncoding(Type type) { |
| if (auto ttp = llvm::dyn_cast<RankedTensorType>(type)) |
| return llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(ttp.getEncoding()); |
| if (auto mdtp = llvm::dyn_cast<StorageSpecifierType>(type)) |
| return mdtp.getEncoding(); |
| return nullptr; |
| } |
| |
| AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl, |
| MLIRContext *context) { |
| auto map = static_cast<AffineMap>(dimToLvl); |
| AffineMap lvlToDim; |
| // Return an empty lvlToDim when inference is not successful. |
| if (!map || map.getNumSymbols() != 0) { |
| lvlToDim = AffineMap(); |
| } else if (map.isPermutation()) { |
| lvlToDim = inversePermutation(map); |
| } else if (isBlockSparsity(map)) { |
| lvlToDim = inverseBlockSparsity(map, context); |
| } |
| return lvlToDim; |
| } |
| |
| AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl, |
| MLIRContext *context) { |
| SmallVector<AffineExpr> lvlExprs; |
| auto numLvls = dimToLvl.getNumResults(); |
| lvlExprs.reserve(numLvls); |
| // lvlExprComponents stores information of the floordiv and mod operations |
| // applied to the same dimension, so as to build the lvlToDim map. |
| std::map<unsigned, SmallVector<AffineExpr, 3>> lvlExprComponents; |
| for (unsigned i = 0, n = numLvls; i < n; i++) { |
| auto result = dimToLvl.getResult(i); |
| if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) { |
| if (result.getKind() == AffineExprKind::FloorDiv) { |
| // Position of the dimension in dimToLvl. |
| auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition(); |
| assert(lvlExprComponents.find(pos) == lvlExprComponents.end() && |
| "expected only one floordiv for each dimension"); |
| SmallVector<AffineExpr, 3> components; |
| // Level variable for floordiv. |
| components.push_back(getAffineDimExpr(i, context)); |
| // Multiplier. |
| components.push_back(binOp.getRHS()); |
| // Map key is the position of the dimension. |
| lvlExprComponents[pos] = components; |
| } else if (result.getKind() == AffineExprKind::Mod) { |
| auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition(); |
| assert(lvlExprComponents.find(pos) != lvlExprComponents.end() && |
| "expected floordiv before mod"); |
| // Add level variable for mod to the same vector |
| // of the corresponding floordiv. |
| lvlExprComponents[pos].push_back(getAffineDimExpr(i, context)); |
| } else { |
| assert(false && "expected floordiv or mod"); |
| } |
| } else { |
| lvlExprs.push_back(getAffineDimExpr(i, context)); |
| } |
| } |
| // Build lvlExprs from lvlExprComponents. |
| // For example, for il = i floordiv 2 and ii = i mod 2, the components |
| // would be [il, 2, ii]. It could be used to build the AffineExpr |
| // i = il * 2 + ii in lvlToDim. |
| for (auto &components : lvlExprComponents) { |
| assert(components.second.size() == 3 && |
| "expected 3 components to build lvlExprs"); |
| auto mulOp = getAffineBinaryOpExpr( |
| AffineExprKind::Mul, components.second[0], components.second[1]); |
| auto addOp = |
| getAffineBinaryOpExpr(AffineExprKind::Add, mulOp, components.second[2]); |
| lvlExprs.push_back(addOp); |
| } |
| return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context); |
| } |
| |
| SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) { |
| assert(isBlockSparsity(dimToLvl) && |
| "expected dimToLvl to be block sparsity for calling getBlockSize"); |
| SmallVector<unsigned> blockSize; |
| for (auto result : dimToLvl.getResults()) { |
| if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) { |
| if (result.getKind() == AffineExprKind::Mod) { |
| blockSize.push_back( |
| dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue()); |
| } |
| } else { |
| blockSize.push_back(0); |
| } |
| } |
| return blockSize; |
| } |
| |
| bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) { |
| if (!dimToLvl) |
| return false; |
| std::map<unsigned, int64_t> coeffientMap; |
| bool hasBlock = false; |
| for (auto result : dimToLvl.getResults()) { |
| if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) { |
| // Check for "dim op const". |
| auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS()); |
| auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS()); |
| if (!dimOp || !conOp || conOp.getValue() <= 0) |
| return false; |
| // Inspect "dim / const" or "dim % const". |
| auto pos = dimOp.getPosition(); |
| if (binOp.getKind() == AffineExprKind::FloorDiv) { |
| // Expect only one floordiv for each dimension. |
| if (coeffientMap.find(pos) != coeffientMap.end()) |
| return false; |
| // Record coefficient of the floordiv. |
| coeffientMap[pos] = conOp.getValue(); |
| } else if (binOp.getKind() == AffineExprKind::Mod) { |
| // Expect floordiv before mod. |
| if (coeffientMap.find(pos) == coeffientMap.end()) |
| return false; |
| // Expect mod to have the same coefficient as floordiv. |
| if (conOp.getValue() != coeffientMap[pos]) |
| return false; |
| hasBlock = true; |
| } else { |
| return false; |
| } |
| } else if (auto dimOp = dyn_cast<AffineDimExpr>(result)) { |
| auto pos = dimOp.getPosition(); |
| // Expect dim to be unset. |
| if (coeffientMap.find(pos) != coeffientMap.end()) |
| return false; |
| coeffientMap[pos] = 0; |
| } else { |
| return false; |
| } |
| } |
| return hasBlock; |
| } |
| |
| bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) { |
| auto hasNonIdentityMap = [](Value v) { |
| auto stt = tryGetSparseTensorType(v); |
| return stt && !stt->isIdentity(); |
| }; |
| |
| return llvm::any_of(op->getOperands(), hasNonIdentityMap) || |
| llvm::any_of(op->getResults(), hasNonIdentityMap); |
| } |
| |
| Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) { |
| if (enc) { |
| assert(enc.isPermutation() && "Non permutation map not supported"); |
| if (const auto dimToLvl = enc.getDimToLvl()) |
| return dimToLvl.getDimPosition(l); |
| } |
| return l; |
| } |
| |
| Level mlir::sparse_tensor::toLvl(SparseTensorEncodingAttr enc, Dimension d) { |
| if (enc) { |
| assert(enc.isPermutation() && "Non permutation map not supported"); |
| if (const auto lvlToDim = enc.getLvlToDim()) |
| return lvlToDim.getDimPosition(d); |
| } |
| return d; |
| } |
| |
| /// We normalized sparse tensor encoding attribute by always using |
| /// ordered/unique LT such that "compressed_nu_no" and "compressed_nu" (as well |
| /// as other variants) lead to the same storage specifier type, and stripping |
| /// irrelevant fields that do not alter the sparse tensor memory layout. |
| static SparseTensorEncodingAttr |
| getNormalizedEncodingForSpecifier(SparseTensorEncodingAttr enc) { |
| SmallVector<LevelType> lts; |
| for (auto lt : enc.getLvlTypes()) |
| lts.push_back(lt.stripStorageIrrelevantProperties()); |
| |
| return SparseTensorEncodingAttr::get( |
| enc.getContext(), lts, |
| AffineMap(), // dimToLvl (irrelevant to storage specifier) |
| AffineMap(), // lvlToDim (irrelevant to storage specifier) |
| // Always use `index` for memSize and lvlSize instead of reusing |
| // `getPosWidth` and `getCrdWidth`. It allows us to reuse the same SSA |
| // value for different bitwidth, it also avoids casting between index and |
| // integer (returned by DimOp) |
| 0, 0, |
| Attribute(), // explicitVal (irrelevant to storage specifier) |
| Attribute(), // implicitVal (irrelevant to storage specifier) |
| enc.getDimSlices()); |
| } |
| |
| StorageSpecifierType |
| StorageSpecifierType::get(MLIRContext *ctx, SparseTensorEncodingAttr encoding) { |
| return Base::get(ctx, getNormalizedEncodingForSpecifier(encoding)); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // SparseTensorDialect Operations. |
| //===----------------------------------------------------------------------===// |
| |
| static LogicalResult lvlIsInBounds(Level lvl, Value tensor) { |
| return success(lvl < getSparseTensorType(tensor).getLvlRank()); |
| } |
| |
| static LogicalResult isMatchingWidth(Value mem, unsigned width) { |
| const Type etp = getMemRefType(mem).getElementType(); |
| return success(width == 0 ? etp.isIndex() : etp.isInteger(width)); |
| } |
| |
| static LogicalResult verifySparsifierGetterSetter( |
| StorageSpecifierKind mdKind, std::optional<Level> lvl, |
| TypedValue<StorageSpecifierType> md, Operation *op) { |
| if (mdKind == StorageSpecifierKind::ValMemSize && lvl) { |
| return op->emitError( |
| "redundant level argument for querying value memory size"); |
| } |
| |
| const auto enc = md.getType().getEncoding(); |
| const Level lvlRank = enc.getLvlRank(); |
| |
| if (mdKind == StorageSpecifierKind::DimOffset || |
| mdKind == StorageSpecifierKind::DimStride) |
| if (!enc.isSlice()) |
| return op->emitError("requested slice data on non-slice tensor"); |
| |
| if (mdKind != StorageSpecifierKind::ValMemSize) { |
| if (!lvl) |
| return op->emitError("missing level argument"); |
| |
| const Level l = lvl.value(); |
| if (l >= lvlRank) |
| return op->emitError("requested level is out of bounds"); |
| |
| if (mdKind == StorageSpecifierKind::PosMemSize && enc.isSingletonLvl(l)) |
| return op->emitError( |
| "requested position memory size on a singleton level"); |
| } |
| return success(); |
| } |
| |
| static Type getFieldElemType(SparseTensorType stt, SparseTensorFieldKind kind) { |
| switch (kind) { |
| case SparseTensorFieldKind::CrdMemRef: |
| return stt.getCrdType(); |
| case SparseTensorFieldKind::PosMemRef: |
| return stt.getPosType(); |
| case SparseTensorFieldKind::ValMemRef: |
| return stt.getElementType(); |
| case SparseTensorFieldKind::StorageSpec: |
| return nullptr; |
| } |
| llvm_unreachable("Unrecognizable FieldKind"); |
| } |
| |
| static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape, |
| SparseTensorType stt, |
| RankedTensorType valTp, |
| TypeRange lvlTps) { |
| if (requiresStaticShape && !stt.hasStaticDimShape()) |
| return op->emitError("the sparse-tensor must have static shape"); |
| if (!stt.hasEncoding()) |
| return op->emitError("the sparse-tensor must have an encoding attribute"); |
| |
| // Verifies the trailing COO. |
| Level cooStartLvl = stt.getAoSCOOStart(); |
| if (cooStartLvl < stt.getLvlRank()) { |
| // We only supports trailing COO for now, must be the last input. |
| auto cooTp = llvm::cast<ShapedType>(lvlTps.back()); |
| // The coordinates should be in shape of <? x rank> |
| unsigned expCOORank = stt.getLvlRank() - cooStartLvl; |
| if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) { |
| op->emitError("input/output trailing COO level-ranks don't match"); |
| } |
| } |
| |
| // Verifies that all types match. |
| StorageLayout layout(stt.getEncoding()); |
| if (layout.getNumDataFields() != lvlTps.size() + 1) // plus one value memref |
| return op->emitError("inconsistent number of fields between input/output"); |
| |
| unsigned idx = 0; |
| bool misMatch = false; |
| layout.foreachField([&idx, &misMatch, stt, valTp, |
| lvlTps](FieldIndex fid, SparseTensorFieldKind fKind, |
| Level lvl, LevelType lt) -> bool { |
| if (fKind == SparseTensorFieldKind::StorageSpec) |
| return true; |
| |
| Type inputTp = nullptr; |
| if (fKind == SparseTensorFieldKind::ValMemRef) { |
| inputTp = valTp; |
| } else { |
| assert(fid == idx && stt.getLvlType(lvl) == lt); |
| inputTp = lvlTps[idx++]; |
| } |
| // The input element type and expected element type should match. |
| Type inpElemTp = llvm::cast<TensorType>(inputTp).getElementType(); |
| Type expElemTp = getFieldElemType(stt, fKind); |
| if (inpElemTp != expElemTp) { |
| misMatch = true; |
| return false; // to terminate the iteration |
| } |
| return true; |
| }); |
| |
| if (misMatch) |
| return op->emitError("input/output element-types don't match"); |
| return success(); |
| } |
| |
| LogicalResult AssembleOp::verify() { |
| const auto valuesTp = getRankedTensorType(getValues()); |
| const auto lvlsTp = getLevels().getTypes(); |
| const auto resTp = getSparseTensorType(getResult()); |
| return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp); |
| } |
| |
| LogicalResult DisassembleOp::verify() { |
| if (getOutValues().getType() != getRetValues().getType()) |
| return emitError("output values and return value type mismatch"); |
| |
| for (auto [ot, rt] : llvm::zip_equal(getOutLevels(), getRetLevels())) |
| if (ot.getType() != rt.getType()) |
| return emitError("output levels and return levels type mismatch"); |
| |
| const auto valuesTp = getRankedTensorType(getRetValues()); |
| const auto lvlsTp = getRetLevels().getTypes(); |
| const auto srcTp = getSparseTensorType(getTensor()); |
| return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp); |
| } |
| |
| LogicalResult ConvertOp::verify() { |
| if (auto tp1 = llvm::dyn_cast<RankedTensorType>(getSource().getType())) { |
| if (auto tp2 = llvm::dyn_cast<RankedTensorType>(getDest().getType())) { |
| if (tp1.getRank() != tp2.getRank()) |
| return emitError("unexpected conversion mismatch in rank"); |
| auto dstEnc = |
| llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding()); |
| if (dstEnc && dstEnc.isSlice()) |
| return emitError("cannot convert to a sparse tensor slice"); |
| |
| auto shape1 = tp1.getShape(); |
| auto shape2 = tp2.getShape(); |
| // Accept size matches between the source and the destination type |
| // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or |
| // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10). |
| for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++) |
| if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic) |
| return emitError("unexpected conversion mismatch in dimension ") << d; |
| return success(); |
| } |
| } |
| return emitError("unexpected type in convert"); |
| } |
| |
| OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) { |
| if (getType() == getSource().getType()) |
| return getSource(); |
| return {}; |
| } |
| |
| bool ConvertOp::needsExtraSort() { |
| SparseTensorType srcStt = getSparseTensorType(getSource()); |
| SparseTensorType dstStt = getSparseTensorType(getDest()); |
| |
| // We do not need an extra sort when returning unordered sparse tensors or |
| // dense tensor since dense tensor support random access. |
| if (dstStt.isAllDense() || !dstStt.isAllOrdered()) |
| return false; |
| |
| if (srcStt.isAllOrdered() && dstStt.isAllOrdered() && |
| srcStt.hasSameDimToLvl(dstStt)) { |
| return false; |
| } |
| |
| // Source and dest tensors are ordered in different ways. We only do direct |
| // dense to sparse conversion when the dense input is defined by a sparse |
| // constant. Note that we can theoretically always directly convert from dense |
| // inputs by rotating dense loops but it leads to bad cache locality and hurt |
| // performance. |
| if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>()) |
| if (isa<SparseElementsAttr>(constOp.getValue())) |
| return false; |
| |
| return true; |
| } |
| |
| LogicalResult CrdTranslateOp::verify() { |
| uint64_t inRank = getEncoder().getLvlRank(); |
| uint64_t outRank = getEncoder().getDimRank(); |
| |
| if (getDirection() == CrdTransDirectionKind::dim2lvl) |
| std::swap(inRank, outRank); |
| |
| if (inRank != getInCrds().size() || outRank != getOutCrds().size()) |
| return emitError("Coordinate rank mismatch with encoding"); |
| |
| return success(); |
| } |
| |
| LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor, |
| SmallVectorImpl<OpFoldResult> &results) { |
| if (getEncoder().isIdentity()) { |
| results.assign(getInCrds().begin(), getInCrds().end()); |
| return success(); |
| } |
| if (getEncoder().isPermutation()) { |
| AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl |
| ? getEncoder().getDimToLvl() |
| : getEncoder().getLvlToDim(); |
| for (AffineExpr exp : perm.getResults()) |
| results.push_back(getInCrds()[cast<AffineDimExpr>(exp).getPosition()]); |
| return success(); |
| } |
| |
| // Fuse dim2lvl/lvl2dim pairs. |
| auto def = getInCrds()[0].getDefiningOp<CrdTranslateOp>(); |
| bool sameDef = def && llvm::all_of(getInCrds(), [def](Value v) { |
| return v.getDefiningOp() == def; |
| }); |
| if (!sameDef) |
| return failure(); |
| |
| bool oppositeDir = def.getDirection() != getDirection(); |
| bool sameOracle = |
| def.getEncoder().getDimToLvl() == getEncoder().getDimToLvl(); |
| bool sameCount = def.getNumResults() == getInCrds().size(); |
| if (!oppositeDir || !sameOracle || !sameCount) |
| return failure(); |
| |
| // The definition produces the coordinates in the same order as the input |
| // coordinates. |
| bool sameOrder = llvm::all_of(llvm::zip_equal(def.getOutCrds(), getInCrds()), |
| [](auto valuePair) { |
| auto [lhs, rhs] = valuePair; |
| return lhs == rhs; |
| }); |
| |
| if (!sameOrder) |
| return failure(); |
| // l1 = dim2lvl (lvl2dim l0) |
| // ==> l0 |
| results.append(def.getInCrds().begin(), def.getInCrds().end()); |
| return success(); |
| } |
| |
| void LvlOp::build(OpBuilder &builder, OperationState &state, Value source, |
| int64_t index) { |
| Value val = builder.create<arith::ConstantIndexOp>(state.location, index); |
| return build(builder, state, source, val); |
| } |
| |
| LogicalResult LvlOp::verify() { |
| if (std::optional<uint64_t> lvl = getConstantLvlIndex()) { |
| auto stt = getSparseTensorType(getSource()); |
| if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank()) |
| emitError("Level index exceeds the rank of the input sparse tensor"); |
| } |
| return success(); |
| } |
| |
| std::optional<uint64_t> LvlOp::getConstantLvlIndex() { |
| return getConstantIntValue(getIndex()); |
| } |
| |
| Speculation::Speculatability LvlOp::getSpeculatability() { |
| auto constantIndex = getConstantLvlIndex(); |
| if (!constantIndex) |
| return Speculation::NotSpeculatable; |
| |
| assert(constantIndex < |
| cast<RankedTensorType>(getSource().getType()).getRank()); |
| return Speculation::Speculatable; |
| } |
| |
| OpFoldResult LvlOp::fold(FoldAdaptor adaptor) { |
| auto lvlIndex = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex()); |
| if (!lvlIndex) |
| return {}; |
| |
| Level lvl = lvlIndex.getAPSInt().getZExtValue(); |
| auto stt = getSparseTensorType(getSource()); |
| if (lvl >= stt.getLvlRank()) { |
| // Follows the same convention used by tensor.dim operation. Out of bound |
| // indices produce undefined behavior but are still valid IR. Don't choke on |
| // them. |
| return {}; |
| } |
| |
| // Helper lambda to build an IndexAttr. |
| auto getIndexAttr = [this](int64_t lvlSz) { |
| return IntegerAttr::get(IndexType::get(getContext()), APInt(64, lvlSz)); |
| }; |
| |
| SmallVector<Size> lvlShape = stt.getLvlShape(); |
| if (!ShapedType::isDynamic(lvlShape[lvl])) |
| return getIndexAttr(lvlShape[lvl]); |
| |
| return {}; |
| } |
| |
| void ReinterpretMapOp::build(OpBuilder &odsBuilder, OperationState &odsState, |
| SparseTensorEncodingAttr dstEnc, Value source) { |
| auto srcStt = getSparseTensorType(source); |
| SmallVector<int64_t> srcLvlShape = srcStt.getLvlShape(); |
| SmallVector<int64_t> dstDimShape = |
| dstEnc.translateShape(srcLvlShape, CrdTransDirectionKind::lvl2dim); |
| auto dstTp = |
| RankedTensorType::get(dstDimShape, srcStt.getElementType(), dstEnc); |
| return build(odsBuilder, odsState, dstTp, source); |
| } |
| |
| LogicalResult ReinterpretMapOp::verify() { |
| auto srcStt = getSparseTensorType(getSource()); |
| auto dstStt = getSparseTensorType(getDest()); |
| ArrayRef<LevelType> srcLvlTps = srcStt.getLvlTypes(); |
| ArrayRef<LevelType> dstLvlTps = dstStt.getLvlTypes(); |
| |
| if (srcLvlTps.size() != dstLvlTps.size()) |
| return emitError("Level rank mismatch between source/dest tensors"); |
| |
| for (auto [srcLvlTp, dstLvlTp] : llvm::zip(srcLvlTps, dstLvlTps)) |
| if (srcLvlTp != dstLvlTp) |
| return emitError("Level type mismatch between source/dest tensors"); |
| |
| if (srcStt.getPosWidth() != dstStt.getPosWidth() || |
| srcStt.getCrdWidth() != dstStt.getCrdWidth()) { |
| return emitError("Crd/Pos width mismatch between source/dest tensors"); |
| } |
| |
| if (srcStt.getElementType() != dstStt.getElementType()) |
| return emitError("Element type mismatch between source/dest tensors"); |
| |
| SmallVector<Size> srcLvlShape = srcStt.getLvlShape(); |
| SmallVector<Size> dstLvlShape = dstStt.getLvlShape(); |
| for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) { |
| if (srcLvlSz != dstLvlSz) { |
| // Should we allow one side to be dynamic size, e.g., <?x?> should be |
| // compatible to <3x4>? For now, we require all the level sizes to be |
| // *exactly* matched for simplicity. |
| return emitError("Level size mismatch between source/dest tensors"); |
| } |
| } |
| |
| return success(); |
| } |
| |
| OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) { |
| if (getSource().getType() == getDest().getType()) |
| return getSource(); |
| |
| if (auto def = getSource().getDefiningOp<ReinterpretMapOp>()) { |
| // A -> B, B -> A ==> A |
| if (def.getSource().getType() == getDest().getType()) |
| return def.getSource(); |
| } |
| return {}; |
| } |
| |
| template <typename ToBufferOp> |
| static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr, |
| OpaqueProperties prop, |
| RegionRange region, |
| SmallVectorImpl<mlir::Type> &ret) { |
| typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region); |
| SparseTensorType stt = getSparseTensorType(adaptor.getTensor()); |
| Type elemTp = nullptr; |
| bool withStride = false; |
| if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) { |
| elemTp = stt.getPosType(); |
| } else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> || |
| std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) { |
| elemTp = stt.getCrdType(); |
| if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>) |
| withStride = stt.getAoSCOOStart() <= adaptor.getLevel(); |
| } else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) { |
| elemTp = stt.getElementType(); |
| } |
| |
| assert(elemTp && "unhandled operation."); |
| SmallVector<int64_t> bufShape = stt.getBatchLvlShape(); |
| bufShape.push_back(ShapedType::kDynamic); |
| |
| auto layout = withStride ? StridedLayoutAttr::StridedLayoutAttr::get( |
| stt.getContext(), ShapedType::kDynamic, |
| {ShapedType::kDynamic}) |
| : StridedLayoutAttr(); |
| ret.emplace_back(MemRefType::get(bufShape, elemTp, layout)); |
| return success(); |
| } |
| |
| LogicalResult ToPositionsOp::verify() { |
| auto stt = getSparseTensorType(getTensor()); |
| if (failed(lvlIsInBounds(getLevel(), getTensor()))) |
| return emitError("requested level is out of bounds"); |
| if (failed(isMatchingWidth(getResult(), stt.getPosWidth()))) |
| return emitError("unexpected type for positions"); |
| return success(); |
| } |
| |
| LogicalResult |
| ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc, |
| ValueRange ops, DictionaryAttr attr, |
| OpaqueProperties prop, RegionRange region, |
| SmallVectorImpl<mlir::Type> &ret) { |
| return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret); |
| } |
| |
| LogicalResult ToCoordinatesOp::verify() { |
| auto stt = getSparseTensorType(getTensor()); |
| if (failed(lvlIsInBounds(getLevel(), getTensor()))) |
| return emitError("requested level is out of bounds"); |
| if (failed(isMatchingWidth(getResult(), stt.getCrdWidth()))) |
| return emitError("unexpected type for coordinates"); |
| return success(); |
| } |
| |
| LogicalResult |
| ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc, |
| ValueRange ops, DictionaryAttr attr, |
| OpaqueProperties prop, RegionRange region, |
| SmallVectorImpl<mlir::Type> &ret) { |
| return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret); |
| } |
| |
| LogicalResult ToCoordinatesBufferOp::verify() { |
| auto stt = getSparseTensorType(getTensor()); |
| if (stt.getAoSCOOStart() >= stt.getLvlRank()) |
| return emitError("expected sparse tensor with a COO region"); |
| return success(); |
| } |
| |
| LogicalResult ToCoordinatesBufferOp::inferReturnTypes( |
| MLIRContext *ctx, std::optional<Location> loc, ValueRange ops, |
| DictionaryAttr attr, OpaqueProperties prop, RegionRange region, |
| SmallVectorImpl<mlir::Type> &ret) { |
| return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region, |
| ret); |
| } |
| |
| LogicalResult ToValuesOp::verify() { |
| auto stt = getSparseTensorType(getTensor()); |
| auto mtp = getMemRefType(getResult()); |
| if (stt.getElementType() != mtp.getElementType()) |
| return emitError("unexpected mismatch in element types"); |
| return success(); |
| } |
| |
| LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx, |
| std::optional<Location> loc, |
| ValueRange ops, DictionaryAttr attr, |
| OpaqueProperties prop, |
| RegionRange region, |
| SmallVectorImpl<mlir::Type> &ret) { |
| return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret); |
| } |
| |
| LogicalResult ToSliceOffsetOp::verify() { |
| auto rank = getRankedTensorType(getSlice()).getRank(); |
| if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0) |
| return emitError("requested dimension out of bound"); |
| return success(); |
| } |
| |
| LogicalResult ToSliceStrideOp::verify() { |
| auto rank = getRankedTensorType(getSlice()).getRank(); |
| if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0) |
| return emitError("requested dimension out of bound"); |
| return success(); |
| } |
| |
| LogicalResult GetStorageSpecifierOp::verify() { |
| return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(), |
| getSpecifier(), getOperation()); |
| } |
| |
| template <typename SpecifierOp> |
| static SetStorageSpecifierOp getSpecifierSetDef(SpecifierOp op) { |
| return op.getSpecifier().template getDefiningOp<SetStorageSpecifierOp>(); |
| } |
| |
| OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) { |
| const StorageSpecifierKind kind = getSpecifierKind(); |
| const auto lvl = getLevel(); |
| for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op)) |
| if (kind == op.getSpecifierKind() && lvl == op.getLevel()) |
| return op.getValue(); |
| return {}; |
| } |
| |
| LogicalResult SetStorageSpecifierOp::verify() { |
| return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(), |
| getSpecifier(), getOperation()); |
| } |
| |
| template <class T> |
| static LogicalResult verifyNumBlockArgs(T *op, Region ®ion, |
| const char *regionName, |
| TypeRange inputTypes, Type outputType) { |
| unsigned numArgs = region.getNumArguments(); |
| unsigned expectedNum = inputTypes.size(); |
| if (numArgs != expectedNum) |
| return op->emitError() << regionName << " region must have exactly " |
| << expectedNum << " arguments"; |
| |
| for (unsigned i = 0; i < numArgs; i++) { |
| Type typ = region.getArgument(i).getType(); |
| if (typ != inputTypes[i]) |
| return op->emitError() << regionName << " region argument " << (i + 1) |
| << " type mismatch"; |
| } |
| Operation *term = region.front().getTerminator(); |
| YieldOp yield = dyn_cast<YieldOp>(term); |
| if (!yield) |
| return op->emitError() << regionName |
| << " region must end with sparse_tensor.yield"; |
| if (!yield.hasSingleResult() || |
| yield.getSingleResult().getType() != outputType) |
| return op->emitError() << regionName << " region yield type mismatch"; |
| |
| return success(); |
| } |
| |
| LogicalResult BinaryOp::verify() { |
| NamedAttrList attrs = (*this)->getAttrs(); |
| Type leftType = getX().getType(); |
| Type rightType = getY().getType(); |
| Type outputType = getOutput().getType(); |
| Region &overlap = getOverlapRegion(); |
| Region &left = getLeftRegion(); |
| Region &right = getRightRegion(); |
| |
| // Check correct number of block arguments and return type for each |
| // non-empty region. |
| if (!overlap.empty()) { |
| if (failed(verifyNumBlockArgs(this, overlap, "overlap", |
| TypeRange{leftType, rightType}, outputType))) |
| return failure(); |
| } |
| if (!left.empty()) { |
| if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType}, |
| outputType))) |
| return failure(); |
| } else if (getLeftIdentity()) { |
| if (leftType != outputType) |
| return emitError("left=identity requires first argument to have the same " |
| "type as the output"); |
| } |
| if (!right.empty()) { |
| if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType}, |
| outputType))) |
| return failure(); |
| } else if (getRightIdentity()) { |
| if (rightType != outputType) |
| return emitError("right=identity requires second argument to have the " |
| "same type as the output"); |
| } |
| return success(); |
| } |
| |
| LogicalResult UnaryOp::verify() { |
| Type inputType = getX().getType(); |
| Type outputType = getOutput().getType(); |
| |
| // Check correct number of block arguments and return type for each |
| // non-empty region. |
| Region &present = getPresentRegion(); |
| if (!present.empty()) { |
| if (failed(verifyNumBlockArgs(this, present, "present", |
| TypeRange{inputType}, outputType))) |
| return failure(); |
| } |
| Region &absent = getAbsentRegion(); |
| if (!absent.empty()) { |
| if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{}, |
| outputType))) |
| return failure(); |
| // Absent branch can only yield invariant values. |
| Block *absentBlock = &absent.front(); |
| Block *parent = getOperation()->getBlock(); |
| Value absentVal = |
| cast<YieldOp>(absentBlock->getTerminator()).getSingleResult(); |
| if (auto arg = dyn_cast<BlockArgument>(absentVal)) { |
| if (arg.getOwner() == parent) |
| return emitError("absent region cannot yield linalg argument"); |
| } else if (Operation *def = absentVal.getDefiningOp()) { |
| if (!isa<arith::ConstantOp>(def) && |
| (def->getBlock() == absentBlock || def->getBlock() == parent)) |
| return emitError("absent region cannot yield locally computed value"); |
| } |
| } |
| return success(); |
| } |
| |
| bool ConcatenateOp::needsExtraSort() { |
| SparseTensorType dstStt = getSparseTensorType(*this); |
| if (dstStt.isAllDense() || !dstStt.isAllOrdered()) |
| return false; |
| |
| bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) { |
| return getSparseTensorType(op).hasSameDimToLvl(dstStt); |
| }); |
| // TODO: When conDim != 0, as long as conDim corresponding to the first level |
| // in all input/output buffers, and all input/output buffers have the same |
| // dimToLvl, the tmp COO buffer is still unnecessary (e.g, concatenate |
| // CSC matrices along column). |
| bool directLowerable = |
| allSameOrdered && getDimension() == 0 && dstStt.isIdentity(); |
| return !directLowerable; |
| } |
| |
| LogicalResult ConcatenateOp::verify() { |
| const auto dstTp = getSparseTensorType(*this); |
| const Dimension concatDim = getDimension(); |
| const Dimension dimRank = dstTp.getDimRank(); |
| |
| if (getInputs().size() <= 1) |
| return emitError("Need at least two tensors to concatenate."); |
| |
| if (concatDim >= dimRank) |
| return emitError(llvm::formatv( |
| "Concat-dimension is out of bounds for dimension-rank ({0} >= {1})", |
| concatDim, dimRank)); |
| |
| for (const auto &it : llvm::enumerate(getInputs())) { |
| const auto i = it.index(); |
| const auto srcTp = getSparseTensorType(it.value()); |
| if (srcTp.hasDynamicDimShape()) |
| return emitError(llvm::formatv("Input tensor ${0} has dynamic shape", i)); |
| const Dimension srcDimRank = srcTp.getDimRank(); |
| if (srcDimRank != dimRank) |
| return emitError( |
| llvm::formatv("Input tensor ${0} has a different rank (rank={1}) " |
| "from the output tensor (rank={2}).", |
| i, srcDimRank, dimRank)); |
| } |
| |
| for (Dimension d = 0; d < dimRank; d++) { |
| const Size dstSh = dstTp.getDimShape()[d]; |
| if (d == concatDim) { |
| if (!ShapedType::isDynamic(dstSh)) { |
| // If we reach here, then all inputs have static shapes. So we |
| // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)` |
| // to avoid redundant assertions in the loop. |
| Size sumSz = 0; |
| for (const auto src : getInputs()) |
| sumSz += getSparseTensorType(src).getDimShape()[d]; |
| // If all dimension are statically known, the sum of all the input |
| // dimensions should be equal to the output dimension. |
| if (sumSz != dstSh) |
| return emitError( |
| "The concatenation dimension of the output tensor should be the " |
| "sum of all the concatenation dimensions of the input tensors."); |
| } |
| } else { |
| Size prev = dstSh; |
| for (const auto src : getInputs()) { |
| const auto sh = getSparseTensorType(src).getDimShape()[d]; |
| if (!ShapedType::isDynamic(prev) && sh != prev) |
| return emitError("All dimensions (expect for the concatenating one) " |
| "should be equal."); |
| prev = sh; |
| } |
| } |
| } |
| |
| return success(); |
| } |
| |
| void PushBackOp::build(OpBuilder &builder, OperationState &result, |
| Value curSize, Value inBuffer, Value value) { |
| build(builder, result, curSize, inBuffer, value, Value()); |
| } |
| |
| LogicalResult PushBackOp::verify() { |
| if (Value n = getN()) { |
| std::optional<int64_t> nValue = getConstantIntValue(n); |
| if (nValue && nValue.value() < 1) |
| return emitOpError("n must be not less than 1"); |
| } |
| return success(); |
| } |
| |
| LogicalResult CompressOp::verify() { |
| const auto stt = getSparseTensorType(getTensor()); |
| if (stt.getLvlRank() != 1 + static_cast<Level>(getLvlCoords().size())) |
| return emitOpError("incorrect number of coordinates"); |
| return success(); |
| } |
| |
| void ForeachOp::build( |
| OpBuilder &builder, OperationState &result, Value tensor, |
| ValueRange initArgs, AffineMapAttr order, |
| function_ref<void(OpBuilder &, Location, ValueRange, Value, ValueRange)> |
| bodyBuilder) { |
| build(builder, result, initArgs.getTypes(), tensor, initArgs, order); |
| // Builds foreach body. |
| if (!bodyBuilder) |
| return; |
| const auto stt = getSparseTensorType(tensor); |
| const Dimension dimRank = stt.getDimRank(); |
| |
| // Starts with `dimRank`-many coordinates. |
| SmallVector<Type> blockArgTypes(dimRank, builder.getIndexType()); |
| // Followed by one value. |
| blockArgTypes.push_back(stt.getElementType()); |
| // Followed by the reduction variables. |
| blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end()); |
| |
| SmallVector<Location> blockArgLocs(blockArgTypes.size(), tensor.getLoc()); |
| |
| OpBuilder::InsertionGuard guard(builder); |
| auto ®ion = *result.regions.front(); |
| Block *bodyBlock = |
| builder.createBlock(®ion, region.end(), blockArgTypes, blockArgLocs); |
| bodyBuilder(builder, result.location, |
| bodyBlock->getArguments().slice(0, dimRank), |
| bodyBlock->getArguments()[dimRank], |
| bodyBlock->getArguments().drop_front(dimRank + 1)); |
| } |
| |
| LogicalResult ForeachOp::verify() { |
| const auto t = getSparseTensorType(getTensor()); |
| const Dimension dimRank = t.getDimRank(); |
| const auto args = getBody()->getArguments(); |
| |
| if (getOrder().has_value() && getOrder()->getNumDims() != t.getLvlRank()) |
| return emitError("Level traverse order does not match tensor's level rank"); |
| |
| if (dimRank + 1 + getInitArgs().size() != args.size()) |
| return emitError("Unmatched number of arguments in the block"); |
| |
| if (getNumResults() != getInitArgs().size()) |
| return emitError("Mismatch in number of init arguments and results"); |
| |
| if (getResultTypes() != getInitArgs().getTypes()) |
| return emitError("Mismatch in types of init arguments and results"); |
| |
| // Cannot mark this const, because the getters aren't. |
| auto yield = cast<YieldOp>(getBody()->getTerminator()); |
| if (yield.getNumOperands() != getNumResults() || |
| yield.getOperands().getTypes() != getResultTypes()) |
| return emitError("Mismatch in types of yield values and results"); |
| |
| const auto iTp = IndexType::get(getContext()); |
| for (Dimension d = 0; d < dimRank; d++) |
| if (args[d].getType() != iTp) |
| emitError( |
| llvm::formatv("Expecting Index type for argument at index {0}", d)); |
| |
| const auto elemTp = t.getElementType(); |
| const auto valueTp = args[dimRank].getType(); |
| if (elemTp != valueTp) |
| emitError(llvm::formatv("Unmatched element type between input tensor and " |
| "block argument, expected:{0}, got: {1}", |
| elemTp, valueTp)); |
| return success(); |
| } |
| |
| OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) { |
| if (getSparseTensorEncoding(getInputCoo().getType()) == |
| getSparseTensorEncoding(getResultCoo().getType())) |
| return getInputCoo(); |
| |
| return {}; |
| } |
| |
| LogicalResult ReorderCOOOp::verify() { |
| SparseTensorType srcStt = getSparseTensorType(getInputCoo()); |
| SparseTensorType dstStt = getSparseTensorType(getResultCoo()); |
| |
| if (!srcStt.isCOOType() || !dstStt.isCOOType()) |
| emitError("Expected COO sparse tensors only"); |
| |
| if (!srcStt.hasSameDimToLvl(dstStt)) |
| emitError("Unmatched dim2lvl map between input and result COO"); |
| |
| if (srcStt.getPosType() != dstStt.getPosType() || |
| srcStt.getCrdType() != dstStt.getCrdType() || |
| srcStt.getElementType() != dstStt.getElementType()) |
| emitError("Unmatched storage format between input and result COO"); |
| |
| return success(); |
| } |
| |
| LogicalResult ReduceOp::verify() { |
| Type inputType = getX().getType(); |
| Region &formula = getRegion(); |
| return verifyNumBlockArgs(this, formula, "reduce", |
| TypeRange{inputType, inputType}, inputType); |
| } |
| |
| LogicalResult SelectOp::verify() { |
| Builder b(getContext()); |
| Type inputType = getX().getType(); |
| Type boolType = b.getI1Type(); |
| Region &formula = getRegion(); |
| return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType}, |
| boolType); |
| } |
| |
| LogicalResult SortOp::verify() { |
| AffineMap xPerm = getPermMap(); |
| uint64_t nx = xPerm.getNumDims(); |
| if (nx < 1) |
| emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx)); |
| |
| if (!xPerm.isPermutation()) |
| emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm)); |
| |
| // We can't check the size of the buffers when n or buffer dimensions aren't |
| // compile-time constants. |
| std::optional<int64_t> cn = getConstantIntValue(getN()); |
| if (!cn) |
| return success(); |
| |
| // Verify dimensions. |
| const auto checkDim = [&](Value v, Size minSize, const char *message) { |
| const Size sh = getMemRefType(v).getShape()[0]; |
| if (!ShapedType::isDynamic(sh) && sh < minSize) |
| emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize)); |
| }; |
| uint64_t n = cn.value(); |
| uint64_t ny = 0; |
| if (auto nyAttr = getNyAttr()) |
| ny = nyAttr.getInt(); |
| checkDim(getXy(), n * (nx + ny), |
| "Expected dimension(xy) >= n * (rank(perm_map) + ny)"); |
| for (Value opnd : getYs()) |
| checkDim(opnd, n, "Expected dimension(y) >= n"); |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Sparse Tensor Iteration Operations. |
| //===----------------------------------------------------------------------===// |
| |
| IterSpaceType IteratorType::getIterSpaceType() const { |
| return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(), |
| getHiLvl()); |
| } |
| |
| IteratorType IterSpaceType::getIteratorType() const { |
| return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl()); |
| } |
| |
| /// Parses a level range in the form "$lo `to` $hi" |
| /// or simply "$lo" if $hi - $lo = 1 |
| static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo, |
| Level &lvlHi) { |
| if (parser.parseInteger(lvlLo)) |
| return failure(); |
| |
| if (succeeded(parser.parseOptionalKeyword("to"))) { |
| if (parser.parseInteger(lvlHi)) |
| return failure(); |
| } else { |
| lvlHi = lvlLo + 1; |
| } |
| |
| if (lvlHi <= lvlLo) |
| parser.emitError(parser.getNameLoc(), |
| "expect larger level upper bound than lower bound"); |
| |
| return success(); |
| } |
| |
| /// Parses a level range in the form "$lo `to` $hi" |
| /// or simply "$lo" if $hi - $lo = 1 |
| static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr, |
| IntegerAttr &lvlHiAttr) { |
| Level lvlLo, lvlHi; |
| if (parseLevelRange(parser, lvlLo, lvlHi)) |
| return failure(); |
| |
| lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo); |
| lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi); |
| return success(); |
| } |
| |
| /// Prints a level range in the form "$lo `to` $hi" |
| /// or simply "$lo" if $hi - $lo = 1 |
| static void printLevelRange(AsmPrinter &p, Level lo, Level hi) { |
| |
| if (lo + 1 == hi) |
| p << lo; |
| else |
| p << lo << " to " << hi; |
| } |
| |
| /// Prints a level range in the form "$lo `to` $hi" |
| /// or simply "$lo" if $hi - $lo = 1 |
| static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo, |
| IntegerAttr lvlHi) { |
| unsigned lo = lvlLo.getValue().getZExtValue(); |
| unsigned hi = lvlHi.getValue().getZExtValue(); |
| printLevelRange(p, lo, hi); |
| } |
| |
| LogicalResult ExtractIterSpaceOp::inferReturnTypes( |
| MLIRContext *ctx, std::optional<Location> loc, ValueRange ops, |
| DictionaryAttr attr, OpaqueProperties prop, RegionRange region, |
| SmallVectorImpl<mlir::Type> &ret) { |
| |
| ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region); |
| SparseTensorType stt = getSparseTensorType(adaptor.getTensor()); |
| ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(), |
| adaptor.getHiLvl())); |
| return success(); |
| } |
| |
| LogicalResult ExtractIterSpaceOp::verify() { |
| if (getLoLvl() >= getHiLvl()) |
| return emitOpError("expected smaller level low than level high"); |
| |
| TypedValue<IteratorType> pIter = getParentIter(); |
| if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) { |
| return emitOpError( |
| "parent iterator should be specified iff level lower bound equals 0"); |
| } |
| |
| if (pIter) { |
| IterSpaceType spaceTp = getResultSpace().getType(); |
| if (pIter.getType().getEncoding() != spaceTp.getEncoding()) |
| return emitOpError( |
| "mismatch in parent iterator encoding and iteration space encoding."); |
| |
| if (spaceTp.getLoLvl() != pIter.getType().getHiLvl()) |
| return emitOpError("parent iterator should be used to extract an " |
| "iteration space from a consecutive level."); |
| } |
| |
| return success(); |
| } |
| |
| /// Materialize a single constant operation from a given attribute value with |
| /// the desired resultant type. |
| Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder, |
| Attribute value, Type type, |
| Location loc) { |
| if (auto op = arith::ConstantOp::materialize(builder, value, type, loc)) |
| return op; |
| return nullptr; |
| } |
| |
| namespace { |
| struct SparseTensorAsmDialectInterface : public OpAsmDialectInterface { |
| using OpAsmDialectInterface::OpAsmDialectInterface; |
| |
| AliasResult getAlias(Attribute attr, raw_ostream &os) const override { |
| if (isa<SparseTensorEncodingAttr>(attr)) { |
| os << "sparse"; |
| return AliasResult::OverridableAlias; |
| } |
| return AliasResult::NoAlias; |
| } |
| }; |
| } // namespace |
| |
| void SparseTensorDialect::initialize() { |
| addInterface<SparseTensorAsmDialectInterface>(); |
| addAttributes< |
| #define GET_ATTRDEF_LIST |
| #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc" |
| >(); |
| addTypes< |
| #define GET_TYPEDEF_LIST |
| #include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc" |
| >(); |
| addOperations< |
| #define GET_OP_LIST |
| #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" |
| >(); |
| declarePromisedInterfaces< |
| bufferization::BufferizableOpInterface, ConcatenateOp, ConvertOp, LoadOp, |
| NewOp, NumberOfEntriesOp, AssembleOp, DisassembleOp, |
| ToCoordinatesBufferOp, ToCoordinatesOp, ToPositionsOp, ToValuesOp>(); |
| } |
| |
| #define GET_OP_CLASSES |
| #include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.cpp.inc" |
| |
| #include "mlir/Dialect/SparseTensor/IR/SparseTensorOpsDialect.cpp.inc" |