| //===- Enums.h - Enums for the SparseTensor dialect -------------*- C++ -*-===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // Typedefs and enums shared between MLIR code for manipulating the |
| // IR, and the lightweight runtime support library for sparse tensor |
| // manipulations. That is, all the enums are used to define the API |
| // of the runtime library and hence are also needed when generating |
// calls into the runtime library. Moreover, the `LevelType` struct
// is also used as the internal IR encoding of level types,
// to avoid code duplication (e.g., for the predicates).
| // |
| // This file also defines x-macros <https://en.wikipedia.org/wiki/X_Macro> |
| // so that we can generate variations of the public functions for each |
| // supported primary- and/or overhead-type. |
| // |
| // Because this file defines a library which is a dependency of the |
| // runtime library itself, this file must not depend on any MLIR internals |
| // (e.g., operators, attributes, ArrayRefs, etc) lest the runtime library |
| // inherit those dependencies. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #ifndef MLIR_DIALECT_SPARSETENSOR_IR_ENUMS_H |
| #define MLIR_DIALECT_SPARSETENSOR_IR_ENUMS_H |
| |
| // NOTE: Client code will need to include "mlir/ExecutionEngine/Float16bits.h" |
| // if they want to use the `MLIR_SPARSETENSOR_FOREVERY_V` macro. |
| |
#include <cassert>
#include <cinttypes>
#include <complex>
#include <optional>
#include <string>
#include <type_traits>
#include <vector>
| |
| namespace mlir { |
| namespace sparse_tensor { |
| |
/// Integer type used throughout the public runtime API wherever MLIR
/// expects a value of the builtin `index` type. The runtime currently
/// assumes `index` is 64 bits wide; targets whose `index` has a different
/// bitwidth should link against an alternatively built runtime support
/// library.
using index_type = std::uint64_t;
| |
/// Selects the integer type used for overhead storage (both positions and
/// coordinates) when "overloading" @newSparseTensor. The numeric values are
/// part of the runtime ABI and must stay fixed.
enum class OverheadType : uint32_t {
  kIndex = 0,
  kU64 = 1,
  kU32 = 2,
  kU16 = 3,
  kU8 = 4,
};
| |
// X-macro invoking `DO(width, type)` once per fixed-width overhead type,
// from widest to narrowest; `width` is the type's bit width. `index_type`
// is left out on purpose: it is often handled specially (e.g., translated
// into the architecture-dependent equivalent fixed-width overhead type).
#define MLIR_SPARSETENSOR_FOREVERY_FIXED_O(DO) \
  DO(64, uint64_t) \
  DO(32, uint32_t) \
  DO(16, uint16_t) \
  DO(8, uint8_t)
| |
// X-macro invoking `DO(width, type)` on every overhead type: first the
// fixed-width ones (via MLIR_SPARSETENSOR_FOREVERY_FIXED_O), then
// `index_type`, for which the `width` argument is 0.
#define MLIR_SPARSETENSOR_FOREVERY_O(DO) \
  MLIR_SPARSETENSOR_FOREVERY_FIXED_O(DO) \
  DO(0, index_type)
| |
// Aliases for the complex element types. They are more than shorthands:
// they fix the concrete representation the runtime uses (`std::complex`),
// as opposed to, e.g., C99's `complex double` or MLIR's `ComplexType`.
using complex32 = std::complex<float>;
using complex64 = std::complex<double>;
| |
/// Selects the element ("primary") type when "overloading" @newSparseTensor.
/// Floating-point kinds come first, then integers, then complex numbers; the
/// `is*PrimaryType` range predicates depend on this ordering.
enum class PrimaryType : uint32_t {
  kF64 = 1,
  kF32 = 2,
  kF16 = 3,
  kBF16 = 4,
  kI64 = 5,
  kI32 = 6,
  kI16 = 7,
  kI8 = 8,
  kC64 = 9,
  kC32 = 10,
};
| |
// X-macro invoking `DO(VName, V)` on every value type, where `VName` is the
// suffix of the matching `PrimaryType` enumerator (e.g. `F64` for `kF64`).
// Expanding it requires `f16`/`bf16` to be in scope (see the NOTE about
// "mlir/ExecutionEngine/Float16bits.h" near the top of this file).
#define MLIR_SPARSETENSOR_FOREVERY_V(DO) \
  DO(F64, double) \
  DO(F32, float) \
  DO(F16, f16) \
  DO(BF16, bf16) \
  DO(I64, int64_t) \
  DO(I32, int32_t) \
  DO(I16, int16_t) \
  DO(I8, int8_t) \
  DO(C64, complex64) \
  DO(C32, complex32)
| |
// Same iteration as MLIR_SPARSETENSOR_FOREVERY_V, but any extra macro
// arguments are forwarded verbatim after the value type, i.e. each call is
// `DO(VName, V, <extra args...>)`.
#define MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, ...) \
  DO(F64, double, __VA_ARGS__) \
  DO(F32, float, __VA_ARGS__) \
  DO(F16, f16, __VA_ARGS__) \
  DO(BF16, bf16, __VA_ARGS__) \
  DO(I64, int64_t, __VA_ARGS__) \
  DO(I32, int32_t, __VA_ARGS__) \
  DO(I16, int16_t, __VA_ARGS__) \
  DO(I8, int8_t, __VA_ARGS__) \
  DO(C64, complex64, __VA_ARGS__) \
  DO(C32, complex32, __VA_ARGS__)
| |
// X-macro invoking `DO(VName, V, width, O)` on every (value type, overhead
// type) pair. The outer iteration runs over overhead types (64, 32, 16, 8,
// then `index_type` with width 0); the inner one runs over value types in
// MLIR_SPARSETENSOR_FOREVERY_V order.
#define MLIR_SPARSETENSOR_FOREVERY_V_O(DO) \
  MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, 64, uint64_t) \
  MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, 32, uint32_t) \
  MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, 16, uint16_t) \
  MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, 8, uint8_t) \
  MLIR_SPARSETENSOR_FOREVERY_V_VAR(DO, 0, index_type)
