blob: 34d99913fbd51beef8f6de3b11df4eedaf59a93d [file] [log] [blame]
//===- SparseTensorType.h - Wrapper around RankedTensorType -----*- C++ -*-===//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// This header defines the `SparseTensorType` wrapper class.
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
namespace mlir {
namespace sparse_tensor {
/// A simple structure that encodes a range of levels in a sparse tensor
/// type that forms a COO segment.
struct COOSegment {
  /// The half-open level range `[low, high)` covered by this segment.
  std::pair<Level, Level> lvlRange; // [low, high)
  /// Whether the segment uses the SoA layout; `isAoS()` is its negation.
  /// Default-initialized so a default-constructed segment has a defined
  /// value (the struct remains an aggregate under C++14 and later).
  bool isSoA = false;

  /// Returns true iff the segment is not SoA (i.e. `!isSoA`).
  bool isAoS() const { return !isSoA; }
  /// Returns true iff `l` is the first level of the segment.
  bool isSegmentStart(Level l) const { return l == lvlRange.first; }
  /// Returns true iff `l` lies in `[lvlRange.first, lvlRange.second)`.
  bool inSegment(Level l) const {
    return l >= lvlRange.first && l < lvlRange.second;
  }
};
/// A wrapper around `RankedTensorType`, which has three goals:
/// (1) To provide a uniform API for querying aspects of sparse-tensor
/// types; in particular, to make the "dimension" vs "level" distinction
/// overt (i.e., explicit everywhere). Thus, throughout the sparsifier
/// this class should be preferred over using `RankedTensorType` or
/// `ShapedType` directly, since the methods of the latter do not make
/// the "dimension" vs "level" distinction overt.
/// (2) To provide a uniform abstraction over both sparse-tensor
/// types (i.e., `RankedTensorType` with `SparseTensorEncodingAttr`)
/// and dense-tensor types (i.e., `RankedTensorType` without an encoding).
/// That is, we want to manipulate dense-tensor types using the same API
/// that we use for manipulating sparse-tensor types; both to keep the
/// "dimension" vs "level" distinction overt, and to avoid needing to
/// handle certain cases specially in the sparsifier.
/// (3) To provide uniform handling of "defaults". In particular
/// this means that dense-tensors should always return the same answers
/// as sparse-tensors with a default encoding. But it additionally means
/// that the answers should be normalized, so that there's no way to
/// distinguish between non-provided data (which is filled in by default)
/// vs explicitly-provided data which equals the defaults.
class SparseTensorType {
// We memoize `lvlRank`, `dimToLvl`, and `lvlToDim` to avoid repeating
// the conditionals throughout the rest of the class.
SparseTensorType(RankedTensorType rtp)
: rtp(rtp), enc(getSparseTensorEncoding(rtp)),
lvlRank(enc ? enc.getLvlRank() : getDimRank()),
dimToLvl(enc.isIdentity() ? AffineMap() : enc.getDimToLvl()),
lvlToDim(enc.isIdentity() ? AffineMap() : enc.getLvlToDim()) {
assert(rtp && "got null RankedTensorType");
assert((!isIdentity() || getDimRank() == lvlRank) && "Rank mismatch");
SparseTensorType(ShapedType stp, SparseTensorEncodingAttr enc)
: SparseTensorType(
RankedTensorType::get(stp.getShape(), stp.getElementType(), enc)) {}
// TODO: remove?
SparseTensorType(SparseTensorEncodingAttr enc)
: SparseTensorType(RankedTensorType::get(
SmallVector<Size>(enc.getDimRank(), ShapedType::kDynamic),
Float32Type::get(enc.getContext()), enc)) {}
SparseTensorType &operator=(const SparseTensorType &) = delete;
SparseTensorType(const SparseTensorType &) = default;
// Factory methods to construct a new `SparseTensorType`
// with the same dimension-shape and element type.
SparseTensorType withEncoding(SparseTensorEncodingAttr newEnc) const {
return SparseTensorType(rtp, newEnc);
SparseTensorType withDimToLvl(AffineMap dimToLvl) const {
return withEncoding(enc.withDimToLvl(dimToLvl));
SparseTensorType withDimToLvl(SparseTensorEncodingAttr dimToLvlEnc) const {
return withEncoding(enc.withDimToLvl(dimToLvlEnc));
SparseTensorType withDimToLvl(const SparseTensorType &dimToLvlSTT) const {
return withDimToLvl(dimToLvlSTT.getEncoding());
SparseTensorType withoutDimToLvl() const {
return withEncoding(enc.withoutDimToLvl());
SparseTensorType withBitWidths(unsigned posWidth, unsigned crdWidth) const {
return withEncoding(enc.withBitWidths(posWidth, crdWidth));
SparseTensorType withoutBitWidths() const {
return withEncoding(enc.withoutBitWidths());
SparseTensorType withExplicitVal(Attribute explicitVal) const {
return withEncoding(enc.withExplicitVal(explicitVal));
SparseTensorType withoutExplicitVal() const {
return withEncoding(enc.withoutExplicitVal());
SparseTensorType withImplicitVal(Attribute implicitVal) const {
return withEncoding(enc.withImplicitVal(implicitVal));
SparseTensorType withoutImplicitVal() const {
return withEncoding(enc.withoutImplicitVal());
withDimSlices(ArrayRef<SparseTensorDimSliceAttr> dimSlices) const {
return withEncoding(enc.withDimSlices(dimSlices));
SparseTensorType withoutDimSlices() const {
return withEncoding(enc.withoutDimSlices());
/// Allow implicit conversion to `RankedTensorType`, `ShapedType`,
/// and `Type`. These are implicit to help alleviate the impedance
/// mismatch for code that has not been converted to use `SparseTensorType`
/// directly. Once more uses have been converted to `SparseTensorType`,
/// we may want to make these explicit instead.
/// WARNING: This user-defined-conversion method causes overload
/// ambiguity whenever passing a `SparseTensorType` directly to a
/// function which is overloaded to accept either `Type` or `TypeRange`.
/// In particular, this includes `RewriterBase::replaceOpWithNewOp<OpTy>`
/// and `OpBuilder::create<OpTy>` whenever the `OpTy::build` is overloaded
/// thus. This happens because the `TypeRange<T>(T&&)` ctor is implicit
/// as well, and there's no SFINAE we can add to this method that would
/// block subsequent application of that ctor. The only way to fix the
/// overload ambiguity is to avoid *implicit* conversion at the callsite:
/// e.g., by using `static_cast` to make the conversion explicit, by
/// assigning the `SparseTensorType` to a temporary variable of the
/// desired type, etc.
// NOTE: We implement this as a single templated user-defined-conversion
// function to avoid ambiguity problems when the desired result is `Type`
// (since both `RankedTensorType` and `ShapedType` can be implicitly
// converted to `Type`).
template <typename T, typename = std::enable_if_t<
std::is_convertible_v<RankedTensorType, T>>>
/*implicit*/ operator T() const {
return rtp;
/// Explicitly convert to `RankedTensorType`. This method is
/// a convenience for resolving overload-ambiguity issues with
/// implicit conversion.
RankedTensorType getRankedTensorType() const { return rtp; }
bool operator==(const SparseTensorType &other) const {
// All other fields are derived from `rtp` and therefore don't need
// to be checked.
return rtp == other.rtp;
bool operator!=(const SparseTensorType &other) const {
return !(*this == other);
MLIRContext *getContext() const { return rtp.getContext(); }
Type getElementType() const { return rtp.getElementType(); }
SparseTensorEncodingAttr getEncoding() const { return enc; }
// SparseTensorEncodingAttr delegators
/// Returns true for tensors which have an encoding, and false for
/// those which do not. Therefore tensors with an all-dense encoding
/// return true.
bool hasEncoding() const { return static_cast<bool>(enc); }
/// Returns true for tensors where every level is dense.
/// (This is always true for dense-tensors.)
bool isAllDense() const { return enc.isAllDense(); }
/// Returns true for tensors where every level is ordered.
/// (This is always true for dense-tensors.)
bool isAllOrdered() const { return enc.isAllOrdered(); }
/// Translates between level / dimension coordinate space.
ValueRange translateCrds(OpBuilder &builder, Location loc, ValueRange crds,
CrdTransDirectionKind dir) const {
return enc.translateCrds(builder, loc, crds, dir);
/// Returns true if the dimToLvl mapping is a permutation.
/// (This is always true for dense-tensors.)
bool isPermutation() const { return enc.isPermutation(); }
/// Returns true if the dimToLvl mapping is the identity.
/// (This is always true for dense-tensors.)
bool isIdentity() const { return enc.isIdentity(); }
// Other methods.
/// Returns the dimToLvl mapping (or the null-map for the identity).
/// If you intend to compare the results of this method for equality,
/// see `hasSameDimToLvl` instead.
AffineMap getDimToLvl() const { return dimToLvl; }
/// Returns the lvlToDiml mapping (or the null-map for the identity).
AffineMap getLvlToDim() const { return lvlToDim; }
/// Returns the dimToLvl mapping, where the identity map is expanded out
/// into a full `AffineMap`. This method is provided as a convenience,
/// but for most purposes other methods (`isIdentity`, `getDimToLvl`,
/// etc) will be more helpful.
AffineMap getExpandedDimToLvl() const {
return dimToLvl
? dimToLvl
: AffineMap::getMultiDimIdentityMap(getDimRank(), getContext());
/// Returns true iff the two types have the same mapping. This method
/// takes care to handle identity maps properly, so it should be preferred
/// over using `getDimToLvl` followed by `AffineMap::operator==`.
bool hasSameDimToLvl(const SparseTensorType &other) const {
// If the maps are the identity, then we need to check the rank
// to be sure they're the same size identity. (And since identity
// means dimRank==lvlRank, we use lvlRank as a minor optimization.)
return isIdentity() ? (other.isIdentity() && lvlRank == other.lvlRank)
: (dimToLvl == other.dimToLvl);
/// Returns the dimension-rank.
Dimension getDimRank() const { return rtp.getRank(); }
/// Returns the level-rank.
Level getLvlRank() const { return lvlRank; }
/// Returns the dimension-shape.
ArrayRef<Size> getDimShape() const { return rtp.getShape(); }
/// Returns the level-shape.
SmallVector<Size> getLvlShape() const {
return getEncoding().translateShape(getDimShape(),
/// Returns the batched level-rank.
unsigned getBatchLvlRank() const { return getEncoding().getBatchLvlRank(); }
/// Returns the batched level-shape.
SmallVector<Size> getBatchLvlShape() const {
auto lvlShape = getEncoding().translateShape(
getDimShape(), CrdTransDirectionKind::dim2lvl);
return lvlShape;
/// Returns the type with an identity mapping.
RankedTensorType getDemappedType() const {
return RankedTensorType::get(getLvlShape(), getElementType(),
/// Safely looks up the requested dimension-DynSize. If you intend
/// to check the result with `ShapedType::isDynamic`, then see the
/// `getStaticDimSize` method instead.
Size getDynamicDimSize(Dimension d) const {
assert(d < getDimRank() && "Dimension is out of bounds");
return getDimShape()[d];
/// Returns true if no dimension has dynamic size.
bool hasStaticDimShape() const { return rtp.hasStaticShape(); }
/// Returns true if any dimension has dynamic size.
bool hasDynamicDimShape() const { return !hasStaticDimShape(); }
/// Returns true if the given dimension has dynamic size. If you
/// intend to call `getDynamicDimSize` based on the result, then see
/// the `getStaticDimSize` method instead.
bool isDynamicDim(Dimension d) const {
// We don't use `rtp.isDynamicDim(d)` because we want the
// OOB error message to be consistent with `getDynamicDimSize`.
return ShapedType::isDynamic(getDynamicDimSize(d));
/// Returns the number of dimensions which have dynamic sizes.
/// The return type is `int64_t` to maintain consistency with
/// `ShapedType::Trait<T>::getNumDynamicDims`.
int64_t getNumDynamicDims() const { return rtp.getNumDynamicDims(); }
ArrayRef<LevelType> getLvlTypes() const { return enc.getLvlTypes(); }
LevelType getLvlType(Level l) const {
// This OOB check is for dense-tensors, since this class knows
// their lvlRank (whereas STEA::getLvlType will/can only check
// OOB for sparse-tensors).
assert(l < lvlRank && "Level out of bounds");
return enc.getLvlType(l);
// We can't just delegate these, since we want to use this class's
// `getLvlType` method instead of STEA's.
bool isDenseLvl(Level l) const { return isDenseLT(getLvlType(l)); }
bool isCompressedLvl(Level l) const { return isCompressedLT(getLvlType(l)); }
bool isLooseCompressedLvl(Level l) const {
return isLooseCompressedLT(getLvlType(l));
bool isSingletonLvl(Level l) const { return isSingletonLT(getLvlType(l)); }
bool isNOutOfMLvl(Level l) const { return isNOutOfMLT(getLvlType(l)); }
bool isOrderedLvl(Level l) const { return isOrderedLT(getLvlType(l)); }
bool isUniqueLvl(Level l) const { return isUniqueLT(getLvlType(l)); }
bool isWithPos(Level l) const { return isWithPosLT(getLvlType(l)); }
bool isWithCrd(Level l) const { return isWithCrdLT(getLvlType(l)); }
/// Returns the coordinate-overhead bitwidth, defaulting to zero.
unsigned getCrdWidth() const { return enc ? enc.getCrdWidth() : 0; }
/// Returns the position-overhead bitwidth, defaulting to zero.
unsigned getPosWidth() const { return enc ? enc.getPosWidth() : 0; }
/// Returns the explicit value, defaulting to null Attribute for unset.
Attribute getExplicitVal() const { return enc.getExplicitVal(); }
/// Returns the implicit value, defaulting to null Attribute for 0.
Attribute getImplicitVal() const { return enc.getImplicitVal(); }
/// Returns the coordinate-overhead MLIR type, defaulting to `IndexType`.
Type getCrdType() const { return enc.getCrdElemType(); }
/// Returns the position-overhead MLIR type, defaulting to `IndexType`.
Type getPosType() const { return enc.getPosElemType(); }
/// Returns true iff this sparse tensor type has a trailing
/// COO region starting at the given level. By default, it
/// tests for a unique COO type at top level.
bool isCOOType(Level startLvl = 0, bool isUnique = true) const;
/// Returns the starting level of this sparse tensor type for a
/// trailing COO region that spans **at least** two levels. If
/// no such COO region is found, then returns the level-rank.
/// DEPRECATED: use getCOOSegment instead;
Level getAoSCOOStart() const;
/// Returns [un]ordered COO type for this sparse tensor type.
RankedTensorType getCOOType(bool ordered) const;
/// Returns a list of COO segments in the sparse tensor types.
SmallVector<COOSegment> getCOOSegments() const;
// These two must be const, to ensure coherence of the memoized fields.
const RankedTensorType rtp;
const SparseTensorEncodingAttr enc;
// Memoized to avoid frequent redundant conditionals.
const Level lvlRank;
const AffineMap dimToLvl;
const AffineMap lvlToDim;
/// Convenience method to obtain a `SparseTensorType` from a `Value`.
/// The value's type is converted with `cast<RankedTensorType>`, so the
/// value is required to have a `RankedTensorType`.
inline SparseTensorType getSparseTensorType(Value val) {
  return SparseTensorType(cast<RankedTensorType>(val.getType()));
}
/// Convenience method to obtain a `SparseTensorType` from a `Value`, or
/// `std::nullopt` when the value's type is not a `RankedTensorType`.
inline std::optional<SparseTensorType> tryGetSparseTensorType(Value val) {
  if (auto rtp = dyn_cast<RankedTensorType>(val.getType()))
    return SparseTensorType(rtp);
  return std::nullopt;
}
} // namespace sparse_tensor
} // namespace mlir