| //===- SPIRVTypes.h - MLIR SPIR-V Types -------------------------*- C++ -*-===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file declares the types in the SPIR-V dialect. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #ifndef MLIR_DIALECT_SPIRV_SPIRVTYPES_H_ |
| #define MLIR_DIALECT_SPIRV_SPIRVTYPES_H_ |
| |
| #include "mlir/IR/Diagnostics.h" |
| #include "mlir/IR/Location.h" |
| #include "mlir/IR/StandardTypes.h" |
| #include "mlir/IR/TypeSupport.h" |
| #include "mlir/IR/Types.h" |
| |
| #include <tuple> |
| |
// The enum classes tied to op availability are only forward-declared here.
// Their full definitions are generated by TableGen into SPIRVEnums.h.inc, and
// declaring them up front lets other declarations inside that generated file
// refer to them.
namespace mlir {
namespace spirv {
enum class Capability : uint32_t;
enum class Extension;
enum class Version : uint32_t;
} // namespace spirv
} // namespace mlir
| |
| // Pull in all enum type definitions and utility function declarations |
| #include "mlir/Dialect/SPIRV/SPIRVEnums.h.inc" |
| // Pull in all enum type availability query function declarations |
| #include "mlir/Dialect/SPIRV/SPIRVEnumAvailability.h.inc" |
| |
| namespace mlir { |
| namespace spirv { |
| /// Returns the implied extensions for the given version. These extensions are |
| /// incorporated into the current version so they are implicitly declared when |
| /// targeting the given version. |
| ArrayRef<Extension> getImpliedExtensions(Version version); |
| |
| /// Returns the directly implied capabilities for the given capability. These |
| /// capabilities are implicitly declared by the given capability. |
| ArrayRef<Capability> getDirectImpliedCapabilities(Capability cap); |
| /// Returns the recursively implied capabilities for the given capability. These |
| /// capabilities are implicitly declared by the given capability. Compared to |
| /// the above function, this function collects implied capabilities recursively: |
| /// if an implicitly declared capability implicitly declares a third one, the |
| /// third one will also be returned. |
| SmallVector<Capability, 0> getRecursiveImpliedCapabilities(Capability cap); |
| |
namespace detail {
// Storage classes backing the SPIR-V types declared below. They are only
// forward-declared in this header; their definitions live elsewhere.
struct ArrayTypeStorage;
struct CooperativeMatrixTypeStorage;
struct ImageTypeStorage;
struct MatrixTypeStorage;
struct PointerTypeStorage;
struct RuntimeArrayTypeStorage;
struct StructTypeStorage;
} // namespace detail
| |
| namespace TypeKind { |
| enum Kind { |
| Array = Type::FIRST_SPIRV_TYPE, |
| CooperativeMatrix, |
| Image, |
| Matrix, |
| Pointer, |
| RuntimeArray, |
| Struct, |
| LAST_SPIRV_TYPE = Struct, |
| }; |
| } |
| |
| // Base SPIR-V type for providing availability queries. |
| class SPIRVType : public Type { |
| public: |
| using Type::Type; |
| |
| static bool classof(Type type); |
| |
| bool isScalarOrVector(); |
| |
| /// The extension requirements for each type are following the |
| /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D)) |
| /// convention. |
| using ExtensionArrayRefVector = SmallVectorImpl<ArrayRef<spirv::Extension>>; |
| |
| /// Appends to `extensions` the extensions needed for this type to appear in |
| /// the given `storage` class. This method does not guarantee the uniqueness |
| /// of extensions; the same extension may be appended multiple times. |
| void getExtensions(ExtensionArrayRefVector &extensions, |
| Optional<spirv::StorageClass> storage = llvm::None); |
| |
| /// The capability requirements for each type are following the |
| /// ((Capability::A OR Extension::B) AND (Capability::C OR Capability::D)) |
| /// convention. |
| using CapabilityArrayRefVector = SmallVectorImpl<ArrayRef<spirv::Capability>>; |
| |
| /// Appends to `capabilities` the capabilities needed for this type to appear |
| /// in the given `storage` class. This method does not guarantee the |
| /// uniqueness of capabilities; the same capability may be appended multiple |
| /// times. |
| void getCapabilities(CapabilityArrayRefVector &capabilities, |
| Optional<spirv::StorageClass> storage = llvm::None); |
| }; |
| |
| // SPIR-V scalar type: bool type, integer type, floating point type. |
| class ScalarType : public SPIRVType { |
| public: |
| using SPIRVType::SPIRVType; |
| |
| static bool classof(Type type); |
| |
| /// Returns true if the given integer type is valid for the SPIR-V dialect. |
| static bool isValid(FloatType); |
| /// Returns true if the given float type is valid for the SPIR-V dialect. |
| static bool isValid(IntegerType); |
| |
| void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, |
| Optional<spirv::StorageClass> storage = llvm::None); |
| void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, |
| Optional<spirv::StorageClass> storage = llvm::None); |
| }; |
| |
| // SPIR-V composite type: VectorType, SPIR-V ArrayType, or SPIR-V StructType. |
| class CompositeType : public SPIRVType { |
| public: |
| using SPIRVType::SPIRVType; |
| |
| static bool classof(Type type); |
| |
| /// Returns true if the given vector type is valid for the SPIR-V dialect. |
| static bool isValid(VectorType); |
| |
| /// Return the number of elements of the type. This should only be called if |
| /// hasCompileTimeKnownNumElements is true. |
| unsigned getNumElements() const; |
| |
| Type getElementType(unsigned) const; |
| |
| /// Return true if the number of elements is known at compile time and is not |
| /// implementation dependent. |
| bool hasCompileTimeKnownNumElements() const; |
| |
| void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, |
| Optional<spirv::StorageClass> storage = llvm::None); |
| void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, |
| Optional<spirv::StorageClass> storage = llvm::None); |
| }; |
| |
| // SPIR-V array type |
| class ArrayType : public Type::TypeBase<ArrayType, CompositeType, |
| detail::ArrayTypeStorage> { |
| public: |
| using Base::Base; |
| |
| static bool kindof(unsigned kind) { return kind == TypeKind::Array; } |
| |
| static ArrayType get(Type elementType, unsigned elementCount); |
| |
| /// Returns an array type with the given stride in bytes. |
| static ArrayType get(Type elementType, unsigned elementCount, |
| unsigned stride); |
| |
| unsigned getNumElements() const; |
| |
| Type getElementType() const; |
| |
| /// Returns the array stride in bytes. 0 means no stride decorated on this |
| /// type. |
| unsigned getArrayStride() const; |
| |
| void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, |
| Optional<spirv::StorageClass> storage = llvm::None); |
| void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, |
| Optional<spirv::StorageClass> storage = llvm::None); |
| }; |
| |
| // SPIR-V image type |
| class ImageType |
| : public Type::TypeBase<ImageType, SPIRVType, detail::ImageTypeStorage> { |
| public: |
| using Base::Base; |
| |
| static bool kindof(unsigned kind) { return kind == TypeKind::Image; } |
| |
| static ImageType |
| get(Type elementType, Dim dim, |
| ImageDepthInfo depth = ImageDepthInfo::DepthUnknown, |
| ImageArrayedInfo arrayed = ImageArrayedInfo::NonArrayed, |
| ImageSamplingInfo samplingInfo = ImageSamplingInfo::SingleSampled, |
| ImageSamplerUseInfo samplerUse = ImageSamplerUseInfo::SamplerUnknown, |
| ImageFormat format = ImageFormat::Unknown) { |
| return ImageType::get( |
| std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo, |
| ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>( |
| elementType, dim, depth, arrayed, samplingInfo, samplerUse, |
| format)); |
| } |
| |
| static ImageType |
| get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo, |
| ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>); |
| |
| Type getElementType() const; |
| Dim getDim() const; |
| ImageDepthInfo getDepthInfo() const; |
| ImageArrayedInfo getArrayedInfo() const; |
| ImageSamplingInfo getSamplingInfo() const; |
| ImageSamplerUseInfo getSamplerUseInfo() const; |
| ImageFormat getImageFormat() const; |
| // TODO: Add support for Access qualifier |
| |
| void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, |
| Optional<spirv::StorageClass> storage = llvm::None); |
| void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, |
| Optional<spirv::StorageClass> storage = llvm::None); |
| }; |
| |
| // SPIR-V pointer type |
| class PointerType : public Type::TypeBase<PointerType, SPIRVType, |
| detail::PointerTypeStorage> { |
| public: |
| using Base::Base; |
| |
| static bool kindof(unsigned kind) { return kind == TypeKind::Pointer; } |
| |
| static PointerType get(Type pointeeType, StorageClass storageClass); |
| |
| Type getPointeeType() const; |
| |
| StorageClass getStorageClass() const; |
| |
| void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, |
| Optional<spirv::StorageClass> storage = llvm::None); |
| void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, |
| Optional<spirv::StorageClass> storage = llvm::None); |
| }; |
| |
| // SPIR-V run-time array type |
| class RuntimeArrayType |
| : public Type::TypeBase<RuntimeArrayType, SPIRVType, |
| detail::RuntimeArrayTypeStorage> { |
| public: |
| using Base::Base; |
| |
| static bool kindof(unsigned kind) { return kind == TypeKind::RuntimeArray; } |
| |
| static RuntimeArrayType get(Type elementType); |
| |
| /// Returns a runtime array type with the given stride in bytes. |
| static RuntimeArrayType get(Type elementType, unsigned stride); |
| |
| Type getElementType() const; |
| |
| /// Returns the array stride in bytes. 0 means no stride decorated on this |
| /// type. |
| unsigned getArrayStride() const; |
| |
| void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, |
| Optional<spirv::StorageClass> storage = llvm::None); |
| void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, |
| Optional<spirv::StorageClass> storage = llvm::None); |
| }; |
| |
| // SPIR-V struct type |
| class StructType : public Type::TypeBase<StructType, CompositeType, |
| detail::StructTypeStorage> { |
| public: |
| using Base::Base; |
| |
| // Type for specifying the offset of the struct members |
| using OffsetInfo = uint32_t; |
| |
| // Type for specifying the decoration(s) on struct members |
| struct MemberDecorationInfo { |
| uint32_t memberIndex : 31; |
| uint32_t hasValue : 1; |
| Decoration decoration; |
| uint32_t decorationValue; |
| |
| MemberDecorationInfo(uint32_t index, uint32_t hasValue, |
| Decoration decoration, uint32_t decorationValue) |
| : memberIndex(index), hasValue(hasValue), decoration(decoration), |
| decorationValue(decorationValue) {} |
| |
| bool operator==(const MemberDecorationInfo &other) const { |
| return (this->memberIndex == other.memberIndex) && |
| (this->decoration == other.decoration) && |
| (this->decorationValue == other.decorationValue); |
| } |
| |
| bool operator<(const MemberDecorationInfo &other) const { |
| return this->memberIndex < other.memberIndex || |
| (this->memberIndex == other.memberIndex && |
| static_cast<uint32_t>(this->decoration) < |
| static_cast<uint32_t>(other.decoration)); |
| } |
| }; |
| |
| static bool kindof(unsigned kind) { return kind == TypeKind::Struct; } |
| |
| /// Construct a StructType with at least one member. |
| static StructType get(ArrayRef<Type> memberTypes, |
| ArrayRef<OffsetInfo> offsetInfo = {}, |
| ArrayRef<MemberDecorationInfo> memberDecorations = {}); |
| |
| /// Construct a struct with no members. |
| static StructType getEmpty(MLIRContext *context); |
| |
| unsigned getNumElements() const; |
| |
| Type getElementType(unsigned) const; |
| |
| /// Range class for element types. |
| class ElementTypeRange |
| : public ::llvm::detail::indexed_accessor_range_base< |
| ElementTypeRange, const Type *, Type, Type, Type> { |
| private: |
| using RangeBaseT::RangeBaseT; |
| |
| /// See `llvm::detail::indexed_accessor_range_base` for details. |
| static const Type *offset_base(const Type *object, ptrdiff_t index) { |
| return object + index; |
| } |
| /// See `llvm::detail::indexed_accessor_range_base` for details. |
| static Type dereference_iterator(const Type *object, ptrdiff_t index) { |
| return object[index]; |
| } |
| |
| /// Allow base class access to `offset_base` and `dereference_iterator`. |
| friend RangeBaseT; |
| }; |
| |
| ElementTypeRange getElementTypes() const; |
| |
| bool hasOffset() const; |
| |
| uint64_t getMemberOffset(unsigned) const; |
| |
| // Returns in `memberDecorations` the spirv::Decorations (apart from |
| // Offset) associated with all members of the StructType. |
| void getMemberDecorations(SmallVectorImpl<StructType::MemberDecorationInfo> |
| &memberDecorations) const; |
| |
| // Returns in `decorationsInfo` all the spirv::Decorations (apart from |
| // Offset) associated with the `i`-th member of the StructType. |
| void getMemberDecorations(unsigned i, |
| SmallVectorImpl<StructType::MemberDecorationInfo> |
| &decorationsInfo) const; |
| |
| void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, |
| Optional<spirv::StorageClass> storage = llvm::None); |
| void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, |
| Optional<spirv::StorageClass> storage = llvm::None); |
| }; |
| |
| llvm::hash_code |
| hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo); |
| |
| // SPIR-V cooperative matrix type |
| class CooperativeMatrixNVType |
| : public Type::TypeBase<CooperativeMatrixNVType, CompositeType, |
| detail::CooperativeMatrixTypeStorage> { |
| public: |
| using Base::Base; |
| |
| static bool kindof(unsigned kind) { |
| return kind == TypeKind::CooperativeMatrix; |
| } |
| |
| static CooperativeMatrixNVType get(Type elementType, spirv::Scope scope, |
| unsigned rows, unsigned columns); |
| Type getElementType() const; |
| |
| /// Return the scope of the cooperative matrix. |
| spirv::Scope getScope() const; |
| /// return the number of rows of the matrix. |
| unsigned getRows() const; |
| /// return the number of columns of the matrix. |
| unsigned getColumns() const; |
| |
| void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, |
| Optional<spirv::StorageClass> storage = llvm::None); |
| void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, |
| Optional<spirv::StorageClass> storage = llvm::None); |
| }; |
| |
| // SPIR-V matrix type |
| class MatrixType : public Type::TypeBase<MatrixType, CompositeType, |
| detail::MatrixTypeStorage> { |
| public: |
| using Base::Base; |
| |
| static bool kindof(unsigned kind) { return kind == TypeKind::Matrix; } |
| |
| static MatrixType get(Type columnType, uint32_t columnCount); |
| |
| static MatrixType getChecked(Type columnType, uint32_t columnCount, |
| Location location); |
| |
| static LogicalResult verifyConstructionInvariants(Location loc, |
| Type columnType, |
| uint32_t columnCount); |
| |
| /// Returns true if the matrix elements are vectors of float elements. |
| static bool isValidColumnType(Type columnType); |
| |
| Type getColumnType() const; |
| |
| /// Returns the number of rows. |
| unsigned getNumRows() const; |
| |
| /// Returns the number of columns. |
| unsigned getNumColumns() const; |
| |
| /// Returns total number of elements (rows*columns). |
| unsigned getNumElements() const; |
| |
| /// Returns the elements' type (i.e, single element type). |
| Type getElementType() const; |
| |
| void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, |
| Optional<spirv::StorageClass> storage = llvm::None); |
| void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, |
| Optional<spirv::StorageClass> storage = llvm::None); |
| }; |
| |
| } // end namespace spirv |
| } // end namespace mlir |
| |
| #endif // MLIR_DIALECT_SPIRV_SPIRVTYPES_H_ |