| |
| constexpr bool isFloatingPrimaryType(PrimaryType valTy) { |
| return PrimaryType::kF64 <= valTy && valTy <= PrimaryType::kBF16; |
| } |
| |
| constexpr bool isIntegralPrimaryType(PrimaryType valTy) { |
| return PrimaryType::kI64 <= valTy && valTy <= PrimaryType::kI8; |
| } |
| |
| constexpr bool isRealPrimaryType(PrimaryType valTy) { |
| return PrimaryType::kF64 <= valTy && valTy <= PrimaryType::kI8; |
| } |
| |
| constexpr bool isComplexPrimaryType(PrimaryType valTy) { |
| return PrimaryType::kC64 <= valTy && valTy <= PrimaryType::kC32; |
| } |
| |
/// Selects which action @newSparseTensor performs. The numeric values are
/// passed across the runtime API and must stay fixed.
enum class Action : uint32_t {
  kEmpty = 0,
  kFromReader = 1,
  kPack = 2,
  kSortCOOInPlace = 3
};
| |
/// Enumerates all supported storage formats, independent of any level
/// properties. Every format other than `Undef` is a single bit within bits
/// 16-31, so a format is recovered from an encoded level type by masking
/// with 0xffff0000.
enum class LevelFormat : uint64_t {
  Undef = 0,
  Dense = uint64_t{1} << 16,
  Batch = uint64_t{1} << 17,
  Compressed = uint64_t{1} << 18,
  Singleton = uint64_t{1} << 19,
  LooseCompressed = uint64_t{1} << 20,
  NOutOfM = uint64_t{1} << 21,
};
| |
| constexpr bool encPowOfTwo(LevelFormat fmt) { |
| auto enc = static_cast<std::underlying_type_t<LevelFormat>>(fmt); |
| return (enc & (enc - 1)) == 0; |
| } |
| |
| // All LevelFormats must have only one bit set (power of two). |
| static_assert(encPowOfTwo(LevelFormat::Dense) && |
| encPowOfTwo(LevelFormat::Batch) && |
| encPowOfTwo(LevelFormat::Compressed) && |
| encPowOfTwo(LevelFormat::Singleton) && |
| encPowOfTwo(LevelFormat::LooseCompressed) && |
| encPowOfTwo(LevelFormat::NOutOfM)); |
| |
| template <LevelFormat... targets> |
| constexpr bool isAnyOfFmt(LevelFormat fmt) { |
| return (... || (targets == fmt)); |
| } |
| |
| /// Returns string representation of the given level format. |
| constexpr const char *toFormatString(LevelFormat lvlFmt) { |
| switch (lvlFmt) { |
| case LevelFormat::Undef: |
| return "undef"; |
| case LevelFormat::Dense: |
| return "dense"; |
| case LevelFormat::Batch: |
| return "batch"; |
| case LevelFormat::Compressed: |
| return "compressed"; |
| case LevelFormat::Singleton: |
| return "singleton"; |
| case LevelFormat::LooseCompressed: |
| return "loose_compressed"; |
| case LevelFormat::NOutOfM: |
| return "structured"; |
| } |
| return ""; |
| } |
| |
/// Enumerates the non-default level properties. Each occupies its own bit in
/// the low 16 bits of an encoded level type, so properties can be OR-ed
/// together.
enum class LevelPropNonDefault : uint64_t {
  Nonunique = uint64_t{1} << 0,
  Nonordered = uint64_t{1} << 1,
  SoA = uint64_t{1} << 2,
};
| |
| /// Returns string representation of the given level properties. |
| constexpr const char *toPropString(LevelPropNonDefault lvlProp) { |
| switch (lvlProp) { |
| case LevelPropNonDefault::Nonunique: |
| return "nonunique"; |
| case LevelPropNonDefault::Nonordered: |
| return "nonordered"; |
| case LevelPropNonDefault::SoA: |
| return "soa"; |
| } |
| return ""; |
| } |
| |
/// This type defines all the sparse representations supportable by
/// the SparseTensor dialect. We use a lightweight encoding to encode
/// the "format" per se (dense, batch, compressed, singleton,
/// loose_compressed, n-out-of-m), the "properties" (ordered, unique, SoA)
/// as well as n and m when the format is NOutOfM.
| /// The encoding is chosen for performance of the runtime library, and thus may |
| /// change in future versions; consequently, client code should use the |
| /// predicate functions defined below, rather than relying on knowledge |
| /// about the particular binary encoding. |
| /// |
| /// The `Undef` "format" is a special value used internally for cases |
| /// where we need to store an undefined or indeterminate `LevelType`. |
| /// It should not be used externally, since it does not indicate an |
| /// actual/representable format. |
| |
| struct LevelType { |
| public: |
| /// Check that the `LevelType` contains a valid (possibly undefined) value. |
| static constexpr bool isValidLvlBits(uint64_t lvlBits) { |
| auto fmt = static_cast<LevelFormat>(lvlBits & 0xffff0000); |
| const uint64_t propertyBits = lvlBits & 0xffff; |
| // If undefined/dense/batch/NOutOfM, then must be unique and ordered. |
| // Otherwise, the format must be one of the known ones. |
| return (isAnyOfFmt<LevelFormat::Undef, LevelFormat::Dense, |
| LevelFormat::Batch, LevelFormat::NOutOfM>(fmt)) |
| ? (propertyBits == 0) |
| : (isAnyOfFmt<LevelFormat::Compressed, LevelFormat::Singleton, |
| LevelFormat::LooseCompressed>(fmt)); |
| } |
| |
| /// Convert a LevelFormat to its corresponding LevelType with the given |
| /// properties. Returns std::nullopt when the properties are not applicable |
| /// for the input level format. |
| static std::optional<LevelType> |
| buildLvlType(LevelFormat lf, |
| const std::vector<LevelPropNonDefault> &properties, |
| uint64_t n = 0, uint64_t m = 0) { |
| assert((n & 0xff) == n && (m & 0xff) == m); |
| uint64_t newN = n << 32; |
| uint64_t newM = m << 40; |
| uint64_t ltBits = static_cast<uint64_t>(lf) | newN | newM; |
| for (auto p : properties) |
| ltBits |= static_cast<uint64_t>(p); |
| |
| return isValidLvlBits(ltBits) ? std::optional(LevelType(ltBits)) |
| : std::nullopt; |
| } |
| static std::optional<LevelType> buildLvlType(LevelFormat lf, bool ordered, |
| bool unique, uint64_t n = 0, |
| uint64_t m = 0) { |
| std::vector<LevelPropNonDefault> properties; |
| if (!ordered) |
| properties.push_back(LevelPropNonDefault::Nonordered); |
| if (!unique) |
| properties.push_back(LevelPropNonDefault::Nonunique); |
| return buildLvlType(lf, properties, n, m); |
| } |
| |
| /// Explicit conversion from uint64_t. |
| constexpr explicit LevelType(uint64_t bits) : lvlBits(bits) { |
| assert(isValidLvlBits(bits)); |
| }; |
| |
| /// Constructs a LevelType with the given format using all default properties. |
| /*implicit*/ LevelType(LevelFormat f) : lvlBits(static_cast<uint64_t>(f)) { |
| assert(isValidLvlBits(lvlBits) && !isa<LevelFormat::NOutOfM>()); |
| }; |
| |
| /// Converts to uint64_t |
| explicit operator uint64_t() const { return lvlBits; } |
| |
| bool operator==(const LevelType lhs) const { |
| return static_cast<uint64_t>(lhs) == lvlBits; |
| } |
| bool operator!=(const LevelType lhs) const { return !(*this == lhs); } |
| |
| LevelType stripStorageIrrelevantProperties() const { |
| // Properties other than `SoA` do not change the storage scheme of the |
| // sparse tensor. |
| constexpr uint64_t mask = |
| 0xffff & ~static_cast<uint64_t>(LevelPropNonDefault::SoA); |
| return LevelType(lvlBits & ~mask); |
| } |
| |
| /// Get N of NOutOfM level type. |
| constexpr uint64_t getN() const { |
| assert(isa<LevelFormat::NOutOfM>()); |
| return (lvlBits >> 32) & 0xff; |
| } |
| |
| /// Get M of NOutOfM level type. |
| constexpr uint64_t getM() const { |
| assert(isa<LevelFormat::NOutOfM>()); |
| return (lvlBits >> 40) & 0xff; |
| } |
| |
| /// Get the `LevelFormat` of the `LevelType`. |
| constexpr LevelFormat getLvlFmt() const { |
| return static_cast<LevelFormat>(lvlBits & 0xffff0000); |
| } |
| |
| /// Check if the `LevelType` is in the `LevelFormat`. |
| template <LevelFormat... fmt> |
| constexpr bool isa() const { |
| return (... || (getLvlFmt() == fmt)) || false; |
| } |
| |
| /// Check if the `LevelType` has the properties |
| template <LevelPropNonDefault p> |
| constexpr bool isa() const { |
| return lvlBits & static_cast<uint64_t>(p); |
| } |
| |
| /// Check if the `LevelType` is considered to be sparse. |
| constexpr bool hasSparseSemantic() const { |
| return isa<LevelFormat::Compressed, LevelFormat::Singleton, |
| LevelFormat::LooseCompressed, LevelFormat::NOutOfM>(); |
| } |
| |
| /// Check if the `LevelType` is considered to be dense-like. |
| constexpr bool hasDenseSemantic() const { |
| return isa<LevelFormat::Dense, LevelFormat::Batch>(); |
| } |
| |
| /// Check if the `LevelType` needs positions array. |
| constexpr bool isWithPosLT() const { |
| assert(!isa<LevelFormat::Undef>()); |
| return isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>(); |
| } |
| |
| /// Check if the `LevelType` needs coordinates array. |
| constexpr bool isWithCrdLT() const { |
| assert(!isa<LevelFormat::Undef>()); |
| // All sparse levels has coordinate array. |
| return hasSparseSemantic(); |
| } |
| |
| std::string toMLIRString() const { |
| std::string lvlStr = toFormatString(getLvlFmt()); |
| std::string propStr = ""; |
| if (isa<LevelFormat::NOutOfM>()) { |
| lvlStr += |
| "[" + std::to_string(getN()) + ", " + std::to_string(getM()) + "]"; |
| } |
| if (isa<LevelPropNonDefault::Nonunique>()) |
| propStr += toPropString(LevelPropNonDefault::Nonunique); |
| |
| if (isa<LevelPropNonDefault::Nonordered>()) { |
| if (!propStr.empty()) |
| propStr += ", "; |
| propStr += toPropString(LevelPropNonDefault::Nonordered); |
| } |
| if (isa<LevelPropNonDefault::SoA>()) { |
| if (!propStr.empty()) |
| propStr += ", "; |
| propStr += toPropString(LevelPropNonDefault::SoA); |
| } |
| if (!propStr.empty()) |
| lvlStr += ("(" + propStr + ")"); |
| return lvlStr; |
| } |
| |
| private: |
| /// Bit manipulations for LevelType: |
| /// |
| /// | 8-bit n | 8-bit m | 16-bit LevelFormat | 16-bit LevelProperty | |
| /// |
| uint64_t lvlBits; |
| }; |
| |
// The helpers below are kept for backward compatibility.
// TODO: remove them after the full migration.

