| //===- LLVMTypes.cpp - MLIR LLVM Dialect types ----------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements the types for the LLVM dialect in MLIR. These MLIR types |
| // correspond to the LLVM IR type system. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "TypeDetail.h" |
| |
| #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| #include "mlir/Dialect/LLVMIR/LLVMTypes.h" |
| #include "mlir/IR/DialectImplementation.h" |
| #include "mlir/IR/TypeSupport.h" |
| |
| #include "llvm/Support/TypeSize.h" |
| |
| using namespace mlir; |
| using namespace mlir::LLVM; |
| |
| //===----------------------------------------------------------------------===// |
| // LLVMType. |
| //===----------------------------------------------------------------------===// |
| |
| bool LLVMType::classof(Type type) { |
| return llvm::isa<LLVMDialect>(type.getDialect()); |
| } |
| |
| LLVMDialect &LLVMType::getDialect() { |
| return static_cast<LLVMDialect &>(Type::getDialect()); |
| } |
| |
| //----------------------------------------------------------------------------// |
| // Integer type utilities. |
| |
| bool LLVMType::isIntegerTy(unsigned bitwidth) { |
| if (auto intType = dyn_cast<LLVMIntegerType>()) |
| return intType.getBitWidth() == bitwidth; |
| return false; |
| } |
| unsigned LLVMType::getIntegerBitWidth() { |
| return cast<LLVMIntegerType>().getBitWidth(); |
| } |
| |
| LLVMType LLVMType::getArrayElementType() { |
| return cast<LLVMArrayType>().getElementType(); |
| } |
| |
| //----------------------------------------------------------------------------// |
| // Array type utilities. |
| |
| unsigned LLVMType::getArrayNumElements() { |
| return cast<LLVMArrayType>().getNumElements(); |
| } |
| |
| bool LLVMType::isArrayTy() { return isa<LLVMArrayType>(); } |
| |
| //----------------------------------------------------------------------------// |
| // Vector type utilities. |
| |
| LLVMType LLVMType::getVectorElementType() { |
| return cast<LLVMVectorType>().getElementType(); |
| } |
| |
| unsigned LLVMType::getVectorNumElements() { |
| return cast<LLVMFixedVectorType>().getNumElements(); |
| } |
| llvm::ElementCount LLVMType::getVectorElementCount() { |
| return cast<LLVMVectorType>().getElementCount(); |
| } |
| |
| bool LLVMType::isVectorTy() { return isa<LLVMVectorType>(); } |
| |
| //----------------------------------------------------------------------------// |
| // Function type utilities. |
| |
| LLVMType LLVMType::getFunctionParamType(unsigned argIdx) { |
| return cast<LLVMFunctionType>().getParamType(argIdx); |
| } |
| |
| unsigned LLVMType::getFunctionNumParams() { |
| return cast<LLVMFunctionType>().getNumParams(); |
| } |
| |
| LLVMType LLVMType::getFunctionResultType() { |
| return cast<LLVMFunctionType>().getReturnType(); |
| } |
| |
| bool LLVMType::isFunctionTy() { return isa<LLVMFunctionType>(); } |
| |
| bool LLVMType::isFunctionVarArg() { |
| return cast<LLVMFunctionType>().isVarArg(); |
| } |
| |
| //----------------------------------------------------------------------------// |
| // Pointer type utilities. |
| |
| LLVMType LLVMType::getPointerTo(unsigned addrSpace) { |
| return LLVMPointerType::get(*this, addrSpace); |
| } |
| |
| LLVMType LLVMType::getPointerElementTy() { |
| return cast<LLVMPointerType>().getElementType(); |
| } |
| |
| bool LLVMType::isPointerTy() { return isa<LLVMPointerType>(); } |
| |
| //----------------------------------------------------------------------------// |
| // Struct type utilities. |
| |
| LLVMType LLVMType::getStructElementType(unsigned i) { |
| return cast<LLVMStructType>().getBody()[i]; |
| } |
| |
| unsigned LLVMType::getStructNumElements() { |
| return cast<LLVMStructType>().getBody().size(); |
| } |
| |
| bool LLVMType::isStructTy() { return isa<LLVMStructType>(); } |
| |
| //----------------------------------------------------------------------------// |
| // Utilities used to generate floating point types. |
| |
| LLVMType LLVMType::getDoubleTy(MLIRContext *context) { |
| return LLVMDoubleType::get(context); |
| } |
| |
| LLVMType LLVMType::getFloatTy(MLIRContext *context) { |
| return LLVMFloatType::get(context); |
| } |
| |
| LLVMType LLVMType::getBFloatTy(MLIRContext *context) { |
| return LLVMBFloatType::get(context); |
| } |
| |
| LLVMType LLVMType::getHalfTy(MLIRContext *context) { |
| return LLVMHalfType::get(context); |
| } |
| |
| LLVMType LLVMType::getFP128Ty(MLIRContext *context) { |
| return LLVMFP128Type::get(context); |
| } |
| |
| LLVMType LLVMType::getX86_FP80Ty(MLIRContext *context) { |
| return LLVMX86FP80Type::get(context); |
| } |
| |
| //----------------------------------------------------------------------------// |
| // Utilities used to generate integer types. |
| |
| LLVMType LLVMType::getIntNTy(MLIRContext *context, unsigned numBits) { |
| return LLVMIntegerType::get(context, numBits); |
| } |
| |
| //----------------------------------------------------------------------------// |
| // Utilities used to generate other miscellaneous types. |
| |
| LLVMType LLVMType::getArrayTy(LLVMType elementType, uint64_t numElements) { |
| return LLVMArrayType::get(elementType, numElements); |
| } |
| |
| LLVMType LLVMType::getFunctionTy(LLVMType result, ArrayRef<LLVMType> params, |
| bool isVarArg) { |
| return LLVMFunctionType::get(result, params, isVarArg); |
| } |
| |
| LLVMType LLVMType::getStructTy(MLIRContext *context, |
| ArrayRef<LLVMType> elements, bool isPacked) { |
| return LLVMStructType::getLiteral(context, elements, isPacked); |
| } |
| |
| LLVMType LLVMType::getVectorTy(LLVMType elementType, unsigned numElements) { |
| return LLVMFixedVectorType::get(elementType, numElements); |
| } |
| |
| //----------------------------------------------------------------------------// |
| // Void type utilities. |
| |
| LLVMType LLVMType::getVoidTy(MLIRContext *context) { |
| return LLVMVoidType::get(context); |
| } |
| |
| bool LLVMType::isVoidTy() { return isa<LLVMVoidType>(); } |
| |
| //----------------------------------------------------------------------------// |
| // Creation and setting of LLVM's identified struct types |
| |
| LLVMType LLVMType::createStructTy(MLIRContext *context, |
| ArrayRef<LLVMType> elements, |
| Optional<StringRef> name, bool isPacked) { |
| assert(name.hasValue() && |
| "identified structs with no identifier not supported"); |
| StringRef stringNameBase = name.getValueOr(""); |
| std::string stringName = stringNameBase.str(); |
| unsigned counter = 0; |
| do { |
| auto type = LLVMStructType::getIdentified(context, stringName); |
| if (type.isInitialized() || failed(type.setBody(elements, isPacked))) { |
| counter += 1; |
| stringName = |
| (Twine(stringNameBase) + "." + std::to_string(counter)).str(); |
| continue; |
| } |
| return type; |
| } while (true); |
| } |
| |
| LLVMType LLVMType::setStructTyBody(LLVMType structType, |
| ArrayRef<LLVMType> elements, bool isPacked) { |
| LogicalResult couldSet = |
| structType.cast<LLVMStructType>().setBody(elements, isPacked); |
| assert(succeeded(couldSet) && "failed to set the body"); |
| (void)couldSet; |
| return structType; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Array type. |
| |
| bool LLVMArrayType::isValidElementType(LLVMType type) { |
| return !type.isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType, |
| LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>(); |
| } |
| |
| LLVMArrayType LLVMArrayType::get(LLVMType elementType, unsigned numElements) { |
| assert(elementType && "expected non-null subtype"); |
| return Base::get(elementType.getContext(), LLVMType::ArrayType, elementType, |
| numElements); |
| } |
| |
| LLVMArrayType LLVMArrayType::getChecked(Location loc, LLVMType elementType, |
| unsigned numElements) { |
| assert(elementType && "expected non-null subtype"); |
| return Base::getChecked(loc, LLVMType::ArrayType, elementType, numElements); |
| } |
| |
| LLVMType LLVMArrayType::getElementType() { return getImpl()->elementType; } |
| |
| unsigned LLVMArrayType::getNumElements() { return getImpl()->numElements; } |
| |
| LogicalResult |
| LLVMArrayType::verifyConstructionInvariants(Location loc, LLVMType elementType, |
| unsigned numElements) { |
| if (!isValidElementType(elementType)) |
| return emitError(loc, "invalid array element type: ") << elementType; |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Function type. |
| |
| bool LLVMFunctionType::isValidArgumentType(LLVMType type) { |
| return !type.isa<LLVMVoidType, LLVMFunctionType>(); |
| } |
| |
| bool LLVMFunctionType::isValidResultType(LLVMType type) { |
| return !type.isa<LLVMFunctionType, LLVMMetadataType, LLVMLabelType>(); |
| } |
| |
| LLVMFunctionType LLVMFunctionType::get(LLVMType result, |
| ArrayRef<LLVMType> arguments, |
| bool isVarArg) { |
| assert(result && "expected non-null result"); |
| return Base::get(result.getContext(), LLVMType::FunctionType, result, |
| arguments, isVarArg); |
| } |
| |
| LLVMFunctionType LLVMFunctionType::getChecked(Location loc, LLVMType result, |
| ArrayRef<LLVMType> arguments, |
| bool isVarArg) { |
| assert(result && "expected non-null result"); |
| return Base::getChecked(loc, LLVMType::FunctionType, result, arguments, |
| isVarArg); |
| } |
| |
| LLVMType LLVMFunctionType::getReturnType() { |
| return getImpl()->getReturnType(); |
| } |
| |
| unsigned LLVMFunctionType::getNumParams() { |
| return getImpl()->getArgumentTypes().size(); |
| } |
| |
| LLVMType LLVMFunctionType::getParamType(unsigned i) { |
| return getImpl()->getArgumentTypes()[i]; |
| } |
| |
| bool LLVMFunctionType::isVarArg() { return getImpl()->isVariadic(); } |
| |
| ArrayRef<LLVMType> LLVMFunctionType::getParams() { |
| return getImpl()->getArgumentTypes(); |
| } |
| |
| LogicalResult LLVMFunctionType::verifyConstructionInvariants( |
| Location loc, LLVMType result, ArrayRef<LLVMType> arguments, bool) { |
| if (!isValidResultType(result)) |
| return emitError(loc, "invalid function result type: ") << result; |
| |
| for (LLVMType arg : arguments) |
| if (!isValidArgumentType(arg)) |
| return emitError(loc, "invalid function argument type: ") << arg; |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Integer type. |
| |
| LLVMIntegerType LLVMIntegerType::get(MLIRContext *ctx, unsigned bitwidth) { |
| return Base::get(ctx, LLVMType::IntegerType, bitwidth); |
| } |
| |
| LLVMIntegerType LLVMIntegerType::getChecked(Location loc, unsigned bitwidth) { |
| return Base::getChecked(loc, LLVMType::IntegerType, bitwidth); |
| } |
| |
| unsigned LLVMIntegerType::getBitWidth() { return getImpl()->bitwidth; } |
| |
| LogicalResult LLVMIntegerType::verifyConstructionInvariants(Location loc, |
| unsigned bitwidth) { |
| constexpr int maxSupportedBitwidth = (1 << 24); |
| if (bitwidth >= maxSupportedBitwidth) |
| return emitError(loc, "integer type too wide"); |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Pointer type. |
| |
| bool LLVMPointerType::isValidElementType(LLVMType type) { |
| return !type.isa<LLVMVoidType, LLVMTokenType, LLVMMetadataType, |
| LLVMLabelType>(); |
| } |
| |
| LLVMPointerType LLVMPointerType::get(LLVMType pointee, unsigned addressSpace) { |
| assert(pointee && "expected non-null subtype"); |
| return Base::get(pointee.getContext(), LLVMType::PointerType, pointee, |
| addressSpace); |
| } |
| |
| LLVMPointerType LLVMPointerType::getChecked(Location loc, LLVMType pointee, |
| unsigned addressSpace) { |
| return Base::getChecked(loc, LLVMType::PointerType, pointee, addressSpace); |
| } |
| |
| LLVMType LLVMPointerType::getElementType() { return getImpl()->pointeeType; } |
| |
| unsigned LLVMPointerType::getAddressSpace() { return getImpl()->addressSpace; } |
| |
| LogicalResult LLVMPointerType::verifyConstructionInvariants(Location loc, |
| LLVMType pointee, |
| unsigned) { |
| if (!isValidElementType(pointee)) |
| return emitError(loc, "invalid pointer element type: ") << pointee; |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Struct type. |
| |
| bool LLVMStructType::isValidElementType(LLVMType type) { |
| return !type.isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType, |
| LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>(); |
| } |
| |
| LLVMStructType LLVMStructType::getIdentified(MLIRContext *context, |
| StringRef name) { |
| return Base::get(context, LLVMType::StructType, name, /*opaque=*/false); |
| } |
| |
| LLVMStructType LLVMStructType::getIdentifiedChecked(Location loc, |
| StringRef name) { |
| return Base::getChecked(loc, LLVMType::StructType, name, /*opaque=*/false); |
| } |
| |
| LLVMStructType LLVMStructType::getLiteral(MLIRContext *context, |
| ArrayRef<LLVMType> types, |
| bool isPacked) { |
| return Base::get(context, LLVMType::StructType, types, isPacked); |
| } |
| |
| LLVMStructType LLVMStructType::getLiteralChecked(Location loc, |
| ArrayRef<LLVMType> types, |
| bool isPacked) { |
| return Base::getChecked(loc, LLVMType::StructType, types, isPacked); |
| } |
| |
| LLVMStructType LLVMStructType::getOpaque(StringRef name, MLIRContext *context) { |
| return Base::get(context, LLVMType::StructType, name, /*opaque=*/true); |
| } |
| |
| LLVMStructType LLVMStructType::getOpaqueChecked(Location loc, StringRef name) { |
| return Base::getChecked(loc, LLVMType::StructType, name, /*opaque=*/true); |
| } |
| |
| LogicalResult LLVMStructType::setBody(ArrayRef<LLVMType> types, bool isPacked) { |
| assert(isIdentified() && "can only set bodies of identified structs"); |
| assert(llvm::all_of(types, LLVMStructType::isValidElementType) && |
| "expected valid body types"); |
| return Base::mutate(types, isPacked); |
| } |
| |
| bool LLVMStructType::isPacked() { return getImpl()->isPacked(); } |
| bool LLVMStructType::isIdentified() { return getImpl()->isIdentified(); } |
| bool LLVMStructType::isOpaque() { |
| return getImpl()->isOpaque() || !getImpl()->isInitialized(); |
| } |
| bool LLVMStructType::isInitialized() { return getImpl()->isInitialized(); } |
| StringRef LLVMStructType::getName() { return getImpl()->getIdentifier(); } |
| ArrayRef<LLVMType> LLVMStructType::getBody() { |
| return isIdentified() ? getImpl()->getIdentifiedStructBody() |
| : getImpl()->getTypeList(); |
| } |
| |
| LogicalResult LLVMStructType::verifyConstructionInvariants(Location, StringRef, |
| bool) { |
| return success(); |
| } |
| |
| LogicalResult |
| LLVMStructType::verifyConstructionInvariants(Location loc, |
| ArrayRef<LLVMType> types, bool) { |
| for (LLVMType t : types) |
| if (!isValidElementType(t)) |
| return emitError(loc, "invalid LLVM structure element type: ") << t; |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Vector types. |
| |
| bool LLVMVectorType::isValidElementType(LLVMType type) { |
| return type.isa<LLVMIntegerType, LLVMPointerType>() || |
| type.isFloatingPointTy(); |
| } |
| |
| /// Support type casting functionality. |
| bool LLVMVectorType::classof(Type type) { |
| return type.isa<LLVMFixedVectorType, LLVMScalableVectorType>(); |
| } |
| |
| LLVMType LLVMVectorType::getElementType() { |
| // Both derived classes share the implementation type. |
| return static_cast<detail::LLVMTypeAndSizeStorage *>(impl)->elementType; |
| } |
| |
| llvm::ElementCount LLVMVectorType::getElementCount() { |
| // Both derived classes share the implementation type. |
| return llvm::ElementCount( |
| static_cast<detail::LLVMTypeAndSizeStorage *>(impl)->numElements, |
| isa<LLVMScalableVectorType>()); |
| } |
| |
| /// Verifies that the type about to be constructed is well-formed. |
| LogicalResult |
| LLVMVectorType::verifyConstructionInvariants(Location loc, LLVMType elementType, |
| unsigned numElements) { |
| if (numElements == 0) |
| return emitError(loc, "the number of vector elements must be positive"); |
| |
| if (!isValidElementType(elementType)) |
| return emitError(loc, "invalid vector element type"); |
| |
| return success(); |
| } |
| |
| LLVMFixedVectorType LLVMFixedVectorType::get(LLVMType elementType, |
| unsigned numElements) { |
| assert(elementType && "expected non-null subtype"); |
| return Base::get(elementType.getContext(), LLVMType::FixedVectorType, |
| elementType, numElements); |
| } |
| |
| LLVMFixedVectorType LLVMFixedVectorType::getChecked(Location loc, |
| LLVMType elementType, |
| unsigned numElements) { |
| assert(elementType && "expected non-null subtype"); |
| return Base::getChecked(loc, LLVMType::FixedVectorType, elementType, |
| numElements); |
| } |
| |
| unsigned LLVMFixedVectorType::getNumElements() { |
| return getImpl()->numElements; |
| } |
| |
| LLVMScalableVectorType LLVMScalableVectorType::get(LLVMType elementType, |
| unsigned minNumElements) { |
| assert(elementType && "expected non-null subtype"); |
| return Base::get(elementType.getContext(), LLVMType::ScalableVectorType, |
| elementType, minNumElements); |
| } |
| |
| LLVMScalableVectorType |
| LLVMScalableVectorType::getChecked(Location loc, LLVMType elementType, |
| unsigned minNumElements) { |
| assert(elementType && "expected non-null subtype"); |
| return Base::getChecked(loc, LLVMType::ScalableVectorType, elementType, |
| minNumElements); |
| } |
| |
| unsigned LLVMScalableVectorType::getMinNumElements() { |
| return getImpl()->numElements; |
| } |