//===- BuiltinTypes.cpp - C Interface to MLIR Builtin Types ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/AffineMap.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/CAPI/AffineMap.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LogicalResult.h"
#include <algorithm>
#include <cassert>
using namespace mlir;
//===----------------------------------------------------------------------===//
// Integer types.
//===----------------------------------------------------------------------===//
/// Returns the TypeID of the builtin IntegerType, wrapped for the C API.
MlirTypeID mlirIntegerTypeGetTypeID() {
  TypeID integerTypeID = IntegerType::getTypeID();
  return wrap(integerTypeID);
}
/// Checks whether the given type is a builtin IntegerType.
bool mlirTypeIsAInteger(MlirType type) {
  Type cppType = unwrap(type);
  return llvm::isa<IntegerType>(cppType);
}
/// Creates an IntegerType of the given bitwidth in `ctx`, using the default
/// signedness of IntegerType::get.
MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth) {
  MLIRContext *context = unwrap(ctx);
  return wrap(IntegerType::get(context, bitwidth));
}
/// Creates an IntegerType of the given bitwidth in `ctx` with
/// IntegerType::Signed signedness.
MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth) {
  MLIRContext *context = unwrap(ctx);
  return wrap(IntegerType::get(context, bitwidth, IntegerType::Signed));
}
/// Creates an IntegerType of the given bitwidth in `ctx` with
/// IntegerType::Unsigned signedness.
MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth) {
  MLIRContext *context = unwrap(ctx);
  return wrap(IntegerType::get(context, bitwidth, IntegerType::Unsigned));
}
/// Returns the bitwidth of an integer type. `type` is expected to be an
/// IntegerType (it is converted with llvm::cast).
unsigned mlirIntegerTypeGetWidth(MlirType type) {
  auto integerType = llvm::cast<IntegerType>(unwrap(type));
  return integerType.getWidth();
}
/// Reports whether an integer type is signless. `type` is expected to be an
/// IntegerType (it is converted with llvm::cast).
bool mlirIntegerTypeIsSignless(MlirType type) {
  auto integerType = llvm::cast<IntegerType>(unwrap(type));
  return integerType.isSignless();
}
/// Reports whether an integer type is signed. `type` is expected to be an
/// IntegerType (it is converted with llvm::cast).
bool mlirIntegerTypeIsSigned(MlirType type) {
  auto integerType = llvm::cast<IntegerType>(unwrap(type));
  return integerType.isSigned();
}
/// Reports whether an integer type is unsigned. `type` is expected to be an
/// IntegerType (it is converted with llvm::cast).
bool mlirIntegerTypeIsUnsigned(MlirType type) {
  auto integerType = llvm::cast<IntegerType>(unwrap(type));
  return integerType.isUnsigned();
}
//===----------------------------------------------------------------------===//
// Index type.
//===----------------------------------------------------------------------===//
/// Returns the TypeID of the builtin IndexType, wrapped for the C API.
MlirTypeID mlirIndexTypeGetTypeID() {
  TypeID indexTypeID = IndexType::getTypeID();
  return wrap(indexTypeID);
}
/// Checks whether the given type is the builtin IndexType.
bool mlirTypeIsAIndex(MlirType type) {
  Type cppType = unwrap(type);
  return llvm::isa<IndexType>(cppType);
}
/// Creates the IndexType in the given context.
MlirType mlirIndexTypeGet(MlirContext ctx) {
  MLIRContext *context = unwrap(ctx);
  return wrap(IndexType::get(context));
}
//===----------------------------------------------------------------------===//
// Floating-point types.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a FloatType.
bool mlirTypeIsAFloat(MlirType type) {
  Type cppType = unwrap(type);
  return llvm::isa<FloatType>(cppType);
}
/// Returns the bitwidth of a floating-point type. `type` is expected to be a
/// FloatType (it is converted with llvm::cast).
unsigned mlirFloatTypeGetWidth(MlirType type) {
  auto floatType = llvm::cast<FloatType>(unwrap(type));
  return floatType.getWidth();
}
/// Returns the TypeID of Float8E5M2Type, wrapped for the C API.
MlirTypeID mlirFloat8E5M2TypeGetTypeID() {
  TypeID typeID = Float8E5M2Type::getTypeID();
  return wrap(typeID);
}
/// Checks whether the given type is Float8E5M2, as reported by
/// Type::isFloat8E5M2().
bool mlirTypeIsAFloat8E5M2(MlirType type) {
  Type cppType = unwrap(type);
  return cppType.isFloat8E5M2();
}
/// Creates the Float8E5M2 type in the given context.
MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) {
  MLIRContext *context = unwrap(ctx);
  return wrap(FloatType::getFloat8E5M2(context));
}
/// Returns the TypeID of Float8E4M3FNType, wrapped for the C API.
MlirTypeID mlirFloat8E4M3FNTypeGetTypeID() {
  TypeID typeID = Float8E4M3FNType::getTypeID();
  return wrap(typeID);
}
/// Checks whether the given type is Float8E4M3FN, as reported by
/// Type::isFloat8E4M3FN().
bool mlirTypeIsAFloat8E4M3FN(MlirType type) {
  Type cppType = unwrap(type);
  return cppType.isFloat8E4M3FN();
}
/// Creates the Float8E4M3FN type in the given context.
MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) {
  MLIRContext *context = unwrap(ctx);
  return wrap(FloatType::getFloat8E4M3FN(context));
}
/// Returns the TypeID of Float8E5M2FNUZType, wrapped for the C API.
MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID() {
  TypeID typeID = Float8E5M2FNUZType::getTypeID();
  return wrap(typeID);
}
/// Checks whether the given type is Float8E5M2FNUZ, as reported by
/// Type::isFloat8E5M2FNUZ().
bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type) {
  Type cppType = unwrap(type);
  return cppType.isFloat8E5M2FNUZ();
}
/// Creates the Float8E5M2FNUZ type in the given context.
MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx) {
  MLIRContext *context = unwrap(ctx);
  return wrap(FloatType::getFloat8E5M2FNUZ(context));
}
/// Returns the TypeID of Float8E4M3FNUZType, wrapped for the C API.
MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID() {
  TypeID typeID = Float8E4M3FNUZType::getTypeID();
  return wrap(typeID);
}
/// Checks whether the given type is Float8E4M3FNUZ, as reported by
/// Type::isFloat8E4M3FNUZ().
bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type) {
  Type cppType = unwrap(type);
  return cppType.isFloat8E4M3FNUZ();
}
/// Creates the Float8E4M3FNUZ type in the given context.
MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx) {
  MLIRContext *context = unwrap(ctx);
  return wrap(FloatType::getFloat8E4M3FNUZ(context));
}
/// Returns the TypeID of Float8E4M3B11FNUZType, wrapped for the C API.
MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID() {
  TypeID typeID = Float8E4M3B11FNUZType::getTypeID();
  return wrap(typeID);
}
/// Checks whether the given type is Float8E4M3B11FNUZ, as reported by
/// Type::isFloat8E4M3B11FNUZ().
bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type) {
  Type cppType = unwrap(type);
  return cppType.isFloat8E4M3B11FNUZ();
}
/// Creates the Float8E4M3B11FNUZ type in the given context.
MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx) {
  MLIRContext *context = unwrap(ctx);
  return wrap(FloatType::getFloat8E4M3B11FNUZ(context));
}
/// Returns the TypeID of BFloat16Type, wrapped for the C API.
MlirTypeID mlirBFloat16TypeGetTypeID() {
  TypeID typeID = BFloat16Type::getTypeID();
  return wrap(typeID);
}
/// Checks whether the given type is bf16, as reported by Type::isBF16().
bool mlirTypeIsABF16(MlirType type) {
  Type cppType = unwrap(type);
  return cppType.isBF16();
}
/// Creates the bf16 type in the given context.
MlirType mlirBF16TypeGet(MlirContext ctx) {
  MLIRContext *context = unwrap(ctx);
  return wrap(FloatType::getBF16(context));
}
/// Returns the TypeID of Float16Type, wrapped for the C API.
MlirTypeID mlirFloat16TypeGetTypeID() {
  TypeID typeID = Float16Type::getTypeID();
  return wrap(typeID);
}
/// Checks whether the given type is f16, as reported by Type::isF16().
bool mlirTypeIsAF16(MlirType type) {
  Type cppType = unwrap(type);
  return cppType.isF16();
}
/// Creates the f16 type in the given context.
MlirType mlirF16TypeGet(MlirContext ctx) {
  MLIRContext *context = unwrap(ctx);
  return wrap(FloatType::getF16(context));
}
/// Returns the TypeID of FloatTF32Type, wrapped for the C API.
MlirTypeID mlirFloatTF32TypeGetTypeID() {
  TypeID typeID = FloatTF32Type::getTypeID();
  return wrap(typeID);
}
/// Checks whether the given type is tf32, as reported by Type::isTF32().
bool mlirTypeIsATF32(MlirType type) {
  Type cppType = unwrap(type);
  return cppType.isTF32();
}
/// Creates the tf32 type in the given context.
MlirType mlirTF32TypeGet(MlirContext ctx) {
  MLIRContext *context = unwrap(ctx);
  return wrap(FloatType::getTF32(context));
}
/// Returns the TypeID of Float32Type, wrapped for the C API.
MlirTypeID mlirFloat32TypeGetTypeID() {
  TypeID typeID = Float32Type::getTypeID();
  return wrap(typeID);
}
/// Checks whether the given type is f32, as reported by Type::isF32().
bool mlirTypeIsAF32(MlirType type) {
  Type cppType = unwrap(type);
  return cppType.isF32();
}
/// Creates the f32 type in the given context.
MlirType mlirF32TypeGet(MlirContext ctx) {
  MLIRContext *context = unwrap(ctx);
  return wrap(FloatType::getF32(context));
}
/// Returns the TypeID of Float64Type, wrapped for the C API.
MlirTypeID mlirFloat64TypeGetTypeID() {
  TypeID typeID = Float64Type::getTypeID();
  return wrap(typeID);
}
/// Checks whether the given type is f64, as reported by Type::isF64().
bool mlirTypeIsAF64(MlirType type) {
  Type cppType = unwrap(type);
  return cppType.isF64();
}
/// Creates the f64 type in the given context.
MlirType mlirF64TypeGet(MlirContext ctx) {
  MLIRContext *context = unwrap(ctx);
  return wrap(FloatType::getF64(context));
}
//===----------------------------------------------------------------------===//
// None type.
//===----------------------------------------------------------------------===//
/// Returns the TypeID of the builtin NoneType, wrapped for the C API.
MlirTypeID mlirNoneTypeGetTypeID() {
  TypeID typeID = NoneType::getTypeID();
  return wrap(typeID);
}
/// Checks whether the given type is the builtin NoneType.
bool mlirTypeIsANone(MlirType type) {
  Type cppType = unwrap(type);
  return llvm::isa<NoneType>(cppType);
}
/// Creates the NoneType in the given context.
MlirType mlirNoneTypeGet(MlirContext ctx) {
  MLIRContext *context = unwrap(ctx);
  return wrap(NoneType::get(context));
}
//===----------------------------------------------------------------------===//
// Complex type.
//===----------------------------------------------------------------------===//
/// Returns the TypeID of the builtin ComplexType, wrapped for the C API.
MlirTypeID mlirComplexTypeGetTypeID() {
  TypeID typeID = ComplexType::getTypeID();
  return wrap(typeID);
}
/// Checks whether the given type is a builtin ComplexType.
bool mlirTypeIsAComplex(MlirType type) {
  Type cppType = unwrap(type);
  return llvm::isa<ComplexType>(cppType);
}
/// Creates a ComplexType whose element type is `elementType`.
MlirType mlirComplexTypeGet(MlirType elementType) {
  Type element = unwrap(elementType);
  return wrap(ComplexType::get(element));
}
/// Returns the element type of a complex type. `type` is expected to be a
/// ComplexType (it is converted with llvm::cast).
MlirType mlirComplexTypeGetElementType(MlirType type) {
  auto complexType = llvm::cast<ComplexType>(unwrap(type));
  return wrap(complexType.getElementType());
}
//===----------------------------------------------------------------------===//
// Shaped type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a ShapedType.
bool mlirTypeIsAShaped(MlirType type) {
  Type cppType = unwrap(type);
  return llvm::isa<ShapedType>(cppType);
}
/// Returns the element type of a shaped type. `type` is expected to be a
/// ShapedType (it is converted with llvm::cast).
MlirType mlirShapedTypeGetElementType(MlirType type) {
  auto shapedType = llvm::cast<ShapedType>(unwrap(type));
  return wrap(shapedType.getElementType());
}
/// Reports whether a shaped type has a rank. `type` is expected to be a
/// ShapedType (it is converted with llvm::cast).
bool mlirShapedTypeHasRank(MlirType type) {
  auto shapedType = llvm::cast<ShapedType>(unwrap(type));
  return shapedType.hasRank();
}
/// Returns the rank of a shaped type. `type` is expected to be a ShapedType
/// (it is converted with llvm::cast).
int64_t mlirShapedTypeGetRank(MlirType type) {
  auto shapedType = llvm::cast<ShapedType>(unwrap(type));
  return shapedType.getRank();
}
/// Reports whether a shaped type has a static shape. `type` is expected to be
/// a ShapedType (it is converted with llvm::cast).
bool mlirShapedTypeHasStaticShape(MlirType type) {
  auto shapedType = llvm::cast<ShapedType>(unwrap(type));
  return shapedType.hasStaticShape();
}
/// Reports whether dimension `dim` of a shaped type is dynamic.
///
/// `type` is expected to be a ShapedType (it is converted with llvm::cast).
/// `dim` is a signed C index that gets narrowed to `unsigned` before being
/// handed to ShapedType::isDynamicDim, so it must be non-negative.
bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim) {
  // Reject negative indices before the narrowing cast, in the same way as the
  // function-type accessors in this file (e.g. mlirFunctionTypeGetInput).
  assert(dim >= 0 && "dimension index must be non-negative");
  return llvm::cast<ShapedType>(unwrap(type))
      .isDynamicDim(static_cast<unsigned>(dim));
}
/// Returns the size of dimension `dim` of a shaped type.
///
/// `type` is expected to be a ShapedType (it is converted with llvm::cast).
/// `dim` is a signed C index that gets narrowed to `unsigned` before being
/// handed to ShapedType::getDimSize, so it must be non-negative.
int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim) {
  // Reject negative indices before the narrowing cast, in the same way as the
  // function-type accessors in this file (e.g. mlirFunctionTypeGetInput).
  assert(dim >= 0 && "dimension index must be non-negative");
  return llvm::cast<ShapedType>(unwrap(type))
      .getDimSize(static_cast<unsigned>(dim));
}
/// Returns ShapedType::kDynamic, the value that marks a dynamic dimension
/// size.
int64_t mlirShapedTypeGetDynamicSize() {
  const int64_t dynamicMarker = ShapedType::kDynamic;
  return dynamicMarker;
}
/// Reports whether `size` is the dynamic-size marker, as decided by
/// ShapedType::isDynamic.
bool mlirShapedTypeIsDynamicSize(int64_t size) {
  const bool isDynamic = ShapedType::isDynamic(size);
  return isDynamic;
}
/// Reports whether `val` is the dynamic marker for a stride or offset. This
/// uses the same ShapedType::isDynamic check as mlirShapedTypeIsDynamicSize.
bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val) {
  const bool isDynamic = ShapedType::isDynamic(val);
  return isDynamic;
}
/// Returns the value that marks a dynamic stride or offset. This is the same
/// ShapedType::kDynamic value returned by mlirShapedTypeGetDynamicSize.
int64_t mlirShapedTypeGetDynamicStrideOrOffset() {
  const int64_t dynamicMarker = ShapedType::kDynamic;
  return dynamicMarker;
}
//===----------------------------------------------------------------------===//
// Vector type.
//===----------------------------------------------------------------------===//
/// Returns the TypeID of the builtin VectorType, wrapped for the C API.
MlirTypeID mlirVectorTypeGetTypeID() {
  TypeID typeID = VectorType::getTypeID();
  return wrap(typeID);
}
/// Checks whether the given type is a builtin VectorType.
bool mlirTypeIsAVector(MlirType type) {
  Type cppType = unwrap(type);
  return llvm::isa<VectorType>(cppType);
}
/// Creates a VectorType from `rank` dimension sizes read from `shape` and the
/// given element type.
MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
                           MlirType elementType) {
  const llvm::ArrayRef<int64_t> dims(shape, static_cast<size_t>(rank));
  return wrap(VectorType::get(dims, unwrap(elementType)));
}
/// Same as mlirVectorTypeGet, but goes through VectorType::getChecked with
/// the given location.
MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
                                  const int64_t *shape, MlirType elementType) {
  const llvm::ArrayRef<int64_t> dims(shape, static_cast<size_t>(rank));
  return wrap(VectorType::getChecked(unwrap(loc), dims, unwrap(elementType)));
}
/// Creates a VectorType with per-dimension scalability flags. Both `shape`
/// and `scalable` are read as arrays of `rank` elements.
MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape,
                                   const bool *scalable, MlirType elementType) {
  const auto numDims = static_cast<size_t>(rank);
  const llvm::ArrayRef<int64_t> dims(shape, numDims);
  const llvm::ArrayRef<bool> scalableDims(scalable, numDims);
  return wrap(VectorType::get(dims, unwrap(elementType), scalableDims));
}
/// Same as mlirVectorTypeGetScalable, but goes through VectorType::getChecked
/// with the given location.
MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
                                          const int64_t *shape,
                                          const bool *scalable,
                                          MlirType elementType) {
  const auto numDims = static_cast<size_t>(rank);
  const llvm::ArrayRef<int64_t> dims(shape, numDims);
  const llvm::ArrayRef<bool> scalableDims(scalable, numDims);
  return wrap(VectorType::getChecked(unwrap(loc), dims, unwrap(elementType),
                                     scalableDims));
}
/// Reports whether a vector type is scalable. `type` is expected to be a
/// VectorType (it is converted with llvm::cast).
bool mlirVectorTypeIsScalable(MlirType type) {
  auto vectorType = llvm::cast<VectorType>(unwrap(type));
  return vectorType.isScalable();
}
/// Reports whether dimension `dim` of a vector type is scalable, by indexing
/// the type's scalable-dims list. `type` is expected to be a VectorType (it is
/// converted with llvm::cast).
bool mlirVectorTypeIsDimScalable(MlirType type, intptr_t dim) {
  auto vectorType = llvm::cast<VectorType>(unwrap(type));
  return vectorType.getScalableDims()[dim];
}
//===----------------------------------------------------------------------===//
// Ranked / Unranked tensor type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a TensorType.
bool mlirTypeIsATensor(MlirType type) {
  Type cppType = unwrap(type);
  return llvm::isa<TensorType>(cppType);
}
/// Returns the TypeID of RankedTensorType, wrapped for the C API.
MlirTypeID mlirRankedTensorTypeGetTypeID() {
  TypeID typeID = RankedTensorType::getTypeID();
  return wrap(typeID);
}
/// Checks whether the given type is a RankedTensorType.
bool mlirTypeIsARankedTensor(MlirType type) {
  Type cppType = unwrap(type);
  return llvm::isa<RankedTensorType>(cppType);
}
/// Returns the TypeID of UnrankedTensorType, wrapped for the C API.
MlirTypeID mlirUnrankedTensorTypeGetTypeID() {
  TypeID typeID = UnrankedTensorType::getTypeID();
  return wrap(typeID);
}
/// Checks whether the given type is an UnrankedTensorType.
bool mlirTypeIsAUnrankedTensor(MlirType type) {
  Type cppType = unwrap(type);
  return llvm::isa<UnrankedTensorType>(cppType);
}
/// Creates a RankedTensorType from `rank` dimension sizes read from `shape`,
/// the given element type and the given encoding attribute.
MlirType mlirRankedTensorTypeGet(intptr_t rank, const int64_t *shape,
                                 MlirType elementType, MlirAttribute encoding) {
  const llvm::ArrayRef<int64_t> dims(shape, static_cast<size_t>(rank));
  return wrap(
      RankedTensorType::get(dims, unwrap(elementType), unwrap(encoding)));
}
/// Same as mlirRankedTensorTypeGet, but goes through
/// RankedTensorType::getChecked with the given location.
MlirType mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank,
                                        const int64_t *shape,
                                        MlirType elementType,
                                        MlirAttribute encoding) {
  const llvm::ArrayRef<int64_t> dims(shape, static_cast<size_t>(rank));
  return wrap(RankedTensorType::getChecked(
      unwrap(loc), dims, unwrap(elementType), unwrap(encoding)));
}
/// Returns the encoding attribute of a ranked tensor type. `type` is expected
/// to be a RankedTensorType (it is converted with llvm::cast).
MlirAttribute mlirRankedTensorTypeGetEncoding(MlirType type) {
  auto tensorType = llvm::cast<RankedTensorType>(unwrap(type));
  return wrap(tensorType.getEncoding());
}
/// Creates an UnrankedTensorType with the given element type.
MlirType mlirUnrankedTensorTypeGet(MlirType elementType) {
  Type element = unwrap(elementType);
  return wrap(UnrankedTensorType::get(element));
}
/// Same as mlirUnrankedTensorTypeGet, but goes through
/// UnrankedTensorType::getChecked with the given location.
MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc,
                                          MlirType elementType) {
  Type element = unwrap(elementType);
  return wrap(UnrankedTensorType::getChecked(unwrap(loc), element));
}
/// Returns the element type of an unranked tensor type. `type` is expected to
/// be an UnrankedTensorType (it is converted with llvm::cast).
MlirType mlirUnrankedTensorTypeGetElementType(MlirType type) {
  auto tensorType = llvm::cast<UnrankedTensorType>(unwrap(type));
  return wrap(tensorType.getElementType());
}
//===----------------------------------------------------------------------===//
// Ranked / Unranked MemRef type.
//===----------------------------------------------------------------------===//
/// Returns the TypeID of the builtin MemRefType, wrapped for the C API.
MlirTypeID mlirMemRefTypeGetTypeID() {
  TypeID typeID = MemRefType::getTypeID();
  return wrap(typeID);
}
/// Checks whether the given type is a (ranked) MemRefType.
bool mlirTypeIsAMemRef(MlirType type) {
  Type cppType = unwrap(type);
  return llvm::isa<MemRefType>(cppType);
}
/// Creates a MemRefType from an element type, `rank` dimension sizes read
/// from `shape`, a layout attribute and a memory space attribute.
///
/// A null `layout` is forwarded as a default-constructed
/// MemRefLayoutAttrInterface; otherwise `layout` is converted to that
/// interface with llvm::cast.
MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank,
                           const int64_t *shape, MlirAttribute layout,
                           MlirAttribute memorySpace) {
  const llvm::ArrayRef<int64_t> dims(shape, static_cast<size_t>(rank));
  MemRefLayoutAttrInterface layoutAttr;
  if (!mlirAttributeIsNull(layout))
    layoutAttr = llvm::cast<MemRefLayoutAttrInterface>(unwrap(layout));
  return wrap(MemRefType::get(dims, unwrap(elementType), layoutAttr,
                              unwrap(memorySpace)));
}
/// Same as mlirMemRefTypeGet, but goes through MemRefType::getChecked with
/// the given location.
///
/// A null `layout` is forwarded as a default-constructed
/// MemRefLayoutAttrInterface; otherwise `layout` is converted to that
/// interface with llvm::cast.
MlirType mlirMemRefTypeGetChecked(MlirLocation loc, MlirType elementType,
                                  intptr_t rank, const int64_t *shape,
                                  MlirAttribute layout,
                                  MlirAttribute memorySpace) {
  const llvm::ArrayRef<int64_t> dims(shape, static_cast<size_t>(rank));
  MemRefLayoutAttrInterface layoutAttr;
  if (!mlirAttributeIsNull(layout))
    layoutAttr = llvm::cast<MemRefLayoutAttrInterface>(unwrap(layout));
  return wrap(MemRefType::getChecked(unwrap(loc), dims, unwrap(elementType),
                                     layoutAttr, unwrap(memorySpace)));
}
/// Creates a MemRefType from an element type, `rank` dimension sizes read
/// from `shape` and a memory space, passing a default-constructed
/// MemRefLayoutAttrInterface as the layout.
MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank,
                                     const int64_t *shape,
                                     MlirAttribute memorySpace) {
  const llvm::ArrayRef<int64_t> dims(shape, static_cast<size_t>(rank));
  const MemRefLayoutAttrInterface defaultLayout;
  return wrap(MemRefType::get(dims, unwrap(elementType), defaultLayout,
                              unwrap(memorySpace)));
}
/// Same as mlirMemRefTypeContiguousGet, but goes through
/// MemRefType::getChecked with the given location.
MlirType mlirMemRefTypeContiguousGetChecked(MlirLocation loc,
                                            MlirType elementType, intptr_t rank,
                                            const int64_t *shape,
                                            MlirAttribute memorySpace) {
  const llvm::ArrayRef<int64_t> dims(shape, static_cast<size_t>(rank));
  const MemRefLayoutAttrInterface defaultLayout;
  return wrap(MemRefType::getChecked(unwrap(loc), dims, unwrap(elementType),
                                     defaultLayout, unwrap(memorySpace)));
}
/// Returns the layout attribute of a memref type. `type` is expected to be a
/// MemRefType (it is converted with llvm::cast).
MlirAttribute mlirMemRefTypeGetLayout(MlirType type) {
  auto memRefType = llvm::cast<MemRefType>(unwrap(type));
  return wrap(memRefType.getLayout());
}
/// Returns the affine map obtained from the layout of a memref type. `type`
/// is expected to be a MemRefType (it is converted with llvm::cast).
MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type) {
  auto memRefType = llvm::cast<MemRefType>(unwrap(type));
  return wrap(memRefType.getLayout().getAffineMap());
}
/// Returns the memory space attribute of a memref type. `type` is expected to
/// be a MemRefType (it is converted with llvm::cast).
MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) {
  auto memRefType = llvm::cast<MemRefType>(unwrap(type));
  return wrap(memRefType.getMemorySpace());
}
/// Computes the strides and offset of a memref type via getStridesAndOffset.
///
/// The offset is written through `offset`. On success every computed stride
/// is copied into the buffer starting at `strides`, so that buffer must be
/// large enough to receive them; on failure `strides` is not written.
/// `type` is expected to be a MemRefType (it is converted with llvm::cast).
/// Returns a failure result when getStridesAndOffset fails, and success
/// otherwise.
MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(MlirType type,
                                                    int64_t *strides,
                                                    int64_t *offset) {
  auto memRefType = llvm::cast<MemRefType>(unwrap(type));
  SmallVector<int64_t> computedStrides;
  LogicalResult status =
      getStridesAndOffset(memRefType, computedStrides, *offset);
  if (failed(status))
    return mlirLogicalResultFailure();
  std::copy_n(computedStrides.begin(), computedStrides.size(), strides);
  return mlirLogicalResultSuccess();
}
/// Returns the TypeID of UnrankedMemRefType, wrapped for the C API.
MlirTypeID mlirUnrankedMemRefTypeGetTypeID() {
  TypeID typeID = UnrankedMemRefType::getTypeID();
  return wrap(typeID);
}
/// Checks whether the given type is an UnrankedMemRefType.
bool mlirTypeIsAUnrankedMemRef(MlirType type) {
  Type cppType = unwrap(type);
  return llvm::isa<UnrankedMemRefType>(cppType);
}
/// Creates an UnrankedMemRefType with the given element type and memory
/// space attribute.
MlirType mlirUnrankedMemRefTypeGet(MlirType elementType,
                                   MlirAttribute memorySpace) {
  Type element = unwrap(elementType);
  Attribute space = unwrap(memorySpace);
  return wrap(UnrankedMemRefType::get(element, space));
}
/// Same as mlirUnrankedMemRefTypeGet, but goes through
/// UnrankedMemRefType::getChecked with the given location.
MlirType mlirUnrankedMemRefTypeGetChecked(MlirLocation loc,
                                          MlirType elementType,
                                          MlirAttribute memorySpace) {
  Type element = unwrap(elementType);
  Attribute space = unwrap(memorySpace);
  return wrap(UnrankedMemRefType::getChecked(unwrap(loc), element, space));
}
/// Returns the memory space attribute of an unranked memref type. `type` is
/// expected to be an UnrankedMemRefType (it is converted with llvm::cast).
MlirAttribute mlirUnrankedMemrefGetMemorySpace(MlirType type) {
  auto memRefType = llvm::cast<UnrankedMemRefType>(unwrap(type));
  return wrap(memRefType.getMemorySpace());
}
//===----------------------------------------------------------------------===//
// Tuple type.
//===----------------------------------------------------------------------===//
/// Returns the TypeID of the builtin TupleType, wrapped for the C API.
MlirTypeID mlirTupleTypeGetTypeID() {
  TypeID typeID = TupleType::getTypeID();
  return wrap(typeID);
}
/// Checks whether the given type is a builtin TupleType.
bool mlirTypeIsATuple(MlirType type) {
  Type cppType = unwrap(type);
  return llvm::isa<TupleType>(cppType);
}
/// Creates a TupleType in `ctx` from `numElements` element types read from
/// `elements`.
MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements,
                          MlirType const *elements) {
  // Backing storage handed to unwrapList for the unwrapped element types.
  SmallVector<Type, 4> elementStorage;
  ArrayRef<Type> elementTypes =
      unwrapList(numElements, elements, elementStorage);
  MLIRContext *context = unwrap(ctx);
  return wrap(TupleType::get(context, elementTypes));
}
/// Returns the number of element types in a tuple type. `type` is expected to
/// be a TupleType (it is converted with llvm::cast).
intptr_t mlirTupleTypeGetNumTypes(MlirType type) {
  auto tupleType = llvm::cast<TupleType>(unwrap(type));
  return tupleType.size();
}
/// Returns the element type at position `pos` of a tuple type. `type` is
/// expected to be a TupleType (it is converted with llvm::cast).
MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos) {
  auto tupleType = llvm::cast<TupleType>(unwrap(type));
  const auto index = static_cast<size_t>(pos);
  return wrap(tupleType.getType(index));
}
//===----------------------------------------------------------------------===//
// Function type.
//===----------------------------------------------------------------------===//
/// Returns the TypeID of the builtin FunctionType, wrapped for the C API.
MlirTypeID mlirFunctionTypeGetTypeID() {
  TypeID typeID = FunctionType::getTypeID();
  return wrap(typeID);
}
/// Checks whether the given type is a builtin FunctionType.
bool mlirTypeIsAFunction(MlirType type) {
  Type cppType = unwrap(type);
  return llvm::isa<FunctionType>(cppType);
}
/// Creates a FunctionType in `ctx` from `numInputs` input types read from
/// `inputs` and `numResults` result types read from `results`.
MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs,
                             MlirType const *inputs, intptr_t numResults,
                             MlirType const *results) {
  // Backing storage handed to unwrapList for the unwrapped types.
  SmallVector<Type, 4> inputStorage;
  SmallVector<Type, 4> resultStorage;
  ArrayRef<Type> inputTypes = unwrapList(numInputs, inputs, inputStorage);
  ArrayRef<Type> resultTypes = unwrapList(numResults, results, resultStorage);
  return wrap(FunctionType::get(unwrap(ctx), inputTypes, resultTypes));
}
/// Returns the number of inputs of a function type. `type` is expected to be
/// a FunctionType (it is converted with llvm::cast).
intptr_t mlirFunctionTypeGetNumInputs(MlirType type) {
  auto functionType = llvm::cast<FunctionType>(unwrap(type));
  return functionType.getNumInputs();
}
/// Returns the number of results of a function type. `type` is expected to be
/// a FunctionType (it is converted with llvm::cast).
intptr_t mlirFunctionTypeGetNumResults(MlirType type) {
  auto functionType = llvm::cast<FunctionType>(unwrap(type));
  return functionType.getNumResults();
}
/// Returns the input type at position `pos` of a function type.
///
/// `type` is expected to be a FunctionType (it is converted with llvm::cast).
/// `pos` must be non-negative because it is narrowed to `unsigned` before
/// being passed to FunctionType::getInput.
MlirType mlirFunctionTypeGetInput(MlirType type, intptr_t pos) {
  // Zero is a valid position, so the requirement is "non-negative", not
  // "positive" as the previous message stated.
  assert(pos >= 0 && "pos in array must be non-negative");
  return wrap(llvm::cast<FunctionType>(unwrap(type))
                  .getInput(static_cast<unsigned>(pos)));
}
/// Returns the result type at position `pos` of a function type.
///
/// `type` is expected to be a FunctionType (it is converted with llvm::cast).
/// `pos` must be non-negative because it is narrowed to `unsigned` before
/// being passed to FunctionType::getResult.
MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos) {
  // Zero is a valid position, so the requirement is "non-negative", not
  // "positive" as the previous message stated.
  assert(pos >= 0 && "pos in array must be non-negative");
  return wrap(llvm::cast<FunctionType>(unwrap(type))
                  .getResult(static_cast<unsigned>(pos)));
}
//===----------------------------------------------------------------------===//
// Opaque type.
//===----------------------------------------------------------------------===//
/// Returns the TypeID of the builtin OpaqueType, wrapped for the C API.
MlirTypeID mlirOpaqueTypeGetTypeID() {
  TypeID typeID = OpaqueType::getTypeID();
  return wrap(typeID);
}
/// Checks whether the given type is a builtin OpaqueType.
bool mlirTypeIsAOpaque(MlirType type) {
  Type cppType = unwrap(type);
  return llvm::isa<OpaqueType>(cppType);
}
/// Creates an OpaqueType from a dialect namespace and type data. The
/// namespace is first turned into a StringAttr created in `ctx`.
MlirType mlirOpaqueTypeGet(MlirContext ctx, MlirStringRef dialectNamespace,
                           MlirStringRef typeData) {
  StringAttr dialectAttr =
      StringAttr::get(unwrap(ctx), unwrap(dialectNamespace));
  return wrap(OpaqueType::get(dialectAttr, unwrap(typeData)));
}
/// Returns the dialect namespace of an opaque type as a string reference.
/// `type` is expected to be an OpaqueType (it is converted with llvm::cast).
MlirStringRef mlirOpaqueTypeGetDialectNamespace(MlirType type) {
  auto opaqueType = llvm::cast<OpaqueType>(unwrap(type));
  return wrap(opaqueType.getDialectNamespace().strref());
}
/// Returns the type data of an opaque type as a string reference. `type` is
/// expected to be an OpaqueType (it is converted with llvm::cast).
MlirStringRef mlirOpaqueTypeGetData(MlirType type) {
  auto opaqueType = llvm::cast<OpaqueType>(unwrap(type));
  return wrap(opaqueType.getTypeData());
}