/// Shifts `n` into the bit position an encoded NOutOfM level uses for n
/// (bits 32-39).
constexpr uint64_t nToBits(uint64_t n) {
  constexpr unsigned kNShift = 32;
  return n << kNShift;
}

/// Shifts `m` into the bit position an encoded NOutOfM level uses for m
/// (bits 40-47).
constexpr uint64_t mToBits(uint64_t m) {
  constexpr unsigned kMShift = 40;
  return m << kMShift;
}
| |
| inline std::optional<LevelType> |
| buildLevelType(LevelFormat lf, |
| const std::vector<LevelPropNonDefault> &properties, |
| uint64_t n = 0, uint64_t m = 0) { |
| return LevelType::buildLvlType(lf, properties, n, m); |
| } |
| inline std::optional<LevelType> buildLevelType(LevelFormat lf, bool ordered, |
| bool unique, uint64_t n = 0, |
| uint64_t m = 0) { |
| return LevelType::buildLvlType(lf, ordered, unique, n, m); |
| } |
| inline bool isUndefLT(LevelType lt) { return lt.isa<LevelFormat::Undef>(); } |
| inline bool isDenseLT(LevelType lt) { return lt.isa<LevelFormat::Dense>(); } |
| inline bool isBatchLT(LevelType lt) { return lt.isa<LevelFormat::Batch>(); } |
| inline bool isCompressedLT(LevelType lt) { |
| return lt.isa<LevelFormat::Compressed>(); |
| } |
| inline bool isLooseCompressedLT(LevelType lt) { |
| return lt.isa<LevelFormat::LooseCompressed>(); |
| } |
| inline bool isSingletonLT(LevelType lt) { |
| return lt.isa<LevelFormat::Singleton>(); |
| } |
| inline bool isNOutOfMLT(LevelType lt) { return lt.isa<LevelFormat::NOutOfM>(); } |
| inline bool isOrderedLT(LevelType lt) { |
| return !lt.isa<LevelPropNonDefault::Nonordered>(); |
| } |
| inline bool isUniqueLT(LevelType lt) { |
| return !lt.isa<LevelPropNonDefault::Nonunique>(); |
| } |
| inline bool isWithCrdLT(LevelType lt) { return lt.isWithCrdLT(); } |
| inline bool isWithPosLT(LevelType lt) { return lt.isWithPosLT(); } |
| inline bool isValidLT(LevelType lt) { |
| return LevelType::isValidLvlBits(static_cast<uint64_t>(lt)); |
| } |
| inline std::optional<LevelFormat> getLevelFormat(LevelType lt) { |
| LevelFormat fmt = lt.getLvlFmt(); |
| if (fmt == LevelFormat::Undef) |
| return std::nullopt; |
| return fmt; |
| } |
| inline uint64_t getN(LevelType lt) { return lt.getN(); } |
| inline uint64_t getM(LevelType lt) { return lt.getM(); } |
| inline bool isValidNOutOfMLT(LevelType lt, uint64_t n, uint64_t m) { |
| return isNOutOfMLT(lt) && lt.getN() == n && lt.getM() == m; |
| } |
| inline std::string toMLIRString(LevelType lt) { return lt.toMLIRString(); } |
| |
| /// Bit manipulations for affine encoding. |
| /// |
| /// Note that because the indices in the mappings refer to dimensions |
| /// and levels (and *not* the sizes of these dimensions and levels), the |
| /// 64-bit encoding gives ample room for a compact encoding of affine |
| /// operations in the higher bits. Pure permutations still allow for |
| /// 60-bit indices. But non-permutations reserve 20-bits for the |
| /// potential three components (index i, constant, index ii). |
| /// |
| /// The compact encoding is as follows: |
| /// |
| /// 0xffffffffffffffff |
| /// |0000 | 60-bit idx| e.g. i |
| /// |0001 floor| 20-bit const|20-bit idx| e.g. i floor c |
| /// |0010 mod | 20-bit const|20-bit idx| e.g. i mod c |
| /// |0011 mul |20-bit idx|20-bit const|20-bit idx| e.g. i + c * ii |
| /// |
| /// This encoding provides sufficient generality for currently supported |
| /// sparse tensor types. To generalize this more, we will need to provide |
| /// a broader encoding scheme for affine functions. Also, the library |
| /// encoding may be replaced with pure "direct-IR" code in the future. |
| /// |
/// Encodes a dimension expression: `i` alone, `i floor cf` (when `cf` is
/// nonzero), or `i mod cm` (when `cm` is nonzero; `cf` and `cm` may not both
/// be nonzero).
constexpr uint64_t encodeDim(uint64_t i, uint64_t cf, uint64_t cm) {
  if (cf == 0 && cm == 0) {
    // Plain index: stored unchanged, must fit in the low 60 bits.
    assert(i <= 0x0fffffffffffffffu);
    return i;
  }
  if (cf != 0) {
    // Floor: tag 0b0001, constant in bits 20-39, index in bits 0-19.
    assert(cf <= 0xfffffu && cm == 0 && i <= 0xfffffu);
    return (uint64_t{0x01u} << 60) | (cf << 20) | i;
  }
  // Mod: tag 0b0010, constant in bits 20-39, index in bits 0-19.
  assert(cm <= 0xfffffu && i <= 0xfffffu);
  return (uint64_t{0x02u} << 60) | (cm << 20) | i;
}
/// Encodes a level expression: `i` alone (when `c` is zero; `ii` is then
/// ignored), or `i + c * ii` (when `c` is nonzero).
constexpr uint64_t encodeLvl(uint64_t i, uint64_t c, uint64_t ii) {
  if (c == 0) {
    // Plain index: stored unchanged, must fit in the low 60 bits.
    assert(i <= 0x0fffffffffffffffu);
    return i;
  }
  // Mul: tag 0b0011, ii in bits 40-59, c in bits 20-39, i in bits 0-19.
  assert(c <= 0xfffffu && ii <= 0xfffffu && i <= 0xfffffu);
  return (uint64_t{0x03u} << 60) | (ii << 40) | (c << 20) | i;
}
// Tag predicates: the top 4 bits of an encoded value select its kind.
constexpr bool isEncodedFloor(uint64_t v) { return 0x01u == (v >> 60); }
constexpr bool isEncodedMod(uint64_t v) { return 0x02u == (v >> 60); }
constexpr bool isEncodedMul(uint64_t v) { return 0x03u == (v >> 60); }

// Field extractors; every field is 20 bits wide.
// Index `i`: bits 0-19.
constexpr uint64_t decodeIndex(uint64_t v) { return 0xfffffu & v; }
// Constant of a floor/mod encoding: bits 20-39.
constexpr uint64_t decodeConst(uint64_t v) { return 0xfffffu & (v >> 20); }
// Constant `c` of a mul encoding: bits 20-39.
constexpr uint64_t decodeMulc(uint64_t v) { return 0xfffffu & (v >> 20); }
// Second index `ii` of a mul encoding: bits 40-59.
constexpr uint64_t decodeMuli(uint64_t v) { return 0xfffffu & (v >> 40); }
| |
| } // namespace sparse_tensor |
| } // namespace mlir |
| |
| #endif // MLIR_DIALECT_SPARSETENSOR_IR_ENUMS_H |